Code references: https://huggingface.co/docs/transformers/model_doc/dinov2 and
https://github.com/vra/dinov2-retrieval

In [2]:
import torch
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import faiss
import numpy as np
import os

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device("cpu")
print(f'Current Device: {device}')

# DINOv2-small checkpoint: the image processor prepares the model inputs, and
# the model itself is moved onto the selected device.
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
model = AutoModel.from_pretrained('facebook/dinov2-small')
model = model.to(device)

# Root folder that is scanned for images to index.
data_folder = '../test_data'

def add_vector_to_index(embedding, index):
    """Append L2-normalised embedding row(s) to a FAISS index.

    Args:
        embedding: torch tensor with one embedding per row; a 1-D tensor is
            treated as a single row.
        index: FAISS index that receives the vectors through ``index.add``.
    """
    vector = embedding.detach().cpu().numpy()
    # FAISS expects a C-contiguous float32 matrix. np.array() always copies, so
    # the in-place normalisation below can never write through to the caller's
    # tensor (a CPU tensor shares its memory with the .numpy() view).
    vector = np.array(vector, dtype=np.float32, order='C')
    if vector.ndim == 1:
        vector = vector.reshape(1, -1)
    # Unit-length vectors make L2-distance rankings agree with cosine similarity.
    faiss.normalize_L2(vector)
    index.add(vector)

# Collect every JPEG/PNG path under data_folder into `images`.
# Directories and files are visited in sorted order because os.walk's own
# order depends on the filesystem; a stable order keeps position i in `images`
# pointing at the same file on every run, which matters when the list is used
# to look up ids returned by a saved index.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
images = []
for root, dirs, files in os.walk(data_folder):
    dirs.sort()  # sorting in place steers the order os.walk descends in
    for file in sorted(files):
        # Case-insensitive match so that e.g. "photo.JPG" is not skipped.
        if file.lower().endswith(IMAGE_EXTENSIONS):
            images.append(os.path.join(root, file))

# Flat L2 index sized from the loaded model's embedding width rather than a
# hard-coded 384 (the feature dimension of DINOv2 ViT-S/14), so the index
# still matches if a different checkpoint is loaded.
index = faiss.IndexFlatL2(model.config.hidden_size)

# Embed every image and add it to the index; vectors are added in the order of
# `images`, so index id i corresponds to images[i].
for img_path in images:
    # Open inside a context manager so the file handle is released right away;
    # convert() returns a loaded copy that stays valid after the file is closed.
    with Image.open(img_path) as img_file:
        img = img_file.convert('RGB')
    with torch.no_grad():
        inputs = processor(images=img, return_tensors='pt').to(device)
        outputs = model(**inputs)
    # Mean-pool last_hidden_state over dim 1 (the token axis) to get a single
    # vector per image.
    features = outputs.last_hidden_state
    add_vector_to_index(features.mean(dim=1), index)

# Persist the populated index so it can be reloaded for searching.
faiss.write_index(index, 'database.index')

  from .autonotebook import tqdm as notebook_tqdm


Current Device: cpu


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


使用 dinov2 对单张图片进行检索

In [5]:
# Query image. Convert to RGB just like the images that were indexed, so a PNG
# with an alpha channel or a palette is preprocessed the same way; the context
# manager closes the file handle once the converted copy exists.
with Image.open('../test_data/aerial/nardo-air_qu-42.png') as query_file:
    image = query_file.convert('RGB')
# NOTE(review): query_idx is not used by the search below; kept in case a
# later cell expects it.
query_idx = faiss.IndexFlatL2(384)

# Extract the features
with torch.no_grad():
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)
# Mean-pool last_hidden_state over dim 1 to get one embedding for the query.
embeddings = outputs.last_hidden_state
embeddings = embeddings.mean(dim=1)

# Normalize the features before search (float32, unit length), matching how
# the database vectors were prepared.
vector = embeddings.detach().cpu().numpy()
vector = np.float32(vector)
faiss.normalize_L2(vector)

# Read the index file and search for the 3 nearest images
index = faiss.read_index("database.index")
dist, idx = index.search(vector, 3)
print('distances:', dist, 'indexes:', idx)

# images[idx[0][k]] is the path of the k-th retrieved image, for k = 0, 1, 2
# (valid as long as `images` is in the same order as when the index was built).

distances: [[0.         0.46195972 0.5801141 ]] indexes: [[3 1 2]]


visualization code