In [None]:
import copy
import numpy as np   
import torch
import torch.nn.functional as F
import os

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from timm.scheduler.cosine_lr import CosineLRScheduler
from utils.sampling import partition_public, partition_data_dataset
from utils.options import args_parser
from models.Update import DatasetSplit
from models.test import test_img
from models.resnet_client import resnet20, resnet16, resnet8

In [None]:
if __name__ == '__main__':
    # Fixed experiment arguments, passed to args_parser explicitly.
    cli_argv = ['--dataset', 'mnist', '--momentum', '0.9', '--alpha', '1',
                '--epochs', '50', '--gpu', '0', '--public_data_ratio', '0.1', '--lr', '0.01']
    args = args_parser(args=cli_argv)

    # Use the requested GPU only when CUDA is available and gpu != -1; otherwise run on CPU.
    want_gpu = torch.cuda.is_available() and args.gpu != -1
    args.device = torch.device(f'cuda:{args.gpu}' if want_gpu else 'cpu')
    print('torch.cuda:', torch.cuda.is_available())
    print(args)

In [None]:
# load dataset and split users
def collect_labels(dataset):
    """Return the class label of every sample in ``dataset`` as an int numpy array.

    ``dataset`` may be any iterable of ``(input, label)`` pairs (e.g. a torch Dataset).
    Labels are gathered in a Python list and converted once. Growing the array with
    ``np.append`` per sample would copy it every time, which is quadratic in dataset size.
    Note that iterating a torchvision dataset still runs its transform on every image.
    """
    return np.array([int(label) for _, label in dataset], dtype=int)


def build_datasets(dataset_name):
    """Build the ``(train, test)`` datasets for ``dataset_name`` from ``data/`` (no download).

    Supported names are 'mnist', 'fashionmnist', 'cifar' and 'cinic'. For 'cinic' the
    'train' and 'valid' folders are concatenated into a single training set.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name in ('mnist', 'fashionmnist'):
        # Replicate the single grayscale channel to 3 channels, then normalise each channel.
        # NOTE(review): FashionMNIST reuses the MNIST mean/std (0.1307/0.3081); confirm this is intended.
        trans_gray = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.expand(3, -1, -1)),
                                         transforms.Normalize((0.1307,0.1307,0.1307), (0.3081,0.3081,0.3081))])
        if dataset_name == 'mnist':
            dataset_cls, root = datasets.MNIST, 'data/mnist/'
        else:
            dataset_cls, root = datasets.FashionMNIST, 'data/fashionmnist/'
        dataset_train = dataset_cls(root, train = True, download = False, transform = trans_gray)
        dataset_test = dataset_cls(root, train = False, download = False, transform = trans_gray)
        return dataset_train, dataset_test

    if dataset_name == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('data/cifar', train = True, download = False, transform = trans_cifar)
        dataset_test = datasets.CIFAR10('data/cifar', train = False, download = False, transform = trans_cifar)
        return dataset_train, dataset_test

    if dataset_name == 'cinic':
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        transform_cinic = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std)
        ])
        cinic_directory = 'data/cinic'
        splits = {split: datasets.ImageFolder(os.path.join(cinic_directory, split), transform=transform_cinic)
                  for split in ('train', 'valid', 'test')}
        dataset_train = torch.utils.data.ConcatDataset([splits['train'], splits['valid']])
        return dataset_train, splits['test']

    raise ValueError("unsupported dataset: {!r}".format(dataset_name))


if __name__ == '__main__':
    dataset_train, dataset_test = build_datasets(args.dataset)

    # Hold out a public split of the training data; every client later trains on it.
    dataset_train, dataset_public = partition_public(dataset_train, args.public_data_ratio)
    print('len(dataset_public): ', len(dataset_public))
    print('len(dataset_train): ', len(dataset_train))
    print('len(dataset_test): ', len(dataset_test))

    dataset_train_labels = collect_labels(dataset_train)
    # Split the remaining private training data across 10 clients; alpha presumably sets
    # the Dirichlet concentration of the label skew (TODO confirm in utils.sampling).
    dict_users = partition_data_dataset(dataset_train_labels, 10, alpha = args.alpha)

    print("num_users:", len(dict_users))
    img_size = dataset_train[0][0].shape
    print(img_size)

In [None]:
# Initialize model
# Heterogeneous clients: client uid gets resnet8 / resnet16 / resnet20 according to uid % 3.
CLIENT_ARCHS = (resnet8, resnet16, resnet20)
model_init = {}
acc_init_test = []
for uid in range(10):
    # The argument 10 is presumably num_classes — confirm in models.resnet_client.
    model_init[uid] = CLIENT_ARCHS[uid % 3](10).to(args.device)
    model_init[uid].eval()
    # Test accuracy of the untrained network, as a baseline.
    acc_test = test_img(model_init[uid], dataset_test, args)
    print("user-uid:", uid, "init_Local_Training_accuracy: {:.2f}".format(acc_test))
    acc_init_test.append(acc_test.item())
print("mean AccTop1 on all clients:",float(np.mean(np.array(acc_init_test))))

In [None]:
# copy init_model_parameters to model
# `model` holds an independent deep copy of each initial network; these copies are the ones
# updated by the training loop, and this cell does not modify model_init itself.
model = {uid: copy.deepcopy(model_init[uid]) for uid in range(10)}
for uid in range(10):
    print("---------------------------------model[", uid, "]---------------------------------")
    print(model[uid])

In [None]:
# copy init_model_parameters to model_intra
# `model_intra` holds a separate deep copy of each initial network. It is pretrained on the
# client's private split in the next cell and later queried without gradients.
model_intra = {uid: copy.deepcopy(model_init[uid]) for uid in range(10)}
for uid in range(10):
    print("---------------------------------model_intra[", uid, "]---------------------------------")
    print(model_intra[uid])

In [None]:
'''
  pretrain-intra_net
'''
# Pretrain each client's model_intra on that client's private split only.
PRETRAIN_EPOCHS = 50
PRETRAIN_LR = 0.01
PRETRAIN_BATCH_SIZE = 128
criterion = nn.CrossEntropyLoss().to(args.device)
for uid in dict_users:
    model_intra[uid] = model_intra[uid].to(args.device)
    optimizer = torch.optim.Adam(model_intra[uid].parameters(), lr = PRETRAIN_LR)
    # Cosine decay from PRETRAIN_LR towards lr_min over PRETRAIN_EPOCHS epochs, stepped per epoch.
    scheduler = CosineLRScheduler(optimizer, t_initial = PRETRAIN_EPOCHS, lr_min = 1e-6)
    # shuffle=True: the client's index list may be grouped by class (TODO confirm in
    # utils.sampling), and an unshuffled loader replays the same batch order every epoch.
    ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users[uid]), batch_size = PRETRAIN_BATCH_SIZE, shuffle = True)
    # Stay in train mode for the whole pretraining run. Evaluation happens once, after the
    # last epoch, so BatchNorm/Dropout are never switched to eval mode mid-training.
    model_intra[uid].train()
    iterator = tqdm(range(PRETRAIN_EPOCHS))
    for epoch_index in iterator:
        for images, labels in ldr_train:
            images = images.to(args.device)
            labels = labels.to(args.device)
            outputs = model_intra[uid](images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            iterator.set_description("Local Participant %d loss = %0.3f" % (uid, loss.item()), refresh = False)
        # timm schedulers are stepped at the end of an epoch with the index of the next epoch.
        scheduler.step(epoch_index + 1)
    # Leave the model in eval mode; it is only queried without gradients afterwards.
    model_intra[uid].eval()
    acc_test = test_img(model_intra[uid], dataset_test, args)
    print("round:",PRETRAIN_EPOCHS - 1,"uid:",uid,"Pretrain-intra_net Testing Accuracy: {:.2f}".format(acc_test))

In [None]:
# Accuracy bookkeeping containers.
# NOTE(review): the training loop in this notebook does not write to these; confirm whether other cells use them.
intra_accs_dict, inter_accs_dict = dict(), dict()
mean_intra_acc_list, mean_inter_acc_list = list(), list()

def _off_diagonal(x):
    """Return the off-diagonal entries of square 2-D tensor ``x`` as a 1-D tensor (row-major order).

    Raises AssertionError if ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    # Dropping the final element lets the flat buffer be viewed as (n-1) rows of length n+1.
    # Each such row begins with a diagonal entry, so removing column 0 leaves only off-diagonals.
    strided = x.flatten()[:-1].view(rows - 1, rows + 1)
    return strided[:, 1:].flatten()

# copy init_model_parameters to model_inter
# model_inter[uid] is the client's model as it stood at the end of the previous round
# (the initial model before round 0). Local training queries it without gradients.
model_inter = {}
for uid in dict_users:
    model_inter[uid] = copy.deepcopy(model_init[uid])

COL_OFF_DIAG_WEIGHT = 0.0051   # weight of the off-diagonal term in the collaboration loss
PUBLIC_BATCH_SIZE = 8
LOCAL_BATCH_SIZE = 8
LOCAL_EPOCHS = 5
WEIGHT_DECAY = 5e-4
# Read the base learning rate once and never write back to args, so the per-round decay
# does not compound across rounds or across re-runs of this cell.
base_lr = args.lr
client_ids = list(dict_users)

for uid in client_ids:
    model[uid] = model[uid].to(args.device)
    model_intra[uid] = model_intra[uid].to(args.device)
    model_inter[uid] = model_inter[uid].to(args.device)

criterionCE = nn.CrossEntropyLoss().to(args.device)
criterionKL = nn.KLDivLoss(reduction='batchmean').to(args.device)

for epoch_index in range(args.epochs):
    # Linear (non-compounding) decay: base_lr in round 0 down towards 10% of base_lr in the last round.
    round_lr = base_lr * (1 - epoch_index / args.epochs * 0.9)

    # ---- Collaborative update on the shared public data ----
    # One optimizer per client for the whole public pass, so SGD momentum accumulates across
    # batches. A freshly constructed SGD has an empty momentum buffer.
    col_optimizers = {
        uid: torch.optim.SGD(model[uid].parameters(), lr = round_lr, momentum = args.momentum, weight_decay = WEIGHT_DECAY)
        for uid in client_ids
    }
    for images, _ in DataLoader(dataset_public, batch_size = PUBLIC_BATCH_SIZE):
        # std(0) over a single sample is NaN, which would propagate into every model; skip such a batch.
        if len(images) < 2:
            continue
        images = images.to(args.device)

        # Aggregate the output from participants: each client's raw output keeps its autograd graph.
        linear_output_list = []
        for uid in client_ids:
            model[uid].train()
            linear_output_list.append(model[uid](images))

        # Consensus target: mean of the detached outputs of all clients. It is the same for every
        # client being updated, so it and its standardisation are computed once per batch.
        linear_output_target_avg = torch.mean(torch.stack([out.detach() for out in linear_output_list]), 0)
        z_2_bn = (linear_output_target_avg - linear_output_target_avg.mean(0)) / linear_output_target_avg.std(0)

        # Update Participants' Models via Col Loss
        for uid, linear_output in zip(client_ids, linear_output_list):
            # Standardise each output dimension over the batch (subtract mean, divide by std).
            z_1_bn = (linear_output - linear_output.mean(0)) / linear_output.std(0)
            # Cross-correlation matrix between this client's and the consensus standardised outputs.
            c = z_1_bn.T @ z_2_bn
            c.div_(len(images))
            # Loss is minimised when diagonal entries are 1 and off-diagonal entries are -1.
            on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
            off_diag = _off_diagonal(c).add_(1).pow_(2).sum()
            col_loss = on_diag + COL_OFF_DIAG_WEIGHT * off_diag
            optimizer = col_optimizers[uid]
            optimizer.zero_grad()
            col_loss.backward()
            optimizer.step()

    # ---- Local update on each client's private data ----
    for uid in client_ids:
        optimizer = torch.optim.SGD(model[uid].parameters(), lr = round_lr, momentum = args.momentum, weight_decay = WEIGHT_DECAY)
        # Both teachers are queried without gradients. Eval mode keeps their BatchNorm layers on
        # running statistics; a deep copy of a train-mode model would otherwise stay in train mode.
        model_intra[uid].eval()
        model_inter[uid].eval()
        model[uid].train()
        # shuffle=True: the client's index list may be grouped by class (TODO confirm in utils.sampling).
        ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users[uid]), batch_size = LOCAL_BATCH_SIZE, shuffle = True)
        for _ in range(LOCAL_EPOCHS):
            for images, labels in ldr_train:
                images = images.to(args.device)
                labels = labels.to(args.device)
                outputs = model[uid](images)
                logsoft_outputs = F.log_softmax(outputs, dim=1)
                with torch.no_grad():
                    intra_soft_outputs = F.softmax(model_intra[uid](images), dim=1)
                    inter_soft_outputs = F.softmax(model_inter[uid](images), dim=1)
                # Cross-entropy on labels plus KL towards both teachers' predictions.
                intra_loss = criterionKL(logsoft_outputs, intra_soft_outputs)
                inter_loss = criterionKL(logsoft_outputs, inter_soft_outputs)
                loss_hard = criterionCE(outputs, labels)
                loss = loss_hard + inter_loss + intra_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        # This round's model becomes next round's inter teacher.
        model_inter[uid] = copy.deepcopy(model[uid])
        model[uid].eval()
        acc_fine_test = test_img(model[uid], dataset_test, args)
        print("round:",epoch_index,"uid:",uid,"Train Testing accuracy: {:.2f}".format(acc_fine_test))