In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from core.config import *

from core.datasets import VOC2012SegDataset
from core.viz import write_html_multi_row_image_caption, compute_cs_gt_pr_maps_with_captions
from core.data_utils import flatten_list_to_depth
from models.vle import VLE_REGISTRY, MapComputeMode, NewLayer

from pathlib import Path

import torchvision.transforms.v2.functional as TF

In [3]:
# Build the run configuration from BASE_CONFIG and the visualization YAML
# (how the two are combined is decided inside setup_config).
config = setup_config(
    BASE_CONFIG,
    Path('/home/olivieri/exp/src/train/vle/2-GT_PR_vs_dT/viz_config.yml'),
)
# Shortcuts to the segmentation and VLE sub-sections of the configuration.
seg_config, vle_config = config['seg'], config['vle']

# Keep the NewLayer members whose value is listed under vle_config['new_layers'],
# in the iteration order of NewLayer.
new_layers = [layer for layer in NewLayer if layer.value in vle_config['new_layers']]

In [4]:
# Fetch the 'flair' entry (version 'flair-cc3m-recap.pt') from the VLE registry, passing the
# layers selected from the config and the configured device.
# NOTE(review): what VLE_REGISTRY.get does with `pretrained_weights_root_path` (e.g. whether
# weights are loaded eagerly) is not visible in this notebook — confirm in models.vle.
vle = VLE_REGISTRY.get(
    name='flair',
    version='flair-cc3m-recap.pt',
    pretrained_weights_root_path=vle_config['pretrained_weights_root_path'],
    new_layers=new_layers,
    device=config['device']
)

In [5]:
# Fixed alternative selection, kept for reference:
# img_idxs = [9, 10, 11, 12, 13, 14, 15, 17]
# Draw 10 distinct indices from [0, 80) using NumPy's global RNG and convert them to a plain list.
# NOTE(review): the "reproducible" claim only holds if the global NumPy seed is fixed before this
# cell runs on a fresh kernel. No seeding and no explicit `import numpy as np` is visible in this
# notebook; presumably both come from `core.config` through the star import — confirm.
# Re-running this cell without restarting the kernel will advance the RNG and change the draw.
img_idxs = np.random.choice(range(0, 80), size=10, replace=False).tolist() # reproducible
# Index 16 must not be sampled (per the original note it is the one-shot example);
# this fails loudly if the draw ever contains it.
assert 16 not in img_idxs # one-shot example
# Display the selected indices together with their count.
img_idxs, len(img_idxs)

([30, 0, 22, 31, 18, 28, 10, 70, 4, 12], 10)

`img_idxs` should be `[30, 0, 22, 31, 18, 28, 10, 70, 4, 12]` (a plain Python list, since `.tolist()` is applied to the sampled array).

In [6]:
# Options forwarded to compute_cs_gt_pr_maps_with_captions in the rendering loop below.

# Map options.
map_compute_mode = MapComputeMode.AVG_TEXT_TOKEN_ATTN
map_resize_mode = TF.InterpolationMode.NEAREST
map_alpha = 1.0
normalize = True

# Mask overlay options.
mask_color = 'L'
mask_alpha = 0.55  # higher values hide more of the underlying scene

# Size passed as `viz_image_size`.
viz_image_size = 224

In [7]:
# JSONL file later handed to compute_cs_gt_pr_maps_with_captions as `jsonl_path`.
jsonl_path = Path('/home/olivieri/exp/data/private/prompts_data/by_model/LRASPP_MobileNet_V3/class-splitted/answer_prs/gemma3:12b-it-qat/speed_test/SepMasks_Ovr_1fs.jsonl')

# VOC2012 segmentation dataset restricted to the 'prompts_split' split; images are resized to the
# segmentation image size from the config and center-cropped.
viz_ds = VOC2012SegDataset(
    split='prompts_split',
    root_path=config['datasets']['VOC2012_root_path'],
    mask_prs_path=config['mask_prs_path'],
    resize_size=config['seg']['image_size'],
    center_crop=True,
    device=config['device'],
)

In [8]:
ckpt_paths: dict[str, Path] = {
    'concat_adapter': Path('/home/olivieri/exp/data/private/exps/vle/2-GT_PR_vs_dT/b128-251021_0740/weights/flair-cc3m-recap.pt-2-GT_PR_vs_dT-b128-251021_0740.pth'),
    'text_post+concat_adapter': Path('/home/olivieri/exp/data/private/exps/vle/2-GT_PR_vs_dT/250930_1351/weights/flair-flair-cc3m-recap.pt-diff_concat_adapter_gt-pr__t_L_b128_250930_1351.pth')
}

In [9]:
# Build, for every checkpoint, one row of rendered images plus a per-image caption prefix
# that accumulates the text fields produced under each checkpoint.
title = 'avg_text_token_attn'
rows = dict()   # checkpoint name -> list of rendered images (field 0 of every flattened entry)

sim_datas = []  # one accumulated caption prefix per flattened entry, shared by all checkpoints

for ckpt_name, ckpt_path in ckpt_paths.items():

    # Swap the checkpoint weights into the already-instantiated model.
    vle.load_model_state_dict(ckpt_path)

    cs_pr_image_text_list = compute_cs_gt_pr_maps_with_captions(
        viz_ds=viz_ds,
        img_idxs=img_idxs,
        jsonl_path=jsonl_path,
        vle=vle,
        mask_color=mask_color,
        mask_alpha=mask_alpha,
        map_alpha=map_alpha,
        map_compute_mode=map_compute_mode,
        map_resize_mode=map_resize_mode,
        viz_image_size=viz_image_size,
        normalize=normalize,
        with_patch_text=True,
        font_size=10
    )

    # Flatten once per checkpoint instead of re-flattening for every derived list.
    # Each flattened entry is indexed as el[0] (image), el[1] and el[2] (text fields);
    # el[3] is consumed by the caption cell that follows the loop.
    flat_entries = flatten_list_to_depth(cs_pr_image_text_list, depth=1)

    # First checkpoint (or nothing accumulated yet): start every caption prefix empty.
    if not sim_datas:
        sim_datas = ["" for _ in range(len(flat_entries))]

    # zip() would silently truncate on a length mismatch and leave captions
    # misaligned with the image rows, so fail loudly instead.
    assert len(flat_entries) == len(sim_datas), (
        f"checkpoint '{ckpt_name}' produced {len(flat_entries)} entries, "
        f"expected {len(sim_datas)}"
    )

    new_sim_data = [f"{el[1]}<br>{el[2]}<br>" for el in flat_entries]
    sim_datas = [s_d+new_s_d for s_d, new_s_d in zip(sim_datas, new_sim_data)]

    images_row = [el[0] for el in flat_entries]

    rows[ckpt_name] = images_row

In [10]:
# Final caption per image: the accumulated per-checkpoint text followed by field 3 of the entry.
# NOTE(review): `cs_pr_image_text_list` is the value left over from the last iteration of the
# checkpoint loop, so field 3 is taken from the last checkpoint only — presumably it is
# checkpoint-independent; confirm.
captions = [
    f"{sim_text}<br>{entry[3]}"
    for sim_text, entry in zip(sim_datas, flatten_list_to_depth(cs_pr_image_text_list, depth=1))
]

In [11]:
# Write the HTML report from the per-checkpoint image rows and the shared captions.
# Per the recorded output, this creates `avg_text_token_attn.html` (named after `title`).
write_html_multi_row_image_caption(title, rows, captions)

Successfully created avg_text_token_attn.html. Open it in your browser to see the result.
