# Show and Tell

A simplified implementation of the model from the [Show and Tell paper](https://arxiv.org/pdf/1411.4555.pdf): a CNN image encoder whose output initializes an RNN caption decoder.

## Prepare the notebook when running in Google Colab

In [None]:
import sys

# True when this notebook runs inside Google Colab, detected by the
# `google.colab` module being present in the interpreter's module table.
IN_COLAB = 'google.colab' in sys.modules.keys()
IN_COLAB

In [None]:
if IN_COLAB:
    # Colab only: mount Google Drive at /content/drive so the Drive-hosted
    # dataset folder can be read.
    from google.colab import drive

    drive.mount(mountpoint='/content/drive')

## Use the GPU if available

In [None]:
import torch

# Prefer CUDA when it is available; otherwise run everything on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

## Load dataset

In [None]:
# Dataset location: a Google Drive folder on Colab, a local directory otherwise.
ROOT = 'drive/My Drive/test_rclone' if IN_COLAB else 'datasets'

DATASET = 'mini_coco'
# Path templates; `{0}` is filled with the split name (e.g. 'train').
ANNOTATIONS_PATH = 'annotations/captions_{0}2014.json'
IMAGES_PATH = 'images/{0}2014'

In [None]:
import torchvision
import os

# Training split as (image, list of caption strings) items.
_dataset_dir = os.path.join(ROOT, DATASET)
train_dataset = torchvision.datasets.CocoCaptions(
    root=os.path.join(_dataset_dir, IMAGES_PATH.format('train')),
    annFile=os.path.join(_dataset_dir, ANNOTATIONS_PATH.format('train')),
)

## Build the vocabulary

In [None]:
from nltk.tokenize import word_tokenize
import string
from collections import defaultdict

In [None]:
import nltk

# `word_tokenize` relies on the Punkt tokenizer data. NLTK >= 3.9 loads it
# from the 'punkt_tab' resource rather than 'punkt', so fetch both; on older
# NLTK versions the extra resource should simply go unused.
nltk.download('punkt')
nltk.download('punkt_tab')

In [None]:
def clean_text(text):
    """Normalize a caption and split it into word tokens.

    Characters in ``string.punctuation`` are removed first, then the text is
    lower-cased and tokenized with NLTK's ``word_tokenize``.
    """
    without_punctuation = text.translate(str.maketrans('', '', string.punctuation))
    return word_tokenize(without_punctuation.lower())

In [None]:
c = defaultdict(int)

# Count word frequencies over all training captions.
# Captions are read straight from the COCO annotation index rather than by
# iterating `train_dataset`, which would open and decode every image only to
# discard it. Image ids and their annotations are visited the way
# torchvision's CocoCaptions builds each item (`coco.getAnnIds` per id in
# `ids`), so word insertion order, and therefore the vocabulary indices
# derived from it, should be unchanged. NOTE(review): verify against the
# installed torchvision version.
_coco = train_dataset.coco
for _img_id in train_dataset.ids:
    for _ann in _coco.loadAnns(_coco.getAnnIds(imgIds=_img_id)):
        for word in clean_text(_ann['caption']):
            c[word] += 1

In [None]:
# Keep only words seen more than once; rarer words are later encoded as <UNK>.
c_filtered = list(filter(lambda word: c[word] > 1, c))

In [None]:
# Special tokens: sequence start, out-of-vocabulary word, sequence end.
START = '<START>'
UNK = '<UNK>'
END = '<END>'

# Appended after the regular words, so they receive the last three indices.
c_filtered.extend([START, UNK, END])

In [None]:
# Vocabulary lookup tables in both directions: index -> word and word -> index.
i2w = dict(enumerate(c_filtered))
w2i = {word: index for index, word in i2w.items()}

In [None]:
def transform_text(text):
    """Encode a caption as vocabulary indices framed by <START> and <END>.

    The caption is normalized with ``clean_text``; tokens missing from
    ``w2i`` are encoded with the <UNK> index.
    """
    sequence = [w2i[START]]
    sequence.extend(
        w2i[token] if token in w2i else w2i[UNK]
        for token in clean_text(text)
    )
    sequence.append(w2i[END])
    return sequence

In [None]:
# Inspect the word -> index mapping.
print(f'{w2i}')

## Create dataloaders

In [None]:
from torchvision import transforms

# Standard ImageNet channel statistics, as used by torchvision's pretrained models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Shorter side -> 224, central 224x224 crop, tensor conversion, normalization.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
def collate_fn_train(batch):
    """Expand (image, captions) items into one (image, caption) pair per caption.

    Each image is transformed once and repeated for every one of its captions.
    Pairs are ordered by caption length, longest first, as required by
    ``pack_sequence(enforce_sorted=True)``. Decoder inputs drop each caption's
    last token; targets drop its first token.

    Returns:
        (stacked image tensors, packed input sequences, packed target sequences)
    """
    pairs = []
    for raw_image, captions in batch:
        image_tensor = transform(raw_image)
        pairs.extend(
            (image_tensor, torch.tensor(transform_text(caption)))
            for caption in captions
        )

    # list.sort is stable, so equal-length captions keep their original order.
    pairs.sort(key=lambda pair: len(pair[1]), reverse=True)
    images, sequences = zip(*pairs)

    pack = torch.nn.utils.rnn.pack_sequence
    packed_inputs = pack([sequence[:-1] for sequence in sequences], enforce_sorted=True)
    packed_outputs = pack([sequence[1:] for sequence in sequences], enforce_sorted=True)
    return torch.stack(images), packed_inputs, packed_outputs

In [None]:
import torch

# Four images per batch; collate_fn_train turns each image into one sample per caption.
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn_train,
)

In [None]:
def collate_fn_validate(batch):
    """Collate a validation batch into (stacked images, reference captions).

    Each image's captions are normalized with ``clean_text`` and re-joined
    with single spaces, giving one list of reference strings per image.
    """
    images, references = [], []
    for image, captions in batch:
        images.append(transform(image))
        references.append([' '.join(clean_text(caption)) for caption in captions])
    return torch.stack(images), references

In [None]:
# One image per batch, in a fixed order.
# NOTE(review): this iterates the *training* split, so validation scores
# measure fit to training captions rather than generalization.
valloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn_validate,
)

## Beam search

In [None]:
import numpy as np

@torch.no_grad()
def simple_beam_search(model, image, w2i, i2w, max_length=15, beam_size=1):
    """Decode a caption for ``image`` with beam search.

    Args:
        model: captioning model exposing ``encoder(image)`` and
            ``decoder(hidden, packed_inputs)``. The decoder output is treated
            as unnormalized scores (logits), matching how training feeds it to
            ``CrossEntropyLoss``.
        image: batch holding one preprocessed image (batch size 1 is assumed;
            TODO confirm for other callers).
        w2i: word -> index mapping; must contain START and END.
        i2w: index -> word mapping (unused; kept for interface compatibility).
        max_length: maximum number of decoding steps.
        beam_size: number of hypotheses kept after each step.

    Returns:
        Token indices of the best hypothesis, starting with START (and ending
        with END if that hypothesis finished within ``max_length`` steps).

    Hypotheses are ranked by the sum of token log-probabilities
    (``log_softmax`` of the decoder scores). Multiplying raw logits would not
    rank sequences by probability, because logits can be negative and are
    unnormalized.
    """
    # Known remaining inefficiency: up to beam_size ** 2 candidates are
    # materialized per step before pruning.
    model = model.to(device)
    image = image.to(device)
    start_index, end_index = w2i[START], w2i[END]

    cur_hyps = [[start_index]]
    cur_scores = [0.0]  # cumulative log-probability of each hypothesis
    # Encoder output becomes the initial RNN hidden state: (1, num_hyps, hidden_size).
    cur_hiddens = model.encoder(image).unsqueeze(0)

    for _ in range(max_length):
        # Once every hypothesis has emitted END, further steps change nothing.
        if all(hyp[-1] == end_index for hyp in cur_hyps):
            break

        packed_inputs = torch.nn.utils.rnn.pack_sequence(
            [torch.tensor([hyp[-1]]) for hyp in cur_hyps], enforce_sorted=True).to(device)
        logits, hiddens = model.decoder(cur_hiddens.to(device), packed_inputs)
        # Every packed sequence has length 1, so row k belongs to hypothesis k.
        log_probs = torch.log_softmax(logits.data, dim=-1)
        top_k = min(beam_size, log_probs.shape[-1])

        new_hyps, new_scores, new_hiddens = [], [], []
        for hyp, score, word_log_probs, hidden in zip(cur_hyps, cur_scores, log_probs, hiddens[0]):
            if hyp[-1] == end_index:
                # Finished hypotheses are carried over unchanged.
                new_hyps.append(hyp)
                new_scores.append(score)
                new_hiddens.append(hidden)
                continue
            top_log_probs, top_words = torch.topk(word_log_probs, top_k)
            for word_log_prob, word in zip(top_log_probs.tolist(), top_words.tolist()):
                new_hyps.append(hyp + [word])
                new_scores.append(score + word_log_prob)
                new_hiddens.append(hidden)

        best = np.argsort(new_scores)[-beam_size:]
        cur_hyps = [new_hyps[i] for i in best]
        cur_scores = [new_scores[i] for i in best]
        cur_hiddens = torch.stack([new_hiddens[i] for i in best]).unsqueeze(0)

    # Pick the best hypothesis explicitly instead of assuming it sits last,
    # which does not hold when scores tie.
    return cur_hyps[int(np.argmax(cur_scores))]

## Validation

In [None]:
import evaluation
from importlib import reload

evaluation = reload(evaluation)

def validate(model, dataloader, w2i, i2w, beam_size=3):
    """Caption every image in ``dataloader`` and score the results.

    Args:
        model: captioning model passed to ``simple_beam_search``.
        dataloader: yields ``(image, texts)`` batches of size 1, where
            ``texts[0]`` is the list of reference captions for that image
            (the format produced by ``collate_fn_validate``).
        w2i: word -> index mapping (must contain END).
        i2w: index -> word mapping used to turn hypotheses into text.
        beam_size: beam width used for decoding (default 3).

    Returns:
        The result of ``evaluation.compute_scores`` for references
        ``{i: [captions]}`` and hypotheses ``{i: [caption]}``.
    """
    gts_dict = {}
    hyps_dict = {}
    was_training = model.training
    # Decode in eval mode (layers such as BatchNorm/Dropout behave
    # deterministically) and without autograd, then restore the caller's mode.
    model.eval()
    try:
        with torch.no_grad():
            for i, (image, texts) in enumerate(dataloader):
                # Drop the leading <START> token, and the trailing <END> if present.
                hyp = simple_beam_search(model, image, w2i, i2w, beam_size=beam_size)[1:]
                if hyp and hyp[-1] == w2i[END]:
                    hyp = hyp[:-1]
                gts_dict[i] = texts[0]
                hyps_dict[i] = [' '.join(i2w[word] for word in hyp)]
    finally:
        model.train(was_training)
    return evaluation.compute_scores(gts_dict, hyps_dict)

## Set up the model

In [None]:
from torch import nn

class SimpleModel(nn.Module):
    """Small from-scratch CNN encoder with a single-layer RNN caption decoder.

    The encoder maps an image batch to one ``hidden_size`` vector per image.
    That vector is used as the RNN's initial hidden state. The decoder embeds
    packed input tokens, runs the RNN and projects every step to ``dict_size``
    unnormalized scores.
    """

    # Flattened encoder output: 20 channels * 26 * 26, which is what the three
    # conv(3x3) -> max-pool(2) stages produce for 224x224 inputs.
    _ENCODER_FEATURES = 20 * 26 * 26

    def __init__(self, dict_size, embedding_dim, hidden_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size

        # Encoder: three conv -> max-pool -> ReLU stages, then a projection.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=5, out_channels=10, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3)
        self.pooling = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(in_features=self._ENCODER_FEATURES, out_features=hidden_size)
        # Plain list (not a ModuleList): every layer is already registered as
        # an attribute above, so the list only fixes the application order.
        self.encoder_layers = [
            layer
            for conv in (self.conv1, self.conv2, self.conv3)
            for layer in (conv, self.pooling, self.relu)
        ]

        # Decoder: token embedding -> RNN -> vocabulary projection.
        self.embedding = nn.Embedding(num_embeddings=dict_size, embedding_dim=embedding_dim)
        self.rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=dict_size)

    def encoder(self, image):
        """Return a ``(batch, hidden_size)`` feature vector per input image."""
        features = image
        for layer in self.encoder_layers:
            features = layer(features)
        flat = features.view(-1, self._ENCODER_FEATURES)
        return self.linear1(flat).view(-1, self.hidden_size)

    def decoder(self, image_vector, input_captions):
        """Run the RNN over packed token indices, starting from ``image_vector``.

        Returns a PackedSequence of per-step vocabulary scores and the final
        RNN hidden state.
        """
        embedded = nn.utils.rnn.PackedSequence(
            self.embedding(input_captions.data), input_captions.batch_sizes)
        rnn_out, last_hidden = self.rnn(embedded, image_vector)
        scores = self.linear2(rnn_out.data)
        return nn.utils.rnn.PackedSequence(scores, rnn_out.batch_sizes), last_hidden

    def forward(self, image, input_captions):
        """Encode ``image`` and decode ``input_captions`` conditioned on it."""
        return self.decoder(self.encoder(image).unsqueeze(0), input_captions)

In [None]:
from torchvision import models


class SimpleModelWithEncoder(nn.Module):
    """Captioning model: frozen pretrained ResNet-101 encoder + RNN decoder.

    The ResNet feature map is flattened and projected to ``hidden_size``. That
    vector is the RNN's initial hidden state. The decoder embeds packed input
    tokens, runs the RNN and projects every step to ``dict_size``
    unnormalized scores.

    The ResNet is frozen: its parameters never require gradients, and it is
    kept in eval mode even when the rest of the model is training (see
    ``train``).
    """

    def __init__(self, dict_size, embedding_dim, hidden_size, *args, **kwargs):
        super(SimpleModelWithEncoder, self).__init__(*args, **kwargs)

        # Pretrained ResNet-101 without its last two children (the average
        # pool and fc layers in torchvision's ResNet).
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
        # favor of `weights=`; kept for compatibility with older versions.
        resnet = torchvision.models.resnet101(pretrained=True)
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.resnet.eval()
        # 100352 is presumably the flattened 2048 x 7 x 7 feature map for
        # 224x224 inputs. TODO: try to use mean instead of flatten all the features.
        self.linear1 = nn.Linear(in_features=100352, out_features=hidden_size)

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(num_embeddings=dict_size, embedding_dim=embedding_dim)
        self.rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=dict_size)

    def train(self, mode=True):
        """Set the training mode, but always keep the frozen ResNet in eval mode.

        The ResNet weights never receive gradients. In training mode, however,
        its BatchNorm layers would still normalize with per-batch statistics
        and overwrite the pretrained running statistics.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def encoder(self, image):
        """Return a ``(batch, hidden_size)`` projection of the ResNet features."""
        image = self.resnet(image)
        return self.linear1(image.view(image.shape[0], -1)).view(-1, self.hidden_size)

    def decoder(self, image_vector, input_captions):
        """Run the RNN over packed token indices, starting from ``image_vector``.

        Returns a PackedSequence of per-step vocabulary scores and the final
        RNN hidden state.
        """
        embeddings = nn.utils.rnn.PackedSequence(
            self.embedding(input_captions.data),
            input_captions.batch_sizes)
        decoded, hiddens = self.rnn(embeddings, image_vector)
        probs = self.linear2(decoded.data)
        return nn.utils.rnn.PackedSequence(probs, decoded.batch_sizes), hiddens

    def forward(self, image, input_captions):
        """Encode ``image`` and decode ``input_captions`` conditioned on it."""
        image_vector = self.encoder(image)
        image_vector = image_vector.unsqueeze(0)
        return self.decoder(image_vector, input_captions)

In [None]:
# Output layer sized to the vocabulary; 512-d embeddings and RNN hidden state.
model = SimpleModelWithEncoder(
    dict_size=len(w2i),
    embedding_dim=512,
    hidden_size=512,
).to(device)

## Training the model

In [None]:
# Token-level cross-entropy over the decoder's vocabulary scores.
criterion = nn.CrossEntropyLoss()
# Optimize only trainable parameters (the ResNet encoder's are frozen).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logger; with no explicit log_dir it writes under ./runs/.
writer = SummaryWriter(log_dir=None)

In [None]:
NUM_EPOCHS = 500
CHECKPOINT_PATH = os.path.join(ROOT, 'model.pth')

for epoch in range(NUM_EPOCHS):
    # Validation below switches to eval mode, so re-enter train mode each epoch.
    model.train()
    total_loss = 0.0
    total_samples = 0
    for image, inputs, outputs in trainloader:
        image = image.to(device)
        inputs = inputs.to(device)
        outputs = outputs.to(device)

        optimizer.zero_grad()
        # Teacher forcing: inputs are captions without their last token,
        # targets are the same captions shifted by one (see collate_fn_train).
        ans, _ = model(image, inputs)
        loss = criterion(ans.data, outputs.data)
        loss.backward()
        optimizer.step()

        # `loss` is a mean over tokens; weight it by the number of
        # (image, caption) samples to average per sample over the epoch.
        batch_size = image.shape[0]
        total_samples += batch_size
        total_loss += loss.item() * batch_size
    # Avoid ZeroDivisionError if the loader yields no batches.
    total_loss /= max(total_samples, 1)

    # Overwrite the checkpoint after every epoch.
    torch.save(model.state_dict(), CHECKPOINT_PATH)

    model.eval()
    with torch.no_grad():
        scores = validate(model, valloader, w2i, i2w)

    writer.add_scalar('loss', total_loss, epoch)
    for score_name in scores:
        writer.add_scalar(score_name, scores[score_name], epoch)
    writer.flush()

    print('-----')
    print('Epoch: ', epoch)
    print('Loss: ', total_loss)
    print('Scores: ', scores)

In [None]:
# IN FEATURES NUM
# Number of input features of the model's image projection layer, i.e. the
# size of the flattened ResNet feature map fed into `linear1`.
# `model_ft.fc` (from the torchvision fine-tuning tutorial) does not exist in
# this notebook; `linear1` is the analogous layer here.
num_ftrs = model.linear1.in_features

In [None]:
# Colab only: show the TensorBoard dashboard inline, reading the logs that
# SummaryWriter writes under ./runs.
if IN_COLAB:
    %load_ext tensorboard
    %tensorboard --logdir runs

## Load the model

In [None]:
# Rebuild the architecture and restore the weights from the same path the
# training loop saves to (os.path.join(ROOT, 'model.pth')).
# NOTE(review): w2i must be rebuilt exactly as during training, or dict_size
# and the token indices will not match the checkpoint.
model = SimpleModelWithEncoder(dict_size=len(w2i), embedding_dim=512, hidden_size=512)
model.load_state_dict(torch.load(os.path.join(ROOT, 'model.pth'), map_location=device))
# Move to the target device and switch to inference mode.
model = model.to(device).eval()