In [4]:
# ===== Cell 1 – mount / install (Colab only) =================================
# Mount Google Drive under /content/drive; force_remount=False leaves an
# already-mounted drive untouched (see the "Drive already mounted" output).
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Refresh the apt package index, then install the OpenSlide tools package
# (presumably providing the native library used by openslide-python — confirm).
!apt-get -qq update
!apt-get -qq install -y openslide-tools
# Python packages imported by the following cells (openslide, tqdm).
!pip -q install openslide-python tqdm

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)


In [5]:
from pathlib import Path
import yaml, json, tarfile, random, io
from tqdm.auto import tqdm
import pandas as pd
from openslide import OpenSlide
from PIL import Image

In [6]:
# ===== Cell 2: Config & environment =========================================
# 1) Read the preprocessing configuration (YAML) from the project folder.
yaml_path = Path('/content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/config/preprocessing.yaml')
with open(yaml_path) as f:
    cfg = yaml.safe_load(f)

# 2) Project root: prefer the Colab path when it exists, otherwise the local one.
colab_root = Path(cfg['env_paths']['colab'])
local_root = Path(cfg['env_paths']['local'])
if colab_root.exists():
    root = colab_root
else:
    root = local_root
if not root.exists():
    raise FileNotFoundError('Impossibile trovare project_root')

base_dir = root / 'data/RCC_WSIs'

# 3) Stage selection: the debug stage is used only when its patient
#    down-sampling switch is enabled; otherwise the training stage applies.
if cfg['stages']['debug']['downsample_patients']['enabled']:
    stage_cfg = cfg['stages']['debug']
else:
    stage_cfg = cfg['stages']['training']

PATCH_SIZE = stage_cfg['patching']['patch_size']
SHARD_SIZE = 5_000                     # images per tar
RANDOM_SEED= stage_cfg['patching']['random_seed']
MAX_DBG    = 10                        # jpg per subtype per split
rng        = random.Random(RANDOM_SEED)

# Echo the resolved settings so the run is self-describing.
for label, *values in (
    ('✅ root:', root),
    ('✅ stage:', 'debug' if stage_cfg is cfg["stages"]["debug"] else 'training'),
    ('✅ patch:', PATCH_SIZE, 'px'),
    ('✅ shard size:', SHARD_SIZE),
):
    print(label, *values)

✅ root: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project
✅ stage: training
✅ patch: 224 px
✅ shard size: 5000


In [7]:
# ===== Cell 3: Load patch_df =================================================
# Read the patch table and make sure every row carries a split assignment.
patch_df_path = root / 'data/processed/patch_df_5000.parquet'
patch_df      = pd.read_parquet(patch_df_path)

has_split_column = 'split' in patch_df.columns
if not has_split_column:
    raise RuntimeError("'split' column mancante nel parquet!")

# Summary line: row count plus number of rows per split.
print(f"✅ patch_df: {len(patch_df)} righe –  splits: {patch_df['split'].value_counts().to_dict()}")

✅ patch_df: 5000 righe –  splits: {'train': 3000, 'val': 1000, 'test': 1000}


In [8]:
# ===== Cell 4: Utility ======================================================
def extract_patch(row, cache):
    """Return the requested patch as an RGB ``PIL.Image``.

    Parameters
    ----------
    row : mapping-like (e.g. one ``pandas`` row)
        Must provide ``wsi_path``, ``roi_file``, ``x`` and ``y``.
        ``wsi_path`` is used when it is not missing, otherwise ``roi_file``.
    cache : dict
        Maps a slide path to its already-open slide handle. It is updated
        in place so each slide is opened only once per cache.

    Returns
    -------
    The ``PATCH_SIZE`` x ``PATCH_SIZE`` region read at level 0 whose
    location is ``(int(x), int(y))``, converted to RGB.

    Raises
    ------
    ValueError
        If both ``wsi_path`` and ``roi_file`` are missing for this row.
    Exception
        Whatever ``OpenSlide`` / ``read_region`` raise for unreadable slides.
    """
    src = row['wsi_path'] if pd.notna(row['wsi_path']) else row['roi_file']
    # Fail with a clear message instead of handing NaN/None to OpenSlide.
    if pd.isna(src):
        raise ValueError(
            "nessun percorso slide: 'wsi_path' e 'roi_file' sono entrambi mancanti"
        )
    if src not in cache:
        cache[src] = OpenSlide(src)
    slide = cache[src]
    region = slide.read_region(
        (int(row['x']), int(row['y'])), 0, (PATCH_SIZE, PATCH_SIZE)
    )
    # read_region output is converted so callers always get a 3-channel image.
    return region.convert('RGB')

In [9]:
# ===== Cell 5 – Estrazione patch → WebDataset, suddivisa per split ===========
import tarfile, io, os
from tqdm.auto import tqdm

# --------------------------------------------------------------------------- #
# directory di output (.tar) e di debug (.jpg)                                #
# --------------------------------------------------------------------------- #
out_root = root / "data/processed/webdataset"
dbg_root = root / "data/visual_debug/extract_examples"

splits = ["train", "val", "test"]
# Create both roots, then one sub-folder per split under each of them.
for base_dir_ in (out_root, dbg_root):
    base_dir_.mkdir(parents=True, exist_ok=True)
for s in splits:
    for base_dir_ in (out_root, dbg_root):
        (base_dir_ / s).mkdir(parents=True, exist_ok=True)

# --------------------------------------------------------------------------- #
# global parameters (already defined upstream, restated here for clarity)     #
# --------------------------------------------------------------------------- #
SHARD_SIZE      = 5_000        # number of images per .tar
MAX_DEBUG_JPG   = 10           # max inspection jpgs per (split, class)

# --------------------------------------------------------------------------- #
# utilità: apertura/rotazione shard                                           #
# --------------------------------------------------------------------------- #
def _open_shard(split: str, idx: int):
    """Apre un nuovo tar in write-mode per lo split indicato."""
    path = out_root / split / f"patches-{idx:04d}.tar"
    return tarfile.open(path, "w"), path

# --------------------------------------------------------------------------- #
# stato per ciascuno split                                                    #
# --------------------------------------------------------------------------- #
# One bookkeeping record per split; shard 0 is opened immediately.
state = {
    s: {
        "tar": _open_shard(s, 0)[0],   # currently open archive
        "shard_idx": 0,                # index of the open shard
        "img_in_shard": 0,             # images written to the open shard
        "dbg_cnt": {},                 # subtype -> debug images saved
    }
    for s in splits
}

def _next_shard(split: str):
    """Chiude lo shard corrente e ne apre uno nuovo."""
    st = state[split]
    st["tar"].close()
    st["shard_idx"]   += 1
    st["img_in_shard"] = 0
    st["tar"], _       = _open_shard(split, st["shard_idx"])

# --------------------------------------------------------------------------- #
# estraiamo le patch – shuffle per distribuire meglio le classi               #
# --------------------------------------------------------------------------- #
slide_cache       = {}
# Shuffle once (seeded) so classes are spread across shards.
patch_df_shuffled = patch_df.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)

pbar = tqdm(total=len(patch_df_shuffled), desc="Extract", unit="patch")

# try/finally: an interrupted or failing run must still close the progress
# bar, every tar archive and every cached slide handle; otherwise the tar
# files are left without their end-of-archive marker and handles leak.
try:
    for idx, row in patch_df_shuffled.iterrows():
        split = row["split"]
        st    = state[split]

        # 1) extract the patch; unreadable patches are logged and skipped
        try:
            img = extract_patch(row, slide_cache)
        except Exception as e:
            pbar.write(f"⚠️  skip {row['patient_id']} @({row['x']},{row['y']}): {e}")
            pbar.update(1)
            continue

        # 2) rotate the shard only when there is an image to write, so a
        #    failed extraction can never leave an empty trailing shard
        if st["img_in_shard"] >= SHARD_SIZE:
            _next_shard(split)

        # 3) write the JPEG into the tar of this split
        # NOTE(review): WebDataset takes the sample key from the text before
        # the first '.' of the file name; patient ids such as "HP15.12550"
        # contain dots — confirm the downstream loader copes with that
        # before relying on (or changing) this naming scheme.
        fname = f"{row['subtype']}_{row['patient_id']}_{idx:06d}.jpg"   # unique within the split
        buf   = io.BytesIO()
        img.save(buf, format="JPEG", quality=90)
        ti = tarfile.TarInfo(fname)
        ti.size = buf.tell()          # position after writing == payload size
        buf.seek(0)
        st["tar"].addfile(ti, buf)
        st["img_in_shard"] += 1

        # 4) debug copy on disk (at most MAX_DEBUG_JPG per class and split)
        cnt = st["dbg_cnt"].get(row["subtype"], 0)
        if cnt < MAX_DEBUG_JPG:
            img.save(dbg_root / split / f"{row['subtype']}_{cnt}.jpg")
            st["dbg_cnt"][row["subtype"]] = cnt + 1

        pbar.update(1)
finally:
    pbar.close()
    # ----------------------------------------------------------------------- #
    # final close of all shards (safe to call on an already closed tar)       #
    # ----------------------------------------------------------------------- #
    for s in splits:
        state[s]["tar"].close()
    # release the slide handles opened by extract_patch
    for slide in slide_cache.values():
        slide.close()

print("\n✅ Estratti shard:")
for s in splits:
    print(f"  {s:<5}: {state[s]['shard_idx'] + 1} tar in {out_root / s}")

print(f"\n✅ Immagini di debug (max {MAX_DEBUG_JPG} / classe e split) in {dbg_root}")


Extract:   0%|          | 0/5000 [00:00<?, ?patch/s]

⚠️  skip HP15.12550 @(67110,98463): Cannot read raw tile
⚠️  skip HP14.5347 @(72656,81638): Cannot read raw tile
⚠️  skip HP19.3695 @(73551,145410): Cannot read raw tile
⚠️  skip HP17008718 @(16688,17218): Cannot read raw tile
⚠️  skip HP12.7601 @(26396,110327): Cannot read raw tile
⚠️  skip HP12.8793 @(60411,149999): Cannot read raw tile
⚠️  skip HP14.9097 @(65639,122052): Cannot read raw tile
⚠️  skip HP18005453 @(15985,7042): TIFFRGBAImageGet failed
⚠️  skip HP20001530 @(5132,22840): Cannot read raw tile
⚠️  skip HP14.1749 @(53876,72060): Unsupported or missing image file
⚠️  skip HP19.3695 @(75931,107526): Cannot read raw tile
⚠️  skip HP18009209 @(16169,6017): TIFFRGBAImageGet failed
⚠️  skip HP19.7715 @(83051,122227): Cannot read raw tile
⚠️  skip HP19.7840 @(27378,72573): Cannot read raw tile
⚠️  skip HP14.4279 @(63516,107374): Cannot read raw tile
⚠️  skip HP17.7980 @(54668,62174): Cannot read raw tile
⚠️  skip HP18009209 @(5525,7848): TIFFRGBAImageGet failed
⚠️  skip HP12.9282

KeyboardInterrupt: 



Extract:   5%
 246/5000 [06:56<1:07:08,  1.18patch/s]
⚠️  skip HP20.5602 @(920,2241): Unsupported or missing image file
⚠️  skip HP20.2506 @(49,11384): Cannot read raw tile
⚠️  skip HP18005453 @(12723,4433): Unsupported or missing image file
⚠️  skip HP11.6090 @(57962,143060): Cannot read raw tile
⚠️  skip HP20001530 @(36457,18206): Cannot read raw tile
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-12-de18dbf906fc> in <cell line: 0>()
     69     # 2) estrai la patch
     70     try:
---> 71         img = extract_patch(row, slide_cache)
     72     except Exception as e:
     73         pbar.write(f"⚠️  skip {row['patient_id']} @({row['x']},{row['y']}): {e}")

2 frames
/usr/local/lib/python3.11/dist-packages/openslide/lowlevel.py in _check_open(result, _func, _args)
    253
    254 # check for errors opening an image file and wrap the resulting handle
--> 255 def _check_open(result: int | None, _func: Any, _args: Any) -> _OpenSlide:
    256     if result is None:
    257         raise OpenSlideUnsupportedFormatError("Unsupported or missing image file")

KeyboardInterrupt:


In [None]:
Per ridurre i **tempi di estrazione delle patch**, ci sono diverse strategie che puoi adottare. Il collo di bottiglia principale in uno script come questo è spesso legato a:

* Estrazione delle immagini (`extract_patch`)
* Scrittura su disco (specie in `.tar`)
* Operazioni I/O in generale

Ecco **suggerimenti concreti** per ottimizzare:

---

### ✅ 1. **Parallelizzazione**

Usa il multiprocessing per processare righe in parallelo (soprattutto `extract_patch`, che è tipicamente costoso). Per esempio:

```python
from multiprocessing import Pool

def process_row(row_idx):
    row = patch_df_shuffled.iloc[row_idx]
    split = row["split"]
    # ...come sopra: shard, patch, save in memory...
    return (split, fname, buf.getvalue(), debug_data)
```

E poi usa:

```python
with Pool(processes=NUM_CORES) as pool:
    results = list(tqdm(pool.imap(process_row, range(len(patch_df_shuffled))), total=len(patch_df_shuffled)))
```

E successivamente salva i risultati nei tar, separatamente.

---

### ✅ 2. **Batching**

Se possibile, salva le immagini in batch invece che una per volta. Questo riduce il tempo speso nelle operazioni su disco.

> Alternativa: usa `wds.ShardWriter` da `webdataset` che gestisce efficientemente la scrittura a shard.

---

### ✅ 3. **Evita salvataggi JPEG per debug durante il loop principale**

Il salvataggio JPEG è lento. Puoi accodare le immagini da salvare per il debug in un buffer e salvarle a parte dopo il loop:

```python
debug_buffer = []

# In loop
if cnt < MAX_DEBUG_JPG:
    debug_buffer.append((split, row['subtype'], img.copy()))

# Dopo il loop
for split, subtype, img in debug_buffer:
    img.save(...)
```

---

### ✅ 4. **Disabilita temporaneamente la compressione tar**

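Per test o per dataset di sviluppo, crea `.tar` non compressi (`tarfile.open(..., mode="w")`) invece di `w:gz` (che è più lento). Nota: lo script sopra apre già gli shard in modalità `"w"`, quindi i tar sono già non compressi e questo punto non porta ulteriori guadagni.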

---

### ✅ 5. **Caching intelligente**

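Se `extract_patch` lavora con grandi immagini (es. WSI), assicurati che `slide_cache` funzioni bene ed eviti ricaricamenti ridondanti da disco. `functools.lru_cache` richiede argomenti hashable, quindi non si può applicare direttamente a `extract_patch` (che riceve una riga pandas e un dizionario): usalo eventualmente su una piccola funzione che apre la slide a partire dal solo percorso.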

---

### ✅ 6. **Profilazione**

Usa `cProfile` o `%time` / `%timeit` per identificare il vero collo di bottiglia:

```python
import cProfile
cProfile.run("main_loop()")
```

---

### ✅ 7. **Salvataggio asincrono**

Con librerie come `aiofiles` (per file normali) o thread separati per `tarfile.addfile()`, puoi scrivere su disco in un thread separato rispetto all’elaborazione immagini.

---

### In sintesi – suggerimenti principali da provare subito:

* ✅ Multiprocessing per `extract_patch`
* ✅ Evita salvataggi JPEG in tempo reale
* ✅ Usa `ShardWriter` da `webdataset` se possibile
* ✅ Cache intelligente dei WSI
* ✅ Salvataggio asincrono o in batch

---

Se vuoi, posso aiutarti a **ristrutturare il codice con multiprocessing o `ShardWriter`**. Fammi sapere quale approccio preferisci.


In [None]:
# ===== Cell 5 – Estrazione patch → WebDataset, suddivisa per split ===========
import tarfile, io, os
from tqdm.auto import tqdm

# --------------------------------------------------------------------------- #
# directory di output (.tar) e di debug (.jpg)                                #
# --------------------------------------------------------------------------- #
# NOTE(review): this cell is a verbatim copy of the earlier "Cell 5"; keep a
# single copy in the notebook so the helpers below are not defined twice.
out_root = root / "data/processed/webdataset"
dbg_root = root / "data/visual_debug/extract_examples"
out_root.mkdir(parents=True, exist_ok=True)
dbg_root.mkdir(parents=True, exist_ok=True)

splits = ["train", "val", "test"]
for s in splits:
    (out_root / s).mkdir(parents=True, exist_ok=True)
    (dbg_root / s).mkdir(parents=True, exist_ok=True)

# --------------------------------------------------------------------------- #
# global parameters (already defined upstream, restated here for clarity)     #
# --------------------------------------------------------------------------- #
SHARD_SIZE      = 5_000        # number of images per .tar
MAX_DEBUG_JPG   = 10           # max inspection jpgs per (split, class)

# --------------------------------------------------------------------------- #
# helpers: opening / rotating shards                                          #
# --------------------------------------------------------------------------- #
def _open_shard(split: str, idx: int):
    """Open a new tar in write mode for ``split``; return ``(tar, path)``."""
    path = out_root / split / f"patches-{idx:04d}.tar"
    return tarfile.open(path, "w"), path

# --------------------------------------------------------------------------- #
# per-split state                                                             #
# --------------------------------------------------------------------------- #
state = {}
for s in splits:
    tar, _ = _open_shard(s, 0)
    state[s] = dict(
        tar         = tar,
        shard_idx   = 0,
        img_in_shard= 0,
        dbg_cnt     = {}        # subtype -> debug images saved
    )

def _next_shard(split: str):
    """Close the current shard of ``split`` and open the next one."""
    st = state[split]
    st["tar"].close()
    st["shard_idx"]   += 1
    st["img_in_shard"] = 0
    st["tar"], _       = _open_shard(split, st["shard_idx"])

# --------------------------------------------------------------------------- #
# extract the patches – shuffled (seeded) to spread classes across shards     #
# --------------------------------------------------------------------------- #
slide_cache       = {}
patch_df_shuffled = patch_df.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)

pbar = tqdm(total=len(patch_df_shuffled), desc="Extract", unit="patch")

# try/finally: an interrupted or failing run must still close the progress
# bar, every tar archive and every cached slide handle; otherwise the tar
# files are left without their end-of-archive marker and handles leak.
try:
    for idx, row in patch_df_shuffled.iterrows():
        split = row["split"]
        st    = state[split]

        # 1) extract the patch; unreadable patches are logged and skipped
        try:
            img = extract_patch(row, slide_cache)
        except Exception as e:
            pbar.write(f"⚠️  skip {row['patient_id']} @({row['x']},{row['y']}): {e}")
            pbar.update(1)
            continue

        # 2) rotate the shard only when there is an image to write, so a
        #    failed extraction can never leave an empty trailing shard
        if st["img_in_shard"] >= SHARD_SIZE:
            _next_shard(split)

        # 3) write the JPEG into the tar of this split
        # NOTE(review): WebDataset takes the sample key from the text before
        # the first '.' of the file name; patient ids such as "HP15.12550"
        # contain dots — confirm the downstream loader copes with that
        # before relying on (or changing) this naming scheme.
        fname = f"{row['subtype']}_{row['patient_id']}_{idx:06d}.jpg"   # unique within the split
        buf   = io.BytesIO()
        img.save(buf, format="JPEG", quality=90)
        ti = tarfile.TarInfo(fname)
        ti.size = buf.tell()          # position after writing == payload size
        buf.seek(0)
        st["tar"].addfile(ti, buf)
        st["img_in_shard"] += 1

        # 4) debug copy on disk (at most MAX_DEBUG_JPG per class and split)
        cnt = st["dbg_cnt"].get(row["subtype"], 0)
        if cnt < MAX_DEBUG_JPG:
            img.save(dbg_root / split / f"{row['subtype']}_{cnt}.jpg")
            st["dbg_cnt"][row["subtype"]] = cnt + 1

        pbar.update(1)
finally:
    pbar.close()
    # ----------------------------------------------------------------------- #
    # final close of all shards (safe to call on an already closed tar)       #
    # ----------------------------------------------------------------------- #
    for s in splits:
        state[s]["tar"].close()
    # release the slide handles opened by extract_patch
    for slide in slide_cache.values():
        slide.close()

print("\n✅ Estratti shard:")
for s in splits:
    print(f"  {s:<5}: {state[s]['shard_idx'] + 1} tar in {out_root / s}")

print(f"\n✅ Immagini di debug (max {MAX_DEBUG_JPG} / classe e split) in {dbg_root}")


In [None]:
# ===== Cell 6: Visualizzazione esempi estratti ===============================
import matplotlib.pyplot as plt
from PIL import Image
import math, itertools

# Collect, per subtype, the largest number of debug images saved in any split
# (counts come from the per-split state filled by the extraction cell).
debug_count = {}
for s in splits:
    for subtype, count in state[s]["dbg_cnt"].items():
        debug_count[subtype] = max(debug_count.get(subtype, 0), count)

# Directory where the extraction cell stored the debug jpgs.
jpg_debug_dir = dbg_root

subtypes = sorted(debug_count.keys())
cols     = MAX_DEBUG_JPG
rows     = len(subtypes)

if rows == 0 or cols == 0:
    # plt.subplots rejects an empty grid, so report and skip the figure.
    print("⚠️  Nessuna immagine di debug da visualizzare.")
else:
    fig, axes = plt.subplots(rows, cols,
                             figsize=(cols*2.2, max(2, rows)*2.2),
                             squeeze=False)

    # One row per subtype, one column per debug index. The loop variable is
    # deliberately not called `st`, which the extraction cell uses for state.
    for r, subtype_name in enumerate(subtypes):
        for c in range(cols):
            ax = axes[r][c]
            # First split (in `splits` order) that has this debug image wins.
            img_p = next(
                (p for p in (jpg_debug_dir / s / f"{subtype_name}_{c}.jpg" for s in splits)
                 if p.exists()),
                None,
            )
            if img_p is not None:
                # Context manager closes the file handle once it is drawn.
                with Image.open(img_p) as debug_img:
                    ax.imshow(debug_img)
            else:
                # Placeholder for a missing (subtype, index) combination.
                ax.text(0.5, 0.5, 'N/A', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, color='gray', fontsize=10)

            ax.axis("off")
            if c == 0:
                ax.set_title(subtype_name, fontsize=12, loc="left")

    plt.tight_layout()
    plt.show()

In [None]:
# --------------------------------------------------------------------------- #
# Statistiche sui tar generati
# --------------------------------------------------------------------------- #
import glob

# Per split: list every generated shard with its member count and file size.
print("\n📦 Statistiche sui .tar generati per split:\n")
for s in splits:
    pattern   = str(out_root / s / "patches-*.tar")
    tar_paths = glob.glob(pattern)
    tar_paths.sort()
    print(f"Split '{s}': {len(tar_paths)} shard")
    for tp in tar_paths:
        # Count the archive members, then close the tar before stat-ing it.
        with tarfile.open(tp, "r") as t:
            n_items = len(t.getmembers())
        size_mb = os.path.getsize(tp) / (1024 * 1024)
        print(f"  • {os.path.basename(tp)} → {n_items:5d} immagini, {size_mb:6.2f} MB")
    print()
