# Pet Detective — Model Training

Fine-tunes **MobileNetV2** (pretrained on ImageNet) on the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/) to classify 37 cat and dog breeds.

## 1. Imports

In [None]:
from pathlib import Path
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torchinfo import summary

if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 2. Dataset Download

The Oxford-IIIT Pet Dataset contains 37 breeds (~200 images each).  
`download=True` fetches and extracts it automatically on first run (~800 MB).

In [26]:
DATA_DIR   = './data'
CACHE_DIR  = Path(DATA_DIR) / 'tensor_cache'
IMG_SIZE   = 224  # MobileNetV2 input size
CACHE_SIZE = 256  # cache at higher res so RandomResizedCrop has real spatial variety
BATCH_SIZE = 32
VAL_SPLIT  = 0.2

# Download raw data (runs once)
_raw      = OxfordIIITPet(root=DATA_DIR, split='trainval', target_types='category', download=True)
_raw_test = OxfordIIITPet(root=DATA_DIR, split='test',     target_types='category', download=True)
NUM_CLASSES = len(_raw.classes)
print(f'Classes: {NUM_CLASSES} | Train+val: {len(_raw)} | Test: {len(_raw_test)}')

# Cache stores values in [0, 1] — no normalization.
preprocess = transforms.Compose([
    transforms.Resize((CACHE_SIZE, CACHE_SIZE)),
    transforms.ToTensor(),
])

def build_cache(split):
    tensors_path = CACHE_DIR / f'{split}_tensors.pt'
    labels_path  = CACHE_DIR / f'{split}_labels.pt'
    if tensors_path.exists() and labels_path.exists():
        print(f'{split}: cache already exists, skipping.')
        return
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    ds = OxfordIIITPet(root=DATA_DIR, split=split, target_types='category',
                       download=False, transform=preprocess)
    loader = DataLoader(ds, batch_size=64, shuffle=False, num_workers=4)
    imgs_list, labels_list = [], []
    for imgs, labels in loader:
        imgs_list.append(imgs)
        labels_list.append(labels)
    torch.save(torch.cat(imgs_list),   tensors_path)
    torch.save(torch.cat(labels_list), labels_path)
    print(f'{split}: cached {len(ds)} images → {tensors_path}')

build_cache('trainval')
build_cache('test')

Classes: 37 | Train+val: 3680 | Test: 3669
trainval: cached 3680 images → data/tensor_cache/trainval_tensors.pt
test: cached 3669 images → data/tensor_cache/test_tensors.pt
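The cache is stored as float32, so each 3×256×256 image costs 786,432 bytes. The back-of-envelope check below estimates what each cache file occupies on disk and in RAM. It compares that with a uint8 cache, which is a hypothetical alternative and not what this notebook does. That alternative would store `(x * 255).round().to(torch.uint8)` and divide by 255 at load time. The round trip is lossless here because `ToTensor()` just divides the resized uint8 pixels by 255. `cache_bytes` is a hypothetical helper.

```python
# Back-of-envelope size of the cached tensors (pure arithmetic, no torch needed).
CACHE_SIZE = 256
CHANNELS = 3
SPLITS = {'trainval': 3680, 'test': 3669}  # image counts from the cell output above


def cache_bytes(num_images: int, bytes_per_value: int) -> int:
    """Hypothetical helper: bytes for num_images tensors of shape (3, CACHE_SIZE, CACHE_SIZE)."""
    return num_images * CHANNELS * CACHE_SIZE * CACHE_SIZE * bytes_per_value


for split, count in SPLITS.items():
    f32 = cache_bytes(count, 4)  # float32, as written by build_cache()
    u8 = cache_bytes(count, 1)   # hypothetical uint8 cache
    print(f'{split:8s}: float32 {f32 / 1e9:.2f} GB | uint8 {u8 / 1e9:.2f} GB')
```

Together the two float32 splits come to roughly 5.8 GB once section 3 loads them. On top of that come the copies made by `trainval_tensors[train_idx]`, because indexing with an index tensor copies the data. A uint8 cache would also need `CachedPetDataset.__getitem__` to convert with `x.float() / 255` before the augments run.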


## 3. DataLoaders

In [27]:
class CachedPetDataset(Dataset):
    """Serves [0, 1] tensors from RAM; applies runtime transforms (augments + normalize)."""
    def __init__(self, tensors, labels, augments=None):
        self.tensors  = tensors
        self.labels   = labels
        self.augments = augments

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.tensors[idx]
        if self.augments:
            x = self.augments(x)
        return x, self.labels[idx]

_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])

# ColorJitter operates on [0, 1] values (as intended), then Normalize follows.
# RandomResizedCrop samples a region of the 256×256 cache and resizes it to 224px. `scale` is
# an area fraction: scale=(0.75, 1.0) with the default ratio=(3/4, 4/3) yields crop sides of
# 192px to 256px (~222px for a square crop at minimum scale), giving real spatial diversity.
train_augments = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.75, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05),
    transforms.RandomRotation(15),
    _normalize,
])

# Val/test: center-crop to model input size, then normalize — no stochastic augments
val_augments = transforms.Compose([
    transforms.CenterCrop(IMG_SIZE),
    _normalize,
])

# Load cached [0, 1] tensors into RAM once (replaces per-epoch JPEG decode + resize; Normalize still runs per sample)
trainval_tensors = torch.load(CACHE_DIR / 'trainval_tensors.pt', weights_only=True)
trainval_labels  = torch.load(CACHE_DIR / 'trainval_labels.pt',  weights_only=True)
test_tensors     = torch.load(CACHE_DIR / 'test_tensors.pt',     weights_only=True)
test_labels      = torch.load(CACHE_DIR / 'test_labels.pt',      weights_only=True)

# Split train/val by index so each subset gets its own Dataset (and transform) cleanly
n        = len(trainval_labels)
perm     = torch.randperm(n)
val_idx  = perm[:int(n * VAL_SPLIT)]
train_idx = perm[int(n * VAL_SPLIT):]

train_dataset = CachedPetDataset(trainval_tensors[train_idx], trainval_labels[train_idx], augments=train_augments)
val_dataset   = CachedPetDataset(trainval_tensors[val_idx],   trainval_labels[val_idx],   augments=val_augments)
test_dataset  = CachedPetDataset(test_tensors, test_labels,                                augments=val_augments)

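# num_workers=0: data is already in RAM. Augmentation still runs per sample in the main process;
# worker processes could parallelize it but add startup and IPC overhead (benchmark before changing)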
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f'Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}')
print(f'Train batches: {len(train_loader)} | Val batches: {len(val_loader)} | Test batches: {len(test_loader)}')

Train: 2944 | Val: 736 | Test: 3669
Train batches: 92 | Val batches: 23 | Test batches: 115
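`RandomResizedCrop` first samples a target area as a `scale` fraction of the image. It then samples an aspect ratio from `ratio`, whose default is `(3/4, 4/3)`, and derives the crop's width and height from those two values. The arithmetic below reproduces the extreme cases for the settings used above, which is where a 192px minimum side length comes from. It only covers the size formula. torchvision's actual sampler also draws the ratio log-uniformly and retries crops that do not fit inside the image. `crop_dims` is a hypothetical helper.

```python
import math

CACHE_SIZE = 256
SCALE = (0.75, 1.0)     # area fraction passed to RandomResizedCrop
RATIO = (3 / 4, 4 / 3)  # torchvision's default aspect-ratio range (width / height)


def crop_dims(scale: float, ratio: float, side: int = CACHE_SIZE) -> tuple[float, float]:
    """Hypothetical helper: (width, height) of a crop covering `scale` of a side x side image."""
    area = scale * side * side
    return math.sqrt(area * ratio), math.sqrt(area / ratio)


for ratio in (RATIO[0], 1.0, RATIO[1]):
    w, h = crop_dims(SCALE[0], ratio)
    print(f'scale={SCALE[0]}, ratio={ratio:.3f}: {w:.1f} x {h:.1f} px')
```

Every crop side therefore lies between 192px and 256px before resizing to 224px. That means content is rescaled by roughly 0.875× to 1.17× along each axis relative to the cache.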

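The split above uses an unseeded `torch.randperm`, so rerunning the cell draws a different validation set. Nothing guarantees that every breed is represented in proportion either. Below is a minimal sketch of a seeded, per-class (stratified) alternative in plain Python. The function name `stratified_split` and the default seed are hypothetical choices, not part of the original notebook.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_fraction=0.2, seed=42):
    """Hypothetical helper: return (train_idx, val_idx), holding out ~val_fraction of each class."""
    rng = random.Random(seed)  # local RNG: reproducible and independent of global random state
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_val = round(len(indices) * val_fraction)  # per-class share of the validation set
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)
```

In the notebook it could replace the `randperm` lines, e.g. `train_idx, val_idx = map(torch.tensor, stratified_split(trainval_labels.tolist(), VAL_SPLIT))`. The resulting subset sizes may differ slightly from 2944/736 because of per-class rounding.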

## 4. Model — MobileNetV2

We load the ImageNet-pretrained backbone, **freeze all layers**, then replace the final classifier with a new head sized to our 37 classes.  
Only the new head will be trained in the first phase.
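Setting `requires_grad = False` stops gradient updates, but it does not freeze everything. `model.train()`, which `run_epoch` calls during Phase 1, still puts the backbone's BatchNorm layers in training mode. In that mode they normalize with per-batch statistics and keep updating their running mean and variance. Whether that drift helps or hurts here is an empirical question. If the intent is a fully frozen backbone, one option is to switch those layers back to eval mode after each `model.train()` call. Below is a minimal sketch with a tiny stand-in backbone; `freeze_batchnorm` is a hypothetical helper name.

```python
import torch
from torch import nn


def freeze_batchnorm(module: nn.Module) -> None:
    """Hypothetical helper: put every BatchNorm layer in eval mode so running stats stop updating.

    Call it *after* module.train(), which would otherwise switch BN back to training mode.
    """
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()


# Tiny stand-in for model.features: conv + BN with all parameters frozen.
backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
for p in backbone.parameters():
    p.requires_grad = False
bn = backbone[1]
before = bn.running_mean.clone()

backbone.train()
freeze_batchnorm(backbone)
with torch.no_grad():
    backbone(torch.randn(4, 3, 16, 16))
print('running_mean unchanged:', torch.equal(before, bn.running_mean))
```

In the notebook this would mean calling `freeze_batchnorm(model.features)` right after `model.train()` inside `run_epoch`, during Phase 1 only.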

In [34]:
model = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)

# Freeze backbone
for param in model.parameters():
    param.requires_grad = False

# Replace classifier head (in_features=1280 for MobileNetV2)
in_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(p=0.2),
    # nn.Dropout(p=0.4),  # swap in above if train/val gap suggests more regularization
    nn.Linear(in_features, NUM_CLASSES),
)

model = model.to(device)
print(summary(model, input_size=(1, 3, IMG_SIZE, IMG_SIZE)))  # in Jupyter, summary() only displays if printed or returned last
model = model.to(device)  # summary() may move the model (e.g. to CPU when training on MPS); move it back
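With the backbone frozen, the only trainable parameters in Phase 1 are the weight and bias of the new `nn.Linear` head. The dropout layer has no parameters. The count follows directly from the head's shape, and `summary()` should report the same figure as its trainable-parameter total. `linear_param_count` is a hypothetical helper.

```python
IN_FEATURES = 1280  # MobileNetV2's final feature width (model.classifier[1].in_features)
NUM_CLASSES = 37


def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Hypothetical helper: parameters in nn.Linear(in_features, out_features)."""
    return in_features * out_features + (out_features if bias else 0)


head_params = linear_param_count(IN_FEATURES, NUM_CLASSES)
print(f'Trainable parameters in Phase 1: {head_params:,}')
```

Inside the notebook the same number can be cross-checked with `sum(p.numel() for p in model.parameters() if p.requires_grad)`.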

## 5. Phase 1 — Train Head

In [35]:
EPOCHS = 8
LR = 1e-3

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.classifier.parameters(), lr=LR, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

def run_epoch(loader, train=True):
    model.train(train)  # train=False is equivalent to model.eval()
    total_loss, correct = 0.0, 0
    ctx = torch.enable_grad() if train else torch.no_grad()
    with ctx:
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            loss = criterion(preds, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * len(images)
            correct += (preds.argmax(1) == labels).sum().item()
    n = len(loader.dataset)
    return total_loss / n, correct / n

for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = run_epoch(train_loader, train=True)
    val_loss,   val_acc   = run_epoch(val_loader,   train=False)
    scheduler.step()
    print(f'Epoch {epoch:02d} | '
          f'Train loss: {train_loss:.4f}  acc: {train_acc:.3f} | '
          f'Val loss: {val_loss:.4f}  acc: {val_acc:.3f}')

test_loss, test_acc = run_epoch(test_loader, train=False)
print(f'[Phase 1] Test loss: {test_loss:.4f}  |  Test accuracy: {test_acc:.3f}')

Epoch 01 | Train loss: 2.2578  acc: 0.530 | Val loss: 1.3669  acc: 0.823
Epoch 02 | Train loss: 1.4032  acc: 0.804 | Val loss: 1.2271  acc: 0.852
Epoch 03 | Train loss: 1.2996  acc: 0.833 | Val loss: 1.1798  acc: 0.860
Epoch 04 | Train loss: 1.2381  acc: 0.847 | Val loss: 1.1716  acc: 0.865
Epoch 05 | Train loss: 1.2015  acc: 0.857 | Val loss: 1.1721  acc: 0.856
Epoch 06 | Train loss: 1.1483  acc: 0.887 | Val loss: 1.1337  acc: 0.882
Epoch 07 | Train loss: 1.1375  acc: 0.896 | Val loss: 1.1277  acc: 0.885
Epoch 08 | Train loss: 1.1501  acc: 0.889 | Val loss: 1.1336  acc: 0.885
[Phase 1] Test loss: 1.1998  |  Test accuracy: 0.863
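The Phase 1 losses stay above 1.1 even at ~89% accuracy, and `label_smoothing=0.1` accounts for part of that. With smoothing, the target places 0.9 + 0.1/37 on the true class and 0.1/37 on each other class. Even a perfect model therefore cannot drive cross-entropy to zero. The lowest reachable loss is the entropy of that smoothed target, computed below in plain Python. `smoothed_ce_floor` is a hypothetical helper.

```python
import math


def smoothed_ce_floor(num_classes: int, smoothing: float) -> float:
    """Hypothetical helper: minimum cross-entropy under label smoothing (entropy of the target)."""
    off = smoothing / num_classes   # probability mass on each wrong class
    on = 1.0 - smoothing + off      # probability mass on the true class
    return -(on * math.log(on) + (num_classes - 1) * off * math.log(off))


print(f'Loss floor: {smoothed_ce_floor(37, 0.1):.4f}')
```

The floor is about 0.67, so a validation loss of ~1.13 sits roughly 0.46 nats above the best achievable value rather than 1.13 above zero.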

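`StepLR(step_size=5, gamma=0.1)` multiplies the learning rate by 0.1 after every 5 calls to `scheduler.step()`. The loop steps once per epoch, so epochs 1 to 5 run at 1e-3 and epochs 6 to 8 run at 1e-4. That lines up with the jump in train and validation accuracy at epoch 6 in the log. A pure-Python sketch of the closed form follows; `step_lr` is a hypothetical helper.

```python
def step_lr(epoch: int, base_lr: float = 1e-3, step_size: int = 5, gamma: float = 0.1) -> float:
    """Hypothetical helper: LR during 1-indexed `epoch`, with scheduler.step() once per epoch."""
    return base_lr * gamma ** ((epoch - 1) // step_size)


for epoch in range(1, 9):
    print(f'Epoch {epoch:02d}: lr = {step_lr(epoch):.0e}')
```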

## 6. Phase 2 — Full Backbone Fine-tuning

The head has converged, so the next step is to unfreeze the entire backbone and fine-tune with very low learning rates. The pretrained features already suit natural images, especially the early ones (edges, textures), and need only small adjustments. The backbone therefore trains at one tenth of the head's rate. The cosine schedule then anneals both parameter groups smoothly toward zero.

- **Backbone**: 1e-6, a minimal nudge that preserves the pretrained features
- **Head**: 1e-5, 10× the backbone rate
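`eta_min` is left at its default of 0. In that case `CosineAnnealingLR(T_max=5)` scales each group's initial learning rate by (1 + cos(πt/5)) / 2 after t calls to `scheduler.step()`. Each parameter group anneals from its own starting value, so the 10× ratio between head and backbone holds throughout. The pure-Python sketch below gives the per-epoch values, assuming one `scheduler.step()` per epoch as in the loop that follows. `cosine_lr` is a hypothetical helper.

```python
import math

T_MAX = 5
BASE_LRS = {'backbone': 1e-6, 'head': 1e-5}


def cosine_lr(base_lr: float, t: int, t_max: int = T_MAX, eta_min: float = 0.0) -> float:
    """Hypothetical helper: closed-form cosine-annealed LR after t scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


for epoch in range(1, T_MAX + 1):
    t = epoch - 1  # scheduler steps taken before this epoch's training pass
    lrs = ', '.join(f'{name} {cosine_lr(lr, t):.2e}' for name, lr in BASE_LRS.items())
    print(f'P2 Epoch {epoch:02d}: {lrs}')
```

The final Phase 2 epoch thus trains at under 10% of the starting rates, about 9.5e-8 for the backbone.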

In [36]:
PHASE2_EPOCHS = 5

for param in model.features.parameters():
    param.requires_grad = True

optimizer = optim.Adam([
    {'params': model.features.parameters(),   'lr': 1e-6, 'weight_decay': 1e-4},
    {'params': model.classifier.parameters(), 'lr': 1e-5, 'weight_decay': 1e-4},
])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=PHASE2_EPOCHS)

for epoch in range(1, PHASE2_EPOCHS + 1):
    train_loss, train_acc = run_epoch(train_loader, train=True)
    val_loss,   val_acc   = run_epoch(val_loader,   train=False)
    scheduler.step()
    print(f'P2 Epoch {epoch:02d} | '
          f'Train loss: {train_loss:.4f}  acc: {train_acc:.3f} | '
          f'Val loss: {val_loss:.4f}  acc: {val_acc:.3f}')

test_loss, test_acc = run_epoch(test_loader, train=False)
print(f'[Phase 2] Test loss: {test_loss:.4f}  |  Test accuracy: {test_acc:.3f}')

P2 Epoch 01 | Train loss: 1.1395  acc: 0.893 | Val loss: 1.1260  acc: 0.886
P2 Epoch 02 | Train loss: 1.1132  acc: 0.905 | Val loss: 1.1192  acc: 0.883
P2 Epoch 03 | Train loss: 1.1218  acc: 0.897 | Val loss: 1.1160  acc: 0.893
P2 Epoch 04 | Train loss: 1.1139  acc: 0.905 | Val loss: 1.1147  acc: 0.889
P2 Epoch 05 | Train loss: 1.1068  acc: 0.904 | Val loss: 1.1210  acc: 0.885
[Phase 2] Test loss: 1.1828  |  Test accuracy: 0.872
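The Phase 2 notes argue that early layers need the least adjustment, yet the optimizer gives every backbone block the same 1e-6. A common alternative is layer-wise (discriminative) learning rates that shrink geometrically toward the input. The sketch below assumes `model.features` is an `nn.Sequential` of blocks, which holds for torchvision's MobileNetV2. The helper name `layerwise_param_groups` and `decay=0.8` are hypothetical and not tuned for this dataset. Small stand-in modules keep it runnable on its own.

```python
from torch import nn, optim


def layerwise_param_groups(features: nn.Sequential, head: nn.Module, head_lr: float = 1e-5,
                           top_lr: float = 1e-6, decay: float = 0.8, weight_decay: float = 1e-4):
    """Hypothetical helper: head at head_lr; block i of `features` at top_lr * decay**k,
    where k counts blocks down from the last one (k=0) toward the input."""
    groups = [{'params': list(head.parameters()), 'lr': head_lr, 'weight_decay': weight_decay}]
    last = len(features) - 1
    for i, block in enumerate(features):
        params = [p for p in block.parameters() if p.requires_grad]
        if not params:  # e.g. activation layers have nothing to optimize
            continue
        groups.append({'params': params, 'lr': top_lr * decay ** (last - i),
                       'weight_decay': weight_decay})
    return groups


# Stand-ins for model.features / model.classifier.
features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
head = nn.Linear(8, 37)
optimizer = optim.Adam(layerwise_param_groups(features, head))
for group in optimizer.param_groups:
    print(f"{len(group['params'])} tensors at lr {group['lr']:.2e}")
```

In the notebook this would replace the two-group list, e.g. `optim.Adam(layerwise_param_groups(model.features, model.classifier))`. `CosineAnnealingLR` then anneals each group from its own starting rate. MobileNetV2 has 19 feature blocks, so with `decay=0.8` the first block would start near 1.8e-8.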


## 7. Save Model

In [37]:
torch.save(model.state_dict(), 'pet_detective_mobilenetv2.pth')
print('Model saved.')

Model saved.
