In [None]:
# ⚙️  Environment bootstrap ---------------------------------------------------------
import glob
import io
import json
import os
import random
import shutil
import subprocess
import sys
import tarfile
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

# Google Drive: mount only on Colab, and only when it is not mounted yet.
try:
    from google.colab import drive  # type: ignore

    _drive_already_mounted = Path("/content/drive").exists()
    if not _drive_already_mounted:
        drive.mount("/content/drive", force_remount=True)
except ModuleNotFoundError:
    print("Not running on Colab – skipping drive.mount()")


def _sh(cmd: str, check: bool = True) -> None:
    """Run ``cmd`` through the shell; with ``check`` a non-zero exit raises CalledProcessError."""
    subprocess.run(cmd, shell=True, check=check)


# System packages: OpenSlide library + headers via apt.
print("[SETUP] Installing OpenSlide system libs…")
for _apt_cmd in (
    "apt-get -qq update",
    "apt-get -qq install -y openslide-tools libopenslide-dev",
):
    _sh(_apt_cmd)

# Python wheels.
print("[SETUP] Installing Python wheels…")
_sh(
    "pip install -qq --upgrade openslide-python openslide-bin "
    "webdataset tqdm matplotlib zarr 'numcodecs<0.8.0'"
)

# Best effort: drop stray "~orch*" dirs (PyTorch leftovers, if present); failures ignored.
_sh("rm -rf /usr/local/lib/python*/dist-packages/~orch*", check=False)

# Fail fast when a core dependency cannot be imported.
try:
    import openslide, webdataset, numpy as np
    from PIL import Image
    from tqdm.auto import tqdm

    print(f"[OK] OpenSlide-Python {openslide.__version__} · NumPy {np.__version__}")
except Exception as import_err:
    sys.exit(f"[FATAL] Import failed → {import_err}")


Mounted at /content/drive
[SETUP] Installing OpenSlide system libs…
[SETUP] Installing Python wheels…
[OK] OpenSlide-Python 1.4.2 · NumPy 2.0.2


In [None]:
# 📑 Load YAML config & define paths -----------------------------------------------
import yaml, pandas as pd, random

# --- Config file ---------------------------------------------------------------------
CFG_PATH = Path("/content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/config/preprocessing.yaml")
if not CFG_PATH.exists():
    sys.exit(f"[FATAL] YAML config not found → {CFG_PATH}")
cfg = yaml.safe_load(CFG_PATH.read_text())

# --- Project root: the Colab path when it exists, otherwise the local checkout -------
COLAB_ROOT = Path(cfg["env_paths"]["colab"])
LOCAL_ROOT = Path(cfg["env_paths"]["local"])
ROOT = next((root for root in (COLAB_ROOT, LOCAL_ROOT) if root.exists()), LOCAL_ROOT)
if not ROOT.exists():
    sys.exit(f"[FATAL] Project root not found → {ROOT}")

# --- Stage: "debug" when patient down-sampling is enabled, else "training" ------------
_stages = cfg["stages"]
stage_cfg = _stages["debug"] if _stages["debug"]["downsample_patients"]["enabled"] else _stages["training"]
_patching = stage_cfg["patching"]
PATCH_SIZE = _patching["patch_size"]
RANDOM_SEED = _patching["random_seed"]
SHARD_SIZE = 5_000          # maxcount of samples per WebDataset shard
MAX_DEBUG_JPG = 10
random.seed(RANDOM_SEED)

_stage_name = "debug" if stage_cfg is _stages["debug"] else "training"
print("[CFG] Stage       :", _stage_name)
print("[CFG] Patch size  :", PATCH_SIZE)
print("[CFG] Root        :", ROOT)

# --- Local SSD cache for slides copied from Drive ------------------------------------
WSI_LOCAL_DIR = Path("/content/WSI_cache")
WSI_LOCAL_DIR.mkdir(parents=True, exist_ok=True)

# --- Drive read budget ----------------------------------------------------------------
DAILY_QUOTA_GB = 400        # stop before Drive read-quota
BYTES_IN_GB = 1024 ** 3

# --- Inputs / outputs under the project root ------------------------------------------
_processed_dir = ROOT / "data/processed"
PATCH_DF_PATH = _processed_dir / "patch_df_5000.parquet"
WDATASET_DIR = _processed_dir / "webdataset"
RESUME_PATH = _processed_dir / "resume_state.json"
ERROR_LOG_PATH = ROOT / "copy_errors.json"
DEBUG_DIR = ROOT / "data/visual_debug/extract_examples"

for _out_dir in (WDATASET_DIR, DEBUG_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)


[CFG] Stage       : training
[CFG] Patch size  : 224
[CFG] Root        : /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project


In [None]:
# 🔄 Resume helpers ---------------------------------------------------------------
def load_resume(path: Optional[Path] = None) -> Dict[str, List[str]]:
    """Load the resume state written by ``save_resume``.

    Args:
        path: JSON file to read; defaults to ``RESUME_PATH``.

    Returns:
        The stored state dict, or ``{"processed_wsis": []}`` when the file does not exist yet.
    """
    path = RESUME_PATH if path is None else Path(path)
    if path.exists():
        return json.loads(path.read_text())
    return {"processed_wsis": []}


def save_resume(state: Dict[str, List[str]], path: Optional[Path] = None):
    """Persist ``state`` as indented JSON, replacing the file in one step.

    The JSON goes to a sibling ``*.tmp`` file first and is then moved over the target
    with ``os.replace``. An interrupted write (Colab disconnect, kernel stop) therefore
    leaves the previous state file untouched instead of a truncated one that
    ``load_resume`` could not parse. (On POSIX filesystems the replace is atomic; on
    the Drive FUSE mount this is best effort — NOTE(review): not verified.)

    Args:
        state: mapping to store, e.g. ``{"processed_wsis": [...]}``.
        path: destination file; defaults to ``RESUME_PATH``.
    """
    path = RESUME_PATH if path is None else Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(state, indent=2))
    os.replace(tmp_path, path)

# Restore progress from earlier runs; a set gives fast "already done?" checks in the loop.
resume_state = load_resume()
processed_wsis = {*resume_state["processed_wsis"]}


In [None]:
# 🐼 Load patch dataframe ----------------------------------------------------------
if not PATCH_DF_PATH.exists():
    sys.exit(f"[FATAL] patch_df not found → {PATCH_DF_PATH}")

patch_df = pd.read_parquet(PATCH_DF_PATH)

# Every column the extraction / thumbnail cells read. Validate them up front: the
# extraction loop catches per-slide exceptions, so a missing column would otherwise only
# surface as one logged error per slide after the whole run.
REQUIRED_PATCH_COLS = {"split", "subtype", "patient_id", "wsi_path", "roi_file", "x", "y"}
missing_cols = REQUIRED_PATCH_COLS - set(patch_df.columns)
assert not missing_cols, f"patch_df is missing required columns: {sorted(missing_cols)}"
print(f"[DATA] patch_df rows  : {len(patch_df):,}")

# A patch comes from either a full WSI or an ROI file; collect every distinct source,
# sorted so the processing order is stable across runs.
wsi_paths   = pd.concat([patch_df["wsi_path"].dropna(),
                         patch_df["roi_file"].dropna()]).unique()
unique_wsis = sorted(set(wsi_paths))
print(f"[DATA] unique WSIs    : {len(unique_wsis):,}")


[DATA] patch_df rows  : 5,000
[DATA] unique WSIs    : 205


In [None]:
# 🔨 Shard writer ------------------------------------------------------------------
import webdataset as wds

# One output directory and one ShardWriter per split; shard files are patches-0000.tar
# (4-digit index) inside each split directory.
splits = ["train", "val", "test"]


def _next_shard_index(split_dir: Path) -> int:
    """Return the index after the highest existing ``patches-NNNN.tar`` in ``split_dir``.

    Returns 0 when the directory holds no shards; files whose suffix is not numeric are ignored.
    """
    indices = [
        int(p.stem.rsplit("-", 1)[-1])
        for p in split_dir.glob("patches-*.tar")
        if p.stem.rsplit("-", 1)[-1].isdigit()
    ]
    return max(indices, default=-1) + 1


shard_writers = {}
for split in splits:
    split_dir = WDATASET_DIR / split
    split_dir.mkdir(exist_ok=True, parents=True)
    # ShardWriter numbers files from ``start_shard`` (default 0). On a resumed run
    # (non-empty resume state) the existing shards already hold patches from the slides
    # that will be skipped, so continue numbering after them instead of overwriting
    # patches-0000.tar. A fresh run starts at 0 as before; stale shards with higher
    # indices from an older run are NOT deleted.
    start_shard = _next_shard_index(split_dir) if processed_wsis else 0
    shard_writers[split] = wds.ShardWriter(
        str(split_dir / "patches-%04d.tar"),
        maxcount=SHARD_SIZE,
        start_shard=start_shard,
    )

# writing /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/webdataset/patches-000000.tar 0 0.0 GB 0


In [None]:
# 🏃 Main extraction loop ----------------------------------------------------------
import hashlib, io, openslide

bytes_copied_today      = 0
total_patches_this_run  = 0
error_log               = []

start_time = datetime.now()


def _local_cache_path(src: Path) -> Path:
    """Collision-free path for the local SSD copy of ``src``.

    Slides from different source directories share basenames (the run log shows two
    ``1.svs`` and two ``13.tif``), so the basename is prefixed with a short hash of
    the full source path.
    """
    digest = hashlib.md5(str(src).encode("utf-8")).hexdigest()[:10]
    return WSI_LOCAL_DIR / f"{digest}_{src.name}"


def _remove_quietly(path: Path) -> None:
    """Best-effort delete; a leftover local copy only costs disk space."""
    try:
        path.unlink()
    except OSError:
        pass


def _copy_with_retry(src: Path, dst: Path, attempts: int = 3):
    """Copy ``src`` to ``dst``, retrying with a 1 s pause between attempts.

    Returns None on success, otherwise a description of the last error. A partially
    written ``dst`` is removed so a later run never mistakes it for a complete copy.
    """
    last_err = "unknown error"
    for attempt in range(attempts):
        try:
            shutil.copy(src, dst)
            return None
        except Exception as exc:
            last_err = f"{type(exc).__name__}: {exc}"
            if attempt < attempts - 1:
                time.sleep(1)
    _remove_quietly(dst)
    return last_err


def _extract_samples(slide_path: Path, rows: pd.DataFrame) -> list:
    """Read and JPEG-encode every patch in ``rows`` from the slide at ``slide_path``.

    All patches are kept in memory before anything is written, so a slide contributes
    either all of its patches to the shards or none. The slide handle is always closed.

    Returns:
        list of ``(split, sample)`` tuples, ``sample`` being a WebDataset dict.

    Raises:
        ValueError: a row's split has no ShardWriter.
        Any OpenSlide / PIL error raised while opening or reading the slide.
    """
    slide = openslide.OpenSlide(str(slide_path))
    try:
        samples = []
        for _, row in rows.iterrows():
            split_name = row["split"]
            if split_name not in shard_writers:
                raise ValueError(f"unknown split {split_name!r}")
            x, y = int(row["x"]), int(row["y"])
            img = slide.read_region((x, y), 0, (PATCH_SIZE, PATCH_SIZE)).convert("RGB")

            buf = io.BytesIO()
            img.save(buf, format="JPEG", quality=90)
            key = f"{split_name}/{row['subtype']}_{row['patient_id']}_{x}_{y}"
            samples.append((split_name, {"__key__": key, "jpg": buf.getvalue()}))
        return samples
    finally:
        slide.close()


for wsi_path in tqdm(unique_wsis, desc="Processing WSI"):
    if wsi_path in processed_wsis:
        continue

    src = Path(wsi_path)
    if not src.exists():
        error_log.append({"wsi": wsi_path, "error": "not_found"})
        continue

    size_b = src.stat().st_size
    # — Check quota BEFORE starting a new WSI —
    if (bytes_copied_today + size_b) / BYTES_IN_GB >= DAILY_QUOTA_GB:
        print(f"\n[STOP] Daily quota {DAILY_QUOTA_GB} GB reached. Halting before new WSI.\n")
        break

    # Copy slide to local SSD -------------------------------------------------------
    dst = _local_cache_path(src)
    if not dst.exists():
        copy_err = _copy_with_retry(src, dst)
        if copy_err is not None:
            error_log.append({"wsi": wsi_path, "error": "copy_failed", "detail": copy_err})
            continue
        bytes_copied_today += size_b

    # Extract patches (all-or-nothing per slide) ------------------------------------
    try:
        sub_df = patch_df[(patch_df["wsi_path"] == wsi_path) |
                          (patch_df["roi_file"] == wsi_path)]
        sub_df = sub_df.sample(frac=1, random_state=RANDOM_SEED)
        samples = _extract_samples(dst, sub_df)
    except Exception as e:
        error_log.append({"wsi": wsi_path, "error": str(e)})
        samples = None
    finally:
        # Remove local copy to free space
        _remove_quietly(dst)

    if samples is None:
        # Left out of the resume state: nothing from this slide reached the shards, so
        # the next run can retry it without duplicating patches.
        tqdm.write(f"[WSI] {src.name}: FAILED – {error_log[-1]['error']}")
        continue

    # Shard-write errors are deliberately not caught: a half-written tar on Drive should
    # stop the run rather than be silently skipped.
    for split_name, sample in samples:
        shard_writers[split_name].write(sample)
    patch_cnt = len(samples)
    total_patches_this_run += patch_cnt

    # Progress message -------------------------------------------------------------
    tqdm.write(f"[WSI] {src.name}: {patch_cnt} patches "
               f"(cumulative: {total_patches_this_run})")

    # Update resume *after* the whole slide has been written ------------------------
    processed_wsis.add(wsi_path)
    save_resume({"processed_wsis": sorted(processed_wsis)})


Processing WSI:   0%|          | 0/205 [00:00<?, ?it/s]

[WSI] 1.svs: 84 patches (cumulative: 84)
[WSI] 10.svs: 63 patches (cumulative: 147)
[WSI] 11.svs: 77 patches (cumulative: 224)
[WSI] 12.svs: 103 patches (cumulative: 327)
[WSI] 13.tif: 88 patches (cumulative: 415)
[WSI] 2.svs: 110 patches (cumulative: 525)
[WSI] 3.svs: 43 patches (cumulative: 568)
[WSI] 4.svs: 55 patches (cumulative: 623)
[WSI] 5.svs: 39 patches (cumulative: 662)
[WSI] 6.svs: 41 patches (cumulative: 703)
[WSI] 7.svs: 37 patches (cumulative: 740)
[WSI] 8.svs: 200 patches (cumulative: 940)
[WSI] 9.svs: 60 patches (cumulative: 1000)
[WSI] 1.svs: 67 patches (cumulative: 1067)
[WSI] 10.tif: 34 patches (cumulative: 1101)
[WSI] 11.tif: 37 patches (cumulative: 1138)
[WSI] 12.tif: 36 patches (cumulative: 1174)
[WSI] 13.tif: 36 patches (cumulative: 1210)
[WSI] 14.tif: 27 patches (cumulative: 1237)
[WSI] 15.tif: 28 patches (cumulative: 1265)
[WSI] 16.tif: 27 patches (cumulative: 1292)
[WSI] 17.tif: 29 patches (cumulative: 1321)
[WSI] 18.tif: 20 patches (cumulative: 1341)
[WSI] 19

In [None]:
# 📝 Run summary -------------------------------------------------------------------
# Close every per-split writer so the last shard of each split is finalized. The shard
# cell builds a dict ``shard_writers``; there is no single ``shard_writer`` variable.
for _writer in shard_writers.values():
    _writer.close()
runtime_min = (datetime.now() - start_time).total_seconds() / 60

print("\n================= SUMMARY =================")
print(f"Runtime        : {runtime_min:.1f} min")
print(f"WSIs completed : {len(processed_wsis):,}/{len(unique_wsis):,}")
print(f"Patches saved  : {total_patches_this_run:,}  (this run)")
print(f"GB copied      : {bytes_copied_today / BYTES_IN_GB:.2f} GB")
print(f"Shards dir     : {WDATASET_DIR}")

if error_log:
    ERROR_LOG_PATH.write_text(json.dumps(error_log, indent=2))
    print(f"Errors logged  : {len(error_log)} → {ERROR_LOG_PATH}")
else:
    print("Errors logged  : none 🎉")



Runtime        : 76.4 min
WSIs completed : 205/205
Patches saved  : 5,000  (this run)
GB copied      : 178.75 GB
Shards dir     : /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/webdataset
Errors logged  : none 🎉


In [None]:
# ---------------------------------- CELL 06b (FIXED ITEM COUNT) ----------------------
# 📊 WebDataset statistics (robust, count real items) -----------------------------
import tarfile


def _count_samples(tar_path: str) -> int:
    """Count distinct WebDataset samples (keys) among the regular files of a tar.

    Files are grouped by key = member path up to the first "." of its basename, so
    ``<key>.jpg`` and ``<key>.json`` form one sample. Counting keys stays correct however
    many files each sample has. The extraction loop writes only ``.jpg``, so dividing the
    member count by 2 would halve the total.
    """
    keys = set()
    with tarfile.open(tar_path, "r") as tar:
        for member in tar:
            if not member.isfile():
                continue
            dirname, basename = os.path.split(member.name)
            keys.add((dirname, basename.split(".", 1)[0]))
    return len(keys)


print("\n[STATS] WebDataset overview (all shards on disk)")

# Shards live in per-split sub-directories (train/val/test); also pick up any legacy
# shards written directly under WDATASET_DIR.
tar_paths = sorted(glob.glob(str(WDATASET_DIR / "patches-*.tar")) +
                   glob.glob(str(WDATASET_DIR / "*" / "patches-*.tar")))
print(f"• Shard files found  : {len(tar_paths)}")

total_samples = 0
valid_shards = 0
samples_per_split = {}

for tp in tar_paths:
    shard_name = os.path.relpath(tp, WDATASET_DIR)
    size_mb = os.path.getsize(tp) / (1024**2)
    if size_mb == 0:
        # Empty shard file
        print(f"  - {shard_name:<20} → EMPTY (0 samples, {size_mb:.2f} MB)")
        continue

    try:
        n_samples = _count_samples(tp)
    except Exception as e:
        # Corrupt or unreadable shard (e.g. a writer that was never closed)
        print(f"  - {shard_name:<20} → ERROR ({e})")
        continue

    valid_shards += 1
    total_samples += n_samples
    split_name = os.path.dirname(shard_name) or "(root)"
    samples_per_split[split_name] = samples_per_split.get(split_name, 0) + n_samples
    print(f"  - {shard_name:<20} → {n_samples:5d} samples, {size_mb:6.2f} MB")

print(f"\n• Valid shards       : {valid_shards}/{len(tar_paths)}")
for split_name, n_split in sorted(samples_per_split.items()):
    print(f"  - {split_name:<6}: {n_split:,}")
print(f"• Total samples (patches): {total_samples:,}\n")



[STATS] WebDataset overview (all shards on disk)
• Shard files found  : 1
  - patches-000000.tar   →  5000 samples, 102.85 MB

• Valid shards       : 1/1
• Total samples (patches): 5,000



In [None]:
# 🔍 Save up to 5 thumbnails per split × class -------------------------------------
from PIL import Image
from tqdm.auto import tqdm as tq
DEBUG_DIR.mkdir(parents=True, exist_ok=True)

THUMBS_PER_GROUP = 5
thumb_paths = []  # (split, subtype, path)

# Shuffle once with the run seed, then keep the first N rows of each split × subtype
# group (the whole group when it is smaller). This is a seeded random pick per group,
# without ``groupby().apply`` on the grouping columns, which recent pandas deprecates.
sample_df = (
    patch_df.sample(frac=1, random_state=RANDOM_SEED)
    .groupby(["split", "subtype"])
    .head(THUMBS_PER_GROUP)
)
# Source slide per row: the WSI when present, otherwise the ROI file. Rows with neither
# get NaN and are dropped by the groupby below.
sample_df = sample_df.assign(_slide=sample_df["wsi_path"].fillna(sample_df["roi_file"]))

# Work slide by slide: copy it once, cut all of its thumbnails, then delete the copy.
# Slides from different directories share basenames (e.g. several "1.svs"), so reusing
# a long-lived cache keyed by basename could silently read the wrong slide.
for slide_path, rows in tq(sample_df.groupby("_slide", sort=False),
                           total=sample_df["_slide"].nunique(), desc="Saving thumbs"):
    src = Path(slide_path)
    if not src.exists():
        continue

    dst = WSI_LOCAL_DIR / src.name
    try:
        shutil.copy(src, dst)
        slide = openslide.OpenSlide(str(dst))
    except Exception:
        dst.unlink(missing_ok=True)
        continue

    try:
        for _, row in rows.iterrows():
            x, y = int(row["x"]), int(row["y"])
            try:
                img = slide.read_region((x, y), 0, (PATCH_SIZE, PATCH_SIZE)).convert("RGB")
            except Exception:
                continue

            out_dir  = DEBUG_DIR / row["split"] / row["subtype"]
            out_dir.mkdir(parents=True, exist_ok=True)
            out_path = out_dir / f"{row['patient_id']}_{x}_{y}.jpg"
            img.save(out_path, quality=85)
            thumb_paths.append((row["split"], row["subtype"], out_path))
    finally:
        slide.close()
        dst.unlink(missing_ok=True)

In [None]:
# 🖼️  Display grid per split × class ----------------------------------------------
import matplotlib.pyplot as plt
from itertools import groupby

# After a kernel restart the thumbnail cell may not have run, so ``thumb_paths`` can be
# undefined (not just empty). Fall back to the JPEGs already saved under DEBUG_DIR.
thumb_paths = globals().get("thumb_paths") or []
if not thumb_paths and DEBUG_DIR.exists():
    thumb_paths = [
        (split_dir.name, subtype_dir.name, img_file)
        for split_dir in DEBUG_DIR.iterdir() if split_dir.is_dir()
        for subtype_dir in split_dir.iterdir() if subtype_dir.is_dir()
        for img_file in subtype_dir.glob("*.jpg")
    ]

if not thumb_paths:
    print("[DEBUG] No thumbnails found to display.")
else:
    n_cols = 5
    # groupby needs its input sorted by the grouping key (the split name).
    for split_name, items in groupby(sorted(thumb_paths, key=lambda t: t[0]),
                                     key=lambda t: t[0]):
        items   = list(items)
        classes = sorted({st for _, st, _ in items})
        n_rows  = len(classes)

        # squeeze=False keeps ``axes`` 2-D even when there is a single class row.
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(n_cols*3, n_rows*3), squeeze=False)

        for r, subtype in enumerate(classes):
            patches = [p for _, st, p in items if st == subtype][:n_cols]
            for c, ax in enumerate(axes[r]):
                if c < len(patches) and patches[c].exists():
                    ax.imshow(Image.open(patches[c]))
                    ax.set_title(subtype, fontsize=8)
                else:
                    ax.text(0.5, 0.5, "Missing",
                            ha="center", va="center", fontsize=8, color="red")
                ax.axis("off")

        fig.suptitle(f"Split: {split_name}", fontsize=12, y=0.92)
        plt.tight_layout()
        plt.show()
