In [None]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# --------- Minimal MLP Model ----------
class ShallowMLP(nn.Module):
    """Minimal two-layer perceptron with 10 output logits.

    Inputs are flattened from dim 1 onward, so each sample must contain
    exactly 28 * 28 = 784 values (e.g. a (N, 1, 28, 28) image batch).
    Architecture: Flatten -> Linear(784, 64) -> ReLU -> Linear(64, 10).
    """

    def __init__(self):
        super().__init__()
        # Kept as a single ``nn.Sequential`` named ``net`` so that state_dict
        # keys stay ``net.1.*`` / ``net.3.*``.
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalised) class scores of shape (N, 10)."""
        logits = self.net(x)
        return logits

# --------- Federated Training Utilities ----------
def train(model, loader, optimizer, criterion, epochs=1):
    """Update ``model`` in place by iterating over ``loader`` ``epochs`` times.

    Each batch performs one zero-grad / forward / loss / backward / step
    cycle. Nothing is returned.

    Args:
        model: module to train; it is switched to training mode first.
        loader: iterable of ``(inputs, targets)`` batches, re-iterated
            once per epoch.
        optimizer: optimizer whose ``step()`` updates the parameters being
            trained (presumably built over ``model.parameters()``).
        criterion: callable ``criterion(outputs, targets)`` returning a
            scalar loss tensor.
        epochs: number of full passes over ``loader``.
    """
    def run_batch(inputs, targets):
        # One SGD-style step on a single mini-batch.
        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()

    model.train()
    for _epoch in range(epochs):
        for inputs, targets in loader:
            run_batch(inputs, targets)

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is put in eval mode (and left that way on return) and autograd
    is disabled while predicting.

    Args:
        model: module returning per-class scores; the predicted class is the
            argmax over dim 1.
        loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        float: fraction of samples whose predicted class equals the target.

    Raises:
        ValueError: if ``loader`` yields no samples, since accuracy is
            undefined (previously this surfaced as a bare ZeroDivisionError).
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    if total == 0:
        raise ValueError("evaluate() needs a loader that yields at least one sample")
    return correct / total

# --------- Main Federated Learning Process ----------
# Local optimizer recipes: name -> (optimizer class, constructor kwargs).
_LOCAL_OPTIMIZERS = {
    'sgd': (optim.SGD, {'lr': 0.01, 'momentum': 0.0}),
    # NOTE(review): 0.99 is much higher than the 0.9 used for 'nag';
    # confirm this momentum value is intended.
    'mgd': (optim.SGD, {'lr': 0.01, 'momentum': 0.99}),
    'nag': (optim.SGD, {'lr': 0.01, 'momentum': 0.9, 'nesterov': True}),
    'adam': (optim.Adam, {'lr': 0.001}),
}


def _client_split_sizes(total, num_clients):
    """Split ``total`` samples into ``num_clients`` sizes that differ by at most one."""
    base, extra = divmod(total, num_clients)
    return [base + 1 if i < extra else base for i in range(num_clients)]


def _fedavg(states, weights):
    """Return the ``weights``-weighted average of the state dicts in ``states``."""
    total = float(sum(weights))
    averaged = {}
    for key, ref in states[0].items():
        # Accumulate in float64, then cast back to the original dtype.
        coeffs = torch.tensor([w / total for w in weights], dtype=torch.float64)
        coeffs = coeffs.view(-1, *([1] * ref.dim()))
        stacked = torch.stack([st[key].to(torch.float64) for st in states], dim=0)
        averaged[key] = (stacked * coeffs).sum(dim=0).to(ref.dtype)
    return averaged


def federated_training(optimizer_name,
                       num_clients=10,
                       rounds=10,
                       batch_size=64,
                       bandwidth_mb_s=20,
                       seed=None):
    """Simulate FedAvg on MNIST with one local optimizer and print per-round metrics.

    The MNIST training set is split into ``num_clients`` random, disjoint
    shards. Each round, every client copies the global ``ShallowMLP``, trains
    it for one epoch on its shard, and the server replaces the global
    weights with the shard-size-weighted average of the client weights.
    After each round the global model is evaluated on the MNIST test set.

    Args:
        optimizer_name: one of ``'sgd'``, ``'mgd'``, ``'nag'``, ``'adam'``.
        num_clients: number of simulated clients (>= 1).
        rounds: number of communication rounds.
        batch_size: local training batch size.
        bandwidth_mb_s: assumed uplink bandwidth in MB/s, used only to
            report a simulated communication time.
        seed: if given, seeds torch's global RNG so that the data split,
            model initialisation and shuffling are reproducible across runs.

    Returns:
        list[float]: global test accuracy after each round.

    Raises:
        ValueError: for an unknown ``optimizer_name``, ``num_clients`` < 1 or
            larger than the training set, or non-positive ``bandwidth_mb_s``.
    """
    # Fail fast, before any dataset download.
    if optimizer_name not in _LOCAL_OPTIMIZERS:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if bandwidth_mb_s <= 0:
        raise ValueError(f"bandwidth_mb_s must be > 0, got {bandwidth_mb_s}")
    if seed is not None:
        torch.manual_seed(seed)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    full_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    if num_clients > len(full_ds):
        raise ValueError(f"num_clients ({num_clients}) exceeds dataset size ({len(full_ds)})")
    # Spread the remainder so lengths always sum to len(full_ds); the old
    # [len // k] * k split made random_split fail for non-divisible k.
    client_sizes = _client_split_sizes(len(full_ds), num_clients)
    client_ds = random_split(full_ds, client_sizes)

    # The test set does not change between rounds; build it once.
    test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(test_ds, batch_size=64)

    global_model = ShallowMLP()
    criterion = nn.CrossEntropyLoss()
    opt_cls, opt_kwargs = _LOCAL_OPTIMIZERS[optimizer_name]

    # Simulated communication (uplink only): each client sends one full set
    # of model parameters per round. Constant across rounds, so compute once.
    bytes_per_model = sum(p.numel() * p.element_size() for p in global_model.parameters())
    total_mb = num_clients * bytes_per_model / 1e6
    sim_comm_time = total_mb / bandwidth_mb_s

    accuracies = []
    print(f"\n=== Running Federated Training with {optimizer_name.upper()} ===")
    for rnd in range(1, rounds + 1):
        print(f"\n--- Round {rnd} ---")
        local_states = []

        for subset in client_ds:
            local_model = ShallowMLP()
            local_model.load_state_dict(global_model.state_dict())
            optimizer = opt_cls(local_model.parameters(), **opt_kwargs)
            loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
            train(local_model, loader, optimizer, criterion, epochs=1)
            local_states.append(local_model.state_dict())

        # Aggregation (FedAvg, weighted by client shard size)
        start = time.perf_counter()
        global_model.load_state_dict(_fedavg(local_states, client_sizes))
        agg_time = time.perf_counter() - start

        acc = evaluate(global_model, test_loader)
        accuracies.append(acc)

        print(f"Accuracy               : {acc:.2%}")
        print(f"Comm. Volume           : {total_mb:.2f} MB")
        print(f"Simulated Comm. Time   : {sim_comm_time:.2f} s")
        print(f"Aggregation Time       : {agg_time:.4f} s")

    return accuracies

# --------- Run Experiment ----------
if __name__ == '__main__':
    # Run the same federated setup once per local optimizer so their
    # per-round accuracy can be compared side by side.
    shared_settings = dict(num_clients=10,
                           rounds=10,
                           batch_size=64,
                           bandwidth_mb_s=20)
    for name in ('sgd', 'mgd', 'nag', 'adam'):
        federated_training(optimizer_name=name, **shared_settings)


100%|██████████| 9.91M/9.91M [00:00<00:00, 58.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.65MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 12.3MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.59MB/s]



=== Running Federated Training with SGD ===

--- Round 1 ---
Accuracy               : 80.33%
Comm. Volume           : 2.04 MB
Simulated Comm. Time   : 0.10 s
Aggregation Time       : 0.0021 s

--- Round 2 ---
Accuracy               : 86.08%
Comm. Volume           : 2.04 MB
Simulated Comm. Time   : 0.10 s
Aggregation Time       : 0.0012 s

--- Round 3 ---
Accuracy               : 87.93%
Comm. Volume           : 2.04 MB
Simulated Comm. Time   : 0.10 s
Aggregation Time       : 0.0014 s

--- Round 4 ---
Accuracy               : 89.04%
Comm. Volume           : 2.04 MB
Simulated Comm. Time   : 0.10 s
Aggregation Time       : 0.0010 s

--- Round 5 ---
Accuracy               : 89.63%
Comm. Volume           : 2.04 MB
Simulated Comm. Time   : 0.10 s
Aggregation Time       : 0.0011 s

--- Round 6 ---
Accuracy               : 90.02%
Comm. Volume           : 2.04 MB
Simulated Comm. Time   : 0.10 s
Aggregation Time       : 0.0010 s

--- Round 7 ---
Accuracy               : 90.41%
Comm. Volume      