In [None]:
import time
from pathlib import Path
from tempfile import TemporaryDirectory

import matplotlib.pyplot as plt
import pandas as pd
import regex as re

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision import models
from torchvision import datasets
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms import v2

In [None]:
# Run on the CUDA device when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Where torchvision downloads / looks for the Oxford-IIIT Pet files.
DATA_ROOT = "data"

# Standard ImageNet per-channel mean/std (the statistics torchvision's
# ImageNet-pretrained weights were trained with). Defined once so the train
# and test pipelines cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random scale/aspect crop to 224x224 and random horizontal
# flip for augmentation, then convert to float in [0, 1] and normalize.
train_transform = v2.Compose([
    v2.ToImage(),
    v2.RandomResizedCrop(size = (224, 224), antialias = True),
    v2.RandomHorizontalFlip(p = 0.5),
    v2.ToDtype(torch.float32, scale = True),
    v2.Normalize(mean = IMAGENET_MEAN, std = IMAGENET_STD)
])

# Evaluation pipeline: deterministic resize (shorter side -> 256) followed by
# a 224x224 center crop, with the same dtype conversion and normalization.
test_transform = v2.Compose([
    v2.ToImage(),
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale = True),
    v2.Normalize(mean = IMAGENET_MEAN, std = IMAGENET_STD)
])

# The dataset class is referenced directly rather than via `datasets.OxfordIIITPet`:
# the bare name `datasets` is rebound by a later cell (to a dict of splits), so
# going through the module alias would make this cell fail when re-run.
training_data = OxfordIIITPet(
    root = DATA_ROOT,
    split = "trainval",
    download = True,
    transform = train_transform
)

test_data = OxfordIIITPet(
    root = DATA_ROOT,
    split = "test",
    download = True,
    transform = test_transform
)

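The `Normalize` step in the transforms standardizes each RGB channel with the ImageNet mean and standard deviation. This gives inputs roughly zero mean and unit variance per channel, which is the input distribution the pretrained torchvision weights expect. It is a fixed per-channel rescaling: it does not remove colour or brightness differences between individual images. It does keep every image on the same numeric scale, though. The model can therefore focus on learning the real structures in the images instead of compensating for differences in input scale.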

In [None]:
# Root of the extracted Oxford-IIIT Pet files.
path = Path("data") / "oxford-iiit-pet"

# Test-split annotation list, parsed as a headerless, space-separated table
# with explicitly named columns.
# NOTE(review): 'Breed' presumably holds the image file stem ("<breed>_<n>"),
# not a bare breed name — confirm against the dataset README.
df = pd.read_csv(
    path.joinpath("annotations", "test.txt"),
    names = ['Breed', 'Class ID', 'Species', 'Breed ID'],
    sep = " ",
)

In [None]:
# Map each 0-based label (the file's 1-based 'Class ID' minus one, matching the
# targets torchvision's OxfordIIITPet yields) to a human-readable breed name.
# 'Breed' holds image names of the form "<breed>_<number>", so the breed is
# everything before the last underscore. The annotation table has one row per
# image: deduplicate first so each class is processed once, and sort so the
# resulting dict iterates in label order.
breeds_by_class = df.drop_duplicates('Class ID').sort_values('Class ID')
class_map = {
    int(class_id) - 1: breed.rsplit("_", 1)[0]
    for breed, class_id in zip(breeds_by_class['Breed'], breeds_by_class['Class ID'])
}

In [None]:
# Group both splits by phase name so later cells can loop over 'train'/'test'.
# NOTE: this rebinds `datasets`, shadowing the `torchvision.datasets` module
# imported at the top of the notebook; code that needs the module after this
# cell must reach it under a different name.
datasets = {
    'train': training_data,
    'test': test_data
}

# Derived from the dict so the sizes always agree with the datasets above.
datasets_size = {split: len(ds) for split, ds in datasets.items()}

BATCH_SIZE = 64
NUM_WORKERS = 2

dataloaders = {x: DataLoader(
                    datasets[x],
                    # Both splits are shuffled; aggregate evaluation metrics
                    # do not depend on sample order.
                    shuffle = True,
                    batch_size = BATCH_SIZE,
                    num_workers = NUM_WORKERS,
                    # DataLoader rejects persistent_workers=True when there are
                    # no worker processes (num_workers == 0).
                    persistent_workers = NUM_WORKERS > 0,
                    # Pinned host memory only speeds up host->GPU copies; on a
                    # CPU-only machine it just triggers a warning.
                    pin_memory = device == "cuda"
                ) for x in ['train', 'test']}