# Image Model Training

This notebook trains two image classification models on the CUB-200-2011 dataset:
1. **ResNet-18** - CNN-based architecture with residual connections
2. **ViT-B/16** - Vision Transformer with self-attention

Both models are trained on the 90 species that appear in both the Xeno-Canto and CUB-200-2011 datasets.

In [None]:
import sys
from pathlib import Path
import pandas as pd
import json
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

sys.path.insert(0, str(Path('..').resolve()))

from src.models.image_resnet import ImageResNet
from src.models.image_vit import ImageViT
from src.datasets.image import ImageDataset, get_image_transforms
from src.training.trainer import Trainer

# Artifact locations: inputs are read from ARTIFACTS, training outputs go under MODELS_DIR.
ARTIFACTS = Path('../artifacts')
MODELS_DIR = ARTIFACTS / 'models'
# parents=True also creates ../artifacts when it is missing, instead of raising FileNotFoundError.
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Seed NumPy and PyTorch so random draws made through these libraries (e.g. DataLoader
# shuffling) are repeatable after Restart & Run All. Some GPU kernels may still be
# non-deterministic, so results can differ slightly between runs on CUDA.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible. `device_obj` is used for `.to(...)` calls,
# while `device` is the same choice as a plain string (passed to Trainer later on).
device_obj = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = str(device_obj)
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Load Data and Create Splits

In [None]:
# Read the pre-filtered CUB metadata table
cub_raw = pd.read_parquet(ARTIFACTS / 'cub_filtered.parquet')

# Keep only species that have at least two images (needed for stratification)
cub_counts = cub_raw['species_normalized'].value_counts()
species_to_keep = cub_counts.loc[cub_counts >= 2].index
keep_mask = cub_raw['species_normalized'].isin(species_to_keep)
cub_df = cub_raw.loc[keep_mask].copy()

# Read the stored train/val/test split definitions
splits_path = ARTIFACTS / 'splits' / 'cub_image_splits.json'
with splits_path.open('r') as f:
    splits = json.load(f)

# Sorted species names define the integer class labels
species_list = sorted(cub_df['species_normalized'].unique())
num_classes = len(species_list)
species_to_idx = dict(zip(species_list, range(num_classes)))

# Report overall and per-split sizes
print(f"Dataset: {len(cub_df)} images, {num_classes} species")
for split_label, split_key in (("Train", "train"), ("Val", "val"), ("Test", "test")):
    print(f"{split_label}: {len(splits[split_key])} samples")

# Image counts for the first few species; filtering removed whole species only,
# so the counts computed before filtering are still valid for the kept ones
print(f"\nSample species:")
for sp in species_list[:5]:
    count = cub_counts.loc[sp]
    print(f"  {sp}: {count} images")

## Create Datasets and DataLoaders

In [None]:
# Settings shared by all three splits
IMAGE_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 4
# Page-locked host memory only helps host-to-GPU copies; on a CPU-only machine
# pin_memory=True has no benefit and makes PyTorch emit a warning.
PIN_MEMORY = torch.cuda.is_available()


def make_split_dataset(split_name, train):
    """Build the ImageDataset for one entry of `splits`.

    Args:
        split_name: Key into `splits` ('train', 'val' or 'test').
        train: Forwarded to `get_image_transforms(train=...)` to select the
            training or the evaluation transform pipeline.

    Returns:
        An ImageDataset over `cub_df` restricted to the indices of that split.
    """
    return ImageDataset(
        df=cub_df,
        indices=splits[split_name],
        species_to_idx=species_to_idx,
        transform=get_image_transforms(train=train, image_size=IMAGE_SIZE),
    )


def make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the shared batch/worker settings.

    Args:
        dataset: Dataset to iterate over.
        shuffle: Whether to reshuffle the samples every epoch.

    Returns:
        A configured torch.utils.data.DataLoader.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )


# Create datasets with transforms (training transforms for the train split only)
train_dataset = make_split_dataset('train', train=True)
val_dataset = make_split_dataset('val', train=False)
test_dataset = make_split_dataset('test', train=False)

print(f"Dataset sizes:")
print(f"  Train: {len(train_dataset)}")
print(f"  Val: {len(val_dataset)}")
print(f"  Test: {len(test_dataset)}")

# Create dataloaders; only the training loader shuffles
train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

print(f"\nDataloader batches:")
print(f"  Train: {len(train_loader)} batches")
print(f"  Val: {len(val_loader)} batches")
print(f"  Test: {len(test_loader)} batches")

## Train ResNet-18

ResNet-18 is a convolutional neural network with 18 layers and skip connections. We use transfer learning with ImageNet pre-trained weights.

In [None]:
# Initialize ResNet-18 with pre-trained weights and move it to the selected device
model = ImageResNet(num_classes=num_classes, pretrained=True).to(device_obj)
print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")

# SGD with momentum and weight decay; StepLR multiplies the learning rate by 0.1
# every 10 scheduler steps (presumably one step per epoch inside Trainer - confirm)
optimizer = torch.optim.SGD(
    model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Create trainer; checkpoints go to MODELS_DIR / 'image_resnet18'
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    checkpoint_dir=MODELS_DIR / 'image_resnet18',
    experiment_name='ImageResNet18',
    use_amp=True,
    gradient_clip=1.0,
    early_stopping_patience=10
)

# The emoji was previously stored as mis-decoded UTF-8 bytes and printed as garbage
print("\n🚀 Starting ResNet-18 training...")
print("This may take 30-60 minutes depending on your GPU.\n")

In [None]:
# Train the ResNet-18 model. `history` is indexed below by 'train_loss', 'val_loss'
# and 'val_acc', each a per-epoch sequence of values.
history = trainer.train(num_epochs=50)

# Status lines use a real check mark; it was previously stored as mis-decoded UTF-8 bytes
print("\n✓ ResNet-18 training complete")
print(f"✓ Best val accuracy: {max(history['val_acc']):.4f}")
print(f"✓ Final train loss: {history['train_loss'][-1]:.4f}")
print(f"✓ Final val loss: {history['val_loss'][-1]:.4f}")

In [None]:
# Directory that receives the figure and the JSON history. Create it here so this
# cell does not depend on another component having created it first.
resnet_dir = MODELS_DIR / 'image_resnet18'
resnet_dir.mkdir(parents=True, exist_ok=True)

# Plot training curves: loss on the left panel, accuracy on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Each entry is (history key suffix, axis label) for one panel
for ax, (metric, label) in zip(axes, (('loss', 'Loss'), ('acc', 'Accuracy'))):
    ax.plot(history[f'train_{metric}'], label='Train', linewidth=2)
    ax.plot(history[f'val_{metric}'], label='Val', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(label, fontsize=12)
    ax.set_title(f'ResNet-18 - {label}', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(resnet_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

# Save history; values are cast to float so they are JSON-serialisable
with open(resnet_dir / 'history.json', 'w') as f:
    json.dump({k: [float(v) for v in vals] for k, vals in history.items()}, f, indent=2)

# The check mark was previously stored as mis-decoded UTF-8 bytes
print(f"✓ Saved training curves and history to {resnet_dir}")

## Train Vision Transformer (ViT-B/16)

The Vision Transformer (ViT) applies self-attention to a sequence of image patches. We use the base model with 16x16-pixel patches (ViT-B/16), initialized from pre-trained weights.

In [None]:
# Initialize ViT-B/16 from the named pre-trained checkpoint and move it to the selected device
model = ImageViT(num_classes=num_classes, pretrained='google/vit-base-patch16-224').to(device_obj)
print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")

# AdamW + cosine annealing for ViT. T_max=50 equals the num_epochs passed to
# trainer.train() in the next cell (assumes Trainer steps the scheduler once
# per epoch - confirm).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=50, eta_min=1e-6
)

# Create trainer; checkpoints go to MODELS_DIR / 'image_vit'
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    checkpoint_dir=MODELS_DIR / 'image_vit',
    experiment_name='ImageViT',
    use_amp=True,
    gradient_clip=1.0,
    early_stopping_patience=10
)

# The emoji was previously stored as mis-decoded UTF-8 bytes and printed as garbage
print("\n🚀 Starting ViT-B/16 training...")
print("This may take 45-90 minutes depending on your GPU.\n")

In [None]:
# Train the ViT model. `vit_history` is indexed below by 'train_loss', 'val_loss'
# and 'val_acc', each a per-epoch sequence of values.
vit_history = trainer.train(num_epochs=50)

# Status lines use a real check mark; it was previously stored as mis-decoded UTF-8 bytes
print("\n✓ ViT-B/16 training complete")
print(f"✓ Best val accuracy: {max(vit_history['val_acc']):.4f}")
print(f"✓ Final train loss: {vit_history['train_loss'][-1]:.4f}")
print(f"✓ Final val loss: {vit_history['val_loss'][-1]:.4f}")

In [None]:
# Directory that receives the figure and the JSON history. Create it here so this
# cell does not depend on another component having created it first.
vit_dir = MODELS_DIR / 'image_vit'
vit_dir.mkdir(parents=True, exist_ok=True)

# Plot training curves: loss on the left panel, accuracy on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Each entry is (history key suffix, axis label) for one panel
for ax, (metric, label) in zip(axes, (('loss', 'Loss'), ('acc', 'Accuracy'))):
    ax.plot(vit_history[f'train_{metric}'], label='Train', linewidth=2)
    ax.plot(vit_history[f'val_{metric}'], label='Val', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(label, fontsize=12)
    ax.set_title(f'ViT-B/16 - {label}', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(vit_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

# Save history; values are cast to float so they are JSON-serialisable
with open(vit_dir / 'history.json', 'w') as f:
    json.dump({k: [float(v) for v in vals] for k, vals in vit_history.items()}, f, indent=2)

# The check mark was previously stored as mis-decoded UTF-8 bytes
print(f"✓ Saved training curves and history to {vit_dir}")

## Summary

Both image models have been trained and their checkpoints saved. The models can now be evaluated on the test set.