In [1]:
import numpy as np
import pandas as pd
import chromadb
import openai
import json
import ast

from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from sentence_transformers import CrossEncoder, util
from IPython.core.display import display, HTML
from sentence_transformers import SentenceTransformer


  from IPython.core.display import display, HTML


## 1. Loading the dataset

In [3]:
data = pd.read_csv('Fashion Dataset v2.csv')
data.head()

Unnamed: 0,p_id,name,products,price,colour,brand,img,ratingCount,avg_rating,description,p_attributes
0,17048614,Khushal K Women Black Ethnic Motifs Printed Ku...,"Kurta, Palazzos, Dupatta",5099,Black,Khushal K,http://assets.myntassets.com/assets/images/170...,4522.0,4.418399,Black printed Kurta with Palazzos with dupatta...,"{'Add-Ons': 'NA', 'Body Shape ID': '443,333,32..."
1,16524740,InWeave Women Orange Solid Kurta with Palazzos...,"Kurta, Palazzos, Floral Print Dupatta",5899,Orange,InWeave,http://assets.myntassets.com/assets/images/165...,1081.0,4.119334,Orange solid Kurta with Palazzos with dupatta<...,"{'Add-Ons': 'NA', 'Body Shape ID': '443,333,32..."
2,16331376,Anubhutee Women Navy Blue Ethnic Motifs Embroi...,"Kurta, Trousers, Dupatta",4899,Navy Blue,Anubhutee,http://assets.myntassets.com/assets/images/163...,1752.0,4.16153,Navy blue embroidered Kurta with Trousers with...,"{'Add-Ons': 'NA', 'Body Shape ID': '333,424', ..."
3,14709966,Nayo Women Red Floral Printed Kurta With Trous...,"Kurta, Trouser, Dupatta",3699,Red,Nayo,http://assets.myntassets.com/assets/images/147...,4113.0,4.088986,Red printed kurta with trouser and dupatta<br>...,"{'Add-Ons': 'NA', 'Body Shape ID': '333,424', ..."
4,11056154,AHIKA Women Black & Green Printed Straight Kurta,Kurta,1350,Black,AHIKA,http://assets.myntassets.com/assets/images/110...,21274.0,3.978377,"Black and green printed straight kurta, has a ...","{'Body Shape ID': '424', 'Body or Garment Size..."


## 2. Performing EDA on the data

Checking the number of rows

In [6]:
len(data)

14214

Examining data types

In [8]:
data.dtypes

p_id              int64
name             object
products         object
price             int64
colour           object
brand            object
img              object
ratingCount     float64
avg_rating      float64
description      object
p_attributes     object
dtype: object

**Checking for the number of unique values**

In [10]:
data.nunique()

p_id            14214
name            13873
products          910
price            1209
colour             50
brand            1022
img             14214
ratingCount       829
avg_rating       2367
description     14181
p_attributes    13089
dtype: int64

**Checking data for null values**

In [12]:
data.isna().sum()

p_id               0
name               0
products           0
price              0
colour             0
brand              0
img                0
ratingCount     7684
avg_rating      7684
description        0
p_attributes       0
dtype: int64

Fixing the 'ratingCount' and 'avg_rating' columns.

The rating is probably null because these could be new products and no one has rated them yet.

In [14]:
data.loc[data['ratingCount'].isna(), 'ratingCount'] = 0
data.loc[data['avg_rating'].isna(), 'avg_rating'] = 0
data.isna().sum()

p_id            0
name            0
products        0
price           0
colour          0
brand           0
img             0
ratingCount     0
avg_rating      0
description     0
p_attributes    0
dtype: int64

Examining data types

In [16]:
data.dtypes

p_id              int64
name             object
products         object
price             int64
colour           object
brand            object
img              object
ratingCount     float64
avg_rating      float64
description      object
p_attributes     object
dtype: object

In [17]:
data = data.convert_dtypes()

In [18]:
data.dtypes

p_id                     Int64
name            string[python]
products        string[python]
price                    Int64
colour          string[python]
brand           string[python]
img             string[python]
ratingCount              Int64
avg_rating             Float64
description     string[python]
p_attributes    string[python]
dtype: object

EDA is now complete

## 3. Setting up ChromaDB

In [21]:
# Set the API key
filepath = "../"

with open(filepath + "OPENAI_API_Key.txt", "r") as f:
  openai.api_key = ''.join(f.readlines())

In [22]:
# Define the path where chroma collections will be stored

chroma_data_path = 'chromadb_data'

In [23]:
# Call PersistentClient()
client = chromadb.PersistentClient(path=chroma_data_path)

In [24]:
# Set up the embedding function using the OpenAI embedding model
model = SentenceTransformer('all-MiniLM-L6-v2')

In [25]:
# Initialise a collection in chroma and pass the embedding_function to it so that it used OpenAI embeddings to embed the documents

#fashion_coll = client.get_or_create_collection(name='myntra_fashion_data', embedding_function=embedding_function)
client.delete_collection(name='myntra_fashion_data')
fashion_coll = client.get_or_create_collection(name='myntra_fashion_data')

#### Creating the collection

Combining attributes **color**, **brand**, **price** and **rating** with metadata so that a metadata search can be performed on these attributes

In [28]:
def combine_metadatas(row):
    metadata_dict = ast.literal_eval(row['p_attributes'].lower())
    metadata_dict.update({'color' : row['colour'].lower(), 
                          'brand' : row['brand'].lower(), 
                          'price' : row['price'], 
                          'ratingCount' : row['ratingCount'], 
                          'avg_rating' : row['avg_rating'], 
                          'img' : row['img'] })
    return metadata_dict


data['new_metadatas'] = data.apply(combine_metadatas, axis=1)
metadata_list = data['new_metadatas'].tolist()


Also adding metadata as text to the description to aid in semantic search

In [30]:
def construct_documents(row):
    doc = f"name {row['name'].lower()} description {row['description'].lower()} product {row['products']} description {" ".join(f"{k} = {v} \n" for k,v in row['new_metadatas'].items())}"
    return doc

documents_list = data.apply(construct_documents, axis=1).tolist()
ids = data['p_id'].apply(lambda id:str(id)).tolist()

Creating the _fashion_coll_ collection which is our main collection

In [32]:
fashion_coll.upsert(
    documents= documents_list,
    ids = ids,
    metadatas = metadata_list
)

## 4. Querying the database

#### Fetch Query Parameters from OpenAI

Here the parameters **rating**, **price**, **brand** and **color** are extracted out of user query. 

The response will be used in querying the database in the _where_ clause so that the results are accurate.

In [35]:
def parse_user_query(user_query):
    prompt = f"""
You are a helpful assistant that converts user queries into filters for a product database.
Given the following user query, extract the conditions and format them as a JSON object for use as a "where" clause in ChromaDB.

The metadata fields are:
- color: string (e.g., "red", "blue")
- price: float (e.g., {{"$lt": 15}} for "under $15", {{"$lt": 15}} for "above $15")
- brand: string (e.g., "XYZ" for "brand XYZ", "Roly Poly" for "from brand Roly Poly", "House of Pataudi" for "by House of Pataudi"  )
- avg_rating: float (e.g., {{"$gt": 4}} for "rating above 4 stars", {{"$lt": 4}} for "user rating below 4 stars"), {{"$lt": 3}} for "average rating below 3") , {{"$lt": 3}} for "average below 4 stars")

####
User Query: {user_query}
####

Strictly return the "content" as a JSON string without the special characters as prefix and suffix.
Don't add any attributes other than "color", "price", "brand", "rating".
This json will be used for "where" clause in chromadb
None of the attributes are mandatory. If you cannot find a value then don't include the attribute in the response.

If you find more than one attributes in the User Query, surround the JSON with an "$and [  ]"
Example: "$and": [  {{'color': 'black'}}, {{'price': {{'$lt': 1000}} }}, {{ 'avg_rating': {{'$gt': 4}} }}    ]

"""

    # Make the API call
    try:
        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": "You are a helpful assistant."},
                      {"role": "user", "content": prompt}]
        )

        response = response.choices[0].message.content
        return json.loads(response.lower())
    except:
        return {}

# Example user query
user_query = 'Please show me black printed kurtas by Brand XYZ'
where_clause = parse_user_query(user_query)
print(where_clause)

user_query = 'Please show me red kurtas with rating of atleast 4 and priced under 1000 by brand Rudra Bazaar'
where_clause = parse_user_query(user_query)
print(where_clause)


{'$and': [{'color': 'black'}, {'brand': 'xyz'}]}
{'$and': [{'color': 'red'}, {'price': {'$lt': 1000}}, {'avg_rating': {'$gt': 4}}, {'brand': 'rudra bazaar'}]}


#### Testing the query

In [37]:
user_query='Black Saree under 5000 rupees with rating of atleast 4 from brand Kalini'
parsed_query = parse_user_query(user_query)
print(parsed_query)
query_embedding = model.encode(user_query.lower())
results = fashion_coll.query(query_embeddings=[query_embedding], n_results=10, where=parsed_query, include=['documents', 'distances', 'metadatas'])
print(results)


{'$and': [{'color': 'black'}, {'price': {'$lt': 5000}}, {'avg_rating': {'$gt': 4}}, {'brand': 'kalini'}]}
{'ids': [['17035744', '16748456', '17407676', '12754022', '17482890', '17241826']], 'embeddings': None, 'documents': [['name kalini black & golden woven design saree description <b> design details </b> <ul> <li> black and gold-toned saree </li> <li> geometric woven design saree with woven design border </li> </ul> <br> the saree comes with an unstitched blouse piece<br>the blouse worn by the model might be for modelling purpose only. check the image of the blouse piece to understand how the actual blouse piece looks like.<p>dryclean</p>length: 5.5 metres plus 0.8 metre blouse piece <br> width: 1.06 metres (approx.) product Saree description blouse = blouse piece \n blouse fabric = pure silk \n border = woven design \n care for me = na \n multipack set = na \n occasion = festive \n ornamentation = na \n pattern = woven design \n print or pattern type = geometric \n saree fabric = pu

#### Create a dataset out of the results from the database

In [39]:
def create_result_dataset(results):
    query_ids = []
    query_docs = []
    query_distances = []
    query_metadatas = []
    
    for key, value in results.items():
        if 'ids' in key:
            query_ids.extend(value[0])
        elif 'documents' in key:
            query_docs.extend(value[0])
        elif 'distances' in key:
            query_distances.extend(value[0])
        elif 'metadatas' in key:
            query_metadatas.extend(value[0])
    
    return pd.DataFrame({
        'IDs': query_ids,
        'Documents': query_docs,
        'Distances': query_distances,
        'Metadatas': query_metadatas
    })


results_df = create_result_dataset(results)
results_df

Unnamed: 0,IDs,Documents,Distances,Metadatas
0,17035744,name kalini black & golden woven design saree ...,1.067969,"{'avg_rating': 4.434782609, 'blouse': 'blouse ..."
1,16748456,name kalini black & off white pure cotton prin...,1.09784,"{'avg_rating': 4.468085106, 'blouse': 'blouse ..."
2,17407676,name kalini black & beige silk blend bandhani ...,1.132728,"{'avg_rating': 4.5, 'blouse': 'blouse piece', ..."
3,12754022,name kalini black & red jute silk embroidered ...,1.135782,"{'avg_rating': 4.055282555, 'blouse': 'blouse ..."
4,17482890,name kalini black & gold-toned striped boat ne...,1.181969,"{'avg_rating': 4.583333333, 'body or garment s..."
5,17241826,name kalini women black geometric checked thre...,1.379131,"{'avg_rating': 4.228205128, 'body or garment s..."


#### Implementing Caching

Create a new collection _cache_collection_ where the user queries and results will be cached for faster response

In [105]:
client.delete_collection(name='cache_collection')
cache_collection = client.get_or_create_collection(name='cache_collection')
threshold = 0.1

**This is the main function which takes in the user query and returns top 10 results in a dataframe.**
1. The user query is first searched in the cache.
2. If the user query is not found in the cache with a distance of more then threshold, then the results are fetched from main collection. The results are stored back in the cache.
3. If the user query is found in the cache with a distance of less than threshold, the results are returned from the cache collection.

In [110]:
def fetch_data(user_query):
    ids = []
    documents = []
    distances = []
    metadatas = []
    results_df = pd.DataFrame()
    
    where_clause = parse_user_query(user_query)
    query_embedding = model.encode(user_query.lower())

    cache_results = cache_collection.query(
        query_texts=user_query,
        n_results=1
    )

    if cache_results['distances'][0] == [] or cache_results['distances'][0][0] > threshold:
        results = fashion_coll.query(
            query_embeddings=[query_embedding], 
            n_results=10, 
            where=where_clause, 
            include=['documents', 'distances', 'metadatas']
        )
        results_df = create_result_dataset(results)
        
        cache_collection.add(
            documents=[user_query],
            ids = [user_query],
            metadatas = { k:str(v) for (k,v) in results_df.to_dict('list').items()}
        )
    elif cache_results['distances'][0][0] <= threshold:
        print('Found in cache!')
        cache_result_dict = cache_results['metadatas'][0][0]
        results_df = pd.DataFrame({k:ast.literal_eval(v) for (k,v) in cache_result_dict.items()})
        
    return results_df
    

Testing the cache implementation

In [45]:
print('First call')
fetch_data('Find black pullovers with user rating above 4').head(2)

First call


Unnamed: 0,IDs,Documents,Distances,Metadatas
0,4423979,name dressberry women black & grey checked pul...,1.313968,"{'avg_rating': 4.489028213, 'body or garment s..."
1,15821466,name defacto women black acrylic pullover with...,1.357816,"{'avg_rating': 5.0, 'body or garment size': 'g..."


In [46]:
print('Similar call again')
fetch_data('Find black pullovers having user rating atleast 4').head(2)

Similar call again
Found in cache!


Unnamed: 0,Distances,Documents,IDs,Metadatas
0,1.313968,name dressberry women black & grey checked pul...,4423979,"{'avg_rating': 4.489028213, 'body or garment s..."
1,1.357816,name defacto women black acrylic pullover with...,15821466,"{'avg_rating': 5.0, 'body or garment size': 'g..."


## 5. Cross Encoding

In [48]:
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [49]:
# Test the cross encoder model

scores = cross_encoder.predict([['Does the insurance cover diabetic patients?', 'The insurance policy covers some pre-existing conditions including diabetes, heart diseases, etc. The policy does not howev'],
                                ['Does the insurance cover diabetic patients?', 'The premium rates for various age groups are given as follows. Age group (<18 years): Premium rate']])

scores

array([  3.8467638, -11.252879 ], dtype=float32)

This function performs the cross validation for the user query against each of the elements returned in the result

In [51]:
def perform_cross_validation(user_query, results_df):
    cross_inputs = [[user_query, response] for response in results_df['Documents']]
    cross_rerank_scores = cross_encoder.predict(cross_inputs)
    results_df['Reranked_scores'] = cross_rerank_scores
    return results_df

Testing the cross validation.

In [53]:
perform_cross_validation(user_query, results_df)

Unnamed: 0,IDs,Documents,Distances,Metadatas,Reranked_scores
0,17035744,name kalini black & golden woven design saree ...,1.067969,"{'avg_rating': 4.434782609, 'blouse': 'blouse ...",-0.738804
1,16748456,name kalini black & off white pure cotton prin...,1.09784,"{'avg_rating': 4.468085106, 'blouse': 'blouse ...",-1.018327
2,17407676,name kalini black & beige silk blend bandhani ...,1.132728,"{'avg_rating': 4.5, 'blouse': 'blouse piece', ...",-1.950579
3,12754022,name kalini black & red jute silk embroidered ...,1.135782,"{'avg_rating': 4.055282555, 'blouse': 'blouse ...",-1.321804
4,17482890,name kalini black & gold-toned striped boat ne...,1.181969,"{'avg_rating': 4.583333333, 'body or garment s...",-4.760744
5,17241826,name kalini women black geometric checked thre...,1.379131,"{'avg_rating': 4.228205128, 'body or garment s...",-2.192584


In [54]:
top_3_semantic = results_df.sort_values(by='Distances')
top_3_semantic[:3]

Unnamed: 0,IDs,Documents,Distances,Metadatas,Reranked_scores
0,17035744,name kalini black & golden woven design saree ...,1.067969,"{'avg_rating': 4.434782609, 'blouse': 'blouse ...",-0.738804
1,16748456,name kalini black & off white pure cotton prin...,1.09784,"{'avg_rating': 4.468085106, 'blouse': 'blouse ...",-1.018327
2,17407676,name kalini black & beige silk blend bandhani ...,1.132728,"{'avg_rating': 4.5, 'blouse': 'blouse piece', ...",-1.950579


In [55]:
top_3_rerank = results_df.sort_values(by='Reranked_scores', ascending=False)
top_3_rerank[:3]

Unnamed: 0,IDs,Documents,Distances,Metadatas,Reranked_scores
0,17035744,name kalini black & golden woven design saree ...,1.067969,"{'avg_rating': 4.434782609, 'blouse': 'blouse ...",-0.738804
1,16748456,name kalini black & off white pure cotton prin...,1.09784,"{'avg_rating': 4.468085106, 'blouse': 'blouse ...",-1.018327
3,12754022,name kalini black & red jute silk embroidered ...,1.135782,"{'avg_rating': 4.055282555, 'blouse': 'blouse ...",-1.321804


In [56]:
top_3_RAG = top_3_rerank[["Documents", "Metadatas"]][:3] if len(top_3_rerank) > 3 else top_3_rerank
print(top_3_RAG["Documents"].tolist())
print(top_3_RAG["Metadatas"].tolist())

['name kalini black & golden woven design saree description <b> design details </b> <ul> <li> black and gold-toned saree </li> <li> geometric woven design saree with woven design border </li> </ul> <br> the saree comes with an unstitched blouse piece<br>the blouse worn by the model might be for modelling purpose only. check the image of the blouse piece to understand how the actual blouse piece looks like.<p>dryclean</p>length: 5.5 metres plus 0.8 metre blouse piece <br> width: 1.06 metres (approx.) product Saree description blouse = blouse piece \n blouse fabric = pure silk \n border = woven design \n care for me = na \n multipack set = na \n occasion = festive \n ornamentation = na \n pattern = woven design \n print or pattern type = geometric \n saree fabric = pure silk \n sustainable = regular \n trends = celebrity saree \n type = na \n wash care = dry clean \n wedding = bride & wedding squad \n color = black \n brand = kalini \n price = 3699 \n ratingCount = 23 \n avg_rating = 4.4

## 6. Retrieval Augmented Generation

This function takes the user query and top 3 results and displays it in HTML format so that the Jupyter notebook can display it properly.

The OpenAI chat completion is used for this purpose.

In [58]:
def display_results(user_query, top_3_RAG):
    prompt = f"""
You are a helpful assistant and a fashion expert.
The user is querying shopping data which contains various fashion apparel.
You have a question asked by the user in '{user_query}' and you have some search results from shopping data in the dataframe '{top_3_RAG["Documents"].tolist()}'. 
These search results are closest match for the shopping results that user is interested in.
Display as much information from "Metadatas" as possible. You should definitely display "Price", "Rating", "Color", "Brand". 
Also include a detailed description of each product.
The price is in Indian Rupees and has a symbol of ₹.
The "img" in the "metadata" has the product image url. Strictly use this value in the "src" attribute of <img> HTML tag. Don't use placeholder. Use actual image URL.
Any more information also will help.

Please return the data in HTML format. The results will be directly used to display on a browser.

The HTML should:
- Be visually appealing.
- Use a grid layout for the products.
- Include CSS for basic styling.
- Response should contain only HTML and CSS elements. It should not have any prefix or suffix like ```html or ```

"""

    # Make the API call
    try:
        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": "You are a helpful assistant and a fashion expert."},
                      {"role": "user", "content": prompt}]
        )

        return response.choices[0].message.content
    except:
        return ''


Testing the RAG function

In [60]:
RAG_results_HTML = display_results(user_query=user_query, top_3_RAG=top_3_RAG)
display(HTML(RAG_results_HTML))

## 7. Putting it all together

In [62]:
def find_fashion_deals(user_query):
    results_df = fetch_data(user_query)
    if len(results_df) > 0:
        results_df = perform_cross_validation(user_query, results_df)
        results_reranked_sorted = results_df.sort_values(by='Reranked_scores', ascending=False)
        top_3_rerank = results_reranked_sorted[:3] if len(results_reranked_sorted) > 3 else results_reranked_sorted
        top_3_RAG = top_3_rerank[["Documents", "Metadatas"]]
        RAG_results_HTML = display_results(user_query, top_3_RAG)
        display(HTML(RAG_results_HTML))
    else:
        print('[FASHION SEARCH AI]: No products found for your search.')
    

In [112]:
user_query = input("[FASHION SEARCH AI]: Please enter your query to find the best of deals!\n")
find_fashion_deals(user_query)

[FASHION SEARCH AI]: Please enter your query to find the best of deals!
  Find me red kurtas with mandarin collar with user rating above 4
