In [None]:
Steps to train the model:
* Take a pre-trained inception v3 to vectorize images
* Stack an LSTM on top of it
* Train on MSCOCO

In [None]:
import numpy as np
import json
from random import choice
from collections import defaultdict, Counter
from textwrap import wrap
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.transform import resize
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm_notebook, tqdm
from IPython.display import clear_output
from models.beheaded_inception3 import beheaded_inception_v3
import warnings
warnings.filterwarnings("ignore")

In [None]:
# load dataset (vectorized images and captions)
img_codes = np.load('data/image_codes.npy')
# read the captions through a context manager so the file handle is closed
# as soon as parsing finishes instead of lingering until garbage collection
with open('data/captions_tokenized.json') as captions_file:
    captions = json.load(captions_file)

In [None]:
# split descriptions into tokens and wrap each one in #START# / #END# markers.
# The nested lists are updated in place; captions that are no longer plain strings
# (i.e. were already tokenized) are skipped so re-running this cell is harmless.
for image_captions in captions:
    for caption_i, sentence in enumerate(image_captions):
        if isinstance(sentence, str):
            image_captions[caption_i] = ["#START#"] + sentence.split(' ') + ["#END#"]

In [None]:
from collections import Counter

# count how many times every token occurs across all captions
word_counts = Counter()
for img in captions:
    for caption in img:
        for token in caption:
            word_counts[token] += 1

# vocabulary: the four special tokens first, then every token seen at least 5 times
# (the list is fully built before being appended, so the membership test only
# checks against the special tokens)
vocab = ['#UNK#', '#START#', '#END#', '#PAD#']
vocab.extend([word for word, count in word_counts.items() if count >= 5 and word not in vocab])
n_tokens = len(vocab)

# reverse lookup: token -> position in vocab
word_to_index = dict(zip(vocab, range(n_tokens)))

In [None]:
# function for converting list of tokens into matrix

# vocabulary indices of the end-of-sentence, unknown-word and padding tokens
eos_ix, unk_ix, pad_ix = (word_to_index[token] for token in ('#END#', '#UNK#', '#PAD#'))

def as_matrix(sequences, max_len=None):
    """
    Convert a list of token sequences into a padded int32 index matrix.

    :param sequences: list of token sequences (each a sequence of words)
    :param max_len: number of columns; longer sequences are truncated to it.
        When falsy (None or 0) the length of the longest sequence is used.
    :returns: numpy int32 array of shape [len(sequences), max_len] where words are
        mapped through word_to_index (unknown words become unk_ix) and unused
        trailing positions hold pad_ix. An empty `sequences` yields a (0, 0) matrix.
    """
    # default=0 keeps an empty batch from raising inside max()
    max_len = max_len or max(map(len, sequences), default=0)

    matrix = np.full((len(sequences), max_len), pad_ix, dtype='int32')
    for i, seq in enumerate(sequences):
        row_ix = [word_to_index.get(word, unk_ix) for word in seq[:max_len]]
        matrix[i, :len(row_ix)] = row_ix

    return matrix

# Model 

<img src="https://github.com/yunjey/pytorch-tutorial/raw/master/tutorials/03-advanced/image_captioning/png/model.png" style="width:70%">

In [None]:
class CaptionNet(nn.Module):
    """
    Recurrent captioning 'head': an LSTM language model whose initial state is
    computed from a CNN image feature vector.
    """

    def __init__(self, n_tokens, emb_size=128, lstm_units=256, cnn_feature_size=2048):
        """
        :param n_tokens: vocabulary size (number of embeddings and of output logits)
        :param emb_size: dimensionality of the word embeddings
        :param lstm_units: size of the LSTM hidden and cell states
        :param cnn_feature_size: length of the image feature vector fed to the network
        """
        super().__init__()

        # linear projections of the image features into the initial LSTM state:
        # hidden state (h_0) and cell state (c_0)
        self.cnn_to_h0 = nn.Linear(cnn_feature_size, lstm_units)
        self.cnn_to_c0 = nn.Linear(cnn_feature_size, lstm_units)

        # embedding for the input words
        self.embedding = nn.Embedding(n_tokens, emb_size)

        # recurrent core; batch_first=True so inputs are [batch, time, features]
        self.lstm = nn.LSTM(emb_size, lstm_units, batch_first=True)

        # maps every LSTM hidden state to one logit per vocabulary token
        self.rnn_to_logits = nn.Linear(lstm_units, n_tokens)

    def forward(self, image_vectors, captions_ix):
        """
        Apply the network in training mode.
        :param image_vectors: torch tensor containing inception vectors. shape: [batch, cnn_feature_size]
        :param captions_ix: torch tensor containing captions as matrix. shape: [batch, word_i].
            padded with pad_ix
        :returns: logits for next token at each tick, shape: [batch, word_i, n_tokens]
        """

        self.lstm.flatten_parameters()

        initial_hid = self.cnn_to_h0(image_vectors)
        initial_cell = self.cnn_to_c0(image_vectors)

        # compute embeddings for captions_ix
        emb_ix = self.embedding(captions_ix)

        # nn.LSTM takes its initial state as the tuple (h_0, c_0), in that order, each
        # with a leading num_layers axis (added here with [None]).
        # NOTE: weights trained with the two states passed in the opposite order are
        # not interchangeable with this ordering.
        lstm_out, _ = self.lstm(emb_ix, (initial_hid[None], initial_cell[None]))

        # compute logits from the hidden state sequence [batch, caption_length, lstm_units]
        logits = self.rnn_to_logits(lstm_out)

        return logits

In [None]:
def compute_loss(network, image_vectors, captions_ix):
    """
    Next-token cross-entropy of the captioning network on one batch.

    :param network: module called as network(image_vectors, captions) that returns
        logits of shape [batch, word_i, n_tokens]
    :param image_vectors: torch tensor containing inception vectors. shape: [batch, cnn_feature_size]
    :param captions_ix: torch tensor containing captions as matrix. shape: [batch, word_i].
        padded with pad_ix
    :returns: crossentropy (neg llh) loss for next captions_ix given previous ones. Scalar float tensor
    """

    # inputs are all tokens except the last one (nothing follows it);
    # targets are the same captions shifted one position to the left
    captions_ix_inp = captions_ix[:, :-1].contiguous()
    captions_ix_next = captions_ix[:, 1:].contiguous()

    # call the module itself rather than .forward() so registered hooks are executed
    logits_for_next = network(image_vectors, captions_ix_inp)

    # F.cross_entropy expects the class axis second: [batch, n_tokens, word_i];
    # positions whose target is pad_ix do not contribute to the loss
    loss = F.cross_entropy(
        logits_for_next.permute((0, 2, 1)),
        captions_ix_next,
        ignore_index=pad_ix
    )

    return loss

In [None]:
def generate_batch(img_codes, captions, batch_size, max_caption_len=None):
    """
    Draw a random training batch of image codes with one caption per image.

    :param img_codes: array of image feature vectors, indexable with an integer array
    :param captions: array holding, for every image, a collection of tokenized captions;
        must support integer-array indexing (e.g. a numpy array)
    :param batch_size: number of (image, caption) pairs to draw; images are drawn
        uniformly with replacement
    :param max_caption_len: forwarded to as_matrix as max_len
    :returns: tuple (float32 tensor with the selected image codes,
        int64 tensor with the caption matrix produced by as_matrix)
    """
    # random image indices (the same index may appear more than once)
    picked = np.random.randint(0, len(img_codes), size=batch_size)
    picked_codes = img_codes[picked]

    # for every selected image pick one of its captions at random
    sampled_captions = [choice(image_captions) for image_captions in captions[picked]]

    # token lists -> padded index matrix
    caption_matrix = as_matrix(sampled_captions, max_len=max_caption_len)

    images = torch.tensor(picked_codes, dtype=torch.float32)
    caption_ids = torch.tensor(caption_matrix, dtype=torch.int64)
    return images, caption_ids

In [None]:
def generate_caption(network, vectorizer, image, caption_prefix = ('#START#',), t=1, sample=True, max_len=100):
    """
    Generate a caption for a single image, one word at a time.

    :param network: captioning network called as network(image_vectors, prefix_ix);
        it is moved to CPU and switched to eval mode (side effect on the caller's object)
    :param vectorizer: callable applied to a [1, 3, H, W] float32 tensor that returns a
        3-tuple whose second element is the image vector passed to `network`
    :param image: numpy array with values in [0, 1] and 3 channels on the last axis
    :param caption_prefix: tokens the caption starts with; the first one is not
        included in the returned text
    :param t: exponent applied to the next-word probabilities before renormalizing
        (larger values sharpen the distribution)
    :param sample: if True sample the next word from the distribution, otherwise take
        the most probable word
    :param max_len: maximum number of words to generate
    :returns: generated words joined by spaces, without the leading prefix token and
        without the terminating '#END#'
    """
    network = network.cpu().eval()

    assert isinstance(image, np.ndarray) and np.max(image) <= 1\
           and np.min(image) >= 0 and image.shape[-1] == 3

    # channels-last image -> channels-first tensor
    image = torch.tensor(image.transpose([2, 0, 1]), dtype=torch.float32)
    caption_prefix = list(caption_prefix)

    # inference only: no autograd graph is needed for the vectorizer or the network
    with torch.no_grad():
        vectors_8x8, vectors_neck, logits = vectorizer(image[None])

        for _ in range(max_len):
            prefix_ix = torch.tensor(as_matrix([caption_prefix]), dtype=torch.int64)
            next_word_logits = network(vectors_neck, prefix_ix)[0, -1]

            # float64 keeps the renormalized probabilities within the sum-to-one
            # tolerance that np.random.choice enforces
            next_word_probs = F.softmax(next_word_logits, -1).numpy().astype(np.float64)
            assert next_word_probs.ndim == 1, 'probs must be one-dimensional'

            next_word_probs = next_word_probs ** t  # apply temperature
            next_word_probs /= next_word_probs.sum()

            if sample:
                next_word = vocab[np.random.choice(len(vocab), p=next_word_probs)]
            else:
                next_word = vocab[np.argmax(next_word_probs)]

            caption_prefix.append(next_word)

            if next_word == '#END#':
                break

    # drop the leading prefix token; drop the last token only if it really is the
    # end marker, so a caption cut off by max_len keeps its final word
    words = caption_prefix[1:]
    if words and words[-1] == '#END#':
        words = words[:-1]

    return ' '.join(words)

In [None]:
# split the dataset
# dtype=object is required because the nested caption lists are generally ragged
# (token lists of different lengths): recent NumPy versions refuse to build an
# array from ragged sequences unless the object dtype is requested explicitly.
captions = np.array(captions, dtype=object)
train_img_codes, val_img_codes, train_captions, val_captions = train_test_split(
    img_codes, captions, test_size=0.1, random_state=42
)

In [None]:
def train(
    network,
    optimizer,
    checkpoint_in, 
    train_img_codes,
    train_captions,
    val_img_codes,
    val_captions,
    batch_size=128,
    n_epochs_to_train=100,
    n_batches_per_epoch=50,
    n_validation_batches=5,
    max_epochs_to_improve=5,
    device: torch.device = torch.device('cpu'),
):
    '''
    Train the network with early stopping and return a checkpoint dictionary.

    - trains the network (network) using optimizer (optimizer) on random batches drawn
      from the training images (train_img_codes) and captions (train_captions)
    - validates after every epoch on random batches drawn from the validation
      images (val_img_codes) and captions (val_captions)
    - keeps track of the per-epoch training and validation losses and plots them

    Training resumes from checkpoint_in when one is given; a falsy checkpoint_in means
    "start from the current state of network and optimizer".

    A checkpoint is a dict with the keys:
        'epoch'      - 1-based number of the next epoch to run (epochs done so far + 1)
        'state_dict' - network.state_dict()
        'optimizer'  - optimizer.state_dict()
        'train_loss' - list of mean training losses, one per finished epoch
        'valid_loss' - list of mean validation losses, one per finished epoch

    Early stopping: the lowest validation loss seen so far is remembered; every epoch
    whose validation loss is above it increments a counter, every epoch that sets a new
    minimum resets the counter. Training stops once the counter exceeds
    max_epochs_to_improve.

    :param batch_size: number of (image, caption) pairs per batch
    :param n_epochs_to_train: maximum number of epochs to run in this call
    :param n_batches_per_epoch: training batches per epoch
    :param n_validation_batches: validation batches per epoch
    :param max_epochs_to_improve: early stopping patience (see above)
    :param device: device the network and the batches are moved to
    :returns: checkpoint dict describing the state after training. The checkpoint is
        only returned; writing it to disk is left to the caller.
    '''

    network.to(device)

    # no checkpoint given: start from the current state with an empty loss history
    if not checkpoint_in:
        checkpoint_in = {
            'epoch': 1,
            'state_dict': network.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_loss': [],
            'valid_loss': []
        }

    # copy the histories so the incoming checkpoint is not modified
    history = {
        'train_loss': list(checkpoint_in['train_loss']),
        'valid_loss': list(checkpoint_in['valid_loss']),
    }

    # restore the network / optimizer state and the epoch numbering
    network.load_state_dict(checkpoint_in['state_dict'])
    optimizer.load_state_dict(checkpoint_in['optimizer'])
    start_epoch = checkpoint_in['epoch']

    # lowest validation loss observed so far (infinity when there is no history)
    min_valid_loss = min(history['valid_loss'], default=np.inf)

    # number of epochs with a validation loss above the minimum since it was last reset
    epochs_to_improve = 0

    # epochs completed in this call; stays 0 when n_epochs_to_train is 0
    epochs_done = 0

    for epoch in range(1, n_epochs_to_train + 1):
        # ---- training ----
        network.train()
        train_loss = 0.0
        for _ in tqdm(range(n_batches_per_epoch)):
            images, captions = generate_batch(train_img_codes, train_captions, batch_size)
            loss_t = compute_loss(network, images.to(device), captions.to(device))
            # clear old gradients, backpropagate, update the weights
            optimizer.zero_grad()
            loss_t.backward()
            optimizer.step()
            train_loss += loss_t.item()

        train_loss /= n_batches_per_epoch
        history['train_loss'].append(train_loss)

        # ---- validation ----
        network.eval()
        val_loss = 0.0
        with torch.no_grad():
            for _ in range(n_validation_batches):
                images, captions = generate_batch(val_img_codes, val_captions, batch_size)
                val_loss += compute_loss(network, images.to(device), captions.to(device)).item()

        val_loss /= n_validation_batches
        history['valid_loss'].append(val_loss)

        epochs_done = epoch
        # start_epoch is the number of the first epoch run in this call
        current_epoch = start_epoch + epoch - 1

        # ---- reporting ----
        clear_output()
        print(f'\nEpoch: {current_epoch}, train loss: {train_loss:.4f}, validation loss: {val_loss:.4f}')

        # x axes are derived from the history lengths so they always match the data
        train_x = range(1, len(history['train_loss']) + 1)
        valid_x = range(1, len(history['valid_loss']) + 1)
        fig, axes = plt.subplots(1, 1, figsize=(7, 7))
        axes.set_title('Loss (Cross Entropy)')
        axes.plot(train_x, history['train_loss'], label='Train Loss')
        axes.plot(valid_x, history['valid_loss'], label='Validation Loss')
        axes.set_xticks(range(1, max(len(train_x), len(valid_x)) + 1))
        axes.grid()
        axes.legend(fontsize=20)
        plt.show()
        plt.close(fig)

        # ---- early stopping bookkeeping ----
        if val_loss > min_valid_loss:
            epochs_to_improve += 1
        elif val_loss < min_valid_loss:
            min_valid_loss = val_loss
            epochs_to_improve = 0

        if epochs_to_improve > max_epochs_to_improve:
            print('\nEarly stopping criteria satisfied!')
            break

    # checkpoint describing the state after training ('epoch' = next epoch to run)
    checkpoint_out = {
        'epoch': start_epoch + epochs_done,
        'state_dict': network.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_loss': history['train_loss'],
        'valid_loss': history['valid_loss']
    }

    print(f"\nTraining completed after {start_epoch + epochs_done - 1} epochs!")

    return checkpoint_out

In [None]:
def print_metrics(checkpoint):
  '''
  Print the last and the average train / validation losses stored in a checkpoint.

  The checkpoint is expected to provide 'epoch' (reported minus one as the number
  of epochs trained) and the non-empty loss lists 'train_loss' and 'valid_loss'.
  '''
  epochs_done = checkpoint['epoch'] - 1

  train_hist = checkpoint['train_loss']
  last_train, mean_train = train_hist[-1], np.mean(train_hist)

  valid_hist = checkpoint['valid_loss']
  last_valid, mean_valid = valid_hist[-1], np.mean(valid_hist)

  # every piece of the report is followed by the same run of eight spaces
  padding = ' ' * 8
  pieces = [
      f"\nAfter training the network for {epochs_done} epoches we achieved: ",
      f"\n last train loss: {last_train:0.4f} ",
      f"\n average train loss: {mean_train:0.4f} ",
      f"\n last validation loss: {last_valid:0.4f} ",
      f"\n average validation loss: {mean_valid:0.4f} ",
      "",
  ]
  print(padding.join(pieces))

In [None]:
# training configuration
BATCH_SIZE = 128                          # (image, caption) pairs per batch
NUM_OF_TRAINING_EPOCHS = 30               # number of training epochs
MODEL_CHECKPOINT_PATH = 'checkpoint.pt'   # path to the checkpoint file

In [None]:
# set up the device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# create the model
model = CaptionNet(n_tokens)
# create an optimizer over the trainable parameters
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
# try to resume from a previously saved checkpoint; train from scratch if there is none
checkpoint_in = None
try:
    checkpoint_in = torch.load(MODEL_CHECKPOINT_PATH, map_location=device)
except FileNotFoundError:
    print('unable to load model checkpoint\n')

# train the network for up to NUM_OF_TRAINING_EPOCHS epochs on the selected device
# (batches per epoch and validation batches use the defaults of train)
checkpoint_out = train(
    model,
    optimizer,
    checkpoint_in, 
    train_img_codes, 
    train_captions,
    val_img_codes,
    val_captions,
    batch_size=BATCH_SIZE, 
    n_epochs_to_train=NUM_OF_TRAINING_EPOCHS,
    device=device,
    )

# persist the checkpoint so the next run of this cell resumes from it
torch.save(checkpoint_out, MODEL_CHECKPOINT_PATH)

# print metrics after training
print_metrics(checkpoint_out)