In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [3]:
'''Import modules'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
from collections import Counter
from skimage import io, transform
from torch.nn.utils.rnn import pack_padded_sequence
from torchsummary import summary

import matplotlib.pyplot as plt # for plotting
import numpy as np
from time import time
import collections
import pickle
import os
import gensim
import nltk
from PIL import Image

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device =", device)
print("Using", torch.cuda.device_count(), "GPUs!")
platform = "local" #colab/local
restore = True #Restore Checkpoint
phase = "Test"

Device = cpu
Using 0 GPUs!


In [7]:
VOCAB = {}
WORD2IDX = {}
IDX2WORD = {}
TRAIN_CAPTIONS_DICT = {}

In [9]:
VOCAB.clear()
WORD2IDX.clear()
IDX2WORD.clear()
if platform != 'colab':
    with open('../dict/VOCAB_comp.pkl', 'rb') as handle:
        VOCAB = pickle.load(handle)
    with open('../dict/WORD2IDX_comp.pkl', 'rb') as handle:
        WORD2IDX = pickle.load(handle)
    with open('../dict/IDX2WORD_comp.pkl', 'rb') as handle:
        IDX2WORD = pickle.load(handle)
    print("Vocab Loaded Successfully")
else:
    with open('/content/drive/My Drive/A4/dict/VOCAB.pkl', 'rb') as handle:
        VOCAB = pickle.load(handle)
    with open('/content/drive/My Drive/A4/dict/WORD2IDX.pkl', 'rb') as handle:
        WORD2IDX = pickle.load(handle)
    with open('/content/drive/My Drive/A4/dict/IDX2WORD.pkl', 'rb') as handle:
        IDX2WORD = pickle.load(handle)
    print("Vocab Loaded Successfully")
print("VOCAB SIZE =", len(VOCAB))


Vocab Loaded Successfully
VOCAB SIZE = 8680


In [10]:
# Define a transform to pre-process the training images.
img_transform = transforms.Compose([ 
    transforms.Resize(256),                          # smaller edge of image resized to 256
    transforms.RandomCrop(224),                      # get 224x224 crop from random location
    transforms.RandomHorizontalFlip(),               # horizontally flip image with probability=0.5
    transforms.ToTensor(),                           # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),      # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))])

    

In [37]:
class ImageCaptionsDataset(Dataset):

    def __init__(self, img_dir, img_transform=None):
        """
        Args:
            img_dir (string): Directory with all the images.
            captions_dict: Dictionary with captions list keyed by image ids (integers)
            img_transform (callable, optional): Optional transform to be applied
                on the image sample.

            captions_transform: (callable, optional): Optional transform to be applied
                on the caption sample (list).
        """
        self.img_dir = img_dir
        self.img_transform = img_transform
        images = os.listdir(os.path.join(img_dir))
        images = [i.split("_")[1][:-4] for i in images]
        images = [int(i) for i in images]
        images.sort()
        self.image_ids = images

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        #print('IMG No.', self.image_ids[idx])
        img_name = os.path.join(self.img_dir, 'image_{}.jpg'.format(self.image_ids[idx]))
        image = Image.open(img_name).convert('RGB')
        #print("RAW IMG", image.shape)
        if self.img_transform:
            image = self.img_transform(image)
        
        sample = {'idx':self.image_ids[idx], 'image': image}

        return sample


In [38]:
class Encoder(nn.Module):
    def __init__(self, embed_dim):
        super(Encoder, self).__init__()
        resnet50 = models.resnet50(pretrained=True, progress=True)        
        
        for param in resnet50.parameters():
            param.requires_grad = False
        self.resnet50 = resnet50
        self.fc2 = nn.Linear(in_features=1000, out_features = embed_dim)
        print("EMBED DIM =", embed_dim)
        print("resnet50 Loaded Successfully..!")

    def forward(self, x):
        x = self.resnet50(x)
        x = self.fc2(x)
        return x

class Decoder(nn.Module):
    def __init__(self, embed_dim, hidden_units, lstm_layers = 1):
        super(Decoder, self).__init__()
        vocab_size = len(VOCAB)
        print("VOCAB SIZE DECODER INIT =", vocab_size)
        
        self.lstm = nn.LSTM(input_size = embed_dim, hidden_size = hidden_units,
                            num_layers = lstm_layers, batch_first = True)
        
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.linear = nn.Linear(hidden_units, vocab_size)
    
    def forward(self, image_features, image_captions):
        
        embedded_captions = self.embed(image_captions)
        input_lstm = torch.cat((image_features, embedded_captions[:,:-1]), dim = 1)
        #print("LSTM INPUT SHAPE", input_lstm.shape)
        lstm_outputs, _ = self.lstm(input_lstm)
        
        lstm_outputs = self.linear(lstm_outputs)
        
        return lstm_outputs

In [39]:
'''Save and Restore Checkpoints'''
def create_checkpoint(path,model, iteration, epoch):
    checkpoint = {'epoch': epoch,
                  'iteration': iteration,
                  'model_state_dict': model.state_dict()}

    if platform == "colab":
        directory = '/content/drive/My Drive/A4/review_cp/'
    else:
        directory = '../review_cp/'

    torch.save(checkpoint, directory + path)
    
def restore_checkpoint(path):
    new_state_dict = collections.OrderedDict()
    if platform == "colab":
        directory = '/content/drive/My Drive/A4/review_cp/'
        checkpoint = torch.load(directory + path, map_location=torch.device('cpu'))
    else:
        directory = '../review_cp/'
        checkpoint = torch.load(directory + path, map_location=torch.device('cpu'))    
    
    epoch = checkpoint['epoch']
    new_state_dict = checkpoint['model_state_dict']
    iteration = checkpoint['iteration']
    print("Iterations = {}, Epoch = {}".format(iteration, epoch))
    return new_state_dict

In [52]:
from nltk.tokenize.treebank import TreebankWordDetokenizer

def predict_captions(image_feature, max_words):
    z = image_feature.unsqueeze(0)
    results = []
    states = None

    #print(x)
    with torch.no_grad():
        for i in range(max_words):

            hiddens, states = decoder.lstm(z, states)
            #print(hiddens.shape)
            decoder_op = decoder.linear(hiddens.squeeze(1))
            predicted_word = decoder_op.argmax(1)
            
            z = decoder.embed(predicted_word)
            z = z.unsqueeze(0)

            word = predicted_word.item()
            results.append(word)

            '''if predicted_word == WORD2IDX["<end>"]:
                break'''
    

    caption = [IDX2WORD[i] for i in  results]
    #caption.remove("<start>")
    #caption.remove("<end>")
    caption = TreebankWordDetokenizer().detokenize(caption)
    caption.replace("<start>", "").replace("<end>", "").replace("<unk>","")
    #caption.remove("<end>")
    return caption

In [54]:
if platform == "colab":
    IMAGE_DIR = '/content/drive/My Drive/train_images/'
else:
    #IMAGE_DIR = 'D:/Padhai/IIT Delhi MS(R)/2019-20 Sem II/COL774 Machine Learning/Assignment/Assignment4/train_images/'
    IMAGE_DIR = 'D:/Padhai/IIT Delhi MS(R)/2019-20 Sem II/COL774 Machine Learning/Assignment/Assignment4/private_test_images/'

test_dataset = ImageCaptionsDataset(IMAGE_DIR, img_transform=img_transform)

NUM_WORKERS = 0 # Parallel threads for dataloading
EMBED_DIM = 256
HIDDEN_UNITS = 512
VOCAB_SIZE = len(VOCAB)
# Creating the DataLoader for batching purposes
train_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)


encoder = Encoder(EMBED_DIM)
decoder = Decoder(EMBED_DIM, HIDDEN_UNITS)

if restore:
    en_bkp_file = "encoder_review_final.pth"
    de_bkp_file = "decoder_review_final.pth"
    en_state_dict = collections.OrderedDict()
    en_state_dict = restore_checkpoint(en_bkp_file)
    
    de_state_dict = collections.OrderedDict()
    de_state_dict = restore_checkpoint(de_bkp_file)
    
    encoder.load_state_dict(en_state_dict)
    decoder.load_state_dict(de_state_dict)
    print("STATE DICTIONARIES LOADED SUCCESSFULLY...!!!")

decoder_params = sum(p.numel() for p in decoder.parameters())
encoder_total = sum(p.numel() for p in encoder.parameters())
encoder_trainable_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
encoder_fc2_params = sum(p.numel() for p in encoder.fc2.parameters())


if device == "cuda":
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    
    print("ENCODER and DECODER TO CUDA...!")


if device != "cpu":
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    torch.backends.cudnn.benchmark = True
t0 = time()
MAX_WORDS = 30
prediction_dict = {}
for batch_idx, sample in enumerate(train_loader):
  
    idx, image = sample['idx'], sample['image']
    if device != "cpu":
        image = image.cuda()
    #print("\nIMAGE:",image)
    image_features = encoder(image)
    #image_features = image_features.view(-1)[torch.randperm(image_features.nelement())].view(image_features.size())
    #print("\nIMAGE FEATURES:",image_features[0].tolist()[:5],image_features[0].tolist()[-5:])
    predicted_caption = predict_captions(image_features, MAX_WORDS)
    prediction_dict[idx.item()] = predicted_caption
    print(idx.item(), predicted_caption)


EMBED DIM = 256
resnet50 Loaded Successfully..!
VOCAB SIZE DECODER INIT = 8680
Iterations = 580, Epoch = 3
Iterations = 580, Epoch = 3
STATE DICTIONARIES LOADED SUCCESSFULLY...!!!
34 <start> un uomo con un cappello da cowboy e una camicia bianca e pantaloni neri sta camminando per strada . <end>. <end>. <end>. <end>. <end>.
51 <start> un uomo con un cappello da cowboy e una camicia bianca e pantaloni neri sta camminando per strada . <end>. <end>. <end>. <end>. <end>.
89 <start> un uomo con un cappello da cowboy e una camicia bianca e pantaloni neri sta camminando per strada . <end>. <end>. <end>. <end>. <end>.
97 <start> un uomo con un cappello da cowboy e una camicia bianca e pantaloni neri sta camminando per strada . <end>. <end>. <end>. <end>. <end>.
124 <start> un uomo con un cappello da cowboy e una camicia bianca e pantaloni neri sta camminando per strada . <end>. <end>. <end>. <end>. <end>.
133 <start> un uomo con un cappello da cowboy e una camicia bianca e pantaloni neri sta c

KeyboardInterrupt: 