# 01 — Data Exploration

Explore the captioning dataset: visualise samples, inspect caption-length and
vocabulary statistics, and preview augmentation transforms.

In [None]:
import json
import sys
from pathlib import Path
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import yaml
from PIL import Image

# Put the project's src/ directory at the front of the import path so its
# modules can be imported from this notebook.
src_dir = Path("../src").resolve()
sys.path.insert(0, str(src_dir))

# Parse the default YAML configuration.
with open("../configs/default.yaml") as config_file:
    cfg = yaml.safe_load(config_file)

# Data locations from the config are taken relative to the parent directory.
data_section = cfg["data"]
IMAGE_ROOT = Path("..") / data_section["image_root"]
TRAIN_PATH = Path("..") / data_section["train_annotations"]

print(f"Image root : {IMAGE_ROOT}")
print(f"Train file : {TRAIN_PATH}")

## 1. Load annotations

In [None]:
def load_annotations(path):
    """Load caption annotations from a JSON or JSONL file as a flat list.

    A file whose suffix is ``.jsonl`` is parsed as one JSON object per line,
    skipping blank lines. Any other suffix is parsed as a single JSON
    document, which is then iterated entry by entry.

    Entries that contain a ``"captions"`` key are expanded into one
    ``{"image": ..., "caption": ...}`` dict per caption. Entries without
    that key are passed through unchanged.

    Args:
        path: Location of the annotation file (``str`` or ``Path``).

    Returns:
        A list with one item per caption.
    """
    path = Path(path)
    # Open inside a context manager so the handle is always closed, and read
    # as UTF-8 rather than relying on the platform's default encoding.
    with open(path, encoding="utf-8") as f:
        if path.suffix == ".jsonl":
            raw = [json.loads(line) for line in f if line.strip()]
        else:
            raw = json.load(f)

    flat = []
    for entry in raw:
        if "captions" in entry:
            # One output record per caption of a multi-caption entry
            flat.extend(
                {"image": entry["image"], "caption": caption}
                for caption in entry["captions"]
            )
        else:
            flat.append(entry)
    return flat

# Load the flattened training annotations and print headline counts.
annotations = load_annotations(TRAIN_PATH)
n_entries = len(annotations)
print(f"Total caption entries: {n_entries:,}")

unique_images = {entry["image"] for entry in annotations}
n_images = len(unique_images)
print(f"Unique images:         {n_images:,}")
# max(..., 1) keeps the ratio defined when there are no images
print(f"Captions per image:    {n_entries / max(n_images, 1):.1f}")
print()
print("Sample entry:")
# Left as the last expression so the notebook displays the first entry
annotations[0]

## 2. Sample images with captions

In [None]:
# Show up to 8 random samples in a 2x4 grid
import textwrap

rng = np.random.default_rng(42)
indices = rng.choice(len(annotations), size=min(8, len(annotations)), replace=False)

fig, axes = plt.subplots(2, 4, figsize=(18, 9))
panels = axes.ravel()
for ax, idx in zip(panels, indices):
    entry = annotations[idx]
    img_path = IMAGE_ROOT / entry["image"]
    try:
        img = Image.open(img_path).convert("RGB")
    except Exception as exc:
        # Best effort: keep the grid intact, but report the failure instead
        # of hiding it behind the placeholder.
        print(f"Could not load {img_path} ({exc}); showing a grey placeholder")
        img = Image.new("RGB", (224, 224), (200, 200, 200))
    ax.imshow(img)
    # Wrap at word boundaries so titles are not split in the middle of a word
    ax.set_title(textwrap.fill(entry["caption"], width=40), fontsize=8)
    ax.axis("off")

# With fewer than 8 samples some panels stay unused; hide their empty frames
for ax in panels[len(indices):]:
    ax.axis("off")

plt.suptitle("Random training samples", fontsize=14)
plt.tight_layout()
plt.show()

## 3. Caption length distribution

In [None]:
# Per-caption lengths: raw characters and whitespace-separated words
captions = [entry["caption"] for entry in annotations]
char_lengths = [len(text) for text in captions]
word_lengths = [len(text.split()) for text in captions]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: character counts, with the mean marked by a dashed line
ax1.hist(char_lengths, bins=60, edgecolor="black", alpha=0.7)
ax1.set_xlabel("Caption length (characters)")
ax1.set_ylabel("Count")
ax1.set_title("Character-length distribution")
char_mean = np.mean(char_lengths)
ax1.axvline(char_mean, color="red", linestyle="--", label=f"mean = {char_mean:.0f}")
ax1.legend()

# Right panel: word counts, with the mean marked by a dashed line
ax2.hist(word_lengths, bins=40, edgecolor="black", alpha=0.7, color="orange")
ax2.set_xlabel("Caption length (words)")
ax2.set_ylabel("Count")
ax2.set_title("Word-length distribution")
word_mean = np.mean(word_lengths)
ax2.axvline(word_mean, color="red", linestyle="--", label=f"mean = {word_mean:.1f}")
ax2.legend()

plt.tight_layout()
plt.show()

# Text summary of the same two distributions
char_median = np.median(char_lengths)
word_median = np.median(word_lengths)
print(f"Characters — min: {min(char_lengths)}, max: {max(char_lengths)}, "
      f"mean: {char_mean:.1f}, median: {char_median:.0f}")
print(f"Words      — min: {min(word_lengths)}, max: {max(word_lengths)}, "
      f"mean: {word_mean:.1f}, median: {word_median:.0f}")

## 4. Vocabulary statistics

In [None]:
# Count lower-cased, whitespace-separated tokens over every caption
word_counter = Counter()
for entry in annotations:
    word_counter.update(entry["caption"].lower().split())

vocab_size = len(word_counter)
total_tokens = sum(word_counter.values())
# Words that occur exactly once in the whole corpus
hapax_count = list(word_counter.values()).count(1)

print(f"Vocabulary size:  {vocab_size:,}")
print(f"Total tokens:     {total_tokens:,}")
print(f"Hapax legomena:   {hapax_count:,} "
      f"({hapax_count / vocab_size * 100:.1f}%)")
print()

# Bar chart of the 30 most frequent words
top30 = word_counter.most_common(30)
words, counts = zip(*top30)

fig, ax = plt.subplots(figsize=(14, 5))
ax.bar(words, counts, edgecolor="black", alpha=0.7)
ax.set(ylabel="Frequency", title="Top 30 most common words")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

In [None]:
# Rank-frequency curve on log-log axes (Zipf's law check)
freqs = sorted(word_counter.values(), reverse=True)
ranks = range(1, len(freqs) + 1)  # rank 1 = most frequent word

fig, ax = plt.subplots(figsize=(8, 5))
ax.loglog(ranks, freqs, alpha=0.7)
ax.set(
    xlabel="Rank",
    ylabel="Frequency",
    title="Word frequency vs. rank (log-log)",
)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 5. Augmented image samples

Preview the training-time augmentation pipeline defined in the config, alongside the evaluation transform for comparison.

In [None]:
from dataset import build_train_transforms, build_eval_transforms

# Augmentation settings are optional in the config; the image size is required.
data_cfg = cfg["data"]
aug_cfg = data_cfg.get("augmentation", {})
image_size = data_cfg["image_size"]

# Build the training pipeline from the augmentation settings and the eval
# pipeline from the image size alone.
train_tfm = build_train_transforms(aug_cfg, image_size)
eval_tfm = build_eval_transforms(image_size)

print("Train transforms:", train_tfm)
print("Eval transforms: ", eval_tfm)

In [None]:
# Use the first annotation's image and render it under repeated augmentation
sample = annotations[0]
img_path = IMAGE_ROOT / sample["image"]
try:
    original = Image.open(img_path).convert("RGB")
except Exception:
    # Image could not be loaded: fall back to a flat grey canvas
    original = Image.new("RGB", (320, 240), (180, 180, 180))


def _show_panel(ax, image, title):
    """Draw *image* on *ax* with a small title and no axis decorations."""
    ax.imshow(image)
    ax.set_title(title, fontsize=9)
    ax.axis("off")


fig, axes = plt.subplots(2, 5, figsize=(18, 7))
top_row, bottom_row = axes

# NOTE(review): imshow needs an image-like input; if the transforms return
# normalised tensors these panels may not render as intended — confirm
# against the dataset module.

# Top row: the untouched image, then the same image after the eval transform
_show_panel(top_row[0], original, "Original")
eval_img = eval_tfm(original)
_show_panel(top_row[1], eval_img, "Eval transform")

# The remaining top-row panels are unused; blank them out
for spare_ax in top_row[2:]:
    spare_ax.axis("off")

# Bottom row: one fresh call to the training transform per panel
for i, aug_ax in enumerate(bottom_row):
    aug_img = train_tfm(original)
    _show_panel(aug_ax, aug_img, f"Aug #{i+1}")

plt.suptitle(f"Augmentation preview — {sample['image']}", fontsize=12)
plt.tight_layout()
plt.show()