In [1]:
from collections import OrderedDict
from typing import Dict, List, Tuple
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision import models
import pandas as pd
import json
import random
from matplotlib import pyplot as plt
import copy
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import os
import time
import warnings
import math

from clientBase import clientBase
from serverBase import serverBase
from models     import CNN,  BaseHeadSplit
from readData   import read_client_data_text, read_client_data

In [2]:
class CustomDataset(Dataset):
    """Minimal map-style dataset that pairs ``data[idx]`` with ``labels[idx]``.

    Both containers are stored as given (no copying or conversion); any indexable
    objects work, e.g. lists or tensors.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # The dataset size is taken from ``data``; ``labels`` is assumed to match.
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [3]:
class FedMR(serverBase):
    """Server for FedMR: parameter averaging plus a periodically repaired global head.

    After a warm-up period, every round the server asks the highest-weighted selected
    clients to fine-tune the freshly broadcast head on their class prototypes
    (``clientMR.get_local_repaired_head``), averages those heads with normalized
    aggregation weights, and pushes the result to every client.
    """

    def __init__(self, args, times):
        super().__init__(args, times)
        # One weight per client, in client order; filled by set_clients.
        self.aggregate_weights = []
        self.set_clients(args, clientMR)

        # Wall-clock seconds spent in each round.
        self.Budget = []

        self.accs = [[] for _ in range(self.num_clients)]
        self.repaired_head = None
        # Head repair runs only for rounds strictly greater than this value.
        self.revise_start_round = getattr(args, "revise_start_round", 10)
        # Maximum number of selected clients that upload a repaired head per round.
        self.num_revise_clients = getattr(args, "num_revise_clients", 4)

    def train(self):
        """Run ``global_rounds + 1`` communication rounds, then print a summary."""
        for i in range(self.global_rounds + 1):
            s_t = time.time()
            self.selected_clients = self.select_clients()
            self.send_models()

            # Repair runs after send_models so clients fine-tune the latest global head.
            if i > self.revise_start_round:
                self.aggregate_revised_head(i)

            if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate global model")
                self.evaluate()

            for client in self.selected_clients:
                client.train(i)

            self.receive_models()
            self.aggregate_parameters()

            self.Budget.append(time.time() - s_t)
            print('-' * 50, self.Budget[-1])

        # These attributes are not set in this class; print None instead of crashing
        # at the end of a long run if the base class does not provide them.
        print("aggregated_var", getattr(self, "aggregated_var", None))
        print("balanced_var", getattr(self, "balanced_var", None))

        print("\nBest global accuracy.")
        print(max(self.rs_test_acc) if self.rs_test_acc else None)
        # Average per-round time, excluding round 0.
        later_budgets = self.Budget[1:]
        print(sum(later_budgets) / len(later_budgets) if later_budgets else 0.0)

        print(f"acc:  {self.rs_test_acc}")
        print(f"loss: {self.rs_train_loss}")

    def set_clients(self, args, clientObj):
        """Create one client per id and record its aggregation weight.

        The weight is ``(classes present locally / num_classes) * train_samples``.
        """
        for i in range(self.num_clients):
            if args.task == "NLP":
                train_data = read_client_data_text(self.dataset, i, is_train=True)
                test_data = read_client_data_text(self.dataset, i, is_train=False)
            else:
                train_data = read_client_data(self.dataset, i, is_train=True)
                test_data = read_client_data(self.dataset, i, is_train=False)

            client = clientObj(args,
                               id=i,
                               train_samples=len(train_data),
                               test_samples=len(test_data))
            self.clients.append(client)
            self.aggregate_weights.append(((client.client_classes * 1.0) / client.num_classes) * client.train_samples)

    def aggregate_revised_head(self, r):
        """Collect repaired heads from the top-weighted selected clients, average them, broadcast.

        Args:
            r: current round index (unused; kept for interface compatibility).
        """
        k = min(self.num_revise_clients, len(self.selected_clients))
        if k <= 0:
            return

        # Aggregation weights of the selected clients, looked up by their position in self.clients.
        selected_ids = [self.clients.index(client) for client in self.selected_clients]
        candidate_weights = torch.tensor([self.aggregate_weights[idx] for idx in selected_ids])
        # topk raises if k exceeds the number of candidates, hence the min() above.
        top_weights, top_k_indices = torch.topk(candidate_weights, k=k)
        uploaders = [self.selected_clients[idx] for idx in top_k_indices.tolist()]

        # Collect
        uploaded_revised_heads = [client.get_local_repaired_head(flag=True) for client in uploaders]

        top_weights = top_weights / torch.sum(top_weights)

        # Aggregate: weighted sum of the uploaded head parameters.
        self.repaired_head = copy.deepcopy(uploaded_revised_heads[0])
        for param in self.repaired_head.parameters():
            param.data.zero_()
        for local_head, w in zip(uploaded_revised_heads, top_weights):
            for global_param, local_param in zip(self.repaired_head.parameters(), local_head.parameters()):
                global_param.data += local_param.data * w

        # Send to all clients. Clients not selected this round do not train on it or
        # upload a model; for them it only affects evaluation on their test sets.
        for client in self.clients:
            client.set_local_repaired_head(self.repaired_head)
    

In [4]:
def dataset_repeat(data, labels, batch_size=10):
    """Tile a small dataset along dim 0 so it holds at least ``batch_size`` samples.

    The whole dataset is repeated ``ceil(batch_size / len(data))`` times (sample
    order is kept; nothing is randomly resampled).

    Args:
        data: tensor of shape (N, ...) with any number of trailing dims.
        labels: tensor whose dim 0 has size N.
        batch_size: minimum number of samples wanted.

    Returns:
        ``(data, labels)`` unchanged when no repetition is needed (or N == 0);
        otherwise the tiled ``(data, labels)``.
    """
    num_samples = data.shape[0]
    if num_samples == 0:
        # Nothing to tile; also avoids a division by zero below.
        return data, labels

    repeats = (batch_size + num_samples - 1) // num_samples
    if repeats <= 1:
        return data, labels

    # Repeat along dim 0 only, for any tensor rank.
    expanded_data = data.repeat(repeats, *([1] * (data.dim() - 1)))
    expanded_labels = labels.repeat(repeats, *([1] * (labels.dim() - 1)))
    return expanded_data, expanded_labels
    

## Client

In [5]:
class clientMR(clientBase):
    """FedMR client.

    It trains locally with a proximal pull of the head toward the repaired global head.
    On request it fine-tunes a copy of the head on this client's per-class feature
    prototypes.
    """

    def __init__(self, args, id, train_samples, test_samples):
        super().__init__(args, id, train_samples, test_samples)

        # Settings related to the FedMR algorithm.
        trainloader = self.load_train_data(batch_size=1)

        # Push a single sample through the feature extractor to learn the feature shape.
        rep = None
        for x, _ in trainloader:
            if isinstance(x, list):
                x[0] = x[0].to(self.device)
            else:
                x = x.to(self.device)
            with torch.no_grad():
                rep = self.model.base(x).detach()
            break
        if rep is None:
            raise ValueError(f"client {id} has no training data; cannot infer the feature shape")
        # Zero tensor used only as a shape template for prototypes.
        self.feature_shape = torch.zeros_like(rep.squeeze())

        # Per-class training-sample counts, and the number of classes present locally.
        self.sample_per_class = torch.zeros(self.num_classes).to(self.device)
        for _, y in trainloader:
            counts = torch.bincount(y.view(-1).long(), minlength=self.num_classes)
            self.sample_per_class += counts.to(self.sample_per_class)
        self.client_classes = torch.count_nonzero(self.sample_per_class).item()

        self.ft_learning_rate = args.ft_learning_rate
        self.repaired_head = None
        self.revise_steps = args.revise_steps
        self.revise_weights = None
        self.mu = args.mu
        # Weight of the proximal term; updated at the end of every train() call.
        self.lamda = 1.0
        self.prototypes = None

    def train(self, r):
        """Run ``local_steps`` epochs of local training for round ``r``.

        When a repaired head is available, the loss adds
        ``lamda * mu / 2 * ||head - repaired_head||^2``.
        """
        print(f"[Client: {self.id:3d}] train.")
        trainloader = self.load_train_data()
        self.model.train()

        # Detached anchor parameters. Without detach(), backward() would also compute and
        # accumulate gradients into self.repaired_head, which is never optimized.
        anchor_params = None
        if self.repaired_head is not None:
            anchor_params = [p.detach() for p in self.repaired_head.parameters()]

        for step in range(self.local_steps):
            for x, y in trainloader:
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)

                output = self.model(x)
                loss = self.loss(output, y)

                if anchor_params is not None:
                    proximal_term = 0.
                    for local_param, anchor_param in zip(self.model.head.parameters(), anchor_params):
                        proximal_term += torch.sum(torch.square(local_param - anchor_param))
                    loss += self.lamda * (self.mu / 2) * proximal_term

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        # Cosine decay used by the next call: 1 at r = 0, 0 at r = global_rounds.
        self.lamda = (math.cos(r * math.pi / max(self.global_rounds, 1)) + 1) / 2

    def set_local_repaired_head(self, repaired_head):
        """Copy ``repaired_head``'s weights into both the model head and the proximal anchor."""
        if self.repaired_head is None:
            self.repaired_head = copy.deepcopy(self.model.head)

        self.model.head.load_state_dict(repaired_head.state_dict())
        self.repaired_head.load_state_dict(repaired_head.state_dict())

    def get_local_repaired_head(self, flag=True):
        """Fine-tune a copy of the current head on this client's class prototypes.

        Args:
            flag: unused; kept for interface compatibility.

        Returns:
            The fine-tuned copy; ``self.model.head`` itself is not modified. If no class
            has local samples, the untouched copy is returned.
        """
        # Create the prototype training set: one (prototype, label) pair per present class.
        self.prototypes = self.get_local_prototypes(self.model)
        present = self.sample_per_class != 0
        prototype_inputs = self.prototypes[present]
        prototype_labels = torch.nonzero(present).flatten()

        # Called after the global model is sent, so this starts from the global head.
        repaired_head_temp = copy.deepcopy(self.model.head)
        if prototype_labels.numel() == 0:
            return repaired_head_temp

        prototype_inputs, prototype_labels = dataset_repeat(prototype_inputs, prototype_labels, batch_size=32)

        prototype_dataset = CustomDataset(prototype_inputs, prototype_labels)
        prototype_dataloader = DataLoader(prototype_dataset, batch_size=32, drop_last=False, shuffle=True)

        optimizer = torch.optim.SGD(repaired_head_temp.parameters(), lr=self.ft_learning_rate)

        # Fine-tune the global head.
        repaired_head_temp.train()
        for step in range(self.revise_steps):
            for x, y in prototype_dataloader:
                x = x.to(self.device)
                y = y.to(self.device)

                optimizer.zero_grad()
                output = repaired_head_temp(x)
                loss = self.loss(output, y)
                loss.backward()
                optimizer.step()

        return repaired_head_temp

    def get_local_prototypes(self, model):
        """Return per-class mean ``model.base`` features over the training set.

        The result has shape ``(num_classes, *feature_shape)``. Rows for classes absent
        from the local data are zero.
        """
        features = torch.zeros((self.num_classes, *self.feature_shape.shape), device=self.device)

        trainloader = self.load_train_data(batch_size=300)

        # NOTE(review): model is not switched to eval() here; if base has BatchNorm/Dropout
        # this runs in whatever mode it was left in — confirm that is intended.
        with torch.no_grad():
            for x, y in trainloader:
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device).long()

                output_base = model.base(x).detach()
                output_base = output_base.reshape(y.shape[0], *features.shape[1:]).to(features.dtype)
                # Sum each sample's features into its class row.
                features.index_add_(0, y, output_base)

        # Divide by per-class counts. clamp avoids 0/0 for absent classes, whose sums are zero.
        counts = self.sample_per_class.view(-1, *([1] * (features.dim() - 1)))
        prototypes = features / counts.clamp(min=1)

        return prototypes

In [6]:
# Text-model settings; not read by run() below.
vocab_size = 87915
max_len=200
emb_dim=64


def _num_classes_for(dataset):
    """Number of label classes inferred from the dataset-name prefix (default 10)."""
    prefix_to_classes = (
        ("cifar100", 100),
        ("pathmnist", 9),
        ("organamnist", 11),
        ("agnews", 4),
        ("sogounews", 5),
    )
    for prefix, num_classes in prefix_to_classes:
        if dataset.startswith(prefix):
            return num_classes
    return 10


def _build_cnn(dataset, num_classes, device):
    """Build the CNN configured for a dataset family (matched by name prefix).

    "cifar100" is tested before "cifar10" because the latter is a prefix of the former.
    """
    if dataset.startswith(("organamnist", "fmnist")):
        return CNN(in_features=1, num_classes=num_classes, dim=1024).to(device)
    if dataset.startswith("pathmnist"):
        return CNN(in_features=3, num_classes=num_classes, dim=1024).to(device)
    if dataset.startswith("cifar100"):
        return CNN(in_features=3, num_classes=num_classes, dim=1600, dim1=1024).to(device)
    if dataset.startswith("cifar10"):
        return CNN(in_features=3, num_classes=num_classes, dim=1600).to(device)
    raise NotImplementedError(f"no CNN configuration for dataset {dataset!r}")


def _make_server(args, i):
    """Instantiate the server class named by ``args.algorithm``.

    Every name except FedMR must be defined or imported elsewhere; none is in this notebook.
    """
    if args.algorithm == "FedMR":
        return FedMR(args, i)
    elif args.algorithm == "FedAvg":
        return FedAvg(args, i)
    elif args.algorithm == "SCAFFOLD":
        return SCAFFOLD(args, i)
    elif args.algorithm == "FedNTD":
        return FedNTD(args, i)
    elif args.algorithm == "FedGEN":
        return FedGEN(args, i)
    elif args.algorithm == "FedProto":
        return FedProto(args, i)
    elif args.algorithm == "FedDyn":
        return FedDyn(args, i)
    elif args.algorithm == "MOON":
        return MOON(args, i)
    raise NotImplementedError


def run(args):
    """Run the experiment for repetition indices ``args.prev .. args.times - 1``.

    Each repetition sets ``args.num_classes`` and builds a fresh model split into
    base/head, stored back into ``args.model``. It then trains a server and finally
    prints the average wall time.

    Raises:
        NotImplementedError: if the model name is not "cnn", the dataset has no CNN
            configuration, or the algorithm is unknown.
    """
    time_list = []
    # args.model is overwritten with a module below. Capture the requested architecture
    # name once so every repetition (not only the first) can rebuild a fresh model.
    model_str = args.model

    for i in range(args.prev, args.times):
        torch.cuda.empty_cache()
        print(f"\n============= Running time: {i}th =============")
        print("Creating server and clients ...")

        args.num_classes = _num_classes_for(args.dataset)

        if model_str != "cnn":
            raise NotImplementedError
        model = _build_cnn(args.dataset, args.num_classes, args.device)

        # Split the classifier: fc becomes the head, the rest is the feature extractor.
        head = copy.deepcopy(model.fc)
        model.fc = nn.Identity()
        args.model = BaseHeadSplit(model, head)

        for key, value in vars(args).items():
            print(f"{key}: {value}")

        server = _make_server(args, i)

        start = time.time()
        server.train()
        time_list.append(time.time() - start)

    if time_list:
        print(f"\nAverage time cost: {round(np.average(time_list), 2)}s.")
    print("All done!")
    

In [7]:
class config():
    """Experiment settings: default hyper-parameters for FedMR with a CNN on CIFAR-10."""

    # (attribute, default) pairs. Order matters: run() prints vars(args) in insertion order.
    _DEFAULTS = (
        # Experiment / data
        ("algorithm", "FedMR"),
        ("task", "-"),
        ("model", "cnn"),
        ("dataset", "cifar10"),
        # Federated training
        ("batch_size", 32),
        ("local_learning_rate", 0.01),
        ("global_rounds", 150),
        ("local_steps", 5),
        ("join_ratio", 0.1),
        ("num_clients", 50),
        ("num_classes", 10),
        # FedMR head repair
        ("revise_steps", 100),
        ("mu", 0.01),
        ("ft_learning_rate", 0.01),
        # Runtime
        ("head", None),
        ("device", "cuda"),
        ("device_id", "0"),
        ("random_join_ratio", False),
        ("prev", 0),
        ("times", 8),
        ("eval_gap", 1),
    )

    def __init__(self):
        for name, value in self._DEFAULTS:
            setattr(self, name, value)
        self.eval_gap = 1


In [None]:
if __name__ == "__main__":
    total_start = time.time()

    args = config()
    # CUDA_VISIBLE_DEVICES is only honored if set before CUDA is initialized in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id

    if args.device == "cuda" and not torch.cuda.is_available():
        print("\ncuda is not available.\n")
        args.device = "cpu"
    if args.device == "cuda":
        torch.cuda.empty_cache()
    run(args)

    print(f"\nTotal time cost: {round(time.time() - total_start, 2)}s.")