# 01 · Data Exploration — LUNA16 Lung Nodule Dataset

**Project:** SAM2 Lung Nodule Segmentation — Uncertainty-Aware CT Analysis  
**Date:** January 2025 (Phase 1 Research)  
**Author:** Rahul Reddy

---

This notebook explores the LUNA16 dataset used to train the SAM2-based lung nodule segmenter. It covers:

1. Dataset structure and statistics
2. HU (Hounsfield Unit) distribution analysis
3. Nodule size distribution
4. Patch visualisation (image + mask overlays)
5. Data augmentation preview
6. Dataset summary statistics

> **No LUNA16 data?** All cells fall back to synthetic data automatically.

In [None]:
# ── Imports ─────────────────────────────────────────────────────────────────
import os
import sys
from pathlib import Path

# Add project root to path
# ".." is resolved against the current working directory, so this presumably
# expects the notebook to be run from a sub-directory of the project — confirm.
PROJECT_ROOT = Path("..").resolve()
# Inserted at position 0 so project packages win over same-named installed ones.
sys.path.insert(0, str(PROJECT_ROOT))

import warnings

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import Normalize

# NOTE(review): this silences every warning category for the whole session,
# including numpy runtime warnings (e.g. mean of an empty array) — consider
# narrowing the filter if a warning would help debugging.
warnings.filterwarnings("ignore")

# Echo the environment so saved notebook outputs record what they were run with.
print(f"PyTorch: {torch.__version__}")
print(f"Project root: {PROJECT_ROOT}")

## 1 · Dataset Loading

The `build_dataset` factory in `data/dataset.py` returns a `LUNA16SliceDataset` if a real
`data/processed/` directory exists, or a `SyntheticNoduleDataset` (for CI / demos) otherwise.
Both expose the same interface, so all downstream cells work identically.

In [None]:
from data.dataset import build_dataset

# Directory the notebook checks for preprocessed data.
DATA_DIR = PROJECT_ROOT / "data" / "processed"
USE_SYNTHETIC = not DATA_DIR.exists()

# When the directory is missing, the literal string "SYNTHETIC" is handed to
# build_dataset in place of a real path.
if not USE_SYNTHETIC:
    print(f"✓  Found processed data at: {DATA_DIR}")
    data_dir_str = str(DATA_DIR)
else:
    print("⚠  No processed data found — using SyntheticNoduleDataset (demo mode)")
    data_dir_str = "SYNTHETIC"

# Build the three splits with identical settings (slice mode, no augmentation),
# in train → val → test order.
_splits = {
    split_name: build_dataset(
        data_dir_str, split=split_name, mode="slice", augment=False
    )
    for split_name in ("train", "val", "test")
}
train_ds = _splits["train"]
val_ds = _splits["val"]
test_ds = _splits["test"]

# Right-aligned, thousands-separated slice counts per split, then the total.
print("\nDataset splits:")
for split_label, split_ds in (("Train", train_ds), ("Val  ", val_ds), ("Test ", test_ds)):
    print(f"  {split_label} : {len(split_ds):>6,} slices")
total_slices = len(train_ds) + len(val_ds) + len(test_ds)
print(f"  Total : {total_slices:>6,} slices")

## 2 · HU Distribution Analysis

Lung CT values are measured in Hounsfield Units (HU):
- Air ≈ −1000 HU
- Lung parenchyma: −800 to −700 HU
- Soft tissue / nodule: +20 to +80 HU  
- Bone: > +400 HU

We use the window `[−1000, 400]` to emphasise the clinically relevant range.

In [None]:
# Sample images from train split for HU analysis
# At most the first 20 training samples are inspected.
N_SAMPLE = min(20, len(train_ds))
all_hu_vals, nodule_hu_vals, bg_hu_vals = [], [], []

for sample_idx in range(N_SAMPLE):
    item = train_ds[sample_idx]
    # Image values are assumed to be normalised to [0, 1] by preprocessing.
    pixels = item["image"].numpy().flatten()
    labels = item["mask"].numpy().flatten()  # treated as binary, thresholded at 0.5
    # Undo the normalisation for display only: HU ≈ value * 1400 - 1000,
    # i.e. [0, 1] maps back onto the [-1000, +400] window.
    hu = pixels * 1400 - 1000
    all_hu_vals += hu.tolist()
    nodule_hu_vals += hu[labels > 0.5].tolist()
    bg_hu_vals += hu[labels < 0.5].tolist()

fig, axes = plt.subplots(1, 2, figsize=(14, 4))
ax_all, ax_split = axes
fig.suptitle(
    "Hounsfield Unit (HU) Distribution Analysis", fontsize=14, fontweight="bold"
)

# Left panel: histogram over every sampled voxel, with reference HU markers.
ax_all.hist(
    all_hu_vals, bins=100, color="#4C72B0", alpha=0.8, density=True, label="All voxels"
)
for hu_mark, mark_colour, mark_label in (
    (-700, "#55A868", "Lung parenchyma"),
    (50, "#C44E52", "Nodule (≈50 HU)"),
    (400, "#DD8452", "Window max (+400)"),
):
    ax_all.axvline(hu_mark, color=mark_colour, linestyle="--", label=mark_label)
ax_all.set_xlabel("Hounsfield Unit (HU)")
ax_all.set_ylabel("Density")
ax_all.set_title("Full Volume HU Distribution")
ax_all.legend(fontsize=9)
ax_all.grid(True, alpha=0.3)

# Right panel: nodule vs background, or a notice when no mask voxel was sampled.
if not nodule_hu_vals:
    ax_split.text(
        0.5,
        0.5,
        "No nodule voxels in sample\n(synthetic mask may be zero)",
        ha="center",
        va="center",
        transform=ax_split.transAxes,
    )
else:
    ax_split.hist(
        bg_hu_vals,
        bins=80,
        alpha=0.6,
        density=True,
        label="Background",
        color="#4C72B0",
    )
    ax_split.hist(
        nodule_hu_vals,
        bins=40,
        alpha=0.8,
        density=True,
        label="Nodule region",
        color="#C44E52",
    )
    ax_split.set_xlabel("Hounsfield Unit (HU)")
    ax_split.set_ylabel("Density")
    ax_split.set_title("Nodule vs Background HU")
    ax_split.legend(fontsize=9)
    ax_split.grid(True, alpha=0.3)
    # Restrict the view to the -200..300 HU band.
    ax_split.set_xlim(-200, 300)

plt.tight_layout()
plt.savefig("hu_distribution.png", dpi=120, bbox_inches="tight")
plt.show()
print("Saved: hu_distribution.png")

## 3 · Nodule Size Statistics

LUNA16 nodules range from 3 mm to 30 mm diameter. The LUNA16 annotations provide
centroid coordinates and a diameter for each nodule accepted by ≥3 radiologists;
malignancy scores are available from the underlying LIDC-IDRI annotations.

In [None]:
# Compute mask pixel counts (proxy for nodule area per slice)
# Only the first 50 training samples (or fewer, if the split is smaller) are used.
N_SIZE_SAMPLE = min(50, len(train_ds))
mask_areas = np.array(
    [int(train_ds[i]["mask"].numpy().sum()) for i in range(N_SIZE_SAMPLE)],
    dtype=np.int64,
)

# Areas of the slices that actually contain mask pixels. This is left empty
# when none were sampled: no placeholder values are substituted, so every
# statistic reported below comes from real samples only.
nonzero_slices = mask_areas[mask_areas > 0]

# Convert pixel area → approximate diameter (assuming 1mm spacing, circular cross-section)
# NOTE: this is the equivalent diameter of one 2-D cross-section, not of the
# whole nodule.
diameters_mm = 2 * np.sqrt(nonzero_slices / np.pi)  # d = 2√(A/π)

# Fraction of slices with/without nodules
n_pos = int((mask_areas > 0).sum())
n_neg = int((mask_areas == 0).sum())
n_total = n_pos + n_neg

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
fig.suptitle(
    "Nodule Size Distribution (Training Split)", fontsize=14, fontweight="bold"
)

# Panel 1: mask area of every sampled slice (background slices included).
axes[0].hist(mask_areas, bins=30, color="#4C72B0", alpha=0.8, edgecolor="white")
axes[0].set_xlabel("Mask area (pixels)")
axes[0].set_ylabel("Count")
axes[0].set_title("Mask Area per Slice")
if mask_areas.size:
    # The mean is undefined for an empty sample, so only mark it when present.
    axes[0].axvline(
        mask_areas.mean(),
        color="red",
        linestyle="--",
        label=f"Mean={mask_areas.mean():.0f}",
    )
    axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Panel 2: equivalent diameters of nodule-bearing slices only.
axes[1].set_xlabel("Equivalent diameter (mm)")
axes[1].set_ylabel("Count")
axes[1].set_title("Nodule Diameter Distribution")
if diameters_mm.size:
    axes[1].hist(diameters_mm, bins=20, color="#55A868", alpha=0.8, edgecolor="white")
    axes[1].axvline(3, color="#C44E52", linestyle="--", label="3 mm (LUNA16 min)")
    axes[1].axvline(30, color="#DD8452", linestyle="--", label="30 mm (LUNA16 max)")
    axes[1].legend(fontsize=8)
    axes[1].grid(True, alpha=0.3)
else:
    axes[1].text(
        0.5,
        0.5,
        "No nodule slices in sample",
        ha="center",
        va="center",
        transform=axes[1].transAxes,
    )

# Panel 3: label balance; a pie chart needs at least one slice to be defined.
if n_total:
    axes[2].pie(
        [n_pos, n_neg],
        labels=["Nodule slice", "Background slice"],
        colors=["#C44E52", "#4C72B0"],
        autopct="%1.1f%%",
        startangle=90,
    )
else:
    axes[2].text(
        0.5,
        0.5,
        "No slices sampled",
        ha="center",
        va="center",
        transform=axes[2].transAxes,
    )
    axes[2].axis("off")
axes[2].set_title("Slice Label Balance")

plt.tight_layout()
plt.savefig("nodule_size_stats.png", dpi=120, bbox_inches="tight")
plt.show()

if n_total:
    print(f"Nodule slices      : {n_pos}/{n_total} ({100*n_pos/n_total:.1f}%)")
else:
    print("Nodule slices      : 0/0 (n/a)")
if diameters_mm.size:
    print(f"Mean diameter (mm) : {diameters_mm.mean():.1f} ± {diameters_mm.std():.1f}")
    print(f"Diameter range     : [{diameters_mm.min():.1f}, {diameters_mm.max():.1f}] mm")
else:
    print("Mean diameter (mm) : n/a (no nodule slices in sample)")

## 4 · Patch Visualisation

Each training sample is a 96×96 pixel patch centred on a candidate annotation.
Below we visualise image slices with their corresponding ground-truth masks overlaid.

In [None]:
def show_patches(dataset, n=6, title="Sample Patches"):
    """Plot image + mask overlay for up to ``n`` randomly chosen samples.

    Args:
        dataset: Indexable, sized collection whose items are mappings with an
            "image" and a "mask" entry (objects supporting
            ``.squeeze().numpy()``) and, optionally, a "slice_idx" entry that
            is shown in the column title ("?" when absent).
        n: Maximum number of samples to draw, without replacement. Fewer
            columns are drawn when the dataset holds fewer than ``n`` items.
        title: Figure super-title.

    Returns:
        The matplotlib figure. The top row shows the grayscale slices; the
        bottom row repeats them with the mask overlaid in red wherever the
        mask is >= 0.5.

    Raises:
        ValueError: If ``dataset`` is empty or ``n`` is smaller than 1.
    """
    if len(dataset) == 0:
        raise ValueError("show_patches: dataset is empty, nothing to plot")
    if n < 1:
        raise ValueError(f"show_patches: n must be >= 1, got {n}")

    # Never create more columns than there are samples to fill them.
    n_cols = min(n, len(dataset))
    indices = np.random.choice(len(dataset), size=n_cols, replace=False)

    # squeeze=False keeps `axes` two-dimensional even for a single column, so
    # the axes[row, col] indexing below is always valid.
    fig, axes = plt.subplots(2, n_cols, figsize=(2.5 * n_cols, 5), squeeze=False)
    fig.suptitle(title, fontsize=13, fontweight="bold")

    for col, idx in enumerate(indices):
        sample = dataset[int(idx)]
        img = sample["image"].squeeze().numpy()  # (H, W)
        msk = sample["mask"].squeeze().numpy()  # (H, W)
        slice_idx = sample.get("slice_idx", "?")

        # Top row: grayscale CT slice
        axes[0, col].imshow(img, cmap="gray", vmin=0, vmax=1)
        axes[0, col].set_title(f"Slice {slice_idx}", fontsize=8)

        # Bottom row: mask overlay
        axes[1, col].imshow(img, cmap="gray", vmin=0, vmax=1)
        if msk.sum() > 0:
            # Mask out background pixels so only the nodule region is tinted.
            masked = np.ma.masked_where(msk < 0.5, msk)
            axes[1, col].imshow(masked, cmap="Reds", alpha=0.5, vmin=0, vmax=1)
        axes[1, col].set_title("Mask", fontsize=8)

    # Hide ticks and frames individually instead of calling axis("off"):
    # turning the axis off would also hide the row labels set below.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Row labels
    axes[0, 0].set_ylabel("CT slice", fontsize=9)
    axes[1, 0].set_ylabel("+ Mask", fontsize=9)

    plt.tight_layout()
    return fig


# Draw six random training patches, then write the figure to disk before
# displaying it.
PATCH_FIG_PATH = "sample_patches.png"
fig = show_patches(
    train_ds,
    n=6,
    title="Training Patches — CT Slice + Ground-Truth Mask",
)
plt.savefig(PATCH_FIG_PATH, dpi=120, bbox_inches="tight")
plt.show()
print(f"Saved: {PATCH_FIG_PATH}")

## 5 · Data Augmentation Preview

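We apply 6 augmentation transforms during training to improve generalisation:
random horizontal/vertical flip, rotation (±15°), zoom (0.85–1.15×),
Gaussian noise (σ=0.02), and random brightness (±10%).

Note that the preview cell below uses stronger settings than these training values
(flip probability 0.9, rotation ±20°, zoom 0.80–1.20×, noise σ=0.05, brightness ±20%).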

In [None]:
from data.augmentation import get_augmentation_pipeline

# Settings for the preview pipeline.
# NOTE(review): these values are stronger than the training values quoted in
# the text and the summary (±15°, 0.85–1.15×, σ=0.02, ±10%) — presumably
# exaggerated so the effects are visible in the preview; confirm.
aug_config = dict(
    augment=True,
    random_flip_prob=0.9,
    vertical_flip_prob=0.9,
    random_rotation_degrees=20.0,
    random_zoom_range=[0.80, 1.20],
    random_brightness=0.20,
    gaussian_noise_std=0.05,
)
augmenter = get_augmentation_pipeline(aug_config, augment=True)

# The first training sample serves as the single source for every preview.
sample = train_ds[0]
img_orig, msk_orig = sample["image"], sample["mask"]  # image noted as (1, H, W)

# One column for the original plus one per augmented variant; two rows
# (image on top, mask below).
N_AUG = 5
n_preview_cols = N_AUG + 1
fig, axes = plt.subplots(2, n_preview_cols, figsize=(3 * n_preview_cols, 6))
fig.suptitle(
    "Augmentation Preview (one source → 5 augmented versions)",
    fontsize=13,
    fontweight="bold",
)


def _show(ax, tensor, title, cmap="gray"):
    """Render ``tensor`` on ``ax`` as an image with a small title and no axis.

    The tensor is squeezed and converted to a numpy array, then drawn with a
    fixed 0–1 colour range using ``cmap``.
    """
    pixels = tensor.squeeze().numpy()
    ax.imshow(pixels, cmap=cmap, vmin=0, vmax=1)
    ax.set_title(title, fontsize=8)
    ax.axis("off")


# Column 0: the untouched source sample (image in row 0, mask in row 1).
for row, (source_tensor, kind) in enumerate(((img_orig, "image"), (msk_orig, "mask"))):
    _show(axes[row, 0], source_tensor, f"Original ({kind})")

# Columns 1..N_AUG: independent augmentations of the same source. Clones are
# passed in so the originals stay untouched between iterations.
for col in range(1, N_AUG + 1):
    aug_out = augmenter({"image": img_orig.clone(), "mask": msk_orig.clone()})
    for row, kind in enumerate(("image", "mask")):
        _show(axes[row, col], aug_out[kind], f"Aug #{col} ({kind})")

plt.tight_layout()
plt.savefig("augmentation_preview.png", dpi=120, bbox_inches="tight")
plt.show()
print(f"Augmentation pipeline: {augmenter}")

## 6 · Dataset Summary Statistics

In [None]:
# Summary table
# Rows are collected as (label, value) pairs and printed with the label padded
# to a fixed width so the colons line up.
SUMMARY_RULE = "=" * 55
n_sampled = n_pos + n_neg

summary_rows = [
    ("Mode", "SYNTHETIC" if USE_SYNTHETIC else "REAL"),
    ("Train slices", f"{len(train_ds):,}"),
    ("Val slices", f"{len(val_ds):,}"),
    ("Test slices", f"{len(test_ds):,}"),
    ("Patch size", "96 × 96 px"),
    ("HU window", "[-1000, +400]"),
    (
        "Nodule slice frac",
        f"{100*n_pos/n_sampled:.1f}% (+), {100*n_neg/n_sampled:.1f}% (-)",
    ),
]
# The diameter row is only reported when diameter estimates exist.
if len(diameters_mm) > 0:
    summary_rows.append(
        (
            "Nodule diameter",
            f"{diameters_mm.mean():.1f} ± {diameters_mm.std():.1f} mm  "
            f"[range: {diameters_mm.min():.0f}–{diameters_mm.max():.0f}]",
        )
    )
# The split ratios and augmentation list below are fixed text, not computed here.
summary_rows.append(("Train/Val/Test split", "72% / 14% / 14% (study-level)"))
summary_rows.append(
    (
        "Augmentations",
        "hflip, vflip, rotate±15°, zoom 0.85–1.15, noise σ=0.02, brightness±10%",
    )
)

print(SUMMARY_RULE)
print("  LUNA16 Dataset Summary (SAM2 Lung Nodule Project)")
print(SUMMARY_RULE)
for row_label, row_value in summary_rows:
    print(f"  {row_label:<20}: {row_value}")
print(SUMMARY_RULE)