In [12]:
# Import required libraries
import json
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
import faiss
from io import BytesIO
import requests
import logging

In [3]:
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [4]:
# Load CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [5]:
#  Upload Product Catalog and remove NaN rows
products = pd.read_csv('/content/drive/MyDrive/WizCommerce_Assignment/product-catalog.csv')
products.dropna(subset=['description', 'primary_image'], inplace = True)
print("Number of Products in the Catalog: {}".format(len(products)))
products.head()

Number of Products in the Catalog: 499


Unnamed: 0,id,sku_id,description,primary_image
0,100187e3-9203-4bea-8753-91172429f571,00585-01,8X7X3 RSN TRI GEO ELEPHANT 2/A BLK/WHT,https://fileserver-g-p.sourcerer.tech/files/f0...
2,74a475fc-96a2-4e54-a739-466939b9bb4c,10483,"GARDEN STOOL, CERAMIC, PEARL SNAKE",https://fileserver-g-p.sourcerer.tech/files/7b...
3,f9b3eee8-8232-4709-a63c-4b8fe6a434ef,10483D,"CERAMIC GARDEN STOOL, PEARL SNAKE",https://fileserver-g-p.sourcerer.tech/files/c2...
4,e6294cfe-22c5-4827-b7b2-00d62302631b,10484,COVERED JAR CERAMIC - PEARL SNAKE,https://fileserver-g-p.sourcerer.tech/files/21...
5,d42216cb-7b46-43d8-b6a6-467c1e4a4a9c,10484D,"COVERED JAR CERAMIC, PEARL SNAKE",https://fileserver-g-p.sourcerer.tech/files/47...


In [6]:
# Initialize FAISS index
dimension = 512  # CLIP embedding dimension
index = faiss.IndexFlatIP(dimension)  # Inner product (cosine similarity)

In [7]:
# Generate product embeddings
product_image_embeddings = []
for url in products['primary_image']:
  response = requests.get(url)
  product_image = Image.open(BytesIO(response.content))

  inputs = processor(images = product_image, return_tensors = "pt")
  with torch.no_grad():
    product_img_embedding = model.get_image_features(**inputs).numpy()
    product_image_embeddings.append(product_img_embedding)

    
# Process descriptions and generate embeddings
descriptions = products['description'].tolist()
inputs = processor(text = descriptions, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
  text_embeddings = model.get_text_features(**inputs).numpy()

# Combine embeddings (average of image and text)
combined_embeddings = [(img + txt) / 2 for img, txt in zip(product_image_embeddings, text_embeddings)]
combined_embeddings = np.array(combined_embeddings).astype('float32')
combined_embeddings_reshape = combined_embeddings.reshape(combined_embeddings.shape[0], -1)

# Add to FAISS index
faiss.normalize_L2(combined_embeddings_reshape)   #Normalize the embedding vector to get scores between 0-1
index.add(combined_embeddings_reshape)
logger.info(f"Processed {len(products)} products into FAISS index")

In [10]:
# Generate scene embedding and match products to a scene image to find how aptly they fit the scene
def match_scene(scene_image_path, top_k):
  scene_image = Image.open(scene_image_path)
  inputs = processor(images=scene_image, return_tensors="pt")
  with torch.no_grad():
    scene_embedding = model.get_image_features(**inputs).numpy().astype('float32')

  faiss.normalize_L2(scene_embedding)   ##Normalize the embedding vector to get scores between 0-1

  # Search FAISS index
  scores, indices = index.search(scene_embedding, top_k)

  results = []
  for score, idx in zip(scores[0], indices[0]):
    results.append({
        "id": products['id'][idx],
        "score": float(score),
        "image_url": products['primary_image'][idx],
        "description": products['description'][idx]
        })

  return {
      "scene": scene_image_path,
      "matches": results
      }

In [13]:
if __name__ == "__main__":
  scene_image_path = '/content/drive/MyDrive/WizCommerce_Assignment/scene-1.jpg'
  result = match_scene(scene_image_path, top_k=3)
  print(json.dumps(result, indent=2))

{
  "scene": "/content/drive/MyDrive/WizCommerce_Assignment/scene-1.jpg",
  "matches": [
    {
      "id": "9d6223a2-85b5-457a-8d1c-88be0eb6b4d2",
      "score": 0.6211128830909729,
      "image_url": "https://fileserver-g-p.sourcerer.tech/files/33666e04-0ebe-4398-87c3-9e4a12009af3",
      "description": "METAL CIRCLE SCULPTURE MIRROR,GOLD"
    },
    {
      "id": "c4175cae-af96-4f7d-a653-29f89c9efff5",
      "score": 0.6160077452659607,
      "image_url": "https://fileserver-g-p.sourcerer.tech/files/9bfa2706-8995-4f3d-bf40-e7ecaf7d1ef7",
      "description": "METAL GEO FRAME MIRROR, GOLD,WINDOW BOX"
    },
    {
      "id": "4dcb8874-8692-4082-987a-17c5f771d4af",
      "score": 0.6049721837043762,
      "image_url": "https://fileserver-g-p.sourcerer.tech/files/498d765f-c0a8-4c74-896a-2373aae5ba9a",
      "description": "RUSTIC WOOD CHEST, NATURAL"
    }
  ]
}
