In [None]:
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline



# Similarity search on Satellite Images (DINOv2)

## Install and import

We start by installing FAISS (Facebook AI Similarity Search), a library for efficient similarity search and clustering of dense vectors.
It contains algorithms that search sets of vectors of any size, including sets that may not fit in RAM.


In [None]:
!pip install faiss-gpu

In [2]:
import torch
from transformers import AutoProcessor, AutoImageProcessor, AutoModel
from PIL import Image
from google.colab import drive
import os
import json
import faiss
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

## Set-up folders

Here, we mount Google Drive and set up base paths for our images and data.


In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive;
# DRIVE_BASE_PATH below points inside this mount.
drive.mount('/content/drive')

In [4]:
# Base directory of "My Drive" as exposed by the Colab Drive mount
DRIVE_BASE_PATH = "/content/drive/MyDrive/"

In [5]:
# Folder holding the dataset images that will be embedded and indexed
IMAGES = os.path.join(DRIVE_BASE_PATH, "dino", "input_images")

In [20]:
# Query ("home") image whose nearest neighbours are searched for — point this at your own file
HOME_IMAGE = os.path.join(DRIVE_BASE_PATH, "dino", "home.jpg")

## Functions

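Functions used in the pipeline:

*   `extract_features` extracts image features with the DINOv2 model by averaging the token embeddings of its last hidden layer;
*   `normalizeL2` scales each embedding to unit length (L2 normalization), so that L2 distances between embeddings directly reflect their cosine similarity.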

In [7]:
def extract_features(image):
    """Embed an image with DINOv2 by mean-pooling its last hidden layer.

    The input is preprocessed by the global ``processor``, moved to the global
    ``device`` and passed through the global ``model`` with autograd disabled.

    Args:
        image: Anything ``processor(images=...)`` accepts — in this notebook a PIL image.

    Returns:
        torch.Tensor with one pooled feature vector per input image, still on ``device``.
    """
    with torch.no_grad():
        batch = processor(images=image, return_tensors="pt").to(device)
        token_embeddings = model(**batch).last_hidden_state
        # Average over the token axis (dim 1); for DINOv2 this axis holds the
        # CLS token together with the patch tokens.
        return token_embeddings.mean(dim=1)

In [8]:
def normalizeL2(embeddings):
    """Turn an embedding tensor into a float32 NumPy array with unit-length rows.

    ``faiss.normalize_L2`` normalizes a float32 NumPy array in place, so the tensor
    is first detached from autograd, copied to host memory and cast to float32.

    Args:
        embeddings: torch.Tensor of feature vectors, e.g. the output of ``extract_features``.

    Returns:
        np.ndarray of dtype float32 whose rows have unit L2 norm.
    """
    host_array = np.float32(embeddings.detach().cpu().numpy())
    faiss.normalize_L2(host_array)  # in-place, row-wise
    return host_array

## Load model and processor

Let's use the 'base' DINOv2 model. It has 86 million parameters and a size of 331MB, and the features it extracts from an image have a dimensionality of 768.

In [None]:
# Hugging Face checkpoint to use. Switching to e.g. 'facebook/dinov2-small' or
# 'facebook/dinov2-large' only requires changing this one name; the resulting
# feature dimensionality is available afterwards as model.config.hidden_size.
MODEL_NAME = 'facebook/dinov2-base'

# Run on the GPU when the runtime provides one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
# Inference only: from_pretrained already returns the model in eval mode;
# the explicit call documents the intent.
model.eval()

## Load data

In [12]:
# Collect the dataset image paths.
# The FAISS index only remembers vectors by insertion position, so images[k] must
# be the file that produced vector k -- also when a saved index is reloaded in a
# later session. os.listdir order is arbitrary, hence the sort.
# Sub-directories and files Pillow does not recognise by extension (e.g. hidden
# system files) are skipped, because opening them would crash the indexing loop.
_PIL_EXTENSIONS = set(Image.registered_extensions())
images = sorted(
    os.path.join(IMAGES, file_name)
    for file_name in os.listdir(IMAGES)
    if os.path.splitext(file_name)[1].lower() in _PIL_EXTENSIONS
    and os.path.isfile(os.path.join(IMAGES, file_name))
)
assert images, f"No images found in {IMAGES}"

## Create FAISS-index

Here, we create and populate a FAISS index for efficient similarity search.

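We choose `IndexFlatL2`, which performs an exact (brute-force) search using the squared L2 (Euclidean) distance between our query vector and every vector loaded into the index (the features extracted from the image dataset). Because all embeddings are L2-normalized, a smaller distance means a higher cosine similarity.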

In [10]:
def open_and_convert_image(image_path):
    """Load an image file as a 3-channel RGB PIL image.

    Grayscale, palette, RGBA, CMYK, ... inputs are converted so that every image
    reaches the processor in the same mode.

    Args:
        image_path: Path to an image file readable by Pillow.

    Returns:
        PIL.Image.Image in mode 'RGB', with its pixel data loaded into memory.
    """
    # `convert` always returns a new, fully loaded image (a plain copy when the
    # file is already RGB), so the underlying file can be closed immediately
    # instead of staying open until garbage collection.
    with Image.open(image_path) as img:
        return img.convert('RGB')

In [None]:
# Index dimension taken from the loaded model (768 for DINOv2 base) instead of
# being hard-coded, so swapping the checkpoint does not break the index.
index = faiss.IndexFlatL2(model.config.hidden_size)

# Iterate over the dataset to extract features and store them in the index.
# Vectors are added in the order of `images`, so vector k corresponds to images[k].
for image_path in tqdm(images):
    img = open_and_convert_image(image_path)
    vector = extract_features(img)
    normalized_vector = normalizeL2(vector)
    index.add(normalized_vector)

# Every image must contribute exactly one vector; otherwise search results
# could not be mapped back to file paths.
assert index.ntotal == len(images), (index.ntotal, len(images))

# Persist the index so the search cell can reload it.
faiss.write_index(index, "dino.index")

## Similarity search

We perform a similarity search with a reference ('home') image and retrieve the 5 most similar images from our dataset.

In [21]:
# Number of nearest neighbours to retrieve.
TOP_K = 5

# Load the query ('home') image exactly like the dataset images (forced to RGB),
# so query and index embeddings are computed from identically prepared inputs.
image = open_and_convert_image(HOME_IMAGE)

# Extract and L2-normalize the home image features, as was done for the index.
image_features = extract_features(image)
image_features = normalizeL2(image_features)

# Reload the index persisted by the indexing cell.
index = faiss.read_index("dino.index")

# d: squared L2 distances (ascending), i: row positions in the index, which map
# back to `images`. Both have one row per query and TOP_K columns.
d, i = index.search(image_features, TOP_K)

## Display images

In [None]:
# Map result positions back to file paths. FAISS pads with position -1 when the
# index holds fewer vectors than requested; skip those instead of letting
# images[-1] silently show an unrelated file.
hits = [(images[idx], distance) for idx, distance in zip(i[0], d[0]) if idx >= 0]
retrieved_images = [path for path, _ in hits]

# 2 x 3 grid: the home image plus up to 5 results (extra results are not shown).
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Display home image
axes[0].imshow(mpimg.imread(HOME_IMAGE))
axes[0].set_title('Home Image')
axes[0].axis('off')

# Display similar images. IndexFlatL2 returns *squared* L2 distances, and the
# vectors are unit length, so ||a - b||^2 = 2 - 2 * cos(a, b) and the cosine
# similarity is 1 - distance / 2 (1.0 = same direction).
for ax, (img_path, distance) in zip(axes[1:], hits):
    ax.imshow(mpimg.imread(img_path))
    ax.set_title(f'Cosine similarity: {1 - distance / 2:.2f}')
    ax.axis('off')

# Hide panels left empty when fewer results than panels were returned.
for ax in axes[1 + len(hits):]:
    ax.axis('off')

plt.tight_layout()
plt.show()