In [None]:
import os
import sys

import evaluate
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, GPT2Tokenizer

# Put the notebook's parent directory on the import path (once) so the local
# `modeling` module imported below can be found.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from modeling import *

In [None]:
# Caption metrics, loaded once at notebook level; score() reads these globals.
bleu = evaluate.load(path="bleu")
meteor = evaluate.load(path="meteor")
rouge = evaluate.load(path="rouge")


def score(out, ann, metrics=None):
    """Print and return BLEU, METEOR and ROUGE scores for a generated caption.

    Args:
        out: The generated caption. A single string is scored as one
            prediction; a list/tuple of strings is scored as several
            predictions, each against the same set of reference captions.
        ann: The reference captions for the image (an iterable of strings,
            as returned by ``CocoCaptions``). A bare string is treated as a
            single reference.
        metrics: Optional mapping ``name -> metric`` where each metric has a
            ``compute(predictions=..., references=...)`` method. Defaults to
            the notebook-level ``bleu``, ``meteor`` and ``rouge`` metrics.

    Returns:
        dict mapping each metric name to the dict returned by its
        ``compute`` call (each result is also printed, in mapping order).
    """
    if metrics is None:
        metrics = {"bleu": bleu, "meteor": meteor, "rouge": rouge}

    predictions = [out] if isinstance(out, str) else list(out)

    # `evaluate` wants list[list[str]]: one list of reference *strings* per
    # prediction. Each caption must stay a plain string -- wrapping it in its
    # own list (and then wrapping again) does not match that format.
    reference_set = [ann] if isinstance(ann, str) else list(ann)
    references = [reference_set for _ in predictions]

    results = {}
    for name, metric in metrics.items():
        results[name] = metric.compute(predictions=predictions, references=references)
        print(results[name])
    return results


def eval(llama_model, ved_model, ds, index, device):
    """Show one dataset sample, caption it with both models and print scores.

    NOTE(review): this name shadows Python's built-in ``eval`` in the notebook.

    Args:
        llama_model: Llama captioning model; must expose ``vit_processor``
            and ``generate``.
        ved_model: Vision-encoder-decoder wrapper; must expose ``generate``.
        ds: Indexable dataset whose items are ``(PIL image, captions)`` where
            ``captions`` is an iterable of strings (e.g. ``CocoCaptions``).
        index: Position of the sample in ``ds``; ``-1`` draws one uniformly
            at random from ``[0, len(ds))``.
        device: Torch device the image tensors are moved to.
    """

    def _report(header, caption):
        # Print one model's caption, then its scores against the targets.
        print(header)
        print(caption)
        score(caption, targets)

    if index == -1:
        index = torch.randint(0, len(ds), (1,)).item()

    pil_image, targets = ds[index]

    display(pil_image)
    print(f'Index: {index}\n')
    print("Target annotations:")
    print("\n".join(targets))
    print("\nOutput annotations:")

    # PILToTensor yields a (C, H, W) tensor; both models start from it.
    image_tensor = transforms.PILToTensor()(pil_image).to(device)

    # The processor gets a batch of one image; the result receives one more
    # leading dim before generation -- presumably the layout
    # LlamaGameDescription.generate expects (TODO confirm in `modeling`).
    pixel_values = llama_model.vit_processor(image_tensor.unsqueeze(0), return_tensors="pt").pixel_values
    llama_caption = llama_model.generate(
        pixel_values.unsqueeze(0).to(device),
        max_new_tokens=64,
        do_sample=False,
        top_p=None,
        temperature=None,
    )
    _report("Llama:", llama_caption)

    ved_caption = ved_model.generate(image_tensor, max_new_tokens=64)
    _report("\nVED:", ved_caption)

In [None]:
# Machine-specific settings. Each one can be overridden through an environment
# variable so the notebook runs on another machine without editing it; the
# defaults are the locations this notebook was originally run against.
EVAL_DEVICE = os.environ.get("EVAL_DEVICE", "cuda:2")
COCO_IMAGES_DIR = os.environ.get("COCO_IMAGES_DIR", "/home/xbuban1/coco/images/val2017")
COCO_CAPTIONS_FILE = os.environ.get(
    "COCO_CAPTIONS_FILE", "/home/xbuban1/coco/annotations/captions_val2017.json"
)
# Other data locations used during development:
#   images:      /mnt/gryf/home/xbuban1/coco/images/train2017
#   annotations: /mnt/gryf/home/xbuban1/coco/annotations/captions_train2017.json
LLAMA_CHECKPOINT = os.environ.get(
    "LLAMA_CHECKPOINT", "/home/xbuban1/LlamaGames/runs/2_Llama_Captions_small/models/model_10"
)
# Other checkpoints used during development (set LLAMA_CHECKPOINT to use one):
#   /mnt/gryf/home/xbuban1/runs/0_Llama_Game_Desc/models/model_1
#   /home/xbuban1/LlamaGames/runs/3_Llama_Captions_448_full/models/model_10
VED_MODEL_DIR = os.environ.get("VED_MODEL_DIR", "/home/xbuban1/LlamaGames/models/ved_model")
VED_PRETRAINED = "nlpconnect/vit-gpt2-image-captioning"

device = torch.device(EVAL_DEVICE)

ds = dset.CocoCaptions(
    root=COCO_IMAGES_DIR,
    annFile=COCO_CAPTIONS_FILE,
)

llama_model = LlamaGameDescription.from_pretrained(
    LLAMA_CHECKPOINT,
    task="caption",
    device=device,
)

ved_model = VEDModel(
    VED_MODEL_DIR,
    VED_PRETRAINED,
    device=device,
)

In [None]:
# Dataset position of the sample to inspect; -1 asks eval() to draw one at random.
index = -1

eval(llama_model, ved_model, ds, index=index, device=device)