# Module 4 Adversarial FL – Outline

## 1. Federated Baseline Imports

In [None]:
from copy import deepcopy
from pathlib import Path

import yaml
import numpy as np
import torch
import torch.nn as nn

from util_functions import set_seed, evaluate_fn, run_fl
from load_data_for_clients import dist_data_per_client
from algos import (
    Server,
    ScaffoldServer,
    FedAdamServer,
    FedAdagradServer,
    FedYogiServer,
)


## 2. Federated Baseline Paths & Config

In [None]:
CONFIG_PATH = Path("config.yaml")
if not CONFIG_PATH.exists():
    raise FileNotFoundError("Could not locate config.yaml in the working directory")

with CONFIG_PATH.open(encoding="utf-8") as f:
    # safe_load (not load) so the config file cannot construct arbitrary Python objects.
    # safe_load returns None for an empty file; fall back to {} so the .get() calls work.
    CONFIG = yaml.safe_load(f) or {}

global_config = CONFIG.get("global_config", {})
data_config = CONFIG.get("data_config", {})
model_config = CONFIG.get("model_config", {})
alg_configs = CONFIG.get("algorithms", {})
attack_defaults = CONFIG.get("attack", {})

# Shared compute device for the surrogate / attack helpers further down.
# An explicit `global_config.device` entry (any string torch.device accepts,
# e.g. "cpu" or "cuda:0") wins; otherwise prefer CUDA when it is available.
DEVICE = torch.device(
    global_config.get("device") or ("cuda" if torch.cuda.is_available() else "cpu")
)

set_seed(global_config.get("seed", 42))
AVAILABLE_ALGORITHMS = list(alg_configs)
print("Loaded config from", CONFIG_PATH.resolve())
print("Available algorithms:", AVAILABLE_ALGORITHMS)
print("Using device:", DEVICE)


## 3. Federated Baseline Helpers

In [None]:
# Name used in config.yaml -> server class that implements the algorithm.
ALGORITHM_MAP = dict(
    FedAvg=Server,
    Scaffold=ScaffoldServer,
    FedAdam=FedAdamServer,
    FedAdagrad=FedAdagradServer,
    FedYogi=FedYogiServer,
)

# Fail fast if the config lists an algorithm that has no server class registered.
missing = sorted(set(AVAILABLE_ALGORITHMS).difference(ALGORITHM_MAP))
if missing:
    raise KeyError(f"No server mapping registered for: {missing}")


def train_server(alg_name: str, attack_cfg: dict | None = None):
    """Run federated training for one configured algorithm and return the server.

    The per-algorithm ``fed_config`` / ``optim_config`` sections are deep-copied,
    so ``run_fl`` cannot mutate the shared configuration. The same applies to
    ``attack_cfg``, which defaults to a benign config with no malicious clients.

    Raises:
        ValueError: if ``alg_name`` has no entry under ``algorithms`` in the config.
    """
    if alg_name not in alg_configs:
        raise ValueError(f"Algorithm {alg_name!r} not found in configuration.")

    settings = alg_configs[alg_name]
    fed_cfg = deepcopy(settings["fed_config"])
    fed_cfg["algorithm"] = alg_name
    optim_cfg = deepcopy(settings.get("optim_config", {}))
    attack = deepcopy(attack_cfg) if attack_cfg else {"malicious_fraction": 0.0}

    server_cls = ALGORITHM_MAP[alg_name]
    return run_fl(
        server_cls,
        global_config,
        data_config,
        fed_cfg,
        model_config,
        optim_cfg,
        attack,
    )


def summarise_server(server) -> dict:
    """Evaluate a trained server and package its final metrics plus history.

    Calls ``evaluate_fn(server.data, server.x, server.criterion, server.device)``.
    It copies ``loss`` / ``accuracy`` out of ``server.results``, using empty lists
    when the server has no ``results`` attribute or a key is absent.
    """
    final_loss, final_acc = evaluate_fn(server.data, server.x, server.criterion, server.device)
    recorded = getattr(server, "results", {})
    history = {key: list(recorded.get(key, [])) for key in ("loss", "accuracy")}
    return {
        "final_loss": float(final_loss),
        "final_accuracy": float(final_acc),
        "history": history,
    }


def run_one_algorithm(alg_name: str, attack_cfg: dict | None = None) -> dict:
    """Train ``alg_name``, summarise it, then free the server and the CUDA cache."""
    trained = train_server(alg_name, attack_cfg=attack_cfg)
    result = summarise_server(trained)
    # Drop our reference first so the cache flush can actually release its tensors.
    del trained
    torch.cuda.empty_cache()
    return result


def run_all_algorithms(
    algorithms: list[str] | None = None,
    attack_cfg: dict | None = None,
) -> dict:
    """Run each named algorithm in order and map its name to its summary.

    Args:
        algorithms: names to run; ``None`` means every algorithm in the config.
            An explicit empty list runs nothing (it used to run everything).
        attack_cfg: attack configuration forwarded to every run.
    """
    if algorithms is None:
        algorithms = AVAILABLE_ALGORITHMS
    return {name: run_one_algorithm(name, attack_cfg=attack_cfg) for name in algorithms}


## 4. Federated Baseline Runs

In [None]:
# Clean (attack-free) federated baseline; later cells reuse these results.
BASELINE_ALGORITHMS = ["FedAvg"]

baseline_results = run_all_algorithms(algorithms=BASELINE_ALGORITHMS)
baseline_results


## 5. Surrogate Imports

In [None]:
from attacks import get_attack
from malicious_client import MaliciousClient
from model import MobileNetV2Transfer


## 6. Surrogate Paths & Config

In [None]:
SURROGATE_CFG = CONFIG.get("surrogate", {})
SURROGATE_CLIENT_ID = int(SURROGATE_CFG.get("client_id", 0))
# The surrogate seed falls back to the global seed (then 42) when not set.
SURROGATE_SEED = SURROGATE_CFG.get("seed", global_config.get("seed", 42))

# Cache filled lazily by _ensure_surrogate_data() on first use.
_SURROGATE_CLIENT_LOADERS = _SURROGATE_TEST_LOADER = None


def _ensure_surrogate_data():
    """Return ``(client_loaders, test_loader)``, building and caching them on first call.

    Each data setting is taken from the ``surrogate`` config section first, then
    from ``data_config``, then from the literal default.
    """
    global _SURROGATE_CLIENT_LOADERS, _SURROGATE_TEST_LOADER
    if _SURROGATE_CLIENT_LOADERS is not None and _SURROGATE_TEST_LOADER is not None:
        return _SURROGATE_CLIENT_LOADERS, _SURROGATE_TEST_LOADER

    def _setting(key, default=None):
        return SURROGATE_CFG.get(key, data_config.get(key, default))

    _SURROGATE_CLIENT_LOADERS, _SURROGATE_TEST_LOADER = dist_data_per_client(
        data_path=_setting("dataset_path"),
        dataset_name=_setting("dataset_name"),
        num_clients=_setting("num_clients", 50),
        batch_size=_setting("batch_size", 96),
        non_iid_per=_setting("non_iid_per", 0.0),
        device=DEVICE,
    )
    return _SURROGATE_CLIENT_LOADERS, _SURROGATE_TEST_LOADER


def get_surrogate_train_loader():
    """Return the training loader of client ``SURROGATE_CLIENT_ID``.

    Raises:
        IndexError: if the id is negative or not smaller than the number of clients.
    """
    client_loaders, _ = _ensure_surrogate_data()
    # Reject negative ids explicitly: plain indexing would silently wrap them to a
    # client counted from the end of the list.
    if not 0 <= SURROGATE_CLIENT_ID < len(client_loaders):
        raise IndexError(
            f"Surrogate client id {SURROGATE_CLIENT_ID} out of range for {len(client_loaders)} clients"
        )
    return client_loaders[SURROGATE_CLIENT_ID]


def get_surrogate_test_loader():
    """Return the test loader produced by dist_data_per_client (loaded on first call)."""
    return _ensure_surrogate_data()[1]


## 7. Surrogate Baseline

In [None]:
def build_surrogate_model(num_classes: int = 10, pretrained: bool | None = None) -> torch.nn.Module:
    """Construct the MobileNetV2 transfer surrogate.

    ``pretrained=None`` defers to ``surrogate.pretrained`` in the config (default True).
    """
    use_pretrained = SURROGATE_CFG.get("pretrained", True) if pretrained is None else pretrained
    return MobileNetV2Transfer(pretrained=use_pretrained, num_classes=num_classes)


def _run_surrogate_epoch(model, loader, criterion, optimizer) -> tuple[float, float]:
    """Train ``model`` for one pass over ``loader``; return (mean loss, accuracy %)."""
    model.train()
    loss_sum = 0.0
    total = 0
    correct = 0
    for inputs, labels in loader:
        inputs = inputs.to(DEVICE).float()
        labels = labels.to(DEVICE).long()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch = labels.size(0)
        # criterion uses mean reduction, so weight by batch size; otherwise a short
        # final batch would count as much as a full one in the epoch mean.
        loss_sum += loss.item() * batch
        total += batch
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    return loss_sum / max(total, 1), 100.0 * correct / max(total, 1)


def train_surrogate_baseline(num_epochs: int | None = None):
    """Train the surrogate on the surrogate client's local data, then evaluate it.

    Args:
        num_epochs: number of training epochs; ``None`` defers to
            ``surrogate.num_epochs`` in the config (default 5). ``0`` is honoured
            and skips training.

    Returns:
        ``(model, summary)``. ``summary`` has the per-epoch ``history`` (loss and
        accuracy %) and the float ``test_loss`` / ``test_accuracy`` from
        ``evaluate_fn`` on the surrogate test loader.
    """
    set_seed(SURROGATE_SEED)

    train_loader = get_surrogate_train_loader()
    # Same fallback chain as _extract_surrogate_params: surrogate config, then the
    # model_config kwargs, then 10.
    num_classes = SURROGATE_CFG.get(
        "num_classes", model_config.get("kwargs", {}).get("num_classes", 10)
    )
    model = build_surrogate_model(num_classes=num_classes).to(DEVICE)

    criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=SURROGATE_CFG.get("local_lr", 0.003),
        weight_decay=SURROGATE_CFG.get("weight_decay", 0.0),
    )
    epochs = num_epochs if num_epochs is not None else SURROGATE_CFG.get("num_epochs", 5)

    history = {"loss": [], "accuracy": []}
    for epoch in range(epochs):
        epoch_loss, epoch_acc = _run_surrogate_epoch(model, train_loader, criterion, optimizer)
        history["loss"].append(epoch_loss)
        history["accuracy"].append(epoch_acc)
        print(f"Epoch {epoch + 1}/{epochs}: loss={epoch_loss:.4f}, acc={epoch_acc:.2f}%")

    # Switch to eval mode before evaluating. Unknown from here whether evaluate_fn
    # does this itself, so do it explicitly.
    model.eval()
    test_loss, test_acc = evaluate_fn(get_surrogate_test_loader(), model, criterion, DEVICE)

    summary = {
        "history": history,
        "test_loss": float(test_loss),
        "test_accuracy": float(test_acc),
    }
    return model, summary


## 8. Baseline Comparison

In [None]:
# Train the surrogate, then print it next to the clean FedAvg baseline.
surrogate_model, surrogate_summary = train_surrogate_baseline()
fedavg_summary = baseline_results.get("FedAvg", {})

print("Federated baseline:", fedavg_summary)
print(
    "Surrogate test metrics → loss: {:.4f}, accuracy: {:.2f}%".format(
        surrogate_summary["test_loss"], surrogate_summary["test_accuracy"]
    )
)

surrogate_summary


## 9. Attack Imports

In [None]:
from attacks import get_attack, pgd_attack, fgsm_attack, random_noise_attack


## 10. Attack Paths & Config

In [None]:
# Raw `attack` config section (same object as attack_defaults) plus derived defaults.
ATTACK_RAW = attack_defaults
BASE_MALICIOUS_FRACTION = ATTACK_RAW.get("malicious_fraction", 0.0)
ATTACK_SEED = ATTACK_RAW.get("seed", global_config.get("seed", 42))


def _extract_attack_params(overrides: dict | None = None) -> dict:
    """Build the attack-parameter dict from the `attack` config section.

    Missing keys get the literal defaults below. A truthy
    ``poison_rate_schedule`` is copied through. ``overrides`` is applied last.
    """
    fallbacks = (
        ("type", "pgd"),
        ("poison_rate", 0.0),
        ("target_label", 0),
        ("epsilon", 0.03137255),
        ("step_size", 0.00784314),
        ("iters", 10),
        ("criterion", "torch.nn.CrossEntropyLoss"),
    )
    params = {key: ATTACK_RAW.get(key, default) for key, default in fallbacks}

    schedule = ATTACK_RAW.get("poison_rate_schedule")
    if schedule:
        params["poison_rate_schedule"] = schedule

    params.update(overrides or {})
    return params


def _extract_surrogate_params(overrides: dict | None = None) -> dict:
    """Build the surrogate-parameter dict used by the malicious client.

    Values come from the `attack` section (``surrogate_*`` keys), then the
    `surrogate`, `data_config` and `model_config` sections. ``overrides`` is
    applied last.
    """
    default_batch = SURROGATE_CFG.get("batch_size", data_config.get("batch_size", 96))
    default_classes = model_config.get("kwargs", {}).get("num_classes", 10)

    params = {
        "pretrained": ATTACK_RAW.get("surrogate_pretrained", True),
        "finetune_epochs": ATTACK_RAW.get("surrogate_finetune_epochs", 0),
        "lr": ATTACK_RAW.get("surrogate_lr", 1e-3),
        "batch_size": ATTACK_RAW.get("surrogate_batch_size", default_batch),
        "client_id": SURROGATE_CFG.get("client_id", 0),
        "num_classes": SURROGATE_CFG.get("num_classes", default_classes),
        "criterion": ATTACK_RAW.get("criterion", "torch.nn.CrossEntropyLoss"),
    }
    if overrides:
        params.update(overrides)
    return params


def build_attack_config(
    *,
    attack_overrides: dict | None = None,
    surrogate_overrides: dict | None = None,
    malicious_fraction: float | None = None,
    seed: int | None = None,
) -> dict:
    """Assemble a full attack configuration (seed, fraction, attack, surrogate).

    ``seed`` and ``malicious_fraction`` default to the config-derived globals.
    ``start_round`` is copied from the raw attack section only when present.
    """
    cfg = {
        "seed": ATTACK_SEED if seed is None else seed,
        "malicious_fraction": (
            malicious_fraction if malicious_fraction is not None else BASE_MALICIOUS_FRACTION
        ),
        "attack": _extract_attack_params(attack_overrides),
        "surrogate": _extract_surrogate_params(surrogate_overrides),
    }
    if "start_round" in ATTACK_RAW:
        cfg["start_round"] = ATTACK_RAW["start_round"]
    return cfg


# Named attack configurations; insertion order is the default sweep order.
ATTACK_RECIPES = dict(
    clean=build_attack_config(attack_overrides={"poison_rate": 0.0}, malicious_fraction=0.0),
    pgd_default=build_attack_config(),
    fgsm_default=build_attack_config(attack_overrides={"type": "fgsm", "iters": 1}),
    random_noise=build_attack_config(attack_overrides={"type": "random_noise"}),
)


## 11. Attack Implementations

In [None]:
from typing import Callable, Dict

# Type alias for attack callables: any argument list, returning a torch.Tensor
# (the perturbed batch, for the attacks registered in ATTACK_FUNCTIONS).
AttackFn = Callable[..., torch.Tensor]


def pgd_attack(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    images: torch.Tensor,
    labels: torch.Tensor,
    eps: float,
    step_size: float,
    iters: int,
    targeted: bool = False,
    clip_min: float = 0.0,
    clip_max: float = 1.0,
) -> torch.Tensor:
    """Projected Gradient Descent with an L-infinity constraint.

    Takes ``iters`` signed-gradient steps of size ``step_size``. After each step
    the perturbation is projected back into the ``eps`` ball around ``images``
    and the result is clamped to ``[clip_min, clip_max]``.

    Args:
        targeted: if True, descend the loss toward ``labels`` (targeted attack).
            Otherwise ascend it away from ``labels``.
        clip_min, clip_max: valid input range. The default ``[0, 1]`` assumes
            un-normalised images; pass the normalised bounds if the pipeline
            normalises inputs (TODO confirm for this dataset).

    Returns:
        A detached adversarial batch with the same shape as ``images``.

    Gradients are taken w.r.t. the input only (``torch.autograd.grad``), so
    ``model``'s parameter ``.grad`` fields are left untouched.
    """
    ori = images.clone().detach()
    adv = ori.clone()
    direction = -1.0 if targeted else 1.0

    for _ in range(iters):
        adv.requires_grad_(True)
        loss = criterion(model(adv), labels)
        (grad,) = torch.autograd.grad(loss, adv)

        adv = adv.detach() + direction * step_size * grad.sign()
        eta = torch.clamp(adv - ori, min=-eps, max=eps)
        adv = torch.clamp(ori + eta, clip_min, clip_max).detach()

    return adv


def fgsm_attack(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    images: torch.Tensor,
    labels: torch.Tensor,
    step_size: float,
    targeted: bool = False,
) -> torch.Tensor:
    """Single gradient-sign step."""
    adv = images.clone().detach().requires_grad_(True)
    loss = criterion(model(adv), labels)

    model.zero_grad(set_to_none=True)
    loss.backward()

    direction = -1 if targeted else 1
    adv = adv + direction * step_size * adv.grad.sign()
    return torch.clamp(adv, 0, 1).detach()


def random_noise_attack(images: torch.Tensor, step_size: float) -> torch.Tensor:
    """Shift every element by +/- ``step_size`` (random sign), clamped to [0, 1].

    The signs are drawn from ``torch.randn_like``, so results follow the torch RNG state.
    """
    signs = torch.randn_like(images).sign()
    return (images + step_size * signs).clamp(0, 1).detach()


# Registry of attack callables keyed by lower-case name ("random" is an alias).
ATTACK_FUNCTIONS: Dict[str, AttackFn] = dict(
    pgd=pgd_attack,
    fgsm=fgsm_attack,
    random=random_noise_attack,
    random_noise=random_noise_attack,
)


## 12. Malicious Client Definition

In [None]:
def make_malicious_client(
    attack_config: dict,
    *,
    local_loader=None,
    surrogate=None,
    num_epochs: int = 0,
    lr: float | None = None,
):
    """Create a MaliciousClient for client ``SURROGATE_CLIENT_ID`` bound to a surrogate.

    Defaults: ``local_loader`` is the surrogate client's train loader,
    ``surrogate`` is the global ``surrogate_model``, and ``lr`` comes from
    ``surrogate.local_lr`` (0.003). The surrogate is moved to ``DEVICE`` and put
    in eval mode.
    """
    loader = get_surrogate_train_loader() if local_loader is None else local_loader
    model = surrogate_model if surrogate is None else surrogate

    client = MaliciousClient(
        client_id=SURROGATE_CLIENT_ID,
        local_data=loader,
        device=DEVICE,
        num_epochs=num_epochs,
        criterion=nn.CrossEntropyLoss().to(DEVICE),
        lr=SURROGATE_CFG.get("local_lr", 0.003) if lr is None else lr,
        attack_config=attack_config,
    )
    client.surrogate = model.to(DEVICE)
    client.surrogate.eval()
    return client


## 13. Attack Execution Helpers

In [None]:

def select_attack_fn(name: str) -> AttackFn:
    """Return the registered attack callable for ``name`` (case-insensitive).

    Raises:
        KeyError: if ``name`` is not a key of ATTACK_FUNCTIONS.
    """
    fn = ATTACK_FUNCTIONS.get(name.lower())
    if fn is None:
        raise KeyError(f"Unknown attack function '{name}'. Available: {sorted(ATTACK_FUNCTIONS)}")
    return fn


def _attach_attack_callable(cfg: dict) -> dict:
    """Return a deep copy of ``cfg`` whose ``attack`` section holds a ``callable``.

    An existing ``callable`` entry is kept. Otherwise it is looked up from
    ``attack.type`` (default "pgd").
    """
    enriched = deepcopy(cfg)
    attack_section = enriched.setdefault("attack", {})
    if "callable" not in attack_section:
        attack_section["callable"] = select_attack_fn(attack_section.get("type", "pgd"))
    return enriched


def craft_adversarial_batch(client: MaliciousClient, inputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Perturb ``inputs`` with the attack configured in ``client.attack_params``.

    "pgd" / "fgsm" attack ``client.surrogate`` with ``client.attack_criterion``.
    Any other type is resolved through ``select_attack_fn`` and called as
    ``fn(inputs, step_size=...)`` (the random-noise signature).

    The attack is targeted when ``attack_params["targeted"]`` is truthy. If that
    key is absent, it is targeted whenever a ``target_label`` is configured.
    Targeted attacks pull predictions toward the target label; untargeted ones
    push them away from the true ``labels``.

    Returns:
        ``(adv_inputs, target_labels)``. ``target_labels`` has the shape and dtype
        of ``labels`` and is filled with the configured target label (0 when none
        is configured).

    Raises:
        KeyError: if the attack type is not registered in ATTACK_FUNCTIONS.
    """
    attack_params = client.attack_params
    attack_type = attack_params.get("type", "pgd").lower()
    step_size = float(attack_params.get("step_size", 0.00784314))
    raw_target = attack_params.get("target_label")
    # An explicit `target_label: null` must not crash in int(None).
    target_labels = torch.full_like(labels, 0 if raw_target is None else int(raw_target))
    targeted = bool(attack_params.get("targeted", raw_target is not None))

    if attack_type in {"pgd", "fgsm"}:
        kwargs = {
            "model": client.surrogate,
            "criterion": client.attack_criterion,
            "images": inputs,
            # Targeted: descend toward the target class. Untargeted: ascend the
            # loss of the true labels.
            "labels": target_labels if targeted else labels,
            "step_size": step_size,
            "targeted": targeted,
        }
        if attack_type == "pgd":
            kwargs["eps"] = float(attack_params.get("epsilon", 0.03137255))
            kwargs["iters"] = int(attack_params.get("iters", 10))
        adv = select_attack_fn(attack_type)(**kwargs)
    else:
        # Resolve via the registry so a misspelt type raises KeyError instead of
        # silently degrading to random noise ("random"/"random_noise" still work).
        adv = select_attack_fn(attack_type)(inputs, step_size=step_size)

    return adv, target_labels


def evaluate_surrogate_attack(recipe_name: str, *, batch_size: int = 32) -> dict:
    """Measure clean accuracy and attack success of a recipe against the surrogate.

    Uses the first batch of the surrogate test loader, truncated to ``batch_size``.

    Returns:
        dict with these keys:

        - ``recipe``
        - ``clean_accuracy`` (%)
        - ``attack_success_rate``: % of adversarial predictions equal to the
          target labels.
        - ``attack_success_rate_non_target``: the same rate restricted to samples
          whose true label differs from the target, since samples already in the
          target class "succeed" without any perturbation. NaN when no such
          sample exists.

    Raises:
        ValueError: if ``batch_size`` is not positive.
        KeyError: if ``recipe_name`` is not in ATTACK_RECIPES.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    attack_cfg = _attach_attack_callable(ATTACK_RECIPES[recipe_name])
    # make_malicious_client binds the global surrogate_model, moves it to DEVICE
    # and sets eval mode, so no extra re-binding is needed here.
    # NOTE(review): num_epochs is passed to the client but no fine-tuning is run here.
    client = make_malicious_client(
        attack_cfg,
        num_epochs=attack_cfg.get("surrogate", {}).get("finetune_epochs", 0),
    )
    model = client.surrogate

    inputs, labels = next(iter(get_surrogate_test_loader()))
    inputs = inputs[:batch_size].to(DEVICE).float()
    labels = labels[:batch_size].to(DEVICE).long()

    with torch.no_grad():
        clean_preds = model(inputs).argmax(dim=1)
    clean_acc = (clean_preds == labels).float().mean().item() * 100.0

    adv_inputs, adv_labels = craft_adversarial_batch(client, inputs, labels)

    with torch.no_grad():
        adv_preds = model(adv_inputs).argmax(dim=1)
    hits = adv_preds == adv_labels
    asr = hits.float().mean().item() * 100.0

    non_target = labels != adv_labels
    asr_non_target = (
        hits[non_target].float().mean().item() * 100.0 if non_target.any() else float("nan")
    )

    return {
        "recipe": recipe_name,
        "clean_accuracy": clean_acc,
        "attack_success_rate": asr,
        "attack_success_rate_non_target": asr_non_target,
    }


def run_attack_recipe_on_server(recipe_name: str, alg_name: str = "FedAvg") -> dict:
    """Run federated training under a named attack recipe; tag the summary with both names."""
    cfg = _attach_attack_callable(ATTACK_RECIPES[recipe_name])
    summary = run_one_algorithm(alg_name, attack_cfg=cfg)
    summary["recipe"] = recipe_name
    summary["algorithm"] = alg_name
    return summary


def sweep_attacks_on_server(alg_name: str = "FedAvg", recipes: list[str] | None = None) -> dict:
    """Run ``alg_name`` under each attack recipe and map recipe name to summary.

    Args:
        alg_name: algorithm to train.
        recipes: recipe names to run; ``None`` means all of ATTACK_RECIPES. An
            explicit empty list runs nothing.

    For "clean", an existing entry in ``baseline_results`` is reused instead of
    retraining. It is copied and tagged with ``recipe`` / ``algorithm`` like the
    attacked runs. If no baseline exists, the clean recipe is trained.
    """
    if recipes is None:
        recipes = list(ATTACK_RECIPES)

    results: dict[str, dict] = {}
    for recipe in recipes:
        baseline = baseline_results.get(alg_name) if recipe == "clean" else None
        if baseline:
            # Copy so later mutation of this entry cannot alter the shared baseline.
            results[recipe] = {**baseline, "recipe": recipe, "algorithm": alg_name}
        else:
            results[recipe] = run_attack_recipe_on_server(recipe, alg_name=alg_name)
    return results


## 14. Surrogate Attack Experiments

In [None]:
# Evaluate each recipe against the surrogate alone (no federated training).
SURROGATE_ATTACK_RECIPES = ["pgd_default", "fgsm_default", "random_noise"]

surrogate_attack_results = dict(
    zip(SURROGATE_ATTACK_RECIPES, map(evaluate_surrogate_attack, SURROGATE_ATTACK_RECIPES))
)

surrogate_attack_results


## 15. Federated Attack Sweeps

In [None]:
# Compare the clean FedAvg baseline against FedAvg trained under PGD poisoning.
FED_ATTACK_RECIPES = ["clean", "pgd_default"]

federated_attack_results = sweep_attacks_on_server("FedAvg", FED_ATTACK_RECIPES)

federated_attack_results
