In [1]:
import torch

from federated import data_utils
from federated.methods import FEDERATION_METHODS
from federated.visualization import plot_accuracy_histories, plot_drift_histories
from federated.models import SimpleCNN

# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")


Using device: cuda


In [2]:
def run_federated_experiment(
    method_name: str,
    *,
    alpha: float,
    num_rounds: int,
    local_epochs: int,
    num_clients: int = 10,
    batch_size: int = 32,
    lr: float = 0.01,
    client_fraction: float = 1.0,
    seed: int = 42,
    train_dataset=None,
    test_dataset=None,
):
    """Dispatch to the requested federated method and return its logs.

    Args:
        method_name: Case-insensitive key into ``FEDERATION_METHODS``
            (e.g. ``"fedavg"``). Leading/trailing whitespace is ignored.
        alpha: Data-heterogeneity parameter forwarded unchanged to the method
            (the plot file names in this notebook suggest a Dirichlet
            partition concentration -- confirm in ``federated.methods``).
        num_rounds: Number of communication rounds, forwarded to the method.
        local_epochs: Local training epochs per round, forwarded to the method.
        num_clients: Number of simulated clients.
        batch_size: Local batch size.
        lr: Learning rate.
        client_fraction: Fraction of clients sampled per round, forwarded to
            the method.
        seed: Random seed forwarded to the method.
        train_dataset: Optional pre-built training dataset. When ``None``
            (the default) it is built via
            ``data_utils.build_cifar10_datasets()``. Pass one in to avoid
            rebuilding/reloading CIFAR-10 on every call, e.g. inside a sweep.
        test_dataset: Optional pre-built test dataset, same semantics as
            ``train_dataset``.

    Returns:
        Whatever the selected federated method returns. The callers in this
        notebook index it with ``"accuracy"`` and ``"drift"``.

    Raises:
        ValueError: If ``method_name`` is not a key of ``FEDERATION_METHODS``.
    """

    method_key = method_name.strip().lower()
    # Validate before touching the (potentially slow) dataset construction.
    if method_key not in FEDERATION_METHODS:
        available = ", ".join(sorted(FEDERATION_METHODS))
        raise ValueError(
            f"Unknown federated method: {method_name!r} (available: {available})"
        )

    # Only build what the caller did not supply; building CIFAR-10 can
    # trigger a download and is wasteful to repeat for every run.
    if train_dataset is None or test_dataset is None:
        built_train, built_test = data_utils.build_cifar10_datasets()
        if train_dataset is None:
            train_dataset = built_train
        if test_dataset is None:
            test_dataset = built_test
    test_loader = data_utils.get_test_loader(test_dataset)

    method = FEDERATION_METHODS[method_key]
    return method(
        model_class=SimpleCNN,
        train_dataset=train_dataset,
        test_loader=test_loader,
        num_clients=num_clients,
        alpha=alpha,
        batch_size=batch_size,
        num_rounds=num_rounds,
        local_epochs=local_epochs,
        lr=lr,
        client_fraction=client_fraction,
        device=DEVICE,
        seed=seed,
    )


In [None]:
from base_params import *

def run_multi_alpha_experiment(
    alphas = ALPHAS,
    num_rounds: int = NUM_ROUNDS,
    local_epochs: int = K_VALUE,
    num_clients: int = NUM_CLIENTS,
    batch_size: int = LOCAL_BATCH_SIZE,
    lr: float = LEARNING_RATE,
    client_fraction: float = K_FRAC_EXP,
    seed: int = 42,
):
    """Sweep FedAvg over several ``alpha`` values, then plot and return the logs.

    For every value in ``alphas``, runs
    ``run_federated_experiment(method_name="fedavg", ...)`` with the shared
    hyper-parameters. It stores the run's ``"accuracy"`` and ``"drift"``
    entries under the label ``"α=<alpha>"``. Afterwards it saves one accuracy
    figure (``fedavg_dirichlet_acc.png``) and one drift figure
    (``fedavg_dirichlet_drift.png``).

    Returns:
        ``(accuracy_histories, drift_histories)``: two dicts keyed by the
        ``"α=<alpha>"`` label, in the order the alphas were visited. A
        repeated alpha overwrites the earlier entry for that label.
    """

    # Everything except alpha is identical across runs of the sweep.
    common_kwargs = dict(
        num_rounds=num_rounds,
        local_epochs=local_epochs,
        num_clients=num_clients,
        batch_size=batch_size,
        lr=lr,
        client_fraction=client_fraction,
        seed=seed,
    )

    acc_by_label = {}
    drift_by_label = {}
    for current_alpha in alphas:
        print(f"=== FedAvg run (alpha={current_alpha}) ===")
        run_logs = run_federated_experiment(
            method_name="fedavg", alpha=current_alpha, **common_kwargs
        )
        label = f"α={current_alpha}"
        acc_by_label[label] = run_logs["accuracy"]
        drift_by_label[label] = run_logs["drift"]

    plot_accuracy_histories(
        acc_by_label,
        title=f"FedAvg Accuracy vs α (K={local_epochs}, R={num_rounds})",
        save_path="fedavg_dirichlet_acc.png",
    )
    plot_drift_histories(
        drift_by_label,
        title=f"Client Drift vs α (K={local_epochs}, R={num_rounds})",
        save_path="fedavg_dirichlet_drift.png",
    )

    return acc_by_label, drift_by_label


# Full alpha sweep with the base_params defaults: trains one FedAvg model per
# alpha, so this cell is slow (the first run may also download the dataset).
(accuracy_histories, drift_histories) = run_multi_alpha_experiment()


Using device: cuda
=== FedAvg run (alpha=100.0) ===


  2%|▏         | 4.26M/170M [00:40<1:26:13, 32.1kB/s]