In [1]:
import ast
import itertools
import os
import random
from collections import Counter, defaultdict
from functools import lru_cache, partial

import cudf
import numpy as np
import pandas as pd
import scipy.sparse as ssp
from tqdm import tqdm, trange

In [2]:
def cast_dtype(df : pd.DataFrame, columns=None):
    """Downcast numeric columns of ``df`` to 32-bit types, in place.

    The kind of each column is decided from its first value:

    * float scalar -> the column is cast to ``float32``
    * int scalar   -> the column is cast to ``int32``
    * ``list``     -> every cell becomes a ``float32`` / ``int32`` NumPy array,
      the element type being taken from the first non-empty list in the column

    Columns of any other kind (strings, booleans, ...) are left untouched.

    Args:
        df: Frame to modify; the selected columns are reassigned on it.
        columns: Column labels to process. Defaults to every column of ``df``.

    Returns:
        None. ``df`` is modified in place.
    """
    if columns is None:
        columns = df.columns
    if len(df) == 0:
        # No first row to sample a type from, so there is nothing to decide.
        return

    def is_float(value):
        return isinstance(value, (float, np.floating))

    def is_int(value):
        # bool is a subclass of int but must not be treated as an integer column.
        return isinstance(value, (int, np.integer)) and not isinstance(value, bool)

    for k in columns:
        first = df[k].iloc[0]
        if is_float(first):
            df[k] = df[k].astype('float32')
        elif is_int(first):
            df[k] = df[k].astype('int32')
        elif type(first) is list:
            # The first list may be empty, so look for the first one that has
            # an element whose type can be inspected.
            sample = next(
                (cell for cell in df[k] if type(cell) is list and len(cell) > 0),
                None,
            )
            if sample is None:
                continue
            if is_float(sample[0]):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif is_int(sample[0]):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def get_min_max_price(session_df : pd.DataFrame, product2price_dict):
    """Compute the lowest and highest item price of every session.

    Args:
        session_df: One row per session, with a ``locale`` column and a
            ``prev_items`` column holding the text form of a list of item ids
            separated by single spaces (presumably like ``"['A' 'B']"`` --
            confirm against the raw CSV).
        product2price_dict: Mapping from ``"<item id>_<locale>"`` to a price.

    Returns:
        ``(min_price_list, max_price_list)``: two Python lists with one entry
        per row of ``session_df``, in row order.

    Raises:
        KeyError: an item/locale pair is missing from ``product2price_dict``.
        ValueError: ``prev_items`` is not a plain literal, or a session
            contains no items.
    """
    min_price_list = []
    max_price_list = []
    for sess in tqdm(session_df.itertuples(), total=session_df.shape[0]):
        locale = sess.locale
        # Turn the space-separated form into a comma-separated list literal.
        # literal_eval only accepts literals, so text coming from the CSV can
        # never be executed as code (unlike eval).
        prev_items = ast.literal_eval(sess.prev_items.replace(' ', ','))
        sess_price_list = [
            product2price_dict[item + '_' + locale] for item in prev_items
        ]

        min_price_list.append(min(sess_price_list))
        max_price_list.append(max(sess_price_list))

    return min_price_list, max_price_list

# Merge session min/max prices into the test candidate features

In [4]:
# Root directory shared by every input file. It defaults to the location used
# on the original machine and can be redirected with the KDDCUP23_WORKSPACE
# environment variable, so the notebook can run where the data lives elsewhere.
WORKSPACE_DIR = os.environ.get(
    'KDDCUP23_WORKSPACE',
    '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23',
)

# Candidate feature table for the test sessions (read here and written back later).
merged_candidates_feature_test_path = f'{WORKSPACE_DIR}/XGBoost/candidates_phase2/merged_candidates_150_test_feature.parquet'
# Raw test sessions (task 1, phase 2).
test_sessions_path = f'{WORKSPACE_DIR}/raw_data/sessions_test_task1_phase2.csv'
# Processed product table that provides the per-product price.
product_path = f'{WORKSPACE_DIR}/data_for_recstudio/processed_products_train.csv'

In [5]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature_test():
    """Load the test candidate feature table from its parquet file.

    The result is memoised: later calls return the same DataFrame object, so
    in-place changes made by a caller are visible to every other caller.
    """
    candidates = pd.read_parquet(
        merged_candidates_feature_test_path,
        engine='pyarrow',
    )
    return candidates

@lru_cache(maxsize=1)
def read_test_sessions():
    """Load the test sessions CSV.

    The result is memoised: later calls return the same DataFrame object
    instead of reading the file again.
    """
    sessions = pd.read_csv(test_sessions_path)
    return sessions

@lru_cache(maxsize=1)
def read_product_feature():
    """Load the processed product table CSV.

    The result is memoised: later calls return the same DataFrame object, so
    callers that need to change it should work on a copy.
    """
    products = pd.read_csv(product_path)
    return products

In [6]:
test_sessions = read_test_sessions()
merged_candidates_feature_test = read_merged_candidates_feature_test()
# read_product_feature() is memoised and hands back the same DataFrame object
# on every call. Writing the converted price into that object would make this
# cell apply exp() again each time it is re-run, so the conversion is done on
# a new frame produced by assign() and the cached frame keeps its stored values.
# NOTE(review): exp(x) - 1 presumably undoes a log(1 + price) transform applied
# when the product file was processed -- confirm.
product_feature = read_product_feature().assign(
    price=lambda frame: np.exp(frame['price']) - 1
)

In [7]:
# Price lookup keyed by "<product id>_<locale>"; when a key occurs more than
# once the last row wins.
product2price = {
    row.id + '_' + row.locale: row.price
    for row in tqdm(product_feature.itertuples(), total=product_feature.shape[0])
}

100%|██████████| 1551057/1551057 [00:16<00:00, 94791.59it/s] 


In [8]:
# One minimum and one maximum price per test session, in session row order.
min_price_list, max_price_list = get_min_max_price(test_sessions, product2price)
# Both lists must line up with the session table.
assert len(min_price_list) == len(test_sessions) == len(max_price_list)

100%|██████████| 316972/316972 [00:14<00:00, 21367.76it/s]


In [9]:
# Convert both price lists to NumPy arrays so they can be indexed with a whole
# column of session ids at once.
min_price_list, max_price_list = np.array(min_price_list), np.array(max_price_list)

In [10]:
# The price arrays follow the row order of test_sessions, so this lookup relies
# on sess_id being the row position of the session in that table.
candidate_sess_ids = merged_candidates_feature_test['sess_id']
merged_candidates_feature_test['sess_min_price'] = min_price_list[candidate_sess_ids]
merged_candidates_feature_test['sess_max_price'] = max_price_list[candidate_sess_ids]

In [12]:
# Downcast the two new columns, then write the table back to the same parquet
# file it was read from (the file is overwritten).
session_price_columns = ['sess_min_price', 'sess_max_price']
cast_dtype(merged_candidates_feature_test, session_price_columns)
merged_candidates_feature_test.to_parquet(
    merged_candidates_feature_test_path,
    engine='pyarrow',
)

In [None]:
# del merged_candidates_feature_test_g
# del product_price_feature_g
# del merged_candidates_feature_test_price_g

In [11]:
# Display the candidate table to check that sess_min_price and sess_max_price
# were added.
merged_candidates_feature_test

Unnamed: 0,sess_id,sess_locale,product,sasrec_scores_2,normalized_sasrec_scores_2,sasrec_scores_3,normalized_sasrec_scores_3,sess_avg_price,product_price,seqmlp_scores,...,lyx_lknn_i2i_score,lyx_lknn_u2i_score,normalized_lyx_lknn_i2i_score,normalized_lyx_lknn_u2i_score,lyx_gru4rec_i2i_score,lyx_gru4rec_u2i_score,normalized_lyx_gru4rec_i2i_score,normalized_lyx_gru4rec_u2i_score,sess_min_price,sess_max_price
0,0,DE,B000Q87D0Q,0.000000,3.282997e-10,0.000000,6.689660e-10,67.527199,36.761604,0.000000,...,0.000000,0.000000,4.878070e-14,4.984622e-09,0.000000,0.000000,1.995635e-09,2.617613e-09,36.761606,107.83
1,0,DE,B000QB30DW,0.501346,5.420036e-10,-0.588501,3.713825e-10,67.527199,9.990000,7.260942,...,12.179352,2.597628,9.498896e-09,6.695266e-08,4.637702,1.632133,2.061624e-07,1.338850e-08,36.761606,107.83
2,0,DE,B004BIG55Q,6.917523,3.315223e-07,5.737720,2.076175e-07,67.527199,8.990000,2.454817,...,12.086594,6.003901,8.657427e-09,2.018801e-06,9.406904,4.884140,2.429114e-05,3.459877e-07,36.761606,107.83
3,0,DE,B0053FTNQY,-0.100895,2.967921e-10,1.507319,3.020121e-09,67.527199,36.761604,3.837643,...,7.574673,3.320492,9.503570e-11,1.379443e-07,3.897757,5.728711,9.836833e-08,8.051055e-07,36.761606,107.83
4,0,DE,B007QWII1S,3.768980,1.422714e-08,4.594047,6.615662e-08,67.527199,54.950001,4.923371,...,15.300370,7.069957,2.153348e-07,5.862403e-06,8.042400,9.173968,6.206567e-06,2.524117e-05,36.761606,107.83
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
96556030,316971,UK,B0B82N3CQQ,-1.076433,6.007382e-08,-0.457645,1.105378e-07,19.459999,13.990000,6.433315,...,9.593886,5.795608,2.128749e-07,7.366873e-06,1.131881,-1.138139,3.417292e-08,5.101982e-09,13.990000,22.41
96556031,316971,UK,B0BB9NW3F3,0.000000,1.762683e-07,0.000000,1.746882e-07,19.459999,22.097065,0.000000,...,0.000000,0.000000,1.450616e-11,2.240178e-08,0.000000,0.000000,1.101824e-08,1.592305e-08,13.990000,22.41
96556032,316971,UK,B0BDMVKTQ3,-1.079334,5.989980e-08,-1.901198,2.609658e-08,19.459999,41.990002,-1.094359,...,7.214885,4.796178,1.972138e-08,2.711667e-06,1.078394,5.592590,3.239313e-08,4.274222e-06,13.990000,22.41
96556033,316971,UK,B0BHW1D5VP,6.722834,1.465088e-04,6.111193,7.876277e-05,19.459999,26.990000,8.700006,...,16.285294,9.269331,1.714612e-04,2.376302e-04,11.048322,7.961729,6.923686e-04,4.568370e-05,13.990000,22.41
