### 1. Data preparation

##### 1.1. Download the dataset from Kaggle

In [1]:
import kagglehub

# Fetch the latest version of the dataset; kagglehub stores it under its local cache
# directory and returns that path.
DATASET_HANDLE = "andrewmvd/dog-and-cat-detection"
data_dir = kagglehub.dataset_download(DATASET_HANDLE)
print("Path to dataset files:", data_dir)

Path to dataset files: C:\Users\admin\.cache\kagglehub\datasets\andrewmvd\dog-and-cat-detection\versions\1


##### 1.2. Create a PyTorch `Dataset` of single-object images

In [4]:
from torch.utils.data import Dataset
import os
from xml.etree import ElementTree as ET
from PIL import Image
import torch

class ImageDataset(Dataset):
    """Cat/dog detection dataset restricted to images with exactly one annotated object.

    Each item is ``(image, label, bbox)``:
      * ``image``: the RGB image, passed through ``transform`` when one is given.
      * ``label``: 0 for ``'cat'``, 1 for ``'dog'``, -1 for any other class name.
      * ``bbox``: float32 tensor ``[xmin, ymin, xmax, ymax]`` where x-coordinates are
        divided by the annotation's ``size/width`` and y-coordinates by ``size/height``.

    Args:
        image_dir: directory containing the image files.
        annotation_dir: directory containing one ``<image stem>.xml`` file per image.
        transform: optional callable applied to the loaded RGB ``PIL.Image``.
    """

    def __init__(self, image_dir, annotation_dir, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        self.image_files = self._filter_images_with_single_object()

    def _annotation_path(self, image_name):
        """Return the path of the XML annotation belonging to ``image_name``."""
        annotation_name = os.path.splitext(image_name)[0] + '.xml'
        return os.path.join(self.annotation_dir, annotation_name)

    def _filter_images_with_single_object(self):
        """Return the sorted image file names whose annotation has exactly one object.

        Sorting makes the order independent of the filesystem, so a seeded
        ``random_split`` over this dataset yields the same split on every machine.
        Images with several objects, with no object, or without an annotation file
        are skipped: ``_parse_annotation`` needs exactly one bounding box.
        """
        valid_image_files = []
        for image_name in sorted(os.listdir(self.image_dir)):
            if not os.path.isfile(os.path.join(self.image_dir, image_name)):
                continue  # ignore sub-directories

            object_count = self._count_object_in_annotation(self._annotation_path(image_name))
            if object_count == 1:
                valid_image_files.append(image_name)
            elif object_count > 1:
                print(f'Image {image_name} has more than 1 object and will be excluded from dataset.')
            else:
                print(f'Image {image_name} has no annotated object and will be excluded from dataset.')
        return valid_image_files

    def _count_object_in_annotation(self, annotation_path):
        """Return the number of ``<object>`` elements in an annotation (0 if the file is missing)."""
        try:
            root = ET.parse(annotation_path).getroot()
        except FileNotFoundError:
            return 0
        return len(root.findall('object'))

    def _parse_annotation(self, annotation_path):
        """Parse the first ``<object>`` of an annotation into ``(label_num, bbox)``.

        Returns:
            label_num: 0 for 'cat', 1 for 'dog', -1 for any other name.
            bbox: float32 tensor ``[xmin/width, ymin/height, xmax/width, ymax/height]``.

        Raises:
            FileNotFoundError: if the annotation file does not exist.
            ValueError: if the annotation contains no ``<object>`` element.
        """
        root = ET.parse(annotation_path).getroot()
        image_width = int(root.find('size/width').text)
        image_height = int(root.find('size/height').text)

        # Only single-object images are kept, so the first object is the one we want.
        obj = root.find('object')
        if obj is None:
            raise ValueError(f'Annotation {annotation_path} contains no object.')

        label = obj.find('name').text
        # float() also accepts fractional pixel coordinates.
        xmin = float(obj.find('bndbox/xmin').text)
        xmax = float(obj.find('bndbox/xmax').text)
        ymin = float(obj.find('bndbox/ymin').text)
        ymax = float(obj.find('bndbox/ymax').text)

        bbox = [
            xmin / image_width,
            ymin / image_height,
            xmax / image_width,
            ymax / image_height,
        ]

        # convert label to numerical representation
        label_num = 0 if label == 'cat' else 1 if label == 'dog' else -1
        return label_num, torch.tensor(bbox, dtype=torch.float32)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label, bbox)`` for the ``idx``-th kept image."""
        image_name = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_name)

        # The context manager releases the file handle once the RGB copy exists.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        label, bbox = self._parse_annotation(self._annotation_path(image_name))
        return image, label, bbox

##### 1.3. Train/validation split and data loaders

In [5]:
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torch

# Split / loader settings (tune here).
SEED = 42
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32

annotation_dir = os.path.join(data_dir, 'annotations')
image_dir = os.path.join(data_dir, 'images')

# Resize to a fixed 224x224 and normalise with the standard ImageNet channel
# mean/std, matching the ImageNet-pretrained ResNet-18 backbone used below.
data_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = ImageDataset(image_dir, annotation_dir, data_transform)
assert len(dataset) > 0, f'No usable images found under {image_dir}'

train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
generator = torch.Generator().manual_seed(SEED)
train_set, val_set = random_split(dataset=dataset, lengths=[train_size, val_size], generator=generator)
assert len(train_set) + len(val_set) == len(dataset)

# data loaders; a dedicated seeded generator makes the training shuffle order
# reproducible as well, not just the split itself
shuffle_generator = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, generator=shuffle_generator)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

Image Cats_Test736.png has more than 1 object and will be excluded from dataset.


### 2. Build a two-head model (one head for classification, one for bounding-box regression)

In [6]:
import torch.nn as nn
import torchvision

class TwoHeadModel(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()

        # resnet18 backbone
        self.backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)

        # remove fc of original resnet18
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # classification head
        self.classifier = nn.Linear(in_features=num_features, out_features=num_classes)

        # regression head
        self.bbox_regressor = nn.Linear(in_features=num_features, out_features=4)

    def forward(self, x):
        x = self.backbone(x)
        classifier_logits = self.classifier(x)
        bbox_coords = torch.sigmoid(self.bbox_regressor(x))
        return classifier_logits, bbox_coords