<a href="https://colab.research.google.com/github/alansshots/federated-learning-institute-project/blob/main/Straggler_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install torchmetrics


Collecting torchmetrics
  Downloading torchmetrics-1.7.3-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->torchmetrics)
  D

In [3]:
!pip install torch torchvision



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, F1Score
from torchvision import datasets, transforms
from copy import deepcopy
import numpy as np
import torchvision.models as models
from tqdm.autonotebook import tqdm
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import time
import random

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook draws from so that a fresh "Run All" is
# repeatable: torch, NumPy (used to shuffle the client data split) and the
# stdlib `random` module (used to pick and size the simulated stragglers).
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
# Request deterministic cuDNN kernels and turn off the cuDNN autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def resnet18(num_classes, in_channels=1, **kwargs):
    """Build a torchvision ResNet-18 with its first conv, max-pool and head replaced.

    Args:
        num_classes: Number of outputs of the final linear layer.
        in_channels: Number of channels in the input images. Defaults to 1
            (grayscale input, as used for MNIST in this notebook).
        **kwargs: Forwarded unchanged to ``models.resnet18``.

    Returns:
        The modified model.
    """
    backbone = models.resnet18(**kwargs)
    # 3x3, stride-1, padded first convolution taking `in_channels` input planes.
    backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # Turn the max-pool stage into a no-op.
    backbone.maxpool = nn.Identity()
    # Size the new head from the layer it replaces rather than hard-coding 512.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone

def uniform_allocation(Y, num_clients):
    """Randomly deal the positions ``0 .. len(Y)-1`` into ``num_clients`` shares.

    Only the length of ``Y`` is used. The positions are shuffled in place with
    NumPy's global RNG and then cut into ``num_clients`` consecutive chunks by
    ``np.array_split``.

    Returns:
        A list with one list of indices per client.
    """
    shuffled_positions = np.arange(len(Y))
    np.random.shuffle(shuffled_positions)

    shares = []
    for chunk in np.array_split(shuffled_positions, num_clients):
        shares.append(list(chunk))
    return shares

# Hyperparameters
num_clients = 5           # number of simulated federated clients
batch_size = 32
global_epochs = 2         # number of federated rounds
local_epochs = 1          # passes over a client's data per round
learning_rate = 1e-2
loss_fn = nn.CrossEntropyLoss()
timeout_seconds = 3       # per-client time budget within one round
num_train_samples = 5000  # how many leading MNIST training samples to keep

# Model: one template network, plus an independent copy and SGD optimizer per client.
model = resnet18(10)
client_models = [deepcopy(model).to(device) for _ in range(num_clients)]
client_optims = [optim.SGD(cm.parameters(), lr=learning_rate) for cm in client_models]

# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    # Presumably the MNIST training-set mean / std -- TODO confirm.
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
# Keep only the first `num_train_samples` training examples; images and labels
# are truncated together so they stay aligned.
train_dataset.data = train_dataset.data[:num_train_samples]
train_dataset.targets = train_dataset.targets[:num_train_samples]

# Shuffle the sample indices and split them across the clients, then wrap each
# index list in a Subset. The index lists get their own name so that
# `train_subsets` only ever refers to Subset objects.
client_indices = uniform_allocation(train_dataset.targets, num_clients)
train_subsets = [torch.utils.data.Subset(train_dataset, indices) for indices in client_indices]
train_subset_dataloaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in train_subsets]
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

def fed_avg(models):
    """Return a new model whose floating-point state is the unweighted mean of ``models``.

    Args:
        models: Non-empty sequence of modules whose state dicts share the same
            keys and tensor shapes. An empty sequence raises ``IndexError``.

    Returns:
        A deep copy of ``models[0]`` moved to the module-level ``device``.
        Every floating-point state entry holds the element-wise average over
        all models; non-floating-point entries are taken from ``models[0]``.
        The input models are not modified.
    """
    # Build each state dict once; previously state_dict() was called again for
    # every (key, model) pair inside the nested loops.
    client_states = [m.state_dict() for m in models]
    avg_model = deepcopy(models[0]).to(device)
    with torch.no_grad():
        for key, target in avg_model.state_dict().items():
            if not target.dtype.is_floating_point:
                # Non-float entries are not averaged; keep the first model's value.
                target.copy_(client_states[0][key])
                continue
            total = torch.zeros_like(target)
            for state in client_states:
                # Move onto the accumulator's device so clients living on a
                # different device can still be summed.
                total += state[key].to(total.device)
            target.copy_(total / len(client_states))
    return avg_model

# The global model starts as a copy of the template network; the three lists
# collect one test-set measurement per federated round.
global_model = deepcopy(model).to(device)
acc_list, f1_list, loss_list = [], [], []

# Named `round_idx` so the builtin `round` is not shadowed (and left rebound
# to an int once the loop has finished).
for round_idx in range(global_epochs):
    print(f"\n--- Round {round_idx+1}/{global_epochs} ---")
    participating_clients = []

    for i in range(num_clients):
        # Every client starts the round from the current global weights.
        client_models[i].load_state_dict(global_model.state_dict())
        client_models[i].train()
        # Monotonic clock: elapsed-time measurements must not be affected by
        # adjustments to the system wall clock.
        start_time = time.monotonic()

        # --- simulated straggler setup ---
        # With probability 0.4 the client is slowed down by a total of
        # 1.5-3 seconds, spread evenly over all of its batches.
        straggler_delay = random.uniform(1.5, 3) if random.random() < 0.4 else 0
        num_batches = local_epochs * len(train_subset_dataloaders[i])
        per_batch_delay = straggler_delay / num_batches if num_batches else 0

        try:
            for _ in range(local_epochs):
                for x, y in train_subset_dataloaders[i]:

                    # --- simulate slow clients ---
                    if per_batch_delay:
                        time.sleep(per_batch_delay)

                    # --- timeout check ---
                    # Checked before each batch; a client over budget is
                    # dropped from this round's aggregation.
                    if time.monotonic() - start_time > timeout_seconds:
                        raise TimeoutError(f"Client {i} timed out")

                    x, y = x.to(device), y.to(device)
                    client_optims[i].zero_grad()
                    y_pred = client_models[i](x)
                    loss = loss_fn(y_pred, y)
                    loss.backward()
                    client_optims[i].step()

            # Only clients that finished all local epochs in time take part.
            participating_clients.append(client_models[i])
        except TimeoutError as e:
            print(e)

        elapsed = time.monotonic() - start_time
        print(f"Client {i} training time: {elapsed:.2f} seconds")

    print(f"{len(participating_clients)} out of {num_clients} clients participated.")

    if participating_clients:
        global_model.load_state_dict(fed_avg(participating_clients).state_dict())
    else:
        # Nothing to aggregate: the global model keeps its previous weights.
        print("No clients completed training in this round.")

    # Evaluate the (possibly unchanged) global model on the full test set.
    global_model.eval()
    correct, total, test_loss = 0, 0, 0
    y_true, y_pred_all = [], []

    with torch.no_grad():
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)
            logits = global_model(x)
            loss = loss_fn(logits, y)
            test_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
            y_true.extend(y.cpu().numpy())
            y_pred_all.extend(preds.cpu().numpy())

    acc = correct / total
    f1 = f1_score(y_true, y_pred_all, average='macro')
    acc_list.append(acc)
    f1_list.append(f1)
    # Mean of the per-batch losses.
    loss_list.append(test_loss / len(test_dataloader))

    print(f"Accuracy: {acc:.4f}, F1 Score: {f1:.4f}, Loss: {loss_list[-1]:.4f}")

# One curve per metric, indexed by 1-based round number.
fig, ax = plt.subplots()
rounds = range(1, len(acc_list) + 1)
ax.plot(rounds, acc_list, label="Accuracy")
ax.plot(rounds, f1_list, label="F1 Score")
ax.plot(rounds, loss_list, label="Loss")
ax.set_xlabel("Round")
ax.set_ylabel("Metric value")
ax.set_title("Federated Learning on MNIST")
ax.legend()
plt.show()



--- Round 1/3 ---
Client 0 timed out
Client 0 training time: 30.04 seconds
Client 1 timed out
Client 1 training time: 30.06 seconds
Client 2 timed out
Client 2 training time: 30.06 seconds


KeyboardInterrupt: 