# Importing Libraries

In [1]:
import os
import math
import spacy
import torch
import numpy as np
import pandas as pd
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

from PIL import Image
from torch import nn
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchvision.models import vgg16, VGG16_Weights
from nltk.translate.bleu_score import sentence_bleu

spacy_eng = spacy.load("en_core_web_sm")



# Preparing the Flickr30k Dataset
### Preparing the Vocabulary

In [3]:
class Vocabulary:
    """Bidirectional token <-> index mapping built from caption text.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Every other token is added the moment its running
    count reaches ``freq_threshold`` during :meth:`build_vocab`.

    Args:
        freq_threshold: number of occurrences a token needs before it is
            given its own index.
    """

    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        """Return the number of entries (special tokens included)."""
        return len(self.itos)

    def build_vocab(self, sentence_list):
        """Count tokens over ``sentence_list`` and register the frequent ones.

        Each element is coerced with ``str()`` before tokenizing, so
        non-string entries (e.g. a NaN read from a malformed CSV row) do not
        crash the tokenizer.
        """
        freqs = {}
        # Continue after the highest existing index so a second call cannot
        # overwrite entries added by an earlier one.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenize(str(sentence)):
                freqs[word] = freqs.get(word, 0) + 1

                if freqs[word] == self.freq_threshold and word not in self.stoi:
                    self.itos[idx] = word
                    self.stoi[word] = idx
                    idx += 1

    def numericalize(self, sentence):
        """Convert ``sentence`` to a list of indices; unknown tokens map to ``<UNK>``.

        The input is coerced with ``str()`` exactly as in :meth:`build_vocab`,
        so any value that was accepted while building can also be encoded.
        """
        unk = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk) for token in self.tokenize(str(sentence))]

    @staticmethod
    def tokenize(sentence):
        """Lower-case ``sentence`` and split it with the module-level spaCy tokenizer."""
        return [token.text.lower() for token in spacy_eng.tokenizer(sentence)]

### Defining a Custom Dataset

In [4]:
class Flickr(Dataset):
    """Image/caption dataset backed by a ``|``-delimited caption table.

    Args:
        root_dir: directory that contains the image files.
        caps: path of the caption CSV; the columns ``image_name`` and
            `` comment`` (with the leading space) are read from it.
        transforms: optional callable applied to each loaded PIL image.
        freq_threshold: forwarded to :class:`Vocabulary`.
    """

    def __init__(self, root_dir, caps, transforms=None, freq_threshold=5):
        self.root_dir = root_dir
        self.transforms = transforms

        self.df = pd.read_csv(caps, delimiter='|')
        self.img_pts = self.df['image_name']
        self.caps = self.df[' comment']

        # The vocabulary is built once, over every caption in the table.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.caps.tolist())

    def __len__(self):
        """One item per caption row."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, caption_ids)`` for row ``idx``.

        ``caption_ids`` is a 1-D tensor of vocabulary indices wrapped in
        ``<SOS>`` ... ``<EOS>``.
        """
        caption_text = self.caps[idx]
        image_path = os.path.join(self.root_dir, self.img_pts[idx])
        img = Image.open(image_path).convert('RGB')

        if self.transforms is not None:
            img = self.transforms(img)

        stoi = self.vocab.stoi
        token_ids = [
            stoi["<SOS>"],
            *self.vocab.numericalize(caption_text),
            stoi["<EOS>"],
        ]

        return img, torch.tensor(token_ids)

### Defining a Custom Caption Collate Function for Padding

In [5]:
class CapCollat:
    """Collate callable: stacks images and pads captions to a common length.

    Args:
        pad_seq: value used to pad the shorter caption tensors.
        batch_first: forwarded to ``pad_sequence``; selects whether the padded
            caption tensor is laid out ``(batch, time)`` or ``(time, batch)``.
    """

    def __init__(self, pad_seq, batch_first=False):
        self.pad_seq = pad_seq
        self.batch_first = batch_first

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into two batched tensors."""
        images = [sample[0] for sample in batch]
        captions = [sample[1] for sample in batch]

        # Stacking along a new leading axis gives the batch dimension.
        stacked_images = torch.stack(images, dim=0)
        padded_captions = pad_sequence(
            captions,
            batch_first=self.batch_first,
            padding_value=self.pad_seq,
        )
        return stacked_images, padded_captions

### Loading and Testing the Dataset 

In [7]:
# --- Data-loading configuration -------------------------------------------
batch_size = 32
root_folder = "/kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images/"
csv_file = "/kaggle/input/flickr-image-dataset/flickr30k_images/results.csv"

transforms = T.Compose([
                        T.Resize((224,224)),
                        T.ToTensor(),
                        T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
                       ])

num_workers = 2
batch_first = True
pin_memory = True
# Fixed seed for the train/val/test split: without it every kernel restart
# reshuffles which captions land in which subset, so saved checkpoints would
# later be evaluated on rows they were trained on.
split_seed = 42

dataset = Flickr(root_folder, csv_file, transforms)
pad_idx = dataset.vocab.stoi["<PAD>"]

# 90% / 5% / remainder, so the three sizes always add up to len(dataset).
train_size = int(0.9 * len(dataset))
val_size = int(0.05 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_set, val_set, test_set = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)


def make_loader(subset, shuffle):
    """Build a DataLoader over ``subset`` using the shared settings above."""
    return DataLoader(subset,
                      batch_size=batch_size,
                      pin_memory=pin_memory,
                      num_workers=num_workers,
                      shuffle=shuffle,
                      collate_fn=CapCollat(pad_seq=pad_idx, batch_first=batch_first))


# Only the training loader is shuffled; validation and test keep a fixed order.
train_loader = make_loader(train_set, shuffle=True)
val_loader = make_loader(val_set, shuffle=False)
test_loader = make_loader(test_set, shuffle=False)

In [None]:
# Build the per-image reference structures from the caption table held by the
# `Flickr` dataset. `dataset` is a `Flickr` instance, not a dict, so it cannot
# be indexed with 'images' / 'sentences'.
#
#   all_labels  : one token list per caption row
#   dict_tokens : image index -> list of token lists (all captions of that image)
#   imgs_idx    : image file name -> image index
#   idx_imgs    : image index -> image file name
#
# Image indices are assigned in order of first appearance in the caption table.
all_labels = list()

dict_tokens = dict()
imgs_idx = dict()
idx_imgs = dict()

for filename, caption in zip(dataset.img_pts, dataset.caps):
    # str() mirrors Vocabulary.build_vocab, which also coerces each caption.
    tokens = Vocabulary.tokenize(str(caption))
    all_labels.append(tokens)

    if filename not in imgs_idx:
        idx = len(imgs_idx)
        imgs_idx[filename] = idx
        idx_imgs[idx] = filename
        dict_tokens[idx] = []

    dict_tokens[imgs_idx[filename]].append(tokens)

In [None]:
# Show the first test batch with its ground-truth captions as a sanity check.
dataitr = iter(test_loader)
batch = next(dataitr)

images, captions = batch

# Iterate captions sample by sample regardless of how the collate function
# laid them out.
if not batch_first:
    captions = captions.t()

# Undo the T.Normalize step of the preprocessing pipeline so the pictures are
# displayed with their real colours instead of being clipped by imshow.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# zip() stops at the actual batch length, which may be smaller than batch_size.
for img, cap in zip(images, captions):
    caption_label = [dataset.vocab.itos[token] for token in cap.tolist()]
    eos_index = caption_label.index('<EOS>')

    # Drop the leading <SOS> and everything from <EOS> on (padding included).
    caption_label = caption_label[1:eos_index]
    caption_label = ' '.join(caption_label)

    img = (img * norm_std + norm_mean).clamp(0, 1)
    img = img.permute(1,2,0)  # (C, H, W) -> (H, W, C) for matplotlib
    plt.imshow(img.numpy())
    plt.title(caption_label)

    plt.show()

# Pre-Trained CNN Encoder: VGG16 

In [None]:
class EncoderCNN(nn.Module):
    """Frozen VGG16 feature extractor followed by two trainable projections.

    Args:
        enc_dim: channel count of the feature map produced by the backbone;
            it is the input width of both linear layers.
        embed_size: output width of the global-feature projection.
        hidden_size: output width of the per-location projection.
    """

    def __init__(self, enc_dim, embed_size, hidden_size):
        super().__init__()

        backbone = vgg16(weights=VGG16_Weights.IMAGENET1K_FEATURES)
        # Keep every child module of the backbone except the last two.
        kept_children = list(backbone.children())[:-2]

        self.vgg = nn.Sequential(*kept_children)
        self.avgpool = nn.AvgPool2d(7)
        self.V_affine = nn.Linear(enc_dim, hidden_size)
        self.vg_affine = nn.Linear(enc_dim, embed_size)

        self.disable_learning()

    def forward(self, images):
        """Encode ``images``.

        Returns:
            A pair ``(enc_image, global_features)``: the projected feature map
            flattened to ``(batch, H*W, hidden_size)`` and the projected
            average-pooled feature of shape ``(batch, embed_size)``.
        """
        feature_map = self.vgg(images)
        n_images, channels, height, width = feature_map.shape

        # Global descriptor: average over the spatial grid, then project.
        pooled = self.avgpool(feature_map).view(n_images, -1)
        global_features = F.relu(self.vg_affine(pooled))

        # Local descriptors: channels-last, one row per spatial location.
        locations = feature_map.permute(0, 2, 3, 1)
        locations = locations.view(n_images, height * width, channels)
        enc_image = F.relu(self.V_affine(locations))

        return enc_image, global_features

    def disable_learning(self):
        """Freeze every parameter of the VGG backbone."""
        self.vgg.requires_grad_(False)

# RNN Decoder: Vanilla RNN Module

In [None]:
class DecoderRNN(nn.Module):
    """Caption decoder: embedding -> (multi-layer) vanilla RNN -> vocabulary logits.

    Args:
        embed_size: width of the word embeddings and of the image feature
            that is fed as the first sequence element.
        vocab_size: number of tokens in the vocabulary.
        hidden_size: width of the RNN hidden state.
        num_layers: number of stacked RNN layers.
    """

    def __init__(self, embed_size,vocab_size, hidden_size, num_layers):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, caption):
        """Return logits for the sequence ``[features, caption tokens...]``.

        ``features`` gets a new leading axis and is concatenated in front of
        the (dropout-regularised) caption embeddings along dim 0, so the
        output has one more step along dim 0 than ``caption``.
        """
        word_vectors = self.dropout(self.embedding(caption))
        image_step = features.unsqueeze(0)
        sequence = torch.cat([image_step, word_vectors], dim=0)

        hidden_states, _ = self.rnn(sequence)
        return self.linear(hidden_states)

# Training and Evaluating the Image Captioning Model: VGG16 + Multi-layer Vanilla RNN

### Setting the Device to GPU

In [None]:
# Pick the first GPU when one is available, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()

if cuda_available:
    # Ask cuDNN for deterministic kernels.
    torch.backends.cudnn.deterministic = True

device = torch.device("cuda:0") if cuda_available else torch.device("cpu")

print('Device:', device)

### Setting Hyperparameters

In [None]:
num_epochs = 10
freq_threshold = 5
# Channel count of the encoder's feature map. EncoderCNN keeps only VGG16's
# convolutional stack, whose last layer has 512 channels; a different value
# makes the encoder's Linear layers fail with a shape error in forward().
enc_dim = 512
embed_size = 300
hidden_size = 512
# Spatial locations in the feature map (7 x 7 for 224 x 224 inputs).
att_dim = 49
train_embed = False
# Learning rates for the encoder projections and for the decoder.
lr_enc = 1e-5
lr_dec = 5e-4

### Building the Vocabulary

In [None]:
# Reuse the vocabulary that `Flickr` built while loading the captions. The
# caption tensors produced by the data loaders are encoded with exactly this
# mapping, so the model and the decoding code must use the same object; a
# second, separately built vocabulary could assign different indices.
vocab = dataset.vocab
vocab_size = len(vocab)

### Configuring Models

In [None]:

# Number of stacked layers of the vanilla RNN decoder.
num_layers = 2

# Padding positions must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index = vocab.stoi["<PAD>"])

enc = EncoderCNN(enc_dim, embed_size, hidden_size).to(device)
# Argument order follows DecoderRNN.__init__(embed_size, vocab_size, hidden_size, num_layers).
dec = DecoderRNN(embed_size, vocab_size, hidden_size, num_layers).to(device)

# One optimiser with two parameter groups: the decoder at lr_dec and the
# encoder's trainable projections at lr_enc. The VGG backbone is frozen and is
# filtered out through requires_grad. Without the second group the encoder
# projections would stay at their random initial values for the whole run.
optim_dec = optim.Adam(
    [
        {'params': dec.parameters(), 'lr': lr_dec},
        {'params': [p for p in enc.parameters() if p.requires_grad], 'lr': lr_enc},
    ],
    betas = (0.8, 0.999),
)

### Defining Metrics

In [None]:
import copy
from collections import defaultdict
import pdb

def precook(s, n=4, out=False):
    """Count every n-gram of order 1..n in the whitespace-split sentence ``s``.

    Args:
        s: sentence as a single string.
        n: highest n-gram order to count.
        out: accepted for signature compatibility; not used.

    Returns:
        A ``defaultdict(int)`` mapping each n-gram (a tuple of words) to the
        number of times it occurs.
    """
    tokens = s.split()
    ngram_counts = defaultdict(int)

    for order in range(1, n + 1):
        # Zipping `order` progressively shifted views of the token list yields
        # every contiguous window of that length (none if the list is shorter).
        shifted = [tokens[offset:] for offset in range(order)]
        for ngram in zip(*shifted):
            ngram_counts[ngram] += 1

    return ngram_counts

def cook_refs(refs, n=4):
    """Return the n-gram counts (see ``precook``) of every reference sentence in ``refs``."""
    cooked = []
    for reference in refs:
        cooked.append(precook(reference, n))
    return cooked

def cook_test(test, n=4):
    """Return the n-gram counts (see ``precook``) of the hypothesis sentence ``test``."""
    hypothesis_counts = precook(test, n, True)
    return hypothesis_counts

class CiderScorer(object):
    """Accumulates hypothesis/reference n-gram counts and scores them with CIDEr.

    Each entry of ``crefs`` is the list of cooked references of one image and
    the entry at the same position of ``ctest`` is that image's cooked
    hypothesis. N-grams are weighted by TF-IDF, where the document frequency
    is the number of images whose references contain the n-gram.

    Args:
        test: optional hypothesis sentence for a first image.
        refs: optional list of reference sentences for that image.
        n: highest n-gram order used.
        sigma: width of the Gaussian length penalty.
    """

    def copy(self):
        """Return a shallow copy carrying the cooked hypotheses and references."""
        new = CiderScorer(n=self.n, sigma=self.sigma)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)

        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        # log(number of images); set by compute_cider().
        self.ref_len = None
        self.cook_append(test, refs)

    def cook_append(self, test, refs):
        """Cook and store one image's references (and hypothesis, if given).

        Nothing is stored when ``refs`` is None. N-grams are counted up to
        order ``self.n`` so the counts always fit the per-order vectors built
        in :meth:`compute_cider`.
        """
        if refs is not None:
            self.crefs.append(cook_refs(refs, self.n))
            if test is not None:
                self.ctest.append(cook_test(test, self.n))
            else:
                self.ctest.append(None)

    def size(self):
        """Return the number of stored images."""
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        """Add a ``(hypothesis, references)`` tuple or merge another scorer."""
        if type(other) is tuple:
            self.cook_append(other[0], other[1])

        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)

        return self

    def compute_doc_freq(self):
        """Recompute, for each n-gram, the number of images whose references contain it."""
        # Start from scratch so that calling compute_score() more than once
        # does not keep adding to the previous counts.
        self.document_frequency = defaultdict(float)

        for refs in self.crefs:
            for ngram in set(ngram for ref in refs for ngram in ref.keys()):
                self.document_frequency[ngram] += 1

    def compute_cider(self):
        """Return the list of per-image CIDEr scores."""

        def counts2vec(cnts):
            """Map n-gram counts to per-order TF-IDF vectors, their norms and a length."""
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]

            for (ngram, term_freq) in cnts.items():
                # Unseen n-grams get a document frequency of 1 (log -> 0).
                df = np.log(max(1.0, self.document_frequency[ngram]))
                order = len(ngram) - 1
                vec[order][ngram] = float(term_freq) * (self.ref_len - df)
                norm[order] += pow(vec[order][ngram], 2)

                # The length fed to the Gaussian penalty is accumulated from
                # the order-index-1 (two-word) n-gram counts.
                if order == 1:
                    length += term_freq

            norm = [np.sqrt(value) for value in norm]

            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            """Per-order clipped cosine similarity scaled by the length penalty."""
            delta = float(length_hyp - length_ref)
            val = np.array([0.0 for _ in range(self.n)])

            for order in range(self.n):
                for ngram in vec_hyp[order].keys():
                    # Clip the hypothesis weight to the reference weight.
                    val[order] += min(vec_hyp[order][ngram], vec_ref[order][ngram]) * vec_ref[order][ngram]

                if (norm_hyp[order] != 0) and (norm_ref[order] != 0):
                    val[order] /= (norm_hyp[order] * norm_ref[order])

                assert(not math.isnan(val[order]))
                val[order] *= np.e**(-(delta**2)/(2*self.sigma**2))

            return val

        self.ref_len = np.log(float(len(self.crefs)))

        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            vec, norm, length = counts2vec(test)
            score = np.array([0.0 for _ in range(self.n)])

            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)

            # Average over n-gram orders and over references, then scale by 10.
            score_avg = np.mean(score)
            score_avg /= len(refs)
            score_avg *= 10.0
            scores.append(score_avg)

        return scores

    def compute_score(self, option=None, verbose=0):
        """Return ``(mean score, array of per-image scores)``.

        ``option`` and ``verbose`` are accepted but not used. With no stored
        images the result is ``(0.0, empty array)``.
        """
        if not self.crefs:
            return 0.0, np.array([])

        self.compute_doc_freq()
        assert(len(self.ctest) >= max(self.document_frequency.values()))
        score = self.compute_cider()

        return np.mean(np.array(score)), np.array(score)

    class Cider:
        """Dictionary-based front end that feeds a :class:`CiderScorer`."""

        def __init__(self, test=None, refs=None, n=4, sigma=6.0):
            # `test` and `refs` are accepted but not used.
            self._n = n
            self._sigma = sigma

        def compute_score(self, gts, res):
            """Score hypotheses ``res`` against references ``gts``.

            Both are dicts with identical keys; each ``res`` value is a list
            holding exactly one hypothesis sentence and each ``gts`` value is
            a non-empty list of reference sentences.

            Returns:
                ``(mean score, array of per-image scores)``.
            """
            assert(gts.keys() == res.keys())
            imgIds = gts.keys()

            cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)

            for img_id in imgIds:
                hypo = res[img_id]
                ref = gts[img_id]

                assert(type(hypo) is list)
                assert(len(hypo) == 1)
                assert(type(ref) is list)
                assert(len(ref) > 0)

                cider_scorer += (hypo[0], ref)

            (score, scores) = cider_scorer.compute_score()

            return score, scores

        def method(self):
            """Return the metric's display name."""
            return "CIDEr"

### Training Models

In [None]:
!pip install pycocoevalcap

from pycocoevalcap import Scorer

In [None]:
def pred(inp, V, vg, states, decoder):
    """Advance the decoder by one token and return the next-word distribution.

    Args:
        inp: embedded input token of shape ``(1, embed_size)``.
        V: per-location image features; accepted for interface compatibility
            and not used by the vanilla RNN decoder.
        vg: global image feature; accepted for interface compatibility (the
            image is injected once, when the hidden state is primed).
        states: RNN hidden state of shape ``(num_layers, 1, hidden_size)``.
        decoder: module exposing ``rnn`` and ``linear`` (as ``DecoderRNN`` does).

    Returns:
        ``(probabilities, new_states)`` where ``probabilities`` is a 1-D numpy
        array over the vocabulary.
    """
    # nn.RNN expects (time, batch, features): add a length-1 time axis.
    hiddens, states = decoder.rnn(inp.unsqueeze(0), states)

    output = decoder.linear(hiddens.squeeze(0))
    output = F.softmax(output, dim = 1)

    return output.view(output.size(1)).detach().cpu().numpy(), states

def caption_image_beam(image, vocabulary, encoder, decoder, device = 'cpu', k = 10, max_length=50):
    """Generate a caption for one image with beam search.

    Args:
        image: image batch of size 1, already on ``device``.
        vocabulary: object with ``stoi`` / ``itos`` mappings.
        encoder: callable returning ``(V, vg)`` for ``image``.
        decoder: module exposing ``embedding``, ``rnn`` and ``linear``
            (as ``DecoderRNN`` does).
        device: device on which the token tensors are created.
        k: beam width.
        max_length: maximum number of generated tokens.

    Returns:
        The best sequence as a list of token strings, without ``<SOS>`` /
        ``<EOS>`` markers. Beams are ranked by length-normalised negative
        log-likelihood.
    """
    def _finished(seq):
        return len(seq) != 0 and vocabulary.itos[seq[-1]] == "<EOS>"

    # Smallest positive float32: keeps math.log() defined when a probability
    # underflows to exactly zero.
    min_prob = float(np.finfo(np.float32).tiny)

    with torch.no_grad():
        V, vg = encoder(image)

        # Prime the hidden state with the global image feature, mirroring
        # DecoderRNN.forward where the feature is the first sequence element.
        _, states = decoder.rnn(vg.unsqueeze(0))

        sequences = [[list(), 0.0, states]]
        sos_idx = vocabulary.stoi['<SOS>']

        for _ in range(max_length):

            # Every beam has ended: further iterations would change nothing.
            if all(_finished(seq) for seq, _, _ in sequences):
                break

            all_candidates = list()

            for seq, score, states in sequences:

                if _finished(seq):
                    # Completed beams are carried over unchanged.
                    all_candidates.append((seq, score, states))
                    continue

                last_token = seq[-1] if len(seq) != 0 else sos_idx
                inp = decoder.embedding(torch.tensor([last_token], device=device))

                predictions, new_states = pred(inp, V, vg, states, decoder)

                word_preds = np.argsort(predictions)[-k:]

                for j in word_preds:
                    j = int(j)
                    step_cost = -math.log(max(float(predictions[j]), min_prob))
                    all_candidates.append((seq + [j], score + step_cost, new_states))

            # Lower is better; dividing by the length avoids favouring short beams.
            ordered = sorted(all_candidates, key=lambda tup:tup[1]/(len(tup[0])))
            sequences = ordered[:k]

    output_arr = sequences[0][0]

    # Nothing was generated (e.g. max_length == 0).
    if len(output_arr) == 0:
        return []

    if vocabulary.itos[sequences[0][0][-1]] == '<EOS>':
        output_arr = sequences[0][0][:-1]

    if vocabulary.itos[sequences[0][0][0]] == '<SOS>':
        output_arr = output_arr[1:]

    return [vocabulary.itos[idx] for idx in output_arr]

def compute_individual_metrics(reference, candidate):
    """Compute the individual 1- to 4-gram BLEU scores of ``candidate``.

    Args:
        reference: list of reference captions, each either a token list or a
            plain string (split on whitespace).
        candidate: hypothesis caption, as a token list or a plain string.

    Returns:
        Dict with the keys ``'Bleu_1'`` .. ``'Bleu_4'``; each value is
        ``sentence_bleu`` evaluated with all weight on that single n-gram order.
    """
    def _as_tokens(text):
        # sentence_bleu treats a bare string as a sequence of characters,
        # so strings are split into words first.
        return text.split() if isinstance(text, str) else list(text)

    references = [_as_tokens(ref) for ref in reference]
    hypothesis = _as_tokens(candidate)

    metrics = dict()

    for order in range(4):
        weights = tuple(1 if position == order else 0 for position in range(4))
        metrics[f'Bleu_{order + 1}'] = sentence_bleu(references, hypothesis, weights=weights)

    return metrics

In [None]:
# Number of training batches between two validation / qualitative checks.
val_every = 600

val_iter = iter(val_loader)

train_losses = list()
val_losses = list()

cumulative_bleu_scores = list()

# CIDEr, METEOR and ROUGE-L are not computed anywhere in this notebook; the
# lists are still created so the plotting cell that reads them keeps working.
cider_scores = list()

meteor_scores = list()

rougel_scores = list()

bleu1_scores = list()
bleu2_scores = list()
bleu3_scores = list()
bleu4_scores = list()


def _caption_loss(imgs, captions):
    """Run encoder and decoder on one batch and return the cross-entropy loss."""
    imgs = imgs.to(device)
    # The decoder's nn.RNN is time-major, so time has to be on dim 0 even when
    # the loaders pad captions batch-first.
    captions = (captions.t() if batch_first else captions).to(device)

    _, global_features = enc(imgs)

    # Decoder input is [image feature, w_0 .. w_{T-2}], giving T outputs.
    # Output 0 (the image step) has no target; outputs 1.. are scored against
    # w_1 .. w_{T-1}.
    outputs = dec(global_features, captions[:-1])

    return criterion(outputs[1:].reshape(-1, outputs.shape[2]), captions[1:].reshape(-1))


for epoch in range(num_epochs):
    for batch_idx, (imgs, captions) in enumerate(train_loader):

        # Validation below switches both models to eval mode.
        enc.train()
        dec.train()

        loss = _caption_loss(imgs, captions)

        optim_dec.zero_grad()

        loss.backward()

        optim_dec.step()

        if batch_idx % val_every != 0:
            continue

        enc.eval()
        dec.eval()

        with torch.no_grad():

            # Cycle through the validation loader one batch per check.
            try:
                val_imgs, val_captions = next(val_iter)

            except StopIteration:
                val_iter = iter(val_loader)
                val_imgs, val_captions = next(val_iter)

            val_loss = _caption_loss(val_imgs, val_captions)

            val_losses.append(val_loss.item())
            train_losses.append(loss.item())

            # Qualitative check on one random validation row. The row index
            # gives access to the image file name, which the loaders do not
            # return, and through it to every caption of that image.
            row = val_set.indices[np.random.randint(len(val_set))]
            filename = dataset.img_pts[row]
            val_img, _ = dataset[row]

            candidate = caption_image_beam(val_img.unsqueeze(0).to(device), dataset.vocab, enc, dec, device)
            ref_tokens = [Vocabulary.tokenize(str(caption))
                          for caption in dataset.caps[dataset.img_pts == filename]]

            cumulative_bleu_score = sentence_bleu(ref_tokens, candidate, weights=(0.25, 0.25, 0.25, 0.25))

            # Individual n-gram BLEU: all weight on a single order.
            metrics = {
                f'Bleu_{order + 1}': sentence_bleu(
                    ref_tokens, candidate,
                    weights=tuple(1 if position == order else 0 for position in range(4)))
                for order in range(4)
            }

            hyp = ' '.join(candidate)
            refs = [' '.join(sentence) for sentence in ref_tokens]

            # Print a random one of the references below.
            np.random.shuffle(refs)

            bleu1_scores.append(metrics['Bleu_1'])
            bleu2_scores.append(metrics['Bleu_2'])
            bleu3_scores.append(metrics['Bleu_3'])
            bleu4_scores.append(metrics['Bleu_4'])

            cumulative_bleu_scores.append(cumulative_bleu_score)

            pil_img = Image.open(os.path.join(root_folder, filename)).convert("RGB")
            plt.imshow(np.asarray(pil_img))
            plt.show()

            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx+1}/{len(train_loader)}] "
                  f"Training loss: {loss.item()} "
                  f"Validation loss: {val_loss.item()} "
                  f"Cumulative BLEU Score: {cumulative_bleu_score} \n")

            print(f"Ground Truth: {refs[0]}\n")

            print(f"Predicted Caption: {hyp}\n\n")

    # One checkpoint per epoch, written after the last batch.
    torch.save(enc.state_dict(), f'/kaggle/working/encoder_{freq_threshold}_{batch_size}_{hidden_size}_{epoch+1}.pth')
    torch.save(dec.state_dict(), f'/kaggle/working/decoder_{freq_threshold}_{batch_size}_{hidden_size}_{epoch+1}.pth')


### Plotting Losses and Scores

In [None]:
def _draw_and_save(figure_id, curves, y_label, out_path):
    """Plot every ``(values, label)`` pair on figure ``figure_id`` and save it to ``out_path``."""
    plt.figure(figure_id)
    for values, label in curves:
        plt.plot(values, label = label)
    plt.ylabel(y_label)
    plt.legend()
    plt.savefig(out_path)


# Figure 0: training vs. validation loss.
_draw_and_save(
    0,
    [(train_losses, 'Training loss'),
     (val_losses, 'Validation loss')],
    'Cross Entropy Loss',
    f'/kaggle/working/{freq_threshold}_{batch_size}_{hidden_size}_{num_epochs}_losses.png',
)

# Figure 1: individual BLEU-n scores.
_draw_and_save(
    1,
    [(bleu1_scores, 'BLEU 1'),
     (bleu2_scores, 'BLEU 2'),
     (bleu3_scores, 'BLEU 3'),
     (bleu4_scores, 'BLEU 4')],
    'BLEU Scores',
    f'/kaggle/working/{freq_threshold}_{batch_size}_{hidden_size}_{num_epochs}_bleu_scores.png',
)

# Figure 2: the remaining caption metrics.
_draw_and_save(
    2,
    [(cumulative_bleu_scores, 'Cumulative BLEU SCORE'),
     (cider_scores, 'CIDEr SCORE'),
     (meteor_scores, 'METEOR SCORE'),
     (rougel_scores, 'ROUGE_L SCORE')],
    'Scores',
    f'/kaggle/working/{freq_threshold}_{batch_size}_{hidden_size}_{num_epochs}_scores.png',
)