In [1]:
import torch
import clip
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
import requests
import json

In [2]:
# Prefer the GPU when one is available; everything below moves tensors to `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# CLIP ViT-B/32 backbone plus its matching image preprocessing transform.
model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
# Raw string: Windows backslashes must not be parsed as escape sequences
# (the previous literal only worked because '\D' and '\A' are invalid escapes).
sample_data_path = r'D:\Datasets\Amazon\toys.json'

In [128]:
def generate_text_embeddings(text):
    """Embed a single string with the CLIP text encoder.

    Parameters
    ----------
    text : str
        Text to embed (here: a product title).

    Returns
    -------
    numpy.ndarray
        1-D float32 embedding of ``text`` (first row of a batch of one).
    """
    # truncate=True: with the default (truncate=False) clip.tokenize raises a
    # RuntimeError for text longer than the model's context length, and long
    # product titles would crash the whole ingest loop.
    tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        text_embeddings = model.encode_text(tokens)
    # Cast to float32 like generate_image_embeddings does, so title and image
    # vectors share a dtype even if the model runs in half precision.
    return text_embeddings.float()[0].cpu().numpy()

def generate_image_embeddings(image_url, timeout=10):
    """Download an image and embed it with the CLIP image encoder.

    Parameters
    ----------
    image_url : str
        HTTP(S) URL of the image to embed.
    timeout : float, optional
        Seconds to wait for the server before giving up (default 10), so a
        stalled request cannot hang the notebook indefinitely.

    Returns
    -------
    numpy.ndarray
        1-D float32 embedding of the image (first row of a batch of one).

    Raises
    ------
    requests.HTTPError
        If the server returns an error status (e.g. 404). Without this check,
        the error page would be passed to PIL and fail with an opaque message.
    requests.Timeout
        If the server does not respond within ``timeout`` seconds.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    # Context manager closes the PIL image once it has been preprocessed.
    with Image.open(BytesIO(response.content)) as img:
        img_preprocessed = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        image_embeddings = model.encode_image(img_preprocessed).float()
    return image_embeddings[0].cpu().numpy()


In [129]:
# Smoke test: embed one known product image and inspect the resulting vector.
e = generate_image_embeddings(
    'https://m.media-amazon.com/images/I/612JNfob9nL._AC_UY218_.jpg'
)
print(e)

[-2.08610132e-01  4.08630490e-01  1.46234974e-01 -4.42701690e-02
 -3.56850028e-01 -6.14461899e-02 -5.60767472e-01  3.61060292e-01
 -2.53259391e-02  4.98774171e-01 -3.23381603e-01 -2.67722845e-01
  8.79097581e-01  1.68361962e-01  2.90638149e-01  2.56888747e-01
 -3.88589621e-01  2.00827450e-01 -2.49647021e-01 -5.81927657e-01
  4.18553442e-01 -2.81286955e-01 -3.60215187e-01 -5.06011784e-01
 -4.50602770e-02  1.44954160e-01  1.28729761e-01 -7.32398629e-01
 -4.11395282e-02  3.99572514e-02  5.24300635e-01  9.42827091e-02
  4.17828828e-01 -4.06405181e-02 -5.97974539e-01 -1.40750557e-02
 -2.04334721e-01 -5.71300387e-02 -3.09228972e-02 -4.96517181e-01
  8.62757936e-02  3.49788308e-01  1.35576487e-01  4.89378840e-01
  6.75252751e-02 -1.02247715e+00 -1.23937964e-01  7.05363750e-01
  5.56168735e-01 -2.24395603e-01  8.05873156e-01  4.19138968e-01
 -2.93181121e-01  6.12093210e-01 -2.85308540e-01  7.82279596e-02
 -2.31565610e-01  6.74251616e-02  1.43902034e-01  5.97470582e-01
 -9.89833772e-02 -3.37480

In [20]:
# NOTE(review): never populated in this cell (title embedding is done in the
# ingest cell); kept so any other cell referring to it still finds it.
text_embeddings = []

# Load the scraped product list. JSON is UTF-8 by specification; without an
# explicit encoding, open() uses the platform default (e.g. cp1252 on Windows),
# which can mis-decode or reject non-ASCII characters in titles.
with open(sample_data_path, encoding='utf-8') as file:
    products = json.load(file)

# Smoke test: embed every product image and print the vector. The file is
# already closed here, so the slow network loop does not hold it open.
for product in products:
    image_embedding = generate_image_embeddings(product['Image'])
    print(image_embedding)

tensor([ 1.4349e-01, -1.4771e-01,  2.0845e-01, -3.1182e-02,  4.2192e-01,
        -1.3866e-01,  2.3576e-01,  1.4752e-01, -3.6201e-01,  1.6100e-02,
         4.2032e-02,  1.8565e-01,  4.5107e-01, -2.1406e-01,  6.5625e-02,
        -2.1122e-02,  5.2586e-01,  8.3030e-01, -1.9966e-01, -4.0490e-01,
        -5.2840e-01,  2.7315e-01,  3.0629e-01,  1.0364e-01, -6.8649e-01,
         7.2678e-02, -1.2261e-02,  5.2864e-02, -2.4551e-01, -3.6119e-03,
         2.9616e-01,  1.0736e-01, -8.0298e-03, -2.0013e-02, -1.0055e-01,
        -5.5729e-01,  7.1272e-02,  8.0928e-01, -1.0143e-01,  1.3242e+00,
        -1.8550e-01,  2.4193e-01,  2.5273e-01,  1.8669e-01, -2.0155e-01,
         1.6852e-01, -6.8699e-02,  1.2017e-01, -1.2170e-01,  3.9902e-02,
        -2.5192e-01, -5.7535e-02, -9.8298e-02, -3.6264e-02,  5.0633e-01,
        -2.6528e-03, -2.1268e-01, -8.3906e-02,  7.4107e-01,  2.6166e-01,
         5.9199e-01, -6.6694e-01, -3.6999e-01, -4.9910e-03,  7.6485e-01,
        -6.7196e-02,  1.4083e-01,  5.3839e-01,  2.7

In [6]:
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
    db
)

In [51]:
# Register the "default" Milvus connection alias on the local server.
# connections.connect() returns None, so there is no client object to store.
connections.connect("default", host="localhost", port="19530")

In [29]:
# Sanity check: list the databases visible on the current Milvus connection.
db.list_database()

['default', 'Products']

In [47]:
# Collection schema: an auto-id primary key, two CLIP vectors, and product metadata.
# NOTE(review): EMBEDDING_DIM must equal the CLIP embedding size (512 assumed for ViT-B/32).
EMBEDDING_DIM = 512

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    # One vector for the title text, one for the product image.
    FieldSchema(name="title_vector", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
    FieldSchema(name="image_vector", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
    # Scalar metadata; price and rating are stored verbatim as strings.
    FieldSchema(name="position", dtype=DataType.INT64),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=500),
    FieldSchema(name="product_link", dtype=DataType.VARCHAR, max_length=5000),
    FieldSchema(name="price", dtype=DataType.VARCHAR, max_length=15),
    FieldSchema(name="reviews", dtype=DataType.INT64),
    FieldSchema(name="rating", dtype=DataType.VARCHAR, max_length=30),
    FieldSchema(name="sponsored", dtype=DataType.VARCHAR, max_length=50),
]

In [48]:
# Start from a clean collection. Guard the drop so a first run against a fresh
# server (where the collection does not exist yet) does not fail.
if utility.has_collection('product_collection'):
    utility.drop_collection('product_collection')

In [49]:
# Create the (empty) collection from the schema fields defined above.
schema = CollectionSchema(fields=fields, description="Product search")
collection_name = 'product_collection'
collection = Collection(collection_name, schema=schema)

In [50]:
def _to_int(value):
    """Parse a count such as ``"1,234"`` (or an already-numeric value) into an int."""
    return int(str(value).replace(",", ""))


# Build one row per product. The previous version rebuilt `entity` on every
# iteration but never kept it, so only the last product was ever inserted.
entities = []
for product in products:
    title_embedding = generate_text_embeddings(product["Title"])
    image_embedding = generate_image_embeddings(product["Image"])

    entity = {
        "title_vector": title_embedding.tolist(),
        "image_vector": image_embedding.tolist(),
        # Scraped counts may contain thousands separators, e.g. "1,234".
        "position": _to_int(product["Position"]),
        "title": product["Title"],
        "product_link": product["Product link"],
        "price": product["Price"],
        "reviews": _to_int(product["Reviews"]),
        "rating": product["Rating"],
        "sponsored": product["Sponsored"],
    }
    entities.append(entity)

# Insert in batches to keep each request small; an empty product list inserts nothing.
# Collection.insert returns a single MutationResult, not a (status, ids) pair,
# so collect the generated primary keys from it.
INSERT_BATCH_SIZE = 500
ids = []
for start in range(0, len(entities), INSERT_BATCH_SIZE):
    insert_result = collection.insert(entities[start:start + INSERT_BATCH_SIZE])
    ids.extend(insert_result.primary_keys)


In [68]:
# collection.drop_index(index_name = 'title_vector')

In [69]:
# IVF_FLAT index with cosine similarity, shared by both vector fields.
index_params = dict(
    metric_type="COSINE",
    index_type="IVF_FLAT",
    params=dict(nlist=128),
)

In [70]:
# Build the same index on the title and image vector fields.
for vector_field in ("title_vector", "image_vector"):
    collection.create_index(field_name=vector_field, index_params=index_params)

Status(code=0, message=)

In [71]:
# Load the indexed collection so the searches below can run against it.
collection.load()

In [72]:
# Search settings; the metric must match the COSINE metric the indexes were built with.
search_params = dict(
    metric_type="COSINE",
    offset=0,
    ignore_growing=False,
    params=dict(nprobe=10),
)

In [130]:
# Query by the second product's image. To query by title instead, use
# products[1]['Title'] with generate_text_embeddings.
sample_query = products[1]['Image']
sample_query_embedding = generate_image_embeddings(sample_query)
sample_query_embedding

array([-7.11805940e-01, -4.88294363e-02,  3.30361724e-01,  2.18768850e-01,
        1.27947137e-01,  1.02498323e-01,  1.86483592e-01,  6.57398775e-02,
        1.82011008e-01,  1.76132232e-01,  3.03569198e-01, -5.66980481e-01,
       -1.45702362e-01, -6.45787418e-02,  2.16639698e-01,  1.31733939e-01,
       -6.89856231e-01, -1.97257280e-01, -1.79133296e-01,  8.63223076e-02,
       -6.48642004e-01, -2.63700396e-01,  2.93147087e-01,  4.38241363e-02,
        8.10087621e-02,  2.22875297e-01, -3.14478755e-01,  3.99146676e-01,
        4.84298170e-03,  3.14859897e-01,  1.40053779e-02, -3.39155197e-01,
        1.23245724e-01, -3.54683131e-01, -3.18638176e-01, -6.34527922e-01,
       -2.28771135e-01,  2.39075616e-01, -6.80975020e-02,  8.76162291e-01,
        1.57782972e-01,  1.91779166e-01,  2.00409859e-01,  5.16777217e-01,
        2.45569929e-01, -5.45612693e-01, -8.59006941e-02,  2.32251883e-02,
        4.35072601e-01,  4.87945974e-04,  1.95686996e-01, -2.96398550e-01,
        4.53765631e-01,  

In [131]:
# Show which image URL was used as the query.
sample_query

'https://m.media-amazon.com/images/I/81sv3I05wCL._AC_UL320_.jpg'

In [132]:
# Find the 10 products whose image vectors are closest to the query embedding,
# returning their title and price. Note: offset (in search_params) + limit
# must stay below 16384.
results = collection.search(
    data=[sample_query_embedding],
    anns_field="image_vector",
    param=search_params,
    limit=10,
    expr=None,
    output_fields=["title", "price"],
    consistency_level="Strong",
)


In [84]:
# Primary keys of the hits for the first (and only) query vector.
results[0].ids

[450687289225645638,
 450687289225645662,
 450687289225645648,
 450687289225645646,
 450687289225645666,
 450687289225645664,
 450687289225645674,
 450687289225645644,
 450687289225645654,
 450687289225645676]

In [85]:
# Cosine similarity of each hit; a top score of ~1.0 presumably means the
# query product matched itself.
results[0].distances

[1.0000001192092896,
 0.5990191698074341,
 0.5963578224182129,
 0.5753413438796997,
 0.5619478225708008,
 0.5553554892539978,
 0.5465940237045288,
 0.538995087146759,
 0.5369709730148315,
 0.520717442035675]

In [133]:
# Title of the best-scoring hit for the query.
top_hits = results[0]
hit = top_hits[0]
hit.entity.get('title')

'Play Purse for Little Girls, 35PCS Toddler Purse with Pretend Makeup for Toddlers, Princess Toys Includes Handbag, Phone, Wallet, Camera, Keys, Kids Purse Birthday Gift for Girls Age 3 4 5 6+'

In [134]:
# Print every hit for the query (id, distance and the requested output fields).
query_hits = results[0]
for result in query_hits:
    print(result.entity)

id: 450687289225645640, distance: 0.9999998807907104, entity: {'title': 'Play Purse for Little Girls, 35PCS Toddler Purse with Pretend Makeup for Toddlers, Princess Toys Includes Handbag, Phone, Wallet, Camera, Keys, Kids Purse Birthday Gift for Girls Age 3 4 5 6+', 'price': '$17.99'}
id: 450687289225645650, distance: 0.6877689957618713, entity: {'title': 'Kids Smart Phone for Girls, Christmas Birthday Gifts for Girls Age 3-10 Kids Toys Cell Phone, 2.8" Touchscreen Toddler Learning Play Toy Phone with Dual Camera, Game, Music Player, 8G SD Card (Purple)', 'price': '$35.90'}
id: 450687289225645664, distance: 0.6292790174484253, entity: {'title': 'Fidget Toys, 120 Pack Fidgets Set Stocking Stuffers for Kids Party Favors Autism Sensory Toy Bulk Adults Kids Boys Girls Teens Stress Autistic ADHD Anxiety Carnival Treasure Classroom Prizes', 'price': '$14.99'}
id: 450687289225645658, distance: 0.6275117993354797, entity: {'title': 'Sloosh Bubble Lawn Mower Toddler Toys - Kids Toys Bubble Mach