In [1]:
import os
import torch
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as transforms
from torchtext.data.metrics import bleu_score

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
TRAIN_RATIO = 0.7
TEST_RATIO = 0.15
VAL_RATIO = 0.15

BATCH_SIZE = 32
WORKERS = 4
LEARNING_RATE=0.01
EMBED_SIZE = 256
HIDDEN_SIZE = 512
NUM_LAYERS = 4
EPOCHS = 30


In [4]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        self.train_CNN = train_CNN

        # inception = models.inception_v3(weights=True)
        resnet101 = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V2)
        modules = list(resnet101.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        self.adaptive_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, embed_size)
        self.dropout = nn.Dropout(p=0.5)
        self.batchnorm = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.fine_tune()

    def forward(self, images):
        out = self.resnet(images)
        out = self.adaptive_pool(out)
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)
        return out
    
    def fine_tune(self, fine_tune=True):

        for p in self.resnet.parameters():
            p.requires_grad = False
            
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune

        

In [5]:
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=40):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions))
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs


In [6]:
class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        self.encoderCNN = EncoderCNN(embed_size).to(device)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

    def forward(self, images, captions):
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features, captions)
        return outputs
    
    def caption_image(self, image, vocabulary, max_length=50):
        result_caption = []
        with torch.no_grad():
            x = self.encoderCNN(image).unsqueeze(0) # so that we have a dimention for batch
            states = None
            for _ in range(max_length):
                hiddens, states = self.decoderRNN.lstm(x, states)
                output = self.decoderRNN.linear(hiddens.squeeze(0))
                predicted = output.argmax(1) # take the word with the highest probability

                result_caption.append(predicted.item())
                x = self.decoderRNN.embed(predicted).unsqueeze(0)

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break
                
        return [vocabulary.itos[idx] for idx in result_caption]

In [7]:
def custom_tokenizer(text):
    patterns = [
        r"\w+",
        r"\d+", 
        r"\S+" 
    ]
    pattern = "|".join(patterns)
    
    # Use the regex pattern to tokenize the text
    tokens = re.findall(pattern, text)
    
    tokens = [token.lower() for token in tokens]
    
    return tokens


In [8]:
class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold # Minimum frequency for a word to be included in the vocabulary

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        return [tok.lower() for tok in custom_tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1

                else:
                    frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]
    
    def caption_len(self,text):
        tokenized_text = self.tokenizer_eng(text)
        return len(tokenized_text)
    

class FlickrDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=3):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Get img, caption columns
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Initialize vocabulary and build vocab
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)
        caplen = self.vocab.caption_len(caption)
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption), torch.LongTensor([caplen])

In [9]:
class SelfCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
        caplen = [item[2] for item in batch]

        return imgs, targets, caplen

def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=BATCH_SIZE,
    num_workers=WORKERS,
    shuffle=True,
    pin_memory=True,
):
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
    total_samples = len(dataset)

    pad_idx = dataset.vocab.stoi["<PAD>"]

    train_size = int(TRAIN_RATIO * total_samples)
    val_size = int(TEST_RATIO * total_samples)
    test_size = total_samples - (train_size + val_size)

    train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])

    train_loader = DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=SelfCollate(pad_idx=pad_idx),
    )
    
    val_loader = DataLoader(
        dataset=val_set,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=pin_memory,
        collate_fn=SelfCollate(pad_idx=pad_idx),
    )
    
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=8,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=pin_memory,
        collate_fn=SelfCollate(pad_idx=pad_idx),
    )

    return train_loader, val_loader, test_loader, dataset

In [10]:
# main

transform = transforms.Compose([
        transforms.Resize((356, 356)),
        transforms.RandomCrop((224, 224)), 
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

train_loader, val_loader, test_loader, dataset = get_loader(
        "../Data/Images/", "../Data/captions.txt", transform=transform, num_workers=WORKERS
    )
vocab_size = len(dataset.vocab) # vocabulary size


In [11]:
model = CNNtoRNN(embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE, vocab_size=vocab_size, num_layers=NUM_LAYERS).to(device)

In [12]:
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi['<PAD>']).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=0.0001)


In [14]:

for images, captions, caplen in tqdm(train_loader, desc="Train\t"):
    images = images.to(device)
    captions = captions.to(device)
    print(captions[:-1])
    output = model(images, captions) 
    print(output.shape)
    # loss = criterion(output.reshape(-1, output.shape[2]), captions.reshape(-1))
    ref = captions.permute(1,0)
    output_prob = output.permute(1,0,2)
    cand = torch.argmax(output_prob, dim=2)
    bleu_scores = []
    for i in range(32):  # Assuming batch size is 32
        candidate_sentence = [dataset.vocab.itos[word.item()] for word in cand[i]]
        reference_sentence = [dataset.vocab.itos[word.item()] for word in ref[i]]
        bleu = bleu_score([candidate_sentence], [[reference_sentence]], max_n=4, weights=[0.20, 0.25, 0.35, 0.20])
        bleu_scores.append(bleu)

    # Compute the average BLEU score for the batch
    average_bleu = sum(bleu_scores) / len(bleu_scores)
    print(f'BLEU Score: {average_bleu * 100:.2f}')


    break


Train	:   0%|          | 0/885 [00:00<?, ?it/s]

tensor([[   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1],
        [   4,    4,    4,  229,  125,    4,    4, 1711,  123,   47,    4,   30,
            4,    4,    4,   15,    4,    4,   84,    4,   66,    4,    4,    4,
           24,   25,   84,    4,    4,   84,    4,  392],
        [  12,    6,   25,   63, 1541,   71,   92,  586,    7,    9,   92,  612,
           25,  244,  123,   25,  123,  302,   49,  147,   12,   25, 3775,  351,
          309,   95,   96,   47,   12,   37,   25,   63],
        [ 288,  992,   14,   95,  167,  106,    9,   18,    4,   76,  288,   25,
          113,   36,  449,   23,   14,  309,  392,   19,   23,  277,  537,    9,
           27,   18, 1155,    9, 1638,   96, 1145,   43],
        [  18,  150, 1622,    9,   45,  455,  574,    4,   57,  536,   72,   23,
           18,    9,   20,  805,    4, 

Train	:   0%|          | 0/885 [00:00<?, ?it/s]


In [15]:
def train(model, train_loader, epochs, criterion, optimizer, scheduler):
    train_loss = []
    all_bleus = []
    best_bleu = 0.0
    for epoch in range(epochs):
        total_train = 0
        running_train_loss = 0.0
        print(f'Epoch: {epoch +1}')
        model.train()
        for images, captions, caplen in tqdm(train_loader, desc="Train\t"):
            images = images.to(device)
            captions = captions.to(device)
           
            output = model(images, captions[:-1]) 
            loss = criterion(output.reshape(-1, output.shape[2]), captions.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
            total_train +=  1

        scheduler.step()

        model.eval()
        total_val = 0
        running_val_loss = 0.0
        epoch_bleu = []
        with torch.no_grad():
            for images, captions, _ in tqdm(val_loader, desc="Validate\t"):
                images = images.to(device)
                captions = captions.to(device)

                output = model(images, captions[:-1])

                ref = captions.permute(1,0)
                output_prob = output.permute(1,0,2)
                cand = torch.argmax(output_prob, dim=2)
                batch_bleu = []
                for i in range(cand.size(0)):  
                    candidate_sentence = [dataset.vocab.itos[word.item()] for word in cand[i]]
                    reference_sentence = [dataset.vocab.itos[word.item()] for word in ref[i]]
                    bleu = bleu_score([candidate_sentence], [[reference_sentence]], max_n=4, weights=[0.20, 0.25, 0.35, 0.20])
                    batch_bleu.append(bleu)

                running_val_loss += loss.item()
                total_val += 1
                epoch_bleu.append(sum(batch_bleu) / len(batch_bleu))

        if epoch_bleu[-1] > best_bleu:
            best_bleu = epoch_bleu[-1]
            torch.save(model.state_dict(), '../models/model.pth')

        print(f'Train Loss: {(running_train_loss/total_train):.4f}, Validation Loss: {(running_val_loss/total_val):.4f}, BLEU Score: {(sum(epoch_bleu) / len(epoch_bleu))* 100:.2f}\n')



In [16]:
train(model, train_loader, epochs=EPOCHS, criterion=criterion, optimizer=optimizer, scheduler=scheduler)
print("Training Complete!!")

Epoch: 1


Train	: 100%|██████████| 885/885 [01:42<00:00,  8.67it/s]
Validate	:   0%|          | 0/190 [00:00<?, ?it/s]


UnboundLocalError: local variable 'total' referenced before assignment

In [None]:
import matplotlib.pyplot as plt

torch.Size([32, 4135])


RuntimeError: a Tensor with 132320 elements cannot be converted to Scalar

In [None]:
def denormalize(image):
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    denormalized_image = image * std + mean
    return denormalized_image

In [None]:
def display_test(image, target_caption):
    image = denormalize(image)


images, targets, caplen = next(iter(test_loader))
print(image.shape)

torch.Size([32, 3, 224, 224])
