# Поиск похожих изображений по картинке

In [69]:
import cv2
import torch
import torchvision.models as models
from transformers import AutoImageProcessor, AutoModel

# Device used for the DINO model and its inputs: the GPU when available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Labels offered in the Gradio dropdown; find_similar_images dispatches on these exact strings.
dino_option = "DINO MODEL EMBEDDINGS"
vgg_option = "VGG-16 MODEL EMBEDDINGS"
hog_option = "HISTOGRAM OF GRADIENTS EMBEDDINGS"
color_hist_option = "COLOR HISTOGRAM EMBEDDINGS"
sift_option = "SIFT EMBEDDINGS"

In [70]:
# DINOv2 (small) backbone plus its matching image preprocessor, fetched by checkpoint
# name with `from_pretrained`; both are used by dino_embedding().
_DINO_CHECKPOINT = "facebook/dinov2-small"
dino_processor = AutoImageProcessor.from_pretrained(_DINO_CHECKPOINT)
dino_model = AutoModel.from_pretrained(_DINO_CHECKPOINT).to(device)

In [71]:
# Pretrained VGG-16 whose `classifier` is cut down to its first fully connected layer,
# so the network outputs that layer's activations (a feature vector) instead of class logits.
vgg_model = models.vgg16(pretrained=True)
vgg_model.classifier = vgg_model.classifier[0]
# The model is only used for inference: put it in eval mode and freeze the weights so
# forward passes never record autograd state.
vgg_model.eval()
vgg_model.requires_grad_(False)

In [75]:
import os
import numpy as np
import pandas as pd
import tqdm as tqdm


def calculate_embeddings(calculate_embedding_method, output_file_name, images_path="images"):
    """Embed every readable image in ``images_path`` and pickle the results.

    Parameters
    ----------
    calculate_embedding_method : callable
        Receives one image as loaded by ``cv2.imread`` (BGR) and returns its embedding.
    output_file_name : str
        Destination of ``DataFrame.to_pickle``. Its parent directory is created if missing.
    images_path : str
        Directory whose entries are embedded, in sorted file-name order. Entries that
        OpenCV cannot decode (non-images, subdirectories) are skipped with a warning.

    The pickled DataFrame has two columns: ``img_path`` (the entry joined onto
    ``images_path``) and ``vector`` (an object column holding each embedding as returned
    by ``calculate_embedding_method``). Its index is a contiguous 0..n-1 range.
    """
    import warnings

    # Sorted so the index order is reproducible across runs and platforms.
    file_names = sorted(os.listdir(images_path))

    img_paths = []
    vectors = []
    skipped = []
    # Iterating the tqdm wrapper closes the bar automatically once the loop finishes.
    for file_name in tqdm.tqdm(file_names):
        img_path = os.path.join(images_path, file_name)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None instead of raising for files it cannot decode.
            skipped.append(img_path)
            continue
        img_paths.append(img_path)
        vectors.append(calculate_embedding_method(image))

    if skipped:
        warnings.warn(
            f"Skipped {len(skipped)} unreadable file(s) in {images_path!r}, "
            f"e.g. {skipped[0]!r}"
        )

    # Build the frame once from lists (row-by-row `df.loc` appends are quadratic).
    df = pd.DataFrame({"img_path": img_paths, "vector": vectors})

    output_dir = os.path.dirname(output_file_name)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df.to_pickle(output_file_name)

        
def sift_descriptors(img):
    """Return the SIFT descriptors of a BGR image after resizing it to 512x512.

    The value is the descriptor output of ``SIFT.detectAndCompute``: one row per detected
    keypoint. The keypoints themselves are discarded. When no keypoint is found, OpenCV's
    Python binding may return None here.
    """
    gray = cv2.cvtColor(cv2.resize(img, (512, 512)), cv2.COLOR_BGR2GRAY)
    detector = cv2.SIFT_create()
    keypoints_and_descriptors = detector.detectAndCompute(gray, None)
    return keypoints_and_descriptors[1]


def color_histogram(img):
    """Return a 768-value color descriptor of a BGR image.

    The image is resized to 512x512 and converted to RGB. Then a 256-bin histogram of
    each channel (R, G, B, in that order) is computed and L2-normalized on its own, and
    the three histograms are concatenated.

    The bins are fixed to the full 8-bit range [0, 256), so bin ``k`` always counts
    intensity ``k``. Without an explicit range, ``np.histogram`` would span each image's
    own min..max, and histograms of different images would not be comparable.
    """
    img = cv2.resize(img, (512, 512))
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    channel_histograms = []
    for channel in range(3):  # 0 = red, 1 = green, 2 = blue after the RGB conversion
        counts, _ = np.histogram(rgb[:, :, channel], bins=256, range=(0, 256))
        # counts sum to 512*512 > 0, so the norm is never zero.
        channel_histograms.append(counts / np.linalg.norm(counts))
    return np.concatenate(channel_histograms)


def hog(img):
    """Return a global histogram of oriented gradients for a BGR image.

    The image is resized to 512x512 and converted to grayscale. Gradients come from 5x5
    Sobel filters. Each pixel's orientation (radians in [0, 2*pi)) is accumulated into
    256 bins, weighted by its gradient magnitude. There is no cell/block structure: this
    is one histogram for the whole image.

    Returns the L2-normalized histogram. A uniform image has no gradients, so it yields
    an all-zero vector rather than NaNs from dividing by a zero norm.
    """
    img = cv2.resize(img, (512, 512))
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=5)
    grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=5)
    magnitude, angle = cv2.cartToPolar(grad_x, grad_y)
    hist, _ = np.histogram(
        angle.ravel(), bins=256, range=(0, 2 * np.pi), weights=magnitude.ravel()
    )

    norm = np.linalg.norm(hist)
    if norm == 0:
        return hist  # all zeros; dividing would produce NaN and poison distance ranking
    return hist / norm


def dino_embedding(img):
    """Return an L2-normalized DINOv2 embedding (float32) of a BGR image.

    The image is resized to 512x512 and converted to RGB before going through
    ``dino_processor``. The embedding is ``last_hidden_state`` averaged over dim 1 (the
    token axis of the model output), with size-1 dimensions squeezed out.
    """
    rgb = cv2.cvtColor(cv2.resize(img, (512, 512)), cv2.COLOR_BGR2RGB)

    # Inference only: no autograd graph is recorded inside this context.
    with torch.no_grad():
        batch = dino_processor(images=rgb, return_tensors="pt").to(device)
        hidden_states = dino_model(**batch).last_hidden_state

    pooled = hidden_states.mean(dim=1).squeeze().cpu().numpy()
    return np.float32(pooled) / np.linalg.norm(pooled)


def vgg_16_embedding(img):
    """Return an L2-normalized VGG-16 feature vector (float32) of a BGR image.

    Preprocessing follows torchvision's ImageNet VGG weights:
    - resize to 224x224 (the original 214 was a typo for the standard input size);
    - convert to RGB and scale pixels to [0, 1];
    - normalize with the ImageNet mean/std.

    The tensor is then fed through ``vgg_model`` on whatever device the model lives on.

    NOTE: this preprocessing differs from earlier versions, so stored VGG vectors must be
    recomputed with calculate_embeddings for queries and index to match.
    """
    img = cv2.resize(img, (224, 224))
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    model_device = next(vgg_model.parameters()).device
    pixels = torch.from_numpy(rgb).to(model_device, dtype=torch.float32) / 255.0
    mean = torch.tensor([0.485, 0.456, 0.406], device=model_device)
    std = torch.tensor([0.229, 0.224, 0.225], device=model_device)
    pixels = (pixels - mean) / std  # broadcasts over the trailing channel axis of (H, W, 3)
    batch = pixels.permute(2, 0, 1).unsqueeze(0)  # (1, 3, H, W) as the model expects

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        features = vgg_model(batch)

    embedding = features.squeeze().cpu().numpy().astype(np.float32)
    return embedding / np.linalg.norm(embedding)

In [76]:
# RUN ONLY 1 TIME: builds the embedding index for every method. Each call embeds every
# file in the default "images" directory and pickles the result under "vectors/",
# which is where find_similar_images reads it back from.
_embedding_jobs = (
    (sift_descriptors, "vectors/sift_vectors.pkl"),
    (color_histogram, "vectors/color_histogram_vectors.pkl"),
    (hog, "vectors/hog_vectors.pkl"),
    (dino_embedding, "vectors/dino_vectors.pkl"),
    (vgg_16_embedding, "vectors/vgg_16_vectors.pkl"),
)
for _embed, _vectors_file in _embedding_jobs:
    calculate_embeddings(_embed, _vectors_file)



  0%|          | 0/4738 [00:00<?, ?it/s][A[A

  0%|          | 1/4738 [00:00<14:04,  5.61it/s][A[A

  0%|          | 2/4738 [00:00<13:27,  5.87it/s][A[A

  0%|          | 3/4738 [00:00<13:11,  5.99it/s][A[A

  0%|          | 4/4738 [00:00<13:03,  6.04it/s][A[A

  0%|          | 5/4738 [00:00<13:06,  6.02it/s][A[A

  0%|          | 6/4738 [00:01<13:10,  5.99it/s][A[A

  0%|          | 7/4738 [00:01<13:08,  6.00it/s][A[A

  0%|          | 8/4738 [00:01<13:07,  6.01it/s][A[A

  0%|          | 9/4738 [00:01<13:05,  6.02it/s][A[A

  0%|          | 10/4738 [00:01<13:08,  5.99it/s][A[A

  0%|          | 11/4738 [00:01<13:02,  6.04it/s][A[A

  0%|          | 12/4738 [00:02<13:04,  6.02it/s][A[A

  0%|          | 13/4738 [00:02<13:07,  6.00it/s][A[A

  0%|          | 14/4738 [00:02<13:13,  5.95it/s][A[A

  0%|          | 15/4738 [00:02<13:13,  5.95it/s][A[A

  0%|          | 16/4738 [00:02<13:14,  5.94it/s][A[A

  0%|          | 17/4738 [00:02<13:14,  5.94it/

In [77]:
def _read_rgb(image_path):
    """Read ``image_path`` with OpenCV and return it in RGB order, or None if unreadable."""
    image = cv2.imread(image_path)
    if image is None:
        return None
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _rank_by_sift_matches(query_descriptors, all_descriptors):
    """Return indices into ``all_descriptors``, most cross-checked SIFT matches first.

    Candidates with zero matches, and stored entries without descriptors (None), are
    dropped. Equal match counts keep ascending index order.
    """
    if query_descriptors is None:
        # The query has no SIFT keypoints, so nothing can be matched against it.
        return []
    matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    scored = []
    for index, descriptors in enumerate(all_descriptors):
        if descriptors is None:
            continue
        match_count = len(matcher.match(query_descriptors, descriptors))
        if match_count > 0:
            scored.append((match_count, index))
    # list.sort stays stable with reverse=True, so ties keep index order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [index for _, index in scored]


def _rank_by_distance(query_vector, all_vectors, use_cosine):
    """Return indices into ``all_vectors``, nearest to ``query_vector`` first.

    With ``use_cosine``, distance is ``1 - dot(query, vector)``; this is cosine distance
    when both sides are L2-normalized. Otherwise Euclidean distance is used. Candidates
    with non-finite distances (e.g. NaN vectors) are dropped. Ties keep index order.
    """
    if len(all_vectors) == 0:
        return []
    matrix = np.stack([np.asarray(vector) for vector in all_vectors])
    if use_cosine:
        distances = 1 - matrix @ query_vector
    else:
        distances = np.linalg.norm(matrix - query_vector, axis=1)
    order = np.argsort(distances, kind="stable")
    return [int(i) for i in order if np.isfinite(distances[i])]


def find_similar_images(image, method=sift_option, top_k=3):
    """Return the stored images most similar to ``image`` under the chosen method.

    Parameters
    ----------
    image : numpy.ndarray
        Query image in BGR channel order, as produced by ``cv2.imread``. The embedding
        functions resize it themselves.
    method : str
        One of ``sift_option``, ``color_hist_option``, ``hog_option``, ``dino_option``,
        ``vgg_option``. It selects both the embedding function and the pickled index
        under ``vectors/`` produced by calculate_embeddings.
    top_k : int
        Maximum number of images to return (default 3).

    Returns
    -------
    list of numpy.ndarray
        Up to ``top_k`` RGB images, best match first. Fewer are returned when fewer
        usable candidates exist: SIFT candidates need at least one match, and distances
        must be finite. Candidates whose image file can no longer be read are skipped.

    Raises
    ------
    ValueError
        If ``method`` is not one of the options above.
    """
    embedding_sources = {
        sift_option: (sift_descriptors, "vectors/sift_vectors.pkl"),
        color_hist_option: (color_histogram, "vectors/color_histogram_vectors.pkl"),
        hog_option: (hog, "vectors/hog_vectors.pkl"),
        dino_option: (dino_embedding, "vectors/dino_vectors.pkl"),
        vgg_option: (vgg_16_embedding, "vectors/vgg_16_vectors.pkl"),
    }
    if method not in embedding_sources:
        raise ValueError("Unknown method")
    embed, vectors_file = embedding_sources[method]

    # NOTE: read_pickle can execute arbitrary code; only load index files this notebook wrote.
    images_embeddings = pd.read_pickle(vectors_file)
    query_embedding = embed(image)

    # Positional lists avoid relying on the DataFrame's index labels.
    all_img_paths = list(images_embeddings["img_path"])
    all_embeddings = list(images_embeddings["vector"])

    if method == sift_option:
        ranking = _rank_by_sift_matches(query_embedding, all_embeddings)
    else:
        ranking = _rank_by_distance(
            query_embedding,
            all_embeddings,
            use_cosine=method in (dino_option, vgg_option),
        )

    similar_images = []
    for index in ranking:
        if len(similar_images) >= top_k:
            break
        rgb = _read_rgb(all_img_paths[index])
        if rgb is not None:
            similar_images.append(rgb)
    return similar_images

In [83]:
import gradio as gr


def get_similar_images(image, method):
    """Gradio callback: return the indexed images most similar to an uploaded image.

    Parameters
    ----------
    image : numpy.ndarray or None
        Value of the ``"image"`` input component. Gradio delivers it in RGB order, or
        None when nothing was uploaded. The stored index and every embedding function
        work on BGR images (as loaded by ``cv2.imread``), so the query is converted to
        BGR before searching.
    method : str or None
        Option selected in the dropdown.

    Returns
    -------
    list of numpy.ndarray
        RGB images, ready for the ``gr.Gallery`` output.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if method is None:
        raise gr.Error("Please choose an embedding method.")

    if image.ndim == 2:  # grayscale upload
        query = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:  # RGBA upload
        query = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    else:
        query = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # find_similar_images already returns RGB images, which the gallery shows as-is.
    return find_similar_images(query, method)

# Web UI: upload a query image, pick an embedding method, and view the most similar
# indexed images in a 3-column gallery. The dropdown gets a default value so submitting
# without touching it does not pass None (which find_similar_images rejects).
(gr.Interface(get_similar_images,
              inputs=["image",
                      gr.Dropdown(choices=
                                  [dino_option, vgg_option, hog_option, color_hist_option, sift_option],
                                  value=dino_option,
                                  label="Embedding method")],
              outputs=gr.Gallery(label="Similar Images", columns=3))
 .launch())

* Running on local URL:  http://127.0.0.1:7865

To create a public link, set `share=True` in `launch()`.


