## Data Load

In [49]:
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import pandas as pd

# Load the e-commerce product dataset by its dataset id via `datasets.load_dataset`.
# The result is indexed by split name further down (`dataset['train']`).
dataset = load_dataset('rajuptvs/ecommerce_products_clip')

  from .autonotebook import tqdm as notebook_tqdm
Downloading readme: 100%|██████████| 477/477 [00:00<00:00, 858kB/s]
Downloading data: 100%|██████████| 48.3M/48.3M [00:04<00:00, 10.9MB/s]
Generating train split: 100%|██████████| 1913/1913 [00:00<00:00, 24350.48 examples/s]


In [187]:
# Materialise the `train` split as a pandas DataFrame for the steps that follow.
df = pd.DataFrame.from_dict(dataset['train'])
# Print the (rows, columns) shape as a sanity check, then preview the first
# rows; the preview is the cell's rendered output.
print(df.shape)
df.head()

(1913, 8)


Unnamed: 0,image,Product_name,Price,colors,Pattern,Description,Other Details,Clipinfo
0,<PIL.JpegImagePlugin.JpegImageFile image mode=...,Men Regular Fit Color Block Casual Shirt,₹349,Dark Green and Black,Color Block,Blive High quality premium Full sleeves printe...,unknown,Dark Green and Black Color Block Men Regular ...
1,<PIL.JpegImagePlugin.JpegImageFile image mode=...,Men Regular Fit Printed Casual Shirt,₹959,Black,Printed,unknown,unknown,Black Printed Men Regular Fit Printed Casual S...
2,<PIL.JpegImagePlugin.JpegImageFile image mode=...,Men Regular Fit Solid Mandarin Collar Casual S...,₹339,Grey,Solid,unknown,unknown,Grey Solid Men Regular Fit Solid Mandarin Coll...
3,<PIL.JpegImagePlugin.JpegImageFile image mode=...,Men Regular Fit Solid Mandarin Collar Casual S...,₹339,Light Blue,Solid,Prime quality Full sleeves Plain Shirt direct ...,shirt for men look,Light Blue Solid Men Regular Fit Solid Mandari...
4,<PIL.JpegImagePlugin.JpegImageFile image mode=...,Men Regular Fit Solid Mandarin Collar Casual S...,₹339,Black,Solid,Prime quality Full sleeves Plain Shirt direct ...,shirt for men look,Black Solid Men Regular Fit Solid Mandarin Col...


In [189]:
# Keep only the text columns of interest and drop rows that are exact
# duplicates across all six of them; the index is rebuilt as 0..n-1.
dfi = df[['Product_name', 'colors', 'Pattern', 'Description',
       'Other Details', 'Clipinfo']].drop_duplicates().reset_index(drop=True)

# Load transformer model
model = SentenceTransformer("all-MiniLM-L6-v2")

# Create embeddings with ONE batched `encode` call instead of one call per row:
# the model batches the texts internally, which removes the per-row call
# overhead of `Series.apply`. `encode` on a list returns a 2-D array, so
# `list(...)` splits it back into one 1-D vector per row and the column holds
# the same kind of object as before (values may differ only by float rounding).
clipinfo_embeddings = model.encode(dfi["Clipinfo"].tolist())
dfi["clipinfo_embeddings"] = list(clipinfo_embeddings)

# Print the shape as a sanity check, then preview the first rows.
print(dfi.shape)
dfi.head()

(1121, 7)


Unnamed: 0,Product_name,colors,Pattern,Description,Other Details,Clipinfo,clipinfo_embeddings
0,Men Regular Fit Color Block Casual Shirt,Dark Green and Black,Color Block,Blive High quality premium Full sleeves printe...,unknown,Dark Green and Black Color Block Men Regular ...,"[-0.074169286, 0.07419989, -0.03574969, 0.0070..."
1,Men Regular Fit Printed Casual Shirt,Black,Printed,unknown,unknown,Black Printed Men Regular Fit Printed Casual S...,"[-0.047791626, 0.08828563, -0.08632753, 0.0710..."
2,Men Regular Fit Solid Mandarin Collar Casual S...,Grey,Solid,unknown,unknown,Grey Solid Men Regular Fit Solid Mandarin Coll...,"[-0.03997681, 0.0080082435, -0.010464183, 0.06..."
3,Men Regular Fit Solid Mandarin Collar Casual S...,Light Blue,Solid,Prime quality Full sleeves Plain Shirt direct ...,shirt for men look,Light Blue Solid Men Regular Fit Solid Mandari...,"[-0.053886674, 0.00586794, -0.0010366179, 0.07..."
4,Men Regular Fit Solid Mandarin Collar Casual S...,Black,Solid,Prime quality Full sleeves Plain Shirt direct ...,shirt for men look,Black Solid Men Regular Fit Solid Mandarin Col...,"[-0.054641847, 0.016245024, -0.0609381, 0.0851..."


### RAG POC

In [171]:
import boto3
from botocore.exceptions import ClientError
import psycopg2
import json

def get_secret(secret_name, region_name="us-west-2"):
    """Fetch a secret from AWS Secrets Manager and parse it as JSON.

    Args:
        secret_name: Identifier passed as ``SecretId`` to ``get_secret_value``.
        region_name: AWS region used for the Secrets Manager client. Defaults
            to ``"us-west-2"`` so existing single-argument calls are unchanged.

    Returns:
        The object produced by ``json.loads`` on the response's
        ``SecretString`` value.

    Raises:
        botocore.exceptions.ClientError: Propagated unchanged from
            ``get_secret_value``. For the list of exceptions thrown, see
            https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
        KeyError: If the response contains no ``SecretString`` entry.
        json.JSONDecodeError: If the secret string is not valid JSON.
    """
    # Create a Secrets Manager client
    client = boto3.session.Session().client(
        service_name='secretsmanager',
        region_name=region_name,
    )

    # ClientError is deliberately not caught: there is nothing useful to add
    # at this level, so it reaches the caller as-is.
    response = client.get_secret_value(SecretId=secret_name)

    return json.loads(response['SecretString'])

In [172]:
# Fetch the connection settings from Secrets Manager
database_secrets = get_secret("rds-coi-db-info")

dbhost, dbport = database_secrets['host'], database_secrets['port']
dbuser, dbpass = database_secrets['username'], database_secrets['password']

# Open the connection; autocommit makes every statement take effect
# immediately, without an explicit commit.
dbconn = psycopg2.connect(
    host=dbhost,
    user=dbuser,
    password=dbpass,
    port=dbport,
    connect_timeout=10,
)
dbconn.set_session(autocommit=True)

cur = dbconn.cursor()

# Schema setup, executed in order:
#   1. enable the pgvector extension,
#   2. drop any previous test table, so each run of this cell starts empty,
#   3. recreate the test table with a 384-dimensional vector column
#      (presumably the output size of the embedding model -- confirm).
for setup_sql in (
    "CREATE EXTENSION IF NOT EXISTS vector;",
    "DROP TABLE IF  EXISTS products_test;",
    """CREATE TABLE IF NOT EXISTS products_test(
  id bigserial primary key,
  colors text,
  pattern text,
  description text,
  clipinfo text,
  description_embeddings vector(384)
);""",
):
    cur.execute(setup_sql)

In [177]:
# Insert embeddings into table
insert_sql = """INSERT INTO products_test (colors, pattern, description, clipinfo, description_embeddings)
  VALUES (%s, %s, %s, %s, %s);"""
insert_columns = ["colors", "Pattern", "Description", "Clipinfo", "clipinfo_embeddings"]

# `itertuples` yields plain tuples and avoids building a Series per row, which
# is what makes `iterrows` slow.
for colors, pattern, description, clipinfo, embedding in dfi[insert_columns].itertuples(index=False, name=None):
    # The vector is sent in pgvector's text form "[x1, x2, ...]". Each
    # component is converted to a built-in float first: with NumPy >= 2.0,
    # str(list(ndarray)) renders elements as "np.float32(...)", which is not a
    # valid vector literal.
    vector_literal = str([float(component) for component in embedding])
    cur.execute(insert_sql, (colors, pattern, description, clipinfo, vector_literal))

# Create index for embeddings. It is built after the rows are loaded so the
# IVFFlat lists are computed from the actual data.
# NOTE: re-running this cell without recreating the table inserts the rows
# again and adds another index.
cur.execute("""CREATE INDEX ON products_test
  USING ivfflat (description_embeddings vector_l2_ops) WITH (lists = 50);""")
# Reclaim space and refresh planner statistics after the bulk load.
cur.execute("VACUUM ANALYZE products_test;")

In [190]:
# Embed the free-text query with the same model used for the stored vectors.
user_input = "black shirt"
user_input_embedding = model.encode(user_input)

# Serialise the query vector as "[x1, x2, ...]" using built-in floats: with
# NumPy >= 2.0, str(list(ndarray)) renders elements as "np.float32(...)",
# which the database cannot parse as a vector.
query_vector = str([float(component) for component in user_input_embedding])

# Return the 3 rows whose stored vector is nearest to the query vector;
# `<->` is pgvector's L2 distance operator, matching the index's vector_l2_ops.
cur.execute("""SELECT id, colors, pattern, description, clipinfo
  FROM products_test
  ORDER BY description_embeddings <-> %s
  limit 3;""", (query_vector,))

In [191]:
# Fetch every remaining row from the cursor's last executed query; as the last
# expression in the cell, the resulting list of tuples is shown as the output.
cur.fetchall()

[(483,
  'Black',
  'Printed',
  'Black and Red printed casual shirt, has a spread collar, long sleeves, button placket, and curved hem. This shirt from Roadster will give you the perfect amount of comfort and durability.  This black piece is the perfect fall shirt when you put it with a pair of jeans and a lightweight jacket to enjoy the crisper weather.',
  'Black Printed Men Regular Fit Printed Casual Shirt'),
 (2,
  'Black',
  'Printed',
  'unknown',
  'Black Printed Men Regular Fit Printed Casual Shirt'),
 (160,
  'Black',
  'Printed',
  'Printed Casual Shirt',
  'Black Printed Men Regular Fit Printed Button Down Collar Casual Shirt')]

In [192]:
# Release the database resources: the cursor first, then its connection.
for db_resource in (cur, dbconn):
    db_resource.close()