# [WIP] My approach to solving the research problem of the AI'nspired project with DINOv2.

**If you want to run this on your machine, be sure to download a pretrained backbone from the official DINOv2 repo: https://github.com/facebookresearch/dinov2**

In [None]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from matplotlib import cm
from matplotlib.colors import Normalize
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.decomposition import PCA
from scipy.ndimage import binary_closing, binary_opening
from typing import Tuple

import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.autograd.profiler as profiler

from utilities import *
from dinov2.models.vision_transformer import vit_small, vit_base, vit_large

## Data

**Quick overview**:
 - MISSING 6A in WEB and AI
 - MISSING 12B in AI
 - MISSING 22B in WEB and AI
 - MISSING 23B in AI
 - MISSING 26 IN WEB and AI (!) - because of this we need to skip group 26 for now

In [None]:
# Load the per-group table (one row per group; later cells read the columns
# group_code, ai_images, web_images and final_submissions from it).
# get_groups is not defined in this notebook — presumably it comes from the
# star import of `utilities`.
groups = get_groups()
# Rebuild the table without the row at position 25 by concatenating the rows
# before it with the rows from slice 26 onwards.
# NOTE(review): this assumes the row at position 25 is group 26 (the group with
# no WEB/AI images) — confirm against group_code, since groups have A/B variants.
final_data = pd.concat([groups.iloc[:25], groups[26:]], axis=0)
# Bare expression so the notebook renders the resulting table.
final_data

Let's check if everything loaded properly.

In [None]:
# Sanity check: for every group, show the second AI image, the second web image
# and the second final submission side by side.
for _, group_row in final_data.iterrows():
    print(f'Sample photos from group {group_row.group_code}')
    samples = (
        ("AI", group_row.ai_images[1]),
        ("WEB", group_row.web_images[1]),
        ("Submission", group_row.final_submissions[1]),
    )

    fig = plt.figure(figsize=(20, 10))
    for position, (title, path) in enumerate(samples, start=1):
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            raise FileNotFoundError(f'Could not read image: {path}')
        plt.subplot(1, 3, position)
        # OpenCV loads images as BGR while matplotlib expects RGB; without the
        # conversion the red and blue channels are swapped in the preview.
        plt.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        plt.title(title)
    plt.show()

## DINOV2

DISCLAIMER: This notebook will not run on your computer as-is. I had to modify DINOv2 manually on my machine to make it work. I created my own fork of the original repository, and I'll adapt the code to load the model from that fork later.

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): the three DEFAULT_* values below are not referenced anywhere in
# the visible part of this notebook — presumably meant for foreground-mask
# post-processing (threshold / morphological opening / closing); confirm or remove.
DEFAULT_BACKGROUND_THRESHOLD = 0.05
DEFAULT_APPLY_OPENING = False
DEFAULT_APPLY_CLOSING = False

# Ask PyTorch's caching allocator to release GPU memory it holds but is not using.
torch.cuda.empty_cache()

In [None]:
class DinoV2():
    """Convenience wrapper around a frozen DINOv2 ViT-B backbone.

    Provides image preprocessing (resize, normalise, crop to a whole number of
    patches), patch-token extraction, a pooled cosine similarity between two
    image files, a PCA visualisation of patch tokens, and attention-map
    rendering as a red RGBA overlay.
    """

    def __init__(self, 
                checkpoint: str ='dinov2_vitb14_reg4_pretrain.pth', 
                patch_size: int = 14, 
                img_size: int = 526, 
                n_register_tokens: int = 4, 
                smaller_edge_size: int = 224, 
                device=DEVICE
                ):
        """Build the ViT-B model, load its weights and freeze it in eval mode.

        Args:
            checkpoint: Path to the pretrained state dict.
            patch_size: Side length (pixels) of one ViT patch.
            img_size: Image size passed to the model constructor.
            n_register_tokens: Number of register tokens in the checkpoint.
            smaller_edge_size: Target length of the shorter image edge after resizing.
            device: Device the model and inputs are moved to.
        """
        self.model = vit_base(
            patch_size=patch_size,
            img_size=img_size,
            init_values=1.0,
            num_register_tokens=n_register_tokens,
            block_chunks=0
        )
        self.patch_size = patch_size
        self.smaller_edge = smaller_edge_size
        self.n_register_tokens = n_register_tokens
        self.device = device
        self.transform = transforms.Compose([
            transforms.Resize(size=self.smaller_edge, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # imagenet defaults
        ])
        self.model.load_state_dict(torch.load(checkpoint, map_location=device, weights_only=True))
        # The backbone is only used for inference, so no parameter needs gradients.
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _read_rgb(filepath: str):
        """Read an image file with OpenCV and return it as an RGB numpy array.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file (cv2.imread returned None).
        """
        bgr = cv2.imread(filepath, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f'Could not read image: {filepath}')
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def prepare_image(self, rgb_image_numpy):
        """Turn an RGB numpy image into a normalised tensor cropped to whole patches.

        Returns:
            A tuple ``(image_tensor, grid_size, resize_scale)`` where
            ``image_tensor`` is C x H x W, ``grid_size`` is
            ``(patch_rows, patch_cols)`` and ``resize_scale`` is the ratio of the
            original width to the resized width (before cropping).
        """
        with torch.inference_mode():
            image = Image.fromarray(rgb_image_numpy)
            image_tensor = self.transform(image)
            resize_scale = image.width / image_tensor.shape[2]

        # Crop image to dimensions that are a multiple of the patch size
        # (a few pixels are dropped from the right and bottom edges).
        height, width = image_tensor.shape[1:] # C x H x W
        cropped_height = height - height % self.patch_size
        cropped_width = width - width % self.patch_size
        image_tensor = image_tensor[:, :cropped_height, :cropped_width]
        grid_size = (cropped_height // self.patch_size, cropped_width // self.patch_size)

        return image_tensor, grid_size, resize_scale
    
    def prepare_mask(self, mask_image_numpy, grid_size, resize_scale):
        """Downsample a source-resolution mask to one value per patch.

        The mask is first cropped to the area covered by the patch grid (in
        source pixels), then resized with nearest-neighbour sampling to
        ``grid_size`` and flattened in row-major order.
        """
        covered_rows = int(grid_size[0]*self.patch_size*resize_scale)
        covered_cols = int(grid_size[1]*self.patch_size*resize_scale)
        cropped_mask_image_numpy = mask_image_numpy[:covered_rows, :covered_cols]
        image = Image.fromarray(cropped_mask_image_numpy)
        # PIL sizes are (width, height), hence the swapped grid indices.
        resized_mask = image.resize((grid_size[1], grid_size[0]), resample=Image.Resampling.NEAREST)
        resized_mask = np.asarray(resized_mask).flatten()
        return resized_mask

    def idx_to_source_position(self, idx, grid_size, resize_scale):
        """Map a flat patch index to the (row, col) of the patch centre in source pixels.

        The centre of patch ``(r, c)`` in the resized image is
        ``((r + 0.5) * patch_size, (c + 0.5) * patch_size)``; the whole
        expression, including the half-patch offset, is scaled back to the
        source resolution with ``resize_scale``.
        """
        patch_extent = self.patch_size*resize_scale
        row = (idx // grid_size[1] + 0.5)*patch_extent
        col = (idx % grid_size[1] + 0.5)*patch_extent
        return row, col
  
    def get_embedding_visualization(self, tokens, grid_size, resized_mask=None):
        """Project patch tokens to 3 PCA components and rescale them to [0, 1].

        Args:
            tokens: Array of patch tokens, one row per patch.
            grid_size: ``(patch_rows, patch_cols)`` used to reshape the result.
            resized_mask: Optional per-patch mask used to index ``tokens``
                (assumes a boolean array as produced from a binary mask — TODO confirm);
                patches outside the mask are left at zero before rescaling.

        Returns:
            Array of shape ``(*grid_size, 3)`` usable as an RGB image.
        """
        pca = PCA(n_components=3)
        if resized_mask is not None:
            tokens = tokens[resized_mask]
        reduced_tokens = pca.fit_transform(tokens.astype(np.float32))
        if resized_mask is not None:
            tmp_tokens = np.zeros((*resized_mask.shape, 3), dtype=reduced_tokens.dtype)
            tmp_tokens[resized_mask] = reduced_tokens
            reduced_tokens = tmp_tokens
        reduced_tokens = reduced_tokens.reshape((*grid_size, -1))
        lowest = np.min(reduced_tokens)
        span = np.max(reduced_tokens) - lowest
        if span == 0:
            # Constant input: avoid dividing by zero and return a uniform image.
            return np.zeros_like(reduced_tokens)
        return (reduced_tokens - lowest) / span

    def extract_features(self, image_numpy, pooling: bool = True):
        """Run the backbone on an RGB numpy image and return its features.

        Args:
            image_numpy: RGB image as a numpy array.
            pooling: When True, return the mean over patch tokens as a torch
                tensor on the model device; when False, return all patch tokens
                as a numpy array with one row per patch.
        """
        with torch.inference_mode():
            image_tensor = self.prepare_image(image_numpy)[0]
            image_tensor = image_tensor.unsqueeze(0).to(self.device)

            # Only drop the batch axis; a bare squeeze() could also drop a
            # token axis of length one.
            tokens = self.model.get_intermediate_layers(image_tensor)[0].squeeze(0)
            del image_tensor
            torch.cuda.empty_cache()

            if not pooling:
                return tokens.cpu().numpy()

            pooled_features = tokens.mean(dim=0)
            del tokens
            torch.cuda.empty_cache()

            return pooled_features

    def calculate_similarity(self, image1: str, image2: str):
        """Return the similarity of two image files in the range [0, 1].

        The pooled features of both images are compared with cosine similarity,
        which is then mapped from [-1, 1] to [0, 1].

        Raises:
            FileNotFoundError: If either image cannot be read.
        """
        with torch.inference_mode():
            features1 = self.extract_features(self._read_rgb(image1))
            features2 = self.extract_features(self._read_rgb(image2))

            similarity = F.cosine_similarity(features1, features2, dim=0)
            del features1, features2
            torch.cuda.empty_cache()

            return (similarity.item() + 1) / 2

    def create_attention_mask(self, image_metric, save: bool = False, show: bool = False):
        """Render a 2-D attention array as a red RGBA PIL image.

        The array is min-max normalised, coloured with the ``Reds`` colormap and
        given an alpha channel equal to the square root of the normalised value,
        so that weakly attended regions remain visible.

        Args:
            image_metric: 2-D numpy array of attention values.
            save: When True, write the mask to ``attention_mask.png``.
            show: When True, display the mask in the notebook.
        """
        normalized_metric = Normalize(vmin=image_metric.min(), vmax=image_metric.max())(image_metric)

        # Apply the Reds colormap
        reds = plt.cm.Reds(normalized_metric)

        alpha_max_value = 1.00  # upper bound of the alpha channel
        gamma = 0.5  # gamma < 1 lifts low values so they stay visible

        # Alpha channel with enhanced visibility for lower values
        alpha_channel = np.power(normalized_metric, gamma) * alpha_max_value

        # Combine the colormap RGB with the computed alpha channel
        rgba_mask = np.zeros((normalized_metric.shape[0], normalized_metric.shape[1], 4))
        rgba_mask[..., :3] = reds[..., :3]  # RGB
        rgba_mask[..., 3] = alpha_channel  # Alpha

        # Convert the numpy array to PIL Image
        rgba_image = Image.fromarray((rgba_mask * 255).astype(np.uint8))

        if save:
            rgba_image.save('attention_mask.png')
        if show:
            display(rgba_image)

        return rgba_image

    def create_attention_photo(self, og_image: Image.Image, attention_mask_image, save: bool = False, show: bool = False):
        """Overlay an RGBA attention mask on a photo and return the composite.

        The input photo is never modified: ``convert`` returns a new image even
        when the photo is already RGBA. The mask is pasted at the top-left
        corner using its own alpha channel, so it is expected to have the same
        size as the photo.

        Args:
            og_image: The original photo.
            attention_mask_image: RGBA mask as produced by ``create_attention_mask``.
            save: When True, write the composite to ``image_with_attention.png``.
            show: When True, display the composite in the notebook.
        """
        composite = og_image.convert('RGBA')
        composite.paste(attention_mask_image, (0, 0), attention_mask_image)

        if save:
            composite.save('image_with_attention.png')
        if show:
            display(composite)

        return composite

    def return_attention_map(self, filepath: str, show: bool = False, mask_only: bool = False):
        """Compute the attention map of an image file.

        Args:
            filepath: Path of the image to analyse.
            show: When True, display the original image, the mask and the overlay.
            mask_only: When True, return only the RGBA mask.

        Returns:
            The RGBA mask when ``mask_only`` is True, otherwise the tuple
            ``(mask, photo_with_overlay)``.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        with torch.inference_mode():
            # Load the photo fully and close the file handle right away.
            with Image.open(filepath) as opened:
                og_image = opened.copy()
            (original_w, original_h) = og_image.size

            if show:
                display(og_image)

            # prepare_image already crops to whole patches and reports the grid.
            img, (grid_rows, grid_cols), _ = self.prepare_image(self._read_rgb(filepath))
            img = img.unsqueeze(0).to(self.device)
            # NOTE(review): get_last_self_attention is presumably provided by the
            # modified DINOv2 fork mentioned in the notebook disclaimer — confirm.
            attention = self.model.get_last_self_attention(img)
            del img
            torch.cuda.empty_cache()
            
            number_of_heads = attention.shape[1]

            # Attention from the first token to the spatial tokens, for every head;
            # assumes the token order [first token, register tokens, patches].
            attention = attention[0, :, 0, 1 + self.n_register_tokens:]

            # One attention value per patch, laid out on the patch grid
            attention = attention.reshape(number_of_heads, grid_rows, grid_cols)
            
            # upscale to higher resolution closer to original image
            attention = nn.functional.interpolate(attention.unsqueeze(0), scale_factor=self.patch_size, mode = "nearest")[0].cpu()

            # sum the attention of all heads to get one map for the entire image
            attention = torch.sum(attention, dim=0)

            # interpolate attention map back into original image dimensions
            attention = nn.functional.interpolate(attention.unsqueeze(0).unsqueeze(0), size=(original_h, original_w), mode='bilinear', align_corners=False)

            # Drop exactly the batch and channel axes to keep a 2-D map.
            image_metric = attention[0, 0].numpy()
            del attention

            attention_mask = self.create_attention_mask(image_metric, show=show)

            if mask_only:
                return attention_mask
            
            photo_with_attention = self.create_attention_photo(og_image, attention_mask, show=show)

            return attention_mask, photo_with_attention
    

### Small demo of the attention maps

In [None]:
# Example image used for the attention-map demo.
file = 'data/ai/1B_15_4.png'
with torch.no_grad():
    # Build the wrapper with its default checkpoint and settings; the global
    # `dino` instance is reused by the cells further down.
    dino = DinoV2()
    # With the default mask_only=False the call returns (mask, photo with overlay).
    attention_mask, attention_map = dino.return_attention_map(file)
    display(attention_map)

In [None]:
def _select_patch_matches(distances, matches, similarity_threshold, max_lines):
    """Pick the closest patch matches and rescale their distances to [0, 1].

    Args:
        distances: 2-D array, ``distances[j, k]`` is the distance from patch ``j``
            of the second image to its ``k``-th neighbour in the first image.
        matches: 2-D array of the same shape holding the neighbour's patch index.
        similarity_threshold: Matches are kept when their globally min-max
            normalised distance is below this value.
        max_lines: Maximum number of matches returned.

    Returns:
        A list of ``(normalised_distance, idx1, idx2)`` tuples sorted by
        ascending distance; ``idx1`` indexes the first image, ``idx2`` the second.
    """
    distances = np.asarray(distances, dtype=float)
    matches = np.asarray(matches)

    # Normalise all distances to [0, 1]; when every distance is identical the
    # matches are equally good, so treat them all as distance 0 instead of
    # producing NaN through a division by zero.
    span = distances.max() - distances.min()
    if span > 0:
        scaled = (distances - distances.min()) / span
    else:
        scaled = np.zeros_like(distances)

    # np.nonzero walks the array row by row, i.e. in (idx2, neighbour) order.
    idx2_hits, neighbour_hits = np.nonzero(scaled < similarity_threshold)
    if idx2_hits.size == 0:
        return []

    # Re-normalise only the kept distances so colours/widths use the full range.
    kept = scaled[idx2_hits, neighbour_hits]
    kept = (kept - kept.min()) / (kept.max() - kept.min() + 1e-8)

    # Stable sort keeps the original order among equally distant matches.
    order = np.argsort(kept, kind="stable")[:max_lines]
    return [
        (float(kept[k]), int(matches[idx2_hits[k], neighbour_hits[k]]), int(idx2_hits[k]))
        for k in order
    ]


def draw_lines(model, image1, image2, origin, similarity_threshold=0.1, max_lines=20, n_neighbors=1):
    """Plot two images side by side and connect their most similar patches.

    Patch tokens of both images are extracted with ``model``; for every patch of
    ``image2`` its nearest neighbours among the patches of ``image1`` are
    looked up, and the closest matches are drawn as lines between the images.

    Args:
        model: Object providing ``prepare_image``, ``extract_features`` and
            ``idx_to_source_position`` (e.g. a ``DinoV2`` instance).
        image1: Path of the first image (shown on the left as "Final submission").
        image2: Path of the second image (shown on the right).
        origin: Label of the second image's source, used in its title.
        similarity_threshold: Upper bound on the normalised distance of a drawn match.
        max_lines: Maximum number of lines drawn.
        n_neighbors: Number of neighbours looked up per patch of ``image2``.

    Raises:
        FileNotFoundError: If either image cannot be read.
    """
    loaded = []
    for path in (image1, image2):
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        if bgr is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        loaded.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    img1, img2 = loaded

    # Only the grid size and scale are needed here; the tensor is recomputed
    # inside extract_features.
    _, grid_size1, resize_scale1 = model.prepare_image(img1)
    _, grid_size2, resize_scale2 = model.prepare_image(img2)

    features1 = model.extract_features(img1, pooling=False)
    features2 = model.extract_features(img2, pooling=False)

    # For each patch of image 2, find its nearest patches in image 1
    knn = NearestNeighbors(n_neighbors=n_neighbors)
    knn.fit(features1)
    distances, matches = knn.kneighbors(features2)

    selected_matches = _select_patch_matches(distances, matches, similarity_threshold, max_lines)

    fig = plt.figure(figsize=(20, 10))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.imshow(img1)
    ax1.axis("off")
    ax1.set_title("Final submission")
    ax2.imshow(img2)
    ax2.axis("off")
    ax2.set_title(f"Closest inspiration - {origin}")

    for dist, idx1, idx2 in selected_matches:
        # The values are distances (0 = best match), so invert them to get a
        # similarity in [0, 1]; the square root spreads out the small distances.
        similarity = 1 - np.sqrt(dist)

        row, col = model.idx_to_source_position(idx1, grid_size1, resize_scale1)
        point_in_img1 = (col, row)  # matplotlib data coordinates are (x, y)

        row, col = model.idx_to_source_position(idx2, grid_size2, resize_scale2)
        point_in_img2 = (col, row)

        color = cm.plasma(similarity)  # Colormap based on similarity
        linewidth = 1 + similarity  # Thicker lines for higher similarity

        con = ConnectionPatch(xyA=point_in_img2, xyB=point_in_img1, coordsA="data", coordsB="data",
                              axesA=ax2, axesB=ax1, color=color, linewidth=linewidth)
        ax2.add_artist(con)

    plt.show()

def draw_attention(model, pic1, pic2):
    """Show the photo, its attention mask and the masked photo for two images.

    For each of the two image paths one figure with three panels is drawn:
    the image as read from disk, the mask returned by
    ``model.return_attention_map`` and the photo with the mask overlaid.

    Args:
        model: Object providing ``return_attention_map(path, show=False)``
            that returns ``(mask, photo_with_overlay)``.
        pic1: Path of the first image.
        pic2: Path of the second image.
    """
    panel_titles = ("Photo", "Attention Mask", "Masked Photo")

    for path in (pic1, pic2):
        mask, overlay = model.return_attention_map(path, show=False)

        figure = plt.figure(figsize=(20, 10))
        panels = [figure.add_subplot(1, 3, slot) for slot in (1, 2, 3)]

        # OpenCV reads BGR; convert so matplotlib shows the true colours.
        photo = cv2.cvtColor(cv2.imread(path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

        for panel, title, content in zip(panels, panel_titles, (photo, mask, overlay)):
            panel.imshow(content)
            panel.set_title(title)
            panel.axis("off")

        plt.tight_layout()
        plt.show()

## Let's see if the similarity scores make any sense.

In [None]:
# Positional index (into final_data) of the group inspected in this cell.
group = 12

with torch.inference_mode():
    torch.cuda.empty_cache()

    group_row = final_data.iloc[group]
    ai = group_row.ai_images
    web = group_row.web_images
    final = group_row.final_submissions
    # Every similarity actually computed, per inspiration source.
    ai_scores, web_scores = [], []
    max_similarity, pic1, pic2 = -1, "", ""
    inspiration = "INCONCLUSIVE"

    print(f'Calculating similarity for group {group_row.group_code}')

    for final_photo in final:
        for origin, photos, scores in (("AI", ai, ai_scores), ("WEB", web, web_scores)):
            torch.cuda.empty_cache()

            print(f"{origin} PHOTOS")
            for photo in photos:
                similarity = dino.calculate_similarity(final_photo, photo)
                print(f"Similarity: {similarity}")
                scores.append(similarity)

                draw_attention(dino, final_photo, photo)
                draw_lines(dino, final_photo, photo, origin)

                if similarity > max_similarity:
                    max_similarity = similarity
                    pic1, pic2 = final_photo, photo
                    # Only name a source when the best match is above the midpoint.
                    if similarity > 0.5:
                        inspiration = origin

        torch.cuda.empty_cache()
        # This cell draws every comparison, so only the first final photo is inspected.
        break

    # Average over the comparisons that were actually made (only one final photo
    # is processed above); a source without images yields NaN instead of an error.
    nan = float('nan')
    ai_total = sum(ai_scores) / len(ai_scores) if ai_scores else nan
    web_total = sum(web_scores) / len(web_scores) if web_scores else nan
    max_ai, min_ai = max(ai_scores, default=nan), min(ai_scores, default=nan)
    max_web, min_web = max(web_scores, default=nan), min(web_scores, default=nan)

    print(f'\tSimilarity scores - AI: {ai_total:.3f}\tWEB: {web_total:.3f}')
    print(f'\tAI similarity - MAX: {max_ai} | MIN: {min_ai}')
    print(f'\tWEB similarity - MAX: {max_web} | MIN: {min_web}')
    print(f'\tAccording to DINO, this group was mostly inspired by {inspiration}.')

    # Show the best-matching pair once more, if any comparison was made.
    if pic1 and pic2:
        draw_attention(dino, pic1, pic2)
        draw_lines(dino, pic1, pic2, inspiration)

## Let's see it in action!

In [None]:
# One dict per (final photo, inspiration photo) comparison; turned into a
# DataFrame once at the end instead of re-concatenating inside the loop.
result_rows = []

with torch.inference_mode():
    for _, group_row in final_data.iterrows():
        torch.cuda.empty_cache()

        ai = group_row.ai_images
        web = group_row.web_images
        final = group_row.final_submissions

        totals = {"AI": 0, "WEB": 0}
        max_similarity, pic1, pic2 = -1, "", ""
        # Least similar pair of the group (tracked but not used in this cell).
        min_similarity, pic3, pic4 = float('inf'), "", ""
        inspiration = "INCONCLUSIVE"

        print(f'Calculating similarity for group {group_row.group_code}')

        for final_photo in final:
            for origin, photos in (("AI", ai), ("WEB", web)):
                torch.cuda.empty_cache()

                for photo in photos:
                    similarity = dino.calculate_similarity(final_photo, photo)
                    result_rows.append({'final_photo': final_photo, 'inspiration': photo, 'similarity': similarity})

                    totals[origin] += similarity

                    if similarity > max_similarity:
                        max_similarity = similarity
                        pic1, pic2 = final_photo, photo
                        # Only name a source when the best match is above the midpoint.
                        if similarity > 0.5:
                            inspiration = origin
                    if similarity < min_similarity:
                        min_similarity = similarity
                        pic3, pic4 = final_photo, photo

        torch.cuda.empty_cache()

        # Mean similarity per source; groups with no images for a source (or no
        # final photos) yield NaN instead of raising ZeroDivisionError.
        ai_pairs = len(final) * len(ai)
        web_pairs = len(final) * len(web)
        ai_total = totals["AI"] / ai_pairs if ai_pairs else float('nan')
        web_total = totals["WEB"] / web_pairs if web_pairs else float('nan')

        print(f'\tSimilarity scores - AI: {ai_total:.3f}\tWEB: {web_total:.3f}')
        print(f'\tAccording to DINO, this group was mostly inspired by {inspiration}.')

        # Visualise the best-matching pair of the group, if any comparison was made.
        if pic1 and pic2:
            draw_attention(dino, pic1, pic2)
            draw_lines(dino, pic1, pic2, inspiration)

similarity_results = pd.DataFrame(result_rows, columns=['final_photo', 'inspiration', 'similarity'])
similarity_results.to_csv('similarity_results.csv', index=False)