In [1]:
from collections import OrderedDict
from typing import Dict, List, Tuple
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision import datasets
from torchvision import models
import pandas as pd
import json
import random
from matplotlib import pyplot as plt
import copy
import torch.nn.functional as F
from torch.utils.data import DataLoader

import os
import time
import warnings
import math

from clientBase import clientBase
from serverBase import serverBase
from models     import CNN, BaseHeadSplit

In [3]:
class FedAvg(serverBase):
    """Federated Averaging (FedAvg) server.

    Each round: select clients, broadcast the global model, optionally
    evaluate it, let every selected client train locally, then collect and
    aggregate the client models. ``select_clients``, ``send_models``,
    ``evaluate``, ``receive_models`` and ``aggregate_parameters`` are not
    defined here, so they come from ``serverBase``.
    """

    def __init__(self, args, times):
        super().__init__(args, times)
        # Presumably instantiates the clients as ``clientAvg`` objects (see
        # serverBase.set_clients). ``clientAvg`` is looked up at construction
        # time, so it may live in a later cell as long as that cell has run.
        self.set_clients(args, clientAvg)
        # Wall-clock seconds spent on each round, indexed by round number.
        self.Budget = []

    def train(self):
        """Run ``global_rounds + 1`` FedAvg rounds and print a summary.

        The global model is evaluated (every ``eval_gap`` rounds) after it is
        sent to the clients but before that round's local training and
        aggregation.
        """
        for i in range(self.global_rounds + 1):
            round_start = time.time()
            self.selected_clients = self.select_clients()
            self.send_models()

            if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate global model")
                self.evaluate()

            for client in self.selected_clients:
                client.train()

            self.receive_models()
            self.aggregate_parameters()

            self.Budget.append(time.time() - round_start)
            print('-'*50, self.Budget[-1])

        print("\nBest global accuracy.")
        # rs_test_acc is presumably appended to by the inherited evaluate();
        # guard the empty case so max() cannot raise ValueError.
        print(max(self.rs_test_acc) if self.rs_test_acc else "n/a")

        # Round 0 is excluded from the average time (presumably to skip
        # warm-up overhead). With global_rounds == 0 nothing remains, and the
        # old expression divided by zero.
        later_rounds = self.Budget[1:]
        if later_rounds:
            print(sum(later_rounds) / len(later_rounds))
        else:
            print("n/a")
        print(f"acc:  {self.rs_test_acc}")
        print(f"loss: {self.rs_train_loss}")

In [4]:
class clientAvg(clientBase):
    """FedAvg client: trains its local copy of the model on its own data.

    The model, loss (``self.loss``), optimizer (``self.optimizer``), device
    and ``local_steps`` are not set here. They are presumably initialised by
    ``clientBase``.
    """

    def __init__(self, args, id, train_samples, test_samples):
        super().__init__(args, id, train_samples, test_samples)

    def train(self):
        """Train for ``local_steps`` full passes over the local training loader."""
        print(f"[Client: {self.id:3d}] train.")
        trainloader = self.load_train_data()
        self.model.train()
        for _ in range(self.local_steps):
            for x, y in trainloader:
                # Some datasets yield the input as a list. Only its first
                # element is moved to the device; any other element is left
                # where it is.
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)

                output = self.model(x)
                loss = self.loss(output, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

In [5]:
# Text-model defaults. ``vocab_size`` is the fallback vocabulary size for
# datasets that do not declare their own in _DATASET_META.
vocab_size = 87915
max_len = 200
# NOTE(review): emb_dim is not referenced in the visible cells — confirm it is still needed.
emb_dim = 64

# (dataset-name prefix, num_classes, vocabulary size or None for the default).
# Checked in order and the first matching prefix wins; unmatched datasets get
# 10 classes.
_DATASET_META = (
    ("cifar100", 100, None),
    ("pathmnist", 9, None),
    ("organamnist", 11, None),
    ("agnews", 4, 87915),
    ("sogounews", 5, 145835),
)

# Exact dataset name -> extra FedAvgCNN keyword arguments (num_classes is
# supplied at build time).
_CNN_CONFIGS = {
    "organamnist": {"in_features": 1, "dim": 1024},
    "fmnist": {"in_features": 1, "dim": 1024},
    "fmnist_pat": {"in_features": 1, "dim": 1024},
    "pathmnist": {"in_features": 3, "dim": 1024},
    "pathmnist_pat": {"in_features": 3, "dim": 1024},
    "pathmnist_N20_alpha03_64": {"in_features": 3, "dim": 1024},
    "cifar10_alpha03": {"in_features": 3, "dim": 1600},
    "cifar10_alpha08": {"in_features": 3, "dim": 1600},
    "cifar10": {"in_features": 3, "dim": 1600},
    "cifar10_pat": {"in_features": 3, "dim": 1600},
    "cifar10_N20_alpha03_64": {"in_features": 3, "dim": 1600},
    "cifar100": {"in_features": 3, "dim": 1600, "dim1": 1024},
    "cifar100_pat": {"in_features": 3, "dim": 1600, "dim1": 1024},
}


def _dataset_meta(dataset):
    """Return ``(num_classes, vocab_size)`` for ``dataset``, matched by name prefix."""
    for prefix, num_classes, vocab in _DATASET_META:
        if dataset.startswith(prefix):
            return num_classes, (vocab_size if vocab is None else vocab)
    return 10, vocab_size


def _build_model(model_str, dataset, num_classes, vocab, device):
    """Instantiate the backbone named ``model_str`` and move it to ``device``.

    Raises:
        NotImplementedError: unknown model name, or ``"cnn"`` requested for a
            dataset without an entry in ``_CNN_CONFIGS``.
    """
    # NOTE(review): FedAvgCNN, TransformerModel, fastText, TextCNN, ResNet10 and
    # AlexNet are not imported in the visible cells (cell 1 imports CNN, not
    # FedAvgCNN) — presumably defined in a cell not shown; verify.
    if model_str == "cnn":
        cnn_kwargs = _CNN_CONFIGS.get(dataset)
        if cnn_kwargs is None:
            raise NotImplementedError(f"model 'cnn' is not configured for dataset '{dataset}'")
        model = FedAvgCNN(num_classes=num_classes, **cnn_kwargs)
    elif model_str == "Transformer":
        model = TransformerModel(ntoken=vocab, d_model=64, nhead=2, d_hid=64, nlayers=2, fc_hid=64,
                                 num_classes=num_classes)
    elif model_str == "fastText":
        model = fastText(hidden_dim=256, vocab_size=vocab, num_classes=num_classes)
    elif model_str == "TextCNN":
        model = TextCNN(hidden_dim=512, max_len=max_len, vocab_size=vocab, num_classes=num_classes)
    elif model_str == "resnet":
        model = ResNet10(num_classes=num_classes)
    elif model_str == "alexnet":
        model = AlexNet(num_classes=num_classes)
    else:
        raise NotImplementedError(f"unknown model: {model_str!r}")
    return model.to(device)


def _build_server(args, times):
    """Create the server for ``args.algorithm``; raise NotImplementedError if unknown."""
    if args.algorithm == "FedMR":
        return FedMR(args, times)
    if args.algorithm == "FedAvg":
        return FedAvg(args, times)
    if args.algorithm == "SCAFFOLD":
        return SCAFFOLD(args, times)
    if args.algorithm == "FedNTD":
        return FedNTD(args, times)
    raise NotImplementedError(f"unknown algorithm: {args.algorithm!r}")


def run(args):
    """Run repetitions ``args.prev .. args.times - 1`` of an FL experiment.

    For each repetition:
      1. ``args.num_classes`` is set from the dataset name.
      2. A fresh backbone is built from the model name originally given in
         ``args.model``.
      3. Its ``fc`` layer is split off into a separate head via
         ``BaseHeadSplit``.
      4. The server chosen by ``args.algorithm`` is trained.

    Finally the average training time is printed.

    Side effects: mutates ``args.num_classes`` and replaces ``args.model``
    with the built module.
    """
    # Capture the model *name* once: args.model is replaced by an nn.Module
    # below, and every repetition must rebuild from the original name.
    model_str = args.model
    time_list = []

    for i in range(args.prev, args.times):
        torch.cuda.empty_cache()
        print(f"\n============= Running time: {i}th =============")
        print("Creating server and clients ...")

        # A separate local name: assigning to `vocab_size` here would make
        # it local and hide the module-level default.
        args.num_classes, vocab = _dataset_meta(args.dataset)

        backbone = _build_model(model_str, args.dataset, args.num_classes, vocab, args.device)
        # Split the classifier off: the backbone gets an identity in place of
        # fc, and a copy of the original fc becomes the separate head.
        head = copy.deepcopy(backbone.fc)
        backbone.fc = nn.Identity()
        args.model = BaseHeadSplit(backbone, head)

        for key, value in vars(args).items():
            print(f"{key}: {value}")

        server = _build_server(args, i)

        start = time.time()
        server.train()
        time_list.append(time.time() - start)

    # np.average of an empty list is nan (with a warning) — skip when no repetition ran.
    if time_list:
        print(f"\nAverage time cost: {round(np.average(time_list), 2)}s.")
    print("All done!")
    

In [6]:
class config:
    """Hyper-parameters and runtime options for one FL experiment.

    Every option becomes a plain instance attribute, set in the order listed
    below. That order is what ``vars(args)`` iterates over.
    """

    def __init__(self):
        defaults = {
            # experiment / training
            "algorithm": "FedAvg",
            "task": "-",
            "model": "cnn",  # a model *name*; run() later replaces it with a module
            "dataset": "cifar10",
            "batch_size": 32,
            "local_learning_rate": 0.1,
            "global_rounds": 150,
            "local_steps": 5,
            "join_ratio": 0.1,
            "num_clients": 50,
            "num_classes": 10,  # run() overwrites this based on the dataset name
            # runtime
            "head": None,
            "device": "cuda",
            "device_id": "0",
            # run control
            "random_join_ratio": False,
            "prev": 0,
            "times": 8,
            "eval_gap": 1,
        }
        for option, value in defaults.items():
            setattr(self, option, value)
        

In [None]:
if __name__ == "__main__":
    # Entry point: build the default config, pick the device, and run all repetitions.
    total_start = time.time()

    args = config()
    # NOTE: CUDA generally honours CUDA_VISIBLE_DEVICES only if it is set
    # before CUDA is first initialised in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id

    # Fall back to CPU when CUDA was requested but is unavailable.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("\ncuda is not available.\n")
        args.device = "cpu"
    torch.cuda.empty_cache()
    run(args)

    print(f"\nTotal time cost: {round(time.time() - total_start, 2)}s.")
