In [46]:
import os
import warnings

import numpy as np
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPModel, CLIPProcessor

# Load the pretrained CLIP checkpoint once at module level; the feature
# helpers in this notebook reuse these shared objects rather than reloading
# them on every call.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Function to extract text features from the query text
def extract_text_features(query_text):
    """Return the CLIP text embedding(s) for *query_text*.

    Args:
        query_text: A string, or a list of strings (padded to a common length).
            Inputs longer than the tokenizer's maximum length are truncated
            instead of raising.

    Returns:
        The projected, un-normalised text features produced by
        ``model.get_text_features`` for the tokenised input.
    """
    inputs = processor(text=query_text, return_tensors="pt", padding=True, truncation=True)
    # Inference only: without no_grad, autograd would record the whole text
    # encoder forward pass and keep its activations alive with the result.
    with torch.no_grad():
        return model.get_text_features(**inputs)

# Function to extract image features from meme images
def extract_image_features(meme_image, query_text):
    """Return the L2-normalised CLIP embedding of *meme_image*.

    Args:
        meme_image: An image the CLIP processor accepts (e.g. a PIL image).
        query_text: Unused. Kept so existing callers keep working; CLIP's image
            tower does not depend on any text prompt.

    Returns:
        The projected image embedding divided by its L2 norm, i.e. the same
        values as the ``image_embeds`` field of a full ``CLIPModel`` forward pass.
    """
    # Only pixel values are needed, so the text encoder is not run at all.
    inputs = processor(images=meme_image, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        image_features = model.get_image_features(**inputs)
    # CLIPModel.forward normalises image_embeds like this; doing the same keeps
    # the returned values identical to what callers received before.
    return image_features / image_features.norm(p=2, dim=-1, keepdim=True)

# Image file extensions (compared case-insensitively) treated as memes.
_MEME_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")


def _load_meme_image_features(meme_folder, query_text):
    """Embed every readable meme image directly inside *meme_folder*.

    Files are visited in sorted filename order so results are reproducible.
    Files that cannot be opened as images are skipped with a warning.

    Returns:
        ``(filenames, features)`` where *features* has one row per filename,
        or ``([], None)`` when no image could be read.
    """
    meme_files = []
    feature_rows = []
    for filename in sorted(os.listdir(meme_folder)):
        if not filename.lower().endswith(_MEME_EXTENSIONS):
            continue
        meme_path = os.path.join(meme_folder, filename)
        try:
            # copy() forces a full decode and detaches the pixels from the
            # file, so the handle is closed as soon as the with-block exits.
            with Image.open(meme_path) as img:
                meme_image = img.copy()
        except OSError as err:  # also covers PIL's UnidentifiedImageError
            warnings.warn(f"Skipping unreadable image {meme_path!r}: {err}")
            continue
        feature_rows.append(extract_image_features(meme_image, query_text))
        meme_files.append(filename)

    if not feature_rows:
        return [], None
    # Assumes each row is a (1, dim) tensor (single-image batch), so
    # concatenating along dim 0 yields (num_memes, dim).
    return meme_files, torch.cat(feature_rows, dim=0)


# Function to query memes based on text similarity
def query_memes_by_text(query_text, meme_folder, top_n=5, similarity_weight=0.5, query_image=None):
    """Rank the meme images in *meme_folder* by CLIP similarity to *query_text*.

    Args:
        query_text: A single text prompt to match against each meme.
        meme_folder: Directory scanned (non-recursively) for .png/.jpg/.jpeg/.webp files.
        top_n: Maximum number of results; ``0`` or less returns ``[]``.
        similarity_weight: Weight of the text-to-image score in the final score;
            ``1 - similarity_weight`` weights the image-to-image score. Only has an
            effect when *query_image* is given.
        query_image: Optional image; when given, each meme is also scored by
            cosine similarity to this image's CLIP embedding.

    Returns:
        A list of ``(filename, score)`` tuples, highest score first, with at most
        *top_n* entries. Empty when the folder holds no readable images.
    """
    if top_n <= 0:
        return []

    meme_files, image_features = _load_meme_image_features(meme_folder, query_text)
    if not meme_files:
        return []
    meme_matrix = image_features.detach().numpy()

    # Text-to-image cosine similarity: one score per meme.
    query_text_features = extract_text_features(query_text)
    text_similarities = cosine_similarity(
        query_text_features.detach().numpy(), meme_matrix
    ).flatten()

    if query_image is None:
        # Nothing to blend with, so the text score is the final score.
        combined_similarities = text_similarities
    else:
        # Image-to-image cosine similarity against the query image.
        query_image_features = extract_image_features(query_image, query_text)
        image_similarities = cosine_similarity(
            query_image_features.detach().numpy(), meme_matrix
        ).flatten()
        combined_similarities = (
            similarity_weight * text_similarities
            + (1 - similarity_weight) * image_similarities
        )

    # argsort is ascending: reverse it, then keep the best top_n.
    top_indices = np.argsort(combined_similarities)[::-1][:top_n]
    return [(meme_files[i], combined_similarities[i]) for i in top_indices]

# Example usage
query_text = "Cat"  # Example query text
# Folder containing the meme images. The default is machine-specific; set the
# MEME_FOLDER environment variable to search elsewhere without editing code.
meme_folder = os.environ.get(
    "MEME_FOLDER",
    "C:/Users/mahed/OneDrive/Desktop/The Meme Files-20241214T145346Z-001/The Meme Files",
)

# Get top matching memes based on text similarity
image_results = query_memes_by_text(query_text, meme_folder)

# Print the top memes based on the query text
print("Top memes based on text query:")
if not image_results:
    print("No memes found.")
for image, score in image_results:
    print(f"Image: {image}, Similarity: {score:.4f}")


Top memes based on text query:
Image: 14 - 1(2).png, Similarity: 0.2910
Image: 912488_170x100.png, Similarity: 0.2811
Image: miVWSLVYCZoeUs6hoaSGXgbq_tEqpOiHMi7WnOQSEcdJ3npMk-djfRe4qKv4etX2aikXt9U64xzDrUdEr9IqXH_0JiPCuITESiBo2GMf=w346-h490.jpg, Similarity: 0.2786
Image: MvOXQLkhcWM2ytqq27txg_U9JBUeQsDlGJX8MnarmIDB9y4XHuIzDkdtf9tOY52_zOehyXZm3Wjt5pSjy8eG74nOOlgM-phxFULh8zWlVaUp_H99=w346-h468.jpg, Similarity: 0.2744
Image: images-17.jpeg, Similarity: 0.2737
