# Adversarial Federated Learning Lab

Welcome to Module 4's hands-on lab. You'll orchestrate a surrogate-driven poisoning attack inside this notebook by wiring together the same building blocks the adversary would use in code.

> Use this lab as a guided worksheet: read each section, execute the code, and jot down observations. Feel free to duplicate the notebook if you want to keep notes or alternative attack settings.

## 1. Environment Setup

We start by locating the project root so we can import the Module 4 package. Because the package name begins with a digit, it is not a valid Python identifier, so we rely on `importlib.import_module` instead of the usual `from ... import ...` form.
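To see why the string-based import is needed, the sketch below builds a throwaway package whose name also starts with a digit. `4_demo_pkg` is a hypothetical name created in a temporary directory for this demo only. The `from` statement fails to compile, while `import_module` loads the same module without complaint.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical package name used only for this demo; like "4_Adversarial_FL",
# it starts with a digit and is therefore not a valid Python identifier.
PKG_NAME = "4_demo_pkg"
print(PKG_NAME.isidentifier())  # False

# The import statement is rejected by the parser before any module lookup happens.
try:
    compile(f"from {PKG_NAME} import helpers", "<demo>", "exec")
    statement_compiles = True
except SyntaxError:
    statement_compiles = False
print(statement_compiles)  # False

# Build the throwaway package on disk and put its parent directory on sys.path.
tmp_root = Path(tempfile.mkdtemp())
pkg_dir = tmp_root / PKG_NAME
pkg_dir.mkdir()
(pkg_dir / "__init__.py").write_text("")
(pkg_dir / "helpers.py").write_text("def greet():\n    return 'hello'\n")
sys.path.insert(0, str(tmp_root))
importlib.invalidate_caches()

# import_module takes the dotted name as a plain string, so the leading digit is fine.
helpers = importlib.import_module(f"{PKG_NAME}.helpers")
print(helpers.greet())  # hello
```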

In [None]:
from pathlib import Path
import sys
from importlib import import_module

PROJECT_ROOT = Path.cwd().resolve()
if not (PROJECT_ROOT / "4_Adversarial_FL").exists():
    PROJECT_ROOT = PROJECT_ROOT.parent
PACKAGE_ROOT = PROJECT_ROOT / "4_Adversarial_FL"

if not PACKAGE_ROOT.exists():
    raise RuntimeError("Run this notebook from the repo root or the 4_Adversarial_FL directory.")

if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import torch
import matplotlib.pyplot as plt

load_data_module = import_module("4_Adversarial_FL.load_data_for_clients")
client_module = import_module("4_Adversarial_FL.client")
malicious_module = import_module("4_Adversarial_FL.malicious_client")
model_module = import_module("4_Adversarial_FL.model")
utils_module = import_module("4_Adversarial_FL.util_functions")

Client = client_module.Client
MaliciousClient = malicious_module.MaliciousClient
MobileNetV3Transfer = model_module.MobileNetV3Transfer

set_seed = utils_module.set_seed
evaluate_fn = utils_module.evaluate_fn
resolve_callable = utils_module.resolve_callable
dist_data_per_client = load_data_module.dist_data_per_client

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using project root: {PROJECT_ROOT}")
print(f"Device detected: {DEVICE}")

## 2. Lab Parameters

Tweak these knobs to explore different scenarios. The defaults keep the run lightweight so you can iterate quickly on CPU.

In [None]:
lab_config = {
    "seed": 27,
    "data": {
        "dataset_path": "./data",
        "dataset_name": "CIFAR10",
        "non_iid_per": 0.2,
    },
    "federated": {
        "num_clients": 4,
        "fraction_clients": 0.75,
        "num_rounds": 2,
        "num_epochs": 1,
        "batch_size": 16,
        "local_lr": 0.05,
        "criterion": "torch.nn.CrossEntropyLoss",
    },
    "model": {
        "num_classes": 10,
        "pretrained": False,
    },
    "malicious": {
        "fraction": 0.5,
        "seed": 2024,
        "attack": {
            "type": "pgd",
            "poison_rate": 0.4,
            "target_label": 0,
            "epsilon": 0.03,
            "step_size": 0.007,
            "iters": 5,
            "criterion": "torch.nn.CrossEntropyLoss",
        },
        "surrogate": {
            "pretrained": False,
            "lr": 5e-4,
            "finetune_epochs": 1,
            "batch_size": 16,
            "num_classes": 10,
        },
    },
}
lab_config

Adjust the parameters above as you go. Increasing `num_rounds` or enabling pretrained weights will make runs slower but can highlight longer-term effects.
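When exploring many scenarios, it can help to derive variants from the base config instead of editing it in place. The helper below, `with_overrides`, is a hypothetical convenience (not part of the Module 4 package): it deep-copies a nested dict and replaces keys addressed by dotted paths, raising `KeyError` on typos rather than silently adding new settings.

```python
from copy import deepcopy


def with_overrides(config, overrides):
    """Return a deep copy of `config` with dotted-path keys replaced.

    Unknown keys raise KeyError so a typo cannot silently add a new setting.
    """
    updated = deepcopy(config)
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = updated
        for key in parents:
            node = node[key]  # a missing parent key raises KeyError here
        if leaf not in node:
            raise KeyError(f"Unknown config key: {dotted_key}")
        node[leaf] = value
    return updated


# Trimmed-down stand-in for lab_config, just enough to show the mechanics.
base = {"federated": {"num_rounds": 2}, "malicious": {"attack": {"poison_rate": 0.4}}}
variant = with_overrides(base, {"federated.num_rounds": 5, "malicious.attack.poison_rate": 0.8})
print(variant["federated"]["num_rounds"], variant["malicious"]["attack"]["poison_rate"])  # 5 0.8
print(base["federated"]["num_rounds"])  # 2, the original is untouched
```

In the notebook you could write `lab_config = with_overrides(lab_config, {"malicious.attack.poison_rate": 0.6})`. Remember to rerun the attack-payload cell as well, because `attack_payload` is built from `lab_config` only once.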

In [None]:
from pprint import pprint

print("Current configuration:")
pprint(lab_config)

## 3. Prepare Client Datasets

We reuse Module 3's data pipeline to split CIFAR-10 across clients. The helper caches splits on disk, so subsequent runs are fast.

In [None]:
set_seed(lab_config["seed"])

client_loaders, test_loader = dist_data_per_client(
    data_path=lab_config["data"]["dataset_path"],
    dataset_name=lab_config["data"]["dataset_name"],
    num_clients=lab_config["federated"]["num_clients"],
    batch_size=lab_config["federated"]["batch_size"],
    non_iid_per=lab_config["data"]["non_iid_per"],
    device=DEVICE,
)

print(f"Prepared {len(client_loaders)} client loaders")
print(f"Test loader batches: {len(test_loader)}")
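The notebook does not show how `dist_data_per_client` interprets `non_iid_per`. As a purely hypothetical illustration of label skew, the sketch below sends a `non_iid_per` share of the samples to the client that "owns" their label (`label % num_clients`) and deals the remainder out evenly at random. The real helper may use a different scheme entirely. `label_skew_split` is an illustrative name, not part of the package.

```python
import numpy as np


def label_skew_split(labels, num_clients, non_iid_per, seed=0):
    """Hypothetical label-skew partition returning one index list per client.

    A `non_iid_per` share of the samples goes to the client that owns their label
    (label % num_clients); the remaining samples are split evenly at random.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    shuffled = rng.permutation(len(labels))
    num_skewed = int(non_iid_per * len(labels))
    skewed, remainder = shuffled[:num_skewed], shuffled[num_skewed:]

    buckets = [[] for _ in range(num_clients)]
    for idx in skewed:
        buckets[int(labels[idx]) % num_clients].append(int(idx))
    for client_id, chunk in enumerate(np.array_split(remainder, num_clients)):
        buckets[client_id].extend(chunk.tolist())
    return buckets


# 1,000 synthetic labels (100 per class) split across 4 clients with 20% label skew.
labels = np.repeat(np.arange(10), 100)
buckets = label_skew_split(labels, num_clients=4, non_iid_per=0.2, seed=27)
for client_id, idxs in enumerate(buckets):
    print(client_id, len(idxs), np.bincount(labels[idxs], minlength=10))
```

Raising `non_iid_per` concentrates more of each class on a single client. This is the kind of heterogeneity that can make a poisoned update harder to tell apart from an honest but unusual one.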

_Optional:_ peek at a batch to understand the local data. Uncomment the block below to inspect the batch shape and the first few labels.

In [None]:
# batch_images, batch_labels = next(iter(client_loaders[0]))
# batch_images.shape, batch_labels[:8]
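A single batch says little about a client's overall label mix. The helper below counts labels across a whole loader with `torch.bincount`. It assumes each batch is an `(images, labels)` pair with integer labels, which matches how the batch is unpacked above. `label_histogram` is a hypothetical name, and the synthetic loader stands in for a real client loader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def label_histogram(loader, num_classes):
    """Count how often each class label appears across every batch of `loader`."""
    counts = torch.zeros(num_classes, dtype=torch.long)
    for _, labels in loader:
        counts += torch.bincount(labels.cpu(), minlength=num_classes)
    return counts


# Synthetic stand-in for a client loader: 50 random 3x32x32 "images" with labels 0-9.
torch.manual_seed(0)
fake_images = torch.rand(50, 3, 32, 32)
fake_labels = torch.randint(0, 10, (50,))
fake_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=16)
print(label_histogram(fake_loader, num_classes=10))
```

With the real data, `label_histogram(client_loaders[0], lab_config["model"]["num_classes"])` would show how skewed the first client's split is.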

## 4. Build Honest and Malicious Clients

Instead of hiding setup logic in a runner, we construct the client pool step by step. This makes it clear how malicious participants differ from honest ones.

In [None]:
from copy import deepcopy
from pprint import pprint

criterion_path = lab_config["federated"]["criterion"]
client_criterion = resolve_callable(criterion_path)()

attack_payload = {
    key: deepcopy(value)
    for key, value in lab_config["malicious"].items()
    if key not in {"fraction"}
}

print("Attack payload passed to malicious clients:")
pprint(attack_payload)
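The notebook does not show how `MaliciousClient` consumes this payload. One plausible reading of the `attack` block, which is an assumption and not the package's actual code, is dirty-label poisoning. Under that reading, a `poison_rate` share of each batch is pushed toward `target_label` with targeted L∞ PGD (`epsilon`, `step_size`, `iters`), using the surrogate model for gradients, and those samples are then relabelled. The sketch also assumes pixel values in [0, 1]. If the CIFAR-10 loaders normalize their inputs, the clamp bounds would need to change. `targeted_pgd` and `poison_batch` are hypothetical names.

```python
import torch
import torch.nn as nn


def targeted_pgd(surrogate, images, target_label, epsilon, step_size, iters):
    """Targeted L-infinity PGD against `surrogate`, assuming pixel values in [0, 1]."""
    surrogate.eval()
    targets = torch.full((images.size(0),), target_label, dtype=torch.long, device=images.device)
    adv = images.clone().detach()
    for _ in range(iters):
        adv.requires_grad_(True)
        loss = nn.functional.cross_entropy(surrogate(adv), targets)
        (grad,) = torch.autograd.grad(loss, adv)
        # Step *down* the loss toward the target class, then project into the epsilon ball.
        adv = adv.detach() - step_size * grad.sign()
        adv = torch.max(torch.min(adv, images + epsilon), images - epsilon).clamp(0.0, 1.0)
    return adv.detach()


def poison_batch(surrogate, images, labels, attack):
    """Hypothetical helper: perturb a `poison_rate` share of the batch and relabel it."""
    num_poison = int(attack["poison_rate"] * images.size(0))
    images, labels = images.clone(), labels.clone()
    if num_poison == 0:
        return images, labels
    chosen = torch.randperm(images.size(0))[:num_poison]
    images[chosen] = targeted_pgd(
        surrogate, images[chosen], attack["target_label"],
        attack["epsilon"], attack["step_size"], attack["iters"],
    )
    labels[chosen] = attack["target_label"]
    return images, labels


# Tiny stand-in surrogate and batch; the values mirror the lab's default attack block.
torch.manual_seed(0)
surrogate = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
batch_images, batch_labels = torch.rand(16, 3, 8, 8), torch.randint(0, 10, (16,))
attack = {"poison_rate": 0.4, "target_label": 0, "epsilon": 0.03, "step_size": 0.007, "iters": 5}
poisoned_images, poisoned_labels = poison_batch(surrogate, batch_images, batch_labels, attack)
```

Note that `iters * step_size` (0.035) exceeds `epsilon` (0.03) with the defaults. The projection step is what keeps every pixel within the budget.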

In [None]:
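def create_clients(malicious_fraction: float):
    """Instantiate honest and malicious clients for the experiment."""
    malicious_fraction = max(0.0, min(1.0, malicious_fraction))
    num_clients = len(client_loaders)
    num_malicious = int(np.floor(num_clients * malicious_fraction))

    # A dedicated generator leaves the global NumPy RNG untouched, so choosing attackers
    # does not shift the per-round client sampling between clean and poisoned runs.
    selection_rng = np.random.default_rng(lab_config["seed"])
    malicious_ids = []
    if num_malicious:
        malicious_ids = sorted(selection_rng.choice(num_clients, size=num_malicious, replace=False).tolist())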

    clients = []
    for idx, loader in enumerate(client_loaders):
        if idx in malicious_ids:
            client = MaliciousClient(
                client_id=idx,
                local_data=loader,
                device=DEVICE,
                num_epochs=lab_config["federated"]["num_epochs"],
                criterion=client_criterion,
                lr=lab_config["federated"]["local_lr"],
                attack_config=deepcopy(attack_payload),
            )
        else:
            client = Client(
                client_id=idx,
                local_data=loader,
                device=DEVICE,
                num_epochs=lab_config["federated"]["num_epochs"],
                criterion=client_criterion,
                lr=lab_config["federated"]["local_lr"],
            )
        clients.append(client)

    return clients, malicious_ids
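Because `create_clients` rounds the attacker count down, small fractions can silently produce zero malicious clients when there are only four participants. The snippet below reproduces that arithmetic in isolation. `num_malicious_clients` is a hypothetical name, but the clamp-then-floor logic is the same as in the function above.

```python
import math


def num_malicious_clients(num_clients, malicious_fraction):
    """Same arithmetic as create_clients: clamp the fraction to [0, 1], then round down."""
    malicious_fraction = max(0.0, min(1.0, malicious_fraction))
    return int(math.floor(num_clients * malicious_fraction))


for fraction in (0.0, 0.2, 0.25, 0.5, 0.74, 1.5):
    print(fraction, num_malicious_clients(4, fraction))
# 0.2 -> 0 attackers, 0.25 -> 1, 0.74 -> 2, and 1.5 is clamped to all 4 clients.
```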

## 5. Federated Training Loop

The function below spells out the FedAvg routine in plain view: sample clients, broadcast the global model, run local updates, then average the resulting weights uniformly across the sampled clients. It also logs which malicious client ids participate in each round.

In [None]:
def run_fedavg(malicious_fraction: float, *, label: str, verbose: bool = True):
    set_seed(lab_config["seed"])

    clients, malicious_ids = create_clients(malicious_fraction)
    num_clients = len(clients)
    sample_fraction = lab_config["federated"]["fraction_clients"]
    num_rounds = lab_config["federated"]["num_rounds"]

    model_kwargs = {
        "num_classes": lab_config["model"]["num_classes"],
        "pretrained": lab_config["model"]["pretrained"],
    }
    global_model = MobileNetV3Transfer(**model_kwargs).to(DEVICE)
    eval_criterion = resolve_callable(lab_config["federated"]["criterion"])()

    history = []
    metrics = {"loss": [], "accuracy": []}

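    # Dedicated generator: per-round sampling depends only on the lab seed, not on how
    # many global NumPy draws happened earlier (e.g. while choosing malicious clients).
    sampling_rng = np.random.default_rng(lab_config["seed"])
    num_sampled = max(1, int(np.floor(sample_fraction * num_clients)))

    for round_idx in range(num_rounds):
        sampled = sorted(sampling_rng.choice(num_clients, size=num_sampled, replace=False).tolist())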
        active_malicious = [idx for idx in sampled if idx in malicious_ids]
        history.append(
            {
                "round": round_idx + 1,
                "sampled": sampled,
                "malicious": active_malicious,
            }
        )

        # Broadcast: each sampled client starts from an independent copy of the global model.
        for idx in sampled:
            clients[idx].x = deepcopy(global_model)

        # Local training on every sampled client.
        for idx in sampled:
            clients[idx].client_update()

        # Uniform averaging over the full state dict, so BatchNorm running statistics are
        # aggregated together with the trainable parameters.
        client_states = [clients[idx].y.state_dict() for idx in sampled]
        averaged_state = {}
        for key, global_value in global_model.state_dict().items():
            if global_value.is_floating_point():
                stacked = torch.stack([state[key].to(DEVICE) for state in client_states])
                averaged_state[key] = stacked.mean(dim=0)
            else:
                # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot be averaged.
                averaged_state[key] = client_states[0][key].clone()
        global_model.load_state_dict(averaged_state)

        loss, acc = evaluate_fn(test_loader, global_model, eval_criterion, DEVICE)
        metrics["loss"].append(loss)
        metrics["accuracy"].append(acc)

        if verbose:
            print(
                f"[{label}] Round {round_idx + 1}: sampled={sampled} malicious={active_malicious} "
                f"loss={loss:.4f} acc={acc:.2f}%"
            )

    return {
        "label": label,
        "global_model": global_model,
        "metrics": metrics,
        "history": history,
        "malicious_ids": malicious_ids,
    }
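`run_fedavg` gives every sampled client one equal vote. FedAvg as introduced by McMahan et al. instead weights client k's model by n_k / n, its share of the training samples in the round. That difference matters if the split leaves clients with unequal dataset sizes. The sketch below is a standalone weighted version over state dicts. `fedavg_state_dicts` is a hypothetical name, and the tiny `nn.Linear` models stand in for the real clients.

```python
import torch
import torch.nn as nn


def fedavg_state_dicts(state_dicts, num_samples):
    """Average state dicts, weighting client k by n_k / sum(n); integer buffers come from client 0."""
    total = float(sum(num_samples))
    averaged = {}
    for key, first_value in state_dicts[0].items():
        if first_value.is_floating_point():
            averaged[key] = sum(sd[key] * (n / total) for sd, n in zip(state_dicts, num_samples))
        else:
            averaged[key] = first_value.clone()
    return averaged


torch.manual_seed(0)
client_models = [nn.Linear(3, 2) for _ in range(3)]
states = [m.state_dict() for m in client_models]

# Equal dataset sizes: weighted FedAvg reduces to the uniform mean used in run_fedavg.
equal = fedavg_state_dicts(states, [100, 100, 100])
uniform_weight = torch.stack([m.weight.detach() for m in client_models]).mean(dim=0)
print(torch.allclose(equal["weight"], uniform_weight))  # True

# Skewed sizes pull the global model toward the data-rich client.
skewed = fedavg_state_dicts(states, [1000, 10, 10])
global_model = nn.Linear(3, 2)
global_model.load_state_dict(skewed)
```

To try this in the lab, you could pass `len(client_loaders[idx].dataset)` as each client's sample count, assuming the loaders are standard `DataLoader` objects. Note that weighting also lets an attacker with a large local dataset gain more influence.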

## 6. Run Clean vs. Poisoned Experiments

Set `RUN_TRAINING` to `True` when you are ready. Start with the defaults, then iterate on the configuration to see how the attack strength shifts the curves.

In [None]:
RUN_TRAINING = False
clean_run = None
attack_run = None

if RUN_TRAINING:
    clean_run = run_fedavg(0.0, label="Clean baseline")
    attack_run = run_fedavg(lab_config["malicious"]["fraction"], label="Poisoned run")

## 7. Visualise Metrics

Overlay the clean and poisoned trajectories to spot divergence. Whenever you change the configuration, rerun the training cell first and then this cell to refresh the plots.

In [None]:
if clean_run and attack_run:
    rounds = range(1, len(clean_run["metrics"]["accuracy"]) + 1)
    plt.figure(figsize=(8, 4))
    plt.plot(rounds, clean_run["metrics"]["accuracy"], marker="o", label=clean_run["label"])
    plt.plot(rounds, attack_run["metrics"]["accuracy"], marker="o", label=attack_run["label"])
    plt.xlabel("Communication round")
    plt.ylabel("Accuracy (%)")
    plt.title("Clean vs. poisoned global accuracy")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    plt.figure(figsize=(8, 4))
    plt.plot(rounds, clean_run["metrics"]["loss"], marker="o", label=clean_run["label"])
    plt.plot(rounds, attack_run["metrics"]["loss"], marker="o", label=attack_run["label"])
    plt.xlabel("Communication round")
    plt.ylabel("Cross-entropy loss")
    plt.title("Clean vs. poisoned global loss")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
else:
    print("Train both runs first by setting RUN_TRAINING = True.")
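Eyeballing two short curves can be ambiguous, so a numeric summary helps pin down where they separate. The hypothetical helper below computes the per-round accuracy gap and the first round where it exceeds a threshold. The 1.0 percentage-point default is an arbitrary choice.

```python
def divergence_report(clean_acc, attack_acc, threshold=1.0):
    """Return per-round accuracy gaps (clean minus poisoned, percentage points)
    and the first round whose gap exceeds `threshold`, or None if none does."""
    gaps = [clean - poisoned for clean, poisoned in zip(clean_acc, attack_acc)]
    first_round = next((rnd for rnd, gap in enumerate(gaps, start=1) if gap > threshold), None)
    return gaps, first_round


# Made-up accuracies purely to show the output format; these are not lab results.
gaps, first_round = divergence_report([40.0, 52.5, 60.0], [39.5, 45.0, 41.0])
print(gaps, first_round)  # [0.5, 7.5, 19.0] 2
```

Once both runs exist, call `divergence_report(clean_run["metrics"]["accuracy"], attack_run["metrics"]["accuracy"])`.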

### Round Participation Log

Inspect which clients (and especially which malicious ids) were active each round.

In [None]:
if attack_run:
    print("Malicious client ids:", attack_run["malicious_ids"])
    print("Round activity (poisoned run):")
    pprint(attack_run["history"])
else:
    print("No poisoned run recorded yet.")
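The log can also be summarised, and the expected exposure can be worked out in advance. When k of n clients are sampled uniformly without replacement and m are malicious, the probability that a round contains at least one attacker is 1 - C(n - m, k) / C(n, k). With the defaults (4 clients, `fraction` 0.5 giving 2 attackers, `fraction_clients` 0.75 giving 3 sampled per round), only 2 honest clients exist, so every round necessarily includes an attacker. Both helpers below are hypothetical names, and the history is illustrative and not real output.

```python
import math
from collections import Counter


def malicious_exposure(history):
    """Count appearances per malicious id and the share of rounds with any attacker."""
    per_client = Counter(idx for entry in history for idx in entry["malicious"])
    rounds_hit = sum(1 for entry in history if entry["malicious"])
    share = rounds_hit / len(history) if history else 0.0
    return per_client, share


def prob_round_has_attacker(num_clients, num_malicious, num_sampled):
    """P(at least one malicious client) when sampling without replacement."""
    clean_only = math.comb(num_clients - num_malicious, num_sampled)  # 0 if too few honest clients
    return 1.0 - clean_only / math.comb(num_clients, num_sampled)


# Illustrative history in the same shape run_fedavg records (not real output).
example_history = [
    {"round": 1, "sampled": [0, 1, 3], "malicious": [1, 3]},
    {"round": 2, "sampled": [0, 1, 2], "malicious": [1]},
    {"round": 3, "sampled": [0, 2, 3], "malicious": [3]},
    {"round": 4, "sampled": [0, 1, 2], "malicious": [1]},
]
per_client, share = malicious_exposure(example_history)
print(per_client, share)                   # Counter({1: 3, 3: 2}) 1.0
print(prob_round_has_attacker(4, 2, 3))    # 1.0 with the lab defaults
print(prob_round_has_attacker(10, 2, 3))   # about 0.533 with 10 clients
```

Applied to real data, `malicious_exposure(attack_run["history"])` shows how often each attacker actually reached the aggregation step.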

## 8. Reflection Prompts

- When do the poisoned and clean curves start to diverge? How does that relate to the sampled malicious clients above?
- How does changing `poison_rate` or the number of rounds affect the attack's stealthiness?
- Try swapping the attack type (e.g. FGSM or random noise) and observe which metrics are most diagnostic.
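For the last prompt, the alternatives differ from PGD mainly in how they use the gradient. The sketch below contrasts a single-step targeted FGSM with a model-agnostic random-noise baseline, both bounded by the same `epsilon`. It assumes nothing about how the package names or implements its own attack types. As with the PGD sketch, pixel values are assumed to lie in [0, 1], and the function names are hypothetical.

```python
import torch
import torch.nn as nn


def fgsm_targeted(model, images, target_label, epsilon):
    """Single-step targeted FGSM: one signed-gradient step of size epsilon toward target_label."""
    model.eval()
    images = images.clone().detach().requires_grad_(True)
    targets = torch.full((images.size(0),), target_label, dtype=torch.long, device=images.device)
    loss = nn.functional.cross_entropy(model(images), targets)
    (grad,) = torch.autograd.grad(loss, images)
    return (images - epsilon * grad.sign()).clamp(0.0, 1.0).detach()


def random_noise(images, epsilon):
    """Baseline: uniform noise in [-epsilon, epsilon] that ignores the model entirely."""
    noise = (torch.rand_like(images) * 2 - 1) * epsilon
    return (images + noise).clamp(0.0, 1.0)


torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
clean = torch.rand(4, 3, 8, 8)
fgsm_batch = fgsm_targeted(model, clean, target_label=0, epsilon=0.03)
noisy_batch = random_noise(clean, epsilon=0.03)
```

If random noise barely moves the metrics while FGSM or PGD does, that suggests the damage comes from gradient direction, not from perturbation size alone.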