In [None]:
import numpy as np
import json
import torch
from torch.utils.data import Dataset
import torch.nn as nn
from scipy.misc import imread, imresize
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import skimage.transform
from PIL import Image
from models import *

In [None]:
# Load the trained encoder/decoder pair stored in the checkpoint file.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only open
# checkpoints that come from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which refuses
# pickled nn.Module objects; pass weights_only=False there if loading fails.
checkpoint_path = 'BEST_checkpoint_10.pth.tar'
# map_location moves the stored tensors onto `device` while loading, so a
# checkpoint written on a GPU can still be opened on a CPU-only machine.
# `device` is not defined in this notebook -- presumably it is exported by
# `models`; verify.
checkpoint = torch.load(checkpoint_path, map_location=device)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()  # inference mode
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()  # inference mode

# Vocabulary: word -> index
with open('WORDMAP.json', 'r') as j:
    word_map = json.load(j)

rev_word_map = {v: k for k, v in word_map.items()}  # idx2word

In [None]:
def predict_output(image):
    """
    Caption an image with greedy decoding (beam size of 1) and print the result.

    At every timestep the most probable word is selected and fed back to the
    LSTM as the next input. Decoding stops as soon as '<end>' is predicted, or
    after max_len steps, so nothing generated after '<end>' ends up in the
    caption.

    Uses the notebook-level `encoder`, `decoder`, `word_map` and `device`.

    :param image: path to the image file to caption
    :return: None; the caption and the beta value of every executed step
             (including the step that produced '<end>') are printed
    """
    max_len = 20
    end_idx = word_map['<end>']

    sampled = []
    rev_word_map = {v: k for k, v in word_map.items()}  # idx2word

    # Read image and process
    img = imread(image)
    if len(img.shape) == 2:
        # Grayscale image: replicate the single channel so that the
        # 3-channel normalisation below applies.
        img = img[:, :, np.newaxis]
        img = np.concatenate([img, img, img], axis=2)
    img = imresize(img, (254, 254))
    img = img.transpose(2, 0, 1)
    img = img / 255.
    img = torch.FloatTensor(img).to(device)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    image = normalize(img)  # (3, 254, 254)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Encode
        image = image.unsqueeze(0)  # (1, 3, 254, 254)
        spatial_image, global_image, enc_image = encoder(image)
        pred = torch.LongTensor([[word_map['<start>']]]).to(device)   # (1, 1)
        h, c = decoder.init_hidden_state(enc_image)       # (1, hidden_size)
        betas_stored = torch.zeros(max_len, 1)
        steps_run = 0

        for timestep in range(max_len):
            embeddings = decoder.embedding(pred).squeeze(1)       # (1, embed_dim)
            inputs = torch.cat((embeddings, global_image), dim=1)    # (1, embed_dim * 2)
            # The sentinel is computed from the state *before* this LSTM step.
            h_prev = h.clone()
            c_prev = c.clone()
            h, c = decoder.LSTM(inputs, (h, c))  # (1, hidden_size)
            # Run the sentinal model
            st = decoder.sentinal(inputs, h_prev, c_prev)

            alpha_t, ct, zt = decoder.spatial_attention(spatial_image, h)
            # Run the adaptive attention model
            c_hat, beta_t, alpha_hat = decoder.adaptive_attention(h, st, zt, ct)
            # Compute the probability
            pt = decoder.fc(c_hat + h)
            _, pred = pt.max(1)
            betas_stored[timestep] = beta_t
            steps_run = timestep + 1

            word_idx = pred.item()
            if word_idx == end_idx:
                # Sentence is finished; anything decoded afterwards is noise.
                break
            sampled.append(word_idx)

    filtered_words = ' '.join([rev_word_map[idx] for idx in sampled])

    print("Prediction Using Greedy Search:", filtered_words)
    # Only the rows that were actually filled in are meaningful.
    print("Betas:", betas_stored[:steps_run])

In [None]:
# Implementation with Beam Search
def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):
    """
    Caption an image with beam search.

    :param encoder: encoder model; called on a (1, 3, 254, 254) image tensor and
                    expected to return (spatial_image, global_image, encoder_out)
    :param decoder: decoder model providing init_hidden_state, embedding, LSTM,
                    sentinal, spatial_attention, adaptive_attention and fc
    :param image_path: path to the image file to caption
    :param word_map: word map, i.e. word2ix; must contain '<start>' and '<end>'
    :param beam_size: number of sequences kept at every decoding step
    :return: (seq, alphas, betas) of the best-scoring sequence:
             seq is a list of word indices starting with '<start>',
             alphas is a nested list with one attention map per position,
             betas is a tensor with one entry per position
             (position 0 holds the placeholder value 1 for '<start>').
             If no beam reaches '<end>' before the step limit, the best
             unfinished sequence is returned instead.
    """
    k = beam_size
    vocab_size = len(word_map)
    end_idx = word_map['<end>']

    # Read image and process
    img = imread(image_path)
    if len(img.shape) == 2:
        # Grayscale image: replicate the single channel three times
        img = img[:, :, np.newaxis]
        img = np.concatenate([img, img, img], axis=2)
    img = imresize(img, (254, 254))
    img = img.transpose(2, 0, 1)
    img = img / 255.
    img = torch.FloatTensor(img).to(device)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    image = normalize(img)  # (3, 254, 254)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Encode
        image = image.unsqueeze(0)  # (1, 3, 254, 254)
        spatial_image, global_image, encoder_out = encoder(image)
        num_pixels = encoder_out.size(1)
        encoder_dim = encoder_out.size(2)
        # Side length of the attention grid; alpha is reshaped to
        # (s, enc_image_size, enc_image_size) below.
        enc_image_size = 8
        # We'll treat the problem as having a batch size of k
        encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)
        # Tensor to store top k previous words at each step; now they're just <start>
        k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)
        # Tensor to store top k sequences; now they're just <start>
        seqs = k_prev_words  # (k, 1)
        # Tensor to store top k sequences' scores; now they're just 0
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)
        # Tensor to store top k sequences' alphas; now they're just 1s
        seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device)
        # Tensor to store the top k sequences' betas; now they're just 1s
        seqs_betas = torch.ones(k, 1, 1).to(device)
        # Lists to store completed sequences, their alphas, betas and scores
        complete_seqs = list()
        complete_seqs_alpha = list()
        complete_seqs_scores = list()
        complete_seqs_betas = list()
        # Start decoding
        step = 1
        h, c = decoder.init_hidden_state(encoder_out)
        # s is a number less than or equal to k, because sequences are removed
        # from this process once they hit <end>
        while True:
            embeddings = decoder.embedding(k_prev_words).squeeze(1)
            inputs = torch.cat((embeddings, global_image.expand_as(embeddings)), dim=1)
            # The sentinel is computed from the state *before* this LSTM step.
            h_prev = h.clone()
            c_prev = c.clone()
            h, c = decoder.LSTM(inputs, (h, c))  # (s, hidden_size)
            # Run the sentinal model
            st = decoder.sentinal(inputs, h_prev, c_prev)
            alpha, ct, zt = decoder.spatial_attention(spatial_image, h)
            # Run the adaptive attention model
            c_hat, beta_t, alpha_hat = decoder.adaptive_attention(h, st, zt, ct)
            alpha = alpha.view(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)
            # Compute the probability
            scores = decoder.fc(c_hat + h)
            scores = F.log_softmax(scores, dim=1)   # (s, vocab_size)
            # Add the running score of every beam to its candidate words
            scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)
            # For the first step, all k points will have the same scores
            # (since same k previous words, h, c), so only row 0 is used
            if step == 1:
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
            else:
                # Unroll and find top scores, and their unrolled indices
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

            # Convert unrolled indices to actual indices of scores.
            # Floor division keeps the result an integer tensor; `/` performs
            # true division on current PyTorch and the float result cannot be
            # used as an index.
            prev_word_inds = top_k_words // vocab_size  # (s)
            next_word_inds = top_k_words % vocab_size  # (s)
            # Add new words to sequences, alphas, betas
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)
            # (s, step+1, enc_image_size, enc_image_size)
            seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], dim=1)
            seqs_betas = torch.cat([seqs_betas[prev_word_inds], beta_t[prev_word_inds].unsqueeze(1)], dim=1)

            # Which sequences are incomplete (didn't reach <end>)?
            reached_end = (next_word_inds == end_idx).tolist()
            incomplete_inds = [ind for ind, done in enumerate(reached_end) if not done]
            complete_inds = [ind for ind, done in enumerate(reached_end) if done]

            # Set aside complete sequences
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
                complete_seqs_betas.extend(seqs_betas[complete_inds])
            k -= len(complete_inds)  # reduce beam length accordingly

            # Proceed with incomplete sequences
            if k == 0:
                break
            seqs = seqs[incomplete_inds]
            seqs_alpha = seqs_alpha[incomplete_inds]
            seqs_betas = seqs_betas[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            # Break if things have been going on too long
            if step > 50:
                break
            step += 1

        if not complete_seqs:
            # The step limit was hit before any beam produced <end>: fall back
            # to the unfinished beams rather than failing on an empty list.
            complete_seqs.extend(seqs.tolist())
            complete_seqs_alpha.extend(seqs_alpha.tolist())
            complete_seqs_scores.extend(top_k_scores.squeeze(1).tolist())
            complete_seqs_betas.extend(seqs_betas)

    # Index of the best-scoring sequence (first one on ties)
    i = max(range(len(complete_seqs_scores)), key=lambda idx: complete_seqs_scores[idx])
    seq = complete_seqs[i]
    alphas = complete_seqs_alpha[i]
    betas = complete_seqs_betas[i]

    return seq, alphas, betas

In [None]:
def visualize_att(image_path, seq, alphas, betas, rev_word_map, smooth=True):
    """
    Visualizes caption with weights at every word.
    Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb
    :param image_path: path to image that has been captioned
    :param seq: caption, as a sequence of word indices
    :param alphas: attention weights, indexable as alphas[t] with one map per word;
                   each map must support .numpy() (i.e. be a CPU tensor)
    :param betas: one value per word, read with betas[t].item(); drawn in red
                  next to the word
    :param rev_word_map: reverse word mapping, i.e. ix2word
    :param smooth: smooth weights?
    """
    image = Image.open(image_path)
    image = image.resize([8 * 8, 8 * 8], Image.LANCZOS)
    words = [rev_word_map[ind] for ind in seq]
    # Drop the first and last token (normally <start> and <end>) for printing
    print(' '.join(words[1:-1]))

    # plt.subplot needs integer grid dimensions; np.ceil returns a float.
    n_rows = int(np.ceil(len(words) / 5.))

    for t in range(len(words)):
        if t > 50:
            break
        plt.subplot(n_rows, 5, t + 1)
        plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12)
        plt.text(50, 1, '%.3f' % (betas[t].item()), color='red', backgroundcolor='white', fontsize=10)
        plt.imshow(image)
        current_alpha = alphas[t, :]
        # Bring the attention map up to the size of the resized image
        if smooth:
            alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=8, sigma=8)
        else:
            alpha = skimage.transform.resize(current_alpha.numpy(), [8 * 8, 8 * 8])
        if t == 0:
            # First token: overlay is fully transparent, only the image shows
            plt.imshow(alpha, alpha=0)
        else:
            plt.imshow(alpha, alpha=0.8)
        plt.set_cmap(cm.Greys_r)
        plt.axis('off')

    plt.show()

In [None]:
# Caption the sample image with greedy decoding; the caption and the betas are printed.
predict_output('test.jpg')

In [None]:
%matplotlib inline
# Notebook-wide plotting defaults: figure size, no interpolation, grayscale colormap
plt.rcParams.update({
    'figure.figsize': (12, 12),
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})
# Caption the sample image with attention and beam search (default beam size, k=3)
seq, alphas, betas = caption_image_beam_search(encoder, decoder, 'test.jpg', word_map)
# The attention maps come back as nested lists; the visualiser indexes them as a tensor
alphas = torch.FloatTensor(alphas)
# Plot the best caption word by word, each with its attention map
visualize_att('test.jpg', seq, alphas, betas, rev_word_map, smooth=True)