In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random
import math
from pathlib import Path
from PIL import Image

%matplotlib inline

## Data Loader

In [None]:
# Registry of known datasets: dataset name -> root directory on disk.
# load_data() looks names up here and treats each sub-directory of the root as one class.
DATASET_ROOTS = dict(
    val_blurred=Path("./val_blurred"),
)

def load_data(dataset, transformation=None, n_train=None, n_test=None, seed=None):
    """
    Loads the data from the dataset, applies the given transformation and splits it into train and test sets.

    Every sub-directory of the dataset root is treated as one class; labels are the
    indices of the class directories in sorted order. Only ``*.jpg`` files located
    directly inside a class directory are collected.

    Split sizes: if neither ``n_train`` nor ``n_test`` is given, 80% of the samples
    go to train and the rest to test. If only one of them is given, the other one
    receives all remaining samples.

    Args:
        dataset: name of the dataset (a key of DATASET_ROOTS)
        transformation: callable applied to each RGB image; if None the image is
            converted with ``np.array``
        n_train: number of train samples
        n_test: number of test samples
        seed: seed for reproducible shuffling; if None the module-level ``random``
            generator is used (so an external ``random.seed(...)`` still applies)

    Returns:
        sample_train: lazy (sample, label) generator for training
        sample_test: lazy (sample, label) generator for testing

    Raises:
        ValueError: if the dataset name is unknown, a requested size is negative,
            or the requested sizes exceed the number of available samples.
    """
    dataset_root = DATASET_ROOTS.get(dataset)

    if dataset_root is None:
        raise ValueError(f"Unknown dataset: {dataset}")
    dataset_root = dataset_root.expanduser().resolve() # normalize so the loader works across diff machines

    class_dirs = sorted(d for d in dataset_root.iterdir() if d.is_dir())

    # Sort the image paths as well: glob() yields files in filesystem order, and a
    # seeded shuffle is only reproducible when it starts from a deterministic list.
    samples = [
        (img_path, label)
        for label, cls_dir in enumerate(class_dirs)
        for img_path in sorted(cls_dir.glob("*.jpg"))
    ]

    if seed is None:
        random.shuffle(samples)
    else:
        # Private generator: honours the seed without touching the global RNG state.
        random.Random(seed).shuffle(samples)

    for size in (n_train, n_test):
        if size is not None and size < 0:
            raise ValueError('Sample sizes must be non-negative')

    total = len(samples)
    if n_train is None and n_test is None:
        n_train = int(0.8 * total)
        n_test = total - n_train
    elif n_train is None:
        n_train = total - n_test
    elif n_test is None:
        n_test = total - n_train

    # A derived size goes negative when the single size that was given is larger
    # than the dataset, so this one check covers the one-sided cases too.
    if n_train < 0 or n_test < 0 or n_train + n_test > total:
        raise ValueError('Sample sizes combined exceed the total data size')

    train_samples = samples[:n_train]
    test_samples = samples[n_train:n_train + n_test]

    def generator(items):
        # Images are opened one at a time, only when the consumer asks for the next item.
        for img_path, label in items:
            with Image.open(img_path) as img:
                img = img.convert("RGB")
                data = transformation(img) if transformation else np.array(img)
            yield data, label

    return generator(train_samples), generator(test_samples)


## Sample the data

In [None]:
def to_array_64(img):
    """Resize ``img`` to 64x64 through its ``resize`` method and return the result as a NumPy array."""
    target_size = (64, 64)
    resized = img.resize(target_size)
    return np.array(resized)

tr = 15  # number of training samples to draw
ts = 15  # number of test samples to draw
SEED = 42  # forwarded to load_data's `seed` argument (documented there as enabling reproducible shuffling)

train_gen, test_gen = load_data(
    "val_blurred",
    transformation=to_array_64,
    n_train=tr,
    n_test=ts,
    seed=SEED,
)

# The generators are lazy: each image is opened and transformed only when it is pulled here.
train_batch = [next(train_gen) for _ in range(tr)]
test_batch = [next(test_gen) for _ in range(ts)]

# for idx, (img, label) in enumerate(train_batch):
#     print(f"train[{idx}] -> shape {img.shape}, dtype {img.dtype}, label {label}")
# for idx, (img, label) in enumerate(test_batch):
#     print(f"test[{idx}]  -> shape {img.shape}, dtype {img.dtype}, label {label}")

## Visualize the data


In [None]:
def show(samples, rows=None, cols=None, outfile=None):
    """
    Shows the given samples in a grid (rows x columns) and optionally saves the figure to a file.

    Args:
        samples: iterable of plain images or (image, label) tuples; any iterable
            is accepted, including generators
        rows: number of rows (default: as many as needed to fit all samples)
        cols: number of columns (default: 5)
        outfile: file in which the figure will be saved (if None shows the figure directly)

    Raises:
        ValueError: if ``samples`` is empty.

    Note:
        If ``rows * cols`` is smaller than the number of samples, only the first
        ``rows * cols`` samples are drawn.
    """
    # Materialize first: generators have no len() and can only be iterated once.
    samples = list(samples)
    n = len(samples)
    if n == 0:
        # Fail early with a clear message instead of asking matplotlib for a 0-row grid.
        raise ValueError("No samples to show")

    if cols is None:
        cols = 5
    if rows is None:
        rows = math.ceil(n / cols)

    fig, axes = plt.subplots(rows, cols, figsize=(cols*2, rows*2))
    try:
        # subplots returns a bare Axes for a 1x1 grid; normalize to a flat array.
        axes = np.atleast_1d(axes).flatten()

        for ax, sample in zip(axes, samples):
            if isinstance(sample, tuple) and len(sample) == 2:  # in case we only input images without labels
                img, label = sample
            else:
                img, label = sample, None

            ax.imshow(img)
            ax.axis("off")

            if label is not None:
                ax.set_title(str(label))

        # Hide the unused cells of the grid.
        for ax in axes[n:]:
            ax.axis("off")

        fig.tight_layout()
        if outfile is not None:
            fig.savefig(outfile, dpi = 300)
        else:
            plt.show()
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)