In [1]:
import os
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

def compute_fashion_mnist_mean_std(root="./data"):
    """
    Return the global pixel mean and standard deviation of the raw
    FashionMNIST training images.

    Pixels are rescaled from uint8 [0, 255] to float [0, 1] (the same
    scaling ``transforms.ToTensor`` applies), and the statistics are then
    taken over every pixel of every training image. The result is suitable
    for ``transforms.Normalize`` (CS231n: zero-center the data and
    normalize its scale).

    Args:
        root: directory where torchvision stores/downloads FashionMNIST.

    Returns:
        ``(mean, std)`` as Python floats; ``std`` uses torch's default
        ``Tensor.std`` estimator.
    """
    # No transform is attached: only the raw ``.data`` tensor is read.
    fashion_train = datasets.FashionMNIST(
        root=root,
        train=True,
        download=True,
        transform=None,
    )

    # ``.data`` holds the whole training set as uint8, roughly [60000, 28, 28].
    pixels = fashion_train.data.float().div(255.0)

    pixel_mean = pixels.mean().item()
    pixel_std = pixels.std().item()
    return pixel_mean, pixel_std


def get_fashion_mnist_datasets(root="./data", val_ratio=0.2, seed=551):
    """
    Acquire FashionMNIST, compute normalization statistics on the training
    subset, and return normalized train, validation, and test datasets.

    - Uses the default 28x28 version.
    - Uses the 60k official training split for train + validation.
    - Uses the 10k official test split as test.

    The mean/std are estimated only from the images that land in the
    training subset, never from the held-out validation images. This
    follows the CS231n rule that preprocessing statistics must come from
    training data alone and then be applied unchanged to val/test, and it
    keeps validation metrics free of leakage.

    Args:
        root: directory where torchvision stores/downloads FashionMNIST.
        val_ratio: fraction of the official training images held out for
            validation; must satisfy ``0 <= val_ratio < 1``.
        seed: seed for the generator that drives the train/val split.

    Returns:
        ``(train_dataset, val_dataset, test_dataset, mean, std)``. The first
        two are ``Subset`` views over the normalized official training set,
        and ``mean``/``std`` are Python floats.

    Raises:
        ValueError: if ``val_ratio`` is outside ``[0, 1)``.
    """
    from torch.utils.data import Subset

    if not 0.0 <= val_ratio < 1.0:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio!r}")

    # Untransformed copy: used to choose split indices and read raw pixels.
    raw_train = datasets.FashionMNIST(
        root=root,
        train=True,
        download=True,
        transform=None,
    )

    # Split the training samples into train and validation.
    total_train = len(raw_train)  # should be 60000
    val_size = int(val_ratio * total_train)
    train_size = total_train - val_size  # >= 1 because val_ratio < 1

    generator = torch.Generator().manual_seed(seed)
    train_split, val_split = random_split(
        raw_train,
        [train_size, val_size],
        generator=generator,
    )
    train_indices = list(train_split.indices)
    val_indices = list(val_split.indices)

    # Statistics from the training subset only, scaled like ToTensor.
    train_pixels = raw_train.data[
        torch.as_tensor(train_indices, dtype=torch.long)
    ].float() / 255.0
    mean = train_pixels.mean().item()
    std = train_pixels.std().item()
    del train_pixels  # drop the large float copy before building datasets

    # Identical deterministic preprocessing for every split (no augmentation).
    transform = transforms.Compose([
        transforms.ToTensor(),  # [0, 255] -> [0, 1]
        transforms.Normalize((mean,), (std,)),  # zero mean, unit-ish variance
    ])

    full_train_dataset = datasets.FashionMNIST(
        root=root,
        train=True,
        download=True,
        transform=transform,
    )

    test_dataset = datasets.FashionMNIST(
        root=root,
        train=False,
        download=True,
        transform=transform,
    )

    # Reuse exactly the indices the statistics were computed from.
    train_dataset = Subset(full_train_dataset, train_indices)
    val_dataset = Subset(full_train_dataset, val_indices)

    return train_dataset, val_dataset, test_dataset, mean, std


def get_fashion_mnist_loaders(
    root="./data",
    val_ratio=0.2,
    batch_size=128,
    num_workers=2,
    seed=551,
    pin_memory=None,
):
    """
    Convenience function that wraps dataset acquisition and returns
    DataLoaders for train, validation, and test sets.

    Args:
        root, val_ratio, seed: forwarded unchanged to
            ``get_fashion_mnist_datasets``.
        batch_size: samples per batch for all three loaders.
        num_workers: worker processes per loader.
        pin_memory: whether the loaders use page-locked host memory.
            ``None`` (the default) enables it only when CUDA is available.
            Pinning only speeds up host-to-GPU copies, and PyTorch warns
            when it is requested on a machine without an accelerator.

    Returns:
        ``(train_loader, val_loader, test_loader, mean, std)``. Only the
        training loader shuffles; validation and test keep a fixed order.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    train_dataset, val_dataset, test_dataset, mean, std = get_fashion_mnist_datasets(
        root=root,
        val_ratio=val_ratio,
        seed=seed,
    )

    def _make_loader(dataset, shuffle):
        # Shared DataLoader settings; only the shuffle flag differs per split.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader, mean, std


if __name__ == "__main__":
    # Smoke test: build every loader, then report split sizes (in batches)
    # and the normalization statistics that were used.
    (
        train_loader,
        val_loader,
        test_loader,
        mean,
        std,
    ) = get_fashion_mnist_loaders()

    print("\n".join(
        f"{split_name} batches: {len(split_loader)}"
        for split_name, split_loader in (
            ("Train", train_loader),
            ("Validation", val_loader),
            ("Test", test_loader),
        )
    ))
    print(f"Computed mean: {mean:.4f}, std: {std:.4f}")

    # Pull a single training batch to sanity-check tensor shapes;
    # images are expected as [batch_size, 1, 28, 28].
    images, labels = next(iter(train_loader))
    print(
        f"Batch image tensor shape: {images.shape}",
        f"Batch labels tensor shape: {labels.shape}",
        sep="\n",
    )


100%|██████████| 26.4M/26.4M [00:01<00:00, 17.2MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 278kB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 5.05MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 9.43MB/s]


Train batches: 375
Validation batches: 94
Test batches: 79
Computed mean: 0.2860, std: 0.3530




Batch image tensor shape: torch.Size([128, 1, 28, 28])
Batch labels tensor shape: torch.Size([128])
