# Module 4 Adversarial FL – Outline

## 1. Federated Baseline Imports

Import core utilities (config loader, model helpers, FL runners) plus PyTorch, NumPy, and logging so the notebook can spin up federated experiments reproducibly.

In [None]:
from copy import deepcopy
from pathlib import Path

import json
import yaml
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as tv_transforms
import matplotlib.pyplot as plt
from torch.utils.data import ConcatDataset, DataLoader, Dataset

from util_functions import set_seed, evaluate_fn, run_fl
from load_data_for_clients import dist_data_per_client
from algos import (
    Server,
    ScaffoldServer,
    FedAdamServer,
    FedAdagradServer,
    FedYogiServer,
)


## 2. Federated Baseline Paths & Config

Load `config.yaml`, seed all RNGs, and capture the global/data/model sections that drive every experiment. The algorithm map that lets us instantiate any server class by name is built in the next section.

In [None]:
CONFIG_PATH = Path("config.yaml")
if not CONFIG_PATH.exists():
    raise FileNotFoundError("Could not locate config.yaml in the working directory")
with CONFIG_PATH.open(encoding="utf-8") as f:
    # safe_load returns None for an empty file; fall back to {} so every
    # section below simply takes its default instead of crashing on .get().
    CONFIG = yaml.safe_load(f) or {}
# `or {}` also covers a section key that is present but empty (parsed as None).
global_config = CONFIG.get("global_config") or {}
data_config = CONFIG.get("data_config") or {}
model_config = CONFIG.get("model_config") or {}
alg_configs = CONFIG.get("algorithms") or {}
attack_defaults = CONFIG.get("attack") or {}
set_seed(global_config.get("seed", 42))


def get_device(preferred: str | None = None) -> torch.device:
    choice = preferred if preferred is not None else global_config.get("device")
    if isinstance(choice, str):
        if choice.startswith("cuda") and not torch.cuda.is_available():
            choice = "cpu"
        return torch.device(choice)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

DEVICE = get_device()
AVAILABLE_ALGORITHMS = [*alg_configs]  # algorithm names in config order
print("Loaded config from", CONFIG_PATH.resolve())
print("Available algorithms:", AVAILABLE_ALGORITHMS)

# Every JSON artefact produced by this notebook is written under this folder.
ARTIFACT_DIR = Path("artifacts")
ARTIFACT_DIR.mkdir(exist_ok=True, parents=True)


## 3. Federated Baseline Helpers

Utility wrappers for running a single algorithm, looping over all configured algorithms, and summarising server outputs (loss/accuracy plus training history). These ensure every later experiment shares the same execution pipeline.

In [None]:
# Config name -> server class. Every algorithm listed in config.yaml must appear here.
ALGORITHM_MAP = dict(
    FedAvg=Server,
    Scaffold=ScaffoldServer,
    FedAdam=FedAdamServer,
    FedAdagrad=FedAdagradServer,
    FedYogi=FedYogiServer,
)

# Fail early, before any training, if the config names an unmapped algorithm.
missing = sorted(set(AVAILABLE_ALGORITHMS).difference(ALGORITHM_MAP))
if missing:
    raise KeyError(f"No server mapping registered for: {missing}")


def train_server(alg_name: str, attack_cfg: dict | None = None):
    """Run federated training for ``alg_name`` and return what ``run_fl`` returns.

    All config sections handed to ``run_fl`` are deep copies, so the shared
    ``alg_configs`` entry is never mutated. The copied ``fed_config`` gets an
    ``"algorithm"`` key set to ``alg_name``. A falsy ``attack_cfg`` is replaced
    by the benign ``{"malicious_fraction": 0.0}``.

    Raises:
        ValueError: if ``alg_name`` has no entry in ``alg_configs``.
    """
    if alg_name not in alg_configs:
        raise ValueError(f"Algorithm {alg_name!r} not found in configuration.")
    chosen = alg_configs[alg_name]

    federation = deepcopy(chosen["fed_config"])
    federation["algorithm"] = alg_name
    optimiser = deepcopy(chosen.get("optim_config", {}))
    adversary = deepcopy(attack_cfg if attack_cfg else {"malicious_fraction": 0.0})

    server_cls = ALGORITHM_MAP[alg_name]
    return run_fl(
        server_cls,
        global_config,
        data_config,
        federation,
        model_config,
        optimiser,
        adversary,
    )


def summarise_server(server) -> dict:
    """Evaluate a trained server and package its metrics as plain Python types.

    Calls ``evaluate_fn(server.data, server.x, server.criterion, server.device)``.
    NOTE(review): presumably ``server.data`` is the evaluation loader and
    ``server.x`` the global model; confirm in util_functions.

    Returns:
        ``{"final_loss", "final_accuracy", "history": {"loss", "accuracy"}}``.
        The history lists come from ``server.results`` when that attribute exists,
        otherwise they are empty.
    """
    final_loss, final_acc = evaluate_fn(server.data, server.x, server.criterion, server.device)
    logged = getattr(server, "results", {})
    return {
        "final_loss": float(final_loss),
        "final_accuracy": float(final_acc),
        "history": {key: list(logged.get(key, [])) for key in ("loss", "accuracy")},
    }


def run_one_algorithm(alg_name: str, attack_cfg: dict | None = None) -> dict:
    """Train ``alg_name`` (optionally under attack) and return its metric summary.

    The server object is discarded before returning and the CUDA cache is
    emptied, so consecutive runs do not accumulate GPU memory from this call.
    """
    trained = train_server(alg_name, attack_cfg=attack_cfg)
    result = summarise_server(trained)
    # Drop this function's reference to the server before releasing cached CUDA blocks.
    del trained
    torch.cuda.empty_cache()
    return result


def run_all_algorithms(
    algorithms: list[str] | None = None,
    attack_cfg: dict | None = None,
) -> dict:
    """Run each algorithm in turn and map its name to its summary.

    A falsy ``algorithms`` (None or an empty list) means every configured
    algorithm (``AVAILABLE_ALGORITHMS``). The same ``attack_cfg`` is used for every run.
    """
    selected = algorithms if algorithms else AVAILABLE_ALGORITHMS
    return {name: run_one_algorithm(name, attack_cfg=attack_cfg) for name in selected}


## 4. Federated Baseline Runs

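Kick off the clean (non-adversarial) benchmark. By default only FedAvg is run, which keeps the runtime short and checks the pipeline end to end. Add names to `BASELINE_ALGORITHMS` (or use `AVAILABLE_ALGORITHMS`) to benchmark every configured algorithm. The result is a dictionary of baseline metrics we can revisit later.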

In [None]:
# Only FedAvg is benchmarked here; extend the list (e.g. with AVAILABLE_ALGORITHMS)
# to cover every configured server.
BASELINE_ALGORITHMS = ["FedAvg"]

baseline_results = run_all_algorithms(algorithms=BASELINE_ALGORITHMS)
baseline_results


### Persist federated baselines

Store the clean FedAvg baseline so you can compare future attack runs without rerunning the entire notebook.


In [None]:
# Persist the clean baseline metrics as pretty-printed JSON.
baseline_path = ARTIFACT_DIR / 'module4_federated_baseline.json'
with open(baseline_path, 'w') as handle:
    json.dump(baseline_results, handle, indent=2)
print(f'Saved baseline metrics to {baseline_path.resolve()}')


### Baseline validation

Check that every baseline algorithm logged the expected number of communication rounds.


In [None]:
def validate_baseline(results, algorithm_config):
    """Check that every baseline run logged at least the configured number of rounds.

    Args:
        results: mapping of algorithm name to summary dict, as produced by
            ``summarise_server``. The length of ``summary["history"]["accuracy"]``
            is taken as the number of logged rounds (presumably one entry per
            round; confirm against the server's ``results``).
        algorithm_config: the ``algorithms`` section of config.yaml. Each checked
            algorithm must provide ``fed_config.num_rounds``.

    Raises:
        ValueError: listing, one per line, every algorithm whose history is
            shorter than expected.
        KeyError: if an algorithm in ``results`` has no ``fed_config.num_rounds``.
    """
    issues = []
    for name, summary in results.items():
        expected = algorithm_config[name]['fed_config']['num_rounds']
        actual = len(summary.get('history', {}).get('accuracy', []))
        if actual < expected:
            issues.append(f"{name}: expected {expected} rounds, saw {actual}")
    if issues:
        # Explicit '\n' escapes: a raw line break inside a quoted literal is a SyntaxError.
        raise ValueError('Baseline validation failed:\n' + '\n'.join(issues))
    print('Baseline validation passed for', ', '.join(sorted(results)))

validate_baseline(results=baseline_results, algorithm_config=alg_configs)  # raises on short histories


In [None]:
def plot_baseline_results(results: dict[str, dict]) -> None:
    """Draw one bar per baseline algorithm showing its final accuracy.

    Algorithms are ordered alphabetically, and a missing ``final_accuracy`` is
    drawn as 0. Prints a hint and returns early when ``results`` is empty.
    """
    if not results:
        print("Baseline run has not been executed yet.")
        return
    names = sorted(results)

    def _final_accuracy(name: str) -> float:
        value = results.get(name, {}).get("final_accuracy")
        return 0.0 if value is None else float(value)

    heights = [_final_accuracy(name) for name in names]
    plt.figure(figsize=(6, 4))
    plt.bar(names, heights, color="#4c72b0", alpha=0.85)
    plt.ylabel('Final accuracy (%)')
    plt.xlabel('Algorithm')
    plt.title('Baseline federated accuracy')
    plt.xticks(rotation=30)
    plt.tight_layout()

plot_baseline_results(baseline_results)


## 5. Surrogate Imports

Add the extra building blocks needed for the attacker: the `MaliciousClient` class and the MobileNetV2 transfer backbone used as the black-box surrogate.

In [None]:

from malicious_client import MaliciousClient
from model import MobileNetV2Transfer


## 6. Surrogate Paths & Config

Construct deterministic data loaders for the surrogate. We pool multiple client shards, optionally apply simple augmentations, and cache the test loader so both notebook and malicious clients share the same view.

In [None]:
# `or {}` so an empty `surrogate:` section (parsed as None) still behaves like "all defaults".
SURROGATE_CFG = CONFIG.get("surrogate") or {}
SURROGATE_CLIENT_ID = int(SURROGATE_CFG.get("client_id", 0))
SURROGATE_SEED = SURROGATE_CFG.get("seed", global_config.get("seed", 42))
SURROGATE_POOL_SIZE = max(1, int(SURROGATE_CFG.get("pool_size", 1)))

# Lazily populated caches for the surrogate's data loaders.
_SURROGATE_CLIENT_LOADERS = None
_SURROGATE_TEST_LOADER = None
_SURROGATE_POOLED_LOADER = None


class AugmentedDataset(Dataset):
    """Wrap a dataset of ``(image, label)`` pairs and apply ``transform`` to each image on access.

    Labels pass through untouched. With ``transform=None`` the wrapper returns
    the base samples unchanged.
    """

    def __init__(self, base_dataset, transform=None):
        self.base_dataset = base_dataset
        self.transform = transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        sample, target = self.base_dataset[index]
        transform = self.transform
        return (sample if transform is None else transform(sample)), target


def _surrogate_transform():
    """Return the surrogate's training augmentation, or None when ``augment`` is off.

    The augmentation is a random horizontal flip followed by a random rotation
    of up to 10 degrees.
    """
    if SURROGATE_CFG.get("augment", False):
        steps = [
            tv_transforms.RandomHorizontalFlip(),
            tv_transforms.RandomRotation(degrees=10),
        ]
        return tv_transforms.Compose(steps)
    return None


def _ensure_surrogate_data():
    """Return ``(client_loaders, test_loader)`` for the surrogate, loading them on first use.

    Each loader setting comes from the ``surrogate`` config section when present,
    otherwise from ``data_config``, otherwise from the literal defaults below.
    The pair is cached in module globals, so later calls return the same loader
    objects without calling ``dist_data_per_client`` again.
    """
    global _SURROGATE_CLIENT_LOADERS, _SURROGATE_TEST_LOADER
    if _SURROGATE_CLIENT_LOADERS is None or _SURROGATE_TEST_LOADER is None:

        def _setting(key, default=None):
            # Surrogate-specific value first, then the shared data config.
            return SURROGATE_CFG.get(key, data_config.get(key, default))

        _SURROGATE_CLIENT_LOADERS, _SURROGATE_TEST_LOADER = dist_data_per_client(
            data_path=_setting("dataset_path"),
            dataset_name=_setting("dataset_name"),
            num_clients=_setting("num_clients", 50),
            batch_size=_setting("batch_size", 96),
            non_iid_per=_setting("non_iid_per", 0.0),
            device=get_device(),
        )
    return _SURROGATE_CLIENT_LOADERS, _SURROGATE_TEST_LOADER


def _build_surrogate_pool_loader():
    """Return (building it once) a shuffled DataLoader over the pooled surrogate shards.

    The loader works as follows:

    * It concatenates the datasets of the first ``SURROGATE_POOL_SIZE`` client
      loaders, capped at the number of clients.
    * It wraps them with the surrogate augmentation when ``augment`` is enabled.
    * It serves them with the surrogate batch size.

    The shuffle order is drawn from a dedicated generator seeded with
    ``SURROGATE_SEED`` (when set), so it does not depend on how much of the
    global torch RNG other code has consumed. Random augmentations still use
    the global RNG. The loader is cached in ``_SURROGATE_POOLED_LOADER``.

    Raises:
        ValueError: if there are no client loaders to pool, or a loader lacks a
            ``dataset`` attribute.
    """
    global _SURROGATE_POOLED_LOADER
    if _SURROGATE_POOLED_LOADER is not None:
        return _SURROGATE_POOLED_LOADER

    client_loaders, _ = _ensure_surrogate_data()
    pool_size = min(SURROGATE_POOL_SIZE, len(client_loaders))
    if pool_size < 1:
        # ConcatDataset([]) would fail with an opaque assertion; say what is wrong instead.
        raise ValueError("No client loaders available; cannot build surrogate pool")

    datasets = []
    for idx in range(pool_size):
        dataset = getattr(client_loaders[idx], "dataset", None)
        if dataset is None:
            raise ValueError("Client loader is missing a dataset attribute; cannot build surrogate pool")
        datasets.append(dataset)

    pooled_dataset = ConcatDataset(datasets)
    transform = _surrogate_transform()
    if transform is not None:
        pooled_dataset = AugmentedDataset(pooled_dataset, transform=transform)

    # Seed the shuffle from the (otherwise unused) surrogate seed for reproducibility.
    generator = None
    if SURROGATE_SEED is not None:
        generator = torch.Generator()
        generator.manual_seed(int(SURROGATE_SEED))

    batch_size = int(SURROGATE_CFG.get("batch_size", data_config.get("batch_size", 96)))
    _SURROGATE_POOLED_LOADER = DataLoader(
        pooled_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        generator=generator,
    )
    return _SURROGATE_POOLED_LOADER


def get_surrogate_train_loader(pool: bool = True):
    """Return the surrogate's training loader.

    Args:
        pool: when True (default), return the pooled multi-shard loader. When
            False, return only the shard of client ``SURROGATE_CLIENT_ID``.

    Raises:
        IndexError: if ``pool`` is False and ``SURROGATE_CLIENT_ID`` is not a valid
            non-negative client index. A negative id would otherwise silently
            select a shard counted from the end.
    """
    client_loaders, _ = _ensure_surrogate_data()
    if pool:
        return _build_surrogate_pool_loader()
    num_clients = len(client_loaders)
    if not 0 <= SURROGATE_CLIENT_ID < num_clients:
        # f-string so the actual id and client count appear in the message.
        raise IndexError(
            f"Surrogate client id {SURROGATE_CLIENT_ID} out of range for {num_clients} clients"
        )
    return client_loaders[SURROGATE_CLIENT_ID]


def get_surrogate_test_loader():
    """Return the shared surrogate test loader, loading the surrogate data on first use."""
    return _ensure_surrogate_data()[1]


## 7. Surrogate Baseline

Fine-tune the surrogate MobileNetV2 against the pooled client data. We support freezing the backbone, early stopping, and weight decay so the surrogate’s test accuracy stays in the same ballpark as the federated baseline before attacks begin.

In [None]:
def build_surrogate_model(num_classes: int = 10, pretrained: bool | None = None) -> torch.nn.Module:
    """Create the MobileNetV2 transfer model used as the attacker's surrogate.

    ``pretrained=None`` defers to ``surrogate.pretrained`` in the config (default True).
    """
    use_pretrained = SURROGATE_CFG.get("pretrained", True) if pretrained is None else pretrained
    return MobileNetV2Transfer(pretrained=use_pretrained, num_classes=num_classes)


def _train_surrogate_epoch(model, loader, criterion, optimizer, device) -> tuple[float, float]:
    """Run one training epoch over ``loader``; return (mean batch loss, accuracy in %)."""
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for inputs, labels in loader:
        inputs = inputs.to(device).float()
        labels = labels.to(device).long()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        total += labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
    # max(..., 1) keeps an empty loader from dividing by zero.
    return running_loss / max(len(loader), 1), 100.0 * correct / max(total, 1)


def train_surrogate_baseline(num_epochs: int | None = None):
    """Fine-tune the surrogate MobileNetV2 on the pooled client data.

    Settings are read from the ``surrogate`` config section:

    * ``num_classes`` (falls back to ``model_config.kwargs.num_classes``, then 10)
    * ``freeze_backbone``
    * ``learning_rate`` (alias ``lr``, default 1e-3)
    * ``weight_decay``
    * ``num_epochs`` (default 5)
    * ``early_stop_patience`` (0 disables early stopping)

    After training, the weights with the lowest validation loss are restored.

    Args:
        num_epochs: number of epochs. ``None`` uses the config value; an explicit
            0 means no training epochs.

    Returns:
        ``(model, summary)``. ``summary`` holds the per-epoch ``history``
        (train/val loss and accuracy) plus the final ``test_loss`` and
        ``test_accuracy``. Validation/test metrics are cast to ``float`` so the
        summary is JSON-serialisable.

    Raises:
        RuntimeError: if freezing leaves no trainable parameters.
    """
    set_seed(global_config.get("seed", 42))
    device = get_device()
    train_loader = get_surrogate_train_loader(pool=True)
    num_classes = SURROGATE_CFG.get(
        "num_classes", model_config.get("kwargs", {}).get("num_classes", 10)
    )
    model = build_surrogate_model(num_classes=num_classes).to(device)
    if SURROGATE_CFG.get("freeze_backbone", False) and hasattr(model, "v2model"):
        # Only v2model.features is frozen; every other parameter stays trainable.
        for param in model.v2model.features.parameters():
            param.requires_grad = False
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # "learning_rate" takes precedence over the shorter "lr" alias.
    lr = SURROGATE_CFG.get("learning_rate", SURROGATE_CFG.get("lr", 1e-3))
    weight_decay = SURROGATE_CFG.get("weight_decay", 0.0)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        raise RuntimeError("No trainable parameters available for surrogate optimisation.")
    optimizer = torch.optim.Adam(trainable_params, lr=lr, weight_decay=weight_decay)
    # Only None falls back to the config, so an explicit 0 is honoured.
    epochs = int(num_epochs if num_epochs is not None else SURROGATE_CFG.get("num_epochs", 5))
    patience = int(SURROGATE_CFG.get("early_stop_patience", 0))
    history = {"loss": [], "accuracy": [], "val_loss": [], "val_accuracy": []}
    best_state = None
    best_val_loss = float("inf")
    epochs_since_improved = 0
    # NOTE(review): the test loader doubles as the validation set for early
    # stopping and checkpoint selection, so the reported test metrics are
    # optimistically biased; a separate validation split would avoid this.
    test_loader = get_surrogate_test_loader()
    for epoch in range(epochs):
        epoch_loss, epoch_acc = _train_surrogate_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate_fn(test_loader, model, criterion, device)
        # Same float() cast as summarise_server, in case evaluate_fn returns tensors.
        val_loss, val_acc = float(val_loss), float(val_acc)
        history["loss"].append(epoch_loss)
        history["accuracy"].append(epoch_acc)
        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(val_acc)
        print(f"Epoch {epoch + 1}/{epochs}: train_loss={epoch_loss:.4f}, train_acc={epoch_acc:.2f}%, val_loss={val_loss:.4f}, val_acc={val_acc:.2f}%")
        # A small margin so numerically-equal losses don't count as improvement.
        if val_loss + 1e-5 < best_val_loss:
            best_val_loss = val_loss
            best_state = deepcopy(model.state_dict())
            epochs_since_improved = 0
        else:
            epochs_since_improved += 1
            if patience and epochs_since_improved >= patience:
                print(f"Stopping early at epoch {epoch + 1} after {patience} epochs without improvement.")
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    test_loss, test_acc = evaluate_fn(test_loader, model, criterion, device)
    summary = {
        "history": history,
        "test_loss": float(test_loss),
        "test_accuracy": float(test_acc),
    }
    return model, summary


## 8. Baseline Comparison

Display the federated-reference metrics alongside the surrogate’s evaluation to confirm the two models are comparable prior to introducing malicious behaviour.

In [None]:
surrogate_model, surrogate_summary = train_surrogate_baseline()

# Side by side: the federated FedAvg reference vs. the freshly trained surrogate.
fedavg_summary = baseline_results.get("FedAvg", {})
print("Federated baseline:", fedavg_summary)
_sur_loss, _sur_acc = surrogate_summary["test_loss"], surrogate_summary["test_accuracy"]
print(f"Surrogate test metrics → loss: {_sur_loss:.4f}, accuracy: {_sur_acc:.2f}%")

surrogate_summary


### Surrogate sanity check

Save the surrogate training history and ensure the held-out accuracy is reasonable before crafting attacks.


In [None]:
surrogate_path = ARTIFACT_DIR / 'module4_surrogate.json'
with open(surrogate_path, 'w') as handle:
    json.dump(surrogate_summary, handle, indent=2)

# Fail fast when the surrogate clearly did not learn (threshold: 5% test accuracy).
acc = surrogate_summary.get('test_accuracy', 0.0)
if acc < 5.0:
    raise ValueError(f'Surrogate accuracy {acc:.2f}% is suspiciously low; revisit training settings.')
print(f'Surrogate test accuracy: {acc:.2f}% (saved details to {surrogate_path.resolve()})')


### Surrogate Training Curves

Track training vs. validation performance of the surrogate to confirm regularisation keeps it aligned with the federated baseline.

In [None]:
def plot_surrogate_history(summary: dict) -> None:
    """Plot train vs. validation loss (left panel) and accuracy (right panel) per epoch."""
    history = summary.get("history", {})
    epochs = range(1, len(history.get("loss", [])) + 1)
    panels = (
        ("loss", "val_loss", "Loss", "Surrogate loss"),
        ("accuracy", "val_accuracy", "Accuracy (%)", "Surrogate accuracy"),
    )
    plt.figure(figsize=(10, 4))
    for position, (train_key, val_key, y_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, history.get(train_key, []), label="train")
        plt.plot(epochs, history.get(val_key, []), label="val")
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True, alpha=0.3)
    plt.tight_layout()

plot_surrogate_history(surrogate_summary)


## 10. Attack Paths & Config

Extract the attack defaults (malicious fraction, PGD/FGSM parameters, poison schedule) and build lightweight config builders that let each recipe override pieces of that base.

In [None]:
# The config's "attack" section; its nested "attack" and "surrogate" sub-sections
# hold the defaults that every recipe below starts from. `or {}` keeps an empty
# sub-section (parsed as None) from breaking the .get() lookups.
ATTACK_RAW = attack_defaults
ATTACK_BASE = ATTACK_RAW.get("attack") or {}
SURROGATE_BASE = ATTACK_RAW.get("surrogate") or {}
ATTACK_SEED = ATTACK_RAW.get("seed", global_config.get("seed", 42))
BASE_MALICIOUS_FRACTION = ATTACK_RAW.get("malicious_fraction", 0.0)


def _extract_attack_params(overrides: dict | None = None) -> dict:
    """Return attack hyper-parameters: config defaults, with ``overrides`` applied last.

    The fallback ``epsilon`` and ``step_size`` values equal 8/255 and 2/255.
    ``poison_rate_schedule`` is included only when the config provides a truthy one.
    """
    defaults = (
        ("type", "pgd"),
        ("poison_rate", 0.0),
        ("target_label", 0),
        ("epsilon", 0.03137255),
        ("step_size", 0.00784314),
        ("iters", 10),
        ("criterion", "torch.nn.CrossEntropyLoss"),
    )
    params = {key: ATTACK_BASE.get(key, fallback) for key, fallback in defaults}
    schedule = ATTACK_BASE.get("poison_rate_schedule")
    if schedule:
        params["poison_rate_schedule"] = schedule
    params.update(overrides or {})
    return params


def _extract_surrogate_params(overrides: dict | None = None) -> dict:
    """Return surrogate hyper-parameters for the malicious client.

    Each value is looked up in the following order:

    1. the top-level ``surrogate`` config section;
    2. the attack section's ``surrogate`` sub-section (sometimes under a
       different key, e.g. ``num_epochs`` vs. ``finetune_epochs``);
    3. a hard default.

    ``client_id`` always mirrors ``SURROGATE_CLIENT_ID``, and ``overrides`` win last.
    """
    def pick(cfg_key: str, base_key: str, fallback):
        return SURROGATE_CFG.get(cfg_key, SURROGATE_BASE.get(base_key, fallback))

    default_classes = model_config.get("kwargs", {}).get("num_classes", 10)
    params = {
        "pretrained": pick("pretrained", "pretrained", True),
        "finetune_epochs": pick("num_epochs", "finetune_epochs", 0),
        "lr": pick("learning_rate", "lr", 1e-3),
        "weight_decay": pick("weight_decay", "weight_decay", 0.0),
        "batch_size": pick("batch_size", "batch_size", data_config.get("batch_size", 96)),
        "client_id": SURROGATE_CLIENT_ID,
        "pool_size": pick("pool_size", "pool_size", 1),
        "freeze_backbone": pick("freeze_backbone", "freeze_backbone", False),
        "augment": pick("augment", "augment", False),
        "early_stop_patience": pick("early_stop_patience", "early_stop_patience", 0),
        "num_classes": pick("num_classes", "num_classes", default_classes),
    }
    if overrides:
        params.update(overrides)
    return params


def build_attack_config(*, attack_overrides: dict | None = None, surrogate_overrides: dict | None = None, malicious_fraction: float | None = None, seed: int | None = None) -> dict:
    """Assemble a full attack config from the config defaults plus per-recipe overrides.

    ``seed`` and ``malicious_fraction`` default to the attack section's values.
    ``start_round`` is copied through only when the attack section defines it.
    """
    cfg = {
        "seed": ATTACK_SEED if seed is None else seed,
        "malicious_fraction": BASE_MALICIOUS_FRACTION if malicious_fraction is None else malicious_fraction,
        "attack": _extract_attack_params(attack_overrides),
        "surrogate": _extract_surrogate_params(surrogate_overrides),
    }
    if "start_round" in ATTACK_RAW:
        cfg["start_round"] = ATTACK_RAW["start_round"]
    return cfg


# Named attack configurations. "clean" sets both the malicious fraction and poison rate to 0.
ATTACK_RECIPES = dict(
    clean=build_attack_config(malicious_fraction=0.0, attack_overrides={"poison_rate": 0.0}),
    pgd_default=build_attack_config(),
    fgsm_default=build_attack_config(attack_overrides={"type": "fgsm", "iters": 1}),
    random_noise=build_attack_config(attack_overrides={"type": "random_noise"}),
)


## 11. Attack Implementations

Notebook-native implementations of PGD, FGSM, and random noise that mirror the malicious client’s expectations. PGD and FGSM accept a `targeted` flag so we can craft targeted poisons directly from the notebook; random noise ignores labels entirely.

In [None]:
from typing import Callable, Dict

# Type of the attack callables below: arbitrary arguments, returns the perturbed batch tensor.
AttackFn = Callable[..., torch.Tensor]


def pgd_attack(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    images: torch.Tensor,
    labels: torch.Tensor,
    eps: float,
    step_size: float,
    iters: int,
    targeted: bool = False,
) -> torch.Tensor:
    """Projected Gradient Descent with an L-infinity constraint.

    Starting from ``images`` (no random start), the attack takes ``iters``
    signed-gradient steps of size ``step_size``. After each step it projects
    back into the ``eps``-ball around the originals and clips to [0, 1].

    * Untargeted (default): ascends ``criterion(model(x), labels)``, pushing
      predictions away from ``labels`` (normally the true labels).
    * Targeted: descends the loss, pulling predictions towards ``labels``
      (the target labels).

    The gradient is taken w.r.t. the input only (``torch.autograd.grad``), so
    the model's parameter ``.grad`` fields are neither cleared nor overwritten.
    The model's train/eval mode is left as the caller set it.

    Returns:
        The detached adversarial batch, same shape as ``images``.
    """
    ori = images.clone().detach()
    adv = ori.clone()
    direction = -1.0 if targeted else 1.0

    for _ in range(iters):
        adv.requires_grad_(True)
        loss = criterion(model(adv), labels)
        (grad,) = torch.autograd.grad(loss, adv)
        stepped = adv.detach() + direction * step_size * grad.sign()
        # Project into the eps-ball around the originals, then into the valid pixel range.
        eta = torch.clamp(stepped - ori, min=-eps, max=eps)
        adv = torch.clamp(ori + eta, 0, 1).detach()

    return adv


def fgsm_attack(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    images: torch.Tensor,
    labels: torch.Tensor,
    step_size: float,
    targeted: bool = False,
) -> torch.Tensor:
    """Fast Gradient Sign Method: a single signed-gradient step of size ``step_size``.

    * Untargeted (default): steps up the loss w.r.t. ``labels`` (normally the
      true labels).
    * Targeted: steps down the loss towards ``labels`` (the target labels).

    The result is clamped to [0, 1]. For inputs already in [0, 1], the
    L-infinity distance to ``images`` is therefore at most ``step_size``.
    The gradient is taken w.r.t. the input only (``torch.autograd.grad``), so
    the model's parameter ``.grad`` fields are left untouched.
    """
    adv = images.clone().detach().requires_grad_(True)
    loss = criterion(model(adv), labels)
    (grad,) = torch.autograd.grad(loss, adv)

    direction = -1.0 if targeted else 1.0
    stepped = adv.detach() + direction * step_size * grad.sign()
    return torch.clamp(stepped, 0, 1).detach()


def random_noise_attack(images: torch.Tensor, step_size: float) -> torch.Tensor:
    """Shift every pixel by ±``step_size`` with a random sign, then clamp to [0, 1].

    The signs come from ``torch.randn_like``, so results follow the global torch RNG.
    """
    signs = torch.randn_like(images).sign()
    return (images + signs * step_size).clamp(0, 1).detach()


# Registry of attack callables by lower-case name.
ATTACK_FUNCTIONS: Dict[str, AttackFn] = {"pgd": pgd_attack, "fgsm": fgsm_attack}
# "random" is a short alias of "random_noise"; both map to the same function.
ATTACK_FUNCTIONS.update(dict.fromkeys(("random", "random_noise"), random_noise_attack))


## 12. Malicious Client Definition

Helper to instantiate a `MaliciousClient` bound to our trained surrogate (switched to eval mode). By default it uses the pooled surrogate training loader and the configured `local_lr`; the attack hyper-parameters come from the supplied attack config.

In [None]:
def make_malicious_client(
    attack_config: dict,
    *,
    local_loader=None,
    surrogate=None,
    num_epochs: int = 0,
    lr: float | None = None,
):
    """Create a ``MaliciousClient`` wired to a surrogate model.

    Args:
        attack_config: attack configuration handed to ``MaliciousClient``.
        local_loader: client data; defaults to the pooled surrogate train loader.
        surrogate: model to attach; defaults to the notebook's ``surrogate_model``.
        num_epochs: local training epochs for the client.
        lr: client learning rate; defaults to ``surrogate.local_lr`` (0.003).

    Returns:
        The client, with ``client.surrogate`` moved to the device and in eval mode.
    """
    device = get_device()
    loader = get_surrogate_train_loader() if local_loader is None else local_loader
    model = surrogate_model if surrogate is None else surrogate
    client_lr = SURROGATE_CFG.get("local_lr", 0.003) if lr is None else lr

    client = MaliciousClient(
        client_id=SURROGATE_CLIENT_ID,
        local_data=loader,
        device=device,
        num_epochs=num_epochs,
        criterion=nn.CrossEntropyLoss().to(device),
        lr=client_lr,
        attack_config=attack_config,
    )
    client.surrogate = model.to(device)
    client.surrogate.eval()
    return client


## 13. Attack Execution Helpers

Utilities for selecting attack callables, crafting adversarial batches, and running both surrogate-level and federated-level attack sweeps. The helpers ensure each recipe is self-contained and reusable across experiments.

In [None]:

def select_attack_fn(name: str) -> AttackFn:
    """Look up an attack callable by case-insensitive name.

    Raises:
        KeyError: if the name is not registered, listing the available names.
    """
    try:
        return ATTACK_FUNCTIONS[name.lower()]
    except KeyError:
        raise KeyError(f"Unknown attack function '{name}'. Available: {sorted(ATTACK_FUNCTIONS)}") from None


def _attach_attack_callable(cfg: dict) -> dict:
    cfg_copy = deepcopy(cfg)
    attack_params = cfg_copy.setdefault("attack", {})
    attack_type = attack_params.get("type", "pgd")
    if "callable" not in attack_params:
        attack_params["callable"] = select_attack_fn(attack_type)
    return cfg_copy


def craft_adversarial_batch(client: MaliciousClient, inputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Craft adversarial inputs for one batch using ``client.surrogate``.

    Reads ``client.attack_params`` (``type``, ``target_label``, ``targeted``,
    ``step_size``, ``epsilon``, ``iters``) and ``client.attack_criterion``.
    ``targeted`` defaults to True whenever ``target_label`` is present.

    * ``pgd`` / ``fgsm``: gradient attacks against the surrogate. Targeted runs
      descend the loss towards ``target_label``; untargeted runs ascend the
      loss w.r.t. the true ``labels``.
    * Any other type: random signed noise (labels are not used).

    Returns:
        ``(adv_inputs, target_labels)``, where ``target_labels`` is ``labels``
        filled with ``target_label``. It is returned regardless of ``targeted``.
    """
    attack_params = client.attack_params
    attack_type = attack_params.get("type", "pgd").lower()
    target_label = int(attack_params.get("target_label", 0))
    target_labels = torch.full_like(labels, target_label)
    targeted = bool(attack_params.get("targeted", attack_params.get("target_label") is not None))
    step_size = float(attack_params.get("step_size", 0.00784314))

    if attack_type not in {"pgd", "fgsm"}:
        noise_fn = select_attack_fn("random_noise")
        return noise_fn(inputs, step_size=step_size), target_labels

    kwargs = {
        "model": client.surrogate,
        "criterion": client.attack_criterion,
        "images": inputs,
        # Targeted: pull towards the target class. Untargeted: push away from the
        # TRUE class. Feeding the target labels here would only steer predictions
        # away from the target class, not away from the correct one.
        "labels": target_labels if targeted else labels,
        "step_size": step_size,
        "targeted": targeted,
    }
    if attack_type == "pgd":
        kwargs["eps"] = float(attack_params.get("epsilon", 0.03137255))
        kwargs["iters"] = int(attack_params.get("iters", 10))
    adv = select_attack_fn(attack_type)(**kwargs)
    return adv, target_labels


def evaluate_surrogate_attack(recipe_name: str, *, batch_size: int = 32) -> dict:
    """Measure one attack recipe against the surrogate on a single test batch.

    Uses the first ``batch_size`` samples of the first test batch. It records
    the surrogate's clean accuracy, crafts adversarial inputs with the recipe,
    and reports the attack success rate (ASR, in %):

    * Targeted attacks: share of adversarial predictions equal to the target label.
      NOTE(review): samples whose true label already equals the target count
      as successes; consider excluding them.
    * Untargeted attacks: share of adversarial predictions that differ from the
      true label.

    ``targeted`` follows the same rule as ``craft_adversarial_batch``: explicit
    ``targeted`` param, else True whenever ``target_label`` is set.
    """
    attack_cfg = _attach_attack_callable(ATTACK_RECIPES[recipe_name])
    # make_malicious_client moves the surrogate to the device and sets eval mode.
    client = make_malicious_client(
        attack_cfg,
        surrogate=surrogate_model,
        num_epochs=attack_cfg.get("surrogate", {}).get("finetune_epochs", 0),
    )

    device = get_device()
    inputs, labels = next(iter(get_surrogate_test_loader()))
    inputs = inputs[:batch_size].to(device).float()
    labels = labels[:batch_size].to(device).long()

    with torch.no_grad():
        clean_preds = surrogate_model(inputs).argmax(dim=1)
    clean_acc = (clean_preds == labels).float().mean().item() * 100.0

    adv_inputs, adv_labels = craft_adversarial_batch(client, inputs, labels)

    attack_params = client.attack_params
    targeted = bool(attack_params.get("targeted", attack_params.get("target_label") is not None))
    with torch.no_grad():
        adv_preds = surrogate_model(adv_inputs).argmax(dim=1)
    hits = (adv_preds == adv_labels) if targeted else (adv_preds != labels)
    asr = hits.float().mean().item() * 100.0

    return {
        "recipe": recipe_name,
        "clean_accuracy": clean_acc,
        "attack_success_rate": asr,
    }


def run_attack_recipe_on_server(recipe_name: str, alg_name: str = "FedAvg", malicious_fraction: float | None = None) -> dict:
    """Train ``alg_name`` under the named attack recipe and return its summary.

    ``malicious_fraction``, when given, overrides the recipe's fraction. The
    summary from ``run_one_algorithm`` is tagged with ``recipe`` and ``algorithm``.
    """
    cfg = _attach_attack_callable(ATTACK_RECIPES[recipe_name])
    if malicious_fraction is not None:
        cfg = {**deepcopy(cfg), "malicious_fraction": malicious_fraction}
    summary = run_one_algorithm(alg_name, attack_cfg=cfg)
    summary |= {"recipe": recipe_name, "algorithm": alg_name}
    return summary


def sweep_attacks_on_server(alg_name: str = "FedAvg", recipes: list[str] | None = None, malicious_fraction: float | None = None) -> dict:
    """Run several attack recipes against one algorithm; return ``{recipe: summary}``.

    ``recipes`` defaults to every entry of ``ATTACK_RECIPES``.

    The ``clean`` recipe reuses a deep copy of ``baseline_results[alg_name]``
    when that baseline exists. Otherwise it is trained like any other recipe,
    always without the ``malicious_fraction`` override so it stays attack-free.
    ``malicious_fraction``, when given, overrides the fraction of every other recipe.
    """
    recipes = recipes or list(ATTACK_RECIPES)
    # Tolerate the baseline cell not having been run in this session.
    baselines = globals().get("baseline_results") or {}
    results = {}
    for recipe in recipes:
        if recipe == "clean":
            stored = baselines.get(alg_name)
            if stored:
                # Copy so edits to the sweep output never mutate the stored baseline.
                results[recipe] = deepcopy(stored)
            else:
                results[recipe] = run_attack_recipe_on_server(recipe, alg_name=alg_name)
            continue
        results[recipe] = run_attack_recipe_on_server(
            recipe,
            alg_name=alg_name,
            malicious_fraction=malicious_fraction,
        )
    return results


## 14. Surrogate Attack Experiments

Evaluate how each attack recipe performs on the surrogate alone: craft adversarial examples for one test batch, measure the surrogate's clean accuracy and the attack success rate, and log both before federated training enters the picture.

In [None]:
SURROGATE_ATTACK_RECIPES = ["pgd_default", "fgsm_default", "random_noise"]

# Evaluate every recipe in order and key the result by recipe name.
surrogate_attack_results = dict(
    zip(SURROGATE_ATTACK_RECIPES, map(evaluate_surrogate_attack, SURROGATE_ATTACK_RECIPES))
)

surrogate_attack_results


### Persist surrogate attack metrics

Capture the clean accuracy and attack success rate for each recipe.


In [None]:
# Persist per-recipe surrogate attack metrics as pretty-printed JSON.
sur_attack_path = ARTIFACT_DIR / 'module4_surrogate_attacks.json'
with open(sur_attack_path, 'w') as handle:
    json.dump(surrogate_attack_results, handle, indent=2)
print(f'Saved surrogate attack metrics to {sur_attack_path.resolve()}')


### Surrogate Attack Summary

Visualise clean accuracy vs. attack success rate for each recipe to gauge transferability before federated training.

In [None]:
import pandas as pd

def plot_surrogate_attack_results(results: dict[str, dict]) -> None:
    """Grouped bar chart of clean accuracy vs. attack success rate, one group per recipe."""
    table = pd.DataFrame(results).T
    bar_width = 0.35
    offset = bar_width / 2
    slots = range(len(table))
    plt.figure(figsize=(6, 4))
    plt.bar([s - offset for s in slots], table['clean_accuracy'], width=bar_width, label='Clean Acc')
    plt.bar([s + offset for s in slots], table['attack_success_rate'], width=bar_width, label='ASR')
    plt.xticks(list(slots), table.index, rotation=30)
    plt.ylabel('Percentage')
    plt.title('Surrogate attack outcomes')
    plt.legend()
    plt.tight_layout()

plot_surrogate_attack_results(surrogate_attack_results)


## 15. Federated Attack Sweeps

Deploy poisoned clients inside the federated loop by reusing the same recipes, then compare the clean baseline with the attacked run (FedAvg by default) to see how poisoning degrades global performance.

In [None]:
# Recipes to run inside the federated loop; "clean" comes from the stored baseline.
FED_ATTACK_RECIPES = ["clean", "pgd_default"]

federated_attack_results = sweep_attacks_on_server("FedAvg", FED_ATTACK_RECIPES)

federated_attack_results


### Persist federated attack metrics

Record clean vs. PGD-poisoned outcomes for reproducibility.


In [None]:
# Persist clean vs. poisoned federated summaries as pretty-printed JSON.
fed_attack_path = ARTIFACT_DIR / 'module4_federated_attacks.json'
with open(fed_attack_path, 'w') as handle:
    json.dump(federated_attack_results, handle, indent=2)
print(f'Saved federated attack metrics to {fed_attack_path.resolve()}')


### Federated attack sanity check

Confirm the attack run differs from the clean baseline so we notice configuration mistakes early.


In [None]:
# Compare FedAvg's clean baseline with the PGD-poisoned run. Warn when the two are
# indistinguishable, since that can indicate the attack never took effect.
clean_acc = baseline_results.get('FedAvg', {}).get('final_accuracy')
attack_acc = federated_attack_results.get('pgd_default', {}).get('final_accuracy')
if clean_acc is not None and attack_acc is not None:
    delta = clean_acc - attack_acc
    print(f'FedAvg clean accuracy: {clean_acc:.2f}%  |  attacked: {attack_acc:.2f}%  |  drop: {delta:.2f} pts')
    if abs(delta) < 1e-6:
        print('WARNING: attacked accuracy matches the clean baseline; check that the attack '
              'is enabled (malicious_fraction, poison_rate, start_round in the attack config).')
else:
    print('Run the baseline and attack cells before executing this check.')


### Federated Attack Impact

Compare final accuracies between clean and attacked runs for the chosen algorithms.

In [None]:
def _resolve_accuracy(summary: dict | None) -> float:
    """Return ``summary["final_accuracy"]`` as a float, or 0.0 when the summary or value is missing."""
    if not summary:
        return 0.0
    value = summary.get('final_accuracy')
    return float(value) if value is not None else 0.0

def plot_federated_attack_results(clean: dict, attacked: dict) -> None:
    """Grouped bars of clean vs. attacked final accuracy per algorithm.

    Both arguments must be keyed by algorithm name (e.g. ``{"FedAvg": summary}``).
    An algorithm missing from one side is drawn as 0 on that side.
    """
    algs = sorted({*clean.keys(), *attacked.keys()})
    if not algs:
        print('No federated results available to plot yet.')
        return
    clean_acc = [_resolve_accuracy(clean.get(alg)) for alg in algs]
    attack_acc = [_resolve_accuracy(attacked.get(alg)) for alg in algs]
    plt.figure(figsize=(6, 4))
    width = 0.35
    positions = list(range(len(algs)))
    plt.bar([i - width / 2 for i in positions], clean_acc, width=width, label='Clean')
    plt.bar([i + width / 2 for i in positions], attack_acc, width=width, label='Attacked')
    plt.xticks(positions, algs, rotation=30)
    plt.ylabel('Final accuracy (%)')
    plt.title('Federated accuracy with and without attack')
    plt.legend()
    plt.tight_layout()

# federated_attack_results is keyed by recipe ("clean", "pgd_default"), not by
# algorithm. Passing it directly would put recipe names on the algorithm axis,
# so re-key the PGD run under the algorithm it attacked.
_pgd_run = federated_attack_results.get('pgd_default') or {}
_attacked_by_algorithm = {_pgd_run.get('algorithm', 'FedAvg'): _pgd_run} if _pgd_run else {}
plot_federated_attack_results(baseline_results, _attacked_by_algorithm)


## 16. Malicious Fraction Sweep

Explore how varying the proportion of malicious clients changes the global outcome. We reuse the sweep helper while overriding the malicious fraction for each run.

In [None]:
malicious_grid = [0.0, 0.05, 0.1, 0.2]
fraction_sweep_results = {}
for frac in malicious_grid:
    # Pass every fraction explicitly. None makes the sweep fall back to the
    # recipe's configured malicious_fraction, so mapping 0.0 to None would run the
    # "0.0" point with the configured default rather than with zero attackers.
    fraction_sweep_results[frac] = sweep_attacks_on_server(
        alg_name="FedAvg",
        recipes=["clean", "pgd_default"],
        malicious_fraction=frac,
    )

# One row per fraction with the PGD run's final metrics (None if missing).
comparison_rows = []
for frac, results in fraction_sweep_results.items():
    attack_summary = results.get("pgd_default", {})
    comparison_rows.append({
        "malicious_fraction": frac,
        "final_loss": attack_summary.get("final_loss"),
        "final_accuracy": attack_summary.get("final_accuracy"),
    })

comparison_rows


### Persist malicious-fraction sweep

Store the summary table for different malicious client ratios.


In [None]:
# Persist the malicious-fraction summary table as pretty-printed JSON.
fraction_path = ARTIFACT_DIR / 'module4_fraction_sweep.json'
with open(fraction_path, 'w') as handle:
    json.dump(comparison_rows, handle, indent=2)
print(f'Saved sweep results to {fraction_path.resolve()}')


### Malicious Fraction vs. Accuracy

Plot the final accuracy achieved by PGD attacks as we vary the proportion of malicious clients.

In [None]:
def plot_fraction_sweep(results: dict[float, dict] | None) -> None:
    """Line plot of the PGD run's final accuracy against the malicious-client fraction.

    Raises:
        ValueError: if ``results`` is None or empty.
    """
    if not results:
        raise ValueError("No fraction sweep results available to plot.")
    ordered = sorted(results)
    pgd_accuracy = [results[f].get('pgd_default', {}).get('final_accuracy') for f in ordered]
    plt.figure(figsize=(6, 4))
    plt.plot(ordered, pgd_accuracy, marker='o')
    plt.xlabel('Malicious fraction')
    plt.ylabel('Final accuracy (%)')
    plt.title('Impact of malicious fraction on FedAvg (PGD)')
    plt.grid(True, alpha=0.3)

# Only plot when the sweep cell has populated its results in this session.
_fraction_results = globals().get('fraction_sweep_results')
if _fraction_results:
    plot_fraction_sweep(_fraction_results)
else:
    print("fraction_sweep_results missing; run the sweep cell above to populate it before plotting.")
