# Adversarial Federated Learning Lab

This notebook complements the Module 4 overview. It mirrors the Module 3 workflow: load configuration from `lab_config.yaml`, run a clean FedAvg baseline, then layer on surrogate-driven poisoning attacks to compare outcomes.

- **Config keys:** `seed`, `baseline` (data + federated hyperparameters), `surrogate` (attacker training schedule), `attack` (crafting recipe), and `attack_run` (malicious participation).
- **Baseline algorithm:** FedAvg with MobileNetV3 backbone, pretrained weights assumed.
- **Attack path:** train a MobileNetV2 surrogate, craft PGD/FGSM/random-noise batches, and deploy them during federated rounds.


> Duplicate this notebook if you want to keep personalised notes or custom configurations.

## 1. Environment Setup

In [None]:
import sys
from copy import deepcopy
from importlib import import_module
from pathlib import Path

# Resolve the repository root whether the notebook is launched from the repo
# root itself or from inside the 4_Adversarial_FL directory.
PROJECT_ROOT = Path.cwd().resolve()
PROJECT_ROOT = PROJECT_ROOT if (PROJECT_ROOT / "4_Adversarial_FL").exists() else PROJECT_ROOT.parent
PACKAGE_ROOT = PROJECT_ROOT / "4_Adversarial_FL"

if not PACKAGE_ROOT.exists():
    raise RuntimeError("Run this notebook from the repo root or inside 4_Adversarial_FL.")

# Put the repo root first on sys.path so the package submodules can be
# imported by their dotted names below.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# The package name starts with a digit, so it cannot appear in a plain
# `import` statement; load each submodule through importlib instead.
load_data_module, client_module, model_module, utils_module, attacks_module = (
    import_module(f"4_Adversarial_FL.{submodule}")
    for submodule in ("load_data_for_clients", "client", "model", "util_functions", "attacks")
)

# Short aliases for the pieces of the package this notebook uses.
Client = client_module.Client
MobileNetV3Transfer = model_module.MobileNetV3Transfer
MobileNetV2Transfer = model_module.MobileNetV2Transfer
set_seed, evaluate_fn, resolve_callable = (
    utils_module.set_seed,
    utils_module.evaluate_fn,
    utils_module.resolve_callable,
)
dist_data_per_client = load_data_module.dist_data_per_client
get_attack = attacks_module.get_attack

# Per-channel ImageNet statistics, viewed as (1, 3, 1, 1) so they broadcast
# against (N, 3, H, W) batches.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def denormalize_batch(batch: torch.Tensor) -> torch.Tensor:
    """Undo ImageNet normalisation: ``x * std + mean`` on ``batch``'s device."""
    std = IMAGENET_STD.to(batch.device)
    mean = IMAGENET_MEAN.to(batch.device)
    return batch * std + mean


def normalize_batch(batch: torch.Tensor) -> torch.Tensor:
    """Apply ImageNet normalisation: ``(x - mean) / std`` on ``batch``'s device."""
    mean = IMAGENET_MEAN.to(batch.device)
    std = IMAGENET_STD.to(batch.device)
    return (batch - mean) / std


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using project root: {PROJECT_ROOT}")
print(f"Device detected: {DEVICE}")

## 2. Load Lab Configuration

In [None]:
import yaml

CONFIG_PATH = PACKAGE_ROOT / "lab_config.yaml"
# safe_load: the config is plain data, never arbitrary Python objects.
with CONFIG_PATH.open(encoding="utf-8") as f:
    CONFIG = yaml.safe_load(f) or {}  # an empty file loads as None

# Fail early with a readable message instead of an opaque TypeError/KeyError
# several lines (or cells) later.
if not isinstance(CONFIG, dict):
    raise TypeError(f"{CONFIG_PATH} must contain a YAML mapping, got {type(CONFIG).__name__}.")
REQUIRED_SECTIONS = ("baseline", "surrogate", "attack", "attack_run")
missing_sections = [name for name in REQUIRED_SECTIONS if name not in CONFIG]
if missing_sections:
    raise KeyError(f"{CONFIG_PATH} is missing required section(s): {', '.join(missing_sections)}")

SEED = CONFIG.get("seed", 27)
BASELINE_CFG = CONFIG["baseline"]
SURROGATE_CFG = CONFIG["surrogate"]
ATTACK_RECIPE = CONFIG["attack"]
ATTACK_RUN_CFG = CONFIG["attack_run"]

print("Loaded configuration from", CONFIG_PATH)
CONFIG

## 3. Clean FedAvg Baseline

### 3.1 Prepare data and clients

In [None]:
# Seed before partitioning; presumably this makes the client split reproducible.
set_seed(SEED)

# One DataLoader per client plus a shared test loader. `non_iid_per` is
# forwarded as-is (presumably the degree of non-IID skew -- confirm in
# load_data_for_clients).
CLIENT_LOADERS, TEST_LOADER = dist_data_per_client(
    data_path=BASELINE_CFG["dataset_path"],
    dataset_name=BASELINE_CFG["dataset_name"],
    num_clients=BASELINE_CFG["num_clients"],
    batch_size=BASELINE_CFG["batch_size"],
    non_iid_per=BASELINE_CFG["non_iid_per"],
    device=DEVICE,
)

# The config stores a dotted path to a loss class; instantiate it once and
# share the instance across honest clients.
criterion_path = BASELINE_CFG["criterion"]
criterion_cls = resolve_callable(criterion_path)
honest_criterion = criterion_cls()

print(f"Prepared {len(CLIENT_LOADERS)} client loaders")

In [None]:
def build_honest_clients() -> list[Client]:
    """Create one honest ``Client`` per entry of ``CLIENT_LOADERS``.

    Client ids follow loader order. Every client uses the baseline epoch count
    and local learning rate and shares the notebook-wide ``honest_criterion``.
    """
    return [
        Client(
            client_id=client_id,
            local_data=client_loader,
            device=DEVICE,
            num_epochs=BASELINE_CFG["num_epochs"],
            criterion=honest_criterion,
            lr=BASELINE_CFG["local_lr"],
        )
        for client_id, client_loader in enumerate(CLIENT_LOADERS)
    ]

In [None]:
import math


def run_fedavg_rounds(clients, config, label: str, *, verbose: bool = True, malicious_ids: list[int] | None = None):
    num_clients = len(clients)
    num_rounds = config["num_rounds"]
    fraction = config["fraction_clients"]
    model_kwargs = {
        "num_classes": BASELINE_CFG["num_classes"] if "num_classes" in BASELINE_CFG else 10,
    }
    global_lr = config.get("global_lr", 1.0)
    eval_criterion = resolve_callable(config["criterion"])()

    malicious_set = set(malicious_ids or [])
    history = []
    metrics = {"loss": [], "accuracy": []}
    global_model = MobileNetV3Transfer(num_classes=model_kwargs["num_classes"]).to(DEVICE)

    for round_idx in range(num_rounds):
        set_seed(SEED + round_idx)
        num_sampled = max(1, int(math.ceil(fraction * num_clients)))
        sampled = sorted(np.random.choice(num_clients, size=num_sampled, replace=False).tolist())
        active_malicious = [idx for idx in sampled if idx in malicious_set]
        history.append(
            {
                "round": round_idx + 1,
                "sampled": sampled,
                "malicious": active_malicious,
            }
        )

        for idx in sampled:
            client_model = MobileNetV3Transfer(num_classes=model_kwargs["num_classes"]).to(DEVICE)
            client_model.load_state_dict(global_model.state_dict())
            clients[idx].x = client_model

        for idx in sampled:
            clients[idx].client_update()

        avg_params = [torch.zeros_like(param, device=DEVICE) for param in global_model.parameters()]
        with torch.no_grad():
            for idx in sampled:
                for avg_param, client_param in zip(avg_params, clients[idx].y.parameters()):
                    avg_param.add_(client_param.data / len(sampled))
            for param, avg_param in zip(global_model.parameters(), avg_params):
                param.data.add_(global_lr * (avg_param.data - param.data))

        loss, acc = evaluate_fn(TEST_LOADER, global_model, eval_criterion, DEVICE)
        metrics["loss"].append(loss)
        metrics["accuracy"].append(acc)

        if verbose:
            print(
                f"[{label}] Round {round_idx + 1}: sampled={sampled} malicious={active_malicious} "
                f"loss={loss:.4f} acc={acc:.2f}%"
            )

    return {
        "label": label,
        "global_model": global_model,
        "metrics": metrics,
        "history": history,
    }

### 3.2 Train baseline

In [None]:
# Report the dataset actually configured rather than a hard-coded name.
print(f"Training clean MobileNetV3 FedAvg baseline on {BASELINE_CFG['dataset_name']}...")
honest_clients = build_honest_clients()
baseline_run = run_fedavg_rounds(honest_clients, BASELINE_CFG, label="Clean FedAvg")
print("Baseline training complete.")

In [None]:
def summarise_run(run: dict) -> dict:
    """Condense a run record into its label plus the last recorded loss and accuracy.

    Raises IndexError if the run recorded no rounds.
    """
    metric_history = run["metrics"]
    summary = {"label": run["label"]}
    summary["final_loss"] = metric_history["loss"][-1]
    summary["final_accuracy"] = metric_history["accuracy"][-1]
    return summary

baseline_summary = summarise_run(run=baseline_run)
baseline_summary  # shown as the cell output

## 4. Attack Primer

Before writing code, let's recap the adversary types encountered so far:

- **Data poisoning**: malicious clients inject adversarial samples into local training batches.
- **Model poisoning**: adversaries forge updates directly in parameter space (e.g., model replacement).
- **Evasion attacks**: perturb inputs at inference time only.

We will implement surrogate-driven data poisoning, where the attacker trains a proxy network to approximate the victim and transfers adversarial examples across architectures.

## 5. White-box vs. Black-box Surrogates

| Setting | Attacker knowledge | Typical strategy |
| --- | --- | --- |
| White-box | Full access to architecture + weights | Craft gradients directly on the victim model. |
| Black-box | Limited to queries or snapshots | Train a surrogate model and transfer attacks via crafted batches. |

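In federated learning, clients usually know the architecture (the server distributes it) but only see periodic snapshots of the global model. We approximate a black-box attacker by fine-tuning a MobileNetV2 surrogate on the attacker's own data shard while the server trains MobileNetV3, so any transferred attack has to bridge an architecture mismatch.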

## 6. Train Surrogate Model

In [None]:
# The attacker only sees its own shard: reuse the dataset behind the chosen
# client's loader, re-wrapped with the surrogate's own batch size.
surrogate_dataset = CLIENT_LOADERS[SURROGATE_CFG["client_id"]].dataset
surrogate_loader = DataLoader(
    surrogate_dataset,
    shuffle=True,
    batch_size=SURROGATE_CFG["batch_size"],
    num_workers=0,
)

# Surrogate architecture deliberately differs from the server's MobileNetV3
# so that crafted examples must transfer across architectures.
surrogate_model = MobileNetV2Transfer(
    num_classes=BASELINE_CFG.get("num_classes", 10),
    pretrained=True,
)
surrogate_model = surrogate_model.to(DEVICE)

surrogate_optimizer = torch.optim.Adam(params=surrogate_model.parameters(), lr=SURROGATE_CFG["lr"])
surrogate_criterion = resolve_callable(SURROGATE_CFG["criterion"])()


In [None]:
print("Fine-tuning surrogate MobileNetV2 on attacker shard...")
set_seed(SEED)
surrogate_model.train()
for epoch in range(SURROGATE_CFG["epochs"]):
    epoch_loss, total, correct = 0.0, 0, 0
    for images, labels in surrogate_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        surrogate_optimizer.zero_grad()
        logits = surrogate_model(images)
        loss = surrogate_criterion(logits, labels)
        loss.backward()
        surrogate_optimizer.step()

        # Weight by batch size so the epoch mean is exact even for a smaller
        # final batch (assuming the criterion averages over the batch).
        batch_size = labels.size(0)
        epoch_loss += loss.item() * batch_size
        total += batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()

    denominator = max(total, 1)
    avg_loss = epoch_loss / denominator
    acc = 100 * correct / denominator
    print(f"Epoch {epoch + 1}: loss={avg_loss:.4f} acc={acc:.2f}%")

In [None]:
def evaluate_model(loader, model, description: str):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            preds = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    acc = 100 * correct / max(total, 1)
    print(f"{description} accuracy: {acc:.2f}% ({correct}/{total})")


surrogate_shard_accuracy = evaluate_model(surrogate_loader, surrogate_model, description="Surrogate (local shard)")

## 7. Craft Adversarial Batches

In [None]:
attack_fn = get_attack(ATTACK_RECIPE["type"])
# Same default as build_attack_clients, so a recipe without an explicit
# criterion works in both places instead of failing here with a KeyError.
attack_criterion = resolve_callable(ATTACK_RECIPE.get("criterion", "torch.nn.CrossEntropyLoss"))()

print("Configured attack recipe:")
ATTACK_RECIPE

In [None]:
def prepare_attack_labels(labels: torch.Tensor, recipe: dict) -> torch.Tensor:
    """Return the labels the attack should be crafted against.

    Without ``recipe["target_label"]`` the original tensor is returned
    unchanged (untargeted). Otherwise a tensor of the same shape/dtype filled
    with ``int(target_label)`` is returned.
    """
    target = recipe.get("target_label")
    return labels if target is None else torch.full_like(labels, int(target))


def craft_adversarial_examples(images: torch.Tensor, labels: torch.Tensor, *, surrogate, attack_fn, attack_criterion, recipe: dict) -> torch.Tensor:
    """Perturb an ImageNet-normalised batch with ``attack_fn`` crafted on ``surrogate``.

    The attack runs in pixel space: the batch is denormalised, perturbed and
    re-normalised. Which keyword arguments ``attack_fn`` receives depends on
    ``recipe["type"]`` (case-insensitive):

    * ``"pgd"``: model/criterion/images/labels plus ``epsilon`` (default 0.03),
      ``step_size`` (0.007) and ``iters`` (5).
    * ``"fgsm"``: model/criterion/images/labels plus ``step_size`` (0.003).
    * anything else: called as ``attack_fn(images, step_size=...)`` (0.003),
      i.e. a label-free perturbation such as random noise.

    NOTE(review): ``labels`` may be target labels (see ``prepare_attack_labels``).
    Whether the attack pushes towards or away from them depends on
    ``attack_fn``'s sign convention -- confirm in the attacks module.

    Returns:
        The adversarial batch, normalised and detached from any autograd graph.
    """
    surrogate.eval()
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

    attack_name = recipe["type"].lower()
    denorm = denormalize_batch(images)

    if attack_name == "pgd":
        adv_denorm = attack_fn(
            model=surrogate,
            criterion=attack_criterion,
            images=denorm,
            labels=labels,
            eps=recipe.get("epsilon", 0.03),
            step_size=recipe.get("step_size", 0.007),
            iters=recipe.get("iters", 5),
        )
    elif attack_name == "fgsm":
        adv_denorm = attack_fn(
            model=surrogate,
            criterion=attack_criterion,
            images=denorm,
            labels=labels,
            step_size=recipe.get("step_size", 0.003),
        )
    else:
        adv_denorm = attack_fn(
            denorm,
            step_size=recipe.get("step_size", 0.003),
        )

    # Detach: poisoned samples must be constants. Otherwise they keep the
    # surrogate's attack graph alive inside the victim's training step, and
    # `.numpy()` on them (e.g. for plotting) raises.
    return normalize_batch(adv_denorm).detach()


def plot_clean_vs_adv(clean_batch: torch.Tensor, adv_batch: torch.Tensor, index: int = 0) -> None:
    """Show sample ``index`` of the clean and adversarial batches side by side.

    Both batches are expected to be ImageNet-normalised (N, 3, H, W) tensors.
    They are denormalised and clipped to [0, 1] for display.
    """
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    for ax, batch, title in zip(axes, (clean_batch, adv_batch), ("Clean", "Adversarial")):
        # detach(): attack outputs may still require grad, and `.numpy()`
        # refuses such tensors.
        image = denormalize_batch(batch[index:index + 1].detach()).squeeze(0).permute(1, 2, 0).cpu().numpy()
        ax.imshow(np.clip(image, 0, 1))
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()


In [None]:
# Craft one batch from the attacker's shard and visualise its first sample.
sample_images, sample_labels = next(iter(surrogate_loader))
attack_labels = prepare_attack_labels(sample_labels, recipe=ATTACK_RECIPE)
adversarial_images = craft_adversarial_examples(
    sample_images,
    attack_labels,
    recipe=ATTACK_RECIPE,
    surrogate=surrogate_model,
    attack_fn=attack_fn,
    attack_criterion=attack_criterion,
)

plot_clean_vs_adv(clean_batch=sample_images, adv_batch=adversarial_images, index=0)


## 8. Deploy Surrogate Attack

In [None]:
class SurrogateAttackClient(Client):
    """Malicious FL client that poisons its local batches with surrogate-crafted examples.

    Training mirrors plain local SGD on a copy of the server model. On every
    local batch, each sample is independently selected with probability
    ``recipe["poison_rate"]``. Selected samples are replaced by adversarial
    examples crafted on the attacker's MobileNetV2 surrogate. If
    ``recipe["target_label"]`` is set, they are also relabelled to that class.
    """

    def __init__(
        self,
        *,
        client_id: int,
        local_data,
        device: torch.device,
        num_epochs: int,
        criterion,
        lr: float,
        surrogate_state,
        attack_fn,
        attack_criterion,
        recipe: dict,
        num_classes: int,
    ) -> None:
        super().__init__(
            client_id=client_id,
            local_data=local_data,
            device=device,
            num_epochs=num_epochs,
            criterion=criterion,
            lr=lr,
        )
        self.recipe = recipe
        self.attack_fn = attack_fn
        self.attack_criterion = attack_criterion
        self.poison_rate = float(recipe.get("poison_rate", 0.0))
        self.target_label = recipe.get("target_label")
        # Rebuild the surrogate locally and load the attacker's fine-tuned
        # weights. It is only used to craft examples, so it stays in eval mode.
        self.surrogate = MobileNetV2Transfer(pretrained=True, num_classes=num_classes).to(self.device)
        self.surrogate.load_state_dict(surrogate_state)
        self.surrogate.eval()

    def client_update(self) -> None:
        """Run ``num_epochs`` of plain SGD on a copy of ``self.x``, poisoning batches on the fly.

        The trained copy is stored in ``self.y``; ``self.x`` is not modified.

        Raises:
            ValueError: If the server has not assigned a model to ``self.x``.
        """
        if self.x is None:
            raise ValueError("Client model `x` has not been initialised by the server.")

        self.y = deepcopy(self.x).to(self.device)
        self.y.train()
        # autograd.grad raises if asked for the gradient of a tensor that does
        # not require grad (e.g. a frozen transfer-learning backbone), so only
        # differentiate w.r.t. trainable parameters.
        trainable = [param for param in self.y.parameters() if param.requires_grad]

        for _ in range(self.num_epochs):
            for inputs, labels in self.data:
                inputs = inputs.float().to(self.device)
                labels = labels.long().to(self.device)

                if self.poison_rate > 0.0:
                    inputs, labels = self._poison_batch(inputs, labels)

                loss = self.criterion(self.y(inputs), labels)
                grads = torch.autograd.grad(loss, trainable)

                with torch.no_grad():
                    for param, grad in zip(trainable, grads):
                        param.sub_(self.lr * grad)

            if self.device.type == "cuda":
                torch.cuda.empty_cache()

    def _poison_batch(self, inputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the batch with a random subset replaced by adversarial samples.

        When no sample is selected, the original tensors are returned as-is.
        Otherwise copies are modified, so the caller's tensors stay untouched.
        """
        mask = torch.rand(labels.size(0), device=self.device) < self.poison_rate
        if not mask.any():
            return inputs, labels

        target_labels = labels[mask]
        if self.target_label is not None:
            target_labels = torch.full_like(target_labels, int(self.target_label))
        poisoned = craft_adversarial_examples(
            inputs[mask],
            target_labels,
            surrogate=self.surrogate,
            attack_fn=self.attack_fn,
            attack_criterion=self.attack_criterion,
            recipe=self.recipe,
        )
        inputs = inputs.clone()
        labels = labels.clone()
        inputs[mask] = poisoned
        labels[mask] = target_labels
        return inputs, labels


In [None]:
def build_attack_clients(recipe: dict, malicious_fraction: float, *, seed: int, surrogate_state) -> tuple[list[Client], list[int]]:
    """Build one client per loader, turning a seeded random subset into attackers.

    ``floor(len(CLIENT_LOADERS) * malicious_fraction)`` distinct client ids
    (fraction clamped to [0, 1]) are drawn with ``np.random.default_rng(seed)``.
    Those become ``SurrogateAttackClient`` instances; the rest are honest
    ``Client`` objects. All clients share the baseline training settings.

    Returns:
        ``(clients, malicious_ids)`` with ``malicious_ids`` sorted ascending.
    """
    num_clients = len(CLIENT_LOADERS)
    clamped_fraction = max(0.0, min(1.0, malicious_fraction))
    num_malicious = int(np.floor(num_clients * clamped_fraction))
    rng = np.random.default_rng(seed)
    if num_malicious > 0:
        malicious_ids = sorted(rng.choice(num_clients, size=num_malicious, replace=False).tolist())
    else:
        malicious_ids = []
    malicious_lookup = set(malicious_ids)

    attack_fn_local = get_attack(recipe["type"])
    attack_criterion_local = resolve_callable(recipe.get("criterion", "torch.nn.CrossEntropyLoss"))()

    clients: list[Client] = []
    for client_id, loader in enumerate(CLIENT_LOADERS):
        common_kwargs = dict(
            client_id=client_id,
            local_data=loader,
            device=DEVICE,
            num_epochs=BASELINE_CFG["num_epochs"],
            criterion=honest_criterion,
            lr=BASELINE_CFG["local_lr"],
        )
        if client_id not in malicious_lookup:
            clients.append(Client(**common_kwargs))
            continue
        clients.append(
            SurrogateAttackClient(
                **common_kwargs,
                surrogate_state=surrogate_state,
                attack_fn=attack_fn_local,
                attack_criterion=attack_criterion_local,
                recipe=recipe,
                num_classes=BASELINE_CFG.get("num_classes", 10),
            )
        )

    return clients, malicious_ids

In [None]:
# Snapshot of the fine-tuned surrogate weights handed to every attacker.
surrogate_state = surrogate_model.state_dict()
print("Running surrogate-driven attack with malicious fraction =", ATTACK_RUN_CFG["malicious_fraction"])
attack_clients, malicious_ids = build_attack_clients(
    recipe=ATTACK_RECIPE,
    malicious_fraction=ATTACK_RUN_CFG["malicious_fraction"],
    seed=ATTACK_RUN_CFG["seed"],
    surrogate_state=surrogate_state,
)
poisoned_run = run_fedavg_rounds(
    clients=attack_clients,
    config=BASELINE_CFG,
    label="Surrogate attack",
    malicious_ids=malicious_ids,
)
print("Attack training complete.")

In [None]:
poisoned_summary = summarise_run(run=poisoned_run)
poisoned_summary  # shown as the cell output

In [None]:
import pandas as pd
# One row per run: label, final loss, final accuracy.
summary_df = pd.DataFrame(data=[baseline_summary, poisoned_summary])
summary_df

### 8.1 Visualise trajectories

In [None]:
rounds = range(1, len(poisoned_run["metrics"]["accuracy"]) + 1)
# One figure per metric, each overlaying the clean and poisoned runs.
for metric_key, y_label, title in (
    ("accuracy", "Accuracy (%)", "Clean vs. poisoned accuracy"),
    ("loss", "Cross-entropy loss", "Clean vs. poisoned loss"),
):
    plt.figure(figsize=(8, 4))
    for run in (baseline_run, poisoned_run):
        plt.plot(rounds, run["metrics"][metric_key], marker="o", label=run["label"])
    plt.xlabel("Communication round")
    plt.ylabel(y_label)
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.show()

### 8.2 Inspect malicious participation

In [None]:
print("Malicious client ids:", malicious_ids)
# One line per round: sampled clients and which of them were malicious.
for round_record in poisoned_run["history"]:
    print(round_record)


## 9. Reflection

- How quickly does the poisoned run diverge from the clean baseline?
- Which hyperparameters (poison rate, attack iterations, surrogate epochs) most influence attack success?
- Try switching `attack.type` to `fgsm` or `rand_noise`. How transferable are those perturbations?
- Experiment with robust-aggregation defences (e.g., coordinate-wise median or trimmed mean), reusing the same structure to benchmark mitigations.