In [None]:
'''
This experiment setting follows FedBN (paper: https://openreview.net/pdf?id=6YEQUn0QICG).
Source code of the data processing: https://github.com/med-air/FedBN

Before running this file, please download the pre-processed datasets from the following URL
(they could not be included because of the maximum file size of the Supplementary Material;
sorry for the inconvenience):
    https://drive.google.com/uc?export=download&id=1moBE_ASD5vIOaU8ZHm_Nsj0KAfX5T0Sf

and unzip them under the 'data/mixed_digit_dataset' directory;
then you can run the following experiments on the mixed-digit dataset.
'''

In [None]:
import torch
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms 
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
import random, os
import numpy as np
from math import sqrt
from matplotlib import pyplot as plt
import pandas as pd
import copy


from fedlab.utils.dataset.partition import CIFAR10Partitioner
from fedlab.utils.dataset import FMNISTPartitioner
from fedlab.utils.functional import partition_report, save_dict
    
from args_mixed_digit_c2 import args_parser
import server_se1 as server
import model

from utils.global_test import *
from utils.local_test import test_on_localdataset
from utils.femnist_dataset import *
from utils.training_loss import train_loss_show,train_localacc_show
from utils.sampling import *
from utils.mixed_digits_data_preprocess import *
from utils.clusteror import cluster_clients
from utils.compute_histogram import compute_histogram, compute_histogram_mixed_digit


# Parse the experiment configuration; later cells read e.g. args.seed, args.K, args.E,
# args.B, args.num_classes and args.device from it.
args = args_parser()

def seed_torch(seed=args.seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and force deterministic cuDNN.

    Args:
        seed: seed value; the default is ``args.seed`` as it was when this function was defined.
    """
    # PYTHONHASHSEED is read at interpreter start-up, so setting it here only affects
    # processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, for multi-GPU runs
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once at start-up (seed_torch defaults to args.seed).
seed_torch()
# Base seed for DataLoader workers: worker `worker_id` is seeded with GLOBAL_SEED + worker_id.
GLOBAL_SEED = 1
def worker_init_fn(worker_id):
    """DataLoader ``worker_init_fn``: remember the worker id and seed it with GLOBAL_SEED + worker_id.

    NOTE(review): ``set_seed`` is not defined in this notebook's visible code (only
    ``seed_torch`` is); it presumably comes from one of the ``utils`` star imports -- confirm,
    otherwise a worker would raise NameError. This function is not passed to any DataLoader
    in the visible cells.
    """
    global GLOBAL_WORKER_ID
    GLOBAL_WORKER_ID = worker_id
    set_seed(GLOBAL_SEED + worker_id)

# Run switches:
#   similarity  - when saving, also save the similarity dicts
#   save_models - write accuracy lists and global-model weights under results/...
#   Train_model - train from scratch (True) or reload previously saved results (False)
#   C           - model tag embedded in the result file names
similarity = False
save_models = False
Train_model = True
C = "3CNN"

In [None]:
# Frequently used settings, copied out of `args`.
num_classes, num_clients, number_perclass = args.num_classes, args.K, args.num_perclass

# Column names of the per-class columns in the partition reports: class0, class1, ...
col_names = [f"class{i}" for i in range(num_classes)]
print(col_names)
hist_color = '#4169E1'  # not referenced in the visible cells
plt.rcParams['figure.facecolor'] = 'white'

In [None]:
from torch.utils.data import Dataset

class DigitsDataset(Dataset):
    """One digit domain loaded from FedBN-style pickled ``(images, targets)`` files.

    Which file(s) are read:

    * ``filename`` given                      -> ``<data_path>/<filename>``
    * ``train=False``                         -> ``<data_path>/test.pkl``
    * ``train=True`` and ``percent >= 0.1``   -> the first ``int(percent*10)`` files
      ``<data_path>/partitions/train_part{k}.pkl``, concatenated in order
    * ``train=True`` and ``percent < 0.1``    -> the leading ``percent*10`` fraction of
      ``partitions/train_part0.pkl``

    Args:
        data_path: directory holding the domain's files.
        channels: 1 (PIL mode 'L') or 3 (PIL mode 'RGB'); checked in ``__getitem__``.
        percent: amount of training data to use, see above.
        filename: explicit file (relative to ``data_path``) overriding the defaults.
        train: read the training partitions (True) or ``test.pkl`` (False).
        transform: optional callable applied to the PIL image in ``__getitem__``.

    Note:
        Files are loaded with ``np.load(..., allow_pickle=True)``, which unpickles them and
        can execute arbitrary code -- only use trusted data files.
    """

    def __init__(self, data_path, channels, percent=0.1, filename=None, train=True, transform=None):
        if filename is not None:
            self.images, self.targets = self._load_pair(data_path, filename)
        elif not train:
            self.images, self.targets = self._load_pair(data_path, 'test.pkl')
        elif percent >= 0.1:
            parts = [self._load_pair(data_path, 'partitions/train_part{}.pkl'.format(part))
                     for part in range(int(percent * 10))]
            if len(parts) == 1:
                self.images, self.targets = parts[0]
            else:
                # Concatenate once instead of re-copying the growing arrays for every part.
                self.images = np.concatenate([images for images, _ in parts], axis=0)
                self.targets = np.concatenate([targets for _, targets in parts], axis=0)
        else:
            self.images, self.targets = self._load_pair(data_path, 'partitions/train_part0.pkl')
            data_len = int(self.images.shape[0] * percent * 10)
            self.images = self.images[:data_len]
            self.targets = self.targets[:data_len]

        self.transform = transform
        self.channels = channels
        # `np.compat.long` (removed in NumPy 2.0) was just an alias of the builtin `int`.
        self.targets = self.targets.astype(int).squeeze()

    @staticmethod
    def _load_pair(data_path, relative_path):
        """Return the ``(images, targets)`` pair pickled in ``<data_path>/<relative_path>``."""
        images, targets = np.load(os.path.join(data_path, relative_path), allow_pickle=True)
        return images, targets

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, target)``: the raw array is wrapped as a PIL image, then transformed.

        Raises:
            ValueError: if ``channels`` is neither 1 nor 3.
        """
        image = self.images[idx]
        target = self.targets[idx]
        if self.channels == 1:
            image = Image.fromarray(image, mode='L')
        elif self.channels == 3:
            image = Image.fromarray(image, mode='RGB')
        else:
            raise ValueError("{} channel is not allowed.".format(self.channels))

        if self.transform is not None:
            image = self.transform(image)

        return image, target

In [None]:
class DigitsDataset_IID(Dataset):
    """All digit domains pooled into one dataset, every image transformed eagerly at load time.

    For domain ``k`` the pickled ``(images, targets)`` pairs are read from
    ``data_path_list[k]``: the ten files ``partitions/train_part0.pkl`` ... ``train_part9.pkl``
    when ``train`` is True, otherwise ``test.pkl``. Each image is wrapped as a PIL image according
    to ``channels_list[k]`` (1 -> mode 'L', 3 -> mode 'RGB') and passed through
    ``transform_list[k]``. The transformed images are kept in the list ``self.images``; the labels
    of all files, concatenated in the same order, are kept in ``self.targets``.

    Args:
        data_path_list: one directory per domain.
        channels_list: channel count (1 or 3) per domain.
        filename: unused; kept for signature compatibility.
        transform_list: one transform per domain.
        train: read the training partitions (True) or the test split (False).

    Raises:
        ValueError: if a domain that has images declares a channel count other than 1 or 3.

    Note:
        Files are loaded with ``np.load(..., allow_pickle=True)`` -- only use trusted data.
    """

    def __init__(self, data_path_list, channels_list, filename=None, transform_list=None, train=True):
        self.transform_list = transform_list
        self.channels_list = channels_list
        self.images = []
        targets_per_file = []

        if train:
            relative_paths = ['partitions/train_part{}.pkl'.format(part) for part in range(10)]
        else:
            relative_paths = ['test.pkl']

        for index, data_path in enumerate(data_path_list):
            for relative_path in relative_paths:
                images, targets = np.load(os.path.join(data_path, relative_path), allow_pickle=True)
                self.images.extend(self._to_tensor(image, index) for image in images)
                targets_per_file.append(targets)

        if targets_per_file:
            # One concatenation instead of re-copying the growing array for every file.
            targets = np.concatenate(targets_per_file, axis=0)
        else:
            # No domains given: an empty dataset rather than an undefined `self.targets`.
            targets = np.empty(0, dtype=int)
        # `np.compat.long` (removed in NumPy 2.0) was just an alias of the builtin `int`.
        self.targets = targets.astype(int).squeeze()

    def _to_tensor(self, image, index):
        """Wrap one raw image of domain ``index`` as a PIL image and apply that domain's transform."""
        channels = self.channels_list[index]
        if channels == 1:
            pil_image = Image.fromarray(image, mode='L')
        elif channels == 3:
            pil_image = Image.fromarray(image, mode='RGB')
        else:
            raise ValueError("{} channel is not allowed.".format(channels))
        return self.transform_list[index](pil_image)

    def __len__(self):
        # `self.images` is a Python list, so use len(); the former `.shape[0]` raised AttributeError.
        return len(self.images)

    def __getitem__(self, idx):
        """Return the already-transformed image at ``idx`` and its label."""
        return self.images[idx], self.targets[idx]

In [None]:
# Per-domain preprocessing. Every pipeline ends with ToTensor + Normalize(mean=0.5, std=0.5)
# per channel, i.e. pixel values go from [0, 1] to [-1, 1], and all domains yield 3-channel
# tensors: the greyscale domains (MNIST, USPS) are replicated to 3 channels.
# SVHN, USPS and SynthDigits are resized to 28x28; MNIST and MNIST-M are used at their stored
# size (presumably already 28x28 -- confirm).
# NOTE: transform_svhn and transform_synth are identical pipelines.
transform_mnist = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

transform_svhn = transforms.Compose([
        transforms.Resize([28,28]),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

transform_usps = transforms.Compose([
        transforms.Resize([28,28]),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

transform_synth = transforms.Compose([
        transforms.Resize([28,28]),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

transform_mnistm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])


# Fraction of each domain's training partitions to load: DigitsDataset reads the first
# int(percent*10) files partitions/train_part{k}.pkl, so percent = 1 loads all ten.
# For the test sets `percent` has no effect (train=False always reads test.pkl).
percent = 1
# MNIST
mnist_trainset  = DigitsDataset(data_path="data/mixed_digit_dataset/MNIST", channels=1, percent=percent, train=True,  transform=transform_mnist)
mnist_testset   = DigitsDataset(data_path="data/mixed_digit_dataset/MNIST", channels=1, percent=0.1, train=False, transform=transform_mnist)

# SVHN
svhn_trainset  = DigitsDataset(data_path='data/mixed_digit_dataset/SVHN', channels=3, percent=percent,  train=True,  transform=transform_svhn)
svhn_testset   = DigitsDataset(data_path='data/mixed_digit_dataset/SVHN', channels=3, percent=0.1,  train=False, transform=transform_svhn)

# USPS
usps_trainset  = DigitsDataset(data_path='data/mixed_digit_dataset/USPS', channels=1, percent=percent,  train=True,  transform=transform_usps)
usps_testset   = DigitsDataset(data_path='data/mixed_digit_dataset/USPS', channels=1, percent=0.1,  train=False, transform=transform_usps)

# Synth Digits
synth_trainset = DigitsDataset(data_path='data/mixed_digit_dataset/SynthDigits/', channels=3, percent=percent,  train=True,  transform=transform_synth)
synth_testset = DigitsDataset(data_path='data/mixed_digit_dataset/SynthDigits/', channels=3, percent=0.1,  train=False, transform=transform_synth)

# MNIST-M
mnistm_trainset = DigitsDataset(data_path='data/mixed_digit_dataset/MNIST_M/', channels=3, percent=percent,  train=True,  transform=transform_mnistm)
mnistm_testset  = DigitsDataset(data_path='data/mixed_digit_dataset/MNIST_M/', channels=3, percent=0.1,  train=False, transform=transform_mnistm)


# Position i in trainsets / testsets corresponds to datasets_name[i].
trainsets = [mnist_trainset, svhn_trainset, usps_trainset, synth_trainset, mnistm_trainset]
testsets  = [mnist_testset, svhn_testset, usps_testset, synth_testset, mnistm_testset]

datasets_name = ['MNIST', 'SVHN', 'USPS', 'SynthDigits', 'MNIST-M']
# Filled in by the next cell with the client ids assigned to each domain.
datasets_client_index = {'MNIST':[], 'SVHN':[], 'USPS':[], 'SynthDigits':[], 'MNIST-M':[]}


In [None]:
# Randomly give each domain int(K / number_of_domains) distinct clients; no client is shared
# between domains.
clients_indexset = list(range(args.K))
clientnumbers_per_dataset = int(args.K / len(datasets_name))
np.random.seed(args.seed)
for i, dataset_name in enumerate(datasets_name):
    picked = np.random.choice(clients_indexset, clientnumbers_per_dataset, replace=False)
    datasets_client_index[dataset_name] = list(picked)
    # Remove the picked clients from the pool before drawing for the next domain.
    clients_indexset = list(set(clients_indexset) - set(datasets_client_index[dataset_name]))


In [None]:
# perform partition
labeldir_parts = []
labeldir_part_dfs = []
for dataset_name, trainset in zip(datasets_name, trainsets):
    dataset_client_num = len(datasets_client_index[dataset_name])
    # labeldir_part = FMNISTPartitioner(trainset.targets, 
    #                                             num_clients=dataset_client_num,
    #                                             partition="iid",
    #                                             seed=1)
    labeldir_part = FMNISTPartitioner(trainset.targets,  
                                               num_clients=dataset_client_num,
                                               partition="noniid-#label", 
                                               major_classes_num=2,
                                               seed=1)
    
    # labeldir_part = FMNISTPartitioner(trainset.targets, 
    #                                     num_clients=dataset_client_num,
    #                                     partition="noniid-labeldir", 
    #                                     dir_alpha=0.5,
    #                                     seed=3)
    
    # generate partition report
    csv_file = "data/fmnist/fmnist_noniid_labeldir_clients_10.csv"
    partition_report(trainset.targets, labeldir_part.client_dict, 
                     class_num=num_classes, 
                     verbose=False, file=csv_file)

    labeldir_part_df = pd.read_csv(csv_file,header=1)
    labeldir_part_df = labeldir_part_df.set_index('client')
    for col in col_names:
        labeldir_part_df[col] = (labeldir_part_df[col] * labeldir_part_df['Amount']).astype(int)
        
    labeldir_parts.append(labeldir_part)
    labeldir_part_dfs.append(labeldir_part_df)

labeldir_part_df

In [None]:
transform_list =  [transform_mnist, transform_svhn, transform_usps, transform_synth, transform_mnistm]
channels_list = [1, 3, 1, 3, 3]
path_list = ["data/mixed_digit_dataset/MNIST", 
            'data/mixed_digit_dataset/SVHN',
            'data/mixed_digit_dataset/USPS',
            'data/mixed_digit_dataset/SynthDigits/',
            'data/mixed_digit_dataset/MNIST_M/']

In [None]:
datasets_client_index.items()

In [None]:
# Inverse of datasets_client_index: client id -> position of its domain in datasets_name.
# A client not assigned to any domain keeps the placeholder [].
clients_dataset_index = {i: [] for i in range(args.K)}
for i in range(args.K):
    owners = [name for name, members in datasets_client_index.items() if i in members]
    if owners:
        # Domains get disjoint client sets, so at most one matches; like the original scan,
        # the last match wins.
        clients_dataset_index[i] = datasets_name.index(owners[-1])

In [None]:
# Display: client id -> index of its domain in datasets_name ([] if unassigned).
clients_dataset_index.items()

In [None]:
# Draw every client's training indices from its own domain's training set, using that
# domain's partition. Clients without a domain keep an empty list.
trainset_sample_rate = args.trainset_sample_rate
rare_class_nums = 0
dict_users_train = {client: [] for client in range(args.K)}
for index, trainset in enumerate(trainsets):
    assigned_clients = datasets_client_index[datasets_name[index]]
    dict_users_train.update(
        trainset_sampling_mixed_digit(args, assigned_clients, trainset, trainset_sample_rate,
                                      rare_class_nums, labeldir_parts[index]))

In [None]:
# Per-client label histogram: training_number[client][class] = number of training samples.
# A client's indices in dict_users_train point into its OWN domain's training set,
# trainsets[clients_dataset_index[client]]. The original code used `trainset`, which is just
# the loop variable left over from the sampling cell (always MNIST-M).
training_number = {}
for i in range(args.K):
    training_number[i] = {j: 0 for j in range(num_classes)}
    dataset_idx = clients_dataset_index[i]
    if isinstance(dataset_idx, list) or len(dict_users_train[i]) == 0:
        # Client without a domain (K not divisible by the number of domains) or without
        # samples: all counts stay zero.
        continue
    client_targets = np.asarray(trainsets[dataset_idx].targets)[list(dict_users_train[i])]
    labels, counts = np.unique(client_targets, return_counts=True)
    for label, count in zip(labels.tolist(), counts.tolist()):
        training_number[i][label] = count

In [None]:
# Label histogram as a table: one row per class, one column per client, plus a per-class
# total column 'Col_sum' and a per-column total row 'Row_sum'.
df_training_number = pd.DataFrame(training_number)
df_training_number['Col_sum'] = df_training_number.sum(axis=1)
df_training_number.loc['Row_sum'] = df_training_number.sum()

df_training_number

In [None]:
# Test indices per domain, sampled by testset_sampling_mixed_digit(args, testset, 10).
dict_test = {
    name: testset_sampling_mixed_digit(args, testset, 10)
    for name, testset in zip(datasets_name, testsets)
}

In [None]:
# A second test sample per domain (dict_varify), regrouped by class:
# dict_datasets_varify[domain][c] = indices i in dict_varify[domain] whose target equals c.
dict_datasets_varify = {name: {c: [] for c in range(args.num_classes)} for name in datasets_name}
dict_varify = {'MNIST':[], 'SVHN':[], 'USPS':[], 'SynthDigits':[], 'MNIST-M':[]}
for index, testset in enumerate(testsets):
    dict_varify[datasets_name[index]] = testset_sampling_mixed_digit(args, testset, 10)

for index, testset in enumerate(testsets):
    name = datasets_name[index]
    # Convert the targets once per domain; the original rebuilt np.array(testset.targets)
    # for every (sample, class) pair.
    targets = np.asarray(testset.targets)
    per_class = dict_datasets_varify[name]
    for i in dict_varify[name]:
        bucket = per_class.get(targets[i])
        if bucket is not None:  # labels outside range(num_classes) are skipped, as before
            bucket.append(i)

In [None]:
# Number of sampled MNIST test indices.
len(dict_test['MNIST'])

In [None]:
# Clients whose sample count (the 'Row_sum' row) leaves a remainder of exactly 1 modulo
# the batch size B lose one sample, so no local batch has size 1 (presumably to keep
# BatchNorm from failing on single-sample batches -- confirm). The dropped index is the last
# element of list(set), i.e. arbitrary; this assumes dict_users_train values are sets.
# NOTE: not idempotent -- re-running this cell drops further samples, and
# df_training_number is not updated.
a_list = []
for i in range(args.K):
    if df_training_number.loc['Row_sum'][i] % args.B == 1:
        a_list.extend([i])
print(a_list)
for k in a_list:
    dict_users_train[k] = dict_users_train[k] - {list(dict_users_train[k])[-1]}

In [None]:
#  baseline----> iid setting with fedavg

In [None]:
# Pool all five domains into one training set (every image transformed eagerly) for the
# iid-FedAvg baseline.
trainset_iid = DigitsDataset_IID(path_list, channels_list, transform_list= transform_list, train=True)
#testset_iid = DigitsDataset_IID(path_list, channels_list, transform_list= transform_list, train=False)

In [None]:
# perform partition
labeldir_parts_iid = []
labeldir_part_dfs_iid = []

labeldir_part_iid = FMNISTPartitioner(trainset_iid.targets, 
                                            num_clients=args.K,
                                            partition="iid",
                                            seed=1)
# labeldir_part = FMNISTPartitioner(trainset.targets,  
#                                            num_clients=dataset_client_num,
#                                            partition="noniid-#label", 
#                                            major_classes_num=5,
#                                            seed=args.seed)
# generate partition report
csv_file = "data/fmnist/fmnist_noniid_labeldir_clients_10.csv"
partition_report(trainset_iid.targets, labeldir_part_iid.client_dict, 
                 class_num=num_classes, 
                 verbose=False, file=csv_file)

labeldir_part_df_iid = pd.read_csv(csv_file,header=1)
labeldir_part_df_iid = labeldir_part_df_iid.set_index('client')
for col in col_names:
    labeldir_part_df_iid[col] = (labeldir_part_df_iid[col] * labeldir_part_df_iid['Amount']).astype(int)

# labeldir_parts_iid.append(labeldir_part_iid)
# labeldir_part_dfs_iid.append(labeldir_part_df_iid)
labeldir_part_df_iid

In [None]:
dict_users_train_iid = trainset_sampling_label(args, trainset_iid, trainset_sample_rate, rare_class_nums, labeldir_part_iid)

In [None]:
len(trainset_iid.targets[list(dict_users_train_iid[1])])

In [None]:
dict_users_train_iid = trainset_sampling_label(args, trainset_iid, trainset_sample_rate, rare_class_nums, labeldir_part_iid) 

In [None]:
# In the iid baseline every client holds data from all five domains (indices 0..4).
clients_dataset_index_iid = {}
for i in range(args.K):
    clients_dataset_index_iid[i] = list(range(5))

In [None]:
# One DigitModel instance is handed to every Server below.
# NOTE(review): presumably each Server copies it; otherwise later runs would start from the
# weights left by earlier ones -- confirm in server_se1.
specf_model = model.DigitModel().to(args.device)

In [None]:
#iid-fedavg

In [None]:
server_iid = server.Server(args, specf_model, trainset_iid, dict_users_train_iid)

In [None]:
# Train the iid baseline, or reload the accuracy curve and weights saved by an earlier run.
# NOTE(review): the reload branch leaves similarity_dict_iid / loss_dict_iid /
# clients_index_iid undefined, and the per-client checkpoints it loads are never written by
# the save cell below (those lines are commented out).
if Train_model:
    global_model_iid, similarity_dict_iid, client_models_iid, loss_dict_iid, clients_index_iid, acc_list_iid = server_iid.fedavg_joint_update(testsets, dict_test, fedbn=True, similarity=True, test_global_model_accuracy = True)
else:
    acc_list_iid = torch.load("results/Test/feature skew/mixed_digit/iid-fedavg/seed{}/acc_list_iid_{}E_{}class.pt".format(args.seed,args.E,C))
    global_model_iid = server_iid.nn
    client_models_iid = server_iid.nns
    path_iid_fedavg = "results/Test/feature skew/mixed_digit/iid-fedavg/seed{}/global_model_iid-fedavg_{}E_{}class.pt".format(args.seed,args.E,C)
    global_model_iid.load_state_dict(torch.load(path_iid_fedavg))
    for i in range(args.K):
        path_iid_fedavg = "results/Test/feature skew/mixed_digit/iid-fedavg/seed{}/client{}_model_iid-fedavg_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_models_iid[i].load_state_dict(torch.load(path_iid_fedavg))

In [None]:
# Save the accuracy curve (plus the similarity dict when `similarity` is set) and the
# global weights; per-client weights are not saved.
if save_models:
    if similarity:
        torch.save(similarity_dict_iid,"results/Test/feature skew/mixed_digit/iid-fedavg/seed{}/similarity_dict_iid_{}E_{}class.pt".format(args.seed,args.E,C))
    torch.save(acc_list_iid,"results/Test/feature skew/mixed_digit/iid-fedavg/seed{}/acc_list_iid_{}E_{}class.pt".format(args.seed,args.E,C))
    path_iid_fedavg = "results/Test/feature skew/mixed_digit/iid-fedavg/seed{}/global_model_iid-fedavg_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_model_iid.state_dict(), path_iid_fedavg)
    # for i in range(args.K):
    #     path_iid_fedavg = "results/Test/feature skew/mixed_digit/iid-fedavg/seed{}/client{}_model_iid-fedavg_{}E_{}class.pt".format(args.seed,i,args.E,C)
    #     torch.save(client_models_iid[i].state_dict(), path_iid_fedavg)

In [None]:
# Evaluate the iid-FedAvg global model on every domain's sampled test indices.
# d_iid[domain] is the first value returned by test_on_globaldataset_mixed_digit (presumably
# accuracy); mean_giid is its unweighted mean over the domains.
d_iid = {'MNIST':[], 'SVHN':[], 'USPS':[], 'SynthDigits':[], 'MNIST-M':[]}
mean_giid = 0.0
for index, testset in enumerate(testsets):
    giid,_ = test_on_globaldataset_mixed_digit(args, global_model_iid, testset, dict_test[datasets_name[index]])
    d_iid[datasets_name[index]] = giid
    mean_giid += giid/len(testsets)  # actual number of domains instead of a hard-coded 5
d_iid

In [None]:
# Per-client training-loss curves (only available right after training).
if Train_model:
    train_loss_show(args, loss_dict_iid,clients_index_iid)

In [None]:
# Free the server and cached GPU memory before the next run.
del server_iid
torch.cuda.empty_cache()

In [None]:
#fedavg

In [None]:
# FedAvg on the non-iid per-domain split, run through Server.fedbn with fe_optimizer_name="fedavg".
server_fedavg =  server.Server(args, specf_model, trainsets, dict_users_train)

In [None]:
# Train, or reload saved results (the reload branch leaves similarity_dict1 / loss_dict1 /
# clients_index1 undefined, and per-client checkpoints are never saved by the next cell).
if Train_model:
    global_model1, similarity_dict1, client_models1, loss_dict1, clients_index1, acc_list1 = server_fedavg.fedbn(testsets, dict_test,clients_dataset_index, similarity=True,fe_optimizer_name = "fedavg", test_global_model_accuracy = True)
else:
    acc_list1 = torch.load("results/Test/feature skew/mixed_digit/fedavg/seed{}/acc_list1_{}E_{}class.pt".format(args.seed,args.E,C))
    global_model1 = server_fedavg.nn
    client_models1 = server_fedavg.nns
    path_fedavg = "results/Test/feature skew/mixed_digit/fedavg/seed{}/global_model_fedavg_{}E_{}class.pt".format(args.seed,args.E,C)
    global_model1.load_state_dict(torch.load(path_fedavg))
    for i in range(args.K):
        path_fedavg = "results/Test/feature skew/mixed_digit/fedavg/seed{}/client{}_model_fedavg_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_models1[i].load_state_dict(torch.load(path_fedavg))

In [None]:
# Save the accuracy curve (and similarity dict if `similarity`) and the global weights.
if save_models:
    if similarity:
        torch.save(similarity_dict1,"results/Test/feature skew/mixed_digit/fedavg/seed{}/similarity_dict1_{}E_{}class.pt".format(args.seed,args.E,C))
    torch.save(acc_list1,"results/Test/feature skew/mixed_digit/fedavg/seed{}/acc_list1_{}E_{}class.pt".format(args.seed,args.E,C))
    path_fedavg = "results/Test/feature skew/mixed_digit/fedavg/seed{}/global_model_fedavg_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_model1.state_dict(), path_fedavg)
    # for i in range(args.K):
    #     path_fedavg = "results/Test/feature skew/mixed_digit/fedavg/seed{}/client{}_model_fedavg_{}E_{}class.pt".format(args.seed,i,args.E,C)
    #     torch.save(client_models1[i].state_dict(), path_fedavg)

In [None]:
# Per-domain evaluation of the FedAvg global model; d_g1[domain] is the first value returned
# by test_on_globaldataset_mixed_digit (presumably accuracy), mean_g1 its mean over domains.
d_g1 = {'MNIST':[], 'SVHN':[], 'USPS':[], 'SynthDigits':[], 'MNIST-M':[]}
mean_g1 = 0.0
for index, testset in enumerate(testsets):
    #print(datasets_name[index])
    g1,_ = test_on_globaldataset_mixed_digit(args, global_model1, testset, dict_test[datasets_name[index]])
    d_g1[datasets_name[index]]  = g1 
    mean_g1 += g1/len(testsets)
d_g1

In [None]:
mean_g1

In [None]:
if Train_model:
    train_loss_show(args, loss_dict1,clients_index1)

In [None]:
del server_fedavg
torch.cuda.empty_cache()

In [None]:
#fedprox

In [None]:
# FedProx on the non-iid per-domain split (Server.fedbn with fe_optimizer_name="fedprox").
server_fedprox_joint = server.Server(args, specf_model, trainsets, dict_users_train)

In [None]:
# NOTE: the FedProx per-client checkpoint name ("client{}_{}E_{}class.pt") lacks the
# "_model_<method>" part used by the other methods; it matches the commented-out save below.
if Train_model:
    global_modelp, similarity_dictp, client_modelsp, loss_dictp, clients_indexp, acc_listp = server_fedprox_joint.fedbn(testsets, dict_test,clients_dataset_index, similarity=True,fe_optimizer_name = "fedprox", test_global_model_accuracy = True)
else:
    acc_listp = torch.load("results/Test/feature skew/mixed_digit/fedprox/seed{}/acc_listp_{}E_{}class.pt".format(args.seed,args.E,C))
    global_modelp = server_fedprox_joint.nn
    client_modelsp = server_fedprox_joint.nns
    path_fedprox = "results/Test/feature skew/mixed_digit/fedprox/seed{}/global_model_fedprox_{}E_{}class.pt".format(args.seed,args.E,C)
    global_modelp.load_state_dict(torch.load(path_fedprox))
    for i in range(args.K):
        path_fedprox = "results/Test/feature skew/mixed_digit/fedprox/seed{}/client{}_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_modelsp[i].load_state_dict(torch.load(path_fedprox))

In [None]:
if save_models:
    if similarity:
        torch.save(similarity_dictp,"results/Test/feature skew/mixed_digit/fedprox/seed{}/similarity_dictp_{}E_{}class.pt".format(args.seed,args.E,C))
    torch.save(acc_listp,"results/Test/feature skew/mixed_digit/fedprox/seed{}/acc_listp_{}E_{}class.pt".format(args.seed,args.E,C))
    path_fedprox = "results/Test/feature skew/mixed_digit/fedprox/seed{}/global_model_fedprox_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_modelp.state_dict(), path_fedprox)
    # for i in range(args.K):
    #     path_fedprox = "results/Test/feature skew/mixed_digit/fedprox/seed{}/client{}_{}E_{}class.pt".format(args.seed,i,args.E,C)
    #     torch.save(client_modelsp[i].state_dict(), path_fedprox)

In [None]:
# Evaluate the FedProx global model on every domain's sampled test indices; d_gp[domain] is
# the first value returned by test_on_globaldataset_mixed_digit (presumably accuracy) and
# mean_gp its unweighted mean over the domains.
d_gp = {'MNIST':[], 'SVHN':[], 'USPS':[], 'SynthDigits':[], 'MNIST-M':[]}
mean_gp= 0.0
for index, testset in enumerate(testsets):
    gp,_ = test_on_globaldataset_mixed_digit(args, global_modelp, testset, dict_test[datasets_name[index]])
    d_gp[datasets_name[index]] = gp
    mean_gp += gp/len(testsets)  # actual number of domains instead of a hard-coded 5
d_gp

In [None]:
mean_gp

In [None]:
if Train_model:
    train_loss_show(args, loss_dictp,clients_indexp)

In [None]:
del server_fedprox_joint
torch.cuda.empty_cache()

In [None]:
#feddyn

In [None]:
# FedDyn on the non-iid per-domain split (Server.fedbn with fe_optimizer_name="feddyn").
server_feddyn = server.Server(args, specf_model, trainsets, dict_users_train)

In [None]:
# Train, or reload saved results (per-client checkpoints are never saved by the next cell).
if Train_model:
    global_modeldyn, similarity_dictdyn, client_modelsdyn, loss_dictdyn, clients_indexdyn, acc_listdyn = server_feddyn.fedbn(testsets, dict_test,clients_dataset_index, similarity=True,fe_optimizer_name = "feddyn", test_global_model_accuracy = True)
else:
    acc_listdyn = torch.load("results/Test/feature skew/mixed_digit/feddyn/seed{}/acc_listdyn_{}E_{}class.pt".format(args.seed,args.E,C))
    global_modeldyn = server_feddyn.nn
    client_modelsdyn = server_feddyn.nns
    path_feddyn = "results/Test/feature skew/mixed_digit/feddyn/seed{}/global_model_feddyn_{}E_{}class.pt".format(args.seed,args.E,C)
    global_modeldyn.load_state_dict(torch.load(path_feddyn))
    for i in range(args.K):
        path_feddyn = "results/Test/feature skew/mixed_digit/feddyn/seed{}/client{}_model_feddyn_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_modelsdyn[i].load_state_dict(torch.load(path_feddyn))

In [None]:
if save_models:
    if similarity:
        torch.save(similarity_dictdyn,"results/Test/feature skew/mixed_digit/feddyn/seed{}/similarity_dictdyn_{}E_{}class.pt".format(args.seed,args.E,C))
    torch.save(acc_listdyn,"results/Test/feature skew/mixed_digit/feddyn/seed{}/acc_listdyn_{}E_{}class.pt".format(args.seed,args.E,C))
    path_feddyn = "results/Test/feature skew/mixed_digit/feddyn/seed{}/global_model_feddyn_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_modeldyn.state_dict(), path_feddyn)
    # for i in range(args.K):
    #     path_feddyn = "results/Test/feature skew/mixed_digit/feddyn/seed{}/client{}_model_feddyn_{}E_{}class.pt".format(args.seed,i,args.E,C)
    #     torch.save(client_modelsdyn[i].state_dict(), path_feddyn)

In [None]:
# Per-domain evaluation of the FedDyn global model (mean over len(testsets) domains).
d_gdyn ={'MNIST':[], 'SVHN':[], 'USPS':[], 'SynthDigits':[], 'MNIST-M':[]}
mean_gdyn = 0.0
for index, testset in enumerate(testsets):
    gdyn,_ = test_on_globaldataset_mixed_digit(args, global_modeldyn, testset, dict_test[datasets_name[index]])
    d_gdyn[datasets_name[index]] = gdyn
    mean_gdyn += gdyn/len(testsets)
d_gdyn

In [None]:
mean_gdyn

In [None]:
if Train_model:
    train_loss_show(args, loss_dictdyn,clients_indexdyn)

In [None]:
del server_feddyn
torch.cuda.empty_cache()

In [None]:
#moon

In [None]:
# MOON on the non-iid per-domain split (Server.fedbn with fe_optimizer_name="moon").
server_moon = server.Server(args, specf_model, trainsets, dict_users_train)

In [None]:
# Train, or reload saved results (per-client checkpoints are never saved by the next cell).
if Train_model:
    global_modelm, similarity_dictm, client_modelsm, loss_dictm, clients_indexm, acc_listm = server_moon.fedbn(testsets, dict_test,clients_dataset_index,similarity=True, fe_optimizer_name = "moon", test_global_model_accuracy = True)
else:
    acc_listm = torch.load("results/Test/feature skew/mixed_digit/moon/seed{}/acc_listm_{}E_{}class.pt".format(args.seed,args.E,C))
    global_modelm = server_moon.nn
    client_modelsm = server_moon.nns
    path_moon = "results/Test/feature skew/mixed_digit/moon/seed{}/global_model_moon_{}E_{}class.pt".format(args.seed,args.E,C)
    global_modelm.load_state_dict(torch.load(path_moon))
    for i in range(args.K):
        path_moon = "results/Test/feature skew/mixed_digit/moon/seed{}/client{}_model_moon_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_modelsm[i].load_state_dict(torch.load(path_moon))

In [None]:
if save_models:
    if similarity:
        torch.save(similarity_dictm,"results/Test/feature skew/mixed_digit/moon/seed{}/similarity_dictm_{}E_{}class.pt".format(args.seed,args.E,C))
    torch.save(acc_listm,"results/Test/feature skew/mixed_digit/moon/seed{}/acc_listm_{}E_{}class.pt".format(args.seed,args.E,C))
    path_moon = "results/Test/feature skew/mixed_digit/moon/seed{}/global_model_moon_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_modelm.state_dict(), path_moon)
    # for i in range(args.K):
    #     path_moon = "results/Test/feature skew/mixed_digit/moon/seed{}/client{}_model_moon_{}E_{}class.pt".format(args.seed,i,args.E,C)
    #     torch.save(client_modelsm[i].state_dict(), path_moon)

In [None]:
# Evaluate the MOON global model on every domain's sampled test indices; d_gm[domain] is the
# first value returned by test_on_globaldataset_mixed_digit (presumably accuracy) and mean_gm
# its unweighted mean over the domains.
d_gm = {'MNIST':[], 'SVHN':[], 'USPS':[], 'SynthDigits':[], 'MNIST-M':[]}
mean_gm = 0.0
for index, testset in enumerate(testsets):
    gm,_ = test_on_globaldataset_mixed_digit(args, global_modelm, testset, dict_test[datasets_name[index]])
    d_gm[datasets_name[index]] = gm
    mean_gm += gm/len(testsets)  # actual number of domains instead of a hard-coded 5
d_gm

In [None]:
mean_gm

In [None]:
if Train_model:
    train_loss_show(args, loss_dictm,clients_indexm)

In [None]:
del server_moon
torch.cuda.empty_cache()

In [None]:
#fedproc

In [None]:
# FedProc on the non-iid per-domain split (Server.fedbn with fe_optimizer_name="fedproc").
server_fedproc = server.Server(args, specf_model, trainsets, dict_users_train)

In [None]:
# Train, or reload saved results (per-client checkpoints are never saved by the next cell).
if Train_model:
    global_modelproc, similarity_dictproc, client_modelsproc, loss_dictproc, clients_indexproc, acc_listproc= server_fedproc.fedbn(testsets, dict_test,clients_dataset_index, similarity=True,fe_optimizer_name = "fedproc", test_global_model_accuracy = True)
else:
    acc_listproc = torch.load("results/Test/feature skew/mixed_digit/fedproc/seed{}/acc_listproc_{}E_{}class.pt".format(args.seed,args.E,C))
    global_modelproc = server_fedproc.nn
    client_modelsproc = server_fedproc.nns
    path_fedproc = "results/Test/feature skew/mixed_digit/fedproc/seed{}/global_model_fedproc_{}E_{}class.pt".format(args.seed,args.E,C)
    global_modelproc.load_state_dict(torch.load(path_fedproc))
    for i in range(args.K):
        path_fedproc = "results/Test/feature skew/mixed_digit/fedproc/seed{}/client{}_model_fedproc_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_modelsproc[i].load_state_dict(torch.load(path_fedproc))

In [None]:
if save_models:
    if similarity:
        torch.save(similarity_dictproc,"results/Test/feature skew/mixed_digit/fedproc/seed{}/similarity_dictproc_{}E_{}class.pt".format(args.seed,args.E,C))
    torch.save(acc_listproc,"results/Test/feature skew/mixed_digit/fedproc/seed{}/acc_listproc_{}E_{}class.pt".format(args.seed,args.E,C))
    path_fedproc = "results/Test/feature skew/mixed_digit/fedproc/seed{}/global_model_fedproc_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_modelproc.state_dict(), path_fedproc)
    # for i in range(args.K):
    #     path_fedproc = "results/Test/feature skew/mixed_digit/fedproc/seed{}/client{}_model_fedproc_{}E_{}class.pt".format(args.seed,i,args.E,C)
    #     torch.save(client_modelsproc[i].state_dict(), path_fedproc)

In [None]:
# Evaluate the FedProc global model on every domain's sampled test indices; d_gproc[domain]
# is the first value returned by test_on_globaldataset_mixed_digit (presumably accuracy) and
# mean_gproc its unweighted mean over the domains.
d_gproc = {'MNIST':[], 'SVHN':[], 'USPS':[], 'SynthDigits':[], 'MNIST-M':[]}
mean_gproc = 0.0
for index, testset in enumerate(testsets):
    gproc,_ = test_on_globaldataset_mixed_digit(args, global_modelproc, testset, dict_test[datasets_name[index]])
    d_gproc[datasets_name[index]] = gproc
    mean_gproc += gproc/len(testsets)  # actual number of domains instead of a hard-coded 5
d_gproc

In [None]:
mean_gproc

In [None]:
if Train_model:
    train_loss_show(args, loss_dictproc,clients_indexproc)

In [None]:
del server_fedproc
torch.cuda.empty_cache()

In [None]:
#fedfa

In [None]:
# FedFA on the non-iid per-domain split (Server.fedbn with fe_optimizer_name="fedfa").
server_fedfa =  server.Server(args, specf_model, trainsets, dict_users_train)

In [None]:
# Train FedFA, or reload the accuracy curve and weights saved by an earlier run.
# NOTE(review): the reload branch leaves similarity_dictfa / loss_dictfa / clients_indexfa
# undefined, and the per-client checkpoints it loads are never written by the save cell.
if Train_model:
    global_modelfa, similarity_dictfa, client_modelsfa, loss_dictfa, clients_indexfa, acc_listfa = server_fedfa.fedbn(testsets, dict_test,clients_dataset_index, similarity=True,fe_optimizer_name = "fedfa", test_global_model_accuracy = True)
else:
    acc_listfa = torch.load("results/Test/feature skew/mixed_digit/fedfa/seed{}/acc_listfa_{}E_{}class.pt".format(args.seed,args.E,C))
    # Use this section's own server; the former `server_feature` is not defined anywhere and
    # raised NameError.
    global_modelfa = server_fedfa.nn
    client_modelsfa = server_fedfa.nns
    # Unlike the other methods, FedFA checkpoint names carry no ".pt" extension; they match
    # the names written by the save cell, so they are kept unchanged.
    path_fedfa = "results/Test/feature skew/mixed_digit/fedfa/seed{}/global_model_fedfa_{}E_{}class".format(args.seed,args.E,C)
    global_modelfa.load_state_dict(torch.load(path_fedfa))
    for i in range(args.K):
        path_fedfa = "results/Test/feature skew/mixed_digit/fedfa/seed{}/client{}_model_fedfa_{}E_{}class".format(args.seed,i,args.E,C)
        client_modelsfa[i].load_state_dict(torch.load(path_fedfa))

In [None]:
# Save the accuracy curve (and similarity dict if `similarity`) and the global weights.
# NOTE: the FedFA model file names have no ".pt" extension, unlike the other methods; the
# reload branch of the training cell uses the same extension-less names.
if save_models:
    if similarity:
        torch.save(similarity_dictfa,"results/Test/feature skew/mixed_digit/fedfa/seed{}/similarity_dictfa_{}E_{}class.pt".format(args.seed,args.E,C))
    torch.save(acc_listfa,"results/Test/feature skew/mixed_digit/fedfa/seed{}/acc_listfa_{}E_{}class.pt".format(args.seed,args.E,C))
    path_fedfa = "results/Test/feature skew/mixed_digit/fedfa/seed{}/global_model_fedfa_{}E_{}class".format(args.seed,args.E,C)
    torch.save(global_modelfa.state_dict(), path_fedfa)
    # for i in range(args.K):
    #     path_fedfa = "results/Test/feature skew/mixed_digit/fedfa/seed{}/client{}_model_fedfa_{}E_{}class".format(args.seed,i,args.E,C)
    #     torch.save(client_modelsfa[i].state_dict(), path_fedfa)

In [None]:
# Evaluate the FedFA global model on every domain's sampled test indices; d_gfa[domain] is
# the first value returned by test_on_globaldataset_mixed_digit (presumably accuracy) and
# mean_gfa its unweighted mean over the domains.
d_gfa =  {'MNIST':[], 'SVHN':[], 'USPS':[], 'SynthDigits':[], 'MNIST-M':[]}
mean_gfa = 0.0
for index, testset in enumerate(testsets):
    gfa,_ = test_on_globaldataset_mixed_digit(args, global_modelfa, testset, dict_test[datasets_name[index]])
    d_gfa[datasets_name[index]] = gfa
    mean_gfa += gfa/len(testsets)  # actual number of domains instead of a hard-coded 5
d_gfa

In [None]:
mean_gfa

In [None]:
# Per-client training-loss curves (only available right after training).
if Train_model:
    train_loss_show(args, loss_dictfa,clients_indexfa)

In [None]:
# Free the server and cached GPU memory.
del server_fedfa
torch.cuda.empty_cache()