In [None]:
import copy
import numpy as np   
import torch
import torch.nn.functional as F
import os

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils.sampling import partition_data_dataset
from utils.options import args_parser
from models.Update import DatasetSplit
from models.test import test_img
from models.resnet_client import resnet20, resnet16, resnet8

In [None]:
if __name__ == '__main__':
    # A notebook has no real command line, so feed args_parser a fixed argument list.
    cli_args = ['--dataset', 'cinic', '--momentum', '0.9', '--alpha', '1',
                '--epochs', '50', '--gpu', '0', '--lr', '0.01']
    args = args_parser(args=cli_args)

    # Use the requested GPU only when CUDA exists and the user did not pass --gpu -1.
    cuda_ok = torch.cuda.is_available()
    use_gpu = cuda_ok and args.gpu != -1
    args.device = torch.device('cuda:{}'.format(args.gpu)) if use_gpu else torch.device('cpu')
    print('torch.cuda:', cuda_ok)
    print(args)

In [None]:
# load dataset and split users
# No Public Data Partition

if __name__ == '__main__':
    def _dataset_labels(dataset):
        """Return the integer class label of every sample in ``dataset``, in index order.

        Reads the stored ``targets`` attribute (exposed by torchvision's MNIST,
        FashionMNIST, CIFAR10 and ImageFolder) instead of decoding and transforming
        every image. A ConcatDataset is handled by concatenating the labels of its
        parts in order, which matches ConcatDataset's own sequential indexing.
        Falls back to iterating the samples when no ``targets`` attribute exists.
        """
        if isinstance(dataset, torch.utils.data.ConcatDataset):
            parts = [_dataset_labels(part) for part in dataset.datasets]
            return np.concatenate(parts) if parts else np.array([], dtype=int)
        targets = getattr(dataset, 'targets', None)
        if targets is None:
            # Slow path: touches every sample, but works for any map-style dataset.
            return np.fromiter((int(y) for _, y in dataset), dtype=int, count=len(dataset))
        if torch.is_tensor(targets):
            targets = targets.cpu().numpy()
        return np.asarray(targets, dtype=int)

    if args.dataset == 'mnist':
        # Grayscale images are expanded to 3 channels so every dataset shares one model stem.
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)),
                                          transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))])
        dataset_train = datasets.MNIST('data/mnist/', train=True, download=False, transform=trans_mnist)
        dataset_test = datasets.MNIST('data/mnist/', train=False, download=False, transform=trans_mnist)

    elif args.dataset == 'fashionmnist':
        # NOTE(review): reuses the MNIST mean/std for FashionMNIST — confirm this is intended.
        trans_fashionmnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)),
                                                 transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))])
        dataset_train = datasets.FashionMNIST('data/fashionmnist/', train=True, download=False, transform=trans_fashionmnist)
        dataset_test = datasets.FashionMNIST('data/fashionmnist/', train=False, download=False, transform=trans_fashionmnist)

    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('data/cifar', train=True, download=False, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('data/cifar', train=False, download=False, transform=trans_cifar)

    elif args.dataset == 'cinic':
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        transform_cinic = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std)
        ])
        cinic_directory = 'data/cinic'
        dataset_train = datasets.ImageFolder(
            os.path.join(cinic_directory, 'train'),
            transform=transform_cinic
        )
        dataset_valid = datasets.ImageFolder(
            os.path.join(cinic_directory, 'valid'),
            transform=transform_cinic
        )
        dataset_test = datasets.ImageFolder(
            os.path.join(cinic_directory, 'test'),
            transform=transform_cinic
        )
        # The CINIC validation split is folded into the training pool.
        dataset_train = torch.utils.data.ConcatDataset([dataset_train, dataset_valid])

    else:
        # Fail here with a clear message instead of a NameError on dict_users below.
        raise ValueError('unsupported dataset: {}'.format(args.dataset))

    print('len(dataset_train): ', len(dataset_train))
    print('len(dataset_test): ', len(dataset_test))

    dataset_train_labels = _dataset_labels(dataset_train)

    # Split the training pool across 10 users (the rest of the notebook assumes 10 clients).
    dict_users = partition_data_dataset(dataset_train_labels, 10, alpha=args.alpha)

    print("num_users:", len(dict_users))
    img_size = dataset_train[0][0].shape
    print(img_size)

In [None]:
# Initialize model
# Heterogeneous clients: uid % 3 == 0 -> resnet8, == 1 -> resnet16, == 2 -> resnet20.
# Each freshly built model is evaluated once on the test set as an untrained baseline.
_arch_by_remainder = (resnet8, resnet16, resnet20)

model_init = {}
acc_init_test = []
for uid in range(10):
    build = _arch_by_remainder[uid % 3]
    net = build(10).to(args.device)
    model_init[uid] = net
    net.eval()
    init_acc = test_img(net, dataset_test, args)
    print("user-uid:", uid, "init_Local_Training_accuracy: {:.2f}".format(init_acc))
    acc_init_test.append(init_acc.item())
print("mean AccTop1 on all clients:", float(np.mean(np.array(acc_init_test))))

In [None]:
# copy init_model_parameters to model
# Every client trains an independent deep copy, so model_init keeps the untouched initial weights.
model = {uid: copy.deepcopy(model_init[uid]) for uid in range(10)}
for uid, net in model.items():
    print("---------------------------------model[", uid, "]---------------------------------")
    print(net)

In [None]:
# Federated state used by the training loop below.
# global_protos: label -> [aggregated prototype tensor], the format returned by
# proto_aggregation. It starts empty, which disables the prototype regularizer
# until the first aggregation. (Previously initialized as a list even though it
# is indexed and queried with .keys() like a dict.)
global_protos = {}
# NOTE(review): idxs_users, train_loss and train_accuracy are not read by any visible cell.
idxs_users = np.arange(len(dict_users))
train_loss, train_accuracy = [], []

def agg_func(protos):
    """
    Average each label's list of local prototypes into a single prototype.

    Args:
        protos: dict mapping label -> non-empty list of prototype tensors of equal shape.

    Returns:
        The same dict object, updated in place so that each label maps to one tensor:
        the element-wise mean of its list. Every result is built from ``.data``, so it
        carries no autograd history. Previously a label with a single prototype kept the
        original tensor, and with it the whole computation graph behind it.
    """
    for label, proto_list in protos.items():
        # Accumulate into a fresh tensor so the input prototypes are never modified.
        proto = torch.zeros_like(proto_list[0].data)
        for p in proto_list:
            proto += p.data
        protos[label] = proto / len(proto_list)

    return protos

def proto_aggregation(local_protos_list):
    """Combine per-client prototypes into one global prototype per label.

    Args:
        local_protos_list: dict client_id -> {label: prototype tensor}.

    Returns:
        dict label -> one-element list holding the unweighted mean (over the clients
        that have that label) of the clients' prototypes, taken via ``.data`` so it
        carries no autograd history. Labels appear in first-seen order.
    """
    grouped = {}
    for client_id in local_protos_list:
        for cls, cls_proto in local_protos_list[client_id].items():
            grouped.setdefault(cls, []).append(cls_proto)

    result = {}
    for cls, candidates in grouped.items():
        if len(candidates) == 1:
            result[cls] = [candidates[0].data]
            continue
        running = 0 * candidates[0].data
        for candidate in candidates:
            running += candidate.data
        result[cls] = [running / len(candidates)]

    return result

# Sequential prototype-regularized local training.
# NOTE(review): global_protos is re-aggregated after EVERY client, not once per round,
# so clients later in a round are regularized toward the prototypes of earlier clients.
NUM_ROUNDS = 1
LOCAL_EPOCHS = 1
LOCAL_BATCH_SIZE = 8
PROTO_LOSS_WEIGHT = 1


def _proto_regularization_target(protos, labels, global_protos):
    """Build the MSE target for the prototype regularizer.

    Args:
        protos: feature batch from ``model.features`` (first dim indexes samples).
        labels: class labels for the rows of ``protos``.
        global_protos: dict label -> [global prototype], as built by proto_aggregation.

    Returns:
        A tensor shaped like ``protos`` with no autograd history. Each row whose label has
        a global prototype is replaced by that prototype. Every other row is a copy of the
        local feature, so it contributes zero loss.
    """
    target = protos.detach().clone()
    for row, label in enumerate(labels.tolist()):
        if label in global_protos:
            target[row] = global_protos[label][0].detach()
    return target


for round_idx in range(NUM_ROUNDS):
    acc_all = []
    local_protos = {}
    print(f'\n | Global Training Round : {round_idx + 1} |\n')

    # Loss modules are stateless; build them once instead of per client / per batch.
    criterionCE = nn.CrossEntropyLoss()
    loss_mse = nn.MSELoss()

    for idx in range(len(dict_users)):
        model[idx].train()
        optimizer = torch.optim.SGD(model[idx].parameters(), lr=args.lr,
                                    momentum=args.momentum, weight_decay=5e-4)
        ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users[idx]),
                               batch_size=LOCAL_BATCH_SIZE, shuffle=True)
        for _epoch in range(LOCAL_EPOCHS):
            agg_protos_label = {}  # label -> per-sample local prototypes of this epoch
            for images, label_g in ldr_train:
                images, labels = images.to(args.device), label_g.to(args.device)
                model[idx].zero_grad()
                log_probs = model[idx](images)
                # NOTE(review): this is a second forward pass through the feature extractor.
                # In train mode it may also update BatchNorm running stats twice per batch.
                protos = model[idx].features(images)
                loss = criterionCE(log_probs, labels)
                if len(global_protos) > 0:
                    target = _proto_regularization_target(protos, labels, global_protos)
                    loss = loss + PROTO_LOSS_WEIGHT * loss_mse(protos, target)
                loss.backward()
                optimizer.step()

                # Store detached rows. Prototypes are only ever consumed as data, and
                # keeping their autograd history would pin each batch's activations in
                # memory (for the first client, protos is not even part of the loss graph).
                for label, proto in zip(label_g.tolist(), protos.detach()):
                    agg_protos_label.setdefault(label, []).append(proto)

        # Summarize instead of dumping every per-sample prototype tensor.
        print({label: len(plist) for label, plist in agg_protos_label.items()})
        local_protos[idx] = agg_func(agg_protos_label)

        # update global protos
        global_protos = proto_aggregation(local_protos)
        if idx == len(dict_users) - 1:
            print(global_protos)
        model[idx].eval()
        acc_fine_test = test_img(model[idx], dataset_test, args)
        acc_all.append(acc_fine_test.item())

    print("mean Fine_Test/AccTop1 on all clients:", float(np.mean(np.array(acc_all))))
    # Compounding learning-rate decay applied to args.lr; the factor is 1.0 after round 0.
    args.lr = args.lr * (1 - round_idx / args.epochs * 0.9)