## Dataset Processing Notebook


In [None]:
import os
import json
import pandas as pd
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

project_dir = Path(os.getcwd()).parent
data_dir = project_dir / "Data" / "RAW"
processed_dir = project_dir / "Data" / "PROCESSED" 

### Links to Datasets:
* Amazon ESCI: [Github](https://github.com/amazon-science/esci-data)
* WANDS: [Github](https://github.com/wayfair/WANDS/tree/main)

### Amazon ESCI's dataset processing

In [3]:
examples_path = data_dir / "shopping_queries_dataset_examples.parquet"
products_path = data_dir / "shopping_queries_dataset_products.parquet"

examples_df = pd.read_parquet(examples_path)
products_df = pd.read_parquet(products_path)

amz_esci_df = pd.merge(
    examples_df,
    products_df,
    how="left",
    left_on="product_id",
    right_on="product_id",
    suffixes=("_example", "_product"),
)

In [21]:
print(f'Merged ESCI data shape: {amz_esci_df.shape}')
print(f'Columns: {amz_esci_df.columns.tolist()}')

Merged ESCI data shape: (613016, 11)
Columns: ['example_id', 'query', 'query_id', 'product_id', 'esci_label', 'split', 'product_title', 'product_description', 'product_bullet_point', 'product_brand', 'product_color']


In [5]:
# Filter to only small version and US locale - from dataset paper
amz_esci_df = amz_esci_df[amz_esci_df['small_version'] == 1].drop(['small_version', 'large_version'], axis=1)
amz_esci_df = amz_esci_df[amz_esci_df['product_locale_example'] == 'us'].drop(['product_locale_example', 'product_locale_product'], axis=1)

In [None]:
# Filter to only train and test splits
amz_esci_train_df_raw = amz_esci_df[amz_esci_df['split'] == 'train'].drop(['split'], axis=1)
amz_esci_test_df_raw = amz_esci_df[amz_esci_df['split'] == 'test'].drop(['split'], axis=1)

In [23]:
print(f'Amazon ESCI train shape: {amz_esci_train_df_raw.shape}')
print(f'Amazon ESCI test shape: {amz_esci_test_df_raw.shape}')

print(f'Amazon ESCI train queries count: {amz_esci_train_df_raw["query"].nunique()}')
print(f'Amazon ESCI test queries count: {amz_esci_test_df_raw["query"].nunique()}')

Amazon ESCI train shape: (427655, 10)
Amazon ESCI test shape: (185361, 10)
Amazon ESCI train queries count: 20888
Amazon ESCI test queries count: 8956


### WAND's dataset processing

In [8]:
product_df = pd.read_csv(data_dir / 'product.csv', sep='\t')
query_df = pd.read_csv(data_dir / 'query.csv', sep='\t')
label_df = pd.read_csv(data_dir / 'label.csv', sep='\t')

x = pd.merge(query_df, label_df, on='query_id', how='inner')
y = pd.merge(x, product_df, on='product_id', how='inner')
wands_df = y.drop(['id'], axis=1)

In [24]:
print(f'Merged WANDs data shape: {wands_df.shape}')
print(f'Columns: {wands_df.columns.tolist()}')

Merged WANDs data shape: (233448, 13)
Columns: ['query_id', 'query', 'query_class', 'product_id', 'label', 'product_name', 'product_class', 'category hierarchy', 'product_description', 'product_features', 'rating_count', 'average_rating', 'review_count']


In [17]:
def train_test_split(
    df: pd.DataFrame,
    test_size: float = 0.3,
    seed: int = 42
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
    
    qids = df["query_id"].astype(str).unique()
    qids_shuffled = pd.Series(qids).sample(frac=1.0, random_state=seed).tolist()
    cut = int(len(qids_shuffled) * (1 - test_size))

    train_qids = set(qids_shuffled[:cut])
    test_qids = set(qids_shuffled[cut:])
    
    train_df = df[df["query_id"].astype(str).isin(train_qids)].reset_index(drop=True)
    test_df = df[df["query_id"].astype(str).isin(test_qids)].reset_index(drop=True)
    
    assert len(set(train_df["query_id"].astype(str)).intersection(set(test_df["query_id"].astype(str)))) == 0, "Train and test sets have overlapping query_ids"
    
    return train_df, test_df

In [18]:
wands_train_df_raw, wands_test_df_raw = train_test_split(wands_df, test_size=0.3, seed=42)

In [26]:
print(f'WANDS train shape: {wands_train_df_raw.shape}')
print(f'WANDS test shape: {wands_test_df_raw.shape}')

print(f'WANDs train queries count: {wands_train_df_raw["query"].nunique()}')
print(f'WANDs test queries count: {wands_test_df_raw["query"].nunique()}')

WANDS train shape: (162789, 13)
WANDS test shape: (70659, 13)
WANDs train queries count: 336
WANDs test queries count: 144


In [31]:
def build_esci_artifacts(
    df: pd.DataFrame,
    *,
    query_id_col: str = "query_id",
    query_col: str = "query",
    product_id_col: str = "product_id",
    label_col: str = "esci_label",
    metadata_cols: Optional[List[str]] = None,
    keep_irrelevant: bool = False,
    top_k: Optional[int] = 40,
    query_prefix: str = "amz_",
    product_prefix: str = "amz_",
    grading: Optional[Dict[str, float]] = None,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]], Dict[str, Dict[str, Any]], pd.DataFrame]:
    """
    ESCI -> (qrels_df, qrels_dict, product_store, query_table_df)

    qrels_df: query_id | ranked_product_ids (List[str])  (sorted by gain desc)
    qrels_dict: {query_id: {product_id: gain}} (graded)
    product_store: {product_id: metadata}
    query_table_df: query_id | query

    Default grading (configurable):
      Exact:3, Substitute:2, Complement:1, Irrelevant:0
    """
    if metadata_cols is None:
        metadata_cols = [
            "product_title",
            "product_description",
            "product_bullet_point",
            "product_brand",
            "product_color_name",
        ]
    if grading is None:
        grading = {
            "EXACT": 3.0, "E": 3.0,
            "SUBSTITUTE": 2.0, "S": 2.0,
            "COMPLEMENT": 1.0, "C": 1.0,
            "IRRELEVANT": 0.0, "I": 0.0,
        }

    required = {query_id_col, query_col, product_id_col, label_col}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"ESCI df missing required columns: {sorted(missing)}")

    metadata_cols = [c for c in metadata_cols if c in df.columns]
    work = df[[query_id_col, query_col, product_id_col, label_col] + metadata_cols].copy()

    # Prefix IDs
    work[query_id_col] = query_prefix + work[query_id_col].astype(str)
    work[product_id_col] = product_prefix + work[product_id_col].astype(str)

    # Label -> gain
    work[label_col] = work[label_col].astype(str).str.strip().str.upper()
    work["gain"] = work[label_col].map(grading)
    work = work.dropna(subset=["gain", query_id_col, product_id_col, query_col])

    if not keep_irrelevant:
        work = work[work["gain"] > 0.0]

    # dedupe by (qid,pid) keeping max gain deterministically
    work = (
        work.sort_values(["gain"], ascending=False, kind="mergesort")
            .drop_duplicates(subset=[query_id_col, product_id_col], keep="first")
    )

    # order within query
    work = work.sort_values(
        by=[query_id_col, "gain", product_id_col],
        ascending=[True, False, True],
        kind="mergesort",
    )

    if top_k is not None:
        work = work.groupby(query_id_col, sort=False, as_index=False).head(top_k)

    qrels_df = (
        work.groupby(query_id_col, sort=False)[product_id_col]
            .apply(list)
            .reset_index(name="ranked_product_ids")
            .rename(columns={query_id_col: "query_id"})
    )

    qrels_dict: Dict[str, Dict[str, float]] = {}
    for qid, sub in work.groupby(query_id_col, sort=False):
        qrels_dict[str(qid)] = {str(pid): float(g) for pid, g in zip(sub[product_id_col], sub["gain"])}

    # product_store (dedupe by product_id)
    prod = (
        work.sort_values([product_id_col], kind="mergesort")
            .drop_duplicates(subset=[product_id_col], keep="first")
    )
    product_store: Dict[str, Dict[str, Any]] = {}
    for _, row in prod.iterrows():
        pid = row[product_id_col]
        product_store[pid] = {c: row[c] for c in metadata_cols}

    query_table_df = (
        work[[query_id_col, query_col]]
        .drop_duplicates(subset=[query_id_col], keep="first")
        .rename(columns={query_id_col: "query_id", query_col: "query"})
        .reset_index(drop=True)
    )

    return qrels_df, qrels_dict, product_store, query_table_df

In [32]:
def build_wands_artifacts(
    df: pd.DataFrame,
    *,
    query_id_col: str = "query_id",
    query_col: str = "query",
    product_id_col: str = "product_id",
    label_col: str = "label",
    keep_irrelevant: bool = False,
    top_k: Optional[int] = 40,
    query_prefix: str = "wands_",
    product_prefix: str = "wands_",
    grading: Optional[Dict[str, float]] = None,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]], Dict[str, Dict[str, Any]], pd.DataFrame]:
    """
    WANDS -> (qrels_df, qrels_dict, product_store, query_table_df)

    Default grading (configurable):
      Exact:2 (or 3), Partial:1, Irrelevant:0
    """
    if grading is None:
        grading = {"EXACT": 2.0, "PARTIAL": 1.0, "IRRELEVANT": 0.0}
    
    # Prefix IDs (collision-proof)
    df[product_id_col] = product_prefix + df[product_id_col].astype(str)
    df[query_id_col] = query_prefix + df[query_id_col].astype(str)

    # product_store: all columns; NaN->None for JSON friendliness
    product_store: Dict[str, Dict[str, Any]] = {}
    for _, r in df.iterrows():
        pid = str(r[product_id_col])
        product_store[pid] = {k: (None if pd.isna(v) else v) for k, v in r.items()}

    query_table_df = (
        df[[query_id_col, query_col]]
        .drop_duplicates(subset=[query_id_col], keep="first")
        .rename(columns={query_id_col: "query_id", query_col: "query"})
        .reset_index(drop=True)
    )

    work = df[[query_id_col, product_id_col, label_col]].copy()
    work[label_col] = work[label_col].astype(str).str.strip().str.upper()
    work["gain"] = work[label_col].map(grading)
    work = work.dropna(subset=["gain", query_id_col, product_id_col])

    if not keep_irrelevant:
        work = work[work["gain"] > 0.0]

    work = (
        work.sort_values(["gain"], ascending=False, kind="mergesort")
            .drop_duplicates(subset=[query_id_col, product_id_col], keep="first")
    )

    work = work.sort_values(
        by=[query_id_col, "gain", product_id_col],
        ascending=[True, False, True],
        kind="mergesort",
    )

    if top_k is not None:
        work = work.groupby(query_id_col, sort=False, as_index=False).head(top_k)

    qrels_df = (
        work.groupby(query_id_col, sort=False)[product_id_col]
            .apply(list)
            .reset_index(name="ranked_product_ids")
            .rename(columns={query_id_col: "query_id"})
    )

    qrels_dict: Dict[str, Dict[str, float]] = {}
    for qid, sub in work.groupby(query_id_col, sort=False):
        qrels_dict[str(qid)] = {str(pid): float(g) for pid, g in zip(sub[product_id_col], sub["gain"])}

    return qrels_df, qrels_dict, product_store, query_table_df

### Build & Save Data artifacts

In [38]:
def build_data_artifacts(
    amz_train_raw: pd.DataFrame,
    amz_test_raw: pd.DataFrame,
    wands_train_raw: pd.DataFrame,
    wands_test_raw: pd.DataFrame
):
    # Amazon ESCI artifacts
    amz_train_qrels_df, amz_train_qrels_dict, amz_train_product_store, amz_train_query_table_df = build_esci_artifacts(amz_train_raw)
    amz_test_qrels_df, amz_test_qrels_dict, amz_test_product_store, amz_test_query_table_df = build_esci_artifacts(amz_test_raw)

    # WANDS artifacts
    wands_train_qrels_df, wands_train_qrels_dict, wands_train_product_store, wands_train_query_table_df = build_wands_artifacts(wands_train_raw)
    wands_test_qrels_df, wands_test_qrels_dict, wands_test_product_store, wands_test_query_table_df = build_wands_artifacts(wands_test_raw)
    
    # Combine product stores for unified access
    product_store = amz_train_product_store | amz_test_product_store | wands_train_product_store | wands_test_product_store
    
    # Combine train and test query tables
    train_qrels_df = pd.concat([amz_train_qrels_df, wands_train_qrels_df], ignore_index=True)
    test_qrels_df = pd.concat([amz_test_qrels_df, wands_test_qrels_df], ignore_index=True)
    
    train_qrels_dict = {**amz_train_qrels_dict, **wands_train_qrels_dict}
    test_qrels_dict = {**amz_test_qrels_dict, **wands_test_qrels_dict}
    
    train_query_table_df = pd.concat([amz_train_query_table_df, wands_train_query_table_df], ignore_index=True)\
                             .drop_duplicates(subset=["query_id"], keep="first")\
                             .reset_index(drop=True)
    test_query_table_df = pd.concat([amz_test_query_table_df, wands_test_query_table_df], ignore_index=True)\
                            .drop_duplicates(subset=["query_id"], keep="first")\
                            .reset_index(drop=True)
    
    return {
        "train_qrels_df": train_qrels_df,
        "test_qrels_df": test_qrels_df,
        "train_qrels_dict": train_qrels_dict,
        "test_qrels_dict": test_qrels_dict,
        "train_query_table_df": train_query_table_df,
        "test_query_table_df": test_query_table_df,
        "product_store": product_store,
    } 

In [39]:
artifacts = build_data_artifacts(
    amz_esci_train_df_raw,
    amz_esci_test_df_raw,
    wands_train_df_raw,
    wands_test_df_raw
)

In [None]:
# Save product store
with open(processed_dir / "product_store.json", "w", encoding="utf-8") as f:
    json.dump(artifacts["product_store"], f, ensure_ascii=False, indent=2)

# Save train qrels dict
with open(processed_dir / "train_qrels.json", "w", encoding="utf-8") as f:
    json.dump(artifacts["train_qrels_dict"], f, ensure_ascii=False, indent=2)

with open(processed_dir / "test_qrels.json", "w", encoding="utf-8") as f:
    json.dump(artifacts["test_qrels_dict"], f, ensure_ascii=False, indent=2)

In [None]:
# Save qrels and query tables as parquet
artifacts['train_qrels_df'].to_parquet(
    processed_dir / "train_qrels.parquet", 
    index=False, compression='gzip'
    )
artifacts['test_qrels_df'].to_parquet(
    processed_dir / "test_qrels.parquet", 
    index=False, compression='gzip'
    )

artifacts['train_query_table_df'].to_parquet(
    processed_dir / "train_query_table.parquet", 
    index=False, compression='gzip'
    )
artifacts['test_query_table_df'].to_parquet(
    processed_dir / "test_query_table.parquet", 
    index=False, compression='gzip'
    )