In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
from pathlib import Path
import tarfile
from typing import Any
from logging import INFO, DEBUG
from collections import defaultdict, OrderedDict
from collections.abc import Sequence, Callable
import numbers
import json

import numpy as np
import torch
from torch import nn
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
from enum import IntEnum
import flwr
from flwr.server import History, ServerConfig
from flwr.server.strategy import FedAvgM as FedAvg, Strategy
from flwr.common import log, NDArrays, Scalar, Parameters, ndarrays_to_parameters
from flwr.client.client import Client

import matplotlib.pyplot as plt


from common.client_utils import (
    Net,
    load_femnist_dataset,
    get_network_generator_cnn as get_network_generator,
    train_femnist,
    test_femnist,
    save_history,
)


class Seeds(IntEnum):
    """Named random seeds used throughout the notebook.

    Register additional seeds here so they show up in autocomplete.
    """

    DEFAULT = 1337


# Seed every RNG the notebook touches -- NumPy, stdlib ``random`` and PyTorch,
# in that order -- with the default seed.
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    _seed_rng(Seeds.DEFAULT)
del _seed_rng

# Give up cuDNN autotuning in exchange for deterministic kernels.
torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = False, True


PathType = Path | str | None


def get_device() -> str:
    """Get the device (cuda, mps, cpu)."""
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = "mps"
    return device

In [14]:
home_dir = Path.cwd()
dataset_dir: Path = home_dir / "femnist"
data_dir: Path = dataset_dir / "data"
centralized_partition: Path = dataset_dir / "client_data_mappings" / "centralized"
# Mapping for client "0" of the centralized partition (presumably its only
# client -- confirm against the partition layout).
centralized_mapping: Path = centralized_partition / "0"
federated_partition: Path = dataset_dir / "client_data_mappings" / "fed_natural"

# Decompress the dataset archive on first run only.
if not dataset_dir.exists():
    archive_path = home_dir / "femnist.tar.gz"
    with tarfile.open(archive_path, "r:gz") as tar:
        # The "data" extraction filter rejects absolute paths, ".." traversal and
        # links escaping the target directory. It exists on Python 3.12+ (and the
        # 3.8.17 / 3.9.17 / 3.10.12 / 3.11.4 security backports); older
        # interpreters fall back to the legacy unfiltered extraction.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=home_dir, filter="data")
        else:
            tar.extractall(path=home_dir)
    log(INFO, "Dataset extracted in %s", dataset_dir)

In [15]:
def set_model_parameters(net: Module, parameters: NDArrays) -> Module:
    """Put a set of parameters into the model object."""
    weights = parameters
    params_dict = zip(net.state_dict().keys(), weights, strict=False)
    state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)
    return net


def get_model_parameters(net: Module) -> NDArrays:
    """Get the current model parameters as NDArrays."""
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

In [16]:
def compute_noise_scale_from_gradients(grad_list, eps=1e-6, *, reduction="mean", batch_size=1):
    """
    Estimate a gradient noise scale from a list of mini-batch gradient vectors.

    With ``G`` the element-wise mean of the supplied vectors and ``v_i`` the
    biased (``unbiased=False``) variance of element ``i`` across them, this returns::

        batch_size * R(v) / (||G||^2 + eps)

    where ``R`` is ``mean`` (``reduction="mean"``, the default) or ``sum``.

    With the defaults this is NOT the "simple noise scale"
    B_simple = tr(Sigma) / |G|^2 of McCandlish et al. (2018). Averaging instead
    of summing divides by the number of parameters. In addition, the variance
    of mini-batch-mean gradients is Sigma / B rather than the per-example
    Sigma. Pass ``reduction="sum"`` and the mini-batch size as ``batch_size``
    to obtain the B_simple estimate.

    Parameters:
        grad_list (list[Tensor]): Gradient vectors of equal shape, one per
            mini-batch (e.g. as returned by ``get_gradient_vector``).
        eps (float): Added to ||G||^2 for numerical stability.
        reduction (str): ``"mean"`` or ``"sum"`` over parameter elements.
        batch_size (int): Mini-batch size each gradient vector was computed with.

    Returns:
        float | None: The estimate, or None when ``grad_list`` is empty or the
        computation fails. This is best-effort: failures are logged at DEBUG
        level, not raised.

    Raises:
        ValueError: If ``reduction`` is not ``"mean"`` or ``"sum"``.
    """
    # Validate outside the best-effort block so a bad argument is not swallowed.
    if reduction not in ("mean", "sum"):
        raise ValueError(f"reduction must be 'mean' or 'sum', got {reduction!r}")
    try:
        if not grad_list:
            log(DEBUG, "Grad list empty")
            return None

        # Stack gradients: shape (num_batches, num_params) for 1-D inputs.
        grad_stack = torch.stack(grad_list)
        mean_grad = grad_stack.mean(dim=0)
        per_element_var = grad_stack.var(dim=0, unbiased=False)
        reduced_var = per_element_var.sum() if reduction == "sum" else per_element_var.mean()
        denom = mean_grad.norm() ** 2 + eps
        noise_scale = batch_size * reduced_var / denom
        return noise_scale.item()
    except Exception as e:  # best-effort diagnostic: never abort training over it
        log(DEBUG, "Error in compute_noise_scale_from_gradients: %s", e)
        return None

def get_gradient_vector(model, data, target, loss_fn, device):
    """Backpropagate one mini-batch and return all parameter gradients flattened.

    The model's gradients are zeroed first, so the result reflects only this
    batch. Parameters that received no gradient are skipped. The gradients
    computed here are left on the model afterwards.

    Returns:
        Tensor | None: 1-D concatenation of the gradients in
        ``model.parameters()`` order, or None if no parameter has a gradient.
    """
    model.zero_grad()
    inputs = data.to(device)
    labels = target.to(device)
    loss_fn(model(inputs), labels).backward()

    flat_grads = [
        param.grad.view(-1) for param in model.parameters() if param.grad is not None
    ]
    if not flat_grads:
        log(DEBUG, "No gradients found")
        return None
    return torch.cat(flat_grads)

def collect_gradients(model, train_loader, device, criterion, num_mini_batches):
    """Gather flattened gradient vectors from the first few mini-batches.

    Iterates ``train_loader`` until ``num_mini_batches`` batches have been
    processed. Each batch yields one ``get_gradient_vector`` result; batches
    for which that returns None are skipped.

    Returns:
        list[Tensor]: At most ``num_mini_batches`` gradient vectors.
    """
    collected = []
    for batch_index, batch in enumerate(train_loader):
        if batch_index >= num_mini_batches:
            break
        inputs, labels = batch
        vector = get_gradient_vector(model, inputs, labels, criterion, device)
        if vector is None:
            continue
        collected.append(vector)
    return collected

In [17]:
class FlowerRayClient(flwr.client.NumPyClient):
    """Flower client for the FEMNIST dataset."""

    def __init__(
        self,
        cid: int,
        partition_dir: Path,
        model_generator: Callable[[], Module],
    ) -> None:
        """Init the client with its unique id and the folder to load data from.

        Parameters:
            cid (int): Unique client id for a client used to map it to its data
                partition
            partition_dir (Path): The directory containing data for each
                client/client id
            model_generator (Callable[[], Module]): The model generator function

        """
        self.cid = cid
        log(INFO, "cid: %s", self.cid)
        self.partition_dir = partition_dir
        # NOTE(review): unlike get_device(), this never selects MPS.
        self.device = str(
            torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        )
        self.model_generator: Callable[[], Module] = model_generator
        self.properties: dict[str, Scalar] = {"tensor_type": "numpy.ndarray"}

    def set_parameters(self, parameters: NDArrays) -> Module:
        """Build a fresh model from the generator and load ``parameters`` into it."""
        net = self.model_generator()
        return set_model_parameters(net, parameters)

    def get_parameters(self, config: dict[str, Scalar]) -> NDArrays:
        """Return the weights of a freshly generated model.

        A new model is always created from ``model_generator``. This can be
        used to initialise the global model on the server. The ``config``
        parameter is unused but mandatory in Flower.
        """
        net = self.model_generator()
        return get_model_parameters(net)

    def fit(self, parameters: NDArrays, config: dict[str, Scalar]) -> tuple[NDArrays, int, dict]:
        """Train the received model on the local client data.

        Config keys read:
            - "batch_size", "num_workers", "epochs", "client_learning_rate",
              "weight_decay" and "max_batches";
            - optionally "noise_scale_batches" (default 5), the number of
              mini-batches used for the local noise-scale estimate.

        Returns:
            tuple: (updated weights, number of local training examples,
            metrics). The metrics always contain "train_loss", and contain
            "noise_scale" only when it could be computed.

        Raises:
            Exception: Any error raised while training is logged and re-raised,
            so Flower records a client failure with its real cause.
        """
        # Only create the model right before training/testing, to lower
        # memory usage when idle.
        try:
            net = self.set_parameters(parameters)
            net.to(self.device)

            train_loader: DataLoader = self._create_data_loader(config, name="train")
            train_loss = self._train(net, train_loader=train_loader, config=config)

            # Estimate the local gradient noise scale on the trained model.
            num_noise_batches = int(config.get("noise_scale_batches", 5))
            grad_vectors = collect_gradients(
                net, train_loader, self.device, torch.nn.CrossEntropyLoss(), num_noise_batches
            )
            local_noise_scale = compute_noise_scale_from_gradients(grad_vectors)

            metrics: dict[str, Scalar] = {"train_loss": train_loss}
            # Flower's Scalar type (bool/bytes/float/int/str) has no None, so
            # omit the key when estimation failed.
            if local_noise_scale is not None:
                metrics["noise_scale"] = local_noise_scale

            # Flower weights aggregation by example count, not batch count.
            num_examples = len(train_loader.dataset)  # type: ignore[reportArgumentType]
            return get_model_parameters(net), num_examples, metrics
        except Exception as e:
            log(DEBUG, "Client %s raised an error during fit: %s", self.cid, e)
            raise

    def evaluate(self, parameters: NDArrays, config: dict[str, Scalar]) -> tuple[float, int, dict]:
        """Evaluate the received model on the local client test data.

        Returns:
            tuple: (loss, number of local test examples, {"local_accuracy": acc}).
        """
        net = self.set_parameters(parameters)
        net.to(self.device)

        test_loader: DataLoader = self._create_data_loader(config, name="test")
        loss, accuracy = self._test(net, test_loader=test_loader, config=config)
        # Report examples, not batches, so Flower weights the aggregated loss correctly.
        num_examples = len(test_loader.dataset)  # type: ignore[reportArgumentType]
        return loss, num_examples, {"local_accuracy": accuracy}

    def _create_data_loader(self, config: dict[str, Scalar], name: str) -> DataLoader:
        """Create the data loader using the specified config parameters.

        Training loaders drop the last incomplete batch; test loaders keep it.
        """
        batch_size = int(config["batch_size"])
        num_workers = int(config["num_workers"])
        dataset = self._load_dataset(name)
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            drop_last=(name == "train"),
        )

    def _load_dataset(self, name: str) -> Dataset:
        """Load this client's ``name`` split from ``partition_dir / cid``."""
        full_file: Path = self.partition_dir / str(self.cid)
        return load_femnist_dataset(
            mapping=full_file,
            name=name,
            data_dir=data_dir,
        )

    def _train(
        self, net: Module, train_loader: DataLoader, config: dict[str, Scalar]
    ) -> float:
        """Train ``net`` with AdamW and cross-entropy; return the training loss."""
        return train_femnist(
            net=net,
            train_loader=train_loader,
            epochs=int(config["epochs"]),
            device=self.device,
            optimizer=torch.optim.AdamW(
                net.parameters(),
                lr=float(config["client_learning_rate"]),
                weight_decay=float(config["weight_decay"]),
            ),
            criterion=torch.nn.CrossEntropyLoss(),
            max_batches=int(config["max_batches"]),
            cid=self.cid,
        )

    def _test(
        self, net: Module, test_loader: DataLoader, config: dict[str, Scalar]
    ) -> tuple[float, float]:
        """Evaluate ``net`` with cross-entropy; return (loss, accuracy)."""
        return test_femnist(
            net=net,
            test_loader=test_loader,
            device=self.device,
            criterion=torch.nn.CrossEntropyLoss(),
            max_batches=int(config["max_batches"]),
        )

    def get_properties(self, config: dict[str, Scalar]) -> dict[str, Scalar]:
        """Return properties for this client."""
        return self.properties

    def get_train_set_size(self) -> int:
        """Return the client train set size."""
        return len(self._load_dataset("train"))  # type: ignore[reportArgumentType]

    def get_test_set_size(self) -> int:
        """Return the client test set size."""
        return len(self._load_dataset("test"))  # type: ignore[reportArgumentType]


def fit_client_seeded(
    client: FlowerRayClient,
    params: NDArrays,
    conf: dict[str, Any],
    seed: Seeds = Seeds.DEFAULT,
    **kwargs: Any,
) -> tuple[NDArrays, int, dict]:
    """Reseed NumPy, PyTorch and ``random`` (in that order), then call ``client.fit``.

    Any extra keyword arguments are forwarded unchanged to ``client.fit``.
    """
    for reseed in (np.random.seed, torch.manual_seed, random.seed):
        reseed(seed)
    return client.fit(params, conf, **kwargs)

In [18]:
def get_flower_client_generator(
    model_generator: Callable[[], Module],
    partition_dir: Path,
    mapping_fn: Callable[[int], int] | None = None,
) -> Callable[[str], FlowerRayClient]:
    """Build a ``client_fn`` that turns a Flower client id into a FlowerRayClient.

    Parameters
    ----------
        model_generator (Callable[[], Module]): factory for each client's model.
        partition_dir (Path): directory holding the per-client data mappings.
        mapping_fn (Optional[Callable[[int], int]]): optional translation from
            the integer id Flower passes in (e.g. after sorting/filtering) to
            the real client id; the id is used as-is when omitted.
    """

    def client_fn(cid: str) -> FlowerRayClient:
        """Instantiate the client for the Flower id ``cid``."""
        real_cid = int(cid)
        if mapping_fn is not None:
            real_cid = mapping_fn(real_cid)
        return FlowerRayClient(
            cid=real_cid,
            partition_dir=partition_dir,
            model_generator=model_generator,
        )

    return client_fn

In [19]:
def compute_critical_batch(noise_scales: list, constant: float = 1.0) -> float:
    """Heuristic critical batch size from a list of noise-scale estimates.

    Computes ``constant / (mean(noise_scales) + 1e-8)``. Entries that are None
    (the failure value of ``compute_noise_scale_from_gradients``) are ignored.

    NOTE(review): McCandlish et al. (2018) relate the critical batch size
    *proportionally* to the noise scale (B_crit ~ B_simple), whereas this
    heuristic is *inversely* proportional to it. Confirm this is intended
    before drawing conclusions from the values.

    Parameters:
        noise_scales (list[float | None]): Noise-scale estimates.
        constant (float): Numerator of the heuristic.

    Returns:
        float: The estimate, or NaN if no valid noise scale was supplied.
    """
    valid_scales = [scale for scale in noise_scales if scale is not None]
    if not valid_scales:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        return float("nan")

    # Simple average of the valid noise scales.
    avg_noise_scale = float(np.mean(valid_scales))
    eps = 1e-8
    return constant / (avg_noise_scale + eps)

## Experiment code

In [20]:
def _train_one_epoch(model, train_loader, optimizer, criterion, device, max_batches):
    """Run at most ``max_batches`` optimisation steps; return mean loss per sample seen."""
    model.train()
    loss_sum = 0.0
    samples_seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx >= max_batches:
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * data.size(0)
        samples_seen += data.size(0)
    # Average over the samples actually processed: max_batches usually stops the
    # epoch long before the whole dataset has been seen.
    return loss_sum / samples_seen if samples_seen else float("nan")


def _evaluate_accuracy(model, test_loader, device, max_batches):
    """Top-1 accuracy over at most ``max_batches`` test batches (NaN if none ran)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            if batch_idx >= max_batches:
                break
            data, target = data.to(device), target.to(device)
            preds = model(data).argmax(dim=1)
            correct += (preds == target).sum().item()
            total += target.size(0)
    return correct / total if total else float("nan")


def centralized_experiment(
    centralized_train_cfg,
    centralized_test_cfg,
    train_loader,
    test_loader,
    device,
    model_generator=None,
):
    """Train a fresh model centrally and track loss, noise scale and accuracy per epoch.

    Parameters:
        centralized_train_cfg (dict): reads "epochs", "client_learning_rate",
            "weight_decay", "max_batches" and optionally "noise_scale_batches"
            (default 5).
        centralized_test_cfg (dict): reads "max_batches".
        train_loader (DataLoader): yields (data, target) training batches.
        test_loader (DataLoader): yields (data, target) evaluation batches.
        device: torch device that the model and batches are moved to.
        model_generator (Callable[[], Module] | None): model factory. When
            omitted, the notebook-level ``network_generator`` is used.

    Returns:
        dict: One entry per epoch under each key:
            - "accuracies": test accuracy (NaN if no test batch ran);
            - "losses": mean training loss per sample processed;
            - "noise_scales": noise-scale estimate, None where estimation failed.
    """
    factory = network_generator if model_generator is None else model_generator
    model = factory().to(device)
    # NOTE(review): the FL clients use AdamW. Plain Adam applies weight_decay as
    # an L2 term on the gradient, so the two setups are not strictly comparable.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=centralized_train_cfg["client_learning_rate"],
        weight_decay=centralized_train_cfg["weight_decay"],
    )
    criterion = nn.CrossEntropyLoss()
    num_epochs = centralized_train_cfg["epochs"]
    num_noise_batches = int(centralized_train_cfg.get("noise_scale_batches", 5))

    epoch_accuracies = []
    epoch_losses = []
    epoch_noise_scales = []

    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(
            model, train_loader, optimizer, criterion, device,
            centralized_train_cfg["max_batches"],
        )
        epoch_losses.append(epoch_loss)

        # Collect gradients over a few mini-batches (model still in train mode).
        grad_vectors = collect_gradients(model, train_loader, device, criterion, num_noise_batches)
        noise_scale = compute_noise_scale_from_gradients(grad_vectors)
        epoch_noise_scales.append(noise_scale)

        accuracy = _evaluate_accuracy(model, test_loader, device, centralized_test_cfg["max_batches"])
        epoch_accuracies.append(accuracy)

        # The noise scale can be None when estimation failed; ":.4e" would crash on it.
        noise_text = f"{noise_scale:.4e}" if noise_scale is not None else "n/a"
        log(
            INFO,
            "Epoch %d/%d, Loss: %.4f, Noise scale: %s, Accuracy: %.2f%%",
            epoch + 1, num_epochs, epoch_loss, noise_text, accuracy * 100,
        )

    return {
        "accuracies": epoch_accuracies,
        "losses": epoch_losses,
        "noise_scales": epoch_noise_scales,
    }


In [23]:
network_generator = get_network_generator()

# Centralized train/test splits, loaded with the same helper the FL clients use.
# The mapping should be the one used in the FL centralized experiment.
centralized_train_dataset = load_femnist_dataset(
    name="train", mapping=centralized_mapping, data_dir=data_dir
)
centralized_test_dataset = load_femnist_dataset(
    name="test", mapping=centralized_mapping, data_dir=data_dir
)

# Training hyper-parameters, kept in line with the FL client config.
# "batch_size" is only a default; the batch-size sweep overrides it per run.
centralized_train_config: dict[str, Any] = dict(
    epochs=1,  # NOTE: FL uses 5 epochs x 10 rounds
    batch_size=32,
    client_learning_rate=0.01,
    weight_decay=0.001,
    num_workers=0,
    max_batches=100,
)

# Evaluation settings.
centralized_test_config: dict[str, Any] = dict(
    batch_size=32,
    num_workers=0,
    max_batches=100,
)


In [None]:
experiment_batch_sizes = [8, 16, 32, 64, 128]
experiment_device = get_device()

centralized_results = []

for batch_size in experiment_batch_sizes:
    # Re-seed before every run so each batch size starts from the same initial
    # weights and RNG state, regardless of how many runs preceded it.
    np.random.seed(Seeds.DEFAULT)
    random.seed(Seeds.DEFAULT)
    torch.manual_seed(Seeds.DEFAULT)

    # Only the *training* batch size is swept.
    # NOTE(review): train "max_batches" is fixed, so larger batch sizes process
    # proportionally more samples per epoch.
    train_cfg = centralized_train_config.copy()
    train_cfg["batch_size"] = batch_size

    # Keep the evaluation batch size fixed. Test "max_batches" caps the number
    # of batches, so scaling the test batch size with the training one would
    # score each run on a different number of test samples.
    test_cfg = centralized_test_config.copy()

    centralized_train_loader = DataLoader(
        dataset=centralized_train_dataset,
        batch_size=train_cfg["batch_size"],
        shuffle=True,                # Shuffle for training
        num_workers=train_cfg["num_workers"],
        drop_last=True,              # FL client train loaders also drop the last batch.
    )

    centralized_test_loader = DataLoader(
        dataset=centralized_test_dataset,
        batch_size=test_cfg["batch_size"],
        shuffle=False,               # No shuffling during evaluation
        num_workers=test_cfg["num_workers"],
        drop_last=False,
    )

    # Pass this run's configs to the experiment.
    no_fl_results = centralized_experiment(
        train_cfg, test_cfg, centralized_train_loader, centralized_test_loader, experiment_device
    )

    centralized_results.append((batch_size, no_fl_results))

In [None]:
critical_batches = []

# Side-by-side panels: accuracy per epoch (left), critical batch size (right).
fig, axes = plt.subplots(1, 2, figsize=(20, 6))

# Left subplot: Accuracy vs Epoch for each batch size configuration
for batch_size, results in centralized_results:
    accuracies = results["accuracies"]
    # Plot epochs 1-indexed to match the training log. Markers keep
    # single-epoch runs (one point per curve) visible.
    epoch_numbers = range(1, len(accuracies) + 1)
    axes[0].plot(epoch_numbers, accuracies, marker="o", label=f"Batch size: {batch_size}")

    bcrit = compute_critical_batch(results["noise_scales"], constant=0.01)
    critical_batches.append((batch_size, bcrit))

    log(INFO, "Batch size: %s", batch_size)
    log(INFO, "Accuracies: %s", accuracies)
    log(INFO, "Losses: %s", results["losses"])
    log(INFO, "Noise scales: %s", results["noise_scales"])
    log(INFO, "Critical batch size: %s", bcrit)

axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Accuracy")
axes[0].set_title("Centralized Training: Accuracy vs Epoch")
axes[0].legend()
axes[0].grid(True)

# Right subplot: Critical Batch Size vs Batch Size
batch_sizes = [bs for bs, _ in critical_batches]
bcrit_values = [bcrit for _, bcrit in critical_batches]

axes[1].plot(batch_sizes, bcrit_values, marker="o")
axes[1].set_xlabel("Batch Size")
axes[1].set_ylabel("Critical Batch Size")
axes[1].set_title("Critical Batch Size vs Batch Size")
axes[1].grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Persist the sweep results so the plots can be regenerated without retraining.
results_path = Path("centralized_results.json")
with results_path.open("w") as results_file:
    json.dump(centralized_results, results_file)

In [None]:
# Reload saved results. JSON turns the (batch_size, results) tuples into lists.
results_path = Path("centralized_results.json")
with results_path.open("r") as results_file:
    centralized_results = json.load(results_file)