In [None]:
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms 
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms

import random, os
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import copy


from fedlab.utils.dataset import FMNISTPartitioner,CIFAR10Partitioner,CIFAR100Partitioner
from fedlab.utils.functional import partition_report
    
from args import args_parser
import server_se1 as server
import model

from utils.global_test import test_on_globaldataset
from utils.sampling import testset_sampling,  trainset_sampling_label

args = args_parser()


def seed_torch(seed=args.seed):
    """Seed every RNG this notebook relies on and force deterministic cuDNN kernels.

    Args:
        seed: integer seed; the default is ``args.seed`` as parsed when this cell ran.
    """
    # cuDNN: no autotuning, deterministic algorithms only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Python-level randomness. Note: PYTHONHASHSEED is read at interpreter start-up, so
    # setting it here only affects Python subprocesses launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    # Numerical libraries (CPU, current GPU, and all GPUs for multi-GPU runs).
    np.random.seed(seed)
    torch.manual_seed(seed)
    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seeder(seed)


seed_torch()
GLOBAL_SEED = 1


def worker_init_fn(worker_id):
    """DataLoader ``worker_init_fn``: give each worker a distinct, deterministic seed.

    Seeds all RNGs with ``GLOBAL_SEED + worker_id`` through ``seed_torch`` (the only
    seeding helper in this notebook; the previous call to an undefined ``set_seed``
    raised NameError in every worker).

    Args:
        worker_id: index of the DataLoader worker process.
    """
    global GLOBAL_WORKER_ID
    GLOBAL_WORKER_ID = worker_id
    seed_torch(GLOBAL_SEED + worker_id)

In [None]:
# Experiment configuration: model, display flags, and a run tag (`D`) built from the CLI args.
# NOTE(review): `init_model` is created here, before a later cell sets `args.num_classes`
# from the dataset; if Client_Model reads args.num_classes it sees the parser default — confirm.
model_name = args.model_name
similarity, training_loss_show = False, True
init_model = model.Client_Model(args, name=model_name).to(args.device)
dict_users_test_iid = [[]]

C = f"{args.split}"
print(C)
D = (f"{args.r}r_lr{args.lr}_decay{round(1-args.weight_decay, 4)}_M{args.momentum}_B{args.B}_C{args.C}"
     f"_fima{args.r_ima}_W{args.window_size}_lrdecay{round(1-args.lr_ima_decay, 4)}_adap_ima_{args.dataset}_{args.K}")
print(D)
args.setup = D

In [None]:
# Load the train/test splits selected by `args.dataset`. Every branch sets `trainset`,
# `testset` and the download directory `root`.
if args.dataset == 'fmnist':
    # FashionMNIST as [0, 1] tensors, no normalisation. Normalised alternative:
    # transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    trans_mnist = transforms.Compose([transforms.ToTensor()])
    root = "data/fmnist/"
    trainset = FashionMNIST(root=root, train=True, download=True, transform=trans_mnist)
    testset = FashionMNIST(root=root, train=False, download=True, transform=trans_mnist)

elif args.dataset == 'mnist':
    # MNIST upsampled to 224x224 (original note: intended for AlexNet).
    trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
    root = "data/mnist/"
    trainset = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=trans_mnist)
    testset = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=trans_mnist)
elif args.dataset == 'cifar10':
    trans_cifar10 = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.491, 0.482, 0.447],
                                                             std=[0.247, 0.243, 0.262])])
    root = "data/CIFAR10/"
    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=trans_cifar10)
    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=trans_cifar10)
elif args.dataset == 'cifar100':
    # NOTE(review): these are the CIFAR-10 channel statistics reused for CIFAR-100
    # (CIFAR-100's own mean/std differ slightly); kept unchanged to preserve results.
    trans_cifar100 = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.491, 0.482, 0.447],
                                                              std=[0.247, 0.243, 0.262])])
    root = "data/CIFAR100/"
    trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=trans_cifar100)
    testset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=trans_cifar100)

else:
    # Raise rather than exit(): in a Jupyter kernel exit() does not reliably stop
    # "Run All", so later cells would fail with confusing NameErrors.
    raise ValueError(f"Error: unrecognized dataset {args.dataset!r}")


In [None]:
# Number of classes = largest training label + 1 (assumes 0-based contiguous labels).
# int(...) keeps it a plain Python int: for (Fashion)MNIST `targets` is a tensor and
# max(list(...)) used to yield a 0-d tensor (and iterated element by element).
num_classes = int(np.max(np.asarray(trainset.targets))) + 1
args.num_classes = num_classes
num_clients = args.K
number_perclass = args.num_perclass

col_names = [f"class{i}" for i in range(num_classes)]
print(col_names)
hist_color = '#4169E1'
plt.rcParams['figure.facecolor'] = 'white'

In [None]:
### Distribution-based (class)

In [None]:
# perform partition
# Split the training set across `num_clients` clients according to `args.split`.
# Branch order matters: 'iid' is also a substring of 'noniid...' names, so it is tested last.
if args.dataset == 'cifar100':
    noniid_labeldir_part = CIFAR100Partitioner(trainset.targets,
                                               num_clients=num_clients,
                                               partition="dirichlet",
                                               dir_alpha=0.5,
                                               seed=1)
elif '_2' in args.split:
    # Shard split with 2 * num_clients shards in total.
    noniid_labeldir_part = CIFAR10Partitioner(trainset.targets,
                                              num_clients=num_clients,
                                              balance=None,
                                              partition="shards",
                                              num_shards=2*num_clients,
                                              seed=1)
elif 'dir' in args.split:
    # Dirichlet alpha is parsed from the last three characters of the split name (e.g. "...0.5").
    noniid_labeldir_part = FMNISTPartitioner(trainset.targets,
                                             num_clients=num_clients,
                                             partition="noniid-labeldir",
                                             dir_alpha=float(args.split[-3:]),
                                             seed=1)
elif 'unbalance' in args.split:
    noniid_labeldir_part = FMNISTPartitioner(trainset.targets,
                                             num_clients=num_clients,
                                             partition="unbalance",
                                             dir_alpha=0.5,
                                             seed=args.seed)
elif 'iid' in args.split:
    noniid_labeldir_part = FMNISTPartitioner(trainset.targets,
                                             num_clients=num_clients,
                                             partition="iid",
                                             seed=1)
else:
    # Fail here instead of with a NameError on `noniid_labeldir_part` below.
    raise ValueError(f"Error: unrecognized split {args.split!r}")

# generate partition report
# Written next to the dataset files; the path used to be hard-coded to data/fmnist/ for every
# dataset, which fails when that directory does not exist. For fmnist the path is unchanged.
report_dir = trainset.root
os.makedirs(report_dir, exist_ok=True)
report_stem = f"{args.dataset}_noniid_labeldir_clients_10"
csv_file = os.path.join(report_dir, f"{report_stem}.csv")
partition_report(trainset.targets, noniid_labeldir_part.client_dict,
                 class_num=num_classes,
                 verbose=False, file=csv_file)

# The class columns are per-client fractions (hence the multiplication by 'Amount').
# Round before astype(int): truncation turns e.g. 0.29 * 100 = 28.999... into 28.
noniid_labeldir_part_df = pd.read_csv(csv_file, header=1)
noniid_labeldir_part_df = noniid_labeldir_part_df.set_index('client')
for col in col_names:
    noniid_labeldir_part_df[col] = (noniid_labeldir_part_df[col]
                                    * noniid_labeldir_part_df['Amount']).round().astype(int)

# select first 10 clients for bar plot
ax = noniid_labeldir_part_df[col_names].iloc[:10].plot.barh(stacked=True)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax.set_xlabel('sample num')
ax.figure.savefig(os.path.join(report_dir, f"{report_stem}.png"),
                  dpi=400, bbox_inches='tight')

In [None]:
# Per-client index sets: training indices drawn from the partition, and per-client test
# indices drawn using the partition table (`number_perclass` presumably caps samples per class).
trainset_sample_rate, rare_class_nums = args.trainset_sample_rate, 0
dict_users_train = trainset_sampling_label(args, trainset, trainset_sample_rate,
                                           rare_class_nums, noniid_labeldir_part)
dict_users_test = testset_sampling(args, testset, number_perclass,
                                   noniid_labeldir_part_df)

In [None]:
# Per-client training-label histogram: training_number[client][class] -> sample count,
# with every class key present (0 for classes the client does not hold).
# The label array is converted once, instead of once per (client, class) pair as before.
train_targets = np.asarray(trainset.targets)
training_number = {}
for i in range(args.K):
    client_labels = train_targets[list(dict_users_train[i])]
    class_counts = np.bincount(client_labels, minlength=num_classes)
    training_number[i] = {j: int(class_counts[j]) for j in range(num_classes)}

In [None]:
# Class x client table of training counts: one column per client, one row per class.
# 'Col_sum' = total over clients for each class; row 'Row_sum' = total of each column.
df_training_number = pd.DataFrame(training_number)
df_training_number['Col_sum'] = df_training_number.sum(axis=1)
df_training_number.loc['Row_sum'] = df_training_number.sum()

df_training_number

In [None]:
# Per-client test-label histogram: test_number[client][class] -> sample count,
# with every class key present (0 for classes missing from the client's test set).
# The label array is converted once, instead of once per (client, class) pair as before.
test_targets = np.asarray(testset.targets)
test_number = {}
for i in range(args.K):
    client_labels = test_targets[list(dict_users_test[i])]
    class_counts = np.bincount(client_labels, minlength=num_classes)
    test_number[i] = {j: int(class_counts[j]) for j in range(num_classes)}

In [None]:
# Class x client table of test counts: one column per client, one row per class.
# 'Col_sum' = total over clients for each class; row 'Row_sum' = total of each column.
df_test_number = pd.DataFrame(test_number)
df_test_number['Col_sum'] = df_test_number.sum(axis=1)
df_test_number.loc['Row_sum'] = df_test_number.sum()

df_test_number

In [None]:
#  baseline---->fedavg

In [None]:
# Train FedAvg with IMA aggregation from a fresh copy of the shared initial model.
# Each server gets its own deep copy of `args`: the notebook reads `server_*.args.r_ima`
# after training as an adaptive value, so a shared namespace could carry values written
# during one run into the next baseline.
specf_model = copy.deepcopy(init_model)
server_fedavg = server.Server(copy.deepcopy(args), specf_model, trainset, dict_users_train)
global_model, _, _, loss_dict, clients_index, acc_list = server_fedavg.fedavg(
    testset, dict_users_test,
    agg_mode='ima',
    test_global_model_accuracy=True)


In [None]:
# Show the IMA start round stored on the FedAvg server's args after training
# (labelled "adaptive", so presumably set by the server during training — TODO confirm).
print("fedavg adaptive-ima start round:", server_fedavg.args.r_ima)

In [None]:
# Evaluate the FedAvg global model on the whole test set and display the first returned
# value (presumably accuracy — TODO confirm), consistent with the other evaluation cells.
g1,_ = test_on_globaldataset(args, global_model, testset)
g1

In [None]:
# Drop the FedAvg server; its results (global_model, acc_list, ...) were already unpacked.
del server_fedavg

In [None]:
#  baseline---->fedprox 

In [None]:
# Train FedProx with IMA aggregation from a fresh copy of the shared initial model.
# Each server gets its own deep copy of `args` so values written into it during one run
# (the notebook reads `server_*.args.r_ima` afterwards) cannot leak into another baseline.
specf_model = copy.deepcopy(init_model)
server_fedprox = server.Server(copy.deepcopy(args), specf_model, trainset, dict_users_train)
global_modelp, _, _, loss_dictp, clients_indexp, acc_listp = server_fedprox.fedprox(
    testset, dict_users_test,
    agg_mode='ima',
    test_global_model_accuracy=True)


In [None]:
# Show the IMA start round stored on the FedProx server's args after training
# (labelled "adaptive", so presumably set by the server during training — TODO confirm).
print("fedprox adaptive-ima start round:", server_fedprox.args.r_ima)

In [None]:
# Evaluate the FedProx global model on the whole test set and display the first returned
# value (presumably accuracy — TODO confirm); the second value is discarded.
gp,_ = test_on_globaldataset(args, global_modelp, testset)
gp

In [None]:
# Drop the FedProx server; its results (global_modelp, acc_listp, ...) were already unpacked.
del server_fedprox

In [None]:
#  baseline---->fedasam

In [None]:
# Train FedASAM (server.fedsam) with IMA aggregation from a fresh copy of the initial model.
# Each server gets its own deep copy of `args` so values written into it during one run
# (the notebook reads `server_*.args.r_ima` afterwards) cannot leak into another baseline.
specf_model = copy.deepcopy(init_model)
server_fedasam = server.Server(copy.deepcopy(args), specf_model, trainset, dict_users_train)
global_modelasam, _, _, loss_dictasam, clients_indexasam, acc_listasam = server_fedasam.fedsam(
    testset, dict_users_test,
    agg_mode='ima',
    test_global_model_accuracy=True)


In [None]:
# Show the IMA start round stored on the FedASAM server's args after training
# (labelled "adaptive", so presumably set by the server during training — TODO confirm).
print("fedasam adaptive-ima start round:", server_fedasam.args.r_ima)

In [None]:
# Evaluate the FedASAM global model on the whole test set and display the first returned
# value (presumably accuracy — TODO confirm); the second value is discarded.
gasam,_ = test_on_globaldataset(args, global_modelasam, testset)
gasam

In [None]:
# Drop the FedASAM server; its results (global_modelasam, acc_listasam, ...) were already unpacked.
del server_fedasam

In [None]:
#  baseline---->fednova

In [None]:
# Train FedNova with FedNova+IMA aggregation from a fresh copy of the initial model.
# Each server gets its own deep copy of `args` so values written into it during one run
# (the notebook reads `server_*.args.r_ima` afterwards) cannot leak into another baseline.
specf_model = copy.deepcopy(init_model)
server_fednova = server.Server(copy.deepcopy(args), specf_model, trainset, dict_users_train)
global_modelnova, _, _, loss_dictnova, clients_indexnova, acc_listnova = server_fednova.fednova(
    testset, dict_users_test,
    agg_mode='fednova+ima',
    test_global_model_accuracy=True)


In [None]:
# Show the IMA start round stored on the FedNova server's args after training
# (labelled "adaptive", so presumably set by the server during training — TODO confirm).
print("fednova adaptive-ima start round:", server_fednova.args.r_ima)

In [None]:
# Evaluate the FedNova global model on the whole test set and display the first returned
# value (presumably accuracy — TODO confirm); the second value is discarded.
gnova,_ = test_on_globaldataset(args, global_modelnova, testset)
gnova

In [None]:
# Drop the FedNova server; its results (global_modelnova, acc_listnova, ...) were already unpacked.
del server_fednova

In [None]:
#  baseline---->fedadam

In [None]:
# Train FedAdam (server.fedavg with 'fedadam+ima' aggregation) from a fresh copy of the model.
# Each server gets its own deep copy of `args` so values written into it during one run
# (the notebook reads `server_*.args.r_ima` afterwards) cannot leak into another baseline.
specf_model = copy.deepcopy(init_model)
server_fedadam = server.Server(copy.deepcopy(args), specf_model, trainset, dict_users_train)
global_modeladam, _, _, loss_dictadam, clients_indexadam, acc_listadam = server_fedadam.fedavg(
    testset, dict_users_test,
    agg_mode='fedadam+ima',
    test_global_model_accuracy=True)


In [None]:
# Show the IMA start round stored on the FedAdam server's args after training
# (labelled "adaptive", so presumably set by the server during training — TODO confirm).
print("fedadam adaptive-ima start round:", server_fedadam.args.r_ima)

In [None]:
# Evaluate the FedAdam global model on the whole test set and display the first returned
# value (presumably accuracy — TODO confirm); the second value is discarded.
gadam,_ = test_on_globaldataset(args, global_modeladam, testset)
gadam

In [None]:
# Drop the FedAdam server; its results (global_modeladam, acc_listadam, ...) were already unpacked.
del server_fedadam

In [None]:
#  baseline---->fedyogi

In [None]:
# Train FedYogi (server.fedavg with 'fedyogi+ima' aggregation) from a fresh copy of the model.
# Each server gets its own deep copy of `args` so values written into it during one run
# (the notebook reads `server_*.args.r_ima` afterwards) cannot leak into another baseline.
specf_model = copy.deepcopy(init_model)
server_fedyogi = server.Server(copy.deepcopy(args), specf_model, trainset, dict_users_train)
global_modelyogi, _, _, loss_dictyogi, clients_indexyogi, acc_listyogi = server_fedyogi.fedavg(
    testset, dict_users_test,
    agg_mode='fedyogi+ima',
    test_global_model_accuracy=True)


In [None]:
# Show the IMA start round stored on the FedYogi server's args after training
# (labelled "adaptive", so presumably set by the server during training — TODO confirm).
print("fedyogi adaptive-ima start round:", server_fedyogi.args.r_ima)

In [None]:
# Evaluate the FedYogi global model on the whole test set and display the first returned
# value (presumably accuracy — TODO confirm); the second value is discarded.
gyogi,_ = test_on_globaldataset(args, global_modelyogi, testset)
gyogi

In [None]:
# Drop the FedYogi server; its results (global_modelyogi, acc_listyogi, ...) were already unpacked.
del server_fedyogi

In [None]:
#  baseline---->fedgma

In [None]:
# Train FedGMA (server.fedavg with 'gma+ima' aggregation) from a fresh copy of the model.
# Each server gets its own deep copy of `args` so values written into it during one run
# (the notebook reads `server_*.args.r_ima` afterwards) cannot leak into another baseline.
specf_model = copy.deepcopy(init_model)
server_fedgma = server.Server(copy.deepcopy(args), specf_model, trainset, dict_users_train)
global_modelgma, _, _, loss_dictgma, clients_indexgma, acc_listgma = server_fedgma.fedavg(
    testset, dict_users_test,
    agg_mode='gma+ima',
    test_global_model_accuracy=True)

                                                                                                                               

In [None]:
# Show the IMA start round stored on the FedGMA server's args after training
# (labelled "adaptive", so presumably set by the server during training — TODO confirm).
print("fedgma adaptive-ima start round:", server_fedgma.args.r_ima)

In [None]:
# Evaluate the FedGMA global model on the whole test set and display the first returned
# value (presumably accuracy — TODO confirm); the second value is discarded.
ggma,_ = test_on_globaldataset(args, global_modelgma, testset)
ggma

In [None]:
# Drop the FedGMA server; its results (global_modelgma, acc_listgma, ...) were already unpacked.
del server_fedgma

In [None]:
#  method---->fedfa (server.fedfa_anchorloss)

In [None]:
# Train FedFA (server.fedfa_anchorloss) with IMA aggregation from a fresh copy of the model.
# Each server gets its own deep copy of `args` so values written into it during one run
# (the notebook reads `server_*.args.r_ima` afterwards) cannot leak into another method.
specf_model = copy.deepcopy(init_model)
server_feature = server.Server(copy.deepcopy(args), specf_model, trainset, dict_users_train)
global_modelfa, _, _, loss_dictfa, clients_indexfa, acc_listfa = server_feature.fedfa_anchorloss(
    testset, dict_users_test,
    agg_mode='ima',
    test_global_model_accuracy=True)



In [None]:
# Show the IMA start round stored on the FedFA server's args after training
# (labelled "adaptive", so presumably set by the server during training — TODO confirm).
print("fedfa adaptive-ima start round:", server_feature.args.r_ima)

In [None]:
# Evaluate the FedFA global model on the whole test set and display the first returned
# value (presumably accuracy — TODO confirm); the second value is discarded.
gfa,_ = test_on_globaldataset(args, global_modelfa, testset)
gfa

In [None]:
# Drop the FedFA server; its results (global_modelfa, acc_listfa, ...) were already unpacked.
del server_feature

In [None]:
# Overlay the accuracy curves returned by every run (one line per method, in run order).
for curve_acc, curve_label in (
    (acc_list, 'FedAvg+IMA'),
    (acc_listp, 'FedProx+IMA'),
    (acc_listasam, 'FedASAM+IMA'),
    (acc_listnova, 'FedNova+IMA'),
    (acc_listadam, 'FedAdam+IMA'),
    (acc_listyogi, 'FedYogi+IMA'),
    (acc_listgma, 'FedGMA+IMA'),
    (acc_listfa, 'FedFA+IMA'),
):
    plt.plot(curve_acc, label=curve_label)
plt.legend()
# plt.savefig(f'acc_comparison_allmethods_{args.dataset}_ima_{args.split}.pdf')