In [None]:
import copy
import numpy as np   
import torch
import torch.nn.functional as F
import os

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils.sampling import partition_data_dataset
from utils.options import args_parser
from models.Update import DatasetSplit
from models.test import test_img
from models.resnet_client import resnet20, resnet16, resnet8
from collections import defaultdict

In [None]:
if __name__ == '__main__':
    # Parse the experiment configuration from a fixed argument list
    # (the notebook has no real command line).
    args = args_parser(args=[
        '--dataset', 'cinic',
        '--momentum', '0.9',
        '--alpha', '5',
        '--epochs', '50',
        '--gpu', '0',
        '--lr', '0.01',
    ])

    # Use the requested GPU only when CUDA is usable and a GPU id was given;
    # otherwise fall back to the CPU.
    if torch.cuda.is_available() and args.gpu != -1:
        args.device = torch.device(f'cuda:{args.gpu}')
    else:
        args.device = torch.device('cpu')
    print('torch.cuda:', torch.cuda.is_available())
    print(args)

In [None]:
# Load the dataset and split it across users.
# No public data partition.

def collect_labels(dataset):
    """Return the integer class label of every sample in ``dataset``.

    The labels are read from the dataset's ``targets`` attribute when it is
    available, so no image has to be loaded or transformed. A
    ``ConcatDataset`` is handled by concatenating the labels of its parts in
    order, which matches the order of its indices. Datasets without
    ``targets`` (or with a ``target_transform``) are iterated sample by sample.

    Args:
        dataset: a dataset yielding ``(sample, label)`` pairs.

    Returns:
        A 1-D numpy integer array with one label per sample.
    """
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        return np.concatenate([collect_labels(part) for part in dataset.datasets])

    targets = getattr(dataset, 'targets', None)
    if targets is None or getattr(dataset, 'target_transform', None) is not None:
        # Slow path: the stored targets are missing or are not what
        # __getitem__ returns, so read the labels through the dataset itself.
        return np.array([int(label) for _, label in dataset], dtype=int)

    if torch.is_tensor(targets):
        targets = targets.tolist()
    return np.asarray(targets, dtype=int)


if __name__ == '__main__':
    if args.dataset == 'mnist':
        # Grayscale images are expanded to 3 channels before normalisation.
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)),
                                          transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))])
        dataset_train = datasets.MNIST('data/mnist/', train = True, download = False, transform = trans_mnist)
        dataset_test = datasets.MNIST('data/mnist/', train = False, download = False, transform = trans_mnist)

    elif args.dataset == 'fashionmnist':
        trans_fashionmnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)),
                                                 transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))])
        dataset_train = datasets.FashionMNIST('data/fashionmnist/', train = True, download = False, transform = trans_fashionmnist)
        dataset_test = datasets.FashionMNIST('data/fashionmnist/', train = False, download = False, transform = trans_fashionmnist)

    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('data/cifar', train = True, download = False, transform = trans_cifar)
        dataset_test = datasets.CIFAR10('data/cifar', train = False, download = False, transform = trans_cifar)

    elif args.dataset == 'cinic':
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        transform_cinic = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std)
        ])
        cinic_directory = 'data/cinic'
        dataset_train = datasets.ImageFolder(
            os.path.join(cinic_directory, 'train'),
            transform=transform_cinic
        )
        dataset_valid = datasets.ImageFolder(
            os.path.join(cinic_directory, 'valid'),
            transform=transform_cinic
        )
        dataset_test = datasets.ImageFolder(
            os.path.join(cinic_directory, 'test'),
            transform=transform_cinic
        )
        # The train and valid folders are merged into one training set.
        dataset_train = torch.utils.data.ConcatDataset([dataset_train, dataset_valid])

    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Error: unrecognized dataset {!r}'.format(args.dataset))

    print('len(dataset_train): ', len(dataset_train))
    print('len(dataset_test): ', len(dataset_test))

    # Label array of the training set, used to partition the samples across 10 users.
    dataset_train_labels = collect_labels(dataset_train)
    dict_users = partition_data_dataset(dataset_train_labels, 10, alpha = args.alpha)

    print("num_users:", len(dict_users))
    img_size = dataset_train[0][0].shape
    print(img_size)

In [None]:
# Initialize model
model_init = {}
for x in range(10):
    # Architectures cycle with the user id:
    # x % 3 == 0 -> resnet8, 1 -> resnet16, 2 -> resnet20 (each built with argument 10).
    model_init[x] = (resnet8, resnet16, resnet20)[x % 3](10).to(args.device)
    model_init[x].eval()
    # Report the accuracy of the freshly initialised model on the test set.
    acc_test = test_img(model_init[x], dataset_test, args)
    print("user-uid:", x, "init_Local_Training_accuracy: {:.2f}".format(acc_test))

In [None]:
# copy init_model_parameters to model
# Each user trains on an independent deep copy, so model_init keeps the initial weights.
model = {user_id: copy.deepcopy(model_init[user_id]) for user_id in range(10)}
for i in range(10):
    print("---------------------------------model[", i, "]---------------------------------")
    print(model[i])

In [None]:
class Trainable_Global_Prototypes(nn.Module):
    """Generator that maps class ids to learnable prototype vectors.

    A class id is looked up in an embedding table, passed through a
    Linear + ReLU hidden layer and projected back to ``feature_dim``.

    Args:
        num_classes: number of rows in the embedding table.
        server_hidden_dim: width of the hidden layer.
        feature_dim: size of each embedding and of each returned prototype.
        device: device recorded on the instance as ``self.device``.
    """

    def __init__(self, num_classes, server_hidden_dim, feature_dim, device):
        super().__init__()

        self.device = device

        self.embedings = nn.Embedding(num_classes, feature_dim)
        # Kept as a Sequential nested inside a Sequential so that parameter
        # names in the state dict stay the same.
        layers = [nn.Sequential(
            nn.Linear(feature_dim, server_hidden_dim),
            nn.ReLU()
        )]
        self.middle = nn.Sequential(*layers)
        self.fc = nn.Linear(server_hidden_dim, feature_dim)

    def forward(self, class_id):
        """Return the generated prototype(s) for ``class_id``.

        Args:
            class_id: an int, a sequence of ints, or an integer tensor of
                class indices.

        Returns:
            A tensor with the shape of ``class_id`` plus a trailing
            ``feature_dim`` axis.
        """
        # as_tensor accepts ints, lists and existing tensors without the
        # copy-construct warning that torch.tensor(tensor) emits. The indices
        # are placed on the embedding's current device so the module keeps
        # working after being moved with .to().
        class_id = torch.as_tensor(class_id, device=self.embedings.weight.device)

        emb = self.embedings(class_id)
        mid = self.middle(emb)
        out = self.fc(mid)

        return out

# Build the server-side prototype generator (hidden width and prototype size
# are both 256) and place it on the training device.
TGP = Trainable_Global_Prototypes(
    num_classes=args.num_classes,
    server_hidden_dim=256,
    feature_dim=256,
    device=args.device,
).to(args.device)
print(TGP)

In [None]:
# Global prototypes; stays empty until the server has generated them at the
# end of the first round.
global_protos = []

def agg_func(protos):
    """
    Collapse each label's list of prototypes into a single prototype.

    ``protos`` maps a label to a non-empty list of tensors. A list with more
    than one entry is replaced by the element-wise mean of its tensors; a
    single-entry list is replaced by that entry. The mapping is updated in
    place and also returned.
    """
    for label in protos:
        members = protos[label]
        if len(members) <= 1:
            protos[label] = members[0]
            continue

        # Accumulate into a zeroed tensor shaped like the first prototype.
        total = 0 * members[0].data
        for member in members:
            total += member.data
        protos[label] = total / len(members)

    return protos

def proto_cluster(protos_list):
    """
    Average the prototypes of each class over all clients.

    ``protos_list`` is a sequence of mappings from class id to prototype
    tensor. The result is a ``defaultdict`` mapping every class id seen in any
    mapping to the detached mean of the prototypes reported for it.
    """
    grouped = defaultdict(list)
    for client_protos in protos_list:
        for class_id, proto in client_protos.items():
            grouped[class_id].append(proto)

    # Replace each class's list with the mean over its stacked prototypes.
    for class_id, members in grouped.items():
        grouped[class_id] = torch.stack(members).mean(dim=0).detach()

    return grouped


# Federated training loop. Each round:
#   1. every client trains its own model for one local epoch (cross-entropy,
#      plus an MSE pull towards the global prototypes once they exist),
#   2. every client computes one averaged prototype per class it holds,
#   3. the server trains TGP on the uploaded prototypes with a margin-adjusted
#      distance loss and regenerates the global prototypes,
#   4. every client model is evaluated on the test set.

# The learning rate is derived from this base value every round. args.lr is
# not modified, so re-running the cell starts from the configured rate, and
# the decay factor is applied once per round rather than compounding.
base_lr = args.lr
# Weight of the prototype loss: 1 for mnist/fashionmnist, 0.1 for cifar/cinic.
proto_loss_weight = 0.1
criterionCE = nn.CrossEntropyLoss()
criterionMSE = nn.MSELoss()
CEloss = nn.CrossEntropyLoss()
server_epochs = 100
class_ids = list(range(args.num_classes))

for epoch_index in range(args.epochs):
    # print(f'\n | Global Training Round : {epoch_index + 1} |\n')
    # Linear decay from base_lr towards 10% of it over args.epochs rounds.
    round_lr = base_lr * (1 - epoch_index / args.epochs * 0.9)
    acc_all = []
    local_protos = {}
    for idx in range(len(dict_users)):
        # ---- local training -------------------------------------------------
        model[idx].train()
        optimizer = torch.optim.SGD(model[idx].parameters(), lr = round_lr, momentum = args.momentum, weight_decay = 5e-4)
        ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users[idx]), batch_size = 8, shuffle = True)
        for _local_epoch in range(1):
            for images, label_g in ldr_train:
                images, labels = images.to(args.device), label_g.to(args.device)
                model[idx].zero_grad()
                log_probs = model[idx](images)
                # assumes features() returns one row per sample — TODO confirm against models.resnet_client
                rep = model[idx].features(images)
                loss = criterionCE(log_probs, labels)

                if len(global_protos) != 0:
                    # Target = the sample's own representation, overwritten by
                    # the global prototype of its class when one is available.
                    proto_new = rep.detach().clone()
                    for row, y_c in enumerate(label_g.tolist()):
                        if not isinstance(global_protos[y_c], list):
                            proto_new[row, :] = global_protos[y_c].data
                    loss += criterionMSE(proto_new, rep) * proto_loss_weight
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # ---- collect local prototypes ----------------------------------------
        model[idx].eval()
        protos = defaultdict(list)
        with torch.no_grad():
            for x, y in ldr_train:
                if isinstance(x, list):
                    x[0] = x[0].to(args.device)
                else:
                    x = x.to(args.device)
                rep = model[idx].features(x)

                # Labels are read on the host to avoid one device sync per sample.
                for row, y_c in enumerate(y.tolist()):
                    protos[y_c].append(rep[row, :].detach().data)
        local_protos[idx] = agg_func(protos)

    # ---- upload: flat (prototype, label) pairs plus the per-client dicts -----
    uploaded_protos = []
    uploaded_protos_per_client = []
    for client_id in local_protos:
        protos = local_protos[client_id]
        for k in protos.keys():
            uploaded_protos.append((protos[k], k))
        uploaded_protos_per_client.append(protos)

    # calculate class-wise minimum distance
    gap = torch.ones(args.num_classes, device = args.device) * 1e9
    avg_protos = proto_cluster(uploaded_protos_per_client)
    for k1 in avg_protos.keys():
        for k2 in avg_protos.keys():
            if k1 > k2:
                dis = torch.norm(avg_protos[k1] - avg_protos[k2], p = 2)
                gap[k1] = torch.min(gap[k1], dis)
                gap[k2] = torch.min(gap[k2], dis)
    min_gap = torch.min(gap)
    # Classes that never got a distance keep the 1e9 sentinel; give them the
    # smallest observed gap instead.
    gap[gap > 1e8] = min_gap
    max_gap = torch.max(gap)
    # print('class-wise minimum distance', gap)
    # print('min_gap', min_gap)
    # print('max_gap', max_gap)

    # ---- server: train the prototype generator -------------------------------
    TGP_opt = torch.optim.SGD(TGP.parameters(), lr = 0.005)
    TGP.train()
    margin = min(max_gap.item(), 100)           # margin threshold = 100
    # Built once per round; it reshuffles each time it is iterated.
    proto_loader = DataLoader(uploaded_protos, 8, shuffle = True)
    for _server_epoch in range(server_epochs):
        for proto, y in proto_loader:
            y = torch.as_tensor(y, dtype = torch.int64, device = args.device)

            proto_gen = TGP(class_ids)

            # Pairwise distance via ||a||^2 - 2 a.b + ||b||^2.
            features_square = torch.sum(torch.pow(proto, 2), 1, keepdim = True)
            centers_square = torch.sum(torch.pow(proto_gen, 2), 1, keepdim = True)
            features_into_centers = torch.matmul(proto, proto_gen.T)
            dist = features_square - 2 * features_into_centers + centers_square.T
            # Rounding can push the squared distance to zero or slightly below,
            # where sqrt yields NaN values or gradients; keep it strictly positive.
            dist = torch.sqrt(torch.clamp(dist, min = 1e-12))

            # Add the margin to the distance of the true class only.
            one_hot = F.one_hot(y, args.num_classes)
            dist = dist + one_hot * margin
            loss = CEloss(-dist, y)

            TGP_opt.zero_grad()
            loss.backward()
            TGP_opt.step()

    uploaded_protos = []

    # ---- regenerate the global prototypes used in the next round -------------
    TGP.eval()
    global_protos = defaultdict(list)
    with torch.no_grad():
        for class_id in range(args.num_classes):
            # A plain int is passed; TGP.forward converts it to a tensor itself.
            global_protos[class_id] = TGP(class_id).detach()

    # ---- evaluate every client model ------------------------------------------
    for idx in range(len(dict_users)):
        model[idx].eval()
        acc_fine_test = test_img(model[idx], dataset_test, args)
        # print("round:", epoch_index,"idx:", idx, "Train Testing accuracy: {:.2f}".format(acc_fine_test))
        acc_all.append(acc_fine_test.item())

    print(epoch_index + 1, ":", "mean Fine_Test/AccTop1 on all clients:", float(np.mean(np.array(acc_all))))