In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import DistilBertModel, DistilBertTokenizer, AutoModel

import os
import random
import requests
import shutil
import itertools
import numpy as np
from collections import Counter
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score

In [2]:
class CFG:
    captions_path = '/kaggle/input/animal-captions/animal10'
    working_path = '/kaggle/working/'
    labels = [label.split('.')[0] for label in os.listdir(captions_path)]
    seen_classes = labels[:8]
    unseen_classes = labels[8:]
    batch_size = 32
    num_workers = 2
    epochs = 20
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Model
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768
    text_tokenizer = "distilbert-base-uncased"
    max_length = 300
    
    # Training hyperparameters
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 3

In [3]:
#--- Text Model --- 
class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.text_encoder_model, trainable=True):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)

        for param in self.model.parameters():
            param.requires_grad = trainable

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids, attention_mask)
        return outputs.last_hidden_state[:, 0]

In [4]:
class CaptionDataset(Dataset):
    def __init__(self, root_dir, tokenizer):
        self.root_dir = root_dir
        self.classes = sorted([animal_class.split('.')[0] for animal_class in os.listdir(self.root_dir)])
        self.class_to_index = {class_name: idx for idx, class_name in enumerate(CFG.labels)}
        
        self.caption = []
        self.label = []
        
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            caption_dir = os.path.join(CFG.captions_path, class_name+'.txt')
            
            with open(caption_dir, 'r') as file:
                for line in file:
                    parts = line.strip().split(',')
                    self.caption.append(parts[1])
                    self.label.append(class_name)
                    
        self.encoded_captions = tokenizer(
            self.caption, padding=True, truncation=True, max_length=CFG.max_length
        )  
        
    def __len__(self):
        return len(self.caption)
    
    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }
        
        item['caption'] = self.caption[idx]
        item['label'] = self.class_to_index[self.label[idx]]
        return item

In [5]:
def copy_description(source_dir, destination_dir, classes):
    os.makedirs(destination_dir, exist_ok=True)
    
    for class_animal in classes:
        source = f"{source_dir}/{class_animal}.txt"
        destination = f"{destination_dir}/{class_animal}.txt"
        try:
            shutil.copyfile(source, destination)
        except FileNotFoundError:
                    print(f"Caption file {source} not found for {class_animal}")

In [6]:
copy_description(CFG.captions_path, os.path.join(CFG.working_path, 'seen'), CFG.seen_classes)
copy_description(CFG.captions_path, os.path.join(CFG.working_path, 'unseen'), CFG.unseen_classes)

In [7]:
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

In [8]:
seen_dataset = CaptionDataset(os.path.join(CFG.working_path, 'seen'), tokenizer)
unseen_dataset = CaptionDataset(os.path.join(CFG.working_path, 'unseen'), tokenizer)

In [9]:
# Create DataLoader for training and validation from seen_dataset
num_train = int(len(seen_dataset) * 0.6)
num_val = int(len(seen_dataset) * 0.2)
num_test = len(seen_dataset) - num_train - num_val

train_dataset, val_dataset, seen_test_dataset = random_split(seen_dataset, [num_train, num_val, num_test])

train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers)
val_loader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)

# Create DataLoader for testing from seen_test_dataset
seen_test_loader = DataLoader(seen_test_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)

# Create DataLoader for testing from unseen_dataset
test_loader = DataLoader(unseen_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)

In [10]:
model = TextEncoder()
model.to(CFG.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.text_encoder_lr)
criterion = nn.CrossEntropyLoss()

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [11]:
# Training Loop
def train(model, train_loader, val_loader, optimizer, criterion, num_epochs, patience):
    best_val_loss = float('inf')
    counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        
        for batch in tqdm(train_loader, desc='Training', total=len(train_loader)):
            input_ids = batch['input_ids'].to(CFG.device)
            attention_mask = batch['attention_mask'].to(CFG.device)
            labels = batch['label'].to(CFG.device)
            
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * input_ids.size(0)
        
        train_loss /= len(train_loader.dataset)
        
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation', total=len(val_loader)):
                input_ids = batch['input_ids'].to(CFG.device)
                attention_mask = batch['attention_mask'].to(CFG.device)
                labels = batch['label'].to(CFG.device)
                
                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * input_ids.size(0)
        
        val_loss /= len(val_loader.dataset)
        
        print(f'Epoch {epoch + 1}/{num_epochs}, '
              f'Training Loss: {train_loss:.4f}, '
              f'Validation Loss: {val_loss:.4f}')
        
        # Check for early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping")
                break

In [12]:
def evaluate(model, loader, dataset_type):
    model.eval()
    test_loss = 0.0
    correct = 0
    true_labels = []
    pred_labels = []

    with torch.no_grad():
        for batch in tqdm(loader, desc='Testing', total=len(loader)):
            input_ids = batch['input_ids'].to(CFG.device)
            attention_mask = batch['attention_mask'].to(CFG.device)
            labels = batch['label'].to(CFG.device)
            
            outputs = model(input_ids, attention_mask)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            true_labels.extend(labels.cpu().numpy())
            pred_labels.extend(predicted.cpu().numpy())

    test_accuracy = accuracy_score(true_labels, pred_labels)
    test_precision = precision_score(true_labels, pred_labels, average='macro')
    test_recall = recall_score(true_labels, pred_labels, average='macro')

    print(f'Accuracy on {dataset_type} dataset: {test_accuracy:.4f}')
    print(f'Precision on {dataset_type} dataset: {test_precision:.4f}')
    print(f'Recall on {dataset_type} dataset: {test_recall:.4f}')

In [13]:
# Train the model
train(model, train_loader, val_loader, optimizer, criterion, CFG.epochs, CFG.patience)

Training:   0%|          | 0/343 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

Epoch 1/20, Training Loss: 0.9998, Validation Loss: 0.1230


Training:   0%|          | 0/343 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

Epoch 2/20, Training Loss: 0.1107, Validation Loss: 0.0905


Training:   0%|          | 0/343 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

Epoch 3/20, Training Loss: 0.0924, Validation Loss: 0.0881


Training:   0%|          | 0/343 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

Epoch 4/20, Training Loss: 0.0752, Validation Loss: 0.0914


Training:   0%|          | 0/343 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

Epoch 5/20, Training Loss: 0.0658, Validation Loss: 0.0964


Training:   0%|          | 0/343 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

Epoch 6/20, Training Loss: 0.0569, Validation Loss: 0.1000
Early stopping


In [14]:
# Evaluate on seen test dataset
model.load_state_dict(torch.load('best_model.pth'))
evaluate(model, seen_test_loader, 'seen test')

Testing:   0%|          | 0/115 [00:00<?, ?it/s]

Accuracy on seen test dataset: 0.9828
Precision on seen test dataset: 0.9814
Recall on seen test dataset: 0.9821


In [15]:
# Evaluate on unseen dataset
evaluate(model, test_loader, 'unseen')

Testing:   0%|          | 0/186 [00:00<?, ?it/s]

Accuracy on unseen dataset: 0.0000
Precision on unseen dataset: 0.0000
Recall on unseen dataset: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
