# Federated Learning

2016 年，Google 提出了聯邦學習 (Federated Learning) 方法。聯邦學習旨在結合各個用戶的資訊共同訓練一個模型，同時保障用戶的資料隱私不會外洩。

![](https://blog.ml.cmu.edu/wp-content/uploads/2019/11/Screen-Shot-2019-11-12-at-10.41.38-AM-970x377.png)

其主要流程如下:

1. 中央伺服器選定一個客戶端模型
2. 中央伺服器將模型參數以及超參數 (epochs、batch size、learning rate) 分配給客戶端
3. 客戶端根據中央伺服器分配的參數訓練自己的資料
4. 客戶端將訓練結果回報給中央伺服器，中央伺服器再整合各客戶端的模型

以上步驟 2.～4. 將會重複 $t$ 輪 (round)，直到模型收斂。

## 分配與整合

聯邦學習的概念就像分組報告：一開始組長會分配任務給每個組員，組員完成任務之後，再透過小組討論整合各個組員的貢獻。當然每個組員的貢獻不一致，所以組長會較多參考任務完成度較高的組員，之後再重新分配新的任務，周而復始。

在分配任務時，我們希望客戶端模型訓練到一定程度 (epochs) 後再回報給中央伺服器，以確保客戶端的模型具有一定的參考價值。此外，由於每個客戶端的訓練資料不同，中央伺服器可採用加權平均等方式整合各客戶端的模型權重，讓資料量較多的客戶端模型佔較大比重 (下方實作中每個客戶端的資料量相同，因此直接取簡單平均)。

## 問題與方向

聯邦學習希望各客戶端的資料分佈是相近的 (因為要整合模型)，但實際上客戶端之間的資料往往是 non-I.I.D. 的。在醫學領域中，聯邦學習是一個常用的方法：醫院之間都有病人的資料，但醫院不會公開病人的資料，且各自常面臨訓練資料過少的情況，所以醫院之間希望透過聯邦學習共同提升模型表現。聯邦學習的理念很完整，但也會遇到一些問題：

1. 惡意客戶端可以很容易地攻擊中央伺服器 (例如回傳一個很差的模型給中央伺服器)
2. 客戶端可能隨時離開 (dropout)；為了避免客戶端離開影響伺服器，伺服器通常每輪只隨機抽選部分客戶端參與訓練
3. 攻擊者可能透過客戶端回傳的梯度反推出用戶資料的相關資訊



### Library

In [1]:
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn
import copy
import random
import numpy as np
from collections import Counter
import tqdm

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar-10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting cifar-10/cifar-10-python.tar.gz to cifar-10
Files already downloaded and verified


### Model Architecture

In [2]:
class Net(nn.Module):
    """Small LeNet-style CNN mapping 3-channel images to 10 class logits.

    Two (conv -> ReLU -> 2x2 max-pool) stages feed three fully connected
    layers. ``fc1`` expects 16 * 5 * 5 features, which is what the two
    conv/pool stages produce from a 32x32 input.
    """

    def __init__(self):
        super().__init__()
        # Registration order is significant: it fixes state_dict key order
        # and the order in which parameters are randomly initialised.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)
        self.activation = nn.ReLU()

    def forward(self, x):
        # Feature extractor: the same pool/activation pair follows each conv.
        for conv_layer in (self.conv1, self.conv2):
            x = self.pool(self.activation(conv_layer(x)))
        # Keep the batch axis, collapse everything else.
        x = x.view(x.shape[0], -1)
        # Hidden fully connected layers are ReLU-activated; the last one is not.
        for hidden_layer in (self.fc1, self.fc2):
            x = self.activation(hidden_layer(x))
        return self.fc3(x)

### Loss function

In [3]:
def criterion():
    """Build and return a new ``nn.CrossEntropyLoss`` with default settings."""
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn

### Client

In [20]:
class Clinet:
    """A federated-learning client that trains its own model copy on local data.

    The class keeps its original (misspelled) name because other code
    constructs it as ``Clinet``.

    Args:
        model: the client's ``nn.Module``; typically a deep copy of the
            server's global model.
        optimizer: optimizer *class* (e.g. ``torch.optim.SGD``); it is
            instantiated here over ``model.parameters()`` with
            ``lr=local_config['lr']``.
        criterion: loss module, called as ``criterion(outputs, labels)``.
        lr_scheduler: ``None`` or a scheduler stepped after every batch.
            NOTE(review): the optimizer is created inside ``__init__``, so a
            scheduler passed in here cannot already be bound to it — confirm
            how callers intend to use non-``None`` schedulers.
        local_config: dict with ``'lr'``, ``'batch_size'``, ``'device'`` and
            ``'epochs'`` entries.
    """

    def __init__(
        self,
        model,
        optimizer,
        criterion,
        lr_scheduler,
        local_config: dict,
    ):
        self.local_config = local_config
        self.model = model
        self.optimizer = optimizer(self.model.parameters(), lr=local_config['lr'])
        self.criterion = criterion
        self.lr_scheduler = lr_scheduler

    def create_data_loader(self, data):
        """Wrap ``data`` (a dataset or list of ``(input, label)`` samples) in a shuffling DataLoader.

        Raises:
            ValueError: if ``local_config['batch_size']`` is ``None``.
        """
        batch_size = self.local_config['batch_size']
        if batch_size is None:
            raise ValueError("local_config['batch_size'] must be set")
        self.data_loader = torch.utils.data.DataLoader(data, shuffle=True, batch_size=batch_size)

    def _step_scheduler(self, loss):
        """Advance the learning-rate scheduler (if any) after one optimizer step."""
        if self.lr_scheduler is None:
            return
        # ReduceLROnPlateau needs a metric; other schedulers take no argument.
        # Dispatching on the type replaces a bare ``except`` that also hid
        # genuine scheduler errors.
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.lr_scheduler.step(loss.item())
        else:
            self.lr_scheduler.step()

    def local_step(self):
        """Train the local model for one pass over ``self.data_loader``.

        Returns:
            ``(mean_loss, accuracy)``: the loss averaged over batches (float)
            and the fraction of correctly classified samples (0-dim tensor),
            measured on each batch's predictions before its weight update.

        Raises:
            ValueError: if ``local_config['device']`` is ``None`` or the data
                loader yields no samples.
        """
        device = self.local_config['device']
        if device is None:
            raise ValueError("local_config['device'] must be set")

        local_acc = 0
        local_loss = 0
        num_images = 0
        num_batches = len(self.data_loader)

        self.model.train()
        for images, labels in self.data_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = self.model(images)
            preds = torch.argmax(outputs, dim=1)
            # Accumulated as a tensor, so the returned accuracy is a 0-dim tensor.
            local_acc += (preds == labels).sum()

            loss = self.criterion(outputs, labels)
            local_loss += loss.item() / num_batches

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self._step_scheduler(loss)

            num_images += images.size(0)

        if num_images == 0:
            raise ValueError("client data loader yielded no samples")
        return local_loss, local_acc / num_images

    def local_update(self):
        """Run ``local_config['epochs']`` passes of ``local_step``.

        Stores the per-epoch averages of loss and accuracy in
        ``self.local_loss`` and ``self.local_acc``.
        """
        epochs = self.local_config['epochs']
        self.local_loss = 0
        self.local_acc = 0

        for _ in range(epochs):
            epoch_loss, epoch_acc = self.local_step()
            self.local_loss += epoch_loss / epochs
            self.local_acc += epoch_acc / epochs

### Server

In [24]:
class Server:
    """Central server for a FedAvg-style simulation.

    Holds the global ``model`` and creates ``n_nodes`` clients, each starting
    from a deep copy of it. Each round, ``aggregate`` trains a random subset
    of clients and replaces the global weights with the unweighted mean of
    their weights. ``broadcast`` then pushes the global weights back to every
    client.

    Args:
        model: global ``nn.Module``.
        optimizer: optimizer *class* forwarded to every client.
        criterion: loss module shared by all clients.
        lr_scheduler: forwarded unchanged to every client. NOTE(review): one
            scheduler instance shared across clients is only meaningful when
            it is ``None``.
        drop_rate: fraction of clients *sampled* per round. Despite the name,
            0.8 means 80% of the clients participate.
        n_nodes: number of clients to create.
        local_config: per-client settings forwarded to each client.
    """

    def __init__(
        self,
        model,
        optimizer,
        criterion,
        lr_scheduler,
        drop_rate,
        n_nodes,
        local_config,
    ):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.lr_scheduler = lr_scheduler
        self.drop_rate = drop_rate
        self.n_nodes = n_nodes
        self.local_config = local_config

        self.create_clients()

    def create_clients(self):
        """Create ``n_nodes`` clients, each with an independent copy of the global model."""
        self.local_nodes = []
        self.total_data = 0  # kept for compatibility; not updated in this class

        for _ in range(self.n_nodes):
            local_node = Clinet(model=copy.deepcopy(self.model),
                                optimizer=self.optimizer,
                                criterion=self.criterion,
                                lr_scheduler=self.lr_scheduler,
                                local_config=self.local_config)
            self.local_nodes.append(local_node)

    def communication(self):
        """Randomly choose the clients that take part in this round.

        Returns ``int(drop_rate * n_nodes)`` distinct clients, clamped to at
        least one (when any clients exist) and at most ``len(local_nodes)``.
        """
        if not self.local_nodes:
            return []
        num_nodes = int(self.drop_rate * self.n_nodes)
        # Never pick zero clients: averaging over nobody would wipe the model.
        num_nodes = min(max(num_nodes, 1), len(self.local_nodes))
        return random.sample(self.local_nodes, num_nodes)

    def broadcast(self):
        """Copy the current global weights into every client's model."""
        # load_state_dict copies values, so clients never alias global tensors.
        global_weight = self.model.state_dict()
        for client in self.local_nodes:
            client.model.load_state_dict(global_weight)

    def aggregate(self):
        """Train the sampled clients and average their weights into the global model.

        Sets ``avg_loss`` and ``avg_acc`` to the means of the participating
        clients' ``local_loss`` and ``local_acc``. Floating-point state entries
        are averaged. Non-floating entries (e.g. integer counters such as
        BatchNorm's ``num_batches_tracked``) cannot be averaged in place and
        are copied from the first participating client. If no client takes
        part, the global model is left unchanged.
        """
        self.avg_loss = 0
        self.avg_acc = 0
        nodes_to_train = self.communication()
        if not nodes_to_train:
            return

        num_selected = len(nodes_to_train)
        # state_dict() returns a fresh OrderedDict (with its _metadata), so
        # replacing its entries does not touch the global model's parameters.
        avg_weight = self.model.state_dict()
        for key in avg_weight:
            avg_weight[key] = torch.zeros_like(avg_weight[key])

        for index, node in enumerate(nodes_to_train):
            node.local_update()
            self.avg_loss += node.local_loss / num_selected
            self.avg_acc += node.local_acc / num_selected
            print(f"Client :Acc {node.local_acc}, Loss: {node.local_loss}")

            node_weight = node.model.state_dict()
            for key in avg_weight:
                value = node_weight[key]
                if torch.is_floating_point(value) or torch.is_complex(value):
                    avg_weight[key] += value / num_selected
                elif index == 0:
                    avg_weight[key] = value.clone()

        print(f"Aggregated training Acc: {self.avg_acc}, Aggregated training Loss: {self.avg_loss}")

        self.model.load_state_dict(avg_weight)
  

### Configuration

In [22]:
def create_config():
    """Return a fresh nested configuration dict for the federated run.

    ``local_config`` holds the per-client settings: device, batch size, local
    epochs per round and learning rate. ``global_config`` holds the server
    settings: ``drop_rate`` is the fraction of clients sampled each round,
    ``epochs`` the number of communication rounds and ``n_nodes`` the number
    of clients.
    """
    # Fall back to CPU so the notebook also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = {
        'local_config': {
            'device': device,
            'batch_size': 32,
            'epochs': 10,
            'lr': 0.001,
        },
        'global_config': {
            'drop_rate': 0.8,
            'epochs': 200,
            'n_nodes': 100,
        },
    }
    return config

### Training and Evaluation

In [23]:
def trainer(server):
    """Run one communication round on ``server`` and return the same object.

    The round calls ``server.aggregate()`` first and ``server.broadcast()``
    second.
    """
    for phase in (server.aggregate, server.broadcast):
        phase()
    return server

def evaluate(model, data_loader, config):
    """Return the classification accuracy of ``model`` over ``data_loader``.

    Puts the model in eval mode and runs without gradient tracking. Inputs
    and labels are moved to ``config['device']``. The result is
    correct / total as a 0-dim tensor.
    """
    device = config['device']
    correct, seen = 0, 0

    model.eval()
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            predicted = model(batch_x).argmax(dim=1)
            seen += batch_x.size(0)
            correct = correct + (predicted == batch_y).sum()
    return correct / seen

def main():
    """Run the CIFAR-10 federated-learning experiment end to end.

    Loads CIFAR-10 and builds a server with ``n_nodes`` clients. Each client
    gets a non-IID shard of 500 training images from each of two classes.
    The script then runs ``global_config['epochs']`` communication rounds and
    prints the global model's test accuracy after each one.
    """
    mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

    train_transforms = transforms.Compose([
        # TODO: add another transformation (e.g. augmentation) and see what happens.
        # Client shards are lazy Subset views, so random transforms are
        # re-applied every epoch instead of being frozen at shard creation.
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_data = torchvision.datasets.CIFAR10(root='cifar-10', train=True, download=True, transform=train_transforms)
    test_data = torchvision.datasets.CIFAR10(root='cifar-10', train=False, download=True, transform=test_transforms)

    config = create_config()

    local_config = config['local_config']
    global_config = config['global_config']

    lr_scheduler = None

    model = Net().to(local_config['device'])

    # Evaluation order does not affect accuracy, so there is no need to shuffle.
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=100, shuffle=False)

    optimizer = torch.optim.SGD

    server = Server(
        model=model,
        optimizer=optimizer,
        criterion=criterion(),
        lr_scheduler=lr_scheduler,
        drop_rate=global_config['drop_rate'],
        n_nodes=global_config['n_nodes'],
        local_config=local_config,
    )

    classes_pair = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
    samples_per_class = 500

    # Index each class once instead of rescanning every label for each client.
    data_label = np.array(train_data.targets)
    class_indices = {
        label: np.where(data_label == label)[0].tolist()
        for pair in classes_pair
        for label in pair
    }

    # Cap on how many clients may share one class pair. The original cap of
    # 100 is kept, but raised to ceil(n_clients / n_pairs) when needed so the
    # rejection loop below always terminates.
    n_clients = len(server.local_nodes)
    max_clients_per_pair = max(100, -(-n_clients // len(classes_pair)))

    chosen_counter = Counter()
    for client in server.local_nodes:
        # Re-draw until we hit a class pair that still has room for a client.
        class_pair = random.choice(classes_pair)
        while chosen_counter[class_pair] >= max_clients_per_pair:
            class_pair = random.choice(classes_pair)

        chosen_counter[class_pair] += 1

        first_class, second_class = class_pair
        first_idx = random.sample(class_indices[first_class], k=samples_per_class)
        second_idx = random.sample(class_indices[second_class], k=samples_per_class)

        # Interleave the two classes and keep the shard as a lazy view of the
        # training set rather than materialising every transformed sample.
        shard_indices = [idx for both in zip(first_idx, second_idx) for idx in both]
        client.create_data_loader(torch.utils.data.Subset(train_data, shard_indices))

    for i in range(global_config['epochs']):
        server = trainer(server)
        acc = evaluate(server.model, test_loader, local_config)
        print(f"Epoch [{i}] Aggregate ACC: {acc}")


if __name__ == '__main__':
    main()

Client :Acc 0.27880001068115234, Loss: 2.2454799473285676
Client :Acc 0.4717000424861908, Loss: 2.1829131744801997
Client :Acc 0.44940003752708435, Loss: 2.1770574316382407
Client :Acc 0.225600004196167, Loss: 2.284136677533388
Client :Acc 0.2785000205039978, Loss: 2.2457356810569764
Client :Acc 0.44950002431869507, Loss: 2.1771710231900214
Client :Acc 0.44950002431869507, Loss: 2.1769544683396815
Client :Acc 0.47040003538131714, Loss: 2.18313050866127
Client :Acc 0.5000000596046448, Loss: 2.1874490469694137
Client :Acc 0.47120004892349243, Loss: 2.1832672104239466
Client :Acc 0.4702000617980957, Loss: 2.1832287460565567
Client :Acc 0.27650001645088196, Loss: 2.2456260688602927
Client :Acc 0.27630001306533813, Loss: 2.245902583748102
Client :Acc 0.4701000452041626, Loss: 2.1832996264100077
Client :Acc 0.4501000642776489, Loss: 2.1770354114472865
Client :Acc 0.4499000310897827, Loss: 2.1771534912288186
Client :Acc 0.4711000621318817, Loss: 2.1828420534729958
Client :Acc 0.50000005960464

KeyboardInterrupt: ignored