# Show and Tell

A minimal image-captioning model based on the [Show and Tell paper](https://arxiv.org/pdf/1411.4555.pdf).

## Load dataset

In [None]:
# Dataset layout: <ROOT>/<DATASET>/<IMAGES_PATH | ANNOTATIONS_PATH>.
# '{0}' in the two path templates is the split name (e.g. 'train'),
# substituted with str.format.
ROOT = 'datasets'
DATASET = 'mini_coco'
ANNOTATIONS_PATH = 'annotations/captions_{0}2014.json'
IMAGES_PATH = 'images/{0}2014'

In [None]:
import torchvision
import os

# Training split of the captions dataset. Each item is an (image, captions)
# pair. The directory comes from the DATASET constant rather than a duplicated
# literal, so changing DATASET actually switches the data being loaded.
train_dataset = torchvision.datasets.CocoCaptions(
    root=os.path.join(ROOT, DATASET, IMAGES_PATH.format('train')),
    annFile=os.path.join(ROOT, DATASET, ANNOTATIONS_PATH.format('train')))

## Create dataloader

In [None]:
# Resize every image to 200x200 (SimpleModel's encoder is sized for this)
# and convert it to a tensor.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.Resize((200, 200)),
     torchvision.transforms.ToTensor()])


def collate_fn(batch):
    """Turn a one-sample DataLoader batch into model inputs.

    The single ``(image, captions)`` sample is processed as follows:
    - The image goes through ``transform`` and gets a leading batch dimension
      of 1.
    - Every caption is tokenized with ``transform_text``.
    - The token sequences are sorted longest-first and packed for teacher
      forcing: inputs drop the last token, targets drop the first.

    Returns:
        ``(image, packed_inputs, packed_outputs)``.

    Raises:
        ValueError: if the batch does not contain exactly one sample. Only
            ``batch_size=1`` is supported, because all captions of one image
            share that image's encoding.
    """
    if len(batch) != 1:
        # Previously extra samples were silently dropped; fail loudly instead.
        raise ValueError(
            f"collate_fn only supports batch_size=1, got a batch of {len(batch)}")
    pil_image, captions = batch[0]
    image = transform(pil_image).unsqueeze(0)

    # pack_sequence(enforce_sorted=True) requires decreasing lengths.
    token_seqs = sorted(
        (torch.tensor(transform_text(text)) for text in captions),
        key=len, reverse=True)

    inputs = [seq[:-1] for seq in token_seqs]
    outputs = [seq[1:] for seq in token_seqs]

    packed_inputs = torch.nn.utils.rnn.pack_sequence(inputs, enforce_sorted=True)
    packed_outputs = torch.nn.utils.rnn.pack_sequence(outputs, enforce_sorted=True)
    return image, packed_inputs, packed_outputs

In [None]:
import torch

# collate_fn handles exactly one (image, captions) sample per batch, so the
# batch size here must stay at 1.
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

## Build the vocabulary

In [None]:
def clean_text(text):
    """Normalize a caption and split it into word tokens.

    ASCII punctuation (``string.punctuation``) is removed first, then the text
    is lowercased and passed to ``word_tokenize``; the token list is returned.
    """
    no_punctuation = text.translate(str.maketrans('', '', string.punctuation))
    return word_tokenize(no_punctuation.lower())

In [None]:
from nltk.tokenize import word_tokenize
import string
from collections import defaultdict

# Word frequencies over every caption in the training split. Iterating the
# dataset yields (image, captions) pairs. The previous code iterated an
# undefined name `cap`; the dataset built above is `train_dataset`.
c = defaultdict(int)

for image, texts in train_dataset:
    for text in texts:
        for word in clean_text(text):
            c[word] += 1

In [None]:
# Vocabulary: words seen more than 3 times. transform_text maps everything
# else to <UNK>.
c_filtered = [word for word, count in c.items() if count > 3]

In [None]:
# Special tokens. They are appended after the regular words, so they receive
# the last three vocabulary indices.
START = '<START>'
UNK = '<UNK>'
END = '<END>'

c_filtered.extend((START, UNK, END))

In [None]:
# Index -> word and word -> index lookups over the vocabulary list.
i2w = dict(enumerate(c_filtered))
w2i = {word: index for index, word in i2w.items()}

In [None]:
def transform_text(text):
    """Convert a raw caption into a list of vocabulary indices.

    The caption is cleaned with ``clean_text``. The result is framed by the
    <START> and <END> indices, and out-of-vocabulary words become <UNK>.
    """
    tokens = clean_text(text)
    return [
        w2i[START],
        *(w2i[token] if token in w2i else w2i[UNK] for token in tokens),
        w2i[END],
    ]

## Set up the model

In [None]:
from torch import nn

class SimpleModel(nn.Module):
    """Minimal CNN-encoder / RNN-decoder captioning model.

    The convolutional encoder maps an image to a ``hidden_size`` vector. That
    vector becomes the initial hidden state of a single-layer ``nn.RNN``. The
    RNN reads embedded caption tokens (a ``PackedSequence``) and scores the
    next token at each step.

    Args:
        dict_size: Vocabulary size (embedding rows and output classes).
        embedding_dim: Size of each token embedding.
        hidden_size: Size of the RNN hidden state and of the image vector.
    """

    # Flattened encoder output: 20 channels * 23 * 23. This holds only for
    # 200x200 inputs: 200 -conv-> 198 -pool-> 99 -conv-> 97 -pool-> 48
    # -conv-> 46 -pool-> 23.
    _ENCODER_FEATURES = 10580

    def __init__(self, dict_size, embedding_dim, hidden_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=5, out_channels=10, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3)
        self.pooling = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(in_features=self._ENCODER_FEATURES, out_features=hidden_size)
        # Plain list (not nn.ModuleList). The layers are already registered as
        # attributes above, so their parameters are tracked without duplicates.
        self.encoder_layers = [
            self.conv1, self.pooling, self.relu,
            self.conv2, self.pooling, self.relu,
            self.conv3, self.pooling, self.relu]

        self.embedding = nn.Embedding(num_embeddings=dict_size, embedding_dim=embedding_dim)
        self.rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=dict_size)
        # Kept for inference (turning logits into probabilities). It is NOT
        # applied in forward, because nn.CrossEntropyLoss expects raw logits
        # and applying softmax first would squash the gradients.
        self.softmax = nn.Softmax(dim=1)

    def encoder(self, image):
        """Encode a batch of images into ``(batch, hidden_size)`` vectors.

        Assumes 200x200 inputs with 3 channels (see ``_ENCODER_FEATURES``).
        """
        for layer in self.encoder_layers:
            image = layer(image)
        # flatten(start_dim=1) keeps the batch axis intact. A wrong spatial
        # size then fails in linear1 instead of silently reshaping across
        # samples.
        return self.linear1(image.flatten(start_dim=1))

    def decoder(self, image_vector, input_captions):
        """Run the RNN over packed caption tokens.

        Args:
            image_vector: Initial hidden state of shape
                ``(1, num_captions, hidden_size)``.
            input_captions: ``PackedSequence`` of token indices.

        Returns:
            ``PackedSequence`` whose ``data`` holds unnormalized next-token
            logits of shape ``(total_tokens, dict_size)``.
        """
        embeddings = nn.utils.rnn.PackedSequence(
            self.embedding(input_captions.data),
            input_captions.batch_sizes,
            input_captions.sorted_indices,
            input_captions.unsorted_indices)
        decoded, _ = self.rnn(embeddings, image_vector)
        logits = self.linear2(decoded.data)
        return nn.utils.rnn.PackedSequence(
            logits,
            decoded.batch_sizes,
            decoded.sorted_indices,
            decoded.unsorted_indices)

    def forward(self, image, input_captions):
        """Score every next token of every caption of ``image``.

        Expects one image per call (as built by the notebook's collate
        function), with all of its captions packed into ``input_captions``.
        """
        image_vector = self.encoder(image).view(-1, self.hidden_size)
        # The RNN's h0 must be (num_layers=1, num_captions, hidden_size). The
        # number of captions per image is not guaranteed to be 5, so take it
        # from the packed batch rather than hard-coding it.
        num_captions = int(input_captions.batch_sizes[0])
        image_vector = image_vector.repeat(1, num_captions, 1)
        return self.decoder(image_vector, input_captions)

In [None]:
model = SimpleModel(
    dict_size=len(w2i),  # one output class per vocabulary entry
    embedding_dim=32,
    hidden_size=32,
)

## Training the model

In [None]:
# Next-token classification loss over the vocabulary, optimized with SGD.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9)

In [None]:
NUM_EPOCHS = 10000
REPORT_EVERY = 100

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for image, inputs, outputs in trainloader:
        optimizer.zero_grad()
        # Logits and targets are compared on the flat packed data.
        loss = criterion(model(image, inputs).data, outputs.data)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the epoch's summed (not averaged) loss.
    if not epoch % REPORT_EVERY:
        print(running_loss)