# FL-EHDS: Scalability Gap + Imaging 10-Seed Statistical Validation

Two experiment batches in one notebook:

**Part A — Scalability gap (9 experiments)**
Completes the 9 missing experiments in `checkpoint_imaging_scalability.json`
(Skin_Cancer K10/K20 s42 + chest_xray K20 s42, all 3 algos).

**Part B — 10-Seed Validation (90 experiments)**
FedAvg/Ditto/HPFL × 3 datasets × 10 seeds, so that a Wilcoxon signed-rank test
can be run on the imaging results (same protocol as tabular Phase 4).

**Setup:** Runtime > Change runtime type > **T4 GPU**

**Total:** 9 + 90 = **99 experiments**

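**Checkpoint:** Saved to Google Drive after **every round**, including per-client and per-epoch training metrics.
If only the browser disconnected, re-run from Section 3 — it auto-resumes. If the runtime was reset (a fresh VM has no repository, datasets or Drive mount), re-run from Section 1; training still resumes from the Drive checkpoint.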

**Estimated time:** ~5 hours on T4 GPU

## 1. Setup Environment

In [None]:
# Mount Google Drive so checkpoints and logs survive Colab runtime resets,
# then make sure the results folder exists.
import os

from google.colab import drive

drive.mount('/content/drive')

DRIVE_OUTPUT = '/content/drive/MyDrive/FL-EHDS-FLICS2026/colab_results'
os.makedirs(DRIVE_OUTPUT, exist_ok=True)
print(f'Drive output: {DRIVE_OUTPUT}')

In [None]:
# Report the PyTorch build and, when CUDA is available, the first GPU's
# name and total memory.
import torch

print(f'PyTorch: {torch.__version__}')
cuda_ok = torch.cuda.is_available()
print(f'CUDA available: {cuda_ok}')
if cuda_ok:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    device_props = torch.cuda.get_device_properties(0)
    # The attribute name has varied across PyTorch versions; fall back to 0.
    total_bytes = (getattr(device_props, 'total_memory', None)
                   or getattr(device_props, 'total_mem', 0))
    print(f'Memory: {total_bytes / 1e9:.1f} GB')

In [None]:
# Clone repository
# If the clone fails (e.g. the directory already exists from an earlier run),
# fall back to updating the existing checkout with `git pull`.
!git clone https://github.com/FabioLiberti/FL-EHDS-FLICS2026.git /content/FL-EHDS-FLICS2026 2>/dev/null || (cd /content/FL-EHDS-FLICS2026 && git pull)
# All relative paths used later (data/..., terminal.fl_trainer) are relative to this directory.
%cd /content/FL-EHDS-FLICS2026/fl-ehds-framework

# Install dependencies
# NOTE(review): torch is not installed here; presumably the Colab image's preinstalled build is used — confirm.
!pip install -q scikit-learn scipy tqdm Pillow structlog cryptography grpcio aiohttp pydantic pyyaml

## 2. Download Datasets

In [None]:
!pip install -q kagglehub

import os
os.environ['KAGGLE_API_TOKEN'] = 'KGAT_edd561c1bc682c9ad06930bacd164431'

import kagglehub
print(f'kagglehub version: {kagglehub.__version__}')

In [None]:
%%time
# Download the Kaggle "chest-xray-pneumonia" dataset (kagglehub caches it) and
# copy its train/test/val splits into data/chest_xray/.
# NOTE: `glob` imported here is also used by the Dataset Summary cell.
import shutil, glob

cache_path = kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")
os.makedirs('data/chest_xray', exist_ok=True)
for item in ['train', 'test', 'val']:
    # Try the nested chest_xray/<split> layout first, then <split> at the cache root.
    src = os.path.join(cache_path, 'chest_xray', item)
    if not os.path.exists(src):
        src = os.path.join(cache_path, item)
    dst = f'data/chest_xray/{item}'
    # Skip splits missing upstream or already copied by an earlier run.
    if os.path.exists(src) and not os.path.exists(dst):
        shutil.copytree(src, dst)
# Remove macOS archive metadata if it ended up in the target directory.
shutil.rmtree('data/chest_xray/__MACOSX', ignore_errors=True)
print('Chest X-Ray ready')

In [None]:
%%time
# Download the Kaggle skin-cancer (malignant vs benign) dataset and copy the
# entire kagglehub cache directory to 'data/Skin Cancer' (the path, including
# the space, must match IMAGING_DATASETS["Skin_Cancer"]["data_dir"]).
# Skipped if the target already exists from an earlier run.
cache_path = kagglehub.dataset_download("fanconic/skin-cancer-malignant-vs-benign")
dst = 'data/Skin Cancer'
if not os.path.exists(dst):
    shutil.copytree(cache_path, dst)
print('Skin Cancer ready')

In [None]:
%%time
# Download the Kaggle brain-tumor MRI dataset and gather every class folder
# found anywhere under the cache into data/Brain_Tumor/<class>/.
# "notumor"/"no_tumor" folders are renamed to "healthy"; a class folder that
# appears several times (e.g. in more than one split directory) is merged.
cache_path = kagglehub.dataset_download("masoudnickparvar/brain-tumor-mri-dataset")
os.makedirs('data/Brain_Tumor', exist_ok=True)
for root, dirs, files in os.walk(cache_path):
    for d in dirs:
        d_lower = d.lower()
        if d_lower in ['glioma', 'meningioma', 'pituitary', 'notumor', 'no_tumor', 'healthy']:
            target = 'healthy' if d_lower in ['notumor', 'no_tumor'] else d_lower
            src = os.path.join(root, d)
            dst_dir = f'data/Brain_Tumor/{target}'
            if not os.path.exists(dst_dir):
                # First occurrence of this class: copy the whole folder.
                shutil.copytree(src, dst_dir)
            else:
                # Class already present: merge in files not copied yet.
                # NOTE(review): a file whose name already exists is skipped, not renamed.
                for f in os.listdir(src):
                    sf, df = os.path.join(src, f), os.path.join(dst_dir, f)
                    if os.path.isfile(sf) and not os.path.exists(df):
                        shutil.copy2(sf, df)
print('Brain Tumor ready')

In [None]:
# Sanity check: count image files and list the class sub-folders of each
# dataset. Imports are local so the cell does not depend on earlier cells.
import glob
import os

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

print('=== Dataset Summary ===')
for ds_name, ds_path in [('Chest X-Ray', 'data/chest_xray'),
                          ('Skin Cancer', 'data/Skin Cancer'),
                          ('Brain Tumor', 'data/Brain_Tumor')]:
    # A failed download leaves no directory: report it instead of crashing
    # on os.listdir.
    if not os.path.isdir(ds_path):
        print(f'  {ds_name:15s}: MISSING ({ds_path} not found)')
        continue
    count = sum(1 for path in glob.iglob(f'{ds_path}/**/*.*', recursive=True)
                if path.lower().endswith(IMAGE_EXTENSIONS))
    # os.listdir order is arbitrary; sort for a stable, comparable listing.
    subdirs = sorted(d for d in os.listdir(ds_path)
                     if os.path.isdir(os.path.join(ds_path, d)))
    print(f'  {ds_name:15s}: {count:5d} images, classes: {subdirs}')

## 3. Shared Utilities & Training Function

Run this cell once per runtime session — its definitions are used by both Part A and Part B.

In [None]:
import sys
import json
import time
import shutil
import tempfile
import traceback
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Optional, Any

import numpy as np
import torch

sys.path.insert(0, '/content/FL-EHDS-FLICS2026/fl-ehds-framework')

from terminal.fl_trainer import ImageFederatedTrainer, _detect_device

# ======================================================================
# Shared configuration
# ======================================================================

# Imaging datasets keyed by name: on-disk folder, number of classes and the
# short label. Insertion order defines the order of the experiment grids.
IMAGING_DATASETS = {
    name: {"data_dir": folder, "num_classes": n_classes, "short": label}
    for name, folder, n_classes, label in (
        ("chest_xray", "data/chest_xray", 2, "CX"),
        ("Brain_Tumor", "data/Brain_Tumor", 4, "BT"),
        ("Skin_Cancer", "data/Skin Cancer", 2, "SC"),
    )
}

# Default trainer settings shared by Part A and Part B; num_clients is
# overridden per experiment by run_single_experiment.
# NOTE(review): alpha presumably is the non-IID partition concentration and
# mu a proximal-term weight — confirm against ImageFederatedTrainer.
IMAGING_CONFIG = {
    "num_clients": 5,
    "num_rounds": 20,
    "local_epochs": 2,
    "batch_size": 32,
    "learning_rate": 0.001,
    "model_type": "resnet18",
    "is_iid": False,
    "alpha": 0.5,
    "freeze_backbone": False,
    "freeze_level": 2,
    "use_fedbn": True,
    "use_class_weights": True,
    "use_amp": True,
    "mu": 0.1,
}

# Per-dataset values merged on top of IMAGING_CONFIG.
DATASET_OVERRIDES = {"Brain_Tumor": {"learning_rate": 0.0005}}

# Early stopping on global accuracy: stop after `patience` rounds without an
# improvement larger than `min_delta`, but never before `min_rounds`.
EARLY_STOPPING = {
    "enabled": True,
    "patience": 4,
    "min_delta": 0.003,
    "min_rounds": 8,
    "metric": "accuracy",
}

# Every checkpoint, trainer-state file and log is written under the Drive
# folder created in Section 1.
OUTPUT_DIR = Path(DRIVE_OUTPUT)

# ======================================================================
# Utilities
# ======================================================================

# Open log file handle used by log(); the Part A / Part B cells assign it and
# reset it to None when done. None means "stdout only".
_log_file = None

def log(msg, also_print=True):
    """Write ``msg`` prefixed with an ``[HH:MM:SS]`` timestamp.

    The line goes to stdout (flushed) when ``also_print`` is true, and is
    appended to the module-level ``_log_file`` when one is set. Errors while
    writing to the file are deliberately ignored so logging never aborts an
    experiment.
    """
    stamp = datetime.now().strftime("%H:%M:%S")
    entry = f"[{stamp}] {msg}"
    if also_print:
        print(entry, flush=True)
    sink = _log_file
    if not sink:
        return
    try:
        sink.write(entry + "\n")
        sink.flush()
    except Exception:
        pass

def save_checkpoint(data, checkpoint_file):
    """Atomically write ``data`` as JSON to ``OUTPUT_DIR / checkpoint_file``.

    Stamps ``data["metadata"]["last_save"]`` with the current time. The
    ``"metadata"`` dict is created if missing, e.g. for a checkpoint produced
    by an older script. The JSON is written to a temp file in the same
    directory and fsynced. The current checkpoint (if any) is then copied to
    ``<checkpoint_file>.bak`` and the temp file is ``os.replace``d into place,
    so a crash never leaves a half-written checkpoint. Values that are not
    JSON-serializable are written via ``str``.

    Args:
        data: checkpoint dict; mutated in place (``metadata.last_save``).
        checkpoint_file: file name relative to ``OUTPUT_DIR``.

    Raises:
        Any error from writing. The temp file is removed before re-raising,
        including on KeyboardInterrupt.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    path = OUTPUT_DIR / checkpoint_file
    bak = OUTPUT_DIR / (checkpoint_file + ".bak")
    data.setdefault("metadata", {})["last_save"] = datetime.now().isoformat()
    fd, tmp = tempfile.mkstemp(dir=str(OUTPUT_DIR), prefix=".ckpt_", suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(data, f, indent=2, default=str)
            f.flush()
            os.fsync(f.fileno())
        if path.exists():
            shutil.copy2(str(path), str(bak))
        os.replace(tmp, str(path))
    except BaseException:
        # BaseException so an interrupted cell does not leave .ckpt_*.tmp
        # files behind on Drive; the error is always re-raised.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise

def load_checkpoint(checkpoint_file):
    """Load a checkpoint written by ``save_checkpoint``, falling back to ``.bak``.

    Tries ``OUTPUT_DIR / checkpoint_file`` first, then
    ``OUTPUT_DIR / (checkpoint_file + ".bak")``. A candidate that is missing,
    unreadable, not valid text, or not valid JSON is skipped.

    Returns:
        The parsed JSON of the first usable file, or None if neither is usable.
    """
    candidates = (OUTPUT_DIR / checkpoint_file,
                  OUTPUT_DIR / (checkpoint_file + ".bak"))
    for p in candidates:
        if not p.exists():
            continue
        try:
            with open(p) as f:
                return json.load(f)
        except (ValueError, OSError):
            # ValueError covers json.JSONDecodeError *and* UnicodeDecodeError:
            # a corrupted file may not even decode as text, and that case must
            # still fall through to the backup.
            continue
    return None

def _cleanup_gpu():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    import gc
    gc.collect()

class EarlyStoppingMonitor:
    """Patience-based early stopping for a metric where higher is better.

    A round counts as an improvement only if the metric beats the best value
    so far by more than ``min_delta``. Rounds without improvement are counted
    from the very first round, but a stop is never signalled before round
    ``min_rounds``.
    """

    def __init__(self, patience=4, min_delta=0.003, min_rounds=8, metric="accuracy"):
        self.patience, self.min_delta = patience, min_delta
        self.min_rounds, self.metric = min_rounds, metric
        # Best metric value seen, the round it occurred in, and the number of
        # consecutive non-improving rounds since then.
        self.best_value = -float('inf')
        self.best_round = 0
        self.counter = 0

    def check(self, round_num, metrics):
        """Record ``metrics`` for ``round_num``; return True if training should stop.

        A missing metric key is treated as 0.
        """
        current = metrics.get(self.metric, 0)
        if not current > self.best_value + self.min_delta:
            self.counter += 1
        else:
            self.best_value, self.best_round, self.counter = current, round_num, 0
        return round_num >= self.min_rounds and self.counter >= self.patience

def _evaluate_per_client(trainer):
    model = trainer.global_model
    model.eval()
    per_client = {}
    with torch.no_grad():
        for cid in range(trainer.num_clients):
            X, y = trainer.client_test_data[cid]
            X_t = torch.FloatTensor(X).to(trainer.device) if isinstance(X, np.ndarray) else X.to(trainer.device)
            y_t = torch.LongTensor(y).to(trainer.device) if isinstance(y, np.ndarray) else y.to(trainer.device)
            correct = total = 0
            for i in range(0, len(y_t), 64):
                out = model(X_t[i:i+64])
                preds = out.argmax(dim=1)
                correct += (preds == y_t[i:i+64]).sum().item()
                total += len(y_t[i:i+64])
            per_client[str(cid)] = correct / total if total > 0 else 0.0
    return per_client

def _compute_fairness(per_client_acc):
    accs = list(per_client_acc.values())
    if not accs:
        return {}
    jain = (sum(accs)**2) / (len(accs) * sum(a**2 for a in accs)) if accs else 0
    sorted_a = sorted(accs)
    n = len(sorted_a)
    cumsum = np.cumsum(sorted_a)
    gini = (2 * sum((i+1)*v for i, v in enumerate(sorted_a))) / (n * cumsum[-1]) - (n+1)/n if cumsum[-1] > 0 else 0
    return {
        "mean": round(float(np.mean(accs)), 4),
        "std": round(float(np.std(accs)), 4),
        "min": round(float(min(accs)), 4),
        "max": round(float(max(accs)), 4),
        "jain_index": round(float(jain), 4),
        "gini": round(float(max(0, gini)), 4),
    }

# ======================================================================
# Training function (generic — works for both Part A and Part B)
# ======================================================================

def _make_early_stopper(es_config):
    """Return an EarlyStoppingMonitor built from ``es_config``, or None if disabled."""
    if not es_config.get("enabled"):
        return None
    return EarlyStoppingMonitor(**{k: v for k, v in es_config.items() if k != "enabled"})


def _best_from_history(history):
    """Return ``(best_accuracy, best_round)`` over ``history``; ties keep the earliest round."""
    best_acc, best_round = 0.0, 0
    for h in history:
        if h["accuracy"] > best_acc:
            best_acc, best_round = h["accuracy"], h["round"]
    return best_acc, best_round


def run_single_experiment(
    ds_name, data_dir, algorithm, seed, num_clients,
    config, es_config,
    exp_idx, total_exps,
    checkpoint_data, checkpoint_file, exp_key,
    trainer_ckpt_path,
):
    """Run one federated imaging experiment with per-round checkpointing.

    Builds an ``ImageFederatedTrainer`` from ``config``, with ``num_clients``
    and any ``DATASET_OVERRIDES[ds_name]`` applied. It then trains up to
    ``config["num_rounds"]`` rounds with optional early stopping. After every
    round it saves the trainer state to ``trainer_ckpt_path`` and a progress
    record to ``checkpoint_data["in_progress"]``, which is persisted via
    ``save_checkpoint(checkpoint_data, checkpoint_file)``. When that record
    already belongs to ``exp_key`` and the trainer state file exists,
    training resumes from it.

    Args:
        ds_name, data_dir: dataset name (key into DATASET_OVERRIDES) and folder.
        algorithm, seed, num_clients: experiment coordinates passed to the trainer.
        config: trainer settings (IMAGING_CONFIG-like; must contain "num_rounds").
        es_config: early-stopping settings ("enabled" plus EarlyStoppingMonitor kwargs).
        exp_idx, total_exps: position shown in the per-round log line.
        checkpoint_data: checkpoint dict; its "in_progress" entry is set each
            round and removed on completion.
        checkpoint_file: file name passed to save_checkpoint.
        exp_key: unique experiment key used to match "in_progress" on resume.
        trainer_ckpt_path: trainer state file path, or falsy to disable it.

    Returns:
        Result dict with history, final/best metrics, per-client accuracy,
        fairness summary, runtime, config, ``stopped_early`` (True only if
        training ended before the last round) and ``actual_rounds``.

    Raises:
        Any exception from the trainer. GPU memory is released before it
        propagates.
    """
    start = time.time()
    num_rounds = config["num_rounds"]

    cfg = {**config, "num_clients": num_clients}
    if ds_name in DATASET_OVERRIDES:
        cfg.update(DATASET_OVERRIDES[ds_name])

    trainer = ImageFederatedTrainer(
        data_dir=data_dir,
        num_clients=num_clients,
        algorithm=algorithm,
        local_epochs=cfg["local_epochs"],
        batch_size=cfg["batch_size"],
        learning_rate=cfg["learning_rate"],
        is_iid=cfg["is_iid"],
        alpha=cfg["alpha"],
        mu=cfg.get("mu", 0.1),
        seed=seed,
        model_type=cfg["model_type"],
        freeze_backbone=cfg.get("freeze_backbone", False),
        freeze_level=cfg.get("freeze_level"),
        use_fedbn=cfg.get("use_fedbn", False),
        use_class_weights=cfg.get("use_class_weights", True),
        use_amp=cfg.get("use_amp", True),
    )
    try:
        trainer.num_rounds = num_rounds
        es = _make_early_stopper(es_config)

        history = []
        best_acc = 0.0
        best_round = 0
        start_round = 0
        stopped_early = False

        # Resume mid-experiment: only when the JSON progress record belongs to
        # this experiment AND the trainer state file is on disk.
        in_prog = checkpoint_data.get("in_progress") if checkpoint_data else None
        if (in_prog and in_prog.get("key") == exp_key
                and trainer_ckpt_path and Path(trainer_ckpt_path).exists()):
            try:
                start_round = trainer.load_checkpoint(trainer_ckpt_path)
                # start_round is the 0-based index of the next round to train,
                # i.e. the number of rounds the restored state has completed
                # (history rounds are 1-based). Drop saved rounds beyond it so
                # an older trainer file can never cause duplicated rounds.
                history = [h for h in in_prog.get("history", [])
                           if h["round"] <= start_round]
                if len(history) != start_round:
                    log(f"  WARNING: trainer state has {start_round} rounds but saved "
                        f"history has {len(history)}")
                best_acc, best_round = _best_from_history(history)
                log(f"  RESUMED from round {start_round} (best={best_acc:.1%})")
                if es:
                    # Replay the saved rounds to restore the patience counter.
                    # If the criterion fires, the interrupted run had already
                    # stopped there, so no further rounds are trained.
                    for h in history:
                        if es.check(h["round"], {"accuracy": h["accuracy"]}):
                            stopped_early = h["round"] < num_rounds
                            start_round = num_rounds
                            log(f"  -> Early-stop criterion already met at R{h['round']}")
                            break
            except Exception as e:
                log(f"  WARNING: resume failed ({e}), restarting from R1")
                es = _make_early_stopper(es_config)
                start_round = 0
                history = []
                best_acc = 0.0
                best_round = 0
                stopped_early = False

        for r in range(start_round, num_rounds):
            rr = trainer.train_round(r)

            client_metrics = [
                {
                    "client_id": cr.client_id,
                    "train_loss": round(cr.train_loss, 6),
                    "train_acc": round(cr.train_acc, 6),
                    "num_samples": cr.num_samples,
                    "epochs_completed": cr.epochs_completed,
                    "epoch_metrics": cr.epoch_metrics or [],
                }
                for cr in rr.client_results
            ]

            metrics = {
                "round": r + 1,
                "accuracy": rr.global_acc,
                "loss": rr.global_loss,
                "f1": rr.global_f1,
                "precision": rr.global_precision,
                "recall": rr.global_recall,
                "auc": rr.global_auc,
                "time_seconds": round(rr.time_seconds, 2),
                "client_results": client_metrics,
            }
            history.append(metrics)

            if rr.global_acc > best_acc:
                best_acc = rr.global_acc
                best_round = r + 1

            log(f"[{exp_idx}/{total_exps}] {ds_name} | {algorithm} | K={num_clients} | s{seed} | "
                f"R{r+1}/{num_rounds} | Acc:{rr.global_acc:.1%} | Best:{best_acc:.1%}(r{best_round})")

            # Save checkpoint after EVERY round: trainer state first, then the
            # JSON progress record that refers to it.
            if trainer_ckpt_path:
                try:
                    trainer.save_checkpoint(trainer_ckpt_path)
                except Exception as e:
                    # Best effort: losing the trainer file only costs resumability.
                    log(f"  WARNING: trainer state save failed ({e})")
            checkpoint_data["in_progress"] = {
                "key": exp_key,
                "dataset": ds_name,
                "algorithm": algorithm,
                "seed": seed,
                "num_clients": num_clients,
                "round": r + 1,
                "total_rounds": num_rounds,
                "best_acc": best_acc,
                "best_round": best_round,
                "history": history,
                "elapsed_seconds": round(time.time() - start, 1),
            }
            save_checkpoint(checkpoint_data, checkpoint_file)

            if es and es.check(r + 1, {"accuracy": rr.global_acc}):
                # Meeting the criterion on the final round is not an *early* stop.
                if r + 1 < num_rounds:
                    stopped_early = True
                    log(f"  -> Early stop at R{r+1} (best={best_acc:.1%} at r{best_round})")
                break

        per_client_acc = _evaluate_per_client(trainer)
        elapsed = time.time() - start
    finally:
        # Free the model even if training raised, so the next experiment does
        # not start with this one's GPU memory still allocated.
        del trainer
        _cleanup_gpu()

    fairness = _compute_fairness(per_client_acc)

    result = {
        "dataset": ds_name,
        "algorithm": algorithm,
        "seed": seed,
        "num_clients": num_clients,
        "history": history,
        "final_metrics": history[-1] if history else {},
        "per_client_acc": per_client_acc,
        "fairness": fairness,
        "runtime_seconds": round(elapsed, 1),
        "config": cfg,
        "stopped_early": stopped_early,
        "actual_rounds": len(history),
        "best_metrics": {"accuracy": best_acc, "round": best_round},
        "best_round": best_round,
    }

    # The experiment is finished: drop the resume record and trainer state.
    checkpoint_data.pop("in_progress", None)
    if trainer_ckpt_path:
        try:
            Path(trainer_ckpt_path).unlink(missing_ok=True)
        except OSError:
            pass

    return result

# Confirm which device training will use and that this cell finished.
print(f'Device: {_detect_device(None)}',
      'Shared utilities & training function loaded OK',
      sep='\n')

## 4. Part A — Complete Scalability Gap (9 experiments)

**9 missing experiments** from `checkpoint_imaging_scalability.json`:
- Skin_Cancer × FedAvg/Ditto/HPFL × K10 × seed=42 (3)
- Skin_Cancer × FedAvg/Ditto/HPFL × K20 × seed=42 (3)
- chest_xray × FedAvg/Ditto/HPFL × K20 × seed=42 (3)

Loads the existing checkpoint from Drive (63 of 72 completed), runs the 9 missing experiments and saves back to Drive after each one.
If everything is already complete, this cell finishes instantly.

In [None]:
# ======================================================================
# PART A: Scalability gap — 9 missing experiments
# ======================================================================
# Resumable: experiments with a successful record in the Drive checkpoint are
# skipped, an interrupted experiment resumes from its last saved round, and
# (with SCALE_RETRY_ERRORS) experiments that previously failed are retried.

SCALE_CHECKPOINT = "checkpoint_imaging_scalability.json"
SCALE_LOG = "experiment_imaging_scalability.log"
SCALE_TRAINER_STATE = ".trainer_state_scalability.pt"

SCALE_ALGORITHMS = ["FedAvg", "Ditto", "HPFL"]
SCALE_K_VALUES = [10, 20]
SCALE_SEEDS = [42, 123, 456, 789]
# Failures are stored in "completed" as {"error": ...} so one bad run does not
# stop the grid; re-running this cell retries them unless this is False.
SCALE_RETRY_ERRORS = True

_log_file = open(OUTPUT_DIR / SCALE_LOG, "a")
try:
    # Build the FULL scalability experiment list (same grid as original script)
    scale_experiments = [
        (ds_name, algo, seed, k_val)
        for k_val in SCALE_K_VALUES
        for ds_name in IMAGING_DATASETS
        for algo in SCALE_ALGORITHMS
        for seed in SCALE_SEEDS
    ]
    total_scale = len(scale_experiments)

    # Load existing checkpoint (should have 63 completed)
    scale_data = load_checkpoint(SCALE_CHECKPOINT)
    if scale_data:
        # Tolerate checkpoints written by older scripts without these keys.
        scale_data.setdefault("completed", {})
        scale_data.setdefault("metadata", {})
        done = sum(1 for v in scale_data["completed"].values() if "error" not in v)
        log(f"PART A — Scalability: loaded {done}/{total_scale} completed")
    else:
        scale_data = {
            "completed": {},
            "metadata": {
                "total_experiments": total_scale,
                "purpose": "Imaging scalability: K=10,20 for FedAvg/Ditto/HPFL (gap completion)",
                "algorithms": SCALE_ALGORITHMS,
                "k_values": SCALE_K_VALUES,
                "datasets": list(IMAGING_DATASETS.keys()),
                "seeds": SCALE_SEEDS,
                "start_time": datetime.now().isoformat(),
                "last_save": None,
                "version": "imaging_scalability_v2_gap",
            }
        }

    # Experiments without a successful record (missing, or errored and retried)
    missing = []
    for ds_name, algo, seed, k_val in scale_experiments:
        key = f"{ds_name}_{algo}_K{k_val}_s{seed}"
        prev = scale_data["completed"].get(key)
        if prev is None or (SCALE_RETRY_ERRORS and "error" in prev):
            missing.append((ds_name, algo, seed, k_val, key))

    log(f"PART A — {len(missing)} experiments to run (out of {total_scale} total)")

    if missing:
        log(f"\n{'='*66}")
        log(f"  Part A: Scalability Gap Completion")
        log(f"  {len(missing)} missing experiments")
        log(f"{'='*66}")

        global_start_a = time.time()
        completed_a = 0
        trainer_ckpt_a = str(OUTPUT_DIR / SCALE_TRAINER_STATE)

        for idx, (ds_name, algo, seed, k_val, key) in enumerate(missing, 1):
            ds_info = IMAGING_DATASETS[ds_name]

            # ETA from the mean time per experiment attempted this session
            # (failed attempts included), times the experiments still to run.
            attempted = idx - 1
            elapsed = time.time() - global_start_a
            if attempted > 0:
                eta = str(timedelta(seconds=int((len(missing) - attempted) * elapsed / attempted)))
            else:
                eta = "calculating..."
            log(f"\n--- [A:{idx}/{len(missing)}] {ds_name} | {algo} | K={k_val} | s{seed} | ETA: {eta} ---")

            try:
                result = run_single_experiment(
                    ds_name=ds_name, data_dir=ds_info["data_dir"],
                    algorithm=algo, seed=seed, num_clients=k_val,
                    config=IMAGING_CONFIG, es_config=EARLY_STOPPING,
                    exp_idx=idx, total_exps=len(missing),
                    checkpoint_data=scale_data, checkpoint_file=SCALE_CHECKPOINT,
                    exp_key=key, trainer_ckpt_path=trainer_ckpt_a,
                )

                scale_data["completed"][key] = result
                completed_a += 1
                save_checkpoint(scale_data, SCALE_CHECKPOINT)

                best_acc = result.get("best_metrics", {}).get("accuracy", 0)
                es_info = f" ES@R{result['actual_rounds']}" if result.get("stopped_early") else ""
                log(f"--- Done: Best={best_acc:.1%}{es_info} | {result['runtime_seconds']:.0f}s | [{completed_a}/{len(missing)}] ---")

            except Exception as e:
                log(f"ERROR in {key}: {e}")
                traceback.print_exc()
                scale_data["completed"][key] = {
                    "dataset": ds_name, "algorithm": algo, "seed": seed,
                    "num_clients": k_val, "error": str(e),
                }
                save_checkpoint(scale_data, SCALE_CHECKPOINT)

        elapsed_a = time.time() - global_start_a
        log(f"\nPart A COMPLETED: {completed_a}/{len(missing)} in {timedelta(seconds=int(elapsed_a))}")
    else:
        log("Part A: All scalability experiments already complete!")

    # Count only successful records; error placeholders are not results.
    total_scale_done = sum(1 for v in scale_data["completed"].values() if "error" not in v)
    log(f"Scalability checkpoint: {total_scale_done}/{total_scale} total")
finally:
    # Close the log even if the cell is interrupted, so a re-run does not
    # leak a second open handle.
    _log_file.close()
    _log_file = None

## 5. Part B — 10-Seed Statistical Validation (90 experiments)

**90 experiments** = 3 algos (FedAvg, Ditto, HPFL) × 3 datasets × 10 seeds

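Enables a Wilcoxon signed-rank test on the imaging results (same protocol as tabular Phase 4).
With 10 paired seeds, the smallest attainable one-sided exact p-value is 1/2¹⁰ ≈ 0.001 (when every seed favours the same method). With only 4 seeds it is 1/2⁴ = 0.0625, so significance at α = 0.05 would be impossible.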

Seeds: 42, 123, 456, 789, 999, 7, 13, 31, 67, 101

Checkpoint saved to Drive after **every training round**, with per-client and per-epoch metrics.
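If Colab disconnects, re-run Section 3 and then this cell (or everything from Section 1 if the runtime was reset). Training auto-resumes from the last saved round.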

In [None]:
# ======================================================================
# PART B: 10-Seed Statistical Validation — 90 experiments
# ======================================================================
# Resumable: experiments with a successful record in the Drive checkpoint are
# skipped, an interrupted experiment resumes from its last saved round, and
# (with SEEDS10_RETRY_ERRORS) experiments that previously failed are retried.

SEEDS10_CHECKPOINT = "checkpoint_imaging_seeds10.json"
SEEDS10_LOG = "experiment_imaging_seeds10.log"
SEEDS10_TRAINER_STATE = ".trainer_state_seeds10.pt"

SEEDS10_ALGORITHMS = ["FedAvg", "Ditto", "HPFL"]
SEEDS10 = [42, 123, 456, 789, 999, 7, 13, 31, 67, 101]
# Failures are stored in "completed" as {"error": ...} so one bad run does not
# stop the grid; re-running this cell retries them unless this is False.
SEEDS10_RETRY_ERRORS = True

_log_file = open(OUTPUT_DIR / SEEDS10_LOG, "a")
try:
    seeds10_experiments = [
        (ds_name, algo, seed)
        for ds_name in IMAGING_DATASETS
        for algo in SEEDS10_ALGORITHMS
        for seed in SEEDS10
    ]
    total_seeds10 = len(seeds10_experiments)

    seeds10_data = load_checkpoint(SEEDS10_CHECKPOINT)
    if seeds10_data:
        # Tolerate checkpoints missing these keys.
        seeds10_data.setdefault("completed", {})
        seeds10_data.setdefault("metadata", {})
        done = sum(1 for v in seeds10_data["completed"].values() if "error" not in v)
        log(f"PART B — Seeds10: AUTO-RESUMED {done}/{total_seeds10} completed")
    else:
        seeds10_data = {
            "completed": {},
            "metadata": {
                "total_experiments": total_seeds10,
                "purpose": "Imaging 10-seed statistical validation for Wilcoxon signed-rank test",
                "algorithms": SEEDS10_ALGORITHMS,
                "datasets": list(IMAGING_DATASETS.keys()),
                "seeds": SEEDS10,
                "num_clients": 5,
                "start_time": datetime.now().isoformat(),
                "last_save": None,
                "version": "imaging_seeds10_v1",
            }
        }

    # (grid position, dataset, algorithm, seed, key) for every experiment
    # without a successful record (missing, or errored and retried).
    pending = []
    for exp_idx, (ds_name, algo, seed) in enumerate(seeds10_experiments, 1):
        key = f"{ds_name}_{algo}_s{seed}"
        prev = seeds10_data["completed"].get(key)
        if prev is None or (SEEDS10_RETRY_ERRORS and "error" in prev):
            pending.append((exp_idx, ds_name, algo, seed, key))

    log(f"\n{'='*66}")
    log(f"  Part B: Imaging 10-Seed Statistical Validation")
    log(f"  {total_seeds10} experiments = {len(SEEDS10_ALGORITHMS)} algos x {len(IMAGING_DATASETS)} DS x {len(SEEDS10)} seeds")
    log(f"  Algorithms: {SEEDS10_ALGORITHMS}")
    log(f"  Seeds: {SEEDS10}")
    log(f"  Device: {_detect_device(None)}")
    log(f"  To run this session: {len(pending)}")
    log(f"{'='*66}")

    global_start_b = time.time()
    completed_b = sum(1 for v in seeds10_data["completed"].values() if "error" not in v)
    trainer_ckpt_b = str(OUTPUT_DIR / SEEDS10_TRAINER_STATE)

    for attempted, (exp_idx, ds_name, algo, seed, key) in enumerate(pending):
        ds_info = IMAGING_DATASETS[ds_name]

        # ETA from this session's own throughput. Dividing this session's
        # elapsed time by experiments finished in earlier sessions would make
        # the estimate far too optimistic after a resume.
        elapsed = time.time() - global_start_b
        if attempted > 0:
            eta = str(timedelta(seconds=int((len(pending) - attempted) * elapsed / attempted)))
        else:
            eta = "calculating..."

        log(f"\n--- [B:{exp_idx}/{total_seeds10}] {ds_name} | {algo} | s{seed} | ETA: {eta} ---")

        try:
            result = run_single_experiment(
                ds_name=ds_name, data_dir=ds_info["data_dir"],
                algorithm=algo, seed=seed, num_clients=5,
                config=IMAGING_CONFIG, es_config=EARLY_STOPPING,
                exp_idx=exp_idx, total_exps=total_seeds10,
                checkpoint_data=seeds10_data, checkpoint_file=SEEDS10_CHECKPOINT,
                exp_key=key, trainer_ckpt_path=trainer_ckpt_b,
            )

            seeds10_data["completed"][key] = result
            completed_b += 1
            save_checkpoint(seeds10_data, SEEDS10_CHECKPOINT)

            best_acc = result.get("best_metrics", {}).get("accuracy", 0)
            es_info = f" ES@R{result['actual_rounds']}" if result.get("stopped_early") else ""
            log(f"--- Done: Best={best_acc:.1%}{es_info} | {result['runtime_seconds']:.0f}s | [{completed_b}/{total_seeds10}] ---")

        except Exception as e:
            log(f"ERROR in {key}: {e}")
            traceback.print_exc()
            seeds10_data["completed"][key] = {
                "dataset": ds_name, "algorithm": algo, "seed": seed,
                "error": str(e),
            }
            save_checkpoint(seeds10_data, SEEDS10_CHECKPOINT)

    elapsed_b = time.time() - global_start_b
    seeds10_data["metadata"]["end_time"] = datetime.now().isoformat()
    seeds10_data["metadata"]["total_elapsed"] = elapsed_b
    save_checkpoint(seeds10_data, SEEDS10_CHECKPOINT)

    log(f"\n{'='*66}")
    log(f"  PART B COMPLETED: {completed_b}/{total_seeds10}")
    log(f"  Total time: {timedelta(seconds=int(elapsed_b))}")
    log(f"{'='*66}")
finally:
    # Close the log even if the cell is interrupted, so a re-run does not
    # leak a second open handle.
    _log_file.close()
    _log_file = None

## 6. Results Summary

In [None]:
import json
import os

import numpy as np

# This cell only reads the Drive checkpoints, so after Section 1 it can be run
# on its own to inspect results from earlier sessions.
SUMMARY_DATASETS = ['chest_xray', 'Brain_Tumor', 'Skin_Cancer']
SUMMARY_ALGOS = ['FedAvg', 'Ditto', 'HPFL']
SUMMARY_K_VALUES = [10, 20]
SUMMARY_SCALE_SEEDS = [42, 123, 456, 789]
SUMMARY_SEEDS10 = [42, 123, 456, 789, 999, 7, 13, 31, 67, 101]


def _is_ok(rec):
    """True for a non-empty result record that is not an error placeholder."""
    return bool(rec) and 'error' not in rec


def _best_acc(rec):
    """Best global accuracy stored in a result record (0 if absent)."""
    return rec.get('best_metrics', {}).get('accuracy', 0)


def _rounds_run(rec):
    """Number of rounds actually trained (early stopping may end runs early)."""
    return rec.get('actual_rounds', len(rec.get('history', [])))


# --- Part A: Scalability ---
print('=' * 66)
print('  PART A: Scalability Results')
print('=' * 66)

ckpt_a = f'{DRIVE_OUTPUT}/checkpoint_imaging_scalability.json'
if os.path.exists(ckpt_a):
    with open(ckpt_a) as fh:
        data_a = json.load(fh)
    results_a = data_a.get('completed', {})
    n_ok = sum(1 for v in results_a.values() if 'error' not in v)
    n_err = len(results_a) - n_ok
    n_grid_a = len(SUMMARY_DATASETS) * len(SUMMARY_ALGOS) * len(SUMMARY_K_VALUES) * len(SUMMARY_SCALE_SEEDS)
    print(f'Completed: {n_ok}/{n_grid_a} (errors: {n_err})')

    header = f'{"DS":<14} {"Algo":<10} {"K":>4} {"Best Acc":>10} {"Rounds":>8}'
    print(f'\n{header}')
    print('-' * len(header))
    for ds in SUMMARY_DATASETS:
        for algo in SUMMARY_ALGOS:
            for k_val in SUMMARY_K_VALUES:
                runs = [results_a.get(f'{ds}_{algo}_K{k_val}_s{seed}', {}) for seed in SUMMARY_SCALE_SEEDS]
                runs = [r for r in runs if _is_ok(r)]
                if runs:
                    mean_acc = 100 * np.mean([_best_acc(r) for r in runs])
                    mean_rounds = np.mean([_rounds_run(r) for r in runs])
                    print(f'{ds:<14} {algo:<10} {k_val:>4} {mean_acc:>9.1f}% {mean_rounds:>8.1f}')
else:
    print('No scalability checkpoint found.')

# --- Part B: 10-Seed ---
print(f'\n{"=" * 66}')
print('  PART B: 10-Seed Statistical Validation Results')
print('=' * 66)

ckpt_b = f'{DRIVE_OUTPUT}/checkpoint_imaging_seeds10.json'
if os.path.exists(ckpt_b):
    with open(ckpt_b) as fh:
        data_b = json.load(fh)
    results_b = data_b.get('completed', {})
    n_ok = sum(1 for v in results_b.values() if 'error' not in v)
    n_err = len(results_b) - n_ok
    n_grid_b = len(SUMMARY_DATASETS) * len(SUMMARY_ALGOS) * len(SUMMARY_SEEDS10)
    print(f'Completed: {n_ok}/{n_grid_b} (errors: {n_err})')

    header = f'{"DS":<14} {"Algo":<10} {"Mean Acc":>10} {"Std":>8} {"Min":>8} {"Max":>8} {"ES":>4}'
    print(f'\n{header}')
    print('-' * len(header))
    for ds in SUMMARY_DATASETS:
        for algo in SUMMARY_ALGOS:
            runs = [results_b.get(f'{ds}_{algo}_s{seed}', {}) for seed in SUMMARY_SEEDS10]
            runs = [r for r in runs if _is_ok(r)]
            if runs:
                accs = [_best_acc(r) for r in runs]
                es_count = sum(1 for r in runs if r.get('stopped_early'))
                # NOTE: np.std is the population std (ddof=0) across seeds.
                print(f'{ds:<14} {algo:<10} {100*np.mean(accs):>9.1f}% {100*np.std(accs):>7.1f}% {100*min(accs):>7.1f}% {100*max(accs):>7.1f}% {es_count:>3}x')
            else:
                print(f'{ds:<14} {algo:<10} {"--":>10} {"--":>8} {"--":>8} {"--":>8} {"--":>4}')

    # Wilcoxon signed-rank test: HPFL vs FedAvg, paired by seed.
    # One-sided (alternative='greater'): H1 is that HPFL's best accuracy is higher.
    print(f'\n--- Wilcoxon Signed-Rank Test: HPFL vs FedAvg ---')
    try:
        from scipy.stats import wilcoxon as _wilcoxon
    except ImportError:
        _wilcoxon = None
        print('  scipy not available for Wilcoxon test')
    if _wilcoxon is not None:
        for ds in SUMMARY_DATASETS:
            hpfl_accs = []
            fedavg_accs = []
            for seed in SUMMARY_SEEDS10:
                h = results_b.get(f'{ds}_HPFL_s{seed}', {})
                fa = results_b.get(f'{ds}_FedAvg_s{seed}', {})
                if _is_ok(h) and _is_ok(fa):
                    hpfl_accs.append(_best_acc(h))
                    fedavg_accs.append(_best_acc(fa))
            if len(hpfl_accs) < 5:
                print(f'  {ds:<14} insufficient data ({len(hpfl_accs)} pairs)')
                continue
            try:
                stat, p = _wilcoxon(hpfl_accs, fedavg_accs, alternative='greater')
            except ValueError as exc:
                # scipy may raise, e.g. when every paired difference is zero.
                print(f'  {ds:<14} Wilcoxon not computable: {exc}')
                continue
            sig = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else 'ns'
            diff = 100 * (np.mean(hpfl_accs) - np.mean(fedavg_accs))
            print(f'  {ds:<14} HPFL-FedAvg = {diff:+.1f}pp  p={p:.4f}  {sig}')
else:
    print('No seeds10 checkpoint found yet.')

## 7. Download Results

In [None]:
# Download the checkpoint JSONs and experiment logs from Drive to this machine.
from google.colab import files

RESULT_FILES = (
    'checkpoint_imaging_scalability.json',
    'checkpoint_imaging_seeds10.json',
    'experiment_imaging_scalability.log',
    'experiment_imaging_seeds10.log',
)

for result_name in RESULT_FILES:
    result_path = f'{DRIVE_OUTPUT}/{result_name}'
    if not os.path.exists(result_path):
        print(f'Not found: {result_name}')
        continue
    files.download(result_path)
    print(f'Downloaded: {result_name}')
        print(f'Not found: {fname}')