In [33]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
from pathlib import Path
from PIL import Image
import random
import itertools

In [34]:
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _read_wordnet_names(words_path):
    """Parse a words file into a ``{wnid: readable name}`` dict.

    Each useful line is ``<wnid><whitespace><name>``; lines without a name
    part (blank lines, a bare ID) are skipped.
    """
    wnid_to_name = {}
    with open(words_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(None, 1)
            if len(parts) == 2:
                wnid, name = parts
                wnid_to_name[wnid] = name
    return wnid_to_name


def _iter_labelled_images(items, transformation):
    """Lazily yield ``(image, label)`` for each ``(path, label)`` in ``items``.

    Each image is opened with PIL, converted to RGB and, when
    ``transformation`` is given, passed through it.
    """
    for path, label in items:
        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives it.
        with Image.open(path) as src:
            img = src.convert('RGB')
        if transformation is not None:
            img = transformation(img)
        yield img, label


def load_data(
    dataset='tiny-imagenet-200',
    transformation=None,
    n_test=None,
    n_train=None,
    data_root='./dataset',
    words_file='words.txt',
    shuffle=True,
    seed=None,
):
    """Index a Tiny-ImageNet-style dataset and return lazy train/test iterators.

    Only the ``train`` split on disk is used. For every class directory
    ``<data_root>/<dataset>/<dataset>/train/<wnid>/images``, the first
    ``n_test`` images (after optional shuffling) go to the test set and the
    rest go to the training set.

    Parameters
    ----------
    dataset : str
        Dataset folder name. The layout is expected to be nested twice
        (``<data_root>/<dataset>/<dataset>/...``).
    transformation : callable, optional
        Applied to each RGB PIL image before it is yielded.
    n_test : int, optional
        Number of images per class held out for testing (default 0).
    n_train : int, optional
        Global cap on the number of training images. It is split as evenly
        as possible across classes: the first ``n_train % n_classes`` classes
        get one extra image. The total therefore equals ``n_train`` whenever
        every class has enough images. ``None`` means no cap.
    data_root : str or path-like
        Root directory containing the dataset folder.
    words_file : str
        File inside the dataset folder that maps WordNet IDs to names.
    shuffle : bool
        If True, shuffle images within each class before splitting, and
        shuffle the final train/test lists. If False, the order is
        deterministic: sorted by class ID, then by file name.
    seed : int, optional
        Seed for a private RNG used for shuffling. When None, the global
        ``random`` module is used, so an earlier ``random.seed()`` still
        applies.

    Returns
    -------
    train_iter, test_iter : generator
        One-shot generators of ``(image, label)`` pairs. Images are read
        from disk only while iterating.
    id_to_name : dict[int, str]
        Integer label -> readable class name. Falls back to the WordNet ID
        when the words file has no entry for it.
    """
    dataset_dir = Path(data_root) / dataset / dataset
    train_dir = dataset_dir / 'train'
    words_path = dataset_dir / words_file

    # Use a dedicated RNG when seeded. Otherwise use the module-level `random`
    # functions, which keeps reproducibility via a global random.seed().
    rng = random.Random(seed) if seed is not None else random

    # -----------------------------------------------------
    # 1) dictionary mapping class indices (0, 1, 2, ...) to names
    # -----------------------------------------------------
    wnid_to_name = _read_wordnet_names(words_path)
    # Sorted so that label indices are stable across runs and machines.
    class_ids = sorted(d for d in os.listdir(train_dir) if (train_dir / d).is_dir())
    id_to_name = {label: wnid_to_name.get(cid, cid) for label, cid in enumerate(class_ids)}

    # -----------------------------------------------------
    # 2) split train/test per class
    # -----------------------------------------------------
    n_test = 0 if n_test is None else n_test
    n_classes = len(class_ids)
    if n_train is None or n_classes == 0:
        quotas = [None] * n_classes  # slicing with [:None] keeps everything
    else:
        # Spread the global cap evenly and hand the remainder to the first
        # classes. Flooring alone could yield 0 images per class.
        base, extra = divmod(n_train, n_classes)
        quotas = [base + (1 if label < extra else 0) for label in range(n_classes)]

    train_list, test_list = [], []
    for label, cid in enumerate(class_ids):
        cls_path = train_dir / cid / 'images'
        # os.listdir order is arbitrary; sort so shuffle=False is reproducible.
        # Match real extensions (with the dot), case-insensitively.
        imgs = [
            cls_path / f
            for f in sorted(os.listdir(cls_path))
            if f.lower().endswith(_IMAGE_EXTENSIONS)
        ]
        if shuffle:
            rng.shuffle(imgs)
        test_list.extend((p, label) for p in imgs[:n_test])
        train_list.extend((p, label) for p in imgs[n_test:][:quotas[label]])

    if shuffle:
        rng.shuffle(train_list)
        rng.shuffle(test_list)

    # -----------------------------------------------------
    # 3) lazy generators (images are read only while iterating)
    # -----------------------------------------------------
    return (
        _iter_labelled_images(train_list, transformation),
        _iter_labelled_images(test_list, transformation),
        id_to_name,
    )


In [35]:
def show(x, id_to_name, outfile=None, title=None, nrows=4, ncols=4):
    """Plot the first ``nrows * ncols`` samples of an image iterable as a grid.

    Parameters
    ----------
    x : iterable of (image, label_idx)
        Typically a generator returned by ``load_data``. Up to
        ``nrows * ncols`` items (16 by default) are consumed from it; a
        one-shot generator will not yield them again.
    id_to_name : dict
        Maps integer label -> readable class name. Labels missing from the
        dict are shown as "Unknown". Only the text before the first comma is
        used in the cell title.
    outfile : str or path-like, optional
        If given, save the figure there and close it. Otherwise display it
        with ``plt.show()``.
    title : str, optional
        Figure title; defaults to "Dataset Samples".
    nrows, ncols : int
        Grid shape (default 4x4). Cells without a sample are left blank.
    """
    title = title or "Dataset Samples"

    # Pull only as many samples as there are grid cells. This happens before
    # the figure is created, so a failing iterator leaves no orphaned figure.
    batch = list(itertools.islice(x, nrows * ncols))

    # squeeze=False keeps `axes` two-dimensional even for 1xN / 1x1 grids, so
    # `.flat` always works. 2.5 inches per cell reproduces the original
    # 10x10 figure for the default 4x4 grid.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(2.5 * ncols, 2.5 * nrows), squeeze=False
    )
    fig.suptitle(title, fontsize=20)

    for i, ax in enumerate(axes.flat):
        if i < len(batch):
            img, label_idx = batch[i]
            class_name = id_to_name.get(label_idx, "Unknown")
            # NOTE(review): imshow accepts PIL images / HxW[xC] arrays; a
            # channel-first tensor from a transformation would need permuting.
            ax.imshow(img)
            ax.set_title(f"{class_name.split(',')[0]}\n(ID: {label_idx})", fontsize=9)
        # Hide ticks/frames on every cell; cells without a sample become blank.
        ax.axis('off')

    fig.tight_layout()
    fig.subplots_adjust(top=0.92)  # leave room for the suptitle

    if outfile:
        fig.savefig(outfile)
        print(f"Plot saved successfully to: {outfile}")
        # Close explicitly so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()



In [36]:
# --- Usage Example ---
# Requires `load_data` and `show` from the cells above.
SEED = 0
random.seed(SEED)  # load_data shuffles with the global `random` module by default

train_g, test_g, id_map = load_data(
    n_test=50,
    transformation=lambda img: img.resize((64, 64)),
)

# NOTE: show() consumes the first 16 items of each one-shot generator.
# Call load_data again before using the generators for training/evaluation.
show(train_g, id_map, outfile="train_samples.png", title="Training Set Samples")
show(test_g, id_map, outfile="test_samples.png", title="Testing Set Samples")

Plot saved successfully to: train_samples.png
Plot saved successfully to: test_samples.png
