In [None]:
from file_utils import read_file_in_dir
from dataset_factory import get_datasets
from model_factory import get_model
import torch
from IPython.display import Image
import nltk

In [None]:
# Experiment to inspect. The name selects both the "<name>.json" config file
# and the "./experiment_data/<name>/" checkpoint directory read further down.
experiment_name = "lstm_1"

In [None]:
# Load best model
# The experiment's JSON config drives both dataset and model construction.
config_data = read_file_in_dir("./", f"{experiment_name}.json")
coco_test, vocab, _, _, test_loader = get_datasets(config_data)
model = get_model(config_data, vocab)
# map_location="cpu" lets a checkpoint written during a GPU run deserialize on
# a CPU-only machine; the inference cell moves the model to CUDA when available.
state_dict = torch.load(
    f"./experiment_data/{experiment_name}/best_model.pt", map_location="cpu"
)
# Put the model in inference mode so that any dropout / batch-norm layers stop
# behaving as in training (assumes get_model returns a torch.nn.Module).
model.eval()
# Kept as the cell's last expression so the key-matching report is displayed.
model.load_state_dict(state_dict)

# Load latest model
# (latest_model.pt keeps the weights under the 'model' key of the checkpoint)
# config_data = read_file_in_dir("./", f"{experiment_name}.json")
# coco_test, vocab, _, _, test_loader = get_datasets(config_data)
# model = get_model(config_data, vocab)
# state_dict = torch.load(f"./experiment_data/{experiment_name}/latest_model.pt", map_location="cpu")
# model.eval()
# model.load_state_dict(state_dict['model'])

In [None]:
# Take the first batch from the test loader and keep only its first sample.
imgs, captions, img_ids = next(iter(test_loader))
# unsqueeze(0) re-adds a leading size-1 dimension, so the single image and
# caption are still shaped like a batch of one.
img, caption = (batch_field[0].unsqueeze(0) for batch_field in (imgs, captions))
img_id = img_ids[0]

In [None]:
# Look up the selected image's metadata record and read its file name
# (coco_test is presumably a pycocotools COCO handle -- TODO confirm).
img_record = coco_test.loadImgs(img_id)[0]
img_path = "./data/images/test/" + img_record['file_name']
# Last expression of the cell: renders the image inline.
Image(filename=img_path)

In [None]:
# Run on the GPU when one is present; otherwise everything stays on the CPU.
if torch.cuda.is_available():
    model.cuda()
    img = img.cuda()
    caption = caption.cuda()

# Forward pass with the reference caption supplied as the model's second
# input; the highest-scoring vocabulary index is kept at every position.
with torch.no_grad():
    output = model(img, caption)
    # flatten() instead of an argument-less squeeze(): squeeze() would collapse
    # a length-1 sequence to a 0-d tensor, which cannot be iterated.
    output = torch.argmax(output, dim=-1).flatten()

# Reference caption: map every vocabulary index back to its word.
target_caption = ' '.join(
    vocab.idx2word[word_idx] for word_idx in caption.flatten().tolist()
)
print(target_caption)

# Caption predicted by the forward pass above, decoded the same way.
predicted_caption = ' '.join(
    vocab.idx2word[word_idx] for word_idx in output.tolist()
)
print(predicted_caption)

In [None]:
# Generate caption(s) for the selected image without tracking gradients.
# NOTE(review): the positional arguments 20, 0.1, False are presumably the
# maximum length, the sampling temperature and a "deterministic" switch --
# confirm against model.sample before relying on that reading.
with torch.no_grad():
    sampled_ids = model.sample(img, 20, 0.1, False)

for sampled_id in sampled_ids:
    caption_words = []
    # The first token is skipped (presumably the start marker -- TODO confirm).
    for word_idx in sampled_id[1:]:
        word = vocab.idx2word[word_idx.item()]
        if word == '<end>':
            # Everything after the end marker is ignored.
            break
        caption_words.append(word)
    # Re-assemble the sentence (each word followed by one space) and let NLTK
    # split it into the final token list.
    predicted_caption = ''.join(word + ' ' for word in caption_words)
    predicted_caption = nltk.tokenize.word_tokenize(predicted_caption)
    print(predicted_caption)