In [27]:
import os
import string
import time
import re
import random
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from pickle import dump, load

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
import torchvision.models as models
from torchvision.models import ResNet50_Weights
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

from tqdm import tqdm

In [28]:
class Vocabulary:
    def __init__(self, min_freq=5):
        self.min_freq = min_freq
        self.itos = {0:"pad", 1:"startofseq", 2:"endofseq", 3:"unk"}
        self.stoi = {v:k for k,v in self.itos.items()}
        self.index = 4
    
    def __len__(self):
        return len(self.itos)
    
    def tokenizer(self, text):
        text = text.lower()
        tokens = re.findall(r"\w+", text) #splits the text into tokens based on punctuation
        return tokens

    def build_vocab(self, sentence_list):
        count = {}
        for sentence in sentence_list:
            tokens = self.tokenizer(sentence)
            for token in tokens:
                if token not in count:
                    count[token] = 1
                else:
                    count[token] += 1
        
        for token, freq in count.items():
            # print(f"{token}:{freq}")
            if freq>=self.min_freq:
                #if freq>=min_freq then add it to vocab
                self.itos[self.index] = token
                self.stoi[token] = self.index
                self.index += 1
    
    def change_to_nums(self, text):
        tokens = self.tokenizer(text)
        nums = []
        for token in tokens:
            if token in self.stoi:
                nums.append(self.stoi[token])
            else:
                nums.append(self.stoi["unk"])
        return nums

In [29]:
def parse_tokens(filepath):
    descriptions = {}
    with open(filepath, "r") as file:
        lines = file.readlines()
        for line in lines[1:]:
            line = line.strip()
            words = line.split(".jpg")
            img_id = words[0] + ".jpg"
            caption = words[1:]
            cap = " ".join(caption)

            if img_id not in descriptions:
                descriptions[img_id] = []
            descriptions[img_id].append(cap.strip(',"'))
    
    return descriptions

# len(parse_tokens(TOKENS_FILE))

In [30]:
class MyDataset(Dataset):
    def __init__(self, descriptions, vocab, transform=None, train=True):
        self.desc = []
        self.vocab = vocab
        self.transform = transform

        imgs=[]
        caps=[]

        for img_id, captions in descriptions.items():
            img_path = os.path.join(IMAGES_DIR, img_id)
            if not os.path.isfile(img_path):
                continue
            else:
                for caption in captions:
                    imgs.append(img_path)
                    caps.append(caption)

        itrain, itest, ctrain, ctest = train_test_split(imgs, caps, shuffle=True, test_size=0.2, random_state=SEED)

        if train:
            for i, img in enumerate(itrain):
                self.desc.append((img, ctrain[i]))
        else:
            for i, img in enumerate(itest):
                self.desc.append((img, ctest[i]))

    def __len__(self):
        return len(self.desc)

    def __getitem__(self, idx):
        img_path, caption = self.desc[idx]
        
        img = Image.open(img_path).convert('RGB')

        if self.transform:
            img = self.transform(img)

        num_caption = [self.vocab.stoi["startofseq"]] #the first token is startofseq
        num_caption += self.vocab.change_to_nums(caption) #we are adding as they r lists
        num_caption.append(self.vocab.stoi["endofseq"])

        return img, torch.tensor(num_caption, dtype=torch.long)
        #the above is a custom tokenizer

In [31]:
#this is a helper function for the dataloader
def padding(data):
    #data is what getitem returns so it is a tuple, rows are batched full data
    data.sort(key=lambda x:len(x[1]), reverse=True)
    imgs = [i[0] for i in data]
    captions = [i[1] for i in data]
    lens = [len(i) for i in captions]
    max_len = max(lens)

    padded = torch.zeros(len(captions), max_len, dtype=torch.long) #this is a 2d torch of no.of rows=captions and no.of cols=max_len
    for i, caption in enumerate(captions):
        end = lens[i]
        padded[i, :end] = caption

    imgs = torch.stack(imgs, dim=0) #to stack them as [batch, channel, row, col] where batch is the new dim at 0
    return imgs, padded, lens

In [32]:
class ResNetEncoder(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        for param in model.parameters():
            param.requires_grad=True #this is so that we can fine tune the model
        #we need to drop the last linear layer and change it to an embedding 
        modules = list(model.children())[:-1] #dropping the last layer

        self.model = nn.Sequential(*modules)
        self.fc = nn.Linear(model.fc.in_features, embed_dim) #input is same as fc but output is embedding space
        self.batchnorm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, imgs):
        with torch.no_grad():
            features = self.model(imgs)
        #these are as [batch size, model.fc.in_features, 1, 1] therefore we need to flatten
        features = features.view(features.size(0), -1)
        features = self.fc(features) #this means features is [batch, embed dim]
        features = self.batchnorm(features)
        
        return features

In [33]:
class LSTMDecoder(nn.Module):
    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim) #it creates a dense vector of size vocab_size and dims of embed_dim
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, features, captions):
        captions_in = captions[:, :-1] #captions is the padded tensor
        embed = self.embeddings(captions_in)
        #lstm needs shape as [batch size, seq length, vocab size]
        features = features.unsqueeze(1) #as features are of shape [batchsize, 2048] so we make [batchsize, 1, 2048] as lstm needs this shape
        lstm_input = torch.cat((features, embed), dim=1) #we are concating these 2 along dim=1(seq length) so final is (seq length+1) there fore first token is image
        outputs, _ = self.lstm(lstm_input)
        logits = self.fc(outputs)
        
        return logits

In [34]:
class FullModel(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, imgs, captions):
        features = self.encoder(imgs)
        outputs = self.decoder(features, captions)

        return outputs

## Training Pipeline

In [35]:
def train_one_epoch(model, dataloader, criterion, optimizer, vocab_size, epoch):
    model.train()
    total_loss=0
    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}", unit="batch")
    
    start_time = time.time()
    for imgs, captions, _lengths in progress:
        imgs = imgs.to(DEVICE)
        captions = captions.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs, captions)
        outputs = outputs[:, 1:, :].contiguous().view(-1, vocab_size) #we need to remove first token as it is the image
        targets = captions[:, 1:].contiguous().view(-1)
        #outputs = [batch * seqlength , vocab size]
        #targets = [batch * seqlength] , we are removing first as it is startofseq token

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress.set_postfix({"loss": f"{loss.item():.4f}"})
    
    avg_loss = total_loss/len(dataloader)

    end_time = time.time()
    time_taken = end_time - start_time
    print(f"Run completed. Time taken: {time_taken:.4f}")
    return avg_loss

In [36]:
def validation(model, dataloader, criterion, vocab_size):
    model.eval()
    total_loss=0

    with torch.no_grad():
        for imgs, captions, _lengths in dataloader:
            imgs = imgs.to(DEVICE)
            captions = captions.to(DEVICE)

            outputs = model(imgs, captions)
            outputs = outputs[:, 1:, :].contiguous().view((-1, vocab_size))
            targets = captions[:, 1:].contiguous().view(-1)

            loss = criterion(outputs, targets)
            total_loss += loss.item()
    avg_loss = total_loss/len(dataloader)
    return avg_loss  

In [37]:
EMBED_DIM = 256
HIDDEN_DIM = 512
LEARNING_RATE = 0.001
BATCH_SIZE = 64
EPOCHS = 25
MIN_WORD_FREQ = 1
SEED = 42
DEVICE = torch.device("cpu")
NUM_WORKERS = 4
 
IMAGES_DIR = "data/images"
TOKENS_FILE = "data/captions.txt"

BEST_CHECKPOINT_PATH = "checkpoints/best_checkpoint.pth"
FINAL_MODEL_PATH = "checkpoints/final_model.pth"
VOCAB_PATH = "vocab/vocab.pkl"
 
RESUME = True

In [38]:
if not RESUME: #i.e. if we are starting from scratch
    desc = parse_tokens(TOKENS_FILE)
    all_captions=[]
    # print(desc[list(desc.keys())[0]])
    for captions in desc.values():
        all_captions.extend(captions)
    
    # print(len(all_captions))
    vocab=Vocabulary(min_freq=1)
    vocab.build_vocab(all_captions)
    with open(VOCAB_PATH, "wb") as f:
        dump(vocab, f)
    print(f"Vocabulary saved to {VOCAB_PATH}")

    vocab_size = len(vocab)
    print(f"Vocabulary size: {vocab_size}")

else:
    with open(VOCAB_PATH, "rb") as f:
        vocab = load(f)
    print(f"Vocabulary loaded from {VOCAB_PATH}")
    vocab_size = len(vocab)
    print(f"Vocabulary size: {vocab_size}")    

Vocabulary loaded from vocab/vocab.pkl
Vocabulary size: 8492


In [39]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = MyDataset(desc, vocab, transform=transform, train=True)
test_dataset = MyDataset(desc, vocab, transform=transform, train=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=NUM_WORKERS, collate_fn=padding)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, collate_fn=padding)

In [40]:
encoder = ResNetEncoder(EMBED_DIM)
decoder = LSTMDecoder(EMBED_DIM, HIDDEN_DIM, vocab_size)
model = FullModel(encoder, decoder)

In [44]:
total_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {total_params}")

total_trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Total trainable parameters: {total_trainable_params}")

criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["pad"])
#since the model is join of encoder and decoder and we dont need to optimize params of encoder
params = list(model.decoder.parameters()) + list(model.encoder.fc.parameters()) + list(model.encoder.batchnorm.parameters())
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE)

start_epoch=0 #this is cuz if we are resuming training from middle we need to know where to start
best_val_loss=float("inf")

#if we want to resume training
if RESUME and os.path.exists(BEST_CHECKPOINT_PATH):
    print(f"Resuming from checkpoint: {BEST_CHECKPOINT_PATH}")
    checkpoint = torch.load(BEST_CHECKPOINT_PATH, map_location=DEVICE)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["best_val_loss"]
    print(f"Resuming at epoch {start_epoch+1}, best_val_loss so far: {best_val_loss:.4f}")

elif RESUME:
    print(f"Warning: {BEST_CHECKPOINT_PATH} not found. Will have to start fresh ...")

Total parameters: 32140396
Total trainable parameters: 32140396
Resuming from checkpoint: checkpoints/best_checkpoint.pth
Resuming at epoch 3, best_val_loss so far: 1.5019


## Training Loop

In [42]:
#just checking...
imgs, captions, lengths = next(iter(train_loader))
imgs = imgs.to(DEVICE)
captions = captions.to(DEVICE)
outputs = model(imgs, captions)
outputs = outputs[:, 1:, :].contiguous().view(-1, vocab_size)
targets = captions[:, 1:].contiguous().view(-1)
loss = criterion(outputs, targets)
loss.backward()

print("Loss:", loss.item())


Loss: 1.187027096748352


In [45]:
try:
    for epoch in range(start_epoch, EPOCHS):#this is cuz if we are resuming training from middle we need to know where to start
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, vocab_size, epoch)
        val_loss = validation(model, test_loader, criterion, vocab_size)

        print(f"[Epoch {epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}\n")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_dict = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "best_val_loss": best_val_loss
            }

            torch.save(checkpoint_dict, BEST_CHECKPOINT_PATH)
            print(f"New best model saved at {BEST_CHECKPOINT_PATH}: {val_loss:.4f}\n")

            final_checkpoint_dict = {
                "model_state_dict": model.state_dict(),
            }
            torch.save(final_checkpoint_dict, FINAL_MODEL_PATH)

except KeyboardInterrupt:
    print(f"\nTraining interrupted --> Best model saved at {BEST_CHECKPOINT_PATH}\n")

print(f"\nFinal model weights at {FINAL_MODEL_PATH}")
print(f"Best val loss: {best_val_loss:.4f}")

Epoch 3: 100%|██████████| 505/505 [1:32:55<00:00, 11.04s/batch, loss=1.4077]

Run completed. Time taken: 5575.7826





[Epoch 3/25] Train Loss: 1.2358 | Val Loss: 1.4714
New best model saved at checkpoints/best_checkpoint.pth: 1.4714


Epoch 4: 100%|██████████| 505/505 [1:36:37<00:00, 11.48s/batch, loss=1.2478]

Run completed. Time taken: 5797.1911





[Epoch 4/25] Train Loss: 1.1275 | Val Loss: 1.4488
New best model saved at checkpoints/best_checkpoint.pth: 1.4488


Epoch 5: 100%|██████████| 505/505 [1:33:47<00:00, 11.14s/batch, loss=0.9549]

Run completed. Time taken: 5627.3400





[Epoch 5/25] Train Loss: 1.0239 | Val Loss: 1.4692


Epoch 6: 100%|██████████| 505/505 [1:31:59<00:00, 10.93s/batch, loss=1.0110]

Run completed. Time taken: 5519.6562





[Epoch 6/25] Train Loss: 0.9385 | Val Loss: 1.4815


Epoch 7: 100%|██████████| 505/505 [1:34:32<00:00, 11.23s/batch, loss=0.8409]

Run completed. Time taken: 5672.4881





[Epoch 7/25] Train Loss: 0.8551 | Val Loss: 1.4995


Epoch 8: 100%|██████████| 505/505 [1:36:19<00:00, 11.44s/batch, loss=0.8715]


Run completed. Time taken: 5779.7256
[Epoch 8/25] Train Loss: 0.7793 | Val Loss: 1.4961


Epoch 9: 100%|██████████| 505/505 [1:33:05<00:00, 11.06s/batch, loss=0.5183]

Run completed. Time taken: 5585.8339





[Epoch 9/25] Train Loss: 0.7164 | Val Loss: 1.5412


Epoch 10: 100%|██████████| 505/505 [1:33:27<00:00, 11.10s/batch, loss=0.8026]

Run completed. Time taken: 5607.7127





[Epoch 10/25] Train Loss: 0.6527 | Val Loss: 1.5633


Epoch 11:  30%|██▉       | 151/505 [27:47<1:05:09, 11.04s/batch, loss=0.6807]


Training interrupted --> Best model saved at checkpoints/best_checkpoint.pth

Final model weights at checkpoints/final_model.pth
Best val loss: 1.4488



