# Image Classification - Deep Learning

This notebook mirrors the CLI workflow for training RegNetY-016 on the Oxford-IIIT Pet dataset.

## 1. Environment Setup
Clone the repository, then install the dependencies (if you need GPU support, choose the PyTorch wheel that matches your CUDA version).

In [None]:
!git clone https://github.com/HenryNVP/image-classification.git
%cd image-classification

In [None]:
!pip install -r requirements.txt

## 2. Download and Prepare Data
This downloads the official dataset and prepares 256x256 splits.

In [None]:
!python scripts/get_data.py
!python scripts/split_dataset.py

## 3. Train the Model
Launch training using the merged configs.

In [None]:
!python train.py --plot

## 4. Evaluate
Evaluate the best checkpoint on the validation split.

In [None]:
!python validate.py --checkpoint checkpoints/regnety_016/best.pth --split val

## 5. Inspect Predictions
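Example visualization of model outputs: load the best checkpoint, run it on one randomly chosen validation image, and display the image with its ground-truth and predicted labels.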

In [None]:
import torch
from pathlib import Path
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

from src.config import load_configs
from src.data import ImageClassificationDataset
from src.models import create_model
from src.utils import get_device

# Load the train, data, model and augmentation config files into a single
# config object; only its 'model' section is used below.
cfg_paths = [
    Path('configs/train.yaml'),
    Path('configs/data.yaml'),
    Path('configs/model/regnety_016.yaml'),
    Path('configs/aug.yaml'),
]
config = load_configs(cfg_paths)

# Wrap the model section in a throwaway class whose class attributes are the
# dict entries (presumably create_model reads its settings via attribute
# access -- confirm against src.models).
model_cfg = config['model']
model = create_model(type('Cfg', (), model_cfg))
device = torch.device(get_device())

# The checkpoint is a dict; the network weights live under 'model_state'.
checkpoint = torch.load('checkpoints/regnety_016/best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state'])
model.eval()
model.to(device)

# Pick one random validation sample and run it through the model.
dataset = ImageClassificationDataset('data/val', 'data/val_labels.csv', train=False)
idx = torch.randint(0, len(dataset), ()).item()
image, label = dataset[idx]
# Normalise the label to a plain int so it prints cleanly and indexes
# class_names whether the dataset returns an int or a 0-dim tensor.
label = int(label)
with torch.no_grad():
    logits = model(image.unsqueeze(0).to(device))
prob = torch.softmax(logits, dim=1).squeeze(0)

# Get class names
class_names = dataset.classes

# Compute the top-k once and reuse it. topk returns values sorted in
# descending order, so index 0 is the top-1 prediction. k is capped so the
# cell still runs for models with fewer than five classes.
top_k = prob.topk(min(5, prob.numel()))
pred = top_k.indices[0].item()

print('Sample index:', idx)
print('Ground truth label:', label, f'({class_names[label]})')
print('Top-5 probabilities:', top_k)
print('Top prediction:', pred, f'({class_names[pred]})')

# Display the image
# NOTE(review): if the dataset transform normalises pixel values, imshow will
# clip them to the displayable range and colours may look off -- confirm and
# un-normalise here if needed.
plt.imshow(image.permute(1, 2, 0))  # Permute to HWC format for matplotlib
plt.title(f"Ground Truth: {class_names[label]}, Prediction: {class_names[pred]}")
plt.axis('off')
plt.show()
