In [74]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import clip

import torchvision
from torch.nn.functional import cross_entropy

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

In [75]:
# Create Dataset class for multilabel classification
class CLIPFineTuneDataset(Dataset):
    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df 
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def __getitem__(self, idx):
        img_name = self.ann_df['image'][idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        super_idx = self.ann_df['superclass_index'][idx]
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = self.ann_df['subclass_index'][idx]
        sub_label = self.sub_map_df['class'][sub_idx]

        description = self.ann_df['description'][idx]

        if self.transform:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label, description

class MultiClassImageTestDataset(Dataset):
    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self): # Count files in img_dir
        return len([fname for fname in os.listdir(self.img_dir)])

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)  
            
        return image, img_name

In [76]:
train_ann_df = pd.read_csv('data/train_data.csv')
test_ann_df = pd.read_csv('data/example_test_predictions.csv')
super_map_df = pd.read_csv('data/superclass_mapping.csv')
sub_map_df = pd.read_csv('data/subclass_mapping.csv')

train_img_dir = 'data/train_images'
test_img_dir = 'data/test_images'

image_preprocessing = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match CLIP expectations
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711))  # CLIP's normalization
])

# Create train and val split
train_dataset = CLIPFineTuneDataset(train_ann_df, super_map_df, sub_map_df, train_img_dir, transform=image_preprocessing)
train_dataset, val_dataset = random_split(train_dataset, [0.9, 0.1]) 

# Create test dataset
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=image_preprocessing)

# Create dataloaders
batch_size = 64
train_loader = DataLoader(train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

val_loader = DataLoader(val_dataset, 
                        batch_size=batch_size, 
                        shuffle=True)

test_loader = DataLoader(test_dataset, 
                         batch_size=1, 
                         shuffle=False)

In [77]:
class CLIPClassifierTrainer():
    def __init__(self, clip_model, criterion, optimizer, train_loader, val_loader, test_loader=None, device=None):
        self.clip_model = clip_model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device or (torch.device('mps' if torch.backends.mps.is_available() else 'cpu'))

        for p in self.clip_model.parameters():
            p.requires_grad = True

        self.super_head = nn.Linear(512, 4).to(self.device)
        self.sub_head = nn.Linear(512, 88).to(self.device)

    def train_epoch(self):
        self.super_head.train()
        self.sub_head.train()
        self.clip_model.eval()

        running_loss = 0.0
        for i, data in enumerate(self.train_loader):
            inputs, super_labels, sub_labels, texts = data[0].to(self.device), data[1].to(self.device), data[3].to(self.device), data[5]

            text_tokens = clip.tokenize(list(texts)).to(self.device)

            image_features = self.clip_model.encode_image(inputs)
            text_features = self.clip_model.encode_text(text_tokens)

            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            assert image_features.size(0) == text_features.size(0), "Batch size mismatch between images and texts"

            logits_per_image = 100.0 * image_features @ text_features.T
            logits_per_text = logits_per_image.T

            ground_truth = torch.arange(len(inputs), device=self.device)
            contrastive_loss = (self.criterion(logits_per_image, ground_truth) + self.criterion(logits_per_text, ground_truth)) / 2

            super_outputs = self.super_head(image_features)
            sub_outputs = self.sub_head(image_features)
            classification_loss = F.cross_entropy(super_outputs, super_labels) + F.cross_entropy(sub_outputs, sub_labels)

            total_loss = contrastive_loss + classification_loss

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            running_loss += total_loss.item()

        print(f'Training loss: {running_loss / (i + 1):.3f}')

    def validate_epoch(self):
        self.super_head.eval()
        self.sub_head.eval()
        self.clip_model.eval()

        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0
        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                inputs, super_labels, sub_labels, texts = data[0].to(self.device), data[1].to(self.device), data[3].to(self.device), data[5]

                text_tokens = clip.tokenize(list(texts)).squeeze(1).to(self.device)

                image_features = self.clip_model.encode_image(inputs)
                text_features = self.clip_model.encode_text(text_tokens)

                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                assert image_features.size(0) == text_features.size(0), "Batch size mismatch between images and texts"

                logits_per_image = 100.0 * image_features @ text_features.T
                logits_per_text = logits_per_image.T

                ground_truth = torch.arange(len(inputs), device=self.device)
                contrastive_loss = (self.criterion(logits_per_image, ground_truth) + self.criterion(logits_per_text, ground_truth)) / 2

                super_outputs = self.super_head(image_features)
                sub_outputs = self.sub_head(image_features)
                classification_loss = F.cross_entropy(super_outputs, super_labels) + F.cross_entropy(sub_outputs, sub_labels)

                total_loss = contrastive_loss + classification_loss
                running_loss += total_loss.item()

                _, super_predicted = torch.max(super_outputs.data, 1)
                _, sub_predicted = torch.max(sub_outputs.data, 1)

                total += super_labels.size(0)
                super_correct += (super_predicted == super_labels).sum().item()
                sub_correct += (sub_predicted == sub_labels).sum().item()

        print(f'Validation loss: {running_loss / (i + 1):.3f}')
        print(f'Validation superclass acc: {100 * super_correct / total:.2f} %')
        print(f'Validation subclass acc: {100 * sub_correct / total:.2f} %')

    def test(self, save_to_csv=False, return_predictions=False):
        if not self.test_loader:
            raise NotImplementedError('test_loader not specified')

        test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}

        with torch.no_grad():
            for i, data in enumerate(self.test_loader):
                inputs, img_name = data[0].to(self.device), data[1]
                image_features = self.clip_model.encode_image(inputs)
                super_outputs = self.super_head(image_features)
                sub_outputs = self.sub_head(image_features)

                _, super_predicted = torch.max(super_outputs.data, 1)
                _, sub_predicted = torch.max(sub_outputs.data, 1)

                test_predictions['image'].append(img_name[0])
                test_predictions['superclass_index'].append(super_predicted.item())
                test_predictions['subclass_index'].append(sub_predicted.item())

        test_predictions = pd.DataFrame(data=test_predictions)

        if save_to_csv:
            test_predictions.to_csv('example_test_predictions.csv', index=False)

        if return_predictions:
            return test_predictions

In [78]:
# Init model and trainer
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
clip_model, preprocess = clip.load("ViT-B/32", device=device)
super_head = nn.Linear(512, 4).to(device)
sub_head = nn.Linear(512, 88).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    list(clip_model.parameters()) + list(super_head.parameters()) + list(sub_head.parameters()), lr=1e-5
)
trainer = CLIPClassifierTrainer(clip_model, criterion, optimizer, train_loader, val_loader, test_loader)
trainer.super_head = super_head
trainer.sub_head = sub_head

In [None]:
# Training loop
for epoch in range(20):
    print(f'Epoch {epoch+1}')
    trainer.train_epoch()
    trainer.validate_epoch()
    print('')

print('Finished Training')

Epoch 1


In [7]:
test_predictions = trainer.test(save_to_csv=True, return_predictions=True)

In [8]:
# Quick script for evaluating generated csv files with ground truth

super_correct = 0
sub_correct = 0
seen_super_correct = 0
seen_sub_correct = 0
unseen_super_correct = 0
unseen_sub_correct = 0

total = 0
seen_super_total = 0
unseen_super_total = 0
seen_sub_total = 0
unseen_sub_total = 0

for i in range(len(test_predictions)):
    super_pred = test_predictions['superclass_index'][i]
    sub_pred = test_predictions['subclass_index'][i]

    super_gt = test_ann_df['superclass_index'][i]
    sub_gt = test_ann_df['subclass_index'][i]

    # Total setting
    if super_pred == super_gt:
        super_correct += 1
    if sub_pred == sub_gt:
        sub_correct += 1
    total += 1

    # Unseen superclass setting
    if super_gt == 3:
        if super_pred == super_gt:
            unseen_super_correct += 1
        if sub_pred == sub_gt:
            unseen_sub_correct += 1
        unseen_super_total += 1
        unseen_sub_total += 1
    
    # Seen superclass, unseen subclass setting
    if super_gt != 3 and sub_gt == 87:
        if super_pred == super_gt:
            seen_super_correct += 1
        if sub_pred == sub_gt:
            unseen_sub_correct += 1
        seen_super_total += 1
        unseen_sub_total += 1

    # Seen superclass and subclass setting
    if super_gt != 3 and sub_gt != 87:
        if super_pred == super_gt:
            seen_super_correct += 1
        if sub_pred == sub_gt:
            seen_sub_correct += 1
        seen_super_total += 1
        seen_sub_total += 1

print('Superclass Accuracy')
print(f'Overall: {100*super_correct/total:.2f} %')
print(f'Seen: {100*seen_super_correct/seen_super_total:.2f} %')
print(f'Unseen: {100*unseen_super_correct/unseen_super_total:.2f} %')

print('\nSubclass Accuracy')
print(f'Overall: {100*sub_correct/total:.2f} %')
print(f'Seen: {100*seen_sub_correct/seen_sub_total:.2f} %')
print(f'Unseen: {100*unseen_sub_correct/unseen_sub_total:.2f} %')

Superclass Accuracy
Overall: 43.83 %
Seen: 61.11 %
Unseen: 0.00 %

Subclass Accuracy
Overall: 2.03 %
Seen: 9.56 %
Unseen: 0.00 %
