In [1]:
import sys
sys.path.append('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/')
import pandas as pd 
import numpy as np 
import datasets
from datasets import Dataset as TFDataset 
import pickle
from bm25.rank_bm25 import BM25Okapi
import random
# import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict
from transformers import PreTrainedTokenizer, AutoTokenizer
import multiprocessing

In [2]:
def cast_dtype(df : pd.DataFrame, columns=None):
    if columns is None:
        columns = df.columns
    for k in columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def tokenize_function(examples, corpus_col_name, tokenizer, max_length):
    if corpus_col_name in examples:
        return tokenizer(examples[corpus_col_name], 
            add_special_tokens=False, # don't add special tokens when preprocess
            truncation=True, 
            max_length=max_length,
            return_attention_mask=False,
            return_token_type_ids=False)

In [4]:
def construct_query_list_from_sessions(sessions_df:pd.DataFrame, product_map:dict, max_seq_len:int, product_corpus:list):
    query_list = []
    for sess in tqdm(sessions_df.itertuples(), total=sessions_df.shape[0]):
        sess_locale = sess.locale
        prev_items = eval(sess.prev_items.replace(' ', ','))[-max_seq_len : ]
        prev_items = [product_map.get(sess_locale+'_'+item, 0) for item in prev_items]
        sess_query = sum([product_corpus[item] for item in prev_items], [])
        query_list.append(sess_query)
    return query_list

# Merge valid BM25 score

In [5]:
merged_candidates_feature_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates_phase2/merged_candidates_150_feature.parquet'
product_data_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/products_train.csv'
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions_phase2.csv'

In [6]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature():
    return pd.read_parquet(merged_candidates_feature_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(product_data_path)

@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)


In [7]:
merged_candidates_feature = read_merged_candidates_feature()
product_data = read_product_data()
valid_sessions = read_valid_sessions()

In [8]:
TOKENIZER_NAME = 'xlm-roberta-base'
MAX_FEAT_LEN = 200

In [9]:
tokenizer = AutoTokenizer.from_pretrained(
        TOKENIZER_NAME,
        use_fast=False,
)

In [10]:
feat_corpus = product_data[['brand', 'color', 'size', 'model', 'material', 'author']]
padding_df = pd.DataFrame({'brand' : [''], 'color' : [''], 'size' : [''], 'model' : [''], 'material' : [''], 'author' : ['']})
feat_corpus = pd.concat([padding_df, feat_corpus]).reset_index(drop=True) # add padding product
feat_corpus['brand'] = feat_corpus['brand'].fillna('')
feat_corpus['color'] = feat_corpus['color'].fillna('')
feat_corpus['size'] = feat_corpus['size'].fillna('')
feat_corpus['model'] = feat_corpus['model'].fillna('')
feat_corpus['material'] = feat_corpus['material'].fillna('')
feat_corpus['author'] = feat_corpus['author'].fillna('')
feat_corpus['feat'] = feat_corpus['brand'] + ' ' + feat_corpus['color'] + ' ' + feat_corpus['size'] + ' ' + feat_corpus['model'] \
        + ' ' + feat_corpus['material'] + ' ' + feat_corpus['author']
feat_corpus['feat'] = feat_corpus['feat'].apply(lambda x : x.lower())


In [11]:
feat_corpus = TFDataset.from_pandas(feat_corpus, preserve_index=False)
feat_corpus = feat_corpus.map(partial(tokenize_function, corpus_col_name='feat', tokenizer=tokenizer, max_length=MAX_FEAT_LEN), 
                                num_proc=8, remove_columns=['feat'], batched=True)
feat_corpus_list = feat_corpus['input_ids']



In [12]:
with open('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/bm25/cache/feat_bm25.pkl', 'rb') as f:
    feat_BM25 = pickle.load(f)

In [13]:
merged_candidates = merged_candidates_feature[['sess_id', 'sess_locale', 'product']]
merged_candidates

Unnamed: 0,sess_id,sess_locale,product
0,0,DE,355165591X
1,0,DE,3833237058
2,0,DE,B00CIXSI6U
3,0,DE,B00NVDOWUW
4,0,DE,B00NVDP3ZU
...,...,...,...
78842194,261815,UK,B0BCX524Y6
78842195,261815,UK,B0BCX6QB4L
78842196,261815,UK,B0BFPJYXQL
78842197,261815,UK,B0BH3X67S3


In [14]:
product_index = product_data[['id', 'locale']]
product_index['product_index'] = product_index.index + 1

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  product_index['product_index'] = product_index.index + 1


In [None]:
merged_candidates = merged_candidates.merge(product_index, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'id'])
merged_candidates['product_index'] = merged_candidates['product_index'].fillna(0)
merged_candidates['product_index'] = merged_candidates['product_index'].astype('int64')
merged_candidates.drop(columns=['id', 'locale'], inplace=True)
assert len(merged_candidates) == len(merged_candidates_feature)
merged_candidates_grouped = merged_candidates.groupby(by='sess_id')['product_index'].apply(list)

In [17]:
locale_product_map = {}
for row in tqdm(product_data.itertuples(), total=product_data.shape[0]):
    locale_product_map[row.locale+'_'+row.id] = row.Index + 1

100%|██████████| 1551057/1551057 [00:07<00:00, 197882.86it/s]


In [18]:
# construct query list 
valid_query_list = construct_query_list_from_sessions(valid_sessions, locale_product_map, 5, product_corpus=feat_corpus_list)

100%|██████████| 261816/261816 [00:10<00:00, 24694.72it/s]


In [19]:
def get_sess_scores(sess):
    sess_id = sess['sess_id']
    scores = feat_BM25.get_batch_scores(valid_query_list[sess_id], merged_candidates_grouped[sess_id])
    return {'sess_bm25_scores' : scores}

In [57]:
scores = feat_BM25.get_scores(valid_query_list[150102])
top_n = np.argsort(scores)[::-1][:5]

In [58]:
top_n

array([1069719,  970734,   19560, 1244867, 1183103])

In [42]:
product_index.query("id=='B004FG87G4' & locale=='UK'")

Unnamed: 0,id,locale,product_index
1329918,B004FG87G4,UK,1329919


In [36]:
tokenizer.decode(valid_query_list[250102])

'tlarder black s aluminium applaws 70g 70 g (pack of 12) mm1019net'

In [37]:
tokenizer.decode(feat_corpus[1275708]['input_ids'])

'applaws 70 g (pack of 12) 9100946'

In [39]:
tokenizer.decode(feat_corpus[1164316]['input_ids'])

'applaws 70 g (pack of 12) 1017ml-ac'

In [41]:
tokenizer.decode(feat_corpus[1033024]['input_ids'])


'applaws 70 g (pack of 12) 8001ml-a'

In [43]:
tokenizer.decode(feat_corpus[1329919]['input_ids'])


'applaws adult 70 g (pack of 24) 9104464'

In [20]:
# about 23 mins
datasets.set_progress_bar_enabled(False)
valid_query_dataset = TFDataset.from_dict({'sess_id' : list(range(len(valid_query_list)))})
valid_query_dataset = valid_query_dataset.map(get_sess_scores, num_proc=20, batched=False)
datasets.set_progress_bar_enabled(True)

In [21]:
valid_scores_list = valid_query_dataset['sess_bm25_scores']

In [22]:
merged_bm25_scores = []
for scores_set in tqdm(valid_scores_list):
    for s in scores_set:
        merged_bm25_scores.append(s)
assert len(merged_bm25_scores) == len(merged_candidates)
assert len(merged_bm25_scores) == len(merged_candidates_feature)

100%|██████████| 261816/261816 [00:24<00:00, 10877.77it/s]


In [24]:
merged_candidates_feature['feat_BM25_scores'] = merged_bm25_scores

In [67]:
cast_dtype(merged_candidates_feature, ['feat_BM25_scores'])
merged_candidates_feature.to_parquet(merged_candidates_feature_path)

In [44]:
merged_candidates_feature.query('sess_id==250102').sort_values(by=['feat_BM25_scores'], ascending=False)[['sess_id', 'sess_locale', 'product', 'feat_BM25_scores', 'normalized_sasrec_scores_2']][:25]

Unnamed: 0,sess_id,sess_locale,product,feat_BM25_scores,normalized_sasrec_scores_2
66317107,220102,UK,B078748C4T,88.072678,0.2013201
66317112,220102,UK,B079YW4SD3,77.730469,3.801024e-05
66317059,220102,UK,B01MT6DAG4,63.776562,4.546821e-06
66317040,220102,UK,B00WZRQNAC,59.516838,0.04448581
66317111,220102,UK,B079YSXCD8,58.733582,0.009240651
66317027,220102,UK,B00P1HCBI6,57.892109,1.082161e-05
66317246,220102,UK,B09BDBYFWQ,56.900654,0.02199206
66317094,220102,UK,B074YT5FL6,55.936466,0.001292779
66317180,220102,UK,B081SWBSFD,55.764687,1.39777e-09
66317212,220102,UK,B08MRMLTKQ,55.249546,0.0004006482


In [26]:
merged_candidates

Unnamed: 0,sess_id,sess_locale,product,product_index
0,0,DE,355165591X,127299
1,0,DE,3833237058,248194
2,0,DE,B00CIXSI6U,115708
3,0,DE,B00NVDOWUW,55370
4,0,DE,B00NVDP3ZU,100302
...,...,...,...,...
78842194,261815,UK,B0BCX524Y6,982599
78842195,261815,UK,B0BCX6QB4L,1046447
78842196,261815,UK,B0BFPJYXQL,1187691
78842197,261815,UK,B0BH3X67S3,983386


In [27]:
merged_candidates_feature

Unnamed: 0,sess_id,sess_locale,product,target,sess_avg_price,product_price,product_freq,sasrec_scores_3,normalized_sasrec_scores_3,sasrec_scores_2,normalized_sasrec_scores_2,seqmlp_scores,normalized_seqmlp_scores,narm_scores,normalized_narm_scores,gru4rec_scores_2,normalized_gru4rec_scores_2,title_BM25_scores,desc_BM25_scores
0,0,DE,355165591X,0.0,43.256542,8.990000,51.0,2.230508,7.658405e-09,0.512931,1.377575e-09,6.044256,3.054954e-09,3.628220,1.469988e-13,5.093824,1.288155e-10,0.000000,0.455206
1,0,DE,3833237058,0.0,43.256542,22.000000,84.0,9.605231,1.221631e-05,9.325538,9.255110e-06,10.732503,3.319589e-07,15.734776,2.661486e-08,13.133082,3.993682e-07,89.696823,0.904280
2,0,DE,B00CIXSI6U,0.0,43.256542,6.470000,7.0,0.714114,1.681035e-09,-0.115904,7.345399e-10,4.902086,9.749146e-10,7.462822,6.802375e-12,4.295491,5.797709e-11,0.000000,0.000000
3,0,DE,B00NVDOWUW,0.0,43.256542,11.990000,166.0,8.750996,5.199363e-06,8.507557,4.084482e-06,13.401752,4.789881e-06,9.774860,6.866981e-11,10.362890,2.502041e-08,180.035004,252.282623
4,0,DE,B00NVDP3ZU,0.0,43.256542,22.990000,502.0,8.056712,2.596729e-06,5.898870,3.007453e-07,9.162767,6.908073e-08,10.392188,1.273116e-10,12.687778,2.558471e-07,148.722244,218.617142
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78842194,261815,UK,B0BCX524Y6,0.0,9.383333,16.990000,7.0,6.813615,1.076201e-03,7.203015,4.597607e-04,6.808200,4.075167e-04,7.742151,7.126370e-04,7.978776,9.559332e-04,95.027313,22.518597
78842195,261815,UK,B0BCX6QB4L,0.0,9.383333,10.990000,53.0,9.030836,9.881445e-03,10.123234,8.526421e-03,10.454871,1.562696e-02,10.198983,8.314902e-03,10.432762,1.112192e-02,95.027313,22.518597
78842196,261815,UK,B0BFPJYXQL,0.0,9.383333,10.560000,7.0,0.796892,2.623396e-06,1.711608,1.895152e-06,-0.960521,1.722791e-07,2.838342,5.286535e-06,2.107661,2.695469e-06,177.093521,130.335434
78842197,261815,UK,B0BH3X67S3,0.0,9.383333,6.830000,38.0,4.250781,8.296004e-05,6.447586,2.159998e-04,5.091795,7.323526e-05,6.170062,1.479513e-04,7.305784,4.876977e-04,33.781818,43.050674
