# Baseline Description Generation Model

In this notebook is the implementation of the baseline model for a description generation model. This model is based on an encoder-decoder model, where the encoder is a CNN and the decoder is a LSTM-RNN language model. 

The encoder will be pretrained on a dataset with a relative small number of labels, for the classification task. After the pretraining, both the encoder and decoder are jointly trained for the task of generating descriptions. 

For the current baseline a simple implementation will be used without any form of attention. 

## import packages

In [2]:
# loadbars to track the run/speed
from tqdm import tqdm_notebook, tnrange

# numpy for arrays/matrices/mathematical stuff
import numpy as np
np.set_printoptions(threshold=np.nan) #will print entire matrix without dots...

# nltk for tokenizer
from nltk.tokenize import wordpunct_tokenize   

# torch for the NN stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD

# torch tools for data processing
from torch.utils.data import DataLoader
import pycocotools #cocoAPI

# torchvision for the image dataset and image processing
from torchvision.datasets import CocoCaptions
from torchvision import transforms
from torchvision import models

# packages for plotting
import matplotlib.pyplot as plt
import seaborn

# additional stuff
import dill
import pickle
from collections import Counter
from collections import defaultdict
import os
from datetime import datetime

In [14]:
resnet = models.resnet152(pretrained=True)
list(resnet.children())[:-1]

[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace),
 MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
 Sequential(
   (0): Bottleneck(
     (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace)
     (downsample): Sequential(
       (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affi

#### test if device has GPU

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
device

device(type='cpu')

## Hyper Parameters

In the code many hyper parameters will be used. For instance the file locations, dimensions for the networks layers, etc.

In [4]:
learning_rate = 1e-1
max_epochs = 30
batch_size = 16

vocab_size = 30000
embedding_size = 2048

save_step = 100

PAD = '<PAD>'
START = '<START>'
END = '<END>'
UNK = '<UNK>'

crop_size = 224
transform = transforms.Compose([ 
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(), 
            transforms.ToTensor(), 
            transforms.Normalize((0.485, 0.456, 0.406), 
                                 (0.229, 0.224, 0.225))])

## Data Processing

First we load the data from the COCO captions Dataset

In [5]:
temp_data = CocoCaptions(root = '/home/victor/coco/images/train2014/',annFile = '/home/victor/coco/annotations/captions_train2014.json', transform=transforms.ToTensor())
train_data = CocoCaptions(root = '/home/victor/coco/images/train2014/',annFile = '/home/victor/coco/annotations/captions_train2014.json', transform=transform)
val_data = CocoCaptions(root = '/home/victor/coco/images/val2014/',annFile = '/home/victor/coco/annotations/captions_val2014.json', transform=transform)

loading annotations into memory...
Done (t=1.53s)
creating index...
index created!
loading annotations into memory...
Done (t=0.83s)
creating index...
index created!
loading annotations into memory...
Done (t=0.72s)
creating index...
index created!


A vocabulary class is created to keep track of words in the dataset.

In [6]:
class DataProcessor():
    def __init__(self, data, vocab_size, filename=None):
        self.vocab_size = vocab_size
        if filename == None:
            filename = 'vocab_'+str(self.vocab_size)+'.pkl'
        self.filename = filename
        if os.path.isfile(self.filename):
            self.vocab, self.vocab_size, self.vocab_weight = self.load(data)
        else: 
            self.vocab, self.vocab_size, self.vocab_weight = self.build_vocab(data)            
        self.w2i, self.i2w = self.build_dicts()
    
    def build_dicts(self):
        """
        creates lookup tables to find the index given the word 
        and the otherway around 
        """
        w2i = defaultdict(lambda: w2i[UNK])
        i2w = dict()
        for i,w in enumerate(self.vocab):
            i2w[i] = w
            w2i[w] = i
        return w2i, i2w
    
    def build_vocab(self, data): 
        """
        builds a vocabulary with the most occuring words, in addition to
        the UNK token at index 1 and PAD token at index 0. 
        START and END tokens are added to the vocabulary through the
        preprocessed sentences.
        with vocab size none, all existing words in the data are used
        """
        vocab = Counter()
        for item in tqdm_notebook(data):
            for sent in item[1]:
                s = wordpunct_tokenize(sent[0].lower())
                for w in s:
                    vocab[w] += 1

        vocab = [k for k,_ in vocab.most_common(self.vocab_size - 4)] #minus 4 because of the default tokens
        vocab_weights = list(range(len(vocab)))
        vocab = [PAD,UNK,START,END] + vocab # padding needs to be first, because of the math
        vocab_weights = [0.,1.,1.,1.] + vocab_weights
        return vocab,len(vocab), vocab_weights 
    
    def save(self):
        pickle.dump(self.vocab, open(self.filename, 'wb'))
        
    def load(self):
        pickle.load(vocab, open(self.filename, 'rb'))
        vocab_size = len(vocab)
        vocab_weights = [0.,1.,1.,1.] + list(range(len(vocab)))
        return vocab, vocab_size, vocab_weights
        

### function for preparing the batch in correct format

In [7]:
def transform_batch(batch, processor):
    """
    input batch: a list of tuples of sentences. 
    the lenght of the list is the number of sentences for an image. 
    the length of the tuple is the batch size.
    
    output batch: a tensor with for each image one of the sentences randomly chosen. 
    the first dim is the batchsize. second dim is the sentence length. 
    the sentences are padded with zeros and prefixed and post fixed with the 
    START and END token. The words are transformed to indices. 
    """
    chosen_sents = []
    sent_lengths = []
    longest = -1
    for sample in range(len(batch[0])):
        sentnum = np.random.choice(len(batch))
        s = [START] + wordpunct_tokenize(batch[sentnum][sample].lower()) + [END]
        l = len(s)
        chosen_sents.append(s)
        sent_lengths.append(l)
        if longest < l:
            longest = l

    trans_batch = np.zeros((len(chosen_sents), longest))
    for i,s in enumerate(chosen_sents):
        trans_batch[i,:len(s)] = np.array([processor.w2i[w] for w in s])
    batch = torch.from_numpy(trans_batch).type(torch.LongTensor).to(device)
    sent_lengths = torch.FloatTensor(sent_lengths).to(device)
    return batch, sent_lengths

## Encoder

The encoder is a CNN which first is pretrained on the image classification task. Once pretrained, it will be used for encoding in an vector representation.

This can be extended to deviding the image into a grid, where each gridcell is encoded into a vector. During decoding, an attention can then be used over the grid vectors. 

In [8]:
class EncoderCNN(nn.Module):
    def __init__(self, embedding_size):
        super().__init__()
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embedding_size)
        self.batchnorm = nn.BatchNorm1d(embedding_size)
        
    def forward(self, x):
        # the resnet is pretrained, so turn of the gradient
        with torch.no_grad():
            out = self.resnet(x)
        out = out.reshape(out.size(0), -1)
        out = self.linear(out)
        out = self.batchnorm(out)
        return out

## Decoder

The decoder is a LSTM-RNN which for each timestep generates a single word. In the first step, the hidden layer is initialized with the encoded vector. 

In [9]:
class Decoder(nn.Module):
    def __init__(self, target_vocab_size, embedding_size):
        super().__init__()
        
        self.embedding_size = embedding_size
    
        self.target_embeddings = nn.Embedding(target_vocab_size, embedding_size)
        self.LSTM = nn.LSTM(embedding_size, embedding_size)
        self.logit_lin = nn.Linear(embedding_size, target_vocab_size) # out
        
    def forward(self, input_words, hidden_input):  
        # find the embedding of the correct word to be predicted
        emb = self.target_embeddings(input_words)
        # reshape to the correct order for the LSTM
        emb = emb.view(1,emb.size(0),self.embedding_size)
        # Put through the next LSTM step
        lstm_output, hidden = self.LSTM(emb, hidden_input)
        output = self.logit_lin(lstm_output)

        return output, hidden

## Encoder-Decoder

A single model is created to tie both the networks together

In [10]:
class CaptionModel(nn.Module):
    def __init__(self, 
                 embedding_size,
                 target_vocab_size,
                 device,
                 num_topics = 100):
        
        super().__init__()
        self.target_vocab_size = target_vocab_size
        
        self.encoder = EncoderCNN(embedding_size).to(device)
        self.decoder = Decoder(target_vocab_size,embedding_size).to(device)
        
        self.topic_embeddings = nn.Embedding(num_topics, embedding_size)
        self.topic_lin1 = nn.Linear(embedding_size, embedding_size)
        self.topic_lin2 = nn.Linear(embedding_size, embedding_size)
        self.topic_relu = nn.ReLU()

        self.loss = nn.CrossEntropyLoss(ignore_index=0, reduce=False).to(device)

    def forward(self,images, captions, caption_lengths):     
        #topic modelling
        topic_r1 = self.topic_relu(self.topic_lin1(images))
        topic_r2 = self.topic_relu(self.topic_lin2(r1))
        
        # Encode
        h0 = self.encoder(images)
        
        #prepare decoder initial hidden state
        h0 = h0.unsqueeze(0)
        c0 = torch.zeros(h0.shape)
        hidden_state = (h0,c0)
        
        # Decode
        batch_size, max_sent_len = captions.shape
        out = torch.zeros((batch_size))  
        for w_idx in range(max_sent_len-1):
            # binary switch
            torch.cat(images, hidden_state)
            # language model
            prediction, hidden_state = self.decoder(captions[:,w_idx].view(-1,1), hidden_state)
            out += self.loss(prediction.squeeze(0), captions[:,w_idx+1])
        
        #normalize loss
        out = torch.mean(torch.div(out,caption_lengths))  # the loss is the average of losses, so divide over number of words in each sentence
        
        return out

## Setup Network

the model is initialised and the optimizer for the model is set. 

In [11]:
caption_model = CaptionModel(embedding_size, vocab_size, device)
caption_model.train(True) #probably not needed. better to be safe
opt = SGD(caption_model.parameters(), lr=learning_rate)

KeyboardInterrupt: 

An dataprocessor is created. If a pickle with the given vocabsize already exists, it is loaded, otherwise a new one is created. 

In [13]:
# setup dataloaders with train and val data
temploader = DataLoader(dataset=temp_data, batch_size=1, shuffle=False, drop_last=False, num_workers=1)
processor = DataProcessor(data=temploader, vocab_size=vocab_size)
processor.save()
del(temploader)
del(temp_data)


HBox(children=(IntProgress(value=0, max=82783), HTML(value='')))




Dataloader for processing the data for both the training and validation data are loaded. 

In [55]:
trainloader = DataLoader(dataset=train_data, batch_size=2, shuffle=True, drop_last=True, num_workers=4)
valloader = DataLoader(dataset=val_data, batch_size=1, shuffle=True, drop_last=True, num_workers=4)

## Train

In [107]:
losses = []

opt.zero_grad()

#loop over number of epochs
for it in tnrange(1):
    batch_losses = []
    #loop over all the training batches
    for i_batch, (image, caption) in tqdm_notebook(enumerate(trainloader), total=len(trainloader),leave=False):
        image = image.to(device)
        caption, caption_lengths = transform_batch(caption, processor)
        loss = caption_model(image, caption, caption_lengths)
        loss.backward()
        batch_losses.append(float(loss))
        opt.step()
        if i_batch == 100:
            break
    losses += batch_losses

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=41391), HTML(value='')))

#### save the trained model

In [108]:
losses

[5.677347183227539,
 6.263830184936523,
 6.010639667510986,
 6.627231597900391,
 5.240501403808594,
 6.097297668457031,
 6.538796424865723,
 5.61591911315918,
 5.088732719421387,
 6.336756706237793,
 5.212302207946777,
 5.183477878570557,
 6.108344078063965,
 5.698338985443115,
 6.346169471740723,
 5.851977348327637,
 5.411420822143555,
 5.807036399841309,
 5.518843650817871,
 4.911383628845215,
 5.431375503540039,
 6.984222412109375,
 6.465095520019531,
 7.481247425079346,
 6.067309379577637,
 6.476211071014404,
 6.563053607940674,
 6.030815124511719,
 6.7606964111328125,
 6.8211259841918945,
 5.965514183044434,
 6.815525054931641,
 5.64166259765625,
 6.392036437988281,
 6.32445764541626,
 8.40359115600586,
 7.272234916687012,
 7.894472122192383,
 7.5314788818359375,
 5.755794525146484,
 7.293753623962402,
 6.357259750366211,
 5.608423233032227,
 8.047916412353516,
 6.749805450439453,
 6.04095458984375,
 7.458600044250488,
 7.317066669464111,
 7.891091346740723,
 7.4335126876831055,
 