In [4]:
from torch.utils.data import DataLoader
import math
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
from sentence_transformers import InputExample
import logging
from datetime import datetime
import os
import gzip
import csv
from datetime import date
import json
from torch.utils.data import DataLoader
from datasets import load_dataset
import pandas as pd

In [5]:
train_batch_size = 384
num_epochs = 2
model_save_path = "output/training_esci-crossencoder-exact_vs_nonexact" + str(datetime.today())


In [6]:
# We use cross-encoder/ms-marco-MiniLM-L-6-v2 as base model and set num_labels=1, 
# which predicts a continuous score between 0 and 1
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", num_labels=1)

  return self.fget.__get__(instance, owner)()


In [7]:
train_samples = []
dev_samples = []
test_samples = []

In [8]:
esci = load_dataset("tasksource/esci")

In [9]:
esci_train_df = esci["train"].to_pandas()
esci_train_df.sample(4)

Unnamed: 0,example_id,query,query_id,product_id,product_locale,esci_label,small_version,large_version,product_title,product_description,product_bullet_point,product_brand,product_color,product_text
1651080,1970610,tattoo aftercare,100839,B000HP1IDG,us,Exact,0,1,"Tattoo Goo The Original After Care Salve, 0.75...",,Natural Choice for healing your tattoo fast\nF...,Tattoo Goo,As Picture,"Tattoo Goo The Original After Care Salve, 0.75..."
150666,83849,2年保証付き 工事なしled蛍光灯 グロー式工事不要・ラビット式工事不要・インバーター式工事不要,3290,B07X9LM6HV,jp,Exact,1,1,LED蛍光灯 40W形 T8直管 18W消費電力 120cm G13口金 昼光色（6500K...,◆商品名：T8直管蛍光灯型LEDランプ◆口金：G13 ◆電圧：AC85V〜265V ◆消費電...,商品名：T8直管蛍光灯型LEDランプ(※T5非対応) ルーメン度：2600LM　口金：G1...,TARUNA,40W,LED蛍光灯 40W形 T8直管 18W消費電力 120cm G13口金 昼光色（6500K...
1403170,2138386,wall mounted bathroom towel storage rack,109536,B07QXT7WDL,us,Exact,0,1,"2-Tier Metal Industrial 23.6"" Bathroom Shelves...","2-Tier Industrial 23.6"" Bathroom Shelves Wall ...","The bathroom shelves are versatile, such as ba...",MBQQ,C,"2-Tier Metal Industrial 23.6"" Bathroom Shelves..."
853546,714798,drive belt troybuilt riding mower 46 inch lawn...,35572,B07SM5Z8FJ,us,Substitute,1,1,"QIJIA Lawn Mower Tranmission Belt 5/8"" x 70.9""...","<p>Replaces OEM:754-04249, 754-04249A, 954-042...","Belt For Drive,Belt measures(inch):5/8"" x 70.9...",QIJIA,,"QIJIA Lawn Mower Tranmission Belt 5/8"" x 70.9""..."


In [10]:
esci_train_df = esci_train_df[esci_train_df.product_locale == 'us']

In [11]:
esci_label_map = {"Irrelevant": 0.0, "Exact": 1.0, "Substitute": 0.2, "Complement": 0.0}
esci_train_df['label'] = esci_train_df.esci_label.map(esci_label_map)

In [12]:
def create_item_text(row):
    title = row["product_title"]
    brand = row["product_brand"]
    color = row["product_color"]
    return f"Title: {title}. Brand: {brand}. Color: {color}"

In [13]:
esci_train_df['item_text'] = esci_train_df.apply(lambda row: create_item_text(row), axis=1)
esci_train_df = esci_train_df.dropna(subset=["query", "item_text", "label"])
esci_train_df.shape

(1420372, 16)

In [14]:
for q, t, s in zip(esci_train_df["query"].tolist(), esci_train_df["item_text"].tolist(), esci_train_df["label"].tolist()):
    train_samples.append(InputExample(texts=[q, t], label=s))

In [15]:
esci_test_df = esci["test"].to_pandas()
esci_test_df = esci_test_df[esci_test_df.product_locale == 'us'].copy().sample(frac=1.0, random_state=19)


esci_test_df['label'] = esci_test_df.esci_label.map(esci_label_map)

esci_test_df['item_text'] = esci_test_df.apply(lambda row: create_item_text(row), axis=1)
esci_test_df = esci_test_df.dropna(subset=["query", "item_text", "label"])
esci_test_df.shape

(434234, 16)

In [16]:
val_size = 5_000
val_queries_ids = esci_test_df["query_id"].unique().tolist()[:val_size]
val_queries_ids = set(val_queries_ids)
len(val_queries_ids)


5000

In [17]:
test_size = 5_000
test_query_ids = set()
for qid in esci_test_df.query_id.unique():
    if qid not in val_queries_ids:
        test_query_ids.add(qid)
    if len(test_query_ids) > test_size:
        break
        

In [18]:
val_ids_df = pd.DataFrame({"query_id": list(val_queries_ids)})
test_ids_df = pd.DataFrame({"query_id": list(test_query_ids)})

In [19]:
val_df = esci_test_df[esci_test_df.query_id.isin(val_queries_ids)]
test_df = esci_test_df[esci_test_df.query_id.isin(test_query_ids)]

In [20]:
for q, t, s in zip(val_df["query"].tolist(), val_df["item_text"].tolist(), val_df["label"].tolist()):
    dev_samples.append(InputExample(texts=[q, t], label=s))

In [21]:
for q, t, s in zip(test_df["query"].tolist(), test_df["item_text"].tolist(), test_df["label"].tolist()):
    test_samples.append(InputExample(texts=[q, t], label=s))

In [22]:
len(train_samples), len(dev_samples), len(test_samples)

(1420372, 111223, 103012)

In [23]:
#### Just some code to print debug information to stdout
logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
logger = logging.getLogger(__name__)

In [24]:
# We wrap train_samples (which is a List[InputExample]) into a pytorch DataLoader
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
# We add an evaluator, which evaluates the performance during training
evaluator = CECorrelationEvaluator.from_input_examples(dev_samples)


In [25]:
# Configure the training
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
logger.info("Warmup-steps: {}".format(warmup_steps))

2024-06-17 07:18:04 - Warmup-steps: 740


In [None]:
# Train the model
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
)


Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3699 [00:00<?, ?it/s]

2024-06-17 07:50:12 - CECorrelationEvaluator: Evaluating the model on  dataset after epoch 0:
2024-06-17 07:50:53 - Correlation:	Pearson: 0.5577	Spearman: 0.5772
2024-06-17 07:50:53 - Save model to output/training_esci-crossencoder-exact_vs_nonexact2024-06-17 07:16:25.325538


Iteration:   0%|          | 0/3699 [00:00<?, ?it/s]

NameError: name 'x' is not defined

In [28]:
##### Load model and eval on test set
model = CrossEncoder(model_save_path)

evaluator = CECorrelationEvaluator.from_input_examples(test_samples, name="esci-test")
evaluator(model)

2024-06-17 08:43:12 - Use pytorch device: cuda
2024-06-17 08:43:12 - CECorrelationEvaluator: Evaluating the model on esci-test dataset:
2024-06-17 08:43:47 - Correlation:	Pearson: 0.5705	Spearman: 0.5843


0.5843368280819196