In [1]:
import ast
import glob
import time
from datetime import date
from pathlib import Path

import pandas as pd
from datasets import load_dataset
from sentence_transformers import InputExample, LoggingHandler, SentenceTransformer, losses, models, util
from torch.utils.data import DataLoader

In [2]:
# Load the `tasksource/esci` dataset with the Hugging Face `datasets` library (train/test splits, shown below).
esci = load_dataset(path="tasksource/esci")

In [3]:
# Inspect the available splits, their feature columns and row counts.
esci

DatasetDict({
    train: Dataset({
        features: ['example_id', 'query', 'query_id', 'product_id', 'product_locale', 'esci_label', 'small_version', 'large_version', 'product_title', 'product_description', 'product_bullet_point', 'product_brand', 'product_color', 'product_text'],
        num_rows: 2027874
    })
    test: Dataset({
        features: ['example_id', 'query', 'query_id', 'product_id', 'product_locale', 'esci_label', 'small_version', 'large_version', 'product_title', 'product_description', 'product_bullet_point', 'product_brand', 'product_color', 'product_text'],
        num_rows: 652490
    })
})

In [4]:
# Materialise the whole train split as a single pandas DataFrame (batched=False is the default).
esci_train_df = esci["train"].to_pandas(batched=False)

In [5]:
# (rows, columns) of the train DataFrame.
esci_train_df.shape

(2027874, 14)

In [6]:
# Peek at the first rows to see the label and product-text columns.
esci_train_df.head(n=5)

Unnamed: 0,example_id,query,query_id,product_id,product_locale,esci_label,small_version,large_version,product_title,product_description,product_bullet_point,product_brand,product_color,product_text
0,0,revent 80 cfm,0,B000MOO21W,us,Irrelevant,0,1,Panasonic FV-20VQ3 WhisperCeiling 190 CFM Ceil...,,WhisperCeiling fans feature a totally enclosed...,Panasonic,White,Panasonic FV-20VQ3 WhisperCeiling 190 CFM Ceil...
1,291891,bathroom fan without light,13723,B000MOO21W,us,Exact,1,1,Panasonic FV-20VQ3 WhisperCeiling 190 CFM Ceil...,,WhisperCeiling fans feature a totally enclosed...,Panasonic,White,Panasonic FV-20VQ3 WhisperCeiling 190 CFM Ceil...
2,1,revent 80 cfm,0,B07X3Y6B1V,us,Exact,0,1,Homewerks 7141-80 Bathroom Fan Integrated LED ...,,OUTSTANDING PERFORMANCE: This Homewerk's bath ...,Homewerks,80 CFM,Homewerks 7141-80 Bathroom Fan Integrated LED ...
3,2,revent 80 cfm,0,B07WDM7MQQ,us,Exact,0,1,Homewerks 7140-80 Bathroom Fan Ceiling Mount E...,,OUTSTANDING PERFORMANCE: This Homewerk's bath ...,Homewerks,White,Homewerks 7140-80 Bathroom Fan Ceiling Mount E...
4,3,revent 80 cfm,0,B07RH6Z8KW,us,Exact,0,1,Delta Electronics RAD80L BreezRadiance 80 CFM ...,This pre-owned or refurbished product has been...,Quiet operation at 1.5 sones\nBuilt-in thermos...,DELTA ELECTRONICS (AMERICAS) LTD.,White,Delta Electronics RAD80L BreezRadiance 80 CFM ...


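## Define target labels

Our objective is to build a dataset in which each query is paired with a list of **positive products** (ESCI label `Exact`) and a list of **negative products** (ESCI label `Irrelevant`). Rows with any other label are ignored. Only queries with at least 3 positives and at least 3 negatives are kept, and each list holds at most the first 20 matching products:

**(query, positive products, negative products)**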

In [9]:
# Set N_SAMPLES to an int (e.g. 100_000) for a quick, reproducible subsample of the train split;
# None keeps the full split.
N_SAMPLES = None
SEED = 42
esci_train_df_sample = (
    esci_train_df.sample(N_SAMPLES, random_state=SEED) if N_SAMPLES else esci_train_df
)

def _first_product_ids_with_label(rows, label, limit=20):
    """Return up to `limit` product_ids from `rows` whose esci_label equals `label`, in row order."""
    return rows.loc[rows["esci_label"] == label, "product_id"].head(limit).tolist()


def filter_negs(rows, limit=20):
    """Return the first `limit` (default 20) product_ids labelled 'Irrelevant' in `rows`.

    Args:
        rows: DataFrame with `esci_label` and `product_id` columns (e.g. one query's group).
        limit: Maximum number of product_ids to return.
    """
    return _first_product_ids_with_label(rows, "Irrelevant", limit)


def filter_pos(rows, limit=20):
    """Return the first `limit` (default 20) product_ids labelled 'Exact' in `rows`.

    Args:
        rows: DataFrame with `esci_label` and `product_id` columns (e.g. one query's group).
        limit: Maximum number of product_ids to return.
    """
    return _first_product_ids_with_label(rows, "Exact", limit)
    

    
def create_mnli_data(esci_train_df_sample, file_id):
    query_negs = esci_train_df_sample.groupby('query_id').apply(lambda x: filter_negs(x))
    query_negs = query_negs.reset_index()
    query_negs.columns = ["query_id", "neg_product_ids"]
    query_negs = query_negs[query_negs.apply(lambda x: len(x["neg_product_ids"]) > 2, axis=1)]

    print("query_negs.shape", query_negs.shape)

    query_pos = esci_train_df_sample.groupby('query_id').apply(lambda x: filter_pos(x))
    query_pos = query_pos.reset_index()
    query_pos.columns = ["query_id", "pos_product_ids"]
    query_pos = query_pos[query_pos.apply(lambda x: len(x["pos_product_ids"]) > 2, axis=1)]
    print("query_pos.shape", query_pos.shape)

    merged_negs_pos = pd.merge(query_negs, query_pos, on="query_id")
    print("merged_negs_pos.shape", merged_negs_pos.shape)
    
    merged_negs_pos.to_csv(f"{file_id}_esci_pos_neg_top3_{date.today()}.csv", index=None)
    

In [10]:
# Pass `esci_train_df_sample` so any subsampling applied when it was defined actually takes effect.
create_mnli_data(esci_train_df_sample, "train")

esci_test_df = esci["test"].to_pandas()
# Trailing `;` keeps the cell's output to the printed shapes only.
create_mnli_data(esci_test_df, "test");

query_negs.shape (20550, 2)
query_pos.shape (92004, 2)
merged_negs_pos.shape (17060, 3)
query_negs.shape (7507, 2)
query_pos.shape (28339, 2)
merged_negs_pos.shape (6304, 3)


In [11]:
# Output files are stamped with an ISO date (YYYY-MM-DD), so the lexicographically last file is the
# newest run; this avoids a hard-coded date that breaks Restart & Run All on any other day.
train_pair_files = sorted(glob.glob("train_esci_pos_neg_top3_*.csv"))
assert train_pair_files, "no train_esci_pos_neg_top3_*.csv found - run create_mnli_data first"
# The list columns were written as Python reprs; literal_eval turns them back into real lists.
train_pairs = pd.read_csv(
    train_pair_files[-1],
    converters={"neg_product_ids": ast.literal_eval, "pos_product_ids": ast.literal_eval},
)
train_pairs

Unnamed: 0,query_id,neg_product_ids,pos_product_ids
0,1,"['B075SCHMPY', 'B082K7V2GZ', 'B00N16T5D8', 'B0...","['B08L3B9B9P', 'B07C1WZG12', 'B077QMNXTS', 'B0..."
1,3,"['B08TQF34WG', 'B00KHTM8L8', 'B07HG7ZST9', 'B0...","['B08TQH4RBB', 'B08H2H18DN', 'B08HJQJT2T', 'B0..."
2,5,"['B07NZ4T2SL', 'B08Y8VX63W', 'B08Y97TG79', 'B0...","['B07RP3F2LW', 'B07FK9PCZB', 'B07GJCSF8Z', 'B0..."
3,6,"['B07ZGD1SL3', 'B07JZJLHCF', 'B07G7F6JZ6', 'B0...","['B082CLHSPD', 'B07NH2D2CP', 'B084Q4R7SG', 'B0..."
4,7,"['B07D3L3FCW', 'B07KKP1VVP', 'B07DJ846QY', 'B0...","['B096LZW7J1', 'B07MCMMS92', 'B096LMZMYT', 'B0..."
...,...,...,...
17055,130630,"['B07CXZMXNV', 'B00B6XM1IO', 'B00VAAIN5I', 'B0...","['B07VJH4Q5X', 'B073J224FL', 'B00LWV5VHI', 'B0..."
17056,130631,"['4119160343', '4401162331', '4636956699', '46...","['B082PQ73YH', 'B099C3GLMD', 'B00PA5HKZ8', 'B0..."
17057,130635,"['B07GC4DX9D', 'B00GMA6KWO', 'B07VNDZZRG', 'B0...","['B07X5V55TX', 'B07X5V55TX', 'B081821VZW', 'B0..."
17058,130636,"['B072J57CTP', 'B07CH9R2VH', 'B093L4W9QV', 'B0...","['B001G61KM4', 'B01N0UYSFP', 'B08Q3DH2ZM', 'B0..."
