In [1]:
import sys
sys.path.append('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/')  # must run before the project-local bm25 import below
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # must be set before GPU libraries are imported
import ast
import itertools
import multiprocessing
import pickle
import random
from collections import Counter, defaultdict
from functools import lru_cache, partial

import cudf
import datasets
import numpy as np
import pandas as pd
import scipy.sparse as ssp
from datasets import Dataset as TFDataset
from tqdm import tqdm, trange
from transformers import PreTrainedTokenizer, AutoTokenizer

from bm25.rank_bm25 import BM25Okapi

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def cast_dtype(df : pd.DataFrame):
    """Downcast the numeric columns of ``df`` in place to 32-bit dtypes.

    The kind of each column is inferred from the Python type of its first
    value:

    * float scalar  -> the column is cast to ``float32``
    * int scalar    -> the column is cast to ``int32``
    * ``list``      -> every row is converted to a ``float32`` / ``int32``
      ``np.ndarray``, chosen from the type of the first element of the
      first list

    Columns of any other type are left untouched. An empty frame, or a list
    column whose first list is empty, offers nothing to infer a type from and
    is skipped instead of raising ``IndexError``.

    Args:
        df: frame to modify in place.

    Returns:
        None; ``df`` itself is mutated.
    """
    if len(df) == 0:
        # No first row to inspect.
        return
    for k in df.columns:
        # Read the first value straight from the column; going through
        # df.iloc[0] would materialise a full row Series for every column.
        first_value = df[k].iloc[0]
        dt = type(first_value)
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            # NOTE(review): values are not range-checked before the cast to int32.
            df[k] = df[k].astype('int32')
        elif dt == list:
            if len(first_value) == 0:
                # Element type cannot be inferred from an empty list; leave the column as is.
                continue
            dt_ = type(first_value[0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def tokenize_function(examples, corpus_col_name, tokenizer, max_length):
    """Tokenise one text column of a batch of examples.

    Args:
        examples: mapping from column name to the batch of values.
        corpus_col_name: name of the text column to tokenise.
        tokenizer: callable invoked with the column values and the options below.
        max_length: truncation length forwarded to the tokenizer.

    Returns:
        Whatever ``tokenizer`` returns, or ``None`` when ``corpus_col_name``
        is not present in ``examples``.
    """
    if corpus_col_name not in examples:
        return None

    tokenizer_options = dict(
        add_special_tokens=False,  # don't add special tokens when preprocess
        truncation=True,
        max_length=max_length,
        return_attention_mask=False,
        return_token_type_ids=False,
    )
    texts = examples[corpus_col_name]
    return tokenizer(texts, **tokenizer_options)

In [4]:
def construct_query_list_from_sessions(sessions_df:pd.DataFrame, product_map:dict, max_seq_len:int, product_corpus:list):
    """Build one flat token query per session from its most recent items.

    Args:
        sessions_df: frame with a ``locale`` column and a ``prev_items`` column
            holding a space-separated list literal such as ``"['A' 'B']"``.
        product_map: maps ``'<locale>_<item id>'`` to an index into
            ``product_corpus``; unknown items fall back to index 0.
        max_seq_len: only the last ``max_seq_len`` items of a session are used.
        product_corpus: sequence of token lists, indexed by the values of
            ``product_map``.

    Returns:
        A list with one entry per row of ``sessions_df``; each entry is the
        concatenation of the token lists of the selected items.

    Raises:
        ValueError / SyntaxError: if ``prev_items`` is not a plain list literal.
    """
    query_list = []
    for sess in tqdm(sessions_df.itertuples(), total=sessions_df.shape[0]):
        sess_locale = sess.locale
        # Turn "['A' 'B']" into "['A','B']" and parse it as a literal only;
        # unlike eval(), literal_eval cannot execute code embedded in the CSV text.
        prev_items = ast.literal_eval(sess.prev_items.replace(' ', ','))[-max_seq_len : ]
        item_indices = [product_map.get(sess_locale+'_'+item, 0) for item in prev_items]
        # chain.from_iterable concatenates in linear time; sum(lists, []) re-copies
        # the accumulated list on every addition.
        sess_query = list(itertools.chain.from_iterable(product_corpus[idx] for idx in item_indices))
        query_list.append(sess_query)
    return query_list

# Merge description BM25 scores into the validation-set candidates

In [5]:
# Candidate feature table (parquet). It is read at the start of this notebook and
# written back to the same path once the new BM25 column has been added.
merged_candidates_feature_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates/merged_candidates_2_feature.parquet'
# Product catalogue (CSV); the code below reads its id, locale, title and descriptive text columns.
product_data_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/products_train.csv'
# Validation sessions (CSV) with prev_items, next_item and locale columns.
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions.csv'

In [6]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature():
    """Load the merged candidate feature table from parquet (memoised).

    The file at ``merged_candidates_feature_path`` is read with the pyarrow
    engine on the first call. ``lru_cache`` returns that same object on every
    later call, so in-place edits made by one caller are visible to all.
    """
    candidates_feature_df = pd.read_parquet(
        merged_candidates_feature_path,
        engine='pyarrow',
    )
    return candidates_feature_df

@lru_cache(maxsize=1)
def read_product_data():
    """Load the product catalogue CSV from ``product_data_path`` (memoised).

    Only the first call touches the disk; ``lru_cache`` hands back the same
    DataFrame object afterwards.
    """
    products_df = pd.read_csv(product_data_path)
    return products_df

@lru_cache(maxsize=1)
def read_valid_sessions():
    """Load the validation sessions CSV from ``valid_sessions_path`` (memoised).

    Only the first call touches the disk; ``lru_cache`` hands back the same
    DataFrame object afterwards.
    """
    sessions_df = pd.read_csv(valid_sessions_path)
    return sessions_df


In [7]:
# Load the three inputs. Each reader is wrapped in lru_cache(maxsize=1), so
# re-running this cell reuses the frames that are already in memory.
merged_candidates_feature = read_merged_candidates_feature()
product_data = read_product_data()
valid_sessions = read_valid_sessions()

In [8]:
# Checkpoint name handed to AutoTokenizer.from_pretrained.
TOKENIZER_NAME = 'xlm-roberta-base'
# Passed as max_length (with truncation) when tokenising the combined product description text.
DESC_MAX_LENGTH = 500

In [9]:
# Load the tokenizer named by TOKENIZER_NAME; use_fast=False asks for the
# non-fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
        TOKENIZER_NAME,
        use_fast=False,
)

'HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /xlm-roberta-base/resolve/main/sentencepiece.bpe.model (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7f54a9736370>, 'Connection to huggingface.co timed out. (connect timeout=10)'))' thrown while requesting HEAD https://huggingface.co/xlm-roberta-base/resolve/main/sentencepiece.bpe.model


In [10]:
# Text fields that are merged into one description string per product.
desc_text_columns = ['desc', 'brand', 'color', 'size', 'model', 'material', 'author']

# Row 0 is an all-empty padding product, so real products start at row 1.
padding_df = pd.DataFrame({column: [''] for column in desc_text_columns})
desc_corpus = pd.concat([padding_df, product_data[desc_text_columns]]).reset_index(drop=True) # add padding product

# Missing attributes become empty strings so the concatenation below never meets NaN.
for column in desc_text_columns:
    desc_corpus[column] = desc_corpus[column].fillna('')

# Append every other attribute to the description, separated by single spaces,
# then lowercase the combined text.
combined_text = desc_corpus['desc']
for column in desc_text_columns[1:]:
    combined_text = combined_text + ' ' + desc_corpus[column]
desc_corpus['desc'] = combined_text.str.lower()

In [11]:
# Tokenise the combined description text in batches across 8 worker processes;
# the raw 'desc' text column is dropped from the resulting dataset.
tokenize_desc = partial(
    tokenize_function,
    corpus_col_name='desc',
    tokenizer=tokenizer,
    max_length=DESC_MAX_LENGTH,
)
desc_corpus = TFDataset.from_pandas(desc_corpus, preserve_index=False)
desc_corpus = desc_corpus.map(
    tokenize_desc,
    batched=True,
    num_proc=8,
    remove_columns=['desc'],
)
# Token ids per corpus row; row 0 is the padding product.
desc_corpus_list = desc_corpus['input_ids']

 #0:  11%|█▏        | 22/194 [00:14<01:37,  1.76ba/s]
 #0:  12%|█▏        | 23/194 [00:15<01:48,  1.58ba/s]
 #0:  12%|█▏        | 24/194 [00:16<02:27,  1.15ba/s]
 #0:  13%|█▎        | 25/194 [00:17<02:51,  1.01s/ba]
 #0:  14%|█▍        | 27/194 [00:20<03:18,  1.19s/ba]
 #0:  14%|█▍        | 28/194 [00:22<03:30,  1.27s/ba]
 #0:  15%|█▌        | 30/194 [00:23<02:51,  1.05s/ba]
 #0:  16%|█▌        | 31/194 [00:24<02:35,  1.05ba/s]
 #0:  16%|█▋        | 32/194 [00:25<02:18,  1.17ba/s]
 #0:  17%|█▋        | 33/194 [00:25<02:08,  1.26ba/s]
 #0:  18%|█▊        | 34/194 [00:27<02:31,  1.05ba/s]
[A

[A[A

 #0:  18%|█▊        | 35/194 [00:29<03:29,  1.32s/ba]
[A

 #0:  19%|█▊        | 36/194 [00:30<02:54,  1.10s/ba]
[A

[A[A

 #0:  19%|█▉        | 37/194 [00:31<03:36,  1.38s/ba]
[A
 #0:  20%|█▉        | 38/194 [00:32<02:59,  1.15s/ba]

 #2:   7%|▋         | 14/194 [00:18<03:37,  1.21s/ba]

 #0:  20%|██        | 39/194 [00:34<03:28,  1.34s/ba]
[A

 #0:  21%|██        | 40/194 [00:35<03:2

In [12]:
# Load the cached BM25 object built over the description corpus.
# NOTE: pickle.load can execute arbitrary code, so only load cache files
# produced by this project.
desc_bm25_cache_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/bm25/cache/desc_bm25.pkl'
with open(desc_bm25_cache_path, 'rb') as cache_file:
    desc_BM25 = pickle.load(cache_file)

In [13]:
# Keep only the columns that identify a (session, candidate product) pair.
candidate_key_columns = ['sess_id', 'sess_locale', 'product']
merged_candidates = merged_candidates_feature[candidate_key_columns]
merged_candidates

Unnamed: 0,sess_id,sess_locale,product
0,0,UK,B000OPPVCS
1,0,UK,B000V599Y2
2,0,UK,B0018HH444
3,0,UK,B0079JI4DU
4,0,UK,B0079JI4EY
...,...,...,...
84407334,361580,DE,B0BB7XV97M
84407335,361580,DE,B0BB7YSRBX
84407336,361580,DE,B0BB7ZMGY8
84407337,361580,DE,B0BD4CP7N3


In [14]:
# Take an explicit copy so the new column is written to an independent frame
# rather than to a slice of product_data (this is what triggered the
# SettingWithCopyWarning).
product_index = product_data[['id', 'locale']].copy()
# Index label + 1, so that 0 stays free for the padding product / unknown items.
product_index['product_index'] = product_index.index + 1

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  product_index['product_index'] = product_index.index + 1


In [15]:
# Spot-check: the same product id can appear under several locales.
product_index[product_index['id'] == 'B0079JI4DU']

Unnamed: 0,id,locale,product_index
101952,B0079JI4DU,DE,101953
1516940,B0079JI4DU,IT,1516941


In [16]:
# Attach the corpus row index of every candidate, matching on (locale, product id),
# and drop the join keys brought in from product_index.
merged_candidates = (
    merged_candidates
    .merge(
        product_index,
        how='left',
        left_on=['sess_locale', 'product'],
        right_on=['locale', 'id'],
    )
    .drop(columns=['id', 'locale'])
)
# Candidates with no match in product_index fall back to 0, the padding product.
merged_candidates['product_index'] = (
    merged_candidates['product_index'].fillna(0).astype('int64')
)
# The left join must not add or remove candidate rows.
assert len(merged_candidates) == len(merged_candidates_feature)
# One list of candidate corpus indices per session.
merged_candidates_grouped = merged_candidates.groupby(by='sess_id')['product_index'].apply(list)

In [17]:
# Key: '<locale>_<product id>'; value: index label of the product row + 1,
# leaving 0 for the padding product. Later duplicates of a key overwrite earlier ones.
locale_product_map = {
    row.locale + '_' + row.id: row.Index + 1
    for row in tqdm(product_data.itertuples(), total=len(product_data))
}

100%|██████████| 1551057/1551057 [00:21<00:00, 71725.33it/s] 


In [18]:
# Build one token-level query per validation session from its last 5 items.
valid_query_list = construct_query_list_from_sessions(
    sessions_df=valid_sessions,
    product_map=locale_product_map,
    max_seq_len=5,
    product_corpus=desc_corpus_list,
)

100%|██████████| 361581/361581 [00:25<00:00, 14316.32it/s]


In [19]:
def get_sess_scores(sess):
    """Score one session's candidates against its query with ``desc_BM25``.

    Args:
        sess: mapping with a ``'sess_id'`` entry, used to index the
            module-level ``valid_query_list`` and ``merged_candidates_grouped``.

    Returns:
        ``{'sess_bm25_scores': ...}`` holding the result of
        ``desc_BM25.get_batch_scores`` for that session.
    """
    sid = sess['sess_id']
    query_tokens = valid_query_list[sid]
    candidate_indices = merged_candidates_grouped[sid]
    batch_scores = desc_BM25.get_batch_scores(query_tokens, candidate_indices)
    return {'sess_bm25_scores' : batch_scores}

In [20]:
# about 23 mins
# Score every validation session in 20 worker processes. Progress bars are
# switched off for the duration of the map call.
datasets.set_progress_bar_enabled(False)
try:
    valid_query_dataset = TFDataset.from_dict({'sess_id' : list(range(len(valid_query_list)))})
    valid_query_dataset = valid_query_dataset.map(get_sess_scores, num_proc=20, batched=False)
finally:
    # Restore the global progress-bar setting even if scoring fails or is
    # interrupted, so later cells are not left silently without progress bars.
    datasets.set_progress_bar_enabled(True)

In [21]:
# One entry per session: the BM25 scores computed for that session's candidates
# (presumably in the order of merged_candidates_grouped[sess_id] — confirm against get_batch_scores).
valid_scores_list = valid_query_dataset['sess_bm25_scores']

In [22]:
# Flatten the per-session score lists into one long list, session after session.
merged_bm25_scores = []
for session_scores in tqdm(valid_scores_list):
    merged_bm25_scores.extend(session_scores)

# There must be exactly one score per candidate row.
n_scores = len(merged_bm25_scores)
assert n_scores == len(merged_candidates)
assert n_scores == len(merged_candidates_feature)

100%|██████████| 361581/361581 [00:14<00:00, 24848.21it/s]


In [23]:
# Attach the flattened scores as a new feature column.
# NOTE(review): this is a positional assignment; it assumes the rows of
# merged_candidates_feature are ordered by sess_id so that they line up with the
# session-by-session flattening — only the lengths are asserted. Confirm.
merged_candidates_feature['desc_BM25_scores'] = merged_bm25_scores

In [24]:
# Downcast numeric columns in place, then persist the table.
# NOTE: the target is merged_candidates_feature_path, the same file the table
# was read from, so the input parquet is overwritten.
cast_dtype(merged_candidates_feature)
merged_candidates_feature.to_parquet(merged_candidates_feature_path)

In [None]:
# Spot-check one validation session: the full row plus its raw prev_items string.
sample_session = valid_sessions.iloc[300001]
sample_session, sample_session['prev_items']

(prev_items    ['B07ZZ5JH12' 'B09KTRFTJJ']
 next_item                      B091BFSKKM
 locale                                 UK
 Name: 300001, dtype: object,
 "['B07ZZ5JH12' 'B09KTRFTJJ']")

In [None]:
# Show id and title of two UK products. The map values are offset by +1,
# hence the -1 when indexing product_data.
item_a = product_data.iloc[locale_product_map['UK_B09KTRFTJJ'] - 1]
item_b = product_data.iloc[locale_product_map['UK_B07ZZ5JH12'] - 1]
item_a['id'], item_a['title'], item_b['id'], item_b['title']

('B09KTRFTJJ',
 'Bird Feeders for Outside, Bird feeder, Wild Bird seed for Outside Feeders, Squirrel Proof Birds Feeder, Garden Decoration Black',
 'B07ZZ5JH12',
 'Oakdale Wild Bird Feeder Pre-Filled with Premium Seeds, Large Hanging Metal Frame with Dual Perches, Refillable Lawn and Garden Outdoor Use, Enjoy Birdwatching or Birding')

In [None]:
# Show id and title of two further UK products. The map values are offset
# by +1, hence the -1 when indexing product_data.
item_a = product_data.iloc[locale_product_map['UK_B09PR5X9LY'] - 1]
item_b = product_data.iloc[locale_product_map['UK_B093GF9T5N'] - 1]
item_a['id'], item_a['title'], item_b['id'], item_b['title']

('B09PR5X9LY',
 'Bird Feeders Hanging - Wild Bird Seed Feeder Garden Metal Bird Feeders for Garden Squirrel Proof Unusual Bird Feeders Frog',
 'B093GF9T5N',
 'Bird Feeders Hanging - Wild Bird Seed Feeder Garden Metal Bird Feeders for Garden Squirrel Proof Unusual Bird Feeders Sunflower')

In [27]:
# Top 50 candidates of session 200002 ranked by title BM25 score, with the
# description BM25 score next to it for comparison.
# NOTE(review): the visible column listing of merged_candidates_feature shows
# `sasrec_normalized_scores_2`; confirm that `normalized_sasrec_scores_2` exists.
inspect_columns = ['sess_id', 'sess_locale', 'product', 'desc_BM25_scores', 'title_BM25_scores', 'normalized_sasrec_scores_2']
(
    merged_candidates_feature
    .query('sess_id==200002')
    .sort_values(by=['title_BM25_scores'], ascending=False)
    .loc[:, inspect_columns]
    .head(50)
)

Unnamed: 0,sess_id,sess_locale,product,desc_BM25_scores,title_BM25_scores,roberta_normalized_scores
46695487,200002,UK,B0B25LQQPC,101.166428,191.407257,0.014758
46695378,200002,UK,B086BKGSC1,180.199966,182.573959,0.017476
46695491,200002,UK,B0B25NTRGD,103.387924,178.059723,0.010392
46695389,200002,UK,B089DNM8LR,162.421295,172.520035,0.015799
46695407,200002,UK,B08GVDNTGJ,162.421295,172.520035,0.015601
46695445,200002,UK,B098W1NDV2,174.515518,169.153549,0.006852
46695320,200002,UK,B0786QNS9B,86.68528,163.970383,0.011266
46695488,200002,UK,B0B25LZGGW,73.187576,160.92189,0.016346
46695489,200002,UK,B0B25MJ1YT,73.187576,160.92189,0.016275
46695493,200002,UK,B0B25P44CL,73.187576,160.92189,0.013167


In [18]:
# Inspect the candidates after the merge; product_index is 0 where the
# (locale, product) pair had no match in product_index.
merged_candidates

Unnamed: 0,sess_id,sess_locale,product,product_index
0,0,UK,B000OPPVCS,1375599
1,0,UK,B000V599Y2,1324417
2,0,UK,B0018HH444,1413111
3,0,UK,B0079JI4DU,0
4,0,UK,B0079JI4EY,0
...,...,...,...,...
84407334,361580,DE,B0BB7XV97M,446969
84407335,361580,DE,B0BB7YSRBX,275922
84407336,361580,DE,B0BB7ZMGY8,429872
84407337,361580,DE,B0BD4CP7N3,276547


In [17]:
# Inspect the per-session lists of candidate product indices.
merged_candidates_grouped

sess_id
0         [1375599, 1324417, 1413111, 0, 0, 970646, 1132...
1         [826127, 673569, 751275, 889131, 654649, 77556...
2         [1149066, 1253359, 1343812, 1310769, 960407, 9...
3         [1186226, 1165726, 1126038, 1410888, 1153463, ...
4         [766710, 592913, 695302, 904349, 882975, 76287...
                                ...                        
361576    [1134842, 1094904, 1212843, 1260859, 1094344, ...
361577    [843495, 679363, 887052, 666868, 521356, 67026...
361578    [111083, 428215, 21233, 56473, 116992, 87684, ...
361579    [140800, 457181, 477726, 329896, 479625, 33046...
361580    [476029, 0, 287329, 372558, 232874, 8876, 8945...
Name: product_index, Length: 361581, dtype: object

In [21]:
# Inspect the candidate table with its product_index column (0 marks an unmatched product).
merged_candidates

Unnamed: 0,sess_id,sess_locale,product,product_index
0,0,UK,B000OPPVCS,1375599
1,0,UK,B000V599Y2,1324417
2,0,UK,B0018HH444,1413111
3,0,UK,B0079JI4DU,0
4,0,UK,B0079JI4EY,0
...,...,...,...,...
84407334,361580,DE,B0BB7XV97M,446969
84407335,361580,DE,B0BB7YSRBX,275922
84407336,361580,DE,B0BB7ZMGY8,429872
84407337,361580,DE,B0BD4CP7N3,276547


In [None]:
# Inspect the full feature table, including the desc_BM25_scores column.
merged_candidates_feature

Unnamed: 0,sess_id,sess_locale,product,target,sasrec_scores_2,sasrec_normalized_scores_2,product_freq,gru4rec_scores,gru4rec_normalized_scores,sess_avg_price,...,roberta_scores,roberta_normalized_scores,title_BM25_scores,sasrec_scores_3,sasrec_normalized_scores_3,normalized_all_items_co_graph_count_0,all_items_co_graph_count_0,seqmlp_scores,seqmlp_normalized_scores,desc_BM25_scores
0,0,UK,B000OPPVCS,0.0,11.972421,2.286162e-04,104,6.484859,3.816029e-05,7.388571,...,265.826630,1.087245e-03,298.915375,10.891474,2.517129e-04,0.002635,2,8.378721,1.007163e-05,623.559754
1,0,UK,B000V599Y2,0.0,13.152878,7.443427e-04,37,4.342063,4.477209e-06,7.388571,...,259.157867,1.380768e-06,111.069756,10.677187,2.031618e-04,0.003953,3,7.534612,4.330198e-06,404.532660
2,0,UK,B0018HH444,0.0,5.606023,3.928400e-07,7,3.220763,1.458925e-06,7.388571,...,257.331421,2.222824e-07,0.000000,6.074605,2.036883e-06,0.001318,1,11.199949,1.691779e-04,405.707216
3,0,UK,B0079JI4DU,0.0,0.000000,1.443945e-09,67,0.000000,5.824698e-08,7.388571,...,0.000000,0.000000e+00,0.000000,0.000000,4.685961e-09,0.002635,2,0.000000,2.313489e-09,388.388606
4,0,UK,B0079JI4EY,0.0,0.000000,1.443945e-09,77,0.000000,5.824698e-08,7.388571,...,0.000000,0.000000e+00,0.000000,0.000000,4.685961e-09,0.002635,2,0.000000,2.313489e-09,388.388606
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
84407334,361580,DE,B0BB7XV97M,0.0,9.117821,6.077226e-05,56,9.268379,1.396883e-05,32.424000,...,263.574158,1.378417e-03,118.126396,9.635838,3.403967e-05,0.003356,2,13.965775,2.930834e-03,200.418798
84407335,361580,DE,B0BB7YSRBX,0.0,9.163816,6.363281e-05,58,7.047796,1.516259e-06,32.424000,...,263.523743,1.310646e-03,124.881615,9.159988,2.115080e-05,0.001678,1,10.212118,6.867505e-05,209.420113
84407336,361580,DE,B0BB7ZMGY8,0.0,11.256460,5.158278e-04,452,9.359167,1.529639e-05,32.424000,...,263.567017,1.368608e-03,124.881615,10.119755,5.522656e-05,0.038591,23,12.275789,5.408041e-04,209.420113
84407337,361580,DE,B0BD4CP7N3,0.0,-3.778687,1.523433e-10,1,-0.593306,7.282568e-10,32.424000,...,265.401611,8.571040e-03,192.540955,-1.612869,4.433373e-10,0.000000,0,-2.456360,2.162463e-10,254.826076
