In [1]:
import torch
import numpy as np

from configuration import config_jup
from utils.data_loader import get_loader_all_clients
from utils.train_utils import get_logger, initialize_clients, FedAvg, weightedFedAvg, test_global_model, save_results

In [2]:
# Build the experiment configuration and pick the compute device:
# the first CUDA GPU when one is available, the CPU otherwise.
args = config_jup.base_parser()
if not torch.cuda.is_available():
    args.device = 'cpu'
else:
    args.cuda = True
    args.device = 'cuda:0'


In [3]:
# Restrict this notebook session to a single run.
args.n_runs = 1
# get_logger is a project helper (utils.train_utils); the object it returns is
# indexed like a nested dict later in the notebook (e.g. logger['test']['acc']).
logger = get_logger(args)
# Print the full resolved configuration so the run is self-describing.
print(args)

Namespace(framework='FCL', dir_data='./data/', dir_output='./output/', dataset_name='cifar10', model_name='resnet', batch_size=10, lr=0.1, optimizer='sgd', local_epochs=1, n_runs=1, n_tasks=5, with_memory=1, memory_size=500, update_strategy='balanced', sampling_strategy='random', balanced_update='random', uncertainty_score='bregman', subsample_size=50, balanced_step='bottomk', n_clients=5, overlap='non-overlap', burnin=30, jump=5, fl_update='w_favg', cuda=True, device='cuda:0', input_size=(3, 32, 32), n_classes=10, n_classes_per_task=2, dir_results='./output//FCL/cifar10/w_favg/non-overlap/30/5/resnet/sgd/01/500/10/1/random/balanced_random/')


In [4]:
# Federated continual-learning training loop.
#
# For every run: each client consumes its task stream one mini-batch at a time;
# after every sweep over the clients a communication round may aggregate the
# local models into a global model, which is evaluated once the run is over.
#
# NOTE: later cells read the loop variables `client`, `client_id` and `run`
# that this cell leaves behind, so their names are kept as they are.
for run in range(args.n_runs):
    loader_clients, cls_assignment_list, global_test_loader = get_loader_all_clients(args, run)
    clients = initialize_clients(args, loader_clients, cls_assignment_list, run)
    # Reset per run: without this, a run with no communication round would hit
    # a NameError at evaluation time, or silently evaluate a model left over
    # from a previous run / a previous execution of this cell.
    global_model = None

    while not all(client.train_completed for client in clients):
        for client in clients:
            if client.train_completed:
                continue

            samples, labels = client.get_next_batch()

            if samples is not None:
                # A batch is available: one local training step.
                if not args.with_memory:
                    client.train(samples, labels)
                elif client.task_id == 0:
                    # First task uses train_with_update; later tasks use
                    # train_with_memory.
                    client.train_with_update(samples, labels)
                else:
                    client.train_with_memory(samples, labels)
                continue

            # get_next_batch returned no samples: the current task is over.
            print(f'Run {run} - Client {client.client_id} - Task {client.task_id} completed - {client.get_current_task()}')
            # compute loss train
            logger = client.compute_loss(logger, run)
            print(f'Run {run} - Client {client.client_id} - Test time - Task {client.task_id}')
            logger = client.test(logger, run)
            logger = client.validation(logger, run)
            logger = client.forgetting(logger, run)

            if client.task_id + 1 >= args.n_tasks:
                # That was the last task for this client.
                client.train_completed = True
                print(f'Run {run} - Client {client.client_id} - Train completed')
                logger = client.balanced_accuracy(logger, run)
            else:
                client.task_id += 1

        # COMMUNICATION ROUND PART
        # Eligible clients: still training, past the burn-in number of batches,
        # and currently at a multiple of `jump` batches.
        selected_clients = [
            client.client_id
            for client in clients
            if client.num_batches >= args.burnin
            and client.num_batches % args.jump == 0
            and not client.train_completed
        ]
        # Aggregation only happens when at least two clients are eligible.
        if len(selected_clients) > 1:
            if args.fl_update == 'favg':
                global_model = FedAvg(args, selected_clients, clients)
            elif args.fl_update == 'w_favg':
                global_model = weightedFedAvg(args, selected_clients, clients)
            else:
                # Fail loudly instead of continuing with an undefined or stale model.
                raise ValueError(
                    f"Unsupported fl_update {args.fl_update!r}; expected 'favg' or 'w_favg'"
                )

            global_parameters = global_model.state_dict()
            # local models update with averaged global parameters
            # (indexing `clients` by client_id assumes ids match list positions
            # -- TODO confirm against initialize_clients)
            for client_id in selected_clients:
                clients[client_id].update_parameters(global_parameters)
                clients[client_id].save_last_global_model(global_model)

    # global model accuracy when all clients finish their training on all tasks (FedCIL ICLR2023)
    if global_model is None:
        print(f'Run {run} - no communication round took place; skipping global model test')
    else:
        logger = test_global_model(args, global_test_loader, global_model, logger, run)

Run 0 - Client 0 - Task 0 completed - (2, 9)
Run 0 - Client 0 - Test time - Task 0
Run 0 - Client 1 - Task 0 completed - (4, 1)
Run 0 - Client 1 - Test time - Task 0
Run 0 - Client 2 - Task 0 completed - (5, 4)
Run 0 - Client 2 - Test time - Task 0
Run 0 - Client 3 - Task 0 completed - (3, 8)
Run 0 - Client 3 - Test time - Task 0
Run 0 - Client 4 - Task 0 completed - (9, 5)
Run 0 - Client 4 - Test time - Task 0
Run 0 - Client 0 - Task 1 completed - (6, 4)
Run 0 - Client 0 - Test time - Task 1
Run 0 - Client 1 - Task 1 completed - (5, 0)
Run 0 - Client 1 - Test time - Task 1
Run 0 - Client 2 - Task 1 completed - (1, 2)
Run 0 - Client 2 - Test time - Task 1
Run 0 - Client 3 - Task 1 completed - (4, 9)
Run 0 - Client 3 - Test time - Task 1
Run 0 - Client 4 - Task 1 completed - (2, 4)
Run 0 - Client 4 - Test time - Task 1
Run 0 - Client 0 - Task 2 completed - (0, 3)
Run 0 - Client 0 - Test time - Task 2
Run 0 - Client 1 - Task 2 completed - (7, 2)
Run 0 - Client 1 - Test time - Task 2
Run 

In [9]:
# `client_id` is not defined in this cell: it holds whatever value another cell
# left behind (both the training loop and the per-client summary loop assign
# it), so the client displayed here depends on execution order.
clients[client_id].task_list

[(9, 5), (2, 4), (7, 1), (0, 8), (6, 3)]

In [10]:
# Per-client summary: task order, accuracy matrix, and final metrics.
for client_id in range(args.n_clients):
    print(f'Client {client_id}: {clients[client_id].task_list}')

    # Accuracy averaged over axis 0 (presumably the runs axis -- confirm
    # against the logger layout); computed once and reused below.
    acc_matrix = np.mean(logger['test']['acc'][client_id], 0)
    print(acc_matrix)

    # Last row of the matrix (index n_tasks-1), averaged over its entries.
    final_acc = np.mean(acc_matrix[args.n_tasks-1,:], 0)
    print(f'Final client accuracy: {final_acc}')

    final_forgetting = np.mean(logger["test"]["forget"][client_id])
    print(f'Final client forgetting: {final_forgetting}')

    final_bal_acc = np.mean(logger["test"]["bal_acc"][client_id])
    print(f'Final client balanced accuracy: {final_bal_acc}')
    print()

Client 0: [(2, 9), (6, 4), (0, 3), (1, 7), (8, 5)]
[[0.5        0.         0.         0.         0.        ]
 [0.005      0.57249999 0.         0.         0.        ]
 [0.         0.65249997 0.         0.         0.        ]
 [0.         0.50749999 0.         0.29249999 0.        ]
 [0.0825     0.66749996 0.         0.0875     0.55250001]]
Final client accuracy: 0.2779999926686287
Final client forgetting: 0.15187500230967999
Final client balanced accuracy: 0.278

Client 1: [(4, 1), (5, 0), (7, 2), (3, 6), (9, 8)]
[[0.85249996 0.         0.         0.         0.        ]
 [0.69499999 0.         0.         0.         0.        ]
 [0.4025     0.         0.32999998 0.         0.        ]
 [0.11       0.         0.         0.5025     0.        ]
 [0.26249999 0.         0.02       0.47749999 0.27250001]]
Final client accuracy: 0.20649999752640724
Final client forgetting: 0.23124999087303877
Final client balanced accuracy: 0.20650000000000004

Client 2: [(5, 4), (1, 2), (9, 6), (7, 0), (3, 8)

In [None]:
# Save the training results collected in `logger` through the project helper
# save_results (where they are written is decided inside the helper --
# presumably under args.dir_results; confirm).
save_results(args, logger)

### Show images in the memory

In [None]:
import os
import matplotlib.pyplot as plt
from torchvision import transforms
from utils.data_loader import get_statistics


mean, std, n_classes, inp_size, in_channels = get_statistics(args)

# Inverse of a Normalize(mean, std) transform, built from two Normalize steps
# (each step computes (x - mean) / std):
#   1. mean = 0,     std = 1/std  ->  x * std
#   2. mean = -mean, std = 1      ->  x + mean
undo_scale = transforms.Normalize(mean=np.dot(0, mean), std=np.divide(1, std))
undo_shift = transforms.Normalize(mean=np.dot(-1, mean), std=np.divide(std, std))
invTrans = transforms.Compose([undo_scale, undo_shift])

def show_images(args, imgs, class_id):
    """Display the given images in a tick-free matplotlib grid.

    Every image is passed through the module-level ``invTrans`` and converted
    to a PIL image before being drawn.

    Layout: with 20 or more images a grid of 10 columns is used, with as many
    rows as needed (a partially filled last row is allowed and its unused
    cells are hidden); with fewer images they are drawn in a single row.

    Args:
        args: configuration namespace; ``dataset_name``, ``memory_size``,
            ``uncertainty_score`` and ``balanced_step`` are read to build the
            plot directory path.
        imgs: sequence of image tensors (presumably normalized CHW tensors as
            stored in the client memory -- confirm). May be empty.
        class_id: class label; only used in the plot directory path.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    dir_plot = f'./images/{args.dataset_name}/{args.memory_size}/{args.uncertainty_score}/{args.balanced_step}/{class_id}'
    # The directory is created for this configuration, but nothing in this
    # function writes into it. exist_ok avoids the check-then-create race.
    os.makedirs(dir_plot, exist_ok=True)

    n_imgs = len(imgs)
    if n_imgs == 0:
        # plt.subplots(ncols=0) would raise; there is nothing to draw anyway.
        print(f'No images to show for class {class_id}')
        return

    n_cols = 10
    if n_imgs // n_cols > 1:
        # Ceiling division: a trailing partial row needs its own row of axes.
        # Using the floored row count here made the drawing loop index one row
        # past the end of `axs` whenever n_imgs was not a multiple of 10.
        n_rows = -(-n_imgs // n_cols)
        fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False, figsize=(5, n_rows/2))
    else:
        fig, axs = plt.subplots(ncols=n_imgs, squeeze=False)

    to_pil = transforms.ToPILImage()
    cells = axs.ravel()  # row-major order: image k lands at (k // ncols, k % ncols)
    for ax, img in zip(cells, imgs):
        pil_img = to_pil(invTrans(img).to('cpu'))
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    # Hide grid cells that did not receive an image.
    for ax in cells[n_imgs:]:
        ax.axis('off')

    plt.subplots_adjust(hspace=0, wspace=0)
    plt.show()


In [None]:
# Pick the class to inspect (any class id works).
class_id = 2
# `client` is not defined in this cell: it is the loop variable left behind by
# the training loop, i.e. the last client that loop iterated over.
class_mask = client.memory.y == class_id
mem_class = client.memory.x[class_mask]
show_images(args, mem_class, class_id)

### Create the long-tailed (LT) version of CIFAR-10

In [None]:
import pickle
from torchvision import datasets, transforms
from utils.data_loader import get_data_per_class

# CIFAR-10 training set with per-channel normalization.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train = datasets.CIFAR10('./data/raw/', train=True, download=True, transform=data_transforms)

# Sample count per class if the training set were split evenly.
max_num = len(train) / args.n_classes
imb_factor = 0.1

# Fraction of samples kept at each class position: decays exponentially from
# 1 at position 0 down to imb_factor at the last position. The arithmetic
# mirrors "num = max_num * decay; w = num / max_num" so the floats are unchanged.
w_per_cls = [
    (max_num * (imb_factor ** (position / (args.n_classes - 1)))) / max_num
    for position in range(args.n_classes)
]

In [None]:
# Build long-tailed (LT) task splits from the per-class data and save them.
#
# For every run, the class order of the existing split is loaded, each class is
# truncated to a fraction of its samples, classes are grouped into tasks of
# `skip` classes, and the resulting loaders plus the class order are pickled
# into a sibling "<dataset>LT" directory.
#
# Requires `w_per_cls` (defined in an earlier cell) and `os` to be in scope.
ds_dict = get_data_per_class(args)
skip = args.n_classes_per_task

for run in range(args.n_runs):
    # Class order of the existing (non-LT) split for this run.
    src_dir = f'{args.dir_data}/data_splits/CL/{args.dataset_name}/run{run}/'
    src_cls_assignment_fn = f'{src_dir}/{args.dataset_name}_cls_assignment.pkl'
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by this project. The `with` block closes the handle promptly.
    with open(src_cls_assignment_fn, 'rb') as infile:
        cls_assignment = pickle.load(infile)
    print(cls_assignment)

    # for each data split (i.e., train/val/test)
    ds_out = {}
    for name_ds, ds in ds_dict.items():
        split_ds = []
        for i in range(0, args.n_classes, skip):
            # The keep-fraction is chosen by the class's POSITION in
            # cls_assignment, not by its label.
            w_list = w_per_cls[i:i+skip]
            t_list = cls_assignment[i:i+skip]
            task_ds_tmp_x = []
            task_ds_tmp_y = []
            for idx, class_id in enumerate(t_list):
                class_x, class_y = ds[class_id]
                # Keep the first num_per_class samples of the class.
                num_per_class = int(w_list[idx]*len(class_y))
                task_ds_tmp_x.append(class_x[:num_per_class])
                task_ds_tmp_y.append(class_y[:num_per_class])

            task_ds_x = torch.cat(task_ds_tmp_x)
            task_ds_y = torch.cat(task_ds_tmp_y)
            split_ds += [(task_ds_x, task_ds_y)]
        ds_out[name_ds] = split_ds

    ds_list = [ds_out['train'], ds_out['val'], ds_out['test']]
    loader_list = []
    for ds in ds_list:
        loader_tmp = []
        for task_data in ds:
            images, label = task_data
            # Shuffle once with a random permutation; the DataLoader itself
            # is created without shuffling.
            indices = torch.from_numpy(np.random.choice(images.size(0), images.size(0), replace=False))
            images = images[indices]
            label = label[indices]
            task_ds = torch.utils.data.TensorDataset(images, label)
            task_loader = torch.utils.data.DataLoader(task_ds, batch_size=args.batch_size, drop_last=True)
            loader_tmp.append(task_loader)
        loader_list.append(loader_tmp)

    dir_output = f'{args.dir_data}/data_splits/CL/{args.dataset_name}LT/run{run}/'
    loader_fn = f'{dir_output}/{args.dataset_name}LT_split.pkl'
    cls_assignment_fn = f'{dir_output}/{args.dataset_name}LT_cls_assignment.pkl'
    # Always ensure the directory exists. The previous check tested for the
    # split FILE and then created the DIRECTORY, which raised FileExistsError
    # when the directory was already there but the file was not.
    os.makedirs(dir_output, exist_ok=True)

    # save data splits and cls_assignment
    with open(loader_fn, 'wb') as outfile:
        pickle.dump(loader_list, outfile)
    with open(cls_assignment_fn, 'wb') as outfile:
        pickle.dump(cls_assignment, outfile)

In [4]:
# Scratch check: range(0) is empty, so the loop body never runs and
# `a` keeps its initial value of 0.
a = 0
for _ in range(0):
    a = a + 2

In [5]:
# Display the scratch value `a`.
a

0