In [1]:
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu
import torch.nn.functional as F
from tqdm import tqdm
from nlgeval import NLGEval
import os
from glob import glob
os.environ['CUDA_VISIBLE_DEVICES']="3"
import matplotlib.pyplot as plt
import cv2

In [2]:
# Parameters
data_folder = 'final_dataset'  # folder with data files saved by create_input_files.py
data_name = 'coco_5_cap_per_img_5_min_word_freq'  # base name shared by data files
checkpoint_file = 'BEST_48checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar'  # model checkpoint

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

In [4]:
# Load model
torch.nn.Module.dump_patches = True
checkpoint = torch.load(checkpoint_file,map_location = device)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()

DecoderWithAttention(
  (attention): Attention(
    (features_att): Linear(in_features=2048, out_features=1024, bias=True)
    (decoder_att): Linear(in_features=1024, out_features=1024, bias=True)
    (full_att): Linear(in_features=1024, out_features=1, bias=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (softmax): Softmax(dim=1)
  )
  (embedding): Embedding(9490, 1024)
  (dropout): Dropout(p=0.5, inplace=False)
  (top_down_attention): LSTMCell(4096, 1024)
  (language_model): LSTMCell(3072, 1024)
  (fc1): Linear(in_features=1024, out_features=9490, bias=True)
  (fc): Linear(in_features=1024, out_features=9490, bias=True)
)

In [5]:
nlgeval = NLGEval()  # loads the evaluator

In [6]:
# Load word map (word2ix)
word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
with open(word_map_file, 'r') as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}
vocab_size = len(word_map)

In [7]:
# get vizwiz test dataset
# test_paths = glob("../mypythia/data/vizwiz/test/*.jpg")
test_paths = glob("../mypythia/data/vizwiz/test_npy/*.npy")

In [40]:
image_features = np.load(test_paths[7])[:36]
image_features = np.expand_dims(image_features, axis=0)
image_features = torch.tensor(image_features, dtype=torch.float32)
torch.min(image_features), torch.max(image_features)

(tensor(0.), tensor(22.8031))

In [8]:
test_paths = sorted(test_paths, key=lambda x: int(os.path.split(x)[-1].split("_")[-1].split(".")[0]))

In [None]:
beam_size = 100
k = beam_size

In [27]:
for x in test_paths:
    print(f"image_file_name: {x!r}")
    image_features = np.load(x)[:36]
    image_features = np.expand_dims(image_features, axis=0)
    image_features = torch.tensor(image_features, dtype=torch.float32)

    references = list()
    hypotheses = list()

    k = beam_size

    # Move to GPU device, if available
    image_features = image_features.to(device)  # (1, 36, 2048)
    image_features_mean = image_features.mean(1)
    print(f"image_features_mean: {image_features_mean.shape!r}")

    image_features_mean = image_features_mean.expand(k,2048)

    # Tensor to store top k previous words at each step; now they're just <start>
    k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)

    # Tensor to store top k sequences; now they're just <start>
    seqs = k_prev_words  # (k, 1)

    # Tensor to store top k sequences' scores; now they're just 0
    top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

    # Lists to store completed sequences and scores
    complete_seqs = list()
    complete_seqs_scores = list()

    # Start decoding
    step = 1
    h1, c1 = decoder.init_hidden_state(k)  # (batch_size, decoder_dim)
    h2, c2 = decoder.init_hidden_state(k)

    # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
    while True:
        embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)
        h1,c1 = decoder.top_down_attention(
            torch.cat([h2,image_features_mean,embeddings], dim=1),
            (h1,c1))  # (batch_size_t, decoder_dim)
        attention_weighted_encoding = decoder.attention(image_features,h1)
        h2,c2 = decoder.language_model(
            torch.cat([attention_weighted_encoding,h1], dim=1),(h2,c2))

        scores = decoder.fc(h2)  # (s, vocab_size)
        scores = F.log_softmax(scores, dim=1)

        # Add
        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

        # For the first step, all k points will have the same scores (since same k previous words, h, c)
        if step == 1:
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
        else:
            # Unroll and find top scores, and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

        # Convert unrolled indices to actual indices of scores
        prev_word_inds = top_k_words / vocab_size  # (s)
        next_word_inds = top_k_words % vocab_size  # (s)

        # Add new words to sequences
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

        # Which sequences are incomplete (didn't reach <end>)?
        incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                           next_word != word_map['<end>']]
        complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

        # Set aside complete sequences
        if len(complete_inds) > 0:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds])
        k -= len(complete_inds)  # reduce beam length accordingly

        # Proceed with incomplete sequences
        if k == 0:
            break
        seqs = seqs[incomplete_inds]
        h1 = h1[prev_word_inds[incomplete_inds]]
        c1 = c1[prev_word_inds[incomplete_inds]]
        h2 = h2[prev_word_inds[incomplete_inds]]
        c2 = c2[prev_word_inds[incomplete_inds]]
        image_features_mean = image_features_mean[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
        
        

        # Break if things have been going on too long
        if step > 50:
            break
        step += 1
        
    if len(complete_seqs) > 0:
        break
        
#     i = complete_seqs_scores.index(max(complete_seqs_scores))
#     seq = complete_seqs[i]

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000000.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000001.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000002.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000003.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000004.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000005.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000006.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000007.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000069.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000070.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000071.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000072.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000073.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000074.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000075.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000076.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000138.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000139.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000140.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000141.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000142.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000143.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000144.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000145.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000207.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000208.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000209.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000210.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000211.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000212.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000213.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000214.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000276.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000277.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000278.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000279.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000280.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000281.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000282.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000283.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000345.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000346.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000347.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000348.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000349.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000350.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000351.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000352.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000414.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000415.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000416.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000417.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000418.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000419.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000420.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000421.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000483.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000484.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000485.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000486.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000487.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000488.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000489.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000490.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000552.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000553.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000554.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000555.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000556.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000557.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000558.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000559.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000621.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000622.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000623.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000624.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000625.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000626.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000627.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000628.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000690.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000691.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000692.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000693.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000694.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000695.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000696.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000697.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000759.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000760.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000761.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000762.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000763.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000764.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000765.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000766.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000828.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000829.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000830.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000831.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000832.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000833.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000834.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000835.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000897.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000898.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000899.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000900.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000901.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000902.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000903.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000904.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000966.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000967.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000968.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000969.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000970.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000971.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000972.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00000973.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001035.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001036.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001037.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001038.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001039.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001040.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001041.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001042.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001104.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001105.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001106.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001107.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001108.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001109.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001110.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001111.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001173.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001174.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001175.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001176.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001177.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001178.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001179.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwiz/test_npy/VizWiz_test_00001180.npy'
image_features_mean: torch.Size([1, 2048])
image_file_name: '../mypythia/data/vizwi

KeyboardInterrupt: 

In [24]:
len(complete_inds), len(incomplete_inds)

(0, 100)

In [160]:
complete_inds

[]

In [159]:
seqs.shape

torch.Size([5, 52])

In [147]:
scores.shape

torch.Size([5, 9490])

In [153]:
rev_word_map[9487]

'<unk>'

In [150]:
np.argmax(scores.cpu().detach().numpy(), axis=1)

array([9487, 9487, 9487, 9487, 9487])

In [143]:
seqs

tensor([[9488, 7961, 7961, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487],
        [9488, 7961, 7961, 9487, 9487, 9487, 9487, 9487, 7961, 9487, 9487, 9487,
         9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487],
        [9488, 7961, 7961, 9487, 9487, 9487, 9487, 7961, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487, 9487,
         9487, 9487, 9487, 9487, 9487, 94

In [141]:
next_word_inds

tensor([9487, 9487, 9487, 9487, 9487], device='cuda:0')

In [142]:
incomplete_inds

[0, 1, 2, 3, 4]

In [31]:
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu
import torch.nn.functional as F
from tqdm import tqdm
from nlgeval import NLGEval

# Parameters
data_folder = 'final_dataset'  # folder with data files saved by create_input_files.py
data_name = 'coco_5_cap_per_img_5_min_word_freq'  # base name shared by data files
checkpoint_file = 'BEST_35checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar'  # model checkpoint

word_map_file = 'WORDMAP_coco_5_cap_per_img_5_min_word_freq.json'  # word map, ensure it's the same the data was encoded with and the model was trained with
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Load model
torch.nn.Module.dump_patches = True
checkpoint = torch.load(checkpoint_file,map_location = device)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()

nlgeval = NLGEval()  # loads the evaluator

# Load word map (word2ix)
word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
with open(word_map_file, 'r') as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}
vocab_size = len(word_map)

def evaluate(beam_size):
    """
    Evaluation
    :param beam_size: beam size at which to generate captions for evaluation
    :return: Official MSCOCO evaluator scores - bleu4, cider, rouge, meteor
    """
    # DataLoader
    loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TEST'),
        batch_size=1, shuffle=True, num_workers=1, pin_memory=torch.cuda.is_available())

    # Lists to store references (true captions), and hypothesis (prediction) for each image
    # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
    references = list()
    hypotheses = list()

    # For each image
    for i, (image_features, caps, caplens, allcaps) in enumerate(
            tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))):
        if i > 2:
            break

        k = beam_size

        # Move to GPU device, if available
        print(f"image_features.shape: {image_features.shape!r}")
        print(f"np.min(image_features): {torch.min(image_features)!r}, np.max(image_features): {torch.max(image_features)!r}")
        image_features = image_features.to(device)  # (1, 36, 2048)
        image_features_mean = image_features.mean(1)
        print(f"image_features_mean.shape: {image_features_mean.shape!r}")
        image_features_mean = image_features_mean.expand(k,2048)
        print(f"image_features_mean.shape: {image_features_mean.shape!r}")        

        # Tensor to store top k previous words at each step; now they're just <start>
        k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)

        # Tensor to store top k sequences; now they're just <start>
        seqs = k_prev_words  # (k, 1)

        # Tensor to store top k sequences' scores; now they're just 0
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

        # Lists to store completed sequences and scores
        complete_seqs = list()
        complete_seqs_scores = list()

        # Start decoding
        step = 1
        h1, c1 = decoder.init_hidden_state(k)  # (batch_size, decoder_dim)
        h2, c2 = decoder.init_hidden_state(k)

        # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
        while True:

            embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)
            h1,c1 = decoder.top_down_attention(
                torch.cat([h2,image_features_mean,embeddings], dim=1),
                (h1,c1))  # (batch_size_t, decoder_dim)
            attention_weighted_encoding = decoder.attention(image_features,h1)
            h2,c2 = decoder.language_model(
                torch.cat([attention_weighted_encoding,h1], dim=1),(h2,c2))

            scores = decoder.fc(h2)  # (s, vocab_size)
            scores = F.log_softmax(scores, dim=1)

            # Add
            scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

            # For the first step, all k points will have the same scores (since same k previous words, h, c)
            if step == 1:
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
            else:
                # Unroll and find top scores, and their unrolled indices
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

            # Convert unrolled indices to actual indices of scores
            prev_word_inds = top_k_words / vocab_size  # (s)
            next_word_inds = top_k_words % vocab_size  # (s)

            # Add new words to sequences
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

            # Which sequences are incomplete (didn't reach <end>)?
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != word_map['<end>']]
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            # Set aside complete sequences
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds])
            k -= len(complete_inds)  # reduce beam length accordingly

            # Proceed with incomplete sequences
            if k == 0:
                break
            seqs = seqs[incomplete_inds]
            h1 = h1[prev_word_inds[incomplete_inds]]
            c1 = c1[prev_word_inds[incomplete_inds]]
            h2 = h2[prev_word_inds[incomplete_inds]]
            c2 = c2[prev_word_inds[incomplete_inds]]
            image_features_mean = image_features_mean[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            # Break if things have been going on too long
            if step > 50:
                break
            step += 1

        i = complete_seqs_scores.index(max(complete_seqs_scores))
        seq = complete_seqs[i]

        # References
        img_caps = allcaps[0].tolist()
        img_captions = list(
            map(lambda c: [rev_word_map[w] for w in c if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}],
                img_caps))  # remove <start> and pads
        img_caps = [' '.join(c) for c in img_captions]
        #print(img_caps)
        references.append(img_caps)

        # Hypotheses
        hypothesis = ([rev_word_map[w] for w in seq if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}])
        hypothesis = ' '.join(hypothesis)
        #print(hypothesis)
        hypotheses.append(hypothesis)
        assert len(references) == len(hypotheses)
    return complete_seqs_scores, complete_seqs

In [32]:
evaluate(5)

EVALUATING AT BEAM SIZE 5:   0%|          | 1/25000 [00:00<1:05:41,  6.34it/s]

image_features.shape: torch.Size([1, 36, 2048])
np.min(image_features): tensor(0.), np.max(image_features): tensor(32.9999)
image_features_mean.shape: torch.Size([1, 2048])
image_features_mean.shape: torch.Size([5, 2048])
image_features.shape: torch.Size([1, 36, 2048])
np.min(image_features): tensor(0.), np.max(image_features): tensor(27.5576)
image_features_mean.shape: torch.Size([1, 2048])
image_features_mean.shape: torch.Size([5, 2048])
image_features.shape: torch.Size([1, 36, 2048])
np.min(image_features): tensor(0.), np.max(image_features): tensor(26.7169)
image_features_mean.shape: torch.Size([1, 2048])
image_features_mean.shape: torch.Size([5, 2048])


EVALUATING AT BEAM SIZE 5:   0%|          | 3/25000 [00:00<42:50,  9.72it/s]  


([tensor(-6.9773, device='cuda:0', grad_fn=<SelectBackward>),
  tensor(-6.4317, device='cuda:0', grad_fn=<SelectBackward>),
  tensor(-7.2896, device='cuda:0', grad_fn=<SelectBackward>),
  tensor(-7.6408, device='cuda:0', grad_fn=<SelectBackward>),
  tensor(-7.6486, device='cuda:0', grad_fn=<SelectBackward>)],
 [[9488, 7961, 5081, 1419, 2599, 7961, 9157, 8038, 5310, 9489],
  [9488, 7961, 5081, 1419, 2599, 7961, 9157, 290, 7961, 5310, 9489],
  [9488,
   7961,
   5081,
   1419,
   2599,
   7961,
   9157,
   290,
   7961,
   9158,
   4437,
   5310,
   9489],
  [9488, 7961, 5081, 1419, 2599, 7961, 9157, 290, 7961, 5310, 702, 3632, 9489],
  [9488, 7961, 5081, 1419, 2599, 7961, 9157, 290, 7961, 72, 4437, 5310, 9489]])