In [None]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# --------- Minimal MLP Model ----------
class MinimalMLPModel(nn.Module):
    """Small two-layer perceptron mapping 28x28 inputs to 10 class logits.

    The input is flattened to 784 features (``nn.Flatten`` keeps dim 0 as the
    batch dimension), fed through a 64-unit hidden layer with ReLU, and
    projected to 10 unnormalized scores.
    """

    def __init__(self):
        super().__init__()
        in_features, hidden_units, num_classes = 28 * 28, 64, 10
        layers = [
            nn.Flatten(),
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        ]
        # The container is named ``net`` with this exact layer order so that
        # state_dict keys stay ``net.1.*`` / ``net.3.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits for ``x`` (each sample must hold 784 values)."""
        return self.net(x)

# --------- Sparsification Utility ----------
def sparsify_state_dict(state_dict, compression_ratio=0.1):
    """Keep only the largest-magnitude entries of every tensor in a state dict.

    For each tensor, ``int(numel * compression_ratio)`` entries with the
    largest absolute value are kept and every other entry is set to zero.
    A tensor too small to keep even one entry becomes all zeros. The input
    tensors are not modified.

    Args:
        state_dict: mapping of name -> tensor (e.g. ``module.state_dict()``).
        compression_ratio: fraction of entries to KEEP per tensor, in [0, 1].

    Returns:
        A new dict with the same keys, each tensor having its original shape.

    Raises:
        ValueError: if ``compression_ratio`` is outside [0, 1] (or NaN).
    """
    if not 0.0 <= compression_ratio <= 1.0:
        raise ValueError(
            f"compression_ratio must be in [0, 1], got {compression_ratio}"
        )
    sparse_state = {}
    for k, v in state_dict.items():
        # reshape (unlike view) also works for non-contiguous tensors.
        flat = v.reshape(-1)
        k_val = int(flat.numel() * compression_ratio)
        if k_val == 0:
            sparse_state[k] = torch.zeros_like(v)
            continue
        _, idxs = torch.topk(flat.abs(), k_val)
        # Scatter the kept values into zeros rather than multiplying by a 0/1
        # mask: ``inf * 0`` would leave NaN in the dropped positions.
        sparse_flat = torch.zeros_like(flat)
        sparse_flat[idxs] = flat[idxs]
        sparse_state[k] = sparse_flat.view(v.shape)
    return sparse_state

# --------- Federated Training Utilities ----------
def train(model, loader, optimizer, criterion, epochs=1):
    """Optimize ``model`` in place for ``epochs`` passes over ``loader``.

    Args:
        model: module to train; it is switched to training mode first.
        loader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer already bound to ``model``'s parameters.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        epochs: number of full passes over ``loader``.
    """
    model.train()
    for _epoch in range(epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Args:
        model: classifier whose output is argmax-ed over dim 1; it is switched
            to eval mode and left in that mode.
        loader: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined).
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    if total == 0:
        raise ValueError("evaluate() got an empty loader; accuracy is undefined")
    return correct / total

def similarity_score(state1, state2):
    """Return a relative-closeness score of ``state1`` to reference ``state2``.

    Computes ``1 - sum_k ||state1[k] - state2[k]|| / sum_k ||state2[k]||``,
    where ``||.||`` is ``torch.norm`` (L2 over all elements) taken per tensor
    and then summed across keys. This is a sum of per-tensor norms, not one
    global L2 norm. Keys are taken from ``state1``; each must exist in
    ``state2``.

    Returns:
        1.0 for identical states; smaller (possibly negative) values as the
        states diverge. If the reference norms sum to zero, returns 1.0 when
        the difference is also zero and ``-inf`` otherwise, instead of
        dividing by zero.
    """
    norm_diff = 0.0
    norm_ref = 0.0
    for k in state1:
        norm_diff += torch.norm(state1[k] - state2[k]).item()
        norm_ref += torch.norm(state2[k]).item()
    if norm_ref == 0.0:
        return 1.0 if norm_diff == 0.0 else float("-inf")
    return 1 - (norm_diff / norm_ref)

# --------- Main Federated Learning Process ----------
_OPTIMIZER_NAMES = ('adam', 'sgd', 'nag', 'mgd')


def _split_lengths(total, parts):
    """Return ``parts`` near-equal lengths that sum exactly to ``total``."""
    base, extra = divmod(total, parts)
    return [base + 1 if i < extra else base for i in range(parts)]


def _make_optimizer(optimizer_name, params):
    """Build the local optimizer named ``optimizer_name`` over ``params``."""
    if optimizer_name == 'adam':
        return optim.Adam(params, lr=0.001)
    if optimizer_name == 'sgd':
        return optim.SGD(params, lr=0.001, momentum=0.0)
    if optimizer_name == 'nag':
        return optim.SGD(params, lr=0.001, momentum=0.9, nesterov=True)
    if optimizer_name == 'mgd':
        return optim.SGD(params, lr=0.001, momentum=0.99)
    raise ValueError(f"Unknown optimizer: {optimizer_name}")


def _average_states(states, keys):
    """Element-wise mean of a non-empty list of state dicts, for each of ``keys``."""
    return {
        key: torch.stack([state[key] for state in states], dim=0).mean(dim=0)
        for key in keys
    }


def federated_training(
        optimizer_name='adam',
        num_clients=10,
        max_batch_size=100,
        rounds=10,
        bandwidth_mb_per_s=20,
        compression_ratio=0.4
    ):
    """Simulate ``rounds`` of federated averaging on MNIST and print metrics.

    Each client gets a random "power" drawn from ``torch.rand``. Its local
    batch size is ``power / max_power * max_batch_size`` (at least 1). In
    every round each client trains one epoch starting from the current
    global weights. A client is included only if its ``similarity_score`` to
    the global model exceeds the round's threshold (0.20 in round 1, 0.65
    afterwards). Included clients' weights are passed through
    ``sparsify_state_dict`` and averaged element-wise into the new global
    model.

    Args:
        optimizer_name: local optimizer, one of 'adam', 'sgd', 'nag', 'mgd'.
        num_clients: number of simulated clients (>= 1). The MNIST training
            set is split into that many near-equal random shards.
        max_batch_size: batch size used by the most powerful client.
        rounds: number of communication rounds.
        bandwidth_mb_per_s: simulated uplink bandwidth in MB/s.
        compression_ratio: fraction of each tensor's entries that is kept.

    Returns:
        dict of per-round lists: 'rounds', 'global_accs', 'comm_volumes',
        'net_times', 'agg_times', and 'conv_speeds' (accuracy change versus
        the previous round, starting from round 2).

    Raises:
        ValueError: for an unknown ``optimizer_name`` or ``num_clients < 1``.
    """
    # Fail fast on a bad configuration, before any dataset download.
    if optimizer_name not in _OPTIMIZER_NAMES:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    client_powers = torch.rand(num_clients)
    max_power = client_powers.max()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    # Equal shards of len // num_clients only sum to len(full_dataset) when it
    # divides evenly; spread the remainder so random_split always accepts it.
    client_data = random_split(full_dataset, _split_lengths(len(full_dataset), num_clients))
    # The test set is the same every round, so build its loader once.
    test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(test_ds, batch_size=64)

    global_model = MinimalMLPModel()
    criterion = nn.CrossEntropyLoss()

    global_accs, comm_volumes, agg_times, net_times, conv_speeds = [], [], [], [], []
    prev_global_acc = 0.0

    total_params = sum(p.numel() for p in global_model.parameters())

    print(f"\n=== Running Federated Training with {optimizer_name.upper()} ===")
    for rnd in range(1, rounds + 1):
        print(f"\n=== Round {rnd} ===")
        selected_states = []
        inclusion_threshold = 0.65 if rnd > 1 else 0.20

        for i in range(num_clients):
            ratio = (client_powers[i] / max_power).item()
            batch_size = max(1, int(ratio * max_batch_size))
            print(f"Client {i}: power={client_powers[i]:.3f} → batch_size={batch_size}")

            local_model = MinimalMLPModel()
            local_model.load_state_dict(global_model.state_dict())
            optimizer = _make_optimizer(optimizer_name, local_model.parameters())

            train_loader = DataLoader(client_data[i], batch_size=batch_size, shuffle=True)
            train(local_model, train_loader, optimizer, criterion, epochs=1)

            sim = similarity_score(local_model.state_dict(), global_model.state_dict())
            print(f"  → similarity with global model: {sim:.2%}", end='')

            if sim > inclusion_threshold:
                print(" ✅ included")
                sparse_state = sparsify_state_dict(local_model.state_dict(), compression_ratio)
                selected_states.append(sparse_state)
            else:
                print(" ❌ excluded")

        sent_clients = len(selected_states)
        # Estimate: kept values only at 4 bytes each (assumes float32 params).
        # The index overhead of a real sparse encoding is not counted.
        comm_mb = sent_clients * total_params * compression_ratio * 4 / 1e6
        net_time = comm_mb / bandwidth_mb_per_s

        start = time.perf_counter()
        if selected_states:
            new_state = _average_states(selected_states, global_model.state_dict().keys())
            global_model.load_state_dict(new_state)
        agg_time = time.perf_counter() - start
        if not selected_states:
            # torch.stack([]) would raise; keep the previous global weights.
            print("No clients passed the similarity gate; global model unchanged.")

        global_acc = evaluate(global_model, test_loader)

        global_accs.append(global_acc)
        comm_volumes.append(comm_mb)
        net_times.append(net_time)
        agg_times.append(agg_time)
        if rnd > 1:
            conv_speeds.append(global_acc - prev_global_acc)
        prev_global_acc = global_acc

        print(f"\nGlobal Accuracy         : {global_acc:.2%}")
        print(f"Selected clients        : {sent_clients}/{num_clients}")
        print(f"Comm Volume             : {comm_mb:.2f} MB")
        print(f"Simulated Net Time      : {net_time:.2f} s")
        print(f"Server Aggregation Time : {agg_time:.4f} s")

    return {
        'rounds': list(range(1, rounds + 1)),
        'global_accs': global_accs,
        'comm_volumes': comm_volumes,
        'net_times': net_times,
        'agg_times': agg_times,
        'conv_speeds': conv_speeds,
    }


if __name__ == '__main__':
    # Run the same federated setup once per local optimizer for comparison.
    # Client powers and data shards are re-drawn randomly on every call.
    for opt in ['sgd', 'nag', 'mgd', 'adam']:
        federated_training(
            optimizer_name=opt,
            num_clients=10,
            max_batch_size=100,
            rounds=10,
            bandwidth_mb_per_s=20,
            compression_ratio=0.4,  # keep top 40% of entries -> 60% sparsity
        )



=== Running Federated Training with SGD ===

=== Round 1 ===
Client 0: power=0.341 → batch_size=36
  → similarity with global model: 96.07% ✅ included
Client 1: power=0.004 → batch_size=1
  → similarity with global model: 48.82% ✅ included
Client 2: power=0.008 → batch_size=1
  → similarity with global model: 48.54% ✅ included
Client 3: power=0.117 → batch_size=12
  → similarity with global model: 88.45% ✅ included
Client 4: power=0.935 → batch_size=100
  → similarity with global model: 98.55% ✅ included
Client 5: power=0.635 → batch_size=67
  → similarity with global model: 97.83% ✅ included
Client 6: power=0.443 → batch_size=47
  → similarity with global model: 96.97% ✅ included
Client 7: power=0.245 → batch_size=26
  → similarity with global model: 94.66% ✅ included
Client 8: power=0.232 → batch_size=24
  → similarity with global model: 94.17% ✅ included
Client 9: power=0.331 → batch_size=35
  → similarity with global model: 95.94% ✅ included

Global Accuracy         : 75.02%
Selec