## Usage
- Update the configuration cell with your COCO JSON and image directory paths.
- Run the loader cell to build a sample list (optionally capped by `SAMPLE_LIMIT`).
- Use the interactive viewer to step through each augmented image with overlaid category masks.
- Render the grid cell for a quick collage of random samples.

In [2]:
%matplotlib inline

from pathlib import Path
import random
import sys
from typing import Dict, List, Sequence, Tuple

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from PIL import Image

try:
    import ipywidgets as widgets
except Exception:
    widgets = None

# Ensure the repository root is on sys.path so diffusion_module can be imported when the notebook
# is opened from the subdirectory (e.g., VS Code's default working directory behavior).
project_root = Path.cwd()
# A tuple (not a set) keeps the search order deterministic: the working directory is tried
# before its parent. Iterating a set of Paths follows string hashing, which is randomised per
# interpreter process by default, so the winner could change between runs.
possible_roots = (project_root, project_root.parent)
for root in possible_roots:
    candidate = root / "diffusion_module"
    # The membership check also avoids a duplicate entry when cwd is the filesystem root
    # (where project_root.parent == project_root).
    if candidate.exists() and str(root) not in sys.path:
        sys.path.append(str(root))

from diffusion_module.coco_io import CocoDataset

# Notebook-wide figure defaults: white figure and axes backgrounds, 120 DPI.
plt.rcParams.update(
    {
        "figure.facecolor": "white",
        "axes.facecolor": "white",
        "figure.dpi": 120,
    }
)

In [None]:
# --- Configure these paths before running the rest of the notebook ---
COCO_JSON_PATH = Path("/path/to/annotations.coco.json")  # COCO annotation file loaded by CocoDataset
IMAGE_ROOT = Path("/path/to/image/root")  # directory passed to CocoDataset.iter_samples to locate images

SAMPLE_LIMIT = 24  # Set to None to use every sample
RANDOM_SEED = 42  # used by the loader cell to seed Python's `random` module
DISPLAY_COLUMNS = 2  # number of columns in the random-sample grid (clamped to at least 1)
MASK_ALPHA = 0.45  # opacity of the category mask overlays (alpha channel, 0 = invisible, 1 = opaque)
OUTLINE_WIDTH = 2  # line width of the bounding-box rectangles

In [None]:
# Validate configuration, load the COCO file, and pick the working set of samples.
if not COCO_JSON_PATH.exists():
    raise FileNotFoundError(f"COCO file not found: {COCO_JSON_PATH}")
if not IMAGE_ROOT.exists():
    raise FileNotFoundError(f"Image root not found: {IMAGE_ROOT}")
if SAMPLE_LIMIT is not None and SAMPLE_LIMIT < 1:
    # A limit of 0 would silently yield an empty list that every later cell rejects.
    raise ValueError(f"SAMPLE_LIMIT must be a positive integer or None, got {SAMPLE_LIMIT!r}")

coco_ds = CocoDataset(COCO_JSON_PATH)
all_samples = list(coco_ds.iter_samples(IMAGE_ROOT))
if not all_samples:
    raise RuntimeError("No image entries were found for the provided COCO file")

# Seed unconditionally. show_random_grid() also draws from the global `random` state, so
# the collage stays reproducible even when SAMPLE_LIMIT is None.
random.seed(RANDOM_SEED)
if SAMPLE_LIMIT is None:
    samples = all_samples
else:
    samples = random.sample(all_samples, min(SAMPLE_LIMIT, len(all_samples)))

# Category id -> display name; falls back to "category_<id>" when a category has no name.
cat_lookup: Dict[int, str] = {cid: cat.get("name", f"category_{cid}") for cid, cat in coco_ds.categories.items()}
num_annotations = sum(len(sample.annotations) for sample in samples)
print(f"Loaded {len(samples)} samples with {num_annotations} annotations spanning {len(cat_lookup)} categories.")

In [None]:
def category_color(category_id: int) -> Tuple[float, float, float]:
    """Return a deterministic RGB colour (each channel in [0, 1)) for a category id.

    The id seeds a fresh NumPy generator, so the same category always gets the same
    colour across cells and runs.
    """
    generator = np.random.default_rng(category_id)
    red, green, blue = generator.random(3).tolist()
    return (red, green, blue)


def collect_annotation_stats(sample) -> List[Dict]:
    """Summarise each annotation of ``sample`` for printing.

    Returns one dict per annotation, in annotation order, with keys ``category_id``,
    ``category_name``, ``bbox``, ``coverage`` and ``area``. ``coverage`` is the fraction
    of the image's pixels set in the mask from ``CocoDataset.annotation_mask`` (0.0 for
    a zero-sized image). ``area`` is the annotation's own ``"area"`` field when present,
    otherwise the mask's pixel count.
    """
    total_pixels = sample.width * sample.height
    stats: List[Dict] = []
    for ann in sample.annotations:
        mask = CocoDataset.annotation_mask(ann, sample.height, sample.width)
        # Threshold the same way _mask_to_rgba does, so bool, 0/1 and 0/255 masks all
        # count each covered pixel exactly once (a raw sum would over-count 0/255 masks).
        covered = int(np.count_nonzero(mask > 0))
        coverage = covered / total_pixels if total_pixels else 0.0
        category_id = ann["category_id"]
        stats.append(
            {
                "category_id": category_id,
                "category_name": cat_lookup.get(category_id, str(category_id)),
                "bbox": ann["bbox"],
                "coverage": coverage,
                "area": float(ann.get("area", covered)),
            }
        )
    return stats


def _mask_to_rgba(mask: np.ndarray, color: Tuple[float, float, float], alpha: float) -> np.ndarray:
    overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.float32)
    overlay[..., :3] = color
    overlay[..., 3] = (mask > 0).astype(np.float32) * alpha
    return overlay


def render_sample(sample, ax=None, alpha: float = MASK_ALPHA):
    """Draw ``sample``'s image with per-annotation mask overlays, bboxes and labels.

    Args:
        sample: A sample from ``CocoDataset.iter_samples``. This function reads its
            ``image_path``, ``image_info["file_name"]``, ``width``, ``height`` and
            ``annotations``. A missing image file is replaced by a black canvas of
            ``height x width`` and a warning is printed.
        ax: Matplotlib axes to draw into. When ``None``, a new 6x6 figure is created
            and shown before returning.
        alpha: Opacity of the mask overlays.

    Returns:
        The axes that were drawn into.
    """
    if sample.image_path.exists():
        # The context manager releases the file handle; convert() returns an independent copy.
        with Image.open(sample.image_path) as src:
            image = src.convert("RGB")
    else:
        print(f"[WARN] Missing image file: {sample.image_path}")
        image = np.zeros((sample.height, sample.width, 3), dtype=np.uint8)

    created_fig = ax is None
    if created_fig:
        _, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(image)

    # With no annotations this loop is a no-op and we fall through to the shared title/axis code.
    for ann in sample.annotations:
        mask = CocoDataset.annotation_mask(ann, sample.height, sample.width)
        if not np.any(mask):
            continue  # empty mask: nothing to overlay, and its bbox/label are skipped too
        color = category_color(ann["category_id"])
        ax.imshow(_mask_to_rgba(mask, color, alpha))
        x, y, w, h = ann["bbox"]
        rect = patches.Rectangle((x, y), w, h, linewidth=OUTLINE_WIDTH, edgecolor=color, facecolor="none")
        ax.add_patch(rect)
        # Reusing the enclosing quote type inside an f-string is only legal from Python 3.12
        # (PEP 701); a plain str() works on every interpreter and renders the same text.
        label = str(cat_lookup.get(ann["category_id"], ann["category_id"]))
        ax.text(
            x,
            max(0, y - 4),  # nudge the label just above the box, clamped to the image top
            label,
            fontsize=9,
            color="white",
            bbox=dict(facecolor=color, alpha=0.7, pad=1),
        )

    ax.set_title(Path(sample.image_info["file_name"]).name)
    ax.axis("off")
    if created_fig:
        plt.show()
    return ax


def preview_sample(index: int) -> None:
    """Print a header and per-annotation stats for one sample, then render it.

    ``index`` is clamped into the valid range of ``samples``, so an out-of-range value
    shows the first or last sample instead of raising ``IndexError``.

    Raises:
        RuntimeError: If ``samples`` is empty.
    """
    if not samples:
        raise RuntimeError("No samples loaded; run the loader cell first")
    last = len(samples) - 1
    position = min(max(index, 0), last)
    chosen = samples[position]
    print(f"Sample {position + 1}/{len(samples)} | {chosen.image_path}")
    for entry in collect_annotation_stats(chosen):
        name, cid = entry["category_name"], entry["category_id"]
        print(f" - {name} (id={cid}): area={entry['area']:.1f}, coverage={entry['coverage']:.2%}")
    render_sample(chosen, alpha=MASK_ALPHA)

In [None]:
# Without ipywidgets, or with nothing to scroll through, render the first sample statically;
# otherwise drive preview_sample from an index slider.
if widgets is None or len(samples) <= 1:
    preview_sample(0)
else:
    _slider = widgets.IntSlider(min=0, max=len(samples) - 1, step=1, description="index")
    widgets.interact(preview_sample, index=_slider)

In [None]:
# Default collage shape: two rows, DISPLAY_COLUMNS columns (never fewer than one).
GRID_ROWS = 2
GRID_COLS = DISPLAY_COLUMNS if DISPLAY_COLUMNS >= 1 else 1

def show_random_grid(rows: int = GRID_ROWS, cols: int = GRID_COLS, alpha: float = MASK_ALPHA):
    """Render up to ``rows * cols`` samples in a grid and return the figure.

    When there are more samples than cells, a random subset is drawn using the global
    ``random`` state. Otherwise every sample is shown in its current order. Each cell is
    titled with the image file's stem; leftover cells are blanked.
    """
    capacity = rows * cols
    if len(samples) > capacity:
        chosen = random.sample(samples, capacity)
    else:
        chosen = samples
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
    cells = np.atleast_1d(axes).ravel()  # uniform 1-D view even for a 1x1 grid
    for idx, cell in enumerate(cells):
        if idx < len(chosen):
            sample = chosen[idx]
            render_sample(sample, ax=cell, alpha=alpha)
            cell.set_title(Path(sample.image_info["file_name"]).stem, fontsize=9)
        else:
            cell.axis("off")
    fig.tight_layout()
    return fig

# Bind the figure instead of leaving it as the cell's last expression: under
# `%matplotlib inline` a returned Figure would be rendered twice (inline backend + cell output).
grid_figure = show_random_grid()