# Federated Learning

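This demo simulates a simple form of federated averaging: two clients each train a copy of a small logistic-regression model on their own local data. After each round, the clients' weights are averaged to produce the updated global model.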

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

def generate_data(n=100):
    """Create a toy binary-classification dataset.

    Draws ``n`` points from a 2-D standard normal distribution and labels a
    point 1.0 when its two coordinates sum to a positive value, else 0.0.

    Returns:
        ``(features, labels)``: ``features`` has shape ``(n, 2)`` and
        ``labels`` has shape ``(n, 1)``, matching the single output column
        expected by ``nn.BCELoss``.
    """
    features = torch.randn(n, 2)
    positive = features.sum(dim=1) > 0
    labels = positive.float().unsqueeze(1)
    return features, labels

class Model(nn.Module):
    """Logistic-regression classifier: 2 input features -> 1 probability."""

    def __init__(self):
        super().__init__()
        # Single affine map from the 2 features to one logit.
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        """Return the sigmoid probability for each row of ``x``."""
        logits = self.fc(x)
        return logits.sigmoid()

def train(model, data, lr=0.1, epochs=5):
    """Train ``model`` in place on one client's local data.

    Runs ``epochs`` full-batch gradient-descent steps on binary cross-entropy
    using plain SGD.

    Args:
        model: module whose forward output is a probability in [0, 1], as
            required by ``nn.BCELoss``.
        data: ``(x, y)`` pair; ``y`` must have the same shape as ``model(x)``.
        lr: SGD learning rate (default 0.1, the previously hard-coded value).
        epochs: number of full-batch update steps (default 5, the previously
            hard-coded value).

    Returns:
        The loss computed in the last step as a Python float, or ``None``
        when ``epochs`` is 0. Callers that ignore the return value are
        unaffected.
    """
    x, y = data
    # Make sure any train/eval-dependent layers are in training mode.
    model.train()
    criterion = nn.BCELoss()
    # A fresh optimizer per call; plain SGD (no momentum) keeps no state
    # between steps that would need to be carried across calls.
    opt = optim.SGD(model.parameters(), lr=lr)
    last_loss = None
    for _ in range(epochs):
        opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        opt.step()
        last_loss = loss.item()
    return last_loss

def get_weights(model):
    """Snapshot a model's parameters.

    Returns a list of detached tensor copies, in ``model.parameters()``
    order. The copies do not share storage with the live model.
    """
    snapshot = []
    for param in model.parameters():
        snapshot.append(param.detach().clone())
    return snapshot

def set_weights(model, weights):
    """Overwrite a model's parameters in place with copies of ``weights``.

    Args:
        model: module whose ``parameters()`` are overwritten.
        weights: iterable of tensors in ``model.parameters()`` order, for
            example the list returned by ``get_weights``.

    Raises:
        ValueError: if the number of tensors, or any tensor's shape, does not
            match the model's parameters. Without this check ``zip`` would
            silently skip extra or missing entries.
    """
    params = list(model.parameters())
    weights = list(weights)
    if len(params) != len(weights):
        raise ValueError(
            f"expected {len(params)} weight tensors, got {len(weights)}"
        )
    for i, (p, w) in enumerate(zip(params, weights)):
        if p.shape != w.shape:
            raise ValueError(
                f"shape mismatch for parameter {i}: "
                f"expected {tuple(p.shape)}, got {tuple(w.shape)}"
            )
    # copy_ under no_grad writes into the existing parameter storage. The
    # parameter keeps its own dtype/device, later changes to ``weights`` do
    # not leak into the model, and autograd does not record the update.
    with torch.no_grad():
        for p, w in zip(params, weights):
            p.copy_(w)

def average(w1, w2, *more):
    """Element-wise unweighted mean of two or more parameter lists.

    Args:
        w1, w2, *more: lists of tensors, each in the same parameter order
            (e.g. from ``get_weights``). Extra positional lists allow more
            than two clients while ``average(w1, w2)`` keeps working.

    Returns:
        A new list whose i-th tensor is the mean of the i-th tensors of all
        inputs.

    Raises:
        ValueError: if the inputs have different lengths. Without this check
            ``zip`` would silently drop the trailing tensors.
    """
    clients = [list(w) for w in (w1, w2, *more)]
    lengths = {len(w) for w in clients}
    if len(lengths) != 1:
        raise ValueError(f"weight lists differ in length: {sorted(lengths)}")
    count = len(clients)
    # sum() starts from 0, so for two inputs this is exactly (a + b) / 2.
    return [sum(tensors) / count for tensors in zip(*clients)]

# Fix the RNG seed so the synthetic data, the model initializations, and the
# printed weights are reproducible from run to run.
torch.manual_seed(0)

NUM_ROUNDS = 3

client1_data = generate_data()
client2_data = generate_data()

global_model = Model()
for rnd in range(NUM_ROUNDS):
    # Each client starts the round from the current global weights; take the
    # snapshot once and hand it to both clients.
    global_weights = get_weights(global_model)
    c1, c2 = Model(), Model()
    set_weights(c1, global_weights)
    set_weights(c2, global_weights)
    # Local update: each client trains only on its own data.
    train(c1, client1_data)
    train(c2, client2_data)
    # Aggregation: the new global weights are the mean of the clients' weights.
    new_weights = average(get_weights(c1), get_weights(c2))
    set_weights(global_model, new_weights)
    print(f'Round {rnd}: {global_model.fc.weight.data.tolist()}')