# Adversarial Federated Learning Lab

This lab mirrors the structure of earlier modules: you will first train a clean MobileNetV3 FedAvg model on CIFAR-10, then escalate into surrogate-driven poisoning attacks and compare the outcomes.

> Duplicate this notebook if you want to keep a personal record of experiments or notes.

## 1. Environment Setup

Locate the repository root, register it on `sys.path`, and import the Module 4 utilities used throughout the lab.

In [None]:
from pathlib import Path
import sys
from importlib import import_module

# Locate the repository root: use the working directory when it contains the
# lab package, otherwise fall back to its parent directory.
PROJECT_ROOT = Path.cwd().resolve()
PACKAGE_ROOT = PROJECT_ROOT / "4_Adversarial_FL"
if not PACKAGE_ROOT.exists():
    PROJECT_ROOT = PROJECT_ROOT.parent
    PACKAGE_ROOT = PROJECT_ROOT / "4_Adversarial_FL"
    if not PACKAGE_ROOT.exists():
        raise RuntimeError("Run this notebook from the repo root or inside 4_Adversarial_FL.")

# Register the repository root at the front of the import path, once only.
if sys.path.count(str(PROJECT_ROOT)) == 0:
    sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

load_data_module = import_module("4_Adversarial_FL.load_data_for_clients")
client_module = import_module("4_Adversarial_FL.client")
model_module = import_module("4_Adversarial_FL.model")
utils_module = import_module("4_Adversarial_FL.util_functions")
attacks_module = import_module("4_Adversarial_FL.attacks")

Client = client_module.Client
MobileNetV3Transfer = model_module.MobileNetV3Transfer
MobileNetV2Transfer = model_module.MobileNetV2Transfer
set_seed = utils_module.set_seed
evaluate_fn = utils_module.evaluate_fn
resolve_callable = utils_module.resolve_callable
dist_data_per_client = load_data_module.dist_data_per_client
get_attack = attacks_module.get_attack

# Per-channel ImageNet statistics, shaped (1, 3, 1, 1) so that they broadcast
# against image batches laid out as (N, 3, H, W).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)


def denormalize_batch(batch: torch.Tensor) -> torch.Tensor:
    """Undo ImageNet normalisation, returning ``batch * std + mean``.

    The statistics are moved to ``batch.device`` before use, so the result
    lives on the same device as the input.
    """
    std = IMAGENET_STD.to(batch.device)
    mean = IMAGENET_MEAN.to(batch.device)
    return batch * std + mean


def normalize_batch(batch: torch.Tensor) -> torch.Tensor:
    """Apply ImageNet normalisation, returning ``(batch - mean) / std``.

    The statistics are moved to ``batch.device`` before use, so the result
    lives on the same device as the input.
    """
    std = IMAGENET_STD.to(batch.device)
    mean = IMAGENET_MEAN.to(batch.device)
    return (batch - mean) / std


# Prefer a CUDA device when PyTorch reports one; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using project root: {PROJECT_ROOT}")
print(f"Device detected: {DEVICE}")

## 2. Clean FedAvg Baseline

Replicate the honest training workflow from the earlier modules: run MobileNetV3 with realistic client counts, rounds, and epochs on CIFAR-10. This establishes the reference trajectory before any adversarial behaviour is introduced.

### 2.1 Configure baseline hyperparameters

In [None]:
# Single configuration dict shared by the data split, the client builders and
# the FedAvg loop in this notebook.
baseline_config = {
    # Base seed: used as-is for the data split; the FedAvg loop re-seeds every
    # round with `seed + round_idx`.
    "seed": 27,
    "data": {
        "dataset_path": "./data",
        "dataset_name": "CIFAR10",
        # Forwarded to dist_data_per_client; presumably the non-IID share of
        # each client's shard — TODO confirm against load_data_for_clients.
        "non_iid_per": 0.2,
    },
    "federated": {
        "num_clients": 50,
        # Share of clients sampled each round (rounded up, at least one).
        "fraction_clients": 0.2,
        "num_rounds": 10,
        # Local epochs per sampled client per round.
        "num_epochs": 2,
        "batch_size": 64,
        # Passed to each client as its `lr`.
        "local_lr": 0.05,
        # Dotted path resolved with resolve_callable and then instantiated.
        "criterion": "torch.nn.CrossEntropyLoss",
    },
    "model": {
        "num_classes": 10,
    },
}
# Last expression of the cell: echoes the config in the notebook output.
baseline_config

Bump `num_rounds` or `num_epochs` after you validate the workflow; the defaults provide a realistic yet tractable run for CPU-only environments.

### 2.2 Prepare CIFAR-10 client loaders

In [None]:
# Seed before splitting; presumably this makes the client shards reproducible
# (depends on what set_seed and dist_data_per_client do — TODO confirm).
set_seed(baseline_config["seed"])

# Build the per-client loaders and the shared test loader. CLIENT_LOADERS is
# later indexed by client id, and each entry's `.dataset` is read when the
# surrogate loader is created.
CLIENT_LOADERS, TEST_LOADER = dist_data_per_client(
    data_path=baseline_config["data"]["dataset_path"],
    dataset_name=baseline_config["data"]["dataset_name"],
    num_clients=baseline_config["federated"]["num_clients"],
    batch_size=baseline_config["federated"]["batch_size"],
    non_iid_per=baseline_config["data"]["non_iid_per"],
    device=DEVICE,
)

# Notebook echo: number of client loaders and length of the test loader
# (its number of batches, assuming TEST_LOADER is a DataLoader).
len(CLIENT_LOADERS), len(TEST_LOADER)

_(Optional)_ Inspect a batch to sanity-check shapes and label balance.

In [None]:
# sample_images, sample_labels = next(iter(CLIENT_LOADERS[0]))
# sample_images.shape, sample_labels[:8]

### 2.3 Build honest clients

In [None]:
from copy import deepcopy

# Resolve the dotted loss path from the config and instantiate it once; the
# same criterion object is handed to every client built in this notebook.
criterion_path = baseline_config["federated"]["criterion"]
honest_criterion = resolve_callable(criterion_path)()


def build_honest_clients() -> list[Client]:
    """Create one honest ``Client`` per entry of ``CLIENT_LOADERS``.

    Client ids are the loaders' positions in ``CLIENT_LOADERS``. Epoch count
    and learning rate come from ``baseline_config["federated"]``; every client
    shares ``honest_criterion`` and runs on ``DEVICE``.

    Returns:
        The clients, ordered by client id.
    """
    fed_cfg = baseline_config["federated"]
    return [
        Client(
            client_id=client_id,
            local_data=shard,
            device=DEVICE,
            num_epochs=fed_cfg["num_epochs"],
            criterion=honest_criterion,
            lr=fed_cfg["local_lr"],
        )
        for client_id, shard in enumerate(CLIENT_LOADERS)
    ]

### 2.4 FedAvg training loop

In [None]:
import math


def run_fedavg_rounds(clients, config, label: str, *, verbose: bool = True, malicious_ids: list[int] | None = None):
    num_clients = len(clients)
    num_rounds = config["federated"]["num_rounds"]
    fraction = config["federated"]["fraction_clients"]
    model_kwargs = {
        "num_classes": config["model"]["num_classes"],
    }
    eval_criterion = resolve_callable(config["federated"]["criterion"])()

    malicious_set = set(malicious_ids or [])
    history = []
    metrics = {"loss": [], "accuracy": []}
    global_model = MobileNetV3Transfer(**model_kwargs).to(DEVICE)

    for round_idx in range(num_rounds):
        set_seed(config["seed"] + round_idx)
        num_sampled = max(1, int(math.ceil(fraction * num_clients)))
        sampled = sorted(np.random.choice(num_clients, size=num_sampled, replace=False).tolist())
        active_malicious = [idx for idx in sampled if idx in malicious_set]
        history.append(
            {
                "round": round_idx + 1,
                "sampled": sampled,
                "malicious": active_malicious,
            }
        )

        for idx in sampled:
            client_model = MobileNetV3Transfer(**model_kwargs).to(DEVICE)
            client_model.load_state_dict(global_model.state_dict())
            clients[idx].x = client_model

        for idx in sampled:
            clients[idx].client_update()

        avg_params = [torch.zeros_like(param, device=DEVICE) for param in global_model.parameters()]
        with torch.no_grad():
            for idx in sampled:
                for avg_param, client_param in zip(avg_params, clients[idx].y.parameters()):
                    avg_param.add_(client_param.data / len(sampled))
            for param, avg_param in zip(global_model.parameters(), avg_params):
                param.data.copy_(avg_param.data)

        loss, acc = evaluate_fn(TEST_LOADER, global_model, eval_criterion, DEVICE)
        metrics["loss"].append(loss)
        metrics["accuracy"].append(acc)

        if verbose:
            print(
                f"[{label}] Round {round_idx + 1}: sampled={sampled} malicious={active_malicious} "
                f"loss={loss:.4f} acc={acc:.2f}%"
            )

    return {
        "label": label,
        "global_model": global_model,
        "metrics": metrics,
        "history": history,
        "config": config,
    }

### 2.5 Execute the clean baseline

In [None]:
# Build a fresh set of honest clients and run the FedAvg loop on them; the
# returned dict (metrics, history, final global model) is kept in
# `baseline_run` for the plotting and comparison cells below.
print("Training clean MobileNetV3 FedAvg baseline on CIFAR-10...")
honest_clients = build_honest_clients()
baseline_run = run_fedavg_rounds(honest_clients, baseline_config, label="Clean FedAvg")
print("Baseline training complete.")

### 2.6 Visualise clean accuracy and loss

In [None]:
# Look the run up defensively: when the training cell has not been executed
# the name does not exist yet, and the hint in the `else` branch should be
# shown instead of a NameError.
_clean_run = globals().get("baseline_run")

if _clean_run:
    # One x position per completed communication round.
    rounds = range(1, len(_clean_run["metrics"]["accuracy"]) + 1)

    # Test accuracy of the global model after each round.
    plt.figure(figsize=(8, 4))
    plt.plot(rounds, _clean_run["metrics"]["accuracy"], marker="o", label=_clean_run["label"])
    plt.xlabel("Communication round")
    plt.ylabel("Accuracy (%)")
    plt.title("Clean FedAvg test accuracy")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.show()

    # Test loss of the global model after each round.
    plt.figure(figsize=(8, 4))
    plt.plot(rounds, _clean_run["metrics"]["loss"], marker="o", label=_clean_run["label"])
    plt.xlabel("Communication round")
    plt.ylabel("Cross-entropy loss")
    plt.title("Clean FedAvg test loss")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.show()
else:
    print("Train the clean baseline first to unlock these plots.")

## 3. Attack Primer

Before coding, recall the main threat families encountered in adversarial federated learning:

- **Data poisoning**: tamper with local batches so the aggregated update nudges the global model toward attacker objectives.
- **Model poisoning**: forge parameter updates directly (e.g., model replacement).
- **Evasion attacks**: craft examples at inference time without modifying training.

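This lab focuses on *surrogate-driven data poisoning*, which blends the ideas above: malicious clients craft evasion-style adversarial examples against a surrogate model and inject them into their own local training batches, so the poison reaches the global model through ordinary FedAvg updates.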

## 4. White-box vs. Black-box Surrogates

| Setting | Attacker knowledge | Typical strategy |
| --- | --- | --- |
| White-box | Full access to the victim architecture and weights | Optimise adversarial updates directly on the victim. |
| Black-box | Only input/output queries or periodic snapshots | Train a *surrogate* model and transfer crafted examples. |

Federated settings often land in between—clients observe the architecture and snapshots but rarely the entire training trace. You will emulate a black-box attacker by swapping MobileNetV3 for a closely related MobileNetV2 surrogate.

## 5. Train a Surrogate Model (MobileNetV2)

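Assume the attacker controls one client shard. Fine-tune a MobileNetV2 surrogate on that shard so it approximates the victim’s behaviour despite the architectural mismatch. Fine-tuning is gated behind the `RUN_SURROGATE_TRAINING` flag, which is off by default; until you enable it, the surrogate is used exactly as it was constructed.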

In [None]:
# Settings for the attacker's surrogate: which client shard it sees and how it
# is fine-tuned.
surrogate_config = {
    "client_id": 0,
    "epochs": 3,
    "lr": 1e-3,
    "batch_size": baseline_config["federated"]["batch_size"],
    "criterion": "torch.nn.CrossEntropyLoss",
}

# Re-wrap the chosen client's dataset in a fresh shuffling loader.
# NOTE(review): any sampler/collate settings of the original client loader are
# not carried over here — confirm the raw `.dataset` yields batches in the
# same format the clients train on.
surrogate_dataset = CLIENT_LOADERS[surrogate_config["client_id"]].dataset
surrogate_loader = DataLoader(
    surrogate_dataset,
    batch_size=surrogate_config["batch_size"],
    shuffle=True,
    num_workers=0,
)

# The surrogate uses a different backbone (MobileNetV2) than the MobileNetV3
# model built in the FedAvg loop.
surrogate_model = MobileNetV2Transfer(
    pretrained=True,
    num_classes=baseline_config["model"]["num_classes"],
).to(DEVICE)

# Adam over all surrogate parameters; the loss is resolved from its dotted path.
surrogate_optimizer = torch.optim.Adam(surrogate_model.parameters(), lr=surrogate_config["lr"])
surrogate_criterion = resolve_callable(surrogate_config["criterion"])()


In [None]:
# Opt-in switch: fine-tuning is skipped unless this is set to True.
RUN_SURROGATE_TRAINING = False

if RUN_SURROGATE_TRAINING:
    set_seed(baseline_config["seed"])
    surrogate_model.train()
    for epoch in range(surrogate_config["epochs"]):
        # Running totals for this epoch's sample-weighted loss and accuracy.
        epoch_loss = 0.0
        total = 0
        correct = 0
        for images, labels in surrogate_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            surrogate_optimizer.zero_grad()
            logits = surrogate_model(images)
            loss = surrogate_criterion(logits, labels)
            loss.backward()
            surrogate_optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample average even when the last batch is smaller.
            epoch_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
        # max(total, 1) guards the division when the loader yields no batches.
        avg_loss = epoch_loss / max(total, 1)
        acc = 100 * correct / max(total, 1)
        print(f"Epoch {epoch + 1}: loss={avg_loss:.4f} acc={acc:.2f}%")
else:
    print("Set RUN_SURROGATE_TRAINING = True to fine-tune the surrogate model.")

In [None]:
def evaluate_model(loader, model, description: str):
    """Print the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled while predicting. Batches are moved to ``DEVICE``. An empty
    loader reports 0.00% rather than dividing by zero.

    Args:
        loader: Iterable of ``(images, labels)`` batches.
        model: Module mapping an image batch to per-class scores.
        description: Prefix for the printed line.
    """
    model.eval()
    seen = 0
    hits = 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_labels = batch_labels.to(DEVICE)
            predictions = model(batch_images.to(DEVICE)).argmax(dim=1)
            hits += (predictions == batch_labels).sum().item()
            seen += batch_labels.size(0)
    acc = 100 * hits / max(seen, 1)
    print(f"{description} accuracy: {acc:.2f}% ({hits}/{seen})")


# Accuracy of the surrogate on the shard it is fine-tuned on.
evaluate_model(surrogate_loader, surrogate_model, "Surrogate (local shard)")

## 6. Craft Attacks on the Surrogate

Configure an attack recipe—PGD, FGSM, or random noise—then craft a sample adversarial batch on the surrogate to inspect it. Malicious clients apply the same recipe on the fly during training (Section 7).

In [None]:
# Attack settings shared by the preview cell and by the malicious clients.
# Only the PGD branch of craft_adversarial_examples reads "epsilon" and
# "iters"; the FGSM and noise branches use "step_size" alone.
attack_recipe = {
    "type": "pgd",           # options: "pgd", "fgsm", "rand_noise"
    "poison_rate": 0.2,       # fraction of a malicious minibatch to replace
    "target_label": 0,        # targeted misclassification (set to None for untargeted)
    "epsilon": 0.03137255,    # 8/255 L_inf budget
    "step_size": 0.00784314,  # 2/255 per step
    "iters": 10,              # PGD iterations
    "criterion": "torch.nn.CrossEntropyLoss",
}

# Look up the attack implementation by name and instantiate its loss.
attack_fn = get_attack(attack_recipe["type"])
attack_criterion = resolve_callable(attack_recipe["criterion"])()

print("Configured attack:")
# Last expression of the cell: echoes the recipe in the notebook output.
attack_recipe

In [None]:
def prepare_attack_labels(labels: torch.Tensor, recipe: dict) -> torch.Tensor:
    """Return the labels the attack should be run against.

    Args:
        labels: Ground-truth labels of the batch.
        recipe: Attack settings; only ``"target_label"`` is read.

    Returns:
        ``labels`` itself when the recipe has no target (key missing or
        ``None``); otherwise a new tensor with the same shape, dtype and
        device, filled with ``int(target_label)``.
    """
    target = recipe.get("target_label")
    return labels if target is None else torch.full_like(labels, int(target))


def craft_adversarial_examples(images: torch.Tensor, labels: torch.Tensor, *, surrogate, attack_fn, attack_criterion, recipe: dict) -> torch.Tensor:
    """Perturb a normalised image batch with the configured attack.

    The batch is moved to ``DEVICE``, de-normalised, handed to ``attack_fn``
    and the result is normalised again before being returned.

    Args:
        images: Normalised image batch.
        labels: Labels passed to the attack (ignored by the noise branch).
        surrogate: Model the attack differentiates through; put in eval mode.
        attack_fn: Attack implementation matching ``recipe["type"]``.
        attack_criterion: Loss handed to the PGD/FGSM attacks.
        recipe: Attack settings (``type``, ``epsilon``, ``step_size``, ``iters``).

    Returns:
        The perturbed batch in normalised space.

    NOTE(review): the attack sees de-normalised pixels while the surrogate is
    trained on normalised batches elsewhere in this notebook — confirm the
    attack implementations normalise before calling the model.
    NOTE(review): for a targeted recipe the caller passes the target class as
    ``labels``; verify the attack implementations treat that as a target to
    move towards rather than a label to move away from.
    """
    surrogate.eval()
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

    kind = recipe["type"].lower()
    pixel_space = denormalize_batch(images)

    if kind in ("pgd", "fgsm"):
        # Both gradient attacks share the model/loss/data arguments.
        attack_kwargs = {
            "model": surrogate,
            "criterion": attack_criterion,
            "images": pixel_space,
            "labels": labels,
        }
        if kind == "pgd":
            attack_kwargs["eps"] = recipe.get("epsilon", 0.03)
            attack_kwargs["step_size"] = recipe.get("step_size", 0.007)
            attack_kwargs["iters"] = recipe.get("iters", 5)
        else:
            attack_kwargs["step_size"] = recipe.get("step_size", 0.003)
        perturbed = attack_fn(**attack_kwargs)
    else:
        # Every other attack type is called with the noise-style signature.
        perturbed = attack_fn(
            pixel_space,
            step_size=recipe.get("step_size", 0.003),
        )

    return normalize_batch(perturbed)


def plot_clean_vs_adv(clean_batch: torch.Tensor, adv_batch: torch.Tensor, index: int = 0) -> None:
    """Show one clean image next to its adversarial counterpart.

    Both batches are expected in normalised space; the selected sample is
    de-normalised, converted to height-width-channel layout and clipped to
    ``[0, 1]`` for display.

    Args:
        clean_batch: Batch of clean images.
        adv_batch: Batch of adversarial images, aligned with ``clean_batch``.
        index: Position of the sample to display.
    """

    def _as_hwc(batch: torch.Tensor):
        # Keep the batch dimension while de-normalising, then drop it.
        single = denormalize_batch(batch[index:index + 1])
        return single.squeeze(0).permute(1, 2, 0).cpu().numpy()

    panels = (
        ("Clean", _as_hwc(clean_batch)),
        ("Adversarial", _as_hwc(adv_batch)),
    )
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    for axis, (caption, picture) in zip(axes, panels):
        axis.imshow(np.clip(picture, 0, 1))
        axis.set_title(caption)
        axis.axis("off")
    plt.tight_layout()


In [None]:
# Preview the attack on a single batch from the surrogate's shard.
sample_batch = next(iter(surrogate_loader))
clean_images, clean_labels = sample_batch
# With a targeted recipe every label is replaced by the target class;
# otherwise the true labels are used unchanged.
attack_labels = prepare_attack_labels(clean_labels, attack_recipe)
adversarial_images = craft_adversarial_examples(
    clean_images,
    attack_labels,
    surrogate=surrogate_model,
    attack_fn=attack_fn,
    attack_criterion=attack_criterion,
    recipe=attack_recipe,
)

# Side-by-side view of the first sample before and after the attack.
plot_clean_vs_adv(clean_images, adversarial_images, index=0)


## 7. Deploy the Attack Against the Clean Model

Wrap the surrogate logic in a malicious client that poisons a random fraction of the samples in each minibatch on the fly, then rerun FedAvg to observe how the global MobileNetV3 degrades relative to the clean baseline.

In [None]:
class SurrogateAttackClient(Client):
    """Client that poisons part of every minibatch with surrogate-crafted examples.

    During local training each sample is selected independently with
    probability ``poison_rate``. Selected samples are replaced by adversarial
    versions crafted on a private MobileNetV2 surrogate, and their labels are
    overwritten with ``target_label`` when one is configured.

    The attributes ``x``, ``y``, ``data``, ``device``, ``lr``, ``criterion``
    and ``num_epochs`` are presumably provided by ``Client`` and the server
    loop — verify against the base class.
    """

    def __init__(
        self,
        *,
        client_id: int,
        local_data,
        device: torch.device,
        num_epochs: int,
        criterion,
        lr: float,
        surrogate_state,
        attack_fn,
        attack_criterion,
        recipe: dict,
        num_classes: int,
    ) -> None:
        """Set up the honest client state and then the attack machinery.

        Args:
            client_id, local_data, device, num_epochs, criterion, lr:
                Forwarded unchanged to ``Client.__init__``.
            surrogate_state: State dict loaded into this client's surrogate.
            attack_fn: Attack implementation matching ``recipe["type"]``.
            attack_criterion: Loss used by the attack.
            recipe: Attack settings; ``poison_rate`` and ``target_label`` are
                cached on the instance.
            num_classes: Output size of the surrogate head.
        """
        super().__init__(
            client_id=client_id,
            local_data=local_data,
            device=device,
            num_epochs=num_epochs,
            criterion=criterion,
            lr=lr,
        )
        self.recipe = recipe
        self.attack_fn = attack_fn
        self.attack_criterion = attack_criterion
        self.poison_rate = float(recipe.get("poison_rate", 0.0))
        self.target_label = recipe.get("target_label")

        # NOTE(review): the weights set up by `pretrained=True` are replaced
        # by `surrogate_state` on the next line — confirm whether building the
        # surrogate without pretrained weights would be sufficient here.
        surrogate = MobileNetV2Transfer(pretrained=True, num_classes=num_classes).to(self.device)
        surrogate.load_state_dict(surrogate_state)
        surrogate.eval()
        self.surrogate = surrogate

    def _apply_surrogate_poison(self, inputs: torch.Tensor, labels: torch.Tensor):
        """Return the batch with a random subset of samples poisoned.

        The original tensors are returned untouched when poisoning is
        disabled or when no sample is selected for this batch.
        """
        if not self.poison_rate > 0.0:
            return inputs, labels

        # Independent per-sample draw against the poison rate.
        mask = torch.rand(labels.size(0), device=self.device) < self.poison_rate
        if not mask.any():
            return inputs, labels

        target_labels = labels[mask]
        if self.target_label is not None:
            target_labels = torch.full_like(target_labels, int(self.target_label))

        poisoned = craft_adversarial_examples(
            inputs[mask],
            target_labels,
            surrogate=self.surrogate,
            attack_fn=self.attack_fn,
            attack_criterion=self.attack_criterion,
            recipe=self.recipe,
        )

        # Write into copies so the tensors yielded by the loader stay intact.
        inputs = inputs.clone()
        labels = labels.clone()
        inputs[mask] = poisoned
        labels[mask] = target_labels
        return inputs, labels

    def client_update(self) -> None:
        """Train a copy of the server model locally on partially poisoned batches.

        Copies ``self.x`` into ``self.y`` and runs ``num_epochs`` passes of
        plain gradient descent with step size ``self.lr``.

        Raises:
            ValueError: If the server has not assigned ``self.x`` yet.
        """
        if self.x is None:
            raise ValueError("Client model `x` has not been initialised by the server.")

        self.y = deepcopy(self.x).to(self.device)
        self.y.train()

        for _ in range(self.num_epochs):
            for batch_inputs, batch_labels in self.data:
                batch_inputs, batch_labels = self._apply_surrogate_poison(
                    batch_inputs.float().to(self.device),
                    batch_labels.long().to(self.device),
                )

                loss = self.criterion(self.y(batch_inputs), batch_labels)
                grads = torch.autograd.grad(loss, self.y.parameters())

                # Manual SGD step: no optimizer object, no momentum.
                with torch.no_grad():
                    for param, grad in zip(self.y.parameters(), grads):
                        param.data.sub_(self.lr * grad.data)

            # Release cached GPU memory between epochs.
            if self.device.type == "cuda":
                torch.cuda.empty_cache()


In [None]:
def build_attack_clients(recipe: dict, malicious_fraction: float, *, seed: int, surrogate_state) -> tuple[list[Client], list[int]]:
    """Build the full client population with a random subset turned malicious.

    Args:
        recipe: Attack settings given to every malicious client.
        malicious_fraction: Share of clients to compromise; clamped to
            ``[0, 1]`` and rounded down to a whole number of clients.
        seed: Seed for the generator that picks the malicious ids.
        surrogate_state: Surrogate state dict given to every malicious client.

    Returns:
        ``(clients, malicious_ids)``: the clients ordered by id, and the
        sorted ids of the malicious ones.
    """
    total_clients = len(CLIENT_LOADERS)
    malicious_fraction = max(0.0, min(1.0, malicious_fraction))
    quota = int(np.floor(total_clients * malicious_fraction))

    # A dedicated generator keeps the choice independent of the global seed.
    rng = np.random.default_rng(seed)
    malicious_ids = (
        sorted(rng.choice(total_clients, size=quota, replace=False).tolist())
        if quota > 0
        else []
    )
    compromised = set(malicious_ids)

    attack_fn_local = get_attack(recipe["type"])
    attack_criterion_local = resolve_callable(recipe.get("criterion", "torch.nn.CrossEntropyLoss"))()
    fed_cfg = baseline_config["federated"]

    clients: list[Client] = []
    for idx, loader in enumerate(CLIENT_LOADERS):
        # Honest and malicious clients share the same training settings.
        common = {
            "client_id": idx,
            "local_data": loader,
            "device": DEVICE,
            "num_epochs": fed_cfg["num_epochs"],
            "criterion": honest_criterion,
            "lr": fed_cfg["local_lr"],
        }
        if idx in compromised:
            client = SurrogateAttackClient(
                **common,
                surrogate_state=surrogate_state,
                attack_fn=attack_fn_local,
                attack_criterion=attack_criterion_local,
                recipe=recipe,
                num_classes=baseline_config["model"]["num_classes"],
            )
        else:
            client = Client(**common)
        clients.append(client)

    return clients, malicious_ids

In [None]:
# Surrogate weights that every malicious client loads into its own copy.
# NOTE(review): taken from the surrogate as it is when this cell runs —
# re-run the cell after any further surrogate training.
surrogate_state = surrogate_model.state_dict()
# Population-level attack settings: share of clients to compromise and the
# seed used to pick them.
attack_run_config = {
    "malicious_fraction": 0.2,
    "seed": 2024,
}


In [None]:
# Opt-in switch: the adversarial run is skipped unless this is set to True.
RUN_POISONED_TRAINING = False
# Defaults so the comparison and inspection cells can test these names even
# when the poisoned run is skipped.
poisoned_run = None
malicious_ids = []

if RUN_POISONED_TRAINING:
    # Same population size as the baseline, with a seeded subset of clients
    # swapped for surrogate-attack clients.
    attack_clients, malicious_ids = build_attack_clients(
        attack_recipe,
        attack_run_config["malicious_fraction"],
        seed=attack_run_config["seed"],
        surrogate_state=surrogate_state,
    )
    # Reuses baseline_config, so rounds and per-round sampling seeds match
    # the clean run.
    poisoned_run = run_fedavg_rounds(
        attack_clients,
        baseline_config,
        label="Surrogate attack",
        malicious_ids=malicious_ids,
    )
    print("Poisoned training complete.")
else:
    print("Enable RUN_POISONED_TRAINING to launch the adversarial run.")

### 7.1 Compare clean and poisoned trajectories

In [None]:
# Look both runs up defensively: a run whose cell has not been executed does
# not exist as a name yet, and the hint in the `else` branch should be shown
# instead of a NameError.
_clean_run = globals().get("baseline_run")
_attack_run = globals().get("poisoned_run")

if _clean_run and _attack_run:
    for _metric, _ylabel, _title in (
        ("accuracy", "Accuracy (%)", "Clean vs. poisoned global accuracy"),
        ("loss", "Cross-entropy loss", "Clean vs. poisoned global loss"),
    ):
        plt.figure(figsize=(8, 4))
        for _run in (_clean_run, _attack_run):
            _series = _run["metrics"][_metric]
            # Each curve gets its own x-axis, so the comparison still plots
            # when the two runs were trained for different numbers of rounds.
            plt.plot(range(1, len(_series) + 1), _series, marker="o", label=_run["label"])
        plt.xlabel("Communication round")
        plt.ylabel(_ylabel)
        plt.title(_title)
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.show()
else:
    print("Train both the clean and poisoned runs to compare metrics.")

### 7.2 Inspect malicious participation

In [None]:
# `poisoned_run` stays None until the adversarial run has been enabled and
# executed, so this cell is safe to run at any point after the run cell.
if poisoned_run:
    print("Malicious client ids:", malicious_ids)
    print("Round-by-round participation:")
    # One record per round: round number, sampled ids, and which of the
    # sampled ids were malicious.
    for record in poisoned_run["history"]:
        print(record)
else:
    print("No poisoned run to summarise yet.")

## 8. Reflection Prompts

- How quickly do the poisoned metrics diverge from the clean baseline?
- What happens if you vary the malicious fraction, poison rate, or attack iterations?
- Try alternative attacks (`fgsm`, `rand_noise`)—which transfer best to the MobileNetV3 victim?
- Modify the surrogate training schedule or backbone to gauge how sensitive the attack is to architectural mismatch.