In [None]:
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
# from models import TransformerLM, TransformerConditionedLM, TransformerSentenceLM
from models_modified import TransformerSentenceLM_FixedImg, TransformerSentenceLM_FixedImg_gated
from datasets import *
from utils_sentence import *       #changed
from nltk.translate.bleu_score import corpus_bleu
# from torch.optim.lr_scheduler import LambdaLR
import shutil

from torch.utils.tensorboard import SummaryWriter

import yaml

In [4]:
with open('../../config.yml') as yml:
    config = yaml.safe_load(yml)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
word_map_path="/net/papilio/storage2/yhaoyuan/transformer_I2S/data/processed/SpokenCOCO_LibriSpeech/WORDMAP_coco_1_cap_per_img_1_min_word_freq.json"
# Load word map (word2ix)
with open(word_map_path) as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word
special_words = {"<unk>", "<start>", "<end>", "<pad>"}

# I2U
# checkpoint_path = "../../saved_model/I2U/VC_5_captions/Trial_1/"
checkpoint_path = "../../saved_model/I2U/VC_5_captions_224/fixed_img_1024_sentence_8/"
# checkpoint_path = "../../saved_model/I2U/VC_5_captions_224/"

with open(checkpoint_path + 'config_sentence.yml', 'r') as yml:
    model_config = yaml.safe_load(yml)

checkpoint = checkpoint_path + f'bleu-4_BEST_checkpoint_coco_{str(model_config["i2u"]["captions_per_image"])}_cap_per_img_{str(model_config["i2u"]["min_word_freq"])}_min_word_freq_gpu.pth.tar'
dir_name = model_config["i2u"]["dir_name"]
model_params = model_config["i2u"]["model_params"]
train_params = model_config["i2u"]["train_params"]
img_refine_params = model_config["i2u"]["refine_encoder_params"]

data_folder = f'../../data/processed/{dir_name}/'  # folder with data files saved by create_input_files.py
# data_name = 'coco_4_cap_per_img_5_min_word_freq'  # base name shared by data files
#data_name = f'coco_{str(config["i2u"]["captions_per_image"])}_cap_per_img_{str(config["i2u"]["min_word_freq"])}_min_word_freq'  # base name shared by data files
data_name = f'coco_{str(model_config["i2u"]["captions_per_image"])}_cap_per_img_{str(model_config["i2u"]["min_word_freq"])}_min_word_freq'  # base name shared by data files
model_params['vocab_size'] = len(word_map)
model_params['refine_encoder_params'] = img_refine_params
model = TransformerSentenceLM_FixedImg(**model_params)
model.load_state_dict(torch.load(checkpoint)["model_state_dict"])
model.eval()
model.to(device)

TransformerSentenceLM_FixedImg(
  (embed): Embedding(1017, 1024)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (LM_decoder): None
  (classifier): Linear(in_features=1024, out_features=1017, bias=True)
  (image_encoder): DinoResEncoder_NoPool(
    (resnet): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [17]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
val_loader = torch.utils.data.DataLoader(
CaptionDataset_transformer(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
batch_size=1, shuffle=True, num_workers=10, pin_memory=True)

In [22]:
refs = list() # GT captions
hypos = list() # pred captions
start_unit = word_map["<start>"]
end_unit = word_map["<end>"]

for i, (imgs, caps, caplens, padding_mask, all_caps, all_padding_mask) in enumerate(iter(val_loader)):
    imgs = imgs.to(device)
    caps = caps.to(device)
    caplens = caplens.to(device)
    caplens = caplens.squeeze()
    padding_mask = padding_mask.to(device)
    all_caps = all_caps.to(device)
    all_padding_mask = all_padding_mask.to(device)

    pred_seq = model.decode(
        start_unit = start_unit,
        end_unit = end_unit,
        image = imgs,
        max_len = 130,
        beam_size = 5,
    )
    if len(pred_seq) == 0:
        continue

    pred_seq = [w for w in pred_seq if w not in {word_map['<start>'], word_map['<pad>']}]
    pred_seq = []
    hypos.append(pred_seq)
    # hypos

    for j in range(all_caps.shape[0]):
        img_caps = all_caps[j].tolist()
        img_captions = list(
            map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
                img_caps))  # remove <start> and pads
        refs.append(img_captions)
    assert len(hypos) == len(refs)

    print(i)
    if i>=10:
        break

bleu_4 = corpus_bleu(refs, hypos)

0
1
2
3
4
5
6
7
8
9
10


In [23]:
bleu_4

0