# Hugging Face Image Dataset Browser

This notebook demonstrates how to load an image dataset from the Hugging Face Hub and visualize samples. The dataset is defined in one place at the top via `DATASET_ID`, currently set to `friedrice231/SG_Memes`.

References:
- Hugging Face Datasets: Load a dataset from the Hub ([docs](https://huggingface.co/docs/datasets/load_hub))
- Dataset: `friedrice231/SG_Memes` ([dataset card](https://huggingface.co/datasets/friedrice231/SG_Memes))

> Tip: Set `USE_STREAMING = True` in the configuration cell to stream rows instead of downloading the entire dataset when you only need to preview samples.


In [None]:
# Optional: install dependencies in the notebook environment
# You can skip this if already installed in your env
# %pip install -q datasets pillow matplotlib

from typing import List, Optional
from dataclasses import dataclass

import os
from datasets import load_dataset, get_dataset_split_names, load_dataset_builder
import matplotlib.pyplot as plt
from PIL import Image

# For better inline figure display
%matplotlib inline


ModuleNotFoundError: No module named 'datasets'

In [None]:
# ---- Configuration: every cell below reads its settings from here ----

# Hub repository to browse; edit this single value to switch datasets.
DATASET_ID: str = "friedrice231/SG_Memes"

# Split to browse, e.g. "train". None requests automatic selection
# (intended: "train" if present, else the first split); the loading cell
# is responsible for resolving it.
BROWSE_SPLIT = None

# True  -> stream rows lazily instead of downloading the dataset.
# False -> download and cache the complete split locally.
USE_STREAMING: bool = False

# How many samples the display/save helpers fetch when not told otherwise.
NUM_SAMPLES_DEFAULT: int = 4


In [None]:
# Inspect dataset builder info and available splits
builder = load_dataset_builder(DATASET_ID)

# Show at most the first 600 characters of the description, and only mark it
# with an ellipsis when it was actually cut short.
max_desc_chars = 600
description = (builder.info.description or "").strip()
preview = description[:max_desc_chars]
if len(description) > max_desc_chars:
    preview += " ..."
print("Description:\n", preview, "\n")
print("Features:", builder.info.features)

# Listing splits is best-effort: report the failure and continue, leaving
# `splits` defined (empty) so later cells can still reference it.
try:
    splits = get_dataset_split_names(DATASET_ID)
except Exception as e:
    splits = []
    print("Could not list splits:", e)
else:
    print("Available splits:", splits)


In [None]:
# Load dataset (streaming vs full download)


def _resolve_split(dataset_id: str, requested: Optional[str]) -> str:
    """Return the split name to load.

    Args:
        dataset_id: Hub repository id whose splits are listed.
        requested: Explicit split name, or None for automatic selection.

    Returns:
        ``requested`` when it is not None; otherwise "train" if the dataset
        lists it (or lists no splits), else the first listed split. If the
        split names cannot be listed, falls back to "train" after printing
        the error.
    """
    if requested is not None:
        return requested
    try:
        available = get_dataset_split_names(dataset_id)
    except Exception as err:  # best-effort: fall back instead of aborting the notebook
        print("Could not list splits; falling back to 'train':", err)
        return "train"
    if not available or "train" in available:
        return "train"
    return available[0]


# Resolve the split from the BROWSE_SPLIT setting in the configuration cell.
SPLIT = _resolve_split(DATASET_ID, BROWSE_SPLIT)
print("Using split:", SPLIT)

if USE_STREAMING:
    ds = load_dataset(DATASET_ID, split=SPLIT, streaming=True)
    print("Loaded streaming dataset.")
else:
    ds = load_dataset(DATASET_ID, split=SPLIT)
    print("Loaded dataset:", ds)

# Peek at the first row. Iteration works for both streaming and in-memory
# datasets, and the None default avoids StopIteration on an empty split.
first_row = next(iter(ds), None)
if first_row is None:
    print("Split", repr(SPLIT), "is empty.")
else:
    # Normalize captions: a single string counts as one caption, None as none.
    first_captions = first_row.get("caption") or []
    if isinstance(first_captions, str):
        first_captions = [first_captions]
    print("Columns:", list(first_row.keys()))
    print("Example filename:", first_row.get("filename"))
    print("Example image id:", first_row.get("img_id"))
    print("Number of captions:", len(first_captions))


In [None]:
def show_sample(row: dict, title: Optional[str] = None):
    """Display one sample's image with a single caption as the plot title.

    Args:
        row: A dataset row; reads the "image" and "caption" keys. "caption"
            may be a list of strings or a single string.
        title: Explicit title. When omitted or empty, the first caption is
            used, or "(no caption)" if the row has none.

    Rows whose image is not a ``PIL.Image.Image`` are not plotted; the image's
    type and the row's keys are printed instead.
    """
    image = row.get("image")
    captions = row.get("caption") or []
    # A bare string caption must not be indexed character-by-character.
    if isinstance(captions, str):
        captions = [captions]

    # Hugging Face Datasets with Image feature typically returns PIL.Image.Image
    if not isinstance(image, Image.Image):
        print("Image not decoded as PIL.Image; got type:", type(image))
        print("Row keys:", list(row.keys()))
        return

    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.axis("off")
    plt.title(title or (captions[0] if captions else "(no caption)"))
    plt.show()


def show_grid(rows: List[dict], cols: int = 2):
    """Display rows in a grid of subplots, each titled with its first caption.

    Args:
        rows: Dataset rows; reads the "image" and "caption" keys. "caption"
            may be a list of strings or a single string.
        cols: Number of grid columns (values below 1 are treated as 1). The
            number of grid rows is derived from ``len(rows)``.

    Cells whose image is not a ``PIL.Image.Image`` show the value's type as
    centered text instead. Prints a message and returns if ``rows`` is empty.
    """
    if not rows:
        print("No rows to display.")
        return

    col_count = max(1, cols)
    # Ceiling division: enough grid rows to hold every sample.
    row_count = -(-len(rows) // col_count)

    plt.figure(figsize=(5 * col_count, 5 * row_count))
    for i, row in enumerate(rows):
        ax = plt.subplot(row_count, col_count, i + 1)
        image = row.get("image")
        if isinstance(image, Image.Image):
            captions = row.get("caption") or []
            # A bare string caption must not be indexed character-by-character.
            if isinstance(captions, str):
                captions = [captions]
            ax.imshow(image)
            ax.set_title(captions[0] if captions else "(no caption)")
        else:
            # Axes-fraction coordinates keep the message centered in its cell.
            ax.text(
                0.5,
                0.5,
                f"Non-image type: {type(image)}",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )
        ax.axis("off")
    plt.tight_layout()
    plt.show()


In [None]:
# Utilities: sample selection and keyword search
import random
from itertools import islice

random.seed(7)


def get_n_samples(n: int = NUM_SAMPLES_DEFAULT):
    """Return up to ``n`` rows from the global ``ds``.

    In streaming mode (``USE_STREAMING``) the first ``n`` rows are taken in
    iteration order. Otherwise ``min(n, len(ds))`` distinct rows are drawn at
    random using the module-level ``random`` generator seeded above. A
    non-positive ``n`` returns an empty list.
    """
    if n <= 0:
        return []
    if USE_STREAMING:
        # Materialize only the first n items of the stream.
        return list(islice(ds, n))
    indices = random.sample(range(len(ds)), k=min(n, len(ds)))
    return [ds[i] for i in indices]

def search_by_keyword(keyword: str, limit: int = 12):
    """Return up to ``limit`` rows where any caption contains ``keyword``.

    Matching is a case-insensitive substring test against each caption of a
    row. A row's "caption" may be a list of strings or a single string;
    non-string entries are ignored. Rows are scanned in the order the global
    ``ds`` yields them and the scan stops once ``limit`` matches are found.
    A non-positive ``limit`` returns an empty list without scanning.
    """
    if limit <= 0:
        return []
    keyword_lower = keyword.lower()
    results = []
    # Iterating ``ds`` directly covers both streaming and in-memory datasets,
    # so a single loop replaces the previously duplicated branches.
    for row in ds:
        captions = row.get("caption") or []
        # A bare string caption would otherwise be searched per character.
        if isinstance(captions, str):
            captions = [captions]
        if any(isinstance(c, str) and keyword_lower in c.lower() for c in captions):
            results.append(row)
            if len(results) >= limit:
                break
    return results


In [None]:
# Quick test: show one sample and a small grid
# (random rows for in-memory datasets; the first rows when streaming)
samples = get_n_samples(NUM_SAMPLES_DEFAULT)

if samples:
    # Show a single sample
    show_sample(samples[0])

    # Show a grid
    show_grid(samples, cols=2)
else:
    # Guard against an empty split, where samples[0] would raise IndexError.
    print("No samples available to display.")


In [None]:
# Optional: save a few samples to disk
from pathlib import Path

# Pillow modes that can be written to JPEG as-is; others (e.g. RGBA, P, LA)
# are converted to RGB first because JPEG has no alpha/palette support.
_JPEG_SAFE_MODES = {"RGB", "L", "CMYK"}


def save_samples(rows: list, out_dir: str = "./data/flickr30k-samples") -> int:
    """Save each row's image, plus its first caption if any, into ``out_dir``.

    Args:
        rows: Dataset rows; reads the "image", "filename" and "caption" keys.
        out_dir: Output directory, created if missing.
            NOTE(review): the default name mentions flickr30k although
            DATASET_ID points elsewhere — consider passing out_dir explicitly.

    Returns:
        The number of images actually written. Rows whose image is not a
        ``PIL.Image.Image`` are skipped with a printed message.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    saved = 0
    for idx, row in enumerate(rows):
        image = row.get("image")
        if not isinstance(image, Image.Image):
            print(f"Skipping non-image sample at index {idx}.")
            continue

        # Keep only the final path component so a dataset-provided filename
        # such as "sub/x.jpg" or "../x.jpg" cannot escape out_dir or point
        # into a subdirectory that does not exist.
        base_name = Path(str(row.get("filename") or "")).name
        if base_name in {"", ".."}:
            base_name = f"sample_{idx:04d}.jpg"
        target = out_path / base_name
        # Keep JPEG/PNG extensions; save anything else as JPEG.
        if target.suffix.lower() not in {".jpg", ".jpeg", ".png"}:
            target = target.with_suffix(".jpg")
        if target.suffix.lower() in {".jpg", ".jpeg"} and image.mode not in _JPEG_SAFE_MODES:
            image = image.convert("RGB")
        image.save(target)
        saved += 1

        # Also save a tiny text file with the first caption
        captions = row.get("caption") or []
        if isinstance(captions, str):
            captions = [captions]
        if captions:
            # Explicit UTF-8 so non-ASCII captions don't depend on the locale.
            (out_path / f"{target.stem}.txt").write_text(str(captions[0]), encoding="utf-8")
    return saved


# Example: save 6 random samples
samples_to_save = get_n_samples(6)
saved_count = save_samples(samples_to_save)
print("Saved", saved_count, "samples.")
