In [1]:
import os
import kagglehub
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from PIL import Image

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from torchvision.models.resnet import ResNet18_Weights

In [2]:
DATA_DIR = kagglehub.dataset_download("andrewmvd/dog-and-cat-detection")
print("Path to dataset files: ", DATA_DIR)

Downloading from https://www.kaggle.com/api/v1/datasets/download/andrewmvd/dog-and-cat-detection?dataset_version_number=1...


100%|██████████| 1.03G/1.03G [00:06<00:00, 166MB/s]

Extracting files...





Path to dataset files:  /root/.cache/kagglehub/datasets/andrewmvd/dog-and-cat-detection/versions/1


In [3]:
class ImageDataset(Dataset):
    def __init__(self, annotations_dir, image_dir, transform=None):
        self.annotations_dir = annotations_dir
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = self.filter_images_with_multiple_objects()

    def filter_images_with_multiple_objects(self):
        valid_images_files = []
        for f in os.listdir(self.image_dir):
            if os.path.isfile(os.path.join(self.image_dir, f)):
                img_name = f
                annotation_name = os.path.splitext(img_name)[0] + ".xml"
                annotation_path = os.path.join(self.annotations_dir, annotation_name)

                if self.count_objects_in_annotation(annotation_path) <= 1:
                    valid_images_files.append(img_name)
                else:
                    print(
                        f"Image {img_name} has multiple objects and will be excluded from the dataset"
                    )
        return valid_images_files

    def count_objects_in_annotation(self, annotation_path):
        try:
            tree = ET.parse(annotation_path)
            root = tree.getroot()
            count = 0
            for obj in root.findall('object'):
                count += 1
            return count
        except FileNotFoundError:
            return 0

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)

        image = Image.open(img_path).convert('RGB')

        annotation_name = os.path.splitext(img_name)[0] + '.xml'
        annotation_path = os.path.join(self.annotations_dir, annotation_name)

        label, bbox = self.parse_annotation(annotation_path)

        if self.transform:
            image = self.transform(image)

        return image, label, bbox

    def parse_annotation(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()

        img_width = int(root.find('size').find('width').text)
        img_height = int(root.find('size').find('height').text)

        label = None
        for obj in root.findall('object'):
            name = obj.find('name').text
            if label is None:
                label = name

                xmin = int(obj.find('bndbox').find('xmin').text)
                ymin = int(obj.find('bndbox').find('ymin').text)
                xmax = int(obj.find('bndbox').find('xmax').text)
                ymax = int(obj.find('bndbox').find('ymax').text)

                bbox = [
                    xmin / img_width,
                    ymin / img_height,
                    xmax / img_width,
                    ymax / img_height,
                ]

        label_num = 0 if label == 'cat' else 1 if label == 'dog' else -1

        return label_num, torch.tensor(bbox, dtype=torch.float32)

In [4]:
annotations_dir = os.path.join(DATA_DIR, 'annotations')
image_dir = os.path.join(DATA_DIR, 'images')

image_files = [f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))]
df = pd.DataFrame({'image_name': image_files})

train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.486, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = ImageDataset(annotations_dir, image_dir, transform=transform)
val_dataset = ImageDataset(annotations_dir, image_dir, transform=transform)

train_dataset.image_files = [f for f in train_dataset.image_files if f in train_df['image_name'].values]
val_dataset.image_files = [f for f in val_dataset.image_files if f in val_df['image_name'].values]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Image Cats_Test736.png has multiple objects and will be excluded from the dataset
Image Cats_Test736.png has multiple objects and will be excluded from the dataset


In [5]:
class TwoHeadedModel(nn.Module):
    def __init__(self, num_classes=2):
        super(TwoHeadedModel, self).__init__()

        self.base_model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        self.num_ftrs = self.base_model.fc.in_features

        self.base_model.fc = nn.Identity()

        self.classifier = nn.Linear(self.num_ftrs, num_classes)

        self.regressor = nn.Linear(self.num_ftrs, 4)

    def forward(self, x):
        x = self.base_model(x)
        class_logits = self.classifier(x)
        bbox_cords = torch.sigmoid(self.regressor(x))

        return class_logits, bbox_cords

In [6]:
model = TwoHeadedModel()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion_class = nn.CrossEntropyLoss()
criterion_bbox = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 111MB/s]


In [7]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, targets, bboxs) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        bboxs = bboxs.to(device)

        score, bboxs_pred = model(data)
        loss_class = criterion_class(score, targets)
        loss_bbox = criterion_bbox(bboxs_pred, bboxs)
        loss = loss_class + loss_bbox

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        total_loss_bbox = 0
        total_samples = 0
        for data, targets, bboxs in val_loader:
            data = data.to(device)
            targets = targets.to(device)
            bboxs = bboxs.to(device)

            scores, bboxs_pred = model(data)
            _, predictions = scores.max(1)
            correct += (predictions == targets).sum().item()
            total += targets.size(0)

            total_loss_bbox += criterion_bbox(bboxs_pred, bboxs).item() * data.size(0)
            total_samples += data.size(0)

        avg_loss_bbox = total_loss_bbox / total_samples

        print(f'Epoch {epoch+1}/{num_epochs}, Validation Accuracy: {correct/total*100:.2f}%, Avg. Bbox Loss: {avg_loss_bbox:.4f}')

Epoch 1/10, Validation Accuracy: 93.50%, Avg. Bbox Loss: 0.0180
Epoch 2/10, Validation Accuracy: 96.07%, Avg. Bbox Loss: 0.0135
Epoch 3/10, Validation Accuracy: 85.64%, Avg. Bbox Loss: 0.0113
Epoch 4/10, Validation Accuracy: 96.88%, Avg. Bbox Loss: 0.0095
Epoch 5/10, Validation Accuracy: 93.36%, Avg. Bbox Loss: 0.0112
Epoch 6/10, Validation Accuracy: 95.66%, Avg. Bbox Loss: 0.0125
Epoch 7/10, Validation Accuracy: 96.88%, Avg. Bbox Loss: 0.0108
Epoch 8/10, Validation Accuracy: 93.90%, Avg. Bbox Loss: 0.0091
Epoch 9/10, Validation Accuracy: 96.61%, Avg. Bbox Loss: 0.0078
Epoch 10/10, Validation Accuracy: 80.89%, Avg. Bbox Loss: 0.0111
