In [1]:
%pip install datasets transformers torch

Note: you may need to restart the kernel to use updated packages.


In [4]:
from transformers import AutoFeatureExtractor, AutoModel
import pickle
# Checkpoint identifier handed to both `from_pretrained` calls below.
model_id = 'google/vit-base-patch16-224-in21k'
extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
# Width of the model's hidden states, read from its config.
hidden_dim = model.config.hidden_size

# Pickle each object into its own file in the current working directory.
# The write order (model, extractor, hidden_dim) is kept as-is.
for filename, artifact in (
    ("model.pkl", model),
    ("extractor.pkl", extractor),
    ("hidden_dim.pkl", hidden_dim),
):
    with open(filename, "wb") as f:
        pickle.dump(artifact, f)



In [3]:
import torch
from PIL import Image
import requests
import torchvision.transforms as transforms
from io import BytesIO
from sklearn.metrics.pairwise import cosine_similarity

# Run on the GPU when PyTorch reports that CUDA is available; otherwise fall
# back to the CPU. Kept as a plain string and passed to `.to(...)` below.
device = "cuda" if torch.cuda.is_available() else "cpu"

def _load_image_tensor(image_url, transform, target_device, timeout):
    """Download one image and return it as a preprocessed, batched tensor."""
    response = requests.get(image_url, timeout=timeout)
    # Fail here on 4xx/5xx responses instead of handing an error page to PIL,
    # which would otherwise surface as a confusing "cannot identify image" error.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    # unsqueeze(0) adds the leading batch dimension the model expects.
    return transform(image).unsqueeze(0).to(target_device)


def _first_token_embedding(model, pixel_values):
    """Return the first-token hidden state as a 2-D NumPy array of shape (1, -1)."""
    with torch.no_grad():
        hidden_states = model(pixel_values=pixel_values).last_hidden_state
    # Index 0 along the sequence axis is the first token (the [CLS] token for
    # ViT-style models); cosine_similarity needs 2-D input, hence the reshape.
    return hidden_states[:, 0].cpu().numpy().reshape(1, -1)


def get_image_embeddings(image_url1, image_url2, model, timeout=30):
    """Compute the cosine similarity between the embeddings of two remote images.

    Despite its name, this function returns a single similarity score, not the
    embeddings themselves. Each image is downloaded, converted to RGB, resized
    to 224x224, normalised with the module-level ``extractor``'s mean/std, and
    passed through ``model``; the first-token hidden state is used as the
    image embedding.

    Args:
        image_url1: URL of the first image.
        image_url2: URL of the second image.
        model: Model called as ``model(pixel_values=...)`` whose output exposes
            ``last_hidden_state``.
        timeout: Seconds passed to ``requests.get`` for each download, so a
            stalled server cannot hang the notebook indefinitely.

    Returns:
        The cosine similarity of the two embeddings as a NumPy float scalar.

    Raises:
        requests.HTTPError: If either URL answers with a 4xx/5xx status.
        requests.Timeout: If a download exceeds ``timeout``.
        PIL.UnidentifiedImageError: If a downloaded payload is not an image.

    Note:
        Relies on the module-level ``extractor`` (for normalisation statistics)
        being defined before the call.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ])

    # Put the inputs on whatever device the model's weights live on, so the
    # call works even if the caller has not moved the model to the global
    # `device`. A model without parameters falls back to the global `device`.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    # Download both images before running the model so a bad URL fails fast.
    img_tensor1 = _load_image_tensor(image_url1, transform, target_device, timeout)
    img_tensor2 = _load_image_tensor(image_url2, transform, target_device, timeout)

    embeddings1 = _first_token_embedding(model, img_tensor1)
    embeddings2 = _first_token_embedding(model, img_tensor2)

    # cosine_similarity returns a 1x1 matrix for two single-row inputs.
    return cosine_similarity(embeddings1, embeddings2)[0][0]

# The two example images to compare.
image_url1, image_url2 = (
    "https://res.cloudinary.com/ddospzdve/image/upload/v1694427471/mhkdootcnybklimno7rj.jpg",
    "https://res.cloudinary.com/ddospzdve/image/upload/v1694417449/lkgftzx5b8hrwkxqqvwf.jpg",
)

# `Module.to` moves the weights in place and returns the same module, so the
# rebinding below leaves `model` pointing at the same object.
model = model.to(device)

similarity = get_image_embeddings(
    image_url1=image_url1,
    image_url2=image_url2,
    model=model,
)
print("Cosine Similarity:", similarity)

Cosine Similarity: 0.60953164
