In [None]:
import torch
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms 
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms

import random, os
import numpy as np
from math import sqrt
from matplotlib import pyplot as plt
import pandas as pd
import copy

from fedlab.utils.dataset.partition import CIFAR10Partitioner
from fedlab.utils.dataset import FMNISTPartitioner
from fedlab.utils.functional import partition_report, save_dict
    
from args_femnist import args_parser
import server_se1 as server
import model

from utils.global_test import test_on_globaldataset, globalmodel_test_on_localdataset,globalmodel_test_on_specifdataset,verify_feature_consistency
from utils.local_test import test_on_localdataset
from utils.femnist_dataset import *
from utils.training_loss import train_loss_show,train_localacc_show
from utils.sampling import *


args = args_parser()


def seed_torch(seed=args.seed):
    """Seed every random number generator the experiment touches.

    Seeds Python's ``random``, NumPy's legacy global RNG and torch (CPU plus
    the current and all CUDA devices), then puts cuDNN into deterministic,
    non-benchmarking mode.

    Args:
        seed: Seed value. The default is ``args.seed`` as it was when this
            function was defined (default arguments are bound once).
    """
    # NOTE: PYTHONHASHSEED is only read at interpreter start-up, so setting it
    # here affects child processes started later, not this kernel's own hashing.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_torch()
# Base seed for per-worker reseeding in worker_init_fn (separate from args.seed).
GLOBAL_SEED = 1
def worker_init_fn(worker_id):
    """Reseed a worker process with GLOBAL_SEED + worker_id.

    The single ``worker_id`` argument matches the shape of a DataLoader
    ``worker_init_fn`` hook, which is presumably the intent; no visible cell
    passes it to a DataLoader. The id is also stored in the module-level
    ``GLOBAL_WORKER_ID``.

    NOTE(review): ``set_seed`` is not defined in any visible cell; it
    presumably comes from one of the ``from utils... import *`` lines --
    confirm. ``seed_torch`` above is the local equivalent.
    """
    global GLOBAL_WORKER_ID
    GLOBAL_WORKER_ID = worker_id
    set_seed(GLOBAL_SEED + worker_id)

# Run-mode switches:
#   cka         -- not referenced by any visible cell.
#   save_models -- when True, the save cells write accuracy curves / weights under results/.
#   Train_model -- True: train every method in this session; False: reload saved results.
cka = False
save_models = False
Train_model = True

In [None]:
#Feature Distribution Skew

In [None]:
class AddGaussianNoise(object):
    """Torchvision-style transform that adds Gaussian noise to an image tensor.

    With ``net_id=None`` noise ``N(0, std^2)`` is added to the whole tensor.
    Otherwise the image is divided into a ``num x num`` grid of patches,
    ``num = ceil(sqrt(total))``, and only patch number ``net_id`` (row-major
    order) receives noise. In both cases ``mean`` is added to every element.

    Args:
        mean: Constant offset added to the output.
        std: Standard deviation of the noise.
        net_id: Index of the client whose patch is perturbed, or ``None`` to
            perturb the whole image.
        total: Number of clients sharing the patch grid (used only when
            ``net_id`` is given).

    Raises:
        ValueError: if ``net_id`` is given but does not index a patch of the grid.
    """

    def __init__(self, mean=0., std=1., net_id=None, total=0):
        self.std = std
        self.mean = mean
        self.net_id = net_id
        # Smallest grid side length that gives every client its own patch.
        self.num = int(sqrt(total))
        if self.num * self.num < total:
            self.num = self.num + 1
        if net_id is not None and not 0 <= net_id < self.num * self.num:
            raise ValueError(
                "net_id must be in [0, {}) for total={}, got {}".format(
                    self.num * self.num, total, net_id))

    def __call__(self, tensor):
        # Draw the noise exactly once in both branches so seeded runs keep the
        # same random stream.
        noise = torch.randn(tensor.size())
        if self.net_id is not None:
            # Patch location in the grid: row = net_id // num, col = net_id % num.
            # (The previous version divided by the patch *size*, which pushed
            # the column past the image edge for net_id >= 2.)
            height, width = tensor.size(-2), tensor.size(-1)
            patch_h, patch_w = height // self.num, width // self.num
            row, col = divmod(self.net_id, self.num)
            mask = torch.zeros(tensor.size())
            mask[..., row * patch_h:(row + 1) * patch_h,
                 col * patch_w:(col + 1) * patch_w] = 1
            noise = noise * mask
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [None]:
# noise = 0.1
# num_clients = 10
# img_idx = 17

# fig = plt.figure(figsize=(15, 10))

# for cid in range(num_clients):
#     if cid == num_clients - 1:
#         noise_level = 0
#     else:
#         noise_level = noise / num_clients * (cid + 1)  # a little different from original NIID-bench
#     transform = transforms.Compose([transforms.ToTensor(),
#                                     AddGaussianNoise(0., noise_level)])
#     trainset_feature_skew = FashionMNIST(root=root, train=True, download=True, 
#                                          transform=transform)
#     ax = fig.add_subplot(2, num_clients/2, cid + 1, xticks=[], yticks=[])
#     ax.imshow(np.squeeze(trainset_feature_skew[img_idx][0]), cmap='viridis')
#     ax.set_title(f"Client {cid}: noise$\sim$Gau({noise_level:.3f})")
#     ax.patch.set_facecolor('white')
#     fig.tight_layout()
    
# # fig.savefig("../imgs/fmnist_feature_skew_vis.png", dpi=400, bbox_inches = 'tight'

In [None]:
root = "data/femnist/"
root_logger = "data/femnist/logger/"

# Images are only converted to tensors: no augmentation or normalisation.
transform = transforms.Compose([transforms.ToTensor()])
# Train split is built first, then test (same order as two separate calls).
trainset, testset = (
    FEMNIST(root, train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

In [None]:
# Non-IID split of the FEMNIST training data into args.K clients using the
# "transfer-from-femnist" partition; only the client -> index mapping
# (dict_users_train) and the per-client class counts are kept.
# NOTE: traindata_cls_counts is overwritten by the IID partition in the next cell.
_, _, _, _, dict_users_train, traindata_cls_counts = partition_data(
    args = args,
    dataset = "femnist", 
    datadir = root, 
    logdir = root_logger, 
    partition = "transfer-from-femnist", 
    n_parties = args.K)

In [None]:
# IID ("homo") split of the same data; feeds the iid-fedavg reference baseline.
_, _, _, _, dict_users_train_iid, traindata_cls_counts = partition_data(
    args = args,
    dataset = "femnist", 
    datadir = root, 
    logdir = root_logger, 
    partition = "homo", 
    n_parties = args.K)

In [None]:
### class-based  

In [None]:
num_classes = args.num_classes
num_clients = args.K
number_perclass = args.num_perclass
 

# One column label per class (used by the commented-out partition plots).
col_names = ["class{}".format(idx) for idx in range(num_classes)]
print(col_names)
hist_color = '#4169E1'
plt.rcParams.update({'figure.facecolor': 'white'})

In [None]:
# # perform partition
# noniid_labeldir_part = FMNISTPartitioner(trainset.targets,  
#                                            num_clients=num_clients,
#                                            partition="noniid-#label", 
#                                            major_classes_num=2,
#                                            seed=args.seed)
# # generate partition report
# csv_file = "data/fmnist/fmnist_noniid_labeldir_clients_10.csv"
# partition_report(trainset.targets, noniid_labeldir_part.client_dict, 
#                  class_num=num_classes, 
#                  verbose=False, file=csv_file)

# noniid_labeldir_part_df = pd.read_csv(csv_file,header=1)
# noniid_labeldir_part_df = noniid_labeldir_part_df.set_index('client')
# for col in col_names:
#     noniid_labeldir_part_df[col] = (noniid_labeldir_part_df[col] * noniid_labeldir_part_df['Amount']).astype(int)

# # select first 10 clients for bar plot
# noniid_labeldir_part_df[col_names].iloc[:10].plot.barh(stacked=True)  
# # plt.tight_layout()
# plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
# plt.xlabel('sample num')
# plt.savefig(f"data/fmnist//fmnist_noniid_labeldir_clients_10.png", 
#             dpi=400, bbox_inches = 'tight')

# # split dataset into training and testing

In [None]:
### Distribution-based (Dirichlet)

In [None]:
# # perform partition
# noniid_labeldir_part = FMNISTPartitioner(trainset.targets, 
#                                         num_clients=num_clients,
#                                         partition="noniid-labeldir", 
#                                         dir_alpha=0.1,
#                                         seed=args.seed)

# # generate partition report
# csv_file = "data/fmnist/fmnist_noniid_labeldir_clients_10.csv"
# partition_report(trainset.targets, noniid_labeldir_part.client_dict, 
#                  class_num=num_classes, 
#                  verbose=False, file=csv_file)

# noniid_labeldir_part_df = pd.read_csv(csv_file,header=1)
# noniid_labeldir_part_df = noniid_labeldir_part_df.set_index('client')
# for col in col_names:
#     noniid_labeldir_part_df[col] = (noniid_labeldir_part_df[col] * noniid_labeldir_part_df['Amount']).astype(int)

# # select first 10 clients for bar plot
# noniid_labeldir_part_df[col_names].iloc[:10].plot.barh(stacked=True)  
# # plt.tight_layout()
# plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
# plt.xlabel('sample num')
# plt.savefig(f"data/fmnist//fmnist_noniid_labeldir_clients_10.png", 
#             dpi=400, bbox_inches = 'tight')

In [None]:
# Quantity Skew (label)

In [None]:
# # perform partition
# noniid_labeldir_part = FMNISTPartitioner(trainset.targets, 
#                                   num_clients=num_clients,
#                                   partition="unbalance", 
#                                   dir_alpha=0.5,
#                                   seed=args.seed)

# # generate partition report
# csv_file = "data/fmnist//fmnist_unbalance_clients_10.csv"
# partition_report(trainset.targets, noniid_labeldir_part.client_dict, 
#                  class_num=num_classes, 
#                  verbose=False, file=csv_file)

# noniid_labeldir_part_df = pd.read_csv(csv_file,header=1)
# noniid_labeldir_part_df = noniid_labeldir_part_df.set_index('client')
# for col in col_names:
#     noniid_labeldir_part_df[col] = (noniid_labeldir_part_df[col] * noniid_labeldir_part_df['Amount']).astype(int)

# # select first 10 clients for bar plot
# noniid_labeldir_part_df[col_names].iloc[:10].plot.barh(stacked=True)  
# # plt.tight_layout()
# plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
# plt.xlabel('sample num')
# plt.savefig(f"data/fmnist/fmnist_unbalance_clients_10.png", 
#             dpi=400, bbox_inches = 'tight')

In [None]:
# noniid_labeldir_part_df

In [None]:
# Set of labels present in each client's non-IID training shard; passed to
# testset_sampling_femnist below. Each client uses its OWN shard
# (dict_users_train[i]); previously every client got client 1's label set.
clients_labeset_femnist = {i:set(np.array(trainset.targets)[list(dict_users_train[i])]) for i in range(args.K)}

In [None]:
# Subsample each client's training indices (rate args.trainset_sample_rate) and
# build per-client test index sets from each client's label set.
# NOTE: dict_users_train is overwritten, so re-running this cell subsamples an
# already-subsampled split; re-run the partition cells first.
trainset_sample_rate = args.trainset_sample_rate
rare_class_nums = 0  # semantics defined by trainset_sampling_label_femnist (utils star import) -- confirm
dict_users_train = trainset_sampling_label_femnist(args, trainset, trainset_sample_rate, rare_class_nums, dict_users_train) 
dict_users_test = testset_sampling_femnist(args, testset, number_perclass, clients_labeset_femnist)

In [None]:
# Per-client class histogram of the non-IID training split:
# training_number[client][class] -> sample count, with every class in
# range(num_classes) present (0 when the client has none of it).
# Labels are converted to NumPy once and counted with one np.unique pass per
# client instead of one list.count scan per class.
train_targets_np = np.array(trainset.targets)
training_number = {}
for i in range(args.K):
    training_number[i] = {j: 0 for j in range(num_classes)}
    labels, counts = np.unique(train_targets_np[list(dict_users_train[i])], return_counts=True)
    for k, count in zip(labels.tolist(), counts.tolist()):
        training_number[i][k] = count

In [None]:
# Class-by-client table of training counts. 'Col_sum' holds each class's total
# over all clients; the 'Row_sum' row holds each client's total (its last cell
# is the grand total).
df_training_number = pd.DataFrame({i: pd.Series(training_number[i]) for i in range(args.K)})

df_training_number['Col_sum'] = df_training_number.sum(axis=1)
df_training_number.loc['Row_sum'] = df_training_number.sum()

df_training_number

In [None]:
# Per-client class histogram of the non-IID test split:
# test_number[client][class] -> sample count (0 for absent classes).
# One NumPy conversion and one np.unique pass per client.
test_targets_np = np.array(testset.targets)
test_number = {}
for i in range(args.K):
    test_number[i] = {j: 0 for j in range(num_classes)}
    labels, counts = np.unique(test_targets_np[list(dict_users_test[i])], return_counts=True)
    for k, count in zip(labels.tolist(), counts.tolist()):
        test_number[i][k] = count

In [None]:
# Class-by-client table of non-IID test counts with per-class ('Col_sum')
# and per-client ('Row_sum') totals.
df_test_number = pd.DataFrame({i: pd.Series(test_number[i]) for i in range(args.K)})

df_test_number['Col_sum'] = df_test_number.sum(axis=1)
df_test_number.loc['Row_sum'] = df_test_number.sum()

df_test_number

In [None]:
# # perform partition
# iid_part = FMNISTPartitioner(trainset.targets, 
#                             num_clients=num_clients,
#                             partition="iid",
#                             seed=args.seed)

# # generate partition report
# csv_file = "data/fmnist/fmnist_iid_clients_10.csv"
# partition_report(trainset.targets, iid_part.client_dict, 
#                  class_num=num_classes, 
#                  verbose=False, file=csv_file)

# iid_part_df = pd.read_csv(csv_file,header=1)
# iid_part_df = iid_part_df.set_index('client')
# for col in col_names:
#     iid_part_df[col] = (iid_part_df[col] * iid_part_df['Amount']).astype(int)

# # select first 10 clients for bar plot
# iid_part_df[col_names].iloc[:10].plot.barh(stacked=True)  
# # plt.tight_layout()
# plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
# plt.xlabel('sample num')
# plt.savefig(f"data/fmnist/fmnist_iid_clients_10.png", 
#             dpi=400, bbox_inches = 'tight')

In [None]:
# Set of labels present in each client's IID training shard.
clients_labeset_femnist_iid = dict(
    (cid, set(np.array(trainset.targets)[list(dict_users_train_iid[cid])]))
    for cid in range(args.K)
)

In [None]:
# Same subsampling / test-split construction as for the non-IID split.
# NOTE: overwrites dict_users_train_iid, so this cell is not idempotent.
dict_users_train_iid = trainset_sampling_label_femnist(args, trainset, trainset_sample_rate,rare_class_nums, dict_users_train_iid) 
dict_users_test_iid = testset_sampling_femnist(args, testset, number_perclass, clients_labeset_femnist_iid)

In [None]:
# Per-client class histogram of the IID test split (overwrites test_number):
# test_number[client][class] -> sample count (0 for absent classes).
test_targets_np = np.array(testset.targets)
test_number = {}
for i in range(args.K):
    test_number[i] = {j: 0 for j in range(num_classes)}
    labels, counts = np.unique(test_targets_np[list(dict_users_test_iid[i])], return_counts=True)
    for k, count in zip(labels.tolist(), counts.tolist()):
        test_number[i][k] = count

In [None]:
# Class-by-client table of IID test counts with per-class ('Col_sum') and
# per-client ('Row_sum') totals.
df_test_number = pd.DataFrame({i: pd.Series(test_number[i]) for i in range(args.K)})

df_test_number['Col_sum'] = df_test_number.sum(axis=1)
df_test_number.loc['Row_sum'] = df_test_number.sum()

df_test_number

In [None]:
# Per-client class histogram of the IID training split:
# training_number_iid[client][class] -> sample count (0 for absent classes).
train_targets_np = np.array(trainset.targets)
training_number_iid = {}
for i in range(args.K):
    training_number_iid[i] = {j: 0 for j in range(num_classes)}
    labels, counts = np.unique(train_targets_np[list(dict_users_train_iid[i])], return_counts=True)
    for k, count in zip(labels.tolist(), counts.tolist()):
        training_number_iid[i][k] = count

In [None]:
# Class-by-client table of IID training counts with per-class ('Col_sum') and
# per-client ('Row_sum') totals.
df_training_number_iid = pd.DataFrame({i: pd.Series(training_number_iid[i]) for i in range(args.K)})

df_training_number_iid['Col_sum'] = df_training_number_iid.sum(axis=1)
df_training_number_iid.loc['Row_sum'] = df_training_number_iid.sum()

df_training_number_iid

In [None]:
# Model architecture / initial weights handed to the servers below.
# NOTE(review): name="fmnist" although the data is FEMNIST -- presumably the same
# single-channel 28x28 CNN variant; confirm in model.Client_Model.
specf_model = model.Client_Model(args, name="fmnist").to(args.device)

In [None]:
# Server over the non-IID training split; each non-IID baseline below runs on a deepcopy of it.
serverz = server.Server(args, specf_model, trainset, dict_users_train)  # dict_users: indices of each user's local dataset

In [None]:
# Model tag embedded in every results filename below (e.g. "..._{E}E__2CNNclass.pt").
C = "_2CNN"

In [None]:
#  baseline----> iid setting with fedavg

In [None]:
# Server over the IID split (reference baseline). It gets its own deep copy of
# specf_model: the very same object is already held by serverz, so training here
# must not be able to alter the initial weights the non-IID baselines start from.
server_iid = server.Server(args, copy.deepcopy(specf_model), trainset, dict_users_train_iid)

In [None]:
if Train_model:
    # Train FedAvg on the IID split, tracking global-model test accuracy per round.
    global_model_iid, _, client_models_iid, loss_dict_iid, clients_index_iid, acc_list_iid = server_iid.fedavg_joint_update(
        testset, dict_users_test_iid[0], test_global_model_accuracy = True)
else:
    # Reload saved results. map_location lets checkpoints written on a GPU be
    # loaded on whatever device args.device names (e.g. a CPU-only machine).
    acc_list_iid = torch.load("results/Test/feature skew/emnist/iid-fedavg/seed{}/acc_list_iid_{}E_{}class.pt".format(args.seed,args.E,C), map_location=args.device)
    global_model_iid = server_iid.nn
    client_models_iid = server_iid.nns
    path_iid_fedavg = "results/Test/feature skew/emnist/iid-fedavg/seed{}/global_model_iid-fedavg_{}E_{}class.pt".format(args.seed,args.E,C)
    global_model_iid.load_state_dict(torch.load(path_iid_fedavg, map_location=args.device))
    for i in range(args.K):
        path_iid_fedavg = "results/Test/feature skew/emnist/iid-fedavg/seed{}/client{}_model_iid-fedavg_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_models_iid[i].load_state_dict(torch.load(path_iid_fedavg, map_location=args.device))

In [None]:
if save_models:
    # Write every file the Train_model=False branch reloads (accuracy curve,
    # global weights and each client's weights), creating the directory first.
    os.makedirs("results/Test/feature skew/emnist/iid-fedavg/seed{}".format(args.seed), exist_ok=True)
    torch.save(acc_list_iid,"results/Test/feature skew/emnist/iid-fedavg/seed{}/acc_list_iid_{}E_{}class.pt".format(args.seed,args.E,C))
    path_iid_fedavg = "results/Test/feature skew/emnist/iid-fedavg/seed{}/global_model_iid-fedavg_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_model_iid.state_dict(), path_iid_fedavg)
    for i in range(args.K):
        client_path = "results/Test/feature skew/emnist/iid-fedavg/seed{}/client{}_model_iid-fedavg_{}E_{}class.pt".format(args.seed,i,args.E,C)
        torch.save(client_models_iid[i].state_dict(), client_path)

In [None]:
# Evaluate the IID-FedAvg global model on the full test set (first return value, presumably accuracy).
g_iid,_ = test_on_globaldataset(args, global_model_iid, testset)
g_iid

In [None]:
# Evaluate the same global model on each client's IID test split; show the mean over clients.
a_iid,_ =globalmodel_test_on_localdataset(args,global_model_iid, testset,dict_users_test_iid)
np.mean(list(a_iid.values()))

In [None]:
# Training diagnostics exist only when the model was trained in this session.
if Train_model:
    train_loss_show(args, loss_dict_iid,clients_index_iid)
    plt.plot(range(args.r), acc_list_iid)  # per-round accuracy curve (length args.r)

In [None]:
#  baseline---->fedavg

In [None]:
# FedAvg baseline on the non-IID split; deepcopy leaves serverz untouched for the other baselines.
server_fedavg =  copy.deepcopy(serverz)  # dict_users: indices of each user's local dataset

In [None]:
if Train_model:
    # Train FedAvg on the non-IID split.
    # NOTE(review): the accuracy-tracking argument is dict_users_test_iid[0]
    # (IID client 0's test indices), whereas fedfa below receives
    # dict_users_test -- confirm this is intended.
    global_model1, personalized_models1, client_models1, loss_dict1, clients_index1, acc_list1 = server_fedavg.fedavg_joint_update(
        testset, dict_users_test_iid[0],test_global_model_accuracy = True)
else:
    # Reload saved results; map_location lets GPU-written checkpoints load anywhere.
    acc_list1 = torch.load("results/Test/feature skew/emnist/fedavg/seed{}/acc_list1_{}E_{}class.pt".format(args.seed,args.E,C), map_location=args.device)
    global_model1 = server_fedavg.nn
    client_models1 = server_fedavg.nns
    path_fedavg = "results/Test/feature skew/emnist/fedavg/seed{}/global_model_fedavg_{}E_{}class.pt".format(args.seed,args.E,C)
    global_model1.load_state_dict(torch.load(path_fedavg, map_location=args.device))
    for i in range(args.K):
        path_fedavg = "results/Test/feature skew/emnist/fedavg/seed{}/client{}_model_fedavg_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_models1[i].load_state_dict(torch.load(path_fedavg, map_location=args.device))

In [None]:
if save_models:
    # Write every file the Train_model=False branch reloads, creating the directory first.
    os.makedirs("results/Test/feature skew/emnist/fedavg/seed{}".format(args.seed), exist_ok=True)
    torch.save(acc_list1,"results/Test/feature skew/emnist/fedavg/seed{}/acc_list1_{}E_{}class.pt".format(args.seed,args.E,C))
    path_fedavg = "results/Test/feature skew/emnist/fedavg/seed{}/global_model_fedavg_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_model1.state_dict(), path_fedavg)
    for i in range(args.K):
        client_path = "results/Test/feature skew/emnist/fedavg/seed{}/client{}_model_fedavg_{}E_{}class.pt".format(args.seed,i,args.E,C)
        torch.save(client_models1[i].state_dict(), client_path)

In [None]:
# FedAvg global model on the full test set.
g1,_ = test_on_globaldataset(args, global_model1, testset)
g1

In [None]:
# FedAvg global model on each client's non-IID test split; mean over clients.
a1,_ =globalmodel_test_on_localdataset(args,global_model1, testset,dict_users_test)
np.mean(list(a1.values()))

In [None]:
# Training-loss plot; loss history exists only when trained in this session.
if Train_model:
    train_loss_show(args, loss_dict1,clients_index1)

In [None]:
#  baseline---->fedprox

In [None]:
# FedProx baseline on a fresh copy of the non-IID server.
server_fedprox_joint = copy.deepcopy(serverz)

In [None]:
if Train_model:
    # Train FedProx on the non-IID split.
    global_modelp, _, client_modelsp, loss_dictp, clients_indexp, acc_listp = server_fedprox_joint.fedprox_joint_update(
        testset, dict_users_test_iid[0],test_global_model_accuracy = True)
else:
    # Reload saved results; map_location lets GPU-written checkpoints load anywhere.
    acc_listp = torch.load("results/Test/feature skew/emnist/fedprox/seed{}/acc_listp_{}E_{}class.pt".format(args.seed,args.E,C), map_location=args.device)
    global_modelp = server_fedprox_joint.nn
    client_modelsp = server_fedprox_joint.nns
    path_fedprox = "results/Test/feature skew/emnist/fedprox/seed{}/global_model_fedprox_{}E_{}class.pt".format(args.seed,args.E,C)
    global_modelp.load_state_dict(torch.load(path_fedprox, map_location=args.device))
    for i in range(args.K):
        # NOTE: client filenames here lack the "_model_fedprox" part used by the other methods.
        path_fedprox = "results/Test/feature skew/emnist/fedprox/seed{}/client{}_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_modelsp[i].load_state_dict(torch.load(path_fedprox, map_location=args.device))

In [None]:
if save_models:
    # Write every file the Train_model=False branch reloads (same client filename
    # pattern it expects), creating the directory first.
    os.makedirs("results/Test/feature skew/emnist/fedprox/seed{}".format(args.seed), exist_ok=True)
    torch.save(acc_listp,"results/Test/feature skew/emnist/fedprox/seed{}/acc_listp_{}E_{}class.pt".format(args.seed,args.E,C))
    path_fedprox = "results/Test/feature skew/emnist/fedprox/seed{}/global_model_fedprox_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_modelp.state_dict(), path_fedprox)
    for i in range(args.K):
        client_path = "results/Test/feature skew/emnist/fedprox/seed{}/client{}_{}E_{}class.pt".format(args.seed,i,args.E,C)
        torch.save(client_modelsp[i].state_dict(), client_path)

In [None]:
# FedProx global model on the full test set.
gp,_ = test_on_globaldataset(args, global_modelp, testset)
gp

In [None]:
# FedProx global model on each client's non-IID test split; mean over clients.
ap,_ =globalmodel_test_on_localdataset(args,global_modelp, testset,dict_users_test)
np.mean(list(ap.values()))

In [None]:
# Training-loss plot; loss history exists only when trained in this session.
if Train_model:
    train_loss_show(args, loss_dictp,clients_indexp)

In [None]:
#  baseline---->feddyn

In [None]:
# FedDyn baseline on a fresh copy of the non-IID server.
server_feddyn = copy.deepcopy(serverz)

In [None]:
if Train_model:
    # Train FedDyn on the non-IID split.
    global_modeldyn, personalized_modeldyn, client_modelsdyn, loss_dictdyn, clients_indexdyn, acc_listdyn = server_feddyn.feddyn(
        testset, dict_users_test_iid[0],test_global_model_accuracy = True)
else:
    # Reload saved results; map_location lets GPU-written checkpoints load anywhere.
    acc_listdyn = torch.load("results/Test/feature skew/emnist/feddyn/seed{}/acc_listdyn_{}E_{}class.pt".format(args.seed,args.E,C), map_location=args.device)
    global_modeldyn = server_feddyn.nn
    client_modelsdyn = server_feddyn.nns
    path_feddyn = "results/Test/feature skew/emnist/feddyn/seed{}/global_model_feddyn_{}E_{}class.pt".format(args.seed,args.E,C)
    global_modeldyn.load_state_dict(torch.load(path_feddyn, map_location=args.device))
    for i in range(args.K):
        path_feddyn = "results/Test/feature skew/emnist/feddyn/seed{}/client{}_model_feddyn_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_modelsdyn[i].load_state_dict(torch.load(path_feddyn, map_location=args.device))

In [None]:
if save_models:
    # Write every file the Train_model=False branch reloads, creating the directory first.
    os.makedirs("results/Test/feature skew/emnist/feddyn/seed{}".format(args.seed), exist_ok=True)
    torch.save(acc_listdyn,"results/Test/feature skew/emnist/feddyn/seed{}/acc_listdyn_{}E_{}class.pt".format(args.seed,args.E,C))
    path_feddyn = "results/Test/feature skew/emnist/feddyn/seed{}/global_model_feddyn_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_modeldyn.state_dict(), path_feddyn)
    for i in range(args.K):
        client_path = "results/Test/feature skew/emnist/feddyn/seed{}/client{}_model_feddyn_{}E_{}class.pt".format(args.seed,i,args.E,C)
        torch.save(client_modelsdyn[i].state_dict(), client_path)

In [None]:
# FedDyn global model on the full test set.
gdyn,_ = test_on_globaldataset(args, global_modeldyn, testset)
gdyn

In [None]:
# FedDyn global model on each client's non-IID test split; mean over clients.
adyn,_ =globalmodel_test_on_localdataset(args,global_modeldyn, testset,dict_users_test)
np.mean(list(adyn.values()))

In [None]:
# Training-loss plot; loss history exists only when trained in this session.
if Train_model:
    train_loss_show(args, loss_dictdyn,clients_indexdyn)

In [None]:
#  baseline---->moon

In [None]:
# MOON baseline on a fresh copy of the non-IID server.
server_moon = copy.deepcopy(serverz)
if Train_model:
    global_modelm, _, client_modelsm, loss_dictm, clients_indexm, acc_listm = server_moon.moon(
        testset, dict_users_test_iid[0],test_global_model_accuracy = True)
else:
    # Reload saved results; map_location lets GPU-written checkpoints load anywhere.
    acc_listm = torch.load("results/Test/feature skew/emnist/moon/seed{}/acc_listm_{}E_{}class.pt".format(args.seed,args.E,C), map_location=args.device)
    global_modelm = server_moon.nn
    client_modelsm = server_moon.nns
    path_moon = "results/Test/feature skew/emnist/moon/seed{}/global_model_moon_{}E_{}class.pt".format(args.seed,args.E,C)
    global_modelm.load_state_dict(torch.load(path_moon, map_location=args.device))
    for i in range(args.K):
        path_moon = "results/Test/feature skew/emnist/moon/seed{}/client{}_model_moon_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_modelsm[i].load_state_dict(torch.load(path_moon, map_location=args.device))

In [None]:
if save_models:
    # Write every file the Train_model=False branch reloads, creating the directory first.
    os.makedirs("results/Test/feature skew/emnist/moon/seed{}".format(args.seed), exist_ok=True)
    torch.save(acc_listm,"results/Test/feature skew/emnist/moon/seed{}/acc_listm_{}E_{}class.pt".format(args.seed,args.E,C))
    path_moon = "results/Test/feature skew/emnist/moon/seed{}/global_model_moon_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_modelm.state_dict(), path_moon)
    for i in range(args.K):
        client_path = "results/Test/feature skew/emnist/moon/seed{}/client{}_model_moon_{}E_{}class.pt".format(args.seed,i,args.E,C)
        torch.save(client_modelsm[i].state_dict(), client_path)

In [None]:
# MOON global model on the full test set.
gm,_ = test_on_globaldataset(args, global_modelm, testset)
gm

In [None]:
# MOON global model on each client's non-IID test split; mean over clients.
am,_ =globalmodel_test_on_localdataset(args,global_modelm, testset,dict_users_test)
np.mean(list(am.values()))

In [None]:
# Training-loss plot; loss history exists only when trained in this session.
if Train_model:
    train_loss_show(args, loss_dictm,clients_indexm)

In [None]:
#baseline fedproc

In [None]:
# FedProc baseline on a fresh copy of the non-IID server.
server_fedproc =  copy.deepcopy(serverz)

In [None]:
if Train_model:
    # Train FedProc on the non-IID split.
    global_modelproc, _, client_modelsproc, loss_dictproc, clients_indexproc, acc_listproc= server_fedproc.fedproc(
        testset, dict_users_test_iid[0],test_global_model_accuracy = True)
else:
    # Reload saved results; map_location lets GPU-written checkpoints load anywhere.
    acc_listproc = torch.load("results/Test/feature skew/emnist/fedproc/seed{}/acc_listproc_{}E_{}class.pt".format(args.seed,args.E,C), map_location=args.device)
    global_modelproc = server_fedproc.nn
    client_modelsproc = server_fedproc.nns
    path_fedproc = "results/Test/feature skew/emnist/fedproc/seed{}/global_model_fedproc_{}E_{}class.pt".format(args.seed,args.E,C)
    global_modelproc.load_state_dict(torch.load(path_fedproc, map_location=args.device))
    for i in range(args.K):
        path_fedproc = "results/Test/feature skew/emnist/fedproc/seed{}/client{}_model_fedproc_{}E_{}class.pt".format(args.seed,i,args.E,C)
        client_modelsproc[i].load_state_dict(torch.load(path_fedproc, map_location=args.device))

In [None]:
if save_models:
    # Write every file the Train_model=False branch reloads, creating the directory first.
    os.makedirs("results/Test/feature skew/emnist/fedproc/seed{}".format(args.seed), exist_ok=True)
    torch.save(acc_listproc,"results/Test/feature skew/emnist/fedproc/seed{}/acc_listproc_{}E_{}class.pt".format(args.seed,args.E,C))
    path_fedproc = "results/Test/feature skew/emnist/fedproc/seed{}/global_model_fedproc_{}E_{}class.pt".format(args.seed,args.E,C)
    torch.save(global_modelproc.state_dict(), path_fedproc)
    for i in range(args.K):
        client_path = "results/Test/feature skew/emnist/fedproc/seed{}/client{}_model_fedproc_{}E_{}class.pt".format(args.seed,i,args.E,C)
        torch.save(client_modelsproc[i].state_dict(), client_path)

In [None]:
# FedProc global model on the full test set.
gproc,_ = test_on_globaldataset(args, global_modelproc, testset)
gproc

In [None]:
# FedProc global model on each client's non-IID test split; mean over clients.
aproc,_ =globalmodel_test_on_localdataset(args,global_modelproc, testset,dict_users_test)
np.mean(list(aproc.values()))

In [None]:
# Training-loss plot; loss history exists only when trained in this session.
if Train_model:
    train_loss_show(args, loss_dictproc,clients_indexproc)

In [None]:
#  our method---->fedfa

In [None]:
# Proposed method (FedFA) on a fresh copy of the non-IID server.
server_feature = copy.deepcopy(serverz)

In [None]:
if Train_model:
    # Train FedFA (anchor loss) on the non-IID split; unlike the baselines it is
    # given the whole dict_users_test mapping.
    global_modelfa, _, client_modelsfa, loss_dictfa, clients_indexfa, acc_listfa = server_feature.fedfa_anchorloss(
        testset, dict_users_test, test_global_model_accuracy = True)
else:
    # Reload saved results; map_location lets GPU-written checkpoints load anywhere.
    # (FedFA weight files are written without a ".pt" extension.)
    acc_listfa = torch.load("results/Test/feature skew/emnist/fedfa/seed{}/acc_listfa_{}E_{}class.pt".format(args.seed,args.E,C), map_location=args.device)
    global_modelfa = server_feature.nn
    client_modelsfa = server_feature.nns
    path_fedfa = "results/Test/feature skew/emnist/fedfa/seed{}/global_model_fedfa_{}E_{}class".format(args.seed,args.E,C)
    global_modelfa.load_state_dict(torch.load(path_fedfa, map_location=args.device))
    for i in range(args.K):
        path_fedfa = "results/Test/feature skew/emnist/fedfa/seed{}/client{}_model_fedfa_{}E_{}class".format(args.seed,i,args.E,C)
        client_modelsfa[i].load_state_dict(torch.load(path_fedfa, map_location=args.device))

In [None]:
if save_models:
    # Write every file the Train_model=False branch reloads (same extension-less
    # weight filenames it expects), creating the directory first.
    os.makedirs("results/Test/feature skew/emnist/fedfa/seed{}".format(args.seed), exist_ok=True)
    torch.save(acc_listfa,"results/Test/feature skew/emnist/fedfa/seed{}/acc_listfa_{}E_{}class.pt".format(args.seed,args.E,C))
    path_fedfa = "results/Test/feature skew/emnist/fedfa/seed{}/global_model_fedfa_{}E_{}class".format(args.seed,args.E,C)
    torch.save(global_modelfa.state_dict(), path_fedfa)
    for i in range(args.K):
        client_path = "results/Test/feature skew/emnist/fedfa/seed{}/client{}_model_fedfa_{}E_{}class".format(args.seed,i,args.E,C)
        torch.save(client_modelsfa[i].state_dict(), client_path)

In [None]:
# FedFA global model on the full test set.
gfa,_ = test_on_globaldataset(args, global_modelfa, testset)
gfa

In [None]:
# FedFA global model on each client's non-IID test split; mean over clients.
afa,_ =globalmodel_test_on_localdataset(args,global_modelfa, testset,dict_users_test)
np.mean(list(afa.values()))

In [None]:
# Training-loss plot; loss history exists only when trained in this session.
if Train_model:
    train_loss_show(args, loss_dictfa,clients_indexfa)

In [None]:
# Per-round global-model test accuracy of every method on a single figure.
fig, ax = plt.subplots()
curves = [
    (acc_list_iid, "iid-fedavg"),
    (acc_listfa, "fedfa"),
    (acc_list1, "fedavg"),
    (acc_listp, "fedprox"),
    (acc_listdyn, "feddyn"),
    (acc_listm, "moon"),
    (acc_listproc, "fedproc"),
]
for acc_curve, method_name in curves:
    ax.plot(range(args.r), acc_curve, label=method_name)
ax.set_xlabel("round")
ax.set_ylabel("global test accuracy")
ax.set_title("FEMNIST feature skew: global accuracy per round")
ax.legend()
plt.show()