In [None]:
import copy
import numpy as np   
import torch
import torch.nn.functional as F
import os

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils.sampling import partition_public, partition_data_dataset
from utils.options import args_parser
from models.Update import DatasetSplit
from models.test import test_img
from models.resnet_client import resnet20, resnet16, resnet8

In [None]:
if __name__ == '__main__':
    # Run configuration, expressed as the command-line tokens the project's
    # argument parser expects (flag followed by its value).
    cli_options = [
        '--dataset', 'fashionmnist',
        '--momentum', '0.9',
        '--alpha', '1',
        '--epochs', '50',
        '--gpu', '0',
        '--public_data_ratio', '0.1',
        '--lr', '0.01',
    ]
    args = args_parser(args=cli_options)

    # Use the requested CUDA device only when CUDA is available and a GPU
    # index other than -1 was given; otherwise run on the CPU.
    if torch.cuda.is_available() and args.gpu != -1:
        args.device = torch.device('cuda:{}'.format(args.gpu))
    else:
        args.device = torch.device('cpu')

    print('torch.cuda:',torch.cuda.is_available())
    print(args)

In [None]:
# load dataset and split users
#
# Each branch below only builds `dataset_train` / `dataset_test` for one
# dataset name. Everything that is identical for all datasets (carving out the
# public subset, collecting labels, partitioning across users) happens once
# after the branch.
if __name__ == '__main__':
    if args.dataset == 'mnist':
        # Single-channel images are expanded to 3 channels before normalisation.
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)),
                                          transforms.Normalize((0.1307,0.1307,0.1307), (0.3081,0.3081,0.3081))])
        dataset_train = datasets.MNIST('data/mnist/', train = True, download = False, transform = trans_mnist)
        dataset_test = datasets.MNIST('data/mnist/', train = False, download = False, transform = trans_mnist)

    elif args.dataset == 'fashionmnist':
        # NOTE(review): the mean/std constants here are the same values used in
        # the 'mnist' branch -- confirm they are intended for Fashion-MNIST.
        trans_fashionmnist = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)),
                                                 transforms.Normalize((0.1307,0.1307,0.1307), (0.3081,0.3081,0.3081))])
        dataset_train = datasets.FashionMNIST('data/fashionmnist/', train = True, download = False, transform = trans_fashionmnist)
        dataset_test = datasets.FashionMNIST('data/fashionmnist/', train = False, download = False, transform = trans_fashionmnist)

    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('data/cifar', train = True, download = False, transform = trans_cifar)
        dataset_test = datasets.CIFAR10('data/cifar', train = False, download = False, transform = trans_cifar)

    elif args.dataset == 'cinic':
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        transform_cinic = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std)
        ])
        cinic_directory = 'data/cinic'
        dataset_train = datasets.ImageFolder(
            os.path.join(cinic_directory, 'train'),
            transform=transform_cinic
        )
        dataset_valid = datasets.ImageFolder(
            os.path.join(cinic_directory, 'valid'),
            transform=transform_cinic
        )
        dataset_test = datasets.ImageFolder(
            os.path.join(cinic_directory, 'test'),
            transform=transform_cinic
        )
        # The 'train' and 'valid' folders are merged into one training pool.
        dataset_train = torch.utils.data.ConcatDataset([dataset_train, dataset_valid])

    else:
        # Fail loudly instead of falling through: in a notebook a silent
        # fall-through would reuse datasets left in the kernel by an earlier run.
        raise ValueError(
            "Unsupported dataset {!r}; expected one of 'mnist', 'fashionmnist', 'cifar', 'cinic'".format(args.dataset)
        )

    # Split off the public subset; the remainder stays as the private training pool.
    dataset_train, dataset_public = partition_public(dataset_train, args.public_data_ratio)
    print('len(dataset_public): ', len(dataset_public))
    print('len(dataset_train): ', len(dataset_train))
    print('len(dataset_test): ', len(dataset_test))

    # Collect every training label in one pass. Building a list and converting
    # once avoids the quadratic cost of growing an array with np.append.
    dataset_train_labels = np.array([int(y) for _, y in dataset_train], dtype=int)

    # Partition the private training indices across 10 users.
    dict_users = partition_data_dataset(dataset_train_labels, 10, alpha = args.alpha)

    print("num_users:", len(dict_users))
    img_size = dataset_train[0][0].shape
    print(img_size)

In [None]:
# Initialize model
# Participants cycle through three architectures according to uid % 3:
# remainder 0 -> resnet8, remainder 1 -> resnet16, remainder 2 -> resnet20.
architectures = (resnet8, resnet16, resnet20)

model_init = {}
for x in range(10):
    build_net = architectures[x % 3]
    # The constructor argument 10 is presumably the number of output classes -- TODO confirm.
    model_init[x] = build_net(10).to(args.device)
    model_init[x].eval()
    # Evaluate the freshly initialised (untrained) network on the test set.
    acc_test = test_img(model_init[x], dataset_test, args)
    print("user-uid:", x, "init_Local_Training_accuracy: {:.2f}".format(acc_test))

In [None]:
# copy init_model_parameters to model
# Each participant trains on an independent deep copy, so model_init keeps the
# untouched initial weights.
model = {i: copy.deepcopy(model_init[i]) for i in range(10)}

# Show every participant's architecture.
for i in range(10):
    print("---------------------------------model[", i, "]---------------------------------")
    print(model[i])

In [None]:
# Accuracy bookkeeping containers; all start out empty.
intra_accs_dict, inter_accs_dict = {}, {}
mean_intra_acc_list, mean_inter_acc_list = [], []

def _off_diagonal(x):
     n, m = x.shape
     assert n == m
     return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def _calculate_isd_sim(features):
     sim_q = torch.mm(features, features.T)
     # Create a mask matrix logits_mask to set the diagonal elements of the similarity matrix 
     # (i.e., the similarity between oneself and oneself) to 0, while keeping the remaining elements to 1
     logits_mask = torch.scatter(
          torch.ones_like(sim_q),
          1,
          torch.arange(sim_q.size(0)).view(-1, 1).to(args.device),
          0
     )
     row_size = sim_q.size(0)                                         # The number of rows in the similarity matrix sim_q
     sim_q = sim_q[logits_mask.bool()].view(row_size, -1)             # The similarity matrix after filtering out diagonal elements has a shape of [n, n-1]
     return sim_q / 0.02                                              # temp = 0.02


# copy init_model_parameters to model_inter
# One independent deep copy of the initial network per participant id.
model_inter = {uid: copy.deepcopy(model_init[uid]) for uid in dict_users}

# Main training loop. Every epoch has two phases:
#   1. collaborative update: all participants see the same public batches and
#      are pulled toward the average of everybody's (detached) outputs;
#   2. local update: each participant trains on its private split with
#      cross-entropy plus a distillation term against its own model from the
#      end of the previous epoch (model_inter).
# After phase 2 each participant is evaluated on the test set.

T = 3                      # softmax temperature of the local distillation term

# The learning rate is tracked in a local variable so that args.lr is never
# modified; re-running this cell therefore always starts from the configured
# rate. The decay is computed from base_lr each time (linear, down to roughly
# 10% of the base rate) rather than multiplied into the previous value, which
# would compound and drive the rate to practically zero within a few epochs.
base_lr = args.lr
current_lr = base_lr

# Built once: the loader does not shuffle, so every epoch iterates it identically.
public_loader = DataLoader(dataset_public, batch_size = 8)

criterionCE = nn.CrossEntropyLoss().to(args.device)
criterionKL = nn.KLDivLoss(reduction='batchmean').to(args.device)

for epoch_index in range(args.epochs):       # args.epochs
    acc_all = []

    # ------------------------------------------------------------------
    # Phase 1: collaborative update on the public data
    # ------------------------------------------------------------------
    for images, _ in public_loader:
        # A batch needs at least two samples: std(0) of a single row is NaN
        # and there is no "other sample" to build a similarity row from.
        if images.size(0) < 2:
            continue

        batch_loss_dict = {}
        images = images.to(args.device)

        # Aggregate the output from participants.
        # Outputs that carry gradients are keyed by uid; the detached copies
        # are only ever averaged, so a plain list is enough for them.
        linear_outputs = {}
        linear_output_targets = []
        logits_sims = {}
        logits_sim_targets = []
        for uid in dict_users:
            model[uid] = model[uid].to(args.device)
            model[uid].train()

            linear_output = model[uid](images)
            linear_outputs[uid] = linear_output
            linear_output_targets.append(linear_output.detach().clone())

            features = F.normalize(model[uid].features(images), dim = 1)
            logits_sim = _calculate_isd_sim(features)
            logits_sims[uid] = logits_sim
            logits_sim_targets.append(logits_sim.detach().clone())

        # Consensus targets: identical for every participant, so they are
        # computed once per batch instead of once per participant.
        linear_output_target_avg = torch.mean(torch.stack(linear_output_targets), 0)
        z_2_bn = (linear_output_target_avg - linear_output_target_avg.mean(0)) / linear_output_target_avg.std(0)
        logits_sim_target_avg = torch.mean(torch.stack(logits_sim_targets), 0)
        targets = F.softmax(logits_sim_target_avg, dim=1)

        # Update Participants' Models via Col Loss
        for uid in dict_users:
            # NOTE(review): a fresh optimizer per step means the momentum buffer
            # never carries over between public batches -- confirm this is intended.
            optimizer = torch.optim.SGD(model[uid].parameters(), lr = current_lr, momentum = args.momentum, weight_decay = 5e-4)

            # FCCM-LOSS: cross-correlation between this participant's
            # standardised outputs and the standardised consensus outputs.
            linear_output = linear_outputs[uid]
            z_1_bn = (linear_output - linear_output.mean(0)) / linear_output.std(0)
            c = (z_1_bn.T @ z_2_bn) / len(images)

            on_diag = (torch.diagonal(c) - 1).pow(2).sum()       # (c_ii - 1)^2: pulls diagonal entries toward 1
            off_diag = (_off_diagonal(c) + 1).pow(2).sum()       # (c_ij + 1)^2: pulls off-diagonal entries toward -1
            fccl_loss = on_diag + 0.0051 * off_diag

            # FISL-LOSS: KL divergence between this participant's
            # sample-similarity distribution and the consensus distribution.
            inputs = F.log_softmax(logits_sims[uid], dim=1)
            loss_distill = F.kl_div(inputs, targets, reduction='batchmean')
            loss_distill = 3 * loss_distill           # Weight of distillation loss: dis_power = 3

            col_loss = fccl_loss + loss_distill       # Total LOSS
            batch_loss_dict[uid] = {'fccl': round(fccl_loss.item(), 3), 'distill': round(loss_distill.item(), 3)}

            optimizer.zero_grad()
            col_loss.backward()
            optimizer.step()

    # ------------------------------------------------------------------
    # Phase 2: local update on each participant's private data
    # ------------------------------------------------------------------
    for uid in dict_users:
        model[uid] = model[uid].to(args.device)
        model_inter[uid] = model_inter[uid].to(args.device)
        model[uid].train()
        model_inter[uid].train()
        optimizer = torch.optim.SGD(model[uid].parameters(), lr = current_lr, momentum = args.momentum, weight_decay = 5e-4)
        ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users[uid]), batch_size = 8)

        for _ in range(5):
            for images, labels in ldr_train:
                images = images.to(args.device)
                labels = labels.to(args.device)
                outputs = model[uid](images)
                bs, class_num = outputs.shape

                # True everywhere except at each sample's ground-truth class.
                non_targets_mask = torch.ones(bs, class_num, device=outputs.device).scatter_(1, labels.view(-1, 1), 0).bool()

                # log_softmax is used instead of log(softmax(...)): the values
                # are mathematically the same, but a probability that underflows
                # to 0 no longer turns into -inf (and then NaN in the KL term).
                non_target_logsoft_outputs = F.log_softmax(outputs / T, dim=1)[non_targets_mask].view(bs, class_num - 1)

                with torch.no_grad():
                    inter_outputs = model_inter[uid](images)
                    soft_inter_outpus = F.softmax(inter_outputs / T, dim=1)
                    # Federal non-target distillation
                    non_target_soft_inter_outputs = soft_inter_outpus[non_targets_mask].view(bs, class_num - 1)

                inter_loss = criterionKL(non_target_logsoft_outputs, non_target_soft_inter_outputs)
                loss_hard = criterionCE(outputs, labels)
                inter_loss = inter_loss * (T ** 2)        # T = 3
                loss = loss_hard + inter_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Snapshot this epoch's model; it is the distillation reference next epoch.
        model_inter[uid] = copy.deepcopy(model[uid])
        model[uid].eval()
        acc_fine_test = test_img(model[uid], dataset_test, args)
        acc_all.append(acc_fine_test.item())

    print(epoch_index + 1, ":", "mean Fine_Test/AccTop1 on all clients:", float(np.mean(np.array(acc_all))))

    # Learning rate for the next epoch, derived from the base rate.
    current_lr = base_lr * (1 - epoch_index / args.epochs * 0.9)