# Module 3 — Federated Optimization Algorithms (FedAvg, FedOpt, SCAFFOLD)

This notebook compares multiple **federated server update rules** under the same training setup.

We will run:
- **FedAvg** (baseline)
- **FedAdagrad**, **FedAdam**, **FedYogi** (FedOpt family: server-side adaptive optimization)
- **SCAFFOLD** (control variates to reduce client drift)

**Outputs:** comparison plots (accuracy/loss vs rounds) and a small summary of final metrics.


In [None]:
from copy import deepcopy
from pathlib import Path

import math
import json
import matplotlib.pyplot as plt
import pandas as pd
import torch

from federated_core import BaseClient, BaseServer
from util_functions import set_seed, evaluate_fn
import yaml


### Load configuration and data

Read the federated config, apply the global seed, and prepare shared data/model settings for every algorithm.

In [None]:
# Locate the experiment configuration; fail early with a clear error because
# every statement below reads from it.
CONFIG_PATH = Path('config.yaml')
if not CONFIG_PATH.exists():
    raise FileNotFoundError('Could not locate config.yaml for Section 3')

# Decode explicitly as UTF-8: without an encoding argument, read_text() uses
# the platform locale, which can mis-decode non-ASCII YAML on some systems.
config = yaml.safe_load(CONFIG_PATH.read_text(encoding='utf-8'))
global_config = config['global_config']
data_config = config['data_config']
model_config = config['model_config']
alg_configs = config['algorithms']

# An explicit 'device' entry in the config wins; otherwise use CUDA when available.
DEVICE = torch.device(global_config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu'))
# Seed once up front (42 when the config does not provide a seed).
set_seed(global_config.get('seed', 42))

# Output directory for the CSV/JSON artefacts; created if missing.
ARTIFACT_DIR = Path('artifacts')
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)


### Algorithm definitions

Define the algorithm-specific server subclasses (and any specialised clients) that override the base aggregation behaviour.

In [None]:
class FedAvgServer(BaseServer):
    """FedAvg baseline server.

    Overrides nothing: aggregation is whatever ``BaseServer`` implements
    (described by the original author as a simple average of client states).
    """

class FedAdamServer(BaseServer):
    """Server that applies an Adam-style adaptive step to the averaged client delta.

    Reads ``beta1``, ``beta2`` and ``epsilon`` from ``self.optim_config``
    (defaults 0.9, 0.99 and 1e-6) and keeps per-parameter first/second moment
    buffers ``self.m`` / ``self.v`` initialised to zeros.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        settings = self.optim_config
        self.beta1 = settings.get('beta1', 0.9)
        self.beta2 = settings.get('beta2', 0.99)
        self.epsilon = settings.get('epsilon', 1e-6)
        reference = self._parameter_state_dict()
        # One zero buffer per parameter tensor for each moment estimate.
        self.m = {key: torch.zeros_like(value) for key, value in reference.items()}
        self.v = {key: torch.zeros_like(value) for key, value in reference.items()}

    def aggregate(self, local_states):
        """Update the global model from the clients' locally trained states.

        Args:
            local_states: iterable of per-client mappings from parameter name
                to tensor. Nothing happens when it is empty.
        """
        if not local_states:
            return
        current = self._parameter_state_dict()
        # Pseudo-gradient of each client: its local weights minus the global ones.
        client_deltas = [
            {key: client_state[key] - current[key] for key in current}
            for client_state in local_states
        ]
        avg_delta = self._average_state_dicts(client_deltas)
        stepped = {}
        for key, weight in current.items():
            step = avg_delta[key]
            first_moment = self.beta1 * self.m[key] + (1 - self.beta1) * step
            second_moment = self.beta2 * self.v[key] + (1 - self.beta2) * torch.square(step)
            # Floor the second moment so the square root below stays well-behaved.
            second_moment = torch.clamp(second_moment, min=1e-10)
            self.m[key] = first_moment
            self.v[key] = second_moment
            stepped[key] = weight + self.global_lr * first_moment / (torch.sqrt(second_moment) + self.epsilon)
        # Overlay the stepped parameters on the full state before loading it back.
        merged_state = self._global_state_cpu()
        for key, tensor in stepped.items():
            merged_state[key] = tensor
        self.global_model.load_state_dict(merged_state)

In [None]:
class FedAdagradServer(BaseServer):
    """Server that applies an Adagrad-style step to the averaged client delta.

    Keeps a per-parameter running sum of squared deltas in ``self.s``
    (initialised to zeros) and reads ``epsilon`` from ``self.optim_config``
    (default 1e-6).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = self.optim_config.get('epsilon', 1e-6)
        reference = self._parameter_state_dict()
        # Accumulator starts at zero for every parameter tensor.
        self.s = {key: torch.zeros_like(value) for key, value in reference.items()}

    def aggregate(self, local_states):
        """Update the global model from the clients' locally trained states.

        Args:
            local_states: iterable of per-client mappings from parameter name
                to tensor. Nothing happens when it is empty.
        """
        if not local_states:
            return
        current = self._parameter_state_dict()
        # Pseudo-gradient of each client: its local weights minus the global ones.
        client_deltas = [
            {key: client_state[key] - current[key] for key in current}
            for client_state in local_states
        ]
        avg_delta = self._average_state_dicts(client_deltas)
        stepped = {}
        for key, weight in current.items():
            step = avg_delta[key]
            accumulated = self.s[key] + torch.square(step)
            self.s[key] = accumulated
            # Epsilon is the only damping term in the denominator.
            stepped[key] = weight + self.global_lr * step / (torch.sqrt(accumulated) + self.epsilon)
        # Overlay the stepped parameters on the full state before loading it back.
        merged_state = self._global_state_cpu()
        for key, tensor in stepped.items():
            merged_state[key] = tensor
        self.global_model.load_state_dict(merged_state)

In [None]:
class FedYogiServer(BaseServer):
    """Server that applies a Yogi-style adaptive step to the averaged client delta.

    Reads ``beta1``, ``beta2`` and ``epsilon`` from ``self.optim_config``
    (defaults 0.9, 0.99 and 1e-6), keeps zero-initialised moment buffers
    ``self.m`` / ``self.v`` and a ``timestep`` counter that starts at 1 and is
    incremented once per non-empty aggregation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        settings = self.optim_config
        self.beta1 = settings.get('beta1', 0.9)
        self.beta2 = settings.get('beta2', 0.99)
        self.epsilon = settings.get('epsilon', 1e-6)
        reference = self._parameter_state_dict()
        self.m = {key: torch.zeros_like(value) for key, value in reference.items()}
        self.v = {key: torch.zeros_like(value) for key, value in reference.items()}
        self.timestep = 1

    def aggregate(self, local_states):
        """Update the global model from the clients' locally trained states.

        Args:
            local_states: iterable of per-client mappings from parameter name
                to tensor. Nothing happens when it is empty.
        """
        if not local_states:
            return
        current = self._parameter_state_dict()
        # Pseudo-gradient of each client: its local weights minus the global ones.
        client_deltas = [
            {key: client_state[key] - current[key] for key in current}
            for client_state in local_states
        ]
        avg_delta = self._average_state_dicts(client_deltas)
        stepped = {}
        for key, weight in current.items():
            step = avg_delta[key]
            step_sq = torch.square(step)
            self.m[key] = self.beta1 * self.m[key] + (1 - self.beta1) * step
            # Sign-controlled second-moment update; the result is floored at
            # epsilon so it can never go negative before the square root.
            raw_second = self.v[key] - (1 - self.beta2) * torch.sign(self.v[key] - step_sq) * step_sq
            self.v[key] = torch.clamp(raw_second, min=self.epsilon)
            stepped[key] = weight + self.global_lr * self.m[key] / (torch.sqrt(self.v[key]) + self.epsilon)
        self.timestep += 1
        # Overlay the stepped parameters on the full state before loading it back.
        merged_state = self._global_state_cpu()
        for key, tensor in stepped.items():
            merged_state[key] = tensor
        self.global_model.load_state_dict(merged_state)

In [None]:
class ScaffoldClient(BaseClient):
    """Client that trains locally with SCAFFOLD control-variate correction.

    Attributes:
        client_c: this client's control variate, keyed by parameter name.
            Entries missing at training time are created as zeros.
        server_c: copy of the server control variate, set through
            ``set_server_controls`` before each training call.
    """

    def __init__(self, *args, control_init: dict[str, torch.Tensor] | None = None, **kwargs):
        """Create the client.

        Args:
            *args, **kwargs: forwarded unchanged to ``BaseClient``.
            control_init: optional initial client control variate keyed by
                parameter name. When omitted, zeros are used lazily.
        """
        super().__init__(*args, **kwargs)
        # Shallow-copy so that filling in missing entries later never mutates
        # the mapping owned by the caller.
        self.client_c: dict[str, torch.Tensor] = dict(control_init) if control_init else {}
        self.server_c: dict[str, torch.Tensor] | None = None

    def set_server_controls(self, server_c: dict[str, torch.Tensor]):
        """Store a cloned snapshot of the server control variate."""
        self.server_c = {k: v.clone() for k, v in server_c.items()}

    def train_with_controls(self, global_model: torch.nn.Module):
        """Run local training starting from ``global_model``.

        Each step applies ``param -= lr * (grad + server_c - client_c)``.
        Afterwards the client control variate is refreshed as
        ``client_c - server_c - delta_y / (steps * lr)``.

        Args:
            global_model: the current global model; it is deep-copied and
                never modified here.

        Returns:
            A tuple ``(local_params, delta_y, delta_c)`` of dicts keyed by
            parameter name: the trained weights (on CPU), their difference
            from the global weights, and the change in the client control
            variate.

        Raises:
            RuntimeError: if ``set_server_controls`` was not called first.
        """
        if self.server_c is None:
            raise RuntimeError('Server controls must be provided before training.')
        local_model = deepcopy(global_model).to(self.device)
        local_model.train()
        named_params = list(local_model.named_parameters())
        param_names = [name for name, _ in named_params]
        params = [param for _, param in named_params]

        # Without an explicit initial control variate the lookup below would
        # raise KeyError; default missing entries to zeros instead. They live
        # on CPU because the deltas combined with them below are CPU tensors.
        for name, param in named_params:
            if name not in self.client_c:
                self.client_c[name] = torch.zeros_like(param.detach(), device='cpu')

        server_controls = [self.server_c[name].to(self.device) for name in param_names]
        client_controls = [self.client_c[name].to(self.device) for name in param_names]

        # Count the steps actually executed rather than deriving them from
        # len(dataset) / batch_size, which is wrong whenever the loader drops
        # the last batch or uses a subset/batch sampler.
        steps_taken = 0
        for _ in range(self.num_epochs):
            for inputs, labels in self.data:
                inputs = inputs.float().to(self.device)
                labels = labels.long().to(self.device)
                outputs = local_model(inputs)
                loss = self.criterion(outputs, labels)
                grads = torch.autograd.grad(loss, params)
                with torch.no_grad():
                    for param, grad, s_c, c_c in zip(params, grads, server_controls, client_controls):
                        param.sub_(self.lr * (grad + s_c - c_c))
                steps_taken += 1

        local_params = {name: param.detach().cpu() for name, param in local_model.named_parameters()}
        global_params = {name: param.detach().cpu() for name, param in global_model.named_parameters()}
        delta_y = {name: local_params[name] - global_params[name] for name in global_params}

        # Guard against division by zero when the loader yielded no batches.
        total_steps = max(steps_taken, 1)
        scale = total_steps * self.lr
        new_client_c = {}
        delta_c = {}
        for name in global_params:
            new_client_c[name] = self.client_c[name] - self.server_c[name] - delta_y[name] / scale
            delta_c[name] = new_client_c[name] - self.client_c[name]
        self.client_c = new_client_c
        return local_params, delta_y, delta_c


In [None]:
class ScaffoldServer(BaseServer):
    """Server for SCAFFOLD: aggregates model deltas and control-variate deltas.

    Attributes:
        c_init: scalar used to fill every client's initial control variate,
            read from ``optim_config['c_init']`` (default 0.0).
        server_c: server control variate, obtained from
            ``self._zeros_like_parameters()``.
    """

    def __init__(self, *args, **kwargs):
        # Read c_init before delegating so it is available during base-class
        # initialisation as well.
        self.c_init = (kwargs.get('optim_config') or {}).get('c_init', 0.0)
        super().__init__(*args, client_cls=BaseClient, **kwargs)
        self.server_c = self._zeros_like_parameters()
        # Replace base clients with Scaffold-aware clients (done in setup()).
        self.clients = []

    def setup(self) -> None:
        """Run the base setup, then rebuild every client as a ``ScaffoldClient``."""
        super().setup()
        control_init = {
            name: torch.full_like(param, self.c_init)
            for name, param in self._parameter_state_dict().items()
        }
        # NOTE(review): server_c starts at zero while clients start at c_init;
        # these only agree when c_init is 0.0 — confirm non-zero values are intended.
        self.clients = [
            ScaffoldClient(
                client_id=client.id,
                local_data=client.data,
                device=client.device,
                num_epochs=client.num_epochs,
                lr=client.lr,
                criterion=client.criterion,
                # Each client gets its own copy so updates never alias.
                control_init=deepcopy(control_init),
            )
            for client in self.clients
        ]

    def collect_client_updates(self, client_ids):
        """Train the selected clients and gather their results.

        Args:
            client_ids: indices into ``self.clients``.

        Returns:
            A list of ``(idx, local_params, delta_y, delta_c)`` tuples.
        """
        updates = []
        for idx in client_ids:
            client: ScaffoldClient = self.clients[idx]
            client.set_server_controls(self.server_c)
            local_params, delta_y, delta_c = client.train_with_controls(self.global_model)
            updates.append((idx, local_params, delta_y, delta_c))
        return updates

    def aggregate(self, updates):
        """Apply the averaged model delta and refresh the server control variate.

        Args:
            updates: sequence of ``(idx, local_params, delta_y, delta_c)``
                tuples as produced by ``collect_client_updates``. Nothing
                happens when it is empty.
        """
        if not updates:
            return
        _, _, delta_y_list, delta_c_list = zip(*updates)
        mean_delta = self._average_state_dicts(delta_y_list)
        global_params = self._parameter_state_dict()
        # Update global model
        updated = {
            name: global_params[name] + self.global_lr * mean_delta[name]
            for name in global_params
        }
        full_state = self._global_state_cpu()
        for name, tensor in updated.items():
            full_state[name] = tensor
        self.global_model.load_state_dict(full_state)

        # SCAFFOLD adds (1/N) * sum of delta_c over the sampled clients, i.e.
        # the participants' mean scaled by |S| / N. Using the plain mean
        # over-weights the correction under partial participation; with full
        # participation the factor is 1 and nothing changes.
        # Assumes _average_state_dicts is an unweighted mean — TODO confirm.
        mean_delta_c = self._average_state_dicts(delta_c_list)
        total_clients = len(self.clients)
        participation = min(len(updates) / total_clients, 1.0) if total_clients else 1.0
        for name in self.server_c:
            self.server_c[name] = self.server_c[name] + participation * mean_delta_c[name]

In [None]:
# Registry mapping the algorithm names used as keys in the config's
# `algorithms` section to the server class that implements each one.
# `run_algorithm` looks names up here.
ALGORITHM_MAP: dict[str, type[BaseServer]] = {
    'FedAvg': FedAvgServer,
    'Scaffold': ScaffoldServer,
    'FedAdam': FedAdamServer,
    'FedAdagrad': FedAdagradServer,
    'FedYogi': FedYogiServer,
}


### Helper to run an algorithm

Instantiate the server, set up clients, train for the configured rounds, and return final metrics plus history.

In [None]:
def run_algorithm(name: str):
    """Train one federated algorithm end-to-end and report its final metrics.

    Args:
        name: key present in both ``alg_configs`` and ``ALGORITHM_MAP``.

    Returns:
        A dict with ``final_loss`` and ``final_accuracy`` (from ``evaluate_fn``
        on the server's test loader) and ``history`` (``server.results``).

    Raises:
        KeyError: if ``name`` has no configuration or no registered server class.
    """
    # Work on a private deep copy so the server can never mutate shared config.
    alg_conf = deepcopy(alg_configs[name])
    if name not in ALGORITHM_MAP:
        raise KeyError(
            f'No server class registered for algorithm {name!r}; '
            f'known algorithms: {sorted(ALGORITHM_MAP)}'
        )
    server_cls = ALGORITHM_MAP[name]
    server = server_cls(
        model_config=model_config,
        global_config=global_config,
        data_config=data_config,
        fed_config=alg_conf['fed_config'],
        # A YAML key with an empty value loads as None; normalise it to {} so
        # servers can call optim_config.get(...) safely.
        optim_config=alg_conf.get('optim_config') or {},
    )
    server.setup()
    server.train()
    loss, acc = evaluate_fn(server.test_loader, server.global_model, server.criterion, server.device)
    return {'final_loss': loss, 'final_accuracy': acc, 'history': server.results}


### Execute all algorithms

Run each configured optimiser once and capture its training history.

In [None]:
# By default every algorithm named in the config is run, in config order.
# Trim this list to run one algorithm at a time if GPU memory is limited.
ALGORITHMS_TO_RUN = [*alg_configs]


In [None]:
# Train the selected algorithms one after another, keyed by name.
results = {}
for algorithm_name in ALGORITHMS_TO_RUN:
    print(f'Running {algorithm_name} ...')
    outcome = run_algorithm(algorithm_name)
    results[algorithm_name] = outcome
    # Hand cached GPU memory back before the next algorithm starts.
    torch.cuda.empty_cache()
results


### Accuracy over rounds

Visualise how each optimiser converges on the global test set.

In [None]:
def plot_histories(results):
    """Plot each algorithm's accuracy history on one shared figure.

    Args:
        results: mapping from algorithm name to a summary dict whose
            ``'history'`` entry may hold an ``'accuracy'`` sequence.
            Algorithms with a missing or empty sequence are left out.
    """
    _, axes = plt.subplots(figsize=(8, 4))
    for label, summary in results.items():
        curve = summary['history'].get('accuracy')
        if not curve:
            continue
        axes.plot(curve, label=label)
    axes.set_xlabel('Communication round')
    axes.set_ylabel('Test accuracy (%)')
    axes.legend()
    axes.grid(True, alpha=0.3)
    axes.set_title('Algorithm convergence comparison')
plot_histories(results)


### Final metrics

Tabulate final loss and accuracy to compare end-of-training performance.

In [None]:
# One row per algorithm, with its final loss and final accuracy as columns.
final_metrics = {
    name: {metric: res[metric] for metric in ('final_loss', 'final_accuracy')}
    for name, res in results.items()
}
summary_df = pd.DataFrame(final_metrics).transpose()
summary_df


### Persist artefacts

Save the convergence history and summary table so instructors can share reference results without rerunning the lab.


In [None]:
# Keep only the per-round history of each algorithm for the JSON export.
history_export = {algorithm: outcome['history'] for algorithm, outcome in results.items()}
summary_path = ARTIFACT_DIR.joinpath('module3_summary.csv')
history_path = ARTIFACT_DIR.joinpath('module3_histories.json')

# Summary table first, then the histories.
summary_df.to_csv(summary_path, float_format='%.4f')
with history_path.open('w') as handle:
    json.dump(history_export, handle, indent=2)

for label, target in (('summary', summary_path), ('histories', history_path)):
    print(f'Saved {label} to {target.resolve()}')


### Validation checks

Ensure each algorithm logged the expected number of rounds and produced non-trivial accuracy.


In [None]:
def validate_histories(results, algorithm_config):
    """Check that every run logged enough rounds and a positive final accuracy.

    Args:
        results: mapping from algorithm name to a summary dict containing
            ``'history'`` (with an optional ``'accuracy'`` sequence) and
            ``'final_accuracy'``.
        algorithm_config: mapping from algorithm name to its configuration;
            ``['fed_config']['num_rounds']`` gives the expected round count.

    Raises:
        ValueError: listing every problem found, one per line.
    """
    issues = []
    for name, summary in results.items():
        expected = algorithm_config[name]['fed_config']['num_rounds']
        actual = len(summary['history'].get('accuracy', []))
        if actual < expected:
            issues.append(f"{name}: expected {expected} rounds, saw {actual}")
        if summary['final_accuracy'] <= 0:
            issues.append(f"{name}: non-positive final accuracy {summary['final_accuracy']}")
    if issues:
        # The newline escapes had been turned into literal line breaks, which
        # left unterminated string literals (a SyntaxError).
        raise ValueError('History validation failed:\n' + '\n'.join(issues))
    print('Validation passed for', ', '.join(sorted(results)))

# Raises ValueError if any run logged too few rounds or a non-positive final accuracy.
validate_histories(results, alg_configs)


### Quick takeaway

Highlight the top-performing optimiser based on the final accuracy.


In [None]:
# Compare the best and worst final accuracies across all algorithms.
final_accuracies = summary_df['final_accuracy']
best_alg = final_accuracies.idxmax()
best_acc = final_accuracies.loc[best_alg]
worst_acc = final_accuracies.min()
accuracy_gap = best_acc - worst_acc
print(f'Best performer: {best_alg} with {best_acc:.2f}% accuracy.')
print(f'Accuracy gap vs. slowest optimiser: {accuracy_gap:.2f} percentage points.')
