In [1]:
import sys
sys.path.append('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/')
import pandas as pd 
import numpy as np 
import datasets
from datasets import Dataset as TFDataset 
import pickle
from bm25.rank_bm25 import BM25Okapi
import random
# import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict
from transformers import PreTrainedTokenizer, AutoTokenizer
import multiprocessing

In [42]:
def cast_dtype(df : pd.DataFrame, columns=None):
    if columns is None:
        columns = df.columns
    for k in columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def tokenize_function(examples, corpus_col_name, tokenizer, max_length):
    if corpus_col_name in examples:
        return tokenizer(examples[corpus_col_name], 
            add_special_tokens=False, # don't add special tokens when preprocess
            truncation=True, 
            max_length=max_length,
            return_attention_mask=False,
            return_token_type_ids=False)

In [4]:
def construct_query_list_from_sessions(sessions_df:pd.DataFrame, product_map:dict, max_seq_len:int, product_corpus:list):
    query_list = []
    for sess in tqdm(sessions_df.itertuples(), total=sessions_df.shape[0]):
        sess_locale = sess.locale
        prev_items = eval(sess.prev_items.replace(' ', ','))[-max_seq_len : ]
        prev_items = [product_map.get(sess_locale+'_'+item, 0) for item in prev_items]
        sess_query = sum([product_corpus[item] for item in prev_items], [])
        query_list.append(sess_query)
    return query_list

# Merge test BM25 score

In [5]:
merged_candidates_feature_test_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates_phase2/merged_candidates_150_test_feature.parquet'
product_data_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/products_train.csv'
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task1_phase2.csv'

In [6]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature_test():
    return pd.read_parquet(merged_candidates_feature_test_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(product_data_path)

@lru_cache(maxsize=1)
def read_test_sessions():
    return pd.read_csv(test_sessions_path)


In [7]:
merged_candidates_feature_test = read_merged_candidates_feature_test()
product_data = read_product_data()
test_sessions = read_test_sessions()

In [8]:
TOKENIZER_NAME = 'xlm-roberta-base'
MAX_FEAT_LEN = 200

In [9]:
tokenizer = AutoTokenizer.from_pretrained(
        TOKENIZER_NAME,
        use_fast=False,
)

In [10]:
feat_corpus = product_data[['brand', 'color', 'size', 'model', 'material', 'author']]
padding_df = pd.DataFrame({'brand' : [''], 'color' : [''], 'size' : [''], 'model' : [''], 'material' : [''], 'author' : ['']})
feat_corpus = pd.concat([padding_df, feat_corpus]).reset_index(drop=True) # add padding product
feat_corpus['brand'] = feat_corpus['brand'].fillna('')
feat_corpus['color'] = feat_corpus['color'].fillna('')
feat_corpus['size'] = feat_corpus['size'].fillna('')
feat_corpus['model'] = feat_corpus['model'].fillna('')
feat_corpus['material'] = feat_corpus['material'].fillna('')
feat_corpus['author'] = feat_corpus['author'].fillna('')
feat_corpus['feat'] = feat_corpus['brand'] + ' ' + feat_corpus['color'] + ' ' + feat_corpus['size'] + ' ' + feat_corpus['model'] \
        + ' ' + feat_corpus['material'] + ' ' + feat_corpus['author']
feat_corpus['feat'] = feat_corpus['feat'].apply(lambda x : x.lower())

In [11]:
feat_corpus = TFDataset.from_pandas(feat_corpus, preserve_index=False)
feat_corpus = feat_corpus.map(partial(tokenize_function, corpus_col_name='feat', tokenizer=tokenizer, max_length=MAX_FEAT_LEN), 
                                num_proc=8, remove_columns=['feat'], batched=True)
feat_corpus_list = feat_corpus['input_ids']



In [12]:
with open('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/bm25/cache/feat_bm25.pkl', 'rb') as f:
    feat_BM25 = pickle.load(f)

In [13]:
merged_candidates = merged_candidates_feature_test[['sess_id', 'sess_locale', 'product']]
merged_candidates

Unnamed: 0,sess_id,sess_locale,product
0,0,DE,B000Q87D0Q
1,0,DE,B000QB30DW
2,0,DE,B004BIG55Q
3,0,DE,B0053FTNQY
4,0,DE,B007QWII1S
...,...,...,...
96556030,316971,UK,B0B82N3CQQ
96556031,316971,UK,B0BB9NW3F3
96556032,316971,UK,B0BDMVKTQ3
96556033,316971,UK,B0BHW1D5VP


In [14]:
product_index = product_data[['id', 'locale']]
product_index['product_index'] = product_index.index + 1

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  product_index['product_index'] = product_index.index + 1


In [15]:
merged_candidates = merged_candidates.merge(product_index, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'id'])
merged_candidates['product_index'] = merged_candidates['product_index'].fillna(0)
merged_candidates['product_index'] = merged_candidates['product_index'].astype('int64')
merged_candidates.drop(columns=['id', 'locale'], inplace=True)
assert len(merged_candidates) == len(merged_candidates_feature_test)
merged_candidates_grouped = merged_candidates.groupby(by='sess_id')['product_index'].apply(list)

In [16]:
locale_product_map = {}
for row in tqdm(product_data.itertuples(), total=product_data.shape[0]):
    locale_product_map[row.locale+'_'+row.id] = row.Index + 1

100%|██████████| 1551057/1551057 [00:07<00:00, 219005.70it/s]


In [17]:
# construct query list 
test_query_list = construct_query_list_from_sessions(test_sessions, locale_product_map, 5, product_corpus=feat_corpus_list)

100%|██████████| 316972/316972 [00:13<00:00, 22817.63it/s]


In [18]:
def get_sess_scores(sess):
    sess_id = sess['sess_id']
    scores = feat_BM25.get_batch_scores(test_query_list[sess_id], merged_candidates_grouped[sess_id])
    return {'sess_bm25_scores' : scores}

In [None]:
scores = feat_BM25.get_scores(test_query_list[150102])
top_n = np.argsort(scores)[::-1][:5]

In [None]:
top_n

array([1069719,  970734,   19560, 1244867, 1183103])

In [39]:
product_index.query("id=='B09NRWVB92' & locale=='UK'")

Unnamed: 0,id,locale,product_index
1137236,B09NRWVB92,UK,1137237


In [35]:
tokenizer.decode(test_query_list[220000])

'brand architekts vitamin c sleep and reveal night cream 200 ml (pack of 1) ba-sf25561 brand architekts white 30 ml (pack of 1) ba-sf25571 nyk1 shampoo & conditioner - salt & sulphate free nyk1 conditioner - salt and sulphate free 500 ml (pack of 1)'

In [36]:
tokenizer.decode(feat_corpus[1150019]['input_ids'])

'nyk1 shampoo - salt and sulphate free 500 ml (pack of 1)'

In [38]:
tokenizer.decode(feat_corpus[977275]['input_ids'])

'by cocochoco 150 ml (pack of 1)'

In [40]:
tokenizer.decode(feat_corpus[1137237]['input_ids'])

'nyk1 purple hair shampoo & conditioner 500 ml (pack of 2)'

In [19]:
datasets.set_progress_bar_enabled(False)
test_query_dataset = TFDataset.from_dict({'sess_id' : list(range(len(test_query_list)))})
test_query_dataset = test_query_dataset.map(get_sess_scores, num_proc=20, batched=False)
datasets.set_progress_bar_enabled(True)

In [20]:
test_scores_list = test_query_dataset['sess_bm25_scores']

In [21]:
merged_bm25_scores = []
for scores_set in tqdm(test_scores_list):
    for s in scores_set:
        merged_bm25_scores.append(s)
assert len(merged_bm25_scores) == len(merged_candidates)
assert len(merged_bm25_scores) == len(merged_candidates_feature_test)

100%|██████████| 316972/316972 [00:20<00:00, 15847.96it/s]


In [22]:
merged_candidates_feature_test['feat_BM25_scores'] = merged_bm25_scores

In [43]:
cast_dtype(merged_candidates_feature_test, ['feat_BM25_scores'])
merged_candidates_feature_test.to_parquet(merged_candidates_feature_test_path)

In [33]:
merged_candidates_feature_test.query('sess_id==220000').sort_values(by=['title_BM25_scores'], ascending=False)[['sess_id', 'sess_locale', 'product', 'feat_BM25_scores', 'sasrec_scores_2']][:50]

Unnamed: 0,sess_id,sess_locale,product,feat_BM25_scores,sasrec_scores_2
66933726,220000,UK,B01LLT790C,148.814966,17.280964
66933773,220000,UK,B07CVJKL3D,43.209051,8.495083
66933789,220000,UK,B07GSCKWQ3,39.906404,1.201975
66933724,220000,UK,B01LGLXU3U,50.352208,9.769053
66933757,220000,UK,B078W6ZC3V,48.767774,4.645908
66933983,220000,UK,B09NRWVB92,98.722858,7.050541
66933827,220000,UK,B07R3ZYQW2,37.049253,10.665802
66933829,220000,UK,B07RVJBPY1,37.049253,13.42028
66933901,220000,UK,B08MRPMLZM,27.786835,8.705627
66933899,220000,UK,B08LF4LPPM,45.269457,13.252043


In [49]:
merged_candidates_feature_test

Unnamed: 0,sess_id,sess_locale,product,sasrec_scores_2,sasrec_normalized_scores_2,gru4rec_scores,gru4rec_normalized_scores,product_freq,sess_avg_price,product_price,...,roberta_scores,roberta_normalized_scores,title_BM25_scores,sasrec_scores_3,sasrec_normalized_scores_3,normalized_all_items_co_graph_count_0,all_items_co_graph_count_0,seqmlp_scores,seqmlp_normalized_scores,desc_BM25_scores
0,0,DE,4088833651,0.000000,2.975813e-09,0.000000,1.580065e-09,828,25.195269,36.761604,...,0.000000,0.000000,0.000000,0.000000,2.622550e-09,0.000000,0,0.000000,2.554478e-10,0.000000
1,0,DE,B000H6W2GW,0.000000,2.975813e-09,0.000000,1.580065e-09,875,25.195269,36.761604,...,0.000000,0.000000,0.000000,0.000000,2.622550e-09,0.000000,0,0.000000,2.554478e-10,0.000000
2,0,DE,B000JG2RAG,7.665308,6.347557e-06,8.104032,5.226502e-06,24,25.195269,23.190001,...,267.192719,0.004943,287.809601,8.885176,1.894552e-05,0.000000,0,8.786958,1.672744e-06,67.792645
3,0,DE,B000RYSOUW,-2.951060,1.555882e-10,-2.857798,9.068785e-11,5,25.195269,6.900000,...,267.322815,0.005629,321.394653,-1.640674,5.083796e-10,0.000000,0,-3.325048,9.188664e-12,170.360588
4,0,DE,B000UGZVQM,3.977920,1.589257e-07,4.688567,1.717488e-07,4,25.195269,21.990000,...,267.242462,0.005195,285.328705,4.972019,3.784811e-07,0.000000,0,5.540127,6.506522e-08,71.169296
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69428426,316970,UK,B0BJCTH4NH,11.327528,1.041200e-04,10.629994,3.818184e-04,74,16.950001,5.800000,...,270.043762,0.014921,449.867401,10.968081,1.849500e-05,0.010237,16,11.838901,9.762144e-04,164.803133
69428427,316970,UK,B0BJTQQWLG,5.604142,3.403292e-07,6.052083,3.923694e-06,6,16.950001,9.880000,...,269.350769,0.007462,431.585815,7.366314,5.044600e-07,0.000640,1,4.890683,9.375031e-07,303.665984
69428428,316970,UK,B0BJV3RL4H,9.146974,1.176336e-05,7.667603,1.973815e-05,7,16.950001,22.097065,...,269.313751,0.007191,419.572662,8.286265,1.265775e-06,0.000640,1,10.187823,1.872800e-04,226.131521
69428429,316970,UK,B0BK7SPC84,-10.383047,3.879279e-14,-6.356799,1.601719e-11,0,16.950001,5.960000,...,270.200653,0.017456,420.993561,-10.871386,6.057512e-15,0.000000,0,-4.160688,1.099036e-10,312.603594
