In [1]:
# Reload edited Python modules (e.g. the project's core/ and models/ packages)
# before each cell executes, so source changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from core.config import *

from core.datasets import VOC2012SegDataset
from core.torch_utils import unprefix_state_dict
from core.viz import write_html_multi_row_image_caption, compute_cs_maps_with_captions
from core.data_utils import flatten_list_to_depth
from models.vle import VLE_REGISTRY, MapComputeMode

from itertools import product
from pathlib import Path

import torch
import torchvision.transforms.v2.functional as TF

In [3]:
# Build the run configuration from the project defaults plus the visualisation YAML
# (presumably the YAML overrides BASE_CONFIG — see setup_config).
viz_config_path = Path('/home/olivieri/exp/src/viz/viz_config.yml')
config = setup_config(BASE_CONFIG, viz_config_path)

# Frequently used sub-sections of the configuration.
seg_config = config['seg']
vle_name: str = config['vle_name']
vle_config = config['vles'][vle_name]

In [None]:
# Instantiate the vision-language encoder (VLE) selected by `vle_name`.
# Only the backbones below are wired up in this notebook; any other name fails fast
# here instead of leaving `vle` undefined and crashing later with a NameError.
if vle_name == 'flair':
    vle = VLE_REGISTRY.get(
        name='flair',
        version='flair-cc3m-recap.pt',
        pretrained_weights_root_path=vle_config['pretrained_weights_root_path'],
        new_layers=vle_config['new_layers'],
        device=config['device']
    )
elif vle_name == 'fg-clip':
    vle = VLE_REGISTRY.get(
        name='fg-clip',
        version='fg-clip-base',
        pretrained_weights_root_path=vle_config['pretrained_weights_root_path'],
        device=config['device'],
        long_captions=True
    )
else:
    raise ValueError(f"Unsupported vle_name '{vle_name}': expected 'flair' or 'fg-clip'.")

In [5]:
# Optionally replace the VLE's pretrained weights with those from a saved checkpoint.
# Skipped entirely when `checkpoint_path` is empty/None in the config.
if vle_config['checkpoint_path']:
    vle_checkpoint_path = Path(vle_config['checkpoint_path'])
    if not vle_checkpoint_path.exists():
        raise FileNotFoundError(f"ERROR: VLE weights path '{vle_checkpoint_path}' not found.")

    # Deserialize onto the CPU; load_state_dict then copies the tensors into the model's own parameters.
    # NOTE check if map_location='cpu' can improve the vRAM usage somewhere else in the codebase.
    checkpoint = torch.load(vle_checkpoint_path, map_location='cpu')
    # '_orig_mod.' is presumably the key prefix added by torch.compile when the model was saved — confirm.
    model_state_dict = unprefix_state_dict(checkpoint['model_state_dict'], prefix='_orig_mod.')
    vle.model.load_state_dict(model_state_dict)
    # Drop the CPU-side copy so it does not linger in the kernel namespace.
    del checkpoint, model_state_dict

In [6]:
# Indices of the images to visualise: 9..15 and 17.
img_idxs = [*range(9, 16), 17]

# How the similarity maps are computed and resized for display.
map_compute_mode = MapComputeMode.ATTENTION
map_resize_mode = TF.InterpolationMode.NEAREST
normalize = True

# Overlay opacities and rendered image size.
mask_alpha = 0.55  # the greater, the less scene is visible
map_alpha = 1.0
viz_image_size = 350

In [7]:
# VOC2012 images used for the visualisation, resized like the segmentation model's
# inputs (seg image_size) and center-cropped; `img_idxs` presumably selects the subset.
viz_ds = VOC2012SegDataset(
    root_path=config['datasets']['VOC2012_root_path'],
    split='prompts_split',
    device=config['device'],
    resize_size=seg_config['image_size'],
    center_crop=True,
    img_idxs=img_idxs,
    mask_prs_path=config['mask_prs_path'],
)

# JSONL file whose records provide the text used alongside the similarity maps
# (per the path, presumably the gemma3 answers for LRASPP_MobileNet_V3 predictions).
jsonl_path = Path('/home/olivieri/exp/data/private/prompts_data/by_model/LRASPP_MobileNet_V3/class-splitted/answer_prs/gemma3:12b-it-qat/speed_test/SepMasks_Ovr_1fs.jsonl')

In [8]:
title = 'test'
# One HTML row of images per (contrastive, mask alpha) configuration.
rows = dict()

# One accumulated similarity string per flattened entry; each configuration
# below appends its own "<br>"-separated chunk to every string.
sim_datas = []

# Configurations to compare. The (mask_a=0.0, contrastive) combination is skipped.
mask_alphas = [mask_alpha, 0.0]
contrs = [False, True]

for mask_a, contr in product(mask_alphas, contrs):

    if mask_a == 0.0 and contr:
        continue

    cs_pr_image_text_list = compute_cs_maps_with_captions(
        viz_ds=viz_ds,
        img_idxs=img_idxs,
        jsonl_path=jsonl_path,
        vle=vle,
        mask_color='L',
        alpha=mask_a,
        map_alpha=map_alpha,
        map_compute_mode=map_compute_mode,
        map_resize_mode=map_resize_mode,
        viz_image_size=viz_image_size,
        normalize=normalize,
        contrastive=contr
    )
    # Flatten once. Each entry's [0] goes into the image row, and [1], [2] are
    # appended to the similarity text.
    cs_pr_flat = flatten_list_to_depth(cs_pr_image_text_list, depth=1)

    if not sim_datas:
        sim_datas = [""] * len(cs_pr_flat)
    elif len(sim_datas) != len(cs_pr_flat):
        # zip() below would otherwise silently drop entries and misalign captions.
        raise ValueError(
            f"Configuration {contr=}, {mask_a=} produced {len(cs_pr_flat)} entries, "
            f"expected {len(sim_datas)}."
        )

    new_sim_data = [f"{el[1]}<br>{el[2]}<br>" for el in cs_pr_flat]
    sim_datas = [s_d + new_s_d for s_d, new_s_d in zip(sim_datas, new_sim_data)]

    images_row = [el[0] for el in cs_pr_flat]

    rows[f"{contr=}, {mask_a=}"] = images_row

In [9]:
# One caption per flattened entry: the similarity text accumulated over all
# configurations, followed by the entry's el[3] field.
# NOTE(review): el[3] is read from `cs_pr_image_text_list` as left by the LAST iteration
# of the configuration loop; presumably el[3] is identical across configurations — confirm.
cs_pr_last_flat = flatten_list_to_depth(cs_pr_image_text_list, depth=1)
if len(sim_datas) != len(cs_pr_last_flat):
    # zip() would otherwise silently truncate and misalign captions with images.
    raise ValueError(
        f"Length mismatch: {len(sim_datas)} similarity strings vs {len(cs_pr_last_flat)} entries."
    )
captions = [s_d + f"<br>{el[3]}" for s_d, el in zip(sim_datas, cs_pr_last_flat)]

In [10]:
# Render the image rows (one per configuration) with the per-entry captions into an HTML page.
write_html_multi_row_image_caption(title, rows, captions)

Successfully created index.html. Open it in your browser to see the result.
