In [None]:
%matplotlib inline
import os
import random
from copy import deepcopy

from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from helper import ExperimentLogger, display_train_stats
from fl_devices import Server, Client

In [None]:
# Seed every RNG the run may draw from (Python `random`, NumPy, PyTorch) so the
# experiment is reproducible; SEED is also used in the results filename.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# prepare dataset

In [None]:
# define subdataset
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Normalize, Compose
from PIL import Image

class SubCIFAR10(Dataset):
    """
    Subset of the CIFAR-10 dataset selected by a list of indices.

    The full CIFAR-10 inputs/labels are passed in as tensors and only the rows
    at ``indices`` are kept (the indices are typically loaded from a pickle file).

    Attributes
    ----------
    indices: iterable of integers
        Positions of the selected samples in ``cifar10_data``/``cifar10_targets``.
    transform: callable
        Applied to each PIL image in ``__getitem__``.
    data: torch.Tensor
        Selected inputs, ``cifar10_data[indices]``.
    targets: torch.Tensor
        Selected labels, ``cifar10_targets[indices]``.

    Methods
    -------
    __init__
    __len__
    __getitem__
    """
    def __init__(self, indices, cifar10_data=None, cifar10_targets=None, transform=None):
        """
        :param indices: iterable of integer positions to keep
        :param cifar10_data: Cifar-10 dataset inputs stored as torch.tensor
        :param cifar10_targets: Cifar-10 dataset labels stored as torch.tensor
        :param transform: callable applied to each PIL image; if None, defaults to
            ToTensor() followed by per-channel CIFAR-10 normalization
        :raises ValueError: if cifar10_data or cifar10_targets is None
        """
        if cifar10_data is None or cifar10_targets is None:
            raise ValueError("cifar10_data and cifar10_targets must both be provided")

        self.indices = indices

        if transform is None:
            transform = Compose([
                ToTensor(),
                Normalize(
                    (0.4914, 0.4822, 0.4465),
                    (0.2023, 0.1994, 0.2010)
                )
            ])
        # Always store the transform: previously a caller-supplied transform was
        # dropped, leaving ``self.transform`` unset and breaking __getitem__.
        self.transform = transform

        self.data = cifar10_data[self.indices]
        self.targets = cifar10_targets[self.indices]

    def __len__(self):
        """Number of samples in the subset."""
        return self.data.size(0)

    def __getitem__(self, index):
        """Return ``(transformed_image, target)`` for the sample at ``index``."""
        # assumes each row of ``data`` is an image array PIL accepts
        # (e.g. uint8 HxWxC, as torchvision's CIFAR10.data) — TODO confirm
        img = Image.fromarray(self.data[index].numpy())
        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(img)

        return img, target

In [None]:
# Download CIFAR-10 and concatenate the train and test splits into one pool of
# samples; the client partitions built below index into these combined tensors.
from torchvision.datasets import CIFAR10
train_dataset = CIFAR10('raw_data', download=True, train=True)
test_dataset = CIFAR10('raw_data', download=True, train=False)

_splits = (train_dataset, test_dataset)
data = torch.cat([torch.tensor(split.data) for split in _splits])
targets = torch.cat([torch.tensor(split.targets) for split in _splits])

In [None]:
# Build client datasets: each sub-directory of ``base_path`` holds one client's
# ``train.pkl`` and ``test.pkl``, presumably pickled lists of indices into the
# concatenated ``data``/``targets`` tensors (TODO confirm against the split script).
# NOTE(review): N_CLIENTS is not referenced in this notebook; the number of
# clients is determined by the sub-directories found under base_path.
N_CLIENTS = 80
import pickle


def _load_indices(pkl_path):
    """Unpickle and return the index list stored at ``pkl_path``.

    NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)


client_train_dataset = []
test_indices = []
base_path = '../data/cifar10/all_data/train'
# os.listdir order is filesystem-dependent; sorting makes the client -> idnum
# mapping (and therefore the seeded run) reproducible. Non-directory entries
# (e.g. stray hidden files) are skipped.
task_dirs = sorted(
    entry for entry in os.listdir(base_path)
    if os.path.isdir(os.path.join(base_path, entry))
)
for task_dir in task_dirs:
    task_path = os.path.join(base_path, task_dir)
    client_train_dataset.append(
        SubCIFAR10(_load_indices(os.path.join(task_path, 'train.pkl')), data, targets)
    )
    # Every client's held-out indices are pooled into one global test set.
    test_indices.extend(_load_indices(os.path.join(task_path, 'test.pkl')))
test_dataset = SubCIFAR10(test_indices, data, targets)

# Define model

In [None]:
import torch.nn as nn
import torchvision.models as tvmodels

class Cifar10_Net(nn.Module):
    """
    torchvision MobileNetV2 whose final classifier layer is resized for CIFAR-10.

    ``classifier[1]`` (the last ``nn.Linear``) is replaced so the network emits
    ``num_classes`` logits.

    :param num_classes: number of output logits; defaults to 10 (CIFAR-10) so the
        class can also be used as a zero-argument model factory.
    :param pretrained: forwarded to ``torchvision.models.mobilenet_v2``; True loads
        torchvision's pretrained backbone weights.
    """
    def __init__(self, num_classes: int = 10, pretrained: bool = True) -> None:
        super().__init__()
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13 in
        # favour of ``weights=``; kept here for compatibility with older versions.
        self.model = tvmodels.mobilenet_v2(pretrained=pretrained)
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the logits produced by the wrapped MobileNetV2 for batch ``x``."""
        return self.model(x)

# Clustered federated learning (CFL) training

In [None]:
def make_optimizer(params):
    """Per-client optimizer: SGD with lr=0.01, momentum=0.9, weight_decay=5e-4."""
    return torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=5e-4)


# One Client per local dataset (idnum = its position); the server evaluates on
# the pooled test set.
clients = [
    Client(Cifar10_Net, make_optimizer, client_data, idnum=client_id)
    for client_id, client_data in enumerate(client_train_dataset)
]
server = Server(Cifar10_Net, test_dataset)

In [None]:
# --- Clustered FL hyper-parameters ---
COMMUNICATION_ROUNDS = 200
# Split thresholds: a cluster is bi-partitioned when
# server.compute_mean_update_norm(...) < EPS_1 AND
# server.compute_max_update_norm(...) > EPS_2 (see the check in the loop).
# Presumably the CFL criterion "cluster is near a stationary point on average,
# yet some clients still have large updates" — confirm in fl_devices.
EPS_1 = 0.15
EPS_2 = 7 # author note "most case"; NOTE(review): presumably the value that works in most runs — confirm


cfl_stats = ExperimentLogger()

# Start with a single cluster holding every client index.
cluster_indices = [np.arange(len(clients)).astype("int")]
# NOTE(review): this initial value is not read before it is recomputed at the
# end of round 1.
client_clusters = [[clients[i] for i in idcs] for idcs in cluster_indices]


for c_round in range(1, COMMUNICATION_ROUNDS+1):

    # Round 1 only: sync every client with the server before local training
    # (presumably copies the server weights into each client — see fl_devices).
    if c_round == 1:
        for client in clients:
            client.synchronize_with_server(server)

    # frac=1.0: assumes select_clients samples this fraction of `clients`, i.e.
    # every client participates each round — TODO confirm.
    participating_clients = server.select_clients(clients, frac=1.0)

    # One local epoch per client; `train_stats` is not used afterwards.
    # NOTE(review): reset() presumably restores the pre-training weights while
    # keeping the computed update for similarity/aggregation — confirm.
    for client in participating_clients:
        train_stats = client.compute_weight_update(epochs=1)
        client.reset()

    # Pairwise similarities over *all* clients, indexed by client position.
    similarities = server.compute_pairwise_similarities(clients)

    cluster_indices_new = []
    for idc in cluster_indices:
        max_norm = server.compute_max_update_norm([clients[i] for i in idc])
        mean_norm = server.compute_mean_update_norm([clients[i] for i in idc])

        # Split only clusters with more than two clients, and only after the
        # first 20 rounds.
        if mean_norm<EPS_1 and max_norm>EPS_2 and len(idc)>2 and c_round>20:

            # Cache the pre-split cluster model (first member's `W`, presumably
            # its weights) with the accuracies from the previous round.
            # `acc_clients` is first assigned at the end of round 1, so it is
            # always defined when c_round > 20.
            server.cache_model(idc, clients[idc[0]].W, acc_clients)

            # Bi-partition the cluster using its sub-matrix of similarities.
            c1, c2 = server.cluster_clients(similarities[idc][:,idc]) 
            cluster_indices_new += [c1, c2]

            cfl_stats.log({"split" : c_round})

        else:
            cluster_indices_new += [idc]


    cluster_indices = cluster_indices_new
    client_clusters = [[clients[i] for i in idcs] for idcs in cluster_indices]

    # Aggregate within each cluster (presumably a per-cluster average of the
    # client updates — see Server.aggregate_clusterwise).
    server.aggregate_clusterwise(client_clusters)

    acc_clients = [client.evaluate() for client in clients]

    # NOTE(review): `mean_norm` / `max_norm` are the values of the *last*
    # cluster visited in the loop above, not a summary over all clusters.
    cfl_stats.log({"acc_clients" : acc_clients, "mean_norm" : mean_norm, "max_norm" : max_norm,
                  "rounds" : c_round, "clusters" : cluster_indices})


    display_train_stats(cfl_stats, EPS_1, EPS_2, COMMUNICATION_ROUNDS)


# After the final round, cache each cluster's model with the last accuracies.
for idc in cluster_indices:    
    server.cache_model(idc, clients[idc[0]].W, acc_clients)

# Save results

In [None]:
import os
import pickle

# Persist the experiment log for later plotting; one pickle file per seed.
path = "../plot/cfl_result/cifar10"
os.makedirs(path, exist_ok=True)

result_file = os.path.join(path, f"seed_{SEED}.pkl")
with open(result_file, 'wb') as fh:
    pickle.dump(cfl_stats, fh)