# Multitask - the OxfordIIITPet dataset for segmentation and classification

In [1]:
from PIL import Image

In [2]:
import torch
import torch.nn as nn
from torchvision.transforms import v2
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader, Dataset
import timm
import torchvision
from torchvision.transforms.v2.functional import hflip
from tqdm import tqdm
import matplotlib.pyplot as plt

In [3]:
img_size = 224
num_classes = 37

In [4]:
class ClassificationDataset(Dataset):
    def __init__(self, split, image_transforms):
        self.data = OxfordIIITPet(
            root="../data",
            download=True,
            split=split,
            target_types=("category",),
        )
        self.image_transforms = image_transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        image, class_label = self.data[i]
        image = self.image_transforms(image)
        return image, class_label


image_transforms_train = v2.Compose([
    v2.ToImage(),
    v2.Resize([img_size, img_size]),
    v2.RandomCrop([img_size, img_size], padding=12),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.AutoAugment(),
])

image_transforms_test = v2.Compose([
    v2.ToImage(),
    v2.Resize([img_size, img_size]),
    v2.ToDtype(torch.float32, scale=True),
])

cutmix_or_mixup = v2.RandomChoice([
    v2.CutMix(num_classes=num_classes),
    v2.MixUp(num_classes=num_classes),
])

train_dataset = ClassificationDataset("trainval", image_transforms_train)
test_dataset = ClassificationDataset("test", image_transforms_test)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, drop_last=True)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32, drop_last=False)

In [14]:
class ClassificationModel(nn.Module):
    def __init__(self, backbone_name='resnet18', num_classes=num_classes):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=True)
        if not hasattr(self.backbone, "fc") and not hasattr(self.backbone, "head"):
            raise RuntimeError("Backbone not implemented: " + backbone_name)
        if not hasattr(self.backbone, "fc"):
            self.backbone.head = nn.Linear(self.backbone.head.weight.size(1), num_classes)
        else:
            self.backbone.fc = nn.Linear(self.backbone.fc.weight.size(1), num_classes)

    def forward(self, x):
        return self.backbone(x)

    def freeze_backbone(self):
        self.backbone.requires_grad_(False)
        if not hasattr(self.backbone, "fc"):
            self.backbone.head.requires_grad_(True)
        else:
            self.backbone.fc.requires_grad_(True)

In [6]:
def f1_score(x, y):
    x_sum = x.sum().item()
    y_sum = y.sum().item()

    if x_sum == y_sum == 0:
        return 1.0
    elif x_sum == 0 or y_sum == 0:
        return 0.0

    return 2.0 * (x & y).sum().item() / (x_sum + y_sum)


def f1_macro(predicted, targets, num_classes):
    f1s = []
    for cls in range(num_classes):
        f1s.append(f1_score(predicted == cls, targets == cls))
    return sum(f1s) / num_classes


In [7]:
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu")
print("Using device", device)
# model = ClassificationModel().to(device)
model = ClassificationModel("hf_hub:timm/resnest14d.gluon_in1k").to(device)
classification_criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)


Using device cuda


In [8]:
def train():
    loss_sum = 0.0
    num_batches = 0

    model.train()
    pbar = tqdm(train_loader, desc="Training")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        images, labels = cutmix_or_mixup(images, labels)

        with torch.autocast(device.type, enabled=device.type == 'cuda'):
            class_logits = model(images)
            loss = classification_criterion(class_logits, labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_sum += loss.item()
        num_batches += 1

        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
        })

    return loss_sum / num_batches


@torch.inference_mode()
def val():
    cls_f1_sum = 0.0
    cls_acc_sum = 0.0
    loss_sum = 0.0
    num_batches = 0

    model.eval()
    pbar = tqdm(test_loader, desc="Evaluating")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        with torch.autocast(device.type, enabled=device.type == 'cuda'):
            class_logits = model(images)
            loss = classification_criterion(class_logits, labels)

            pred_classes = class_logits.argmax(dim=1)

            cls_f1 = f1_macro(pred_classes, labels, num_classes=num_classes)
            cls_acc = (pred_classes == labels).float().mean().item()

        cls_f1_sum += cls_f1
        cls_acc_sum += cls_acc
        loss_sum += loss
        num_batches += 1

        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'ClsF1': f'{cls_f1:.4f}',
            'ClsAcc': f'{cls_acc:.4f}',
        })

    return (
        cls_f1_sum / num_batches,
        cls_acc_sum / num_batches,
        loss_sum / num_batches,
    )


@torch.inference_mode()
def val_tta(tta_type):
    cls_f1_sum = 0.0
    cls_acc_sum = 0.0
    num_batches = 0

    model.eval()
    pbar = tqdm(test_loader, desc=f"Evaluating with TTA level {level}")

    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        with torch.autocast(device.type, enabled=device.type == 'cuda'):
            combined = [images]
            if tta_type == "mirroring":
                combined.append(hflip(images))
            elif tta_type == "translate":  # left
                padding_size = 2
                padded = v2.functional.pad(images, [padding_size])
                for i in [-2, 0, 2]:
                    for j in [-2, 0, 2]:
                        if i == 0 and j == 0:
                            continue
                        x = padding_size + i
                        y = padding_size + j
                        combined.append(padded[:, :, x:x + img_size, y:y + img_size])
            elif tta_type == "mirroring_and_translate":
                combined.append(hflip(images))
                padding_size = 2
                padded = v2.functional.pad(images, [padding_size])
                for i in [-2, 0, 2]:
                    for j in [-2, 0, 2]:
                        if i == 0 and j == 0:
                            continue
                        x = padding_size + i
                        y = padding_size + j
                        aux = padded[:, :, x:x + img_size, y:y + img_size]
                        combined.append(aux)
                        combined.append(hflip(aux))
            elif tta_type == "translate_aggressive":
                padding_size = 4
                padded = v2.functional.pad(images, [padding_size])
                for i in [-4, -2, 0, 2, 4]:
                    for j in [-4, -2, 0, 2, 4]:
                        if i == 0 and j == 0:
                            continue
                        x = padding_size + i
                        y = padding_size + j
                        combined.append(padded[:, :, x:x + img_size, y:y + img_size])

            outputs = sum(model(x) for x in combined)
        outputs = outputs.argmax(dim=1)

        cls_f1 = f1_macro(outputs, labels, num_classes=num_classes)
        cls_acc = (outputs == labels).float().mean().item()

        cls_f1_sum += cls_f1
        cls_acc_sum += cls_acc
        num_batches += 1

        pbar.set_postfix({
            'ClsF1': f'{cls_f1:.4f}',
            'ClsAcc': f'{cls_acc:.4f}',
        })

    return (
        cls_f1_sum / num_batches,
        cls_acc_sum / num_batches,
    )


In [None]:
models = [
    "hf_hub:timm/resnet18.a1_in1k",
    "hf_hub:timm/resnet50.a1h_in1k",
    "hf_hub:timm/resnest14d.gluon_in1k",
    "hf_hub:timm/resnest26d.gluon_in1k",
    "hf_hub:timm/vit_base_patch16_clip_224.openai",
    "hf_hub:timm/vit_small_patch16_224.augreg_in21k",
    "hf_hub:timm/maxvit_tiny_tf_224.in1k",
    "hf_hub:timm/maxvit_small_tf_224.in1k",
]
# Loading all models first
for model_name in models:
    model = ClassificationModel(model_name)

In [17]:
results = []

for model_name in models:
    model = ClassificationModel(model_name).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.freeze_backbone()
    num_epochs = 10
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        tr_loss = train()
        vl_cls_f1, vl_cls_acc, vl_loss = val()
        print(f"[Train]                                | Loss: {tr_loss:.4f}")
        print(f"[Val]   ClsF1: {vl_cls_f1:.4f} | ClsAcc: {vl_cls_acc:.4f} | Loss: {vl_loss:.4f}")

    for level in ["no_tta", "mirroring", "translate"]:
        val_f1_tta, val_acc_tta = val_tta(level)
        results.append((model_name, level, val_f1_tta, val_acc_tta))


pytorch_model.bin:   0%|          | 0.00/599M [00:00<?, ?B/s]


Epoch 1/10


Training: 100%|██████████| 115/115 [01:46<00:00,  1.08it/s, Loss=1.9354]
Evaluating:  81%|████████  | 93/115 [01:26<00:20,  1.08it/s, Loss=0.9837, ClsF1=0.9437, ClsAcc=0.8438]


KeyboardInterrupt: 

In [None]:
for model_name, level, val_f1_tta, val_acc_tta in results:
    print(f"{model_name.split('/')[-1]: <30} | {level: <10} | ClsF1: {val_f1_tta:.4f} | ClsAcc: {val_acc_tta:.4f}")