In [1]:
import os
import datetime 
import time
import copy
from copy import deepcopy
import csv
import random
import math
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.autograd import Variable
from torch.distributions import Normal
import torch.multiprocessing as mp

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

# Select the 'file_system' strategy for sharing tensors between processes.
mp.set_sharing_strategy('file_system')
# ## mp.set_start_method('spawn')        
# ## For Colab : https://colab.research.google.com/github/pnavaro/python-notebooks/blob/master/notebooks/10-Multiprocessing.ipynb#scrollTo=a8yyJ5xMi6lU
# ##             https://stackoverflow.com/questions/61939952/mp-set-start-methodspawn-triggered-an-error-saying-the-context-is-already-be
# try:
#    mp.set_start_method('spawn', force=True)
#    print("spawned")
# except RuntimeError:
#    pass
# torch.set_num_threads(1)

In [4]:
#@title FedLearners


def train_handler(args, net, net_id, dataidxs, reduction = "mean"):
    """Build the data loader, optimizer and loss used for one client's local training.

    Args:
        args: Run configuration. Reads ``optimizer`` ('sgd' or 'adam'),
            ``dataset``, ``datadir``, ``batch_size``, ``lr``, ``reg`` (used as
            weight decay), ``rho`` (used as SGD momentum) and ``device``.
        net: Model whose parameters with ``requires_grad=True`` are optimized.
        net_id: Client identifier; not used here, kept for interface
            compatibility.
        dataidxs: Forwarded unchanged to ``get_dataloader`` (presumably the
            sample indices owned by this client -- confirm against
            ``get_dataloader``).
        reduction: ``reduction`` mode passed to ``torch.nn.CrossEntropyLoss``.

    Returns:
        Tuple ``(train_dataloader, optimizer, criterion)``.

    Raises:
        ValueError: If ``args.optimizer`` is neither 'sgd' nor 'adam'.
    """
    # Validate first so a typo in the configuration fails immediately instead of
    # surfacing as an unbound local after the data loader has been built.
    optimizer_name = args.optimizer
    if optimizer_name not in ("sgd", "adam"):
        raise ValueError(
            f"Unsupported optimizer {optimizer_name!r}; expected 'sgd' or 'adam'"
        )

    # The second value returned by get_dataloader is not needed for training.
    # NOTE(review): the literal 32 looks like the batch size of that second
    # loader -- confirm against get_dataloader.
    train_dataloader, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs)

    # Frozen parameters are excluded from the optimizer.
    trainable_params = [p for p in net.parameters() if p.requires_grad]
    if optimizer_name == "sgd":
        optimizer = torch.optim.SGD(trainable_params, lr=args.lr, momentum=args.rho, weight_decay=args.reg)
    else:
        optimizer = torch.optim.Adam(trainable_params, lr=args.lr, weight_decay=args.reg)

    criterion = torch.nn.CrossEntropyLoss(reduction=reduction).to(args.device)
    return train_dataloader, optimizer, criterion


def update_global(args, networks, selected, freqs):
    """Overwrite the global model with a weighted sum of the selected client models.

    Args:
        args: Run configuration; only ``args.arch`` is read and it must be "cnn".
        networks: Mapping with the global model under "global_model" and the
            client models under "nets" (indexable by client id).
        selected: Ordered collection of client ids taking part in this round.
        freqs: Weights aligned by position with ``selected``.

    Raises:
        ValueError: If ``args.arch`` is not "cnn".

    Side effects:
        Every selected client model is moved to the CPU, and the global model's
        state is replaced through ``load_state_dict``.
    """
    if args.arch != "cnn":
        raise ValueError("Wrong arch!")

    aggregated = networks["global_model"].state_dict()
    for position, client_id in enumerate(selected):
        client_state = networks["nets"][client_id].cpu().state_dict()
        for name, tensor in client_state.items():
            if position == 0:
                # The first client replaces the stale global entry outright.
                aggregated[name] = tensor * freqs[position]
            else:
                aggregated[name] += tensor * freqs[position]
    networks["global_model"].load_state_dict(aggregated)


class BaseAlgorithm:
    """Interface shared by the federated learning algorithms in this notebook.

    Subclasses provide a per-client training step (``local_update``) and a
    server-side aggregation step (``global_update``).
    """

    def local_update(self, args, net, global_net, net_id, dataidxs):
        """Run one client's local step; must be overridden by subclasses."""
        raise NotImplementedError

    def global_update(self, args, networks, selected, net_dataidx_map):
        """Update the global model from the selected clients; must be overridden."""
        raise NotImplementedError


class FED(BaseAlgorithm):
    """Federated learner: plain local training, uniformly weighted aggregation."""

    def local_update(self, args, net, global_net, net_id, dataidxs):
        """Train ``net`` on one client's data and save its weights to disk.

        Args:
            args: Run configuration. Reads ``device``, ``epochs`` and ``logdir``
                here, plus whatever ``train_handler`` reads.
            net: Client model; trained in place and left on the CPU afterwards.
            global_net: Unused by this algorithm; present so every algorithm
                shares the same signature.
            net_id: Client identifier, used in the checkpoint file name.
            dataidxs: Forwarded to ``train_handler``.

        Side effects:
            Writes ``{args.logdir}/clients/client_{net_id}.pt``.
        """
        net.to(args.device)
        train_dataloader, optimizer, criterion = train_handler(args, net, net_id, dataidxs)
        for _ in range(args.epochs):
            for x, target in train_dataloader:
                # Inputs are deliberately NOT marked requires_grad: only the
                # parameters are optimized, so an input gradient would be wasted
                # work (and setting it fails for non-floating inputs).
                x = x.to(args.device)
                target = target.to(args.device).long()
                optimizer.zero_grad()
                loss = criterion(net(x), target)
                loss.backward()
                optimizer.step()
        net.to("cpu")
        torch.save(net.state_dict(), f"{args.logdir}/clients/client_{net_id}.pt")

    def global_update(self, args, networks, selected, net_dataidx_map):
        """Aggregate the selected clients into the global model with equal weights.

        ``net_dataidx_map`` is not used: every selected client gets weight
        ``1 / len(selected)``.
        """
        num_selected = len(selected)
        fed_freqs = [1 / num_selected for _ in selected]
        update_global(args, networks, selected, fed_freqs)

    def __str__(self):
        """Return the human-readable name of the algorithm."""
        return "Bayesian Federated Learning algorithm"


class FEDAvg(FED):
    """FED variant whose aggregation weights follow each client's data share."""

    def global_update(self, args, networks, selected, net_dataidx_map):
        """Aggregate the selected clients, weighting each by its number of samples.

        A client's weight is ``len(net_dataidx_map[client])`` divided by the
        summed size over all selected clients.
        """
        client_sizes = [len(net_dataidx_map[client]) for client in selected]
        total_data_points = sum(client_sizes)
        fed_avg_freqs = [size / total_data_points for size in client_sizes]
        update_global(args, networks, selected, fed_avg_freqs)

    def __str__(self):
        """Return the human-readable name of the algorithm."""
        return "Federated Learning algorithm with FedAVG global update rule"


class FEDProx(FEDAvg):
    """FEDAvg variant whose local loss adds a proximal term towards the global model."""

    def local_update(self, args, net, global_net, net_id, dataidxs):
        """Train ``net`` with the task loss plus ``(mu / 2) * ||w - w_global||^2``.

        Args:
            args: Run configuration. Reads ``device``, ``epochs``, ``mu`` and
                ``logdir`` here, plus whatever ``train_handler`` reads.
            net: Client model; trained in place and left on the CPU afterwards.
            global_net: Reference model for the proximal term. It is moved to
                ``args.device`` during training and back to the CPU afterwards;
                its weights are not modified.
            net_id: Client identifier, used in the checkpoint file name.
            dataidxs: Forwarded to ``train_handler``.

        Side effects:
            Writes ``{args.logdir}/clients/client_{net_id}.pt``.
        """
        net.to(args.device)
        global_net.to(args.device)
        train_dataloader, optimizer, criterion = train_handler(args, net, net_id, dataidxs)

        # Detach the global weights: the proximal term should pull `net` towards
        # them without back-propagating into the global model, which would
        # otherwise accumulate `.grad` on its parameters at every step.
        global_weights = [param.detach() for param in global_net.parameters()]
        local_params = list(net.parameters())

        for _ in range(args.epochs):
            for x, target in train_dataloader:
                # Inputs are not marked requires_grad; only parameters are trained.
                x = x.to(args.device)
                target = target.to(args.device).long()
                optimizer.zero_grad()
                loss = criterion(net(x), target)
                # Sum of squared differences equals the squared 2-norm, computed
                # directly so no square root is taken and then squared again.
                squared_distance = sum(
                    (local - reference).pow(2).sum()
                    for local, reference in zip(local_params, global_weights)
                )
                loss = loss + (args.mu / 2) * squared_distance
                loss.backward()
                optimizer.step()

        net.to("cpu")
        global_net.to("cpu")
        torch.save(net.state_dict(), f"{args.logdir}/clients/client_{net_id}.pt")

    def __str__(self):
        """Return the human-readable name of the algorithm."""
        return "Federated Learning algorithm with FedProx global update rule"


class FEDNova(BaseAlgorithm):
    """Federated learner that aggregates normalized client updates (FedNova)."""

    @staticmethod
    def _normalization_factor(tau, rho):
        """Return the factor ``a_i`` that normalizes a client's accumulated update.

        Computes ``(tau - rho * (1 - rho**tau) / (1 - rho)) / (1 - rho)``, which
        reduces to ``tau`` when ``rho`` is 0. ``rho`` must not equal 1, because
        the expression divides by ``1 - rho``.
        """
        return (tau - rho * (1 - pow(rho, tau)) / (1 - rho)) / (1 - rho)

    def local_update(self, args, net, global_net, net_id, dataidxs):
        """Train ``net`` locally, then save its weights and its normalized update.

        The normalized update is ``(global_weights - local_weights) / a_i``,
        where ``a_i`` is derived from the number of local steps and ``args.rho``.

        Args:
            args: Run configuration. Reads ``device``, ``epochs``, ``rho`` and
                ``logdir`` here, plus whatever ``train_handler`` reads.
            net: Client model; trained in place and left on the CPU afterwards.
            global_net: Model the update is measured against; not modified.
            net_id: Client identifier, used in both output file names.
            dataidxs: Forwarded to ``train_handler``.

        Side effects:
            Writes ``client_{net_id}.pt`` and ``norm_grad_{net_id}.pt`` under
            ``{args.logdir}/clients``.
        """
        net.to(args.device)
        train_dataloader, optimizer, criterion = train_handler(args, net, net_id, dataidxs)
        for _ in range(args.epochs):
            for x, target in train_dataloader:
                # Inputs are not marked requires_grad; only parameters are trained.
                x = x.to(args.device)
                target = target.to(args.device).long()
                optimizer.zero_grad()
                loss = criterion(net(x), target)
                loss.backward()
                optimizer.step()
        net.to("cpu")

        # Number of optimizer steps this client performed.
        tau = len(train_dataloader) * args.epochs
        a_i = self._normalization_factor(tau, args.rho)

        global_net_para = global_net.state_dict()
        net_para = net.state_dict()
        # true_divide gives a floating result even for integer entries.
        norm_grad = {
            key: torch.true_divide(global_net_para[key] - net_para[key], a_i)
            for key in global_net_para
        }
        torch.save(net.state_dict(), f"{args.logdir}/clients/client_{net_id}.pt")
        torch.save(norm_grad, f"{args.logdir}/clients/norm_grad_{net_id}.pt")

    def global_update(self, args, networks, selected, net_dataidx_map):
        """Apply the data-weighted sum of the saved normalized updates to the global model.

        Each selected client's ``norm_grad_{id}.pt`` is loaded and weighted by
        that client's share of the selected data. The weighted sum is scaled by
        the equally weighted average of the clients' ``a_i`` factors and
        subtracted from the global weights.

        Args:
            args: Run configuration. Reads ``logdir``, ``batch_size``,
                ``epochs`` and ``rho``.
            networks: Mapping holding the global model under "global_model".
            selected: Client ids taking part in this round. An empty selection
                leaves the global model untouched.
            net_dataidx_map: Maps a client id to its collection of sample indices.

        Raises:
            ZeroDivisionError: If the selected clients hold no samples in total.
        """
        if len(selected) == 0:
            # Nothing to aggregate; the update would be zero anyway.
            return

        total_data_points = sum(len(net_dataidx_map[r]) for r in selected)
        freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]

        global_model = networks["global_model"]
        global_para = global_model.state_dict()

        # Scalars here; the first accumulation turns every entry into a tensor.
        norm_grad_total = dict.fromkeys(global_para, 0.0)
        coeff = 0.0
        for i, r in enumerate(selected):
            norm_grad = torch.load(f"{args.logdir}/clients/norm_grad_{r}.pt")
            for key in norm_grad_total:
                norm_grad_total[key] = norm_grad_total[key] + norm_grad[key] * freqs[i]
            # NOTE(review): this recomputes the client's step count from the
            # data size and assumes the training loader keeps the last partial
            # batch, so that it matches len(train_dataloader) in local_update --
            # confirm against get_dataloader.
            tau = math.ceil(len(net_dataidx_map[r]) / args.batch_size) * args.epochs
            coeff += self._normalization_factor(tau, args.rho) * freqs[i]

        for key, value in global_para.items():
            step = coeff * norm_grad_total[key]
            # Cast to the entry's own dtype/device so integer entries and
            # GPU-resident weights are updated in place without a type error.
            value.sub_(step.to(dtype=value.dtype, device=value.device))
        global_model.load_state_dict(global_para)

    def __str__(self):
        """Return the human-readable name of the algorithm."""
        return "Federated Learning algorithm with FedNOVA global update rule"


