In [1]:
import sys
sys.path.append('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/')
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import pandas as pd 
import numpy as np 
import datasets
from datasets import Dataset as TFDataset 
import pickle
from bm25.rank_bm25 import BM25Okapi
import random
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict
from transformers import PreTrainedTokenizer, AutoTokenizer
import multiprocessing

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def cast_dtype(df : pd.DataFrame):
    for k in df.columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def tokenize_function(examples, corpus_col_name, tokenizer, max_length):
    if corpus_col_name in examples:
        return tokenizer(examples[corpus_col_name], 
            add_special_tokens=False, # don't add special tokens when preprocess
            truncation=True, 
            max_length=max_length,
            return_attention_mask=False,
            return_token_type_ids=False)

In [4]:
def construct_query_list_from_sessions(sessions_df:pd.DataFrame, product_map:dict, max_seq_len:int, product_corpus:list):
    query_list = []
    for sess in tqdm(sessions_df.itertuples(), total=sessions_df.shape[0]):
        sess_locale = sess.locale
        prev_items = eval(sess.prev_items.replace(' ', ','))[-max_seq_len : ]
        prev_items = [product_map.get(sess_locale+'_'+item, 0) for item in prev_items]
        sess_query = sum([product_corpus[item] for item in prev_items], [])
        query_list.append(sess_query)
    return query_list

# Merge test BM25 score

In [28]:
merged_candidates_feature_test_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates/merged_candidates_test_feature.parquet'
product_data_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/products_train.csv'
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task1.csv'

In [29]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature_test():
    return pd.read_parquet(merged_candidates_feature_test_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(product_data_path)

@lru_cache(maxsize=1)
def read_test_sessions():
    return pd.read_csv(test_sessions_path)


In [30]:
merged_candidates_feature_test = read_merged_candidates_feature_test()
product_data = read_product_data()
test_sessions = read_test_sessions()

In [31]:
TOKENIZER_NAME = 'xlm-roberta-base'
DESC_MAX_LENGTH = 500

In [35]:
tokenizer = AutoTokenizer.from_pretrained(
        TOKENIZER_NAME,
        use_fast=False,
)

In [33]:
desc_corpus = product_data[['desc', 'brand', 'color', 'size', 'model', 'material', 'author']]
padding_df = pd.DataFrame({'desc' : [''], 'brand' : [''], 'color' : [''], 'size' : [''], 'model' : [''], 'material' : [''], 'author' : ['']})
desc_corpus = pd.concat([padding_df, desc_corpus]).reset_index(drop=True) # add padding product
desc_corpus['desc'] = desc_corpus['desc'].fillna('')
desc_corpus['brand'] = desc_corpus['brand'].fillna('')
desc_corpus['color'] = desc_corpus['color'].fillna('')
desc_corpus['size'] = desc_corpus['size'].fillna('')
desc_corpus['model'] = desc_corpus['model'].fillna('')
desc_corpus['material'] = desc_corpus['material'].fillna('')
desc_corpus['author'] = desc_corpus['author'].fillna('')

desc_corpus['desc'] = desc_corpus['desc'] + ' ' + desc_corpus['brand'] + ' ' + desc_corpus['color'] + ' ' + desc_corpus['size'] + ' ' + desc_corpus['model'] \
    + ' ' + desc_corpus['material'] + ' ' + desc_corpus['author']
desc_corpus['desc'] = desc_corpus['desc'].apply(lambda x : x.lower())

In [34]:
desc_corpus = TFDataset.from_pandas(desc_corpus, preserve_index=False)
desc_corpus = desc_corpus.map(partial(tokenize_function, corpus_col_name='desc', tokenizer=tokenizer, max_length=DESC_MAX_LENGTH), 
                                num_proc=8, remove_columns=['desc'], batched=True)
desc_corpus_list = desc_corpus['input_ids']

 #0:   9%|▉         | 17/194 [00:18<03:23,  1.15s/ba]
 #0:  10%|▉         | 19/194 [00:19<02:19,  1.25ba/s]
 #0:  10%|█         | 20/194 [00:20<02:20,  1.24ba/s]
 #1:   6%|▌         | 12/194 [00:16<03:29,  1.15s/ba]
 #0:  11%|█▏        | 22/194 [00:21<02:34,  1.11ba/s]
 #0:  12%|█▏        | 23/194 [00:23<02:40,  1.07ba/s]
 #0:  12%|█▏        | 24/194 [00:24<03:06,  1.10s/ba]
[A

[A[A
 #0:  13%|█▎        | 25/194 [00:25<02:55,  1.04s/ba]

[A[A

 #0:  13%|█▎        | 26/194 [00:26<02:39,  1.06ba/s]
[A

 #0:  14%|█▍        | 27/194 [00:26<02:21,  1.18ba/s]
[A

 #0:  14%|█▍        | 28/194 [00:27<01:57,  1.41ba/s]

[A[A
[A

 #0:  15%|█▍        | 29/194 [00:28<02:30,  1.10ba/s]
[A

 #0:  15%|█▌        | 30/194 [00:29<02:23,  1.14ba/s]

[A[A
 #0:  16%|█▌        | 31/194 [00:30<02:48,  1.03s/ba]

[A[A
 #0:  16%|█▋        | 32/194 [00:31<02:39,  1.02ba/s]

[A[A

 #0:  17%|█▋        | 33/194 [00:32<02:16,  1.18ba/s]
[A

 #0:  18%|█▊        | 34/194 [00:32<01:56,  1.37ba/s]

[

In [36]:
with open('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/bm25/cache/desc_bm25.pkl', 'rb') as f:
    desc_BM25 = pickle.load(f)

In [37]:
merged_candidates = merged_candidates_feature_test[['sess_id', 'sess_locale', 'product']]
merged_candidates

Unnamed: 0,sess_id,sess_locale,product
0,0,DE,4088833651
1,0,DE,B000H6W2GW
2,0,DE,B000JG2RAG
3,0,DE,B000RYSOUW
4,0,DE,B000UGZVQM
...,...,...,...
69428426,316970,UK,B0BJCTH4NH
69428427,316970,UK,B0BJTQQWLG
69428428,316970,UK,B0BJV3RL4H
69428429,316970,UK,B0BK7SPC84


In [38]:
product_index = product_data[['id', 'locale']]
product_index['product_index'] = product_index.index + 1

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  product_index['product_index'] = product_index.index + 1


In [39]:
merged_candidates = merged_candidates.merge(product_index, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'id'])
merged_candidates['product_index'] = merged_candidates['product_index'].fillna(0)
merged_candidates['product_index'] = merged_candidates['product_index'].astype('int64')
merged_candidates.drop(columns=['id', 'locale'], inplace=True)
assert len(merged_candidates) == len(merged_candidates_feature_test)
merged_candidates_grouped = merged_candidates.groupby(by='sess_id')['product_index'].apply(list)

In [40]:
locale_product_map = {}
for row in tqdm(product_data.itertuples(), total=product_data.shape[0]):
    locale_product_map[row.locale+'_'+row.id] = row.Index + 1

100%|██████████| 1551057/1551057 [00:07<00:00, 218589.58it/s]


In [41]:
# construct query list 
test_query_list = construct_query_list_from_sessions(test_sessions, locale_product_map, 5, product_corpus=desc_corpus_list)

100%|██████████| 316971/316971 [00:39<00:00, 8002.21it/s] 


In [42]:
def get_sess_scores(sess):
    sess_id = sess['sess_id']
    scores = desc_BM25.get_batch_scores(test_query_list[sess_id], merged_candidates_grouped[sess_id])
    return {'sess_bm25_scores' : scores}

In [43]:
datasets.set_progress_bar_enabled(False)
test_query_dataset = TFDataset.from_dict({'sess_id' : list(range(len(test_query_list)))})
test_query_dataset = test_query_dataset.map(get_sess_scores, num_proc=20, batched=False)
datasets.set_progress_bar_enabled(True)

In [44]:
test_scores_list = test_query_dataset['sess_bm25_scores']

In [45]:
merged_bm25_scores = []
for scores_set in tqdm(test_scores_list):
    for s in scores_set:
        merged_bm25_scores.append(s)
assert len(merged_bm25_scores) == len(merged_candidates)
assert len(merged_bm25_scores) == len(merged_candidates_feature_test)

100%|██████████| 316971/316971 [00:10<00:00, 29549.10it/s]


In [46]:
merged_candidates_feature_test['desc_BM25_scores'] = merged_bm25_scores

In [None]:
cast_dtype(merged_candidates_feature_test)
merged_candidates_feature_test.to_parquet(merged_candidates_feature_test_path)

In [81]:
test_sessions.iloc[200000], test_sessions.iloc[200000]['prev_items']

(prev_items    ['B09NBQKRPC' 'B09NBQKRPC' 'B0BHQQQK2D']
 locale                                              JP
 Name: 200000, dtype: object,
 "['B09NBQKRPC' 'B09NBQKRPC' 'B0BHQQQK2D']")

In [86]:
product_data.iloc[locale_product_map['JP_B09V2MMFX4'] - 1]['title'], product_data.iloc[locale_product_map['JP_B08QVNCF14'] - 1]['title']

('RAVIAD USB C ライトニングケーブル 【1M/MFi 認証】 iPhone 充電ケーブル 急速充電 データ転送 高耐久 タイプC ライトニングケーブル PD対応 iPhone 13/13 Pro/13 Pro Max/12/12 Pro/12 Pro Max/12 mini/11 Pro Max/SE/XS/XR/X/8/8 Plus各種対応 Type C Lightningケーブル',
 'RAVIAD USB C ライトニングケーブル 【2M/MFi 認証】 iPhone 充電ケーブル 急速充電 データ転送 高耐久 タイプC ライトニングケーブル PD対応 iPhone 13/13 Pro/13 Pro Max/12/12 Pro/12 Pro Max/12 mini/11 Pro Max/SE/XS/XR/X/8/8 Plus各種対応 Type C Lightningケーブル')

In [48]:
merged_candidates_feature_test.query('sess_id==200000').sort_values(by=['title_BM25_scores'], ascending=False)[['sess_id', 'sess_locale', 'product', 'title_BM25_scores', 'desc_BM25_scores', 'sasrec_scores_2']][:50]

Unnamed: 0,sess_id,sess_locale,product,title_BM25_scores,desc_BM25_scores,sasrec_scores_2
43632311,200000,JP,B09V2MMFX4,820.670532,867.441694,14.521214
43632253,200000,JP,B08QVNCF14,820.670532,887.937117,15.688082
43632252,200000,JP,B08QVJ2BDF,820.670532,878.520506,15.84468
43632310,200000,JP,B09TVFSQ7F,796.628235,868.63111,17.949009
43632246,200000,JP,B08PS31FCM,788.036377,859.537915,15.484927
43632340,200000,JP,B0B2R8P5GX,776.291321,374.292121,12.940094
43632351,200000,JP,B0B4NSZ8FX,704.16925,231.982512,11.017905
43632330,200000,JP,B09YCMWPF5,704.16925,195.438868,9.152035
43632248,200000,JP,B08PYNP5BV,704.16925,202.674156,6.978341
43632367,200000,JP,B0B7MRLJH1,700.240784,224.814992,14.76034


In [49]:
merged_candidates_feature_test

Unnamed: 0,sess_id,sess_locale,product,sasrec_scores_2,sasrec_normalized_scores_2,gru4rec_scores,gru4rec_normalized_scores,product_freq,sess_avg_price,product_price,...,roberta_scores,roberta_normalized_scores,title_BM25_scores,sasrec_scores_3,sasrec_normalized_scores_3,normalized_all_items_co_graph_count_0,all_items_co_graph_count_0,seqmlp_scores,seqmlp_normalized_scores,desc_BM25_scores
0,0,DE,4088833651,0.000000,2.975813e-09,0.000000,1.580065e-09,828,25.195269,36.761604,...,0.000000,0.000000,0.000000,0.000000,2.622550e-09,0.000000,0,0.000000,2.554478e-10,0.000000
1,0,DE,B000H6W2GW,0.000000,2.975813e-09,0.000000,1.580065e-09,875,25.195269,36.761604,...,0.000000,0.000000,0.000000,0.000000,2.622550e-09,0.000000,0,0.000000,2.554478e-10,0.000000
2,0,DE,B000JG2RAG,7.665308,6.347557e-06,8.104032,5.226502e-06,24,25.195269,23.190001,...,267.192719,0.004943,287.809601,8.885176,1.894552e-05,0.000000,0,8.786958,1.672744e-06,67.792645
3,0,DE,B000RYSOUW,-2.951060,1.555882e-10,-2.857798,9.068785e-11,5,25.195269,6.900000,...,267.322815,0.005629,321.394653,-1.640674,5.083796e-10,0.000000,0,-3.325048,9.188664e-12,170.360588
4,0,DE,B000UGZVQM,3.977920,1.589257e-07,4.688567,1.717488e-07,4,25.195269,21.990000,...,267.242462,0.005195,285.328705,4.972019,3.784811e-07,0.000000,0,5.540127,6.506522e-08,71.169296
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69428426,316970,UK,B0BJCTH4NH,11.327528,1.041200e-04,10.629994,3.818184e-04,74,16.950001,5.800000,...,270.043762,0.014921,449.867401,10.968081,1.849500e-05,0.010237,16,11.838901,9.762144e-04,164.803133
69428427,316970,UK,B0BJTQQWLG,5.604142,3.403292e-07,6.052083,3.923694e-06,6,16.950001,9.880000,...,269.350769,0.007462,431.585815,7.366314,5.044600e-07,0.000640,1,4.890683,9.375031e-07,303.665984
69428428,316970,UK,B0BJV3RL4H,9.146974,1.176336e-05,7.667603,1.973815e-05,7,16.950001,22.097065,...,269.313751,0.007191,419.572662,8.286265,1.265775e-06,0.000640,1,10.187823,1.872800e-04,226.131521
69428429,316970,UK,B0BK7SPC84,-10.383047,3.879279e-14,-6.356799,1.601719e-11,0,16.950001,5.960000,...,270.200653,0.017456,420.993561,-10.871386,6.057512e-15,0.000000,0,-4.160688,1.099036e-10,312.603594
