<a href="https://colab.research.google.com/github/Yash-Kavaiya/AI-Accelerate-Retail-Agent-Teams/blob/main/product_image_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install git+https://github.com/facebookresearch/ImageBind.git
!pip install elasticsearch pandas pillow requests


Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting git+https://github.com/facebookresearch/ImageBind.git
  Cloning https://github.com/facebookresearch/ImageBind.git to /tmp/pip-req-build-th8ylt9o
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/ImageBind.git /tmp/pip-req-build-th8ylt9o
  Resolved https://github.com/facebookresearch/ImageBind.git to commit cfee5753e8e0b7a9df8aeb19c493e8facd81a5f2
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pytorchvideo@ git+https://github.com/facebookresearch/pytorchvideo.git@6cdc929315aab1b5674b6dcf73b16ec99147735f (from imagebind==0.1.0)
  Cloning https://github.com/facebookresearch/pytorchvideo.git (to revision 6cdc929315aab1b5674b6dcf73b16ec99147735f) to /tmp/pip-install-yjdw7e65/pytorchvideo_f81c7aeb30a348768f150bd90f89a8d8
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/pytorchvideo.git /tmp/pip-install-yjdw7e65/pytorchvideo_

In [None]:
import torch
import pandas as pd
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from tqdm import tqdm
import os

# Configuration
# NOTE(security): credentials belong in the environment (e.g. Colab secrets),
# not in source control. Each setting can be overridden with an environment
# variable of the same name; the literals are kept only as backward-compatible
# fallbacks. The fallback API key has been published with this notebook and
# should be rotated.
# Despite its name, ES_CLOUD_ID holds a "host:port" endpoint (it is prefixed
# with "https://" below), not an Elastic Cloud ID.
ES_CLOUD_ID = os.environ.get(
    "ES_CLOUD_ID",
    "my-elasticsearch-project-e0ae1f.es.us-central1.gcp.elastic.cloud:443",
)
ES_API_KEY = os.environ.get(
    "ES_API_KEY",
    "cTBuSkVKb0JDNDR2WWhEajVsWjQ6UkhPNEtaZkJTa2VoWTEtVmxtbW8wUQ==",
)
INDEX_NAME = "imagebind-embeddings"
BATCH_SIZE = 16  # Reduce if you hit OOM errors
MAX_RETRIES = 3

# Check GPU availability
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Initialize Elasticsearch client
client = Elasticsearch(
    f"https://{ES_CLOUD_ID}",
    api_key=ES_API_KEY,
    request_timeout=120,
)

# Test connection; fail fast so no time is spent loading the model when the
# cluster is unreachable.
if client.ping():
    print("Successfully connected to Elasticsearch!")
else:
    # ConnectionError is a subclass of Exception, so existing handlers still match.
    raise ConnectionError("Could not connect to Elasticsearch")

# Load ImageBind model in inference mode on the selected device
print("Loading ImageBind model...")
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)
print("Model loaded successfully!")

def create_index():
    """Drop any existing ``INDEX_NAME`` index and create it afresh.

    The mapping declares the product metadata fields plus an indexed,
    1024-dimensional ``dense_vector`` field compared with cosine similarity.
    """
    # (field name, Elasticsearch type) in the order they appear in the mapping.
    metadata_fields = [
        ("filename", "keyword"),
        ("image_url", "text"),
        ("gender", "keyword"),
        ("masterCategory", "keyword"),
        ("subCategory", "keyword"),
        ("articleType", "keyword"),
        ("baseColour", "keyword"),
        ("season", "keyword"),
        ("year", "integer"),
        ("usage", "keyword"),
        ("productDisplayName", "text"),
    ]
    properties = {name: {"type": es_type} for name, es_type in metadata_fields}
    properties["image_embedding"] = {
        "type": "dense_vector",
        "dims": 1024,
        "index": True,
        "similarity": "cosine",
    }
    index_body = {"mappings": {"properties": properties}}

    already_exists = client.indices.exists(index=INDEX_NAME)
    if already_exists:
        client.indices.delete(index=INDEX_NAME)
        print(f"Deleted existing index: {INDEX_NAME}")

    client.indices.create(index=INDEX_NAME, body=index_body)
    print(f"Created index: {INDEX_NAME}")

def load_images_csv(path):
    """Load the image list at *path* as a DataFrame with ``filename`` and ``link``.

    The file is first read as a regular comma-separated CSV. If that does not
    yield a non-empty frame containing both columns, the file is re-read as a
    "pair-per-row" listing in which each filename line is followed by its link
    line (an optional single header line is skipped, and a trailing unpaired
    line is dropped).

    Args:
        path: Path of the CSV/listing file.

    Returns:
        A DataFrame with at least the columns ``filename`` and ``link``;
        empty (but with both columns) when the file holds no usable lines.

    Raises:
        OSError: If the file cannot be opened.
    """
    columns = ['filename', 'link']

    try:
        df = pd.read_csv(path, delimiter=',', on_bad_lines='skip')
    except ValueError:
        # pandas' ParserError/EmptyDataError (and UnicodeDecodeError) are all
        # ValueError subclasses; treat them as "not a regular CSV" and fall
        # through to the recovery path instead of hiding every exception.
        df = None
    if df is not None and not df.empty and set(columns).issubset(df.columns):
        return df

    # Recovery for pair-per-row format
    with open(path, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]
    if not lines:
        # An empty/blank file previously crashed on lines[0].
        return pd.DataFrame(columns=columns)

    # Treat the first line as a header unless it already looks like data
    # (an image filename or a URL).
    first = lines[0]
    start_idx = 0 if first.endswith(".jpg") or first.startswith("http") else 1
    payload = lines[start_idx:]
    # Even positions are filenames, odd positions their links; zip() drops a
    # trailing filename that has no link.
    fixed_rows = [list(pair) for pair in zip(payload[0::2], payload[1::2])]
    return pd.DataFrame(fixed_rows, columns=columns)

def load_styles_csv(path):
    """Read the comma-separated styles file at *path*, skipping malformed rows."""
    styles = pd.read_csv(path, sep=',', on_bad_lines='skip')
    return styles

def load_data(images_path='/content/images.csv', styles_path='/content/styles.csv'):
    """Load the image list and style metadata and join them on a numeric ``id``.

    The ``id`` of an image is its filename with the literal ``.jpg`` removed.
    Rows whose filename does not reduce to a number are skipped (with a
    message) instead of aborting the whole load.

    Args:
        images_path: Path of the image list (``filename``/``link``).
        styles_path: Path of the styles CSV; it must contain an ``id`` column.

    Returns:
        The inner join of both tables on ``id``.
    """
    print("Loading CSV files...")
    images_df = load_images_csv(images_path)
    styles_df = load_styles_csv(styles_path)

    # regex=False: with the regex default of older pandas versions the "." in
    # ".jpg" would match any character.
    stems = images_df['filename'].astype(str).str.replace('.jpg', '', regex=False)
    ids = pd.to_numeric(stems, errors='coerce')
    unusable = ids.isna()
    if unusable.any():
        print(f"Skipping {int(unusable.sum())} rows with non-numeric filenames")

    images_df = images_df.loc[~unusable].copy()
    images_df['id'] = ids[~unusable].astype(int)

    merged_df = images_df.merge(styles_df, on='id', how='inner')
    print(f"Loaded {len(merged_df)} records")
    return merged_df

def download_image(url, retries=MAX_RETRIES):
    """Download *url* and return it as an RGB PIL image, or ``None`` on failure.

    Args:
        url: HTTP(S) address of the image. Anything else (a missing/NaN cell,
            a placeholder such as ``"undefined"``) is rejected without any
            network request.
        retries: Maximum number of download attempts.

    Returns:
        The decoded image converted to RGB, or ``None`` when the URL is
        invalid or every attempt failed.
    """
    # A non-string URL used to crash while formatting the failure message, and
    # a schemeless string can never succeed, so retrying it is pointless.
    if not isinstance(url, str) or not url.strip().lower().startswith(("http://", "https://")):
        print(f"Skipping invalid image URL: {str(url)[:50]}...")
        return None

    last_error = None
    for _ in range(retries):
        try:
            response = requests.get(url, timeout=15)
            if response.status_code == 200:
                return Image.open(BytesIO(response.content)).convert('RGB')
            last_error = f"HTTP {response.status_code}"
        except Exception as e:
            # Deliberate best-effort boundary: network and image-decoding
            # errors both just trigger the next attempt.
            last_error = e

    if retries > 0:
        # Report every kind of final failure (including non-200 responses)
        # together with its cause.
        print(f"Failed after {retries} attempts: {url[:50]}... ({last_error})")
    return None

def generate_embeddings_batch(image_urls, batch_indices):
    """Download a batch of images and compute their ImageBind vision embeddings.

    Each downloaded image is written to a private temporary directory because
    the vision loader is given file paths; the directory and its contents are
    removed when the function returns, on success or failure.

    Args:
        image_urls: URLs of the images in this batch.
        batch_indices: One index per URL, used only to name the temporary files.

    Returns:
        ``(embeddings, valid_indices)`` where ``embeddings`` is a NumPy array
        with one row per usable image and ``valid_indices`` holds the
        positions (within ``image_urls``) those rows belong to; or
        ``(None, [])`` when no image could be used or the model call failed.
    """
    import tempfile  # local import: not part of the file-level import block

    image_paths = []
    valid_indices = []
    # A per-call directory avoids name clashes between concurrent runs and
    # guarantees cleanup even if saving or embedding raises.
    with tempfile.TemporaryDirectory(prefix="imagebind_") as tmp_dir:
        for idx, url in enumerate(image_urls):
            img = download_image(url)
            if img is None:
                continue
            temp_path = os.path.join(tmp_dir, f"temp_img_{batch_indices[idx]}.jpg")
            try:
                img.save(temp_path)
            except (OSError, ValueError) as e:
                # One unsaveable image should not sink the rest of the batch.
                print(f"Could not save image {str(url)[:50]}...: {e}")
                continue
            image_paths.append(temp_path)
            valid_indices.append(idx)

        if not image_paths:
            return None, []

        try:
            inputs = {
                ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device)
            }
            with torch.no_grad():
                embeddings = model(inputs)
                vision_embeddings = embeddings[ModalityType.VISION]
            embeddings_np = vision_embeddings.cpu().numpy()
            return embeddings_np, valid_indices
        except Exception as e:
            print(f"Error generating embeddings: {e}")
            return None, []
        finally:
            # Release cached GPU memory after every batch, including failed ones.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

def _clean_value(value, default=''):
    """Return *default* for a missing (None/NaN) metadata value, else the value."""
    if value is None or pd.isna(value):
        return default
    return value


def _build_action(row, embedding):
    """Build the bulk-index action for one merged CSV row and its embedding."""
    year = row.get('year')
    return {
        "_index": INDEX_NAME,
        "_id": str(row['id']),
        "_source": {
            "filename": _clean_value(row['filename']),
            "image_url": _clean_value(row['link']),
            "gender": _clean_value(row.get('gender', '')),
            "masterCategory": _clean_value(row.get('masterCategory', '')),
            "subCategory": _clean_value(row.get('subCategory', '')),
            "articleType": _clean_value(row.get('articleType', '')),
            "baseColour": _clean_value(row.get('baseColour', '')),
            "season": _clean_value(row.get('season', '')),
            "year": int(year) if pd.notna(year) else 0,
            "usage": _clean_value(row.get('usage', '')),
            "productDisplayName": _clean_value(row.get('productDisplayName', '')),
            "image_embedding": embedding.tolist(),
        },
    }


def _flush_actions(actions, error_prefix):
    """Send *actions* with the bulk helper and return ``(indexed, failed)`` counts."""
    try:
        indexed, errors = bulk(client, actions, raise_on_error=False)
    except Exception as e:
        # Request-level failure: how many documents got through is unknown,
        # so the whole buffer is counted as failed.
        print(f"{error_prefix}: {e}")
        return 0, len(actions)
    if errors:
        print(f"Bulk index errors: {len(errors)}")
        # Show one rejected item so the cause can be diagnosed from the log.
        print(f"  First error: {str(errors[0])[:300]}")
    return indexed, len(errors)


def index_documents(df, bulk_size=50):
    """Embed every image referenced by *df* and index it into Elasticsearch.

    Rows are processed in slices of ``BATCH_SIZE``; the resulting documents are
    buffered and sent with the bulk helper once at least *bulk_size* are
    pending, plus one final flush. A summary of indexed and failed documents
    is printed at the end.

    Args:
        df: Merged frame; must provide ``id``, ``filename`` and ``link``.
            The remaining metadata columns are optional.
        bulk_size: Number of buffered documents that triggers a bulk request.
    """
    print("Generating embeddings and indexing to Elasticsearch...")
    actions = []
    successful = 0
    failed = 0
    total_batches = (len(df) + BATCH_SIZE - 1) // BATCH_SIZE

    for batch_idx in tqdm(range(0, len(df), BATCH_SIZE), total=total_batches):
        batch = df.iloc[batch_idx:batch_idx + BATCH_SIZE]
        batch_indices = list(range(batch_idx, batch_idx + len(batch)))
        embeddings, valid_indices = generate_embeddings_batch(batch['link'].tolist(), batch_indices)
        if embeddings is None:
            failed += len(batch)
            continue

        # Images of this batch that could not be downloaded count as failures.
        failed += len(batch) - len(valid_indices)
        for idx, embedding in zip(valid_indices, embeddings):
            # Missing CSV cells arrive as NaN; _build_action replaces them,
            # since NaN is not valid JSON and is a likely cause of rejected
            # documents.
            actions.append(_build_action(batch.iloc[idx], embedding))

        if len(actions) >= bulk_size:
            indexed, rejected = _flush_actions(actions, "Error during bulk indexing")
            successful += indexed
            failed += rejected
            actions = []

    if actions:
        indexed, rejected = _flush_actions(actions, "Error indexing final batch")
        successful += indexed
        failed += rejected

    # "successful" reflects documents Elasticsearch accepted, not merely queued.
    print(f"\n{'='*50}")
    print("Indexing completed!")
    print(f"Successfully indexed: {successful}")
    print(f"Failed: {failed}")
    print(f"{'='*50}")

def search_similar_images(query_text=None, query_image_path=None, top_k=10, filters=None):
    """Run a kNN search over the image embeddings and print the hits.

    The query is embedded with the ImageBind model, either from text or from
    an image file; when both are given, the text query is used.

    Args:
        query_text: Free-text query.
        query_image_path: Path of a query image, used when no text is given.
        top_k: Number of nearest neighbours to return.
        filters: Optional ``{field: value}`` mapping; every pair becomes a
            ``term`` clause that all hits must satisfy.

    Returns:
        The raw Elasticsearch search response.

    Raises:
        ValueError: If neither ``query_text`` nor ``query_image_path`` is given.
    """
    if query_text:
        modality = ModalityType.TEXT
        model_input = data.load_and_transform_text([query_text], device)
    elif query_image_path:
        modality = ModalityType.VISION
        model_input = data.load_and_transform_vision_data([query_image_path], device)
    else:
        raise ValueError("Provide either query_text or query_image_path")

    with torch.no_grad():
        embeddings = model({modality: model_input})
        query_embedding = embeddings[modality][0].cpu().numpy()

    knn_clause = {
        "field": "image_embedding",
        "query_vector": query_embedding.tolist(),
        "k": top_k,
        # num_candidates must not be smaller than k; a fixed 100 made every
        # request with top_k > 100 invalid.
        "num_candidates": max(100, top_k),
    }
    if filters:
        knn_clause["filter"] = {
            "bool": {"must": [{"term": {field: value}} for field, value in filters.items()]}
        }
    search_query = {
        "knn": knn_clause,
        "_source": ["filename", "image_url", "productDisplayName", "baseColour", "gender", "articleType"],
    }

    results = client.search(index=INDEX_NAME, body=search_query)
    hits = results['hits']['hits']

    print(f"\n{'='*60}")
    print(f"Search Query: {query_text if query_text else query_image_path}")
    if filters:
        print(f"Filters: {filters}")
    print(f"{'='*60}\n")
    print(f"Top {len(hits)} similar images:\n")
    for i, hit in enumerate(hits, 1):
        source = hit['_source']
        print(f"{i}. Score: {hit['_score']:.4f}")
        print(f"   Product: {source.get('productDisplayName', 'N/A')}")
        print(f"   Type: {source.get('articleType', 'N/A')} | Colour: {source.get('baseColour', 'N/A')}")
        # .get(): a document stored without image_url should not abort the listing.
        print(f"   URL: {str(source.get('image_url', ''))[:70]}...")
        print()
    return results

def main():
    """Rebuild the index, ingest every product image, then run two demo searches.

    Ends by printing the primary-shard document count reported by the index
    stats API.
    """
    banner = "=" * 60

    # --- Ingestion ---------------------------------------------------------
    create_index()
    products = load_data()
    # For a quick smoke test, limit the frame first: products = products.head(100)
    index_documents(products)
    # Make the freshly indexed documents visible to the searches below.
    client.indices.refresh(index=INDEX_NAME)

    # --- Example queries ---------------------------------------------------
    print("\n" + banner)
    print("EXAMPLE SEARCHES")
    print(banner)

    print("\n1. Text Search: 'blue shirt for men'")
    search_similar_images(query_text="blue shirt for men", top_k=5)

    print("\n2. Filtered Search: 'casual watch' for Women")
    search_similar_images(query_text="casual watch", top_k=5, filters={"gender": "Women"})

    # --- Summary -----------------------------------------------------------
    index_stats = client.indices.stats(index=INDEX_NAME)
    doc_count = index_stats['indices'][INDEX_NAME]['primaries']['docs']['count']
    print(f"\n{banner}")
    print(f"Total documents indexed: {doc_count}")
    print(banner)


if __name__ == "__main__":
    main()


Using device: cuda:0
GPU: Tesla T4
GPU Memory: 15.83 GB
Successfully connected to Elasticsearch!
Loading ImageBind model...
Model loaded successfully!
Deleted existing index: imagebind-embeddings
Created index: imagebind-embeddings
Loading CSV files...
Loaded 44424 records
Generating embeddings and indexing to Elasticsearch...


  0%|          | 8/2777 [02:04<11:53:58, 15.47s/it]

Bulk index errors: 2


  1%|          | 20/2777 [05:17<14:03:29, 18.36s/it]

Bulk index errors: 2


  1%|          | 32/2777 [08:37<13:13:58, 17.35s/it]

Bulk index errors: 2


  1%|▏         | 36/2777 [09:36<11:53:21, 15.62s/it]

Bulk index errors: 1


  2%|▏         | 44/2777 [11:36<11:19:48, 14.92s/it]

Bulk index errors: 1


  2%|▏         | 60/2777 [16:11<12:53:22, 17.08s/it]

Bulk index errors: 4


  3%|▎         | 76/2777 [20:40<11:57:45, 15.94s/it]

Bulk index errors: 2


  3%|▎         | 80/2777 [21:43<11:56:41, 15.94s/it]

Bulk index errors: 1


  3%|▎         | 84/2777 [22:46<12:05:05, 16.15s/it]

Bulk index errors: 1


  3%|▎         | 96/2777 [25:57<11:44:18, 15.76s/it]

Bulk index errors: 2


  4%|▎         | 100/2777 [27:10<13:46:36, 18.53s/it]

Bulk index errors: 1


  4%|▍         | 116/2777 [31:18<11:43:04, 15.85s/it]

Bulk index errors: 2


  5%|▍         | 132/2777 [35:37<12:02:43, 16.39s/it]

Bulk index errors: 1


  5%|▍         | 136/2777 [36:51<13:17:24, 18.12s/it]

Bulk index errors: 1


  6%|▌         | 164/2777 [44:54<13:25:30, 18.50s/it]

Bulk index errors: 3


  6%|▋         | 176/2777 [48:37<13:19:56, 18.45s/it]

Bulk index errors: 1


  7%|▋         | 204/2777 [57:49<14:38:59, 20.50s/it]

Bulk index errors: 1


  7%|▋         | 208/2777 [59:07<14:11:31, 19.89s/it]

Bulk index errors: 2


  8%|▊         | 220/2777 [1:02:33<12:05:35, 17.03s/it]

Bulk index errors: 3


  8%|▊         | 224/2777 [1:03:41<12:10:27, 17.17s/it]

Bulk index errors: 1


  8%|▊         | 228/2777 [1:04:49<12:31:54, 17.70s/it]

Bulk index errors: 1


  8%|▊         | 232/2777 [1:05:54<11:58:06, 16.93s/it]

Bulk index errors: 1


  8%|▊         | 236/2777 [1:06:56<10:52:11, 15.40s/it]

Bulk index errors: 1


  9%|▊         | 240/2777 [1:08:02<11:41:55, 16.60s/it]

Bulk index errors: 2


  9%|▉         | 244/2777 [1:09:07<11:15:39, 16.00s/it]

Bulk index errors: 1


  9%|▉         | 256/2777 [1:12:28<11:42:18, 16.71s/it]

Bulk index errors: 2


 10%|▉         | 264/2777 [1:14:46<11:45:19, 16.84s/it]

Bulk index errors: 1


 10%|▉         | 268/2777 [1:16:07<13:40:40, 19.63s/it]

Bulk index errors: 2


 11%|█         | 292/2777 [1:22:31<10:51:38, 15.73s/it]

Bulk index errors: 3


 11%|█         | 308/2777 [1:26:59<11:30:52, 16.79s/it]

Bulk index errors: 1


 12%|█▏        | 320/2777 [1:30:13<11:16:16, 16.51s/it]

Bulk index errors: 1


 12%|█▏        | 328/2777 [1:32:27<11:33:56, 17.00s/it]

Bulk index errors: 2


 12%|█▏        | 336/2777 [1:34:33<10:13:51, 15.09s/it]

Bulk index errors: 1


 13%|█▎        | 356/2777 [1:39:55<11:03:28, 16.44s/it]

Bulk index errors: 1


 13%|█▎        | 360/2777 [1:41:01<10:58:35, 16.35s/it]

Bulk index errors: 1


 13%|█▎        | 368/2777 [1:43:10<10:31:36, 15.73s/it]

Bulk index errors: 1


 14%|█▎        | 376/2777 [1:45:33<11:39:33, 17.48s/it]

Bulk index errors: 1


 14%|█▍        | 384/2777 [1:47:52<11:44:21, 17.66s/it]

Bulk index errors: 1


 14%|█▍        | 388/2777 [1:49:03<11:45:57, 17.73s/it]

Bulk index errors: 2


 14%|█▍        | 392/2777 [1:50:17<11:45:55, 17.76s/it]

Bulk index errors: 1


 15%|█▍        | 408/2777 [1:54:41<11:32:16, 17.53s/it]

Bulk index errors: 3


 15%|█▌        | 418/2777 [1:57:10<9:54:48, 15.13s/it]

Failed after 3 attempts: undefined...


 15%|█▌        | 428/2777 [2:00:01<11:31:36, 17.67s/it]

Bulk index errors: 2


 16%|█▌        | 432/2777 [2:01:05<10:41:13, 16.41s/it]

Bulk index errors: 1


 17%|█▋        | 460/2777 [2:08:39<11:55:51, 18.54s/it]

Bulk index errors: 1


 18%|█▊        | 492/2777 [2:17:59<11:27:20, 18.05s/it]

Bulk index errors: 2


 18%|█▊        | 500/2777 [2:20:23<11:21:15, 17.95s/it]

Bulk index errors: 1


 18%|█▊        | 512/2777 [2:23:57<10:54:21, 17.33s/it]

Bulk index errors: 1


 19%|█▉        | 532/2777 [2:29:35<10:56:45, 17.55s/it]

Bulk index errors: 2


 19%|█▉        | 536/2777 [2:30:37<10:01:03, 16.09s/it]

Bulk index errors: 1


 19%|█▉        | 540/2777 [2:31:49<11:04:43, 17.83s/it]

Bulk index errors: 1


 20%|██        | 556/2777 [2:36:09<9:56:03, 16.10s/it]

Bulk index errors: 1


 21%|██▏       | 596/2777 [2:47:22<10:20:40, 17.08s/it]

Bulk index errors: 1


 22%|██▏       | 604/2777 [2:49:36<10:06:18, 16.74s/it]

Bulk index errors: 1


 22%|██▏       | 616/2777 [2:53:11<9:55:46, 16.54s/it]

Bulk index errors: 1


 22%|██▏       | 624/2777 [2:55:34<11:08:13, 18.62s/it]

Bulk index errors: 1


 23%|██▎       | 636/2777 [2:58:55<9:36:45, 16.16s/it]

Bulk index errors: 1


 23%|██▎       | 648/2777 [3:01:47<8:21:44, 14.14s/it]

Bulk index errors: 1


 24%|██▍       | 664/2777 [3:05:49<8:25:52, 14.36s/it]

Bulk index errors: 1


 24%|██▍       | 672/2777 [3:07:44<8:00:24, 13.69s/it]

Bulk index errors: 1


 24%|██▍       | 676/2777 [3:08:41<7:53:40, 13.53s/it]

Bulk index errors: 1


 25%|██▌       | 696/2777 [3:13:20<8:23:43, 14.52s/it]

Bulk index errors: 1


 26%|██▌       | 712/2777 [3:17:13<7:51:52, 13.71s/it]

Bulk index errors: 1


 26%|██▌       | 720/2777 [3:19:17<9:09:26, 16.03s/it]

Bulk index errors: 1


 26%|██▌       | 728/2777 [3:21:35<10:11:04, 17.89s/it]

Bulk index errors: 1


 26%|██▋       | 732/2777 [3:22:43<9:44:35, 17.15s/it]

Bulk index errors: 1


 27%|██▋       | 744/2777 [3:26:03<9:24:21, 16.66s/it]

Bulk index errors: 2


 28%|██▊       | 768/2777 [3:32:43<9:25:51, 16.90s/it]

Bulk index errors: 1


 29%|██▊       | 792/2777 [3:39:33<9:11:01, 16.66s/it]

Bulk index errors: 2


 29%|██▉       | 800/2777 [3:41:49<9:01:37, 16.44s/it]

Bulk index errors: 1


 29%|██▉       | 812/2777 [3:45:16<9:32:29, 17.48s/it]

Bulk index errors: 1


 29%|██▉       | 816/2777 [3:46:30<10:10:19, 18.67s/it]

Bulk index errors: 1


 30%|███       | 844/2777 [3:54:32<9:05:32, 16.93s/it]

Bulk index errors: 1


 31%|███       | 848/2777 [3:55:42<9:07:58, 17.04s/it]

Bulk index errors: 1


 31%|███       | 852/2777 [3:56:45<8:29:33, 15.88s/it]

Bulk index errors: 1


 31%|███       | 856/2777 [3:57:48<8:36:37, 16.14s/it]

Bulk index errors: 1


 31%|███       | 860/2777 [3:58:57<9:00:35, 16.92s/it]

Bulk index errors: 1


 31%|███       | 861/2777 [3:59:13<8:59:57, 16.91s/it]