In [None]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# --------- Minimal MLP Model ----------
class MinimalMLPModel(nn.Module):
    """Two-layer perceptron mapping 28x28 inputs to 10 class logits.

    All layers live in ``self.net`` (Flatten -> Linear(784, 64) -> ReLU ->
    Linear(64, 10)), so the learnable state-dict keys are ``net.1.*`` and
    ``net.3.*``.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 64
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` (all dims after the first) and return raw class scores."""
        logits = self.net(x)
        return logits

# --------- Federated Training Utilities ----------
def train(model, loader, optimizer, criterion, epochs=1):
    """Run ``epochs`` passes of mini-batch optimisation over ``loader``.

    Switches ``model`` to training mode. For every ``(inputs, targets)`` batch
    it clears gradients, back-propagates ``criterion(model(inputs), targets)``
    and applies one ``optimizer`` step. ``model`` is updated in place;
    nothing is returned.
    """
    model.train()
    for _epoch in range(epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            criterion(outputs, targets).backward()
            optimizer.step()

def evaluate(model, loader):
    """Return top-1 accuracy of ``model`` over all ``(inputs, labels)`` batches.

    Leaves ``model`` in eval mode and runs without autograd. Raises
    ZeroDivisionError if ``loader`` yields no samples.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            predicted = torch.argmax(model(inputs), dim=1)
            hits += int((predicted == labels).sum())
            seen += labels.shape[0]
    return hits / seen

def similarity_score(state1, state2):
    """Return ``1 - sum_k ||state1[k] - state2[k]|| / sum_k ||state2[k]||``.

    Each term is the L2 norm of the flattened tensor, and ``state2`` is the
    reference. A result of 1.0 means identical states. Lower values mean
    ``state1`` has drifted further from ``state2``; the result can go
    negative.

    Integer/bool tensors (e.g. counter buffers) are cast to float32, because
    norms and subtraction are not defined for them.

    If the reference norm is zero, the ratio is undefined. The function then
    returns 1.0 when the states are identical, and ``-inf`` otherwise (the
    limit of the formula).

    Raises:
        KeyError: if ``state2`` lacks a key present in ``state1``.
    """
    def _as_float(t):
        return t if (t.is_floating_point() or t.is_complex()) else t.to(torch.float32)

    norm_diff = 0.0
    norm_ref = 0.0
    for key in state1:
        current = _as_float(state1[key])
        reference = _as_float(state2[key])
        norm_diff += torch.linalg.vector_norm(current - reference).item()
        norm_ref += torch.linalg.vector_norm(reference).item()
    if norm_ref == 0.0:
        # Avoid ZeroDivisionError for an all-zero (or empty) reference state.
        return 1.0 if norm_diff == 0.0 else float('-inf')
    return 1 - (norm_diff / norm_ref)

# --------- INT8 Quantization Utilities ----------
def quantize_to_int8(tensor):
    """Symmetrically quantize ``tensor`` to int8 using one per-tensor scale.

    ``scale = max(|tensor|) / 127`` and
    ``q = clamp(round(tensor / scale), -128, 127)``, so ``q * scale``
    approximates ``tensor``.

    Returns:
        ``(q_tensor, scale)``. ``q_tensor`` is int8 with the shape of
        ``tensor``, and ``scale`` is a 0-dim tensor. An all-zero or empty
        input yields an all-zero ``q_tensor`` and a zero scale, so
        dequantizing reproduces zeros.
    """
    if tensor.numel() == 0:
        # abs().max() raises on zero-element tensors; there is nothing to scale.
        return torch.zeros_like(tensor, dtype=torch.int8), torch.tensor(0.0)
    scale = tensor.abs().max() / 127.0
    if scale == 0:
        # All-zero input: avoid 0/0 = NaN in the division below.
        return torch.zeros_like(tensor, dtype=torch.int8), scale
    q_tensor = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
    return q_tensor, scale

def dequantize_from_int8(q_tensor, scale):
    """Convert int8 codes back to float32 values by multiplying with ``scale``."""
    as_float = q_tensor.float()
    return as_float * scale

def quantize_state(state_dict):
    """Quantize every tensor of ``state_dict``.

    Returns a new dict with the same keys, whose values are the
    ``(int8_tensor, scale)`` pairs produced by ``quantize_to_int8``.
    """
    quantized = {}
    for name, tensor in state_dict.items():
        quantized[name] = quantize_to_int8(tensor)
    return quantized

def dequantize_state(quantized_state):
    """Rebuild a float32 tensor dict from ``{key: (int8_tensor, scale)}`` pairs."""
    restored = {}
    for name, pair in quantized_state.items():
        restored[name] = dequantize_from_int8(pair[0], pair[1])
    return restored

# --------- Main Federated Learning Process ----------
# Local optimizer factories keyed by the ``optimizer_name`` values accepted by
# federated_training; each client builds a fresh optimizer every round.
_OPTIMIZER_FACTORIES = {
    'adam': lambda params: optim.Adam(params, lr=0.001),
    'sgd': lambda params: optim.SGD(params, lr=0.001, momentum=0.0),
    'nag': lambda params: optim.SGD(params, lr=0.001, momentum=0.9, nesterov=True),
    'mgd': lambda params: optim.SGD(params, lr=0.001, momentum=0.99),
}


def _even_split_sizes(total, parts):
    """Return ``parts`` integer sizes that sum to ``total`` and differ by at most one."""
    base, extra = divmod(total, parts)
    return [base + 1 if i < extra else base for i in range(parts)]


def _average_quantized_states(quantized_states, keys):
    """Dequantize each uploaded state once and return the element-wise mean per key."""
    dequantized = [dequantize_state(qs) for qs in quantized_states]
    return {
        key: torch.stack([state[key] for state in dequantized], dim=0).mean(dim=0)
        for key in keys
    }


def federated_training(
        optimizer_name='adam',
        num_clients=10,
        max_batch_size=100,
        rounds=10,
        bandwidth_mb_per_s=20
    ):
    """Simulate federated training of MinimalMLPModel on MNIST with INT8 uploads.

    The MNIST training set is split across ``num_clients`` clients, and each
    client gets a random "power" in [0, 1).

    Each round:
      * Every client copies the global weights and trains one local epoch.
        Its batch size scales with its power relative to the strongest
        client.
      * Clients whose ``similarity_score`` to the global model exceeds the
        round's threshold upload an INT8-quantized state.
      * The server sets the global model to the mean of the dequantized
        uploads and evaluates it on the MNIST test set.

    Args:
        optimizer_name: one of 'adam', 'sgd', 'nag', 'mgd' (local optimizer).
        num_clients: number of simulated clients.
        max_batch_size: batch size used by the strongest client.
        rounds: number of communication rounds.
        bandwidth_mb_per_s: bandwidth used to simulate upload time.

    Returns:
        A dict of per-round metrics:
          * 'rounds', 'global_accs', 'comm_volumes' (MB), 'net_times' (s)
            and 'agg_times' (s) have one entry per round.
          * 'conv_speeds' has one entry per round after the first: the change
            in accuracy from the previous round.

    Raises:
        ValueError: if ``optimizer_name`` is not supported.
    """
    # Fail fast, before any dataset download, on an unsupported optimizer.
    if optimizer_name not in _OPTIMIZER_FACTORIES:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")
    make_optimizer = _OPTIMIZER_FACTORIES[optimizer_name]

    client_powers = torch.rand(num_clients)
    max_power = client_powers.max()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    # Spread any remainder over the first clients; random_split requires the
    # lengths to sum exactly to len(full_dataset).
    client_data = random_split(full_dataset, _even_split_sizes(len(full_dataset), num_clients))

    # The test set does not change between rounds, so load it once.
    test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(test_ds, batch_size=64)

    global_model = MinimalMLPModel()
    criterion = nn.CrossEntropyLoss()

    global_accs, comm_volumes, agg_times, net_times, conv_speeds = [], [], [], [], []
    prev_global_acc = 0.0

    total_params = sum(p.numel() for p in global_model.parameters())
    print(f"\n=== Running Federated Training with {optimizer_name.upper()} (INT8 Quantization) ===")
    for rnd in range(1, rounds + 1):
        print(f"\n=== Round {rnd} ===")
        selected_states = []

        # Round 1 uses a looser inclusion threshold than later rounds.
        inclusion_threshold = 0.20 if rnd == 1 else 0.65
        # The global model only changes at aggregation time, so one snapshot
        # serves as both the starting point and the similarity reference.
        global_state = global_model.state_dict()

        for i in range(num_clients):
            ratio = (client_powers[i] / max_power).item()
            batch_size = max(1, int(ratio * max_batch_size))
            print(f"Client {i}: power={client_powers[i]:.3f} → batch_size={batch_size}")

            local_model = MinimalMLPModel()
            local_model.load_state_dict(global_state)
            optimizer = make_optimizer(local_model.parameters())

            train_loader = DataLoader(client_data[i], batch_size=batch_size, shuffle=True)
            train(local_model, train_loader, optimizer, criterion, epochs=1)

            local_state = local_model.state_dict()
            sim = similarity_score(local_state, global_state)
            print(f"  → similarity with global model: {sim:.2%}", end='')

            if sim > inclusion_threshold:
                print(" ✅ included")
                selected_states.append(quantize_state(local_state))
            else:
                print(" ❌ excluded")

        sent_clients = len(selected_states)
        # 1 byte per parameter (int8); the small per-tensor float scales are not counted.
        comm_mb = sent_clients * total_params * 1 / 1e6
        net_time = comm_mb / bandwidth_mb_per_s

        start = time.perf_counter()
        if selected_states:
            global_model.load_state_dict(
                _average_quantized_states(selected_states, global_state.keys())
            )
        agg_time = time.perf_counter() - start
        if not selected_states:
            # torch.stack cannot average an empty list; keep the previous model.
            print("\nNo clients selected; global model left unchanged.")

        global_acc = evaluate(global_model, test_loader)

        global_accs.append(global_acc)
        comm_volumes.append(comm_mb)
        net_times.append(net_time)
        agg_times.append(agg_time)
        if rnd > 1:
            conv_speeds.append(global_acc - prev_global_acc)
        prev_global_acc = global_acc

        print(f"\nGlobal Accuracy         : {global_acc:.2%}")
        print(f"Selected clients        : {sent_clients}/{num_clients}")
        print(f"Comm Volume             : {comm_mb:.2f} MB")
        print(f"Simulated Net Time      : {net_time:.2f} s")
        print(f"Server Aggregation Time : {agg_time:.4f} s")

    return {
        'rounds': list(range(1, rounds + 1)),
        'global_accs': global_accs,
        'comm_volumes': comm_volumes,
        'net_times': net_times,
        'agg_times': agg_times,
        'conv_speeds': conv_speeds,
    }



if __name__ == '__main__':
    # Run the identical federated experiment once per local optimizer.
    for opt in ('sgd', 'mgd', 'nag', 'adam'):
        federated_training(optimizer_name=opt, num_clients=10, max_batch_size=100,
                           rounds=10, bandwidth_mb_per_s=20)



=== Running Federated Training with SGD (INT8 Quantization) ===

=== Round 1 ===
Client 0: power=0.054 → batch_size=6
  → similarity with global model: 78.15% ✅ included
Client 1: power=0.549 → batch_size=61
  → similarity with global model: 97.72% ✅ included
Client 2: power=0.434 → batch_size=48
  → similarity with global model: 97.14% ✅ included
Client 3: power=0.370 → batch_size=41
  → similarity with global model: 96.65% ✅ included
Client 4: power=0.552 → batch_size=62
  → similarity with global model: 97.76% ✅ included
Client 5: power=0.270 → batch_size=30
  → similarity with global model: 95.47% ✅ included
Client 6: power=0.545 → batch_size=61
  → similarity with global model: 97.71% ✅ included
Client 7: power=0.292 → batch_size=32
  → similarity with global model: 95.80% ✅ included
Client 8: power=0.268 → batch_size=30
  → similarity with global model: 95.49% ✅ included
Client 9: power=0.887 → batch_size=100
  → similarity with global model: 98.61% ✅ included

Global Accuracy  