In [None]:
%pip install git+https://github.com/openai/CLIP.git

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import clip

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP ViT-B/32 checkpoint together with its image preprocessing
# transform.
model, preprocess = clip.load("ViT-B/32", device=device)

# Gallery of candidate images that the query image is compared against.
image_paths = [
    './sample_data/mix.png',
    './sample_data/whiteshoes.png',
    './sample_data/whiteshoesdouble.png',
    './sample_data/newwhite.png',
    './sample_data/whiteshoesopposite.png',
    './sample_data/image1.png',
    './sample_data/image2.png',
    './sample_data/image3.png',
    './sample_data/image4.png',
    './sample_data/image5.png',
]

# Query image; note that it also appears as the last gallery entry.
new_image_path = './sample_data/image5.png'

def get_features_from_image_path(image_paths):
    """Encode a list of image files with the CLIP image encoder.

    Each file is opened with PIL, converted to RGB and passed through the
    module-level ``preprocess`` transform. The results are stacked into a
    single batch, moved to ``device`` and fed to ``model.encode_image``.

    Args:
        image_paths: Iterable of image file paths.

    Returns:
        The output of ``model.encode_image`` for the stacked batch, cast to
        float (presumably one feature row per path, in input order).
    """
    preprocessed = []
    for path in image_paths:
        rgb_image = Image.open(path).convert("RGB")
        preprocessed.append(preprocess(rgb_image))

    batch = torch.stack(preprocessed).to(device)

    # Inference only: no gradient tracking is needed.
    with torch.no_grad():
        encoded = model.encode_image(batch).float()
    return encoded
def get_feature_for_new_image(new_image_path):
    """Encode a single image file with the CLIP image encoder.

    The file is opened with PIL, converted to RGB and preprocessed. A leading
    batch dimension is added before the tensor is moved to ``device`` and
    passed to ``model.encode_image``.

    Args:
        new_image_path: Path of the image file to encode.

    Returns:
        The float-cast encoder output with the leading batch dimension
        squeezed away again.
    """
    rgb_image = Image.open(new_image_path).convert("RGB")
    # The encoder is fed a batch, so wrap the single image in a batch of one.
    batch = torch.unsqueeze(preprocess(rgb_image), 0).to(device)

    # Inference only: no gradient tracking is needed.
    with torch.no_grad():
        encoded = model.encode_image(batch).float()
    return torch.squeeze(encoded, 0)

def cosine_similarity(v1, v2):
    """Return the cosine similarity of two tensors as a Python number.

    A leading dimension is added to both inputs before they are handed to
    ``torch.nn.functional.cosine_similarity``; the single resulting value is
    extracted with ``.item()``.

    Args:
        v1: First tensor.
        v2: Second tensor.

    Returns:
        The similarity score as a plain Python scalar.
    """
    lhs = torch.unsqueeze(v1, 0)
    rhs = torch.unsqueeze(v2, 0)
    score = F.cosine_similarity(lhs, rhs)
    return score.item()

# Embed the gallery images and the query image with CLIP.
image_features = get_features_from_image_path(image_paths)
new_image_feature = get_feature_for_new_image(new_image_path)

# Pair every gallery index with its cosine similarity to the query image.
similarities = [
    (idx, cosine_similarity(feature, new_image_feature))
    for idx, feature in enumerate(image_features)
]

for idx, sim in similarities:
    print(f"Similarity with image {image_paths[idx]}: {sim:.4f}")

# Gallery images scoring strictly above this cutoff count as matches.
threshold = 0.63
matching_images = [(idx, sim) for idx, sim in similarities if sim > threshold]

if matching_images:
    # One subplot for the query image plus one per match, all in a single row.
    new_image = Image.open(new_image_path).convert("RGB")
    fig, axes = plt.subplots(1, len(matching_images) + 1, figsize=(5 * (len(matching_images) + 1), 5))
    axes[0].imshow(new_image)
    axes[0].axis('off')
    axes[0].set_title("New Image")
    for i, (idx, sim) in enumerate(matching_images):
        matched_image = Image.open(image_paths[idx]).convert("RGB")
        axes[i + 1].imshow(matched_image)
        axes[i + 1].axis('off')
        axes[i + 1].set_title(f"Sim: {sim:.2f}")

    plt.tight_layout()
    plt.show()
else:
    # Report the cutoff that was actually applied rather than a hard-coded
    # value that can drift out of sync with `threshold`.
    print(f"No images found with similarity greater than {threshold:.2f}.")
