In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [1]:
'''Import modules'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
from collections import Counter
from skimage import io, transform
from torch.nn.utils.rnn import pack_padded_sequence
from torchsummary import summary

import matplotlib.pyplot as plt # for plotting
import numpy as np
from time import time
import collections
import pickle
import os
import gensim
import nltk
from PIL import Image

In [2]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Aman\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device =", device)
print("Using", torch.cuda.device_count(), "GPUs!")
platform = "local" #colab/local
restore = False #Restore Checkpoint
phase = "Train"

Device = cpu
Using 0 GPUs!


In [4]:
VOCAB = {}
WORD2IDX = {}
IDX2WORD = {}
TRAIN_CAPTIONS_DICT = {}

In [5]:
class BuildVocab:
    def __init__(self, captions_file_path):
        self.captions_file_path = captions_file_path
        self.raw_captions_dict = self.read_raw_captions()
        # Preprocess captions
        self.captions_dict = self.process_captions()
        
        # Create vocabulary
        self.start = "<start>"
        self.end = "<end>"
        self.oov = "<unk>"
        self.pad = "<pad>"
        self.vocab = self.generate_vocabulary()
        self.word2index = self.convert_word2index()        
        self.index2word = self.convert_index2word()
        
    def read_raw_captions(self):
        """
        Returns:
            Dictionary with raw captions list keyed by image ids (integers)
        """
        captions_dict = {}
        with open(self.captions_file_path, 'r', encoding='utf-8') as f:
            for img_caption_line in f.readlines():
                img_captions = img_caption_line.strip().split('\t')
                captions_dict[int(img_captions[0])] = img_captions[1:]
        
        TRAIN_CAPTIONS_DICT.clear()
        TRAIN_CAPTIONS_DICT.update(captions_dict)        

        return captions_dict 

    def process_captions(self):
        """
        Use this function to generate dictionary and other preprocessing on captions
        """

        raw_captions_dict = self.raw_captions_dict 
        
        # Do the preprocessing here                
        captions_dict = raw_captions_dict

        return captions_dict
    
    def generate_vocabulary(self):
        
        captions_dict = self.captions_dict
        
        all_captions = ""        
        for cap_lists in captions_dict.values():
            all_captions += " ".join(cap_lists)
        
        all_captions = nltk.tokenize.word_tokenize(all_captions.lower())
        all_captions.sort()
        
        vocab = {self.pad :1, self.oov :1, self.start :1, self.end :1}
        vocab_update = Counter(all_captions) 
        vocab_update = {k:v for k,v in vocab_update.items() if v >= freq_threshold}         
        
        vocab.update(vocab_update)

        VOCAB.clear()
        VOCAB.update(vocab)
        
        if platform == "colab":
            fname = '/content/drive/My Drive/A4/dict/VOCAB_comp.pkl'
        else:
            fname = '../dict/VOCAB_comp.pkl'

        with open(fname, 'wb') as handle:
            pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
        print("VOCAB SIZE =", len(VOCAB))
        return vocab
    
    def convert_word2index(self):
        """
        word to index converter
        """
        word2index = {}
        vocab = self.vocab
        idx = 0
        words = vocab.keys()
        for w in words:
            word2index[w] = idx
            idx +=1

        WORD2IDX.clear()
        WORD2IDX.update(word2index)
        if platform == "colab":
            fname = '/content/drive/My Drive/A4/dict/WORD2IDX_comp.pkl'
        else:
            fname = '../dict/WORD2IDX_comp.pkl'
        #if not os.path.isfile(fname):
        with open(fname, 'wb') as handle:
            pickle.dump(word2index, handle, protocol=pickle.HIGHEST_PROTOCOL)
        return word2index
    
    def convert_index2word(self):
        """
        index to word converter
        """
        index2word = {}
        w2i = self.word2index
        idx = 0
        
        for k, v in w2i.items():
            index2word[v] = k
            
        IDX2WORD.clear()
        IDX2WORD.update(index2word)
        if platform == "colab":
            fname = '/content/drive/My Drive/A4/dict/IDX2WORD_comp.pkl'
        else:
            fname = '../dict/IDX2WORD_comp.pkl'
        #if not os.path.isfile(fname):
        with open(fname, 'wb') as handle:
            pickle.dump(index2word, handle, protocol=pickle.HIGHEST_PROTOCOL)
        return index2word

if platform == "colab":
    CAPTIONS_FILE_PATH = '/content/drive/My Drive/A4/train_captions.tsv'
else:
    #CAPTIONS_FILE_PATH = "D:/Padhai/IIT Delhi MS(R)/2019-20 Sem II/COL774 Machine Learning/Assignment/Assignment4/train_captions.tsv"
    CAPTIONS_FILE_PATH = "../data/train_cap64.tsv"

freq_threshold = 5

if phase == "Train":
    BuildVocab(CAPTIONS_FILE_PATH)
else:
    VOCAB.clear()
    WORD2IDX.clear()
    IDX2WORD.clear()
    if platform != 'colab':
        with open('../dict/VOCAB.pkl', 'rb') as handle:
            VOCAB = pickle.load(handle)
        with open('../dict/WORD2IDX.pkl', 'rb') as handle:
            WORD2IDX = pickle.load(handle)
        with open('../dict/IDX2WORD.pkl', 'rb') as handle:
            IDX2WORD = pickle.load(handle)
        print("Vocab Loaded Successfully")
    else:
        with open('/content/drive/My Drive/A4/dict/VOCAB.pkl', 'rb') as handle:
            VOCAB = pickle.load(handle)
        with open('/content/drive/My Drive/A4/dict/WORD2IDX.pkl', 'rb') as handle:
            WORD2IDX = pickle.load(handle)
        with open('/content/drive/My Drive/A4/dict/IDX2WORD.pkl', 'rb') as handle:
            IDX2WORD = pickle.load(handle)
        print("Vocab Loaded Successfully")
    print("VOCAB SIZE =", len(VOCAB))


VOCAB SIZE = 129


In [6]:
# Define a transform to pre-process the training images.
img_transform = transforms.Compose([ 
    transforms.Resize(256),                          # smaller edge of image resized to 256
    transforms.RandomCrop(224),                      # get 224x224 crop from random location
    transforms.RandomHorizontalFlip(),               # horizontally flip image with probability=0.5
    transforms.ToTensor(),                           # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),      # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))])

    

In [11]:
class ImageCaptionsDataset(Dataset):

    def __init__(self, img_dir, captions_dict, img_transform=None):
        """
        Args:
            img_dir (string): Directory with all the images.
            captions_dict: Dictionary with captions list keyed by image ids (integers)
            img_transform (callable, optional): Optional transform to be applied
                on the image sample.

            captions_transform: (callable, optional): Optional transform to be applied
                on the caption sample (list).
        """
        self.img_dir = img_dir
        self.captions_dict = captions_dict
        self.img_transform = img_transform
        images = os.listdir(os.path.join(img_dir))
        images = [i.split("_")[1][:-4] for i in images]
        images = [int(i) for i in images]
        images.sort()
        self.image_ids = images

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        #print('IMG No.', self.image_ids[idx])
        img_name = os.path.join(self.img_dir, 'image_{}.jpg'.format(self.image_ids[idx]))
        image = Image.open(img_name).convert('RGB')
        #print("RAW IMG", image.shape)
        if self.img_transform:
            image = self.img_transform(image)
            
        start = "<start>"
        end = "<end>"
        oov = "<unk>"
        
        captions = self.captions_dict[self.image_ids[idx]]
        processed_captions = list(map(lambda x: nltk.tokenize.word_tokenize(x.lower()) ,captions))
        
        processed_captions = list(map(lambda x: [start]+ x + [end], processed_captions))
        processed_captions = list(map(lambda x: list(map(lambda y: WORD2IDX[y] if y in VOCAB else WORD2IDX[oov],x)),
                                  processed_captions))
        #processed_captions = list(map(lambda x: torch.LongTensor(x) , processed_captions))
        #print(processed_captions)
        sample = {'idx':idx, 'image': image, 'captions': processed_captions}

        return sample
    
def custom_batch(batch):
    batch_size = len(batch)
    captions = []
    normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    
    list(map(lambda b: captions.extend(b['captions']),batch))    
    x = list(map(lambda b: b['image'].unsqueeze(0),batch)) 
    #print(x)
    captions = list(map(lambda c: torch.LongTensor(c),captions))
    lengths = list(map(lambda c: len(c),captions))
    captions = pad_sequence(captions, batch_first=True)
    images = torch.cat(x, dim=0)
    
    sample = {'image': images, 'captions': captions}    
    return sample

In [8]:
class Encoder(nn.Module):
    def __init__(self, embed_dim):
        super(Encoder, self).__init__()
        resnet50 = models.resnet50(pretrained=True, progress=True)        
        
        for param in resnet50.parameters():
            param.requires_grad = False
        self.resnet50 = resnet50
        self.fc2 = nn.Linear(in_features=1000, out_features = embed_dim)
        print("EMBED DIM =", embed_dim)
        print("resnet50 Loaded Successfully..!")

    def forward(self, x):
        x = self.resnet50(x)
        x = self.fc2(x)
        return x

class Decoder(nn.Module):
    def __init__(self, embed_dim, hidden_units, lstm_layers = 1):
        super(Decoder, self).__init__()
        vocab_size = len(VOCAB)
        print("VOCAB SIZE DECODER INIT =", vocab_size)
        
        self.lstm = nn.LSTM(input_size = embed_dim, hidden_size = hidden_units,
                            num_layers = lstm_layers, batch_first = True)
        
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.linear = nn.Linear(hidden_units, vocab_size)
    
    def forward(self, image_features, image_captions):
        
        embedded_captions = self.embed(image_captions)
        input_lstm = torch.cat((image_features, embedded_captions[:,:-1]), dim = 1)
        #print("LSTM INPUT SHAPE", input_lstm.shape)
        lstm_outputs, _ = self.lstm(input_lstm)
        
        lstm_outputs = self.linear(lstm_outputs)
        
        return lstm_outputs

In [9]:
'''Save and Restore Checkpoints'''
def create_checkpoint(path,model, optim_obj, loss_obj,iteration, epoch):
    checkpoint = {'epoch': epoch,
                  'iteration': iteration,
                  'model_state_dict': model.state_dict()}

    if platform == "colab":
        directory = '/content/drive/My Drive/A4/review_cp/'
    else:
        directory = '../review_cp/'

    torch.save(checkpoint, directory + path)
    
def restore_checkpoint(path):
    new_state_dict = collections.OrderedDict()
    if platform == "colab":
        directory = '/content/drive/My Drive/A4/review_cp/'
        checkpoint = torch.load(directory + path, map_location=torch.device('cpu'))
    else:
        directory = '../review_cp/'
        checkpoint = torch.load(directory + path, map_location=torch.device('cpu'))    
    
    epoch = checkpoint['epoch']
    new_state_dict = checkpoint['model_state_dict']
    iteration = checkpoint['iteration']
    print("Iterations = {}, Epoch = {}".format(iteration, epoch))
    return new_state_dict

In [10]:
if platform == "colab":
    IMAGE_DIR = '/content/drive/My Drive/train_images/'
else:
    #IMAGE_DIR = 'D:/Padhai/IIT Delhi MS(R)/2019-20 Sem II/COL774 Machine Learning/Assignment/Assignment4/train_images/'
    IMAGE_DIR = "../data/train/"
train_dataset = ImageCaptionsDataset(
    IMAGE_DIR, TRAIN_CAPTIONS_DICT, img_transform=img_transform
)
NUMBER_OF_EPOCHS = 3
LEARNING_RATE = 1e-3
BATCH_SIZE = 2
NUM_WORKERS = 0 # Parallel threads for dataloading
EMBED_DIM = 256
HIDDEN_UNITS = 512
VOCAB_SIZE = len(VOCAB)
# Creating the DataLoader for batching purposes
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
                          collate_fn=custom_batch)


encoder = Encoder(EMBED_DIM)
decoder = Decoder(EMBED_DIM, HIDDEN_UNITS)

if device == "cuda":
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    loss_function = loss_function.cuda()
    print("ENCODER DECODER AND LOSS FUN. TO CUDA...!")


decoder_params = sum(p.numel() for p in decoder.parameters())
encoder_total = sum(p.numel() for p in encoder.parameters())
encoder_trainable_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
encoder_fc2_params = sum(p.numel() for p in encoder.fc2.parameters())


paramaters = list(decoder.parameters()) + list(encoder.fc2.parameters())
optimizer = optim.Adam(paramaters, lr=LEARNING_RATE)
params_for_adam = sum(p.numel() for p in paramaters)

loss_function = nn.CrossEntropyLoss(ignore_index=WORD2IDX["<pad>"])

print("DECODER PARAMS ={}, ENCODER TOTAL PARAMS ={}, ENCODER TRAINABLE PARAMS ={}, ENCODER FC2 PARAMS ={}, TOTAL ADAM PARAMS: {}"
      .format(decoder_params, encoder_total, encoder_trainable_params, encoder_fc2_params, params_for_adam))
print("TOTAL EPOCHS: {}, BATCH SIZE: {}, OPTIMIZER: {}".format(NUMBER_OF_EPOCHS, BATCH_SIZE, optimizer))


if device != "cpu":
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    torch.backends.cudnn.benchmark = True
t0 = time()
for epoch in range(NUMBER_OF_EPOCHS):
    print("$$$$$----EPOCH {}----$$$$$$".format(epoch+1))
    iteration = 0
    
    for batch_idx, sample in enumerate(train_loader):
        iteration +=1
        if iteration%25 == 0:
            LEARNING_RATE *= 0.98
            for param_group in optimizer.param_groups:
                param_group['lr'] = LEARNING_RATE
            print("----LEARNING RATE =", LEARNING_RATE)
            
            
        image_batch, captions_batch = sample['image'], sample['captions']
        print("TRAIN INPUTS", image_batch.shape, captions_batch.shape)
        #print(image_batch)
        #print(captions_batch)
        optimizer.zero_grad()
        
        if device != "cpu":
            image_batch, captions_batch = image_batch.cuda(), captions_batch.cuda()
        
        image_features = encoder(image_batch)
        #print("IMAGE FEATURES:",image_features)
        image_features = torch.Tensor.repeat_interleave(image_features, repeats=5 , dim=0)
        #print("IMAGE INTERLEAVE FEATURES:",image_features)
        image_features = image_features.unsqueeze(1)
        
        decoder_op = decoder(image_features, captions_batch)
        #print("DECODER OP", decoder_op.shape)
        #print(decoder_op)
        #print("LOSS FUN INPUT", decoder_op.view(-1, VOCAB_SIZE).shape, captions_batch.view(-1).shape)
        loss = loss_function(decoder_op.view(-1, VOCAB_SIZE), captions_batch.view(-1))
        # Backward pass.
        loss.backward()
        
        # Update the parameters in the optimizer.
        optimizer.step()
        
        if iteration%50 == 0:
            create_checkpoint("encoder_review.pth", encoder, optimizer, loss, iteration, epoch+1)
            create_checkpoint("decoder_review.pth", decoder, optimizer, loss, iteration, epoch+1)
            
        print("ITERATION:[{}/{}] | LOSS: {} | EPOCH = [{}/{}] | TIME ELAPSED ={}Mins".format(iteration, round(29000/BATCH_SIZE)+1,
              round(loss.item(), 6), epoch+1, NUMBER_OF_EPOCHS, round((time()-t0)/60,2)))
    print("\n$$Loss = {},EPOCH: [{}/{}]\n\n".format(round(loss.item(), 6), epoch+1, NUMBER_OF_EPOCHS))
    create_checkpoint("encoder_review_epoch.pth", encoder, optimizer, loss, iteration, epoch+1)
    create_checkpoint("decoder_review_epoch.pth", decoder, optimizer, loss, iteration, epoch+1)

create_checkpoint("encoder_review_final.pth", encoder, optimizer, loss, iteration, epoch+1)
create_checkpoint("decoder_review_final.pth", decoder, optimizer, loss, iteration, epoch+1)


EMBED DIM = 256
resnet50 Loaded Successfully..!
VOCAB SIZE DECODER INIT = 129
DECODER PARAMS =1676161, ENCODER TOTAL PARAMS =25813288, ENCODER TRAINABLE PARAMS =256256, ENCODER FC2 PARAMS =256256, TOTAL ADAM PARAMS: 1932417
TOTAL EPOCHS: 3, BATCH SIZE: 2, OPTIMIZER: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)
$$$$$----EPOCH 1----$$$$$$
IMG No. 28
IMG No. 31
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 22])
ITERATION:[1/14501] | LOSS: 4.867113 | EPOCH = [1/3] | TIME ELAPSED =0.01Mins
IMG No. 47
IMG No. 48
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 20])
ITERATION:[2/14501] | LOSS: 4.726323 | EPOCH = [1/3] | TIME ELAPSED =0.02Mins
IMG No. 37
IMG No. 2
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 23])
ITERATION:[3/14501] | LOSS: 4.506893 | EPOCH = [1/3] | TIME ELAPSED =0.04Mins
IMG No. 49
IMG No. 32
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 25])
ITERATION:[4/1

TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 17])
ITERATION:[15/14501] | LOSS: 2.146194 | EPOCH = [2/3] | TIME ELAPSED =0.86Mins
IMG No. 31
IMG No. 28
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 22])
ITERATION:[16/14501] | LOSS: 2.469762 | EPOCH = [2/3] | TIME ELAPSED =0.87Mins
IMG No. 49
IMG No. 39
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 29])
ITERATION:[17/14501] | LOSS: 2.434672 | EPOCH = [2/3] | TIME ELAPSED =0.89Mins
IMG No. 42
IMG No. 21
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 21])
ITERATION:[18/14501] | LOSS: 2.393933 | EPOCH = [2/3] | TIME ELAPSED =0.91Mins
IMG No. 30
IMG No. 38
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 18])
ITERATION:[19/14501] | LOSS: 1.909273 | EPOCH = [2/3] | TIME ELAPSED =0.93Mins
IMG No. 1
IMG No. 20
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 16])
ITERATION:[20/14501] | LOSS: 1.916919 | EPOCH = [2/3] | TIME ELAPSED =0.95Mins
IMG No. 61
IMG No. 32
TRAIN INPUTS torc

ITERATION:[30/14501] | LOSS: 2.367513 | EPOCH = [3/3] | TIME ELAPSED =1.87Mins
IMG No. 47
IMG No. 32
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 20])
ITERATION:[31/14501] | LOSS: 2.407211 | EPOCH = [3/3] | TIME ELAPSED =1.88Mins
IMG No. 40
IMG No. 61
TRAIN INPUTS torch.Size([2, 3, 224, 224]) torch.Size([10, 25])
ITERATION:[32/14501] | LOSS: 2.304646 | EPOCH = [3/3] | TIME ELAPSED =1.9Mins

$$Loss = 2.304646,EPOCH: [3/3]


