In [None]:
import os
# Restrict the process to GPU 0. This is set before `import torch` below so the
# restriction is already in the environment when CUDA is first initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import tqdm
import random
import torch
import matplotlib.pyplot as plt

from utils.visualization import visualize_sample
from visual_tokenizer.directsam import DirectSAMTokenizer
from model.utils import create_vlm
from model.utils import VisualTextualTokenization
from data import get_dataset


In [None]:
# Dataset that the loss estimate and the generation example below draw samples from.
# NOTE(review): the variable is named `eval_dataset` but the 'train' split is
# loaded here — confirm this is intentional.
eval_dataset = get_dataset(
    'coco',
    '/share/datasets/coco2017',
    split='train',
)
# Alternative: CLEVR captions, validation split.
# eval_dataset = get_dataset('clevr_caption', '/home/dchenbs/workspace/datasets/CLEVR_v1.0', split='val')

In [None]:
# Load the vision-language model and its text tokenizer from a training checkpoint.
checkpoint = "runs/0919-1632-coco-vae-no-query-SmolLM-360M-Instruct/checkpoint-2000"
model, textual_tokenizer = create_vlm(checkpoint)

# Prepare the model for inference: move it to the GPU, cast it to half
# precision, then switch it to evaluation mode.
model = model.cuda()
model = model.half()
model = model.eval()

In [None]:
# Settings for the DirectSAM-based visual tokenizer. The image resolution is
# read from the loaded model's config rather than hard-coded here.
visual_tokenizer_kwargs = dict(
    checkpoint="chendelong/DirectSAM-tiny-distilled-30ep-plus-50ep-1024px-0910",
    threshold=0.1,
    image_resolution=model.config.vlm_config.image_resolution,
    max_tokens=128,
    device="cuda",
)
visual_tokenizer = DirectSAMTokenizer(**visual_tokenizer_kwargs)

# Combined tokenizer built from the text tokenizer and the visual tokenizer;
# it is called on lists of dataset samples in the cells below.
vl_tokenizer = VisualTextualTokenization(textual_tokenizer, visual_tokenizer)

In [None]:
# Estimate the mean loss over a handful of randomly drawn samples.
n_samples = 10
loss = 0  # running sum of per-sample losses; averaged when printed
for _ in tqdm.tqdm(range(n_samples)):
    # random.randrange excludes its upper bound, so the index is always in
    # [0, len(eval_dataset) - 1]. random.randint(0, len(eval_dataset)) includes
    # len(eval_dataset) itself and can index one past the end of the dataset.
    sample = eval_dataset[random.randrange(len(eval_dataset))]
    inputs = vl_tokenizer([sample], eval=True)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    loss += outputs['loss'].item()

print(f"Loss: {loss / n_samples}")

In [None]:
# Qualitative check: generate a completion for one random sample and print it
# next to the reference text.

# random.randrange excludes its upper bound, so the index is always valid.
# random.randint(0, len(eval_dataset)) includes len(eval_dataset) itself and
# can index one past the end of the dataset.
sample = eval_dataset[random.randrange(len(eval_dataset))]

# Split the text once at the assistant marker: the part before it is the
# prompt, the part after it (with the EOS token removed) is the reference.
segments = sample['text'].split('<|assistant|>')
label = segments[1].strip().replace(textual_tokenizer.eos_token, '')
# Keep only the prompt, ending with the assistant marker, so the model has to
# produce the answer itself.
sample['text'] = segments[0] + '<|assistant|>'

inputs = vl_tokenizer([sample], eval=True)

# Inference only: build the input embeddings and decode without tracking
# gradients, matching the loss-evaluation loop above.
with torch.no_grad():
    inputs_embeds, labels = model.prepare_inputs_embeds(
        inputs['text'], inputs['image'], inputs['masks']
    )

    # Greedy decoding (do_sample=False), limited to the dataset's text-token budget.
    outputs = model.generate(
        inputs_embeds=inputs_embeds,
        do_sample=False,
        max_new_tokens=eval_dataset.max_text_tokens,
        eos_token_id=textual_tokenizer.eos_token_id,
        pad_token_id=textual_tokenizer.pad_token_id,
    )
prediction = textual_tokenizer.decode(outputs[0], skip_special_tokens=True)

visualize_sample(sample, inputs)
print(label)
print(prediction)