# FedAsync

This notebook is organized with one cell per original source file,
and is optimized for running on **Google Colab** with all results
stored on **Google Drive**.

**How to use on Colab:**

1. Upload this notebook to Colab.
2. Run the first cell to mount Google Drive and set your base path.
3. All subsequent paths (data, logs, models, results) will be created
   relative to that base directory.


In [None]:
# Install PyTorch Lightning; the client cell later in this notebook imports it as `pytorch_lightning`.
! pip install pytorch_lightning

In [None]:

# @title Mount Google Drive and set base project directory
from google.colab import drive
import os
from pathlib import Path

# Mount your Google Drive so that everything written below persists there.
drive.mount('/content/drive')

# Base directory INSIDE your Drive where everything will be stored.
# You can change this to any folder you like.  The trailing `@param`
# comment is a Colab form annotation and must stay on the assignment line.
BASE_DIR = "/content/drive/MyDrive/Arya/FedAsync"  # @param {type:"string"}

# Resolve "~", create the folder if needed, and make it the working directory
# so that relative paths used later (e.g. "./data") land inside it.
BASE_PATH = Path(BASE_DIR).expanduser()
BASE_PATH.mkdir(parents=True, exist_ok=True)
os.chdir(BASE_PATH)
# Exported for the partitioning cell, which reads FEDASYNC_BASE_DIR from the environment.
os.environ["FEDASYNC_BASE_DIR"] = str(BASE_PATH)
print("Working directory set to:", os.getcwd())


## config.yaml

Below is the YAML configuration used by the code.


In [None]:

from pathlib import Path
import textwrap

# Render the default experiment configuration and save it as <BASE_DIR>/config.yaml.
# Every io.* entry is a POSIX-style path rooted at BASE_DIR; the paths are
# computed up front so the YAML template below only interpolates plain names.
_cfg_root = Path(BASE_DIR)
_cfg_logs = _cfg_root / "logs"
_cfg_results = _cfg_root / "results"

_checkpoints_dir = (_cfg_root / "checkpoints" / "FedAsync").as_posix()
_logs_dir = _cfg_logs.as_posix()
_results_dir = _cfg_results.as_posix()
_global_log_csv = (_cfg_logs / "FedAsync.csv").as_posix()
_participation_csv = (_cfg_logs / "FedAsyncClientParticipation.csv").as_posix()
_final_model_path = (_cfg_results / "FedAsyncModel.pt").as_posix()

config_path = _cfg_root / "config.yaml"
config_text = f"""

# ---- Data ----
data:
  dataset: cifar10
  data_dir: ./data
  num_classes: 10

# ---- Clients ----
clients:
  total: 20
  concurrent: 10
  local_epochs: 2
  batch_size: 128
  lr: 0.005

  # --- heterogeneity controls ---
  struggle_percent: 20          # % of all clients that are slow
  delay_slow_range: [0.8, 2.0]  # seconds, Uniform[a, b] for slow clients
  delay_fast_range: [0.0, 0.2]  # seconds, Uniform[a, b] for fast/normal clients
  jitter_per_round: 0.1         # extra +/- seconds each local fit; 0 disables
  fix_delays_per_client: true   # true: sample once per client; false: resample every fit

# ---- Async FedAsync ----
async:
  alpha: 0.5

# ---- Evaluation / stopping ----
eval:
  # Aggregation-based eval cadence
  eval_every_aggs: 5
  target_accuracy: 0.8     # stop when global test_acc >= this value

# ---- Safety cap on merges (optional) ----
train:
  max_rounds: 1000

# ---- Partitioning ----
partition_alpha: 0.1

# ---- Reproducibility ----
seed: 42

# ---- Runtime / I/O ----
server_runtime:
  client_delay: 0.0   # extra global delay added on every client before training

io:
  checkpoints_dir: "{_checkpoints_dir}"
  logs_dir: "{_logs_dir}"
  results_dir: "{_results_dir}"
  global_log_csv: "{_global_log_csv}"
  client_participation_csv: "{_participation_csv}"
  final_model_path: "{_final_model_path}"

"""

# Strip the template's leading blank lines so the file starts at the first YAML comment.
_rendered_config = textwrap.dedent(config_text).lstrip()
config_path.write_text(_rendered_config)
print(f"config.yaml written to {config_path}")



In [None]:

from pathlib import Path
import textwrap

# Write the run configuration to <BASE_DIR>/config_run.yaml.
# The template matches the config.yaml cell except for `partition_alpha`
# (1000 here vs 0.1 there); per DataDistributor.distribute_data, a smaller
# Dirichlet alpha means a more non-IID split.
# NOTE(review): the io.* paths are the same as in config.yaml, so runs with
# either file append to the same CSV logs and overwrite the same final model —
# confirm that is intended.
config_path = Path(BASE_DIR) / "config_run.yaml"
config_text = f"""

# ---- Data ----
data:
  dataset: cifar10
  data_dir: ./data
  num_classes: 10

# ---- Clients ----
clients:
  total: 20
  concurrent: 10
  local_epochs: 2
  batch_size: 128
  lr: 0.005

  # --- heterogeneity controls ---
  struggle_percent: 20          # % of all clients that are slow
  delay_slow_range: [0.8, 2.0]  # seconds, Uniform[a, b] for slow clients
  delay_fast_range: [0.0, 0.2]  # seconds, Uniform[a, b] for fast/normal clients
  jitter_per_round: 0.1         # extra +/- seconds each local fit; 0 disables
  fix_delays_per_client: true   # true: sample once per client; false: resample every fit

# ---- Async FedAsync ----
async:
  alpha: 0.5

# ---- Evaluation / stopping ----
eval:
  # Aggregation-based eval cadence
  eval_every_aggs: 5
  target_accuracy: 0.8     # stop when global test_acc >= this value

# ---- Safety cap on merges (optional) ----
train:
  max_rounds: 1000

# ---- Partitioning ----
partition_alpha: 1000

# ---- Reproducibility ----
seed: 42

# ---- Runtime / I/O ----
server_runtime:
  client_delay: 0.0   # extra global delay added on every client before training

io:
  checkpoints_dir: "{(Path(BASE_DIR) / 'checkpoints' / 'FedAsync').as_posix()}"
  logs_dir: "{(Path(BASE_DIR) / 'logs').as_posix()}"
  results_dir: "{(Path(BASE_DIR) / 'results').as_posix()}"
  global_log_csv: "{(Path(BASE_DIR) / 'logs' / 'FedAsync.csv').as_posix()}"
  client_participation_csv: "{(Path(BASE_DIR) / 'logs' / 'FedAsyncClientParticipation.csv').as_posix()}"
  final_model_path: "{(Path(BASE_DIR) / 'results' / 'FedAsyncModel.pt').as_posix()}"

"""

config_path.write_text(textwrap.dedent(config_text).lstrip())
# Report the file actually written (the message previously said "config.yaml").
print(f"{config_path.name} written to {config_path}")



In [None]:
# ===== helper.py =====

from __future__ import annotations

import random
import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator the project relies on.

    Seeds Python's built-in ``random`` module, NumPy's global generator
    and PyTorch (CPU, plus all CUDA devices when CUDA is available), and
    switches cuDNN to deterministic mode with auto-tuning disabled.

    Parameters
    ----------
    seed:
        The random seed to use.  Defaults to ``42``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        # Seed the current CUDA device first, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Determinism costs some speed, but reproducible experiments matter more here.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_device() -> torch.device:
    """Pick the computation device, preferring accelerators over the CPU.

    The order of preference is CUDA, then MPS (Metal Performance
    Shaders, the GPU backend on Apple Silicon), then the CPU.

    Returns
    -------
    device:
        A ``torch.device`` of type ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # Older PyTorch builds have no `torch.backends.mps`, hence the hasattr guard.
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)


In [None]:
# ===== model.py =====

# Utilities for building models and converting parameters
from typing import Dict, List
import torch
import torch.nn as nn
from torchvision import models


def build_resnet18(num_classes: int = 10, pretrained: bool = False) -> nn.Module:
    """Create a ResNet-18 tailored for CIFAR-size inputs.

    The torchvision ResNet-18 is adapted by replacing the first convolution
    with a 3x3, stride-1 convolution, replacing the max-pool with an identity,
    and replacing the classifier head with one that has ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes of the final linear layer.
        pretrained: When true, start from torchvision's default ResNet-18
            weights instead of random initialisation (this flag used to be
            ignored).  The replaced ``conv1`` and ``fc`` layers are freshly
            initialised either way.

    Returns:
        The adapted model, with ``num_classes`` also stored as an attribute.
    """
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    m = models.resnet18(weights=weights)
    m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    m.maxpool = nn.Identity()

    in_features = m.fc.in_features
    m.fc = nn.Linear(in_features, num_classes)
    m.num_classes = num_classes
    return m


def state_to_list(state: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
    """Return the tensors of a state_dict, in key order, as detached CPU copies.

    Each returned tensor is cloned, so it does not share storage with the
    tensor it was taken from.
    """
    return [tensor.detach().cpu().clone() for tensor in state.values()]


def list_to_state(template: Dict[str, torch.Tensor], arrs: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rebuild a state_dict from a list of tensors using a template for keys/dtypes/devices.

    Args:
        template: Reference state_dict; supplies the keys (in order) and the
            device and dtype each output tensor is converted to.
        arrs: Tensors in the same order as ``template``'s keys.

    Returns:
        A new dict mapping each template key to the matching tensor from
        ``arrs``, moved to the template tensor's device and cast to its dtype.

    Raises:
        ValueError: If ``arrs`` and ``template`` have different lengths.
            Pairing them with ``zip`` would otherwise drop the surplus
            entries silently and return a partial state_dict.
    """
    arrs = list(arrs)
    if len(arrs) != len(template):
        raise ValueError(
            f"list_to_state expected {len(template)} tensors to match the template, got {len(arrs)}."
        )
    return {
        key: arr.to(ref.device).type_as(ref)
        for (key, ref), arr in zip(template.items(), arrs)
    }


In [None]:

# ===== partitioning.py =====

import os
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Subset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Any


# Root for data and results: $FEDASYNC_BASE_DIR when set, otherwise the current directory.
BASE_PATH = Path(os.getenv("FEDASYNC_BASE_DIR", ".")).expanduser()
# Default dataset location, kept as a plain string.
DEFAULT_DATA_DIR = os.fspath(BASE_PATH.joinpath("data"))
# Default output folder for generated figures, kept as a Path.
DEFAULT_RESULTS_DIR = BASE_PATH.joinpath("results")


class DataDistributor:
    """Load a torchvision dataset and split its training set across clients.

    Typical use: construct, call ``distribute_data()`` once, then fetch each
    client's ``Subset`` with ``get_client_data()`` and optionally plot the
    split with ``visualize_distribution()``.
    """

    def __init__(self, dataset_name: str, data_dir: str = DEFAULT_DATA_DIR):
        """
        Flexible data distributor for federated learning experiments.

        Args:
            dataset_name (str): Name of dataset ('CIFAR10', 'MNIST',
                'FashionMNIST'); matched case-insensitively.
            data_dir (str): Directory to store data. A relative path is
                resolved against ``BASE_PATH``.
        """
        self.dataset_name = dataset_name.lower()
        data_root = Path(data_dir)
        if not data_root.is_absolute():
            data_root = BASE_PATH / data_root
        self.data_dir = str(data_root)
        self.train_dataset, self.test_dataset, self.num_classes = self._load_dataset()
        # Filled in by distribute_data(): client id -> list of training-set indices.
        self.partitions = None

    def _load_dataset(self) -> Tuple[Any, Any, int]:
        """Load supported torchvision datasets.

        Returns:
            ``(train_dataset, test_dataset, num_classes)``; both datasets are
            downloaded into ``self.data_dir`` if missing and use a plain
            ``ToTensor`` transform.

        Raises:
            ValueError: If ``self.dataset_name`` is not a supported dataset.
        """
        transform = transforms.Compose([transforms.ToTensor()])

        if self.dataset_name == "cifar10":
            train = datasets.CIFAR10(self.data_dir, train=True, download=True, transform=transform)
            test = datasets.CIFAR10(self.data_dir, train=False, download=True, transform=transform)
            num_classes = 10

        elif self.dataset_name == "mnist":
            train = datasets.MNIST(self.data_dir, train=True, download=True, transform=transform)
            test = datasets.MNIST(self.data_dir, train=False, download=True, transform=transform)
            num_classes = 10

        elif self.dataset_name == "fashionmnist":
            train = datasets.FashionMNIST(self.data_dir, train=True, download=True, transform=transform)
            test = datasets.FashionMNIST(self.data_dir, train=False, download=True, transform=transform)
            num_classes = 10

        else:
            raise ValueError(f"Dataset '{self.dataset_name}' is not supported yet.")

        return train, test, num_classes

    def distribute_data(self, num_clients: int, alpha: float = 0.5, seed: int = 42) -> Dict[int, List[int]]:
        """
        Perform Dirichlet-based data partitioning across clients (Non-IID).

        For every class, the class's sample indices are shuffled and split
        among the clients according to proportions drawn from a symmetric
        Dirichlet distribution.  Every training sample ends up in exactly one
        client's partition.

        Note: this reseeds NumPy's *global* random generator with ``seed``.

        Args:
            num_clients (int): Number of clients (must be at least 1).
            alpha (float): Dirichlet distribution parameter (smaller = more non-IID).
            seed (int): Random seed for reproducibility.

        Returns:
            Mapping from client id to that client's (shuffled) list of
            training-set indices; also stored in ``self.partitions``.

        Raises:
            ValueError: If ``num_clients`` is smaller than 1.
        """
        if num_clients < 1:
            # Fail with a clear message instead of an obscure error further down.
            raise ValueError(f"num_clients must be at least 1, got {num_clients}.")

        np.random.seed(seed)
        targets = np.array(self.train_dataset.targets)
        self.partitions = {i: [] for i in range(num_clients)}

        for cls in range(self.num_classes):
            idxs = np.where(targets == cls)[0]
            # Shuffle indices for this class
            np.random.shuffle(idxs)
            # Sample proportions from a Dirichlet distribution
            proportions = np.random.dirichlet(alpha=np.repeat(alpha, num_clients))
            # Convert proportions to integer counts (floor) for each client
            int_props = np.floor(proportions * len(idxs)).astype(int)
            # Assign counts to clients
            start = 0
            for client_id, size in enumerate(int_props):
                self.partitions[client_id].extend(idxs[start:start + size])
                start += size
            # If any samples are left over due to floor truncation, assign them
            # to the clients with the largest initial share.  This ensures that
            # the union of partitions covers the full dataset.
            remaining = len(idxs) - start
            if remaining > 0:
                # Rank clients by proportion (descending); equal proportions
                # keep whatever order np.argsort yields.
                ranked_clients = np.argsort(-proportions)
                # Distribute leftover samples in round-robin order among ranked clients
                for i in range(remaining):
                    cid = ranked_clients[i % len(ranked_clients)]
                    self.partitions[int(cid)].append(idxs[start + i])

        # Mix the classes within each client's index list.
        for cid in self.partitions:
            np.random.shuffle(self.partitions[cid])

        return self.partitions

    def get_client_data(self, client_id: int) -> Subset:
        """
        Retrieve dataset subset for a specific client.

        Args:
            client_id (int): Client identifier.

        Returns:
            A ``Subset`` of the training dataset restricted to the client's indices.

        Raises:
            ValueError: If ``distribute_data()`` has not been called yet.
        """
        if self.partitions is None:
            raise ValueError("Please run distribute_data() before accessing client data.")
        indices = self.partitions[client_id]
        return Subset(self.train_dataset, indices)

    def visualize_distribution(self, save_path: str | None = None) -> None:
        """
        Create IEEE-style stacked bar chart of sample counts per client.

        Args:
            save_path (str | None): File path to save the visualization.
                Defaults to ``DEFAULT_RESULTS_DIR / "data_distribution_ieee.png"``;
                a relative path is resolved against ``BASE_PATH``.

        Raises:
            ValueError: If ``distribute_data()`` has not been called yet.
        """
        if self.partitions is None:
            raise ValueError("Run distribute_data() before visualization.")

        path_obj = Path(save_path) if save_path is not None else DEFAULT_RESULTS_DIR / "data_distribution_ieee.png"
        if not path_obj.is_absolute():
            path_obj = BASE_PATH / path_obj
        path_obj.parent.mkdir(parents=True, exist_ok=True)

        targets = np.array(self.train_dataset.targets)
        # Rows are clients, columns are classes.
        client_counts = np.zeros((len(self.partitions), self.num_classes), dtype=int)

        for cid, idxs in self.partitions.items():
            class_counts = np.bincount(targets[idxs], minlength=self.num_classes)
            client_counts[cid, :] = class_counts

        # NOTE(review): an IEEE single column is ~3.5in wide, but this figure is
        # 1.8in x 1.2in at 300 dpi — confirm the intended size.
        fig, ax = plt.subplots(figsize=(1.8, 1.2), dpi=300)
        bottom = np.zeros(len(self.partitions))
        colors = plt.get_cmap("tab20").colors

        # Stack one bar segment per class on top of the previous classes.
        for cls in range(self.num_classes):
            ax.bar(
                x=np.arange(len(self.partitions)),
                height=client_counts[:, cls],
                bottom=bottom,
                color=colors[cls % len(colors)],
                linewidth=0.1,
                edgecolor="white",
            )
            bottom += client_counts[:, cls]

        ax.set_xlabel("Client ID", fontsize=8)
        ax.set_ylabel("Samples", fontsize=8)
        ax.set_title(f"{self.dataset_name.upper()} Data Distribution Among Clients", fontsize=9)
        ax.tick_params(axis="both", which="major", labelsize=7)

        fig.savefig(path_obj, dpi=300, bbox_inches="tight")
        plt.close(fig)
        print(f"📊 Data distribution plot saved at: {path_obj}")


# -------------------------------
# Example Usage (for testing)
# -------------------------------
# NOTE(review): inside a notebook cell `__name__` is "__main__", so this example
# presumably runs (and downloads CIFAR-10) every time the cell executes — confirm
# that is wanted.
if __name__ == "__main__":
    # Split CIFAR-10 across 5 clients with a fairly skewed Dirichlet prior and plot it.
    distributor = DataDistributor(dataset_name="CIFAR10", data_dir=DEFAULT_DATA_DIR)
    distributor.distribute_data(num_clients=5, alpha=0.3, seed=42)
    distributor.visualize_distribution(DEFAULT_RESULTS_DIR / "cifar10_distribution_ieee.png")

    # Retrieve client dataset subset
    client_data = distributor.get_client_data(0)
    print(f"✅ Client 0 has {len(client_data)} samples.")


In [None]:
# ==== server.py ====

# Async FedAsync server with periodic evaluation/logging and accuracy-based stopping
import csv
import time
import threading
from pathlib import Path
from typing import List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, Future

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from contextlib import redirect_stdout, redirect_stderr
import io
from utils.model import state_to_list, list_to_state, build_resnet18
from utils.helper import get_device



def _testloader(root: str, batch_size: int = 256):
    """Build a non-shuffled DataLoader over the CIFAR-10 test split.

    Images are converted to tensors and normalised with fixed per-channel
    mean/std constants.  The dataset is downloaded into ``root`` if missing;
    anything torchvision prints while doing so is captured and discarded.
    Batches are loaded in the calling process (``num_workers=0``).

    NOTE(review): the client cell defines another `_testloader`; when both
    cells run in one notebook namespace the later definition replaces this
    one — confirm that is intended.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    # Silence torchvision download/cache prints
    sink = io.StringIO()
    with redirect_stdout(sink), redirect_stderr(sink):
        test_set = datasets.CIFAR10(root=root, train=False, download=True, transform=pipeline)

    return DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)


def _evaluate(model: torch.nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    """Compute mean cross-entropy loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and run without gradient tracking;
    it is NOT moved to ``device`` here, only the batches are.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample.  An empty loader
        yields ``(0.0, 0.0)``.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    model.eval()

    seen = 0
    hits = 0
    running_loss = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            # The criterion returns a batch mean; re-weight it by the batch size.
            running_loss += float(loss_fn(outputs, labels).item()) * batch_size
            seen += batch_size
            hits += (outputs.argmax(1) == labels).sum().item()

    denom = max(1, seen)  # avoids dividing by zero on an empty loader
    return running_loss / denom, hits / denom


def _async_eval_worker(
    state_dict: dict,
    data_dir: str,
    num_classes: int,
    log_path: str,
    agg_id: int,
    avg_train_loss: float,
    avg_train_acc: float,
    device_str: str,
) -> Tuple[float, float]:
    """Evaluate a snapshot of the global model and append one row to the CSV log.

    A fresh ResNet-18 is built, loaded with ``state_dict`` and evaluated on
    the test loader created from ``data_dir``.  The row written to
    ``log_path`` holds the aggregation id, the supplied training averages,
    the measured test loss/accuracy and the current wall-clock time; a header
    row is written first when the file does not exist yet.

    Returns:
        ``(test_loss, test_acc)`` as measured on the test set.
    """
    eval_device = torch.device(device_str)

    net = build_resnet18(num_classes=num_classes)
    net.load_state_dict(state_dict)
    net = net.to(eval_device)

    test_loader = _testloader(root=data_dir, batch_size=256)
    test_loss, test_acc = _evaluate(net, test_loader, eval_device)

    csv_file = Path(log_path)
    csv_file.parent.mkdir(parents=True, exist_ok=True)
    # Decide about the header before opening: append mode creates the file.
    needs_header = not csv_file.exists()
    with csv_file.open("a", newline="") as handle:
        out = csv.writer(handle)
        if needs_header:
            out.writerow(
                ["total_agg", "avg_train_loss", "avg_train_acc", "test_loss", "test_acc", "time"]
            )
        row = [
            int(agg_id),
            float(avg_train_loss),
            float(avg_train_acc),
            float(test_loss),
            float(test_acc),
            time.time(),
        ]
        out.writerow(row)

    return test_loss, test_acc


class AsyncFedServer:
    """FedAsync: fixed alpha mix (1 - alpha) * w_global + alpha * w_client. Logs every `eval_every_aggs` aggregations.

    Clients pull the current weights with ``get_global()`` and push results
    with ``submit_update()``.  Each accepted update is merged into the global
    model immediately and recorded in the client-participation CSV.  Every
    ``eval_every_aggs`` merges a snapshot of the model is evaluated on a
    background thread; the server stops once that evaluation reaches
    ``target_accuracy`` or once ``max_rounds`` merges have been performed.
    """
    def __init__(
        self,
        global_model: torch.nn.Module,
        alpha: float = 0.5,
        target_accuracy: float = 0.70,
        max_rounds: Optional[int] = None,
        eval_every_aggs: int = 5,
        data_dir: str = "./data",
        logs_dir: str = "./logs",
        global_log_csv: Optional[str] = None,
        client_participation_csv: Optional[str] = None,
        final_model_path: Optional[str] = None,
        num_classes: int = 10,
        device: Optional[torch.device] = None,
    ):
        """Set up the global model, log files and background-evaluation state.

        Args:
            global_model: Model whose weights act as the global state.
            alpha: Mixing weight given to each incoming client update.
            target_accuracy: Global test accuracy at which the server stops.
            max_rounds: Optional cap on the number of merges; ``None`` disables it.
            eval_every_aggs: Number of merges between two global evaluations.
            data_dir: Dataset root used for evaluation.
            logs_dir: Folder for the default CSV locations (created if missing).
            global_log_csv: Path of the global evaluation CSV
                (default: ``<logs_dir>/FedAsync.csv``).
            client_participation_csv: Path of the per-update CSV
                (default: ``<logs_dir>/FedAsyncClientParticipation.csv``).
            final_model_path: Where the final weights are saved
                (default: ``./results/FedAsyncModel.pt``).
            num_classes: Number of classes of the model rebuilt for evaluation.
            device: Device for the global model; chosen by ``get_device()`` when omitted.
        """
        self.model = global_model
        # Snapshot taken before the model is moved: it fixes the key order, dtype
        # and device that merged parameters are converted back to.
        self.template = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
        self.device = device or get_device()
        self.model.to(self.device)
        self.num_classes = int(num_classes)

        self.alpha = float(alpha)

        self.eval_every_aggs = int(eval_every_aggs)
        self.target_accuracy = float(target_accuracy)
        self.max_rounds = int(max_rounds) if max_rounds is not None else None

        # I/O
        self.data_dir = data_dir
        self.log_dir = Path(logs_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Paths supplied by config (with defaults)
        self.csv_path = Path(global_log_csv) if global_log_csv else (self.log_dir / "FedAsync.csv")
        self.participation_csv = Path(client_participation_csv) if client_participation_csv else (self.log_dir / "FedAsyncClientParticipation.csv")
        self.final_model_path = Path(final_model_path) if final_model_path else Path("./results/FedAsyncModel.pt")
        self.final_model_path.parent.mkdir(parents=True, exist_ok=True)

        # Init CSV headers if files don't exist
        self._init_async_eval_log()

        if not self.participation_csv.exists():
            self.participation_csv.parent.mkdir(parents=True, exist_ok=True)
            with self.participation_csv.open("w", newline="") as f:
                csv.writer(f).writerow(
                    [
                        "client_id",
                        "local_train_loss",
                        "local_train_acc",
                        "local_test_loss",
                        "local_test_acc",
                        "total_agg",
                        "staleness",
                    ]
                )

        # Re-entrant on purpose: `Future.add_done_callback` invokes the callback
        # immediately, in the calling thread, when the future has already
        # finished.  `_launch_async_eval_if_needed` registers `_handle_eval_result`
        # while `submit_update` holds this lock, and that callback takes the lock
        # again; with a plain Lock that path would deadlock.
        self._lock = threading.RLock()
        self._stop = False
        self.t_round = 0  # increments on every merge
        self.testloader = _testloader(self.data_dir)
        self._train_loss_acc_accum: List[Tuple[float, float, int]] = []  # (loss, acc, n) since last eval
        # Async evaluation state
        self._eval_executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(max_workers=1)
        self._async_eval_future: Optional[Future] = None
        self._last_eval: Tuple[float, float] = (0.0, 0.0)

    def _save_final_model(self) -> None:
        """Write the current global weights to ``final_model_path``."""
        torch.save(self.model.state_dict(), self.final_model_path)

    def _shutdown_eval_executor(self) -> None:
        """Release async eval thread quickly when stopping."""
        if self._eval_executor is None:
            return
        try:
            self._eval_executor.shutdown(wait=False, cancel_futures=True)
        except Exception:
            # Best effort: a failed shutdown must not prevent the server from stopping.
            pass
        self._eval_executor = None

    # ---------- client/server API ----------
    def get_global(self):
        """Return ``(parameters, version)``: CPU copies of the global weights and the merge counter."""
        with self._lock:
            return state_to_list(self.model.state_dict()), self.t_round

    def submit_update(
        self,
        client_id: int,
        base_version: int,
        new_params: List[torch.Tensor],
        num_samples: int,
        train_time_s: float,
        train_loss: float,
        train_acc: float,
        test_loss: float,
        test_acc: float,
    ) -> None:
        """Merge one client update into the global model and log it.

        The update is ignored once the server has stopped.  When the
        ``max_rounds`` cap has already been reached, the update is not merged;
        instead the server stops, saves the final model and shuts the
        evaluation executor down.

        Args:
            client_id: Identifier written to the participation CSV.
            base_version: Merge counter the client started from; only used to
                compute the logged staleness.
            new_params: Client parameters, ordered like the global state_dict.
            num_samples: Weight of this client's metrics in the training averages.
            train_time_s: Local training time (accepted but not used here).
            train_loss: Client's local training loss.
            train_acc: Client's local training accuracy.
            test_loss: Client's local test loss.
            test_acc: Client's local test accuracy.
        """
        cleanup_requested = False
        with self._lock:
            if self._stop:
                return
            if self.max_rounds is not None and self.t_round >= self.max_rounds:
                self._stop = True
                cleanup_requested = True
            else:
                # FedAsync merge (fixed alpha, still logs staleness for visibility)
                staleness = max(0, self.t_round - base_version)
                eff = self.alpha

                g = state_to_list(self.model.state_dict())
                merged = [(1.0 - eff) * gi + eff * ci for gi, ci in zip(g, new_params)]
                new_state = list_to_state(self.template, merged)
                self.model.load_state_dict(new_state, strict=True)

                self.t_round += 1

                # Client participation CSV (append staleness like TrustWeight)
                with self.participation_csv.open("a", newline="") as f:
                    csv.writer(f).writerow(
                        [
                            client_id,
                            f"{train_loss:.6f}",
                            f"{train_acc:.6f}",
                            f"{test_loss:.6f}",
                            f"{test_acc:.6f}",
                            self.t_round,
                            float(staleness),
                        ]
                    )

                # accumulate metrics since last eval tick (used by optional timer)
                self._train_loss_acc_accum.append((float(train_loss), float(train_acc), int(num_samples)))
                # Kick off async global eval every eval_every_aggs
                if self.t_round % self.eval_every_aggs == 0:
                    avg_loss, avg_acc = self._compute_avg_train()
                    self._train_loss_acc_accum.clear()
                    self._launch_async_eval_if_needed(self.t_round, avg_loss, avg_acc)

                # only print aggregation number to console
                print(self.t_round)
        if cleanup_requested:
            self._save_final_model()
            self._shutdown_eval_executor()

    def should_stop(self) -> bool:
        """Return whether the server has been told to stop."""
        with self._lock:
            return self._stop

    def mark_stop(self) -> None:
        """Stop the server, save the final model and release the evaluation thread."""
        with self._lock:
            self._stop = True
            # store final model when marking stop
            self._save_final_model()
        self._shutdown_eval_executor()

    # ---------- evaluation / logging ----------
    def _compute_avg_train(self) -> Tuple[float, float]:
        """Return the sample-weighted mean training loss and accuracy since the last evaluation."""
        if not self._train_loss_acc_accum:
            return 0.0, 0.0
        loss_sum, acc_sum, n_sum = 0.0, 0.0, 0
        for l, a, n in self._train_loss_acc_accum:
            loss_sum += l * n
            acc_sum += a * n
            n_sum += n
        return loss_sum / max(1, n_sum), acc_sum / max(1, n_sum)

    def _init_async_eval_log(self) -> None:
        """Ensure the async evaluation CSV exists with header."""
        if self.csv_path.exists():
            return
        self.csv_path.parent.mkdir(parents=True, exist_ok=True)
        with self.csv_path.open("w", newline="") as f:
            csv.writer(f).writerow(
                ["total_agg", "avg_train_loss", "avg_train_acc", "test_loss", "test_acc", "time"]
            )

    def _launch_async_eval_if_needed(self, total_agg: int, avg_train_loss: float, avg_train_acc: float) -> None:
        """Schedule a non-blocking global eval every eval_every_aggs aggregations.

        Nothing is scheduled when the executor is already shut down, when
        ``total_agg`` is not a multiple of ``eval_every_aggs``, or while a
        previous evaluation is still running (that evaluation tick is skipped).
        """
        if self._eval_executor is None:
            return
        if total_agg % self.eval_every_aggs != 0:
            return
        if self._async_eval_future is not None and not self._async_eval_future.done():
            return

        # snapshot state on current device
        state_copy = {k: v.clone() for k, v in self.model.state_dict().items()}
        self._async_eval_future = self._eval_executor.submit(
            _async_eval_worker,
            state_copy,
            self.data_dir,
            self.num_classes,
            str(self.csv_path),
            total_agg,
            avg_train_loss,
            avg_train_acc,
            str(self.device),
        )
        self._async_eval_future.add_done_callback(self._handle_eval_result)

    def _handle_eval_result(self, fut: Future) -> None:
        """Record a finished evaluation and stop the server once the target accuracy is met."""
        if fut.cancelled():
            # Expected when the executor is shut down with pending work.
            return
        try:
            test_loss, test_acc = fut.result()
        except Exception as exc:
            # Report instead of swallowing: a failing evaluation otherwise leaves no
            # trace, and the accuracy-based stop would silently never trigger.
            print(f"[AsyncFedServer] async evaluation failed: {exc!r}")
            return
        with self._lock:
            self._last_eval = (float(test_loss), float(test_acc))
            if test_acc >= self.target_accuracy:
                self._stop = True
                self._save_final_model()
                self._shutdown_eval_executor()

    def wait(self):
        """Block until the server is asked to stop, then finalise via ``mark_stop()``."""
        try:
            while not self.should_stop():
                time.sleep(0.2)
        finally:
            self.mark_stop()


In [None]:
# ===== client.py =====

# Lightning-based local client, no Flower deps
import time
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from contextlib import redirect_stdout, redirect_stderr
import io
import pytorch_lightning as pl
from utils.model import build_resnet18, state_to_list, list_to_state
from utils.helper import get_device
import random


def _device_to_accelerator(device: torch.device) -> str:
    """Translate a ``torch.device`` into an accelerator name string.

    ``cuda`` maps to ``"gpu"``, ``mps`` to ``"mps"`` and every other device
    type to ``"cpu"``.
    """
    accelerator_by_type = {"cuda": "gpu", "mps": "mps"}
    return accelerator_by_type.get(device.type, "cpu")


def _testloader(root: str, batch_size: int = 256):
    """Return a non-shuffled DataLoader over the CIFAR-10 test split.

    Images become tensors normalised with fixed per-channel mean/std
    constants.  The dataset is downloaded into ``root`` when missing, with
    torchvision's console output captured and dropped.  Batches are loaded
    by two worker processes (``num_workers=2``).
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # Silence torchvision download/cache prints
    captured = io.StringIO()
    with redirect_stdout(captured), redirect_stderr(captured):
        test_split = datasets.CIFAR10(root=root, train=False, download=True, transform=preprocess)

    return DataLoader(test_split, batch_size=batch_size, shuffle=False, num_workers=2)


def _evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    """Measure mean cross-entropy loss and accuracy of ``model`` on ``loader``.

    The model is moved to ``device`` and put in eval mode; batches are run
    without gradient tracking.

    Returns:
        ``(mean_loss, accuracy)`` averaged per sample; ``(0.0, 0.0)`` for an
        empty loader.
    """
    loss_fn = nn.CrossEntropyLoss()
    model = model.to(device)
    model.eval()

    n_seen = 0
    n_correct = 0
    weighted_loss = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            n_batch = labels.size(0)
            # The criterion yields a batch mean; weight it by the batch size.
            weighted_loss += float(loss_fn(outputs, labels).item()) * n_batch
            n_seen += n_batch
            n_correct += (outputs.argmax(1) == labels).sum().item()

    divisor = max(1, n_seen)  # guards the empty-loader case
    return weighted_loss / divisor, n_correct / divisor


class LitCifar(pl.LightningModule):
    """Lightning wrapper that trains a classifier with cross-entropy and Adam.

    Besides the usual Lightning hooks it keeps running totals of loss,
    correct predictions and samples; the totals are cleared at the start of
    every training epoch and can be read with ``get_epoch_metrics()``.
    """

    def __init__(self, base_model: nn.Module, lr: float = 1e-3):
        """Wrap ``base_model``; ``lr`` is stored as a hyperparameter for the optimizer."""
        super().__init__()
        # The model object itself is excluded from the saved hyperparameters.
        self.save_hyperparameters(ignore=["base_model"])
        self.model = base_model
        self.criterion = nn.CrossEntropyLoss()
        self._train_loss_sum, self._train_correct, self._train_total = 0.0, 0, 0

    def forward(self, x):
        """Delegate the forward pass to the wrapped model."""
        return self.model(x)

    def training_step(self, batch, _batch_idx):
        """Compute the batch loss and fold the batch into the running epoch totals."""
        inputs, targets = batch
        logits = self(inputs)
        loss = self.criterion(logits, targets)

        n_batch = targets.size(0)
        hits = (logits.argmax(1) == targets).sum().item()
        # The criterion returns a batch mean, so weight it by the batch size.
        self._train_loss_sum += float(loss.item()) * n_batch
        self._train_correct += hits
        self._train_total += n_batch
        return loss

    def on_train_epoch_start(self):
        """Clear the running totals so metrics cover a single epoch."""
        self._train_loss_sum, self._train_correct, self._train_total = 0.0, 0, 0

    def get_epoch_metrics(self) -> Tuple[float, float]:
        """Return ``(mean_loss, accuracy)`` over the samples seen since the last reset."""
        seen = self._train_total
        if seen:
            return self._train_loss_sum / seen, self._train_correct / seen
        return 0.0, 0.0

    def configure_optimizers(self):
        """Optimise all parameters with Adam at the stored learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


class LocalAsyncClient:
    """Pull global -> Lightning local fit -> push update with averaged metrics.

    Supports per-client slow/fast delays with optional per-round jitter.

    The simulated delay before each local fit is, in seconds:
    ``server_runtime.client_delay`` (from the config, default 0)
    + the client's base delay + a uniform jitter in ``[-jitter, jitter]``,
    clamped at zero.
    """

    def __init__(
        self,
        cid: int,
        cfg: dict,
        subset: Subset,
        base_delay: float = 0.0,
        slow: bool = False,
        delay_ranges: Optional[tuple] = None,   # ((a_s, b_s), (a_f, b_f))
        jitter: float = 0.0,
        fix_delay: bool = True,
    ):
        """Build the local model, data loaders and delay settings.

        Args:
            cid: Client id reported to the server with every update.
            cfg: Parsed configuration; reads ``data`` and ``clients`` here and
                ``server_runtime`` (optional) when sleeping.
            subset: Partition for this client. Only its ``indices`` are used;
                samples are re-read from the CIFAR-10 training split with
                augmentation applied.
            base_delay: Base delay in seconds. NOTE: when ``fix_delay`` is
                true and ``delay_ranges`` is given, this value is replaced by
                a fresh sample drawn here.
            slow: Whether the client samples from the slow range.
            delay_ranges: ``((slow_lo, slow_hi), (fast_lo, fast_hi))`` or None.
            jitter: Half-width of the per-fit uniform jitter; ``<= 0`` disables.
            fix_delay: True samples the base delay once; False resamples it
                before every fit.
        """
        self.cid = cid
        self.cfg = cfg
        self.device = get_device()

        base = build_resnet18(num_classes=cfg["data"]["num_classes"], pretrained=False)
        self.lit = LitCifar(base, lr=float(cfg["clients"]["lr"]))

        # Rebuild a training subset with CIFAR-style augmentation (keeps partition indices, avoids touching partitioner)
        indices = subset.indices if hasattr(subset, "indices") else list(range(len(subset)))
        train_tfm = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        ])
        train_ds = datasets.CIFAR10(cfg["data"]["data_dir"], train=True, download=False, transform=train_tfm)
        aug_subset = Subset(train_ds, indices)
        self.loader = DataLoader(aug_subset, batch_size=int(cfg["clients"]["batch_size"]),
                                 shuffle=True, num_workers=0)

        # delay controls
        self.base_delay = float(base_delay)
        self.slow = bool(slow)
        self.delay_ranges = delay_ranges
        self.jitter = float(jitter)
        self.fix_delay = bool(fix_delay)

        # pre-sample fixed delay if requested (overrides the passed base_delay)
        if self.fix_delay and self.delay_ranges is not None:
            self.base_delay = self._draw_base_delay()

        self.accelerator = _device_to_accelerator(self.device)

        # local test loader to compute per-client test metrics
        self.testloader = _testloader(cfg["data"]["data_dir"])

    def _draw_base_delay(self) -> float:
        """Sample a base delay from the slow or fast range, depending on ``self.slow``."""
        slow_range, fast_range = self.delay_ranges
        low, high = slow_range if self.slow else fast_range
        return random.uniform(float(low), float(high))

    def _to_list(self) -> List[torch.Tensor]:
        """Return the local model's state as a list of tensors."""
        return state_to_list(self.lit.model.state_dict())

    def _from_list(self, arrs: List[torch.Tensor]) -> None:
        """Load ``arrs`` into the local model and move it to this client's device."""
        sd = self.lit.model.state_dict()
        new_sd = list_to_state(sd, arrs)
        self.lit.model.load_state_dict(new_sd, strict=True)
        self.lit.to(self.device)

    def _sleep_delay(self):
        """Sleep for this fit's simulated delay (no-op when it is zero)."""
        # global delay from config (kept for backward compat); an empty
        # `server_runtime:` section or a null `client_delay` parses as None
        # in YAML, so treat both as "no extra delay" instead of crashing.
        runtime_cfg = self.cfg.get("server_runtime") or {}
        global_d = float(runtime_cfg.get("client_delay") or 0.0)

        # per-client base delay; if not fixed, resample each fit
        base = self.base_delay
        if not self.fix_delay and self.delay_ranges is not None:
            base = self._draw_base_delay()

        # add +/- jitter
        jit = random.uniform(-self.jitter, self.jitter) if self.jitter > 0.0 else 0.0

        # jitter may push the sum below zero; clamp
        delay = max(0.0, global_d + base + jit)
        if delay > 0.0:
            time.sleep(delay)

    def fit_once(self, server) -> bool:
        """Run one pull / train / evaluate / submit cycle against ``server``.

        ``train_time_s`` sent to the server covers ``trainer.fit`` only, not
        the simulated delay.

        Returns:
            True if the server has not signalled stop after this update.
        """
        # pull global
        params, version = server.get_global()
        self._from_list(params)

        # emulate heterogeneous device speed
        self._sleep_delay()

        # train for local_epochs; checkpoints disabled for async runs
        epochs = int(self.cfg["clients"]["local_epochs"])
        trainer = pl.Trainer(
            max_epochs=epochs,
            accelerator=self.accelerator,
            devices=1,
            enable_checkpointing=False,
            logger=False,
            enable_model_summary=False,
            num_sanity_val_steps=0,
            enable_progress_bar=False,
            callbacks=[],
        )
        start = time.time()
        trainer.fit(self.lit, train_dataloaders=self.loader)
        duration = time.time() - start

        # local metrics
        train_loss, train_acc = self.lit.get_epoch_metrics()
        test_loss, test_acc = _evaluate(self.lit.model, self.testloader, self.device)

        new_params = self._to_list()
        num_examples = len(self.loader.dataset)

        server.submit_update(
            client_id=self.cid,
            base_version=version,
            new_params=new_params,
            num_samples=num_examples,
            train_time_s=duration,
            train_loss=train_loss,
            train_acc=train_acc,
            test_loss=test_loss,
            test_acc=test_acc,
        )
        return not server.should_stop()


In [None]:
# ==== run.py ====

# Orchestrator: partitions data, starts server, runs Lightning clients
import os
# ---- silence libraries before anything else ----
import logging
import warnings

# Environment flags intended to quiet progress bars and warning output.
os.environ.update({
    "TQDM_DISABLE": "1",
    "PYTHONWARNINGS": "ignore",
    "LIGHTNING_DISABLE_RICH": "1",
})

# Raise the threshold of each noisy library logger and detach it from the
# root logger so its records are not forwarded to root handlers.
for name in (
    "pytorch_lightning", "lightning", "lightning.pytorch",
    "lightning_fabric", "lightning_utilities", "torch", "torchvision",
):
    _quiet_logger = logging.getLogger(name)
    _quiet_logger.setLevel(logging.ERROR)
    _quiet_logger.propagate = False

logging.getLogger().setLevel(logging.WARNING)
warnings.filterwarnings("ignore")

import time
from typing import Dict, Any, List
import random
from concurrent.futures import ThreadPoolExecutor

import yaml


# Path of the YAML config for this run: the FEDASYNC_CONFIG environment
# variable wins, otherwise fall back to config_run.yaml under BASE_DIR.
# NOTE(review): `Path` and `BASE_DIR` are not imported/defined in this cell;
# presumably they come from the notebook's setup cell — confirm.
# NOTE(review): the experiment.py cell defaults to config.yaml instead —
# confirm which file name is intended.
CFG_PATH = os.environ.get("FEDASYNC_CONFIG", str(Path(BASE_DIR) / "config_run.yaml"))


def load_cfg(path: str) -> Dict[str, Any]:
    """Read the YAML file at ``path`` and return the parsed configuration.

    The file is decoded as UTF-8 explicitly so that parsing does not depend
    on the platform's default locale encoding. Parsing uses
    ``yaml.safe_load``.

    Raises:
        OSError: if the file cannot be opened.
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def main():
    """Run one FedAsync simulation described by the YAML file at ``CFG_PATH``.

    Steps: seed the RNGs, partition the dataset across clients, build the
    async server, assign slow/fast delays, then run every client in its own
    thread until the server reports that it should stop.

    All ``clients.total`` clients take part; at most ``clients.concurrent``
    of them are inside ``fit_once`` at any moment.

    Raises:
        ValueError: if ``clients.concurrent`` is less than 1.
        Exception: the first exception raised inside a client thread is
            re-raised here after the server has been told to stop.
    """
    import threading  # local import: only needed for the concurrency gate

    cfg = load_cfg(CFG_PATH)

    # Reproducibility
    seed = int(cfg.get("seed", 42))
    set_seed(seed)
    random.seed(seed)

    # Partition dataset
    dd = DataDistributor(dataset_name=cfg["data"]["dataset"], data_dir=cfg["data"]["data_dir"])
    dd.distribute_data(
        num_clients=int(cfg["clients"]["total"]),
        alpha=float(cfg.get("partition_alpha", 0.5)),
        seed=seed,
    )

    # `train` (and `train.max_rounds`) are optional: a missing/empty section
    # means "no round limit", same as a missing `max_rounds` key.
    train_cfg = cfg.get("train") or {}
    max_rounds = train_cfg.get("max_rounds")

    # Build server with periodic eval/log and accuracy-based stopping
    global_model = build_resnet18(num_classes=cfg["data"]["num_classes"], pretrained=False)
    server = AsyncFedServer(
        global_model=global_model,
        alpha=float(cfg["async"]["alpha"]),
        target_accuracy=float(cfg["eval"]["target_accuracy"]),
        max_rounds=int(max_rounds) if max_rounds is not None else None,
        eval_every_aggs=int(cfg["eval"].get("eval_every_aggs", 5)),
        data_dir=cfg["data"]["data_dir"],
        logs_dir=cfg["io"]["logs_dir"],
        global_log_csv=cfg["io"].get("global_log_csv"),
        client_participation_csv=cfg["io"].get("client_participation_csv"),
        final_model_path=cfg["io"].get("final_model_path"),
        num_classes=int(cfg["data"]["num_classes"]),
        device=get_device(),
    )

    # ---- derive per-client delays to simulate heterogeneity ----
    n = int(cfg["clients"]["total"])
    pct = max(0, min(100, int(cfg["clients"].get("struggle_percent", 0))))
    k_slow = (n * pct) // 100
    slow_ids = set(random.sample(range(n), k_slow)) if k_slow > 0 else set()

    a_s, b_s = cfg["clients"].get("delay_slow_range", [0.8, 2.0])
    a_f, b_f = cfg["clients"].get("delay_fast_range", [0.0, 0.2])
    fix_delays = bool(cfg["clients"].get("fix_delays_per_client", True))
    jitter = float(cfg["clients"].get("jitter_per_round", 0.0))

    # NOTE: LocalAsyncClient re-samples its own fixed delay when it is given
    # delay_ranges, so these values only matter if that behaviour changes.
    per_client_base_delay: Dict[int, float] = {}
    if fix_delays:
        for cid in range(n):
            if cid in slow_ids:
                per_client_base_delay[cid] = random.uniform(float(a_s), float(b_s))
            else:
                per_client_base_delay[cid] = random.uniform(float(a_f), float(b_f))

    # Create clients
    clients: List[LocalAsyncClient] = []
    for cid in range(n):
        subset = dd.get_client_data(cid)
        base_delay = per_client_base_delay.get(cid, 0.0)
        is_slow = cid in slow_ids
        clients.append(LocalAsyncClient(
            cid=cid,
            cfg=cfg,
            subset=subset,
            base_delay=base_delay,
            slow=is_slow,
            delay_ranges=((float(a_s), float(b_s)), (float(a_f), float(b_f))),
            jitter=jitter,
            fix_delay=fix_delays,
        ))

    # Concurrency gate. Every client gets its own worker thread and a
    # semaphore limits how many are training at once. Sizing the pool itself
    # to `concurrent` would starve clients: each client_loop runs until the
    # server stops, so tasks queued behind the first `concurrent` ones would
    # never start.
    max_concurrent = int(cfg["clients"]["concurrent"])
    if max_concurrent < 1:
        raise ValueError("clients.concurrent must be greater than 0")
    gate = threading.BoundedSemaphore(max_concurrent)

    def client_loop(client: LocalAsyncClient):
        try:
            while not server.should_stop():
                with gate:
                    # the stop flag may have been set while waiting for a slot
                    if server.should_stop():
                        break
                    cont = client.fit_once(server)
                if not cont:
                    break
                # pause outside the gate so waiting clients can take the slot
                time.sleep(0.05)
        except Exception:
            # ensure the orchestrator stops if any client fails
            server.mark_stop()
            raise

    with ThreadPoolExecutor(max_workers=max(1, len(clients))) as executor:
        futures = [executor.submit(client_loop, cl) for cl in clients]
        try:
            while not server.should_stop():
                if all(f.done() for f in futures):
                    # all clients exited (success or error) -> stop server loop
                    server.mark_stop()
                    break
                time.sleep(0.2)
        finally:
            server.mark_stop()
            # surface the first client exception, if any
            for f in futures:
                f.result()



# Script entry point: run the orchestrator when this module is executed directly.
if __name__ == "__main__":
    main()


In [None]:
# ===== experiment.py =====

# Straggler sweep runner for FedAsync
import os
import time
import random
import logging
import warnings
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterable, List
from concurrent.futures import ThreadPoolExecutor

import yaml


def load_cfg(path: str) -> Dict[str, Any]:
    """Read the YAML file at ``path`` and return the parsed configuration.

    The file is decoded as UTF-8 explicitly so that parsing does not depend
    on the platform's default locale encoding. Parsing uses
    ``yaml.safe_load``.

    Raises:
        OSError: if the file cannot be opened.
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _override_io(cfg: Dict[str, Any], exp_dir: Path) -> Dict[str, Any]:
    """Return a deep copy of ``cfg`` whose ``io`` paths all point inside ``exp_dir``.

    ``exp_dir`` is resolved to an absolute path first. The input ``cfg`` is
    left untouched. ``cfg`` must already contain an ``io`` mapping; other
    keys in it are preserved.
    """
    patched = deepcopy(cfg)
    root = exp_dir.resolve()
    results_root = root / "results"

    io_overrides = {
        "logs_dir": root,
        "checkpoints_dir": root / "checkpoints",
        "results_dir": results_root,
        "global_log_csv": root / "FedAsync.csv",
        "client_participation_csv": root / "FedAsyncClientParticipation.csv",
        "final_model_path": results_root / "FedAsyncModel.pt",
    }
    for key, location in io_overrides.items():
        patched["io"][key] = str(location)
    return patched


def run_once(cfg: Dict[str, Any]) -> None:
    """Run one FedAsync simulation with the given configuration.

    Steps: quiet library logging, seed the RNGs, partition the dataset,
    build the async server, assign slow/fast delays, then run every client
    in its own thread until the server reports that it should stop.

    All ``clients.total`` clients take part; at most ``clients.concurrent``
    of them are inside ``fit_once`` at any moment.

    Args:
        cfg: Parsed configuration; reads the ``data``, ``clients``, ``async``,
            ``eval`` and ``io`` sections plus optional ``seed``,
            ``partition_alpha`` and ``train``.

    Raises:
        ValueError: if ``clients.concurrent`` is less than 1.
        Exception: the first exception raised inside a client thread is
            re-raised here after the server has been told to stop.
    """
    import threading  # local import: only needed for the concurrency gate

    # Silence noisy logs
    os.environ["TQDM_DISABLE"] = "1"
    os.environ["PYTHONWARNINGS"] = "ignore"
    os.environ["LIGHTNING_DISABLE_RICH"] = "1"
    for name in [
        "pytorch_lightning", "lightning", "lightning.pytorch",
        "lightning_fabric", "lightning_utilities", "torch", "torchvision",
    ]:
        logging.getLogger(name).setLevel(logging.ERROR)
        logging.getLogger(name).propagate = False
    logging.getLogger().setLevel(logging.WARNING)
    warnings.filterwarnings("ignore")

    # Reproducibility
    seed = int(cfg.get("seed", 42))
    set_seed(seed)
    random.seed(seed)

    # Partition dataset
    dd = DataDistributor(dataset_name=cfg["data"]["dataset"], data_dir=cfg["data"]["data_dir"])
    dd.distribute_data(
        num_clients=int(cfg["clients"]["total"]),
        alpha=float(cfg.get("partition_alpha", 0.5)),
        seed=seed,
    )

    # `train` (and `train.max_rounds`) are optional: a missing/empty section
    # means "no round limit", same as a missing `max_rounds` key.
    train_cfg = cfg.get("train") or {}
    max_rounds = train_cfg.get("max_rounds")

    # Build server
    global_model = build_resnet18(num_classes=cfg["data"]["num_classes"], pretrained=False)
    server = AsyncFedServer(
        global_model=global_model,
        alpha=float(cfg["async"]["alpha"]),
        target_accuracy=float(cfg["eval"]["target_accuracy"]),
        max_rounds=int(max_rounds) if max_rounds is not None else None,
        eval_every_aggs=int(cfg["eval"].get("eval_every_aggs", 5)),
        data_dir=cfg["data"]["data_dir"],
        logs_dir=cfg["io"]["logs_dir"],
        global_log_csv=cfg["io"].get("global_log_csv"),
        client_participation_csv=cfg["io"].get("client_participation_csv"),
        final_model_path=cfg["io"].get("final_model_path"),
        num_classes=int(cfg["data"]["num_classes"]),
        device=get_device(),
    )

    # Straggler sampling
    n = int(cfg["clients"]["total"])
    pct = max(0, min(100, int(cfg["clients"].get("struggle_percent", 0))))
    k_slow = (n * pct) // 100
    slow_ids = set(random.sample(range(n), k_slow)) if k_slow > 0 else set()

    a_s, b_s = cfg["clients"].get("delay_slow_range", [0.8, 2.0])
    a_f, b_f = cfg["clients"].get("delay_fast_range", [0.0, 0.2])
    fix_delays = bool(cfg["clients"].get("fix_delays_per_client", True))
    jitter = float(cfg["clients"].get("jitter_per_round", 0.0))

    # NOTE: LocalAsyncClient re-samples its own fixed delay when it is given
    # delay_ranges, so these values only matter if that behaviour changes.
    per_client_base_delay: Dict[int, float] = {}
    if fix_delays:
        for cid in range(n):
            if cid in slow_ids:
                per_client_base_delay[cid] = random.uniform(float(a_s), float(b_s))
            else:
                per_client_base_delay[cid] = random.uniform(float(a_f), float(b_f))

    # Clients
    clients: List[LocalAsyncClient] = []
    for cid in range(n):
        subset = dd.get_client_data(cid)
        base_delay = per_client_base_delay.get(cid, 0.0)
        is_slow = cid in slow_ids
        clients.append(LocalAsyncClient(
            cid=cid,
            cfg=cfg,
            subset=subset,
            base_delay=base_delay,
            slow=is_slow,
            delay_ranges=((float(a_s), float(b_s)), (float(a_f), float(b_f))),
            jitter=jitter,
            fix_delay=fix_delays,
        ))

    # Concurrency gate. Every client gets its own worker thread and a
    # semaphore limits how many are training at once. Sizing the pool itself
    # to `concurrent` would starve clients: each client_loop runs until the
    # server stops, so tasks queued behind the first `concurrent` ones would
    # never start — which would also skew the effective straggler share.
    max_concurrent = int(cfg["clients"]["concurrent"])
    if max_concurrent < 1:
        raise ValueError("clients.concurrent must be greater than 0")
    gate = threading.BoundedSemaphore(max_concurrent)

    def client_loop(client: LocalAsyncClient):
        try:
            while not server.should_stop():
                with gate:
                    # the stop flag may have been set while waiting for a slot
                    if server.should_stop():
                        break
                    cont = client.fit_once(server)
                if not cont:
                    break
                # pause outside the gate so waiting clients can take the slot
                time.sleep(0.05)
        except Exception:
            # make sure the orchestrator stops if any client fails
            server.mark_stop()
            raise

    with ThreadPoolExecutor(max_workers=max(1, len(clients))) as executor:
        futures = [executor.submit(client_loop, cl) for cl in clients]
        try:
            while not server.should_stop():
                if all(f.done() for f in futures):
                    # every client exited (success or error)
                    server.mark_stop()
                    break
                time.sleep(0.2)
        finally:
            server.mark_stop()
            # surface the first client exception, if any
            for f in futures:
                f.result()


def straggler_sweep(
    base_cfg: Dict[str, Any],
    percents: Iterable[int],
    out_root: Path,
) -> None:
    """Run one experiment per straggler percentage, one after another.

    For each value in ``percents`` a copy of ``base_cfg`` gets
    ``clients.struggle_percent`` set to that value and its ``io`` paths
    redirected to ``out_root / "straggle_<pct>pct"``; the directory is created
    and ``run_once`` is called with the resulting config. ``base_cfg`` itself
    is not modified.
    """
    for percent in percents:
        run_dir = out_root / f"straggle_{percent}pct"

        run_cfg = deepcopy(base_cfg)
        run_cfg["clients"]["struggle_percent"] = int(percent)
        run_cfg = _override_io(run_cfg, run_dir)

        run_dir.mkdir(parents=True, exist_ok=True)
        print(f"[straggler_sweep] percent={percent}% -> logs at {run_dir}")
        run_once(run_cfg)


# Script entry point: load the base config and run the straggler sweep.
if __name__ == "__main__":
    # FEDASYNC_CONFIG overrides the default config.yaml under BASE_DIR.
    # NOTE(review): the run.py cell defaults to config_run.yaml instead —
    # confirm which file name is intended.
    cfg_path = os.environ.get("FEDASYNC_CONFIG", str(Path(BASE_DIR) / "config.yaml"))
    base = load_cfg(cfg_path)
    # All sweep outputs go under <io.logs_dir>/FedAsyncStragglerExp.
    out_root = Path(base["io"]["logs_dir"]) / "FedAsyncStragglerExp"
    out_root.mkdir(parents=True, exist_ok=True)
    # One run per straggler percentage.
    straggler_sweep(base, percents=[10, 20, 30, 40, 50], out_root=out_root)
