# Image Classification Demo

This notebook mirrors the CLI workflow for training RegNetY-016 on the Oxford-IIIT Pet dataset.

## 1. Environment Setup
Install the dependencies. For GPU support, install the PyTorch wheel that matches your CUDA version.

In [None]:
!pip install -r requirements.txt

## 2. Download and Prepare Data
This downloads the official dataset and prepares 224×224 splits.

In [None]:
# Per the section description: download the dataset, then prepare the 224×224 splits.
# NOTE(review): `!python` resolves via the shell's PATH, which may differ from the
# kernel's interpreter — confirm both point at the same environment.
!python scripts/get_data.py
!python scripts/split_dataset.py

Optional: create the small demo subset (first two classes).

In [None]:
# Optional: build the small demo subset described above (first two classes).
# NOTE(review): unclear from here whether train.py picks up this subset
# automatically or needs a config change — confirm before relying on it.
!python scripts/create_demo_dataset.py

## 3. Train the Model
Launch a short (5-epoch) training run using the merged configs.

In [None]:
# Short 5-epoch run. `--plot` presumably produces the training_curves.png
# displayed in section 5 — confirm against train.py.
!python train.py --epochs 5 --plot

## 4. Evaluate
Evaluate the best checkpoint on the validation split.

In [None]:
# Evaluate the checkpoint the training run is expected to write at this path,
# on the `val` split.
!python validate.py --checkpoint checkpoints/regnety_016/best.pth --split val

## 5. Inspect Predictions
Display the training curves, then run the best checkpoint on a randomly chosen validation image and print its top predictions.

In [None]:
from pathlib import Path

from IPython.display import Image, display

# The curves image is expected to come from the training cell (run with --plot).
# Check for it up front so a missing file gives an actionable error.
curves_path = Path('checkpoints/regnety_016/training_curves.png')
if not curves_path.is_file():
    raise FileNotFoundError(
        f'{curves_path} not found; run the training cell with --plot first.'
    )
# filename= states the intent explicitly instead of letting IPython guess
# whether the positional string is a path, a URL, or raw image data.
display(Image(filename=str(curves_path)))

In [None]:
import torch
from pathlib import Path
from torchvision import transforms  # NOTE(review): unused in this cell
from PIL import Image  # NOTE(review): unused here; shadows IPython.display.Image from the previous cell

from src.config import load_configs
from src.data import ImageClassificationDataset
from src.models import create_model
from src.utils import get_device

# Run the best checkpoint on one validation image and print its top predictions.

CHECKPOINT_PATH = Path('checkpoints/regnety_016/best.pth')
SAMPLE_SEED = 0  # fixes which validation image is drawn so re-runs are reproducible
TOP_K = 5

cfg_paths = [
    Path('configs/train.yaml'),
    Path('configs/data.yaml'),
    Path('configs/model/regnety_016.yaml'),
    Path('configs/aug.yaml'),
]
config = load_configs(cfg_paths)

# create_model presumably reads settings via attribute access (cfg.<key>), so the
# dict's keys are exposed as attributes of an ad-hoc class — confirm against src.models.
model_cfg = config['model']
model = create_model(type('Cfg', (), model_cfg))
device = torch.device(get_device())

# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True; if the
# checkpoint stores non-tensor objects beyond plain containers this may need adjusting.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state'])
model.eval()
model.to(device)

dataset = ImageClassificationDataset('data/val', 'data/val_labels.csv', train=False)
assert len(dataset) > 0, 'validation dataset is empty; run the data preparation cells first'

# Seeded generator: the same sample is picked on every Restart & Run All.
generator = torch.Generator().manual_seed(SAMPLE_SEED)
idx = torch.randint(0, len(dataset), (), generator=generator).item()
image, label = dataset[idx]
with torch.no_grad():
    logits = model(image.unsqueeze(0).to(device))
prob = torch.softmax(logits, dim=1).squeeze(0).cpu()

# Clamp k: a model with fewer than TOP_K classes (e.g. trained on the two-class
# demo subset) would otherwise make topk raise.
k = min(TOP_K, prob.numel())
top_probs, top_classes = prob.topk(k)

print('Sample index:', idx)
print('Ground truth label:', label)
print(f'Top-{k} probabilities:')
for class_idx, class_prob in zip(top_classes.tolist(), top_probs.tolist()):
    print(f'  class {class_idx}: {class_prob:.4f}')
