In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # set before importing torch so CUDA device selection takes effect

import copy
import json
import random

import matplotlib.pyplot as plt
import torch
import tqdm

from utils.visualization import visualize_sample
from model.utils import create_vlm
from model.utils import VisualTextualTokenization
from data import get_dataset
from visual_tokenizer import get_visual_tokenizer

In [None]:
# Validation split used by every evaluation cell below.
# The dataset root is machine-specific: override it with the CLEVR_ROOT
# environment variable instead of editing the path here.
# (COCO alternative: get_dataset('coco', '/share/datasets/coco2017', split='val').)
clevr_root = os.environ.get('CLEVR_ROOT', '/home/dchenbs/workspace/datasets/CLEVR_v1.0')
eval_dataset = get_dataset('clevr_caption', clevr_root, split='val')

In [None]:
# Load the trained VLM together with its text tokenizer from a run checkpoint.
checkpoint = "runs/0921-1420-clevr_caption-directsam_tiny(32)-convnext_in22k_stage2-SmolLM-360M-Instruct/checkpoint-200"
model, textual_tokenizer = create_vlm(checkpoint)

# Prepare for inference: move to GPU, cast weights to fp16, and switch to
# evaluation mode (e.g. disables dropout).
model = model.cuda()
model = model.half()
model = model.eval()

In [None]:
# Visual tokenizer settings.
# NOTE(review): the checkpoint run name mentions `directsam_tiny(32)`, but this
# cell loads `directsam_0424` with max_tokens=64. Confirm these match what the
# checkpoint was trained with.
image_resolution = 512
max_tokens = 64

# Other configs that have been tried:
#   configs/visual_tokenizer/patch_8_per_side_random.json
#   configs/visual_tokenizer/patch_8_per_side_raster.json
#   configs/visual_tokenizer/directsam_tiny.json
config_path = 'configs/visual_tokenizer/directsam_0424.json'

# A context manager closes the file deterministically; json.load(open(...))
# leaves the handle open until garbage collection.
with open(config_path) as config_file:
    config = json.load(config_file)

visual_tokenizer = get_visual_tokenizer(**config, image_resolution=image_resolution, max_tokens=max_tokens)

# Combines the textual and visual tokenizers; the evaluation cells call it on a
# list of samples to build the model inputs.
vl_tokenizer = VisualTextualTokenization(textual_tokenizer, visual_tokenizer)

In [None]:
# Estimate the mean loss over `n_samples` randomly drawn validation samples
# (sampled with replacement).
n_samples = 10
loss = 0
for _ in tqdm.tqdm(range(n_samples)):
    # randrange excludes its upper bound. random.randint(0, len(...)) includes it
    # and could return len(eval_dataset), an out-of-range index.
    sample = eval_dataset[random.randrange(len(eval_dataset))]
    inputs = vl_tokenizer([sample], eval=True)

    with torch.no_grad():
        outputs = model(**inputs)
        loss += outputs['loss'].item()

print(f"Loss: {loss / n_samples}")

In [None]:
# Greedy-decode a caption for one random validation sample and show it next to
# the reference caption.

# randrange excludes its upper bound (randint(0, len(...)) could index past the end).
# Work on a copy: 'text' is rewritten below, and mutating the dataset's own
# object would corrupt later evaluations if the dataset caches its samples.
sample = copy.copy(eval_dataset[random.randrange(len(eval_dataset))])

assistant_tag = '<|assistant|>'
segments = sample['text'].split(assistant_tag)
if len(segments) < 2:
    raise ValueError(f"sample text has no {assistant_tag!r} turn: {sample['text']!r}")

# Reference answer = text of the first assistant turn, without the EOS token.
label = segments[1].strip().replace(textual_tokenizer.eos_token, '')
# Prompt = everything up to and including the assistant tag, so the model
# generates the answer itself.
sample['text'] = segments[0] + assistant_tag

inputs = vl_tokenizer([sample], eval=True)

# Inference only: avoid building an autograd graph for the embedding pass
# (generate may already disable gradients; prepare_inputs_embeds presumably does not).
with torch.no_grad():
    inputs_embeds, _ = model.prepare_inputs_embeds(
        inputs['text'], inputs['image'], inputs['masks']
    )
    outputs = model.generate(
        inputs_embeds=inputs_embeds,
        do_sample=False,  # greedy decoding
        max_new_tokens=eval_dataset.max_text_tokens,
        eos_token_id=textual_tokenizer.eos_token_id,
        pad_token_id=textual_tokenizer.pad_token_id,
    )
prediction = textual_tokenizer.decode(outputs[0], skip_special_tokens=True)

visualize_sample(sample, inputs)
print(label)
print(prediction)