In [2]:
# %% [markdown]
# Stage 3 - Federated Learning for Cross-Country Iris Recognition (SwinV2 Tiny; runs on CUDA when available, otherwise CPU)

# %%
import os, json, random, time, pathlib
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import copy


import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from sklearn.metrics import roc_auc_score, roc_curve

# timm provides the SwinV2 backbone. The import is best-effort so the remaining
# cells can still be defined; SwinV2_Tiny_Siamese raises when timm is None.
try:
    import timm
except ImportError:
    timm = None
    print("⚠️ timm is not installed. Install with: pip install timm")
except Exception as e:
    # timm is present but failed while importing (e.g. a version conflict):
    # report the real cause instead of claiming the package is missing.
    timm = None
    print(f"⚠️ timm could not be imported: {e!r}")

# ---------------------- Global Config ----------------------
SEED = 42
# Seed the Python, NumPy and PyTorch random number generators.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", DEVICE)

# Image geometry and metric-learning hyper-parameters
INPUT_SIZE = (64, 512)   # (height, width) every image is resized to
BATCH_SIZE = 32          # default batch size when computing embeddings
EMBED_DIM = 256          # size of the projected embedding
MARGIN = 0.3             # triplet-loss margin
TEMPERATURE = 0.07       # supervised-contrastive temperature
CLIP_NORM = 5.0          # gradient-norm clipping threshold

# PK batches: P identities x K images per identity
P_IDENTITIES = 16
K_IMAGES = 2

# Filesystem layout
RESULTS_ROOT = r"C:\Users\awais\OneDrive\Desktop\Thesis\Eye_Results"
SPLITS_DIR = r"C:\Users\awais\OneDrive\Desktop\Thesis\splits\Eyes\FIHR"
MODELS_ROOT = r"C:\Users\awais\OneDrive\Desktop\Thesis\Eye_Results"
FL_MODELS_ROOT = os.path.join(MODELS_ROOT, "FL_Stage3")
os.makedirs(FL_MODELS_ROOT, exist_ok=True)

# One federated client per country: name -> root folder of its normalized images
DATASETS = {
    "Pakistan": r"C:\Users\awais\OneDrive\Desktop\Thesis\Pakistan\PakIris_Normalized_Aug",
    "China": r"C:\Users\awais\OneDrive\Desktop\Thesis\China\CASIA-Iris-Interval_Normalized_Aug",
    "Czech": r"C:\Users\awais\OneDrive\Desktop\Thesis\Czech\CzechIris_Normalized_Enhanced_FIHR",
    "India": r"C:\Users\awais\OneDrive\Desktop\Thesis\India\IITD_Normalized",
    "Iraq": r"C:\Users\awais\OneDrive\Desktop\Thesis\Iraq\AMF_Normalized_Enhanced_FIHR",
    "Malaysia": r"C:\Users\awais\OneDrive\Desktop\Thesis\Malaysia\MMU_Normalized_Enhanced_FIHR",
    "Iran": r"C:\Users\awais\OneDrive\Desktop\Thesis\Iran\Iris_Normalized_Enhanced_FIHR",
}

def log(msg, level="INFO", dataset=None, model=None):
    """Print ``msg`` as ``[level] [dataset | model] msg``.

    The bracketed context tag is only emitted for the parts that are truthy:
    both, only the dataset, only the model, or nothing at all.
    """
    context = [f"{part}" for part in (dataset, model) if part]
    header = f"[{level}]"
    if context:
        header += f" [{' | '.join(context)}]"
    print(f"{header} {msg}")


Using device: cuda


In [3]:
# %% [markdown]
# Cell 2 - JSON train/val/test splits & mean/std

# %%
def load_json_split(dataset_name: str, dataset_root: str):
    """
    Read ``<SPLITS_DIR>/<dataset_name lowercased>_split.json`` and return its splits.

    Each entry of the ``train`` / ``val`` / ``test`` lists may be a ``[path, label]``
    pair, a dict (``path`` or ``image`` key, optional ``label`` key) or a bare path.
    Relative paths are joined onto ``dataset_root``; a missing label falls back to
    the name of the image's parent directory.

    Returns:
        dict with keys ``train``, ``val``, ``test`` (lists of ``(abs_path, label)``)
        and ``classes`` (sorted unique labels over all three splits).

    Raises:
        FileNotFoundError: if the split file does not exist.
    """
    split_file = os.path.join(SPLITS_DIR, f"{dataset_name.lower()}_split.json")
    if not os.path.exists(split_file):
        raise FileNotFoundError(split_file)
    with open(split_file, 'r', encoding='utf-8') as fh:
        raw = json.load(fh)

    def _parse_entry(entry):
        # Normalise one JSON entry into (absolute path, label).
        if isinstance(entry, (list, tuple)):
            rel = entry[0]
            label = entry[1] if len(entry) > 1 else None
        elif isinstance(entry, dict):
            rel = entry.get('path') or entry.get('image') or next(iter(entry.values()))
            label = entry.get('label')
        else:
            rel, label = entry, None

        full = rel if os.path.isabs(rel) else os.path.join(dataset_root, rel)
        if label is None:
            label = pathlib.Path(full).parent.name
        return (full, label)

    parsed = {
        part: [_parse_entry(entry) for entry in raw.get(part, [])]
        for part in ('train', 'val', 'test')
    }
    train, val, test = parsed['train'], parsed['val'], parsed['test']

    classes = sorted({label for _, label in train + val + test})
    log(f"Loaded {dataset_name}: train={len(train)}, val={len(val)}, test={len(test)}, classes={len(classes)}",
        "SPLIT", dataset_name)
    return {'train': train, 'val': val, 'test': test, 'classes': classes}

def estimate_mean_std(items: List[Tuple[str,str]], max_samples: int = 512):
    """
    Estimate the grayscale pixel mean and standard deviation of a set of images.

    At most ``max_samples`` images are used (a ``random.sample`` of ``items`` when
    there are more). Pixels are scaled to [0, 1]; per-image mean/variance are
    merged into a running pooled mean/variance weighted by pixel count.

    Returns:
        ``(mean, std)`` as Python floats; the variance is floored at 1e-8 before
        the square root.
    """
    if len(items) <= max_samples:
        chosen = items
    else:
        chosen = random.sample(items, max_samples)

    run_mean = 0.0
    run_var = 0.0
    run_count = 0
    for path, _ in chosen:
        with Image.open(path) as im:
            pixels = np.asarray(im.convert('L'), dtype=np.float32) / 255.0
        img_mean, img_var, img_count = pixels.mean(), pixels.var(), pixels.size

        # Pooled update of mean and (population) variance for two groups.
        merged_count = run_count + img_count
        shift = img_mean - run_mean
        run_mean = run_mean + shift * (img_count / merged_count)
        run_var = (run_count*run_var + img_count*img_var + (shift**2)*run_count*img_count/merged_count) / merged_count
        run_count = merged_count

    return float(run_mean), float(np.sqrt(max(run_var, 1e-8)))


In [4]:
# %% [markdown]
# Cell 3 - Siamese Train Dataset, PKSampler, IndexDataset

# %%
from torchvision import transforms

class SiameseTrainSet(Dataset):
    """
    Augmented grayscale training images for metric learning.

    ``items`` is a list of ``(path, label)`` pairs and ``label_to_idx`` maps each
    label to an integer class index. Samples are fetched through :meth:`load`.
    """
    def __init__(self, items, label_to_idx, mean, std):
        self.items = items
        self.label_to_idx = label_to_idx
        self.mean = mean
        self.std = std

        # label -> positions of that label's images inside `items`
        self.by_lbl: Dict[str, List[int]] = {}
        for pos, (_, lbl) in enumerate(items):
            self.by_lbl.setdefault(lbl, []).append(pos)

        # Augmentations applied to the PIL image ...
        image_stage = [
            transforms.Grayscale(1),
            transforms.Resize(INPUT_SIZE),
            transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.9, 1.1)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
        # ... followed by tensor conversion, normalization and random erasing.
        tensor_stage = [
            transforms.ToTensor(),
            transforms.Normalize([self.mean], [self.std]),
            transforms.RandomErasing(p=0.3, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
        ]
        self.tx = transforms.Compose(image_stage + tensor_stage)

    def __len__(self):
        return len(self.items)

    def load(self, idx: int):
        """Return ``(augmented tensor, class index)`` for the item at ``idx``."""
        path, lbl = self.items[idx]
        tensor = self.tx(Image.open(path).convert('L'))  # (1, H, W) with (H, W) = INPUT_SIZE
        return tensor, self.label_to_idx[lbl]

class PKSampler(Sampler[List[int]]):
    """
    Batch sampler producing P identities x K image indices per batch.

    Only labels with at least ``K`` images are eligible. Each eligible label is
    used at most once per pass and a trailing group with fewer than ``P`` labels
    is dropped. With ``shuffle=False`` the first ``K`` images of each label are
    taken in label insertion order.
    """
    def __init__(self, items, P=16, K=2, shuffle=True):
        self.items = items
        self.P = P
        self.K = K
        self.shuffle = shuffle

        # label -> positions of that label's images inside `items`
        self.by_lbl: Dict[str, List[int]] = {}
        for pos, (_, lbl) in enumerate(items):
            self.by_lbl.setdefault(lbl, []).append(pos)
        self.labels = [lbl for lbl, members in self.by_lbl.items() if len(members) >= K]

    def __iter__(self):
        order = list(self.labels)
        if self.shuffle:
            random.shuffle(order)

        # Start offsets of the complete groups of P labels.
        for start in range(0, len(order) - self.P + 1, self.P):
            batch = []
            for lbl in order[start:start + self.P]:
                members = self.by_lbl[lbl]
                if self.shuffle:
                    picked = random.sample(members, self.K)
                else:
                    picked = members[:self.K]
                batch.extend(picked)
            yield batch

    def __len__(self):
        return max(0, len(self.labels) // self.P)

class IndexDataset(Dataset):
    """Dataset of size ``n`` whose item ``i`` is simply ``int(i)``; paired with PKSampler."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        return int(i)


In [5]:
# %% [markdown]
# Cell 4 - Metric-learning losses + ROC/EER evaluation (with interpolation)

# %%
def pairwise_sim(z):
    """Return the cosine-similarity matrix between the rows of ``z``."""
    unit_rows = F.normalize(z, dim=1)
    return unit_rows.matmul(unit_rows.t())

def batch_hard_triplet_loss(z, y, margin=MARGIN):
    """
    Batch-hard triplet loss on ``2 - 2 * cosine_similarity`` distances.

    For every anchor the farthest same-label sample and the closest
    different-label sample are selected; the loss is the mean of
    ``relu(pos - neg + margin)`` over all anchors.
    """
    y = y.view(-1, 1)
    sim = pairwise_sim(z)
    dist = torch.clamp(2 - 2*sim, min=0.0)
    same = (y == y.t())

    # Non-matching entries get -1 so they can never win the max.
    filler = torch.tensor(-1., device=dist.device)
    hardest_pos = (dist * same).where(same, filler).max(1).values
    # Matching entries are pushed far away so they can never win the min.
    hardest_neg = (dist + 1e6*same.float()).min(1).values

    return F.relu(hardest_pos - hardest_neg + margin).mean()

def supervised_contrastive_loss(z, y, temperature=TEMPERATURE):
    """
    Supervised contrastive loss over one batch of embeddings.

    Each anchor's positives are the other samples sharing its label. The softmax
    denominator excludes the anchor itself. The per-anchor loss is averaged over
    its positives (the count is clamped to at least 1) and then over anchors.
    """
    unit = F.normalize(z, dim=1)
    logits = unit @ unit.t() / temperature
    y = y.view(-1, 1)

    same = (y == y.t()).float()
    not_self = torch.ones_like(same) - torch.eye(same.shape[0], device=same.device)
    positives = same * not_self

    # The diagonal is driven to a very negative value so it drops out of the logsumexp.
    denom = torch.logsumexp(logits * not_self - 1e9*(1-not_self), dim=1, keepdim=True)
    log_prob = logits - denom

    n_pos = positives.sum(1).clamp_min(1.0)
    per_anchor = -(positives * log_prob).sum(1) / n_pos
    return per_anchor.mean()


# ---------- Interpolated EER and TAR@FAR helpers ----------

def _eer_from_fpr_tpr_interp(fpr, tpr):
    """
    EER via interpolation between the two points where fpr-fnr changes sign.
    Falls back to closest point if no sign change.
    """
    fnr = 1.0 - tpr
    diff = fpr - fnr

    # If never crosses zero, fall back to nearest point
    if np.all(diff >= 0) or np.all(diff <= 0):
        i = int(np.nanargmin(np.abs(diff)))
        return float(max(fpr[i], fnr[i]))

    # indices where sign changes between i and i+1
    idx = np.where(np.sign(diff[:-1]) != np.sign(diff[1:]))[0]
    i0 = int(idx[0])
    i1 = i0 + 1

    # linear interpolation on diff -> 0
    d0, d1 = diff[i0], diff[i1]
    t = d0 / (d0 - d1 + 1e-12)

    x0, x1 = fpr[i0], fpr[i1]
    eer = x0 + t * (x1 - x0)
    return float(eer)


def _tar_at_far_interp(fpr, tpr, far):
    """
    TAR at given FAR using linear interpolation on (fpr, tpr).
    """
    fpr = np.asarray(fpr)
    tpr = np.asarray(tpr)

    if fpr.size == 0:
        return float("nan")

    if far <= fpr[0]:
        return float(tpr[0])
    if far >= fpr[-1]:
        return float(tpr[-1])

    idx1 = int(np.searchsorted(fpr, far))
    idx0 = idx1 - 1

    x0, x1 = fpr[idx0], fpr[idx1]
    y0, y1 = tpr[idx0], tpr[idx1]

    if x1 == x0:
        return float(y0)

    alpha = (far - x0) / (x1 - x0)
    tar = y0 + alpha * (y1 - y0)
    return float(tar)


@torch.no_grad()
def compute_embeddings(model, items, mean, std, batch_size=BATCH_SIZE):
    """
    Embed a list of ``(path, label)`` items with ``model`` in evaluation mode.

    Evaluation pipeline (no augmentation):
    - grayscale
    - resize to INPUT_SIZE
    - normalize with the given dataset-specific mean/std
    - forward pass through ``model`` on DEVICE; outputs are collected on the CPU

    Args:
        model: module mapping an image batch to embeddings; it is switched to
            ``eval()`` mode and left that way.
        items: sequence of ``(image_path, label)`` pairs.
        mean, std: scalar normalization statistics.
        batch_size: evaluation batch size.

    Returns:
        ``(feats, labels)``: a float array with one row per item and an array of
        the corresponding labels, in the order of ``items``. For empty ``items``
        a ``(0, 0)`` float32 array and an empty label array are returned.
    """
    model.eval()

    # Nothing to embed: torch.cat([]) below would raise. Empty arrays are mapped
    # to NaN metrics by evaluate_embeddings (its len(labels) < 2 branch).
    if len(items) == 0:
        return np.zeros((0, 0), dtype=np.float32), np.array([])

    tx = transforms.Compose([
        transforms.Grayscale(1),
        transforms.Resize(INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([mean], [std]),
    ])

    class _EvalSet(Dataset):
        def __init__(self, items):
            self.items = items
        def __len__(self):
            return len(self.items)
        def __getitem__(self, i):
            p, l = self.items[i]
            # Context manager releases the file handle as soon as the pixels are read,
            # matching how estimate_mean_std opens images.
            with Image.open(p) as img:
                x = tx(img.convert("L"))
            return x, l

    pin = (DEVICE.type == "cuda")
    loader = DataLoader(
        _EvalSet(items),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=pin,
    )

    feats, labels = [], []
    for xb, yb in loader:
        xb = xb.to(DEVICE, non_blocking=pin)
        z = model(xb).cpu()
        feats.append(z)
        labels.extend(yb)

    return torch.cat(feats).numpy(), np.array(labels)


def evaluate_embeddings(feats, labels):
    """
    Verification metrics over all unordered pairs of embeddings.

    Pairs are scored by cosine similarity; a pair is genuine when both labels are
    equal. Returns a dict with:
    - ``roc_auc``: area under the ROC curve
    - ``eer``: interpolated equal error rate
    - ``tar1`` / ``tar01``: interpolated TAR at FAR = 0.01 / 0.001

    Every value is NaN when there are fewer than two samples or when the pairs
    contain only genuine or only impostor comparisons.
    """
    def _undefined():
        nan = float("nan")
        return {"roc_auc": nan, "eer": nan, "tar1": nan, "tar01": nan}

    if len(labels) < 2:
        return _undefined()

    # Cosine similarity matrix of the (re-)normalized embeddings.
    emb = F.normalize(torch.tensor(feats, dtype=torch.float32), dim=1)
    sim = (emb @ emb.t()).numpy()

    same_id = labels.reshape(-1, 1) == labels.reshape(1, -1)
    upper = np.triu_indices_from(sim, 1)   # strict upper triangle: each pair once
    y_true = same_id[upper].astype(np.uint8)
    y_score = sim[upper]

    if len(np.unique(y_true)) < 2:
        return _undefined()

    fpr, tpr, _ = roc_curve(y_true, y_score)
    return {
        "roc_auc": float(roc_auc_score(y_true, y_score)),
        "eer": _eer_from_fpr_tpr_interp(fpr, tpr),
        "tar1": _tar_at_far_interp(fpr, tpr, 0.01),
        "tar01": _tar_at_far_interp(fpr, tpr, 0.001),
    }


In [6]:
# %% [markdown]
# Cell 5 - SwinV2 Tiny Siamese (timm)

# %%
class SwinV2_Tiny_Siamese(nn.Module):
    """
    SwinV2-Tiny backbone (timm, single input channel) followed by an MLP
    projection head; ``forward`` returns L2-normalized embeddings.

    Raises:
        RuntimeError: on construction when timm could not be imported.
    """
    def __init__(self, embed_dim=EMBED_DIM):
        super().__init__()
        if timm is None:
            raise RuntimeError("timm is required for SwinV2_Tiny_Siamese. Install via pip.")

        # num_classes=0 makes timm return pooled features instead of class logits.
        self.backbone = timm.create_model(
            'swinv2_tiny_window8_256',
            pretrained=True,
            num_classes=0,
            in_chans=1,
            img_size=INPUT_SIZE,
        )
        width = self.backbone.num_features

        # Projection head: Linear -> GELU -> LayerNorm -> Linear(embed_dim)
        head_layers = [
            nn.Linear(width, width),
            nn.GELU(),
            nn.LayerNorm(width),
            nn.Linear(width, embed_dim),
        ]
        self.proj = nn.Sequential(*head_layers)

    def forward(self, x):
        embedding = self.proj(self.backbone(x))
        return F.normalize(embedding, p=2, dim=1)

def swin_model_fn():
    """Zero-argument model factory: a fresh SwinV2_Tiny_Siamese with EMBED_DIM outputs."""
    model = SwinV2_Tiny_Siamese(embed_dim=EMBED_DIM)
    return model


In [7]:
# %% [markdown]
# Cell 6 - FL configs (ONLY FHIR with polynomial-learned weights)

from dataclasses import dataclass, field
import random
import numpy as np
import torch

# ---------------- Training ----------------
@dataclass
class FLTrainConfig:
    """Round schedule, early stopping and communication settings."""
    rounds: int = 200          # maximum number of federated rounds
    local_epochs: int = 1      # local passes each client runs per round
    eval_every: int = 5        # evaluate the global model every N rounds
    patience: int = 20         # non-improving evaluations tolerated before stopping
    fp16_comms: bool = True    # clients return their weights cast to float16

# ---------------- Optimizers (client) ----------------
@dataclass
class FLOptimConfig:
    """Hyper-parameters of the AdamW optimizer each client builds per round."""
    lr: float = 1e-4
    weight_decay: float = 1e-4
    betas: tuple = (0.9, 0.999)

# ---------------- Loss ----------------
@dataclass
class FLLossConfig:
    """Client loss settings: ``triplet + supcon_weight * supervised-contrastive``."""
    supcon_weight: float = 0.5             # weight of the supervised-contrastive term
    margin: float = MARGIN                 # triplet margin
    temperature: float = TEMPERATURE       # supervised-contrastive temperature

# ---------------- FHIR + POLY settings ----------------
@dataclass
class FHIRPolyConfig:
    """Settings of the learned (polynomial ridge) aggregation-weight model."""

    # Rounds 1..warmup_rounds aggregate with the OG-FHIR weights while
    # regression samples are being collected.
    warmup_rounds: int = 50

    # Degree of the polynomial feature expansion (2 = quadratic).
    poly_degree: int = 2
    # Ridge regularization strength.
    ridge_alpha: float = 1e-2

    # Minimum number of (round, client) samples required before the first fit;
    # 7 clients * ~12 eval steps ≈ 84.
    min_fit_samples: int = 80

    # Small constant added to every predicted score for stability.
    score_floor: float = 1e-6

# ---------------- OG-FHIR weights (warmup only) ----------------
@dataclass
class FHIRWeightConfig:
    """Coefficients of the hand-crafted OG-FHIR client score used during warmup."""
    alpha_size: float = 0.35       # α · log(1 + N_k)
    beta_div: float = 0.25         # β · (q_k / N_k)
    gamma_quality: float = 0.25    # γ · Q_k
    delta_rarity: float = 0.15     # δ · (1 / sqrt(q_k))

@dataclass
class FLExperimentConfig:
    """Top-level experiment configuration bundling all sub-configs."""
    train: FLTrainConfig = field(default_factory=FLTrainConfig)
    optim: FLOptimConfig = field(default_factory=FLOptimConfig)
    loss: FLLossConfig = field(default_factory=FLLossConfig)

    # Aggregation strategy name; "fhir_poly" is the only one in use.
    strategy: str = "fhir_poly"

    # OG-FHIR coefficients (warmup) and polynomial-learner settings.
    fhir: FHIRWeightConfig = field(default_factory=FHIRWeightConfig)
    poly: FHIRPolyConfig = field(default_factory=FHIRPolyConfig)

def set_global_seed(seed: int = SEED):
    """Seed the Python, NumPy and PyTorch random number generators with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


In [8]:
# %% [markdown]
# Cell 7 - Minimal server state (ONLY FHIR)

from dataclasses import dataclass
import copy
import torch

@dataclass
class ServerState:
    """Server-side state carried across federated rounds."""
    global_weights: dict   # state dict of the current global model
    round: int = 0         # counter incremented after each aggregation

class Strategy:
    """Minimal strategy object; it only knows how to create the initial server state."""

    def init_state(self, model_fn):
        """Build a model with ``model_fn`` and wrap a deep copy of its weights."""
        initial_weights = copy.deepcopy(model_fn().state_dict())
        return ServerState(global_weights=initial_weights)

def build_strategy(name: str) -> Strategy:
    """
    Return the strategy object for ``name``.

    Raises:
        ValueError: for any name other than ``"fhir_poly"``.
    """
    if name == "fhir_poly":
        return Strategy()
    raise ValueError(f"Unsupported strategy: {name}. Use 'fhir_poly' only.")


In [9]:
# %% [markdown]
# Cell 8 - Prepare FL client data (splits, label maps, mean/std)

# %%
from collections import namedtuple
import copy

# Immutable per-client bundle: split item lists, label -> index map and
# normalization statistics.
FLClientData = namedtuple(
    "FLClientData",
    "name train_items val_items test_items label_to_idx mean std",
)

def prepare_fl_clients():
    """
    Build one FLClientData per entry of DATASETS.

    All JSON splits are loaded first; then, per dataset, the label -> index map is
    derived from the sorted class list and the normalization mean/std are estimated
    from the training split.

    Returns:
        dict mapping dataset name to its FLClientData.
    """
    splits = {}
    for ds_name, ds_root in DATASETS.items():
        splits[ds_name] = load_json_split(ds_name, ds_root)

    clients_data = {}
    for ds_name, split in splits.items():
        class_index = {label: idx for idx, label in enumerate(split['classes'])}

        # Dataset names are unique dict keys, so the statistics are computed
        # exactly once per dataset.
        mean, std = estimate_mean_std(split['train'])
        log(f"[FL] mean/std for {ds_name} = ({mean:.3f}, {std:.3f})",
            "FL-NORM", ds_name, "SwinV2_Tiny_Siamese")

        clients_data[ds_name] = FLClientData(
            name=ds_name,
            train_items=split['train'],
            val_items=split['val'],
            test_items=split['test'],
            label_to_idx=class_index,
            mean=mean,
            std=std,
        )
    return clients_data

# Build the per-country client bundles once; this dict is what the training cell
# passes to train_federated(). Reads the split JSON files and samples training images.
fl_clients_data = prepare_fl_clients()


[SPLIT] [Pakistan] Loaded Pakistan: train=2304, val=288, test=288, classes=720
[SPLIT] [China] Loaded China: train=2106, val=256, test=277, classes=395
[SPLIT] [Czech] Loaded Czech: train=312, val=36, test=36, classes=128
[SPLIT] [India] Loaded India: train=1756, val=209, test=215, classes=425
[SPLIT] [Iraq] Loaded Iraq: train=440, val=50, test=50, classes=108
[SPLIT] [Malaysia] Loaded Malaysia: train=365, val=45, test=40, classes=90
[SPLIT] [Iran] Loaded Iran: train=633, val=80, test=79, classes=158
[FL-NORM] [Pakistan | SwinV2_Tiny_Siamese] [FL] mean/std for Pakistan = (0.361, 0.125)
[FL-NORM] [China | SwinV2_Tiny_Siamese] [FL] mean/std for China = (0.505, 0.267)
[FL-NORM] [Czech | SwinV2_Tiny_Siamese] [FL] mean/std for Czech = (0.366, 0.175)
[FL-NORM] [India | SwinV2_Tiny_Siamese] [FL] mean/std for India = (0.264, 0.195)
[FL-NORM] [Iraq | SwinV2_Tiny_Siamese] [FL] mean/std for Iraq = (0.367, 0.168)
[FL-NORM] [Malaysia | SwinV2_Tiny_Siamese] [FL] mean/std for Malaysia = (0.324, 0.185

In [9]:
# %% [markdown]
# Cell 9 - FLClient class (local training loop)

# %%
class FLClient:
    """
    One federated participant: holds a client's training set and runs local training.

    Args:
        client_data: the client's FLClientData (items, label map, mean/std).
        cfg: experiment configuration; optimizer, loss and schedule settings are read from it.
        model_fn: zero-argument callable returning a fresh model.
        prox_mu: coefficient of the proximal regularizer; ``0.0`` disables it.
    """
    def __init__(self, client_data: FLClientData,
                 cfg: FLExperimentConfig,
                 model_fn,
                 prox_mu: float = 0.0):
        self.data = client_data
        self.cfg = cfg
        self.model_fn = model_fn
        self.prox_mu = prox_mu

        self.train_set = SiameseTrainSet(
            items=self.data.train_items,
            label_to_idx=self.data.label_to_idx,
            mean=self.data.mean,
            std=self.data.std
        )

    def _prox_term(self, model, global_weights):
        """FedProx / FHIR-Prox term: ``0.5 * mu * sum ||p - p_global||^2`` over trainable params."""
        if self.prox_mu <= 0.0:
            return torch.tensor(0., device=DEVICE)

        prox = torch.tensor(0., device=DEVICE)
        for (n, p) in model.named_parameters():
            if p.requires_grad:
                prox = prox + torch.sum((p - global_weights[n].to(p.device)) ** 2)
        return 0.5 * self.prox_mu * prox

    def train_one_round(self, global_weights: dict):
        """
        Train a fresh model initialised from ``global_weights`` for
        ``cfg.train.local_epochs`` passes of PK-sampled batches.

        Returns:
            ``(updated, avg_loss)``: the trained state dict as CPU tensors
            (floating-point entries cast to float16 when ``cfg.train.fp16_comms``
            is set) and the mean batch loss (0.0 when no batch was produced).
        """
        device = DEVICE
        model = self.model_fn().to(device)
        model.load_state_dict(copy.deepcopy(global_weights), strict=True)
        model.train()

        opt_cfg = self.cfg.optim
        opt = torch.optim.AdamW(
            model.parameters(),
            lr=opt_cfg.lr,
            weight_decay=opt_cfg.weight_decay,
            betas=opt_cfg.betas
        )

        total_loss = 0.0
        total_batches = 0

        for _ in range(self.cfg.train.local_epochs):
            # A new sampler per epoch so identities are regrouped every pass.
            pk_sampler = PKSampler(
                self.data.train_items,
                P=P_IDENTITIES,
                K=K_IMAGES,
                shuffle=True,
            )

            # IMPORTANT: keep this single-process (num_workers=0)
            # to avoid DataLoader worker crashes on Windows / notebooks.
            # The loader only yields index batches; images are loaded below.
            train_loader = DataLoader(
                IndexDataset(len(self.data.train_items)),
                batch_sampler=pk_sampler,
                num_workers=0,      # <- no multiprocessing workers
                pin_memory=False,   # GPU still works; this only affects host→device copies
            )

            for batch_idx in train_loader:
                xb, yb = [], []
                for idx in batch_idx.tolist():
                    x, y = self.train_set.load(int(idx))
                    xb.append(x)
                    yb.append(y)

                x = torch.stack(xb).to(device)
                y = torch.tensor(yb, dtype=torch.long).to(device)

                opt.zero_grad(set_to_none=True)
                z = model(x)

                # Metric-learning losses
                l_trip = batch_hard_triplet_loss(
                    z, y,
                    margin=self.cfg.loss.margin
                )
                l_sup = supervised_contrastive_loss(
                    z, y,
                    temperature=self.cfg.loss.temperature
                )

                loss = l_trip + self.cfg.loss.supcon_weight * l_sup

                # FedProx / FHIR-Prox proximal term (0 if mu == 0)
                loss = loss + self._prox_term(model, global_weights)

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
                opt.step()

                total_loss += float(loss)
                total_batches += 1

        avg_loss = total_loss / max(1, total_batches)
        updated = copy.deepcopy(model.state_dict())

        # Hand the update back on the CPU in every case: the server aggregates on
        # CPU tensors, so leaving it on the training device (the fp16_comms=False
        # path used to) breaks aggregation when training on CUDA.
        # When compressing, only floating-point tensors are cast to fp16 so that
        # integer / bool buffers keep their dtype and exact values.
        compress = self.cfg.train.fp16_comms
        for k in updated:
            t = updated[k]
            if compress and t.is_floating_point():
                t = t.to(torch.float16)
            updated[k] = t.cpu()

        return updated, avg_loss


In [10]:
# %% [markdown]
# Cell 10 - Global evaluation on val/test of each client

# %%
def evaluate_global_on_clients(model_fn, global_state_dict, fl_clients_data: dict, split_type="val"):
    """
    Evaluate one global state dict on every client's held-out split.

    Args:
        model_fn: zero-argument callable returning a fresh model.
        global_state_dict: weights to evaluate (loaded with ``strict=True``).
        fl_clients_data: mapping client name -> FLClientData.
        split_type: ``"val"`` selects the validation items; any other value
            selects the test items.

    Returns:
        dict mapping client name to the metrics dict of ``evaluate_embeddings``.
    """
    results = {}
    if not fl_clients_data:
        return results

    # The weights are identical for every client, so build and load the model
    # once instead of once per client.
    # NOTE(review): this calls model_fn() fewer times than before, so code that
    # depends on the exact global RNG stream may see different draws - confirm
    # if bit-for-bit reproduction of earlier runs matters.
    model = model_fn().to(DEVICE)
    model.load_state_dict(global_state_dict, strict=True)

    for name, cdata in fl_clients_data.items():
        items = cdata.val_items if split_type == "val" else cdata.test_items
        # Normalization statistics are client-specific.
        feats, labels = compute_embeddings(model, items, cdata.mean, cdata.std)
        results[name] = evaluate_embeddings(feats, labels)
    return results


In [11]:
# ============================================================
# Cell 11 - FL server training loop (ONLY FHIR warmup + POLY learned weights)
#   Improvements:
#   1) StandardScaler before poly expansion
#   2) Softplus positivity instead of clipping
#   3) Explicit weights_next_round (computed at eval, applied next round) -> no leakage
# ============================================================

import math
import copy
import numpy as np
import torch

from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.linear_model import Ridge

def _fmt_named(vals_by_name, order, nd=3):
    return ", ".join([f"{n}:{vals_by_name[n]:.{nd}f}" for n in order])

@torch.no_grad()
def _eval_one_client(model_fn, state_dict, cdata):
    """Load ``state_dict`` into a fresh model and return its metrics on ``cdata``'s val split."""
    net = model_fn().to(DEVICE)
    net.load_state_dict(state_dict, strict=True)
    embeddings, identities = compute_embeddings(
        net, cdata.val_items, cdata.mean, cdata.std
    )
    return evaluate_embeddings(embeddings, identities)

def _softplus_np(x):
    # numerically stable softplus: log(1 + exp(x))
    # for large x, softplus(x) ~ x
    x = np.asarray(x, dtype=np.float32)
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)

def train_federated(model_fn,
                    fl_clients_data: dict,
                    cfg: FLExperimentConfig):
    """
    Run FHIR-POLY federated training and return the best global state.

    Every round each client trains locally from the current global weights and
    the server forms the new global model as a weighted sum of the client state
    dicts. Rounds up to ``cfg.poly.warmup_rounds`` use the hand-crafted OG-FHIR
    weights (size, identity density, quality, rarity). Later rounds use weights
    derived from a ridge regression on polynomial features; those weights are
    always the ones computed at an evaluation of an earlier round
    (``weights_next_round``), never from the round they are applied to.

    Args:
        model_fn: zero-argument callable returning a fresh model.
        fl_clients_data: mapping client name -> FLClientData.
        cfg: experiment configuration.

    Returns:
        ``(best_state, info)``: a deep copy of the ServerState with the lowest
        macro-EER seen at an evaluation (the initial state if none improved) and
        a dict with ``best_round``, ``best_macro_eer``, ``poly_samples`` and
        ``poly_fitted``.
    """

    set_global_seed(SEED)
    strategy = build_strategy(cfg.strategy)
    state = strategy.init_state(model_fn)

    # ---------------- Build Clients ----------------
    clients = []
    for name, cdata in fl_clients_data.items():
        clients.append(FLClient(cdata, cfg, model_fn, prox_mu=0.0))

    client_names = [cl.data.name for cl in clients]
    num_clients = len(client_names)

    # ---------------- Precompute N_k, q_k and OG terms (for warmup only) ----------------
    stats_by_name = {}
    for cl in clients:
        n_k = len(cl.data.train_items)
        q_k = len(cl.data.label_to_idx)
        stats_by_name[cl.data.name] = {
            "n": n_k,
            "q": q_k,
            "size_score": math.log(1.0 + float(n_k)),
            "id_score": float(q_k) / max(float(n_k), 1.0),
            "rarity_score": 1.0 / math.sqrt(max(float(q_k), 1.0)),
        }

    # Warmup quality scores Q_k for OG-FHIR (only used during warmup)
    quality_scores = {n: 1.0 for n in client_names}

    # ---------------- Polynomial regression buffers ----------------
    # X_k = [EER, TAR1, TAR01, log1p(Nk), log1p(qk)]
    # y_k = utility = max(0, EER_global_on_client - EER_clientupd_on_client)
    X_buf = []
    y_buf = []

    # poly learner objects
    poly = PolynomialFeatures(degree=cfg.poly.poly_degree, include_bias=False)
    scaler = StandardScaler(with_mean=True, with_std=True)
    poly_model = None

    # cache weights to use in next round (NO leakage)
    # Round 1 uses uniform by default
    weights_next_round = {n: 1.0 / num_clients for n in client_names}

    def _fit_poly_if_ready():
        """(Re)fit scaler -> polynomial features -> ridge on the whole buffer; False if too few samples."""
        nonlocal poly_model, scaler
        if len(X_buf) < int(cfg.poly.min_fit_samples):
            return False

        X = np.asarray(X_buf, dtype=np.float32)
        y = np.asarray(y_buf, dtype=np.float32)

        # scale -> poly -> ridge
        scaler.fit(X)
        Xs = scaler.transform(X)
        Xp = poly.fit_transform(Xs)

        reg = Ridge(alpha=float(cfg.poly.ridge_alpha), fit_intercept=True, random_state=SEED)
        reg.fit(Xp, y)
        poly_model = reg
        return True

    def _predict_scores(metrics_by_client):
        """
        metrics_by_client: dict[name] = {eer, tar1, tar01}
        returns score_by_client (positive scores)
        """
        if poly_model is None:
            # uniform until fitted
            return {n: 1.0 for n in client_names}

        # NOTE(review): a client whose metrics are NaN is passed through unchanged
        # here; presumably the scikit-learn transforms reject non-finite input -
        # confirm every client's val split contains genuine and impostor pairs.
        X = []
        for n in client_names:
            m = metrics_by_client[n]
            st = stats_by_name[n]
            X.append([
                float(m["eer"]),
                float(m["tar1"]),
                float(m["tar01"]),
                float(np.log1p(st["n"])),
                float(np.log1p(st["q"])),
            ])
        X = np.asarray(X, dtype=np.float32)

        # IMPORTANT: use same scaler as training
        Xs = scaler.transform(X)
        Xp = poly.transform(Xs)  # poly already fitted in _fit_poly_if_ready

        raw = poly_model.predict(Xp).astype(np.float32)

        # Softplus -> strictly positive, smooth
        pos = _softplus_np(raw) + float(cfg.poly.score_floor)

        return {n: float(v) for n, v in zip(client_names, pos.tolist())}

    # ---------------------------------------------------------------
    best_state = copy.deepcopy(state)
    best_round = 0
    best_macro_eer = float("inf")
    bad_rounds = 0

    def _summary():
        """Info dict returned alongside the best state (same keys on every exit path)."""
        return {
            "best_round": best_round,
            "best_macro_eer": best_macro_eer,
            "poly_samples": len(X_buf),
            "poly_fitted": (poly_model is not None),
        }

    # ---------------- Federated Rounds ----------------
    for r in range(cfg.train.rounds):
        round_id = r + 1
        log(f"==== [FL] Round {round_id}/{cfg.train.rounds} | strategy={cfg.strategy} ====", "FL")

        # snapshot of global before local training (used for utility labels)
        global_pre = {k: v.to(torch.float32) for k, v in state.global_weights.items()}

        # ---- Local training ----
        client_updates = []
        for client in clients:
            upd, avg_loss = client.train_one_round(global_pre)
            log(f"Client {client.data.name}: avg_loss={avg_loss:.4f}", "FL-CLIENT")
            for k in upd:
                # Bring every update to float32 on the device of the global tensor it
                # is accumulated into; without the device move, updates that stay on
                # the GPU (fp16_comms=False) cannot be added to the CPU-side global.
                upd[k] = upd[k].to(device=global_pre[k].device, dtype=torch.float32)
            client_updates.append(upd)

        # ---- Aggregation weights (NO leakage) ----
        in_warmup = (round_id <= int(cfg.poly.warmup_rounds))

        if in_warmup:
            # OG-FHIR warmup weights (computed from current Q_k)
            alpha = cfg.fhir.alpha_size
            beta  = cfg.fhir.beta_div
            gamma = cfg.fhir.gamma_quality
            delta = cfg.fhir.delta_rarity

            raw = {}
            for n in client_names:
                st = stats_by_name[n]
                w_k = (
                    alpha * st["size_score"] +
                    beta  * st["id_score"] +
                    gamma * float(quality_scores.get(n, 1.0)) +
                    delta * st["rarity_score"]
                )
                raw[n] = float(max(w_k, 1e-6))

            w_sum = sum(raw.values()) + 1e-12
            alphas = {n: raw[n] / w_sum for n in client_names}
            log(f"FHIR weights (WARMUP-OG): { _fmt_named(alphas, client_names, nd=4) }", "FL-WEIGHTS")

        else:
            # Poly weights for THIS round come from weights_next_round
            alphas = dict(weights_next_round)
            log(f"FHIR weights (POLY, from prev eval): { _fmt_named(alphas, client_names, nd=4) }", "FL-WEIGHTS")

        # ---- Apply aggregation ----
        new_global = {k: torch.zeros_like(v) for k, v in global_pre.items()}
        for n, upd in zip(client_names, client_updates):
            a_k = float(alphas[n])
            for k in new_global:
                new_global[k] += a_k * upd[k]

        state.global_weights = new_global
        state.round += 1

        # ---- Periodic Evaluation ----
        if (round_id) % cfg.train.eval_every == 0:
            log("Evaluating global model on validation splits...", "FL-EVAL")

            metrics = evaluate_global_on_clients(
                model_fn,
                state.global_weights,
                fl_clients_data,
                split_type="val"
            )

            for cname, m in metrics.items():
                log(f"  {cname}: EER={m['eer']:.4f}, "
                    f"TAR@1%={m['tar1']:.4f}, TAR@0.1%={m['tar01']:.4f}",
                    "FL-METRIC")

            # ---- During warmup: update Q_k (normalized EER) ----
            if in_warmup:
                eers = [m["eer"] for m in metrics.values() if np.isfinite(m["eer"])]
                if len(eers) > 0:
                    e_min = float(min(eers))
                    e_max = float(max(eers))
                    span = max(e_max - e_min, 1e-6)
                    for cname, m in metrics.items():
                        e = m["eer"]
                        if np.isfinite(e):
                            # Lowest EER -> Q = 1, highest EER -> Q = 0.
                            q_norm = 1.0 - (float(e) - e_min) / span
                            quality_scores[cname] = float(q_norm)
                    log(f"Quality Q_k (warmup OG): { _fmt_named(quality_scores, client_names) }", "FL-Q")

            # ---- Collect regression samples (X,y) using EER utility ----
            log("Collecting POLY regression samples (EER utility)...", "FL-POLY")

            for n, upd in zip(client_names, client_updates):
                cdata = fl_clients_data[n]

                # EER of the pre-round global vs. this client's own update, both on
                # the client's validation split.
                base_m = _eval_one_client(model_fn, global_pre, cdata)
                upd_m  = _eval_one_client(model_fn, upd, cdata)

                if not (np.isfinite(base_m["eer"]) and np.isfinite(upd_m["eer"])):
                    continue

                st = stats_by_name[n]
                Xk = [
                    float(metrics[n]["eer"]),
                    float(metrics[n]["tar1"]),
                    float(metrics[n]["tar01"]),
                    float(np.log1p(st["n"])),
                    float(np.log1p(st["q"])),
                ]
                yk = max(0.0, float(base_m["eer"]) - float(upd_m["eer"]))

                X_buf.append(Xk)
                y_buf.append(yk)

            log(f"POLY buffer size: {len(X_buf)} samples", "FL-POLY")

            # ---- Fit poly model once warmup is done and enough samples exist ----
            fitted_now = False
            if (round_id >= int(cfg.poly.warmup_rounds)) and (len(X_buf) >= int(cfg.poly.min_fit_samples)):
                fitted_now = _fit_poly_if_ready()
                if fitted_now:
                    log("POLY model: fitted/updated (scaled -> degree-2 poly -> ridge).", "FL-POLY")

            # ---- Compute weights for NEXT round (no leakage) ----
            if (round_id >= int(cfg.poly.warmup_rounds)) and (poly_model is not None):
                score_by = _predict_scores(metrics)
                s_sum = sum(score_by.values()) + 1e-12
                weights_next_round = {n: float(score_by[n] / s_sum) for n in client_names}
                log(f"POLY score s_k: { _fmt_named(score_by, client_names, nd=4) }", "FL-POLY")
                log(f"weights_next_round: { _fmt_named(weights_next_round, client_names, nd=4) }", "FL-POLY")
            elif (round_id >= int(cfg.poly.warmup_rounds)):
                # still not fitted -> keep uniform next round
                weights_next_round = {n: 1.0 / num_clients for n in client_names}

            # ---- Early stopping based on macro EER (global) ----
            valid_eers = [m["eer"] for m in metrics.values() if np.isfinite(m["eer"])]
            macro_eer = float(np.mean(valid_eers)) if len(valid_eers) else float("inf")
            log(f"  Macro-EER: {macro_eer:.4f}", "FL-MACRO")

            if np.isfinite(macro_eer) and macro_eer + 1e-5 < best_macro_eer:
                best_macro_eer = macro_eer
                best_round = round_id
                best_state = copy.deepcopy(state)
                bad_rounds = 0
                log(f"  ✅ New best macro-EER {best_macro_eer:.4f} at round {best_round}", "FL-BEST")
            else:
                # bad_rounds counts evaluations without improvement, so patience is
                # measured in evaluations (every eval_every rounds), not in rounds.
                bad_rounds += 1
                log(f"  No improvement. bad_rounds={bad_rounds}/{cfg.train.patience}", "FL-PATIENCE")
                if bad_rounds >= cfg.train.patience:
                    log(f"Early stopping at round {round_id}. "
                        f"Best macro-EER={best_macro_eer:.4f} @ round {best_round}",
                        "FL-STOP")
                    return best_state, _summary()

    return best_state, _summary()


In [13]:
# ============================================================
# Cell 12 - FHIR (Warmup OG) -> POLY learned weights (EER utility): train + save best
# ============================================================

import torch
import os

# Build the experiment configuration for the single supported strategy.
fl_cfg = FLExperimentConfig()
fl_cfg.strategy = "fhir_poly"   # ONLY strategy now

# ---- Training schedule ----
fl_cfg.train.rounds = 250
fl_cfg.train.local_epochs = 1
fl_cfg.train.eval_every = 1
# Early stopping triggers once this many consecutive evaluations show no
# macro-EER improvement (see the patience check inside train_federated).
fl_cfg.train.patience = 150
fl_cfg.loss.supcon_weight = 0.5

# ---- Warmup + polynomial weighting settings ----
fl_cfg.poly.warmup_rounds = 50
fl_cfg.poly.poly_degree = 2
fl_cfg.poly.ridge_alpha = 1e-2
fl_cfg.poly.min_fit_samples = 80

# ---- OG-FHIR weights used ONLY during warmup (the four terms sum to 1.0) ----
fl_cfg.fhir.alpha_size    = 0.35
fl_cfg.fhir.beta_div      = 0.25
fl_cfg.fhir.gamma_quality = 0.25
fl_cfg.fhir.delta_rarity  = 0.15

# Run federated training; returns the best snapshot plus a summary dict with
# keys "best_round", "best_macro_eer", "poly_samples" and "poly_fitted".
best_state, info = train_federated(
    model_fn=swin_model_fn,
    fl_clients_data=fl_clients_data,
    cfg=fl_cfg
)

best_round = info.get("best_round", None)
best_macro_eer = info.get("best_macro_eer", None)
# Only apply the float format spec when a value is present; formatting None
# with ":.6f" would raise TypeError.
best_macro_eer_text = "None" if best_macro_eer is None else f"{best_macro_eer:.6f}"

print("✅ FHIR-POLY training finished.")
print(f"Best round: {best_round}")
print(f"Best macro-EER: {best_macro_eer_text}")
print(f"Poly samples: {info.get('poly_samples', None)} | poly_fitted: {info.get('poly_fitted', None)}")

# train_federated only snapshots a state when the macro-EER improves, so fail
# with a clear message instead of an AttributeError if no snapshot exists.
# NOTE(review): the initial value of best_state is not visible from this cell.
if best_state is None:
    raise RuntimeError("train_federated returned no best state; nothing to save.")

# Derive the checkpoint location from the shared FL_MODELS_ROOT (config cell)
# rather than repeating the absolute path, so both stay in sync.
SAVE_PATH = os.path.join(FL_MODELS_ROOT, "FHIR_POLY_EER_best.pt")
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)

# Detach and move every tensor to CPU so the checkpoint carries no autograd
# history and can be loaded on a machine without the training device.
global_weights = best_state.global_weights
global_weights_cpu = {k: v.detach().cpu() for k, v in global_weights.items()}

torch.save({
    "global_model": global_weights_cpu,
    "best_round": best_round,
    "best_macro_eer": best_macro_eer,
    "strategy": "FHIR-POLY-EER",
    "poly_degree": fl_cfg.poly.poly_degree,
    "warmup_rounds": fl_cfg.poly.warmup_rounds,
}, SAVE_PATH)

print(f"💾 Model saved successfully at:\n{SAVE_PATH}")


[FL] ==== [FL] Round 1/250 | strategy=fhir_poly ====
[FL-CLIENT] Client Pakistan: avg_loss=1.0478
[FL-CLIENT] Client China: avg_loss=1.6263
[FL-CLIENT] Client Czech: avg_loss=1.5531
[FL-CLIENT] Client India: avg_loss=0.9323
[FL-CLIENT] Client Iraq: avg_loss=1.6854
[FL-CLIENT] Client Malaysia: avg_loss=1.9171
[FL-CLIENT] Client Iran: avg_loss=1.7054
[FL-WEIGHTS] FHIR weights (WARMUP-OG): Pakistan:0.1622, China:0.1589, Czech:0.1266, India:0.1563, Iraq:0.1309, Malaysia:0.1275, Iran:0.1376
[FL-EVAL] Evaluating global model on validation splits...
[FL-METRIC]   Pakistan: EER=0.0748, TAR@1%=0.7384, TAR@0.1%=0.5301
[FL-METRIC]   China: EER=0.0847, TAR@1%=0.7170, TAR@0.1%=0.4721
[FL-METRIC]   Czech: EER=0.0589, TAR@1%=0.6944, TAR@0.1%=0.5833
[FL-METRIC]   India: EER=0.0348, TAR@1%=0.9351, TAR@0.1%=0.8798
[FL-METRIC]   Iraq: EER=0.0791, TAR@1%=0.6900, TAR@0.1%=0.5000
[FL-METRIC]   Malaysia: EER=0.0444, TAR@1%=0.7889, TAR@0.1%=0.4556
[FL-METRIC]   Iran: EER=0.1242, TAR@1%=0.7205, TAR@0.1%=0.3292