In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import pickle 
import os
from PIL import Image
import json

In [3]:
import pandas as pd

In [4]:
from pycocotools.coco import COCO

In [5]:
from caption import caption_image_beam_search

In [6]:
# Caption annotations for the split being evaluated; switch to one of the
# commented alternatives to caption a different dataset.
coco = COCO(annotation_file='./data/annotations/captions_flickr8k.json')
# coco = COCO(annotation_file='./data/annotations/captions_flickr30k.json')
# coco = COCO(annotation_file='./data/annotations/captions_val2014.json')

loading annotations into memory...
Done (t=0.05s)
creating index...
index created!


In [7]:
# Distinct ids of every image that has at least one caption annotation.
img_ids = {ann['image_id'] for ann in coco.anns.values()}

In [8]:
# (file_name, id) pairs for every annotated image, in img_ids iteration order.
# One batched loadImgs call replaces the two identical per-image lookups the
# loop used to make; loadImgs returns one image record per requested id.
images = [(img['file_name'], img['id']) for img in coco.loadImgs(list(img_ids))]

In [9]:
# Holds the {'image_id', 'caption'} records produced by `main`.
results = list()

In [10]:

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_models(model_path):
    """Load the encoder and decoder stored in a training checkpoint.

    Args:
        model_path: path to a checkpoint dict with 'encoder' and 'decoder'
            entries holding whole (pickled) modules.

    Returns:
        (encoder, decoder), both moved to ``device`` and set to eval mode.
    """
    # map_location lets a checkpoint saved on a GPU load on a CPU-only
    # machine; without it torch.load tries to restore tensors onto their
    # original CUDA device and fails when none is available.
    # NOTE(review): the modules are unpickled, so torch releases whose
    # torch.load defaults to weights_only=True may reject this file; if so,
    # pass weights_only=False -- and only ever for trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    decoder = checkpoint['decoder'].to(device)
    decoder.eval()
    encoder = checkpoint['encoder'].to(device)
    encoder.eval()
    return encoder, decoder


def _load_word_maps(word_map_path):
    """Read the JSON word map and return (word2idx, idx2word) dictionaries."""
    with open(word_map_path, 'r') as j:
        word_map = json.load(j)
    rev_word_map = {v: k for k, v in word_map.items()}
    return word_map, rev_word_map


def _decode_caption(seq, rev_word_map):
    """Convert a sequence of word indices into a space-separated sentence.

    The first index is skipped (presumably the <start> token emitted by the
    beam search -- TODO confirm) and decoding stops at the first '<end>'.
    """
    words = []
    for ind in seq[1:]:
        word = rev_word_map[ind]
        if word == '<end>':
            break
        words.append(word)
    return ' '.join(words)


def main(args, image_list=None):
    """Caption every image with beam search and collect the results.

    Args:
        args: dict with required keys 'model' (checkpoint path) and
            'word_map' (word-map JSON path). Optional keys:
            'image_dir' -- directory joined with each file name, defaults to
            'data/Flickr8k/Flicker8k_Dataset/' (other datasets used here:
            'data/Flickr30k/flickr30k-images/', 'data/val2014/');
            'beam_size' -- beam width, defaults to 4.
        image_list: iterable of (file_name, image_id) pairs; defaults to the
            notebook-level ``images`` list.

    Returns:
        A new list of {'image_id': ..., 'caption': ...} dicts, one per image.
        The list is built locally, so re-running this function never
        duplicates captions from an earlier run.
    """
    if image_list is None:
        image_list = images
    image_dir = args.get('image_dir', 'data/Flickr8k/Flicker8k_Dataset/')
    beam_size = args.get('beam_size', 4)

    encoder, decoder = _load_models(args['model'])
    word_map, rev_word_map = _load_word_maps(args['word_map'])

    captions = []
    # Inference only: skip autograd bookkeeping during beam search (harmless
    # if caption_image_beam_search already disables it internally).
    with torch.no_grad():
        for i, (image_path, image_id) in enumerate(image_list):
            if i % 1000 == 0:
                print(i, 'th image')

            seq, _ = caption_image_beam_search(
                encoder, decoder, os.path.join(image_dir, image_path),
                word_map, beam_size)

            captions.append({
                'image_id': image_id,
                'caption': _decode_caption(seq, rev_word_map),
            })

    return captions

In [11]:
# Checkpoint and word map for the Flickr8k model; the commented blocks below
# are the equivalent settings for Flickr30k and COCO.
args = dict(
    model='BEST_checkpoint_flickr8k_5_cap_per_img_5_min_word_freq.pth.tar',
    word_map='data/Flickr8k/WORDMAP_flickr8k_5_cap_per_img_5_min_word_freq.json',
)

# Flickr30k:
# args = dict(
#     model='BEST_checkpoint_flickr30k_5_cap_per_img_5_min_word_freq.pth.tar',
#     word_map='data/Flickr30k/WORDMAP_flickr30k_5_cap_per_img_5_min_word_freq.json',
# )

# COCO:
# args = dict(
#     model='models/BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar',
#     word_map='models/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json',
# )

In [12]:
# Run captioning over every annotated image (slow: one beam search per image).
results = main(args=args)

0 th image


In [13]:
# Preview the first ten generated captions.
results[0:10]

[{'image_id': 7000, 'caption': 'two brown dogs are playing in the snow'},
 {'image_id': 7001, 'caption': 'a dog is swimming in a pool'},
 {'image_id': 7002,
  'caption': 'a man in a green shirt is carrying a baby in a market'},
 {'image_id': 7003,
  'caption': 'a man is sitting on a bench in front of a red <unk>'},
 {'image_id': 7004, 'caption': 'a football player in a red uniform'},
 {'image_id': 7005, 'caption': 'a dog runs through the grass'},
 {'image_id': 7006, 'caption': 'a woman and a woman are smiling'},
 {'image_id': 7007,
  'caption': 'the brown and white dog is running through the grass'},
 {'image_id': 7008, 'caption': 'a black dog is running in the sand'},
 {'image_id': 7009, 'caption': 'a group of men play basketball'}]

In [14]:
# Save the captions as a JSON list of {'image_id', 'caption'} records.
# open() does not create missing directories, so ensure results/ exists
# first; a fresh checkout would otherwise raise FileNotFoundError here.
results_path = 'results/captions_flickr8k_results_beam4.json'
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, 'w') as fp:
    json.dump(results, fp)

# Output paths used for the other datasets:
# results_path = 'results/captions_flickr30k_results_beam4.json'
# results_path = 'results/captions_val2014_vinyals2015_results.json'