<a href="https://colab.research.google.com/github/DhruboDevPramanik/FL_with_MINST_dataset/blob/main/part1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cell 1: Install required packages
# NOTE: the leading `!` is the IPython/Colab shell escape, so this line only
# works inside a notebook; it is a syntax error in a plain Python script.
!pip install torch torchvision flwr

Collecting flwr
  Downloading flwr-1.21.0-py3-none-any.whl.metadata (15 kB)
Collecting click<8.2.0 (from flwr)
  Downloading click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting cryptography<45.0.0,>=44.0.1 (from flwr)
  Downloading cryptography-44.0.3-cp39-abi3-manylinux_2_34_x86_64.whl.metadata (5.7 kB)
Collecting grpcio-health-checking<2.0.0,>=1.62.3 (from flwr)
  Downloading grpcio_health_checking-1.75.0-py3-none-any.whl.metadata (1.0 kB)
Collecting iterators<0.0.3,>=0.0.2 (from flwr)
  Downloading iterators-0.0.2-py3-none-any.whl.metadata (2.5 kB)
Collecting pathspec<0.13.0,>=0.12.1 (from flwr)
  Downloading pathspec-0.12.1-py3-none-any.whl.metadata (21 kB)
Collecting protobuf<5.0.0,>=4.21.6 (from flwr)
  Downloading protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl.metadata (541 bytes)
Collecting pycryptodome<4.0.0,>=3.18.0 (from flwr)
  Downloading pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Collecting tomli<3.0.0,>=2.0.1 

In [None]:
# Cell 2: Import all required libraries
from collections import OrderedDict
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.datasets import MNIST
import flwr as fl
from flwr.common import NDArrays, Scalar
import numpy as np

  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
# Cell 3: Define the neural network model
class Net(nn.Module):
    """Small convolutional classifier for single-channel images.

    Two (5x5 convolution -> ReLU -> 2x2 max-pool) stages are followed by three
    fully connected layers. The flatten size of 16 * 4 * 4 corresponds to
    28x28 inputs. The final layer emits one raw logit per class.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        # Layers are registered in a fixed order: other code in this file
        # matches parameters to state_dict keys purely by position.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of images."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Collapse channel and spatial dimensions into one vector per sample.
        flat = features.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# Cell 4: Training and testing functions
def train(net, trainloader, optimizer, epochs, device: str):
    """Optimise `net` for `epochs` full passes over `trainloader`.

    The model is moved to `device` and switched to training mode. For every
    batch the cross-entropy loss is back-propagated and `optimizer.step()` is
    called once. The function returns nothing; `net` is updated in place.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    net.to(device)
    net.train()
    for _ in range(epochs):
        for batch_x, batch_y in trainloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            optimizer.zero_grad()
            logits = net(batch_x)
            batch_loss = loss_fn(logits, batch_y)
            batch_loss.backward()
            optimizer.step()

def test(net, testloader, device: str):
    """Validate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.to(device)
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy

In [None]:
# Cell 5: Dataset preparation functions
def get_mnist(data_path: str = "./data"):
    """Return the MNIST `(trainset, testset)` pair stored under `data_path`.

    Both splits are requested with `download=True` and share one transform:
    conversion to a tensor followed by normalisation with mean 0.1307 and
    standard deviation 0.3081.
    """
    transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    # Training split first, then the held-out test split.
    train_split, test_split = (
        MNIST(root=data_path, train=is_train, download=True, transform=transform)
        for is_train in (True, False)
    )
    return train_split, test_split

def prepare_dataset(num_partitions: int, batch_size: int, val_ratio: float = 0.1):
    """Split MNIST into per-client train/validation loaders plus one test loader.

    The training set is randomly divided into `num_partitions` disjoint parts
    (fixed seed 2023). Each part is then split again, with the same seed, into
    a training portion and a validation portion of `int(len(part) * val_ratio)`
    samples.

    Args:
        num_partitions: Number of client partitions to create.
        batch_size: Batch size of the per-client train and validation loaders.
        val_ratio: Fraction of each partition held out for validation.

    Returns:
        `(trainloaders, validateloaders, testloader)`: two lists with one
        `DataLoader` per partition, and a single loader over the MNIST test
        set with batch size 128.
    """
    trainset, testset = get_mnist()

    # Split the dataset into num_partitions parts. random_split requires the
    # lengths to sum to len(trainset), so the leftover samples of an uneven
    # division are spread one by one over the first partitions.
    base_len, remainder = divmod(len(trainset), num_partitions)
    partition_len = [
        base_len + 1 if index < remainder else base_len
        for index in range(num_partitions)
    ]

    trainsets = random_split(trainset, partition_len, torch.Generator().manual_seed(2023))

    # Create dataloaders with train + val split
    trainloaders = []
    validateloaders = []
    for trainset_ in trainsets:
        num_total = len(trainset_)
        num_val = int(num_total * val_ratio)
        num_train = num_total - num_val

        for_train, for_val = random_split(trainset_, [num_train, num_val], torch.Generator().manual_seed(2023))

        # Only the training loader is shuffled; validation order stays fixed.
        trainloaders.append(DataLoader(for_train, batch_size=batch_size, shuffle=True))
        validateloaders.append(DataLoader(for_val, batch_size=batch_size, shuffle=False))

    testloader = DataLoader(testset, batch_size=128)
    return trainloaders, validateloaders, testloader

In [None]:
# Cell 6: Flower client class with validation tracking
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and validates its own `Net` instance."""

    def __init__(self, trainloader, valloader, num_classes) -> None:
        super().__init__()

        # Data this client trains on and validates against.
        self.trainloader = trainloader
        self.valloader = valloader

        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = Net(num_classes)
        self.model.to(self.device)

    def set_parameters(self, parameters):
        """Overwrite the model weights with `parameters`.

        The arrays are matched to the model's state_dict keys by position.
        """
        new_state = OrderedDict()
        for key, array in zip(self.model.state_dict().keys(), parameters):
            new_state[key] = torch.tensor(array)
        self.model.load_state_dict(new_state, strict=True)

    def get_parameters(self, config: Dict[str, Scalar] = None):
        """Return every state_dict tensor as a NumPy array, in state_dict order."""
        return [tensor.cpu().numpy() for tensor in self.model.state_dict().values()]

    def fit(self, parameters, config):
        """Load `parameters`, train locally with SGD, and return the new weights.

        `config` must provide "lr", "momentum" and "local_epochs". The result is
        `(weights, number of training samples, empty metrics dict)`.
        """
        self.set_parameters(parameters)

        # Read all hyperparameters up front, before building the optimizer.
        lr, momentum, epochs = (
            config[key] for key in ("lr", "momentum", "local_epochs")
        )
        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=momentum)

        train(self.model, self.trainloader, optimizer, epochs, self.device)

        num_examples = len(self.trainloader.dataset)
        return self.get_parameters(), num_examples, {}

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """Load `parameters` and score them on this client's validation loader.

        Returns `(loss, number of validation samples, {"accuracy": ...})`.
        """
        self.set_parameters(parameters)

        loss, accuracy = test(self.model, self.valloader, self.device)

        # Report this client's validation result.
        print(f"Client validation - Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

        num_examples = len(self.valloader.dataset)
        metrics = {"accuracy": float(accuracy)}
        return float(loss), num_examples, metrics

def generate_client_fn(trainloaders, valloaders, num_classes):
    """Build a factory that creates a `FlowerClient` for a given client id.

    The returned `client_fn(cid)` converts `cid` to an int and uses it to index
    both loader lists.
    """

    def client_fn(cid: str):
        index = int(cid)
        return FlowerClient(
            trainloader=trainloaders[index],
            valloader=valloaders[index],
            num_classes=num_classes,
        )

    return client_fn

In [None]:
# Cell 7: Server functions with validation tracking
def get_on_fit_config(config):
    """Return a per-round callback that yields the client fit settings.

    The callback ignores the round number and returns a new dict holding the
    'lr', 'momentum' and 'local_epochs' entries of `config`.
    """
    forwarded_keys = ('lr', 'momentum', 'local_epochs')

    def fit_config_fn(server_round: int):
        # Values are looked up on every call, so later changes to `config`
        # are reflected in subsequent rounds.
        return {key: config[key] for key in forwarded_keys}

    return fit_config_fn

def get_evaluate_fn(num_classes, testloader):
    """Return a callback that scores a set of global weights on `testloader`.

    The callback builds a new `Net`, loads `parameters` into it by position,
    evaluates it, prints the result and returns `(loss, {"accuracy": ...})`.
    """

    def evaluate_fn(server_round: int, parameters, config):
        model = Net(num_classes)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Pair each state_dict key with the array at the same position.
        weights = OrderedDict()
        for name, array in zip(model.state_dict().keys(), parameters):
            weights[name] = torch.tensor(array)
        model.load_state_dict(weights, strict=True)

        loss, accuracy = test(model, testloader, device)

        # Report the global model's test result for this round.
        print(f"Global test - Round {server_round}: Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

        return loss, {"accuracy": accuracy}

    return evaluate_fn

In [None]:
# Cell 8: Main federated learning function with validation tracking
class SaveMetricsStrategy(fl.server.strategy.FedAvg):
    """FedAvg strategy that also prints the mean client validation accuracy."""

    def aggregate_evaluate(self, server_round, results, failures):
        """Aggregate client evaluation results and print their mean accuracy.

        Args:
            server_round: Current round number, used in the printed message.
            results: Pairs of `(client proxy, evaluation result)`; each
                evaluation result carries a `metrics` dict with an
                "accuracy" entry.
            failures: Failed evaluations, forwarded unchanged to FedAvg.

        Returns:
            `(None, {})` when there are no results, otherwise exactly what
            `FedAvg.aggregate_evaluate` returns.
        """
        if not results:
            return None, {}

        # Call aggregate_evaluate from base class (FedAvg) to aggregate loss and metrics
        aggregated_loss, aggregated_metrics = super().aggregate_evaluate(server_round, results, failures)

        # Each entry is a (client_proxy, evaluate_res) tuple, so the metrics
        # must be read from the second element, not from the tuple itself.
        accuracies = [evaluate_res.metrics["accuracy"] for _, evaluate_res in results]
        avg_accuracy = sum(accuracies) / len(accuracies)

        print(f"Round {server_round} - Average validation accuracy: {avg_accuracy:.4f}")

        return aggregated_loss, aggregated_metrics

def run_federated_learning(server_address="[::]:8080"):
    """Prepare the data and strategy, then start a Flower server.

    Args:
        server_address: Address the Flower server binds to.

    The captured notebook output shows that `start_server()` is reported as
    deprecated by the installed Flower version.
    """
    # Configuration parameters
    config = {
        'num_clients': 10,
        'batch_size': 32,
        'num_clients_per_round_fit': 3,
        'num_clients_per_round_eval': 3,
        'num_rounds': 5,
        'num_classes': 10,
        'config_fit': {
            'lr': 0.01,
            'momentum': 0.9,
            'local_epochs': 1
        }
    }

    print("Configuration:")
    for key, value in config.items():
        print(f"{key}: {value}")

    # Prepare dataset
    print("Preparing dataset...")
    trainloaders, validateloaders, testloader = prepare_dataset(
        config['num_clients'], config['batch_size']
    )

    # Define clients
    # NOTE(review): client_fn is built but never handed to start_server below,
    # so this function does not launch any clients itself — confirm whether
    # clients are expected to connect from elsewhere.
    client_fn = generate_client_fn(
        trainloaders, validateloaders, num_classes=config['num_classes']
    )

    # Define strategy with validation tracking
    strategy = SaveMetricsStrategy(
        fraction_fit=1.0,
        min_fit_clients=config['num_clients_per_round_fit'],
        fraction_evaluate=1.0,
        min_evaluate_clients=config['num_clients_per_round_eval'],
        min_available_clients=config['num_clients'],
        on_fit_config_fn=get_on_fit_config(config['config_fit']),
        evaluate_fn=get_evaluate_fn(config['num_classes'], testloader)
    )

    # Start Flower server
    print("Starting Flower server...")

    # Create client manager
    client_manager = fl.server.SimpleClientManager()

    # Start server. The strategy and client manager are passed directly, so
    # no separate Server object is constructed here.
    fl.server.start_server(
        server_address=server_address,
        config=fl.server.ServerConfig(num_rounds=config['num_rounds']),
        strategy=strategy,
        client_manager=client_manager
    )

In [None]:
# Cell 9: Alternative simple approach with validation tracking
def run_simple_federated_learning(
    num_clients: int = 3,
    batch_size: int = 32,
    num_rounds: int = 2,
    num_classes: int = 10,
):
    """Run an in-process federated averaging loop without a Flower server.

    In every round each client starts from the current global weights, trains
    locally and is scored on its own validation loader. The client weights are
    then averaged element-wise with equal weight per client and loaded into
    the global model. The global model is scored on the test set and on every
    client's validation loader. All results are printed; nothing is returned.

    Args:
        num_clients: Number of clients (and data partitions).
        batch_size: Batch size of the per-client loaders.
        num_rounds: Number of federated rounds.
        num_classes: Number of output classes of the model.
    """
    print("Running simplified federated learning...")

    # Prepare dataset
    trainloaders, validateloaders, testloader = prepare_dataset(num_clients, batch_size)

    # Create a simple federated learning loop
    global_model = Net(num_classes)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    fit_config = {"lr": 0.01, "momentum": 0.9, "local_epochs": 1}

    for round_idx in range(num_rounds):
        print(f"\nRound {round_idx + 1}/{num_rounds}")
        print("-" * 50)

        # The global weights do not change while clients train, so they are
        # extracted once per round rather than once per client.
        global_params = [val.cpu().numpy() for val in global_model.state_dict().values()]

        # Train on each client
        client_models = []
        client_accuracies = []

        for i, (trainloader, valloader) in enumerate(zip(trainloaders, validateloaders)):
            print(f"  Training client {i + 1}/{num_clients}")
            client = FlowerClient(trainloader, valloader, num_classes)

            # Train client
            client_params, _, _ = client.fit(global_params, fit_config)
            client_models.append(client_params)

            # Test client on validation set
            _, client_accuracy = test(client.model, valloader, device)
            client_accuracies.append(client_accuracy)
            print(f"  Client {i + 1} validation accuracy: {client_accuracy:.4f}")

        # Print average client validation accuracy
        avg_client_accuracy = sum(client_accuracies) / len(client_accuracies)
        print(f"  Average client validation accuracy: {avg_client_accuracy:.4f}")

        # Average client models (simple FedAvg): zip(*...) groups the arrays
        # belonging to the same layer across all clients.
        averaged_params = [
            np.mean(layer_params, axis=0) for layer_params in zip(*client_models)
        ]

        # Update global model
        params_dict = zip(global_model.state_dict().keys(), averaged_params)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        global_model.load_state_dict(state_dict, strict=True)

        # Test global model on test set
        _, accuracy = test(global_model, testloader, device)
        print(f"  Global model test accuracy: {accuracy:.4f}")

        # Test global model on validation sets (average across all clients)
        global_val_accuracies = [
            test(global_model, valloader, device)[1] for valloader in validateloaders
        ]

        avg_global_val_accuracy = sum(global_val_accuracies) / len(global_val_accuracies)
        print(f"  Global model average validation accuracy: {avg_global_val_accuracy:.4f}")

In [None]:
# Cell 10: Main function
def main():
    """Run the Flower server workflow, falling back to the in-process loop.

    If `run_federated_learning` raises `SystemExit` or `ImportError`, the
    failure is printed and `run_simple_federated_learning` runs instead.
    """
    print("Starting federated learning with Flower...")
    try:
        # Try the Flower server approach first
        run_federated_learning()
    except (SystemExit, ImportError) as e:
        # Both failure kinds are handled the same way.
        print(f"Simulation approach failed: {e}")
        print("Trying alternative approach...")
        run_simple_federated_learning()

# Entry point: call main() only when this module is executed as `__main__`,
# not when it is imported by another module.
if __name__ == "__main__":
    main()

Starting federated learning with Flower...
Configuration:
num_clients: 10
batch_size: 32
num_clients_per_round_fit: 3
num_clients_per_round_eval: 3
num_rounds: 5
num_classes: 10
config_fit: {'lr': 0.01, 'momentum': 0.9, 'local_epochs': 1}
Preparing dataset...


100%|██████████| 9.91M/9.91M [00:00<00:00, 11.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 335kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 2.71MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.39MB/s]
	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=5, no rou

Starting Flower server...
Simulation approach failed: Port in server address [::]:8080 is already in use.
Trying alternative approach...
Running simplified federated learning...

Round 1/2
--------------------------------------------------
  Training client 1/3


  return datetime.utcnow().replace(tzinfo=utc)


  Client 1 validation accuracy: 0.9665
  Training client 2/3
  Client 2 validation accuracy: 0.9640
  Training client 3/3
  Client 3 validation accuracy: 0.9635
  Average client validation accuracy: 0.9647
  Global model test accuracy: 0.9720
  Global model average validation accuracy: 0.9677

Round 2/2
--------------------------------------------------
  Training client 1/3
  Client 1 validation accuracy: 0.9720
  Training client 2/3
  Client 2 validation accuracy: 0.9805
  Training client 3/3
  Client 3 validation accuracy: 0.9815
  Average client validation accuracy: 0.9780
  Global model test accuracy: 0.9846
  Global model average validation accuracy: 0.9822
