# Prepare Example Dataset

We use [UCI Online Retail](https://archive.ics.uci.edu/dataset/352/online+retail) dataset and [preprocess](https://cstorm125.github.io/posts/sales_prediction/) it to a customer-level dataset with `TargetSales` (sales during target period), `TargetDescriptions` (products purchased during target period), and `TargetCategories` (categories purchased during target period). Features are calculated using transactions 2011-01 to 2011-09 and outcome using those between 2011-10 to 2011-12. We use `global.anthropic.claude-sonnet-4-5-20250929-v1:0` to augment product categories from product descriptions.

In [1]:
import pandas as pd
import numpy as np
import random
import dspy
from ucimlrepo import fetch_ucirepo 
from tqdm.auto import tqdm

def string_to_yearmon(date):
    date = date.split()
    date = date[0].split('/') + date[1].split(':')
    date = date[2] + '-' + date[0].zfill(2) #+ '-' + date[1].zfill(2) + ' ' + date[3].zfill(2) + ':' + date[4].zfill(2)
    return date

#categorize product to categories
lm = dspy.LM(
    model='bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0',
    region_name='us-west-2',
    max_tokens=10_000,
    temperature=0.,
)
dspy.settings.configure(lm=lm)
def categorize_product(product_descriptions: str) -> str:
    categories = [
        'Home Decor',
        'Kitchen and Dining',
        'Fashion Accessories',
        'Stationary and Gifts',
        'Toys and Games',
        'Seasonal and Holiday',
        'Personal Care and Wellness',
        'Outdoor and Garden',   
    ]
    output_prompt = f"""Product descriptions with their classified product categories namely: {", ".join(categories)} or Others, if they do not fall into any. 
Use the following format:
"product description of product #1"|"product category classified into"
"product description of product #2"|"product category classified into"
...
"product description of product #n"|"product category classified into"
Make sure all product descriptions have a classified product category.                       
"""
    class ClaudeQuery(dspy.Signature):
        """You are a product categorizer at a retail website"""
        product_description = dspy.InputField(desc="Product descriptions")
        output = dspy.OutputField(desc=output_prompt)
    
    # Use ChainOfThought for better reasoning
    claude_module = dspy.ChainOfThought(ClaudeQuery)
    
    # Generate response
    result = claude_module(
        product_description=product_descriptions
    )
    
    return result.output

  from .autonotebook import tqdm as notebook_tqdm


In [50]:
online_retail = fetch_ucirepo(id=352) 
transaction_df = online_retail['data']['original']
original_nb = transaction_df.shape[0]

#create yearmon for train-valid split
transaction_df['yearmon'] = transaction_df.InvoiceDate.map(string_to_yearmon)

#get rid of transactions without cid
transaction_df = transaction_df[~transaction_df.CustomerID.isna()].reset_index(drop=True)
has_cid_nb = transaction_df.shape[0]

#fill in unknown descriptions
transaction_df['Description'] = transaction_df.Description.fillna('UNKNOWN')

#convert customer id to string
transaction_df['CustomerID'] = transaction_df['CustomerID'].map(lambda x: str(int(x)))

#filter out non-product stock code
transaction_df = transaction_df[transaction_df.StockCode.map(lambda x: x not in ['BANK CHARGES','C2','DOT','M','PADS','POST'])]

#simplify by filtering unit price and quantity to be non-zero (get rid of discounts, cancellations, etc)
transaction_df = transaction_df[(transaction_df.UnitPrice>0)&\
                                (transaction_df.Quantity>0)].reset_index(drop=True)
has_sales_nb = transaction_df.shape[0]

#add sales
transaction_df['Sales'] = transaction_df.UnitPrice * transaction_df.Quantity

In [51]:
# #get stockcode to description mapping
# stock_code_description = transaction_df.groupby('StockCode').Description.max().reset_index()
# stock_code_description.tail()

# # generate product category loop through descriptions in batches of batch_size
# product_descriptions = stock_code_description.Description.tolist()
# res_texts = []
# original_descriptions = []
# batch_size = 100
# for i in tqdm(range(0, len(product_descriptions), batch_size)):
#     original_description = product_descriptions[i:i+batch_size]
#     res = categorize_product('\n'.join(original_description))
#     res_text = res.strip().split('\n')
#     if len(res_text)!=batch_size:
#         print(f'{batch_size} descriptions but {len(res_text)} categorized')
#     res_texts.extend(res_text)
#     original_descriptions.extend(original_description)
# stock_code_description['category'] = [i.split('|')[1].replace('"','') for i in res_texts]
# stock_code_description.to_csv('../data/product_description_category.csv',
#                               index=False)

In [52]:
product_description_category = pd.read_csv('../data/product_description_category.csv')
transaction_df = transaction_df.merge(product_description_category[['StockCode','category']],
                                      how='left',
                                      on='StockCode')

# product category distribution 
product_description_category.category.value_counts(normalize=True)

category
Home Decor                    0.317300
Kitchen and Dining            0.172451
Fashion Accessories           0.154687
Stationary and Gifts          0.132550
Seasonal and Holiday          0.100847
Toys and Games                0.050287
Personal Care and Wellness    0.036075
Outdoor and Garden            0.031703
Others                        0.004099
Name: proportion, dtype: float64

In [55]:
feature_period = {'start': '2011-01', 'end': '2011-09'}
outcome_period = {'start': '2011-10', 'end': '2011-12'}

feature_transaction = transaction_df[(transaction_df.yearmon>=feature_period['start'])&\
                                      (transaction_df.yearmon<=feature_period['end'])]
outcome_transaction = transaction_df[(transaction_df.yearmon>=outcome_period['start'])&\
                                      (transaction_df.yearmon<=outcome_period['end'])]

#aggregate sales during outcome period
outcome_sales = outcome_transaction.groupby('CustomerID').Sales.sum().reset_index()

#aggregate sales during feature period
feature_sales = feature_transaction.groupby('CustomerID').Sales.sum().reset_index()

#aggregate items during outcome period
outcome_items = outcome_transaction.groupby('CustomerID').Description.apply(lambda x: '|'.join(x.unique()))

#aggregate categories during feature period
feature_categories = feature_transaction.groupby('CustomerID').category.apply(lambda x: '|'.join(x.unique()))

#aggregate categories during outcome period
outcome_categories = outcome_transaction.groupby('CustomerID').category.apply(lambda x: '|'.join(x.unique()))

#aggregate items during feature period
feature_items = feature_transaction.groupby('CustomerID').Description.apply(lambda x: '|'.join(x.unique()))

#merge to get TargetSales including those who spent during feature period but not during outcome (zeroes)
outcome_df = feature_sales[['CustomerID']]\
    .merge(outcome_sales, on='CustomerID', how='left')\
    .merge(outcome_items, on='CustomerID', how='left')\
    .merge(outcome_categories, on='CustomerID', how='left')\
    .merge(feature_items, on='CustomerID', how='left')\
    .merge(feature_categories, on='CustomerID', how='left')

outcome_df.columns = ['CustomerID',
                      'TargetSales','TargetDescriptions','TargetCategories',
                      'bought_descriptions','bought_categories',
                      ]
outcome_df['TargetSales'] = outcome_df['TargetSales'].fillna(0)
outcome_df['TargetDescriptions'] = outcome_df['TargetDescriptions'].fillna('')
outcome_df['TargetCategories'] = outcome_df['TargetCategories'].fillna('')
outcome_df.tail()

Unnamed: 0,CustomerID,TargetSales,TargetDescriptions,TargetCategories,bought_descriptions,bought_categories
3429,18280,0.0,,,WOOD BLACK BOARD ANT WHITE FINISH|RETROSPOT LA...,Home Decor|Kitchen and Dining|Seasonal and Hol...
3430,18281,0.0,,,ROBOT BIRTHDAY CARD|CARD CIRCUS PARADE|PENNY F...,Stationary and Gifts|Toys and Games|Fashion Ac...
3431,18282,77.84,REGENCY CAKESTAND 3 TIER|ROSES REGENCY TEACUP ...,Kitchen and Dining|Stationary and Gifts,ANTIQUE CREAM CUTLERY CUPBOARD|FRENCH STYLE ST...,Home Decor|Kitchen and Dining|Seasonal and Hol...
3432,18283,974.21,16 PIECE CUTLERY SET PANTRY DESIGN|BISCUIT TIN...,Kitchen and Dining|Seasonal and Holiday|Statio...,CHARLOTTE BAG PINK POLKADOT|LUNCH BAG WOODLAND...,Fashion Accessories|Toys and Games|Home Decor|...
3433,18287,1072.0,HAND WARMER OWL DESIGN|SET OF 3 WOODEN SLEIGH ...,Personal Care and Wellness|Seasonal and Holida...,SMALL PURPLE BABUSHKA NOTEBOOK |SMALL RED BABU...,Stationary and Gifts|Seasonal and Holiday|Toys...


In [58]:
#convert invoice date to datetime
feature_transaction['InvoiceDate'] = pd.to_datetime(feature_transaction['InvoiceDate'])

# last date in feature set
current_date = feature_transaction['InvoiceDate'].max()

#rfm
customer_features = feature_transaction.groupby('CustomerID').agg({
    'InvoiceDate': [
        ('recency', lambda x: (current_date - x.max()).days),
        ('first_purchase_date', 'min'),
        ('purchase_day', 'nunique'),
    ],
    'InvoiceNo': [('nb_invoice', 'nunique')],
    'Sales': [
        ('total_sales', 'sum')
    ],
    'StockCode': [('nb_product', 'nunique')],
    'category': [('nb_category', 'nunique')]
}).reset_index()

# Flatten column names
customer_features.columns = [
    'CustomerID',
    'recency',
    'first_purchase_date',
    'purchase_day',
    'nb_invoice',
    'total_sales',
    'nb_product',
    'nb_category'
]

customer_features['customer_lifetime'] = (current_date - customer_features['first_purchase_date']).dt.days
customer_features['avg_purchase_frequency'] = customer_features['customer_lifetime'] / customer_features['purchase_day']
customer_features['avg_purchase_value'] = customer_features['total_sales'] / customer_features['purchase_day']

#category preference
category_sales = feature_transaction.pivot_table(
    values='Sales', 
    index='CustomerID', 
    columns='category', 
    aggfunc='sum', 
    fill_value=0
)
category_sales.columns = [i.lower().replace(' ','_') for i in category_sales.columns]
customer_features = customer_features.merge(category_sales, on='CustomerID', how='left')

total_sales = customer_features['total_sales']
for col in category_sales.columns:
    percentage_col = f'per_{col}'
    customer_features[percentage_col] = customer_features[col] / total_sales

selected_features = [
 'recency',
 'purchase_day',
 'total_sales',
 'nb_product',
 'nb_category',
 'customer_lifetime',
 'avg_purchase_frequency',
 'avg_purchase_value',
 'per_fashion_accessories',
 'per_home_decor',
 'per_kitchen_and_dining',
 'per_others',
 'per_outdoor_and_garden',
 'per_personal_care_and_wellness',
 'per_seasonal_and_holiday',
 'per_stationary_and_gifts',
 'per_toys_and_games'
 ]

customer_features = customer_features[['CustomerID']+selected_features]
df = outcome_df.merge(customer_features, on='CustomerID')
df.to_csv('../data/uci_online_retail.csv',index=False)
df.tail()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  feature_transaction['InvoiceDate'] = pd.to_datetime(feature_transaction['InvoiceDate'])


Unnamed: 0,CustomerID,TargetSales,TargetDescriptions,TargetCategories,bought_descriptions,bought_categories,recency,purchase_day,total_sales,nb_product,...,avg_purchase_value,per_fashion_accessories,per_home_decor,per_kitchen_and_dining,per_others,per_outdoor_and_garden,per_personal_care_and_wellness,per_seasonal_and_holiday,per_stationary_and_gifts,per_toys_and_games
3429,18280,0.0,,,WOOD BLACK BOARD ANT WHITE FINISH|RETROSPOT LA...,Home Decor|Kitchen and Dining|Seasonal and Hol...,207,1,180.6,10,...,180.6,0.082226,0.590255,0.098007,0.0,0.0,0.0,0.229513,0.0,0.0
3430,18281,0.0,,,ROBOT BIRTHDAY CARD|CARD CIRCUS PARADE|PENNY F...,Stationary and Gifts|Toys and Games|Fashion Ac...,110,1,80.82,7,...,80.82,0.204157,0.18931,0.0,0.0,0.0,0.0,0.0,0.187082,0.419451
3431,18282,77.84,REGENCY CAKESTAND 3 TIER|ROSES REGENCY TEACUP ...,Kitchen and Dining|Stationary and Gifts,ANTIQUE CREAM CUTLERY CUPBOARD|FRENCH STYLE ST...,Home Decor|Kitchen and Dining|Seasonal and Hol...,56,1,100.21,7,...,100.21,0.0,0.127233,0.332402,0.0,0.0,0.363736,0.176629,0.0,0.0
3432,18283,974.21,16 PIECE CUTLERY SET PANTRY DESIGN|BISCUIT TIN...,Kitchen and Dining|Seasonal and Holiday|Statio...,CHARLOTTE BAG PINK POLKADOT|LUNCH BAG WOODLAND...,Fashion Accessories|Toys and Games|Home Decor|...,25,10,1114.72,191,...,111.472,0.404101,0.182216,0.195699,0.0,0.008926,0.03382,0.02598,0.096015,0.053242
3433,18287,1072.0,HAND WARMER OWL DESIGN|SET OF 3 WOODEN SLEIGH ...,Personal Care and Wellness|Seasonal and Holida...,SMALL PURPLE BABUSHKA NOTEBOOK |SMALL RED BABU...,Stationary and Gifts|Seasonal and Holiday|Toys...,131,1,765.28,27,...,765.28,0.0,0.488684,0.0,0.0,0.0,0.176406,0.109372,0.165952,0.059586


## Product Description Vector Database

In [1]:
import polars as pl
from sentence_transformers import SentenceTransformer
from src.counterfactual.similarity_searcher import SimilaritySearcher

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#load product description 
product_df = pl.read_csv('../data/product_description_category.csv')
product_df.shape

(3659, 3)

In [None]:
# Load embedding model
model = SentenceTransformer('NovaSearch/stella_en_1.5B_v5')

# Generate embeddings
embeddings = model.encode(product_df['Description'])

embedding_cols = {f'emb_{i}': embeddings[:, i] for i in range(embeddings.shape[1])}
product_df = product_df.hstack(pl.DataFrame(embedding_cols))
product_df.write_csv('../data/product_description_category_emb.csv')
product_df

StockCode,Description,category,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,emb_10,emb_11,emb_12,emb_13,emb_14,emb_15,emb_16,emb_17,emb_18,emb_19,emb_20,emb_21,emb_22,emb_23,emb_24,emb_25,emb_26,emb_27,emb_28,emb_29,emb_30,emb_31,emb_32,emb_33,…,emb_987,emb_988,emb_989,emb_990,emb_991,emb_992,emb_993,emb_994,emb_995,emb_996,emb_997,emb_998,emb_999,emb_1000,emb_1001,emb_1002,emb_1003,emb_1004,emb_1005,emb_1006,emb_1007,emb_1008,emb_1009,emb_1010,emb_1011,emb_1012,emb_1013,emb_1014,emb_1015,emb_1016,emb_1017,emb_1018,emb_1019,emb_1020,emb_1021,emb_1022,emb_1023
str,str,str,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""10002""","""INFLATABLE POLITICAL GLOBE ""","""Toys and Games""",0.09797,0.52877,0.688597,-0.049607,-0.39251,-0.034466,-0.839582,1.249184,0.455385,-0.04335,-0.48993,0.755253,0.180672,0.040401,0.334391,0.618307,-0.550197,-1.144616,-0.490638,-0.595389,0.567807,0.116539,-1.176983,-0.309072,-0.281908,0.226463,-0.099196,-1.264173,0.107136,0.496841,-0.215387,0.579674,-0.751316,-1.101377,…,0.093357,-0.342301,0.694351,0.357893,-0.069469,-0.104702,-0.481543,-0.816441,-0.075111,0.226359,-0.043889,0.071427,-0.203997,-0.398894,-0.766069,0.540775,0.55563,0.840128,0.32328,0.469963,-0.284414,0.141342,-0.040315,-0.470405,1.022281,0.69458,-0.308926,0.044527,0.589936,1.310948,0.66225,-0.087186,0.68555,-1.205114,-1.148063,-0.534157,1.484566
"""10080""","""GROOVY CACTUS INFLATABLE""","""Toys and Games""",-0.347405,0.381222,0.607923,0.21378,0.044668,-0.000772,-0.517559,0.517194,0.434281,0.047029,0.027418,0.393879,0.35171,-0.392793,0.176278,0.112154,-0.294087,-0.411064,0.031848,-0.182201,0.082042,0.487947,-0.890516,-0.849967,-0.150745,0.09369,-0.300535,-0.619114,0.158257,0.378622,-0.310621,0.381253,-0.642744,-1.094794,…,-0.343972,0.027678,0.486431,0.45054,-0.510519,-0.259736,0.111654,-0.314042,-0.0516,0.223027,-0.246322,0.056138,-0.204005,-0.154736,-0.549224,0.42955,0.728388,0.525432,0.405687,0.464178,-0.114606,-0.245039,-0.01754,-0.479815,0.263839,-0.076843,-0.354421,-0.164753,0.180709,0.291167,0.587385,-0.176118,0.182536,-0.453956,-0.526059,0.119049,0.988272
"""10120""","""DOGGY RUBBER""","""Stationary and Gifts""",-0.631071,0.908719,1.054359,0.45567,-0.297082,0.232016,-0.826012,1.369144,0.714186,0.075906,-0.072901,0.277224,0.984123,-0.356684,0.244168,0.275512,-0.355399,-0.640439,0.012325,-0.495924,0.359716,0.63398,-1.3134,-1.412417,-0.314442,-0.175914,-0.364045,-1.303088,0.158304,0.728896,-0.35522,1.128381,-1.282468,-1.660495,…,-0.378258,-0.234497,0.91852,0.667035,-1.073139,-0.411292,0.200203,-0.608779,-0.078403,0.345564,-0.241761,0.030522,-0.321446,-0.643157,-0.750831,0.761127,1.276431,1.110931,0.756835,1.212485,-0.402981,-0.300275,0.2482,-0.607273,0.606658,-0.130685,-0.613512,0.247889,0.481746,0.729057,0.697257,-0.359537,0.538085,-1.039176,-1.107872,-0.302364,1.371418
"""10123C""","""HEARTS WRAPPING TAPE ""","""Stationary and Gifts""",-0.782022,0.8347,1.046383,0.874038,0.017138,0.021939,-0.799635,1.109006,1.015718,0.773115,-0.357521,0.611825,1.255638,-0.407927,0.151851,0.066019,-0.761465,-0.538999,-0.416364,-0.298912,0.374055,0.093368,-1.10751,-1.277474,-0.650284,0.181221,-0.563834,-1.094759,-0.245934,0.993042,-0.112146,0.740669,-1.133127,-1.816272,…,-0.302476,0.172798,0.657739,1.097979,-0.72816,-0.781545,0.485683,-0.475137,0.263176,0.418247,-0.289595,-0.164604,0.000234,-0.336874,-1.084661,0.569308,1.245114,0.903228,0.52282,0.890602,-0.515013,-0.375527,-0.044275,-0.712855,0.725209,-0.762814,-0.821496,0.068349,0.249773,1.067853,0.736138,-0.65862,0.553964,-0.984261,-1.240549,0.497538,1.366852
"""10124A""","""SPOTS ON RED BOOKCOVER TAPE""","""Stationary and Gifts""",-0.391419,0.183194,0.62852,0.391882,-0.341371,0.21175,-0.616613,0.948919,0.450837,0.439172,0.043348,0.684145,0.521359,-0.262855,0.062017,0.188432,-0.097526,-1.090487,-0.326444,-0.259293,0.127875,0.190115,-1.2821,-0.568068,0.148642,0.636491,0.113123,-1.252529,-0.057968,0.306752,-0.024477,0.677334,-0.906317,-1.24937,…,-0.331599,-0.135946,0.415264,0.757922,-0.935403,-0.707713,0.2233,-0.626341,-0.09955,0.208032,0.163378,0.050371,-0.227078,-0.203597,-1.200686,0.932282,1.335649,0.626397,0.243387,0.470265,-0.353656,0.264847,0.044937,-0.234007,0.326365,-0.86187,-0.552408,0.582964,0.291259,1.024892,0.543527,-0.366075,0.43287,-1.019098,-0.914329,0.310193,1.37177
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""90214U""","""LETTER ""U"" BLING KEY RING""","""Fashion Accessories""",-0.374687,0.539812,-0.180003,0.423528,-0.70176,0.339897,-0.091606,0.936668,0.807156,-0.735187,-0.664057,0.329441,0.598741,-0.195631,0.592881,0.567391,-0.531321,-1.227798,-0.387537,-0.145475,0.277294,0.128882,-1.070593,-0.416432,-0.009479,0.235597,-0.25603,-1.089052,0.079904,0.270456,0.676973,0.314132,-0.594751,-0.694704,…,-0.323553,0.722178,0.084239,0.247127,-0.97412,0.413182,0.170306,-0.326332,0.583068,-0.225572,-0.55657,-0.497389,0.06489,-0.719749,-0.707106,0.370221,0.724837,0.979838,0.165588,0.749781,0.369076,-0.146627,-0.594984,-0.443635,0.736152,-0.476199,-1.061437,0.733691,0.417372,0.248823,0.639008,0.380204,0.890001,-0.714833,-0.470144,-0.0036,0.941995
"""90214V""","""LETTER ""V"" BLING KEY RING""","""Fashion Accessories""",-0.372562,0.261567,0.261163,0.566093,-0.627734,0.327683,-0.114117,0.95283,0.827494,-0.629188,-0.635776,0.506057,0.366829,-0.414967,0.55959,0.421473,-0.411623,-1.226873,-0.352772,-0.228755,0.445966,0.236731,-1.028975,-0.494957,0.083574,0.144084,-0.339356,-1.126376,0.193937,0.151737,0.691,0.380005,-0.572798,-0.740825,…,-0.322177,0.554235,-0.000484,-0.146259,-0.776708,0.587896,0.051565,-0.270714,0.32635,-0.39139,-0.89556,-0.555574,0.010102,-0.792509,-0.841667,0.167425,0.519104,1.010214,0.027302,0.752631,0.316805,0.072763,-0.343574,-0.60554,0.725009,-0.385385,-0.827287,0.374201,0.478875,0.328974,0.45831,0.229797,0.898561,-0.577782,-0.325038,0.173634,1.038139
"""90214W""","""LETTER ""W"" BLING KEY RING""","""Fashion Accessories""",-0.306514,0.538191,0.015089,0.38017,-0.687554,0.203737,-0.206501,0.906383,0.930412,-0.674272,-0.688661,0.298578,0.39312,-0.152018,0.627574,0.430264,-0.206023,-1.245329,-0.348162,-0.056932,0.420931,-0.014924,-1.02258,-0.433313,-0.27875,0.289973,-0.231243,-1.091326,0.043085,0.354163,0.505423,0.330737,-0.265795,-0.688141,…,-0.260986,0.380649,0.063702,-0.22769,-0.852271,0.36748,0.160362,-0.212683,0.469409,-0.354669,-0.831466,-0.576493,0.212333,-0.446714,-1.02335,0.334727,0.365602,1.038367,0.130972,0.749433,0.142639,-0.190556,-0.606746,-0.418639,0.680122,-0.300327,-0.874665,0.643537,0.454198,0.462616,0.55718,0.235191,0.95126,-0.39838,-0.465476,-0.085751,0.930102
"""90214Y""","""LETTER ""Y"" BLING KEY RING""","""Fashion Accessories""",-0.237547,0.59991,0.144916,0.599599,-0.814979,0.113586,-0.151361,1.06674,0.712871,-0.815609,-0.662914,0.209356,0.341349,-0.405578,0.657158,0.293047,-0.266851,-1.310045,-0.35564,-0.083757,0.316441,-0.05675,-0.914381,-0.172349,-0.111866,0.20313,-0.286485,-1.0317,0.081053,0.264566,0.278164,0.481326,-0.14208,-1.00365,…,-0.207773,0.36166,0.100924,0.046373,-0.894472,0.470858,0.289756,-0.381021,0.408375,-0.19729,-0.540592,-0.551812,-0.237419,-0.255261,-0.4659,0.280877,0.394007,0.783666,0.234748,0.917534,0.250819,-0.075693,-0.39814,-0.430862,0.816863,-0.371522,-0.700414,0.727824,0.422215,0.427966,0.585446,0.000907,0.852546,-0.679986,-0.518456,-0.154034,1.1433


In [3]:
product_df = pl.read_csv('../data/product_description_category_emb.csv')

In [5]:
# initialize SimilaritySearcher
similarity_features = [f'emb_{i}' for i in range(1024)]
searcher = SimilaritySearcher(product_df, similarity_features=similarity_features)

In [6]:
#example query
product_df[20].to_dicts()[0]['Description']

'EDWARDIAN PARASOL PINK'

In [8]:
[i['Description'] for i in searcher.search(product_df[20].to_dicts()[0])]

['EDWARDIAN PARASOL PINK',
 'MINI LADLE LOVE HEART PINK',
 'ACRYLIC HANGING JEWEL,PINK',
 'ACRYLIC JEWEL ICICLE, PINK',
 'PIN CUSHION BABUSHKA PINK']

In [7]:
[i['_similarity_score'] for i in searcher.search(product_df[20].to_dicts()[0])]

[1.0,
 4.5481869055452e-11,
 2.490316581516502e-13,
 1.731577282831863e-14,
 5.680553022877409e-16]