In [None]:
# Libraries
import model as md
import data
import torch
import transformers
import json
from pathlib import Path
from torch.utils.data import DataLoader, SequentialSampler
#conda install -c conda-forge transformers
#conda install -c anaconda mkl-service


# Configuration
# Checkpoint directory, relative to the user directory joined in below.
PRETRAINED_MODEL_PATH = "Desktop/Projects/large_both_knowledge/"
# File name used to build `to_model_path`.
STAGED_MODEL_FILENAME = "model.pt"
# NOTE: TO_PROJECT, LOG_DIR and STAGED_MODEL_TYPE are not referenced in the
# visible part of this script; they are kept for code that may use them.
TO_PROJECT = "admirer-training"
LOG_DIR = Path("training", "logs")
STAGED_MODEL_TYPE = "prod-ready"

# Derived locations: where the pre-trained weights are read from, and the
# model file path in the current working directory.
from_model_path = Path("/home/andrewhinh/", PRETRAINED_MODEL_PATH)
to_model_path = Path("./", STAGED_MODEL_FILENAME)


# Load best pre-trained model from https://github.com/guilk/KAT
# (weights are read from the local directory `from_model_path`).
model_class = md.FiDT5
model = model_class.from_pretrained(from_model_path)


# Loading data
model_name = 't5-large'
tokenizer = transformers.T5Tokenizer.from_pretrained(model_name)
# 64 is the collator's first positional argument -- presumably the maximum
# token length of each encoded passage; TODO confirm against data.OKvqaCollator.
collator = data.OKvqaCollator(64, tokenizer)

eval_examples = data.load_okvqa_data(
    './val2014',
    split_type='val2014',
    # Single-process evaluation: ranks are 0-based, so the only valid rank for
    # world_size=1 is 0. The previous value (1) was out of range; if the loader
    # shards with `k % world_size == global_rank` (as FiD-style loaders do) it
    # would have selected no examples at all -- TODO confirm against
    # data.load_okvqa_data.
    global_rank=0,
    world_size=1,
    use_gpt=True,
)

# 40 is presumably the number of knowledge passages kept per question;
# TODO confirm against data.OkvqaDataset.
dataset = data.OkvqaDataset(eval_examples, 40)

# Iterate the dataset in its stored order, one example per batch, keeping the
# final (possibly short) batch.
sampler = SequentialSampler(dataset)
dataloader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=1,  # 8
    drop_last=False,
    num_workers=16,
    collate_fn=collator,
)

# Generate one decoded answer string per example in `dataloader`.
answers = []

# Unwrap containers that expose the real network as `.module`
# (e.g. torch.nn.DataParallel) so `generate` is called on the model itself.
model = model.module if hasattr(model, "module") else model

# The model and its inputs must live on the same device. Previously the inputs
# were moved to CUDA while the model was never moved, which fails inside
# `generate`. Fall back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

with torch.no_grad():  # inference only: no autograd bookkeeping needed
    for batch in dataloader:
        # The collator yields six fields; only the last two (context token
        # ids and their attention mask) are needed for generation.
        _img_id, _idx, _, _, context_ids, context_mask = batch

        outputs = model.generate(
            input_ids=context_ids.to(device),
            attention_mask=context_mask.to(device),
            max_length=10,
        )
        # Decode every generated sequence in the batch, dropping special tokens.
        answers.extend(
            tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        )

print(answers)