In [None]:
import os
import sys

import evaluate
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

# Put the parent directory (resolved to an absolute path) on sys.path, unless
# it is already listed, so that the project-local import below can resolve.
# NOTE(review): presumably `modeling` lives one level above this notebook --
# confirm; the relative '..' depends on the kernel's working directory.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from modeling import *

In [None]:
# Load the three text-overlap metrics once at module level; `score` below
# reads them as globals. The load order matches the unpacking order.
bleu, meteor, rouge = (
    evaluate.load(metric_name) for metric_name in ("bleu", "meteor", "rouge")
)


def score(out, ann):
    """Print BLEU, METEOR and ROUGE for one generated caption.

    Args:
        out: The generated caption, passed to the metrics as the single
            prediction (assumed to be a plain string -- TODO confirm against
            what ``model.generate`` returns).
        ann: Iterable of reference caption strings for the same image.

    Returns:
        None. Each metric's result dict is printed, in the order
        BLEU, METEOR, ROUGE.
    """
    # The metrics take one entry in `references` per entry in `predictions`,
    # and each entry is the flat list of acceptable reference strings for
    # that prediction. There is a single prediction here, so `references`
    # is a one-element list holding all captions of the image.
    predictions = [out]
    references = [list(ann)]

    for metric in (bleu, meteor, rouge):
        print(metric.compute(predictions=predictions, references=references))


def eval(model, ds, index, samples, device):
    """Show one dataset sample, then generate and score captions for it.

    Note: this function shadows Python's builtin ``eval`` inside the notebook.

    Args:
        model: Object exposing ``vit_processor`` and ``generate``.
        ds: Indexable dataset whose items are ``(image, annotations)`` pairs;
            the annotations are joined with newlines, so they must be strings.
        index: Position of the sample in ``ds``; ``-1`` selects a random one.
        samples: Number of times ``model.generate`` is called on the image.
        device: Torch device the image tensors are moved to.

    Returns:
        None. The image, the target captions, every generated output and its
        metric scores are displayed/printed.
    """
    if index == -1:
        # Upper bound of randint is exclusive, so this is a valid position.
        index = torch.randint(0, len(ds), (1,)).item()

    image, captions = ds[index]

    display(image)
    print("Target annotations:")
    print("\n".join(captions))
    print("\nOutput annotations:")

    # PIL image -> tensor on the target device, then through the model's own
    # image processor. A leading axis is added both before and after the
    # processor call.
    # NOTE(review): presumably generate() expects that extra leading axis --
    # confirm against the model implementation.
    to_tensor = transforms.PILToTensor()
    image_tensor = to_tensor(image).to(device)
    processed = model.vit_processor(image_tensor.unsqueeze(0), return_tensors="pt")
    img_batch = processed.pixel_values.unsqueeze(0).to(device)

    for _ in range(samples):
        generated = model.generate(img_batch, max_new_tokens=64)

        print(generated)
        score(generated, captions)
        print()

In [None]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on a machine without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hard-coded, machine-specific locations collected in one place so they are
# easy to adjust. Per the path names these are the COCO val2017 images and
# captions file and a saved model checkpoint directory.
COCO_IMAGE_ROOT = '/mnt/gryf/home/xbuban1/coco/images/val2017'
COCO_CAPTION_FILE = '/mnt/gryf/home/xbuban1/coco/annotations/captions_val2017.json'
MODEL_DIR = "/mnt/gryf/home/xbuban1/runs/0_Llama_Game_Desc/models/model_1"

# Caption dataset: indexing it yields (image, annotations) pairs, which is
# how `eval` consumes it.
ds = dset.CocoCaptions(
    root=COCO_IMAGE_ROOT,
    annFile=COCO_CAPTION_FILE,
)

# Restore the captioning model from the checkpoint onto the chosen device.
model = LlamaGameDescription.from_pretrained(
    MODEL_DIR,
    task="caption",
    device=device,
)

In [None]:
# Which dataset sample to caption; -1 makes eval() draw a random index.
index = -1
# How many captions to generate (and score) for that sample.
samples = 5

eval(model=model, ds=ds, index=index, samples=samples, device=device)