# Поиск похожих изображений по картинке

In [1]:
import cv2

<div>
    Реализовано несколько способов поиска на основе:
<ol>
    <li>гистограммы цветов;</li>
    <li>гистограммы градиентов;</li>
    <li>SIFT</li>
</ol>
</div>
<div>
    Данные — <a href="https://www.kaggle.com/datasets/theaayushbajaj/cbir-dataset">Content Based Image Retrieval (CBIR) dataset</a> с Kaggle.
</div>

In [2]:
import os
import numpy as np
import pandas as pd


def calculate_vectors(calculate_vector_method, output_file_name, images_path="images"):
    """Compute a descriptor for every image in a directory and pickle the result.

    Args:
        calculate_vector_method: callable taking an RGB image array and
            returning its descriptor (e.g. ``sift_descriptors``,
            ``color_histogram`` or ``hog``).
        output_file_name: path of the pickle file to write. Its parent
            directory is created if it does not exist.
        images_path: directory whose files are read as images.

    The pickled DataFrame has two columns: ``img_path`` and ``vector``.
    Files that OpenCV cannot decode are skipped.
    """
    records = []

    # os.listdir returns entries in arbitrary order; sort for a reproducible index.
    for file_name in sorted(os.listdir(images_path)):
        img_path = os.path.join(images_path, file_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals "not a readable image" by returning None
            # instead of raising; skip such files rather than crash later.
            continue
        # cv2.imread yields BGR, while the descriptor functions are written
        # for RGB input (the same layout the query image arrives in).
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        records.append({"img_path": img_path, "vector": calculate_vector_method(img)})

    # Build the frame once instead of appending row by row.
    df = pd.DataFrame(records, columns=["img_path", "vector"])

    output_dir = os.path.dirname(output_file_name)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df.to_pickle(output_file_name)

        
def sift_descriptors(img):
    """Return the SIFT descriptors of an image.

    The input is converted to a numpy array, turned into grayscale with
    the RGB->GRAY conversion, and passed to a freshly created SIFT
    detector. Only the descriptors (second element of the
    ``detectAndCompute`` result) are returned; the keypoints are discarded.
    """
    detector = cv2.SIFT_create()
    grayscale = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
    # detectAndCompute returns (keypoints, descriptors); no mask is used.
    descriptors = detector.detectAndCompute(grayscale, None)[1]
    return descriptors


def color_histogram(img):
    """Build a per-channel color histogram descriptor.

    For each of the first three channels a 256-bin histogram is computed
    over the fixed interval [0, 256], L2-normalized, and the three
    histograms are concatenated.

    The bin edges are fixed so that bin ``k`` always counts intensity ``k``.
    Without an explicit range ``np.histogram`` derives the edges from each
    image's own min/max, which makes histograms of different images
    incomparable.

    Args:
        img: array-like of shape (H, W, C) with C >= 3.
            Assumes 8-bit intensities (0..255) -- TODO confirm for callers
            that pass float images; values outside [0, 256] are not counted.

    Returns:
        1-D float numpy array of length 768 (3 channels x 256 bins).
    """
    img = np.array(img)
    channel_histograms = []
    for channel_id in range(3):
        histogram, _ = np.histogram(img[:, :, channel_id], bins=256, range=(0, 256))
        norm = np.linalg.norm(histogram)
        if norm > 0:
            histogram = histogram / norm
        else:
            # Nothing was counted (e.g. empty image): keep zeros instead of NaN.
            histogram = histogram.astype(float)
        channel_histograms.append(histogram)
    return np.concatenate(channel_histograms)


def hog(img):
    """Build a gradient-orientation histogram descriptor for an image.

    The image is converted to grayscale (RGB->GRAY), Sobel derivatives
    (kernel size 5) are taken along x and y, and converted to polar form.
    The gradient angles are binned into 256 bins over [0, 2*pi], each
    pixel weighted by its gradient magnitude. This is a single histogram
    over the whole image (no cells or blocks).

    Returns:
        1-D numpy array of length 256, L2-normalized. For an image without
        any gradient (e.g. a uniform color) the histogram is all zeros and
        is returned as is, rather than being divided by a zero norm.
    """
    gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)

    sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=5)
    sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=5)
    magnitude, angle = cv2.cartToPolar(sobel_x, sobel_y)
    hist, _ = np.histogram(
        angle.ravel(), bins=256, range=(0, 2 * np.pi), weights=magnitude.ravel()
    )

    norm = np.linalg.norm(hist)
    if norm > 0:
        # Guard against 0/0 -> NaN, which would poison every distance
        # computed against this vector.
        hist = hist / norm

    return hist


# calculate_vectors(sift_descriptors, "vectors/sift_vectors.pkl")
# calculate_vectors(color_histogram, "vectors/color_histogram_vectors.pkl")
# calculate_vectors(hog, "vectors/hog_vectors.pkl")

In [3]:
def _load_rgb(path):
    """Read an image file and return it in RGB order, or None if unreadable."""
    image = cv2.imread(path)
    if image is None:
        return None
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def find_similar_images(image, method="sift", top_k=3):
    """Find the stored images most similar to ``image``.

    Args:
        image: query image passed to the descriptor function of ``method``.
        method: ``"sift"``, ``"color_histogram"`` or ``"hog"``.
        top_k: number of results to return (default 3).

    Returns:
        A list of exactly ``top_k`` entries, best match first. Each entry is
        an RGB image array; entries are ``None`` when fewer than ``top_k``
        usable candidates exist or a stored file can no longer be read.

    Raises:
        ValueError: if ``method`` is not one of the supported names.

    Ranking:
        * ``sift``: number of cross-checked brute-force L2 matches, more is
          better; candidates with zero matches are ignored.
        * ``color_histogram`` / ``hog``: Euclidean distance between the
          vectors, smaller is better; non-finite distances are ignored.
        Ties keep the order of the stored table.
    """
    vector_files = {
        "sift": "vectors/sift_vectors.pkl",
        "color_histogram": "vectors/color_histogram_vectors.pkl",
        "hog": "vectors/hog_vectors.pkl",
    }
    if method not in vector_files:
        raise ValueError("Unknown method")

    # NOTE(security): unpickling executes arbitrary code; only load vector
    # files that were produced locally by calculate_vectors.
    images_vectors = pd.read_pickle(vector_files[method])
    if method == "sift":
        query_vector = sift_descriptors(image)
    elif method == "color_histogram":
        query_vector = color_histogram(image)
    else:
        query_vector = hog(image)

    # Iterate positionally so the result does not depend on the frame's index labels.
    candidates = zip(images_vectors["img_path"], images_vectors["vector"])

    ranked = []  # (cost, path) pairs; lower cost means more similar
    if method == "sift":
        # SIFT yields no descriptors (None) for images without keypoints;
        # such a query cannot be matched against anything.
        if isinstance(query_vector, np.ndarray) and query_vector.size > 0:
            bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
            for path, vector in candidates:
                if not isinstance(vector, np.ndarray) or vector.size == 0:
                    continue
                num_matches = len(bf.match(query_vector, vector))
                if num_matches > 0:
                    # Negate so that "more matches" sorts first.
                    ranked.append((-num_matches, path))
    else:
        for path, vector in candidates:
            if not isinstance(vector, np.ndarray):
                continue
            distance = np.linalg.norm(query_vector - vector)
            if np.isfinite(distance):
                ranked.append((distance, path))

    # list.sort is stable, so equal scores keep their table order.
    ranked.sort(key=lambda item: item[0])

    results = [_load_rgb(path) for _, path in ranked[:top_k]]
    # Pad so callers can always index top_k results.
    results.extend([None] * (top_k - len(results)))
    return results

# chosen_img = "images/0.jpg"
# img = cv2.imread(chosen_img)
# chosen_method = "sift"
# find_similar_images(img, "sift")

In [6]:
import gradio as gr

def get_similar_images(image, method):
    """Gradio callback: resize the query to 512x512 and return the three best matches.

    The resized image is passed to ``find_similar_images`` together with
    ``method``; the first three results are each wrapped in a ``gr.Image``.
    """
    resized = cv2.resize(image, (512, 512))
    matches = find_similar_images(resized, method)
    outputs = []
    for position in range(3):
        outputs.append(gr.Image(matches[position]))
    return outputs

# Web UI: an image input plus a dropdown to choose the search method;
# the three output slots show the images returned by get_similar_images.
gr.Interface(
    fn=get_similar_images,
    inputs=["image", gr.Dropdown(choices=["sift", "color_histogram", "hog"])],
    outputs=["image", "image", "image"],
).launch()

* Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


