In [17]:
!python --version

Python 3.12.3


In [19]:
!pip install torch torchvision transformers faiss-cpu pillow




In [21]:
import os
import torch
import faiss
import numpy as np
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt 

In [None]:
#Loading of the Contrastive Language-Image Model 

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e

In [None]:

IMAGE_DIR = r"C:\Users\phiwe\Downloads\test_data_v2"

def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    return processor(images=image, return_tensors="pt")["pixel_values"].to(device)


image_paths = [os.path.join(IMAGE_DIR, img) for img in os.listdir(IMAGE_DIR) if img.endswith(('png', 'jpg', 'jpeg'))]
image_paths = image_paths[:500] 
image_embeddings = []

with torch.no_grad():
    for img_path in image_paths:
        img_tensor = preprocess_image(img_path)
        img_embedding = model.get_image_features(img_tensor)
        image_embeddings.append(img_embedding.cpu().numpy())

image_embeddings = np.vstack(image_embeddings) 


In [None]:

dimension = image_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(image_embeddings) 


In [None]:
import pickle


with open("image_embeddings.pkl", "wb") as f:
    pickle.dump((image_embeddings, image_paths), f)


with open("image_embeddings.pkl", "rb") as f:
    image_embeddings, image_paths = pickle.load(f)

index = faiss.IndexFlatL2(image_embeddings.shape[1])
index.add(image_embeddings)

In [None]:

index = faiss.IndexFlatL2(image_embeddings.shape[1])
index.add(image_embeddings.astype(np.float32))


In [None]:
def search_images(query, top_k=5):
   
    inputs = processor(text=[query], return_tensors="pt").to(device)
    with torch.no_grad():
        query_embedding = model.get_text_features(**inputs).cpu().numpy()
    
    # Search in FAISS
    distances, indices = index.search(query_embedding, top_k)
    
    fig, axes = plt.subplots(1, top_k, figsize=(15, 5))
    for i, idx in enumerate(indices[0]):
        img = Image.open(image_paths[idx])
        axes[i].imshow(img)
        axes[i].axis("off")
        axes[i].set_title(f"Rank {i+1}")
    
    plt.show()



In [None]:
import ipywidgets as widgets
from IPython.display import display

query_input = widgets.Text(
    placeholder="Enter an image description...",
    layout=widgets.Layout(width='50%')
)


search_button = widgets.Button(description="Search")

output = widgets.Output()

def on_search_clicked(b):
    with output:
        output.clear_output()
        search_images(query_input.value, top_k=5)


search_button.on_click(on_search_clicked)


display(query_input, search_button, output)


Text(value='', layout=Layout(width='50%'), placeholder='Enter an image description...')

Button(description='Search', style=ButtonStyle())

Output()