In [None]:
import torch
import os
from PIL import Image
from glob import glob
from matplotlib import pyplot as plt
import numpy as np
import clip

# Remember the directory the kernel was in the first time this cell ran, so that
# re-running the cell always lands in the same place instead of climbing one
# more level on every execution.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()

# Step up one level from the notebook's directory (presumably so the
# project-local imports in the next cell resolve -- confirm against repo layout).
# os.chdir returns None, so PWD is read back from os.getcwd() afterwards.
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))
PWD = os.getcwd()
print(f'PWD is {PWD}')

# Device string used later when loading the CLIP text encoder.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
from model import VisionTransformerExtraHead
from data import transform_val
from utils import visualize_attention_patches, load_pretrained_vit, stitch_images, get_top_n_results
from _types import (
    PretrainedViTNames, 
    vit_extended_same_norm_masked_28_args_16_heads_512_width as vit_multimodal_patch_args, 
    vit_extended_28_args_16_heads_512_width as vit_no_cls_args
)

# Reference ("pretrained") vision tower used as the baseline in every
# comparison below. It is called with a single preprocessed image batch and
# is unpacked as (embedding, attention probabilities) by the later cells.
_pretrained_vit_name = PretrainedViTNames.vit_b_32
pretrained_clip_vit = load_pretrained_vit(_pretrained_vit_name)

In [None]:
# glob() yields matches in arbitrary, filesystem-dependent order. Sorting pins
# the order so that positional indices into `imagenames` (used when picking an
# image to visualize and when printing search hits) mean the same file on
# every run and on every machine.
imagenames = sorted(glob('xxx/*.jpg'))
len(imagenames)

## Store embeddings and attention probabilities for the pretrained ViT

In [None]:
# Run every image through the pretrained ViT once and keep, per image:
#   * the squeezed output embedding, and
#   * attn[0, 0, 1:] reshaped to a 7x7 grid, i.e. the first row of the first
#     attention map with its leading entry dropped
#     (NOTE(review): presumably the CLS token's attention over the 49 patches -- confirm).
pretrained_embed_list, pretrained_attn_list = [], []

with torch.no_grad():
    for path in imagenames:
        pixels = transform_val(Image.open(path).convert('RGB')).unsqueeze(0)
        embed, attn = pretrained_clip_vit(pixels)
        pretrained_embed_list.append(embed.squeeze())
        pretrained_attn_list.append(attn[0, 0, 1:].view(7, 7))

image_embeds_pretrained = torch.stack(pretrained_embed_list)
image_attn_probs_pretrained = torch.stack(pretrained_attn_list)

# L2-normalised copy of the embeddings, used by the text-search cell.
image_embeds_pretrained_unit = image_embeds_pretrained / image_embeds_pretrained.norm(dim=-1, keepdim=True)

image_embeds_pretrained.shape, image_attn_probs_pretrained.shape

## Store embeddings and attention probabilities for the newly trained ViT

In [None]:
# Checkpoint location, resolved relative to the working directory chosen in
# the first cell.
CKPT_DIR = os.path.join(PWD, '..', 'checkpoints')
CKPT_FILE = os.path.join(CKPT_DIR, 'checkpoint_epoch24_vit_extended_dim_2024-04-11_19-18-30.pt')

# Build the student architecture from its argument object. `scale` is cleared
# before the weights are loaded; keep that order.
# NOTE(review): presumably the checkpoint carries no entry for `scale` -- confirm.
vit_no_cls_model = VisionTransformerExtraHead(**vit_no_cls_args.model_dump())
vit_no_cls_model.scale = None

# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be loaded here. Tensors are mapped to CPU.
state_dict = torch.load(CKPT_FILE, map_location='cpu')
vit_no_cls_model.load_state_dict(state_dict['model_state_dict'])

# Switch to inference mode; binding the result keeps the module repr out of
# the cell output.
_ = vit_no_cls_model.eval()

In [None]:
# Run every image through the newly trained (student) ViT once and keep, per
# image, the squeezed output embedding and the first attention entry reshaped
# to an 8x8 grid.
student_embed_list, student_attn_list = [], []

with torch.no_grad():
    for path in imagenames:
        pixels = transform_val(Image.open(path).convert('RGB')).unsqueeze(0)
        embed, attn = vit_no_cls_model(pixels, same_norm=False)
        student_embed_list.append(embed.squeeze())
        student_attn_list.append(attn[0].view(8, 8))

image_embeds_student = torch.stack(student_embed_list)
image_attn_probs_student = torch.stack(student_attn_list)

# L2-normalised copy of the embeddings, used by the text-search cell.
image_embeds_student_unit = image_embeds_student / image_embeds_student.norm(dim=-1, keepdim=True)

image_embeds_student.shape, image_attn_probs_student.shape

## Visualize and compare attention probabilities

In [None]:
def compare_attn_probs(index:int) -> Image.Image:
    """Stitch three attention overlays for the image at position `index`.

    Panels, in order:
      1. the student attention grid as stored in `image_attn_probs_student`;
      2. the same grid resized to 7x7 with `interpolate`'s default mode;
      3. the pretrained ViT's grid from `image_attn_probs_pretrained`.

    Args:
        index: Position of the image in `imagenames` (and in both stacked
            attention tensors).

    Returns:
        Whatever `stitch_images` produces for the three overlays.
    """
    photo = Image.open(imagenames[index]).convert('RGB').resize((224,224))
    student_grid = image_attn_probs_student[index]

    panels = [visualize_attention_patches(student_grid, photo)]

    # interpolate() needs batch and channel axes, so two leading singleton
    # dimensions are added and squeezed away again afterwards.
    student_grid_7x7 = torch.nn.functional.interpolate(student_grid[None, None, ...], (7,7)).squeeze()
    panels.append(visualize_attention_patches(student_grid_7x7, photo))

    panels.append(visualize_attention_patches(image_attn_probs_pretrained[index], photo))
    return stitch_images(panels)

In [None]:
# Position in `imagenames` of the image to inspect; it must be smaller than
# the number of images found by the glob above.
i = 4372
compare_attn_probs(index=i)

In [None]:
# Full CLIP model, used here only for its text encoder. The second return
# value (CLIP's image preprocessing transform) is not needed.
pretrained_clip_model, _ = clip.load('ViT-B/32', device=DEVICE)

def get_text_vector(text: str) -> torch.Tensor:
    """Encode a single text query with the pretrained CLIP text encoder.

    Args:
        text: The query string.

    Returns:
        The text features for `text` as a float32 CPU tensor with a leading
        batch dimension of 1.
    """
    with torch.no_grad():
        tokens = clip.tokenize([text]).to(DEVICE)
        text_features = pretrained_clip_model.encode_text(tokens)
    # The image embeddings in this notebook come from models fed CPU inputs,
    # while CLIP runs on DEVICE and uses half precision on CUDA. Bringing the
    # text features to float32 on the CPU keeps them comparable with those
    # embeddings; on a CPU-only machine this is a no-op.
    return text_features.float().cpu()

In [None]:
def compare_search(query:str, n:int=5) -> Image.Image:
    """Show the top-`n` search hits of both models for one text query.

    The query is embedded with `get_text_vector` and ranked against the
    unit-normalised student embeddings and then the unit-normalised
    pretrained embeddings. The hit indices of each model are printed.

    Args:
        query: Text to search for.
        n: Number of hits to show per model.

    Returns:
        The student hits stitched into one strip, stacked (with
        `horizontal=False`) on top of the pretrained hits strip.
    """
    def _thumbnail_strip(indices):
        # Load each hit, bring it to 224x224 and stitch the thumbnails together.
        thumbs = []
        for idx in indices:
            thumbs.append(Image.open(imagenames[idx]).convert('RGB').resize((224,224)))
        return stitch_images(thumbs)

    text_vector = get_text_vector(query)

    # Student model first; the returned scores are not used.
    _, top_n_indices = get_top_n_results(text_vector, image_embeds_student_unit, n=n)
    print(f'Top n indices for student model: {top_n_indices}')
    top_image = _thumbnail_strip(top_n_indices)

    # Then the pretrained model.
    _, top_n_indices = get_top_n_results(text_vector, image_embeds_pretrained_unit, n=n)
    print(f'Top n indices for pretrained model: {top_n_indices}')
    bottom_image = _thumbnail_strip(top_n_indices)

    return stitch_images([top_image, bottom_image], horizontal=False)

In [None]:
# Example text query; the cell output is the stitched comparison image
# (student hits on top, pretrained hits below).
query = 'a dog jumping'
compare_search(query=query)