In [1]:
!git clone https://github.com/Soobiwan/atml-pa4.git
import sys
sys.path.append("/kaggle/working/atml-pa4")

Cloning into 'atml-pa4'...
remote: Enumerating objects: 67, done.[K
remote: Counting objects: 100% (67/67), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 67 (delta 30), reused 53 (delta 16), pack-reused 0 (from 0)[K
Receiving objects: 100% (67/67), 11.62 MiB | 51.07 MiB/s, done.
Resolving deltas: 100% (30/30), done.


In [2]:
import json

def save_log(log, filename: str, keys=("accuracy", "drift")) -> None:
    """
    Save selected metric histories from a federated learning log to JSON.

    Only the entries named in ``keys`` (by default ``"accuracy"`` and
    ``"drift"``) are written; a key missing from ``log`` is stored as ``[]``.

    Args:
        log: Mapping returned by an experiment run (anything with ``.get``).
        filename: Destination path of the JSON file.
        keys: Names of the log entries to keep.
    """
    filtered_log = {key: log.get(key, []) for key in keys}
    # Serialise before opening the file so a failure cannot leave a truncated
    # JSON file behind. NOTE(review): histories may contain numpy/torch
    # scalars, which json cannot encode natively; default=float converts them.
    payload = json.dumps(filtered_log, default=float)
    with open(filename, "w") as f:
        f.write(payload)
    print(f"Filtered log saved to {filename}")


def load_log(filename: str) -> dict:
    """
    Load a log dictionary previously written by ``save_log``.

    Args:
        filename: Path of the JSON file to load.

    Returns:
        dict: The decoded log, e.g. ``{"accuracy": [...], "drift": [...]}``.
    """
    with open(filename, "r") as f:
        return json.load(f)

In [3]:
import torch

from federated import data_utils
from federated.methods import FEDERATION_METHODS
from federated.visualization import plot_accuracy_histories, plot_drift_histories
from federated.models import SimpleCNN

# Train on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

[INFO] Cloning SAM into external/sam ...


Cloning into 'external/sam'...


Using device: cuda


In [5]:
def run_federated_experiment(
    method_name: str,
    *,
    alpha: float,
    num_rounds: int,
    local_epochs: int,
    num_clients: int = 10,
    batch_size: int = 32,
    lr: float = 0.01,
    client_fraction: float = 1.0,
    seed: int = 42,
    model_class=None,
    device=None,
):
    """
    Run one federated learning experiment on CIFAR-10 with any supported
    method (fedavg, fedprox, scaffold, gh, fedsam).

    Args:
        method_name: Case-insensitive key into ``FEDERATION_METHODS``.
        alpha: Heterogeneity parameter forwarded to the method runner
            (smaller = more skewed client data). Must be > 0.
        num_rounds: Number of communication rounds (>= 1).
        local_epochs: Local epochs per client per round (>= 1).
        num_clients: Number of simulated clients (>= 1).
        batch_size: Client mini-batch size (>= 1).
        lr: Client learning rate.
        client_fraction: Fraction of clients used each round, in (0, 1].
        seed: Random seed forwarded to the method runner.
        model_class: Model class to train; ``None`` means ``SimpleCNN``.
        device: Torch device to train on; ``None`` means the notebook's ``DEVICE``.

    Returns:
        Whatever the selected runner returns; the cells in this notebook index
        it with ``"accuracy"`` and ``"drift"``.

    Raises:
        ValueError: If the method is unknown or a hyperparameter is out of range.
    """
    # Validate cheap arguments first so a typo fails before the slow dataset build.
    if not alpha > 0:  # written this way so NaN is rejected too
        raise ValueError(f"alpha must be > 0, got {alpha}")
    for arg_name, value in (
        ("num_rounds", num_rounds),
        ("local_epochs", local_epochs),
        ("num_clients", num_clients),
        ("batch_size", batch_size),
    ):
        if value < 1:
            raise ValueError(f"{arg_name} must be >= 1, got {value}")
    if not 0 < client_fraction <= 1:
        raise ValueError(f"client_fraction must be in (0, 1], got {client_fraction}")

    method_key = method_name.strip().lower()
    if method_key not in FEDERATION_METHODS:
        raise ValueError(
            f"Unknown federated method: {method_name}. "
            f"Choose from: {list(FEDERATION_METHODS.keys())}"
        )
    # Look up the correct federated method runner
    method = FEDERATION_METHODS[method_key]

    # Resolve defaults at call time (not definition time) so they track the
    # current notebook globals.
    if model_class is None:
        model_class = SimpleCNN
    if device is None:
        device = DEVICE

    # Build dataset
    train_dataset, test_dataset = data_utils.build_cifar10_datasets()
    test_loader = data_utils.get_test_loader(test_dataset)

    print(f"\n=== Running {method_key.upper()} ===")
    print(f"Clients: {num_clients}, α={alpha}, rounds={num_rounds}, "
          f"K={local_epochs}, lr={lr}, frac={client_fraction}")

    return method(
        model_class=model_class,
        train_dataset=train_dataset,
        test_loader=test_loader,
        num_clients=num_clients,
        alpha=alpha,
        batch_size=batch_size,
        num_rounds=num_rounds,
        local_epochs=local_epochs,
        lr=lr,
        client_fraction=client_fraction,
        device=device,
        seed=seed,
    )

# Hyperparameters

### Alpha (data heterogeneity)
Smaller α means more skewed (non-IID) client data:
- Highly skewed: α = 0.1
- Moderately skewed: α = 1
- No skew (≈ IID): α = 100

### Learning Rate = 0.01

### Communication rounds = 15
We chose 15 rounds because, in our runs, accuracy for most models plateaus around this point; training further leads to overfitting.

### Local epochs (K) = 5
Five local epochs per round give each client enough local updates to expose client-drift effects, while keeping computation low enough to run locally.

# FedAvg

Optimizer: SGD, momentum = 0.9, weight_decay = 5e-4

### Highly skewed

In [None]:
# FedAvg, highly skewed split. The run configuration lives in one dict so the
# plot titles are always derived from the values actually used.
run_cfg = {"alpha": 0.1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedavg_high_logs = run_federated_experiment("fedavg", **run_cfg)

plot_accuracy_histories(
    {"fedavg": fedavg_high_logs["accuracy"]},
    title=f"FedAvg Accuracy vs {run_desc}",
    save_path="/kaggle/working/fedavg_high_acc.png",
)

# Method name in the title keeps drift figures distinguishable; the path is
# absolute like every other figure in this notebook.
plot_drift_histories(
    {"fedavg": fedavg_high_logs["drift"]},
    title=f"FedAvg Client Drift vs {run_desc}",
    save_path="/kaggle/working/fedavg_high_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedavg_high_logs, "/kaggle/working/fedavg_high_logs.json")


=== Running FEDAVG ===
Clients: 10, α=0.1, rounds=15, K=5, lr=0.01, frac=1.0
Round 1/15 | Acc 38.13% | Drift 0.47%
Round 2/15 | Acc 44.02% | Drift 0.46%
Round 3/15 | Acc 49.77% | Drift 0.38%
Round 4/15 | Acc 52.85% | Drift 0.30%


### Moderately Skewed

In [None]:
# FedAvg, moderately skewed split; titles are derived from run_cfg.
run_cfg = {"alpha": 1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedavg_medium_logs = run_federated_experiment("fedavg", **run_cfg)

plot_accuracy_histories(
    {"fedavg": fedavg_medium_logs["accuracy"]},
    title=f"FedAvg Accuracy vs {run_desc}",
    save_path="/kaggle/working/fedavg_medium_acc.png",
)

plot_drift_histories(
    {"fedavg": fedavg_medium_logs["drift"]},
    title=f"FedAvg Client Drift vs {run_desc}",
    save_path="/kaggle/working/fedavg_medium_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedavg_medium_logs, "/kaggle/working/fedavg_medium_logs.json")

### Not skewed

In [None]:
# FedAvg, (near-)IID split; titles are derived from run_cfg.
run_cfg = {"alpha": 100, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedavg_low_logs = run_federated_experiment("fedavg", **run_cfg)

plot_accuracy_histories(
    {"fedavg": fedavg_low_logs["accuracy"]},
    title=f"FedAvg Accuracy vs {run_desc}",
    save_path="/kaggle/working/fedavg_low_acc.png",
)

plot_drift_histories(
    {"fedavg": fedavg_low_logs["drift"]},
    title=f"FedAvg Client Drift vs {run_desc}",
    save_path="/kaggle/working/fedavg_low_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedavg_low_logs, "/kaggle/working/fedavg_low_logs.json")

# FedProx

Proximal coefficient μ = 0.1 — the most commonly reported setting, and the one that performs best on CIFAR-10.

### Highly Skewed

In [None]:
# FedProx, highly skewed split; titles are derived from run_cfg.
run_cfg = {"alpha": 0.1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedprox_high_logs = run_federated_experiment("fedprox", **run_cfg)

plot_accuracy_histories(
    {"fedprox": fedprox_high_logs["accuracy"]},
    title=f"FedProx Accuracy vs {run_desc}",
    save_path="/kaggle/working/fedprox_high_acc.png",
)

plot_drift_histories(
    {"fedprox": fedprox_high_logs["drift"]},
    title=f"FedProx Client Drift vs {run_desc}",
    save_path="/kaggle/working/fedprox_high_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedprox_high_logs, "/kaggle/working/fedprox_high_logs.json")

### Moderately Skewed

In [None]:
# FedProx, moderately skewed split; titles are derived from run_cfg.
run_cfg = {"alpha": 1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedprox_medium_logs = run_federated_experiment("fedprox", **run_cfg)

plot_accuracy_histories(
    {"fedprox": fedprox_medium_logs["accuracy"]},
    title=f"FedProx Accuracy vs {run_desc}",
    save_path="/kaggle/working/fedprox_medium_acc.png",
)

plot_drift_histories(
    {"fedprox": fedprox_medium_logs["drift"]},
    title=f"FedProx Client Drift vs {run_desc}",
    save_path="/kaggle/working/fedprox_medium_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedprox_medium_logs, "/kaggle/working/fedprox_medium_logs.json")

### Not skewed

In [None]:
# FedProx, (near-)IID split; titles are derived from run_cfg.
run_cfg = {"alpha": 100, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedprox_low_logs = run_federated_experiment("fedprox", **run_cfg)

plot_accuracy_histories(
    {"fedprox": fedprox_low_logs["accuracy"]},
    title=f"FedProx Accuracy vs {run_desc}",
    save_path="/kaggle/working/fedprox_low_acc.png",
)

plot_drift_histories(
    {"fedprox": fedprox_low_logs["drift"]},
    title=f"FedProx Client Drift vs {run_desc}",
    save_path="/kaggle/working/fedprox_low_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedprox_low_logs, "/kaggle/working/fedprox_low_logs.json")

# SCAFFOLD

### Highly skewed

In [None]:
# SCAFFOLD, highly skewed split; titles are derived from run_cfg.
run_cfg = {"alpha": 0.1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedscaffold_high_logs = run_federated_experiment("scaffold", **run_cfg)

plot_accuracy_histories(
    {"scaffold": fedscaffold_high_logs["accuracy"]},
    title=f"SCAFFOLD Accuracy vs {run_desc}",
    save_path="/kaggle/working/scaffold_high_acc.png",
)

plot_drift_histories(
    {"scaffold": fedscaffold_high_logs["drift"]},
    title=f"SCAFFOLD Client Drift vs {run_desc}",
    save_path="/kaggle/working/scaffold_high_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedscaffold_high_logs, "/kaggle/working/fedscaffold_high_logs.json")

### Moderately skewed

In [None]:
# SCAFFOLD, moderately skewed split; titles are derived from run_cfg.
run_cfg = {"alpha": 1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedscaffold_medium_logs = run_federated_experiment("scaffold", **run_cfg)

plot_accuracy_histories(
    {"scaffold": fedscaffold_medium_logs["accuracy"]},
    title=f"SCAFFOLD Accuracy vs {run_desc}",
    save_path="/kaggle/working/scaffold_medium_acc.png",
)

plot_drift_histories(
    {"scaffold": fedscaffold_medium_logs["drift"]},
    title=f"SCAFFOLD Client Drift vs {run_desc}",
    save_path="/kaggle/working/scaffold_medium_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedscaffold_medium_logs, "/kaggle/working/fedscaffold_medium_logs.json")

### Not skewed

In [None]:
# SCAFFOLD, (near-)IID split; titles are derived from run_cfg.
run_cfg = {"alpha": 100, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedscaffold_low_logs = run_federated_experiment("scaffold", **run_cfg)

plot_accuracy_histories(
    {"scaffold": fedscaffold_low_logs["accuracy"]},
    title=f"SCAFFOLD Accuracy vs {run_desc}",
    save_path="/kaggle/working/scaffold_low_acc.png",
)

plot_drift_histories(
    {"scaffold": fedscaffold_low_logs["drift"]},
    title=f"SCAFFOLD Client Drift vs {run_desc}",
    save_path="/kaggle/working/scaffold_low_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedscaffold_low_logs, "/kaggle/working/fedscaffold_low_logs.json")

# FedGH

### Highly skewed

In [None]:
# FedGH, highly skewed split; titles are derived from run_cfg.
run_cfg = {"alpha": 0.1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedgh_high_logs = run_federated_experiment("gh", **run_cfg)

plot_accuracy_histories(
    {"gh": fedgh_high_logs["accuracy"]},
    title=f"FedGH Accuracy vs {run_desc}",
    save_path="/kaggle/working/gh_high_acc.png",
)

plot_drift_histories(
    {"gh": fedgh_high_logs["drift"]},
    title=f"FedGH Client Drift vs {run_desc}",
    save_path="/kaggle/working/gh_high_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedgh_high_logs, "/kaggle/working/fedgh_high_logs.json")

### Moderately skewed

In [None]:
# FedGH, moderately skewed split; titles are derived from run_cfg.
run_cfg = {"alpha": 1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedgh_medium_logs = run_federated_experiment("gh", **run_cfg)

plot_accuracy_histories(
    {"gh": fedgh_medium_logs["accuracy"]},
    title=f"FedGH Accuracy vs {run_desc}",
    save_path="/kaggle/working/gh_medium_acc.png",
)

plot_drift_histories(
    {"gh": fedgh_medium_logs["drift"]},
    title=f"FedGH Client Drift vs {run_desc}",
    save_path="/kaggle/working/gh_medium_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedgh_medium_logs, "/kaggle/working/fedgh_medium_logs.json")

### Not skewed

In [None]:
# FedGH, (near-)IID split; titles are derived from run_cfg.
run_cfg = {"alpha": 100, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedgh_low_logs = run_federated_experiment("gh", **run_cfg)

plot_accuracy_histories(
    {"gh": fedgh_low_logs["accuracy"]},
    title=f"FedGH Accuracy vs {run_desc}",
    save_path="/kaggle/working/gh_low_acc.png",
)

plot_drift_histories(
    {"gh": fedgh_low_logs["drift"]},
    title=f"FedGH Client Drift vs {run_desc}",
    save_path="/kaggle/working/gh_low_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedgh_low_logs, "/kaggle/working/fedgh_low_logs.json")

# FedSAM

### Highly skewed

In [None]:
# FedSAM, highly skewed split; titles are derived from run_cfg.
run_cfg = {"alpha": 0.1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedsam_high_logs = run_federated_experiment("fedsam", **run_cfg)

plot_accuracy_histories(
    {"fedsam": fedsam_high_logs["accuracy"]},
    title=f"FedSAM Accuracy vs {run_desc}",
    save_path="/kaggle/working/fedsam_high_acc.png",
)

plot_drift_histories(
    {"fedsam": fedsam_high_logs["drift"]},
    title=f"FedSAM Client Drift vs {run_desc}",
    save_path="/kaggle/working/fedsam_high_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedsam_high_logs, "/kaggle/working/fedsam_high_logs.json")

### Moderately Skewed

In [None]:
# FedSAM, moderately skewed split; titles are derived from run_cfg.
run_cfg = {"alpha": 1, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedsam_medium_logs = run_federated_experiment("fedsam", **run_cfg)

plot_accuracy_histories(
    {"fedsam": fedsam_medium_logs["accuracy"]},
    title=f"FedSAM Accuracy vs {run_desc}",
    save_path="/kaggle/working/fedsam_medium_acc.png",
)

plot_drift_histories(
    {"fedsam": fedsam_medium_logs["drift"]},
    title=f"FedSAM Client Drift vs {run_desc}",
    save_path="/kaggle/working/fedsam_medium_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedsam_medium_logs, "/kaggle/working/fedsam_medium_logs.json")

### Not skewed

In [None]:
# FedSAM, (near-)IID split; titles are derived from run_cfg.
run_cfg = {"alpha": 100, "num_rounds": 15, "local_epochs": 5}
run_desc = f"α={run_cfg['alpha']} (K={run_cfg['local_epochs']}, R={run_cfg['num_rounds']})"

fedsam_low_logs = run_federated_experiment("fedsam", **run_cfg)

plot_accuracy_histories(
    {"fedsam": fedsam_low_logs["accuracy"]},
    title=f"FedSAM Accuracy vs {run_desc}",
    save_path="/kaggle/working/fedsam_low_acc.png",
)

plot_drift_histories(
    {"fedsam": fedsam_low_logs["drift"]},
    title=f"FedSAM Client Drift vs {run_desc}",
    save_path="/kaggle/working/fedsam_low_drift.png",
)

# Same absolute location the comparison cell loads from.
save_log(fedsam_low_logs, "/kaggle/working/fedsam_low_logs.json")

In [None]:
# Load every saved run, keyed "<method>_<skew level>". The file on disk for
# key K is /kaggle/working/K_logs.json; iteration order (method-major, then
# high/medium/low) determines the dict's insertion order.
loaded_logs = {
    f"{method}_{skew}": load_log(f"/kaggle/working/{method}_{skew}_logs.json")
    for method in ("fedavg", "fedprox", "fedscaffold", "fedgh", "fedsam")
    for skew in ("high", "medium", "low")
}

import matplotlib.pyplot as plt

def _plot_metric(logs: dict, key: str, ylabel: str, title: str, save_path) -> None:
    """Draw one line per run for ``logs[label][key]`` against communication round."""
    # Scope the ggplot style to this figure rather than changing global rcParams
    # for every later plot in the notebook.
    with plt.style.context("ggplot"):
        fig, ax = plt.subplots(figsize=(10, 6))
        for label, log in logs.items():
            history = log[key]
            ax.plot(range(1, len(history) + 1), history, marker="o", markersize=3, label=label)
        ax.set_xlabel("Communication Round")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if logs:  # avoid matplotlib's "no artists with labels" warning
            ax.legend()
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path)
        plt.show()
    plt.close(fig)


def plot_all_fed_logs(logs: dict, save_path_acc="/kaggle/working/all_acc", save_path_drift="/kaggle/working/all_drift"):
    """
    Plot accuracy and drift for multiple federated learning logs.

    Produces two figures (accuracy per round, then drift per round) with one
    line per entry of ``logs``; each figure is optionally saved, shown, then closed.

    Args:
        logs: dict of the form { "label": {"accuracy": [...], "drift": [...]}, ... }
        save_path_acc: Where to save the accuracy figure; a falsy value skips
            saving. Without a file extension matplotlib appends its default format.
        save_path_drift: Same, for the drift figure.
    """
    _plot_metric(logs, "accuracy", "Accuracy (%)", "Federated Learning Accuracies", save_path_acc)
    _plot_metric(logs, "drift", "Mean L2 Weight Divergence", "Federated Learning Client Drift", save_path_drift)
