In [None]:
%matplotlib inline
from IPython.display import clear_output
import os
from copy import deepcopy

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from helper import ExperimentLogger, display_train_stats
from fl_devices import Server, Client

In [None]:
# Seed every RNG the experiment may touch. Python's `random` is seeded as well
# because helper modules (e.g. client selection) may draw from it; leaving it
# unseeded would make runs with the same SEED diverge.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# download MNIST and concatenate the train and test splits into one pool
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import Dataset
from PIL import Image

train_dataset = MNIST('raw_data', download=True, train=True)
test_dataset = MNIST('raw_data', download=True, train=False)

# Pool both official splits; the per-client pickles below index into this
# combined tensor. torch.as_tensor returns an existing tensor as-is (no extra
# copy, no copy-construct warning) and still accepts array-likes; torch.cat
# produces the new combined tensor.
data = torch.cat([
    torch.as_tensor(train_dataset.data),
    torch.as_tensor(test_dataset.data),
])

targets = torch.cat([
    torch.as_tensor(train_dataset.targets),
    torch.as_tensor(test_dataset.targets),
])

class SubMNIST(Dataset):
    """Dataset view over a subset of the pooled MNIST tensors.

    Parameters
    ----------
    indices : index selector
        Positions into ``data`` / ``targets`` that make up this subset.
    data : torch.Tensor
        Image tensor, indexed along its first dimension.
    targets : torch.Tensor
        Label tensor aligned with ``data``.

    Items are returned as ``(transformed_image, int_label)``.
    """

    def __init__(self, indices, data, targets) -> None:
        # Remember the raw selector alongside the gathered tensors.
        self.indices = indices
        # ToTensor, then Normalize with mean 0.1307 / std 0.3081.
        pipeline = [ToTensor(), Normalize((0.1307,), (0.3081,))]
        self.transform = Compose(pipeline)
        # Gather the subset once up front rather than on every lookup.
        self.data, self.targets = data[indices], targets[indices]

    def __len__(self):
        """Number of samples (size of the first dimension of ``self.data``)."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the transformed image and its label as a Python int."""
        pixels = self.data[index].numpy()
        label = int(self.targets[index])
        # Wrap as a grayscale PIL image so the torchvision pipeline accepts it.
        return self.transform(Image.fromarray(pixels, mode='L')), label

In [None]:
# build client datasets: one training subset per task directory, plus a single
# test set pooled from every task's held-out indices.
N_CLIENTS = 100  # NOTE(review): not used below; the client count is the number of task directories found
import pickle


def _load_pickled_indices(path):
    """Load and return the pickled index collection stored at ``path``.

    NOTE: pickle.load can execute arbitrary code; only use this on trusted,
    locally generated split files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


client_train_dataset = []
test_indices = []
base_path = '../data/mnist/all_data/train'
# os.listdir returns entries in arbitrary order. Sort so that client i (and
# therefore idnum=i and the clustering output) is stable across runs, and skip
# stray files that are not task directories.
task_dirs = sorted(
    entry for entry in os.listdir(base_path)
    if os.path.isdir(os.path.join(base_path, entry))
)
for task_dir in task_dirs:
    task_path = os.path.join(base_path, task_dir)

    train_idx = _load_pickled_indices(os.path.join(task_path, 'train.pkl'))
    client_train_dataset.append(SubMNIST(train_idx, data, targets))

    test_indices.extend(_load_pickled_indices(os.path.join(task_path, 'test.pkl')))
test_dataset = SubMNIST(test_indices, data, targets)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class MnistPerceptron(nn.Module):
    """One-hidden-layer MLP for 28x28 inputs.

    Each sample is flattened to 784 features, passed through a 128-unit
    linear layer with ReLU, and mapped to ``num_classes`` raw logits.
    Layer attribute names (``fc``, ``classifier``) are kept stable because
    they determine the parameter names in ``state_dict``.
    """

    _IN_FEATURES = 28 * 28  # flattened image size
    _HIDDEN = 128           # hidden-layer width

    def __init__(self, num_classes) -> None:
        super().__init__()
        self.fc = nn.Linear(self._IN_FEATURES, self._HIDDEN)
        self.classifier = nn.Linear(self._HIDDEN, num_classes)

    def forward(self, x):
        """Return unnormalized class scores of shape (N, num_classes)."""
        flat = x.view(-1, self._IN_FEATURES)
        hidden = torch.relu(self.fc(flat))
        return self.classifier(hidden)

In [None]:
def _make_optimizer(x):
    """Per-client optimizer factory: SGD(lr=0.1, momentum=0.9, weight_decay=5e-4) over ``x``."""
    return torch.optim.SGD(x, lr=0.1, momentum=0.9, weight_decay=5e-4)


# One Client per local training split; the i-th split gets idnum=i.
clients = []
for i, dat in enumerate(client_train_dataset):
    clients.append(Client(MnistPerceptron, _make_optimizer, dat, idnum=i))

# The server is built with the same model class and the pooled test split.
server = Server(MnistPerceptron, test_dataset)

In [None]:
# Clustered-FL hyper-parameters.
COMMUNICATION_ROUNDS = 200
EPS_1 = 0.03  # NOTE(review): an earlier note said 0.05 was used for seed 42 — confirm which value produced the reported run
EPS_2 = 0.5   # 0.5 in seed 42, 0.7 in seed 43
WARMUP_ROUNDS = 20  # no cluster may split until c_round > WARMUP_ROUNDS

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

cfl_stats = ExperimentLogger()

# Start with a single cluster containing every client.
cluster_indices = [np.arange(len(clients)).astype("int")]
client_clusters = [[clients[i] for i in idcs] for idcs in cluster_indices]

# Initial synchronization with the server, done once before training
# (previously guarded by `if c_round == 1` inside the loop).
for client in clients:
    client.synchronize_with_server(server)

for c_round in range(1, COMMUNICATION_ROUNDS + 1):

    participating_clients = server.select_clients(clients, frac=1.0)

    # Local training step for each participant; the returned stats were unused.
    for client in participating_clients:
        client.compute_weight_update(epochs=1)
        client.reset()

    # NOTE(review): similarities and norms are computed over *all* clients,
    # which matches the participants only because frac=1.0.
    similarities = server.compute_pairwise_similarities(clients)

    cluster_indices_new = []
    for idc in cluster_indices:
        members = [clients[i] for i in idc]
        max_norm = server.compute_max_update_norm(members)
        mean_norm = server.compute_mean_update_norm(members)

        # Split when the mean update norm is below EPS_1, the largest update
        # norm exceeds EPS_2, the cluster has more than 2 members, and the
        # warm-up period is over.
        if mean_norm < EPS_1 and max_norm > EPS_2 and len(idc) > 2 and c_round > WARMUP_ROUNDS:

            # Cache the pre-split model with the previous round's accuracies
            # (acc_clients is first assigned at the end of round 1).
            server.cache_model(idc, clients[idc[0]].W, acc_clients)

            c1, c2 = server.cluster_clients(similarities[idc][:, idc])
            cluster_indices_new += [c1, c2]

            cfl_stats.log({"split" : c_round})

        else:
            cluster_indices_new += [idc]

    cluster_indices = cluster_indices_new
    client_clusters = [[clients[i] for i in idcs] for idcs in cluster_indices]

    server.aggregate_clusterwise(client_clusters)

    acc_clients = [client.evaluate() for client in clients]

    # NOTE(review): mean_norm / max_norm logged here are the values of the
    # last cluster processed in the loop above, not an aggregate over clusters.
    cfl_stats.log({"acc_clients" : acc_clients, "mean_norm" : mean_norm, "max_norm" : max_norm,
                  "rounds" : c_round, "clusters" : cluster_indices})
    logging.info(f"the mean accuracy is {np.mean(acc_clients, axis=0)} at round {c_round}")

    display_train_stats(cfl_stats, EPS_1, EPS_2, COMMUNICATION_ROUNDS)

# Cache every final cluster's model together with the last evaluation.
for idc in cluster_indices:
    server.cache_model(idc, clients[idc[0]].W, acc_clients)

In [None]:
import os
import pickle

# Persist the experiment log, one pickle per seed.
path = "../plot/cfl_result/mnist"
os.makedirs(path, exist_ok=True)  # tolerate an already-existing directory

stats_file = os.path.join(path, f"seed_{SEED}.pkl")
with open(stats_file, 'wb') as f:
    pickle.dump(cfl_stats, f)