In [None]:
import sys

INSTALL = False # Switch this to install dependencies
if INSTALL: # Try installing package with extras
    REPO_URL = "https://github.com/facebookresearch/dinov2"
    !{sys.executable} -m pip install -e {REPO_URL}'[extras]' --extra-index-url https://download.pytorch.org/whl/cu117  --extra-index-url https://pypi.nvidia.com
else:
    REPO_PATH = "/home/osero/Desktop/CMPE/dinov2" # Specify a local path to the repository (or use installed package instead)
    sys.path.append(REPO_PATH)

# import torch
# torch.cuda.empty_cache()

In [None]:
# A simple notebook demonstrating how to extract an attention map from DinoV2 inference (with registers) 

# Most of the core code was originally published here:
#  https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

# November 11th, 2023 by Lance Legel (lance@3co.ai) from 3co, Inc. (https://3co.ai)

# Enable IPython's autoreload extension in mode 2, which re-imports modules before
# each cell executes, so edits to the local repository code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# This must be set BEFORE dinov2 is imported below: a switch that a package checks
# while its modules are being imported is ignored if it is only set afterwards.
# Setting the variable disables xFormers; remove this line to enable it.
os.environ["XFORMERS_DISABLED"] = "1"
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
from torchvision import datasets, transforms
import numpy as np
from PIL import Image
from dinov2.models.vision_transformer import vit_small, vit_base, vit_large
from matplotlib.colors import Normalize
from io import BytesIO
import requests
# Echo the switch so the notebook output records which setting was active
print(os.environ.get("XFORMERS_DISABLED") )


In [None]:
# These are settings for ensuring input images to DinoV2 are properly sized

class ResizeAndPad:
    """Resize an image, then pad it so both sides are multiples of ``multiple``.

    Args:
        target_size: size handed to ``transforms.Resize``.
        multiple: value the padded width and height must be divisible by.

    Calling the instance returns the resized image with padding split between
    the two sides of each axis; when the padding is odd, the extra pixel goes
    to the right / bottom edge.
    """

    def __init__(self, target_size, multiple):
        self.target_size, self.multiple = target_size, multiple

    def __call__(self, img):
        resized = transforms.Resize(self.target_size)(img)

        # (-n) % m is the distance from n up to the next multiple of m
        # (0 when n is already aligned).
        extra_w = -resized.width % self.multiple
        extra_h = -resized.height % self.multiple

        # Border order is (left, top, right, bottom)
        left, top = extra_w // 2, extra_h // 2
        border = (left, top, extra_w - left, extra_h - top)

        return transforms.Pad(border)(resized)

# Side length, in pixels, of the square image handed to the model
image_dimension = 952

# This is what DinoV2 sees
target_size = (image_dimension,) * 2

# Inference / testing / deployment pipeline: no data augmentation, only resizing to the
# target size (padded to a multiple of 14), a centre crop, tensor conversion and
# per-channel normalisation.
data_transforms = transforms.Compose([
    ResizeAndPad(target_size, 14),
    transforms.CenterCrop(image_dimension),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
image_size = (image_dimension, image_dimension)
output_dir = '.'
patch_size = 14

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#model.load_state_dict(torch.load('/home/osero/Desktop/CMPE/dinov2/my_test/dinov2_vitg14_pretrain.pth'))
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14").to(device)

# Inference only: switch to eval mode and freeze the weights so that outputs do not
# carry gradients (a tensor that requires grad cannot be converted with .numpy()).
model.eval()
model.requires_grad_(False)

# Number of register tokens sitting between the CLS token and the patch tokens.
# Read it from the loaded model instead of hard-coding it: the slice that drops these
# tokens must match the checkpoint, otherwise real patch tokens are discarded and the
# reshape to the patch grid fails. Falls back to 0 if the model does not expose the
# attribute (presumably a build without register support — verify for custom models).
n_register_tokens = getattr(model, "num_register_tokens", 0)

In [None]:
# # URL of the image
# image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/c/cd/STS-124_launch_from_a_distance.jpg/800px-STS-124_launch_from_a_distance.jpg"  # Replace with your image URL

# # Download the image
# response = requests.get(image_url)

# # Check if the request was successful
# if response.status_code == 200:
#     # Open the image
#     original_image = Image.open(BytesIO(response.content))

#     # Display the image
#     display(original_image)
# else:
#     print(f"Failed to download the image. Status code: {response.status_code}")

original_image = Image.open('dino_test_imgs/STS-124_launch_from_a_distance.jpg')

# Remember the source resolution so the attention map can be resized back to it later
(original_w, original_h) = original_image.size

# The Normalize step in data_transforms is configured with three channel statistics,
# so feed it an RGB copy; grayscale / CMYK / RGBA files would otherwise fail there.
# original_image itself is left as loaded.
img = data_transforms(original_image.convert('RGB'))

# make the image divisible by the patch size
# NOTE: img is laid out (channels, height, width), so despite their names `w` (and
# `w_featmap`) count rows while `h` (and `h_featmap`) count columns. The names are
# kept because other cells use them.
w, h = img.shape[1] - img.shape[1] % patch_size, img.shape[2] - img.shape[2] % patch_size
img = img[:, :w, :h]

# Number of patches along each spatial axis
w_featmap = img.shape[-2] // patch_size
h_featmap = img.shape[-1] // patch_size

# Add the batch axis and move the tensor to the model's device
img = img.unsqueeze(0)
img = img.to(device)

In [None]:
# Forward pass only: run it without autograd so no graph is kept in memory and the
# result can later be converted with .numpy() (which rejects tensors that require grad).
with torch.no_grad():
    attention = model.get_last_self_attention(img.to(device))

In [None]:
# Show the shape of the raw attention tensor followed by the tensor itself
print(f"Attention {attention.shape}: {attention}")

In [None]:
# Axis 1 of the attention tensor is taken as the head axis
# (presumably the layout is (batch, heads, query, key) — confirm against the model)
number_of_heads = attention.size(1)

# Register tokens are packed in after the first token; the spatial tokens follow.
# Keep batch item 0, row 0 (the first token's attention) and only the spatial columns.
first_spatial_token = 1 + n_register_tokens
attention = attention[0, :, 0, first_spatial_token:]
attention = attention.reshape(number_of_heads, -1)

In [None]:
# Sanity check: after the slice above this is (number_of_heads, number of spatial tokens)
print(attention.shape)

In [None]:
# Lay the flat per-head token vector back out on the patch grid:
# one attention value per transformer token
token_grid_shape = (number_of_heads, w_featmap, h_featmap)
attention = torch.reshape(attention, token_grid_shape)
print(attention.shape)

In [None]:
# Nearest-neighbour upscaling by the patch size: every token cell becomes a
# patch_size x patch_size block, bringing the map closer to image resolution.
# interpolate() wants a leading batch axis, which is added here and dropped again.
attention = (
    nn.functional.interpolate(attention[None], mode="nearest", scale_factor=patch_size)
    .squeeze(0)
    .cpu()
)
print(attention.shape)

In [None]:
# Collapse the head axis: adding up the per-head maps (however many heads the model has)
# gives one map of attention across the entire image
attention = attention.sum(dim=0)
print(attention.shape)

In [None]:
# Bring the attention map back to the source image's resolution. interpolate() needs
# batch and channel axes in front, so they are added for the call and squeezed away after.
attention_of_image = nn.functional.interpolate(
    attention[None, None],
    size=(original_h, original_w),
    mode='bilinear',
    align_corners=False,
).squeeze()
print(attention_of_image.shape)

In [None]:
# Normalize image_metric to the range [0, 1]
# detach() + cpu() first: .numpy() raises on a tensor that still requires grad and
# cannot read a tensor living on the GPU.
image_metric = attention_of_image.detach().cpu().numpy()
normalized_metric = Normalize(vmin=image_metric.min(), vmax=image_metric.max())(image_metric)

# Apply the Reds colormap
reds = plt.cm.Reds(normalized_metric)

# Create the alpha channel
alpha_max_value = 1.00  # Set your max alpha value

# Adjust this value as needed to enhance lower values visibility
# (an exponent below 1 lifts small normalized values towards 1)
gamma = 0.5

# Apply gamma transformation to enhance lower values
enhanced_metric = np.power(normalized_metric, gamma)

# Create the alpha channel with enhanced visibility for lower values
alpha_channel = enhanced_metric * alpha_max_value

# Add the alpha channel to the RGB data
rgba_mask = np.zeros((image_metric.shape[0], image_metric.shape[1], 4))
rgba_mask[..., :3] = reds[..., :3]  # RGB
rgba_mask[..., 3] = alpha_channel  # Alpha

# Convert the numpy array (floats in [0, 1]) to an 8-bit PIL Image
rgba_image = Image.fromarray((rgba_mask * 255).astype(np.uint8))

# Save the image under output_dir, the same location the overlay step reads it from
rgba_image.save("{}/attention_mask.png".format(output_dir))

In [None]:
# Load the attention mask with PIL
attention_mask_image = Image.open("{}/attention_mask.png".format(output_dir))

# Work on an RGBA copy (convert() returns a new image) so original_image stays
# untouched and re-running this cell does not stack the overlay a second time.
image_with_attention = original_image.convert('RGBA')

# Overlay the mask onto the copy, using the mask's own alpha channel as the paste mask.
# The mask must be the same size as the image.
image_with_attention.paste(attention_mask_image, (0, 0), attention_mask_image)

# Save the combined image next to the mask
image_with_attention.save("{}/image_with_attention.png".format(output_dir))

# Or display it
display(image_with_attention)