In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy

In [8]:
class SampleNet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_layer = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3) # 28*28*1 -> 26*26*4
        self.pooling = nn.MaxPool2d(kernel_size=2) 
        self.activation = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=3)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=11*11*16, out_features=512)
        self.dropout = nn.Dropout(0.2)
        self.output = nn.Linear(in_features=512, out_features=10)
    
    def forward(self, x):
        x = self.pooling(self.activation(self.input_layer(x))) # 28 * 28 * 1 -> 26 * 26 * 4 -> 13 * 13 * 4
        x = self.activation(self.conv(x)) # 13 * 13 * 4 -> 11 * 11 * 16
        return self.output(self.dropout(self.linear(self.flatten(x)))) # 11 * 11 * 16 -> 512 -> 10
        

In [9]:
def evaluate_model(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    accuracy = 100 * correct / total
    return accuracy

def client_update(model, train_loader, epochs=1, lr=0.01):
    model = copy.deepcopy(model)
    model.train()
    
    optimizer = optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    
    for _ in range(epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

    return model.state_dict()


In [10]:
def average_weights(w_list):
    avg_w = copy.deepcopy(w_list[0])
    for key in avg_w.keys():
        for i in range(1, len(w_list)):
            avg_w[key] += w_list[i][key]
        avg_w[key] = torch.div(avg_w[key], len(w_list))
    return avg_w


In [11]:
global_model = SampleNet()
num_rounds = 5
num_clients = 3

In [12]:
transform = transforms.ToTensor()
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

client_data_indices = np.array_split(np.arange(len(mnist_data)), 3)
client_loaders = [DataLoader(Subset(mnist_data, idx), batch_size=32, shuffle=True) for idx in client_data_indices]
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)


In [13]:
for round in range(num_rounds):
    print(f"\n--- Round {round+1} ---")
    local_weights = []

    for client_id in range(num_clients):
        client_model = copy.deepcopy(global_model)
        local_w = client_update(client_model, client_loaders[client_id], epochs=1)
        local_weights.append(local_w)

    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    acc = evaluate_model(global_model, test_loader)
    print(f"Global Test Accuracy: {acc:.2f}%")


--- Round 1 ---
Global Test Accuracy: 89.59%

--- Round 2 ---
Global Test Accuracy: 91.68%

--- Round 3 ---
Global Test Accuracy: 93.19%

--- Round 4 ---
Global Test Accuracy: 94.26%

--- Round 5 ---
Global Test Accuracy: 95.39%
