# OCT Dataset Inspection Notebook

**Purpose:** Quickly inspect and summarise a retinal OCT dataset for classification tasks (e.g., NORMAL / CNV / DME / DRUSEN).

**What this notebook does:**
- Detect class folders and count images
- Build a metadata table (path, class, width, height, file size)
- Plot class balance and image size distribution (matplotlib only)
- Create a simple montage image preview (single-plot)
- Export a `labels.csv` ready for PyTorch training (includes a heuristic `patient_id` column so you can create patient-level splits later)

**Instructions:**
1. Set `DATA_ROOT` below to your dataset path (parent folder containing class subfolders).
2. (Optional) Adjust `CLASS_DIRS` if your class folders differ.
3. Run each cell top-to-bottom.

> Note: This notebook intentionally avoids seaborn and subplots. Each figure is a single-plot matplotlib figure.

In [None]:
import os, sys, csv, math, random, hashlib
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# === CONFIG ===
DATA_ROOT = "/path/to/OCT_Dataset"  # <-- CHANGE THIS to your dataset root
# If None, class directories will be auto-detected as immediate subfolders of DATA_ROOT.
CLASS_DIRS = None  # e.g., ["CNV", "DME", "DRUSEN", "NORMAL"]

# Allowed image extensions
IMG_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}

# Set a random seed for reproducibility (seeds both `random` and NumPy)
RNG_SEED = 42
random.seed(RNG_SEED)
np.random.seed(RNG_SEED)

DATA_ROOT = Path(DATA_ROOT)
assert DATA_ROOT.exists(), f"DATA_ROOT not found: {DATA_ROOT}"

if CLASS_DIRS is None:
    # Skip hidden folders (e.g. Jupyter's ".ipynb_checkpoints"): they are not
    # classes, and including them would add phantom labels and shift label ids.
    CLASS_DIRS = [
        p.name
        for p in DATA_ROOT.iterdir()
        if p.is_dir() and not p.name.startswith(".")
    ]
# Sorted so that the class order is stable from run to run.
CLASS_DIRS = sorted(CLASS_DIRS)
print("Detected classes:", CLASS_DIRS)

In [None]:
# Scan files and build a metadata table: one row per readable image with its
# path, class folder name, pixel dimensions and size on disk.
rows = []
for cls in CLASS_DIRS:
    cdir = DATA_ROOT / cls
    if not cdir.exists():
        print(f"Warning: class folder missing: {cdir}")
        continue
    for root, dirs, files in os.walk(cdir):
        # os.walk yields entries in arbitrary filesystem order. Sort in place
        # (dirs) and per directory (files) so the row order -- and therefore
        # the seeded sampling done later -- is the same on every machine.
        dirs.sort()
        for fn in sorted(files):
            ext = os.path.splitext(fn)[1].lower()
            if ext not in IMG_EXTS:
                continue
            fp = Path(root) / fn
            try:
                with Image.open(fp) as im:
                    w, h = im.size
                # Kept inside the guarded region so a file that disappears or
                # is unreadable is skipped instead of aborting the whole scan.
                size_bytes = fp.stat().st_size
            except Exception as e:
                print(f"[SKIP] {fp}: {e}")
                continue
            rows.append({
                "path": str(fp),
                "label_name": cls,
                "width": w,
                "height": h,
                "filesize": size_bytes,
            })
meta = pd.DataFrame(rows)
assert len(meta) > 0, "No images found. Check DATA_ROOT and CLASS_DIRS."
print("Total images:", len(meta))
meta.head()

In [None]:
# Number of images per class, ordered by class name
label_column = meta['label_name']
class_counts = label_column.value_counts().sort_index()
print(class_counts)

# Bar chart of the class balance, drawn on its own single-plot figure
plt.figure()
axes = class_counts.plot.bar()
axes.set_title('Class Counts')
axes.set_xlabel('Class')
axes.set_ylabel('Images')
plt.tight_layout()
plt.show()

In [None]:
# Image area distribution: histogram of width * height (in pixels) per image,
# followed by a few summary statistics.
areas = (meta['width'] * meta['height']).astype(int)
plt.figure()
plt.hist(areas, bins=50)
plt.title('Image Area Distribution (pixels)')
# The multiplication sign is written as an escape so the label survives any
# re-encoding of this file (the previous literal had been corrupted to mojibake).
plt.xlabel('Area (W\u00d7H)')
plt.ylabel('Count')
plt.tight_layout()
plt.show()
print("Mean area:", int(areas.mean()))
print("Median area:", int(areas.median()))
print("Min/Max area:", int(areas.min()), int(areas.max()))

In [None]:
# Build a single-image montage (no subplots): stitch a grid using PIL, then show once.
def build_montage(paths, grid=(4, 6), thumb=(160, 160)):
    """Stitch greyscale thumbnails of *paths* into one grid image.

    Args:
        paths: Sequence of image file paths; only the first rows*cols are used.
        grid: ``(rows, cols)`` of the montage.
        thumb: ``(width, height)`` in pixels that every tile is resized to.

    Returns:
        A PIL image in mode ``'L'`` of size ``(cols*width, rows*height)``.
        Tiles are filled left-to-right, top-to-bottom; cells whose image could
        not be read stay black.
    """
    n_rows, n_cols = grid[0], grid[1]
    tile_w, tile_h = thumb
    canvas = Image.new('L', (n_cols * tile_w, n_rows * tile_h), color=0)
    for idx, path in enumerate(paths[:n_rows * n_cols]):
        try:
            # Manage the *opened* image so its file handle is released even if
            # conversion fails; convert()/resize() return independent copies.
            with Image.open(path) as src:
                tile = src.convert('L').resize((tile_w, tile_h))
        except Exception as exc:
            # Still best-effort, but report the failure instead of hiding it.
            print(f"[SKIP] {path}: {exc}")
            continue
        row, col = divmod(idx, n_cols)
        canvas.paste(tile, (col * tile_w, row * tile_h))
    return canvas

# Draw up to `per_class` random images from every class for the preview.
sample_paths = []
per_class = 6
for cls in CLASS_DIRS:
    paths = meta[meta['label_name']==cls]['path'].tolist()
    random.shuffle(paths)
    sample_paths.extend(paths[:per_class])

# Size the grid from what was actually sampled rather than assuming exactly
# four classes: a fixed 4-row grid silently drops samples when there are more
# classes and leaves blank rows when there are fewer.
n_rows = max(1, math.ceil(len(sample_paths) / per_class))
mont = build_montage(sample_paths, grid=(n_rows, per_class), thumb=(160, 160))
plt.figure()
plt.imshow(mont, cmap='gray')
plt.axis('off')
plt.title('Sample Montage')
plt.tight_layout()
plt.show()

In [None]:
# Optional: naive duplicate detection via file hash on a subset
def file_hash(path, block_size=1<<20):
    """Return the hexadecimal MD5 digest of the file at *path*.

    The file is opened in binary mode and consumed in chunks of *block_size*
    bytes, so it is never held in memory as a whole.
    """
    digest = hashlib.md5()
    with open(path, 'rb') as stream:
        # iter() with a sentinel keeps reading until read() returns b'' (EOF).
        for chunk in iter(lambda: stream.read(block_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

# Hash a seeded random sample of at most 2000 files and list the rows whose
# MD5 digest occurs more than once within that sample.
n_checked = min(len(meta), 2000)
subset = meta.sample(n_checked, random_state=RNG_SEED).copy()
subset['md5'] = subset['path'].map(file_hash)
shares_digest = subset['md5'].duplicated(keep=False)
dups = subset.loc[shares_digest].sort_values(by='md5')
print(f"Checked {len(subset)} files; potential duplicates found:", dups['md5'].nunique())
dups.head(10)

In [None]:
# Prepare the image-level labels.csv export (per-patient aggregation can be
# done later). Each class name gets its position in CLASS_DIRS as integer label.
label_to_id = dict(zip(CLASS_DIRS, range(len(CLASS_DIRS))))
meta['label'] = meta['label_name'].map(label_to_id)

# Heuristic patient_id from path (customise for your dataset)
from pathlib import Path
def infer_patient_id(p):
    """Build a heuristic patient id of the form ``<parent folder>_<digits>``.

    ``<digits>`` is every digit character of the file stem concatenated in
    order, or ``'na'`` when the stem contains no digit.
    """
    location = Path(p)
    digits = ''.join(filter(str.isdigit, location.stem))
    if not digits:
        digits = 'na'
    return '_'.join((location.parent.name, digits))

# Attach the heuristic patient id to every row
meta['patient_id'] = meta['path'].map(infer_patient_id)

# Write the training columns to labels.csv in the working directory
export_columns = ['patient_id', 'path', 'label', 'label_name']
labels = meta.loc[:, export_columns].copy()
labels.to_csv('labels.csv', index=False)
print('Wrote labels.csv with', len(labels), 'rows')
labels.head()

## Next Steps
- Verify `labels.csv` and adjust `patient_id` parsing to your dataset naming.
- Create **patient-level** splits (train/val/test) to avoid leakage.
- Plug `labels.csv` into your PyTorch scaffold (`experiments/exp001_baseline.yaml`).
- Start with a small subset to prove the pipeline, then scale up.