<a href="https://colab.research.google.com/github/TheAmirHK/Experiments/blob/main/FederatedLearning/FederatedLearning_with_FedAvg(test).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

def load_mnist_data():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)

    num_clients = 5
    data_size = len(train_dataset) // num_clients
    client_data_indices = [list(range(i * data_size, (i + 1) * data_size)) for i in range(num_clients)]

    client_datasets = [Subset(train_dataset, indices) for indices in client_data_indices]

    return client_datasets

client_datasets = load_mnist_data()

In [7]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [8]:
def fed_avg(client_weights):

    global_weights = client_weights[0].copy()
    for key in global_weights.keys():
        global_weights[key] = torch.stack([client_weights[i][key] for i in range(len(client_weights))], dim=0).mean(dim=0)
    return global_weights

In [11]:
def train_local_model(client_id, model, dataset, epochs=1, lr=0.1):

    model.train()
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for images, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    return model.state_dict()

In [12]:
def federated_learning(num_rounds=20, num_clients=5):
    global_model = SimpleNN()
    client_models = [SimpleNN() for _ in range(num_clients)]

    for round_num in range(num_rounds):
        print(f"Round {round_num + 1}")
        client_weights = []

        for client_id in range(num_clients):
            client_model = client_models[client_id]
            client_model.load_state_dict(global_model.state_dict())
            client_dataset = client_datasets[client_id]
            trained_weights = train_local_model(client_id, client_model, client_dataset, epochs=1)
            client_weights.append(trained_weights)

        global_weights = fed_avg(client_weights)
        global_model.load_state_dict(global_weights)

        test_accuracy = evaluate_global_model(global_model)
        print(f"Test Accuracy after Round {round_num + 1}: {test_accuracy:.2f}%")

def evaluate_global_model(model):
    model.eval()
    test_dataset = torchvision.datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    correct, total = 0, 0

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total

federated_learning()

Round 1
Test Accuracy after Round 1: 82.85%
Round 2
Test Accuracy after Round 2: 85.85%
Round 3
Test Accuracy after Round 3: 90.48%
Round 4
Test Accuracy after Round 4: 92.32%
Round 5
Test Accuracy after Round 5: 92.59%
Round 6
Test Accuracy after Round 6: 93.78%
Round 7
Test Accuracy after Round 7: 93.58%
Round 8
Test Accuracy after Round 8: 94.33%
Round 9
Test Accuracy after Round 9: 94.47%
Round 10
Test Accuracy after Round 10: 94.49%
Round 11
Test Accuracy after Round 11: 95.14%
Round 12
Test Accuracy after Round 12: 94.86%
Round 13
Test Accuracy after Round 13: 94.68%
Round 14
Test Accuracy after Round 14: 94.84%
Round 15
Test Accuracy after Round 15: 95.14%
Round 16
Test Accuracy after Round 16: 95.50%
Round 17
Test Accuracy after Round 17: 95.01%
Round 18
Test Accuracy after Round 18: 95.43%
Round 19
Test Accuracy after Round 19: 95.61%
Round 20
Test Accuracy after Round 20: 95.11%


In [16]:
# Here, I added a proximal term to the local training loss to penalize deviations from the global model weights. Well seems to be funny and accurate in long runs but won't be time-effeicient !

def train_local_model_fedprox(client_id, model, global_weights, dataset, epochs=1, lr=0.1, mu=0.01):

    model.train()
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for images, labels in dataloader:
            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            prox_term = 0.0
            for param, global_param in zip(model.parameters(), global_weights.values()):
                prox_term += torch.norm(param - global_param) ** 2

            loss += (mu / 2) * prox_term

            loss.backward()
            optimizer.step()

    return model.state_dict()

In [17]:
def federated_learning_fedprox(num_rounds=20, num_clients=5, mu=0.01):
    global_model = SimpleNN()
    client_models = [SimpleNN() for _ in range(num_clients)]

    for round_num in range(num_rounds):
        print(f"Round {round_num + 1}")
        client_weights = []

        for client_id in range(num_clients):
            client_model = client_models[client_id]
            client_model.load_state_dict(global_model.state_dict())
            client_dataset = client_datasets[client_id]

            trained_weights = train_local_model_fedprox(
                client_id, client_model, global_model.state_dict(), client_dataset, epochs=1, mu=mu
            )
            client_weights.append(trained_weights)

        global_weights = fed_avg(client_weights)
        global_model.load_state_dict(global_weights)

        test_accuracy = evaluate_global_model(global_model)
        print(f"Test Accuracy after Round {round_num + 1}: {test_accuracy:.2f}%")


In [18]:
federated_learning_fedprox()

Round 1
Test Accuracy after Round 1: 72.90%
Round 2
Test Accuracy after Round 2: 83.00%
Round 3
Test Accuracy after Round 3: 88.88%
Round 4
Test Accuracy after Round 4: 90.66%
Round 5
Test Accuracy after Round 5: 91.87%
Round 6
Test Accuracy after Round 6: 93.39%
Round 7
Test Accuracy after Round 7: 93.89%
Round 8
Test Accuracy after Round 8: 94.05%
Round 9
Test Accuracy after Round 9: 94.71%
Round 10
Test Accuracy after Round 10: 94.93%
Round 11
Test Accuracy after Round 11: 94.88%
Round 12
Test Accuracy after Round 12: 95.07%
Round 13
Test Accuracy after Round 13: 94.63%
Round 14
Test Accuracy after Round 14: 95.41%
Round 15
Test Accuracy after Round 15: 95.04%
Round 16
Test Accuracy after Round 16: 95.53%
Round 17
Test Accuracy after Round 17: 95.70%
Round 18
Test Accuracy after Round 18: 95.67%
Round 19
Test Accuracy after Round 19: 95.46%
Round 20
Test Accuracy after Round 20: 95.81%
