In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import pandas as pd
import csv
import numpy as np
import cv2
import random

from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torchsummary import summary
from pycocotools.coco import COCO

# Seed every RNG this notebook draws from so Restart & Run All is reproducible:
# numpy (CocoDataset's special-token vectors), torch (weight init, reparameterization
# noise, DataLoader shuffling) and Python's `random` (imported above).
SEED = 0
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

We use the COCO captions dataset (2017 validation split) to train and test our image captioning model.

In [3]:
class CocoDataset(data.Dataset):

    '''
    Implement the dataloader for COCO. This is used to train the image captioning model.

    Each item is a tuple ``(picture, caption_embeddings, word_indices)``:

    * ``picture``: the image resized to 28x28, converted to grayscale and scaled to [0, 1].
    * ``caption_embeddings``: array of shape (max_length + 2, 50) -- BOS, up to ``max_length``
      word vectors, EOS, then PAD vectors.
    * ``word_indices``: array of shape (max_length + 2,) with the matching vocabulary indices.

    Vocabulary layout: 0 = EOS, 1 = BOS, 2 = UNK, 3 = PAD, GloVe words from index 4 upwards.
    ``num_tokens`` is the total vocabulary size *including* the special tokens, so every index
    is strictly below ``num_tokens`` (matching a classifier head with ``num_tokens`` outputs).
    '''

    EOS_INDEX = 0
    BOS_INDEX = 1
    UNK_INDEX = 2
    PAD_INDEX = 3
    NUM_SPECIAL_TOKENS = 4

    def __init__(self):

        self.captions = COCO('./data/Coco/annotations/captions_val2017.json')
        # Total vocabulary size, special tokens included.
        self.num_tokens = 15000

        # Keep things very simple for now - just these 4 special tokens.
        # TODO: save these vectors to a file instead of drawing them here, so that re-creating the
        # dataset (or a different RNG state) cannot silently change them.
        self.eos = np.random.randn(50)  # index 0
        self.bos = np.random.randn(50)  # index 1
        self.unk = np.random.randn(50)  # index 2
        self.pad = np.random.randn(50)  # index 3
        glove_data_filepath = './data/glove.6B.50d.txt'

        # Unlike the paper, we use pretrained word embeddings (GloVe). This fits the spirit of the
        # paper: learning with less information.
        # In practice we would want a real tokenizer (e.g. nltk's tokenize, or a scheme similar to
        # BERT's WordPiece). The captions are already fairly clean, so we simply split on whitespace
        # and lowercase at lookup time; OOV words map to the UNK token.
        df = pd.read_csv(
            glove_data_filepath,
            sep=' ',
            quoting=3,              # csv.QUOTE_NONE: GloVe tokens include bare quote characters
            header=None,
            index_col=0,
            dtype={0: str},         # keep numeric-looking words such as "1" as strings
            keep_default_na=False,  # keep real tokens such as "null", "nan", "n/a" instead of NaN
            # Only the words that fit next to the special tokens, so all indices are < num_tokens.
            nrows=self.num_tokens - self.NUM_SPECIAL_TOKENS,
        )

        self.word_vector_dict = {
            word: (i + self.NUM_SPECIAL_TOKENS, vector)
            for i, (word, vector) in enumerate(zip(df.index, df.to_numpy()))
        }

        self.max_length = 16

    @staticmethod
    def _tokenize(caption):
        '''Split a caption on whitespace, emitting a trailing '.', '!' or '?' as its own token.'''
        caption = caption.strip()
        if caption and caption[-1] in '.!?':
            return caption[:-1].split() + [caption[-1]]
        return caption.split()

    def _encode_caption(self, caption):
        '''Return (embeddings, indices) arrays for BOS + first max_length words + EOS + PADs.'''
        word_indices = [self.BOS_INDEX]
        caption_embeddings = [self.bos]
        for word in self._tokenize(caption)[:self.max_length]:
            index, word_vector = self.word_vector_dict.get(word.lower(), (self.UNK_INDEX, self.unk))
            caption_embeddings.append(word_vector)
            word_indices.append(index)

        # Pad so every item has exactly max_length + 2 positions (BOS + words + EOS + PADs),
        # which lets the default collate function stack a batch.
        len_padding = self.max_length + 1 - len(caption_embeddings)
        caption_embeddings = caption_embeddings + [self.eos] + [self.pad] * len_padding
        word_indices = word_indices + [self.EOS_INDEX] + [self.PAD_INDEX] * len_padding
        return np.array(caption_embeddings), np.array(word_indices)

    @staticmethod
    def _load_picture(picture_filepath):
        '''Load an image as a 28x28 grayscale array scaled to [0, 1].'''
        # In a real-life scenario these images would be normalized with global mean/variance
        # statistics and augmented; the tiny size is only a proof of concept for a small GPU.
        picture = cv2.imread(picture_filepath)
        if picture is None:
            # cv2.imread signals a missing/unreadable file by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {picture_filepath}')
        picture = cv2.resize(picture, (28, 28))
        picture = cv2.cvtColor(picture, cv2.COLOR_BGR2GRAY)
        return picture / 255

    def __getitem__(self, idx):

        annotation = self.captions.dataset['annotations'][idx]
        caption_embeddings, word_indices = self._encode_caption(annotation['caption'])

        picture_name = str(annotation['image_id']).zfill(12)
        picture = self._load_picture(f'./data/Coco/images/{picture_name}.jpg')

        return picture, caption_embeddings, word_indices

    def __len__(self):

        return len(self.captions.dataset['annotations'])

In [4]:
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension, e.g. (B, C, H, W) -> (B, C*H*W)."""

    def forward(self, input):
        batch_size = input.size(0)
        # view (not reshape): requires a contiguous input and never copies.
        return input.view(batch_size, -1)

class UnFlatten(nn.Module):
    """Reshape the input into a (batch, size, 1, 1) feature map for transposed convolutions."""

    def forward(self, input, size=128):
        target_shape = (input.size(0), size, 1, 1)
        return input.view(target_shape)

In [None]:
class VAE(nn.Module):
    '''
    Convolutional variational autoencoder for 28x28 images.

    The encoder maps a (B, image_channels, 28, 28) image to a flat feature vector of length h_dim.
    From that vector, fc1/fc2 produce the mean and log-variance of a z_dim-dimensional Gaussian
    posterior. The decoder maps a latent vector back to a (B, image_channels, 28, 28) image with
    values in (0, 1).

    Args:
        image_channels: number of input/output image channels.
        h_dim: size of the flattened encoder output. With the layers below and 28x28 inputs this is
            8 channels * 4 * 4 = 128, so it must stay 128 unless the encoder or input size changes.
        z_dim: latent dimensionality.
    '''

    def __init__(self, image_channels=3, h_dim=128, z_dim=32):
        super(VAE, self).__init__()
        # 28x28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4; 8 * 4 * 4 = 128 features.
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 4, kernel_size=5),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Flatten(),
        )

        self.fc1 = nn.Linear(h_dim, z_dim)  # posterior mean
        self.fc2 = nn.Linear(h_dim, z_dim)  # posterior log-variance
        self.fc3 = nn.Linear(z_dim, h_dim)  # latent -> decoder input

        # 1x1 -> 5x5 -> 13x13 -> 28x28
        self.decoder = nn.Sequential(
            # (B, h_dim) -> (B, h_dim, 1, 1); follows h_dim instead of a hard-coded 128.
            nn.Unflatten(1, (h_dim, 1, 1)),
            nn.ConvTranspose2d(h_dim, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(8, image_channels, kernel_size=6, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar, num_samples):
        '''
        Draw z = mu + sigma * eps with eps ~ N(0, I) (the reparameterization trick).

        Returns a tensor shaped like mu when num_samples == 1, otherwise shaped
        (num_samples, *mu.shape).
        '''
        std = torch.exp(0.5 * logvar)
        sample_shape = mu.shape if num_samples == 1 else (num_samples, *mu.shape)
        # Create the noise on mu's own device/dtype so this works on CPU and GPU without relying
        # on a global `device` variable.
        eps = torch.randn(sample_shape, device=mu.device, dtype=mu.dtype)
        return mu + std * eps

    def bottleneck(self, h, num_samples):
        '''Map encoder features to (z, mu, logvar).'''
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar, num_samples)
        return z, mu, logvar

    def encode(self, x, num_samples=1):
        '''Encode images to latent samples; returns (z, mu, logvar).'''
        h = self.encoder(x)
        z, mu, logvar = self.bottleneck(h, num_samples=num_samples)
        return z, mu, logvar

    def decode(self, z):
        '''Decode a (B, z_dim) batch of latent vectors into images.'''
        z = self.fc3(z)
        z = self.decoder(z)
        return z

    def forward(self, x):
        '''Encode with a single sample and decode; returns (reconstruction, mu, logvar).'''
        z, mu, logvar = self.encode(x)
        z = self.decode(z)
        return z, mu, logvar

In [None]:
class RNN(nn.Module):
    '''
    GRU caption decoder whose initial hidden state is supplied by the caller.

    Args:
        hidden_size: GRU hidden size; must equal the last dimension of the initial hidden state.
        num_tokens: vocabulary size, i.e. the number of output logits per step. Every target
            index used with these logits must be < num_tokens.
        glove_embedding_dim: size of each input word vector.
        max_length: number of positions in a padded caption, BOS and EOS included.
    '''

    def __init__(self, hidden_size=32, num_tokens=15000, glove_embedding_dim=50, max_length=16 + 2):

        super(RNN, self).__init__()

        self.hidden_size = hidden_size
        self.rnn = nn.GRU(glove_embedding_dim, hidden_size, num_layers=1, batch_first=True)
        self.fc_out = nn.Linear(hidden_size, num_tokens)
        self.num_tokens = num_tokens
        self.max_length = max_length

    def get_last_token(self, x, hidden):
        '''
        Run the GRU on x (batch-first) and project its outputs to vocabulary logits.

        Returns (logits, new_hidden). NOTE(review): squeeze(0) only removes a dimension when the
        batch size is 1, so for a (B, 1, dim) input with B > 1 the logits are (B, 1, num_tokens).
        '''

        output, hidden = self.rnn(x, hidden)

        prediction = self.fc_out(output.squeeze(0))

        return prediction, hidden

    def forward(self, x, hidden):
        '''
        Teacher-forced decoding.

        Args:
            x: (B, T, glove_embedding_dim) caption embeddings with T >= max_length - 1 and the
                BOS embedding at position 0.
            hidden: (1, B, hidden_size) initial GRU state.

        Returns:
            (max_length, B, num_tokens) logits. Row i predicts position i from the ground-truth
            positions 0..i-1; row 0 (the BOS position) has nothing to predict from and is zero.
        '''
        if x.shape[1] < self.max_length - 1:
            raise ValueError(
                f'expected at least {self.max_length - 1} caption positions, got {x.shape[1]}')

        # Feeding the ground-truth prefix through the GRU in one call is equivalent to stepping
        # token by token with teacher forcing, but avoids a Python loop over time steps.
        rnn_output, _ = self.rnn(x[:, :self.max_length - 1, :], hidden)
        logits = self.fc_out(rnn_output).transpose(0, 1)  # (max_length - 1, B, num_tokens)

        # Allocated from `logits` so it lives on the model's device/dtype (not always the CPU).
        bos_row = logits.new_zeros((1,) + tuple(logits.shape[1:]))
        return torch.cat([bos_row, logits], dim=0)

In [None]:
def loss_fn(pred_x, x, mu, logvar, pred_caption=None, caption=None, alpha=1, beta=1, ignore_index=2):
    '''
    Negative ELBO of the VAE, optionally plus the caption cross-entropy.

    Args:
        pred_x: reconstructed images with values in [0, 1], same shape as x.
        x: target images with values in [0, 1].
        mu, logvar: posterior mean and log-variance, as produced by VAE.encode.
        pred_caption: caption logits shaped (B, num_tokens, L), as F.cross_entropy expects.
        caption: target token indices shaped (B, L). When None, only the VAE loss is returned.
        alpha: weight of the VAE (reconstruction + KL) term.
        beta: weight of the caption term.
        ignore_index: target index excluded from the caption loss. CocoDataset uses 2 for UNK
            and 3 for PAD. NOTE(review): confirm whether PAD (3) should be ignored instead.

    Returns:
        Scalar loss tensor.
    '''
    # Reconstruction term, summed over pixels and batch.
    BCE = F.binary_cross_entropy(pred_x, x, reduction='sum')

    # See Appendix B of Kingma and Welling, "Auto-Encoding Variational Bayes", ICLR 2014:
    #   -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # Summed (not averaged) so it is on the same scale as the summed BCE term.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    vae_loss = alpha * (BCE + KLD)

    if caption is None:
        return vae_loss

    sentence_loss = F.cross_entropy(pred_caption, caption.long(), ignore_index=ignore_index)
    return vae_loss + beta * sentence_loss

In [None]:
def train(num_epochs=10, batch_size=64, num_samples=10, lr=1e-3):
    '''
    Jointly train the VAE and the caption RNN on COCO captions.

    For every batch, the images are encoded into `num_samples` latent samples (a Monte Carlo
    estimate of the objective; more samples reduce the variance of the reparameterization
    gradient). Each sample is decoded back to an image and is also used as the RNN's initial
    hidden state for teacher-forced caption prediction. The loss is averaged over the samples.

    Returns:
        (vae, rnn): the trained models.
    '''
    coco = CocoDataset()
    dataloader = torch.utils.data.DataLoader(dataset=coco,
                                             batch_size=batch_size,
                                             shuffle=True)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    vae = VAE(image_channels=1).to(device)
    # RNN hidden_size (default 32) must equal VAE z_dim (default 32): latent samples seed the GRU.
    rnn = RNN().to(device)

    # Optimize both networks so that the caption loss also updates the RNN's weights.
    optimizer = torch.optim.Adam(list(vae.parameters()) + list(rnn.parameters()), lr=lr)

    for epoch in range(num_epochs):
        for images, caption, word_indices in dataloader:

            images = images.float().unsqueeze(1).to(device)   # add the channel dimension
            caption = caption.float().to(device)              # word embeddings (RNN input)
            word_indices = word_indices.long().to(device)     # token ids (loss target)

            monte_carlo_embeddings, mu, logvar = vae.encode(images, num_samples=num_samples)
            if num_samples == 1:
                # reparameterize omits the sample dimension for a single sample.
                monte_carlo_embeddings = monte_carlo_embeddings.unsqueeze(0)

            loss = 0

            for embedding in monte_carlo_embeddings:

                # GRU initial state is (num_layers=1, B, hidden); unsqueeze instead of a fixed
                # batch_size so the smaller final batch also works.
                pred_caption = rnn(caption, embedding.unsqueeze(0))
                pred_caption = pred_caption.permute(1, 2, 0)  # (B, num_tokens, L) for cross_entropy

                pred_images = vae.decode(embedding)
                loss += loss_fn(pred_images, images, mu, logvar, pred_caption, word_indices)

            loss /= num_samples

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return vae, rnn