# Comparing Contextual Mappings of Tokens

In [1]:
!pip install torch torchvision transformers timm ipywidgets -q

[0m

In [4]:
import torch
from transformers import ViTModel, ViTConfig
from torchvision import transforms
from PIL import Image, ImageDraw
import ipywidgets as widgets
import numpy as np
import math 

def display_image_with_patches(image, patch_size, title):
    image_width, image_height = image.size
    draw = ImageDraw.Draw(image)
    patch_id = 0

    for y in range(0, image_height, patch_size):
        for x in range(0, image_width, patch_size):
            draw.rectangle([x, y, x + patch_size, y + patch_size], outline="red", width=2)
            draw.text((x + 5, y + 5), str(patch_id), fill="red")
            patch_id += 1

    image.show(title=title)

    
# Load a pre-trained Vision Transformer model
config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
model = ViTModel.from_pretrained("google/vit-base-patch16-224")

# Prepare the image inputs
image_path_A = "heart.png"
image_path_B = "heart2.png"
image_A = Image.open(image_path_A).convert("RGB")
image_B = Image.open(image_path_B).convert("RGB")

num_patches_sqrt = int(math.sqrt(config.num_hidden_layers))
patch_size = 224 // num_patches_sqrt

# Display images with labeled patches
display_image_with_patches(image_A.copy(), patch_size, title="Image A with Patches")
display_image_with_patches(image_B.copy(), patch_size, title="Image B with Patches")


# Transform the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image_input_A = transform(image_A).unsqueeze(0)  # Add batch dimension
image_input_B = transform(image_B).unsqueeze(0)  # Add batch dimension

# Pass the images through the model to get features
with torch.no_grad():
    outputs_A = model(image_input_A)
    outputs_B = model(image_input_B)
    hidden_states_A = outputs_A.last_hidden_state
    hidden_states_B = outputs_B.last_hidden_state

# Function to compute cosine similarity between image patches
def compute_cosine_similarity(patch_index_A, patch_index_B):
    patch_vector_A = hidden_states_A[0, patch_index_A, :]
    patch_vector_B = hidden_states_B[0, patch_index_B, :]
    
    cosine_similarity = torch.nn.functional.cosine_similarity(patch_vector_A.unsqueeze(0), patch_vector_B.unsqueeze(0))
    return cosine_similarity.item()

# Create input boxes for patch indices
patch_index_box_A = widgets.IntText(
    value=1,
    description='Patch index A:',
    disabled=False
)

patch_index_box_B = widgets.IntText(
    value=1,
    description='Patch index B:',
    disabled=False
)

# Create a button to trigger the computation of cosine similarity
button = widgets.Button(description="Compute cosine similarity")

def on_button_click(b):
    cosine_similarity = compute_cosine_similarity(patch_index_box_A.value, patch_index_box_B.value)
    print(f"Cosine similarity for patch indices {patch_index_box_A.value} and {patch_index_box_B.value}: {cosine_similarity}")

button.on_click(on_button_click)

# Display the UI
display(patch_index_box_A)
display(patch_index_box_B)
display(button)

Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


IntText(value=1, description='Patch index A:')

IntText(value=1, description='Patch index B:')

Button(description='Compute cosine similarity', style=ButtonStyle())

Cosine similarity for patch indices 178 and 76: 0.6201204061508179
