In [1]:
import os
import datetime 
import time
import copy
from copy import deepcopy
import csv
import random
import math
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.autograd import Variable
from torch.distributions import Normal
import torch.multiprocessing as mp

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

# Select the 'file_system' strategy for sharing tensors through torch.multiprocessing.
# NOTE(review): presumably chosen to avoid hitting the open-file-descriptor limit when
# many worker processes exchange tensors -- confirm against the training driver.
mp.set_sharing_strategy('file_system')
# Disabled alternatives below are kept for reference only (forcing the 'spawn' start
# method and limiting intra-op threads); none of them is executed.
# ## mp.set_start_method('spawn')        
# ## For Colab : https://colab.research.google.com/github/pnavaro/python-notebooks/blob/master/notebooks/10-Multiprocessing.ipynb#scrollTo=a8yyJ5xMi6lU
# ##             https://stackoverflow.com/questions/61939952/mp-set-start-methodspawn-triggered-an-error-saying-the-context-is-already-be
# try:
#    mp.set_start_method('spawn', force=True)
#    print("spawned")
# except RuntimeError:
#    pass
# torch.set_num_threads(1)

In [None]:
#@title Helper Snippets


###----------------------------- model snippets ------------------------------

class CNN(nn.Module):
    """Two 5x5 convolutions (each followed by ReLU and 2x2 max-pooling) plus three linear layers.

    Args:
        input_dim: Number of features per sample after the second pooling step;
            the feature map is flattened to ``(-1, input_dim)``. For 32x32 inputs
            this is ``16 * 5 * 5``.
        hidden_dims: Sequence whose first two entries are the widths of the two
            hidden linear layers.
        output_dim: Number of output logits.
        input_channel: Number of channels of the input images.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=10, input_channel=3):
        super().__init__()
        first_hidden, second_hidden = hidden_dims[0], hidden_dims[1]

        # Plain attributes: read by forward() and by scoring code (output_dim).
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Feature extractor. Layers are created in the same order as before so
        # that seeded initialisation and state_dict key order are unchanged.
        self.conv1 = nn.Conv2d(input_channel, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Classifier head.
        self.fc1 = nn.Linear(input_dim, first_hidden)
        self.fc2 = nn.Linear(first_hidden, second_hidden)
        self.fc3 = nn.Linear(second_hidden, output_dim)

    def forward(self, x):
        """Return the raw (un-normalised) logits for a batch of images."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = features.view(-1, self.input_dim)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


def init_nets(args):
    """Create one model per party plus a global model.

    Args:
        args: Configuration object. Reads ``dataset``, ``arch``, ``n_parties``
            and ``is_same_initial``.

    Returns:
        ``(global_net, nets)`` where ``nets`` maps party index
        (``0 .. n_parties - 1``) to its model. When ``args.is_same_initial`` is
        truthy every party model is loaded with the global model's weights.

    Raises:
        ValueError: If ``args.dataset`` or ``args.arch`` is not supported.
    """
    if args.dataset == 'cifar10':
        input_channel = 3
        input_dim = (16 * 5 * 5)
        hidden_dims = [120, 84]
        output_dim = 10
    ### elif ...... add here if needed
    else:
        # Previously an unsupported dataset left the shape variables undefined
        # and surfaced later as a confusing NameError.
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    def build_net():
        """Instantiate a single model of the requested architecture."""
        if args.arch.lower() == "cnn":
            return CNN(input_dim=input_dim, hidden_dims=hidden_dims,
                       output_dim=output_dim, input_channel=input_channel)
        ### elif .... add here if needed
        raise ValueError("Unknown architecture: {}".format(args.arch))

    # Party models are built before the global model, as before, so that a
    # seeded run yields the same initial weights.
    nets = {net_i: build_net() for net_i in range(args.n_parties)}
    global_net = build_net()

    if args.is_same_initial:
        global_para = global_net.state_dict()
        for net in nets.values():
            net.load_state_dict(global_para)

    return global_net, nets


### --------------------------- Scorer Snippets --------------------------------

## Expected Calibration Error (ECE) Naeini et al., 2015 
   ##--> approximate difference in expectation between confidence and accuracy of machine learning models
def ece(preds, target, device, minibatch=True):
    """Expected calibration error over 100 equal-width confidence bins, scaled by 100.

    Args:
        preds: ``(N, C)`` tensor of class probabilities; the row maximum is the
            confidence and its index the predicted class.
        target: ``(N, C)`` tensor whose row-wise arg-max is the true class.
        device: Device on which the running total is allocated.
        minibatch: Accepted for signature parity with the other scorers; unused.

    Returns:
        The bin-weighted sum of ``|mean confidence - accuracy|`` times 100, as a
        Python float.
    """
    n_bins = 100
    confidences, predictions = preds.max(dim=1)
    target_cls = target.max(dim=1).indices
    hits = predictions.eq(target_cls)

    edges = torch.linspace(0, 1, n_bins + 1)
    total = torch.zeros(1, device=device)
    for b in range(n_bins):
        lower = edges[b].item()
        upper = edges[b + 1].item()
        # Half-open bin (lower, upper].
        in_bin = confidences.gt(lower) & confidences.le(upper)
        weight = in_bin.float().mean()
        if weight.item() > 0:
            bin_accuracy = hits[in_bin].float().mean()
            bin_confidence = confidences[in_bin].mean()
            total += torch.abs(bin_confidence - bin_accuracy) * weight * 100

    return total.item()

## Negative log-likelihood (NLL) Quinonero-Candela et al., 2005 --> metric for evaluating predictive uncertainty
def nll(preds, target, minibatch=True):
    """Negative log-likelihood of ``target`` under the predicted probabilities.

    Args:
        preds: ``(N, C)`` tensor of class probabilities.
        target: ``(N, C)`` tensor of per-class weights (one-hot in the scorers of
            this file).
        minibatch: If true, return the NLL *summed* over the batch (so callers
            can accumulate across batches, like ``acc``); otherwise return the
            mean over the batch.

    Returns:
        A Python float.
    """
    # The 1e-8 offset keeps log() finite for zero probabilities.
    per_sample = -(torch.log(preds + 1e-8) * target).sum(1)
    if minibatch:
        # The previous code called .item() directly on the per-sample vector,
        # which raises for any batch with more than one sample.
        return per_sample.sum().item()
    return per_sample.mean().item()

## Accuracy 
def acc(preds, target, minibatch=True):
    """Top-1 accuracy scaled by 100.

    Args:
        preds: ``(N, C)`` tensor of class scores.
        target: ``(N, C)`` tensor whose row-wise arg-max is the true class.
        minibatch: If true, return ``100 * number of correct predictions`` (a
            running total for the caller to normalise); otherwise return the
            accuracy in percent.

    Returns:
        A Python float.
    """
    hits = (preds.argmax(1) == target.argmax(1)) * 1.0
    reduced = hits.sum() if minibatch else hits.mean()
    return (reduced * 100).item()

def compute_scores(model, dataloader, args, device="cpu", n_sample=1):
    """Evaluate ``model`` and return ``(accuracy, ece, nll)``.

    Args:
        model: Module to score. It is moved to ``device``, must expose
            ``output_dim`` (used for one-hot encoding the labels) and returns
            logits.
        dataloader: A single iterable of ``(x, target)`` batches, or a list of
            such iterables; all batches are pooled before scoring.
        args: Unused; kept for interface compatibility.
        device: Device on which evaluation runs.
        n_sample: Number of forward passes averaged (after softmax) per batch.

    Returns:
        Tuple ``(acc, ece, nll)`` computed over all pooled samples with
        ``minibatch=False``.
    """
    loaders = dataloader if isinstance(dataloader, list) else [dataloader]

    was_training = model.training
    if was_training:
        model.eval()

    preds = []
    targets = []
    try:
        model.to(device)
        with torch.no_grad():
            for loader in loaders:
                for x, target in loader:
                    x, target = x.to(device), target.to(device, dtype=torch.int64)

                    # Average the predictive distribution over n_sample passes.
                    outs = [F.softmax(model(x), 1) for _ in range(n_sample)]
                    preds.append(torch.stack(outs).mean(0))
                    targets.append(F.one_hot(target, model.output_dim))

        targets = torch.cat(targets)
        preds = torch.cat(preds)

        _acc = acc(preds, targets, minibatch=False)
        _ece = ece(preds, targets, device, minibatch=False)
        _nll = nll(preds, targets, minibatch=False)
    finally:
        # Restore training mode even when evaluation fails, so an error here
        # cannot silently leave the model stuck in eval mode.
        if was_training:
            model.train()
    return _acc, _ece, _nll

In [4]:
#@title FedLearners


def train_handler(args, net, net_id, dataidxs, reduction = "mean"):
    """Build the training data loader, optimizer and loss for one client.

    Args:
        args: Configuration object. Reads ``optimizer`` (``'sgd'`` or ``'adam'``),
            ``dataset``, ``datadir``, ``batch_size``, ``lr``, ``rho`` (SGD
            momentum), ``reg`` (weight decay) and ``device``.
        net: Model whose trainable parameters are optimised.
        net_id: Client identifier (unused here; kept for interface parity).
        dataidxs: Sample indices forwarded to ``get_dataloader``.
        reduction: Reduction mode of the cross-entropy loss.

    Returns:
        ``(train_dataloader, optimizer, criterion)``.

    Raises:
        ValueError: If ``args.optimizer`` is not a supported optimizer name.
    """
    # Fail fast: previously an unknown name left `optimizer` unbound and only
    # failed at the return statement with an UnboundLocalError.
    if args.optimizer not in ('sgd', 'adam'):
        raise ValueError("Unknown optimizer: {}".format(args.optimizer))

    # The literal 32 is the fourth positional argument of get_dataloader --
    # presumably the evaluation batch size; confirm against its definition.
    train_dataloader, _ = get_dataloader(args.dataset, args.datadir, args.batch_size, 32, dataidxs)

    trainable = filter(lambda p: p.requires_grad, net.parameters())
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(trainable, lr=args.lr, momentum=args.rho, weight_decay=args.reg)
    else:
        optimizer = torch.optim.Adam(trainable, lr=args.lr, weight_decay=args.reg)

    criterion = torch.nn.CrossEntropyLoss(reduction=reduction).to(args.device)
    return train_dataloader, optimizer, criterion


def update_global(args, networks, selected, freqs):
    """Overwrite the global model with a weighted sum of the selected client models.

    Args:
        args: Configuration object; only ``arch`` is read.
        networks: Dict with ``"global_model"`` (the model to update) and
            ``"nets"`` (mapping client id -> model).
        selected: Iterable of client ids to aggregate.
        freqs: Weights aligned with ``selected`` by position.

    Raises:
        ValueError: If ``args.arch`` is not a supported architecture.

    Note:
        Each selected client model is moved to the CPU as a side effect. With an
        empty ``selected`` the global weights are left unchanged.
    """
    # Case-insensitive, matching how the architecture name is interpreted when
    # the models are created (args.arch.lower()).
    if args.arch.lower() != "cnn":
        raise ValueError("Wrong arch!")

    global_para = networks["global_model"].state_dict()
    for idx, net_id in enumerate(selected):
        net_para = networks["nets"][net_id].cpu().state_dict()
        for key in net_para:
            weighted = net_para[key] * freqs[idx]
            if idx == 0:
                # The first client replaces the old global value with a fresh
                # tensor, so later in-place additions never touch live weights.
                global_para[key] = weighted
            else:
                global_para[key] += weighted
    networks["global_model"].load_state_dict(global_para)


class BaseAlgorithm:
    """Interface shared by the federated algorithms: one local and one global step."""

    def local_update(self, args, net, global_net, net_id, dataidxs):
        """Train client ``net_id``'s model ``net``; must be overridden."""
        raise NotImplementedError

    def global_update(self, args, networks, selected, net_dataidx_map):
        """Aggregate the ``selected`` clients into the global model; must be overridden."""
        raise NotImplementedError


class FED(BaseAlgorithm):
    """Plain local training combined with uniform averaging of the selected clients."""

    def local_update(self, args, net, global_net, net_id, dataidxs):
        """Train ``net`` for ``args.epochs`` epochs and checkpoint it.

        The model is trained on ``args.device``, moved back to the CPU, and its
        state dict is written to ``{args.logdir}/clients/client_{net_id}.pt``.
        ``global_net`` is not used.
        """
        net.to(args.device)
        loader, optimizer, criterion = train_handler(args, net, net_id, dataidxs)

        for _ in range(args.epochs):
            for x, target in loader:
                x = x.to(args.device)
                target = target.to(args.device)
                optimizer.zero_grad()
                x.requires_grad = True
                target.requires_grad = False
                labels = target.long()
                loss = criterion(net(x), labels)
                loss.backward()
                optimizer.step()

        net.to("cpu")
        checkpoint_path = f"{args.logdir}/clients/client_{net_id}.pt"
        torch.save(net.state_dict(), checkpoint_path)

    def global_update(self, args, networks, selected, net_dataidx_map):
        """Average the selected clients with equal weight ``1 / len(selected)``."""
        n_selected = len(selected)
        uniform_freqs = [1 / n_selected for _ in range(n_selected)]
        update_global(args, networks, selected, uniform_freqs)

    def __str__(self):
        return "Bayesian Federated Learning algorithm"


class FEDAvg(FED):
    """FED with the global step weighted by each client's number of samples."""

    def global_update(self, args, networks, selected, net_dataidx_map):
        """Aggregate the selected clients, weighting each by its share of the data points."""
        sizes = [len(net_dataidx_map[r]) for r in selected]
        total_data_points = sum(sizes)
        fed_avg_freqs = [size / total_data_points for size in sizes]
        update_global(args, networks, selected, fed_avg_freqs)

    def __str__(self):
        return "Federated Learning algorithm with FedAVG global update rule"


class FEDProx(FEDAvg):
    """FedAvg aggregation with a proximal term added to the local objective."""

    def local_update(self, args, net, global_net, net_id, dataidxs):
        """Train ``net`` with loss ``CE + (mu / 2) * ||w - w_global||^2``.

        Args:
            args: Configuration; reads ``device``, ``epochs``, ``mu``, ``logdir``
                plus whatever ``train_handler`` needs.
            net: Client model, trained in place.
            global_net: Reference model for the proximal term (read-only).
            net_id: Client id, used in the checkpoint file name.
            dataidxs: Sample indices of this client.

        Both models are moved to ``args.device`` for training and back to the
        CPU afterwards; the client state dict is saved to
        ``{args.logdir}/clients/client_{net_id}.pt``.
        """
        net.to(args.device)
        global_net.to(args.device)
        train_dataloader, optimizer, criterion = train_handler(args, net, net_id, dataidxs)

        # Detached views: the proximal term must not back-propagate into the
        # global model, otherwise its .grad buffers grow on every step.
        global_weights = [p.detach() for p in global_net.parameters()]

        for _ in range(args.epochs):
            for x, target in train_dataloader:
                x, target = x.to(args.device), target.to(args.device)
                optimizer.zero_grad()
                x.requires_grad = True
                target.requires_grad = False
                target = target.long()
                out = net(x)
                loss = criterion(out, target)

                # Squared L2 distance computed directly rather than as
                # norm(...)**2: no square root is involved, so the gradient is
                # well defined when local and global weights coincide (which is
                # the case on the first step after loading the global weights).
                sq_dist = sum(
                    ((param - ref) ** 2).sum()
                    for param, ref in zip(net.parameters(), global_weights)
                )
                loss = loss + (args.mu / 2) * sq_dist

                loss.backward()
                optimizer.step()

        net.to("cpu")
        global_net.to("cpu")
        torch.save(net.state_dict(), f"{args.logdir}/clients/client_{net_id}.pt")

    def __str__(self):
        return "Federated Learning algorithm with FedProx global update rule"


class FEDNova(BaseAlgorithm):
    """FedNova: clients send normalised updates that the server rescales and applies."""

    @staticmethod
    def _normalization_factor(tau, rho):
        """Return ``a_i`` for ``tau`` local steps with momentum ``rho`` (requires ``rho != 1``)."""
        return (tau - rho * (1 - pow(rho, tau)) / (1 - rho)) / (1 - rho)

    def local_update(self, args, net, global_net, net_id, dataidxs):
        """Train ``net`` locally and store its normalised update.

        Writes the client state dict to
        ``{args.logdir}/clients/client_{net_id}.pt`` and
        ``(global - local) / a_i`` for every state-dict entry to
        ``{args.logdir}/clients/norm_grad_{net_id}.pt``.
        """
        net.to(args.device)
        train_dataloader, optimizer, criterion = train_handler(args, net, net_id, dataidxs)
        for _ in range(args.epochs):
            for x, target in train_dataloader:
                x, target = x.to(args.device), target.to(args.device)
                optimizer.zero_grad()
                x.requires_grad = True
                target.requires_grad = False
                target = target.long()
                out = net(x)
                loss = criterion(out, target)
                loss.backward()
                optimizer.step()
        net.to("cpu")

        # tau = number of optimizer steps actually taken by this client.
        tau = len(train_dataloader) * args.epochs
        a_i = FEDNova._normalization_factor(tau, args.rho)

        global_net_para = global_net.state_dict()
        net_para = net.state_dict()
        norm_grad = deepcopy(global_net.state_dict())
        for key in norm_grad:
            norm_grad[key] = torch.true_divide(global_net_para[key] - net_para[key], a_i)
        torch.save(net.state_dict(), f"{args.logdir}/clients/client_{net_id}.pt")
        torch.save(norm_grad, f"{args.logdir}/clients/norm_grad_{net_id}.pt")

    def global_update(self, args, networks, selected, net_dataidx_map):
        """Apply the data-weighted sum of the stored normalised updates to the global model.

        Args:
            args: Configuration; reads ``logdir``, ``batch_size``, ``epochs``, ``rho``.
            networks: Dict holding the model to update under ``"global_model"``.
            selected: Client ids whose ``norm_grad_{id}.pt`` files are aggregated.
            net_dataidx_map: Mapping client id -> sample indices (sizes give the weights).
        """
        total_data_points = sum([len(net_dataidx_map[r]) for r in selected])
        freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]

        global_para = networks["global_model"].state_dict()
        # Start every accumulator at the scalar 0.0; the first addition turns it
        # into a tensor of the right shape.
        norm_grad_total = dict.fromkeys(global_para, 0.0)
        coeff = 0.0

        # The loop must unpack (position, client id): the position indexes
        # `freqs`, the id names the file. The previous `for i in enumerate(...)`
        # bound a tuple to `i` and never defined `r`.
        for i, r in enumerate(selected):
            norm_grad = torch.load(f"{args.logdir}/clients/norm_grad_{r}.pt")
            for key in norm_grad_total:
                norm_grad_total[key] += norm_grad[key] * freqs[i]

            # Step count re-derived from the data size; assumes the client
            # loader keeps the last partial batch -- TODO confirm.
            tau = math.ceil(len(net_dataidx_map[r]) / args.batch_size) * args.epochs
            coeff = coeff + FEDNova._normalization_factor(tau, args.rho) * freqs[i]

        for key in global_para:
            # Integer entries need an explicit cast before the in-place update.
            if global_para[key].type() == 'torch.LongTensor':
                global_para[key] -= (coeff * norm_grad_total[key]).type(torch.LongTensor)
            elif global_para[key].type() == 'torch.cuda.LongTensor':
                global_para[key] -= (coeff * norm_grad_total[key]).type(torch.cuda.LongTensor)
            else:
                global_para[key] -= coeff * norm_grad_total[key]
        networks["global_model"].load_state_dict(global_para)

    def __str__(self):
        return "Federated Learning algorithm with FedNOVA global update rule"


class Scaffold(BaseAlgorithm):
    """SCAFFOLD: local training corrected by control variates stored on disk.

    Control variates live in ``{args.logdir}/clients``: ``c_global.pt`` for the
    server, ``c_{id}.pt`` per client, and ``c_delta_{id}.pt`` for the change a
    client produced in its last local update. These files must exist before the
    first ``local_update`` call.
    """

    def local_update(self, args, net, global_net, net_id, dataidxs):
        """Train ``net`` with control-variate correction and refresh the client variate.

        After every optimizer step each state-dict entry is shifted by
        ``-lr * (c_global - c_local)``. When training ends the method saves the
        client weights, the new client variate
        ``c_local - c_global + (global - local) / (steps * lr)`` (overwriting
        ``c_{net_id}.pt``) and its change ``c_delta_{net_id}.pt``.
        """
        c_global_path = f"{args.logdir}/clients/c_global.pt"
        c_local_path = f"{args.logdir}/clients/c_{net_id}.pt"

        c_global_para = torch.load(c_global_path, map_location=args.device)
        c_local_para = torch.load(c_local_path, map_location=args.device)
        net.to(args.device)
        global_net.to(args.device)
        train_dataloader, optimizer, criterion = train_handler(args, net, net_id, dataidxs)
        cnt = 0  # number of optimizer steps taken
        for _ in range(args.epochs):
            for x, target in train_dataloader:
                x, target = x.to(args.device), target.to(args.device)
                optimizer.zero_grad()
                x.requires_grad = True
                target.requires_grad = False
                target = target.long()
                out = net(x)
                loss = criterion(out, target)
                loss.backward()
                optimizer.step()
                # Drift correction applied on top of the optimizer step.
                net_para = net.state_dict()
                for key in net_para:
                    net_para[key] = net_para[key] - args.lr * (c_global_para[key] - c_local_para[key])
                net.load_state_dict(net_para)
                cnt += 1

        # Everything below runs on the CPU. The global model has to come back
        # too: it was moved to args.device above, and leaving it there makes
        # `global - local` fail with a device mismatch on CUDA.
        net.to("cpu")
        global_net.to("cpu")
        c_global_para = torch.load(c_global_path, map_location="cpu")
        c_local_para = torch.load(c_local_path, map_location="cpu")
        # Copies (instead of re-reading the same file) that receive the results.
        c_new_para = deepcopy(c_local_para)
        c_delta_para = deepcopy(c_local_para)

        global_model_para = global_net.state_dict()
        net_para = net.state_dict()
        # NOTE(review): cnt == 0 (empty loader) yields inf/nan here -- confirm
        # that every client always has at least one batch.
        for key in net_para:
            c_new_para[key] = c_local_para[key] - c_global_para[key] + (global_model_para[key] - net_para[key]) / (cnt * args.lr)
            c_delta_para[key] = c_new_para[key] - c_local_para[key]
        torch.save(net.state_dict(), f"{args.logdir}/clients/client_{net_id}.pt")
        torch.save(c_new_para, c_local_path)
        torch.save(c_delta_para, f"{args.logdir}/clients/c_delta_{net_id}.pt")

    def global_update(self, args, networks, selected, net_dataidx_map):
        """Update the server control variate, then average the client models.

        The server variate receives the mean of the selected clients'
        ``c_delta`` files; the global model becomes the data-size-weighted
        average of the selected client models (which are moved to the CPU).
        """
        # Accumulators start as the scalar 0.0 and become tensors on first add.
        total_delta = deepcopy(networks["global_model"].state_dict())
        for key in total_delta:
            total_delta[key] = 0.0
        # NOTE(review): the mean is taken over the selected clients; the
        # SCAFFOLD paper divides by the total number of clients -- confirm
        # which is intended under partial participation.
        for r in selected:
            c_delta_para = torch.load(f"{args.logdir}/clients/c_delta_{r}.pt")
            for key in total_delta:
                total_delta[key] += c_delta_para[key] / len(selected)
        c_global_para = torch.load(f"{args.logdir}/clients/c_global.pt")
        for key in c_global_para:
            # Integer entries need an explicit cast before the in-place update.
            if c_global_para[key].type() == 'torch.LongTensor':
                c_global_para[key] += total_delta[key].type(torch.LongTensor)
            elif c_global_para[key].type() == 'torch.cuda.LongTensor':
                c_global_para[key] += total_delta[key].type(torch.cuda.LongTensor)
            else:
                c_global_para[key] += total_delta[key]
        torch.save(c_global_para, f"{args.logdir}/clients/c_global.pt")

        # Model aggregation: weighted by each client's share of the data.
        total_data_points = sum([len(net_dataidx_map[r]) for r in selected])
        fed_avg_freqs = [len(net_dataidx_map[r]) / total_data_points for r in selected]
        global_para = networks["global_model"].state_dict()
        for i, r in enumerate(selected):
            net_para = networks["nets"][r].cpu().state_dict()
            if i == 0:
                for key in net_para:
                    global_para[key] = net_para[key] * fed_avg_freqs[i]
            else:
                for key in net_para:
                    global_para[key] += net_para[key] * fed_avg_freqs[i]
        networks["global_model"].load_state_dict(global_para)

    def __str__(self):
        return "Federated Learning algorithm with Scaffold global update rule"


def get_algorithm(args):
    """Return the algorithm class named by ``args.alg`` (case-insensitive).

    Recognised names: ``fed``, ``fedavg``, ``fedprox``, ``fednova``, ``scaffold``.
    The class itself is returned, not an instance.

    Raises:
        NotImplementedError: For any other name.
    """
    requested = args.alg.lower()
    if requested == "fed":
        return FED
    if requested == "fedavg":
        return FEDAvg
    if requested == "fedprox":
        return FEDProx
    if requested == "fednova":
        return FEDNova
    if requested == "scaffold":
        return Scaffold
    raise NotImplementedError(f"{args.alg} is not implemented!")