In [None]:
import medmnist
import torch
import numpy as np

from configuration import config_jup
from utils.data_loader import get_loader_with_assignment
from utils.train_utils import get_logger, initialize_model, compute_avg_acc_for, save_results
from utils.utils_memory import Memory
from utils.cl_utils import Client

In [None]:
args = config_jup.base_parser()
client_id = 0

# Use the first CUDA device when one is visible; otherwise run on the CPU.
if not torch.cuda.is_available():
    args.device = 'cpu'
else:
    args.cuda = True
    args.device = 'cuda:0'

In [None]:
# Override parser defaults here if needed (the entries below are an example).
default_overrides = {
    'dataset_name': 'cifar10',
    'memory_size': 200,
}
for arg_name, arg_value in default_overrides.items():
    setattr(args, arg_name, arg_value)

logger = get_logger(args)
print(args)

In [None]:
def _train_on_batch(client, samples, labels):
    """Run one training step on a batch, picking the update rule from the config.

    Without memory (`args.with_memory` falsy) `client.train` is used. With memory,
    the first task (task_id 0) uses `train_with_update` and every later task uses
    `train_with_memory`.
    """
    if not args.with_memory:
        client.train(samples, labels)
    elif client.task_id == 0:
        client.train_with_update(samples, labels)
    else:
        client.train_with_memory(samples, labels)


def _evaluate_finished_task(client, logger, run):
    """Record train loss, test, validation and forgetting metrics for the task that just ended.

    The `...MC` client methods are used when `args.model_name == 'resnetmc'`.
    Returns the updated logger.
    """
    print(f'Run {run} - Client {client.client_id} - Task {client.task_id} completed - {client.get_current_task()}')
    # compute loss train
    logger = client.compute_loss(logger, run)
    print(f'Run {run} - Client {client.client_id} - Test time - Task {client.task_id}')
    if args.model_name == 'resnetmc':
        logger = client.testMC(logger, run)
        logger = client.validationMC(logger, run)
    else:
        logger = client.test(logger, run)
        logger = client.validation(logger, run)
    return client.forgetting(logger, run)


for run in range(args.n_runs):
    # Class-to-task assignment for this run: for MedMNIST datasets the loader
    # returns it; otherwise it is a seeded random permutation of the class ids.
    if args.dataset_name in medmnist.INFO.keys():
        loader_client, cls_assignment = get_loader_with_assignment(args, None, None)
        print(cls_assignment)
    else:
        np.random.seed(run)
        cls_assignment = np.arange(args.n_classes)
        np.random.shuffle(cls_assignment)
        loader_client, _ = get_loader_with_assignment(args, cls_assignment.tolist(), run)

    # for reproducibility purposes
    np.random.seed(run)
    torch.manual_seed(run)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model, optimizer, criterion = initialize_model(args)
    memory_client = Memory(args)
    client = Client(args, loader_client, model, optimizer, criterion, memory_client, client_id, cls_assignment)

    while not client.train_completed:
        samples, labels = client.get_next_batch()
        if samples is not None:
            _train_on_batch(client, samples, labels)
            continue

        # A `None` batch signals the end of the current task: evaluate, then
        # either advance to the next task or finish training.
        logger = _evaluate_finished_task(client, logger, run)
        if client.task_id + 1 < args.n_tasks:
            client.task_id += 1
        else:
            client.train_completed = True
            print(f'Run {run} - Client {client.client_id} - Train completed')

    balanced_accuracy_fn = client.balanced_accuracyMC if args.model_name == 'resnetmc' else client.balanced_accuracy
    logger = balanced_accuracy_fn(logger, run)
    print()

    print(logger['test']['acc'][client_id][run])
    print()
    print()

In [None]:
mean_acc, std_acc, mean_for, std_for = compute_avg_acc_for(args, logger)

# Report as percentages rounded to 2 decimals; the mean is padded to width 5.
acc_mean_pct, acc_std_pct = np.round(mean_acc * 100, 2), np.round(std_acc * 100, 2)
for_mean_pct, for_std_pct = np.round(mean_for * 100, 2), np.round(std_for * 100, 2)
print(f'Final accuracy: {acc_mean_pct:5} (+-) {acc_std_pct}')
print(f'Final forgetting: {for_mean_pct:5} (+-) {for_std_pct}')
print()

In [None]:
# Save the training results: passes `args` and the metrics accumulated in `logger`
# over all runs to utils.train_utils.save_results (the output location is decided there).
save_results(args, logger)

### Show the images stored in the client's memory

In [None]:
import os
import matplotlib.pyplot as plt
from torchvision import transforms
from utils.data_loader import get_statistics


mean, std, n_classes, inp_size, in_channels = get_statistics(args)

# Inverse of transforms.Normalize (x_norm = (x - mean) / std), done in two passes:
#   1) (x_norm - 0) / (1 / std)  ->  x_norm * std
#   2) (y - (-mean)) / 1         ->  y + mean
mean_arr, std_arr = np.asarray(mean), np.asarray(std)
invTrans = transforms.Compose([
    transforms.Normalize(mean=mean_arr * 0, std=1 / std_arr),
    transforms.Normalize(mean=-mean_arr, std=std_arr / std_arr),
])

def show_images(args, imgs, class_id, n_cols=10):
    """Plot a collection of normalized images (e.g. one class of the memory) in a grid.

    Each image is un-normalized with the module-level ``invTrans``, moved to the
    CPU and converted with ``ToPILImage`` before plotting.

    Args:
        args: run configuration; ``dataset_name``, ``memory_size``,
            ``uncertainty_score`` and ``balanced_step`` form the directory path
            created below.
        imgs: indexable collection of image tensors (presumably C x H x W as
            expected by ``ToPILImage`` -- confirm against ``Memory.x``).
        class_id: class label; used in the directory path and messages only.
        n_cols: images per row in grid mode (default 10, the original layout).

    Layout: with fewer than ``2 * n_cols`` images everything goes on one row;
    otherwise a grid of ``n_cols`` columns with as many rows as needed,
    including a trailing partial row whose unused cells are hidden.
    An empty ``imgs`` prints a message and returns without plotting.

    NOTE(review): ``dir_plot`` is created but the figure is only shown, never
    saved into it.
    """
    dir_plot = f'./images/{args.dataset_name}/{args.memory_size}/{args.uncertainty_score}/{args.balanced_step}/{class_id}'
    os.makedirs(dir_plot, exist_ok=True)

    n_imgs = len(imgs)
    if n_imgs == 0:
        # plt.subplots(ncols=0) raises; report the empty selection instead.
        print(f'No images to show for class {class_id}')
        return

    if n_imgs // n_cols > 1:
        # Ceil division so a trailing partial row gets its own row of axes
        # (a floor-based row count indexes past the end of `axs`).
        n_rows = -(-n_imgs // n_cols)
        fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False, figsize=(5, n_rows / 2))
    else:
        fig, axs = plt.subplots(ncols=n_imgs, squeeze=False)

    # axs.flat walks the grid row-major, so position k shows imgs[k].
    for img_idx, ax in enumerate(axs.flat):
        if img_idx >= n_imgs:
            # Unused cells of the last grid row: hide the empty frame.
            ax.axis('off')
            continue
        img = transforms.ToPILImage()(invTrans(imgs[img_idx]).to('cpu'))
        ax.imshow(np.asarray(img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    fig.subplots_adjust(hspace=0, wspace=0)
    plt.show()


In [None]:
class_id = 2  # class whose stored samples to display; change to inspect another class
in_class = client.memory.y == class_id
mem_class = client.memory.x[in_class]
show_images(args, mem_class, class_id)

### Create a long-tailed (LT) version of CIFAR-10

In [None]:
import pickle
from torchvision import datasets, transforms
from utils.data_loader import get_data_per_class

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# CIFAR-10 training set (downloaded if missing); in this cell only its length is used.
# NOTE(review): the download may also be what populates ./data/raw/ for
# get_data_per_class -- confirm before removing it.
train = datasets.CIFAR10('./data/raw/', train=True, download=True, transform=data_transforms)
max_num = len(train) / args.n_classes
imb_factor = 0.1

# Exponential long-tail profile: position idx keeps the fraction
# imb_factor ** (idx / (n_classes - 1)), i.e. 1.0 for the first position down to
# imb_factor for the last. max_num cancels out; the multiply-then-divide is kept
# so the floats are exactly those of the original computation.
w_per_cls = [
    (max_num * imb_factor ** (idx / (args.n_classes - 1))) / max_num
    for idx in range(args.n_classes)
]

In [None]:
ds_dict = get_data_per_class(args)
skip = args.n_classes_per_task


def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards.

    NOTE: pickle can execute arbitrary code on load; only use it on files this
    project wrote itself.
    """
    with open(path, 'rb') as infile:
        return pickle.load(infile)


def _save_pickle(obj, path):
    """Pickle `obj` to `path`; the `with` block closes the file."""
    with open(path, 'wb') as outfile:
        pickle.dump(obj, outfile)


def _build_lt_split(ds, cls_assignment):
    """Return one (x, y) pair per task with every class truncated to its long-tail share.

    Args:
        ds: indexable by class id, returning that class's (x, y) tensors.
        cls_assignment: class ids in task order; each task takes `skip`
            consecutive entries.

    The class at position p of `cls_assignment` keeps its first
    int(w_per_cls[p] * n) of n samples, so the imbalance follows the task order
    rather than the class id.
    """
    split_ds = []
    for i in range(0, args.n_classes, skip):
        task_x, task_y = [], []
        for weight, class_id in zip(w_per_cls[i:i + skip], cls_assignment[i:i + skip]):
            class_x, class_y = ds[class_id]
            num_per_class = int(weight * len(class_y))
            task_x.append(class_x[:num_per_class])
            task_y.append(class_y[:num_per_class])
        split_ds.append((torch.cat(task_x), torch.cat(task_y)))
    return split_ds


def _make_task_loaders(split_ds):
    """Wrap each task's (x, y) pair in a DataLoader after a one-off random permutation.

    The permutation uses NumPy's global RNG (seeded per run by the caller).
    drop_last=True discards each task's final partial batch.
    """
    loaders = []
    for images, label in split_ds:
        indices = torch.from_numpy(np.random.choice(images.size(0), images.size(0), replace=False))
        task_ds = torch.utils.data.TensorDataset(images[indices], label[indices])
        loaders.append(torch.utils.data.DataLoader(task_ds, batch_size=args.batch_size, drop_last=True))
    return loaders


for run in range(args.n_runs):
    # Seed per run, as the training cell does, so the permutations drawn in
    # _make_task_loaders (and therefore the saved splits) are reproducible.
    np.random.seed(run)

    # Reuse the class order of the existing (balanced) split for this run.
    dir_input = os.path.join(args.dir_data, 'data_splits', 'CL', args.dataset_name, f'run{run}')
    cls_assignment = _load_pickle(os.path.join(dir_input, f'{args.dataset_name}_cls_assignment.pkl'))
    print(cls_assignment)

    # NOTE(review): the truncation is applied to every split in ds_dict (train,
    # val and test alike) -- confirm the evaluation splits should be long-tailed too.
    ds_out = {name_ds: _build_lt_split(ds, cls_assignment) for name_ds, ds in ds_dict.items()}
    loader_list = [_make_task_loaders(ds_out[name_ds]) for name_ds in ('train', 'val', 'test')]

    dir_output = os.path.join(args.dir_data, 'data_splits', 'CL', f'{args.dataset_name}LT', f'run{run}')
    # exist_ok: the run directory may already exist even if the split file does not.
    os.makedirs(dir_output, exist_ok=True)

    # save data splits and cls_assignment
    _save_pickle(loader_list, os.path.join(dir_output, f'{args.dataset_name}LT_split.pkl'))
    _save_pickle(cls_assignment, os.path.join(dir_output, f'{args.dataset_name}LT_cls_assignment.pkl'))