In [None]:
!pip install git+https://github.com/openai/CLIP.git

In [85]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import clip

import torchvision
from torch.nn.functional import cross_entropy

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image
from tqdm import tqdm


In [319]:
# All data lives under one root; paths below are built from it.
DATA_DIR = './data'

# Annotation tables: original training rows followed by the synthetic "novel"
# rows, stacked into a single frame with a fresh 0..N-1 index.
train_ann_df = pd.read_csv(f'{DATA_DIR}/train_data.csv')
novel_train_ann_df = pd.read_csv(f'{DATA_DIR}/novel_train_data.csv')
combined_df = pd.concat([train_ann_df, novel_train_ann_df], ignore_index=True)

# Index -> class-name lookup tables for the two label levels.
super_map_df = pd.read_csv(f'{DATA_DIR}/superclass_mapping.csv')
sub_map_df = pd.read_csv(f'{DATA_DIR}/subclass_mapping.csv')

# Image directories.
train_img_dir = f'{DATA_DIR}/train_images'
novel_img_dir = f'{DATA_DIR}/synthetic_novel_resized'
test_img_dir = f'{DATA_DIR}/test_images'
# Folder contains both original training images and novel generated resized images
combined_dir = f'{DATA_DIR}/combined_images'

In [321]:
# Create Dataset classes for two-level (superclass + subclass) classification

class MultiClassImageTrainDataset(Dataset):
    """Labelled image dataset yielding both superclass and subclass targets.

    ``ann_df`` must provide the columns ``image`` (file name relative to
    ``img_dir``), ``superclass_index``, ``subclass_index`` and ``description``.
    ``super_map_df`` / ``sub_map_df`` must have a ``class`` column, looked up by
    the superclass / subclass index respectively.

    ``__getitem__`` returns the tuple
    ``(image, super_idx, super_label, sub_idx, sub_label, description)``,
    where ``image`` has ``transform`` applied when one is given.
    """

    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def __getitem__(self, idx):
        # Dataset indices are positions 0..len-1, so address rows positionally.
        # Label-based `[idx]` would raise KeyError for a frame with a
        # non-default index (e.g. a filtered or non-reset frame).
        img_name = self.ann_df['image'].iloc[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # The context manager closes the file handle once pixels are decoded;
        # convert() returns a new, fully-loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        super_idx = self.ann_df['superclass_index'].iloc[idx]
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = self.ann_df['subclass_index'].iloc[idx]
        sub_label = self.sub_map_df['class'][sub_idx]

        description = self.ann_df['description'].iloc[idx]

        if self.transform:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label, description

class MultiClassImageTestDataset(Dataset):
    """Unlabelled test images stored as ``<idx>.jpg`` files in ``img_dir``.

    ``__getitem__(idx)`` opens ``<idx>.jpg`` and returns ``(image, img_name)``,
    with ``transform`` applied to the image when one is given.
    ``super_map_df`` / ``sub_map_df`` are stored but not used here.
    """

    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        # Count only entries following the `<integer>.jpg` naming that
        # __getitem__ relies on. Stray files (e.g. `.DS_Store`, `._0.jpg`, notes)
        # would otherwise inflate the length and make the last index point at a
        # file that does not exist.
        count = 0
        for fname in os.listdir(self.img_dir):
            stem, ext = os.path.splitext(fname)
            if ext == '.jpg' and stem.isdigit():
                count += 1
        return count

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        # Close the file handle once the pixels are decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

In [322]:
# Create LightWeightAdapter class

class LightWeightAdapter(nn.Module):
    """Bottleneck MLP with a residual (skip) connection.

    Computes ``x + up(relu(down(x)))`` where ``down`` projects ``input_dim`` to
    ``hidden_dim`` and ``up`` projects back, so the output shape equals the
    input shape.
    """

    def __init__(self, input_dim=512, hidden_dim=128):
        super().__init__()
        down = nn.Linear(input_dim, hidden_dim)
        up = nn.Linear(hidden_dim, input_dim)
        # The attribute name and layer order define the state_dict keys
        # ("adapter.0.*", "adapter.2.*"); keep them stable for saved weights.
        self.adapter = nn.Sequential(down, nn.ReLU(), up)

    def forward(self, x):
        residual = x
        return residual + self.adapter(x)

In [421]:
# Create DualHeadCLIP class (frozen CLIP encoder + adapter + two classification heads, with train/validate loops)

class DualHeadCLIP(nn.Module):
    """CLIP image encoder + trainable adapter + superclass/subclass linear heads.

    Besides being an ``nn.Module`` this class also acts as the trainer: it holds
    the loss, optimizer and data loaders and exposes ``train_epoch`` /
    ``validate_epoch``. CLIP features are always computed under
    ``torch.no_grad()``, so only parameters handed to ``optimizer`` are updated.

    Args:
        clip_model: model exposing ``encode_image(images)``; kept in eval mode.
        adapter: module applied to the (float32) CLIP image features.
        super_head: module mapping adapted features to superclass logits.
        sub_head: module mapping adapted features to subclass logits.
        criterion: loss applied to each head separately; the two are summed.
        optimizer: optimizer stepping the trainable parameters.
        train_loader, val_loader: iterables of batches where ``batch[0]`` is the
            image tensor, ``batch[1]`` the superclass indices and ``batch[3]``
            the subclass indices (the layout of ``MultiClassImageTrainDataset``).
        test_loader: optional; stored for callers, not used by this class.
        device: device all modules and batches are moved to.
    """

    def __init__(self, clip_model, adapter, super_head, sub_head, criterion, optimizer,
                 train_loader, val_loader, test_loader=None, device=None):
        super().__init__()
        self.device = device
        self.clip_model = clip_model.to(self.device)
        self.adapter = adapter.to(self.device)
        self.super_head = super_head.to(self.device)
        self.sub_head = sub_head.to(self.device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

    def forward(self, inputs):
        """Return ``(super_logits, sub_logits)`` for a batch of preprocessed images."""
        with torch.no_grad():
            # No gradients flow into CLIP. Its output may be half precision on
            # GPU, so cast to float32 for the heads.
            image_features = self.clip_model.encode_image(inputs).to(torch.float32)
        adapted_features = self.adapter(image_features)
        return self.super_head(adapted_features), self.sub_head(adapted_features)

    def _set_head_modes(self, training):
        # Adapter and heads follow `training`. CLIP always stays in eval mode;
        # eval() only fixes layer behaviour, and freezing comes from no_grad above.
        self.adapter.train(training)
        self.super_head.train(training)
        self.sub_head.train(training)
        self.clip_model.eval()

    def _unpack(self, data):
        # Batch layout: (images, super_idx, super_label, sub_idx, sub_label, description).
        return data[0].to(self.device), data[1].to(self.device), data[3].to(self.device)

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Prints and returns the mean per-batch loss (``nan`` for an empty loader).
        """
        self._set_head_modes(training=True)

        running_loss = 0.0
        num_batches = 0
        for data in tqdm(self.train_loader, desc='Training'):
            inputs, super_labels, sub_labels = self._unpack(data)

            super_outputs, sub_outputs = self(inputs)
            loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Training loss: {mean_loss:.3f}')
        return mean_loss

    def validate_epoch(self):
        """Evaluate on ``val_loader`` without updating any parameters.

        Prints and returns a dict with ``loss`` (mean per-batch loss) and
        ``super_acc`` / ``sub_acc`` (percent correct). Values are ``nan`` when
        the loader yields no samples.
        """
        self._set_head_modes(training=False)

        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for data in tqdm(self.val_loader, desc='Validating'):
                inputs, super_labels, sub_labels = self._unpack(data)

                super_outputs, sub_outputs = self(inputs)
                loss = self.criterion(super_outputs, super_labels) + self.criterion(sub_outputs, sub_labels)

                super_predicted = super_outputs.argmax(dim=1)
                sub_predicted = sub_outputs.argmax(dim=1)

                total += super_labels.size(0)
                super_correct += (super_predicted == super_labels).sum().item()
                sub_correct += (sub_predicted == sub_labels).sum().item()
                running_loss += loss.item()
                num_batches += 1

        nan = float('nan')
        metrics = {
            'loss': running_loss / num_batches if num_batches else nan,
            'super_acc': 100 * super_correct / total if total else nan,
            'sub_acc': 100 * sub_correct / total if total else nan,
        }
        print(f'Validation loss: {metrics["loss"]:.3f}')
        print(f'Validation superclass acc: {metrics["super_acc"]:.2f} %')
        print(f'Validation subclass acc: {metrics["sub_acc"]:.2f} %')
        return metrics

In [422]:
def test_dualheadclip(model, test_loader, SUPER_THRESH, SUB_THRESH, save_to_csv=False, return_predictions=False):
    model.eval()
    model.adapter.eval()
    model.super_head.eval()
    model.sub_head.eval()
    model.clip_model.eval()

    test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}

    with torch.no_grad():
        for i, data in enumerate(tqdm(test_loader, desc='Testing')):
            inputs, img_name = data[0].to(model.device), data[1]

            image_features = model.clip_model.encode_image(inputs).to(model.device).to(torch.float32)
            adapted_features = model.adapter(image_features)

            super_outputs = model.super_head(adapted_features)
            sub_outputs = model.sub_head(adapted_features)

            super_probs = F.softmax(super_outputs, dim=1)
            sub_probs = F.softmax(sub_outputs, dim=1)

            super_conf, super_pred_head = torch.max(super_probs, dim=1)
            sub_conf, sub_pred_head = torch.max(sub_probs, dim=1)

            super_pred = torch.where(super_conf > SUPER_THRESH, super_pred_head, torch.tensor(3).to(model.device))
            sub_pred = torch.where(sub_conf > SUB_THRESH, sub_pred_head, torch.tensor(87).to(model.device))

            test_predictions['image'].append(img_name[0])
            test_predictions['superclass_index'].append(super_pred.item())
            test_predictions['subclass_index'].append(sub_pred.item())

    test_predictions = pd.DataFrame(data=test_predictions)

    if save_to_csv:
        test_predictions.to_csv('./test_predictions/DualHeadCLIP_test_predictions.csv', index=False)

    if return_predictions:
        return test_predictions

In [None]:
# Prefer CUDA, then the MPS backend, and fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [423]:
# Load CLIP; `preprocess` is the image transform that matches this backbone.
clip_model, preprocess = clip.load("ViT-B/32", device=device)
image_dim = clip_model.visual.output_dim

# Seed the global RNG so the randomly initialised heads (and, when cells run in
# order, later draws such as the adapter's initialisation) are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Output sizes of the two heads. Test-time thresholding assigns the last index
# of each (3 and 87) to rejected / novel predictions.
NUM_SUPERCLASSES = 4
NUM_SUBCLASSES = 88

# Define super and sub head
super_head = nn.Linear(image_dim, NUM_SUPERCLASSES)
sub_head = nn.Linear(image_dim, NUM_SUBCLASSES)

In [424]:
generator = torch.Generator().manual_seed(42)

# Create train and val split (90 / 10) over the combined original + novel annotations
full_train_dataset = MultiClassImageTrainDataset(combined_df, super_map_df, sub_map_df, combined_dir, transform=preprocess)
train_dataset, val_dataset = random_split(full_train_dataset, [0.9, 0.1], generator=generator)

# Create test dataset
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=preprocess)

# Create dataloaders
batch_size = 32
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          generator=generator)

# Validation metrics are order-independent sums, so no shuffling is needed. Not
# passing `generator` also stops each validation pass from consuming the RNG
# that drives the training shuffle order.
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False)

# batch_size kept at 1: test-time prediction code may assume one image per batch.
test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         shuffle=False)

In [425]:
# Trainable residual adapter on top of the CLIP image features. Its width must
# match the CLIP embedding size (`image_dim`, read from the loaded backbone)
# rather than a hard-coded 512, so swapping backbones keeps working.
adapter = LightWeightAdapter(input_dim=image_dim, hidden_dim=128)

# The same label-smoothed loss is applied to both heads.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Only the adapter and the two heads are optimised; CLIP parameters are not
# passed to the optimizer.
optimizer = torch.optim.AdamW(
    list(adapter.parameters()) +
    list(super_head.parameters()) +
    list(sub_head.parameters()),
    lr=2e-4,
    weight_decay=0.01
)

trainer = DualHeadCLIP(
    clip_model=clip_model,
    adapter=adapter,
    super_head=super_head,
    sub_head=sub_head,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device
)

In [426]:
# Training loop: one optimisation pass followed by one validation pass per epoch.
NUM_EPOCHS = 20

for epoch_num in range(1, NUM_EPOCHS + 1):
    print(f'Epoch {epoch_num}')
    trainer.train_epoch()
    trainer.validate_epoch()
    print('')

print('Finished Training')

Epoch 1


Training: 100%|██████████| 194/194 [00:25<00:00,  7.59it/s]


Training loss: 4.170


Validating: 100%|██████████| 22/22 [00:02<00:00,  7.41it/s]


Validation loss: 3.274
Validation superclass acc: 99.27 %
Validation subclass acc: 35.17 %

Epoch 2


Training: 100%|██████████| 194/194 [00:24<00:00,  7.91it/s]


Training loss: 2.764


Validating: 100%|██████████| 22/22 [00:02<00:00,  7.46it/s]


Validation loss: 2.330
Validation superclass acc: 98.84 %
Validation subclass acc: 74.27 %

Epoch 3


Training: 100%|██████████| 194/194 [00:26<00:00,  7.30it/s]


Training loss: 2.098


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.36it/s]


Validation loss: 1.939
Validation superclass acc: 98.98 %
Validation subclass acc: 86.19 %

Epoch 4


Training: 100%|██████████| 194/194 [00:23<00:00,  8.25it/s]


Training loss: 1.834


Validating: 100%|██████████| 22/22 [00:02<00:00,  7.94it/s]


Validation loss: 1.775
Validation superclass acc: 98.98 %
Validation subclass acc: 88.08 %

Epoch 5


Training: 100%|██████████| 194/194 [00:25<00:00,  7.69it/s]


Training loss: 1.719


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.19it/s]


Validation loss: 1.707
Validation superclass acc: 98.98 %
Validation subclass acc: 89.68 %

Epoch 6


Training: 100%|██████████| 194/194 [00:25<00:00,  7.66it/s]


Training loss: 1.648


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.10it/s]


Validation loss: 1.645
Validation superclass acc: 98.98 %
Validation subclass acc: 90.84 %

Epoch 7


Training: 100%|██████████| 194/194 [00:24<00:00,  7.77it/s]


Training loss: 1.601


Validating: 100%|██████████| 22/22 [00:02<00:00,  7.78it/s]


Validation loss: 1.614
Validation superclass acc: 99.13 %
Validation subclass acc: 90.99 %

Epoch 8


Training: 100%|██████████| 194/194 [00:24<00:00,  7.88it/s]


Training loss: 1.564


Validating: 100%|██████████| 22/22 [00:02<00:00,  7.96it/s]


Validation loss: 1.588
Validation superclass acc: 99.13 %
Validation subclass acc: 90.99 %

Epoch 9


Training: 100%|██████████| 194/194 [00:25<00:00,  7.73it/s]


Training loss: 1.537


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.07it/s]


Validation loss: 1.565
Validation superclass acc: 99.13 %
Validation subclass acc: 91.72 %

Epoch 10


Training: 100%|██████████| 194/194 [00:24<00:00,  7.83it/s]


Training loss: 1.513


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.11it/s]


Validation loss: 1.555
Validation superclass acc: 99.13 %
Validation subclass acc: 91.86 %

Epoch 11


Training: 100%|██████████| 194/194 [00:24<00:00,  7.87it/s]


Training loss: 1.492


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.37it/s]


Validation loss: 1.536
Validation superclass acc: 99.27 %
Validation subclass acc: 92.15 %

Epoch 12


Training: 100%|██████████| 194/194 [00:24<00:00,  7.93it/s]


Training loss: 1.475


Validating: 100%|██████████| 22/22 [00:02<00:00,  7.95it/s]


Validation loss: 1.518
Validation superclass acc: 99.13 %
Validation subclass acc: 92.30 %

Epoch 13


Training: 100%|██████████| 194/194 [00:24<00:00,  7.99it/s]


Training loss: 1.459


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.14it/s]


Validation loss: 1.511
Validation superclass acc: 99.27 %
Validation subclass acc: 92.73 %

Epoch 14


Training: 100%|██████████| 194/194 [00:24<00:00,  7.98it/s]


Training loss: 1.444


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.31it/s]


Validation loss: 1.497
Validation superclass acc: 99.13 %
Validation subclass acc: 92.30 %

Epoch 15


Training: 100%|██████████| 194/194 [00:23<00:00,  8.16it/s]


Training loss: 1.431


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.45it/s]


Validation loss: 1.487
Validation superclass acc: 99.27 %
Validation subclass acc: 93.46 %

Epoch 16


Training: 100%|██████████| 194/194 [00:23<00:00,  8.19it/s]


Training loss: 1.418


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.32it/s]


Validation loss: 1.479
Validation superclass acc: 99.42 %
Validation subclass acc: 92.59 %

Epoch 17


Training: 100%|██████████| 194/194 [00:23<00:00,  8.20it/s]


Training loss: 1.408


Validating: 100%|██████████| 22/22 [00:02<00:00,  7.65it/s]


Validation loss: 1.467
Validation superclass acc: 99.13 %
Validation subclass acc: 93.17 %

Epoch 18


Training: 100%|██████████| 194/194 [00:24<00:00,  7.92it/s]


Training loss: 1.398


Validating: 100%|██████████| 22/22 [00:02<00:00,  7.92it/s]


Validation loss: 1.472
Validation superclass acc: 99.27 %
Validation subclass acc: 92.88 %

Epoch 19


Training: 100%|██████████| 194/194 [00:24<00:00,  7.83it/s]


Training loss: 1.386


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.39it/s]


Validation loss: 1.458
Validation superclass acc: 99.27 %
Validation subclass acc: 93.46 %

Epoch 20


Training: 100%|██████████| 194/194 [00:23<00:00,  8.24it/s]


Training loss: 1.377


Validating: 100%|██████████| 22/22 [00:02<00:00,  8.16it/s]

Validation loss: 1.455
Validation superclass acc: 99.42 %
Validation subclass acc: 93.75 %

Finished Training





In [442]:
# Softmax-confidence cut-offs: a prediction whose top probability does not
# exceed its threshold is replaced by the novel index inside test_dualheadclip.
SUPER_CONF_THRESH = 0.84
SUB_CONF_THRESH = 0.47

test_predictions = test_dualheadclip(
    model=trainer,
    test_loader=test_loader,
    SUPER_THRESH=SUPER_CONF_THRESH,
    SUB_THRESH=SUB_CONF_THRESH,
    save_to_csv=True,
    return_predictions=True,
)

Testing: 100%|██████████| 11180/11180 [02:46<00:00, 67.10it/s]


In [444]:
# Reload the saved predictions and count how many test images were labelled
# novel at each level (indices assigned by test-time thresholding).
NOVEL_SUPER_IDX = 3
NOVEL_SUB_IDX = 87

test_ann_df = pd.read_csv('./test_predictions/DualHeadCLIP_test_predictions.csv')

sup_matches = test_ann_df[test_ann_df['superclass_index'] == NOVEL_SUPER_IDX]
sub_matches = test_ann_df[test_ann_df['subclass_index'] == NOVEL_SUB_IDX]

# NOTE(review): a large novel-subclass share relative to the validation subclass
# accuracy may indicate SUB_THRESH is too strict — worth checking on held-out data.
print(f"Novel superclass predictions: {len(sup_matches)} / {len(test_ann_df)}")
print(f"Novel subclass predictions: {len(sub_matches)} / {len(test_ann_df)}")


Matching indices: 3106
Matching indices: 7385
