In [None]:
import os
# Expose only physical GPU 1 to this process. This is set before `import torch`
# so that CUDA sees just this one device.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

import tqdm
import random
import torch
import matplotlib.pyplot as plt

from utils.visualization import visualize_sample
from model.utils import create_vlm
from model.utils import VisualTextualTokenization
from data import get_dataset
from visual_tokenizer import get_visual_tokenizer
import json

In [None]:

from data import (
    ShareGPT4V,
    ImageNet,
    Cambrian
)

# Dataset sampled by the evaluation and generation cells below.
# NOTE(review): absolute, user-specific path; adjust for your machine.
SHAREGPT4V_ROOT = '/private/home/delong/workspace/data/ShareGPT4V'
SHAREGPT4V_SPLIT = 'sharegpt4v_instruct_gpt4-vision_cap100k.json'

dataset = ShareGPT4V(root=SHAREGPT4V_ROOT, split=SHAREGPT4V_SPLIT)


In [None]:
# Load the trained VLM and its text tokenizer from a saved checkpoint directory.
checkpoint = "/private/home/delong/workspace/subobjects-VLM/runs/sharegpt4v/1106-1359-patch_6_per_side_raster(36)-clip_vit_l_14_336-SmolLM2-1_7B-Instruct/checkpoint-2000"
model, textual_tokenizer = create_vlm(checkpoint)

# Prepare for inference. Order matches the original chain: device, then dtype, then mode.
model = model.cuda()  # move to the visible GPU
model = model.half()  # cast to fp16
model = model.eval()  # switch to eval mode

In [None]:
# Visual tokenizer settings. max_tokens presumably caps the number of visual
# tokens per image (TODO confirm); keep it consistent with the checkpoint
# loaded above, whose run name mentions 36.
image_resolution = 384
max_tokens = 36

# Read the tokenizer preset. The `with` block closes the file deterministically;
# a bare `json.load(open(...))` leaves the handle to the garbage collector.
config_path = 'configs/visual_tokenizer/directsam/directsam_tiny_dsa_100ep@0.1.json'
with open(config_path, encoding='utf-8') as config_file:
    config = json.load(config_file)

visual_tokenizer = get_visual_tokenizer(**config, image_resolution=image_resolution, max_tokens=max_tokens)

# Joint tokenizer: encodes samples with both the LM's text tokenizer and the
# visual tokenizer. It is called as vl_tokenizer([sample], eval=True) below.
vl_tokenizer = VisualTextualTokenization(textual_tokenizer, visual_tokenizer)

In [None]:
# Average the model's reported `loss` over n_samples dataset items drawn
# uniformly at random (with replacement).
n_samples = 10
loss = 0
for _ in tqdm.tqdm(range(n_samples)):
    # randrange excludes its upper bound. randint(0, len(dataset)) is inclusive
    # and could return len(dataset), which is an out-of-range index.
    sample = dataset[random.randrange(len(dataset))]
    inputs = vl_tokenizer([sample], eval=True)

    # Forward pass only; no autograd graph is needed for evaluation.
    with torch.no_grad():
        outputs = model(**inputs)
        loss += outputs['loss'].item()

print(f"Loss: {loss / n_samples}")

In [None]:
# Qualitative check: decode a response for one random sample without sampling,
# then print it below the reference answer.
ASSISTANT_TAG = '<|assistant|>'

# randrange excludes len(dataset), so the index is always valid
# (randint's upper bound is inclusive).
sample = dataset[random.randrange(len(dataset))]

text_parts = sample['text'].split(ASSISTANT_TAG)
assert len(text_parts) >= 2, f"sample text has no {ASSISTANT_TAG!r} turn"

# Reference answer: the first assistant turn. EOS is removed before stripping,
# so whitespace that preceded it is trimmed too.
label = text_parts[1].replace(textual_tokenizer.eos_token, '').strip()

# Prompt: everything up to and including the first assistant tag, so the model
# continues from there. This builds a shallow copy instead of mutating `sample`,
# in case the dataset hands out shared objects. Assumes dataset items are dicts
# (TODO confirm).
prompt_sample = {**sample, 'text': text_parts[0] + ASSISTANT_TAG}

inputs = vl_tokenizer([prompt_sample], eval=True)

# Inference only: skip building an autograd graph during embedding prep and decoding.
with torch.no_grad():
    inputs_embeds, _labels = model.prepare_inputs_embeds(
        inputs['text'], inputs['image'], inputs['masks']
    )
    outputs = model.generate(
        inputs_embeds=inputs_embeds,
        do_sample=False,  # no sampling, so decoding is deterministic
        max_new_tokens=dataset.max_text_tokens,
        eos_token_id=textual_tokenizer.eos_token_id,
        pad_token_id=textual_tokenizer.pad_token_id,
    )
prediction = textual_tokenizer.decode(outputs[0], skip_special_tokens=True)

visualize_sample(prompt_sample, inputs)
print(label)
print('-' * 80)
print(prediction)