In [None]:
import cv2

from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs

from PIL import Image
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


In [None]:
# Run settings for this notebook. `device` and `batch_size` are used by the
# model load / inference calls in this cell. `schedule`, `lr` and `niter` are
# not referenced by any visible cell (presumably kept for a later
# global-alignment step — TODO confirm or remove).
device = 'cuda'
batch_size = 1
schedule = 'cosine'
lr = 0.01
niter = 300

# Identifier of the pretrained weights handed to from_pretrained().
model_name = "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
# you can put the path to a local checkpoint in model_name if needed
model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)

# Registry of captured tensors, keyed by the label passed to getActivation().
activation = dict()


def getActivation(name):
    """Build a forward hook that stores a module's output under `name`.

    The returned callable takes the (module, inputs, output) arguments of a
    forward hook. Each time it is called it overwrites ``activation[name]``
    with the detached output, so only the most recent call is remembered.
    """
    def _record(module, inputs, result):
        # Detach so the stored tensor does not keep the autograd graph alive.
        detached = result.detach()
        activation[name] = detached

    return _record

# add hooks to the model
# Hook the query / key projection layers *inside* the first cross-attention
# module of each decoder branch. Registering both hooks on `cross_attn` itself
# (as was done before) stores the same block output under the "q" and the "k"
# label, so the two tensors are identical and are not projections at all.
# NOTE(review): relies on the cross-attention module exposing `projq` / `projk`
# sub-layers — confirm against the CrossAttention implementation in use.
cross_attn_1 = model.dec_blocks[0].cross_attn
cross_attn_2 = model.dec_blocks2[0].cross_attn

# The returned handles can be used to detach the hooks later via `.remove()`.
proj_q_1 = cross_attn_1.projq.register_forward_hook(getActivation('projq_1'))
proj_k_1 = cross_attn_1.projk.register_forward_hook(getActivation('projk_1'))

proj_q_2 = cross_attn_2.projq.register_forward_hook(getActivation('projq_2'))
proj_k_2 = cross_attn_2.projk.register_forward_hook(getActivation('projk_2'))

# load_images can take a list of images or a directory
images = load_images(['croco/assets/Chateau1.png', 'croco/assets/Chateau2.png'], size=512)
# Build the image pairs fed to the network (symmetrize=True presumably adds
# both orderings of every pair — TODO confirm against make_pairs).
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
# Run the model; the forward hooks registered above fire during this call.
# Each hook overwrites its `activation` entry, so if more than one forward
# pass is executed only the tensors of the last pass are kept.
output = inference(pairs, model, device, batch_size=batch_size)



In [None]:
# Display the module tree of the loaded model (useful for locating layers to hook).
model

In [None]:
import math
import torch

# Rebuild per-head attention weights from the tensors captured by the forward
# hooks: split channels into heads, take scaled dot products between queries
# and keys, then softmax over the key axis.
# NOTE(review): if the model's cross-attention transforms q/k after the
# projections (e.g. a rotary position embedding), these weights are only an
# approximation of what the model computes internally — confirm.

#collection of the activations
projq_1 = activation['projq_1']
projk_1 = activation['projk_1']

projq_2 = activation['projq_2']
projk_2 = activation['projk_2']

print(projq_1.shape)  # [B, N, C]; the original run reported [1, 768, 768]
print(projk_1.shape)  # [B, N, C]; the original run reported [1, 768, 768]

B, N, C = projq_1.shape

# Take the head count from the hooked module rather than hard-coding it: a
# wrong value silently groups channels into the wrong heads. Falls back to the
# previous constant (8) if the module does not expose `num_heads`.
num_heads = getattr(model.dec_blocks[0].cross_attn, 'num_heads', 8)
if C % num_heads != 0:
    raise ValueError(f"Channel dim {C} is not divisible by num_heads={num_heads}.")
head_dim = C // num_heads

# [B, tokens, C] -> [B, num_heads, tokens, head_dim]. The token count is
# inferred per tensor (-1) so queries and keys may differ in length.
projq_1 = projq_1.reshape(B, -1, num_heads, head_dim).transpose(1, 2)
projk_1 = projk_1.reshape(B, -1, num_heads, head_dim).transpose(1, 2)

projq_2 = projq_2.reshape(B, -1, num_heads, head_dim).transpose(1, 2)
projk_2 = projk_2.reshape(B, -1, num_heads, head_dim).transpose(1, 2)

# Compute attention scores: [B, num_heads, query_tokens, key_tokens]
attn_scores_1 = torch.matmul(projq_1, projk_1.transpose(-2, -1))
attn_scores_2 = torch.matmul(projq_2, projk_2.transpose(-2, -1))

# Scale the scores
attn_scores_1 = attn_scores_1 / math.sqrt(head_dim)
attn_scores_2 = attn_scores_2 / math.sqrt(head_dim)

# Apply softmax over the key axis so every query row sums to 1
attention_weights_1 = torch.softmax(attn_scores_1, dim=-1)
attention_weights_2 = torch.softmax(attn_scores_2, dim=-1)

print(attention_weights_1.shape)  # [B, num_heads, query_tokens, key_tokens]
print(attention_weights_2.shape)  # [B, num_heads, query_tokens, key_tokens]

In [None]:
# Visualise, for every head of `attention_weights_1`, how much attention each
# key patch receives in total, and overlay that map on an image.

# The two views do not change between heads, so load them once up front
# instead of re-reading the files on every iteration.
image_path = 'croco/assets/Chateau1.png'  # Replace with your actual image path
image = Image.open(image_path).convert('RGB')
image_np = np.array(image)

image_path_other_view = 'croco/assets/Chateau2.png'  # Replace with your actual image path
image_other_view = Image.open(image_path_other_view).convert('RGB')
image_np_other_view = np.array(image_other_view)

# Patch grid the token axis is unfolded into (rows x columns).
num_patches = 768
grid_height = 32
grid_width = 24

assert grid_height * grid_width == num_patches, "Grid dimensions do not match number of patches."

# Blend weights for the overlay
alpha = 0.6  # Transparency factor for the original image
beta = 0.4   # Transparency factor for the attention map
gamma = 0    # Scalar added to each sum

# Iterate over however many heads the tensor actually holds.
for head_idx in range(attention_weights_1.shape[1]):

    # Extract attention matrix for the selected head: [query_tokens, key_tokens]
    attention_matrix = attention_weights_1[0, head_idx].detach().cpu().numpy()

    # Fail with a readable message if the data does not match the assumed grid
    # (the reshape below would otherwise raise a less helpful error).
    assert attention_matrix.shape == (num_patches, num_patches), (
        f"Expected a {num_patches}x{num_patches} attention matrix, got {attention_matrix.shape}."
    )

    # Unfold both token axes into (rows, cols): [32, 24, 32, 24]
    attention_grid = attention_matrix.reshape(grid_height, grid_width, grid_height, grid_width)

    # Aggregate attention across query patches -> one value per key patch: [32, 24]
    aggregated_attention = attention_grid.sum(axis=(0, 1))

    # Normalize the attention map so its maximum is 1
    aggregated_attention = aggregated_attention / aggregated_attention.max()

    attention_map = aggregated_attention  # [32, 24]

    # Use OpenCV to resize the map to the image size; cv2.resize takes (width, height)
    attention_map_resized = cv2.resize(attention_map, (image_np.shape[1], image_np.shape[0]))

    # Scale the attention map to [0, 255]
    attention_map_resized = np.uint8(255 * attention_map_resized)

    # Apply a color map (e.g., JET)
    attention_map_colored = cv2.applyColorMap(attention_map_resized, cv2.COLORMAP_JET)

    # Convert from BGR (OpenCV default) to RGB
    attention_map_colored = cv2.cvtColor(attention_map_colored, cv2.COLOR_BGR2RGB)

    # Blend the attention map with the original image
    overlayed_image = cv2.addWeighted(image_np, alpha, attention_map_colored, beta, gamma)

    # Plot the results
    fig, axes = plt.subplots(1, 4, figsize=(18, 6))

    # The other view. Both image panels used to be titled 'Original Image',
    # which made them indistinguishable; the titles now follow the sibling
    # visualisation cell for `attention_weights_2`.
    axes[0].imshow(image_np_other_view)
    axes[0].set_title('Source Image')
    axes[0].axis('off')

    # Aggregated Attention Heatmap
    sns.heatmap(aggregated_attention, cmap='viridis', ax=axes[1])
    axes[1].set_title(f'Aggregated Attention - Head {head_idx + 1}')
    axes[1].set_xlabel('Key Patch X (Width)')
    axes[1].set_ylabel('Key Patch Y (Height)')

    # Overlayed Image
    axes[2].imshow(overlayed_image)
    axes[2].set_title('Target Overlayed Attention Map')
    axes[2].axis('off')

    # The image the overlay was drawn on, without the heat map
    axes[3].imshow(image)
    axes[3].set_title('Target Image')
    axes[3].axis('off')

    fig.tight_layout()
    plt.show()



In [None]:


# Visualise, for every head of `attention_weights_2`, how much attention each
# key patch receives in total, and overlay that map on an image.

# The two views do not change between heads, so load them once up front
# instead of re-reading the files on every iteration.
image_path = 'croco/assets/Chateau2.png'  # Replace with your actual image path
image = Image.open(image_path).convert('RGB')
image_np = np.array(image)

image_path_other_view = 'croco/assets/Chateau1.png'  # Replace with your actual image path
image_other_view = Image.open(image_path_other_view).convert('RGB')
image_np_other_view = np.array(image_other_view)

# Patch grid the token axis is unfolded into (rows x columns).
num_patches = 768
grid_height = 32
grid_width = 24

assert grid_height * grid_width == num_patches, "Grid dimensions do not match number of patches."

# Blend weights for the overlay
alpha = 0.6  # Transparency factor for the original image
beta = 0.4   # Transparency factor for the attention map
gamma = 0    # Scalar added to each sum

# Iterate over however many heads the tensor actually holds.
for head_idx in range(attention_weights_2.shape[1]):

    # Extract attention matrix for the selected head: [query_tokens, key_tokens]
    attention_matrix = attention_weights_2[0, head_idx].detach().cpu().numpy()

    # Fail with a readable message if the data does not match the assumed grid
    # (the reshape below would otherwise raise a less helpful error).
    assert attention_matrix.shape == (num_patches, num_patches), (
        f"Expected a {num_patches}x{num_patches} attention matrix, got {attention_matrix.shape}."
    )

    # Unfold both token axes into (rows, cols): [32, 24, 32, 24]
    attention_grid = attention_matrix.reshape(grid_height, grid_width, grid_height, grid_width)

    # Aggregate attention across query patches -> one value per key patch: [32, 24]
    aggregated_attention = attention_grid.sum(axis=(0, 1))

    # Normalize the attention map so its maximum is 1
    aggregated_attention = aggregated_attention / aggregated_attention.max()

    attention_map = aggregated_attention  # [32, 24]

    # Use OpenCV to resize the map to the image size; cv2.resize takes (width, height)
    attention_map_resized = cv2.resize(attention_map, (image_np.shape[1], image_np.shape[0]))

    # Scale the attention map to [0, 255]
    attention_map_resized = np.uint8(255 * attention_map_resized)

    # Apply a color map (e.g., JET)
    attention_map_colored = cv2.applyColorMap(attention_map_resized, cv2.COLORMAP_JET)

    # Convert from BGR (OpenCV default) to RGB
    attention_map_colored = cv2.cvtColor(attention_map_colored, cv2.COLOR_BGR2RGB)

    # Blend the attention map with the original image
    overlayed_image = cv2.addWeighted(image_np, alpha, attention_map_colored, beta, gamma)

    # Plot the results
    fig, axes = plt.subplots(1, 4, figsize=(18, 6))

    # The other view
    axes[0].imshow(image_np_other_view)
    axes[0].set_title('Source Image')
    axes[0].axis('off')

    # Aggregated Attention Heatmap
    sns.heatmap(aggregated_attention, cmap='viridis', ax=axes[1])
    axes[1].set_title(f'Aggregated Attention - Head {head_idx + 1}')
    axes[1].set_xlabel('Key Patch X (Width)')
    axes[1].set_ylabel('Key Patch Y (Height)')

    # Overlayed Image
    axes[2].imshow(overlayed_image)
    axes[2].set_title('Target Overlayed Attention Map')
    axes[2].axis('off')

    # The image the overlay was drawn on, without the heat map
    axes[3].imshow(image)
    axes[3].set_title('Target Image')
    axes[3].axis('off')

    fig.tight_layout()
    plt.show()

