# Zero-Shot CLIP Inference with OpenVINO

This notebook performs image- and text-based image search using a zero-shot CLIP model optimized with OpenVINO.

In [None]:
import os
import numpy as np
import pickle
from pathlib import Path
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
from openvino.runtime import Core
import faiss
import matplotlib.pyplot as plt

## Configuration

In [None]:
# Locations of the exported OpenVINO IR files and of the prebuilt FAISS store.
MODEL_DIR = Path(r"e:\Projects\AI Based\RecTrio\V2\models")
VECTOR_DB_DIR = Path(r"e:\Projects\AI Based\RecTrio\V2\vector_db")

# Hugging Face checkpoint used for the processor and the projection weights.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

# Individual artifact paths, derived from the two directories above.
VISION_MODEL_PATH = MODEL_DIR.joinpath("clip_vision_model.xml")
TEXT_MODEL_PATH = MODEL_DIR.joinpath("clip_text_model.xml")
FAISS_INDEX_PATH = VECTOR_DB_DIR.joinpath("faiss_index.bin")
METADATA_PATH = VECTOR_DB_DIR.joinpath("metadata.pkl")

# Default number of nearest neighbours requested per search.
TOP_K = 10

## Load Processor and Model Components

In [None]:
print("Loading CLIP processor...")
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

print("Loading CLIP model for projections...")
model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
model.eval()

# Only the two projection heads are taken from the PyTorch model. They map the
# encoder features returned by the OpenVINO models (presumably the pooled
# outputs — confirm against the export script) into the shared embedding
# space. The weights are transposed so callers can compute `features @ proj`.
visual_projection, text_projection = (
    head.weight.detach().numpy().T
    for head in (model.visual_projection, model.text_projection)
)

print("Processor and projections loaded successfully")

## Load OpenVINO Models

In [None]:
core = Core()
# Target device for both compiled encoders.
ov_device = "CPU"

# Compile each exported encoder and keep handles to its ports, so the embedding
# functions can index the inference result by output port.
print("Loading vision model...")
vision_compiled_model = core.compile_model(str(VISION_MODEL_PATH), ov_device)
vision_input_layer, vision_output_layer = (
    vision_compiled_model.input(0),
    vision_compiled_model.output(0),
)

print("Loading text model...")
text_compiled_model = core.compile_model(str(TEXT_MODEL_PATH), ov_device)
text_output_layer = text_compiled_model.output(0)

print("OpenVINO models loaded on CPU")

## Load FAISS Index and Metadata

In [None]:
print("Loading FAISS index...")
index = faiss.read_index(str(FAISS_INDEX_PATH))
print(f"Loaded index with {index.ntotal} vectors")

print("Loading metadata...")
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only load metadata produced by a trusted indexing run.
with open(METADATA_PATH, 'rb') as f:
    metadata = pickle.load(f)

# search_similar_images maps FAISS row ids straight to image_paths[idx], so the
# two collections are expected to be aligned one-to-one.
image_paths = metadata['image_paths']
print(f"Loaded {len(image_paths)} image paths")

# A size mismatch means the index and metadata are out of sync: hits would be
# attributed to the wrong files or raise IndexError during search.
if len(image_paths) != index.ntotal:
    print(
        f"WARNING: index has {index.ntotal} vectors but metadata lists "
        f"{len(image_paths)} image paths; search results may be wrong."
    )

## Embedding Functions

In [None]:
def get_image_embedding(image_input):
    """Embed an image with the OpenVINO CLIP vision encoder.

    Args:
        image_input: Path to an image file (``str`` or any ``os.PathLike``,
            e.g. ``pathlib.Path``), or an already-opened PIL image.

    Returns:
        A float32 NumPy embedding: the encoder output projected with
        ``visual_projection`` and divided by its norm.
    """
    if isinstance(image_input, (str, os.PathLike)):
        # Close the file handle once decoded; convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(image_input) as img:
            image = img.convert('RGB')
    else:
        image = image_input.convert('RGB')

    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs['pixel_values'].numpy()

    result = vision_compiled_model([pixel_values])[vision_output_layer]
    # Drop the batch axis: the batch contains only this one image.
    pooled_output = result[0]

    # Project into the shared CLIP space, then normalise so scores are
    # comparable across queries.
    embedding = np.dot(pooled_output, visual_projection)
    embedding = embedding / np.linalg.norm(embedding)

    return embedding.astype('float32')

def get_text_embedding(text):
    """Embed a text query with the OpenVINO CLIP text encoder.

    Args:
        text: The query string.

    Returns:
        A float32 NumPy embedding: the encoder output projected with
        ``text_projection`` and divided by its norm.
    """
    # truncation=True clips over-long prompts to the tokenizer's maximum length
    # (model_max_length; 77 tokens for the OpenAI CLIP checkpoints). Without it,
    # long queries produce sequences the text encoder cannot accept.
    inputs = processor(
        text=[text], return_tensors="pt", padding=True, truncation=True
    )
    input_ids = inputs['input_ids'].numpy()
    attention_mask = inputs['attention_mask'].numpy()

    result = text_compiled_model(
        {'input_ids': input_ids, 'attention_mask': attention_mask}
    )[text_output_layer]
    # Drop the batch axis: the batch contains only this one query.
    pooled_output = result[0]

    # Project into the shared CLIP space, then normalise.
    embedding = np.dot(pooled_output, text_projection)
    embedding = embedding / np.linalg.norm(embedding)

    return embedding.astype('float32')

## Search Function

In [None]:
def search_similar_images(query_embedding, top_k=TOP_K):
    """Return up to ``top_k`` indexed images nearest to ``query_embedding``.

    Args:
        query_embedding: Query vector, e.g. from get_image_embedding or
            get_text_embedding.
        top_k: Maximum number of neighbours to request from FAISS.

    Returns:
        A list of ``{'path': ..., 'similarity': float}`` dicts in the order
        FAISS returns them. ``similarity`` is the raw FAISS score, whose meaning
        depends on the index metric (e.g. inner product vs. L2 distance —
        confirm against how the index was built). The list is shorter than
        ``top_k`` if the index holds fewer vectors.
    """
    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim).
    query = np.ascontiguousarray(query_embedding, dtype='float32').reshape(1, -1)

    distances, indices = index.search(query, top_k)

    # FAISS pads missing neighbours with id -1. Without this filter,
    # image_paths[-1] would silently return the last image as a "hit".
    return [
        {'path': image_paths[idx], 'similarity': float(dist)}
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]

## Visualization Function

In [None]:
def display_results(results, query_info=None):
    """Show result images in a 5-column grid, each titled with its score.

    Args:
        results: List of ``{'path', 'similarity'}`` dicts as returned by
            search_similar_images.
        query_info: Optional text shown as the figure's super-title.
    """
    n_results = len(results)
    if n_results == 0:
        # plt.subplots rejects a zero-row grid, and there is nothing to draw.
        print("No results to display.")
        return

    cols = 5
    rows = (n_results + cols - 1) // cols

    # squeeze=False always yields a 2-D array of Axes, so a single result
    # (one row) is handled exactly like many results.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows), squeeze=False)
    axes = axes.ravel()

    for ax, result in zip(axes, results):
        # imshow copies the pixel data, so the file can be closed right after.
        with Image.open(result['path']) as img:
            ax.imshow(img)
        ax.axis('off')
        ax.set_title(f"Similarity: {result['similarity']:.4f}", fontsize=10)

    # Hide the unused grid cells in the last row.
    for ax in axes[n_results:]:
        ax.axis('off')

    if query_info:
        fig.suptitle(f"Query: {query_info}", fontsize=14, y=1.00)

    plt.tight_layout()
    plt.show()

## Image-Based Search

In [None]:
query_image_path = r"e:\Projects\AI Based\RecTrio\V1\datasets\animals\raw-img\cat\1.jpeg"

# Embed the example image and retrieve its nearest neighbours from the index.
print(f"Searching for images similar to: {query_image_path}")
query_embedding = get_image_embedding(query_image_path)
results = search_similar_images(query_embedding)

print(f"\nTop {len(results)} similar images:")
for rank, hit in enumerate(results, start=1):
    print(f"{rank}. {hit['path']} (Similarity: {hit['similarity']:.4f})")

query_label = "Image: " + Path(query_image_path).name
display_results(results, query_info=query_label)

## Text-Based Search

In [None]:
query_text = "a photo of a cat"

# Embed the text prompt and retrieve the images closest to it.
print(f"Searching for: '{query_text}'")
query_embedding = get_text_embedding(query_text)
results = search_similar_images(query_embedding)

print(f"\nTop {len(results)} similar images:")
for rank, hit in enumerate(results, start=1):
    print(f"{rank}. {hit['path']} (Similarity: {hit['similarity']:.4f})")

display_results(results, query_info="Text: '" + query_text + "'")

## Interactive Search

In [None]:
def _report_results(results, query_info):
    """Print a numbered list of search results, then show them as a grid."""
    print(f"\nTop {len(results)} similar images:")
    for i, result in enumerate(results, 1):
        print(f"{i}. {result['path']} (Similarity: {result['similarity']:.4f})")

    display_results(results, query_info=query_info)


def interactive_search():
    """Prompt on stdin for a query and run an image or text search.

    The user enters "1" to search by an image file or "2" to search by text.
    Results are printed and displayed. Invalid input prints a message and
    returns without searching.
    """
    print("Choose search type:")
    print("1. Image search")
    print("2. Text search")

    choice = input("Enter choice (1 or 2): ").strip()

    if choice == "1":
        # Paths pasted from a file manager or shell are often wrapped in
        # quotes (e.g. Windows "Copy as path"); remove them.
        image_path = input("Enter image path: ").strip().strip('"\'')
        # isfile rather than exists: a directory would pass exists() and
        # then fail inside Image.open.
        if not os.path.isfile(image_path):
            print("Image not found!")
            return

        print(f"\nSearching for images similar to: {image_path}")
        query_embedding = get_image_embedding(image_path)
        results = search_similar_images(query_embedding)
        _report_results(results, f"Image: {Path(image_path).name}")

    elif choice == "2":
        text_query = input("Enter search text: ").strip()
        if not text_query:
            print("Search text cannot be empty!")
            return

        print(f"\nSearching for: '{text_query}'")
        query_embedding = get_text_embedding(text_query)
        results = search_similar_images(query_embedding)
        _report_results(results, f"Text: '{text_query}'")
    else:
        print("Invalid choice!")


interactive_search()