In [None]:
# Libraries
import model as md
import data
import torch
import transformers
from pathlib import Path
from torch.utils.data import DataLoader, SequentialSampler


# Variables
# Filesystem locations. HOME_DIR is an absolute, machine-specific path;
# CURR_DIR is relative to the working directory the script is run from.
HOME_DIR = Path("/home/andrewhinh/Desktop/Projects/")
CURR_DIR = Path("./")
PRETRAINED_MODEL_PATH = "large_both_knowledge/"
STAGED_MODEL_FILENAME = "model.pt"

# Directory the pre-trained checkpoint is read from (passed to `from_pretrained`).
from_model_path = HOME_DIR.joinpath(PRETRAINED_MODEL_PATH)
# Target file for a staged copy of the model.
to_model_path = CURR_DIR.joinpath(STAGED_MODEL_FILENAME)

# Data / inference knobs, consumed further down in this script.
text_maxlength = 64     # handed to data.OKvqaCollator
n_context = 40          # handed to data.OkvqaDataset
per_gpu_batch_size = 1  # DataLoader batch_size
num_workers = 16        # DataLoader worker processes
max_length = 10         # max_length passed to model.generate


# Load best pre-trained model from https://github.com/guilk/KAT
model_class = md.FiDT5
model = model_class.from_pretrained(from_model_path)
# Put the model in evaluation mode right after loading; nothing below
# touches the model until generation.
model.eval()


# Loading data
# Tokenizer for the 't5-large' checkpoint; the collator receives it together
# with the maximum text length.
model_name = 't5-large'
tokenizer = transformers.T5Tokenizer.from_pretrained(model_name)
collator = data.OKvqaCollator(text_maxlength, tokenizer)

# Evaluation examples come from the 'val2014' directory under HOME_DIR.
eval_examples = data.load_okvqa_data(
    HOME_DIR.joinpath('val2014'),
    split_type='val2014',
    use_gpt=True,
)
dataset = data.OkvqaDataset(eval_examples, n_context)

# Sequential sampling keeps the batches in dataset order, and drop_last=False
# keeps the final (possibly smaller) batch.
sampler = SequentialSampler(dataset)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=per_gpu_batch_size,
    sampler=sampler,
    num_workers=num_workers,
    collate_fn=collator,
    drop_last=False,
)


# Generating Answers
# Decoded answer strings, one per generated sequence, in dataloader order.
answers = []
model.eval()
# If the model is wrapped (anything exposing `.module`, e.g. a
# DataParallel-style wrapper), call `generate` on the underlying model.
model = model.module if hasattr(model, "module") else model

# NOTE(review): neither the model nor the batch tensors are moved to a device
# in the visible part of this script, so generation runs wherever
# `from_pretrained` placed the weights (presumably CPU) -- confirm this is
# intended.
with torch.no_grad():
    for batch in dataloader:
        # The collated batch is unpacked as a 6-tuple; only the last two
        # fields (context token ids and their attention mask) are used here.
        (_, _, _, _, context_ids, context_mask) = batch

        outputs = model.generate(
            input_ids=context_ids,
            attention_mask=context_mask,
            max_length=max_length,
        )

        # Decode every generated sequence of the batch, dropping special tokens.
        answers.extend(
            tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        )

print(answers)