# Importing Libraries

In [None]:
!pip install torchsummary

import os
import math
import spacy
import torch
import numpy as np
import pandas as pd
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

from PIL import Image
from torch import nn
from torchsummary import summary
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision.models import vgg16, VGG16_Weights
from nltk.translate.bleu_score import sentence_bleu

spacy_eng = spacy.load("en_core_web_sm")

### Setting the Device

In [None]:
# Pick the compute device once; every later cell moves tensors/models to `device`.
cuda_available = torch.cuda.is_available()

if cuda_available:
    # Ask cuDNN to select deterministic algorithms when running on the GPU.
    torch.backends.cudnn.deterministic = True

device = torch.device("cuda:0") if cuda_available else torch.device("cpu")

print('Device:', device)

# Preparing the Flickr30k Dataset
### Preparing the Vocabulary

In [3]:
class Vocabulary:
    """Bidirectional token/index mapping built from a caption corpus.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Regular words receive indices from 4 upwards, in
    the order in which they reach ``freq_threshold`` occurrences.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that only contains the special tokens.

        Args:
            freq_threshold: number of occurrences a word needs before it is
                given its own index by ``build_vocab``.
        """
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {token: index for index, token in self.itos.items()}

    def __len__(self):
        """Return the number of known tokens (special tokens included)."""
        return len(self.itos)

    def build_vocab(self, sentence_list):
        """Register every word that occurs ``freq_threshold`` times.

        Args:
            sentence_list: iterable of captions; each entry is converted with
                ``str`` before it is tokenized.
        """
        occurrences = {}
        next_index = 4

        for raw_sentence in sentence_list:
            for word in self.tokenize(str(raw_sentence)):
                occurrences[word] = occurrences.get(word, 0) + 1

                # A word is registered exactly once: at the moment its count
                # hits the threshold.
                if occurrences[word] != self.freq_threshold:
                    continue

                self.itos[next_index] = word
                self.stoi[word] = next_index
                next_index += 1

    def numericalize(self, sentence):
        """Convert a sentence to a list of indices; unknown tokens map to ``<UNK>``."""
        unknown_index = self.stoi["<UNK>"]

        return [self.stoi.get(token, unknown_index) for token in self.tokenize(sentence)]

    @staticmethod
    def tokenize(sentence):
        """Split a sentence into lower-cased tokens with the spaCy tokenizer."""
        document = spacy_eng.tokenizer(str(sentence))

        return [piece.text.lower() for piece in document]

### Defining a Custom Dataset

In [4]:
class Flickr(Dataset):
    """Image/caption dataset read from a ``|``-delimited captions file.

    Every row of the captions file is one sample: the ``image_name`` column
    names a file under ``root_dir`` and the `` comment`` column (note the
    leading space in the header) holds its caption. A ``Vocabulary`` is built
    from all captions when the dataset is created.
    """

    def __init__(self, root_dir, caption_path, transform, freq_threshold=5, max_seq_length=50):
        """Read the captions file and build the vocabulary.

        Args:
            root_dir: directory that contains the image files.
            caption_path: path of the ``|``-delimited captions file.
            transform: callable applied to every loaded RGB ``PIL`` image.
            freq_threshold: occurrences a word needs to enter the vocabulary.
            max_seq_length: fixed length (including ``<SOS>``/``<EOS>``) of
                every encoded caption.

        Raises:
            ValueError: if ``max_seq_length`` cannot hold ``<SOS>`` and ``<EOS>``.
        """
        if max_seq_length < 2:
            raise ValueError("max_seq_length must be at least 2 to hold <SOS> and <EOS>")

        self.freq_threshold = freq_threshold
        self.max_seq_length = max_seq_length
        self.transform = transform
        self.root_dir = root_dir

        self.df = pd.read_csv(caption_path, delimiter='|')

        self.images = self.df['image_name']
        self.captions = self.df[' comment']

        self.vocab = Vocabulary(freq_threshold)

        self.vocab.build_vocab(self.captions.tolist())

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def _load_image(self, image_name):
        """Open ``image_name`` under ``root_dir`` and return it as an RGB image."""
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(os.path.join(self.root_dir, image_name)) as image:
            return image.convert("RGB")

    def _encode_caption(self, caption):
        """Return ``caption`` as exactly ``max_seq_length`` token indices.

        The sequence is ``<SOS>`` + tokens + ``<EOS>``, right-padded with
        ``<PAD>``. When it is too long, the tail is cut but the final slot is
        still ``<EOS>`` so that every sequence has an end marker.
        """
        stoi = self.vocab.stoi

        tokens = [stoi["<SOS>"]] + self.vocab.numericalize(caption) + [stoi["<EOS>"]]

        if len(tokens) > self.max_seq_length:
            tokens = tokens[:self.max_seq_length - 1] + [stoi["<EOS>"]]

        tokens += [stoi["<PAD>"]] * (self.max_seq_length - len(tokens))

        return tokens

    def __getitem__(self, index):
        """Return ``(transformed_image, caption_tensor)`` for row ``index``."""
        # Positional access keeps working even if the frame's index is not 0..N-1.
        image = self.transform(self._load_image(self.images.iloc[index]))

        numericalized_caption = self._encode_caption(self.captions.iloc[index])

        return image, torch.tensor(numericalized_caption)

    def get_label(self, index):
        """Return the caption of row ``index`` as seen by the model.

        The text is rebuilt from the encoded caption, so out-of-vocabulary
        words appear as ``<UNK>`` and over-long captions are truncated. The
        image is not loaded.
        """
        tokens = self._encode_caption(self.captions.iloc[index])

        label = [self.vocab.itos[token] for token in tokens]

        # Guard against a missing end marker instead of raising ValueError.
        eos_index = label.index('<EOS>') if '<EOS>' in label else len(label)

        return ' '.join(label[1: eos_index])

    def to_list(self):
        """Return all raw captions as a Python list, in row order."""
        return self.captions.tolist()
        

### Defining a Custom Caption Collate Function for Padding

In [5]:
class CapCollat:
    """Collate function that batches images and pads captions to equal length."""

    def __init__(self, pad_seq, batch_first=False):
        """Remember the padding configuration.

        Args:
            pad_seq: value used to fill captions that are shorter than the
                longest caption of the batch.
            batch_first: forwarded to ``pad_sequence``; selects whether the
                padded caption tensor has the batch as its first dimension.
        """
        self.pad_seq = pad_seq
        self.batch_first = batch_first

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into two batch tensors."""
        # Stacking adds the new leading batch dimension in one step.
        image_batch = torch.stack([sample[0] for sample in batch], dim=0)

        caption_batch = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_seq,
        )

        return image_batch, caption_batch

### Loading and Testing the Dataset 

In [6]:
root_folder = "/kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images/"
csv_file = "/kaggle/input/flickr-image-dataset/flickr30k_images/results.csv"

# Resize, centre-crop to 224x224, convert to a tensor and normalize per channel.
transform = T.Compose([
    T.Resize((256, 256), interpolation=T.InterpolationMode.BILINEAR),
    T.CenterCrop((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loader / vocabulary settings.
batch_size = 37
num_workers = 2
freq_threshold = 5
batch_first = True
pin_memory = True

dataset = Flickr(root_folder, csv_file, transform, freq_threshold)
pad_idx = dataset.vocab.stoi["<PAD>"]

# Sequential 90/10 split: the first rows train, the remaining rows validate.
data_size = len(dataset)
train_size = int(0.9 * data_size)
val_size = data_size - train_size

train_set = torch.utils.data.Subset(dataset, range(train_size))
val_set = torch.utils.data.Subset(dataset, range(train_size, data_size))

# Both loaders share every option except shuffling.
loader_options = dict(
    batch_size=batch_size,
    pin_memory=pin_memory,
    num_workers=num_workers,
    collate_fn=CapCollat(pad_seq=pad_idx, batch_first=batch_first),
)

train_loader = DataLoader(train_set, shuffle=True, **loader_options)
val_loader = DataLoader(val_set, shuffle=False, **loader_options)

In [7]:
# Inclusive range of dataset rows that make up the validation split.
# The split above gives rows [train_size, data_size) to validation, so the
# first validation row is data_size - val_size (== train_size); subtracting one
# more would point at the last training row.
val_set_start = data_size - val_size
val_set_end = data_size - 1

# Raw caption strings for every dataset row, in row order.
all_labels = dataset.to_list()

In [None]:
# The dataset returns normalized tensors; undo T.Normalize before plotting so
# matplotlib receives values in [0, 1] instead of clipping them.
display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

for idx in range(0, 100, 10):
    image, _ = dataset[idx + val_size]

    label = all_labels[idx + val_size]

    # (C, H, W) -> de-normalized (H, W, C) array for imshow.
    image = (image * display_std + display_mean).clamp(0, 1).permute(1, 2, 0)

    plt.imshow(image.numpy())
    plt.title(str(label))

    plt.show()

# Pre-Trained CNN Encoder: VGG16 

In [21]:
class EncoderCNN(nn.Module):
    """Frozen pre-trained VGG16 followed by a trainable linear projection."""

    def __init__(self, embed_size=256):
        """Build the encoder.

        Args:
            embed_size: size of the image embedding produced by the final
                linear layer.
        """
        super().__init__()

        backbone = vgg16(weights=VGG16_Weights.IMAGENET1K_FEATURES)

        # Freeze every pre-trained weight; only the projection below is trained.
        backbone.requires_grad_(False)

        # Keep all child modules except the last one (presumably the VGG
        # classifier head); the projection expects 512 * 7 * 7 flattened inputs.
        kept_layers = list(backbone.children())[:-1]

        projection = nn.Linear(512 * 7 * 7, embed_size)

        self.encoder = nn.Sequential(*kept_layers, nn.Flatten(), projection)

    def forward(self, image):
        """Return the embedding of ``image`` produced by the encoder stack."""
        return self.encoder(image)

In [None]:
# torchsummary creates its dummy input on CUDA by default when a GPU is
# available; put the model on the same device and pass it explicitly so the
# cell works on both CPU and GPU runtimes.
summary(EncoderCNN(256).to(device), (3, 224, 224), device=device.type)

# RNN Decoder: Vanilla RNN Module

In [11]:
class DecoderRNN(nn.Module):
    """Vanilla RNN language model whose initial hidden state comes from image features."""

    def __init__(self, embed_size, vocab_size, hidden_size, num_layers):
        """Create the embedding, recurrent, output and feature-projection layers.

        Args:
            embed_size: size of the token embeddings and of the image features.
            vocab_size: number of tokens the decoder can emit.
            hidden_size: size of the RNN hidden state.
            num_layers: number of stacked RNN layers.
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

        self.features2hidden = nn.Linear(embed_size, hidden_size)

    def forward(self, features, captions):
        """Return vocabulary scores for every position of ``captions``.

        Args:
            features: image features, one row per sample.
            captions: token indices, batch first.

        Returns:
            Tensor of scores with a trailing dimension of ``vocab_size``.
        """
        # Every RNN layer starts from the same projection of the image features.
        projected = self.features2hidden(features)
        first_hidden = projected.unsqueeze(0).expand(self.rnn.num_layers, -1, -1).contiguous()

        token_vectors = self.embedding(captions)

        hidden_sequence, _ = self.rnn(token_vectors, first_hidden)

        return self.linear(hidden_sequence)


In [68]:
class ImageCap(nn.Module):
    """Image captioning model: CNN encoder followed by an RNN decoder."""

    def __init__(self, embed_size, vocab_size, hidden_size, num_layers):
        """Build the encoder and the decoder.

        Args:
            embed_size: size of the image embedding and of the token embeddings.
            vocab_size: number of tokens in the vocabulary.
            hidden_size: size of the decoder's hidden state.
            num_layers: number of stacked RNN layers in the decoder.
        """
        super(ImageCap, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, vocab_size, hidden_size, num_layers)

    def forward(self, images, captions):
        """Return the decoder's vocabulary scores for ``captions`` given ``images``."""
        x = self.encoderCNN(images)
        x = self.decoderRNN(x, captions)

        return x

    def caption(self, image, vocabulary, maxlength=50):
        """Greedily generate a caption for a single image.

        Decoding mirrors the decoder's ``forward`` pass: the image features are
        projected by ``features2hidden`` to form the initial hidden state of
        every RNN layer, and the first input token is ``<SOS>``.

        Args:
            image: batch holding exactly one image (``.item()`` is called on
                the prediction, so larger batches are not supported).
            vocabulary: object exposing ``stoi`` and ``itos`` mappings.
            maxlength: maximum number of tokens to generate.

        Returns:
            List of generated token strings; generation stops after ``<EOS>``,
            which is included in the list when it was produced.
        """
        result_caption = []
        decoder = self.decoderRNN

        with torch.no_grad():
            features = self.encoderCNN(image)

            # Same initial state construction as DecoderRNN.forward.
            states = decoder.features2hidden(features).unsqueeze(0).repeat(decoder.rnn.num_layers, 1, 1)

            predicted = torch.full(
                (features.size(0),),
                vocabulary.stoi["<SOS>"],
                dtype=torch.long,
                device=features.device,
            )

            for _ in range(maxlength):
                # One time step: (batch, 1, embed) because the RNN is batch-first.
                x = decoder.embedding(predicted).unsqueeze(1)

                hiddens, states = decoder.rnn(x, states)
                output = decoder.linear(hiddens.squeeze(1))
                predicted = output.argmax(1)
                result_caption.append(predicted.item())

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break

        return [vocabulary.itos[i] for i in result_caption]


# Training and Evaluating the Image Captioning Model: VGG16 + Multi-layer Vanilla RNN

### Setting Hyperparameters

In [47]:
# Number of full passes over the training loader.
num_epochs = 4
# NOTE(review): enc_dim is not referenced by any other cell shown in this notebook.
enc_dim = 2048
# Size of the image embedding and of the token embeddings (passed to ImageCap).
embed_size = 224
# Size of the decoder RNN's hidden state.
hidden_size = 512
# Number of stacked RNN layers in the decoder.
num_layers = 1
# Learning rate handed to the Adam optimizer.
learning_rate = 3e-4

### Setting the Vocabulary

In [48]:
# Reuse the vocabulary that the dataset built from its captions.
vocab = dataset.vocab
# Number of tokens (special tokens included); sizes the decoder's embedding and output layers.
vocab_size = len(vocab)

### Configuring Models

In [69]:
# Cross-entropy over the vocabulary; positions whose target is <PAD> are ignored.
criterion = nn.CrossEntropyLoss(ignore_index = vocab.stoi["<PAD>"]).to(device)

# Encoder + decoder, moved to the selected device.
model = ImageCap(embed_size, vocab_size, hidden_size, num_layers).to(device)

# Adam receives every model parameter, including those with requires_grad=False.
optimizer = optim.Adam(model.parameters(), lr = learning_rate)

### Training Models

In [None]:
!pip install pycocoevalcap

In [18]:
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider

class Scorer():
    """Compute BLEU-1..4, METEOR, ROUGE-L and CIDEr with pycocoevalcap.

    ``ref`` and ``gt`` are dictionaries keyed by sample id. ``ref`` is passed
    as the first argument and ``gt`` as the second argument of every scorer's
    ``compute_score``. Values may be lists of caption strings or, for
    convenience, a single caption string, which is wrapped in a one-element
    list before scoring.
    """

    def __init__(self,ref,gt):
        """Store the caption dictionaries and instantiate the scorers.

        Args:
            ref: dict mapping sample id -> reference caption(s).
            gt: dict mapping sample id -> generated caption(s).
        """
        self.ref = ref
        self.gt = gt

        # Each entry pairs a scorer with the metric name(s) it produces; the
        # BLEU scorer yields four values, hence the list.
        self.word_based_scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Meteor(),"METEOR"),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr"),
            ]

    @staticmethod
    def _as_caption_lists(captions):
        """Return a copy of ``captions`` in which bare strings become one-element lists."""
        return {
            key: [value] if isinstance(value, str) else value
            for key, value in captions.items()
        }

    @staticmethod
    def _empty_scores(extra_keys=()):
        """Return the result dictionary with one empty list per metric."""
        names = ["Bleu1", "Bleu2", "Bleu3", "Bleu4", "METEOR", "ROUGE_L", "CIDEr"]

        return {name: [] for name in [*names, *extra_keys]}

    @staticmethod
    def _record(total_scores, method, score):
        """Append one scorer result to ``total_scores`` under the right key(s)."""
        if isinstance(method, list):
            # BLEU returns one value per n-gram order.
            total_scores["Bleu1"].append(score[0])
            total_scores["Bleu2"].append(score[1])
            total_scores["Bleu3"].append(score[2])
            total_scores["Bleu4"].append(score[3])
        else:
            total_scores[method].append(score)

    def compute_scores(self):
        """Score the whole corpus at once.

        Returns:
            Dict with keys ``Bleu1``..``Bleu4``, ``METEOR``, ``ROUGE_L`` and
            ``CIDEr``; each value is a list holding the corpus-level score.
        """
        refs = self._as_caption_lists(self.ref)
        hyps = self._as_caption_lists(self.gt)

        total_scores = self._empty_scores()

        for scorer, method in self.word_based_scorers:
            score, _ = scorer.compute_score(refs, hyps)
            self._record(total_scores, method, score)

        return total_scores

    def compute_scores_iterative(self):
        """Score every sample separately.

        Returns:
            Dict with the same metric keys as ``compute_scores`` plus a
            ``SPICE`` key that is kept for compatibility and always stays
            empty; each other value holds one score per key of ``ref``.
        """
        refs = self._as_caption_lists(self.ref)
        hyps = self._as_caption_lists(self.gt)

        total_scores = self._empty_scores(extra_keys=("SPICE",))

        for key in refs:
            curr_ref = {key: refs[key]}
            curr_gt = {key: hyps[key]}

            for scorer, method in self.word_based_scorers:
                score, _ = scorer.compute_score(curr_ref, curr_gt)
                self._record(total_scores, method, score)

        return total_scores

In [None]:
# Per-batch losses (training and validation).
train_losses = list()
val_losses = list()

# Per-epoch validation metrics (one value appended per epoch).
cumulative_bleu_scores = list()

cider_scores = list()

meteor_scores = list()

rougel_scores = list()

bleu1_scores = list()
bleu2_scores = list()
bleu3_scores = list()
bleu4_scores = list()

# Tokens that must not count as words when hypotheses are scored.
special_tokens = {"<PAD>", "<SOS>", "<EOS>"}

for epoch in range(num_epochs):
    model.train()

    for batch_idx, (images, captions) in enumerate(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing: the decoder reads tokens 0..T-2 and must predict
        # tokens 1..T-1. Without this shift the target at every step equals the
        # input at that step. Captions are batch-first (see `batch_first`).
        decoder_inputs = captions[:, :-1]
        targets = captions[:, 1:]

        optimizer.zero_grad()

        train_score = model(images, decoder_inputs)

        train_loss = criterion(train_score.reshape(-1, vocab_size), targets.reshape(-1))
        train_losses.append(train_loss.item())

        train_loss.backward()

        optimizer.step()

    model.eval()

    with torch.no_grad():
        refs = {}
        hyps = {}
        epoch_bleu_scores = []

        # Running position inside the (unshuffled) validation subset; it is
        # mapped to the dataset row through `val_set.indices`.
        position = 0

        # Iterating the loader directly gives a fresh pass every epoch.
        for images, captions in val_loader:
            images = images.to(device)
            captions = captions.to(device)

            val_score = model(images, captions[:, :-1])

            val_loss = criterion(val_score.reshape(-1, vocab_size), captions[:, 1:].reshape(-1))

            val_losses.append(val_loss.item())

            for image in images:
                idx = val_set.indices[position]
                position += 1

                val_pred = model.caption(image.unsqueeze(0), vocab)

                candidate_tokens = [token for token in val_pred if token not in special_tokens]

                # Tokenize the reference like the training captions so that
                # casing and punctuation splitting match the hypothesis.
                reference_tokens = vocab.tokenize(all_labels[idx])

                # pycocoevalcap expects lists of caption strings per id.
                hyps[idx] = [' '.join(candidate_tokens)]
                refs[idx] = [' '.join(reference_tokens)]

                # NLTK expects a list of tokenized references and a tokenized hypothesis.
                epoch_bleu_scores.append(
                    sentence_bleu([reference_tokens], candidate_tokens, weights=(0.25, 0.25, 0.25, 0.25))
                )

        # Mean sentence-level BLEU over the validation set for this epoch.
        cumulative_bleu_score = sum(epoch_bleu_scores) / max(len(epoch_bleu_scores), 1)
        cumulative_bleu_scores.append(cumulative_bleu_score)

        # Scorer returns one-element lists keyed "Bleu1".."Bleu4", "METEOR",
        # "ROUGE_L", "CIDEr"; unwrap them to plain numbers.
        metrics = {name: values[0] for name, values in Scorer(refs, hyps).compute_scores().items()}

        bleu1_scores.append(metrics['Bleu1'])
        bleu2_scores.append(metrics['Bleu2'])
        bleu3_scores.append(metrics['Bleu3'])
        bleu4_scores.append(metrics['Bleu4'])

        cider_scores.append(metrics['CIDEr'])

        rougel_scores.append(metrics['ROUGE_L'])

        meteor_scores.append(metrics['METEOR'])

        print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx+1}/{len(train_loader)}] | Training loss: {train_loss.item()} Validation loss: {val_loss.item()} | Cumulative BLEU Score: {cumulative_bleu_score} CIDEr Score: {metrics['CIDEr']} METEOR Score: {metrics['METEOR']} ROUGE_L: {metrics['ROUGE_L']} \n")

        torch.save(model.state_dict(), f'/kaggle/working/ImageCap_{epoch+1}.pth')


### Plotting Losses and Scores

In [None]:
# One entry per figure: (figure number, y-axis label, output file, curves).
# Each curve is a (values, legend label) pair.
figure_specs = [
    (
        0,
        'Cross Entropy Loss',
        f'/kaggle/working/{freq_threshold}_{batch_size}_{hidden_size}_{num_epochs}_losses.png',
        [
            (train_losses, 'Training loss'),
            (val_losses, 'Validation loss'),
        ],
    ),
    (
        1,
        'BLEU Scores',
        f'/kaggle/working/{freq_threshold}_{batch_size}_{hidden_size}_{num_epochs}_bleu_scores.png',
        [
            (bleu1_scores, 'BLEU 1'),
            (bleu2_scores, 'BLEU 2'),
            (bleu3_scores, 'BLEU 3'),
            (bleu4_scores, 'BLEU 4'),
        ],
    ),
    (
        2,
        'Scores',
        f'/kaggle/working/{freq_threshold}_{batch_size}_{hidden_size}_{num_epochs}_scores.png',
        [
            (cumulative_bleu_scores, 'Cumulative BLEU SCORE'),
            (cider_scores, 'CIDEr SCORE'),
            (meteor_scores, 'METEOR SCORE'),
            (rougel_scores, 'ROUGE_L SCORE'),
        ],
    ),
]

# Draw every figure, then save it next to the model checkpoints.
for figure_number, y_label, output_path, curves in figure_specs:
    plt.figure(figure_number)

    for values, legend_label in curves:
        plt.plot(values, label = legend_label)

    plt.ylabel(y_label)
    plt.legend()
    plt.savefig(output_path)