In [15]:
import os
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import json
from tqdm import tqdm

In [27]:
class CLIPModelHandler:
    """Thin wrapper that turns images and text into CLIP feature arrays.

    Holds a pretrained ``CLIPModel`` together with its matching ``CLIPProcessor``
    and exposes one method per modality. Both methods run without gradient
    tracking and hand back NumPy arrays on the CPU.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the model and processor identified by ``model_name``."""
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def _embed(self, extract_features, **processor_kwargs):
        """Preprocess the given inputs, run ``extract_features`` on them, return NumPy."""
        model_inputs = self.processor(return_tensors="pt", **processor_kwargs)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            features = extract_features(**model_inputs)
        return features.cpu().numpy()

    def get_image_embedding(self, image):
        """Return the model's image features for ``image`` as a NumPy array.

        ``image`` is passed unchanged to the processor's ``images`` argument.
        """
        return self._embed(self.model.get_image_features, images=image)

    def get_text_embedding(self, text):
        """Return the model's text features for the single string ``text``.

        The string is wrapped in a one-element list and tokenized with padding
        enabled before being passed to the model.
        """
        return self._embed(self.model.get_text_features, text=[text], padding=True)

class ImageDatabase:
    """JSON-file-backed mapping from image path to its stored embedding.

    ``image_data`` is a dict keyed by image path; each value is a dict with the
    keys ``"embedding"`` (nested list) and ``"image_path"``. Changes live in
    memory until ``save_db`` is called.
    """

    def __init__(self, db_file="image_db.json"):
        """Remember ``db_file`` and load whatever it currently contains."""
        self.db_file = db_file
        self.image_data = self.load_db()

    def load_db(self):
        """Return the parsed contents of ``db_file``, or ``{}`` if it does not exist."""
        if not os.path.exists(self.db_file):
            return {}
        with open(self.db_file, "r") as handle:
            stored = json.load(handle)
        return stored

    def save_db(self):
        """Overwrite ``db_file`` with the current ``image_data`` as JSON."""
        with open(self.db_file, "w") as handle:
            json.dump(self.image_data, handle)

    def add_image(self, image_path, embedding):
        """Insert or replace the record for ``image_path``.

        ``embedding`` must offer ``.tolist()`` (e.g. a NumPy array); the list form
        is what gets stored so the record stays JSON-serializable.
        """
        record = {"embedding": embedding.tolist(), "image_path": image_path}
        self.image_data[image_path] = record

    def get_all_embeddings(self):
        """Return every stored embedding stacked into one NumPy array, in insertion order."""
        stacked = []
        for record in self.image_data.values():
            stacked.append(record["embedding"])
        return np.array(stacked)

    def get_image_paths(self):
        """Return the stored image paths as a list, in the same order as the embeddings."""
        return [path for path in self.image_data]

class RAGPipeline:
    """Retrieve stored images by similarity to a text or image query.

    Images are embedded with ``model_handler`` and kept in ``image_db``; a query
    is embedded the same way and ranked against every stored embedding by
    cosine similarity.

    Args:
        model_handler: Object providing ``get_image_embedding(image)`` and
            ``get_text_embedding(text)``.
        image_db: Object providing ``image_data``, ``add_image``, ``save_db``,
            ``get_all_embeddings`` and ``get_image_paths``.
        top_k: Number of results the retrieve methods return when a call does
            not pass its own ``top_k``.
    """

    def __init__(self, model_handler, image_db, top_k=5):
        self.model_handler = model_handler
        self.image_db = image_db
        self.top_k = top_k

    def store_image_embeddings(self, image_dir):
        """Embed every not-yet-stored ``.webp`` file in ``image_dir`` and save the DB.

        Files whose joined path is already a key in the database are skipped, so
        re-running this on the same directory only processes new files.
        """
        for image_name in tqdm(os.listdir(image_dir)):
            if not image_name.endswith(".webp"):
                continue
            image_path = os.path.join(image_dir, image_name)
            # The database is keyed by the joined path (see add_image below), so
            # the membership test must use the path, not the bare file name.
            if image_path in self.image_db.image_data:
                continue
            # Close each file as soon as it has been embedded; a large folder
            # would otherwise accumulate open handles.
            with Image.open(image_path) as image:
                embedding = self.model_handler.get_image_embedding(image)
            self.image_db.add_image(image_path, embedding)
        self.image_db.save_db()

    def _rank(self, query_embedding, top_k):
        """Return ``(paths, similarities)`` of the ``top_k`` best matches, best first."""
        if top_k is None:
            top_k = self.top_k
        image_paths = self.image_db.get_image_paths()
        # Nothing to rank for an empty store or a non-positive k. Without this
        # guard, slicing with ``[-0:]`` would return every image.
        if not image_paths or top_k <= 0:
            return np.array([], dtype=str), np.array([], dtype=float)

        # Each stored entry may be a (1, dim) batch rather than a flat vector;
        # collapse it to exactly one row per image.
        image_embeddings = np.asarray(self.image_db.get_all_embeddings()).reshape(
            len(image_paths), -1
        )
        query = np.asarray(query_embedding).reshape(1, -1)

        similarities = cosine_similarity(query, image_embeddings).flatten()
        # argsort is ascending: keep the last k entries and reverse them.
        top_k_indices = similarities.argsort()[-top_k:][::-1]
        return np.array(image_paths)[top_k_indices], similarities[top_k_indices]

    def retrieve_top_k(self, text_description, top_k=None):
        """Rank stored images against a text query.

        Args:
            text_description: Query string to embed.
            top_k: Maximum number of results; ``None`` uses the pipeline default.

        Returns:
            Tuple ``(image_paths, similarities)`` of NumPy arrays, ordered from
            most to least similar. Both are empty if the store is empty or
            ``top_k`` is not positive.
        """
        text_embedding = self.model_handler.get_text_embedding(text_description)
        return self._rank(text_embedding, top_k)

    def retrieve_top_k_with_image(self, image_path, top_k=None):
        """Rank stored images against the image found at ``image_path``.

        Args:
            image_path: Path of the query image; opened with ``Image.open``.
            top_k: Maximum number of results; ``None`` uses the pipeline default.

        Returns:
            Tuple ``(image_paths, similarities)`` of NumPy arrays, ordered from
            most to least similar. Both are empty if the store is empty or
            ``top_k`` is not positive.
        """
        with Image.open(image_path) as image:
            image_embedding = self.model_handler.get_image_embedding(image)
        return self._rank(image_embedding, top_k)

In [None]:
# Build the CLIP handler once with its default model name; later cells reuse this instance.
model_handler = CLIPModelHandler()

In [33]:
# Sanity check: cosine similarity between the text embeddings of two short strings.
# The bare expression is the cell's last statement, so its value is displayed as output.
cosine_similarity(
    model_handler.get_text_embedding("Karim KARIM KaRiM"),
    model_handler.get_text_embedding("Karim is ")
)

array([[0.9290306]], dtype=float32)

In [None]:
def teach_with_image_caption(image, caption):
    """Placeholder for associating ``image`` with ``caption``; not implemented yet.

    NOTE(review): the intended behaviour is not defined anywhere in this
    notebook. The stub raises explicitly so the cell is valid Python and any
    accidental call fails loudly instead of silently doing nothing.

    Args:
        image: The image to learn from (type not yet specified).
        caption: The caption to associate with ``image``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("teach_with_image_caption is not implemented yet")

In [18]:
# Open the JSON-backed embedding store (an empty one if the file does not exist yet)
# and build the retrieval pipeline on top of it.
# NOTE(review): top_k is passed to the constructor here -- confirm that the RAGPipeline
# definition currently loaded in the kernel accepts it; otherwise this raises TypeError.
image_db = ImageDatabase("image_db.json")
rag_pipeline = RAGPipeline(model_handler, image_db, top_k=20)
# Indexing step, left disabled; uncomment to embed the sticker folder and save the store.
# rag_pipeline.store_image_embeddings("./WhatsApp-Stickers")

In [26]:
# Retrieve the top-k images for a free-text description.
# No top_k is passed, so the number of hits is whatever retrieve_top_k defaults to.
text_description = "Mark Zuckerberg good job team"
top_k_images, top_k_similarities = rag_pipeline.retrieve_top_k(text_description)

# The two returned sequences are index-aligned: one cosine similarity per image path.
for img_path, similarity in zip(top_k_images, top_k_similarities):
    print(f"Image: {img_path}, Similarity: {similarity}")

(512,)
(8126, 512)
Image: ./WhatsApp-Stickers/STK-20210408-WA0050.webp, Similarity: 0.31752158451610424
Image: ./WhatsApp-Stickers/STK-20240709-WA0014.webp, Similarity: 0.30731167439716567
Image: ./WhatsApp-Stickers/STK-20240419-WA0013.webp, Similarity: 0.27728643630470473
Image: ./WhatsApp-Stickers/STK-20240629-WA0004.webp, Similarity: 0.27726663677550034
Image: ./WhatsApp-Stickers/STK-20230911-WA0002.webp, Similarity: 0.2692796310062471
Image: ./WhatsApp-Stickers/STK-20221005-WA0002.webp, Similarity: 0.26872216333119603
Image: ./WhatsApp-Stickers/STK-20201123-WA0062.webp, Similarity: 0.2684034479995195
Image: ./WhatsApp-Stickers/STK-20210702-WA0004.webp, Similarity: 0.2674686101165349
Image: ./WhatsApp-Stickers/STK-20210113-WA0012.webp, Similarity: 0.2663539763018652
Image: ./WhatsApp-Stickers/STK-20220425-WA0022.webp, Similarity: 0.26552057978475896
Image: ./WhatsApp-Stickers/STK-20210522-WA0001.webp, Similarity: 0.26482932589941727
Image: ./WhatsApp-Stickers/STK-20211214-WA0027.web

In [50]:
# Target image path: the query image whose nearest stored neighbours we want.
# If this file is itself in the store it is expected to come back as the first hit.
image_path = "./WhatsApp-Stickers/STK-20240922-WA0058.webp"
top_k_images, top_k_similarities = rag_pipeline.retrieve_top_k_with_image(image_path)

# The two returned sequences are index-aligned: one cosine similarity per image path.
for img_path, similarity in zip(top_k_images, top_k_similarities):
    print(f"Image: {img_path}, Similarity: {similarity}")

Image: ./WhatsApp-Stickers/STK-20240922-WA0058.webp, Similarity: 1.0000000000000013
Image: ./WhatsApp-Stickers/STK-20210408-WA0002.webp, Similarity: 0.991802594534197
Image: ./WhatsApp-Stickers/STK-20210411-WA0007.webp, Similarity: 0.7937019988029289
Image: ./WhatsApp-Stickers/STK-20240912-WA0037.webp, Similarity: 0.7806971058890797
Image: ./WhatsApp-Stickers/STK-20210424-WA0054.webp, Similarity: 0.7701733408645601
Image: ./WhatsApp-Stickers/STK-20211118-WA0009.webp, Similarity: 0.7595438039673797
Image: ./WhatsApp-Stickers/STK-20210415-WA0013.webp, Similarity: 0.7544395148228715
Image: ./WhatsApp-Stickers/STK-20210408-WA0028.webp, Similarity: 0.7524808492695164
Image: ./WhatsApp-Stickers/STK-20210411-WA0005.webp, Similarity: 0.7489389823805476
Image: ./WhatsApp-Stickers/STK-20210411-WA0006.webp, Similarity: 0.7458556883054313
Image: ./WhatsApp-Stickers/STK-20210112-WA0035.webp, Similarity: 0.7355798537040582
Image: ./WhatsApp-Stickers/STK-20210415-WA0017.webp, Similarity: 0.7350884189