<a href="https://colab.research.google.com/github/Rujengelal/MSCThesis/blob/main/MSCThesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Installing dependencies

In [2]:
!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision

In [3]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

from pathlib import Path
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import Strategy
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

# Select the compute device once for the whole notebook: the GPU when CUDA is
# available, otherwise the CPU. Every model and batch below is moved to DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

Training on cuda
Flower 1.19.0 / PyTorch 2.6.0+cu124


It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware accelerator: GPU > Save`). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting `DEVICE = torch.device("cpu")`. If the runtime has GPU acceleration enabled, you should see the output `Training on cuda`; otherwise it will say `Training on cpu`.

### Data loading

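Let's now load the CIFAR-10 training and test sets, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap everything in its own `DataLoader`. The test set is not partitioned; every client gets a loader over the full test split.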

In [4]:
def load_datasets(partition_id, num_partitions: int):
    """Build the DataLoaders for one federated CIFAR-10 partition.

    The "train" split of ``uoft-cs/cifar10`` is divided into
    ``num_partitions`` partitions; partition ``partition_id`` is loaded and
    split 80/20 (fixed seed 42) into a local training and validation set.
    The "test" split is loaded whole, without partitioning.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Number of partitions the training split is divided into.

    Returns:
        A ``(trainloader, valloader, testloader)`` tuple. All loaders use a
        batch size of 32; only the training loader shuffles.
    """
    to_normalized_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    def apply_transforms(batch):
        # Applied lazily through Dataset.with_transform, so images are only
        # converted when a batch is actually read.
        batch["img"] = [to_normalized_tensor(image) for image in batch["img"]]
        return batch

    fds = FederatedDataset(dataset="uoft-cs/cifar10", partitioners={"train": num_partitions})

    # Local split on this node: 80% train, 20% validation.
    local_splits = fds.load_partition(partition_id).train_test_split(test_size=0.2, seed=42)
    local_splits = local_splits.with_transform(apply_transforms)

    trainloader = DataLoader(local_splits["train"], batch_size=32, shuffle=True)
    valloader = DataLoader(local_splits["test"], batch_size=32)

    global_testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(global_testset, batch_size=32)

    return trainloader, valloader, testloader

### Model training/evaluation

Let's continue with the usual model definition (including `set_parameters` and `get_parameters`), training and test functions:

In [5]:
class Net(nn.Module):
    """Small convolutional classifier with 10 output logits.

    Two convolution + max-pool stages feed three fully connected layers.
    The flattening step assumes each sample yields 16 * 5 * 5 features after
    the second pooling stage.
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # initialisation and the state_dict key order stay the same.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class logits for a batch of images."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flattened = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flattened))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


def get_parameters(net) -> List[np.ndarray]:
    """Return every entry of ``net.state_dict()`` as a NumPy array.

    The arrays are listed in state-dict order; each tensor is moved to the
    CPU before conversion.
    """
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays


def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into ``net`` in state-dict order.

    Args:
        net: Module whose ``state_dict`` keys define the expected order.
        parameters: One array per state-dict entry, in the same order as
            ``net.state_dict().keys()``.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when keys are
            missing (fewer arrays than state-dict entries) or shapes differ.
    """
    params_dict = zip(net.state_dict().keys(), parameters)
    # torch.tensor keeps each array's shape and dtype. The legacy
    # torch.Tensor constructor always produces float32 and does not preserve
    # 0-dim arrays (e.g. BatchNorm's num_batches_tracked), which makes the
    # strict load below fail for models that contain such buffers.
    state_dict = OrderedDict((key, torch.tensor(value)) for key, value in params_dict)
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set.

    Uses cross-entropy loss and an Adam optimizer with default settings, and
    prints the mean per-sample loss and the accuracy after every epoch.

    Args:
        net: Model to train in place; expected to already live on ``DEVICE``.
        trainloader: Iterable of batches with ``"img"`` and ``"label"`` entries.
        epochs: Number of passes over ``trainloader``.

    Raises:
        ZeroDivisionError: If ``trainloader`` yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images = batch["img"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            optimizer.zero_grad()
            # One forward pass serves both the loss and the accuracy metric.
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            batch_size = labels.size(0)
            # .item() detaches the value so the autograd graph of each batch
            # is not kept alive; the criterion returns a batch mean, so it is
            # weighted by the batch size to obtain a true per-sample average.
            epoch_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        epoch_loss /= total
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set.

    Args:
        net: Model to evaluate; expected to already live on ``DEVICE``.
        testloader: Iterable of batches with ``"img"`` and ``"label"`` entries.

    Returns:
        A ``(loss, accuracy)`` tuple, where ``loss`` is the mean per-sample
        cross-entropy and ``accuracy`` is the fraction of correct predictions.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images = batch["img"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            outputs = net(images)
            batch_size = labels.size(0)
            # The criterion returns the batch mean; weight it by the batch
            # size so that dividing by the sample count below yields a real
            # per-sample average rather than a value scaled by 1/batch_size.
            loss += criterion(outputs, labels).item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    loss /= total
    accuracy = correct / total
    return loss, accuracy

### Flower client

To implement the Flower client, we (again) create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`. Here, we also pass the `partition_id` to the client and use it to log additional details. We then create an instance of `ClientApp` and pass it the `client_fn`.

In [6]:
# class FlowerClient(NumPyClient):
#     def __init__(self, partition_id, net, trainloader, valloader):
#         self.partition_id = partition_id
#         self.net = net
#         self.trainloader = trainloader
#         self.valloader = valloader

#     def get_parameters(self, config):
#         print(f"[Client {self.partition_id}] get_parameters")
#         return get_parameters(self.net)

#     def fit(self, parameters, config):
#         print(f"[Client {self.partition_id}] fit, config: {config}")
#         set_parameters(self.net, parameters)
#         train(self.net, self.trainloader, epochs=1)
#         return get_parameters(self.net), len(self.trainloader), {}

#     def evaluate(self, parameters, config):
#         print(f"[Client {self.partition_id}] evaluate, config: {config}")
#         set_parameters(self.net, parameters)
#         loss, accuracy = test(self.net, self.valloader)
#         return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


# def client_fn(context: Context) -> Client:
#     net = Net().to(DEVICE)
#     partition_id = context.node_config["partition-id"]
#     num_partitions = context.node_config["num-partitions"]
#     trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
#     return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# # Create the ClientApp
# client = ClientApp(client_fn=client_fn)

In [7]:
def distillation(teacher, student, data_loader, distillation_epochs: int, L, T):
    """Train ``student`` to imitate ``teacher`` via knowledge distillation.

    The per-batch objective is
    ``L * T^2 * KL(teacher_soft || student_soft) + (1 - L) * CE(student, labels)``,
    where the soft distributions are softmax outputs at temperature ``T``.
    Only the student's parameters are optimised (Adam, default settings).
    The mean per-sample loss and the student's accuracy are printed after
    every epoch.

    Args:
        teacher: Model providing the soft targets. Its train/eval mode is
            left exactly as the caller set it.
        student: Model trained in place.
        data_loader: Iterable of batches with ``"img"`` and ``"label"`` entries.
        distillation_epochs: Number of passes over ``data_loader``.
        L: Imitation weight; ``0`` uses only the hard-label loss, ``1`` only
            the distillation loss.
        T: Softmax temperature applied to both teacher and student logits.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    # KLDivLoss expects log-probabilities as input and probabilities as
    # target; "batchmean" is the reduction that matches the mathematical
    # definition of KL divergence (the default "mean" averages over classes too).
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    optimizer = torch.optim.Adam(student.parameters())
    student.train()
    for epoch in range(distillation_epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in data_loader:
            images = batch["img"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            optimizer.zero_grad()

            # The teacher only supplies targets, so no autograd graph is
            # needed for its forward pass.
            with torch.no_grad():
                teacher_soft_labels = F.softmax(teacher(images) / T, dim=1)

            student_output = student(images)
            student_log_soft_labels = F.log_softmax(student_output / T, dim=1)

            soft_loss = kl_loss(student_log_soft_labels, teacher_soft_labels) * T * T
            hard_loss = criterion(student_output, labels)
            loss = soft_loss * L + hard_loss * (1 - L)
            loss.backward()
            optimizer.step()
            # Metrics
            batch_size = labels.size(0)
            # .item() detaches the value; weighting by the batch size turns
            # the batch-mean loss into a per-sample average after division.
            epoch_loss += loss.item() * batch_size
            total += batch_size
            correct += (student_output.argmax(dim=1) == labels).sum().item()
        epoch_loss /= total
        epoch_acc = correct / total
        print(f"Distillation Epoch {epoch+1}: T {T}, L {L}, loss {epoch_loss}, accuracy {epoch_acc}")

In [8]:
class FlowerClient(NumPyClient):
    """Flower client that trains locally and runs a distillation sweep.

    Each client keeps a "teacher" model persisted on disk under
    ``teacher_models/teacher_model_<partition_id>.pt``. If that file already
    exists when the client is constructed it is loaded; otherwise the first
    global model received in ``fit`` is saved and used as the teacher.
    """

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.teacher_model = None

        # Sets self.teacher_model_path and restores a saved teacher, if any.
        self.load_teacher()

    def get_parameters(self, config):
        """Return the current local model parameters as NumPy arrays."""
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Train on the local data starting from the given global parameters.

        Returns:
            ``(parameters, num_examples, metrics)`` with the locally trained
            parameters, the number of local training examples, and an empty
            metrics dict.
        """
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)

        # Set the teacher model based on the global model on first run
        if self.teacher_model is None:
            torch.save(self.net.state_dict(), self.teacher_model_path)
            self.load_teacher()

        train(self.net, self.trainloader, epochs=1)

        temperatures = [1, 5, 25]
        imitation_parameters = [0, 0.25, 0.5, 0.75, 1.0]

        # NOTE: every student is a throwaway copy of the trained model. It is
        # distilled only so that distillation() prints its metrics; it is
        # neither stored nor returned, so the sweep does not affect the
        # parameters sent back to the server.
        for T in temperatures:
            for L in imitation_parameters:
                student = deepcopy(self.net).to(DEVICE)
                distillation(
                    self.teacher_model,
                    student,
                    self.trainloader,
                    distillation_epochs=1,
                    L=L,
                    T=T,
                )

        # Flower expects the number of examples here; len(DataLoader) would
        # be the number of batches.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the given global parameters on the local validation set.

        Returns:
            ``(loss, num_examples, metrics)`` where ``metrics`` holds the
            validation ``"accuracy"``.
        """
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        # Number of validation examples, not number of batches.
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

    def load_teacher(self):
        """Resolve this client's teacher checkpoint path and load it if present.

        Creates the ``teacher_models`` directory when needed and always sets
        ``self.teacher_model_path``. ``self.teacher_model`` is only assigned
        when the checkpoint file exists.
        """
        teacher_models_dir = Path("teacher_models")
        teacher_models_dir.mkdir(parents=True, exist_ok=True)
        self.teacher_model_path = teacher_models_dir / f"teacher_model_{self.partition_id}.pt"
        if self.teacher_model_path.exists():
            self.teacher_model = deepcopy(self.net)
            # map_location lets a checkpoint written on a GPU runtime be
            # loaded on a CPU-only runtime (and vice versa).
            self.teacher_model.load_state_dict(
                torch.load(self.teacher_model_path, map_location=DEVICE)
            )
            self.teacher_model.eval()
            # The teacher is never optimised; freezing it avoids computing
            # gradients for its parameters during distillation.
            self.teacher_model.requires_grad_(False)
            self.teacher_model = self.teacher_model.to(DEVICE)


def client_fn(context: Context) -> Client:
    """Build the Flower client for the node described by ``context``.

    Reads ``"partition-id"`` and ``"num-partitions"`` from the node config,
    creates a fresh ``Net`` on ``DEVICE`` and loads that partition's training
    and validation loaders (the test loader is not used by the client).
    """
    node_config = context.node_config
    partition_id = node_config["partition-id"]
    num_partitions = node_config["num-partitions"]

    model = Net().to(DEVICE)
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)

    flower_client = FlowerClient(partition_id, model, trainloader, valloader)
    return flower_client.to_client()


# ClientApp that builds one client per node through client_fn
client = ClientApp(client_fn=client_fn)

In [None]:
NUM_PARTITIONS = 10


def server_fn(context: Context) -> ServerAppComponents:
    """Return the server components: three training rounds, default strategy."""
    three_round_config = ServerConfig(num_rounds=3)
    # No strategy is passed, so ServerAppComponents falls back to FedAvg.
    return ServerAppComponents(config=three_round_config)


# ServerApp that obtains its components from server_fn
server = ServerApp(server_fn=server_fn)

# Resources requested for each simulated client.
# With None, each client is allocated 2x CPU and 0x GPUs by default;
# on a CUDA runtime every client gets one CPU and a tenth of the GPU.
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_gpus": 0.1, "num_cpus": 1}}
else:
    backend_config = {"client_resources": None}

# Launch the simulation with one supernode per data partition
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

DEBUG:flwr:Asyncio event loop already running.
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=3151)[0m 2025-07-15 20:06:54.249499: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=3151)[0m E0000 00:00:1752610014.283715    3151 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=3151)[0m E0000 00:00:1752610014.294423    3151 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=3152)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppAct

[36m(ClientAppActor pid=3152)[0m [Client 6] get_parameters
