# Importing Libraries

In [1]:
import os
import math
import spacy
import torch
import numpy as np
import pandas as pd
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

from PIL import Image
from torch import nn
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision.models import vgg16, VGG16_Weights
from nltk.translate.bleu_score import sentence_bleu

# English spaCy pipeline; Vocabulary.tokenize below only uses its `.tokenizer`.
spacy_eng = spacy.load("en_core_web_sm")

# Preparing the Flickr30k Dataset
### Preparing the Vocabulary

In [2]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from a caption corpus.

    Indices 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``.
    Every other word is assigned the next free index at the moment its running
    count reaches ``freq_threshold``.
    """

    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        return len(self.itos)

    def build_vocab(self, sentence_list):
        """Count tokens over ``sentence_list`` and register words that reach the threshold.

        Non-string entries (e.g. NaN from a malformed CSV row) are coerced with
        ``str()`` before tokenizing.
        """
        freqs = {}
        idx = 4  # first index after the four reserved special tokens

        for sentence in sentence_list:
            for word in self.tokenize(str(sentence)):
                freqs[word] = freqs.get(word, 0) + 1

                # Register exactly once: the first time the count hits the threshold.
                if freqs[word] == self.freq_threshold:
                    self.itos[idx] = word
                    self.stoi[word] = idx
                    idx += 1

    def numericalize(self, sentence):
        """Map ``sentence`` to a list of token ids; unknown words become ``<UNK>``.

        The input is coerced with ``str()`` exactly as in ``build_vocab``, so a
        caption that was countable is also encodable (a NaN caption no longer
        crashes the tokenizer).
        """
        unk_id = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_id) for token in self.tokenize(str(sentence))]

    @staticmethod
    def tokenize(sentence):
        """Lower-cased spaCy tokens of ``sentence``."""
        return [token.text.lower() for token in spacy_eng.tokenizer(sentence)]

### Defining a Custom Dataset

In [3]:
class Flickr(Dataset):
    """Flickr30k caption dataset: one item per CSV caption row.

    Args:
        root_dir: directory containing the image files named in ``image_name``.
        caps: path to the ``|``-delimited captions CSV (columns ``image_name``
            and `` comment`` — note the leading space in the latter).
        transforms: optional callable applied to the PIL RGB image.
        freq_threshold: minimum count for a word to enter the vocabulary.

    ``__getitem__`` returns ``(image, ids)`` where ``ids`` is a 1-D tensor of
    ``<SOS>`` + caption token ids + ``<EOS>``.
    """

    def __init__(self, root_dir, caps, transforms=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(caps, delimiter='|')
        self.transforms = transforms

        self.img_pts = self.df['image_name']
        self.caps = self.df[' comment']

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.caps.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_path = os.path.join(self.root_dir, self.img_pts[idx])
        with Image.open(image_path) as raw_image:
            img = raw_image.convert('RGB')

        if self.transforms is not None:
            img = self.transforms(img)

        stoi = self.vocab.stoi
        token_ids = [stoi["<SOS>"], *self.vocab.numericalize(self.caps[idx]), stoi["<EOS>"]]
        return img, torch.tensor(token_ids)

### Defining a Custom Caption Collate Function for Padding

In [4]:
class CapCollat:
    """DataLoader ``collate_fn`` that batches images and pads variable-length captions.

    Args:
        pad_seq: token id used to pad shorter captions.
        batch_first: if True captions come out as (batch, seq_len), else (seq_len, batch).
    """

    def __init__(self, pad_seq, batch_first=False):
        self.pad_seq = pad_seq
        self.batch_first = batch_first

    def __call__(self, batch):
        # Stacking along a new leading axis equals unsqueeze(0) + cat(dim=0).
        stacked_images = torch.stack([sample[0] for sample in batch], dim=0)
        padded_captions = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_seq,
        )
        return stacked_images, padded_captions

### Loading and Testing the Dataset 

In [70]:
root_folder = "/kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images/"
csv_file = "/kaggle/input/flickr-image-dataset/flickr30k_images/results.csv"

# ImageNet mean/std normalisation, matching the pretrained VGG16 encoder.
transforms = T.Compose([
                        T.Resize((224,224)),
                        T.ToTensor(),
                        T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
                       ])
batch_size = 32
num_workers = 2
batch_first = True
pin_memory = True
split_seed = 42  # seeds the train/val split so it is identical on every run

dataset = Flickr(root_folder, csv_file, transforms)
pad_idx = dataset.vocab.stoi["<PAD>"]

data_size = len(dataset)

# results.csv has one row per caption (five per image). Splitting rows at random
# would place the same image in both train and val, so split the unique images
# 90/10 and then take every caption row belonging to each image.
unique_images = dataset.img_pts.unique()
split_generator = torch.Generator().manual_seed(split_seed)
image_order = torch.randperm(len(unique_images), generator=split_generator).tolist()
n_train_images = int(0.9 * len(unique_images))
train_images = {unique_images[i] for i in image_order[:n_train_images]}

train_indices = [i for i, name in enumerate(dataset.img_pts) if name in train_images]
val_indices = [i for i, name in enumerate(dataset.img_pts) if name not in train_images]
train_size = len(train_indices)
val_size = len(val_indices)

# Subset keeps `.indices`, which maps loader positions back to dataset rows.
train_set = torch.utils.data.Subset(dataset, train_indices)
val_set = torch.utils.data.Subset(dataset, val_indices)

cap_collate = CapCollat(pad_seq=pad_idx, batch_first=batch_first)

dataset_loader = DataLoader(dataset,
                            batch_size=batch_size,
                            pin_memory=pin_memory,
                            num_workers=num_workers,
                            collate_fn=cap_collate)

train_loader = DataLoader(train_set,
                            batch_size=batch_size,
                            pin_memory=pin_memory,
                            num_workers=num_workers,
                            shuffle=True,
                            collate_fn=cap_collate)

# shuffle=False: the validation loop relies on loader order == val_set.indices order.
val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            pin_memory=pin_memory,
                            num_workers=num_workers,
                            shuffle=False,
                            collate_fn=cap_collate)

In [None]:
val_set_start = data_size - val_size - 1
val_set_end = data_size - 1

all_labels = list()
dict_tokens = dict()

# Reference caption per dataset row, keyed by row index. Each caption is mapped
# through dataset.vocab (unknown words -> "<UNK>") and re-joined, giving the same
# string as decoding the loader's <SOS> ... <EOS> tensor — but without loading,
# resizing and normalising every image just to read its caption back.
itos = dataset.vocab.itos
for idx, raw_caption in enumerate(dataset.caps):
    token_ids = dataset.vocab.numericalize(str(raw_caption))
    caption_label = ' '.join(itos[token] for token in token_ids)

    dict_tokens[idx] = [caption_label]
    all_labels.append(caption_label)

In [None]:
# Per-channel stats from the Normalize transform, shaped to broadcast over (C, H, W).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

images, captions = next(iter(val_loader))

# zip over the actual batch: the first batch can be smaller than batch_size.
for img, cap in zip(images, captions):
    caption_label = [dataset.vocab.itos[token] for token in cap.tolist()]

    eos_index = caption_label.index('<EOS>')

    # Drop <SOS> and everything from <EOS> onward (including padding).
    caption_label = ' '.join(caption_label[1:eos_index])

    # Undo T.Normalize so imshow gets RGB values in [0, 1] instead of clipped z-scores.
    img = (img * IMAGENET_STD + IMAGENET_MEAN).clamp(0, 1).permute(1, 2, 0)
    plt.imshow(img)
    plt.title(caption_label)

    plt.show()

# Pre-Trained CNN Encoder: VGG16 

In [7]:
class EncoderCNN(nn.Module):
    """Frozen VGG16 convolutional trunk plus a trainable per-position projection.

    Args:
        enc_dim: channel count of the VGG feature map (input width of ``V_affine``).
        hidden_size: output width of each projected spatial position.

    ``forward(image)`` returns a (batch, H*W, hidden_size) tensor: one ReLU-ed
    projected vector per spatial cell of the feature map, in row-major (H, W) order.
    """

    def __init__(self, enc_dim, hidden_size):
        super(EncoderCNN, self).__init__()

        backbone = vgg16(weights=VGG16_Weights.IMAGENET1K_FEATURES)
        # Keep all but the last two children — assuming torchvision's VGG layout
        # (features, avgpool, classifier), this leaves only the conv stack.
        self.vgg = nn.Sequential(*list(backbone.children())[:-2])

        self.V_affine = nn.Linear(enc_dim, hidden_size)

        self.disable_learning()

    def forward(self, image):
        feature_map = self.vgg(image)                 # (B, C, H, W)
        positions = feature_map.flatten(start_dim=2)  # (B, C, H*W)
        positions = positions.transpose(1, 2)         # (B, H*W, C)
        return F.relu(self.V_affine(positions))

    def disable_learning(self):
        """Freeze the VGG trunk; only ``V_affine`` stays trainable."""
        for weight in self.vgg.parameters():
            weight.requires_grad_(False)

# RNN Decoder: Vanilla RNN Module

In [8]:
class DecoderRNN(nn.Module):
    """Multi-layer vanilla-RNN caption decoder (time-major).

    ``forward(features, caption)``:
        features: (batch, embed_size) image vector, fed as time step 0.
        caption: (seq_len, batch) token ids — time-major, because ``nn.RNN`` is
            built without ``batch_first``.
    Returns (seq_len + 1, batch, vocab_size) logits.
    """

    def __init__(self, embed_size,vocab_size, hidden_size, num_layers):
        super(DecoderRNN, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, caption):
        word_steps = self.dropout(self.embedding(caption))
        image_step = features.unsqueeze(0)
        rnn_input = torch.cat([image_step, word_steps], dim=0)
        hidden_seq = self.rnn(rnn_input)[0]
        return self.linear(hidden_seq)

In [None]:
class Hybrid(nn.Module):
    """VGG16 encoder + vanilla-RNN decoder captioning model.

    The encoder yields one ``embed_size`` vector per spatial position; these are
    mean-pooled into a single (batch, embed_size) image vector, which the
    time-major decoder consumes as its first step.

    Args:
        embed_size: word-embedding width (and pooled image-feature width).
        vocab_size: number of output token classes.
        hidden_size: RNN hidden width.
        num_layers: number of stacked RNN layers.
        enc_dim: channel count of the encoder feature map; 512 for VGG16's conv trunk.
    """

    def __init__(self, embed_size, vocab_size, hidden_size, num_layers, enc_dim=512):
        super(Hybrid, self).__init__()
        # EncoderCNN needs both its input width and output width; the output must
        # equal embed_size so the pooled vector fits the decoder's input.
        self.encoderCNN = EncoderCNN(enc_dim, embed_size)
        self.decoderRNN = DecoderRNN(embed_size, vocab_size, hidden_size, num_layers)

    def _encode(self, images):
        """Encode (batch, 3, H, W) images into pooled (batch, embed_size) vectors."""
        return self.encoderCNN(images).mean(dim=1)

    def forward(self, images, caption):
        """Return decoder logits; ``caption`` is (seq_len, batch) token ids."""
        features = self._encode(images)
        return self.decoderRNN(features, caption)

    def captionImage(self, image, vocabulary, maxlength=50):
        """Greedy-decode a caption for ONE image.

        Args:
            image: (3, H, W) or (1, 3, H, W) tensor.
            vocabulary: object with an ``itos`` mapping from id to word.
            maxlength: maximum number of generated tokens.

        Returns the list of generated words, ending with ``"<EOS>"`` if it was produced.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)

        result_caption = []

        with torch.no_grad():
            x = self._encode(image).unsqueeze(0)  # (1, 1, embed_size): one time step
            states = None

            for _ in range(maxlength):
                hiddens, states = self.decoderRNN.rnn(x, states)
                output = self.decoderRNN.linear(hiddens.squeeze(0))
                predicted = output.argmax(1)
                result_caption.append(predicted.item())
                # Feed back the embedding of the chosen token id (not the logits).
                x = self.decoderRNN.embedding(predicted).unsqueeze(0)

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break

        return [vocabulary.itos[i] for i in result_caption]

    # snake_case alias used by the training loop.
    caption_image = captionImage

# Training and Evaluating the Image Captioning Model: VGG16 + Multi-layer Vanilla RNN

### Setting the Device to GPU

In [9]:
# Pick the first GPU when present; with CUDA, also ask cuDNN for deterministic kernels.
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.backends.cudnn.deterministic = True

device = torch.device("cuda:0" if cuda_available else "cpu")

print('Device:', device)

Device: cpu


### Setting Hyperparameters

In [10]:
num_epochs = 10
freq_threshold = 5  # should match the Flickr default (5) used when `dataset` was built
enc_dim = 512  # channels in VGG16's final conv feature map = EncoderCNN input width
embed_size = 300
hidden_size = 512
num_layers = 2
learning_rate = 3e-4

### Building the Vocabulary

In [11]:
# Reuse the vocabulary the dataset was numericalized with. Rebuilding one from
# `all_labels` (captions already mapped to "<UNK>" and re-joined) re-tokenizes the
# literal "<UNK>" text as ordinary words, which can add entries and shift indices.
# Predicted ids would then no longer line up with the training targets.
vocab = dataset.vocab
vocab_size = len(vocab)

### Configuring Models

In [None]:

# Padding positions (the id the collate function pads with) contribute no loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

# Keyword arguments so the sizes cannot be transposed: Hybrid's signature is
# (embed_size, vocab_size, hidden_size, num_layers). The output layer must cover
# exactly the ids produced by dataset.vocab, which numericalizes the targets.
model = Hybrid(
    embed_size=embed_size,
    vocab_size=len(dataset.vocab),
    hidden_size=hidden_size,
    num_layers=num_layers,
).to(device=device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

### Training Models

In [None]:
!pip install pycocoevalcap

In [14]:
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider

class Scorer():
    """Caption metrics (BLEU-1..4, METEOR, ROUGE-L, CIDEr) via pycocoevalcap scorers.

    Args:
        ref: dict id -> list of caption strings; passed FIRST to every
            ``compute_score`` call.
        gt: dict with the same ids; passed SECOND. NOTE: the training loop passes
            its hypotheses here, so despite the name this is the candidate side.
    """

    def __init__(self,ref,gt):
        self.ref = ref
        self.gt = gt

        # (scorer, metric name) pairs. Bleu(4) reports four n-gram scores at once,
        # which is why its "name" is a list.
        self.word_based_scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Meteor(), "METEOR"),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr"),
        ]

    @staticmethod
    def _record(totals, method, score):
        """Append one scorer result to ``totals``; a BLEU result fans out to four keys."""
        if type(method) is list:
            for position, name in enumerate(("Bleu1", "Bleu2", "Bleu3", "Bleu4")):
                totals[name].append(score[position])
        else:
            totals[method].append(score)

    def compute_scores(self):
        """Score the whole corpus once per metric; every key maps to a one-element list."""
        totals = {name: [] for name in ("Bleu1", "Bleu2", "Bleu3", "Bleu4", "METEOR", "ROUGE_L", "CIDEr")}

        for scorer, method in self.word_based_scorers:
            corpus_score, _ = scorer.compute_score(self.ref, self.gt)
            self._record(totals, method, corpus_score)

        return totals

    def compute_scores_iterative(self):
        """Score each id separately; every key maps to one entry per id. "SPICE" stays empty."""
        totals = {
            name: []
            for name in ("Bleu1", "Bleu2", "Bleu3", "Bleu4", "METEOR", "ROUGE_L", "CIDEr", "SPICE")
        }

        for key in self.ref:
            single_ref = {key: self.ref[key]}
            single_gt = {key: self.gt[key]}

            for scorer, method in self.word_based_scorers:
                item_score, _ = scorer.compute_score(single_ref, single_gt)
                self._record(totals, method, item_score)

        return totals

In [None]:
# Tokens stripped from generated captions before scoring.
special_tokens = {"<PAD>", "<SOS>", "<EOS>"}


def _time_major(caps):
    """Return captions as (seq_len, batch); the loaders emit (batch, seq_len) when batch_first."""
    return caps.t() if batch_first else caps


# One entry per epoch in every list below, so all curves share the same x-axis.
train_losses = list()
val_losses = list()

cumulative_bleu_scores = list()

cider_scores = list()

meteor_scores = list()

rougel_scores = list()

bleu1_scores = list()
bleu2_scores = list()
bleu3_scores = list()
bleu4_scores = list()

for epoch in range(num_epochs):
    # ---------------- training ----------------
    model.train()
    running_train_loss = 0.0

    for imgs, captions in train_loader:
        imgs = imgs.to(device)
        captions = _time_major(captions.to(device))

        # The decoder prepends the image feature as step 0, so feeding tokens
        # 0..T-2 yields T outputs that align with the full target sequence 0..T-1.
        train_score = model(imgs, captions[:-1])
        train_loss = criterion(train_score.reshape(-1, train_score.shape[2]), captions.reshape(-1))

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        running_train_loss += train_loss.item()

    train_losses.append(running_train_loss / max(len(train_loader), 1))

    # ---------------- validation ----------------
    model.eval()

    hyp = {}
    refs = {}
    epoch_bleu = []
    running_val_loss = 0.0
    # Position in val_set.indices; valid because val_loader iterates without shuffling.
    position = 0

    with torch.no_grad():
        for val_imgs, val_captions in val_loader:
            val_imgs = val_imgs.to(device)
            val_caps = _time_major(val_captions.to(device))

            val_score = model(val_imgs, val_caps[:-1])
            val_loss = criterion(val_score.reshape(-1, val_score.shape[2]), val_caps.reshape(-1))
            running_val_loss += val_loss.item()

            for img in val_imgs:
                # Dataset row index of this sample -> key into dict_tokens.
                key = val_set.indices[position]
                position += 1

                # Decode with the same vocabulary that produced the training targets.
                val_pred = model.captionImage(img.unsqueeze(0), dataset.vocab)
                pred_tokens = [tok for tok in val_pred if tok not in special_tokens]

                hyp[key] = [' '.join(pred_tokens)]
                refs[key] = dict_tokens[key]

                # sentence_bleu expects lists of tokens, not raw strings.
                reference_tokens = [ref.split() for ref in refs[key]]
                epoch_bleu.append(
                    sentence_bleu(reference_tokens, pred_tokens, weights=(0.25, 0.25, 0.25, 0.25))
                )

    val_losses.append(running_val_loss / max(len(val_loader), 1))

    cumulative_bleu_score = sum(epoch_bleu) / len(epoch_bleu) if epoch_bleu else 0.0
    cumulative_bleu_scores.append(cumulative_bleu_score)

    # Scorer.compute_scores returns one-element lists under keys
    # Bleu1..Bleu4, METEOR, ROUGE_L and CIDEr.
    metrics = Scorer(refs, hyp).compute_scores()

    bleu1_scores.append(metrics['Bleu1'][0])
    bleu2_scores.append(metrics['Bleu2'][0])
    bleu3_scores.append(metrics['Bleu3'][0])
    bleu4_scores.append(metrics['Bleu4'][0])

    cider_scores.append(metrics['CIDEr'][0])

    rougel_scores.append(metrics['ROUGE_L'][0])

    meteor_scores.append(metrics['METEOR'][0])

    print(
        f"Epoch [{epoch+1}/{num_epochs}] | Training loss: {train_losses[-1]:.4f} "
        f"Validation loss: {val_losses[-1]:.4f} | Cumulative BLEU Score: {cumulative_bleu_score:.4f} "
        f"CIDEr Score: {metrics['CIDEr'][0]:.4f} METEOR Score: {metrics['METEOR'][0]:.4f} "
        f"ROUGE_L: {metrics['ROUGE_L'][0]:.4f} \n"
    )

# if batch_idx == (len(train_loader) - 1):
#     torch.save(model.state_dict(), f'/kaggle/working/encoder_{freq_threshold}_{batch_size}_{hidden_size}_{epoch+1}.pth')


### Plotting Losses and Scores

In [None]:
def _plot_curves(figure_num, curves, ylabel, out_path):
    """Plot each ``(values, label)`` pair on figure ``figure_num``, add a legend, and save it."""
    plt.figure(figure_num)
    for values, label in curves:
        plt.plot(values, label=label)
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(out_path)


_plot_curves(
    0,
    [(train_losses, 'Training loss'), (val_losses, 'Validation loss')],
    'Cross Entropy Loss',
    f'/kaggle/working/{freq_threshold}_{batch_size}_{hidden_size}_{num_epochs}_losses.png',
)

_plot_curves(
    1,
    [
        (bleu1_scores, 'BLEU 1'),
        (bleu2_scores, 'BLEU 2'),
        (bleu3_scores, 'BLEU 3'),
        (bleu4_scores, 'BLEU 4'),
    ],
    'BLEU Scores',
    f'/kaggle/working/{freq_threshold}_{batch_size}_{hidden_size}_{num_epochs}_bleu_scores.png',
)

_plot_curves(
    2,
    [
        (cumulative_bleu_scores, 'Cumulative BLEU SCORE'),
        (cider_scores, 'CIDEr SCORE'),
        (meteor_scores, 'METEOR SCORE'),
        (rougel_scores, 'ROUGE_L SCORE'),
    ],
    'Scores',
    f'/kaggle/working/{freq_threshold}_{batch_size}_{hidden_size}_{num_epochs}_scores.png',
)