In [14]:
import os
from datetime import datetime
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

from torchvision import transforms
from torchvision.transforms import InterpolationMode

In [5]:
# Create Dataset class for multilabel classification
class MultiClassImageDataset(Dataset):
    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df 
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def __getitem__(self, idx):
        img_name = self.ann_df['image'][idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        super_idx = self.ann_df['superclass_index'][idx]
        super_label = self.super_map_df['class'][super_idx]
        
        sub_idx = self.ann_df['subclass_index'][idx]
        sub_label = self.sub_map_df['class'][sub_idx]
        
        if self.transform:
            image = self.transform(image)  
            
        return image, super_idx, super_label, sub_idx, sub_label

class MultiClassImageTestDataset(Dataset):
    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self): # Count files in img_dir
        return len([fname for fname in os.listdir(self.img_dir)])

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)  
            
        return image, img_name

In [6]:
train_ann_df = pd.read_csv('/Users/liu/Desktop/NNDL_final_proj/Released_Data_NNDL_2025/train_data_novel.csv')
super_map_df = pd.read_csv('/Users/liu/Desktop/NNDL_final_proj/Released_Data_NNDL_2025/superclass_mapping.csv')
sub_map_df = pd.read_csv('/Users/liu/Desktop/NNDL_final_proj/Released_Data_NNDL_2025/subclass_mapping.csv')

train_img_dir = '/Users/liu/Desktop/NNDL_final_proj/Released_Data_NNDL_2025/train_images_with_novel'
test_img_dir = '/Users/liu/Desktop/NNDL_final_proj/Released_Data_NNDL_2025/test_images'

# image_preprocessing = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0), std=(1)),
# ])

image_preprocessing = transforms.Compose([
    transforms.Resize((64, 64), interpolation=InterpolationMode.LANCZOS),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

# Create train and val split
train_dataset = MultiClassImageDataset(train_ann_df, super_map_df, sub_map_df, train_img_dir, transform=image_preprocessing)
train_dataset, val_dataset = random_split(train_dataset, [0.9, 0.1]) 

# Create test dataset
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=image_preprocessing)

# Create dataloaders
batch_size = 64
train_loader = DataLoader(train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

val_loader = DataLoader(val_dataset, 
                        batch_size=batch_size, 
                        shuffle=True)

test_loader = DataLoader(test_dataset, 
                         batch_size=1, 
                         shuffle=False)

In [7]:
# Model Definition
class CNN(nn.Module):
    def __init__(self, input_size=64):
        super().__init__()
        
        self.feature_size = input_size // (2**3)
        
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2)
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding='same'), 
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2)
        )

        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding='same'), 
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2)
        )
        
        self.fc1 = nn.Linear(self.feature_size * self.feature_size * 128, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3a = nn.Linear(128, 4)
        self.fc3b = nn.Linear(128, 88)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        super_out = self.fc3a(x)
        sub_out = self.fc3b(x)
        return super_out, sub_out

In [None]:
# Trainer
class Trainer():
    def __init__(self, model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

    def train_epoch(self):
        running_loss = 0.0
        for i, data in enumerate(self.train_loader):
            inputs, super_labels, sub_labels = data[0].to(device), data[1].to(device), data[3].to(device)

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Training loss: {running_loss/i:.3f}')

    def validate_epoch(self):
        super_correct_all = 0
        sub_correct_all = 0
        super_correct_seen = 0
        super_correct_novel = 0
        sub_correct_seen = 0
        sub_correct_novel = 0
        total = 0
        seen_super_total = 0
        novel_super_total = 0
        seen_sub_total = 0
        novel_sub_total = 0
        running_loss = 0.0
        ce_super_total = 0.0
        ce_sub_total = 0.0

        # Define novel class indices
        novel_super_indices = [3] 
        novel_sub_indices = [87]   

        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                inputs, super_labels, _, sub_labels, _ = data
                inputs = inputs.to(device)
                super_labels = super_labels.to(device)
                sub_labels = sub_labels.to(device)

                super_outputs, sub_outputs = self.model(inputs)

                # Separate CE losses
                ce_super = self.criterion(super_outputs, super_labels)
                ce_sub = self.criterion(sub_outputs, sub_labels)
                loss = ce_super + ce_sub

                running_loss += loss.item()
                ce_super_total += ce_super.item()
                ce_sub_total += ce_sub.item()

                _, super_preds = torch.max(super_outputs, 1)
                _, sub_preds = torch.max(sub_outputs, 1)

                total += super_labels.size(0)
                super_correct_all += (super_preds == super_labels).sum().item()
                sub_correct_all += (sub_preds == sub_labels).sum().item()

                # Superclass: Seen vs Novel
                for j in range(super_labels.size(0)):
                    label = super_labels[j].item()
                    if label in novel_super_indices:
                        novel_super_total += 1
                        if super_preds[j] == super_labels[j]:
                            super_correct_novel += 1
                    else:
                        seen_super_total += 1
                        if super_preds[j] == super_labels[j]:
                            super_correct_seen += 1

                # Subclass: Seen vs Novel
                for j in range(sub_labels.size(0)):
                    label = sub_labels[j].item()
                    if label in novel_sub_indices:
                        novel_sub_total += 1
                        if sub_preds[j] == sub_labels[j]:
                            sub_correct_novel += 1
                    else:
                        seen_sub_total += 1
                        if sub_preds[j] == sub_labels[j]:
                            sub_correct_seen += 1

        # Avoid division by zero
        seen_super_acc = 100 * super_correct_seen / seen_super_total if seen_super_total > 0 else 0
        novel_super_acc = 100 * super_correct_novel / novel_super_total if novel_super_total > 0 else 0
        seen_sub_acc = 100 * sub_correct_seen / seen_sub_total if seen_sub_total > 0 else 0
        novel_sub_acc = 100 * sub_correct_novel / novel_sub_total if novel_sub_total > 0 else 0

        # Final Output
        overall_cross_entropy = running_loss / len(self.val_loader)
        print(f'Cross-Entropy: Superclass={ce_super_total / len(self.val_loader):.4f} | Subclass={ce_sub_total / len(self.val_loader):.4f}')
        print(f'Overall Cross-Entropy Loss: {overall_cross_entropy:.4f}')
        print(f'Superclass Acc: Overall={100*super_correct_all/total:.2f}% | Seen={seen_super_acc:.2f}% | Novel={novel_super_acc:.2f}%')
        print(f'Subclass  Acc: Overall={100*sub_correct_all/total:.2f}% | Seen={seen_sub_acc:.2f}% | Novel={novel_sub_acc:.2f}%')

        return overall_cross_entropy

    def test(self, save_to_csv=False, return_predictions=False):
        if not self.test_loader:
            raise NotImplementedError('test_loader not specified')

        # Evaluate on test set, in this simple demo no special care is taken for novel/unseen classes
        test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
        with torch.no_grad():
            for i, data in enumerate(self.test_loader):
                inputs, img_name = data[0].to(device), data[1]
        
                super_outputs, sub_outputs = self.model(inputs)
                _, super_predicted = torch.max(super_outputs.data, 1)
                _, sub_predicted = torch.max(sub_outputs.data, 1)
                
                test_predictions['image'].append(img_name[0])
                test_predictions['superclass_index'].append(super_predicted.item())
                test_predictions['subclass_index'].append(sub_predicted.item())
                
        test_predictions = pd.DataFrame(data=test_predictions)
        
        if save_to_csv:
            test_predictions.to_csv('test_predictions.csv', index=False)

        if return_predictions:
            return test_predictions

In [9]:
# Training Setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
model = CNN(input_size=64).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, test_loader)

Using device: cpu


In [None]:
# Training loop
best_val_loss = float('inf')
patience = 5
patience_counter = 0
best_epoch = 0
best_model_path = "best_model.pth"

num_epochs = 30
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}')
    trainer.train_epoch()

    val_loss = trainer.validate_epoch()

    # Early stopping based on val_loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved at epoch {epoch+1} with val loss {best_val_loss:.4f}")
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")

    if patience_counter >= patience:
        print("Early stopping triggered.")
        break

print(f'Finished Training. Best model at epoch {best_epoch} with val loss {best_val_loss:.4f}')

Epoch 1
Training loss: 1.089
Cross-Entropy: Superclass=0.3869 | Subclass=0.9251
Overall Cross-Entropy Loss: 1.3121
Superclass Acc: Overall=85.96% | Seen=90.00% | Novel=41.10%
Subclass  Acc: Overall=70.78% | Seen=62.75% | Novel=91.80%
New best model saved at epoch 1 with val loss 1.3121%
Epoch 2
Training loss: 0.780
Cross-Entropy: Superclass=0.4357 | Subclass=0.7847
Overall Cross-Entropy Loss: 1.2204
Superclass Acc: Overall=83.92% | Seen=87.41% | Novel=45.21%
Subclass  Acc: Overall=75.08% | Seen=68.86% | Novel=91.39%
New best model saved at epoch 2 with val loss 1.2204%
Epoch 3
Training loss: 0.569
Cross-Entropy: Superclass=0.3562 | Subclass=0.8248
Overall Cross-Entropy Loss: 1.1809
Superclass Acc: Overall=87.09% | Seen=90.49% | Novel=49.32%
Subclass  Acc: Overall=76.44% | Seen=68.54% | Novel=97.13%
New best model saved at epoch 3 with val loss 1.1809%
Epoch 4
Training loss: 0.443
Cross-Entropy: Superclass=0.4185 | Subclass=0.7203
Overall Cross-Entropy Loss: 1.1388
Superclass Acc: Overa

In [13]:
# Test and Save Prediction 
# test_predictions = trainer.test(save_to_csv=True, return_predictions=True)

model.load_state_dict(torch.load(best_model_path))
test_predictions = trainer.test(save_to_csv=True, return_predictions=True)

In [15]:
model.load_state_dict(torch.load(best_model_path))
test_predictions = trainer.test(save_to_csv=False, return_predictions=True)

notebook_name = "NNDL_CNN_early_stopping.ipynb"

output_dir = "test_predictions"
os.makedirs(output_dir, exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
base_name = f'test_predictions_{timestamp}'
csv_filename = os.path.join(output_dir, base_name + ".csv")
meta_filename = os.path.join(output_dir, base_name + "_info.txt")

test_predictions.to_csv(csv_filename, index=False)

with open(meta_filename, 'w') as f:
    f.write(f"Best Epoch: {best_epoch}\n")
    f.write(f"Validation Loss: {best_val_loss:.4f}\n")
    f.write(f"Notebook Source: {notebook_name}\n")