In [1]:
import os
import pickle
import sys

# Put the directory two levels above the current working directory on the import path
# (presumably the repository root so project modules resolve — assumes the notebook is
# launched from its own folder; TODO confirm).
sys.path.append('../..')

import numpy as np
import polars as pl
from dotenv import load_dotenv
from replay.preprocessing.filters import MinCountFilter
from sentence_transformers import SentenceTransformer

In [2]:
# Load variables from a .env file into os.environ (expected to provide BEAUTY_DATA_DIR, read below)
load_dotenv()

True

In [3]:
BEAUTY_DATA_DIR = os.getenv('BEAUTY_DATA_DIR')
!ls {BEAUTY_DATA_DIR}

meta_Beauty.parquet  reviews_Beauty_5.parquet


In [4]:
# Load the raw review records (one row per reviewer/product review)
reviews_path = os.path.join(BEAUTY_DATA_DIR, 'reviews_Beauty_5.parquet')
interactions = pl.read_parquet(reviews_path)
print(f'{interactions.shape = }')
interactions.head()

interactions.shape = (198502, 9)


reviewerID,asin,reviewerName,helpful,reviewText,overall,summary,unixReviewTime,reviewTime
str,str,str,list[i64],str,i64,str,i64,str
"""A1YJEY40YUW4SE""","""7806397051""","""Andrea""","[3, 4]","""Very oily and creamy. Not at a…",1,"""Don't waste your money""",1391040000,"""01 30, 2014"""
"""A60XNB876KYML""","""7806397051""","""Jessica H.""","[1, 1]","""This palette was a decent pric…",3,"""OK Palette!""",1397779200,"""04 18, 2014"""
"""A3G6XNM240RMWA""","""7806397051""","""Karen""","[0, 1]","""The texture of this concealer …",4,"""great quality""",1378425600,"""09 6, 2013"""
"""A1PQFP6SAJ6D80""","""7806397051""","""Norah""","[2, 2]","""I really can't tell what exact…",2,"""Do not work on my face""",1386460800,"""12 8, 2013"""
"""A38FVHZTNQ271F""","""7806397051""","""Nova Amor""","[0, 0]","""It was a little smaller than I…",3,"""It's okay.""",1382140800,"""10 19, 2013"""


In [5]:
# Select columns: keep only the fields used downstream, renamed to generic recommender names
interactions = interactions.select(
    user_id=pl.col('reviewerID'),
    item_id=pl.col('asin'),
    rating=pl.col('overall'),
    timestamp=pl.col('unixReviewTime'),
)
interactions.head(1)

user_id,item_id,rating,timestamp
str,str,i64,i64
"""A1YJEY40YUW4SE""","""7806397051""",1,1391040000


In [6]:
# Keep positive interactions only; the rating column is not needed afterwards.
# Ratings are integers (i64), so ">= 4" keeps exactly the rows the former "> 3" kept.
MIN_RATING = 4  # smallest rating still treated as positive (inclusive)

interactions = interactions.filter(pl.col('rating') >= MIN_RATING).drop('rating')
len(interactions)

154272

In [7]:
# Number of distinct users and items remaining after the rating filter
interactions.select(pl.col('user_id').n_unique(), pl.col('item_id').n_unique())

user_id,item_id
u32,u32
22269,12086


In [8]:
# Split data into train/val and test subsets with one global time threshold:
# interactions strictly after the TIME_THRESHOLD_QUANTILE timestamp quantile form the test period.
TIME_THRESHOLD_QUANTILE = 0.9
TIME_THRESHOLD = interactions.get_column('timestamp').quantile(TIME_THRESHOLD_QUANTILE)

is_test_period = pl.col('timestamp') > TIME_THRESHOLD
test_interactions = interactions.filter(is_test_period)
train_val_interactions = interactions.filter(is_test_period.not_())

len(train_val_interactions), len(test_interactions)

(138899, 15373)

In [9]:
# Apply N-core filtering to train/val data: filter by user, then by item, and repeat until a
# full pass leaves the frame unchanged (MinCountFilter is assumed to drop groups with < N rows).
N = 3

while True:
    shape = train_val_interactions.shape
    for group_column in ('user_id', 'item_id'):
        core_filter = MinCountFilter(N, groupby_column=group_column)
        train_val_interactions = core_filter.transform(train_val_interactions)
    if train_val_interactions.shape == shape:
        break

len(train_val_interactions)

133780

In [10]:
# Assign tail interactions back to test users: each test user also gets their (N-core filtered)
# train/val interactions, appended after the test-period rows.
test_users = test_interactions.get_column('user_id').unique()
test_user_history = train_val_interactions.filter(pl.col('user_id').is_in(test_users))
test_interactions = pl.concat([test_interactions, test_user_history])
len(test_interactions)

46727

In [11]:
# Split data into train and validation subsets by user: a seeded random
# TRAIN_USERS_FRACTION of users goes to train, everyone else to validation.
TRAIN_USERS_FRACTION = 0.9

# sort() pins down the otherwise unspecified order of unique(), so the seeded sample is reproducible
all_users = train_val_interactions.get_column('user_id').unique().sort()
TRAIN_USERS = all_users.sample(fraction=TRAIN_USERS_FRACTION, seed=42)

is_train_user = pl.col('user_id').is_in(TRAIN_USERS)
train_interactions = train_val_interactions.filter(is_train_user)
val_interactions = train_val_interactions.filter(is_train_user.not_())

len(train_interactions), len(val_interactions)

(120122, 13658)

In [12]:
# Encode warm item IDs: "warm" items are those seen in train, numbered in sorted order
WARM_ITEMS = train_interactions.get_column('item_id').unique().sort()
print(f'{len(WARM_ITEMS) = }')

# Index 0 stays free for padding, so warm items get 1..len(WARM_ITEMS)
item2index_warm = dict(zip(WARM_ITEMS, range(1, len(WARM_ITEMS) + 1)))

encode_warm = pl.col('item_id').replace_strict(item2index_warm)
train_interactions = train_interactions.with_columns(encode_warm)

# Validation may contain items never seen in train; drop them first,
# since replace_strict fails on values missing from the mapping
val_interactions = (
    val_interactions
    .filter(pl.col('item_id').is_in(WARM_ITEMS))
    .with_columns(encode_warm)
)

len(WARM_ITEMS) = 11165


In [13]:
# Encode cold item IDs: test items that never occur in train get indices after the warm range.
# Parenthesize the negation so the alias names the final boolean expression explicitly
# (previously `~` was applied to the already-aliased expression due to operator precedence).
test_interactions = test_interactions.with_columns(
    (~pl.col('item_id').is_in(WARM_ITEMS)).alias('is_cold')
)

COLD_ITEMS = test_interactions.filter(pl.col('is_cold')).get_column('item_id').unique().sort()
print(f'{len(COLD_ITEMS) = }')

# Cold indices continue right after the largest warm index, so the two ranges never collide
first_cold_index = max(item2index_warm.values()) + 1
item2index_cold = {item: first_cold_index + offset for offset, item in enumerate(COLD_ITEMS)}

test_interactions = test_interactions.with_columns(
    pl.col('item_id').replace_strict({**item2index_warm, **item2index_cold})
)

len(COLD_ITEMS) = 568


In [14]:
# Separate ground truth from test interactions: for every test user, the chronologically last
# interaction becomes the ground truth and the remaining ones stay as the input history.
# `position` counts each user's interactions from the end of their timeline (1 = most recent).
# Timestamps are day-granular, so same-day ties are common; `item_id` is a secondary sort key
# so the chosen ground-truth row is deterministic instead of depending on row order.
test_interactions = test_interactions.with_columns(
    pl.col('user_id')
    .cum_count(reverse=True)
    .over('user_id', order_by=['timestamp', 'item_id'])
    .alias('position')
)

condition = pl.col('position') == 1

ground_truth = test_interactions.filter(condition)
test_interactions = test_interactions.filter(~condition)

In [15]:
# Row counts of every produced subset: (train, val, test history, ground truth)
tuple(len(frame) for frame in (train_interactions, val_interactions, test_interactions, ground_truth))

(120122, 13649, 41566, 5161)

In [16]:
# Apply some basic checks
# Train/val must lie entirely before the time threshold and ground truth entirely after it
assert (train_interactions.get_column('timestamp') <= TIME_THRESHOLD).all()
assert (val_interactions.get_column('timestamp') <= TIME_THRESHOLD).all()
assert (ground_truth.get_column('timestamp') > TIME_THRESHOLD).all()

# Exactly one ground-truth interaction per test user
assert ground_truth.get_column('user_id').is_unique().all(), 'duplicate users in ground truth'

# The user-level train/validation split must not share users
assert not (
    train_interactions.get_column('user_id').is_in(val_interactions.get_column('user_id')).any()
), 'train and validation users overlap'

In [17]:
# Save processed data
OUTPUT_DIR = '../../data/beauty'
PROCESSED_DIR = os.path.join(OUTPUT_DIR, 'processed')

# exist_ok replaces the check-then-create pattern and keeps the cell safely re-runnable
os.makedirs(PROCESSED_DIR, exist_ok=True)

train_interactions.write_parquet(os.path.join(PROCESSED_DIR, 'train_interactions.parquet'))
val_interactions.write_parquet(os.path.join(PROCESSED_DIR, 'val_interactions.parquet'))
test_interactions.write_parquet(os.path.join(PROCESSED_DIR, 'test_interactions.parquet'))
ground_truth.write_parquet(os.path.join(PROCESSED_DIR, 'ground_truth.parquet'))

# Persist the item ID -> integer index mappings (presumably used downstream to decode indices)
with open(os.path.join(PROCESSED_DIR, 'item2index_warm.pkl'), mode='wb') as file:
    pickle.dump(item2index_warm, file)

with open(os.path.join(PROCESSED_DIR, 'item2index_cold.pkl'), mode='wb') as file:
    pickle.dump(item2index_cold, file)

In [18]:
# Load raw product metadata (title, description, categories, price, ...)
metadata_path = os.path.join(BEAUTY_DATA_DIR, 'meta_Beauty.parquet')
metadata = pl.read_parquet(metadata_path)
print(f'{metadata.shape = }')
metadata.head(1)

metadata.shape = (259204, 9)


asin,description,title,imUrl,salesRank,categories,price,related,brand
str,str,str,str,struct[29],list[list[str]],f64,struct[4],str
"""0205616461""","""As we age, our once youthful, …","""Bio-Active Anti-Aging Serum (F…","""http://ecx.images-amazon.com/i…","{null,null,null,null,null,null,null,null,null,null,null,461765,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null}","[[""Beauty"", ""Skin Care"", … ""Creams & Moisturizers""]]",,"{null,null,null,null}",


In [19]:
# Process metadata: keep only warm/cold items, replace ASINs with their integer indices,
# and order rows by index so row position lines up with item index
known_items = pl.concat([WARM_ITEMS, COLD_ITEMS])
item2index = {**item2index_warm, **item2index_cold}

metadata = metadata.rename({'asin': 'item_id'}).filter(pl.col('item_id').is_in(known_items))
metadata = metadata.with_columns(pl.col('item_id').replace_strict(item2index)).sort('item_id')
print(f'{metadata.shape = }')
metadata.head(1)

metadata.shape = (11733, 9)


item_id,description,title,imUrl,salesRank,categories,price,related,brand
i64,str,str,str,struct[29],list[list[str]],f64,struct[4],str
1,"""Prada Candy By Prada Eau De Pa…","""Prada Candy By Prada Eau De Pa…","""http://ecx.images-amazon.com/i…","{null,null,null,78916,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null}","[[""Beauty"", ""Fragrance"", … ""Eau de Parfum""]]",65.86,"{[""B006C5OHSI"", ""B006P14842"", … ""9788072208""],[""B0072CSVB4"", ""B005YWBOHW"", … ""B000C1Z3LS""],[""B006C5OHSI"", ""B006P14842""],null}","""Prada"""


In [20]:
# Every warm and cold item needs exactly one metadata row, ordered by item index 1..n.
# A length check alone can pass with a duplicated ASIN hiding a missing item, while the
# embedding split below slices rows purely by position.
n_items = len(item2index_warm) + len(item2index_cold)
assert len(metadata) == n_items
assert metadata.get_column('item_id').to_list() == list(range(1, n_items + 1)), (
    'metadata item_ids must be exactly 1..n in order (no duplicates, no gaps)'
)

In [21]:
CHECKPOINT = 'intfloat/e5-base-v2'

model = SentenceTransformer(CHECKPOINT)

# One embedding per metadata row, in metadata order (sorted by item_id).
# Missing titles are replaced by '' so every input passed to encode() is a string.
# NOTE(review): the e5 model card asks for a 'query: ' / 'passage: ' prefix on each input;
# titles are encoded without one here. Consider adding it (this changes the embeddings).
titles = metadata.get_column('title').fill_null('').to_list()
item_embeddings = model.encode(
    titles,
    batch_size=512,
    show_progress_bar=True,
    normalize_embeddings=False,
)
item_embeddings.shape

Batches:   0%|          | 0/23 [00:00<?, ?it/s]

(11733, 768)

In [22]:
# Rows follow item index order: the first len(WARM_ITEMS) rows are warm items, the rest are cold
split_at = len(WARM_ITEMS)
warm_item_embeddings, cold_item_embeddings = item_embeddings[:split_at], item_embeddings[split_at:]

In [23]:
# Save item embeddings (warm and cold stored as separate .npy arrays)
EMBEDDINGS_DIR = os.path.join(OUTPUT_DIR, 'item_embeddings')

# exist_ok replaces the check-then-create pattern and keeps the cell safely re-runnable
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)

np.save(os.path.join(EMBEDDINGS_DIR, 'embeddings_warm.npy'), warm_item_embeddings)
np.save(os.path.join(EMBEDDINGS_DIR, 'embeddings_cold.npy'), cold_item_embeddings)