## Utils

In [None]:
# Custom Dataset Normalization
import torch
from torch.utils.data import DataLoader

def compute_mean_std(dataset):
    """Compute the per-channel mean and standard deviation over every pixel of ``dataset``.

    ``dataset`` is any iterable of ``(image, label)`` pairs. Only the first three
    channels of each image are used; a channel-first layout (C, H, W) is assumed —
    TODO confirm for callers other than ``PokemonDataset``. Integer images (e.g. the
    uint8 tensors produced by ``decode_image``) are accepted because values are
    accumulated in float64.

    The statistics are computed over all pixels of all images pooled together,
    so the std includes variation *between* images. Averaging per-image stds
    would miss that variation and underestimate the true spread.

    Returns:
        ``(mean, std)``: two float32 tensors of shape ``(3,)``. ``std`` is the
        population standard deviation (divides by the pixel count).

    Raises:
        ValueError: if ``dataset`` yields no images.
    """
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    pixel_count = 0

    for image, _ in dataset:
        # (3, H, W) -> (3, H*W) in float64 so uint8 input works and sums don't lose precision
        flat = image[:3].reshape(3, -1).to(torch.float64)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += flat.square().sum(dim=1)
        pixel_count += flat.size(1)

    if pixel_count == 0:
        raise ValueError("compute_mean_std: dataset is empty")

    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative values from rounding
    std = (channel_sq_sum / pixel_count - mean.square()).clamp_min(0).sqrt()
    return mean.float(), std.float()



def compute_mean_std2(dataset, batch_size=8):
    """Batched per-channel mean and standard deviation over every pixel of ``dataset``.

    Iterates ``dataset`` through a non-shuffling ``DataLoader``. Each batch is
    assumed to be ``(batch, C, H, W)`` with ``C == 3`` (TODO confirm for new callers),
    paired with labels that are ignored. Values are accumulated in float64, so
    integer (e.g. uint8) images are accepted.

    All pixels of all images are pooled, so the std includes variation between
    images. Averaging per-image stds would miss that variation and underestimate
    the dataset's spread.

    Args:
        dataset: map-style dataset of ``(image, label)`` pairs.
        batch_size: number of samples loaded per step (default 8).

    Returns:
        ``(mean, std)``: two float32 tensors of shape ``(3,)``. ``std`` is the
        population standard deviation.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    pixel_count = 0

    for data, _ in dataloader:
        # (batch, C, H, W) -> (batch, C, H*W); reshape (unlike view) also handles non-contiguous batches
        flat = data.reshape(data.size(0), data.size(1), -1).to(torch.float64)
        channel_sum += flat.sum(dim=(0, 2))
        channel_sq_sum += flat.square().sum(dim=(0, 2))
        pixel_count += flat.size(0) * flat.size(2)

    if pixel_count == 0:
        raise ValueError("compute_mean_std2: dataset is empty")

    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative values from rounding
    std = (channel_sq_sum / pixel_count - mean.square()).clamp_min(0).sqrt()
    return mean.float(), std.float()


## Build the Dataset and DataLoader

In [None]:
import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torchvision.io import decode_image
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.optim import SGD
from torchvision.transforms import v2

In [None]:
class PokemonDataset(Dataset):
    """Map-style dataset backed by an annotations CSV and a directory of images.

    Row ``idx`` (positional) of the CSV describes one sample. Its first column is the
    image file name, joined onto ``img_dir``, and its second column is the label. The
    image is decoded with ``decode_image(..., mode='RGB')``. A CSV column named
    ``label`` is parsed as int64.

    Args:
        annotations_file: path of the CSV, read with ``pandas.read_csv``.
        img_dir: directory that the CSV file names are relative to.
        transform: optional callable applied to the decoded image (skipped if falsy).
        target_transform: optional callable applied to the label (skipped if falsy).
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # Accessed positionally in __getitem__: column 0 = file name, column 1 = label
        self.img_labels = pd.read_csv(annotations_file, dtype={'label': np.int64})

    def __len__(self):
        # One sample per CSV row
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        file_name = self.img_labels.iloc[idx, 0]
        image = decode_image(os.path.join(self.img_dir, file_name), mode='RGB')
        label = self.img_labels.iloc[idx, 1]

        transform, target_transform = self.transform, self.target_transform
        sample = transform(image) if transform else image
        target = target_transform(label) if target_transform else label
        return sample, target

In [None]:
def setup_training_dataset(base_dataset, output_size):
    """Grow ``base_dataset`` towards ``output_size`` samples by appending augmented copies.

    Returns a ``ConcatDataset`` whose first part is ``base_dataset`` unchanged. It is
    followed by shallow copies of ``base_dataset`` whose ``transform`` is replaced by
    a random augmentation pipeline. The copies share everything else with the base
    dataset (e.g. its file list, image directory and ``target_transform``), so they
    read the same images and labels. Only the preprocessing differs.

    The total length is ``len(base_dataset) * max(round(output_size / len(base_dataset)), 1)``.
    That is ``output_size`` rounded to a whole number of copies, and never less than
    the base dataset itself.

    Args:
        base_dataset: dataset exposing a ``transform`` attribute that it applies to each
            image in ``__getitem__`` (e.g. ``PokemonDataset``).
        output_size: desired total number of samples.

    Returns:
        ``torch.utils.data.ConcatDataset`` of the base dataset followed by its augmented copies.

    Raises:
        ValueError: if ``base_dataset`` is empty.
        TypeError: if ``base_dataset`` has no ``transform`` attribute to override.
    """
    import copy

    size = len(base_dataset)
    if size == 0:
        raise ValueError("setup_training_dataset: base_dataset is empty")
    if not hasattr(base_dataset, "transform"):
        raise TypeError("setup_training_dataset: base_dataset must expose a 'transform' attribute")

    # base_dataset already supplies one copy, so only the remaining copies are augmented
    n_augmented = max(round(output_size / size) - 1, 0)

    augment_transform = v2.Compose([
        # 1. GEOMETRIC AUGMENTATIONS
        v2.Resize(size=(280, 280)),  # Slightly larger for random crop
        v2.RandomCrop(size=(256, 256)),  # Random crop instead of direct resize
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomRotation(degrees=15),  # Pokemon can be rotated slightly
        v2.RandomAffine(
            degrees=0,  # Already handled by RandomRotation
            translate=(0.1, 0.1),  # Small translation
            scale=(0.9, 1.1),  # Slight scaling
            shear=5  # Small shear
        ),

        # 2. COLOR AUGMENTATIONS (Important for Pokemon variety)
        v2.ColorJitter(
            brightness=0.3,  # Pokemon have varied brightness
            contrast=0.3,    # Different contrast levels
            saturation=0.4,  # Pokemon have vibrant colors
            hue=0.1         # Slight hue variation
        ),
        v2.RandomGrayscale(p=0.1),  # Occasionally convert to grayscale

        # 3. CONVERT AND NORMALIZE
        # Values stay on the 0-255 scale (no scale=True); these statistics must match
        # the normalization applied to the non-augmented training samples.
        v2.ToDtype(torch.float32),
        v2.Normalize(
            mean=[85.6542, 80.4268, 72.8841],
            std=[93.3963, 88.0354, 82.0991]
        ),

        # 4. CUTOUT/RANDOM ERASING (Applied after normalization)
        v2.RandomErasing(
            p=0.3,  # 30% chance
            scale=(0.02, 0.15),  # Small to medium patches
            ratio=(0.3, 3.3),
            value=0  # Erase with normalized mean (0)
        )
    ])

    parts = [base_dataset]
    for _ in range(n_augmented):
        # Shallow copy: reuse the base dataset's sources, swap only the image transform
        augmented = copy.copy(base_dataset)
        augmented.transform = augment_transform
        parts.append(augmented)

    return ConcatDataset(parts)

In [None]:
# Per-channel statistics of the *training* images (presumably from compute_mean_std).
# ToDtype below is called without scale=True, so pixel values stay on the 0-255 scale
# these numbers are expressed in.
TRAIN_MEAN = [85.6542, 80.4268, 72.8841]
TRAIN_STD = [93.3963, 88.0354, 82.0991]

train_transform = v2.Compose([
    v2.Resize(size=(256, 256)),
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=TRAIN_MEAN, std=TRAIN_STD),
])
# The test set must go through exactly the same preprocessing as the training set.
# Normalizing it with its own statistics would leak test-set information and feed
# the model inputs on a different scale than it was trained on.
test_transform = v2.Compose([
    v2.Resize(size=(256, 256)),
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=TRAIN_MEAN, std=TRAIN_STD),
])

train_data = setup_training_dataset(PokemonDataset('data/train_labels.csv', 'data/train', train_transform), 10000)
test_data = PokemonDataset('data/test_labels.csv', 'data/test', test_transform)

In [None]:
batch_size = 8
# Shuffle only the training data. Evaluation needs no random order, and a
# fixed order keeps test runs reproducible.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

### Build model

In [None]:
# Get cpu, gpu or mps device for training.
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Build model
class NeuralNetwork(nn.Module):
    """Fully connected classifier over flattened inputs.

    The input is flattened (everything after the batch dimension) and passed through
    Linear+ReLU layers of widths 6144, 3072, 1536, 768 and 384. A final Linear layer
    then produces 151 raw logits (no softmax).

    Args:
        input_size: number of features per sample after flattening.
    """

    def __init__(self, input_size):
        super().__init__()
        self.flatten = nn.Flatten()

        hidden_widths = (6144, 3072, 1536, 768, 384)
        layers = []
        in_features = input_size
        for width in hidden_widths:
            # Same construction order as a literal Sequential, so parameter init and indices match
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        layers.append(nn.Linear(in_features, 151))

        self.sequential_stack = nn.Sequential(*layers)

    def forward(self, x):
        return self.sequential_stack(self.flatten(x))

In [None]:
# Inputs are RGB images resized to 256x256, flattened: 3 * 256 * 256 = 196608 features.
model = NeuralNetwork(input_size=3 * 256 * 256).to(device)
model  # displays the architecture as the cell output

## Optimize the Model Parameters

In [None]:
def train(dataloader, model, loss_fn, optimizer, log_every=15):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Each batch is moved to the module-level ``device`` before the forward pass. Every
    ``log_every`` batches, starting with the first, the batch loss is printed. The
    print also shows how many samples have been processed so far out of the dataset
    size.

    Args:
        dataloader: yields ``(inputs, targets)`` batches; its ``dataset`` must support ``len``.
        model: module to train (switched to train mode).
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        log_every: print progress every this many batches (default 15).
    """
    model.train()
    size = len(dataloader.dataset)
    seen = 0

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Clear gradients accumulated by the previous step
        optimizer.zero_grad()

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation and parameter update
        loss.backward()
        optimizer.step()

        # Count the real batch size: the final batch may be smaller than the others
        seen += len(X)
        if batch % log_every == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

In [None]:
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and average loss.

    Runs in eval mode with autograd disabled. Inputs are moved to the module-level
    ``device``. The reported loss is the mean of the per-batch losses. Accuracy is
    the fraction of samples whose arg-max prediction equals the target.
    """
    model.eval()
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    loss_total = 0
    n_correct = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss_total += loss_fn(logits, targets).item()
            hits = logits.argmax(1) == targets
            n_correct += hits.type(torch.float).sum().item()

    avg_loss = loss_total / n_batches
    accuracy = n_correct / n_samples
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>7f} \n")

In [None]:
# CrossEntropyLoss takes raw logits plus integer class targets; plain SGD with lr 0.001.
loss_fn = nn.CrossEntropyLoss()
optimizer = SGD(params=model.parameters(), lr=1e-3)

In [None]:
epochs = 10
# One pass of training followed by one evaluation per epoch (epochs numbered from 1).
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")