# Federated Optimization Experiment Setup

This notebook loads `config.yaml` to reproduce the Section 3 experiments. We use the Imagenette dataset split 
across 100 simulated clients with 10% participation per round, five local epochs, and mini-batches of size 64. 
Use this notebook alongside the [section overview](README.md#section-3-federated-optimization-algorithms) for theory context.

- **Config keys:** `global_config` (hardware/seed), `data_config` (dataset + partitioning), `model_config` (MobileNetV3 transfer head),
  `algorithms` (per-algorithm hyperparameters), and `attack` (PGD-style backdoor setup).
- **Algorithms available:** FedAvg, SCAFFOLD, FedAdam, FedAdagrad, and FedYogi—switchable via `alg_configs`.


In [None]:
import os
from copy import deepcopy

from util_functions import set_logger, save_plt
import torch
import importlib
import numpy as np
from torch.utils.data import DataLoader
from model import *
from load_data_for_clients import dist_data_per_client
from util_functions import set_seed, evaluate_fn, run_fl
from algos import Client, Server, FedOptClient,FedAdamServer, FedAdagradServer, FedYogiServer, ScaffoldClient, ScaffoldServer
import yaml 



with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)
seed = config.get("seed", 42)
torch.manual_seed(seed)
np.random.seed(seed)

ALGORITHM_MAP = {
    "FedAvg": Server,
    "Scaffold": ScaffoldServer,
    "FedAdam": FedAdamServer,
    "FedAdagrad": FedAdagradServer,
    "FedYogi": FedYogiServer,
}

def train_server(alg_name, global_config, data_config, model_config, alg_configs, attack_config):
    alg_conf = alg_configs[alg_name]
    fed_cfg  = alg_conf["fed_config"].copy()
    fed_cfg["algorithm"] = alg_name
    optim_cfg = alg_conf.get("optim_config", {})
    return run_fl(
        ALGORITHM_MAP[alg_name],
        global_config, data_config, fed_cfg,
        model_config, optim_cfg,
        attack_config
    )
def eval_server(server):
    loss, acc = evaluate_fn(
        server.data, server.x,
        server.criterion, server.device
    )
    return loss, acc

In [None]:
def run_one_algorithm(alg_name,
                      global_config,
                      data_config,
                      model_config,
                      alg_configs,
                      attack_config=None):

    if alg_name not in alg_configs:
        raise ValueError(f"Algorithm {alg_name!r} not found in the configuration.")

    if attack_config is None:
        attack_config = {"malicious_fraction": 0}

    alg_conf = alg_configs[alg_name]
    fed_cfg  = alg_conf["fed_config"].copy()
    fed_cfg["algorithm"] = alg_name
    optim_cfg = alg_conf.get("optim_config", {})

    server = train_server(alg_name,
        global_config,
        data_config,
        model_config,
        alg_configs,
        attack_config)

    loss, acc = eval_server(server)
    del server; torch.cuda.empty_cache()
    return loss, acc

In [45]:
def run_all_algorithms(global_config,
                       data_config,
                       model_config,
                       alg_configs,
                       attack_config=None):

    results = {}
    for alg_name in alg_configs:
        loss, acc = run_one_algorithm(
            alg_name,
            global_config,
            data_config,
            model_config,
            alg_configs
        )
        results[alg_name] = {
            "loss": loss,
            "accuracy": acc
        }

    return results

In [None]:
def run_attack(alg_name,
               global_config,
               data_config,
               model_config,
               alg_configs,
               attack_config):

    if alg_name not in alg_configs:
        raise ValueError(f"Algorithm {alg_name!r} not in config")

    alg_conf  = alg_configs[alg_name]
    fed_cfg   = alg_conf["fed_config"].copy()
    fed_cfg["algorithm"] = alg_name
    optim_cfg = alg_conf.get("optim_config", {})


    server = train_server(alg_name,
        global_config,
        data_config,
        model_config,
        alg_configs,
        attack_config)
    

    _, clean_acc = eval_server(server)



    mal_client = next(c for c in server.clients if isinstance(c, MaliciousClient))

    mc = MaliciousClient(
        client_id     = mal_client.id,
        local_data    = mal_client.data,
        device        = server.device,
        num_epochs    = 0,
        criterion     = server.criterion,
        lr            = server.lr_l,
        attack_config = attack_config
    )
    dataset = mal_client.data.dataset if hasattr(mal_client.data, "dataset") else mal_client.data
    dataset.x = dataset.x.cpu()
    dataset.y = dataset.y.cpu()
    surr_loader = DataLoader(
        dataset,
        batch_size=attack_config["surrogate_batch_size"],
        shuffle=True,
        num_workers=2
    )
    mc.train_surrogate(surr_loader)

    x_batch, y_batch = next(iter(mc.data))
    x_adv = mc.perform_attack(
        x_batch.to(server.device),
        y_batch.to(server.device)
    )

    preds = server.x(x_adv).argmax(dim=1)
    asr   = (preds == attack_config["target_label"]).float().mean().item()

    del mc, server
    torch.cuda.empty_cache()

    return clean_acc, asr

global_config = config["global_config"]
data_config = config["data_config"]
model_config = config["model_config"]
alg_configs = config["algorithms"]
attack_config = config["attack"]

### Clean FedAvg baseline
Run `run_one_algorithm` with the FedAvg configuration to obtain the loss and accuracy after the configured rounds.
Use this cell as the reference convergence trace before introducing attacks or alternative optimizers.


In [43]:
loss, acc = run_one_algorithm(
    "FedAvg",
    global_config,
    data_config,
    model_config,
    alg_configs,
    
)
print(f"FedAvg (clean) → Loss: {loss:.4f}, Acc: {acc:.2f}%")

Server is successfully initialized



Preparing Data
Loading cached client data from cache/client_data_b910f99db14aade51a10770cb0305f61.pkl



Clients are successfully initialized

Communication Round:1
	client_update has completed
	server_update has completed
	Loss:1.9446   Accuracy:43.41%

Communication Round:2
	client_update has completed
	server_update has completed
	Loss:1.6128   Accuracy:69.43%

Execution has completed


FedAvg (clean) → Loss: 1.6136, Acc: 6.9e+01%


### FedAvg under PGD backdoor attack
This block reuses the FedAvg setup but enables the `attack` configuration to report clean accuracy and attack success rate.
Expect the ASR to rise when the PGD-crafted trigger is effective while clean accuracy degrades relative to the baseline.


In [44]:
atk_config = config["attack"]
clean_acc, asr = run_attack(
    "FedAvg",
    global_config,
    data_config,
    model_config,
    alg_configs,
    attack_config=atk_config
)
print(f"FedAvg (poisoned) → Clean Acc: {clean_acc:.2f}%, ASR: {asr:.2f}%")

Server is successfully initialized



Preparing Data
Loading cached client data from cache/client_data_b910f99db14aade51a10770cb0305f61.pkl



Clients are successfully initialized

Communication Round:1
	client_update has completed
	server_update has completed
	Loss:1.9627   Accuracy:41.38%

Communication Round:2
	client_update has completed
	server_update has completed
	Loss:1.6331   Accuracy:68.56%

Execution has completed


FedAvg (poisoned) → Clean Acc: 68.56%, ASR: 0.084%


In [47]:
clean_results = run_all_algorithms(
    global_config,
    data_config,
    model_config,
    alg_configs
)

Server is successfully initialized



Preparing Data
Loading cached client data from cache/client_data_b910f99db14aade51a10770cb0305f61.pkl



Clients are successfully initialized

Communication Round:1
	client_update has completed
	server_update has completed
	Loss:1.9446   Accuracy:43.41%

Communication Round:2
	client_update has completed
	server_update has completed


KeyboardInterrupt: 

In [None]:
clean_results