In [1]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

In [2]:
# Create Dataset class for multilabel classification
class MultiClassImageDataset(Dataset):
    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df 
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def __getitem__(self, idx):
        img_name = self.ann_df['image'][idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        super_idx = self.ann_df['superclass_index'][idx]
        super_label = self.super_map_df['class'][super_idx]
        
        sub_idx = self.ann_df['subclass_index'][idx]
        sub_label = self.sub_map_df['class'][sub_idx]
        
        if self.transform:
            image = self.transform(image)  
            
        return image, super_idx, super_label, sub_idx, sub_label

class MultiClassImageTestDataset(Dataset):
    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self): # Count files in img_dir
        return len([fname for fname in os.listdir(self.img_dir)])

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)  
            
        return image, img_name

In [4]:
train_ann_df = pd.read_csv('./data/train_data.csv')
#test_ann_df = pd.read_csv('./data/example_test_predictions.csv')
super_map_df = pd.read_csv('./data/superclass_mapping.csv')
sub_map_df = pd.read_csv('./data/subclass_mapping.csv')

train_img_dir = './data/train_images'
test_img_dir = './data/test_images'

image_preprocessing = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

# Create train and val split
train_dataset = MultiClassImageDataset(train_ann_df, super_map_df, sub_map_df, train_img_dir, transform=image_preprocessing)
train_dataset, val_dataset = random_split(train_dataset, [0.9, 0.1]) 

# Create test dataset
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=image_preprocessing)

# Create dataloaders
batch_size = 64
train_loader = DataLoader(train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

val_loader = DataLoader(val_dataset, 
                        batch_size=batch_size, 
                        shuffle=True)

test_loader = DataLoader(test_dataset, 
                         batch_size=1, 
                         shuffle=False)

In [10]:
class LeNet5MultiOutput(nn.Module):
    def __init__(self):
        super(LeNet5MultiOutput, self).__init__()

        # Block 1: Conv1 + ReLU + AvgPool
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=2),  # (B, 3, 64, 64) → (B, 6, 64, 64)
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2)                 # → (B, 6, 32, 32)
        )

        # Block 2: Conv2 + ReLU + AvgPool
        self.block2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1),            # (B, 6, 32, 32) → (B, 16, 28, 28)
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2)                 # → (B, 16, 14, 14)
        )

        # Block 3: Fully connected + 2 classification heads
        self.fc1 = nn.Linear(16 * 14 * 14, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3a = nn.Linear(84, 4)   # Super-class output (bird, dog, reptile, novel)
        self.fc3b = nn.Linear(84, 88)  # Sub-class output (87 known + 1 novel)

    def forward(self, x):
        x = self.block1(x)            # → (B, 6, 32, 32)
        x = self.block2(x)            # → (B, 16, 14, 14)
        x = torch.flatten(x, 1)       # → (B, 3136)
        x = F.relu(self.fc1(x))       # → (B, 120)
        x = F.relu(self.fc2(x))       # → (B, 84)
        super_out = self.fc3a(x)      # → (B, 4)
        sub_out = self.fc3b(x)        # → (B, 88)
        return super_out, sub_out

class Trainer():
    def __init__(self, model, criterion, optimizer, train_loader, val_loader, test_loader=None, device=None):
        self.device = device
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

    def train_epoch(self):
        running_loss = 0.0
        for i, data in enumerate(self.train_loader):
            inputs, super_labels, sub_labels = data[0].to(self.device), data[1].to(self.device), data[3].to(self.device)

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Training loss: {running_loss/i:.3f}')

    def validate_epoch(self):
        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0
        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                inputs, super_labels, sub_labels = data[0].to(self.device), data[1].to(self.device), data[3].to(self.device)

                super_outputs, sub_outputs = self.model(inputs)
                loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
                _, super_predicted = torch.max(super_outputs.data, 1)
                _, sub_predicted = torch.max(sub_outputs.data, 1)

                total += super_labels.size(0)
                super_correct += (super_predicted == super_labels).sum().item()
                sub_correct += (sub_predicted == sub_labels).sum().item()
                running_loss += loss.item()

        print(f'Validation loss: {running_loss/i:.3f}')
        print(f'Validation superclass acc: {100 * super_correct / total:.2f} %')
        print(f'Validation subclass acc: {100 * sub_correct / total:.2f} %')

    def test(self, save_to_csv=False, return_predictions=False):
        if not self.test_loader:
            raise NotImplementedError('test_loader not specified')

        # Evaluate on test set, in this simple demo no special care is taken for novel/unseen classes
        test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
        with torch.no_grad():
            for i, data in enumerate(self.test_loader):
                inputs, img_name = data[0].to(self.device), data[1]

                super_outputs, sub_outputs = self.model(inputs)
                _, super_predicted = torch.max(super_outputs.data, 1)
                _, sub_predicted = torch.max(sub_outputs.data, 1)

                test_predictions['image'].append(img_name[0])
                test_predictions['superclass_index'].append(super_predicted.item())
                test_predictions['subclass_index'].append(sub_predicted.item())

        test_predictions = pd.DataFrame(data=test_predictions)

        if save_to_csv:
            test_predictions.to_csv('./test_predictions/LeNet5_test_predictions.csv', index=False)

        if return_predictions:
            return test_predictions

In [11]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [12]:
# Init model and trainer
model = LeNet5MultiOutput().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, test_loader, device)

In [13]:
# Training loop
for epoch in range(20):
    print(f'Epoch {epoch+1}')
    trainer.train_epoch()
    trainer.validate_epoch()
    print('')

print('Finished Training')

Epoch 1
Training loss: 4.685
Validation loss: 4.062
Validation superclass acc: 82.80 %
Validation subclass acc: 15.45 %

Epoch 2
Training loss: 3.005
Validation loss: 3.145
Validation superclass acc: 85.99 %
Validation subclass acc: 31.37 %

Epoch 3
Training loss: 2.122
Validation loss: 2.301
Validation superclass acc: 92.20 %
Validation subclass acc: 47.77 %

Epoch 4
Training loss: 1.583
Validation loss: 1.995
Validation superclass acc: 92.99 %
Validation subclass acc: 53.82 %

Epoch 5
Training loss: 1.238
Validation loss: 1.806
Validation superclass acc: 94.75 %
Validation subclass acc: 56.53 %

Epoch 6
Training loss: 0.991
Validation loss: 1.802
Validation superclass acc: 93.31 %
Validation subclass acc: 60.67 %

Epoch 7
Training loss: 0.788
Validation loss: 1.764
Validation superclass acc: 94.11 %
Validation subclass acc: 60.03 %

Epoch 8
Training loss: 0.632
Validation loss: 1.802
Validation superclass acc: 94.43 %
Validation subclass acc: 59.87 %

Epoch 9
Training loss: 0.482
Val

In [14]:
test_predictions = trainer.test(save_to_csv=True, return_predictions=True)