In [None]:
!pip install pytorch-pretrained-bert

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets, models
import numpy as np
import matplotlib.pyplot as plt
import random
import re
from scipy import ndimage
from torch.autograd import Variable
from PIL import Image
import numpy as np

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from pytorch_pretrained_bert import BertTokenizer, BertModel
from tqdm import tqdm, trange
import pandas as pd
import io
import os
import numpy as np
import matplotlib.pyplot as plt
# % matplotlib inline

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Image preprocessing applied before the CNN: resize to 224x224, convert to a
# tensor, then normalise each channel with the mean/std given below.
scaler = transforms.Resize([224, 224])
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Pretrained ResNet-18 with its last two child modules removed, so it returns a
# feature map instead of class scores (the captioning head flattens 7*7*512 values).
feature_extraction = nn.Sequential(
    *list(torchvision.models.resnet18(pretrained=True).to(device).children())[:-2]
).to(device)

# BERT word-piece tokenizer and encoder used to embed the caption written so far.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_WE = BertModel.from_pretrained('bert-base-uncased').to(device)

In [None]:
# The CNN is only used as a fixed feature extractor: put it in inference mode
# and switch off gradient tracking for every one of its weights.
feature_extraction.eval()
for frozen_param in feature_extraction.parameters():
    frozen_param.requires_grad = False

def img_load_feat(img_loc,img_name):
    """Load ``<img_loc>/<img_name>.jpg`` and return its CNN feature map.

    Args:
        img_loc: directory that contains the image (with or without a trailing
            path separator).
        img_name: file name without the ``.jpg`` extension; converted with
            ``str()`` before use.

    Returns:
        The output of ``feature_extraction`` for the preprocessed image, with a
        leading batch dimension of 1, living on ``device``. The tensor is
        computed under ``torch.no_grad()`` and therefore carries no autograd
        graph.
    """
    img_path = os.path.join(img_loc, str(img_name) + '.jpg')

    # The context manager releases the file handle; convert() returns a fully
    # loaded copy, so the image stays usable after the file is closed.
    # Forcing RGB keeps the 3-channel normalisation valid for grayscale/RGBA files.
    with Image.open(img_path) as img:
        rgb_img = img.convert('RGB')

    t_img = normalize(to_tensor(scaler(rgb_img)))
    t_img = torch.unsqueeze(t_img, 0).to(device)

    # Features are constants for the rest of the notebook: never build a graph
    # for them, regardless of whether the backbone has been frozen yet.
    with torch.no_grad():
        feature = feature_extraction(t_img)
    return feature

In [None]:
# Pre-compute the CNN feature map of every file in the image directory once,
# keyed by the part of the file name before the first '.'.
img_loc = '../input/flickr8k/Images/'
img_features = {}
for img_id_jpg in os.listdir(img_loc):
    img_id = img_id_jpg.partition('.')[0]
    # requires_grad_() works in place and returns the same tensor, so the cached
    # entry is the extracted feature itself, flagged as requiring gradients.
    img_features[img_id] = img_load_feat(img_loc, img_id).requires_grad_()

In [None]:
# Read the whole caption file; the context manager closes the handle even on error.
with open('../input/flickr8k/captions.txt', 'r') as file:
    ip_desc = file.read()

data = list()          # samples [[img_id, prefix_ids], target_ids], one per caption token
all_desc = dict()      # img_id -> list of cleaned, untruncated captions
max_len = 15           # length of the BERT input built by the model, [CLS]/[SEP] included
index = 0              # next free class id, i.e. the size of my_vocab so far
my_vocab = dict()      # word-piece token -> class id
my_rev_vocab = dict()  # class id -> word-piece token

# The model wraps the caption ids in beg_seq/end_seq ([CLS] and [SEP]) before
# padding to max_len, so a caption may contribute at most max_len - 2 tokens.
max_caption_tokens = max_len - 2

# The first line is skipped (presumably the CSV header). splitlines() plus the
# blank-line check keeps the last caption whether or not the file ends with a newline.
for line in ip_desc.splitlines()[1:]:
    if not line.strip():
        continue

    # Only the first comma separates file name from caption; anything after it
    # (including quotes and further commas) belongs to the caption.
    ip = line.split(',', 1)
    if len(ip) < 2:
        continue

    img_id = ip[0].split('.')[0]

    # cleaning desc: keep ASCII letters and spaces only, then normalise
    clean_desc = re.sub(r'[^A-Za-z ]', '', ip[1]).rstrip().lower()

    all_desc.setdefault(img_id, list()).append(clean_desc)

    # tokenization of clean desc; truncate on token boundaries (not characters)
    # so words are never cut in half and the model's length limit is respected
    tok_desc = tokenizer.tokenize(clean_desc)[:max_caption_tokens]

    for tok in tok_desc:
        if tok not in my_vocab:
            my_vocab[tok] = index
            my_rev_vocab[index] = tok
            index += 1

    # converting tokens to IDs
    tok_ids = tokenizer.convert_tokens_to_ids(tok_desc)

    # One sample per token, in the layout train()/test() unpack as (ip, op):
    #   ip = [img_id, prefix_ids]  -> ip[0] indexes img_features, ip[1] feeds the model
    #   op = prefix_ids[1:] + [next_id], whose last element is the id to predict
    for i in range(len(tok_ids)):
        prefix_ids = tok_ids[:i]
        data.append([[img_id, prefix_ids], prefix_ids[1:] + [tok_ids[i]]])

In [None]:
# Sanity check: vocabulary size and number of images with cached features.
print(index, len(img_features))

In [None]:
# Ids of BERT's sequence delimiters, each as a one-element list so they can be
# concatenated with a caption's id list. They are looked up directly in the
# vocabulary instead of going through tokenize(), which only yields a single
# token when the tokenizer is configured to leave bracketed specials unsplit.
beg_seq = tokenizer.convert_tokens_to_ids(["[CLS]"])
end_seq = tokenizer.convert_tokens_to_ids(["[SEP]"])

In [None]:
class Caption_Generation(nn.Module):
    """Next-token scorer that fuses an image feature map with the caption so far.

    The image branch flattens ``x1`` to 7*7*512 values and projects it to 256.
    The text branch embeds ``[CLS] + x2 + [SEP]`` (padded to the global
    ``max_len``) with the global ``bert_WE`` encoder, flattens the
    ``max_len * 768`` activations and projects them to 256. Both projections are
    added and mapped to ``index`` raw scores, one per entry of the caption
    vocabulary.

    The final layer returns unnormalised logits (no activation), which is what
    ``nn.CrossEntropyLoss`` and an arg-max over classes expect.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense1 = nn.Sequential(nn.Linear(7*7*512,256), nn.Dropout(0.2), nn.ReLU())

        self.dense2 = nn.Sequential(nn.Linear(max_len*768,1024), nn.Dropout(0.2), nn.ReLU())
        self.dense3 = nn.Sequential(nn.Linear(1024,256), nn.Dropout(0.2), nn.ReLU())

        # No ReLU after the last Linear: clamping logits at zero would discard
        # every negative score before the loss / arg-max sees it.
        self.cap_gen = nn.Sequential(
            nn.Linear(256,128), nn.ReLU(),
            nn.Linear(128,index)
        )

    def forward(self, x1, x2):
        """Score the next caption token.

        Args:
            x1: image feature tensor that flattens to 7*7*512 values per sample.
            x2: sequence of BERT token ids of the caption generated so far
                (may be empty), without [CLS]/[SEP].

        Returns:
            Tensor of raw vocabulary scores with ``index`` entries per sample.

        Raises:
            ValueError: if ``x2`` plus the two delimiters exceeds ``max_len``;
                the fixed-size text projection cannot accept a longer input.
        """
        desc = beg_seq + list(x2) + end_seq
        if len(desc) > max_len:
            raise ValueError(
                'caption prefix has %d tokens but at most %d fit into max_len=%d'
                % (len(x2), max_len - len(beg_seq) - len(end_seq), max_len)
            )

        n_pad = max_len - len(desc)
        # Embedding lookups need integer (long) indices; 1 marks real tokens in
        # the attention mask and 0 marks padding.
        att_mask = torch.tensor([[1]*len(desc) + [0]*n_pad], dtype=torch.long, device=device)
        t_pad_desc = torch.tensor([desc + [0]*n_pad], dtype=torch.long, device=device)

        ip1 = self.flatten(x1)
        ip1 = self.dense1(ip1)

        # bert_WE is a module-level encoder, not a submodule of this model, so an
        # optimizer built from self.parameters() never updates it: run it in
        # inference mode and skip building an autograd graph through it.
        bert_WE.eval()
        with torch.no_grad():
            ip2 = (bert_WE(t_pad_desc, attention_mask = att_mask, output_all_encoded_layers = False))[0]
        ip2 = self.flatten(ip2)
        ip2 = self.dense2(ip2)
        ip2 = self.dense3(ip2)

        ip = torch.add(ip1,ip2)
        out = self.cap_gen(ip)

        return out

In [None]:
def train(train_data, epochs, model, optim, loss_f):
    """Train ``model`` on ``train_data`` for ``epochs`` passes, one sample per step.

    Args:
        train_data: sized iterable of ``(ip, op)`` samples. ``ip[0]`` is a key of
            the global ``img_features``; ``ip[1]`` is the list of BERT token ids
            of the caption so far; ``op`` is a sequence of BERT token ids whose
            last element is the id of the token to predict.
            NOTE(review): confirm the dataset-building cell emits this layout.
        epochs: number of passes over ``train_data``.
        model: module called as ``model(image_feature, token_ids)`` that returns
            raw scores over the caption vocabulary (``my_vocab``).
        optim: optimizer over ``model``'s parameters.
        loss_f: criterion called as ``loss_f(scores, target_class)``, e.g.
            ``nn.CrossEntropyLoss``.

    Prints the average loss after every epoch; returns nothing.
    """
    model = model.to(device)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for ip,op in train_data:
            optim.zero_grad()

            # The cached image feature is a constant input: detach it so no
            # gradient is computed for (or accumulated on) the cache entry.
            scores = model(img_features[ip[0]].detach(), ip[1])
            scores = scores[-1].view(1,-1)

            # Map the BERT id of the true next token back to its class id in
            # the caption vocabulary, which is what the score vector indexes.
            next_tok = tokenizer.convert_ids_to_tokens([op[-1]])[0]
            target = torch.tensor([my_vocab[next_tok]], dtype=torch.long, device=device)

            # The loss is taken on the scores themselves so that backward()
            # reaches the model's weights; an arg-max'ed prediction would not.
            loss = loss_f(scores, target)
            loss.backward()
            optim.step()

            # .item() keeps a plain float and lets the step's graph be freed.
            train_loss += loss.item()

        train_loss /= max(len(train_data), 1)
        print('Epoch: ', epoch, 'Avg Train loss: ', float(train_loss))

def test(test_data, model, loss_f):
    """Evaluate ``model`` on ``test_data`` and print its predictions.

    Args:
        test_data: iterable of ``(ip, op)`` samples. ``ip[0]`` is a key of the
            global ``img_features``; ``ip[1]`` is the list of BERT token ids of
            the caption so far; ``op`` is a sequence of BERT token ids whose last
            element is the id of the token to predict.
            NOTE(review): confirm the dataset-building cell emits this layout.
        model: module called as ``model(image_feature, token_ids)`` that returns
            raw scores over the caption vocabulary (``my_vocab``).
        loss_f: criterion called as ``loss_f(scores, target_class)``.

    For every sample the highest-scoring token is printed on its own line; the
    average loss over all samples is printed at the end. Returns nothing.
    """
    model.eval()
    test_loss = 0.0
    n_samples = 0

    # Evaluation only: no autograd graph is needed for any of these passes.
    with torch.no_grad():
        for ip,op in test_data:
            scores = model(img_features[ip[0]],ip[1])
            scores = scores[-1].view(1,-1)

            predicted = int(torch.max(scores, 1)[1])
            print(my_rev_vocab[predicted], end = ' ')

            # Class id (in the caption vocabulary) of the true next token,
            # recovered from the BERT id stored at the end of op.
            next_tok = tokenizer.convert_ids_to_tokens([op[-1]])[0]
            target = torch.tensor([my_vocab[next_tok]], dtype=torch.long, device=device)

            test_loss += loss_f(scores, target).item()
            n_samples += 1
            print()

    if n_samples:
        print('Avg Test loss: ', test_loss / n_samples)

In [None]:
# Instantiate the captioning head on the selected device.
caption_generation = Caption_Generation()
caption_generation = caption_generation.to(device)

# Cross-entropy criterion and an Adam optimizer over the head's parameters.
loss = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(params=caption_generation.parameters(), lr=1e-3)

# Two passes over the full sample list.
train(train_data=data, epochs=2, model=caption_generation, optim=optimizer, loss_f=loss)

In [None]:
# Quick look at the model's predictions on the first five samples.
test(test_data=data[:5], model=caption_generation, loss_f=loss)