# Show and Tell

Implementing the simplest model based on the [Show and Tell paper](https://arxiv.org/pdf/1411.4555.pdf).

## Load dataset

In [None]:
# Directory that holds every dataset, and the dataset used by this notebook.
ROOT = "datasets"
DATASET = "mini_coco"

# Paths relative to ROOT/DATASET; `{0}` is filled in with the split name
# (e.g. 'train') via str.format.
ANNOTATIONS_PATH = "annotations/captions_{0}2014.json"
IMAGES_PATH = "images/{0}2014"

In [None]:
import torchvision
import os

# Training split: images and their caption annotations both live under
# ROOT/DATASET, with the split name substituted into the path templates.
train_dataset = torchvision.datasets.CocoCaptions(
    root=os.path.join(ROOT, DATASET, IMAGES_PATH.format('train')),
    annFile=os.path.join(ROOT, DATASET, ANNOTATIONS_PATH.format('train')),
)

## Create dictionary

In [None]:
from nltk.tokenize import word_tokenize
import string
from collections import defaultdict

In [None]:
def clean_text(text):
    """Normalize a caption and split it into tokens.

    All characters from ``string.punctuation`` are removed, the result is
    lower-cased and then tokenized with nltk's ``word_tokenize``.

    Returns:
        The value returned by ``word_tokenize`` for the normalized text.
    """
    strip_punctuation = str.maketrans('', '', string.punctuation)
    normalized = text.translate(strip_punctuation).lower()
    return word_tokenize(normalized)

In [None]:
# Word -> number of occurrences over every caption of the training set.
c = defaultdict(int)

for _image, captions in train_dataset:
    for caption in captions:
        for token in clean_text(caption):
            c[token] += 1

In [None]:
# Vocabulary candidates: words seen more than 0 times. With this threshold
# every counted word is kept; raise it to drop rare words.
c_filtered = [word for word, count in c.items() if count > 0]

In [None]:
# Special tokens: sequence start, out-of-vocabulary word, sequence end.
START = '<START>'
UNK = '<UNK>'
END = '<END>'

# The special tokens take the last three vocabulary slots, in this order.
c_filtered.extend((START, UNK, END))

In [None]:
# Two-way vocabulary mapping: a word's index is its position in c_filtered.
i2w = dict(enumerate(c_filtered))
w2i = {word: index for index, word in i2w.items()}

In [None]:
def transform_text(text):
    """Encode a raw caption as a list of vocabulary indices.

    The caption is tokenized with ``clean_text``; the resulting sequence is
    wrapped in the START and END indices, and every token that is missing
    from ``w2i`` is replaced with the UNK index.
    """
    encoded = [w2i[START]]
    for token in clean_text(text):
        encoded.append(w2i[token] if token in w2i else w2i[UNK])
    encoded.append(w2i[END])
    return encoded

In [None]:
# Show the full word -> index vocabulary for a quick visual check.
print(w2i)

## Create dataloader

In [None]:
# Image preprocessing: resize to a fixed 200x200 and convert to a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(200, 200)),
    torchvision.transforms.ToTensor(),
])

In [None]:
def collate_fn_train(batch):
    """Build a training batch out of ``(image, captions)`` samples.

    Every caption becomes its own training example, paired with the
    transformed image it describes. Examples are ordered by decreasing
    caption length, as required by ``pack_sequence(enforce_sorted=True)``.

    Returns:
        A tuple ``(images, packed_inputs, packed_outputs)``: the stacked
        image tensors, the packed captions without their last token and the
        packed captions without their first token (the prediction targets).
    """
    examples = []
    for picture, captions in batch:
        picture_tensor = transform(picture)
        for caption in captions:
            examples.append((picture_tensor, torch.tensor(transform_text(caption))))

    # Stable sort, longest caption first.
    examples.sort(key=lambda example: example[1].shape[0], reverse=True)
    images_list, texts_list = zip(*examples)

    inputs = []
    outputs = []
    for encoded in texts_list:
        inputs.append(encoded[:-1])
        outputs.append(encoded[1:])

    pack = torch.nn.utils.rnn.pack_sequence
    packed_inputs = pack(inputs, enforce_sorted=True)
    packed_outputs = pack(outputs, enforce_sorted=True)
    return torch.stack(images_list), packed_inputs, packed_outputs

In [None]:
import torch

# Shuffled training batches of 4 images; each image contributes one example
# per caption (see collate_fn_train).
trainloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn_train,
)

In [None]:
def collate_fn_validate(batch):
    """Build a validation batch out of ``(image, captions)`` samples.

    Returns:
        A tuple ``(images, references)``: the stacked transformed images and,
        for each image, the list of its captions cleaned with ``clean_text``
        and re-joined with single spaces.
    """
    images_list = []
    texts_list = []
    for picture, captions in batch:
        images_list.append(transform(picture))
        texts_list.append([' '.join(clean_text(caption)) for caption in captions])
    return torch.stack(images_list), texts_list

In [None]:
# One image per batch, in dataset order.
# NOTE(review): validation iterates over train_dataset, not a held-out split.
valloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn_validate,
)

## Beam search

In [None]:
import numpy as np

def simple_beam_search(model, image, w2i, i2w, max_length=15, beam_size=1):
    """Decode a caption for ``image`` with beam search.

    Args:
        model: Object exposing ``encoder(image)`` and
            ``decoder(hidden, packed_inputs)``; the decoder must return a
            packed sequence of per-word probabilities and a hidden state
            tensor whose first dimension is indexed with ``[0]``.
        image: Already transformed image tensor, passed to ``model.encoder``
            unchanged. The search starts from a single hypothesis, so this
            assumes a batch holding exactly one image.
        w2i: Word -> index mapping; must contain START and END.
        i2w: Index -> word mapping. Unused; kept for interface compatibility.
        max_length: Maximum number of decoding steps.
        beam_size: Number of hypotheses kept after every step.

    Returns:
        The most probable hypothesis as a list of word indices. It starts
        with the START index and ends with the END index if it was finished
        within ``max_length`` steps.
    """
    # TODO: two known inefficiencies remain:
    # 1. new_hyps can hold up to beam_size ** 2 candidates per step, although
    #    2 * beam_size memory would be enough.
    # 2. The per-hypothesis Python loops could be replaced with vectorized
    #    numpy/torch operations.
    start_index = w2i[START]
    end_index = w2i[END]

    cur_hyps = [[start_index]]
    cur_probs = np.array([1.])

    # Decoding is inference only: do not record an autograd graph.
    with torch.no_grad():
        cur_hiddens = torch.unsqueeze(model.encoder(image), 0)
        for _ in range(max_length):
            # Finished hypotheses are carried over unchanged, so once all of
            # them have ended further steps cannot change the result.
            if all(hyp[-1] == end_index for hyp in cur_hyps):
                break

            packed_inputs = torch.nn.utils.rnn.pack_sequence(
                [torch.tensor([hyp[-1]]) for hyp in cur_hyps], enforce_sorted=True)
            probs, hiddens = model.decoder(cur_hiddens, packed_inputs)

            new_hyps = []
            new_probs = []
            new_hiddens = []
            for hyp, cur_prob, prob, hidden in zip(cur_hyps, cur_probs, probs.data, hiddens[0]):
                if hyp[-1] == end_index:
                    # Keep a finished hypothesis as a candidate as is.
                    new_hyps.append(hyp)
                    new_probs.append(cur_prob)
                    new_hiddens.append(hidden)
                    continue
                # Expand with the beam_size most probable next words.
                for word in torch.argsort(prob)[-beam_size:]:
                    new_hyps.append(hyp + [word.item()])
                    new_probs.append(cur_prob * prob[word].item())
                    new_hiddens.append(hidden)

            new_probs = np.array(new_probs)
            # argsort is ascending: the best candidate ends up last.
            best_hyps = np.argsort(new_probs)[-beam_size:]
            cur_probs = new_probs[best_hyps]
            cur_hiddens = torch.unsqueeze(
                torch.stack(new_hiddens)[torch.as_tensor(best_hyps)], 0)
            cur_hyps = [new_hyps[hyp_num] for hyp_num in best_hyps]

    # Hypotheses are kept in ascending order of probability, so the last one
    # is a most probable one even when several probabilities are tied.
    return cur_hyps[-1]

## Model evaluation

### BLEU

In [None]:
#!/usr/bin/env python

# bleu_scorer.py
# David Chiang <chiang@isi.edu>

# Copyright (c) 2004-2006 University of Maryland. All rights
# reserved. Do not redistribute without permission from the
# author. Not for commercial use.

# Modified by:
# Hao Fang <hfang@uw.edu>
# Tsung-Yi Lin <tl483@cornell.edu>

''' Provides:
cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
'''

import copy
import sys, math, re
from collections import defaultdict


def precook(s, n=4, out=False):
    """Count the n-grams of a whitespace-tokenized sentence.

    Args:
        s: Sentence as a string; it is split on whitespace.
        n: Highest n-gram order to count (orders 1..n are counted).
        out: Unused.

    Returns:
        A tuple ``(number_of_words, counts)`` where ``counts`` maps every
        n-gram (a tuple of words) to its number of occurrences.
    """
    tokens = s.split()
    ngram_counts = defaultdict(int)
    for order in range(1, n + 1):
        # Zipping `order` shifted views of the tokens yields every
        # contiguous n-gram of that order, left to right.
        shifted = [tokens[offset:] for offset in range(order)]
        for ngram in zip(*shifted):
            ngram_counts[ngram] += 1
    return len(tokens), ngram_counts


def cook_refs(refs, eff=None, n=4):
    """Pre-process the reference sentences of a single segment.

    Args:
        refs: List of reference sentences (strings).
        eff: How to collapse the reference lengths: "shortest" keeps the
            minimum, "average" keeps the mean; any other value keeps the
            full list of lengths so that the effective length (e.g.
            "closest") can be chosen later.
        n: Highest n-gram order to count.

    Returns:
        A tuple ``(reflen, maxcounts)`` where ``maxcounts`` maps each n-gram
        to its highest count within any single reference.
    """
    lengths = []
    clipped_counts = {}
    for sentence in refs:
        length, ngram_counts = precook(sentence, n)
        lengths.append(length)
        for ngram, occurrences in ngram_counts.items():
            if occurrences > clipped_counts.get(ngram, 0):
                clipped_counts[ngram] = occurrences

    # Effective reference length.
    if eff == "shortest":
        return min(lengths), clipped_counts
    if eff == "average":
        return float(sum(lengths)) / len(lengths), clipped_counts
    return lengths, clipped_counts


def cook_test(test, ref_tuple, eff=None, n=4):
    """Pre-process a test sentence against its cooked references.

    Args:
        test: Test (hypothesis) sentence as a string.
        ref_tuple: The ``(reflen, maxcounts)`` pair returned by cook_refs.
        eff: With "closest", ``reflen`` must be a list of lengths and the
            one closest to the test length is kept (the shorter one on a
            tie); otherwise ``reflen`` is stored unchanged.
        n: Highest n-gram order to count.

    Returns:
        A dict with keys "reflen", "testlen", "guess" (number of n-grams of
        each order in the test sentence) and "correct" (per order, the test
        n-gram counts clipped by the reference counts).
    """
    hyp_len, hyp_counts = precook(test, n, True)
    ref_lens, ref_max_counts = ref_tuple

    if eff == "closest":
        effective_len = min(ref_lens, key=lambda l: (abs(l - hyp_len), l))
    else:  # "average", "shortest" or None
        effective_len = ref_lens

    matched = [0] * n
    for ngram, count in hyp_counts.items():
        matched[len(ngram) - 1] += min(ref_max_counts.get(ngram, 0), count)

    return {
        "reflen": effective_len,
        "testlen": hyp_len,
        "guess": [max(0, hyp_len - order + 1) for order in range(1, n + 1)],
        "correct": matched,
    }


class BleuScorer(object):
    """Accumulate cooked test/reference sentences and compute BLEU-1..BLEU-n.

    Sentences are added through the constructor or ``+=``; the scores are
    computed lazily by ``compute_score`` and cached until the stored
    sentences change.
    """

    __slots__ = ("n", "crefs", "ctest", "_score", "_bleu_list", "_ratio",
                 "_testlen", "_reflen", "special_reflen")

    # special_reflen is used in oracle (proportional effective ref len for a node).

    def copy(self):
        '''Return a new scorer holding shallow copies of the cooked data.'''
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        '''Create a scorer, optionally seeded with one test/refs pair.

        Args:
            test: Test sentence (string) or None.
            refs: List of reference sentences or None.
            n: Highest n-gram order to score.
            special_reflen: If given, used as the reference length of every
                sentence instead of the length derived from the references.
        '''
        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        '''Cook and store one test/refs pair.

        Called by the constructor and __iadd__ to avoid creating new
        instances. Nothing is stored when ``refs`` is None.
        '''
        if refs is not None:
            # Cook with this scorer's n-gram order so that the cooked
            # statistics always have the n entries compute_score reads.
            self.crefs.append(cook_refs(refs, n=self.n))
            if test is not None:
                self.ctest.append(cook_test(test, self.crefs[-1], n=self.n))
            else:
                self.ctest.append(None)  # lens of crefs and ctest have to match

        self._score = None  # need to recompute

    def ratio(self, option=None):
        '''Return the corpus-level test length / reference length ratio.'''
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        '''Return a (bleu, len_ratio) pair.

        ``bleu`` is the corpus score of the highest n-gram order (the
        original code called an ``fscore`` method that does not exist).
        '''
        bleus, _ = self.compute_score(option=option)
        return bleus[-1], self.ratio(option=option)

    def score_ratio_str(self, option=None):
        '''Format score_ratio() as "bleu (ratio)".'''
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        '''Return the total effective reference length.'''
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        '''Return the total test length.'''
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        '''Replace the test sentence(s), keeping the references.

        Args:
            new_test: A single sentence or a list with one sentence per
                stored reference set.

        Returns:
            self, so that calls can be chained.
        '''
        if isinstance(new_test, str):
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs, n=self.n))
        self._score = None

        return self

    def rescore(self, new_test):
        '''Replace test(s) with new test(s), and return the new score.'''
        return self.retest(new_test).compute_score()

    def size(self):
        '''Return the number of stored sentences.'''
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''Add a (test, refs) tuple or merge another BleuScorer.'''
        if isinstance(other, tuple):
            # avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None  # need to recompute

        return self

    def compatible(self, other):
        '''Return True if ``other`` is a BleuScorer with the same n.'''
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        '''Return the effective reference length of the first sentence.'''
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):
        '''Collapse a list of reference lengths into one effective length.

        Args:
            reflens: Lengths of the references of one sentence.
            option: "shortest", "average" or "closest".
            testlen: Test sentence length; only used by "closest".

        Raises:
            AssertionError: If ``option`` is not supported.
        '''
        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens)) / len(reflens)
        elif option == "closest":
            reflen = min((abs(l - testlen), l) for l in reflens)[1]
        else:
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError("unsupported reflen option %s" % option)

        return reflen

    def recompute_score(self, option=None, verbose=0):
        '''Drop the cached score and compute it again.'''
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        '''Compute corpus-level and per-sentence BLEU-1..BLEU-n.

        Args:
            option: Effective reference length policy ("shortest",
                "average" or "closest"). Defaults to "average" when a
                single sentence is stored and to "closest" otherwise.
            verbose: Values above 0 print the corpus statistics, values
                above 1 also print the per-sentence statistics.

        Returns:
            A tuple ``(bleus, bleu_list)``: ``bleus[k]`` is the corpus
            BLEU-(k+1) and ``bleu_list[k]`` holds the BLEU-(k+1) of each
            sentence. The result is cached; a cached result is returned in
            the same shape, whatever ``option`` is passed.
        '''
        if self._score is not None:
            return self._score, self._bleu_list

        n = self.n
        small = 1e-9
        tiny = 1e-15  # so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None:  # need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess', 'correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        / (float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1. / (k + 1)))
            ratio = (testlen + tiny) / (reflen + small)  # N.B.: avoid zero division
            if ratio < 1:
                # brevity penalty
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1 / ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1. / (k + 1)))
        ratio = (self._testlen + tiny) / (self._reflen + small)  # N.B.: avoid zero division
        if ratio < 1:
            # brevity penalty
            for k in range(n):
                bleus[k] *= math.exp(1 - 1 / ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        # Cache everything the accessors and later calls need.
        self._ratio = ratio
        self._score = bleus
        self._bleu_list = bleu_list
        return self._score, self._bleu_list

In [None]:
#!/usr/bin/env python
#
# File Name : bleu.py
#
# Description : Wrapper for BLEU scorer.
#
# Creation Date : 06-01-2015
# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
# Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>

class Bleu:
    """Wrapper that feeds hypothesis/reference dicts to a BleuScorer."""

    def __init__(self, n=4):
        # By default compute BLEU scores up to 4-grams.
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res):
        """Compute BLEU for a set of images.

        Args:
            gts: Maps an image id to the list of its reference sentences
                (at least one).
            res: Maps an image id to a list holding exactly one hypothesis
                sentence. Must have the same keys as ``gts``.

        Returns:
            The ``(score, scores)`` pair produced by
            ``BleuScorer.compute_score`` with the "closest" reference
            length option.
        """
        assert gts.keys() == res.keys()

        bleu_scorer = BleuScorer(n=self._n)
        for image_id in gts.keys():
            hypothesis = res[image_id]
            references = gts[image_id]

            # Sanity check.
            assert type(hypothesis) is list
            assert len(hypothesis) == 1
            assert type(references) is list
            assert len(references) >= 1

            bleu_scorer += (hypothesis[0], references)

        # Other supported reference length options: 'shortest', 'average'.
        corpus_score, sentence_scores = bleu_scorer.compute_score(option='closest', verbose=0)

        # (bleu, bleu_info)
        return corpus_score, sentence_scores

    def method(self):
        """Return the name of this metric."""
        return "Bleu"

## Validation

In [None]:
def validate(model, dataloader, w2i, i2w):
    """Caption every batch of ``dataloader`` and score the result with BLEU.

    For each batch a caption is decoded with ``simple_beam_search``
    (beam size 3); the leading start token and a trailing '<END>' token are
    dropped before the indices are mapped back to words. Only the references
    of the first image of a batch (``texts[0]``) are used, so the loader is
    expected to yield one image per batch.

    Returns:
        The tuple ``(bleu1, bleu2, bleu3, bleu4)``.
    """
    references = {}
    hypotheses = {}
    for sample_id, (image, texts) in enumerate(dataloader):
        decoded = simple_beam_search(model, image, w2i, i2w, beam_size=3)
        word_ids = decoded[1:]
        if word_ids[-1] == w2i['<END>']:
            word_ids = word_ids[:-1]
        sentence = ' '.join([i2w[word_id] for word_id in word_ids])
        references[sample_id] = texts[0]
        hypotheses[sample_id] = [sentence]

    scorer = Bleu(n=4)
    (bleu1, bleu2, bleu3, bleu4), _ = scorer.compute_score(references, hypotheses)
    return bleu1, bleu2, bleu3, bleu4

## Setup model

In [None]:
from torch import nn

class SimpleModel(nn.Module):
    """Small CNN encoder + RNN decoder captioning model.

    The encoder maps an image to a vector of ``hidden_size`` features that
    is used as the initial hidden state of the RNN decoder. The decoder
    emits, for every input token, a softmax distribution over the
    dictionary.
    """

    def __init__(self, dict_size, embedding_dim, hidden_size, *args, **kwargs):
        """
        Args:
            dict_size: Number of words in the dictionary.
            embedding_dim: Size of the word embeddings.
            hidden_size: Size of the RNN hidden state and of the image vector.
        """
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size

        # --- image encoder ---
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=5, out_channels=10, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3)
        self.pooling = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        # 10580 = 20 channels * 23 * 23, the flattened feature map size that
        # the three conv/pool stages produce for a 200x200 input.
        self.linear1 = nn.Linear(in_features=10580, out_features=hidden_size)
        self.encoder_layers = []
        for conv in (self.conv1, self.conv2, self.conv3):
            self.encoder_layers += [conv, self.pooling, self.relu]

        # --- caption decoder ---
        self.embedding = nn.Embedding(num_embeddings=dict_size, embedding_dim=embedding_dim)
        self.rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=dict_size)
        self.softmax = nn.Softmax(dim=1)

    def encoder(self, image):
        """Encode a batch of images into vectors of ``hidden_size`` features."""
        features = image
        for layer in self.encoder_layers:
            features = layer(features)
        flattened = features.view(-1, 10580)
        return self.linear1(flattened).view(-1, self.hidden_size)

    def decoder(self, image_vector, input_captions):
        """Run the RNN over packed captions, starting from ``image_vector``.

        Args:
            image_vector: Initial hidden state passed to the RNN.
            input_captions: PackedSequence of word indices.

        Returns:
            A tuple ``(probs, hiddens)``: a PackedSequence of per-token word
            probabilities and the final hidden state returned by the RNN.
        """
        packed_embeddings = nn.utils.rnn.PackedSequence(
            self.embedding(input_captions.data),
            input_captions.batch_sizes)
        rnn_output, hiddens = self.rnn(packed_embeddings, image_vector)
        word_probs = self.softmax(self.linear2(rnn_output.data))
        packed_probs = nn.utils.rnn.PackedSequence(word_probs, rnn_output.batch_sizes)
        return packed_probs, hiddens

    def forward(self, image, input_captions):
        """Encode ``image`` and decode ``input_captions`` conditioned on it."""
        initial_hidden = self.encoder(image).unsqueeze(0)
        return self.decoder(initial_hidden, input_captions)

In [None]:
# The dictionary size follows the vocabulary built above; embeddings and the
# RNN hidden state both have 32 features.
model = SimpleModel(
    dict_size=len(w2i),
    embedding_dim=32,
    hidden_size=32,
)

## Training the model

In [None]:
def criterion(probabilities, targets):
    """Mean negative log-likelihood of ``targets`` under ``probabilities``.

    The decoder of SimpleModel already applies a softmax, so its output is a
    probability distribution. ``nn.CrossEntropyLoss`` expects unnormalized
    logits and would apply a second softmax on top of it; the loss is
    therefore computed from the log of the probabilities instead.

    Args:
        probabilities: Tensor of shape (N, dict_size) whose rows are
            probability distributions.
        targets: Tensor of N target word indices.

    Returns:
        Scalar loss tensor.
    """
    # Clamp so that a probability of exactly 0 does not produce -inf.
    log_probabilities = torch.log(probabilities.clamp_min(1e-12))
    return nn.functional.nll_loss(log_probabilities, targets)
# Plain SGD with momentum over every model parameter.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.9,
)

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer that receives the per-epoch loss and BLEU scalars.
writer = SummaryWriter()

In [None]:
# Train for 200 epochs; after every epoch run validation and log the summed
# training loss together with BLEU-1..BLEU-4.
for epoch in range(200):
    epoch_loss = 0.0
    for images, caption_inputs, caption_targets in trainloader:
        optimizer.zero_grad()

        predictions, _ = model(images, caption_inputs)
        batch_loss = criterion(predictions.data, caption_targets.data)
        batch_loss.backward()
        optimizer.step()
        # Sum of the batch losses over the epoch (not averaged).
        epoch_loss += batch_loss.item()

    bleu_scores = validate(model, valloader, w2i, i2w)

    writer.add_scalar('loss', epoch_loss, epoch)
    for tag, value in zip(('bleu1', 'bleu2', 'bleu3', 'bleu4'), bleu_scores):
        writer.add_scalar(tag, value, epoch)