In [None]:
from file_utils import read_file_in_dir
from dataset_factory import get_datasets
from model_factory import get_model
import torch
from IPython.display import Image, display

# Specify the experiment name. It selects both the JSON config read from the
# current directory and the checkpoint folder under ./experiment_data/.
experiment_name = 'arch2'

# Build the datasets and the model described by the experiment's config.
config_data = read_file_in_dir("./", f"{experiment_name}.json")
coco_test, vocab, _, _, test_loader = get_datasets(config_data)
model = get_model(config_data, vocab)

# This notebook only runs inference, so put the model in evaluation mode:
# layers such as dropout / batch-norm otherwise keep their training behavior.
model.eval()

# Load best model. map_location="cpu" lets a checkpoint that was saved from a
# GPU be read on any machine; load_state_dict copies the values into the
# model's existing parameters, so the model's own device is unaffected.
state_dict = torch.load(
    f"./experiment_data/{experiment_name}/best_model.pt",
    map_location="cpu",
)
model.load_state_dict(state_dict)

Rerun the next cell to get a random image from the test set and get the reference captions:

In [None]:
# Pull a single batch from the test loader; only its first image is inspected
# here, but the whole `imgs` batch is kept for the captioning cell below.
batch = next(iter(test_loader))
imgs, captions, img_ids = batch
img_id = img_ids[0]

# Look up the image's file name in the COCO index and render it inline.
img_record = coco_test.loadImgs(img_id)[0]
img_path = "./data/images/test/" + img_record['file_name']
display(Image(filename=img_path))

# Print every reference caption annotated for this image, one per line.
annotations = coco_test.imgToAnns[img_id]
for reference in annotations:
    print(reference['caption'])

Generate a non-deterministic and a deterministic predicted caption for the first image of the batch:

In [None]:
# Set the temperature
temperature = 0.1

# Second argument passed to model.sample.
# NOTE(review): presumably the maximum caption length in tokens -- confirm
# against model.sample's signature.
max_length = 20


def decode_caption(token_ids, idx2word, end_token='<end>'):
    """Turn one sampled sequence of token ids into a caption string.

    Args:
        token_ids: iterable of elements exposing ``.item()`` (e.g. the
            entries of one row returned by ``model.sample``).
        idx2word: mapping from integer token id to word.
        end_token: word that terminates the caption; it and everything
            after it are dropped.

    Returns:
        The decoded words joined by single spaces (no trailing space).
    """
    tokens = iter(token_ids)
    # The first position is always skipped.
    # NOTE(review): presumably it holds the start-of-sentence token -- confirm.
    next(tokens, None)

    words = []
    for token in tokens:
        word = idx2word[token.item()]
        if word == end_token:
            break
        words.append(word)
    return ' '.join(words)


# Move model and batch to the GPU once, falling back to the CPU when CUDA is
# not available instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
imgs = imgs.to(device)

# Run the sampler in both modes; the last argument of model.sample is the flag
# that switches between non-deterministic (False) and deterministic (True).
with torch.no_grad():
    for label, deterministic in (("Non-deterministic:", False),
                                 ("Deterministic:", True)):
        sampled_ids = model.sample(imgs, max_length, temperature, deterministic)
        # Get predicted caption for the first image of the batch only.
        predicted_caption = decode_caption(next(iter(sampled_ids)), vocab.idx2word)
        print(label)
        print(predicted_caption)