In [1]:
import os
import pickle
import re
import sys

sys.path.append('../..')

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

from source.filters import ConsecutiveDuplicatesFilter

In [2]:
# Load environment variables (presumably from a local .env file) so AMAZON_M2_DATA_DIR is available below.
load_dotenv()

True

In [3]:
# Root directory of the raw Amazon-M2 (KDD Cup 2023) files; raises KeyError if the variable is unset.
AMAZON_M2_DATA_DIR = os.environ['AMAZON_M2_DATA_DIR']
!ls {AMAZON_M2_DATA_DIR}

19dd45a8-5f0c-4c95-bf60-506398327251_kdd-2023-ground-truth.zip
ground_truth
products_train.csv
sessions_test_task2.csv
sessions_test_task2_phase1.csv
sessions_train.csv


In [4]:
def load_amazon_m2_data(filepath: str, locale: str = 'UK') -> pd.DataFrame:
    """Read an Amazon-M2 CSV file and keep only the rows of one locale.

    Args:
        filepath: Path (or file-like buffer) of a CSV file that has a ``locale`` column.
        locale: Locale code to keep, e.g. ``'FR'``.

    Returns:
        An independent copy of the matching rows; the original row labels are preserved.
    """
    data = pd.read_csv(filepath)
    # `.copy()` detaches the filtered frame from the full one, so later column assignments
    # (e.g. `sessions.prev_items = ...`) do not trigger SettingWithCopyWarning.
    data = data[data.locale == locale].copy()
    print(f'{data.shape = }')
    return data

In [5]:
# French-locale training sessions: browsed items (`prev_items`) plus the item that followed (`next_item`).
train_sessions = load_amazon_m2_data(
    filepath=os.path.join(AMAZON_M2_DATA_DIR, 'sessions_train.csv'),
    locale='FR',
)
train_sessions.head()

data.shape = (117561, 3)


Unnamed: 0,prev_items,next_item,locale
3361763,['B07HB3B78C' 'B07HB3B78C' 'B07HB3B78C' 'B089L...,B0924SNP46,FR
3361764,['B00I3P3BRS' 'B00I3P3DXK' 'B086C24G9G' 'B00I3...,B00I3P3CQI,FR
3361765,['B07VFT4D6B' 'B08CKYFLHQ'],B076Q6P3TW,FR
3361766,['B07H4Q5LFH' 'B0B7S7LBMB' 'B094R3R9XH' 'B0B7S...,B0B5ND342Y,FR
3361767,['B0BD5GYKXP' 'B0B7BBS7D7'],B0B5XH75H7,FR


In [None]:
# French-locale phase-1 test sessions: only the browsed items, the target is held out.
test_sessions = load_amazon_m2_data(
    filepath=os.path.join(AMAZON_M2_DATA_DIR, 'sessions_test_task2_phase1.csv'),
    locale='FR',
)
test_sessions.head()

data.shape = (12520, 2)


Unnamed: 0,prev_items,locale
8176,['B0BJJZG64G' 'B06W55K9N6' 'B0BJJZG64G'],FR
8177,['B09FY9DN26' 'B09RNG5QRL' 'B09RNGP5G6'],FR
8178,['B07W5JSTNH' 'B0B75HYH3Z' 'B0B75HYH3Z' 'B07W5...,FR
8179,['B08HVF83S3' 'B09KMWXFZ3'],FR
8180,['B0B2PLC1WZ' 'B07CJ9K9YN'],FR


In [7]:
# Held-out next item for each phase-1 test session (row labels line up with `test_sessions`).
ground_truth = load_amazon_m2_data(
    filepath=os.path.join(AMAZON_M2_DATA_DIR, 'ground_truth/phase1/gt_task2.csv'),
    locale='FR',
)
ground_truth.head()

data.shape = (12520, 2)


Unnamed: 0,next_item,locale
8176,B0713WPGLL,FR
8177,B08TTS7VP3,FR
8178,B0BFTWDD1N,FR
8179,B09CGWVRTF,FR
8180,B0B3RFCH32,FR


In [8]:
_QUOTED_ITEM = re.compile(r"'(.+?)'")


def parse_session(text: str) -> list[str]:
    """Extract the single-quoted item IDs from a numpy-style list repr such as "['A' 'B']"."""
    return _QUOTED_ITEM.findall(text)


parse_session("['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9Z']")

['B00R9R5ND6', 'B00R9RZ9ZS', 'B00R9RZ9Z']

In [9]:
# Turn the stringified item arrays into Python lists of item IDs. `assign` builds a new frame
# instead of writing into a possible slice (avoids SettingWithCopyWarning).
train_sessions = train_sessions.assign(prev_items=train_sessions.prev_items.map(parse_session))
train_sessions.head(1)

Unnamed: 0,prev_items,next_item,locale
3361763,"[B07HB3B78C, B07HB3B78C, B07HB3B78C, B089LDLP8W]",B0924SNP46,FR


In [10]:
# Flatten sessions into one row per (session, item) event. The target `next_item` is appended as the
# final event of its session, and the original row label becomes the user/session id.
full_sessions = train_sessions.prev_items + train_sessions.next_item.map(lambda item: [item])
train_val_interactions = full_sessions.explode().rename_axis('user_id').reset_index(name='item_id')
# Position of each event within its session.
train_val_interactions = train_val_interactions.assign(
    timestamp=train_val_interactions.groupby('user_id').cumcount()
)
print(f'{train_val_interactions.shape = }')
train_val_interactions.head()

train_val_interactions.shape = (534358, 3)


Unnamed: 0,user_id,item_id,timestamp
0,3361763,B07HB3B78C,0
1,3361763,B07HB3B78C,1
2,3361763,B07HB3B78C,2
3,3361763,B089LDLP8W,3
4,3361763,B0924SNP46,4


In [None]:
# Collapse runs of the same item within a session (see the output: three consecutive B07HB3B78C
# events become one). NOTE(review): exact filter semantics live in source.filters — confirm.
deduplicate = ConsecutiveDuplicatesFilter()
train_val_interactions = deduplicate(train_val_interactions)
print(f'{train_val_interactions.shape = }')
train_val_interactions.head()

train_val_interactions.shape = (506928, 3)


Unnamed: 0,user_id,item_id,timestamp
0,3361763,B07HB3B78C,0
1,3361763,B089LDLP8W,3
2,3361763,B0924SNP46,4
3,3361764,B00I3P3BRS,0
4,3361764,B00I3P3DXK,1


In [12]:
# Split whole sessions (not individual events) 80/20 into train and validation; seeded for reproducibility.
users = train_val_interactions.user_id.drop_duplicates().sort_values()
train_users = users.sample(frac=0.8, random_state=42)

in_train_split = train_val_interactions.user_id.isin(train_users)
train_interactions = train_val_interactions[in_train_split]
val_interactions = train_val_interactions[~in_train_split]

len(train_interactions), len(val_interactions)

(405649, 101279)

In [13]:
# Encode warm item IDs: every item seen in the training split gets an integer index.
WARM_ITEMS = train_interactions.item_id.sort_values().unique()
print(f'{len(WARM_ITEMS) = }')

# Keep zero for padding
item2index_warm = {item: index + 1 for index, item in enumerate(WARM_ITEMS)}

# `assign` returns a new frame: no SettingWithCopyWarning on these slices of
# `train_val_interactions`, and `item_id` gets a proper integer dtype.
train_interactions = train_interactions.assign(
    item_id=train_interactions.item_id.map(item2index_warm)
)

# Filter out cold items from validation subset. This must happen BEFORE encoding: once mapped,
# `item_id` holds integer indices, and comparing them with the string IDs in WARM_ITEMS would
# drop every row (leaving an empty validation set).
# NOTE(review): dropping cold items may create new consecutive duplicates and gaps in `timestamp`.
val_interactions = val_interactions[val_interactions.item_id.isin(WARM_ITEMS)]
val_interactions = val_interactions.assign(item_id=val_interactions.item_id.map(item2index_warm))

len(WARM_ITEMS) = 42647


In [14]:
# Turn the stringified item arrays into Python lists of item IDs. `assign` builds a new frame
# instead of writing into a possible slice (avoids SettingWithCopyWarning).
test_sessions = test_sessions.assign(prev_items=test_sessions.prev_items.map(parse_session))
test_sessions.head(1)

Unnamed: 0,prev_items,locale
8176,"[B0BJJZG64G, B06W55K9N6, B0BJJZG64G]",FR


In [15]:
# One row per (session, item) test-input event; timestamp is the position within the session.
test_interactions = (
    test_sessions.drop(columns='locale')
    .explode('prev_items')
    .rename_axis('user_id')
    .reset_index()
    .rename(columns={'prev_items': 'item_id'})
)
test_interactions['timestamp'] = test_interactions.groupby('user_id').cumcount()
print(f'{test_interactions.shape = }')
test_interactions.head()

test_interactions.shape = (48143, 3)


Unnamed: 0,user_id,item_id,timestamp
0,8176,B0BJJZG64G,0
1,8176,B06W55K9N6,1
2,8176,B0BJJZG64G,2
3,8177,B09FY9DN26,0
4,8177,B09RNG5QRL,1


In [16]:
# Same (user_id, item_id) layout as the interaction tables; user_id is the original row label.
ground_truth = (
    ground_truth.drop(columns='locale')
    .rename_axis('user_id')
    .reset_index()
    .rename(columns={'next_item': 'item_id'})
)
ground_truth.head()

Unnamed: 0,user_id,item_id
0,8176,B0713WPGLL
1,8177,B08TTS7VP3
2,8178,B0BFTWDD1N
3,8179,B09CGWVRTF
4,8180,B0B3RFCH32


In [17]:
# Flag test-input items that never occur in the training split, and show their share.
test_interactions['is_cold'] = np.logical_not(test_interactions.item_id.isin(WARM_ITEMS))
test_interactions['is_cold'].value_counts(normalize=True)

is_cold
False    0.923914
True     0.076086
Name: proportion, dtype: float64

In [18]:
# Flag target items that never occur in the training split, and show their share.
ground_truth['is_cold'] = np.logical_not(ground_truth.item_id.isin(WARM_ITEMS))
ground_truth['is_cold'].value_counts(normalize=True)

is_cold
False    0.929633
True     0.070367
Name: proportion, dtype: float64

In [19]:
# Cold items appear in the test inputs or targets but never in training. They are numbered right
# after the warm range so the two vocabularies do not overlap.
bias = max(item2index_warm.values()) + 1

COLD_ITEMS = (
    pd.concat(
        [
            test_interactions[test_interactions.is_cold].item_id,
            ground_truth[ground_truth.is_cold].item_id,
        ]
    )
    .sort_values()
    .unique()
)
print(f'{len(COLD_ITEMS) = }')

# COLD_ITEMS is already sorted and de-duplicated, so no extra sort is needed.
item2index_cold = {item: index + bias for index, item in enumerate(COLD_ITEMS)}

# Encode test inputs and targets with the combined vocabulary. Plain column assignment
# (instead of `.loc[:, col] =`) replaces the object column, yielding an integer dtype.
item2index_all = {**item2index_warm, **item2index_cold}
test_interactions['item_id'] = test_interactions.item_id.map(item2index_all)
ground_truth['item_id'] = ground_truth.item_id.map(item2index_all)

len(COLD_ITEMS) = 1402


In [20]:
# Save processed data
OUTPUT_DIR = '../../data/amazon_m2'
PROCESSED_DIR = os.path.join(OUTPUT_DIR, 'processed')

# exist_ok makes this safe to re-run and avoids the check-then-create race.
os.makedirs(PROCESSED_DIR, exist_ok=True)

train_interactions.to_csv(os.path.join(PROCESSED_DIR, 'train_data.csv'))
val_interactions.to_csv(os.path.join(PROCESSED_DIR, 'val_data.csv'))
test_interactions.to_csv(os.path.join(PROCESSED_DIR, 'test_inputs.csv'))
ground_truth.to_csv(os.path.join(PROCESSED_DIR, 'test_target.csv'))

# Raw item ID -> integer index vocabularies (0 is reserved for padding).
with open(os.path.join(PROCESSED_DIR, 'item2index_warm.pkl'), mode='wb') as file:
    pickle.dump(item2index_warm, file)

with open(os.path.join(PROCESSED_DIR, 'item2index_cold.pkl'), mode='wb') as file:
    pickle.dump(item2index_cold, file)

In [21]:
# French-locale product catalogue; its `id` column is matched against the item vocabularies below.
products_path = os.path.join(AMAZON_M2_DATA_DIR, 'products_train.csv')
products = load_amazon_m2_data(products_path, locale='FR')
products.head(1)

data.shape = (44577, 11)


Unnamed: 0,id,locale,title,price,brand,color,size,model,material,author,desc
1456019,B08B86CXL3,FR,"Coolreall Batterie Externe 10000mAh, Slim Powe...",18.99,Coolreall,Vert 10000mah,,PB125-C3,,,✔Avantage de sécurité fiable: la coque en plas...


In [22]:
# Describe each product as "title: ...; brand: ...; ..." using only the fields that are present.
COLUMNS = ['title', 'brand', 'color', 'size', 'model', 'material', 'author']

metadata = ''
for column in COLUMNS:
    # A missing value makes the whole "; <column>: <value>" piece NaN, which is then blanked.
    piece = (f'; {column}: ' + products[column]).fillna('')
    metadata = metadata + piece

# Drop the separator characters left in front of the first present field.
metadata = metadata.str.lstrip('; ')
metadata.index = products['id'].rename('item_id')
metadata.head()

item_id
B08B86CXL3    title: Coolreall Batterie Externe 10000mAh, Sl...
B08DQYCBGM    title: Tatouage Halloween, 6 pièces Le jour de...
B086VPVVG7    title: TOMY Ricky Zoom - Lumières Et Sons Rick...
B079NS6M8V    title: PLAYMOBIL 9318 Aventure au Camping - Fa...
B091HGL8YJ    title: Lecteur MP3 Bluetooth 5.0, MECHEN 2.4" ...
dtype: object

In [23]:
# Inspect one complete metadata string (the Series repr above truncates long values).
print(metadata.iat[42])

title: 9M Double Face Extra Fort, 3 Pièces Double Face Transparent, Ruban Adhésif Double Face Puissant Nano Tape Sans Percer Réutilisable Multifonctionnel Lavable Pour Tapis Voitures(3M*3CM+3M*2CM+3M*1CM); brand: YOGINGO; color: 3 Pièces (3M*3CM+3M*2CM+3M*1CM); material: Polyuréthane


In [24]:
# Keep only products in the encoded vocabulary, re-key them by integer item index, and order by it
# (warm indices first, then cold).
all_item2index = {**item2index_warm, **item2index_cold}
in_vocabulary = metadata.index.isin(all_item2index)
metadata = metadata[in_vocabulary]
metadata.index = metadata.index.map(all_item2index)
metadata = metadata.sort_index()
metadata.head()

item_id
1    title: A Song of Ice and Fire; brand: Brand: H...
2    title: A Song of Ice and Fire; brand: Harper C...
3    title: Celebrating the Graphic Design Studio B...
4    title: The Curious Incident of the Dog in the ...
5    title: Lot 2 recharges de gaz 300 ml pour briq...
dtype: object

In [25]:
# Every encoded item (warm and cold) must have exactly one metadata row, ordered as the contiguous
# indices 1..N. A length check alone would pass with one duplicated and one missing item, but the
# warm/cold embedding split slices rows by position.
n_encoded_items = len(item2index_warm) + len(item2index_cold)
assert metadata.index.equals(pd.RangeIndex(1, n_encoded_items + 1)), (
    'metadata must hold exactly one row per encoded item, ordered by item index'
)

In [26]:
# Embed each item's metadata text with a multilingual sentence encoder; row k of the result
# corresponds to metadata.iloc[k].
# NOTE(review): the intfloat E5 model card asks for a "passage: " / "query: " prefix on every
# input; none is added here — confirm this is intentional.
CHECKPOINT = 'intfloat/multilingual-e5-base'

texts = list(metadata)
model = SentenceTransformer(CHECKPOINT)
item_embeddings = model.encode(
    texts,
    normalize_embeddings=False,
    show_progress_bar=True,
    batch_size=512,
)
item_embeddings.shape

Batches:   0%|          | 0/87 [00:00<?, ?it/s]

(44049, 768)

In [27]:
# metadata (and therefore item_embeddings) is ordered by item index: warm indices 1..len(WARM_ITEMS)
# come first, then the cold ones. Row k thus holds item index k + 1 (index 0 is padding).
warm_item_embeddings, cold_item_embeddings = np.split(item_embeddings, [len(WARM_ITEMS)])

In [29]:
# Save item embeddings
EMBEDDINGS_DIR = os.path.join(OUTPUT_DIR, 'item_embeddings')

# exist_ok makes this safe to re-run and avoids the check-then-create race.
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)

np.save(os.path.join(EMBEDDINGS_DIR, 'embeddings_warm.npy'), warm_item_embeddings)
np.save(os.path.join(EMBEDDINGS_DIR, 'embeddings_cold.npy'), cold_item_embeddings)