In [None]:
import time
import regex as re

import matplotlib.pyplot as plt
import pandas as pd

from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision import models
from torchvision import datasets
from torchvision.transforms import v2

from tempfile import TemporaryDirectory

In [None]:
# Pick the compute device once: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Per-channel normalization statistics shared by the train and test pipelines.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop + random horizontal flip for augmentation,
# then conversion to float32 in [0, 1] and per-channel normalization.
train_transform = v2.Compose([
    v2.ToImage(),
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=channel_mean, std=channel_std),
])

# Evaluation pipeline: deterministic resize + center crop, same scaling and
# normalization as the training pipeline.
test_transform = v2.Compose([
    v2.ToImage(),
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=channel_mean, std=channel_std),
])

# Both splits live under the same root directory and are downloaded on demand.
training_data = datasets.OxfordIIITPet(
    root="data",
    split="trainval",
    download=True,
    transform=train_transform,
)

test_data = datasets.OxfordIIITPet(
    root="data",
    split="test",
    download=True,
    transform=test_transform,
)

The `Normalize` step in the transform standardizes each color channel with a fixed mean and standard deviation, which reduces the effect of color shifts and small intensity changes between images of the same content. This lets the model learn the real structures in the images instead of having to deal with differences in scale.

In [None]:
# Root folder of the dataset files (relative to the notebook's working directory).
path = Path("data/oxford-iiit-pet")

# The annotation file is space-separated and has no header row, so the column
# names are supplied explicitly.
annotation_file = path / "annotations/test.txt"
annotation_columns = ['Breed', 'Class ID', 'Species', 'Breed ID']

df = pd.read_csv(annotation_file, sep=" ", names=annotation_columns)

In [None]:
# Entries in the 'Breed' column end in "_<digits>"; the capture group keeps
# everything before that suffix. Entries that do not match yield "".
breed_pattern = re.compile(r"(.+)_\d+$")
breed_names = ("".join(breed_pattern.findall(entry)) for entry in df['Breed'])

# Map (class id - 1) -> breed name; later rows overwrite earlier ones that
# share the same class id.
class_map = dict(zip(df['Class ID'] - 1, breed_names))

In [None]:
# NOTE: this dict is deliberately NOT called `datasets`. That name is already
# bound to the `torchvision.datasets` module by the import cell, and rebinding
# it here would make the dataset-loading cell fail when it is re-run.
image_datasets = {
    'train': training_data,
    'test': test_data
}

# Number of samples per split; used to turn summed loss / correct counts into
# per-sample averages.
datasets_size = {split: len(dataset) for split, dataset in image_datasets.items()}

# One DataLoader per split. Only the training split is shuffled: evaluation
# metrics do not depend on sample order, so the test split is read in order.
dataloaders = {
    split: DataLoader(
        dataset,
        shuffle = (split == 'train'),
        batch_size = 64,
        num_workers = 2,
        persistent_workers = True,
        pin_memory = True
    )
    for split, dataset in image_datasets.items()
}

In [2]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, return_history=False):
    """Train ``model`` and restore the weights that scored best on the test split.

    Each epoch runs a training pass followed by an evaluation pass. Whenever the
    evaluation accuracy improves, the model's ``state_dict`` is checkpointed to a
    temporary file; after the last epoch the best checkpoint is loaded back into
    ``model``.

    Relies on the notebook-level globals ``dataloaders`` (dict with ``'train'``
    and ``'test'`` loaders), ``datasets_size`` (samples per split) and ``device``.

    Args:
        model: The network to train; updated in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training pass.
        num_epochs: Number of epochs to run.
        return_history: When True, also return the per-epoch metrics.

    Returns:
        ``model`` with the best weights loaded. If ``return_history`` is True,
        a ``(model, history)`` tuple instead, where ``history`` maps
        ``'train'`` / ``'test'`` to a list of ``(loss, accuracy)`` float pairs,
        one per epoch.
    """
    since = time.time()

    best_acc = 0.0
    history = {'train': [], 'test': []}

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = Path(tempdir) / "best_models_params.pt"

        # Save the initial weights so there is always a checkpoint to restore,
        # even if no epoch ever beats the starting accuracy of 0.0.
        torch.save(model.state_dict(), best_model_params_path)

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'test']:
                is_training = phase == 'train'
                if is_training:
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(is_training):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if is_training:
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                    # statistics, kept as plain Python numbers so that nothing
                    # stored in `history` or printed is a device tensor.
                    # The loss is weighted by batch size before summing.
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels).item()

                if is_training:
                    scheduler.step()

                epoch_loss = running_loss / datasets_size[phase]
                epoch_acc = running_corrects / datasets_size[phase]

                history[phase].append((epoch_loss, epoch_acc))

                print(f'{phase} Loss: {epoch_loss:.4f} Corrects: {running_corrects} Acc: {epoch_acc:.4f}')

                # checkpoint the weights whenever test accuracy improves
                if phase == 'test' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best test Acc: {best_acc:.4f}')

        # load best model weights (must happen before the temp dir is removed);
        # the checkpoint is a plain state_dict, so weights_only loading suffices.
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))

    if return_history:
        return model, history
    return model

Load the `resnet18` model and freeze it. Then replace the model head with a sequential layer and train it.

In [None]:
NUM_CLASSES = 37   # output size of the new classification head
NUM_EPOCHS = 15

model = models.resnet18(weights = models.ResNet18_Weights.DEFAULT)

# Freeze the pretrained backbone. `model.parameters()` yields parameter
# tensors, not modules, so testing them with isinstance(..., nn.Conv2d) never
# matches and would leave every layer trainable; freeze each parameter directly.
# NOTE(review): BatchNorm layers still update their running statistics while
# the model is in train mode, even with frozen weights.
for param in model.parameters():
    param.requires_grad = False

num_features = model.fc.in_features

# Replace the head after freezing: freshly constructed layers default to
# requires_grad=True, so only this block is trained.
model.fc = nn.Sequential(
    nn.Linear(in_features = num_features, out_features = 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(in_features = 256, out_features = NUM_CLASSES)
)

model = model.to(device)

criterion = nn.CrossEntropyLoss()

# Hand the optimizer only the parameters that are actually being trained.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.SGD(trainable_params, lr = 0.001, momentum = 0.9)

# Multiply the learning rate by 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size = 7, gamma = 0.1)

model = train_model(
    model,
    criterion,
    optimizer,
    exp_lr_scheduler,
    num_epochs = NUM_EPOCHS
)