import os
import io
import gzip
import csv
import shutil
import tempfile
import zipfile
import glob
import numpy as np

In [None]:

import subprocess
import sys
def pip_install(pkg):
    """Quietly pip-install *pkg* into the interpreter that runs this notebook."""
    command = [sys.executable, "-m", "pip", "install", "-q", pkg]
    subprocess.check_call(command)


# Runtime dependencies that may be missing from a fresh environment.
for p in ("nibabel", "scipy", "requests"):
    pip_install(p)
print("OK: nibabel, scipy, requests")

In [None]:

import nibabel as nib
import requests
import matplotlib
matplotlib.use("Agg")

In [None]:

import matplotlib.pyplot as plt

# Root path: /content on Colab, otherwise the current working directory
# (for local runs, start the notebook from the project folder).
ROOT = os.path.abspath("/content" if os.path.isdir("/content") else os.getcwd())
# Data folder: each case needs *_brain_{flair,t1ce,t1,t2}.nii.zip and *_final_seg.nii.
DATA_DIR = os.path.join(ROOT, "data")
# All outputs (per-case predictions, CSVs, figures) go under one folder.
OUT_ROOT = os.path.join(ROOT, "ablation_out")
FIG_DIR = os.path.join(OUT_ROOT, "figures")
os.makedirs(OUT_ROOT, exist_ok=True)
os.makedirs(FIG_DIR, exist_ok=True)

# Swin UNETR model API. Set the ABLATION_API_BASE environment variable to point
# at another server without editing the notebook; a trailing "/" is stripped so
# PREDICT_URL never contains "//".
# NOTE(review): plain HTTP to a raw IP address -- image volumes are uploaded
# unencrypted; confirm this is acceptable for the data being processed.
API_BASE = os.environ.get("ABLATION_API_BASE", "http://216.126.237.218:8086").rstrip("/")
PREDICT_URL = f"{API_BASE}/predict"

# Fallback case IDs: used only when find_cases() finds no complete case in
# DATA_DIR (or DATA_DIR does not exist).
CASE_IDS = ["00000042", "00000057"]

def find_cases(data_dir):
    """Return the sorted IDs of complete cases found in *data_dir*.

    A case ``<id>`` is complete when ``<id>_brain_<m>.nii.zip`` exists for every
    m in flair, t1ce, t1, t2 and ``<id>_final_seg.nii`` exists. Candidate IDs
    are discovered from the FLAIR archives.
    """
    suffix = "_brain_flair.nii.zip"
    # glob.escape: a data_dir containing glob metacharacters ("[", "]", "*", "?")
    # would otherwise be treated as a pattern and silently match nothing.
    pattern = os.path.join(glob.escape(data_dir), "*" + suffix)
    found = []
    for path in glob.glob(pattern):
        # Strip the suffix (str.replace would also remove any inner occurrence).
        case_id = os.path.basename(path)[: -len(suffix)]
        required = [f"{case_id}_brain_{m}.nii.zip" for m in ["flair", "t1ce", "t1", "t2"]]
        required.append(f"{case_id}_final_seg.nii")
        if all(os.path.isfile(os.path.join(data_dir, name)) for name in required):
            found.append(case_id)
    return sorted(found)

# Prefer the complete cases discovered on disk; fall back to the manual list.
cases = (find_cases(DATA_DIR) if os.path.isdir(DATA_DIR) else []) or list(CASE_IDS)
print("Data dir:", DATA_DIR)
print("Output root:", OUT_ROOT)
print("Cases to process:", cases or "(none found — put data in DATA_DIR)")

In [None]:
# ========== Loading, modality replacement and API helpers ==========
# Modality keys, in the order used for loading, scenario generation and upload.
MODALITY_ORDER = ["flair", "t1ce", "t1", "t2"]
# Non-zero label values that the metric functions score (label 0 is not scored).
LABELS = [1, 2, 4]
# RGB colour (floats in [0, 1]) used by seg_to_png to paint each label value.
LABEL_RGB = {0: (0, 0, 0), 1: (1, 0, 0), 2: (0, 1, 0), 4: (0, 0, 1)}

def unzip_nii(zip_path, out_dir):
    """Extract the first ``.nii`` member of *zip_path* into *out_dir*.

    macOS metadata entries (anything under ``__MACOSX/`` or named ``._*``) are
    skipped because they are not NIfTI images despite their ``.nii`` suffix.

    Returns:
        The path of the extracted file. ``ZipFile.extract`` recreates the
        member's directory structure under *out_dir*, so its return value is
        used rather than rebuilding the path from the member's basename.

    Raises:
        FileNotFoundError: if the archive contains no usable ``.nii`` member.
    """
    with zipfile.ZipFile(zip_path, "r") as z:
        names = [
            n for n in z.namelist()
            if n.endswith(".nii")  # also excludes ".nii.gz" and directory entries
            and not n.startswith("__MACOSX/")
            and not n.rsplit("/", 1)[-1].startswith("._")
        ]
        if not names:
            raise FileNotFoundError(f"No .nii in {zip_path}")
        return z.extract(names[0], out_dir)

def load_case_modalities(data_dir, case_id, tmp_dir):
    """Unzip and load every modality of *case_id* as float32 arrays.

    Each ``<case_id>_brain_<modality>.nii.zip`` in *data_dir* is extracted into
    *tmp_dir* and read with nibabel.

    Returns:
        ``(volumes, affine)`` where *volumes* maps each MODALITY_ORDER key to
        its float32 array and *affine* is the affine of the first modality.
    """
    volumes = {}
    first_affine = None
    for modality in MODALITY_ORDER:
        archive = os.path.join(data_dir, f"{case_id}_brain_{modality}.nii.zip")
        image = nib.load(unzip_nii(archive, tmp_dir))
        volumes[modality] = np.asarray(image.dataobj, dtype=np.float32)
        if first_affine is None:
            first_affine = image.affine
    return volumes, first_affine

def replace_black(vol):
    """Return an all-zero array with the shape and dtype of *vol*."""
    blank = np.zeros_like(vol)
    return blank
def replace_gaussian(vol, ref_vol=None, seed=42, rel_std=0.2):
    """Return *vol* plus Gaussian noise, clipped to the reference intensity range.

    Args:
        vol: array the noise is added to; the result has its shape and dtype.
        ref_vol: array whose intensity range sets the noise scale and the clip
            bounds; defaults to *vol*.
        seed: seed for ``np.random.default_rng``. It is fixed, so calls with the
            same shape produce identical noise (reproducible scenarios).
        rel_std: noise standard deviation as a fraction of the reference range.

    NaN voxels in the reference are ignored both for the range and for the clip
    bounds, so a single NaN voxel no longer turns the whole output into NaN.
    A non-positive (or undefined) range falls back to 1.0.

    NOTE(review): the noise is added on top of *vol*; the scenario builders pass
    the removed modality itself, so its signal is still present in this
    "replacement". Confirm this is the intended ablation (vs. pure noise).
    """
    ref = ref_vol if ref_vol is not None else vol
    lo, hi = np.nanmin(ref), np.nanmax(ref)
    value_range = hi - lo
    if not value_range > 0:  # also catches a NaN range (all-NaN reference)
        value_range = 1.0
    std = rel_std * value_range
    noise = np.random.default_rng(seed).normal(0, std, size=vol.shape).astype(vol.dtype)
    return np.clip(vol + noise, lo, hi)
def replace_mean(volumes):
    """Voxel-wise mean of a sequence of equally shaped arrays, cast to the first one's dtype."""
    stacked = np.asarray(volumes)
    averaged = stacked.mean(axis=0)
    return averaged.astype(volumes[0].dtype)
def replace_copy(other):
    """Return an independent C-ordered copy of *other* with the same dtype."""
    return np.array(other, dtype=other.dtype, order="C")
def replace_interpolation(removed_key, mods):
    """Synthesize *removed_key* from its neighbours in MODALITY_ORDER.

    An interior modality gets the voxel-wise average of the previous and next
    modality in MODALITY_ORDER (cast to the previous one's dtype). The first
    and last modalities have a single neighbour and get a copy of it.

    NOTE(review): "neighbour" means position in MODALITY_ORDER, not physical
    similarity between MR sequences -- confirm this ordering is intended.
    """
    order = MODALITY_ORDER
    idx = order.index(removed_key)
    if idx == 0:
        # Derived from the order (was hard-coded) so both edge cases stay in sync.
        return replace_copy(mods[order[1]])
    if idx == len(order) - 1:
        return replace_copy(mods[order[idx - 1]])
    left, right = mods[order[idx - 1]], mods[order[idx + 1]]
    return (0.5 * left + 0.5 * right).astype(left.dtype)

def build_single_scenarios(mods):
    """Build every single-modality ablation scenario.

    Each modality in MODALITY_ORDER is replaced in turn by five alternatives:
    ``black`` (zeros), ``gaussian`` (replace_gaussian of the removed volume),
    ``mean`` (voxel-wise mean of the three remaining modalities),
    ``copy_<m>`` (copy of the first remaining modality) and ``interp``
    (replace_interpolation).

    Args:
        mods: dict mapping each MODALITY_ORDER key to its volume.

    Returns:
        List of ``(name, modality_dict)`` pairs named
        ``single_<removed>_<replacement>``, in MODALITY_ORDER x replacement order.
        Each modality_dict is a new dict whose replaced entry is a fresh array.
        The untouched modalities reference the arrays in *mods* instead of
        deep copies; the old version held four full-volume copies per scenario
        for all 20 scenarios at once. Treat the arrays as read-only.
    """
    scenarios = []
    for removed in MODALITY_ORDER:
        remaining = [k for k in MODALITY_ORDER if k != removed]
        replacements = [
            ("black", replace_black(mods[removed])),
            ("gaussian", replace_gaussian(mods[removed])),
            ("mean", replace_mean([mods[k] for k in remaining])),
            (f"copy_{remaining[0]}", replace_copy(mods[remaining[0]])),
            ("interp", replace_interpolation(removed, mods)),
        ]
        for name, repl in replacements:
            out = {k: mods[k] for k in MODALITY_ORDER}  # shared, not copied
            out[removed] = repl
            scenarios.append((f"single_{removed}_{name}", out))
    return scenarios

def build_double_scenarios(mods):
    """Build every two-modality ablation scenario.

    For each unordered pair of modalities (in MODALITY_ORDER order), both are
    replaced by: ``black`` (zeros), ``gaussian`` (replace_gaussian of each
    removed volume), ``mean`` (voxel-wise mean of the two remaining modalities,
    used for both), ``copy_<m1>`` and ``copy_<m2>`` (copies of each remaining
    modality, used for both).

    Args:
        mods: dict mapping each MODALITY_ORDER key to its volume.

    Returns:
        List of ``(name, modality_dict)`` pairs named
        ``double_<r1>_<r2>_<replacement>``. The two replaced entries are fresh
        arrays. The untouched modalities reference the arrays in *mods*
        instead of deep copies (previously four full copies per scenario for all
        30 scenarios). Treat the arrays as read-only.
    """
    # Same pairs and order as the former hard-coded list for MODALITY_ORDER.
    pairs = [(a, b) for i, a in enumerate(MODALITY_ORDER) for b in MODALITY_ORDER[i + 1:]]
    scenarios = []
    for r1, r2 in pairs:
        remaining = [k for k in MODALITY_ORDER if k not in (r1, r2)]
        mean_rest = replace_mean([mods[k] for k in remaining])
        for rep_name, v1, v2 in [
            ("black", replace_black(mods[r1]), replace_black(mods[r2])),
            ("gaussian", replace_gaussian(mods[r1]), replace_gaussian(mods[r2])),
            ("mean", mean_rest, replace_copy(mean_rest)),
            (f"copy_{remaining[0]}", replace_copy(mods[remaining[0]]), replace_copy(mods[remaining[0]])),
            (f"copy_{remaining[1]}", replace_copy(mods[remaining[1]]), replace_copy(mods[remaining[1]])),
        ]:
            out = {k: mods[k] for k in MODALITY_ORDER}  # shared, not copied
            out[r1], out[r2] = v1, v2
            scenarios.append((f"double_{r1}_{r2}_{rep_name}", out))
    return scenarios

def nii_to_gz(nii_path):
    """Write a gzip-compressed copy of *nii_path* next to it and return its path.

    The source file is left in place; the output path is ``nii_path + ".gz"``.
    """
    target = f"{nii_path}.gz"
    with open(nii_path, "rb") as src, gzip.open(target, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return target

def run_api(mods, affine, tmp_dir):
    """Upload the four modalities to PREDICT_URL and return the predicted segmentation bytes.

    Each MODALITY_ORDER volume is saved as ``<tmp_dir>/<key>.nii`` with *affine*,
    gzipped, and posted as multipart field ``<key>`` (file name ``<key>.nii.gz``).

    Returns:
        The response body, gunzipped when it starts with the gzip magic bytes,
        or ``None`` when the request fails (connection error, timeout) or the
        server answers with a non-200 status. Failures are printed instead of
        raised, so one bad request does not abort a long ablation run.
    """
    handles = {}
    try:
        for key in MODALITY_ORDER:
            nii_path = os.path.join(tmp_dir, f"{key}.nii")
            nib.save(nib.Nifti1Image(mods[key], affine), nii_path)
            handles[key] = open(nii_to_gz(nii_path), "rb")
        files = {k: (os.path.basename(fh.name), fh, "application/octet-stream") for k, fh in handles.items()}
        try:
            r = requests.post(PREDICT_URL, files=files, timeout=300)
        except requests.RequestException as exc:
            print(f"    API request failed: {exc}")
            return None
        if r.status_code != 200:
            print(f"    API {r.status_code}")
            return None
        content = r.content
        if content[:2] == b"\x1f\x8b":  # gzip magic: server returned .nii.gz
            content = gzip.decompress(content)
        return content
    finally:
        # Close every upload handle, including when saving/compressing failed midway.
        for fh in handles.values():
            fh.close()
print("Ablation helpers loaded.")

In [None]:
# ========== متریک‌ها (Dice, Sensitivity, HD95)، seg_to_png، و اجرای ablation برای هر case ==========
from scipy import ndimage
from scipy.spatial.distance import cdist

def dice_per_label(pred, gt, labels):
    """Dice coefficient for each label value after rounding both volumes to int32.

    Returns a dict ``{label: dice}``. A label absent from both volumes scores 1.0.
    """
    p = np.round(pred).astype(np.int32)
    g = np.round(gt).astype(np.int32)
    scores = {}
    for lab in labels:
        in_p = p == lab
        in_g = g == lab
        denom = np.sum(in_p) + np.sum(in_g)
        if denom > 0:
            scores[lab] = 2.0 * np.sum(in_p & in_g) / denom
        else:
            scores[lab] = 1.0
    return scores
def sensitivity_per_label(pred, gt, labels):
    """Sensitivity (TP / ground-truth count) per label after rounding to int32.

    Returns a dict ``{label: sensitivity}``; a label absent from *gt* scores 1.0.
    """
    p = np.round(pred).astype(np.int32)
    g = np.round(gt).astype(np.int32)
    result = {}
    for lab in labels:
        truth = g == lab
        tp = np.sum((p == lab) & truth)
        n_truth = np.sum(truth)  # equals tp + fn
        result[lab] = tp / n_truth if n_truth > 0 else 1.0
    return result
def surface_voxels(mask):
    """Coordinates (N x ndim) of voxels in *mask* removed by a 6-connected 3-D erosion."""
    conn = ndimage.generate_binary_structure(3, 1)
    interior = ndimage.binary_erosion(mask, structure=conn)
    boundary = mask & (~interior)
    return np.argwhere(boundary)
def hausdorff_95(pred, gt, label):
    """Symmetric 95th-percentile Hausdorff distance (in voxels) for one label.

    Both volumes are rounded to int32. Surface voxels are the voxels of each
    mask removed by a 6-connected erosion (the definition surface_voxels uses).
    For every surface voxel of one mask, the exact Euclidean distance to the
    nearest surface voxel of the other mask is computed with a distance
    transform. The larger of the two directed 95th percentiles is returned.
    The transform is limited to the joint bounding box of both surfaces,
    which leaves those distances unchanged.

    Returns:
        0.0 when *label* is absent from both volumes. When it is present in only
        one of them (e.g. a total miss), returns the length of the grid diagonal,
        the largest distance possible in the volume, instead of a misleading
        perfect 0.0. Otherwise returns the HD95 value.
    """
    pred = np.round(pred).astype(np.int32)
    gt = np.round(gt).astype(np.int32)
    pa, pb = (pred == label), (gt == label)
    has_a, has_b = bool(pa.any()), bool(pb.any())
    if not has_a and not has_b:
        return 0.0
    worst = float(np.linalg.norm(np.maximum(pa.shape, pb.shape)))
    if not (has_a and has_b):
        return worst
    conn = ndimage.generate_binary_structure(3, 1)
    sa = np.argwhere(pa & ~ndimage.binary_erosion(pa, structure=conn))
    sb = np.argwhere(pb & ~ndimage.binary_erosion(pb, structure=conn))
    if sa.size == 0 or sb.size == 0:
        return worst
    lo = np.minimum(sa.min(axis=0), sb.min(axis=0))
    hi = np.maximum(sa.max(axis=0), sb.max(axis=0)) + 1
    box = tuple(int(v) for v in hi - lo)
    ia = tuple((sa - lo).T)
    ib = tuple((sb - lo).T)
    # EDT measures distance to the nearest zero voxel, so zero out the target surface.
    not_b = np.ones(box, dtype=bool)
    not_b[ib] = False
    not_a = np.ones(box, dtype=bool)
    not_a[ia] = False
    d_ab = ndimage.distance_transform_edt(not_b)[ia]
    d_ba = ndimage.distance_transform_edt(not_a)[ib]
    return float(max(np.percentile(d_ab, 95), np.percentile(d_ba, 95)))

def compute_metrics(pred_path, gt_path):
    """Score a predicted segmentation file against a ground-truth file.

    Both NIfTI volumes are rounded to int32, and every value outside
    {0, 1, 2, 4} is set to 0.

    Returns:
        ``{"error": True}`` when the two volumes differ in shape. Otherwise a
        dict with:

        * per-label keys ``dice_<L>``, ``sensitivity_<L>`` and ``hd95_<L>`` for
          each L in LABELS (the same keys as before), and
        * region keys ``dice_<R>``, ``sensitivity_<R>`` and ``hd95_<R>`` for the
          BraTS evaluation regions TC (labels 1+4), WT (labels 1+2+4) and
          ET (label 4).
    """
    pred = np.round(np.asarray(nib.load(pred_path).dataobj)).astype(np.int32)
    gt = np.round(np.asarray(nib.load(gt_path).dataobj)).astype(np.int32)
    if pred.shape != gt.shape:
        return {"error": True}
    valid = [0, 1, 2, 4]
    pred[~np.isin(pred, valid)] = 0
    gt[~np.isin(gt, valid)] = 0

    dice = dice_per_label(pred, gt, LABELS)
    sens = sensitivity_per_label(pred, gt, LABELS)
    hd95 = {L: hausdorff_95(pred, gt, L) for L in LABELS}
    out = {f"dice_{L}": dice[L] for L in LABELS}
    out.update({f"sensitivity_{L}": sens[L] for L in LABELS})
    out.update({f"hd95_{L}": hd95[L] for L in LABELS})

    # The reported columns are named TC/WT/ET; in BraTS those are unions of
    # labels, not single labels, so score them on binary region masks.
    for region, region_labels in (("TC", (1, 4)), ("WT", (1, 2, 4)), ("ET", (4,))):
        pm = np.isin(pred, region_labels).astype(np.int32)
        gm = np.isin(gt, region_labels).astype(np.int32)
        out[f"dice_{region}"] = dice_per_label(pm, gm, [1])[1]
        out[f"sensitivity_{region}"] = sensitivity_per_label(pm, gm, [1])[1]
        out[f"hd95_{region}"] = hausdorff_95(pm, gm, 1)
    return out

def seg_to_png(nii_path, png_path, title=""):
    """Save a 2x2 PNG preview of four slices of a label volume.

    The volume is rounded to int32, values outside {0, 1, 2, 4} are zeroed, and
    each label is painted with its LABEL_RGB colour. Slices are taken along the
    third array axis at depth//4, depth//2 - 1, depth//2 and 3*depth//4, clamped
    into range. The title defaults to the file name with ".nii" removed.
    """
    seg = np.round(np.asarray(nib.load(nii_path).dataobj)).astype(np.int32)
    seg[~np.isin(seg, [0, 1, 2, 4])] = 0
    n_z = seg.shape[2]
    picks = [max(0, min(z, n_z - 1)) for z in (n_z // 4, n_z // 2 - 1, n_z // 2, 3 * n_z // 4)]
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    for ax, z in zip(axes.flat, picks):
        plane = seg[:, :, z]
        rgb = np.zeros(plane.shape + (3,), dtype=np.float32)
        for lab, colour in LABEL_RGB.items():
            rgb[plane == lab] = colour
        ax.imshow(rgb)
        ax.set_title(f"Slice {z}")
        ax.axis("off")
    plt.suptitle(title or os.path.basename(nii_path).replace(".nii", ""), fontsize=12)
    plt.tight_layout()
    plt.savefig(png_path, dpi=150, bbox_inches="tight")
    plt.close()

def run_ablation_for_case(data_dir, case_id, out_dir, run_baseline=True, run_single=True, run_double=True):
    """Run the baseline, single- and double-removal scenarios for one case.

    For every scenario the modalities are sent to the API. The prediction is
    saved as ``<out_dir>/<case_id>_<scenario>_pred.nii`` with a PNG preview next
    to it, and it is scored against ``<case_id>_final_seg.nii``.

    A scenario that fails (request error, unreadable prediction, shape
    mismatch) is reported and skipped, so one bad scenario does not abort the
    remaining scenarios or later cases.

    Returns:
        A list of metric rows ``{"case_id", "scenario", **metrics}``; ``[]``
        when the ground-truth file is missing.
    """
    gt_path = os.path.join(data_dir, f"{case_id}_final_seg.nii")
    if not os.path.isfile(gt_path):
        print(f"  Skip {case_id}: no GT"); return []
    os.makedirs(out_dir, exist_ok=True)
    rows = []
    with tempfile.TemporaryDirectory(prefix="ablation_") as tmp:
        mods, affine = load_case_modalities(data_dir, case_id, tmp)

        def run_scenario(name, mod_dict):
            # Predict, save, preview and score one scenario; None if it yields no row.
            content = run_api(mod_dict, affine, tmp)
            if content is None:
                return None
            pred_path = os.path.join(out_dir, f"{case_id}_{name}_pred.nii")
            with open(pred_path, "wb") as f:
                f.write(content)
            seg_to_png(pred_path, os.path.join(out_dir, f"{case_id}_{name}_pred.png"), f"{case_id} {name}")
            met = compute_metrics(pred_path, gt_path)
            if met.get("error"):
                print(f"    {name}: prediction/GT shape mismatch, skipped")
                return None
            return {"case_id": case_id, "scenario": name, **met}

        def scenarios():
            # Lazily yield (name, modality dict) in baseline, single, double order.
            if run_baseline:
                yield "baseline", {k: replace_copy(mods[k]) for k in MODALITY_ORDER}
            if run_single:
                yield from build_single_scenarios(mods)
            if run_double:
                yield from build_double_scenarios(mods)

        for name, mod_dict in scenarios():
            try:
                row = run_scenario(name, mod_dict)
            except Exception as exc:  # scenario boundary: report and continue with the next one
                print(f"    {name}: failed ({type(exc).__name__}: {exc})")
                continue
            if row:
                rows.append(row)
    return rows
print("Metrics and run_ablation loaded.")

In [None]:
# ========== CSV writing and plotting helpers ==========
def scenario_order(s):
    """Sort key placing the baseline first, then single-, then all other scenarios."""
    if s == "baseline":
        rank = 0
    elif s.startswith("single_"):
        rank = 1
    else:
        rank = 2
    return (rank, s)

def parse_scenario(s):
    """Split a scenario name into ``(removed_modality, replacement_type)``.

    ``"baseline"`` -> ``("none", "baseline")``;
    ``"single_<mod>_<rep>"`` -> ``(<mod>, <rep>)`` for a mod in MODALITY_ORDER;
    ``"double_<m1>_<m2>_<rep>"`` -> ``("<m1>_<m2>", <rep>)`` (``"double"`` when
    no replacement part is present); anything else -> ``("", s)``.
    """
    if s == "baseline":
        return "none", "baseline"
    single_prefix, double_prefix = "single_", "double_"
    if s.startswith(single_prefix):
        rest = s[len(single_prefix):]
        mod = next((m for m in MODALITY_ORDER if rest.startswith(m + "_")), None)
        if mod is None:
            return "", s
        return mod, rest[len(mod) + 1:]
    if s.startswith(double_prefix):
        parts = s[len(double_prefix):].split("_")
        removed = "_".join(parts[:2])
        replacement = "_".join(parts[2:]) if len(parts) >= 3 else "double"
        return removed, replacement
    return "", s

def write_csvs(all_rows, out_root):
    """Write ``metrics.csv`` and ``metrics_summary.csv`` under *out_root*.

    Rows are sorted by case id, then baseline / single / double scenario. Each
    metric column ``<metric>_<region>`` (metric dice/sensitivity/hd95, region
    TC/WT/ET) is taken from the row's ``<metric>_<region>`` key when present.
    Otherwise the legacy per-label key ``<metric>_1`` / ``_2`` / ``_4`` is
    used, as before. Missing values default to 0; all values are rounded to 6
    decimals. The summary file adds ``removed_modality`` and
    ``replacement_type`` from parse_scenario.

    NOTE(review): with the legacy fallback, the TC/WT/ET columns hold metrics
    of the single labels 1/2/4, which are not the BraTS TC/WT regions.

    Returns:
        The list of summary-row dicts, in file order (also used for plotting).
    """
    os.makedirs(out_root, exist_ok=True)
    metrics = ("dice", "sensitivity", "hd95")
    region_to_label = (("TC", 1), ("WT", 2), ("ET", 4))
    metric_cols = [f"{m}_{reg}" for m in metrics for reg, _ in region_to_label]

    def metric_values(r):
        # Column -> rounded value, preferring region keys over legacy label keys.
        vals = {}
        for m in metrics:
            for reg, lab in region_to_label:
                col = f"{m}_{reg}"
                vals[col] = round(r[col] if col in r else r.get(f"{m}_{lab}", 0), 6)
        return vals

    sorted_rows = sorted(all_rows, key=lambda r: (r.get("case_id", ""), scenario_order(r.get("scenario", ""))))
    values = [metric_values(r) for r in sorted_rows]

    with open(os.path.join(out_root, "metrics.csv"), "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=["case_id", "scenario"] + metric_cols)
        w.writeheader()
        for r, vals in zip(sorted_rows, values):
            w.writerow({"case_id": r["case_id"], "scenario": r["scenario"], **vals})

    summary_rows = []
    with open(os.path.join(out_root, "metrics_summary.csv"), "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=["case_id", "scenario", "removed_modality", "replacement_type"] + metric_cols)
        w.writeheader()
        for r, vals in zip(sorted_rows, values):
            rem, rep = parse_scenario(r.get("scenario", ""))
            row = {"case_id": r["case_id"], "scenario": r["scenario"],
                   "removed_modality": rem, "replacement_type": rep, **vals}
            summary_rows.append(row)
            w.writerow(row)
    return summary_rows

# Modality order for the plot axes (same values as MODALITY_ORDER); plot_all
# labels the ticks "FLAIR", "T1ce", "T1", "T2" in this order.
MOD_ORDER = ["flair", "t1ce", "t1", "t2"]
# Replacement types always shown first, in this order; rep_order() appends the
# copy_* types found in the data.
FIXED_REP = ["black", "gaussian", "mean", "interp"]
def rep_order(single_rows):
    """Return FIXED_REP followed by the sorted ``copy_*`` replacement types present in *single_rows*."""
    present = {row["replacement_type"] for row in single_rows}
    copies = sorted(t for t in present if t.startswith("copy_") and t not in FIXED_REP)
    return list(FIXED_REP) + copies
def get_val(single_rows, mod, rep, key):
    """Mean of ``row[key]`` over rows with removed modality *mod* and replacement *rep*.

    Averaging (rather than returning the first match) keeps the value correct
    when *single_rows* holds rows from several cases; with one row per
    (mod, rep) the result equals that row's value. Returns ``np.nan`` when no
    row matches.
    """
    vals = [r[key] for r in single_rows if r["removed_modality"] == mod and r["replacement_type"] == rep]
    return float(np.mean(vals)) if vals else np.nan

def plot_all(summary_rows, fig_dir):
    """Render six summary figures of the single-modality ablation into *fig_dir*.

    Uses rows whose ``removed_modality`` is in MOD_ORDER (single-removal
    scenarios) plus the ``baseline`` rows. When the rows cover several cases,
    every plotted value -- including the baseline reference lines -- is the
    mean over all matching rows, so all figures aggregate the same way.
    A "region mean" is the average of a row's TC, WT and ET columns.

    Figures: dice_by_modality_replacement.png, dice_heatmap.png,
    dice_per_region.png (first 20 non-baseline scenarios),
    sensitivity_by_modality.png, hd95_by_modality.png and
    dice_by_replacement_type.png.
    """
    single = [r for r in summary_rows if r["removed_modality"] in MOD_ORDER]
    baselines = [r for r in summary_rows if r["scenario"] == "baseline"]
    if not single or not baselines:
        print("Not enough rows to plot."); return

    def region_mean(row, metric):
        # Mean of one row's TC, WT and ET columns for the given metric prefix.
        return (row[f"{metric}_TC"] + row[f"{metric}_WT"] + row[f"{metric}_ET"]) / 3

    def avg(rows, metric):
        # Mean of region_mean over rows (e.g. across cases); NaN if rows is empty.
        vals = [region_mean(r, metric) for r in rows]
        return float(np.mean(vals)) if vals else np.nan

    def pick(mod=None, rep=None):
        # Single-removal rows filtered by removed modality and/or replacement type.
        return [r for r in single
                if (mod is None or r["removed_modality"] == mod)
                and (rep is None or r["replacement_type"] == rep)]

    def save(fig, filename):
        # Lay out, write at 150 dpi into fig_dir, and release the figure.
        fig.tight_layout()
        fig.savefig(os.path.join(fig_dir, filename), dpi=150, bbox_inches="tight")
        plt.close(fig)

    mod_labels = ["FLAIR", "T1ce", "T1", "T2"]
    x_mod = np.arange(len(MOD_ORDER))
    rep_ord = rep_order(single)
    baseline_dice = avg(baselines, "dice")

    # 1) Dice by removed modality, one bar per replacement type
    fig, ax = plt.subplots(figsize=(12, 6))
    width = 0.12
    for i, rep in enumerate(rep_ord):
        vals = [avg(pick(mod, rep), "dice") for mod in MOD_ORDER]
        ax.bar(x_mod + (i - len(rep_ord) / 2) * width + width / 2, vals, width, label=rep.replace("copy_", "C."))
    ax.axhline(y=baseline_dice, color="gray", linestyle="--", label="Baseline")
    ax.set_xticks(x_mod); ax.set_xticklabels(mod_labels)
    ax.set_ylabel("Dice (mean TC, WT, ET)"); ax.set_title("Dice by removed modality and replacement")
    ax.legend(loc="lower right", fontsize=8); ax.set_ylim(0, 1.05)
    save(fig, "dice_by_modality_replacement.png")

    # 2) Heatmap of removed modality x replacement; missing combinations are
    #    left blank and annotated "n/a" instead of being drawn as Dice 0.
    mat = np.array([[avg(pick(mod, rep), "dice") for rep in rep_ord] for mod in MOD_ORDER])
    fig, ax = plt.subplots(figsize=(8, 5))
    im = ax.imshow(np.ma.masked_invalid(mat), aspect="auto", vmin=0, vmax=1, cmap="RdYlGn")
    ax.set_xticks(np.arange(len(rep_ord))); ax.set_yticks(x_mod)
    ax.set_xticklabels([r.replace("copy_", "C.") for r in rep_ord]); ax.set_yticklabels(mod_labels)
    for i in range(len(MOD_ORDER)):
        for j in range(len(rep_ord)):
            text = "n/a" if np.isnan(mat[i, j]) else f"{mat[i, j]:.2f}"
            ax.text(j, i, text, ha="center", va="center", fontsize=9)
    plt.colorbar(im, ax=ax, label="Dice (mean)"); ax.set_title("Dice — Removed modality vs replacement")
    save(fig, "dice_heatmap.png")

    # 3) Dice per region (first 20 non-baseline scenarios)
    plot_rows = [r for r in summary_rows if r["scenario"] != "baseline"][:20]
    if plot_rows:
        fig, ax = plt.subplots(figsize=(14, 5))
        xx = np.arange(len(plot_rows)); w = 0.25
        ax.bar(xx - w, [r["dice_TC"] for r in plot_rows], w, label="Dice TC", color="C0")
        ax.bar(xx, [r["dice_WT"] for r in plot_rows], w, label="Dice WT", color="C1")
        ax.bar(xx + w, [r["dice_ET"] for r in plot_rows], w, label="Dice ET", color="C2")
        ax.set_xticks(xx); ax.set_xticklabels([r["removed_modality"] + "\n" + r["replacement_type"] for r in plot_rows], rotation=45, ha="right", fontsize=7)
        ax.set_ylabel("Dice"); ax.legend(); ax.set_ylim(0, 1.05); ax.set_title("Dice per region by scenario")
        save(fig, "dice_per_region.png")

    # 4) Sensitivity by removed modality (mean over replacements and cases)
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.bar(x_mod, [avg(pick(mod=m), "sensitivity") for m in MOD_ORDER], color="steelblue")
    ax.axhline(y=avg(baselines, "sensitivity"), color="gray", linestyle="--", label="Baseline")
    ax.set_xticks(x_mod); ax.set_xticklabels(mod_labels)
    ax.set_ylabel("Sensitivity (mean)"); ax.set_title("Mean sensitivity when each modality removed"); ax.legend(); ax.set_ylim(0, 1.05)
    save(fig, "sensitivity_by_modality.png")

    # 5) HD95 by removed modality (mean over replacements and cases)
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.bar(x_mod, [avg(pick(mod=m), "hd95") for m in MOD_ORDER], color="coral")
    ax.axhline(y=avg(baselines, "hd95"), color="gray", linestyle="--", label="Baseline")
    ax.set_xticks(x_mod); ax.set_xticklabels(mod_labels)
    ax.set_ylabel("HD95 (mean)"); ax.set_title("Mean HD95 when each modality removed"); ax.legend()
    save(fig, "hd95_by_modality.png")

    # 6) Dice by replacement type (mean over removed modalities and cases)
    fig, ax = plt.subplots(figsize=(8, 5))
    x_rep = np.arange(len(rep_ord))
    ax.bar(x_rep, [avg(pick(rep=rep), "dice") for rep in rep_ord], color="teal", alpha=0.8)
    ax.axhline(y=baseline_dice, color="gray", linestyle="--", label="Baseline")
    ax.set_xticks(x_rep); ax.set_xticklabels([r.replace("copy_", "Copy ") for r in rep_ord], rotation=30, ha="right")
    ax.set_ylabel("Dice (mean)"); ax.set_title("Dice by replacement type"); ax.legend(); ax.set_ylim(0, 1.05)
    save(fig, "dice_by_replacement_type.png")
    print("  All 6 figures saved in", fig_dir)
print("CSV and plot helpers loaded.")

In [None]:
# ========== Integrated run: ablation for every case, then save CSVs and figures ==========
# One metrics row per successfully scored scenario, accumulated over all cases.
all_rows = []
for case_id in cases:
    # Per-case predictions and PNG previews go into OUT_ROOT/<case_id>/.
    case_out = os.path.join(OUT_ROOT, case_id)
    print(f"--- {case_id} ---")
    rows = run_ablation_for_case(DATA_DIR, case_id, case_out, run_baseline=True, run_single=True, run_double=True)
    all_rows.extend(rows)
    print(f"  {len(rows)} scenarios -> {case_out}")

if not all_rows:
    print("No results. Check DATA_DIR and case data (zip + final_seg.nii).")
else:
    # Write metrics.csv / metrics_summary.csv, then plot from the summary rows.
    summary_rows = write_csvs(all_rows, OUT_ROOT)
    os.makedirs(FIG_DIR, exist_ok=True)
    plot_all(summary_rows, FIG_DIR)
    print(f"\nTotal: {len(all_rows)} rows. CSVs and figures in {OUT_ROOT}")

In [None]:
# ========== خلاصه خروجی: ساختار پوشه، جدول متریک، نمایش نمودارها، دانلود یکجا ==========
from IPython.display import display, Image, HTML
import base64

# Print the output folder tree (at most 15 files listed per folder).
print("خروجی‌ها در:", OUT_ROOT)
for root, dirs, files in os.walk(OUT_ROOT):
    # Depth below OUT_ROOT, measured by path separators in the relative part.
    level = root.replace(OUT_ROOT, "").count(os.sep)
    indent = "  " * level
    print(f"{indent}{os.path.basename(root) or 'ablation_out'}/")
    for f in sorted(files)[:15]:
        print(f"{indent}  {f}")
    if len(files) > 15:
        print(f"{indent}  ... و {len(files)-15} فایل دیگر")

# Metrics table: print the row count and the first 8 rows.
csv_path = os.path.join(OUT_ROOT, "metrics.csv")
if os.path.isfile(csv_path):
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
    print(f"\nجدول متریک (تعداد سطر: {len(rows)}). نمونه:")
    for r in rows[:8]:
        print(r)

# Display the saved figures inline in the notebook (missing files are skipped).
for name in ["dice_by_modality_replacement.png", "dice_heatmap.png", "sensitivity_by_modality.png", "hd95_by_modality.png", "dice_by_replacement_type.png", "dice_per_region.png"]:
    p = os.path.join(FIG_DIR, name)
    if os.path.isfile(p):
        display(Image(filename=p, width=500))
        print(name)

# Package all outputs into one zip; in Colab, also trigger a browser download.
zip_path = os.path.join(ROOT, "ablation_out.zip")
if os.path.isdir(OUT_ROOT):
    # make_archive returns the path of the archive it actually wrote.
    zip_path = shutil.make_archive(os.path.join(ROOT, "ablation_out"), "zip", ROOT, "ablation_out")
    print(f"\nفایل zip: {zip_path}")
    try:
        from google.colab import files
    except ImportError:
        # Not running in Colab: the archive simply stays on disk at zip_path.
        pass
    else:
        try:
            files.download(zip_path)
        except Exception as exc:  # best-effort download: report the failure instead of hiding it
            print(f"Download failed ({type(exc).__name__}): {exc}")