In [2]:
import nltk
from pycocotools.coco import COCO
import torch.utils.data as data
import torchvision.models as models
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence

import os
import pickle
import numpy as np
from PIL import Image
from collections import Counter
import matplotlib.pyplot as plt

import torch
import torch.nn as nn




# Load Dictionary

In [3]:
class Vocab(object):
    
    def __init__(self):
        self.w2i = {} # word to index
        self.i2w = {}
        self.index = 0
    
    # if v is a Vocab object, v(token) will return the index associated with the token
    def __call__(self,token):
        if not token in self.w2i:
            return self.w2i["<unk>"]
        return self.w2i[token]
    
    def __len__(self):
        return len(self.w2i)
    
    def add_token(self,token):
        if not token in self.w2i:
            # if the token is new, assign an index to it and update w2i, i2w, update index
            self.w2i[token] = self.index
            self.i2w[self.index] = token
            self.index += 1
    
            
        

In [4]:
with open("vocabulary.pkl","rb") as f:
    vocabulary = pickle.load(f)

# Models

In [5]:
class CNN(nn.Module):
    def __init__(self,embedding_size,weight_path):
        # pretrained ResNext-50
        super(CNN,self).__init__()
        resnet = models.resnext101_32x8d(pretrained=False)
        if weight_path:
            resnet.load_state_dict(torch.load(weight_path))
        # exclude the last layer
        module_list = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*module_list) 
        self.linear = nn.Linear(resnet.fc.in_features,embedding_size) # embed the output features
        self.batch_norm = nn.BatchNorm1d(embedding_size,momentum=0.01)
        self.dropout = nn.Dropout(0.1)
    
    def forward(self,input_imgs):
        # extract features
        with torch.no_grad():
            features = self.resnet(input_imgs)
        
        # embed the features
        features = features.reshape(features.size(0),-1)
        features = self.dropout(self.batch_norm(self.linear(features)))
        return features


        
        

In [13]:
class LSTM(nn.Module):
    def __init__(self,embed,hidden,vocab,num_layers,max_seq_len=30):
        super(LSTM,self).__init__()
        self.embedding = nn.Embedding(vocab,embed)
        #self.lstm = nn.LSTM(embed,hidden,num_layers,batch_first=True,dropout=0.1)
        self.lstm = nn.LSTM(embed*2,hidden,num_layers,batch_first=True,dropout=0.1)
        self.linear = nn.Linear(hidden,vocab)
        self.max_seq_len = max_seq_len
    
    def forward(self,input_features,capts,lens):
        # decode image feature vectors and generate captions
        ####
        capts = capts[:,:-1] # for each caption, ignore the last token
        ###
        
        embeddings = self.embedding(capts)
        # input_features (batch,embed_size)
        # embeddings (batch,sequence_len,embed_size)
        
        #####
        repeated_features = input_features.unsqueeze(1).repeat(1,embeddings.size(dim=1),1)
        embeddings = torch.cat((repeated_features,embeddings),2)
        
        #####
        
        #embeddings = torch.cat((input_features.unsqueeze(1),embeddings), 1)
        #lstm_input = pack_padded_sequence(embeddings, np.array(lens), batch_first=True)
        
        lstm_input = pack_padded_sequence(embeddings, np.array(lens)-1, batch_first=True)
        hidden_variables,_ = self.lstm(lstm_input)
        model_outputs = self.linear(hidden_variables[0])
        return model_outputs
    
    def sample(self,input_features,lstm_states=None):
        # generate captions for given image features
        
        sampled_indices = []
        features = input_features
        
        ### 
        word = "<start>"
        embedding = torch.tensor(vocabulary(word)).unsqueeze(0)
        embedding = self.embedding(embedding.cuda())
        #print("embedding dimension before cat: ", embedding.shape)
        lstm_inputs = torch.cat((features,embedding),1).unsqueeze(1)
        #print("embedding dimension after cat: ",lstm_inputs.shape)
        
        
        ###
        
        for i in range(self.max_seq_len):
            # hidden_variables (batch, 1, hidden_size)
            print("input: ", lstm_inputs[0,0,-256:])
            hidden_variables, lstm_states = self.lstm(lstm_inputs,lstm_states)
            # output (batch,num_vocab)
            model_outputs = self.linear(hidden_variables.squeeze(1))
            # predicted outputs (batch,)
            _, predicted_outputs = model_outputs.max(1)
            sampled_indices.append(predicted_outputs)
            #print(predicted_outputs)
            lstm_inputs = self.embedding(predicted_outputs)
            lstm_inputs =torch.cat((features,embedding),1).unsqueeze(1)
        # sampled_ids (batch,max_seq_len)
        sampled_indices = torch.stack(sampled_indices,1)
        return sampled_indices

# load image & model

In [7]:
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize((0.485,0.456,0.406),
                                                   (0.229,0.224,0.225))])
def load_image(path):
    img = Image.open(path).convert("RGB")
    img = img.resize([224,224],Image.LANCZOS)
    img = transform(img).unsqueeze(0)
    return img

In [14]:
encoder = CNN(256,None).eval()
decoder = LSTM(256,1024,len(vocabulary),1).eval()
encoder = encoder.to("cuda")
decoder = decoder.to("cuda")
encoder.load_state_dict(torch.load("Weights/complex_run-5-encoder.pth"))
decoder.load_state_dict(torch.load("Weights/complex_run-5-decoder.pth"))

<All keys matched successfully>

# Generate captioning

In [18]:
img = load_image("Images/4.jpg")
img = img.to("cuda")
feature = encoder(img)
sampled_indices = decoder.sample(feature)
#print(sampled_indices)
sampled_indices = sampled_indices[0].cpu().numpy()
sampled_indices

input:  tensor([-0.8467, -1.6892,  3.2016,  1.4408,  0.0467,  1.0292, -2.2968,  0.7602,
        -0.5187, -0.5010,  0.7584, -0.9053,  1.3836, -2.1331,  0.8453, -1.0967,
        -0.0484, -0.6404, -1.2127, -0.6120, -0.5505,  0.6289, -0.0200, -1.9444,
         1.7948,  1.7444,  1.0262,  1.5052, -0.5746, -0.8528, -0.1328,  0.6217,
         1.5925,  1.5349, -0.3052, -0.0918, -0.2944, -0.9253, -0.5546,  0.5540,
        -0.7967, -1.1423, -0.6112,  0.4851,  0.6963, -0.4579, -1.0087, -0.2497,
        -0.7815,  0.3194, -0.8156,  0.7245, -0.2217,  0.1612, -2.0000,  0.3959,
         0.6304,  0.3433,  1.2666, -1.7570,  0.7388, -0.0683, -1.1832,  0.2739,
         1.0795,  0.5357, -0.0813,  0.1950, -0.0322,  0.2982, -1.5255,  0.5604,
         0.1549, -0.2992,  0.0666,  0.0525, -1.1806, -0.6942,  0.3070, -0.6814,
         1.4881,  0.4291,  1.1495, -0.2090,  0.0643, -1.7271,  0.4208,  0.7165,
         2.4250,  2.2224,  0.3455, -0.2024, -0.6149,  0.6268, -0.3841, -0.3493,
         1.4415,  1.0075,  1.616

input:  tensor([-0.8467, -1.6892,  3.2016,  1.4408,  0.0467,  1.0292, -2.2968,  0.7602,
        -0.5187, -0.5010,  0.7584, -0.9053,  1.3836, -2.1331,  0.8453, -1.0967,
        -0.0484, -0.6404, -1.2127, -0.6120, -0.5505,  0.6289, -0.0200, -1.9444,
         1.7948,  1.7444,  1.0262,  1.5052, -0.5746, -0.8528, -0.1328,  0.6217,
         1.5925,  1.5349, -0.3052, -0.0918, -0.2944, -0.9253, -0.5546,  0.5540,
        -0.7967, -1.1423, -0.6112,  0.4851,  0.6963, -0.4579, -1.0087, -0.2497,
        -0.7815,  0.3194, -0.8156,  0.7245, -0.2217,  0.1612, -2.0000,  0.3959,
         0.6304,  0.3433,  1.2666, -1.7570,  0.7388, -0.0683, -1.1832,  0.2739,
         1.0795,  0.5357, -0.0813,  0.1950, -0.0322,  0.2982, -1.5255,  0.5604,
         0.1549, -0.2992,  0.0666,  0.0525, -1.1806, -0.6942,  0.3070, -0.6814,
         1.4881,  0.4291,  1.1495, -0.2090,  0.0643, -1.7271,  0.4208,  0.7165,
         2.4250,  2.2224,  0.3455, -0.2024, -0.6149,  0.6268, -0.3841, -0.3493,
         1.4415,  1.0075,  1.616

input:  tensor([-0.8467, -1.6892,  3.2016,  1.4408,  0.0467,  1.0292, -2.2968,  0.7602,
        -0.5187, -0.5010,  0.7584, -0.9053,  1.3836, -2.1331,  0.8453, -1.0967,
        -0.0484, -0.6404, -1.2127, -0.6120, -0.5505,  0.6289, -0.0200, -1.9444,
         1.7948,  1.7444,  1.0262,  1.5052, -0.5746, -0.8528, -0.1328,  0.6217,
         1.5925,  1.5349, -0.3052, -0.0918, -0.2944, -0.9253, -0.5546,  0.5540,
        -0.7967, -1.1423, -0.6112,  0.4851,  0.6963, -0.4579, -1.0087, -0.2497,
        -0.7815,  0.3194, -0.8156,  0.7245, -0.2217,  0.1612, -2.0000,  0.3959,
         0.6304,  0.3433,  1.2666, -1.7570,  0.7388, -0.0683, -1.1832,  0.2739,
         1.0795,  0.5357, -0.0813,  0.1950, -0.0322,  0.2982, -1.5255,  0.5604,
         0.1549, -0.2992,  0.0666,  0.0525, -1.1806, -0.6942,  0.3070, -0.6814,
         1.4881,  0.4291,  1.1495, -0.2090,  0.0643, -1.7271,  0.4208,  0.7165,
         2.4250,  2.2224,  0.3455, -0.2024, -0.6149,  0.6268, -0.3841, -0.3493,
         1.4415,  1.0075,  1.616

array([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4], dtype=int64)

In [48]:
# convert tokens to text
predicted_caption = []
for token_index in sampled_indices:
    word = vocabulary.i2w[token_index]
    predicted_caption.append(word)
    if word == "<end>":
        break
predicted_setence = " ".join(predicted_caption) 
predicted_setence

'a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a'