In [None]:
!pip install pytorch-pretrained-bert

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets, models
import numpy as np
import matplotlib.pyplot as plt
import random
import re
from scipy import ndimage
from torch.autograd import Variable
from PIL import Image
import numpy as np

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from pytorch_pretrained_bert import BertTokenizer, BertModel
from tqdm import tqdm, trange
import pandas as pd
import io
import os
import numpy as np
import matplotlib.pyplot as plt
# % matplotlib inline

from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Image preprocessing: resize to 224x224, convert to a tensor, then normalize each channel.
scaler = transforms.Resize([224, 224])
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Pretrained ResNet-18 with its final two children removed (in torchvision's ResNet
# these are the global average pool and the fc classifier), leaving a conv feature map.
_resnet = torchvision.models.resnet18(pretrained=True).to(device)
feature_extraction = nn.Sequential(*list(_resnet.children())[:-2]).to(device)
del _resnet

# Uncased BERT-base tokenizer and encoder, used to embed the partially generated caption.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_WE = BertModel.from_pretrained('bert-base-uncased').to(device)

In [None]:
# The ResNet backbone is a fixed feature extractor: disable gradients for all of its
# parameters and put it in inference mode (batch-norm uses running statistics).
feature_extraction.requires_grad_(False)
feature_extraction.eval()

def img_load_feat(img_loc,img_name):
    """Load ``<img_loc><img_name>.jpg`` and return its feature map from ``feature_extraction``.

    The image is converted to RGB, passed through ``scaler``, ``to_tensor`` and
    ``normalize``, moved to ``device`` and fed to the extractor as a batch of one.

    Args:
        img_loc: directory prefix, including the trailing path separator
            (the file name is appended by plain string concatenation).
        img_name: image id without the ``.jpg`` extension.

    Returns:
        The extractor output for the image, computed without autograd history.
    """
    img_path = img_loc + str(img_name) + '.jpg'
    # ``with`` closes the underlying file; ``convert('RGB')`` makes grayscale,
    # RGBA or CMYK images match the 3-channel ``normalize`` (they used to crash it).
    with Image.open(img_path) as img:
        t_img = normalize(to_tensor(scaler(img.convert('RGB'))))
    t_img = torch.unsqueeze(t_img.to(device), 0)
    # The extractor is frozen, so there is no reason to record a graph.
    with torch.no_grad():
        feature = feature_extraction(t_img)
    return feature

In [None]:
# Pre-compute the feature map for every image once; training then reuses these tensors.
img_loc = '../input/flickr8k/Images/'
img_features = dict()
with torch.no_grad():
    # sorted() gives a deterministic order regardless of the filesystem.
    for img_id_jpg in tqdm(sorted(os.listdir(img_loc)), desc='extracting features'):
        # img_load_feat re-appends '.jpg', so anything else in the folder is skipped.
        if not img_id_jpg.endswith('.jpg'):
            continue
        img_id = img_id_jpg.split('.')[0]
        # These are fixed inputs and must NOT require grad: with requires_grad_()
        # every backward pass also computed and accumulated a .grad on each cached tensor.
        img_features[img_id] = img_load_feat(img_loc, img_id)

In [None]:
# Parse captions.txt: a header row, then "<file>.jpg,<caption>" per line, with
# captions containing commas wrapped in double quotes.
with open('../input/flickr8k/captions.txt', 'r') as file:
    ip_desc = file.read()

img_cap = dict()          # img_id -> list of cleaned caption strings (used for CIDEr)
img_cap_indexed = dict()  # img_id -> list of token-id lists, each zero-padded to max_len
max_len = 15              # maximum caption length, in word-piece TOKENS
index = 1                 # next unused vocabulary id; 0 is reserved for padding
my_vocab = dict()         # word-piece token -> id
my_rev_vocab = dict()     # id -> word-piece token

# splitlines() also keeps the final caption when the file has no trailing newline.
for line in ip_desc.splitlines()[1:]:
    if ',' not in line:
        continue  # blank or malformed line

    # Split only on the first comma: the caption itself may contain commas.
    img_name, raw_desc = line.split(',', 1)
    img_id = img_name.split('.')[0]

    # Keep only ASCII letters and spaces (drops digits, punctuation and CSV quotes).
    clean_desc = re.sub(r'[^A-Za-z ]', '', raw_desc).rstrip().lower()

    # Tokenize, then truncate to max_len TOKENS. The previous code cut the string to
    # max_len characters, which chopped captions mid-word ("a child in a pi").
    tok_desc = tokenizer.tokenize(clean_desc)[:max_len]

    # Map each img_id to its list of captions.
    if img_id not in img_cap:
        img_cap[img_id] = list()
        img_cap_indexed[img_id] = list()
    img_cap[img_id].append(clean_desc)

    # Build the vocabulary and the zero-padded id sequence for this caption.
    tok_desc_ind = list()
    for tok in tok_desc:
        if tok not in my_vocab:
            my_vocab[tok] = index
            my_rev_vocab[index] = tok
            index += 1
        tok_desc_ind.append(my_vocab[tok])
    tok_desc_ind += [0] * (max_len - len(tok_desc_ind))
    img_cap_indexed[img_id].append(tok_desc_ind)
    

In [None]:
# Deterministic 80/20 split in caption-file order (no shuffling).
img_ids = list(img_cap)
train_data = img_ids[: len(img_ids) * 80 // 100]
test_data = img_ids[len(img_ids) * 80 // 100 :]

# Vocabulary size (next free id) and number of images with extracted features.
print(index, len(img_features))

In [None]:
def get_cider(captions, max_n=4):
    """Compute a CIDEr-style consensus score for every reference caption.

    Each caption of an image is compared with that image's other captions.
    For every n-gram order 1..max_n, the captions are TF-IDF vectorized with a
    ``TfidfVectorizer`` fitted on this image's captions only. The cosine
    similarity to each other caption is averaged, and the per-order averages
    are then averaged over the ``max_n`` orders.

    Args:
        captions: mapping key -> list of caption strings.
        max_n: highest n-gram order to include (default 4).

    Returns:
        dict mapping each key to a list of floats, one score per caption, in the
        same order as ``captions[key]``. Keys with fewer than two captions get
        0.0 for each caption, because there is nothing to agree with.
    """
    cider = dict()
    for key, corpus in captions.items():
        n_caps = len(corpus)
        if n_caps < 2:
            # The old code divided by (len(corpus) - 1) here.
            cider[key] = [0.0] * n_caps
            continue

        totals = np.zeros(n_caps)
        for n in range(1, max_n + 1):
            # One vectorizer per order. The old code always used the unigram
            # vectorizer, so all four "orders" were identical.
            vectorizer = TfidfVectorizer(ngram_range=(n, n))
            try:
                tfidf = vectorizer.fit_transform(corpus).toarray()
            except ValueError:
                # sklearn raises on an empty vocabulary, i.e. no caption has an
                # n-gram of this order; the order then contributes 0 similarity.
                continue
            norms = np.linalg.norm(tfidf, axis=1)
            dots = tfidf @ tfidf.T
            denom = np.outer(norms, norms)
            # Cosine similarity is dot / (|a| * |b|). The old code divided by
            # sqrt(|a| * |b|). Zero vectors get similarity 0 instead of NaN.
            sims = np.divide(dots, denom, out=np.zeros_like(dots), where=denom > 0)
            np.fill_diagonal(sims, 0.0)  # exclude each caption's self-similarity
            totals += sims.sum(axis=1) / (n_caps - 1)

        cider[key] = (totals / max_n).tolist()
    return cider

In [None]:
class ConsensusLoss(torch.nn.Module):
    """Consensus-weighted caption loss.

    For each reference caption i of an image, the loss takes the mean
    log-probability that the per-step scores assign to that caption's token
    ids, weights it by ``ciders[i]``, negates it, and averages over captions.
    """

    def __init__(self):
        # Previously misspelled ``_init_``, so it was never run as the constructor.
        super(ConsensusLoss, self).__init__()

    def softmax(self, x):
        """Softmax over dim 1. Unlike a raw exp/sum, this does not overflow for large scores."""
        return torch.softmax(x, dim=1)

    def log_softmax(self, x):
        """Log-softmax over dim 1: ``x - logsumexp(x, dim=1)``.

        The previous ``exp(x) - sum(exp(x))`` was not a log-probability.
        """
        return torch.log_softmax(x, dim=1)

    def forward(self, outputs, targets, ciders):
        """Compute the loss for one image.

        Args:
            outputs: 2-D scores, one row per decoding step and one column per
                vocabulary id.
            targets: 2-D token ids, one row per reference caption. Each row's
                length must equal ``outputs.shape[0]`` (they are paired per step).
            ciders: sequence of per-caption weights. Only the first
                ``len(ciders)`` rows of ``targets`` are used.

        Returns:
            Scalar tensor: ``-mean_i(ciders[i] * mean_t log p(targets[i, t] | step t))``.

        Raises:
            ValueError: if ``ciders`` is empty.
        """
        if len(ciders) == 0:
            raise ValueError('ciders must contain at least one weight')

        num_words = targets.shape[1]
        sentence_size = outputs.shape[0]
        log_probs = self.log_softmax(outputs)

        weights = torch.as_tensor(ciders, dtype=log_probs.dtype, device=log_probs.device)
        steps = torch.arange(sentence_size, device=log_probs.device)
        # picked[i, t] = log_probs[t, targets[i, t]] for every used caption i.
        picked = log_probs[steps, targets[:len(ciders)]]
        per_caption = picked.sum(dim=1) / num_words
        return -(per_caption * weights).mean()

In [None]:
# Per-caption consensus weights for every image, keyed by image id.
all_ciders = get_cider(captions=img_cap)
print(type(all_ciders))

In [None]:
# BERT ids of the [CLS], [SEP] and [PAD] markers, each as a one-element list.
beg_seq, end_seq, pad_seq = (
    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(marker))
    for marker in ("[CLS] ", " [SEP]", " [PAD] ")
)
print(beg_seq, end_seq, pad_seq)

In [None]:
class Caption_Generation(nn.Module):
    """Greedy caption generator conditioned on image features and BERT embeddings.

    At each of ``max_len`` steps, two embeddings are concatenated and mapped by
    ``cap_gen`` to ``index`` vocabulary scores:
      * the image embedding from ``dense1``;
      * a BERT encoding of ``[CLS] + caption so far + [SEP]`` from ``dense2``.
    The arg-max word is then appended to the caption. Assumes a batch of one image.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Image branch: flattened 512 x 7 x 7 feature map -> 256.
        self.dense1 = nn.Sequential(nn.Linear(7*7*512,8192), nn.Dropout(0.2), nn.ReLU(),
                                    nn.Linear(8192,1024), nn.Dropout(0.2), nn.ReLU(),
                                    nn.Linear(1024,256), nn.Dropout(0.2), nn.ReLU()
                                   )

        # Text branch: (max_len + 2) positions x 768-dim BERT outputs -> 256.
        self.dense2 = nn.Sequential(nn.Linear((max_len+2)*768,4096), nn.Dropout(0.2), nn.ReLU(),
                                    nn.Linear(4096,1024), nn.Dropout(0.2), nn.ReLU(),
                                    nn.Linear(1024,256), nn.Dropout(0.2), nn.ReLU()
                                   )

        # Fusion head: 256 + 256 -> vocabulary scores. There is no ReLU after the
        # last Linear: these are logits fed to a (log-)softmax. Clamping them at 0
        # made all negative scores indistinguishable, and an all-zero row decoded
        # to the padding id. Parameter names (cap_gen.0 / cap_gen.2) are unchanged.
        self.cap_gen = nn.Sequential(
            nn.Linear(512,256), nn.ReLU(),
            nn.Linear(256,index)
        )

    def forward(self, feat, idx=None):
        """Greedily decode one caption.

        Args:
            feat: feature map for a single image, as returned by ``img_load_feat``.
            idx: optional progress counter. When it is a multiple of 100, the
                decoded words are printed. ``None`` (the default) disables
                printing, so ``model(feat)`` works.

        Returns:
            List of ``max_len`` 1-D score tensors, one per decoding step.
        """
        output = list()
        desc = list()

        # The image branch does not depend on the partial caption, so compute it once.
        img_emb = self.dense1(self.flatten(feat))

        for _ in range(max_len):
            # [CLS] + words so far + [SEP] + zeros -> always max_len + 2 ids.
            pad_desc = (beg_seq
                        + tokenizer.convert_tokens_to_ids(desc)
                        + end_seq
                        + [0] * (max_len - len(desc)))
            att_mask = torch.tensor([[int(tok_id > 0) for tok_id in pad_desc]], device=device)
            t_pad_desc = torch.tensor([pad_desc], dtype=torch.long, device=device)

            # bert_WE is not a submodule and is not optimized, and its input ids are
            # not differentiable. Building its autograd graph only wasted memory.
            bert_WE.eval()
            with torch.no_grad():
                txt_emb = bert_WE(t_pad_desc, attention_mask=att_mask,
                                  output_all_encoded_layers=False)[0]
            txt_emb = self.dense2(self.flatten(txt_emb))

            out = self.cap_gen(torch.cat((img_emb, txt_emb), dim=1))[0]
            output.append(out)

            word_id = int(out.argmax())
            # Id 0 (padding) has no entry in my_rev_vocab; indexing it raised KeyError.
            desc.append(my_rev_vocab.get(word_id, '[PAD]'))

        if idx is not None and idx % 100 == 0:
            print(idx, desc)
        return output

In [None]:
def train(train_data, all_ciders, epochs, model, optim, loss_f):
    """Train ``model`` with ``loss_f`` and plot the per-epoch average loss.

    Args:
        train_data: iterable of image ids. Ids missing from ``img_features`` are skipped.
        all_ciders: mapping img_id -> per-caption weights passed to ``loss_f``.
        epochs: number of passes over ``train_data``.
        model: caption model returning a list of per-step score tensors.
        optim: optimizer over ``model``'s parameters.
        loss_f: callable ``(scores, targets, ciders) -> scalar tensor``.

    Returns:
        List with the average training loss (a float) of each epoch.
    """
    model = model.to(device)
    model.train()
    loss_vals = list()

    for epoch in range(epochs):
        train_loss = 0.0
        idx = 0  # images actually processed this epoch
        for img_id in train_data:
            if img_id not in img_features:
                continue
            optim.zero_grad()

            output = model(img_features[img_id], idx)
            # Stack the step tensors directly. Rebuilding them from Python lists
            # (old code) created a new leaf tensor, so backward() never reached
            # the model's parameters and optim.step() changed nothing.
            t_output = torch.stack(output)
            t_op = torch.LongTensor(img_cap_indexed[img_id]).to(device)

            loss = loss_f(t_output, t_op, all_ciders[img_id])
            loss.backward()
            optim.step()

            if idx % 100 == 0:
                print(idx, img_id, loss.item())
                with Image.open(img_loc + img_id + '.jpg') as img:
                    plt.imshow(np.asarray(img))
                plt.show()
            idx += 1
            # .item(): summing the tensor itself would keep every step's graph alive.
            train_loss += loss.item()

        # Average over images actually trained on, not over skipped ids too.
        train_loss /= max(idx, 1)
        print('Epoch: ', epoch, 'Avg Train loss: ', float(train_loss))
        loss_vals.append(train_loss)

    fig, ax = plt.subplots()
    ax.plot(range(1, epochs + 1), loss_vals)
    ax.set(xlabel='epoch', ylabel='avg train loss', title='Training loss')
    return loss_vals
        

def test(img_loc, test_data, model, loss_f, ciders=None):
    """Decode and score captions for held-out images.

    For every id in ``test_data`` that has cached features, this shows the image,
    prints the greedily decoded caption and accumulates ``loss_f`` against that
    image's reference captions.

    Args:
        img_loc: directory prefix of the ``<img_id>.jpg`` files.
        test_data: iterable of image ids.
        model: caption model returning a list of per-step score tensors.
        loss_f: callable ``(scores, targets, ciders) -> scalar tensor``.
        ciders: mapping img_id -> per-caption weights. Defaults to the global ``all_ciders``.

    Returns:
        Average loss over the evaluated images (0.0 if none were evaluated).
    """
    if ciders is None:
        ciders = all_ciders
    model.eval()
    test_loss = 0.0
    n_eval = 0

    with torch.no_grad():
        for img_id in test_data:
            if img_id not in img_features:
                continue

            with Image.open(img_loc + img_id + '.jpg') as img:
                plt.imshow(np.asarray(img))
            plt.show()

            output = model(img_features[img_id], n_eval)
            t_output = torch.stack(output)
            t_op = torch.LongTensor(img_cap_indexed[img_id]).to(device)

            # Each step's scores are a 1-D vector over the vocabulary. The old
            # ``word_prob[-1]`` took a single scalar and so always decoded id 0.
            words = [my_rev_vocab.get(int(word_prob.argmax()), '[PAD]') for word_prob in output]
            print(' '.join(words))

            test_loss += loss_f(t_output, t_op, ciders[img_id]).item()
            n_eval += 1

    # Average over evaluated test images (the old code divided by len(train_data)).
    test_loss /= max(n_eval, 1)
    print('Avg Test loss: ', float(test_loss))
    return test_loss

In [None]:
# Build the model, optimizer and loss, then train for two epochs.
caption_generation = Caption_Generation().to(device)
optimizer = torch.optim.Adam(params=caption_generation.parameters(), lr=1e-3)
loss = ConsensusLoss().to(device)

train(train_data=train_data, all_ciders=all_ciders, epochs=2,
      model=caption_generation, optim=optimizer, loss_f=loss)

In [None]:
test(img_loc=img_loc, test_data=test_data, model=caption_generation, loss_f=loss)

In [None]:
# embedding dimension 512
# positional embedding 2048
# attention layers 6