In [1]:
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from src.pre_vqa import PreVQA
from src.vision_transformer import interpolate_pos_embed
from src.tokenization_bert import BertTokenizer

import utils
from dataset.utils import save_result
from dataset import create_dataset, create_sampler, create_loader, vqa_collate_fn

from scheduler import create_scheduler
from optim import create_optimizer

In [2]:
config = yaml.load(open('./configs/vqa.yaml', 'r'), Loader=yaml.Loader)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = PreVQA(config=config, text_encoder='bert-base-uncased', text_decoder='bert-base-uncased', tokenizer=tokenizer)
 

In [3]:
device = 'cuda:1'
img_root=config['vqa_root']
model=model.to(device)
model.eval()

PreVQA(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate=none)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      

In [4]:
checkpoint = torch.load('output/vqa/checkpoint_00.pth', map_location='cpu') 
state_dict = checkpoint['model']
model.load_state_dict(state_dict,strict=False)

<All keys matched successfully>

In [58]:
from PIL import Image
from torchvision import transforms
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
test_transform = transforms.Compose([
        transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC),
        transforms.ToTensor(),
        normalize,
        ]) 



In [73]:
pair = { "question": "What color is his hat?", "image": "train2014/COCO_train2014_000000393227.jpg", "dataset": "vqa"}
answer_list = json.load(open(config['answer_list'],'r'))  
print(answer_list)
answer_list = [answer+config['eos'] for answer in answer_list]
print(answer_list)
answer_input = tokenizer(answer_list, padding='longest', return_tensors='pt').to(device) 
image = os.path.join(img_root,pair['image'])
image = Image.open(image).convert('RGB')   
image = test_transform(image) 
print(image.size) 
image = torch.unsqueeze(image,0)
image = image.to(device,non_blocking=True) 
question = pair['question']
question_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
topk_ids, topk_probs = model(image, question_input, answer_input, train=False, k=config['k_test'])



answer:black[SEP]
