In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import pickle 
import os
from PIL import Image
import json

In [None]:
from pycocotools.coco import COCO

In [None]:
from caption import caption_image_beam_search, visualize_att

In [None]:
# Load the Flickr30k caption annotation file through the pycocotools COCO wrapper
annotation_file = './data/annotations/captions_flickr30k.json'
coco = COCO(annotation_file)

In [None]:
# Collect the distinct image ids referenced by the loaded caption annotations
img_ids = {annotation['image_id'] for annotation in coco.anns.values()}

In [None]:
# Turn every collected image id into the path of its image file on disk
images = [
    'data/Flickr30k/flickr30k-images/' + coco.loadImgs(image_id)[0]['file_name']
    for image_id in img_ids
]

In [None]:

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def main(args, test_images, visualize=False):
    """Caption every image in ``test_images`` and show it with its caption.

    Args:
        args: Mapping with two entries: ``'model'``, the path of a checkpoint
            readable by ``torch.load`` that contains ``'encoder'`` and
            ``'decoder'`` entries, and ``'word_map'``, the path of a JSON file
            mapping words to indices.
        test_images: Iterable of image file paths to caption.
        visualize: When true, ``visualize_att`` is additionally called for
            each image with the predicted sequence and its attention weights.

    Returns:
        None. The caption is printed and the captioned image is displayed
        with matplotlib as side effects.
    """
    # Load model.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only point args['model'] at checkpoints from a trusted source.
    # map_location makes a checkpoint that was saved on a GPU loadable on a
    # CPU-only machine instead of failing during deserialization.
    checkpoint = torch.load(args['model'], map_location=device)
    decoder = checkpoint['decoder'].to(device)
    decoder.eval()
    encoder = checkpoint['encoder'].to(device)
    encoder.eval()

    # Load word map (word2idx)
    with open(args['word_map'], 'r', encoding='utf-8') as j:
        word_map = json.load(j)

    # idx2word
    rev_word_map = {v: k for k, v in word_map.items()}

    for image_path in test_images:
        # NOTE(review): the trailing 4 is presumably the beam size -- confirm
        # against caption.caption_image_beam_search.
        seq, alphas = caption_image_beam_search(encoder, decoder, image_path, word_map, 4)

        # Skip the first token of the sequence (presumably the start marker)
        # and keep words up to, but not including, '<end>'.
        sampled_caption = []
        for ind in seq[1:]:
            word = rev_word_map[ind]
            if word == '<end>':
                break
            sampled_caption.append(word)

        sentence = ' '.join(sampled_caption)

        # Print out the image and the generated caption
        print(sentence)
        # The context manager closes the image file once its pixels have been
        # copied into the array, so handles do not pile up across the loop.
        with Image.open(image_path) as image:
            plt.imshow(np.asarray(image))
        plt.title(sentence)
        plt.xticks([])
        plt.yticks([])
        plt.show()

        if visualize:
            alphas = torch.FloatTensor(alphas)
            visualize_att(image_path, seq, alphas, rev_word_map)

In [None]:
# Caption the six bundled sample images (png/test1.png ... png/test6.png)
# and show the attention visualisation for each of them.
args = dict(
    model='./models/BEST_checkpoint_flickr30k_5_cap_per_img_5_min_word_freq.pth.tar',
    word_map='./models/WORDMAP_flickr30k_5_cap_per_img_5_min_word_freq.json',
)
test_images = [f'png/test{i}.png' for i in range(1, 7)]
main(args, test_images, visualize=True)

In [None]:
# Caption ten images taken from the dataset image list (positions 10-19)
# and show the attention visualisation for each of them.
args = dict(
    model='./models/BEST_checkpoint_flickr30k_5_cap_per_img_5_min_word_freq.pth.tar',
    word_map='./models/WORDMAP_flickr30k_5_cap_per_img_5_min_word_freq.json',
)
test_images = images[10:20]
main(args, test_images, visualize=True)