In [17]:
import torch
import clip
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
import requests
import json
import os
import requests
import time
import torch
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
    db, 
    model
)

In [14]:
from pymilvus import MilvusClient

client = MilvusClient("milvus_demo.db")

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model, preprocess = clip.load("ViT-B/16", device=device)

100%|███████████████████████████████████████| 335M/335M [00:26<00:00, 13.2MiB/s]


In [4]:
print(torch.cuda.is_available())

True


In [5]:
DATASET_DIRECTORY = "home/Semantic-Search-using-Vector-Database/Amazon"
sample_data_path = 'home/Semantic-Search-using-Vector-Database/Amazon/meta_Gift_Cards.jsonl'

In [6]:
def image_url_to_img(image_url, retries=5, delay=1):
    for _ in range(retries):
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
            return img
        except requests.exceptions.RequestException as e:
            print(f"Error fetching image from {image_url}: {e}")
            time.sleep(delay)  # Wait before retrying
    print(f"Failed to fetch image from {image_url} after {retries} retries.")
    return None

In [7]:
def generate_text_embeddings(text):
    text_features = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_embeddings = model.encode_text(text_features)
        text_embeddings /= text_embeddings.norm(dim = -1, keepdim= True)
        return text_embeddings.cpu().numpy()[0]

def generate_image_embeddings(img):
    img_preprocessed = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        image_embeddings = model.encode_image(img_preprocessed).float()
        image_embeddings /= image_embeddings.norm(dim = -1, keepdim = True)
    return image_embeddings[0].cpu().numpy()


In [None]:
img = image_url_to_img('https://m.media-amazon.com/images/I/612JNfob9nL._AC_UY218_.jpg')
embed = generate_image_embeddings(img)
print(embed)

In [9]:
def extract_img_urls(image_array):
    urls = []
    for item in image_array:
        # print(item)
        if 'hi_res' in item and item['hi_res']:
            urls.append(item['hi_res'])
        elif 'large' in item and item['large']:
            urls.append(item['large'])
        else:
            print(f"Key 'hi_res' and 'large' not found in item: {item}")
    return urls

In [None]:
# client = connections.connect("default", host="localhost", port="19530")

In [18]:
fields = [
    FieldSchema(name="product_id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="title_vector", dtype=DataType.FLOAT_VECTOR, dim=512),  
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length = 500),
    FieldSchema(name="average_rating", dtype=DataType.FLOAT),
    FieldSchema(name="features", dtype=DataType.ARRAY, max_capacity = 100, element_type = DataType.VARCHAR, max_length = 1000),
    FieldSchema(name="description", dtype=DataType.ARRAY, max_capacity = 50, element_type = DataType.VARCHAR, max_length = 5000),
    FieldSchema(name="categories", dtype=DataType.ARRAY, max_capacity = 50, element_type = DataType.VARCHAR, max_length = 100),
    FieldSchema(name="price", dtype=DataType.FLOAT),
    FieldSchema(name="store", dtype=DataType.VARCHAR, max_length = 100),
    FieldSchema(name="main_category", dtype=DataType.VARCHAR, max_length = 50)
]

fields_images = [
    FieldSchema(name="image_id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="p_id", dtype=DataType.INT64),  # Foreign key to Products
    FieldSchema(name="image_vector", dtype=DataType.FLOAT_VECTOR, dim=512),
    FieldSchema(name="image_url", dtype=DataType.VARCHAR, max_length=5000)
]

In [27]:
product_schema = CollectionSchema(fields, description="Products collection")
# products_collection = Collection(name = 'products', schema=product_schema)
products_collection = client.create_collection(collection_name = 'products', schema = product_schema)

In [28]:
image_schema = CollectionSchema(fields_images, description='Images Collection')
# images_collection = Collection(name = 'images', schema = image_schema)
images_collection = client.create_collection(collection_name = 'images', schema = image_schema)

In [29]:
index_params = {
    "metric_type": "COSINE",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128}
}

products_collection.create_index(field_name="title_vector", index_params = index_params)
products_collection.load()

images_collection.create_index(field_name="image_vector", index_params = index_params)
images_collection.load()

AttributeError: 'NoneType' object has no attribute 'create_index'

In [None]:
res = client.insert(collection_name="demo_collection", data=data)


In [None]:
i = 0
with open(sample_data_path) as file:
    for line in file:
        data = json.loads(line.strip())
        i += 1
        if i == 1100:
            break
        elif i < 50:
            continue
        # Generate title embedding
        if '"' in data['title']:
            continue
        title_embedding = generate_text_embeddings(data['title'])
        
        # Extract product information
        title = data['title']
        average_rating = data['average_rating']
        features = data['features']
        description = data['description']
        categories = data['categories']
        price = data['price']
        main_category = data['main_category']
        store = data['store']
        if not price:
            price = 0.0
        if not store:
            store = ''
        if not main_category:
            main_category = ''
        if not average_rating:
            average_rating = 0.0
        
        
        # Prepare product data
        product_data = {
            'title_vector': title_embedding.tolist(),
            'title': title,
            'average_rating': average_rating,
            'features': features,
            'description': description,
            'categories': categories,
            'price': price,
            'store': store,
            'main_category': main_category,
        }
        
        # Insert product data into the products collection
        products_collection.insert([product_data])
        # Flush to ensure data is written
        products_collection.flush()
        
        # Retrieve product ID
        product_ids = products_collection.query(expr=f'title == "{title}"', output_fields=["product_id"])
        product_id = product_ids[0]["product_id"]
        
        # Extract and process image URLs
        image_urls = extract_img_urls(data['images'])
        image_embeddings = []
        for url in image_urls:
            img = image_url_to_img(url)
            if img:
                image_embedding = generate_image_embeddings(img)
                image_embeddings.append((image_embedding, url))
        
        # Insert each image embedding with the associated product ID
        for image_embedding, image_url in image_embeddings:
            image_data = {
                'p_id': product_id,
                'image_vector': image_embedding.tolist(),
                'image_url': image_url,
            }
            images_collection.insert([image_data])
        
        # Flush to ensure data is written
        images_collection.flush()
        
        print(f"Inserted product ID: {product_id} with {len(image_embeddings)} images")