In [None]:
# 0) Environment setup (Colab or Kaggle)
# On a fresh runtime without PyTorch/torchvision, uncomment the line below to install them:
# !pip install torch torchvision --quiet
# --- Standard Library ---
import os
import time
import math
import random
from copy import deepcopy
from functools import partial
from collections import OrderedDict
from pathlib import Path
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple

# --- Third-Party Data & Plotting ---
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

# --- PyTorch Core ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.cuda.amp import autocast, GradScaler

# --- PyTorch Data Utilities ---
from torch.utils.data import Dataset, DataLoader, Subset, random_split

# --- Torchvision ---
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets
from torchvision.models import resnet18, resnet50

# Note: 'CIFAR10' is available via 'datasets.CIFAR10'
from torchvision.datasets import CIFAR10

# Master seed for this session; every RNG the notebook relies on is seeded from it.
seed = 0

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(seed)

# Prefer the GPU when PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

In [None]:
def eval_global_on_test(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` on ``test_loader``.

    A deep copy of the model is evaluated, so the caller's model keeps its
    device placement and train/eval mode. Returns 0.0 when the loader yields
    no samples.
    """
    evaluator = deepcopy(model).to(device)
    evaluator.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = evaluator(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)
    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


In [None]:
# 1) Utilities: flattening, cosine similarity, pairwise stats
def flat_params_from_model(model):
  """Concatenate every trainable parameter of ``model`` into one detached 1-D CPU tensor."""
  trainable = [param for param in model.parameters() if param.requires_grad]
  flat = parameters_to_vector(trainable)
  return flat.detach().cpu()

def set_flat_params_to_model(model, flat_vec):
  """Write a flat parameter vector back into ``model``'s trainable parameters.

  Inverse of ``flat_params_from_model``: elements are consumed in
  ``model.parameters()`` order, skipping parameters with ``requires_grad=False``.

  Args:
    model: module whose trainable parameters are overwritten in place.
    flat_vec: tensor, NumPy array, list or tuple holding exactly as many
      elements as the model has trainable parameters.

  Raises:
    ValueError: if ``flat_vec`` has the wrong number of elements.
  """
  params = [p for p in model.parameters() if p.requires_grad]
  if not torch.is_tensor(flat_vec):
    flat_vec = torch.as_tensor(np.asarray(flat_vec))
  flat_vec = flat_vec.reshape(-1)

  expected = sum(p.numel() for p in params)
  if flat_vec.numel() != expected:
    raise ValueError(
      f"flat_vec has {flat_vec.numel()} elements but the model has "
      f"{expected} trainable parameters")

  # Copy in place rather than rebinding ``param.data`` (what vector_to_parameters
  # does). This keeps each parameter's own dtype and device: a list of Python
  # floats would otherwise turn float32 weights into float64. It also keeps the
  # parameters from aliasing the caller's vector.
  with torch.no_grad():
    offset = 0
    for p in params:
      n = p.numel()
      p.copy_(flat_vec[offset:offset + n].reshape_as(p))
      offset += n

def cosine_sim_np(a, b):
  """Cosine similarity of two array-likes, each flattened to a 1-D float vector.

  A 1e-12 epsilon in the denominator keeps the result finite (0.0) when either
  input is all zeros.
  """
  u = np.asarray(a, dtype=float).ravel()
  v = np.asarray(b, dtype=float).ravel()
  numerator = np.dot(u, v)
  norms = np.linalg.norm(u) * np.linalg.norm(v)
  return float(numerator / (norms + 1e-12))

def pairwise_cosine_stats(list_of_flat_grads):
  """Return ``(min, mean)`` cosine similarity over all unordered pairs of gradients.

  Pairs are visited as (i, j) with i < j. With fewer than two gradients there
  are no pairs, and ``(0.0, 0.0)`` is returned.
  """
  count = len(list_of_flat_grads)
  pair_sims = [
    cosine_sim_np(list_of_flat_grads[i], list_of_flat_grads[j])
    for i in range(count)
    for j in range(i + 1, count)
  ]
  if not pair_sims:
    return 0.0, 0.0
  return float(np.min(pair_sims)), float(np.mean(pair_sims))

In [None]:
# 2) Fed-GGA utilities (seed-based perturbations & client scoring)
def sample_k_seeds(K, base_seed=None):
  """Draw ``K`` integer seeds in ``[0, 2**31 - 2]`` from a RandomState seeded with ``base_seed``.

  The same ``base_seed`` always yields the same list; ``None`` seeds from OS entropy.
  """
  generator = np.random.RandomState(base_seed)
  seeds = []
  for _ in range(K):
    seeds.append(int(generator.randint(0, 2**31 - 1)))
  return seeds


def get_heldout_split(domain_loaders, held_out):
  """Leave-one-domain-out split.

  Returns ``(train_loaders, test_loader, held_out)``:
  - ``test_loader`` is ``domain_loaders[held_out]``;
  - ``train_loaders`` lists every other loader, in the mapping's iteration order.

  An assertion fails if ``held_out`` is not a key of ``domain_loaders``.
  """
  assert held_out in domain_loaders, f"{held_out} is not a valid domain!"
  train_loaders = []
  for domain_name, loader in domain_loaders.items():
    if domain_name != held_out:
      train_loaders.append(loader)
  return train_loaders, domain_loaders[held_out], held_out

In [None]:
def set_global_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device), and force deterministic cuDNN.

    cuDNN autotuning (``benchmark``) is switched off as well, which may cost some speed.
    """
    seed = int(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
### Gradient Abstraction

@dataclass
class GradientDict:
    """Per-client gradients keyed by parameter name.

    Fields:
        gradients: mapping ``parameter name -> gradient tensor``.
        client_id: identifier of the client that produced the gradients.
    """
    gradients: Dict[str, torch.Tensor]
    client_id: str

    def sign(self) -> Dict[str, torch.Tensor]:
        """Element-wise sign (-1, 0 or +1) of every stored gradient, keyed by name."""
        signs = {}
        for param_name, grad in self.gradients.items():
            signs[param_name] = torch.sign(grad)
        return signs

    def to_device(self, device):
        """Replace ``gradients`` with copies on ``device`` and return ``self`` for chaining."""
        moved = {}
        for param_name, grad in self.gradients.items():
            moved[param_name] = grad.to(device)
        self.gradients = moved
        return self

    @staticmethod
    def from_model(model: nn.Module, client_id: str) -> 'GradientDict':
        """Snapshot ``param.grad`` of every parameter that has one; parameters without a grad are skipped."""
        snapshot = {
            param_name: param.grad.clone().detach()
            for param_name, param in model.named_parameters()
            if param.grad is not None
        }
        return GradientDict(gradients=snapshot, client_id=client_id)

    def apply_to_model(self, model: nn.Module, learning_rate: float):
        """In-place SGD step ``param -= learning_rate * grad`` for each parameter present in ``gradients``."""
        with torch.no_grad():
            for param_name, param in model.named_parameters():
                if param_name not in self.gradients:
                    continue
                param.data -= learning_rate * self.gradients[param_name]


In [None]:
class AggregationStrategy(ABC):
    """Interface for server-side rules that merge per-client gradients into a single update."""

    @abstractmethod
    def aggregate(self, gradients: List[GradientDict]) -> Dict[str, torch.Tensor]:
        """Combine the clients' gradients into one ``parameter name -> tensor`` update."""
        ...


class AgreementWeightedFedAvg(AggregationStrategy):
    """
    Agreement-Weighted Federated Averaging.

    For every parameter coordinate j, the agreement weight is the absolute sign
    consensus across the N clients:

        W_j = |Σ_i sign((g_i)_j)| / N    (1.0 = all clients agree in sign)

    The aggregated update is the plain average scaled element-wise by W:
    g_avg ⊙ W.
    """

    def __init__(self, verbose: bool = False):
        self.verbose = verbose
        # Weights from the most recent aggregate() call; read by get_agreement_statistics().
        self.last_agreement_weights = {}

    def compute_average_gradient(
        self,
        gradients: List[GradientDict]
    ) -> Dict[str, torch.Tensor]:
        """Element-wise mean g_avg = (1/N) * Σ g_i, for every parameter name of the first client."""
        count = len(gradients)
        return {
            name: sum(client.gradients[name] for client in gradients) / count
            for name in gradients[0].gradients.keys()
        }

    def compute_agreement_weights(
        self,
        gradients: List[GradientDict]
    ) -> Dict[str, torch.Tensor]:
        """
        Per-coordinate sign agreement W_j = |Σ sign((g_i)_j)| / N, for every
        parameter name of the first client.
        """
        count = len(gradients)
        return {
            name: torch.abs(sum(torch.sign(client.gradients[name]) for client in gradients)) / count
            for name in gradients[0].gradients.keys()
        }

    def aggregate(self, gradients: List[GradientDict]) -> Dict[str, torch.Tensor]:
        """
        Return the agreement-weighted average g_avg ⊙ W.

        Also stores W in ``last_agreement_weights`` and, when ``verbose`` is set,
        prints per-parameter min/max/mean agreement.

        Raises:
            ValueError: if ``gradients`` is empty.
        """
        if not gradients:
            raise ValueError("Cannot aggregate empty gradient list")

        averaged = self.compute_average_gradient(gradients)
        agreement = self.compute_agreement_weights(gradients)
        self.last_agreement_weights = agreement

        weighted = {name: averaged[name] * agreement[name] for name in averaged.keys()}

        if self.verbose:
            for name, w in agreement.items():
                print(f"  {name}: Agreement - Min: {w.min():.3f}, "
                      f"Max: {w.max():.3f}, Mean: {w.mean():.3f}")

        return weighted

    def get_agreement_statistics(self) -> Dict[str, Dict[str, float]]:
        """Per-parameter min/max/mean/std of the agreement weights from the last aggregate() call."""
        def summarize(weight):
            return {
                "min": float(weight.min()),
                "max": float(weight.max()),
                "mean": float(weight.mean()),
                "std": float(weight.std()),
            }
        return {name: summarize(weight) for name, weight in self.last_agreement_weights.items()}

In [None]:
class PruningFedAvg(AggregationStrategy):
    """
    Agreement-weighted FedAvg with patience-based pruning of conflicting coordinates.

    Each call to ``aggregate``, for every parameter coordinate j:
      * W_j = |Σ_i sign((g_i)_j)| / N is the sign agreement across the N clients;
      * the update is the agreement-weighted average g_avg ⊙ W;
      * a per-coordinate counter tracks consecutive calls with W_j < ``threshold``
        and is reset to 0 as soon as W_j >= ``threshold``;
      * coordinates whose counter has reached ``patience`` are zeroed ("pruned").

    Note: unlike the base-class annotation, ``aggregate`` returns a tuple
    ``(pruned_grads, pruning_rate_percent)``.
    """

    def __init__(self, threshold: float = 0.3, patience: int = 5):
        self.threshold = threshold
        self.patience = patience
        # Percentage of coordinates pruned by the most recent aggregate() call.
        self.pruning_rate = 0.0
        # name -> float tensor: consecutive rounds each coordinate has been in conflict.
        self.consecutive_conflict_counts = {}
        # name -> agreement weights W from the most recent aggregate() call.
        self.last_agreement_weights = {}

    def compute_agreement_weights(self, gradients: List[GradientDict]):
        """Return ``name -> |Σ_i sign(g_i)| / N`` for every parameter name seen in any client."""
        n_clients = len(gradients)

        weights = {}
        for gradient in gradients:
            for name, grad in gradient.gradients.items():
                if name not in weights:
                    weights[name] = torch.zeros_like(grad)
                weights[name] += torch.sign(grad)
        for name in weights:
            weights[name] = torch.abs(weights[name]) / n_clients
        return weights

    def aggregate(self, gradients: List[GradientDict]):
        """
        Aggregate client gradients and prune persistently conflicting coordinates.

        Returns:
            (pruned_grads, pruning_rate): ``pruned_grads`` maps each parameter name of
            the first client to (g_avg ⊙ W) * mask. ``pruning_rate`` is the percentage
            of coordinates whose mask is 0 (0.0 when there are no coordinates).

        Raises:
            ValueError: if ``gradients`` is empty.
        """
        if not gradients:
            raise ValueError("Cannot aggregate empty gradient list")

        n_clients = len(gradients)
        param_names = list(gradients[0].gradients.keys())

        # W is computed once and reused for both the weighting and the conflict bookkeeping.
        weights = self.compute_agreement_weights(gradients)
        self.last_agreement_weights = weights

        pruned_grads = {}
        total_params, pruned_params = 0, 0

        for name in param_names:
            agreement = weights[name]

            # 1. Agreement-weighted average gradient: g_avg ⊙ W.
            g_avg = sum(g.gradients[name] for g in gradients) / n_clients
            weighted = g_avg * agreement

            # 2. Consecutive-conflict counter: +1 where W < threshold, reset to 0 elsewhere.
            counts = self.consecutive_conflict_counts.get(name)
            if counts is None:
                counts = torch.zeros_like(agreement)
            else:
                # Follow the gradients if they moved to another device between rounds.
                counts = counts.to(agreement.device)
            has_conflict = (agreement < self.threshold).float()
            counts = (counts + has_conflict) * has_conflict
            self.consecutive_conflict_counts[name] = counts

            # 3. Prune only coordinates that have been in conflict for `patience` rounds.
            mask = (counts < self.patience).float()
            pruned_grads[name] = weighted * mask

            total_params += mask.numel()
            pruned_params += int((mask == 0).sum().item())

        pruning_rate = (pruned_params / total_params) * 100 if total_params > 0 else 0.0
        self.pruning_rate = pruning_rate
        return pruned_grads, pruning_rate

    def get_agreement_statistics(self) -> Dict[str, Dict[str, float]]:
        """Get statistics about the last computed agreement weights."""
        stats = {}
        for name, weight in self.last_agreement_weights.items():
            stats[name] = {
                "min": float(weight.min()),
                "max": float(weight.max()),
                "mean": float(weight.mean()),
                "std": float(weight.std())
            }
        return stats

In [None]:
# ImageNet channel statistics. CIFAR-10 images are upsampled to 224x224 and
# normalized with these (presumably to match an ImageNet-style backbone — TODO confirm).
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std  = (0.229, 0.224, 0.225)

# Train-time pipeline: upsample, random-resized crop (80-100% of the area), random flip.
train_transform_cifar = T.Compose(
    [
        T.Resize(256),
        T.RandomResizedCrop(224, scale=(0.8, 1.0)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(imagenet_mean, imagenet_std),
    ]
)

# Test-time pipeline: deterministic resize + center crop to the same 224x224 input size.
test_transform_cifar = T.Compose(
    [
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(imagenet_mean, imagenet_std),
    ]
)


def dirichlet_partition_noniid(dataset_targets, num_clients, alpha, min_size=10, rng=None):
    """
    Partition dataset indices into ``num_clients`` non-IID shards by label.

    For every class c, proportions p ~ Dirichlet(alpha * 1_K) decide what share of
    that class's shuffled samples each client receives. Smaller ``alpha`` gives
    more skewed label distributions.

    Args:
        dataset_targets: sequence/array of integer class labels (0..C-1).
        num_clients: number of shards K.
        alpha: Dirichlet concentration (> 0).
        min_size: target minimum number of samples per client. A repair pass moves
            samples from the currently largest client to any client below it, and
            never takes a donor down to or below ``min_size``. It always succeeds
            when ``len(dataset_targets) >= num_clients * min_size``.
        rng: ``np.random.RandomState`` used for all randomness (default: seed 1234).

    Returns:
        List of ``num_clients`` lists; entry k holds the dataset indices of client k.
        Every index appears in exactly one client.
    """
    if rng is None:
        rng = np.random.RandomState(1234)
    labels = np.asarray(dataset_targets)
    indices_per_client = [[] for _ in range(num_clients)]
    if labels.size == 0 or num_clients <= 0:
        return indices_per_client

    num_classes = int(labels.max()) + 1
    idx_by_class = [np.where(labels == c)[0].tolist() for c in range(num_classes)]

    for idx_c in idx_by_class:
        if len(idx_c) == 0:
            continue
        # Draw the class's split proportions, then shuffle its indices.
        proportions = rng.dirichlet([alpha] * num_clients)
        rng.shuffle(idx_c)

        # Largest-remainder rounding: floor the ideal counts, then give the leftover
        # samples to the clients with the largest fractional parts (ties -> lower index),
        # so that sum(counts) == len(idx_c).
        ideal = proportions * len(idx_c)
        counts = ideal.astype(int)
        leftover = len(idx_c) - int(counts.sum())
        if leftover > 0:
            order = np.argsort(-(ideal - counts), kind="stable")
            counts[order[:leftover]] += 1

        # Hand out consecutive slices of the shuffled class indices.
        pointer = 0
        for k, cnt in enumerate(counts):
            cnt = int(cnt)
            indices_per_client[k].extend(idx_c[pointer:pointer + cnt])
            pointer += cnt

    # Repair pass: top up each under-filled client one sample at a time from the
    # currently largest other client.
    for k in range(num_clients):
        while len(indices_per_client[k]) < min_size:
            donor = max(
                (j for j in range(num_clients) if j != k),
                key=lambda j: len(indices_per_client[j]),
                default=None,
            )
            if donor is None or len(indices_per_client[donor]) <= min_size:
                break  # not enough data elsewhere to satisfy min_size
            indices_per_client[k].append(indices_per_client[donor].pop())

    return indices_per_client


def build_cifar_clients(data_root, num_clients=10, alpha=0.1, batch_size=32, test_batch_size=256, num_workers=2, seed=0):
    """
    Load CIFAR-10 and split its training set across ``num_clients`` clients using a
    Dirichlet(``alpha``) label partition (``dirichlet_partition_noniid``).

    Args:
        data_root: directory where CIFAR-10 is stored / downloaded to.
        num_clients: number of client shards.
        alpha: Dirichlet concentration; smaller means stronger label skew.
        batch_size: batch size of each client's training DataLoader.
        test_batch_size: batch size of the global test DataLoader.
        num_workers: DataLoader worker processes.
        seed: passed to ``set_global_seed`` and used to seed the partition RNG.

    Returns:
        clients: list of ``(name, train_loader)`` tuples named ``client_0`` ..
            ``client_{num_clients-1}``. Each loader shuffles that client's subset.
            These are plain tuples, not ``FedClient`` objects.
        test_loader: unshuffled DataLoader over the full CIFAR-10 test split.
        indices_per_client: per-client lists of training-set indices (for reproducibility).
    """
    set_global_seed(seed)
    # Download if missing, otherwise reuse the local copy.
    cifar_train = CIFAR10(root=data_root, train=True, download=True, transform=train_transform_cifar)
    cifar_test = CIFAR10(root=data_root, train=False, download=True, transform=test_transform_cifar)

    # Labels drive the non-IID partition.
    train_targets = [int(x) for x in cifar_train.targets]
    indices_per_client = dirichlet_partition_noniid(train_targets, num_clients, alpha, rng=np.random.RandomState(seed))

    # One shuffled training DataLoader per client.
    clients = []
    for k in range(num_clients):
        subset = Subset(cifar_train, indices_per_client[k])
        loader = DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        clients.append((f"client_{k}", loader))

    # Global test loader over the full test split.
    test_loader = DataLoader(cifar_test, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

    return clients, test_loader, indices_per_client

# Example usage (small example): 3 clients with strong label skew (alpha=0.1).
# Set the CIFAR_DATA_ROOT environment variable to store the dataset elsewhere
# (e.g. on Colab); the default is the Kaggle working directory.
data_root = os.environ.get('CIFAR_DATA_ROOT', '/kaggle/working/cifar')
clients, test_loader, idxs = build_cifar_clients(data_root, num_clients=3, alpha=0.1, batch_size=32)

In [None]:
# ---- helper to select "head" params (works for the SqueezeNet variants) ----
def get_head_param_list_and_names(model):
    """
    Return ``(param_list, param_names)`` for the classifier/head parameters only.

    Only parameters with ``requires_grad=True`` are returned. ``param_names[i]`` is
    always the fully-qualified name (as in ``model.named_parameters()``) of
    ``param_list[i]``.

    Lookup order:
      1. ``model.fc`` if it is an ``nn.Linear`` (SqueezeNetLinear-style models);
      2. ``model.backbone.classifier[1]`` (SqueezeNetClassifier-style final conv);
      3. otherwise the last trainable parameter of the model.
    Returns ``([], [])`` when no trainable head parameter is found.
    """
    def _trainable_named(module, prefix):
        # (params, names) of module's trainable parameters, names made model-relative.
        pairs = [(prefix + n, p) for n, p in module.named_parameters() if p.requires_grad]
        return [p for _, p in pairs], [n for n, _ in pairs]

    # 1. SqueezeNetLinear: model.fc
    if hasattr(model, "fc") and isinstance(model.fc, torch.nn.Linear):
        return _trainable_named(model.fc, "fc.")

    # 2. SqueezeNetClassifier: final conv at backbone.classifier[1]
    if hasattr(model, "backbone") and hasattr(model.backbone, "classifier"):
        return _trainable_named(model.backbone.classifier[1], "backbone.classifier.1.")

    # 3. Fallback: the last trainable parameter (names filtered the same way as params).
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    if not trainable:
        return [], []
    last_name, last_param = trainable[-1]
    return [last_param], [last_name]

# Uniform delta sampler: U(-rho, rho) per parameter, optionally scaled by the model norm.
def make_uniform_delta_from_seed(seed, prototype_vector, rho, device='cpu', scale_by_norm=False):
    """Deterministic perturbation ``delta ~ U(-rho, rho)`` with the shape of ``prototype_vector``.

    A private ``torch.Generator`` seeded with the low 32 bits of ``seed`` is used,
    so the global torch RNG is not advanced and a given seed always yields the
    same delta. With ``scale_by_norm=True`` the sample is also multiplied by
    ``||prototype_vector|| + 1e-12``.
    """
    reference = prototype_vector.to(device)
    generator = torch.Generator(device=device)
    generator.manual_seed(int(seed) & 0xffffffff)
    uniform = (torch.rand(reference.shape, generator=generator, device=device) * 2.0 - 1.0) * rho
    if not scale_by_norm:
        return uniform
    return uniform * (torch.norm(reference) + 1e-12)

In [None]:
def client_compute_scores_for_fedgga(model, loss_fn, data_loader, seeds, rho, device='cpu', 
                                     scale_by_norm=False, search_head_only=True):
    """
    Score candidate perturbations (one per seed) by gradient alignment on one batch.

    All work happens on a private deepcopy of ``model`` in train mode:
      1. take the first batch of ``data_loader`` and compute the reference gradient
         g_ref of ``loss_fn`` with respect to the searched parameters;
      2. for each seed, add delta = make_uniform_delta_from_seed(seed, theta, rho, ...)
         to those parameters in place. Then recompute the gradient g_k on the same
         batch, record cos(g_k, g_ref), and subtract delta again.

    The searched parameters are the head parameters from
    ``get_head_param_list_and_names`` when ``search_head_only`` is True, otherwise
    every trainable parameter.

    Returns:
      scores: list of float cosine similarities, ``len(seeds)`` entries;
      ref_grad_numpy: g_ref as a 1-D NumPy array;
      loss_ref: float loss of the unperturbed model on the batch;
      losses_k: list of float losses of each perturbed model, ``len(seeds)`` entries.
    If there are no parameters to search, returns
    ``([0.0]*len(seeds), np.zeros(1), 0.0, [0.0]*len(seeds))``.

    Raises:
      ValueError: if ``data_loader`` yields no batch.
    """
    # Private copy: the perturbations, gradients and any train-mode buffer updates
    # below never touch the caller's model.
    local_model = deepcopy(model).to(device)
    local_model.train()

    # One batch serves for the reference gradient and for every candidate seed.
    try:
        xs, ys = next(iter(data_loader))
    except StopIteration:
        raise ValueError("data_loader yielded no batches; cannot score seeds") from None
    xs, ys = xs.to(device), ys.to(device)

    # Parameters that are perturbed and whose gradients are compared.
    if search_head_only:
        param_list, _ = get_head_param_list_and_names(local_model)
    else:
        param_list = [p for p in local_model.parameters() if p.requires_grad]
    if len(param_list) == 0:
        return [0.0] * len(seeds), np.zeros(1), 0.0, [0.0] * len(seeds)

    numels = [p.numel() for p in param_list]
    flat_theta = parameters_to_vector(param_list).detach().to(device)

    def _loss_and_flat_grad():
        # Forward/backward on the scoring batch. Returns (loss as float, concatenated
        # gradient of param_list on `device`), using zeros for params without a grad.
        local_model.zero_grad()
        loss_t = loss_fn(local_model(xs), ys)
        loss_val = float(loss_t.detach().cpu().item())
        loss_t.backward()
        parts = [
            torch.zeros(p.numel(), device=device) if p.grad is None else p.grad.detach().reshape(-1)
            for p in param_list
        ]
        return loss_val, torch.cat(parts)  # cat copies, so later zero_grad() is safe

    def _shift_params(delta_flat, alpha):
        # In-place theta += alpha * delta for param_list (alpha=+1 applies, -1 reverts).
        offset = 0
        for p, n in zip(param_list, numels):
            p.data.add_(delta_flat[offset: offset + n].view_as(p.data), alpha=alpha)
            offset += n

    loss_ref, ref_grad_t = _loss_and_flat_grad()
    ref_grad_numpy = ref_grad_t.detach().cpu().numpy()
    ref_norm = torch.norm(ref_grad_t)  # loop-invariant

    scores = []
    losses_k = []
    for seed in seeds:
        delta = make_uniform_delta_from_seed(seed, flat_theta, rho, device=device, scale_by_norm=scale_by_norm)
        _shift_params(delta, 1.0)

        loss_k, gk_t = _loss_and_flat_grad()
        denom = torch.norm(gk_t) * ref_norm + 1e-12
        scores.append(float(torch.dot(gk_t, ref_grad_t).item() / denom.item()))
        losses_k.append(loss_k)

        _shift_params(delta, -1.0)

    # Defensive exact restore, since add/sub round trips can accumulate float error.
    # flat_theta covers only param_list, so write it back only there (it does not
    # match the full set of trainable parameters in head-only mode).
    with torch.no_grad():
        offset = 0
        for p, n in zip(param_list, numels):
            p.data.copy_(flat_theta[offset: offset + n].view_as(p.data))
            offset += n

    return scores, ref_grad_numpy, loss_ref, losses_k

def fedavg_from_state_dicts(state_dicts):
    """
    Uniformly average a list of ``state_dict()`` mappings (FedAvg with equal weights).

    All dicts are assumed to share the keys of the first one. Each entry is averaged
    on the CPU and returned with the dtype of the first dict's tensor:
      * floating-point/complex tensors are accumulated in at least float32 precision
        (float64 stays float64, so double-precision state is not truncated), then
        cast back;
      * integer/bool tensors (e.g. BatchNorm ``num_batches_tracked``) are accumulated
        in float64, rounded to the nearest value (ties to even) and cast back.

    Raises:
        ValueError: if ``state_dicts`` is empty.
    """
    if len(state_dicts) == 0:
        raise ValueError("No state dicts provided to fedavg_from_state_dicts")
    n = len(state_dicts)
    new_sd = {}
    for k in state_dicts[0].keys():
        ref = state_dicts[0][k]
        is_float_like = ref.is_floating_point() or ref.is_complex()
        acc_dtype = torch.promote_types(ref.dtype, torch.float32) if is_float_like else torch.float64

        accum = torch.zeros(ref.shape, dtype=acc_dtype)
        for sd in state_dicts:
            accum += sd[k].detach().cpu().to(acc_dtype)
        mean = accum / float(n)

        if not is_float_like:
            mean = torch.round(mean)
        new_sd[k] = mean.to(ref.dtype)
    return new_sd


In [None]:
from torch import mode

class FedClient:
    """
    One federated client. It holds a train/test DataLoader and a device, and runs
    local work on private deep copies of the model it is given (the passed-in
    model is never modified).
    """

    def __init__(self, name, train_loader, test_loader, device):
        self.name = name
        self.train_loader = train_loader
        self.test_loader = test_loader
        # Either a string ('cpu', 'cuda', 'cuda:1') or a torch.device.
        self.device = device

    def score_seeds(self, model, loss_fn, seeds, rho, scale_by_norm, search_head_only):
        """Fed-GGA seed scoring on this client's training data (client_compute_scores_for_fedgga)."""
        return client_compute_scores_for_fedgga(model, loss_fn, self.train_loader, seeds, rho, device=self.device, 
                                                scale_by_norm=scale_by_norm, search_head_only=search_head_only)

    def local_update(self, global_model, local_epochs=1, lr=0.01, max_steps=None, use_amp=False):
        """
        Train a private copy of ``global_model`` with Adam on this client's data.

        Args:
            global_model: server model; copied, never modified.
            local_epochs: passes over ``train_loader``.
            lr: Adam learning rate.
            max_steps: optional cap on optimizer steps across all epochs.
            use_amp: use CUDA mixed precision (autocast + GradScaler). Ignored unless
                the client device is a CUDA device.

        Returns:
            A deep copy of the trained model's ``state_dict()``. A model without
            trainable parameters is returned untrained.
        """
        model = deepcopy(global_model).to(self.device)
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        if len(trainable_params) == 0:
            return deepcopy(model.state_dict())

        opt = torch.optim.Adam(trainable_params, lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        model.train()

        # str() so both device strings and torch.device objects are handled.
        on_cuda = str(self.device).startswith('cuda')
        scaler = GradScaler() if (use_amp and on_cuda) else None

        step = 0
        for _ in range(local_epochs):
            for xb, yb in self.train_loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                opt.zero_grad()
                if scaler is not None:
                    with autocast():
                        loss = loss_fn(model(xb), yb)
                    scaler.scale(loss).backward()
                    scaler.step(opt)
                    scaler.update()
                else:
                    loss = loss_fn(model(xb), yb)
                    loss.backward()
                    opt.step()
                step += 1
                if max_steps is not None and step >= max_steps:
                    break
            if max_steps is not None and step >= max_steps:
                break

        return deepcopy(model.state_dict())

    def compute_avg_gradient(self, global_model, local_epochs=1, max_batches=None, device=None):
        """
        Average per-batch gradients over local data without taking optimizer steps.

        Works on a deepcopy of ``global_model``, so the server model is untouched.
        Each batch gradient (of the mean CrossEntropy loss) is moved to CPU and
        summed, then the sum is divided by the number of batches processed.

        Args:
            global_model: server model (copied).
            local_epochs: passes over ``train_loader``.
            max_batches: optional cap on the total number of batches across epochs.
            device: compute device; defaults to ``self.device``.

        Returns:
            dict ``parameter name -> averaged gradient tensor`` (on CPU). It is empty
            if no batch was processed; parameters that never got a gradient are absent.
        """
        device = device or self.device
        model = deepcopy(global_model).to(device)
        model.train()
        loss_fn = nn.CrossEntropyLoss()

        accumulated = {}
        batch_count = 0
        for _ in range(local_epochs):
            for xb, yb in self.train_loader:
                if max_batches is not None and batch_count >= max_batches:
                    break
                xb, yb = xb.to(device), yb.to(device)
                model.zero_grad()
                loss = loss_fn(model(xb), yb)
                loss.backward()

                # Accumulate on CPU so device memory stays predictable.
                for name, param in model.named_parameters():
                    if param.grad is None:
                        continue
                    grad_cpu = param.grad.detach().cpu()
                    if name in accumulated:
                        accumulated[name] += grad_cpu
                    else:
                        # clone(): .cpu() is a no-op on CPU and would alias param.grad.
                        accumulated[name] = grad_cpu.clone()
                batch_count += 1
            if max_batches is not None and batch_count >= max_batches:
                break

        if batch_count == 0:
            return {}
        return {name: total / float(batch_count) for name, total in accumulated.items()}

    def eval_on_test(self, model):
        """Top-1 accuracy of a private copy of ``model`` on this client's test loader (0.0 if empty)."""
        model = deepcopy(model).to(self.device)
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in self.test_loader:
                x, y = x.to(self.device), y.to(self.device)
                preds = model(x).argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        return correct / total if total > 0 else 0.0

class FedGGAServer:
    """Federated server that runs a phased training schedule over its clients.

    Every round uses exactly one strategy, checked in this order:

    1. GGA search -- ``enable_gga`` and ``R_start <= rnd <= R_end``: each client
       scores ``K`` seeded perturbations of the server weights; the best
       admissible one (see ``_select_candidate``) is added to the server
       weights scaled by ``beta``. No local training happens in these rounds.
    2. Agreement-weighted dampening -- ``enable_dampening`` and
       ``D_start <= rnd < P_start``: clients return averaged gradients that
       ``AgreementWeightedFedAvg`` combines into a single server step.
    3. Pruning -- ``enable_pruning`` and ``rnd >= P_start``: like dampening,
       but combined by ``PruningFedAvg``.
    4. Otherwise plain FedAvg over locally trained client weights.

    A missing ``D_start`` disables dampening and a missing ``P_start`` disables
    pruning; a missing ``P_start`` leaves the dampening window open-ended.
    Phases 2/3 fall back to plain FedAvg if no client returns gradients.
    """

    class _ClientGradients:
        """Record handed to the aggregators: ``.gradients`` (name -> CPU tensor) and ``.client_id``."""

        def __init__(self, gradients, client_id):
            self.gradients = gradients
            self.client_id = client_id

    def __init__(self, server_model, clients, device, config, test_loader):
        """
        Args:
            server_model: initial global model; a private deep copy is moved to ``device``.
            clients: client objects used via ``local_update``, ``compute_avg_gradient``,
                ``score_seeds``, ``eval_on_test`` and ``.name``.
            device: torch device for the server model.
            config: schedule / hyper-parameter dict (shallow-copied).
            test_loader: held-out loader, stored for callers that evaluate separately.
        """
        self.model = deepcopy(server_model).to(device)
        self.clients = clients
        self.device = device
        self.config = config.copy()
        self.log = []
        self.W_history = []
        self.test_loader = test_loader

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _in_window(rnd, start, stop=None):
        """True when ``start <= rnd`` and (``stop`` is None or ``rnd < stop``).

        A ``start`` of None means the phase is not configured, so the result is False.
        """
        if start is None:
            return False
        return rnd >= start and (stop is None or rnd < stop)

    @staticmethod
    def _server_step_lr(cfg):
        """Server step size for dampening/pruning: ``server_lr`` (default 1e-3) x ``server_lr_scale`` (default 10)."""
        return cfg.get('server_lr', 1e-3) * cfg.get('server_lr_scale', 10)

    @staticmethod
    def _trainable_vector(model):
        """Parameters of ``model`` that require grad, flattened into one CPU numpy vector."""
        return parameters_to_vector([p for p in model.parameters() if p.requires_grad]).detach().cpu().numpy()

    @staticmethod
    def _write_flat_to_params(params, flat):
        """Copy consecutive slices of the 1-D tensor ``flat`` into ``params`` in order."""
        offset = 0
        for p in params:
            n = p.numel()
            p.data.copy_(flat[offset:offset + n].view_as(p.data))
            offset += n

    @staticmethod
    def _select_candidate(avg_scores, avg_losses, sim_floor, ref_loss, loss_relax):
        """Index of the perturbation to apply, or None.

        A candidate is admissible when its mean agreement score exceeds
        ``sim_floor`` and its mean loss exceeds ``ref_loss`` by less than
        ``loss_relax``; the highest-scoring admissible candidate wins (first on
        ties). If none is admissible, the overall highest-scoring candidate is
        returned provided it satisfies the loss condition alone.
        """
        avg_scores = np.asarray(avg_scores)
        accepted = [k for k in range(len(avg_scores))
                    if avg_scores[k] > sim_floor and (avg_losses[k] - ref_loss) < loss_relax]
        if accepted:
            return accepted[int(np.argmax(avg_scores[accepted]))]
        best = int(np.argmax(avg_scores))
        return best if (avg_losses[best] - ref_loss) < loss_relax else None

    def _trainable_params(self):
        """``(name, parameter)`` pairs of the server model that require grad, in model order."""
        return [(name, p) for name, p in self.model.named_parameters() if p.requires_grad]

    def _apply_server_step(self, grads, lr):
        """In-place ``p -= lr * g`` for every server parameter whose name is in ``grads``."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name not in grads:
                    continue
                g = grads[name]
                if isinstance(g, np.ndarray):
                    g = torch.from_numpy(g)
                param.data.add_(-lr * g.to(param.device))

    def _collect_client_gradients(self, cfg):
        """Ask every client for its averaged gradient at the current server weights.

        Returns a list of ``_ClientGradients`` whose ``.gradients`` maps every
        trainable server parameter name to a CPU tensor shaped like that
        parameter; entries a client did not return are zero-filled so all
        clients align. Clients returning no gradients at all are skipped.
        """
        trainable = self._trainable_params()
        records = []
        for client in self.clients:
            grads = client.compute_avg_gradient(
                global_model=self.model,
                local_epochs=cfg.get('local_epochs', 1),
                max_batches=cfg.get('max_client_steps', None),
                device=self.device,
            )
            if not grads:
                continue  # e.g. empty loader
            fixed = {}
            for name, param in trainable:
                g = grads.get(name)
                if g is None:
                    # Zeros with the parameter's own shape so the server step can add it.
                    fixed[name] = torch.zeros(param.shape, dtype=param.dtype)
                    continue
                if isinstance(g, np.ndarray):
                    g = torch.from_numpy(g)
                fixed[name] = g.detach().cpu().clone()
            records.append(self._ClientGradients(fixed, client.name))
        return records

    def _flatten_gradients(self, records):
        """One float32 numpy vector per record, concatenated in server parameter order."""
        names = [name for name, _ in self._trainable_params()]
        return [np.concatenate([r.gradients[n].reshape(-1).numpy() for n in names]).astype(np.float32)
                for r in records]

    # ------------------------------------------------------------------ phases

    def _fedavg_round(self, cfg, measure_similarity):
        """Train every client locally from the server weights, then load their FedAvg.

        With ``measure_similarity`` the pairwise cosine stats of the clients'
        trainable-parameter deltas are returned; otherwise ``(None, None)``.
        """
        server_vec = self._trainable_vector(self.model) if measure_similarity else None
        scratch = deepcopy(self.model) if measure_similarity else None
        client_state_dicts, client_deltas = [], []
        for client in self.clients:
            state = client.local_update(self.model, cfg['local_epochs'], cfg['local_lr'],
                                        cfg['max_client_steps'], cfg['use_amp'])
            client_state_dicts.append(state)
            if measure_similarity:
                scratch.load_state_dict(state)
                client_deltas.append(self._trainable_vector(scratch) - server_vec)
        sims = pairwise_cosine_stats(client_deltas) if measure_similarity else (None, None)
        self.model.load_state_dict(fedavg_from_state_dicts(client_state_dicts))
        return sims

    def _apply_seed_delta(self, seed, cfg, head_only):
        """Add ``beta`` times the seeded perturbation to the searched server parameters."""
        if head_only:
            params, _ = get_head_param_list_and_names(self.model)
        else:
            params = [p for _, p in self._trainable_params()]
        flat_theta = parameters_to_vector(params).detach().to(self.device)
        delta = make_uniform_delta_from_seed(seed, flat_theta.cpu(), cfg['rho'], device=self.device,
                                             scale_by_norm=cfg.get('scale_by_norm', False))
        beta = cfg.get('beta', 1.0)
        self._write_flat_to_params(params, (flat_theta + beta * delta.to(self.device)).clone())

    def _gga_round(self, cfg, rnd, loss_fn):
        """One GGA perturbation search.

        Returns ``(applied, min_sim, mean_sim, grad_evals)``: the similarity
        stats are over the clients' reference (unperturbed) gradients and
        ``grad_evals`` approximates 1 reference + K candidate evaluations per client.
        """
        seeds = sample_k_seeds(cfg['K'], base_seed=cfg.get('base_seed', 1234) + rnd)
        head_only = cfg.get('search_last_layer_only', True)

        client_scores, ref_grads, ref_losses, cand_losses = [], [], [], []
        for client in self.clients:
            scores, ref_grad, ref_loss, losses_k = client.score_seeds(
                self.model, loss_fn, seeds, cfg['rho'],
                scale_by_norm=cfg.get('scale_by_norm', False),
                search_head_only=head_only)
            client_scores.append(scores)
            ref_grads.append(ref_grad)
            ref_losses.append(ref_loss)
            cand_losses.append(losses_k)
        grad_evals = len(self.clients) * (1 + len(seeds))

        # Acceptance floor: worst pairwise agreement between the unperturbed client gradients.
        sim_ref_min, sim_ref_mean = pairwise_cosine_stats(ref_grads)
        ref_loss_mean = float(np.mean(ref_losses))
        avg_scores = np.mean(np.stack(client_scores, axis=0), axis=0)  # (n_clients, K) -> (K,)
        avg_losses = np.mean(np.stack(cand_losses, axis=0), axis=0)    # (n_clients, K) -> (K,)

        best_k = self._select_candidate(avg_scores, avg_losses, sim_ref_min, ref_loss_mean,
                                        cfg.get('loss_relax', 0.1))
        if best_k is None:
            return False, sim_ref_min, sim_ref_mean, grad_evals
        self._apply_seed_delta(seeds[best_k], cfg, head_only)
        return True, sim_ref_min, sim_ref_mean, grad_evals

    def _dampening_round(self, cfg):
        """One agreement-weighted server step. Returns ``(min_sim, mean_sim, damp_W_mean)``."""
        # NOTE(review): a new aggregator is built every round, so it carries no state across rounds.
        aggregation = AgreementWeightedFedAvg(verbose=cfg.get('agg_verbose', False))
        records = self._collect_client_gradients(cfg)
        if not records:
            self._fedavg_round(cfg, measure_similarity=False)
            return None, None, None
        min_sim, mean_sim = pairwise_cosine_stats(self._flatten_gradients(records))
        weighted_grad = aggregation.aggregate(records)  # param_name -> CPU tensor
        self._apply_server_step(weighted_grad, self._server_step_lr(cfg))
        try:
            weights = getattr(aggregation, 'last_agreement_weights', None)
            damp_W_mean = (float(np.mean([float(v.mean()) for v in weights.values()]))
                           if weights else None)
        except Exception:  # diagnostics are best-effort and must not abort training
            damp_W_mean = None
        return min_sim, mean_sim, damp_W_mean

    def _pruning_round(self, cfg):
        """One pruned-gradient server step. Returns ``(min_sim, mean_sim, pct_pruned, damp_W_mean)``."""
        # NOTE(review): a fresh PruningFedAvg every round resets any cross-round state that
        # ``patience`` may rely on -- confirm this is intended.
        aggregation = PruningFedAvg(threshold=cfg.get('P_tolerance'), patience=cfg.get('P_patience'))
        records = self._collect_client_gradients(cfg)
        if not records:
            self._fedavg_round(cfg, measure_similarity=False)
            return None, None, 0.0, None
        min_sim, mean_sim = pairwise_cosine_stats(self._flatten_gradients(records))
        pruned_grads, pruning_rate = aggregation.aggregate(records)
        self._apply_server_step(pruned_grads, self._server_step_lr(cfg))
        try:
            stats = aggregation.get_agreement_statistics()
            damp_W_mean = np.mean([s['mean'] for s in stats.values()]) if len(stats) > 0 else None
        except Exception:  # diagnostics are best-effort and must not abort training
            damp_W_mean = None
        return min_sim, mean_sim, pruning_rate, damp_W_mean

    # ------------------------------------------------------------------ driver

    def run(self):
        """Run ``config['rounds']`` rounds and return the per-round log (list of dicts).

        Each entry holds the round index, mean client test accuracy, min/mean
        pairwise cosine similarity (None when not measured), whether a GGA
        perturbation was applied, the pruning rate, the mean agreement weight,
        the round's wall time and the cumulative extra-gradient-evaluation estimate.
        """
        cfg = self.config
        loss_fn = nn.CrossEntropyLoss()
        extra_grad_evals = 0
        for rnd in range(cfg['rounds']):
            t0 = time.time()
            # Reset every diagnostic each round so no branch logs stale or undefined values.
            applied = False
            min_sim = mean_sim = None
            pct_pruned, damp_W_mean = 0.0, None

            if cfg.get('enable_gga', False) and cfg['R_start'] <= rnd <= cfg['R_end']:
                applied, min_sim, mean_sim, evals = self._gga_round(cfg, rnd, loss_fn)
                extra_grad_evals += evals
            elif cfg.get('enable_dampening', False) and self._in_window(rnd, cfg.get('D_start'), cfg.get('P_start')):
                min_sim, mean_sim, damp_W_mean = self._dampening_round(cfg)
            elif cfg.get('enable_pruning', False) and self._in_window(rnd, cfg.get('P_start')):
                min_sim, mean_sim, pct_pruned, damp_W_mean = self._pruning_round(cfg)
            else:
                min_sim, mean_sim = self._fedavg_round(cfg, measure_similarity=True)

            # Mean of each client's eval_on_test accuracy for the current global model.
            avg_acc = float(np.mean([c.eval_on_test(self.model) for c in self.clients]))

            self.log.append({
                'round': rnd,
                'avg_client_acc': avg_acc,
                'min_pairwise_sim': min_sim,
                'mean_pairwise_sim': mean_sim,
                'applied_delta': bool(applied),
                'pct_pruned': pct_pruned,
                'damp_W_mean': damp_W_mean,
                'time': time.time() - t0,
                'extra_grad_evals_est': extra_grad_evals,
            })

            if rnd % 5 == 0:
                print(f"[R{rnd}] avg_acc={avg_acc:.4f} min_sim={min_sim} mean_sim={mean_sim} applied_delta={applied}")

        return self.log

In [None]:
class SmallCNN(nn.Module):
    """Three conv-BN-ReLU-maxpool stages, global average pooling and a linear head.

    Padded 3x3 convolutions keep the spatial size and each 2x2 max-pool halves
    it, so an ``H x W`` input reaches global pooling as ``128 x H//8 x W//8``.
    Because of the global pooling, any input of at least 8x8 works.

    Args:
        num_classes: number of output logits per sample.
        in_channels: channels of the input images (3 for RGB).
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2)                  # halves H and W
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map ``(N, in_channels, H, W)`` images to ``(N, num_classes)`` logits."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # (N, 32,  H/2, W/2)
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # (N, 64,  H/4, W/4)
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # (N, 128, H/8, W/8)
        x = torch.flatten(self.avgpool(x), 1)           # (N, 128)
        return self.fc(x)

In [None]:
import time
import pandas as pd
from copy import deepcopy

def eval_global_on_test(model, test_loader, device):
    """Return top-1 accuracy of ``model`` over ``test_loader`` (0.0 if it yields no samples).

    Runs on a deep copy placed on ``device`` in eval mode, so the caller's model
    keeps its original device and train/eval mode.
    NOTE(review): duplicates the identical helper defined in the first code cell.
    """
    scratch = deepcopy(model).to(device)
    scratch.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            logits = scratch(batch_x.to(device))
            batch_y = batch_y.to(device)
            hits += (logits.argmax(dim=1) == batch_y).sum().item()
            seen += batch_y.size(0)
    return hits / seen if seen > 0 else 0.0

def run_fed_cifar_experiments(data_root,
                              num_clients=10,
                              seeds=None,
                              alpha=0.1,
                              optimizer='adam',
                              save_dir='/kaggle/working/fed_gga_cifar',
                              client_batch_size=32,
                              test_batch_size=256,
                              use_pretrained=True,
                              freeze_backbone=False,
                              cfg_overrides=None):
    """Train FedGGA on partitioned CIFAR-10 once per seed and summarise the results.

    Args:
        data_root: dataset directory passed to ``build_cifar_clients``.
        num_clients: number of federated clients.
        seeds: iterable of integer seeds, one full run each; ``None`` means ``[0]``
            (the notebook's global seed).
        alpha: partition parameter forwarded to ``build_cifar_clients``
            (presumably a Dirichlet concentration -- confirm there).
        optimizer: only recorded in the results table; not used for training.
        save_dir: directory for per-round runlogs and summary CSVs (created if missing).
        client_batch_size, test_batch_size: loader batch sizes.
        use_pretrained, freeze_backbone: accepted for interface compatibility;
            unused because ``SmallCNN`` is always trained from scratch.
        cfg_overrides: optional dict whose entries replace the default server config.

    Returns:
        (per-run DataFrame, per-config mean/std DataFrame,
         {config label: "mean ± std"}, server model of the last seed)

    Raises:
        ValueError: if ``seeds`` is empty.
    """
    seeds = [0] if seeds is None else list(seeds)
    if not seeds:
        raise ValueError("seeds must contain at least one seed")

    os.makedirs(save_dir, exist_ok=True)

    # Default schedule / hyper-parameters for FedGGAServer.
    cfg = {
        'rounds': 50,
        'R_start': 2,
        'R_end': 15,
        'D_start': 20,
        'P_start': 50,
        'P_tolerance': 0.2,
        'P_patience': 1,
        'K': 8,
        'rho': 1e-5,
        'beta': 0.3,
        'local_epochs': 2,
        'max_client_steps': 100,
        'local_lr': 1e-3,
        'server_lr': 1e-3,
        'enable_gga': True,
        'enable_dampening': True,
        'enable_pruning': False,
        'use_amp': True,
        'scale_by_norm': True,
        'search_last_layer_only': False,
        'loss_relax': 0.05,
    }
    cfg.update(cfg_overrides or {})

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    runs = []
    final_model = None

    for run_seed in seeds:
        set_global_seed(run_seed)
        clients_list, global_test_loader, indices_per_client = build_cifar_clients(
            data_root,
            num_clients=num_clients,
            alpha=alpha,
            batch_size=client_batch_size,
            test_batch_size=test_batch_size,
            num_workers=2,
            seed=run_seed
        )
        fed_clients = [FedClient(name, train_loader, global_test_loader, device)
                       for name, train_loader in clients_list]

        # Fresh model per seed, built after seeding so its initialisation is reproducible.
        model = SmallCNN().to(device)
        server = FedGGAServer(model, fed_clients, device, cfg.copy(), global_test_loader)

        t0 = time.time()
        run_log = server.run()
        run_time = time.time() - t0

        global_acc = eval_global_on_test(server.model, global_test_loader, device)
        print('evaluation done')

        client_sizes = [len(idxs) for idxs in indices_per_client]

        run_id = f"cifar_alpha{alpha}_clients{num_clients}_seed{run_seed}"
        pd.DataFrame(run_log).to_csv(os.path.join(save_dir, f"runlog_{run_id}.csv"), index=False)

        runs.append({
            'dataset': 'CIFAR10',
            'alpha': alpha,
            'num_clients': num_clients,
            'seed': run_seed,
            'optimizer': optimizer,
            'global_test_acc': float(global_acc),
            'time_s': float(run_time),
            'cfg_K': cfg['K'],
            'cfg_rho': cfg['rho'],
            'cfg_R_start': cfg['R_start'],
            'cfg_R_end': cfg['R_end'],
            'mean_client_size': float(np.mean(client_sizes)),
        })

        print(f"[seed {run_seed}] global_test_acc={global_acc:.4f} time={run_time:.1f}s mean_client_size={np.mean(client_sizes):.1f}")
        final_model = server.model

    df = pd.DataFrame(runs)
    df.to_csv(os.path.join(save_dir, "per_run_results_cifar.csv"), index=False)

    # std is NaN when a configuration has a single seed.
    summary = df.groupby(['alpha', 'num_clients'])['global_test_acc'].agg(['mean', 'std']).reset_index()

    table_rows = {}
    for _, row in summary.iterrows():
        k = f"alpha={row['alpha']}_clients={int(row['num_clients'])}"
        table_rows[k] = f"{float(row['mean']):.4f} ± {float(row['std']):.4f}"

    summary.to_csv(os.path.join(save_dir, "per_config_summary_cifar.csv"), index=False)
    print("\nSaved CIFAR run results to:", save_dir)

    return df, summary, table_rows, final_model

In [None]:
# Single-seed CIFAR-10 run with 3 clients and alpha=0.1; outputs are written to save_dir.
df, summary, table, trained_model = run_fed_cifar_experiments(
    data_root='/kaggle/working/cifar',
    num_clients=3,
    seeds=[0],
    alpha=0.1,
    save_dir='/kaggle/working/fed_gga_cifar',
)

In [None]:
# Plotting helpers: load a saved runlog CSV, plot accuracy and mean_pairwise_sim per round, and return summary metrics.
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def load_runlog_csv(path):
    """Read a runlog CSV; if it has a ``round`` column, cast it to int and sort by it.

    Raises:
        FileNotFoundError: when ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Runlog not found: {path}")
    frame = pd.read_csv(path)
    if 'round' not in frame.columns:
        return frame
    frame['round'] = frame['round'].astype(int)
    return frame.sort_values('round').reset_index(drop=True)

def select_accuracy_series(df, acc_priority=('global_accu','avg_client_acc')):
    """Pick the first column of ``acc_priority`` present in ``df`` as the accuracy series.

    Values are coerced to numeric and gaps are filled forward, then backward
    (so leading NaNs take the first valid value). The series is indexed by
    ``df['round']`` when that column exists, otherwise by position.
    NOTE(review): FedGGAServer logs 'avg_client_acc'; 'global_accu' may be a typo -- confirm.

    Returns:
        (acc_series, acc_col_name)

    Raises:
        ValueError: if none of the candidate columns exists.
    """
    acc_col = next((c for c in acc_priority if c in df.columns), None)
    if acc_col is None:
        raise ValueError(f"No accuracy column found in runlog. Expected one of {acc_priority}")
    # ffill()/bfill() replace fillna(method=...), which is deprecated since pandas 2.1.
    s = pd.to_numeric(df[acc_col], errors='coerce').ffill().bfill()
    s.index = df['round'].values if 'round' in df.columns else np.arange(len(s))
    return s, acc_col

def select_similarity_series(df, sim_col='mean_pairwise_sim'):
    """Return ``df[sim_col]`` as a numeric series indexed by round (or by position).

    Gaps are filled forward then backward. If ``sim_col`` is absent, an
    all-NaN float series of the same length and index is returned.
    """
    idx = df['round'].values if 'round' in df.columns else np.arange(len(df))
    if sim_col not in df.columns:
        # dtype=float keeps np.isnan usable even when df is empty.
        return pd.Series([np.nan] * len(df), index=idx, dtype=float)
    # ffill()/bfill() replace fillna(method=...), which is deprecated since pandas 2.1.
    s = pd.to_numeric(df[sim_col], errors='coerce').ffill().bfill()
    s.index = idx
    return s

def compute_metrics_from_series(acc_s, sim_s):
    """Summarise a run: final accuracy, mean similarity and normalised accuracy AUC.

    ``auc`` is the trapezoidal area under accuracy-vs-round divided by the
    round span (1.0 if the span is 0); with a single point it equals the
    final accuracy. ``mean_pairwise_sim`` ignores NaNs and is NaN if no value exists.

    Raises:
        ValueError: if ``acc_s`` is empty.
    """
    if len(acc_s) == 0:
        raise ValueError("Empty accuracy series.")
    final_acc = float(acc_s.iloc[-1])
    sim_valid = sim_s.dropna()
    mean_sim = float(sim_valid.mean()) if sim_valid.size > 0 else float('nan')

    rounds = np.asarray(acc_s.index, dtype=float)
    if len(rounds) < 2:
        return {'final_acc': final_acc, 'mean_pairwise_sim': mean_sim, 'auc': final_acc}

    # np.trapz was renamed np.trapezoid in NumPy 2.0 (old name deprecated).
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    auc_raw = trapezoid(y=acc_s.values, x=rounds)
    span = rounds[-1] - rounds[0]
    auc = float(auc_raw / (span if span > 0 else 1.0))
    return {'final_acc': final_acc, 'mean_pairwise_sim': mean_sim, 'auc': auc}

def plot_from_runlog(runlog_path, out_fig=None, title=None, show=True):
    """
    Plot per-round accuracy (left axis) and mean_pairwise_sim (right axis) from a runlog CSV.

    Args:
        runlog_path: CSV written by the experiment runner.
        out_fig: optional image path; its parent directory is created if needed.
        title: figure title, defaulting to the CSV file name.
        show: display the figure (True) or close it after optional saving (False).

    Returns: metrics dict and (acc_series, sim_series, df) for further use.
    """
    df = load_runlog_csv(runlog_path)
    acc_s, acc_col = select_accuracy_series(df)
    sim_s = select_similarity_series(df, sim_col='mean_pairwise_sim')

    metrics = compute_metrics_from_series(acc_s, sim_s)

    fig, ax1 = plt.subplots(figsize=(9, 5))
    rounds = acc_s.index.values

    # accuracy (left axis)
    ax1.plot(rounds, acc_s.values, marker='o', linewidth=2, label=f'Accuracy ({acc_col})')
    ax1.set_xlabel('round')
    ax1.set_ylabel('accuracy')
    ax1.grid(True)

    # similarity (right axis); annotate instead of plotting when there is no data
    ax2 = ax1.twinx()
    if np.all(np.isnan(sim_s.values)):
        ax2.text(0.5, 0.5, 'no similarity data', transform=ax2.transAxes, ha='center', va='center', alpha=0.7)
        ax2.set_ylabel('mean_pairwise_sim (N/A)')
    else:
        ax2.plot(rounds, sim_s.values, marker='x', linewidth=2, color='tab:orange', label='mean_pairwise_sim')
        ax2.set_ylabel('mean_pairwise_sim')

    # One legend for both axes, keeping the label that names the accuracy column actually plotted.
    lines_1, labels_1 = ax1.get_legend_handles_labels()
    lines_2, labels_2 = ax2.get_legend_handles_labels()
    ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='best')

    ax1.set_title(title if title is not None else os.path.basename(runlog_path))
    fig.tight_layout()

    if out_fig is not None:
        # dirname is '' for a bare file name; makedirs('') would raise.
        os.makedirs(os.path.dirname(out_fig) or '.', exist_ok=True)
        fig.savefig(out_fig, dpi=200)
        print("Saved figure:", out_fig)

    if show:
        plt.show()
    else:
        plt.close(fig)

    print("Metrics:")
    print(f"  final_acc = {metrics['final_acc']:.4f}")
    if not np.isnan(metrics['mean_pairwise_sim']):
        print(f"  mean_pairwise_sim = {metrics['mean_pairwise_sim']:.6f}")
    else:
        print("  mean_pairwise_sim = N/A")
    print(f"  auc (norm.) = {metrics['auc']:.6f}")

    return metrics, (acc_s, sim_s, df)


# Plot a saved runlog. The external Kaggle dataset is used when it is mounted;
# otherwise fall back to the runlog written by run_fed_cifar_experiments in this notebook,
# so a fresh Restart & Run All does not fail on a missing input file.
runlog_path = "/kaggle/input/comb1-exp/runlog_cifar_alpha0.1_clients3_seed0_comb1.csv"
if not os.path.exists(runlog_path):
    runlog_path = "/kaggle/working/fed_gga_cifar/runlog_cifar_alpha0.1_clients3_seed0.csv"
metrics, (acc_s, sim_s, df) = plot_from_runlog(runlog_path,
                                              out_fig="/kaggle/working/fed_gga_cifar/plots/acc_vs_sim_seed0.png",
                                              show=True)

In [None]:
# Persist the final global model's weights next to the experiment outputs.
model_path = os.path.join("/kaggle/working/fed_gga_cifar", "final_server_model.pth")
torch.save(trained_model.state_dict(), model_path)
print("Saved final server model to:", model_path)