In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [25]:
from core.config import *

from core.datasets import get_answer_objects, JsonlIO, VOC2012SegDataset
from core.torch_utils import unprefix_state_dict
from core.viz import overlay_map, create_diff_mask, write_html_multi_row_image_caption
from core.prompter import get_significant_classes
from core.torch_utils import blend_tensors
from core.color_map import apply_colormap
from core.data_utils import flatten_list_to_depth
from models.vle import VLE_REGISTRY, VLEncoder, NewLayer, MapComputeMode

from itertools import product

import torchvision.transforms.v2.functional as TF

from core._types import Optional, Literal

In [3]:
# Build the notebook configuration from BASE_CONFIG and the visualisation YAML file,
# then keep handles on the segmentation ('seg') and encoder ('vle') sub-sections.
viz_config_path = Path('/home/olivieri/exp/src/viz/config.yml')
config = setup_config(BASE_CONFIG, viz_config_path)
seg_config, vle_config = config['seg'], config['vle']

In [4]:
# Vision-Language Encoder: fetch the "flair" entry from the registry, using the
# checkpoint name, weights root and extra layers taken from the 'vle' config section.
vle_kwargs = dict(
    name="flair",
    version='flair-cc3m-recap.pt',
    pretrained_weights_root_path=vle_config['pretrained_weights_root_path'],
    new_layers=vle_config['new_layers'],
    device=config['device'],
)
vle = VLE_REGISTRY.get(**vle_kwargs)

In [5]:
# Checkpoint whose 'model_state_dict' entry is loaded into the encoder.
vle_weights_path = Path('/home/olivieri/exp/data/private/exps/vle/1-dM_vs_dT/250902_2047/weights/flair-flair-cc3m-recap.pt-text_proj_L_b256_250902_2047.pth')
if not vle_weights_path.exists():
    # A missing checkpoint is a file-system error, not an attribute lookup failure.
    raise FileNotFoundError(f"ERROR: VLE weights path '{vle_weights_path}' not found.")

# Map the tensors onto the device the encoder was created on instead of a
# hard-coded 'cuda', so the cell also runs on hosts without a GPU.
checkpoint = torch.load(vle_weights_path, map_location=config['device'])
# The keys carry an '_orig_mod.' prefix (presumably left by torch.compile — TODO confirm)
# that must be stripped before they match the model's parameter names.
vle.model.load_state_dict(unprefix_state_dict(checkpoint['model_state_dict'], prefix='_orig_mod.'))

In [6]:
# --- Samples to visualise (indices passed to the dataset and the answers file) ---
img_idxs = [9, 10, 11, 12, 13, 14, 15, 17]

# --- Map computation ---
map_compute_mode = MapComputeMode.ATTENTION
map_resize_mode = TF.InterpolationMode.NEAREST
normalize = True

# --- Rendering ---
viz_image_size = 350
map_alpha = 1.0
mask_alpha = 0.55  # the greater, the less of the scene is visible

In [15]:
def compute_maps_with_captions(
        model: VLEncoder,
        images: torch.Tensor,
        captions: list[str],
        map_compute_mode: MapComputeMode,
        viz_image_size: Optional[int | list[int]] = None,
        map_resize_mode: TF.InterpolationMode = TF.InterpolationMode.NEAREST,
        normalize: bool = True,
        overlay_alpha: Optional[float] = None,
) -> list[tuple[Image.Image, str, str, str, str]]:
    """Overlay an image-text map on each image and collect it with its caption data.

    Images and captions are paired positionally with ``zip``; surplus items on
    either side are ignored.

    Args:
        model: Encoder used to preprocess inputs and to compute similarity and maps.
        images: Iterable of images, processed one at a time.
        captions: One text per image.
        map_compute_mode: Forwarded to ``model.get_maps``.
        viz_image_size: If truthy, the image is resized to this size (bilinear) and
            the map is upsampled to it; otherwise the image is left as is.
        map_resize_mode: Interpolation used by ``model.get_maps`` to upsample the map.
        normalize: Forwarded to ``overlay_map``.
        overlay_alpha: Alpha forwarded to ``overlay_map``. When ``None`` the
            notebook-level ``map_alpha`` is used, as in the original behaviour.

    Returns:
        One tuple per pair: ``(overlay, "SIM = ...", "MIN VALUE = ..., MAX VALUE = ...",
        caption, "---")``. The overlay is whatever ``overlay_map`` returns
        (annotated as a PIL image — TODO confirm).
    """
    if overlay_alpha is None:
        # Backward compatibility: fall back on the global defined in the parameters cell.
        overlay_alpha = map_alpha

    image_text_list = []

    for img, text in zip(images, captions):
        img_tensor = model.preprocess_images([img])
        text_tensor = model.preprocess_texts([text])
        sim = model.get_similarity(img_tensor, text_tensor, broadcast=False)
        sim_map, min_value, max_value = model.get_maps(
            img_tensor, text_tensor,
            map_compute_mode=map_compute_mode,
            upsample_size=viz_image_size, upsample_mode=map_resize_mode,
            attn_heads_idx=[0, 3, 5, 7] # as done by the authors
        ) # noted by the original author as [1, 1, H, W], m, M
        # Drops the leading dimension only.
        # NOTE(review): the original comment claimed [H, W] here; confirm the actual shape.
        sim_map = sim_map.squeeze(0)
        if viz_image_size:
            img = TF.resize(img, size=viz_image_size, interpolation=TF.InterpolationMode.BILINEAR)
        ovr_img = overlay_map(img, sim_map, normalize=normalize, alpha=overlay_alpha)

        image_text_list.append((ovr_img, f"SIM = {sim.item():.2f}", f"MIN VALUE = {min_value.item():.2f}, MAX VALUE = {max_value.item():.2f}", text, "---"))

    return image_text_list

In [None]:
def compute_cs_pr_maps_with_captions(
        img_idxs: list[int],
        vle: VLEncoder,
        mask_color: Literal['L', 'RB'],
        alpha: float,
        map_compute_mode: MapComputeMode,
        map_resize_mode: TF.InterpolationMode,
        viz_image_size: int | tuple[int, int],
        normalize: bool,
        contrastive: bool = False
) -> list[list[tuple[Image.Image, str, str, str, str]]]:
    """Build per-class difference overlays for each image and compute their maps.

    For every selected image, the ground-truth and predicted masks are compared
    class by class, the difference mask is blended onto the scene, and each
    overlay is passed with the matching answer text to
    ``compute_maps_with_captions``. Dataset paths are read from the notebook-level
    ``config``.

    Args:
        img_idxs: Indices of the images to load and of the answers to read.
        vle: Encoder forwarded to ``compute_maps_with_captions``.
        mask_color: ``'L'`` blends the difference mask scaled by 255; ``'RB'``
            blends a colour-mapped mask (value 1 -> red, value 2 -> blue).
        alpha: Blending factor passed to ``blend_tensors``.
        map_compute_mode: Forwarded to ``compute_maps_with_captions``.
        map_resize_mode: Forwarded to ``compute_maps_with_captions``.
        viz_image_size: Forwarded to ``compute_maps_with_captions``.
        normalize: Forwarded to ``compute_maps_with_captions``.
        contrastive: If True, every class slot of an image shows the overlay of
            its first class instead of its own.

    Returns:
        One list per image, holding the tuples produced by
        ``compute_maps_with_captions`` for that image.

    Raises:
        ValueError: If ``mask_color`` is neither ``'L'`` nor ``'RB'``.
    """
    # Fail fast: an unknown colour mode used to produce empty rows silently.
    if mask_color not in ('L', 'RB'):
        raise ValueError(f"mask_color must be 'L' or 'RB', got {mask_color!r}.")

    prs_path = Path('/home/olivieri/exp/data/private/prompts_data/by_model/LRASPP_MobileNet_V3/class-splitted/answer_prs/gemma3:12b-it-qat/speed_test/SepMasks_Ovr_1fs.jsonl')
    cs_answers_pr = get_answer_objects(prs_path, idxs=None, jsonlio=JsonlIO(), return_state=False, format_to_dict=True)
    # One list of answer texts per requested image (presumably one text per class — TODO confirm).
    cs_answers_pr_text = [list(cs_answers_pr[i].values()) for i in img_idxs]

    viz_ds = VOC2012SegDataset(
        root_path=config['datasets']['VOC2012_root_path'],
        split='prompts_split',
        device=config['device'],
        resize_size=config['seg']['image_size'],
        center_crop=True,
        img_idxs=img_idxs,
        mask_prs_path=config['mask_prs_path']
    )

    scs, gts, prs = viz_ds[:]

    cs_ovr_mask_prs = []

    for sc, gt, pr in zip(scs, gts, prs):
        # Union of the classes significant in either mask; class 0 is dropped
        # unless it is the only one.
        sign_classes = set(get_significant_classes(gt) + get_significant_classes(pr))
        if 0 in sign_classes and sign_classes != {0}:
            sign_classes.discard(0)
        sign_classes = sorted(sign_classes)

        ovr_mask_prs = []

        for pos_c in sign_classes:
            pos_class_gt = (gt == pos_c)
            pos_class_pr = (pr == pos_c)

            diff_mask = create_diff_mask(pos_class_gt, pos_class_pr)

            # Only the requested overlay is built; the other one was never used.
            if mask_color == 'L':
                ovr_diff_mask = blend_tensors(sc, diff_mask*255, alpha)
            else:
                diff_mask += (diff_mask*pos_class_gt) # sets to 2 the false negatives
                diff_mask_col_RB = apply_colormap([diff_mask], {0: (0, 0, 0), 1: (255, 0, 0), 2: (0, 0, 255)}).squeeze()
                ovr_diff_mask = blend_tensors(sc, diff_mask_col_RB, alpha)

            ovr_mask_prs.append(ovr_diff_mask)

        # Contrastive mode: show the first class's overlay in every class slot.
        # The emptiness guard avoids an IndexError when no class was found.
        if contrastive and ovr_mask_prs:
            ovr_mask_prs = [ovr_mask_prs[0]] * len(ovr_mask_prs)

        cs_ovr_mask_prs.append(ovr_mask_prs)

    cs_pr_image_text_list = [
        compute_maps_with_captions(
            vle, ovr_mask_prs, answers_pr_text,
            map_compute_mode=map_compute_mode,
            viz_image_size=viz_image_size,
            map_resize_mode=map_resize_mode,
            normalize=normalize,
        )
        for ovr_mask_prs, answers_pr_text in zip(cs_ovr_mask_prs, cs_answers_pr_text)
    ]

    return cs_pr_image_text_list

In [28]:
title = 'visual_proj_text_proj_l_full_b256'
rows = {}

# Per-sample caption prefix, extended with the similarity/min/max of every configuration.
sim_datas = []

mask_alphas = [mask_alpha, 0.0]
contrs = [False, True]

for mask_a, contr in product(mask_alphas, contrs):

    # The contrastive variant is not rendered for mask_a == 0.0.
    if contr and mask_a == 0.0:
        continue

    cs_pr_image_text_list = compute_cs_pr_maps_with_captions(
        img_idxs=img_idxs,
        vle=vle,
        mask_color='L',
        alpha=mask_a,
        map_compute_mode=map_compute_mode,
        map_resize_mode=map_resize_mode,
        viz_image_size=viz_image_size,
        normalize=normalize,
        contrastive=contr
    )

    # Flatten the per-image lists once: one 5-tuple per (image, class) pair.
    flat_entries = flatten_list_to_depth(cs_pr_image_text_list, depth=1)

    if not sim_datas:
        sim_datas = [""] * len(flat_entries)

    # Slots 1 and 2 hold the similarity and the min/max strings.
    new_sim_data = [f"{el[1]}<br>{el[2]}<br>" for el in flat_entries]
    sim_datas = [s_d+new_s_d for s_d, new_s_d in zip(sim_datas, new_sim_data)]

    # Slot 0 holds the overlay image; one HTML row per configuration.
    images_row = [el[0] for el in flat_entries]
    rows[f"{contr=}, {mask_a=}"] = images_row

In [29]:
# Final captions: the accumulated similarity data followed by the text stored in slot 3
# of each entry. `cs_pr_image_text_list` is the value left by the last executed
# iteration of the loop that builds `rows`.
flat_entries = flatten_list_to_depth(cs_pr_image_text_list, depth=1)
captions = [f"{s_d}<br>{el[3]}" for s_d, el in zip(sim_datas, flat_entries)]

In [31]:
# Render the image rows and the per-column captions to an HTML page
# (the recorded cell output reports that index.html was created).
write_html_multi_row_image_caption(title, rows, captions)

Successfully created index.html. Open it in your browser to see the result.
