# Image Captioning

Keywords: CV, NLP, LSTM, attention (soft), MS-COCO, image processing, InceptionV3.

This implementation closely follows [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044).

In [0]:
# miscellaneous
import inspect
import json
import time
from collections import Counter

import h5py
import numpy as np
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm

# code
from caption import *

# PyTorch
import torch
import torch.optim
import torch.utils.data
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.optim import Adam
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence

# computing environment: this notebook requires a CUDA GPU.
# An explicit raise (not `assert`) so the check survives `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is required but no CUDA device is available')
device = torch.device('cuda')
# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays
# off when input sizes stay the same from batch to batch.
cudnn.benchmark = True

### Parameters

In [0]:
# Locations consumed by the rest of the notebook.
input_dir = "processed_input"  # folder containing the preprocessed dataset
trained_model = 'final_model.pth.tar'  # checkpoint file produced by training

### Load Word Map and Model

In [0]:
# Load the word map (token -> id) and build the inverse lookup (id -> token).
with open('drive/My Drive/wordmap.json', 'r', encoding='utf-8') as f:
    word_map = json.load(f)
reverse_word_map = {idx: token for token, idx in word_map.items()}
vocab_size = len(word_map)

# Load the checkpoint. Its 'encoder'/'decoder' entries are used directly as
# modules below (.to / .eval), i.e. the checkpoint pickles whole model objects
# rather than plain state dicts. torch >= 2.6 defaults torch.load to
# weights_only=True, which refuses such objects, so opt out explicitly when
# the parameter exists (older torch versions lack it).
# SECURITY: full unpickling can execute arbitrary code -- only load trusted
# checkpoints.
load_kwargs = {'map_location': device}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
trained_model = torch.load(trained_model, **load_kwargs)
decoder = trained_model['decoder']
encoder = trained_model['encoder']

# Move both networks to the device and put them in evaluation mode
# (disables dropout, uses batch-norm running statistics).
encoder = encoder.to(device)
decoder = decoder.to(device)
encoder.eval()
decoder.eval()

### Transform

In [0]:
# Per-channel (R, G, B) normalization using the ImageNet mean/std values that
# torchvision's pretrained models expect.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

### Evaluate Function

In [0]:
def _strip_tokens(tokens, drop):
    """Return the items of *tokens* that are not in the set *drop*."""
    return [w for w in tokens if w not in drop]


def _beam_search(enc_out, beam_size, max_steps=100):
    """Caption a single encoded image with beam search.

    Args:
        enc_out: encoder output for ONE image. It is flattened below with
            ``view(1, -1, enc_dim)`` where ``enc_dim = enc_out.size(3)``, so it
            is assumed to be (1, H, W, enc_dim) -- TODO confirm with encoder.
        beam_size: number of candidate sequences kept per step.
        max_steps: decoding stops once the step counter exceeds this value.

    Returns:
        list[int]: token ids of the best sequence, including the leading
        ``<start>`` (and ``<end>`` when one was reached).
    """
    k = beam_size  # number of live beams; shrinks as beams emit <end>
    end_token = word_map['<end>']

    # Flatten the spatial grid into a pixel axis, then repeat per beam.
    enc_dim = enc_out.size(3)
    enc_out = enc_out.view(1, -1, enc_dim)
    num_pix = enc_out.size(1)
    enc_out = enc_out.expand(k, num_pix, enc_dim)

    # Every beam starts from <start> with a cumulative log-prob of 0.
    prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)
    seqs = prev_words
    top_scores = torch.zeros(k, 1).to(device)

    done_seqs = []    # finished sequences (ended with <end>)
    done_scores = []  # their cumulative log-probabilities

    # NOTE(review): init_state() receives no encoder output here; confirm its
    # returned states are compatible with a batch of `k` beams.
    hidden, cell = decoder.init_state()

    t = 1
    while True:
        # One attention-gated LSTM step.
        emb = decoder.embedding(prev_words).squeeze(1)
        att_out, _ = decoder.attention(enc_out, hidden)
        gate = decoder.sigmoid(decoder.f_beta(hidden))
        att_out = gate * att_out
        hidden, cell = decoder.decode(torch.cat([emb, att_out], dim=1), (hidden, cell))

        # Cumulative log-probabilities for every (beam, word) extension.
        scores = F.log_softmax(decoder.fc(hidden), dim=1)
        scores = top_scores.expand_as(scores) + scores

        if t == 1:
            # All beams are identical at the first step; only look at one.
            top_scores, top_words = scores[0].topk(k, 0, largest=True, sorted=True)
        else:
            top_scores, top_words = scores.view(-1).topk(k, 0, largest=True, sorted=True)

        # Decode the flat index into (source beam, next word). Floor division:
        # `/` on tensors is true division (torch >= 1.5) and would yield floats.
        prev_word_idxs = top_words // vocab_size
        next_word_idxs = top_words % vocab_size
        seqs = torch.cat([seqs[prev_word_idxs], next_word_idxs.unsqueeze(1)], dim=1)

        # Compare token *values* (not identity) against the <end> id.
        next_list = next_word_idxs.tolist()
        going_idxs = [idx for idx, w in enumerate(next_list) if w != end_token]
        ended_idxs = [idx for idx, w in enumerate(next_list) if w == end_token]

        if ended_idxs:
            done_seqs.extend(seqs[ended_idxs].tolist())
            done_scores.extend(top_scores[ended_idxs].tolist())

        k -= len(ended_idxs)
        if k == 0:
            break

        # Keep only unfinished beams and the states they came from.
        seqs = seqs[going_idxs]
        keep = prev_word_idxs[going_idxs]
        hidden = hidden[keep]
        cell = cell[keep]
        enc_out = enc_out[keep]
        top_scores = top_scores[going_idxs].unsqueeze(1)
        prev_words = next_word_idxs[going_idxs].unsqueeze(1)

        if t > max_steps:
            break
        t += 1

    if not done_seqs:
        # No beam reached <end> before the step limit. topk is sorted
        # descending and filtering preserves order, so row 0 is the best.
        return seqs[0].tolist()
    best = done_scores.index(max(done_scores))
    return done_seqs[best]


def evaluate(b):
    """Compute corpus BLEU of beam-search captions on the 'test' split.

    Args:
        b: beam size, applied afresh to every image.

    Returns:
        float: ``nltk`` ``corpus_bleu`` with its default weights (BLEU-4).
    """
    # batch_size=1 is required: _beam_search decodes one image at a time.
    data_loader = torch.utils.data.DataLoader(
        CapData(input_dir, 'test', transform=transforms.Compose([normalize])),
        batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

    # Tokens excluded from both references and hypotheses so BLEU only
    # compares real words.
    special = {word_map['<start>'], word_map['<end>'], word_map['<pad>']}

    ground_truths = []  # per image: list of reference token-id lists
    predictions = []    # per image: one hypothesis token-id list

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for img, _caps, _len_caps, all_caps in tqdm(data_loader):
            enc_out = encoder(img.to(device))
            seq = _beam_search(enc_out, b)

            img_caps = all_caps[0].tolist()
            ground_truths.append([_strip_tokens(cap, special) for cap in img_caps])
            predictions.append(_strip_tokens(seq, special))

    assert len(ground_truths) == len(predictions)
    return corpus_bleu(ground_truths, predictions)