# Module 4 Adversarial FL – Outline

## 1. Federated Baseline Imports

In [None]:
from copy import deepcopy
from pathlib import Path

import yaml
import numpy as np
import torch

from util_functions import set_seed, evaluate_fn, run_fl
from load_data_for_clients import dist_data_per_client
from algos import (
    Server,
    ScaffoldServer,
    FedAdamServer,
    FedAdagradServer,
    FedYogiServer,
)


## 2. Federated Baseline Paths & Config

In [None]:
CONFIG_PATH = Path("config.yaml")
if not CONFIG_PATH.exists():
    raise FileNotFoundError("Could not locate config.yaml in the working directory")

with CONFIG_PATH.open(encoding="utf-8") as f:
    # safe_load returns None for an empty document; normalise that to {}.
    CONFIG = yaml.safe_load(f) or {}

if not isinstance(CONFIG, dict):
    raise ValueError(
        f"{CONFIG_PATH} must contain a mapping at the top level, "
        f"got {type(CONFIG).__name__}"
    )

# `or {}` rather than a .get() default: a section key that is present but left
# empty in YAML (e.g. a bare `attack:` line) parses to None, and every later
# `.get(...)` on that section would then fail with AttributeError.
global_config = CONFIG.get("global_config") or {}
data_config = CONFIG.get("data_config") or {}
model_config = CONFIG.get("model_config") or {}
alg_configs = CONFIG.get("algorithms") or {}
attack_defaults = CONFIG.get("attack") or {}

set_seed(global_config.get("seed", 42))
# Algorithm names are the keys of the `algorithms` section, in file order.
AVAILABLE_ALGORITHMS = list(alg_configs)
print("Loaded config from", CONFIG_PATH.resolve())
print("Available algorithms:", AVAILABLE_ALGORITHMS)


## 3. Federated Baseline Helpers

In [None]:
# Maps each algorithm name usable under `algorithms:` in config.yaml to the
# server class that implements it.
ALGORITHM_MAP = dict(
    FedAvg=Server,
    Scaffold=ScaffoldServer,
    FedAdam=FedAdamServer,
    FedAdagrad=FedAdagradServer,
    FedYogi=FedYogiServer,
)

# Fail as soon as this cell runs if the config names an algorithm that has no
# registered server class, instead of failing mid-sweep.
missing = sorted(set(AVAILABLE_ALGORITHMS).difference(ALGORITHM_MAP))
if missing:
    raise KeyError(f"No server mapping registered for: {missing}")


def train_server(alg_name: str, attack_cfg: dict | None = None):
    """Launch one federated training run for a configured algorithm.

    Args:
        alg_name: Key under ``algorithms`` in config.yaml. It is also used to
            select the server class from ``ALGORITHM_MAP``.
        attack_cfg: Attack settings forwarded to ``run_fl``. ``None`` or an
            empty dict falls back to ``{"malicious_fraction": 0.0}``.

    Returns:
        Whatever ``run_fl`` returns. ``summarise_server`` treats it as the
        trained server object.

    Raises:
        ValueError: If ``alg_name`` is not present in the configuration.
    """
    if alg_name not in alg_configs:
        raise ValueError(f"Algorithm {alg_name!r} not found in configuration.")

    settings = alg_configs[alg_name]

    # Every config dict handed to run_fl is a deep copy, so the shared
    # configuration is protected if run_fl mutates its arguments.
    fed_settings = deepcopy(settings["fed_config"])
    fed_settings["algorithm"] = alg_name
    optim_settings = deepcopy(settings.get("optim_config", {}))
    attack_settings = deepcopy(attack_cfg or {"malicious_fraction": 0.0})

    server_cls = ALGORITHM_MAP[alg_name]
    return run_fl(
        server_cls,
        global_config,
        data_config,
        fed_settings,
        model_config,
        optim_settings,
        attack_settings,
    )


def summarise_server(server) -> dict:
    """Evaluate a trained server and package its final metrics and curves.

    ``server.data``, ``server.x``, ``server.criterion`` and ``server.device``
    are passed straight to ``evaluate_fn``. ``server.x`` is presumably the
    global model (TODO confirm against ``algos``).

    Returns:
        A dict with ``final_loss`` and ``final_accuracy`` as floats, plus a
        ``history`` dict holding ``loss`` and ``accuracy`` lists copied from
        ``server.results``. The lists are empty when the server has no
        ``results`` attribute or lacks those keys.
    """
    final_loss, final_acc = evaluate_fn(
        server.data, server.x, server.criterion, server.device
    )
    recorded = getattr(server, "results", {})
    curves = {key: list(recorded.get(key, [])) for key in ("loss", "accuracy")}
    return {
        "final_loss": float(final_loss),
        "final_accuracy": float(final_acc),
        "history": curves,
    }


def run_one_algorithm(alg_name: str, attack_cfg: dict | None = None) -> dict:
    """Train one algorithm and return only its summary dict.

    The trained server is not kept. It is only referenced as a temporary
    argument, so it can be released before the CUDA allocator cache is
    cleared for the next run.
    """
    outcome = summarise_server(train_server(alg_name, attack_cfg=attack_cfg))
    torch.cuda.empty_cache()
    return outcome


def run_all_algorithms(
    algorithms: list[str] | None = None,
    attack_cfg: dict | None = None,
) -> dict:
    algorithms = algorithms or AVAILABLE_ALGORITHMS
    results: dict[str, dict] = {}
    for name in algorithms:
        results[name] = run_one_algorithm(name, attack_cfg=attack_cfg)
    return results


## 4. Federated Baseline Runs

In [None]:
# Each name must be a key of `algorithms` in config.yaml (train_server raises
# ValueError otherwise). Extend the list to sweep additional algorithms.
BASELINE_ALGORITHMS = ["FedAvg"]

# Clean baseline: attack_cfg=None means malicious_fraction 0.0 in train_server.
baseline_results = run_all_algorithms(algorithms=BASELINE_ALGORITHMS, attack_cfg=None)
baseline_results


## 5. Surrogate Imports

In [None]:
from attacks import get_attack
from malicious_client import MaliciousClient
from model import MobileNetV2Transfer


## 6. Surrogate Paths & Config

In [None]:
# `or {}` instead of a .get() default: a `surrogate:` key left empty in YAML
# parses to None, which would break every SURROGATE_CFG.get() call below.
SURROGATE_CFG = CONFIG.get("surrogate") or {}
# Index into the per-client loaders returned by dist_data_per_client.
SURROGATE_CLIENT_ID = SURROGATE_CFG.get("client_id", 0)
# Falls back to the global seed so surrogate and federated runs agree by default.
SURROGATE_SEED = SURROGATE_CFG.get("seed", global_config.get("seed", 42))

def ensure_surrogate_loader():
    """Partition the dataset and return the surrogate client's data loader.

    Each partitioning setting comes from the ``surrogate`` config section when
    present, otherwise from ``data_config``, otherwise from a hard-coded
    default. Note: the dataset is re-partitioned on every call; nothing is
    cached despite the ``ensure_`` name.
    """

    def setting(key, default=None):
        # Surrogate-specific override first, then the shared data config.
        return SURROGATE_CFG.get(key, data_config.get(key, default))

    client_loaders, _ = dist_data_per_client(
        data_path=setting("dataset_path"),
        dataset_name=setting("dataset_name"),
        num_clients=setting("num_clients", 50),
        batch_size=setting("batch_size", 96),
        non_iid_per=setting("non_iid_per", 0.0),
        device=torch.device(global_config.get("device", "cpu")),
    )
    return client_loaders[SURROGATE_CLIENT_ID]


## 7. Surrogate Baseline

In [None]:
def build_surrogate_model(num_classes: int = 10, pretrained: bool | None = None):
    """Instantiate the MobileNetV2 transfer model used as the surrogate.

    Args:
        num_classes: Output size passed to ``MobileNetV2Transfer``.
        pretrained: Explicit override. ``None`` reads ``surrogate.pretrained``
            from config and defaults to ``True``.
    """
    use_pretrained = SURROGATE_CFG.get("pretrained", True) if pretrained is None else pretrained
    return MobileNetV2Transfer(pretrained=use_pretrained, num_classes=num_classes)


def _run_surrogate_epoch(model, loader, criterion, optimizer, device) -> tuple[float, float]:
    """Run one training pass over ``loader``; return (mean batch loss, accuracy %)."""
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for inputs, labels in loader:
        # Move each batch to the explicit training device. Relying on a
        # `model.device` attribute is unsafe: plain torch.nn.Module has none.
        inputs = inputs.float().to(device)
        labels = labels.long().to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total += labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    # max(..., 1) guards against an empty loader.
    mean_loss = running_loss / max(len(loader), 1)
    accuracy = 100 * correct / max(total, 1)
    return mean_loss, accuracy


def train_surrogate_baseline(num_epochs: int | None = None):
    """Train the surrogate model centrally on one client's local data.

    Args:
        num_epochs: Number of training epochs. ``None`` reads
            ``surrogate.num_epochs`` from config (default 5). An explicit 0 is
            honoured and returns an untrained model with empty history.

    Returns:
        ``(model, history)``. ``history["loss"]`` holds the per-epoch mean
        batch loss and ``history["accuracy"]`` the per-epoch training
        accuracy in percent.
    """
    set_seed(SURROGATE_SEED)

    device = torch.device(global_config.get("device", "cpu"))
    loader = ensure_surrogate_loader()
    model = build_surrogate_model(num_classes=SURROGATE_CFG.get("num_classes", 10))
    model.to(device)

    # Explicit criterion. The previous code introspected evaluate_fn.__defaults__,
    # which raises IndexError with a single default and otherwise picks
    # whatever object happens to be second.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=SURROGATE_CFG.get("local_lr", 0.003),
    )
    epochs = SURROGATE_CFG.get("num_epochs", 5) if num_epochs is None else num_epochs

    history = {"loss": [], "accuracy": []}
    for epoch in range(epochs):
        epoch_loss, epoch_acc = _run_surrogate_epoch(model, loader, criterion, optimizer, device)
        history["loss"].append(epoch_loss)
        history["accuracy"].append(epoch_acc)
        print(f"Epoch {epoch+1}/{epochs}: loss {epoch_loss:.4f}, acc {epoch_acc:.2f}%")

    return model, history


## 8. Baseline Comparison

In [None]:
# num_epochs=None defers to surrogate.num_epochs in config.yaml.
surrogate_model, surrogate_history = train_surrogate_baseline(num_epochs=None)
surrogate_history


## 9. Attack Imports

In [None]:
# Side-by-side check of the clean federated FedAvg run and the centrally
# trained surrogate. An empty dict is shown if FedAvg was not in the sweep.
baseline_summary = baseline_results.get("FedAvg", {})
print("Federated baseline:", baseline_summary)
if surrogate_history["loss"]:
    print("Surrogate final metrics:", surrogate_history['loss'][-1], surrogate_history['accuracy'][-1])
else:
    # Zero training epochs leave the history empty; indexing [-1] would raise.
    print("Surrogate final metrics: n/a (no epochs recorded)")


## 10. Attack Paths & Config

In [None]:
# TODO: load attack-related configuration blocks


## 11. Attack Implementations

In [None]:
# TODO: define PGD/FGSM/random-noise routines in-notebook


## 12. Malicious Client Definition

In [None]:
# TODO: implement or override malicious client behaviour


## 13. Attack Execution Helpers

In [None]:
# TODO: add helpers to run one attack or sweep attacks


## 14. Surrogate Attack Experiments

In [None]:
# TODO: craft adversarial batches and report surrogate metrics


## 15. Federated Attack Sweeps

In [None]:
# TODO: run attacks against the FL server under different settings
