# Поиск похожих изображений по картинке

In [234]:
import cv2
import torch
import torchvision.models as models
from torchvision.models import VGG16_Weights, Inception_V3_Weights
from transformers import AutoImageProcessor, AutoModel


# Run models on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Folder holding the image collection; embeddings are stored in "<IMAGES_FOLDER>_vectors".
IMAGES_FOLDER = "flickr30k_images"

# Labels of the available embedding methods. These exact strings are shown in the
# UI dropdown and are also what find_similar_images dispatches on.
CLIP_OPTION = "CLIP Model Embeddings"
INCEPTION_OPTION = "Inception Model Embeddings"
DINO_OPTION = "DINO Model Embeddings (Preferred for Images)"
VGG_OPTION = "VGG-16 Model Embeddings"
HOG_OPTION = "Histogram of Gradients Embeddings (HOG)"
COLOR_HIST_OPTION = "Color Histogram Embeddings"
SIFT_OPTION = "SIFT Embeddings"

In [235]:
# DINOv2 (small) backbone plus the preprocessor that matches it, loaded from the same checkpoint.
_DINO_CHECKPOINT = 'facebook/dinov2-small'
dino_processor = AutoImageProcessor.from_pretrained(_DINO_CHECKPOINT)
dino_model = AutoModel.from_pretrained(_DINO_CHECKPOINT).to(device)

In [236]:
# VGG-16 with its default pretrained weights. The classifier head is cut down to its
# first layer, so a forward pass returns that layer's activations as the embedding.
_vgg_backbone = models.vgg16(weights=VGG16_Weights.DEFAULT)
_vgg_backbone.classifier = _vgg_backbone.classifier[0]
vgg_model = _vgg_backbone.to(device)
del _vgg_backbone  # only `vgg_model` should remain as a module-level name

In [237]:
# Inception-v3 with its default pretrained weights. The final `fc` layer is replaced by a
# pass-through, so the model returns the features that would have been fed into it.
_inception_backbone = models.inception_v3(weights=Inception_V3_Weights.DEFAULT)
_inception_backbone.fc = torch.nn.Identity()
inception_model = _inception_backbone.to(device)
del _inception_backbone  # only `inception_model` should remain as a module-level name

In [238]:
import clip.clip as clip
from PIL import Image

# Load CLIP ViT-B/32 and its preprocessing transform onto the torch.device chosen at the
# top of the notebook. `device` is deliberately not rebound to a plain string here; every
# consumer (`.to(device)`, clip.load) accepts a torch.device.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

In [239]:
from PIL import Image
import os
import numpy as np
import pandas as pd
import tqdm as tqdm
from torchvision import transforms


def calculate_embeddings(calculate_embedding_method, output_file_name, images_path=IMAGES_FOLDER):
    """Embed every image file in a folder and pickle the results as a DataFrame.

    Args:
        calculate_embedding_method: callable taking a BGR image (as returned by
            ``cv2.imread``) and returning its embedding / descriptor.
        output_file_name: path of the pickle file to write.
        images_path: folder whose entries are embedded (not recursive).

    The pickled DataFrame has columns ``img_path`` (path of the image file) and
    ``vector`` (whatever ``calculate_embedding_method`` returned for it). Files that
    OpenCV cannot decode are skipped with a printed summary, so one stray file does
    not abort a long run.
    """
    # List the folder once; sorting gives a reproducible row order.
    file_names = sorted(os.listdir(images_path))
    img_paths = []
    vectors = []
    skipped = []

    with tqdm.tqdm(total=len(file_names)) as bar:
        for file_name in file_names:
            img_path = os.path.join(images_path, file_name)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals unreadable / non-image files by returning None.
                skipped.append(img_path)
            else:
                img_paths.append(img_path)
                vectors.append(calculate_embedding_method(img))
            bar.update(1)

    if skipped:
        print(f"Skipped {len(skipped)} unreadable file(s), e.g. {skipped[0]}")

    # Build the frame in one step. Enlarging it row by row with df.loc is slow
    # for tens of thousands of rows.
    df = pd.DataFrame({"img_path": img_paths, "vector": vectors}, columns=["img_path", "vector"])
    df.to_pickle(output_file_name)

        
def sift_descriptors(img):
    """Return the SIFT descriptors of a BGR image after resizing it to 512x512.

    The value is the descriptor output of ``detectAndCompute`` as-is. It may be
    ``None`` when no keypoints are found; verify against the OpenCV version in use.
    """
    detector = cv2.SIFT_create()
    gray = cv2.cvtColor(np.array(cv2.resize(img, (512, 512))), cv2.COLOR_BGR2GRAY)
    descriptors = detector.detectAndCompute(gray, None)[1]
    return descriptors


def color_histogram(img):
    """Return a per-channel colour-histogram embedding of a BGR image.

    The image is resized to 512x512 and converted to RGB. For each channel, a 256-bin
    histogram over the fixed intensity range [0, 256) is computed and L2-normalised.
    This assumes 8-bit images, as produced by ``cv2.imread`` by default. The three
    histograms are concatenated in R, G, B order into one 768-element vector.
    """
    rgb = cv2.cvtColor(cv2.resize(img, (512, 512)), cv2.COLOR_BGR2RGB)

    channel_histograms = []
    for channel_id in range(3):  # 0 = red, 1 = green, 2 = blue after the conversion above
        # Fixed `range` so bin k always means intensity k. Without it, np.histogram spans
        # each image's own min..max, and histograms of different images are not comparable.
        histogram, _ = np.histogram(rgb[:, :, channel_id], bins=256, range=(0, 256))
        channel_histograms.append(histogram / np.linalg.norm(histogram))
    return np.concatenate(channel_histograms)


def hog(img):
    """Return a global gradient-orientation histogram of a BGR image.

    This is a single histogram for the whole image rather than a cell/block HOG
    descriptor:
    - The image is resized to 512x512 and converted to grayscale.
    - Sobel gradients are computed with a 5x5 kernel.
    - Gradient angles over [0, 2*pi) are binned into 256 bins, each pixel weighted by
      its gradient magnitude.
    - The histogram is L2-normalised.

    An image without any gradient (e.g. a flat colour) yields an all-zero vector
    instead of NaNs.
    """
    gray = cv2.cvtColor(cv2.resize(img, (512, 512)), cv2.COLOR_BGR2GRAY)

    sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=5)
    sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=5)
    magnitude, angle = cv2.cartToPolar(sobel_x, sobel_y)  # angle in radians
    hist, _ = np.histogram(angle.ravel(), bins=256, range=(0, 2 * np.pi), weights=magnitude.ravel())

    norm = np.linalg.norm(hist)
    # Dividing by a zero norm would produce NaN entries, which break the distance ranking.
    return hist / norm if norm > 0 else hist


def dino_embedding(img):
    """Return an L2-normalised float32 DINOv2 embedding of a BGR image.

    The image is resized to 512x512, converted to RGB and passed through
    ``dino_processor`` and ``dino_model``. The last hidden state is averaged over
    its token dimension to give one vector.
    """
    rgb = cv2.cvtColor(cv2.resize(img, (512, 512)), cv2.COLOR_BGR2RGB)

    with torch.no_grad():
        model_inputs = dino_processor(images=rgb, return_tensors="pt").to(device)
        model_outputs = dino_model(**model_inputs)

    hidden = model_outputs.last_hidden_state
    vector = hidden.mean(dim=1).squeeze().cpu().detach().numpy()
    return np.float32(vector) / np.linalg.norm(vector)


def vgg_16_embedding(img):
    """Return an L2-normalised float32 VGG-16 embedding of a BGR image.

    Preprocessing follows the standard ImageNet recipe for torchvision's VGG-16
    weights:
    - resize the shorter side to 256;
    - centre-crop to 224x224;
    - normalise with the ImageNet mean/std.

    The previous 214/214 values were almost certainly a typo for 224. After changing
    this preprocessing, regenerate "vgg_16_vectors.pkl" so that stored and query
    embeddings are produced the same way.
    """
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    input_tensor = preprocess(pil_img).unsqueeze(0).to(device)

    vgg_model.eval()  # inference mode
    with torch.no_grad():
        embedding = vgg_model(input_tensor).squeeze().cpu().numpy()

    return np.float32(embedding) / np.linalg.norm(embedding)


def inception_embedding(img):
    """Return an L2-normalised float32 Inception-v3 feature vector of a BGR image.

    Preprocessing steps:
    - convert to RGB;
    - resize the shorter side to 299;
    - centre-crop to 299x299;
    - normalise with the ImageNet mean/std.

    The result is passed through ``inception_model``.
    """
    to_model_input = transforms.Compose(
        [
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    batch = to_model_input(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        inception_model.eval()
        features = inception_model(batch).squeeze().cpu().detach().numpy()

    return np.float32(features) / np.linalg.norm(features)


def clip_embedding_img(img):
    """Return the CLIP image embedding of a BGR image as a 1-D array (not normalised)."""
    pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    batch = clip_preprocess(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        features = clip_model.encode_image(batch)

    return features.cpu().numpy()[0]


def clip_embedding_text(text):
    """Return the CLIP text embedding of ``text`` as a 1-D array (not normalised)."""
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        features = clip_model.encode_text(tokens)
    return features.cpu().numpy()[0]

In [240]:
# RUN ONLY 1 TIME: computes the embedding pickles that find_similar_images loads.
# Re-run a line only after changing the matching embedding function, so that stored
# vectors and query vectors are produced the same way.
vectors_folder = f"{IMAGES_FOLDER}_vectors"
os.makedirs(vectors_folder, exist_ok=True)  # no-op when the folder already exists

# Classic computer vision methods.
# NOTE: the UI dropdown still offers these options. Selecting one before its pickle has
# been generated makes pd.read_pickle fail, so uncomment the matching line first.
# calculate_embeddings(sift_descriptors, f"{vectors_folder}/sift_vectors.pkl")
# calculate_embeddings(color_histogram, f"{vectors_folder}/color_histogram_vectors.pkl")
# calculate_embeddings(hog, f"{vectors_folder}/hog_vectors.pkl")

# Deep learning methods:
calculate_embeddings(dino_embedding, f"{vectors_folder}/dino_vectors.pkl")  # Preferred for images
calculate_embeddings(clip_embedding_img, f"{vectors_folder}/clip_vectors.pkl")  # Preferred for text
calculate_embeddings(vgg_16_embedding, f"{vectors_folder}/vgg_16_vectors.pkl")
calculate_embeddings(inception_embedding, f"{vectors_folder}/inception_vectors.pkl")






  0%|          | 0/31783 [00:00<?, ?it/s][A[A[A[A[A




  0%|          | 2/31783 [00:00<36:22, 14.56it/s][A[A[A[A[A




  0%|          | 5/31783 [00:00<27:49, 19.03it/s][A[A[A[A[A




  0%|          | 7/31783 [00:00<27:46, 19.06it/s][A[A[A[A[A




  0%|          | 10/31783 [00:00<25:53, 20.45it/s][A[A[A[A[A




  0%|          | 13/31783 [00:00<26:30, 19.97it/s][A[A[A[A[A




  0%|          | 15/31783 [00:00<26:51, 19.71it/s][A[A[A[A[A




  0%|          | 17/31783 [00:00<27:14, 19.43it/s][A[A[A[A[A




  0%|          | 19/31783 [00:00<27:28, 19.27it/s][A[A[A[A[A




  0%|          | 22/31783 [00:01<26:11, 20.21it/s][A[A[A[A[A




  0%|          | 25/31783 [00:01<26:06, 20.28it/s][A[A[A[A[A




  0%|          | 28/31783 [00:01<25:25, 20.81it/s][A[A[A[A[A




  0%|          | 31/31783 [00:01<25:21, 20.87it/s][A[A[A[A[A




  0%|          | 34/31783 [00:01<25:11, 21.01it/s][A[A[A[A[A




  0%|          | 37/31783 [

In [241]:
def _load_embeddings_and_query(query, method):
    """Load the stored embeddings for `method` and embed `query` with the matching function."""
    if method == CLIP_OPTION:
        # CLIP maps text and images into one space: text queries use the text encoder,
        # image queries the image encoder (the stored vectors come from the image encoder).
        clip_embed = clip_embedding_text if isinstance(query, str) else clip_embedding_img
    else:
        clip_embed = None

    sources = {
        SIFT_OPTION: ("sift_vectors.pkl", sift_descriptors),
        COLOR_HIST_OPTION: ("color_histogram_vectors.pkl", color_histogram),
        HOG_OPTION: ("hog_vectors.pkl", hog),
        DINO_OPTION: ("dino_vectors.pkl", dino_embedding),
        VGG_OPTION: ("vgg_16_vectors.pkl", vgg_16_embedding),
        INCEPTION_OPTION: ("inception_vectors.pkl", inception_embedding),
        CLIP_OPTION: ("clip_vectors.pkl", clip_embed),
    }
    if method not in sources:
        raise ValueError("Unknown method")

    file_name, embed = sources[method]
    images_embeddings = pd.read_pickle(f"{IMAGES_FOLDER}_vectors/{file_name}")
    return images_embeddings, embed(query)


def _rank_by_sift_matches(query_descriptors, all_descriptors):
    """Return row indices ordered by number of cross-checked SIFT matches, most first."""
    matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    match_counts = []
    for descriptors in all_descriptors:
        if query_descriptors is None or descriptors is None:
            # Descriptors can be None (e.g. no keypoints); bf.match cannot handle that.
            match_counts.append(0)
        else:
            match_counts.append(len(matcher.match(query_descriptors, descriptors)))
    # sorted() stays stable with reverse=True, so equal counts keep their stored order.
    return sorted(range(len(match_counts)), key=match_counts.__getitem__, reverse=True)


def _rank_by_distance(query_embedding, all_embeddings, method):
    """Return row indices ordered by ascending distance to the query embedding.

    Neural-network embeddings use cosine distance; histogram embeddings use
    Euclidean distance.
    """
    uses_cosine = method in (DINO_OPTION, VGG_OPTION, INCEPTION_OPTION, CLIP_OPTION)
    # Presumably meant to hide exact / near duplicates of the query (e.g. the query image
    # itself): anything closer than the threshold is pushed back to distance 1.
    near_duplicate_threshold = 0.03 if method == DINO_OPTION else 0.1
    query_norm = np.linalg.norm(query_embedding)

    distances = []
    for vector in all_embeddings:
        if uses_cosine:
            cosine_similarity = np.dot(query_embedding, vector) / (query_norm * np.linalg.norm(vector))
            distance = 1 - cosine_similarity
            if distance < near_duplicate_threshold:
                distance = 1
        else:
            distance = np.linalg.norm(query_embedding - vector)
        distances.append(distance)

    # Stable sort: ties keep their stored order.
    return sorted(range(len(distances)), key=distances.__getitem__)


def find_similar_images(query, method = SIFT_OPTION, top_k = 30):
    """Return up to `top_k` images from the collection that are most similar to `query`.

    Args:
        query: a BGR image (as from cv2.imread) for every method. For CLIP_OPTION a text
            string is also accepted and embedded with the CLIP text encoder.
        method: one of the *_OPTION labels. Selects the stored embeddings
            ("<IMAGES_FOLDER>_vectors/*.pkl") and the function used to embed `query`.
        top_k: maximum number of results. It is converted to int (in case it arrives as
            a float) and never exceeds the size of the collection.

    Returns:
        List of RGB images (numpy arrays), best match first. Result files that can no
        longer be read are skipped, so the list may be shorter than `top_k`.

    Raises:
        ValueError: if `method` is not a known option.
    """
    images_embeddings, query_embedding = _load_embeddings_and_query(query, method)
    all_img_paths = images_embeddings["img_path"].tolist()
    all_embeddings = images_embeddings["vector"]

    if method == SIFT_OPTION:
        ranking = _rank_by_sift_matches(query_embedding, all_embeddings)
    else:
        ranking = _rank_by_distance(query_embedding, all_embeddings, method)

    # Slicing clamps to the number of ranked images, so a large top_k no longer raises IndexError.
    top_k = max(0, int(top_k))
    result_images = []
    for row in ranking[:top_k]:
        img = cv2.imread(all_img_paths[row])
        if img is None:
            continue
        result_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return result_images

In [242]:
import gradio as gr


def get_similar_images(image, slider_value, method = CLIP_OPTION):
    """Gradio callback: return images similar to an uploaded image or to a text query.

    Args:
        image: numpy image from ``gr.Image`` (image tab), or a string from
            ``gr.Textbox`` (text tab). ``None`` when no image was uploaded.
        slider_value: number of images to return.
        method: embedding option chosen in the dropdown. If it is ``None`` (nothing
            selected), DINO is used for images and CLIP for text.

    Returns:
        List of RGB images for ``gr.Gallery``; empty when there is no query image.
    """
    if image is None:
        return []
    if method is None:
        method = CLIP_OPTION if isinstance(image, str) else DINO_OPTION
    if not isinstance(image, str):
        # gr.Image (numpy type, default RGB image mode) delivers RGB. Every embedding function
        # expects BGR, as produced by cv2.imread, which the stored vectors were computed from.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    return find_similar_images(image, method, int(slider_value))


with gr.Blocks() as gradio_interface:
    choice = gr.Radio(["Image Input", "Text Input"], label="Choose Input Type", value="Image Input")

    # Image query tab: the user picks the embedding method.
    with gr.Group(visible=True) as image_block:
        with gr.Row(variant="compact"):
            with gr.Column(variant="compact"):
                image_input = gr.Image(label="Upload Image", height=300)
                top_k_slider = gr.Slider(1, 300, 30, step=1, label="Similar Images Number")
                option_dropdown = gr.Dropdown(
                    choices=[DINO_OPTION, CLIP_OPTION, VGG_OPTION, INCEPTION_OPTION, HOG_OPTION, COLOR_HIST_OPTION, SIFT_OPTION],
                    value=DINO_OPTION,  # pre-select, so the callback never receives None
                    label="Choose Embedding Method",
                )
                process_image_button = gr.Button("Find Similar Images")

            with gr.Column(variant="compact"):
                image_output = gr.Gallery(label="Similar Images", columns=3)

        process_image_button.click(
            get_similar_images,
            inputs=[image_input, top_k_slider, option_dropdown],
            outputs=[image_output],
        )

    # Text query tab: no method input, so get_similar_images falls back to its
    # default method (CLIP). Widgets get their own names instead of rebinding the
    # image tab's `top_k_slider` / `image_output`.
    with gr.Group(visible=False) as text_block:
        with gr.Row(variant="compact"):
            with gr.Column(variant="compact"):
                text_input = gr.Textbox(label="Enter Text")
                text_top_k_slider = gr.Slider(1, 300, 30, step=1, label="Similar Images Number")
                process_text_button = gr.Button("Submit Text")

            with gr.Column(variant="compact"):
                text_output = gr.Gallery(label="Similar Images", columns=3)

        process_text_button.click(
            get_similar_images,
            inputs=[text_input, text_top_k_slider],
            outputs=[text_output],
        )

    def update_interface(choice):
        """Show the group for the selected input type and hide the other one."""
        show_text = choice == "Text Input"
        return gr.update(visible=not show_text), gr.update(visible=show_text)

    choice.change(
        update_interface,
        inputs=[choice],
        outputs=[image_block, text_block],
    )

gradio_interface.launch()

* Running on local URL:  http://127.0.0.1:7902

To create a public link, set `share=True` in `launch()`.


