# FL-EHDS: Imaging Byzantine Robustness

Tests the resilience of imaging federated learning to Byzantine attacks.
Evaluates how an adversarial client (label_flip, noise, sign_flip attacks) affects
model accuracy and whether the Krum robust-aggregation defense protects training
(all scenarios here use Krum; Trimmed Mean is not run in this notebook).

**Setup:** Runtime > Change runtime type > **T4 GPU**

**Experiments:** 3 algos × 3 datasets × 4 scenarios × 3 seeds = **108 experiments**

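**Checkpoint:** Saved to Google Drive after **every round** (~1-2 min granularity).
If the session disconnects, first re-run Sections 1–2, because the runtime may have been reset and lost the Drive mount, the repository and the datasets. Then re-run Section 3, which auto-resumes from the last saved round.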

**Estimated time:** ~8-10 hours on T4 GPU

## 1. Setup Environment

In [None]:
import os
from google.colab import drive

# Mount Google Drive so checkpoints and logs outlive the Colab runtime.
drive.mount('/content/drive')

# All experiment artefacts (checkpoint JSON, log) are written here.
DRIVE_OUTPUT = os.path.join('/content/drive/MyDrive', 'FL-EHDS-FLICS2026', 'colab_results')
os.makedirs(DRIVE_OUTPUT, exist_ok=True)
print(f'Drive output: {DRIVE_OUTPUT}')

In [None]:
import torch

# Report the PyTorch build and, when present, the CUDA device and its memory.
cuda_ok = torch.cuda.is_available()
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {cuda_ok}')
if cuda_ok:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    gpu_props = torch.cuda.get_device_properties(0)
    # The attribute name differs between PyTorch versions; try both.
    total_bytes = getattr(gpu_props, 'total_memory', None) or getattr(gpu_props, 'total_mem', 0)
    print(f'Memory: {total_bytes / 1e9:.1f} GB')

In [None]:
# Clone the framework repository; if the clone fails (e.g. the directory already
# exists from an earlier run), update the existing checkout with `git pull`.
!git clone https://github.com/FabioLiberti/FL-EHDS-FLICS2026.git /content/FL-EHDS-FLICS2026 2>/dev/null || (cd /content/FL-EHDS-FLICS2026 && git pull)
# Work from the framework directory: the dataset cells below write to relative
# data/... paths, so they resolve inside this directory.
%cd /content/FL-EHDS-FLICS2026/fl-ehds-framework

# Python packages the framework needs in addition to what the runtime provides.
!pip install -q scikit-learn scipy tqdm Pillow structlog cryptography grpcio aiohttp pydantic pyyaml

## 2. Download Datasets

In [None]:
!pip install -q kagglehub

import os
os.environ['KAGGLE_API_TOKEN'] = 'KGAT_edd561c1bc682c9ad06930bacd164431'

import kagglehub
print(f'kagglehub version: {kagglehub.__version__}')

In [None]:
%%time
import shutil, glob

cache_path = kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")
os.makedirs('data/chest_xray', exist_ok=True)
for item in ['train', 'test', 'val']:
    src = os.path.join(cache_path, 'chest_xray', item)
    if not os.path.exists(src):
        src = os.path.join(cache_path, item)
    dst = f'data/chest_xray/{item}'
    if os.path.exists(src) and not os.path.exists(dst):
        shutil.copytree(src, dst)
shutil.rmtree('data/chest_xray/__MACOSX', ignore_errors=True)
print('Chest X-Ray ready')

In [None]:
%%time
cache_path = kagglehub.dataset_download("fanconic/skin-cancer-malignant-vs-benign")
dst = 'data/Skin Cancer'
if not os.path.exists(dst):
    shutil.copytree(cache_path, dst)
print('Skin Cancer ready')

In [None]:
%%time
cache_path = kagglehub.dataset_download("masoudnickparvar/brain-tumor-mri-dataset")
os.makedirs('data/Brain_Tumor', exist_ok=True)
for root, dirs, files in os.walk(cache_path):
    for d in dirs:
        d_lower = d.lower()
        if d_lower in ['glioma', 'meningioma', 'pituitary', 'notumor', 'no_tumor', 'healthy']:
            target = 'healthy' if d_lower in ['notumor', 'no_tumor'] else d_lower
            src = os.path.join(root, d)
            dst_dir = f'data/Brain_Tumor/{target}'
            if not os.path.exists(dst_dir):
                shutil.copytree(src, dst_dir)
            else:
                for f in os.listdir(src):
                    sf, df = os.path.join(src, f), os.path.join(dst_dir, f)
                    if os.path.isfile(sf) and not os.path.exists(df):
                        shutil.copy2(sf, df)
print('Brain Tumor ready')

In [None]:
import glob

print('=== Dataset Summary ===')
# For each dataset, count image files at any depth and list its top-level
# sub-folders. A missing dataset directory is reported instead of raising
# FileNotFoundError from os.listdir, so one failed download does not hide
# the status of the others.
for ds_name, ds_path in [('Chest X-Ray', 'data/chest_xray'),
                         ('Skin Cancer', 'data/Skin Cancer'),
                         ('Brain Tumor', 'data/Brain_Tumor')]:
    if not os.path.isdir(ds_path):
        print(f'  {ds_name:15s}: MISSING ({ds_path})')
        continue
    count = sum(1 for path in glob.iglob(f'{ds_path}/**/*.*', recursive=True)
                if path.lower().endswith(('.jpg', '.jpeg', '.png')))
    # Sorted so the listing is stable across runs/filesystems.
    subdirs = sorted(d for d in os.listdir(ds_path) if os.path.isdir(os.path.join(ds_path, d)))
    print(f'  {ds_name:15s}: {count:5d} images, classes: {subdirs}')

## 3. Run Byzantine Experiments

**108 experiments** = 3 algos × 3 datasets × 4 scenarios (none + 3 attacks) × 3 seeds

Each scenario uses 5 clients; in the attack scenarios, client 0 is adversarial (20% Byzantine).
Defense: Krum, which tolerates f Byzantine clients when n > 2f + 2 (here n = 5, f = 1, and 5 > 4).

Checkpoint saved to Drive after **every training round**.
If Colab disconnects, re-run Sections 1–2 and then the cells of this section — training auto-resumes.

In [None]:
import sys
import json
import time
import shutil
import tempfile
import traceback
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Optional, Any

import numpy as np
import torch

sys.path.insert(0, '/content/FL-EHDS-FLICS2026/fl-ehds-framework')

from terminal.fl_trainer import ImageFederatedTrainer, _detect_device
from core.byzantine_resilience import ByzantineConfig

# ======================================================================
# Configuration
# ======================================================================

# Federated algorithms compared in every scenario, and the seeds each run repeats with.
ALGORITHMS = ["FedAvg", "Ditto", "HPFL"]
SEEDS = [42, 123, 456]

# (scenario label, attack applied to client 0). The first entry has no attack
# and is the baseline; Krum aggregation is active in every scenario.
ATTACK_SCENARIOS = [
    ("no_attack", None),
    *((name, name) for name in ("label_flip", "noise", "sign_flip")),
]

NUM_CLIENTS = 5
NUM_BYZANTINE = 1  # one adversarial client (client 0) => 20% of 5

IMAGING_DATASETS = {
    "chest_xray": dict(data_dir="data/chest_xray", num_classes=2, short="CX"),
    "Brain_Tumor": dict(data_dir="data/Brain_Tumor", num_classes=4, short="BT"),
    "Skin_Cancer": dict(data_dir="data/Skin Cancer", num_classes=2, short="SC"),
}

# Base training configuration shared by all runs.
IMAGING_CONFIG = {
    "num_clients": NUM_CLIENTS, "num_rounds": 20, "local_epochs": 2, "batch_size": 32,
    "learning_rate": 0.001, "model_type": "resnet18", "is_iid": False, "alpha": 0.5,
    "freeze_backbone": False, "freeze_level": 2, "use_fedbn": True,
    "use_class_weights": True, "use_amp": True, "mu": 0.1,
}

# Per-dataset values merged over IMAGING_CONFIG for that dataset's runs.
DATASET_OVERRIDES = {"Brain_Tumor": {"learning_rate": 0.0005}}

# "enabled" toggles early stopping; the other keys configure the monitor.
EARLY_STOPPING = {
    "enabled": True, "patience": 4, "min_delta": 0.003, "min_rounds": 8, "metric": "accuracy",
}

OUTPUT_DIR = Path(DRIVE_OUTPUT)
CHECKPOINT_FILE = "checkpoint_imaging_byzantine.json"
LOG_FILE = "experiment_imaging_byzantine.log"
TRAINER_STATE_FILE = ".trainer_state_byzantine.pt"

_n_experiments = len(ALGORITHMS) * len(IMAGING_DATASETS) * len(ATTACK_SCENARIOS) * len(SEEDS)
print(f"Device: {_detect_device(None)}")
print(
    f"Experiments: {len(ALGORITHMS)} algos x {len(IMAGING_DATASETS)} ds x "
    f"{len(ATTACK_SCENARIOS)} scenarios x {len(SEEDS)} seeds = {_n_experiments}"
)
print(f"Byzantine setup: {NUM_BYZANTINE}/{NUM_CLIENTS} adversarial clients, defense=Krum")
print(f"Checkpoint: {OUTPUT_DIR / CHECKPOINT_FILE}")

In [None]:
# ======================================================================
# Utilities
# ======================================================================

# Open text file that log() mirrors messages into; None disables mirroring.
_log_file = None

def log(msg, also_print=True):
    """Emit ``msg`` prefixed with an ``[HH:MM:SS]`` timestamp.

    The line is printed (flushed immediately) unless ``also_print`` is False,
    and appended to ``_log_file`` when one is open. Writing to the file is
    best-effort: any error there is ignored so logging never interrupts
    training.
    """
    stamped = "[{}] {}".format(datetime.now().strftime("%H:%M:%S"), msg)
    if also_print:
        print(stamped, flush=True)
    if not _log_file:
        return
    try:
        _log_file.write(stamped + "\n")
        _log_file.flush()
    except Exception:
        pass

def save_checkpoint(data):
    """Atomically write the checkpoint dict ``data`` as JSON to OUTPUT_DIR.

    The JSON is first written and fsync'ed to a temporary file in the same
    directory. The previous checkpoint, if any, is copied to ``<name>.bak``,
    and the temp file then replaces the checkpoint via ``os.replace``. An
    interrupted write therefore never leaves a truncated main checkpoint.

    ``data["metadata"]["last_save"]`` is stamped with the current time; a
    missing ``metadata`` section is created instead of raising KeyError.
    Values JSON cannot encode are stored via ``str``.

    Raises:
        Any error from writing or replacing the file, after the temporary
        file has been removed.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    path = OUTPUT_DIR / CHECKPOINT_FILE
    bak = OUTPUT_DIR / (CHECKPOINT_FILE + ".bak")
    data.setdefault("metadata", {})["last_save"] = datetime.now().isoformat()
    fd, tmp = tempfile.mkstemp(dir=str(OUTPUT_DIR), prefix=".ckpt_byz_", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)
            f.flush()
            os.fsync(f.fileno())
        if path.exists():
            shutil.copy2(str(path), str(bak))
        os.replace(tmp, str(path))
    except Exception:
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise

def load_checkpoint():
    """Return the saved checkpoint dict, falling back to the ``.bak`` copy.

    The main checkpoint is tried first, then its backup. A file is skipped
    when it cannot be read (OSError) or decoded. ValueError covers both
    json.JSONDecodeError and UnicodeDecodeError from a corrupted file; the
    latter previously escaped and prevented the backup from being tried. A
    file whose top-level JSON value is not an object is also skipped.

    Returns:
        The checkpoint dict, or None if no usable file exists.
    """
    candidates = (OUTPUT_DIR / CHECKPOINT_FILE, OUTPUT_DIR / (CHECKPOINT_FILE + ".bak"))
    for p in candidates:
        if not p.exists():
            continue
        try:
            with open(p, encoding="utf-8") as f:
                data = json.load(f)
        except (ValueError, OSError):
            continue
        if isinstance(data, dict):
            return data
    return None

def _cleanup_gpu():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    import gc
    gc.collect()

class EarlyStoppingMonitor:
    """Plateau detector that signals when a metric stops improving.

    A round counts as an improvement only if its metric beats the best value
    seen so far by more than ``min_delta``. Non-improving rounds increment
    ``counter`` even during the first ``min_rounds`` rounds. The stop signal
    itself is suppressed until ``round_num >= min_rounds``.
    """

    def __init__(self, patience=4, min_delta=0.003, min_rounds=8, metric="accuracy"):
        self.patience, self.min_delta = patience, min_delta
        self.min_rounds, self.metric = min_rounds, metric
        # Best value/round seen so far and the current run of non-improving rounds.
        self.best_value = float('-inf')
        self.best_round = 0
        self.counter = 0

    def check(self, round_num, metrics):
        """Record ``metrics`` for ``round_num``; return True when training should stop.

        A missing metric key is treated as 0.
        """
        current = metrics.get(self.metric, 0)
        improved = current > self.best_value + self.min_delta
        if improved:
            self.best_value, self.best_round, self.counter = current, round_num, 0
        else:
            self.counter += 1
        in_warmup = round_num < self.min_rounds
        return not in_warmup and self.counter >= self.patience

def _evaluate_per_client(trainer):
    model = trainer.global_model
    model.eval()
    per_client = {}
    is_hpfl = trainer.algorithm == "HPFL"
    if is_hpfl and hasattr(trainer, '_hpfl_classifier_names'):
        saved_cls = {n: p.data.clone() for n, p in model.named_parameters()
                     if n in trainer._hpfl_classifier_names}
    else:
        is_hpfl = False
    with torch.no_grad():
        for cid in range(trainer.num_clients):
            if is_hpfl:
                for n, p in model.named_parameters():
                    if n in trainer._hpfl_classifier_names:
                        p.data.copy_(trainer.client_classifiers[cid][n])
            X, y = trainer.client_test_data[cid]
            X_t = torch.FloatTensor(X).to(trainer.device) if isinstance(X, np.ndarray) else X.to(trainer.device)
            y_t = torch.LongTensor(y).to(trainer.device) if isinstance(y, np.ndarray) else y.to(trainer.device)
            correct = total = 0
            for i in range(0, len(y_t), 64):
                out = model(X_t[i:i+64])
                preds = out.argmax(dim=1)
                correct += (preds == y_t[i:i+64]).sum().item()
                total += len(y_t[i:i+64])
            per_client[str(cid)] = correct / total if total > 0 else 0.0
    if is_hpfl:
        for n, p in model.named_parameters():
            if n in trainer._hpfl_classifier_names:
                p.data.copy_(saved_cls[n])
    return per_client

def _compute_fairness(per_client_acc):
    accs = list(per_client_acc.values())
    if not accs:
        return {}
    jain = (sum(accs)**2) / (len(accs) * sum(a**2 for a in accs)) if accs else 0
    sorted_a = sorted(accs)
    n = len(sorted_a)
    cumsum = np.cumsum(sorted_a)
    gini = (2 * sum((i+1)*v for i, v in enumerate(sorted_a))) / (n * cumsum[-1]) - (n+1)/n if cumsum[-1] > 0 else 0
    return {
        "mean": round(float(np.mean(accs)), 4),
        "std": round(float(np.std(accs)), 4),
        "min": round(float(min(accs)), 4),
        "max": round(float(max(accs)), 4),
        "jain_index": round(float(jain), 4),
        "gini": round(float(max(0, gini)), 4),
    }

print('Utilities loaded OK')

In [None]:
# ======================================================================
# Training function: Byzantine scenario
# ======================================================================

def _flip_labels(y, num_classes):
    """Mirror integer class labels: label ``c`` becomes ``num_classes - 1 - c``.

    NumPy arrays and torch tensors keep their type (the subtraction broadcasts
    for both); plain Python sequences are converted to a NumPy array first.
    """
    if isinstance(y, (list, tuple)):
        y = np.asarray(y)
    return (num_classes - 1) - y


def _best_from_history(history):
    """Return ``(best_acc, best_round)`` over ``history``; the earliest maximum wins.

    Mirrors the running-best update in the training loop (strict ``>``
    against a starting value of 0.0), so the result matches what that loop
    tracks for the same rounds.
    """
    best_acc, best_round = 0.0, 0
    for entry in history:
        if entry["accuracy"] > best_acc:
            best_acc, best_round = entry["accuracy"], entry["round"]
    return best_acc, best_round


def run_single_byzantine(
    ds_name, data_dir, algorithm, seed, attack_label, attack_type,
    config, es_config,
    exp_idx, total_exps,
    checkpoint_data=None, exp_key=None,
    trainer_ckpt_path=None,
):
    """Train one (dataset, algorithm, attack, seed) run under Krum aggregation.

    Args:
        ds_name: dataset key. Selects DATASET_OVERRIDES and, for label
            flipping, the class count in IMAGING_DATASETS.
        data_dir: image directory passed to ImageFederatedTrainer.
        algorithm: FL algorithm name (e.g. "FedAvg", "Ditto", "HPFL").
        seed: random seed for the trainer.
        attack_label: scenario name stored in logs and in the result.
        attack_type: None, "label_flip", "noise" or "sign_flip".
        config: base training configuration (IMAGING_CONFIG-style dict).
        es_config: early-stopping settings. "enabled" toggles the monitor;
            the remaining keys are EarlyStoppingMonitor arguments.
        exp_idx, total_exps: position in the experiment grid, for log lines.
        checkpoint_data: shared checkpoint dict. When given together with
            exp_key, progress is written into it (and saved) after every
            round, and its "in_progress" entry is used to resume.
        exp_key: unique key identifying this run inside checkpoint_data.
        trainer_ckpt_path: file for the trainer's own state. It is saved
            every round and deleted when the run finishes successfully.

    Returns:
        dict with the per-round history, final/best metrics, per-client
        accuracies, fairness summary, runtime (including time spent before a
        resume) and the effective config.
    """
    start = time.time()
    num_rounds = config["num_rounds"]

    cfg = {**config}
    if ds_name in DATASET_OVERRIDES:
        cfg.update(DATASET_OVERRIDES[ds_name])

    def make_monitor():
        # Fresh early-stopping monitor (or None when disabled).
        if not es_config.get("enabled"):
            return None
        return EarlyStoppingMonitor(**{k: v for k, v in es_config.items() if k != "enabled"})

    # Byzantine config: Krum defense, 1 adversarial out of 5
    byz_config = ByzantineConfig(
        aggregation_rule="krum",
        num_byzantine=NUM_BYZANTINE,
        multi_krum_m=NUM_CLIENTS - NUM_BYZANTINE,
        enable_detection=True,
        detection_threshold=3.0,
    )

    trainer = ImageFederatedTrainer(
        data_dir=data_dir,
        num_clients=cfg["num_clients"],
        algorithm=algorithm,
        local_epochs=cfg["local_epochs"],
        batch_size=cfg["batch_size"],
        learning_rate=cfg["learning_rate"],
        is_iid=cfg["is_iid"],
        alpha=cfg["alpha"],
        mu=cfg.get("mu", 0.1),
        seed=seed,
        model_type=cfg["model_type"],
        freeze_backbone=cfg.get("freeze_backbone", False),
        freeze_level=cfg.get("freeze_level"),
        use_fedbn=cfg.get("use_fedbn", False),
        use_class_weights=cfg.get("use_class_weights", True),
        use_amp=cfg.get("use_amp", True),
        byzantine_config=byz_config,
    )
    try:
        trainer.num_rounds = num_rounds

        # NOTE(review): only label_flip is injected in this function. For
        # "noise" and "sign_flip", nothing here passes attack_type to the
        # trainer or to ByzantineConfig. Unless ImageFederatedTrainer applies
        # those attacks by some mechanism not visible here, those scenarios
        # train exactly like "no_attack". Verify before interpreting results.
        if attack_type == "label_flip" and 0 in trainer.client_data:
            X, y = trainer.client_data[0]
            # Use the dataset's class count, not the classes present on client
            # 0: with a non-IID split client 0 can miss classes, and
            # (k - 1) - y with a too-small k maps high labels to invalid
            # negative ids (e.g. label 3 -> -1 when only 3 of 4 classes occur).
            # Assumes labels are encoded 0..num_classes-1 — confirm against
            # ImageFederatedTrainer.
            num_classes = IMAGING_DATASETS.get(ds_name, {}).get("num_classes")
            if num_classes is None:
                labels = y.tolist() if hasattr(y, 'tolist') else list(y)
                num_classes = int(max(labels)) + 1 if labels else 1
            trainer.client_data[0] = (X, _flip_labels(y, num_classes))
            log(f"  Attack: label_flip on client 0 ({len(y)} samples, {num_classes} classes)")
        elif attack_type in ("noise", "sign_flip"):
            log(f"  Attack: {attack_type} on client 0 (applied during aggregation)")
        else:
            log(f"  No attack (baseline with Krum defense)")

        es = make_monitor()

        history = []
        best_acc = 0.0
        best_round = 0
        start_round = 0
        stopped_early = False

        in_prog = checkpoint_data.get("in_progress") if checkpoint_data else None
        if (in_prog and in_prog.get("key") == exp_key
                and trainer_ckpt_path and Path(trainer_ckpt_path).exists()):
            try:
                start_round = trainer.load_checkpoint(trainer_ckpt_path)
                # Keep only the rounds the restored trainer state covers. If the
                # last trainer.save_checkpoint failed, the JSON history is ahead
                # of it, and re-running those rounds would duplicate entries.
                history = list(in_prog.get("history", []))[:start_round]
                best_acc, best_round = _best_from_history(history)
                if es:
                    # Replay so patience/best state matches the uninterrupted run.
                    for h in history:
                        stopped_early = es.check(h["round"], {"accuracy": h["accuracy"]})
                prior_elapsed = float(in_prog.get("elapsed_seconds") or 0.0)
            except Exception as e:
                log(f"  WARNING: resume failed ({e}), restarting from R1")
                start_round = 0
                history = []
                best_acc = 0.0
                best_round = 0
                stopped_early = False
                es = make_monitor()
            else:
                # Count wall-clock time already spent before the disconnect.
                start -= prior_elapsed
                log(f"  RESUMED from round {start_round} (best={best_acc:.1%})")

        if stopped_early:
            # The session died between saving the last round and acting on the
            # early-stop signal; honour it instead of training further.
            log(f"  -> Early stop at R{start_round} (best={best_acc:.1%} at r{best_round})")
        first_round = num_rounds if stopped_early else start_round

        for r in range(first_round, num_rounds):
            rr = trainer.train_round(r)

            client_metrics = [
                {"client_id": cr.client_id, "train_loss": round(cr.train_loss, 6),
                 "train_acc": round(cr.train_acc, 6), "num_samples": cr.num_samples,
                 "epochs_completed": cr.epochs_completed}
                for cr in rr.client_results
            ]

            byz_info = {}
            if rr.byzantine_selected is not None:
                byz_info["selected"] = rr.byzantine_selected
            if rr.byzantine_rejected is not None:
                byz_info["rejected"] = rr.byzantine_rejected
            if rr.byzantine_trust_scores is not None:
                byz_info["trust_scores"] = [round(s, 4) for s in rr.byzantine_trust_scores]

            metrics = {
                "round": r + 1, "accuracy": rr.global_acc, "loss": rr.global_loss,
                "f1": rr.global_f1, "precision": rr.global_precision,
                "recall": rr.global_recall, "auc": rr.global_auc,
                "time_seconds": round(rr.time_seconds, 2),
                "client_results": client_metrics,
                "byzantine_info": byz_info,
            }
            history.append(metrics)

            if rr.global_acc > best_acc:
                best_acc = rr.global_acc
                best_round = r + 1

            byz_sel = f" sel={byz_info.get('selected', '?')}" if byz_info else ""
            log(f"[{exp_idx}/{total_exps}] {ds_name} | {algorithm} | {attack_label} | s{seed} | "
                f"R{r+1}/{num_rounds} | Acc:{rr.global_acc:.1%} | Best:{best_acc:.1%}(r{best_round}){byz_sel}")

            if checkpoint_data is not None and exp_key:
                if trainer_ckpt_path:
                    try:
                        trainer.save_checkpoint(trainer_ckpt_path)
                    except Exception as e:
                        # Best-effort, but visible: a resume will fall back to
                        # the last successfully saved trainer state.
                        log(f"  WARNING: trainer state save failed at R{r+1} ({e})")
                checkpoint_data["in_progress"] = {
                    "key": exp_key, "dataset": ds_name, "algorithm": algorithm,
                    "seed": seed, "attack": attack_label,
                    "round": r + 1, "total_rounds": num_rounds,
                    "best_acc": best_acc, "best_round": best_round,
                    "history": history, "elapsed_seconds": round(time.time() - start, 1),
                }
                save_checkpoint(checkpoint_data)

            if es and es.check(r + 1, {"accuracy": rr.global_acc}):
                stopped_early = True
                log(f"  -> Early stop at R{r+1} (best={best_acc:.1%} at r{best_round})")
                break

        per_client_acc = _evaluate_per_client(trainer)
        fairness = _compute_fairness(per_client_acc)
        elapsed = time.time() - start

        result = {
            "dataset": ds_name, "algorithm": algorithm, "seed": seed,
            "attack": attack_label, "attack_type": attack_type,
            "num_byzantine": NUM_BYZANTINE, "defense": "krum",
            "history": history, "final_metrics": history[-1] if history else {},
            "per_client_acc": per_client_acc, "fairness": fairness,
            "runtime_seconds": round(elapsed, 1), "config": cfg,
            "stopped_early": stopped_early,
            "actual_rounds": len(history),
            "best_metrics": {"accuracy": best_acc, "round": best_round},
            "best_round": best_round,
        }

        if checkpoint_data is not None:
            checkpoint_data.pop("in_progress", None)
        if trainer_ckpt_path:
            try:
                Path(trainer_ckpt_path).unlink(missing_ok=True)
            except OSError:
                pass
    finally:
        # Drop the trainer (model, optimizer state, client data) and release
        # GPU memory even when training raised, so the next run starts clean.
        del trainer
        _cleanup_gpu()
    return result

print('Training function loaded OK')

In [None]:
# ======================================================================
# MAIN EXPERIMENT LOOP
# ======================================================================

_log_file = open(OUTPUT_DIR / LOG_FILE, "a")

try:
    # Full grid, ordered attack scenario -> dataset -> algorithm -> seed.
    experiments = [
        (attack_label, attack_type, ds_name, algo, seed)
        for attack_label, attack_type in ATTACK_SCENARIOS
        for ds_name in IMAGING_DATASETS
        for algo in ALGORITHMS
        for seed in SEEDS
    ]
    total_exps = len(experiments)

    checkpoint_data = load_checkpoint()
    if checkpoint_data:
        # Tolerate checkpoints missing a top-level section.
        checkpoint_data.setdefault("completed", {})
        checkpoint_data.setdefault("metadata", {})
        n_ok = sum(1 for v in checkpoint_data["completed"].values() if "error" not in v)
        n_err = len(checkpoint_data["completed"]) - n_ok
        log(f"AUTO-RESUMED: {n_ok}/{total_exps} completed ({n_err} recorded as errors)")
    else:
        checkpoint_data = {
            "completed": {},
            "metadata": {
                "total_experiments": total_exps,
                "purpose": "Imaging Byzantine robustness: attack resilience on ResNet18 with Krum defense",
                "algorithms": ALGORITHMS,
                "datasets": list(IMAGING_DATASETS.keys()),
                "attacks": [a[0] for a in ATTACK_SCENARIOS],
                "defense": "krum",
                "num_byzantine": NUM_BYZANTINE,
                "num_clients": NUM_CLIENTS,
                "seeds": SEEDS,
                "start_time": datetime.now().isoformat(),
                "last_save": None,
                "version": "imaging_byzantine_v1",
            }
        }

    log(f"\n{'='*66}")
    log(f"  FL-EHDS Imaging — Byzantine Robustness")
    log(f"  {total_exps} experiments = {len(ATTACK_SCENARIOS)} scenarios x {len(ALGORITHMS)} algos x {len(IMAGING_DATASETS)} DS x {len(SEEDS)} seeds")
    log(f"  Attacks: {[a[0] for a in ATTACK_SCENARIOS]}")
    log(f"  Defense: Krum ({NUM_BYZANTINE}/{NUM_CLIENTS} Byzantine)")
    log(f"  Device: {_detect_device(None)}")
    log(f"{'='*66}")

    global_start = time.time()
    # Successful runs overall (this and earlier sessions), for progress lines.
    completed_count = sum(1 for v in checkpoint_data["completed"].values() if "error" not in v)
    # Runs attempted in THIS session. The ETA divides this session's elapsed
    # time by this count; dividing by the all-time count made the ETA far too
    # optimistic after a resume, because earlier sessions' time is not in
    # `elapsed`.
    session_attempted = 0
    trainer_ckpt_path = str(OUTPUT_DIR / TRAINER_STATE_FILE)

    for exp_idx, (attack_label, attack_type, ds_name, algo, seed) in enumerate(experiments, 1):
        key = f"{ds_name}_{algo}_{attack_label}_s{seed}"

        # Failed runs are also stored under "completed" (with an "error"
        # field) and are therefore skipped here; delete those entries from
        # the checkpoint to retry them.
        if key in checkpoint_data["completed"]:
            continue

        ds_info = IMAGING_DATASETS[ds_name]

        elapsed = time.time() - global_start
        if session_attempted > 0:
            remaining = max(total_exps - len(checkpoint_data["completed"]), 0)
            eta = str(timedelta(seconds=int(remaining * elapsed / session_attempted)))
        else:
            eta = "calculating..."

        log(f"\n--- [{exp_idx}/{total_exps}] {ds_name} | {algo} | {attack_label} | seed={seed} | ETA: {eta} ---")

        try:
            result = run_single_byzantine(
                ds_name=ds_name, data_dir=ds_info["data_dir"],
                algorithm=algo, seed=seed,
                attack_label=attack_label, attack_type=attack_type,
                config=IMAGING_CONFIG, es_config=EARLY_STOPPING,
                exp_idx=exp_idx, total_exps=total_exps,
                checkpoint_data=checkpoint_data, exp_key=key,
                trainer_ckpt_path=trainer_ckpt_path,
            )

            checkpoint_data["completed"][key] = result
            completed_count += 1
            save_checkpoint(checkpoint_data)

            best_acc = result.get("best_metrics", {}).get("accuracy", 0)
            es_info = f" ES@R{result['actual_rounds']}" if result.get("stopped_early") else ""
            log(f"--- Done: Best={best_acc:.1%}{es_info} | {result['runtime_seconds']:.0f}s | [{completed_count}/{total_exps}] ---")

        except Exception as e:
            log(f"ERROR in {key}: {e}")
            traceback.print_exc()
            # Keep the full traceback in the persistent log as well.
            log(traceback.format_exc(), also_print=False)
            # A failed run must not leave a stale "in_progress" entry behind.
            checkpoint_data.pop("in_progress", None)
            checkpoint_data["completed"][key] = {
                "dataset": ds_name, "algorithm": algo, "seed": seed,
                "attack": attack_label, "error": str(e),
            }
            save_checkpoint(checkpoint_data)
            _cleanup_gpu()

        session_attempted += 1

    elapsed_total = time.time() - global_start
    checkpoint_data["metadata"]["end_time"] = datetime.now().isoformat()
    checkpoint_data["metadata"]["total_elapsed"] = elapsed_total
    save_checkpoint(checkpoint_data)

    log(f"\n{'='*66}")
    log(f"  COMPLETED: {completed_count}/{total_exps}")
    log(f"  Total time: {timedelta(seconds=int(elapsed_total))}")
    log(f"{'='*66}")
finally:
    # Close the log even on KeyboardInterrupt, and stop log() from writing to
    # the closed handle afterwards.
    if _log_file:
        _log_file.close()
    _log_file = None

## 4. Check Progress & Results

In [None]:
import json, numpy as np

# Print progress and a mean-best-accuracy table (over seeds) read from the
# Drive checkpoint; works on its own, without running Section 3 first.
ckpt_path = f'{DRIVE_OUTPUT}/checkpoint_imaging_byzantine.json'
if os.path.exists(ckpt_path):
    with open(ckpt_path) as f:
        data = json.load(f)

    completed = data.get('completed', {})
    meta = data.get('metadata', {})
    n_ok = sum(1 for v in completed.values() if 'error' not in v)
    n_err = sum(1 for v in completed.values() if 'error' in v)
    total = meta.get('total_experiments', '?')

    print(f'Completed: {n_ok}/{total} (errors: {n_err})')

    in_prog = data.get('in_progress', {})
    if in_prog:
        print(f'In progress: {in_prog.get("key", "?")} '
              f'round {in_prog.get("round", "?")}/{in_prog.get("total_rounds", "?")}')

    # Take the experiment grid from the checkpoint's own metadata, so the table
    # matches the configuration that produced it; fall back to the defaults
    # for older checkpoints without these fields.
    attacks = meta.get('attacks') or ['no_attack', 'label_flip', 'noise', 'sign_flip']
    datasets = meta.get('datasets') or ['chest_xray', 'Brain_Tumor', 'Skin_Cancer']
    algos = meta.get('algorithms') or ['FedAvg', 'Ditto', 'HPFL']
    seeds = meta.get('seeds') or [42, 123, 456]

    header = f'{"DS":<14} {"Algo":<8}' + ''.join(f' {a:>12}' for a in attacks)
    print(f'\n{header}')
    print('-' * len(header))

    for ds in datasets:
        for algo in algos:
            row = f'{ds:<14} {algo:<8}'
            for attack in attacks:
                accs = []
                for seed in seeds:
                    r = completed.get(f'{ds}_{algo}_{attack}_s{seed}', {})
                    if r and 'error' not in r:
                        accs.append(r.get('best_metrics', {}).get('accuracy', 0))
                if accs:
                    row += f' {100*np.mean(accs):>11.1f}%'
                else:
                    row += f' {"--":>12}'
            print(row)
else:
    print('No checkpoint found yet.')

## 5. Download Results

In [None]:
from google.colab import files

# Offer the checkpoint JSON (and, if present, the run log) as browser downloads.
ckpt_path = f'{DRIVE_OUTPUT}/checkpoint_imaging_byzantine.json'
log_path = f'{DRIVE_OUTPUT}/experiment_imaging_byzantine.log'
if not os.path.exists(ckpt_path):
    print('No checkpoint to download yet.')
else:
    files.download(ckpt_path)
    print('Downloaded: checkpoint_imaging_byzantine.json')
    if os.path.exists(log_path):
        files.download(log_path)