In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict

In [2]:
def cast_dtype(df : pd.DataFrame):
    for k in df.columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def get_avg_price(session_df, product2price_dict):
    avg_price_list = []
    for i in tqdm(range(session_df.shape[0])):
        sess = session_df.iloc[i]
        locale = sess['locale']
        avg_price = 0.0
        prev_items = eval(sess['prev_items'].replace(' ', ','))
        for item in prev_items:
            avg_price += product2price_dict[item+'_'+locale]
        avg_price = avg_price / len(prev_items)
        avg_price_list.append(avg_price)
    return avg_price_list

In [4]:
def get_product_price(candidate_df : pd.DataFrame, product2price_dict):
    price_list = []
    for candidate in tqdm(candidate_df.itertuples(index=False), total=candidate_df.shape[0]):
        product = candidate.product
        locale = candidate.sess_locale
        if product+'_'+locale not in product2price_dict:
            if locale == 'DE': price_list.append(36.7616060903638)
            elif locale == 'JP': price_list.append(4201.2729840839065)
            elif locale == 'UK': price_list.append(22.097065056579634)
        else:
            price_list.append(product2price_dict[product+'_'+locale])
    return price_list

# Merge valid price 

In [5]:
merged_candidates_feature_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates/merged_candidates_no_hist_feature.parquet'
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions.csv'
product_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/processed_products_train.csv'

In [6]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature():
    return pd.read_parquet(merged_candidates_feature_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_product_feature():
    return pd.read_csv(product_path)

In [7]:
valid_sessions = read_valid_sessions()
product_feature = read_product_feature()
merged_candidates_feature = read_merged_candidates_feature()
product_feature['price'] = np.exp(product_feature['price']) - 1

In [8]:
product2price = {}
for i in tqdm(range(product_feature.shape[0])):
    product = product_feature.iloc[i]
    k = product['id']+'_'+product['locale']
    v = product['price']
    product2price[k] = v

100%|██████████| 1551057/1551057 [01:48<00:00, 14235.52it/s]


In [9]:
avg_price_list = get_avg_price(valid_sessions, product2price)
assert len(avg_price_list) == len(valid_sessions)

100%|██████████| 361581/361581 [00:19<00:00, 18648.13it/s]


In [10]:
avg_price_list = np.array(avg_price_list)
sess_index = merged_candidates_feature['sess_id'].to_numpy()
avg_price_list = avg_price_list[sess_index]

In [11]:
merged_candidates_feature['sess_avg_price'] = avg_price_list

In [12]:
product_price_list = get_product_price(merged_candidates_feature, product2price)
assert len(product_price_list) == len(merged_candidates_feature)

100%|██████████| 84407339/84407339 [04:01<00:00, 349625.76it/s]


In [13]:
merged_candidates_feature['product_price'] = product_price_list

In [15]:
cast_dtype(merged_candidates_feature)
merged_candidates_feature.to_parquet(merged_candidates_feature_path, engine='pyarrow')

In [14]:
merged_candidates_feature

Unnamed: 0,sess_id,sess_locale,product,target,sasrec_scores_2,sasrec_normalized_scores_2,product_freq,gru4rec_scores,gru4rec_normalized_scores,sess_avg_price,product_price
0,0,UK,B000OPPVCS,0.0,11.972421,2.286162e-04,104,6.484859,3.816029e-05,7.388571,7.280000
1,0,UK,B000V599Y2,0.0,13.152878,7.443427e-04,37,4.342063,4.477209e-06,7.388571,5.200000
2,0,UK,B0018HH444,0.0,5.606023,3.928400e-07,7,3.220763,1.458925e-06,7.388571,15.800000
3,0,UK,B0079JI4DU,0.0,0.000000,1.443945e-09,67,0.000000,5.824698e-08,7.388571,22.097065
4,0,UK,B0079JI4EY,0.0,0.000000,1.443945e-09,77,0.000000,5.824698e-08,7.388571,22.097065
...,...,...,...,...,...,...,...,...,...,...,...
84407334,361580,DE,B0BB7XV97M,0.0,9.117821,6.077226e-05,56,9.268379,1.396883e-05,32.424000,47.990000
84407335,361580,DE,B0BB7YSRBX,0.0,9.163816,6.363281e-05,58,7.047796,1.516259e-06,32.424000,43.990000
84407336,361580,DE,B0BB7ZMGY8,0.0,11.256460,5.158278e-04,452,9.359167,1.529639e-05,32.424000,41.990000
84407337,361580,DE,B0BD4CP7N3,0.0,-3.778687,1.523433e-10,1,-0.593306,7.282568e-10,32.424000,24.990000


# Merge test price

In [16]:
merged_candidates_feature_test_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates/merged_candidates_test_no_hist_feature.parquet'
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task1.csv'
product_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/processed_products_train.csv'

In [17]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature_test():
    return pd.read_parquet(merged_candidates_feature_test_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_test_sessions():
    return pd.read_csv(test_sessions_path)

@lru_cache(maxsize=1)
def read_product_feature():
    return pd.read_csv(product_path)

In [18]:
test_sessions = read_test_sessions()
product_feature = read_product_feature()
merged_candidates_feature_test = read_merged_candidates_feature_test()
product_feature['price'] = np.exp(product_feature['price']) - 1

In [19]:
product2price = {}
for i in tqdm(range(product_feature.shape[0])):
    product = product_feature.iloc[i]
    k = product['id']+'_'+product['locale']
    v = product['price']
    product2price[k] = v

100%|██████████| 1551057/1551057 [01:49<00:00, 14131.04it/s]


In [20]:
avg_price_list = get_avg_price(test_sessions, product2price)
assert len(avg_price_list) == len(test_sessions)

100%|██████████| 316971/316971 [00:16<00:00, 19340.64it/s]


In [21]:
avg_price_list = np.array(avg_price_list)
sess_index = merged_candidates_feature_test['sess_id'].to_numpy()
avg_price_list = avg_price_list[sess_index]

In [22]:
merged_candidates_feature_test['sess_avg_price'] = avg_price_list

In [23]:
merged_candidates_feature_test_g = cudf.from_pandas(merged_candidates_feature_test)
product_price_feature_g = cudf.from_pandas(product_feature[['id', 'locale', 'price']])

In [24]:
merged_candidates_feature_test_price_g = merged_candidates_feature_test_g.merge(product_price_feature_g, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'id'])
merged_candidates_feature_test_price_g = merged_candidates_feature_test_price_g.sort_values(by=['sess_id', 'product'])
merged_candidates_feature_test_price_g.reset_index(drop=True, inplace=True)
merged_candidates_feature_test_price_g.rename(columns={'price' : 'product_price'}, inplace=True)
merged_candidates_feature_test_price_g.drop(columns=['id', 'locale'], inplace=True)

In [29]:
merged_candidates_feature_test_price = merged_candidates_feature_test_price_g.to_pandas()

In [31]:
# fill nan price 
DE_price_NA = (merged_candidates_feature_test_price['sess_locale'] == 'DE') & merged_candidates_feature_test_price['product_price'].isna()
JP_price_NA = (merged_candidates_feature_test_price['sess_locale'] == 'JP') & merged_candidates_feature_test_price['product_price'].isna()
UK_price_NA = (merged_candidates_feature_test_price['sess_locale'] == 'UK') & merged_candidates_feature_test_price['product_price'].isna()

In [32]:
merged_candidates_feature_test_price.loc[merged_candidates_feature_test_price.index[DE_price_NA], 'product_price'] = 36.7616060903638
merged_candidates_feature_test_price.loc[merged_candidates_feature_test_price.index[JP_price_NA], 'product_price'] = 4201.2729840839065
merged_candidates_feature_test_price.loc[merged_candidates_feature_test_price.index[UK_price_NA], 'product_price'] = 22.097065056579634

In [34]:
cast_dtype(merged_candidates_feature_test_price)
merged_candidates_feature_test_price.to_parquet(merged_candidates_feature_test_path, engine='pyarrow')

In [35]:
merged_candidates_feature_test_price[merged_candidates_feature_test_price['sess_id'] == 3]

Unnamed: 0,sess_id,sess_locale,product,sasrec_scores_2,sasrec_normalized_scores_2,gru4rec_scores,gru4rec_normalized_scores,product_freq,sess_avg_price,product_price
622,3,DE,1687728321,5.240613,3.914508e-07,5.415880,1.595564e-07,1,13.7225,14.80
623,3,DE,1720106894,1.750717,1.194084e-08,2.297490,7.056911e-09,3,13.7225,14.80
624,3,DE,3473327328,8.976005,1.640353e-05,8.315353,2.898275e-06,23,13.7225,14.99
625,3,DE,3735850669,13.315015,1.257030e-03,13.339020,4.404434e-04,5,13.7225,16.00
626,3,DE,3741524131,17.372890,7.272080e-02,18.678909,9.182791e-02,26,13.7225,8.00
...,...,...,...,...,...,...,...,...,...,...
814,3,DE,B0BHL3M32W,3.161731,4.895877e-08,0.510518,1.181796e-09,58,13.7225,14.95
815,3,DE,B0BHN57LC9,6.468208,1.336027e-06,3.847412,3.324582e-08,5,13.7225,9.95
816,3,DE,B0BJBYY1TX,13.677494,1.806210e-03,13.774452,6.807632e-04,3,13.7225,6.99
817,3,DE,B0BJNDCGQX,14.244232,3.183461e-03,13.746093,6.617283e-04,12,13.7225,7.99


In [None]:
del merged_candidates_feature_test_g
del product_price_feature_g
del merged_candidates_feature_test_price_g

In [33]:
merged_candidates_feature_test_price

Unnamed: 0,sess_id,sess_locale,product,sasrec_scores_2,sasrec_normalized_scores_2,gru4rec_scores,gru4rec_normalized_scores,product_freq,sess_avg_price,product_price
0,0,DE,4088833651,0.000000,2.975813e-09,0.000000,1.580065e-09,828,25.195268,36.761606
1,0,DE,B000H6W2GW,0.000000,2.975813e-09,0.000000,1.580065e-09,875,25.195268,36.761606
2,0,DE,B000JG2RAG,7.665308,6.347557e-06,8.104032,5.226502e-06,24,25.195268,23.190000
3,0,DE,B000RYSOUW,-2.951060,1.555882e-10,-2.857798,9.068785e-11,5,25.195268,6.900000
4,0,DE,B000UGZVQM,3.977920,1.589257e-07,4.688567,1.717488e-07,4,25.195268,21.990000
...,...,...,...,...,...,...,...,...,...,...
69428426,316970,UK,B0BJCTH4NH,11.327528,1.041200e-04,10.629994,3.818184e-04,74,16.950000,5.800000
69428427,316970,UK,B0BJTQQWLG,5.604142,3.403292e-07,6.052083,3.923694e-06,6,16.950000,9.880000
69428428,316970,UK,B0BJV3RL4H,9.146974,1.176336e-05,7.667603,1.973815e-05,7,16.950000,22.097065
69428429,316970,UK,B0BK7SPC84,-10.383047,3.879279e-14,-6.356799,1.601719e-11,0,16.950000,5.960000


In [25]:
merged_candidates_feature_test

Unnamed: 0,sess_id,sess_locale,product,sasrec_scores_2,sasrec_normalized_scores_2,gru4rec_scores,gru4rec_normalized_scores,product_freq,sess_avg_price
0,0,DE,4088833651,0.000000,2.975813e-09,0.000000,1.580065e-09,828,25.195268
1,0,DE,B000H6W2GW,0.000000,2.975813e-09,0.000000,1.580065e-09,875,25.195268
2,0,DE,B000JG2RAG,7.665308,6.347557e-06,8.104032,5.226502e-06,24,25.195268
3,0,DE,B000RYSOUW,-2.951060,1.555882e-10,-2.857798,9.068785e-11,5,25.195268
4,0,DE,B000UGZVQM,3.977920,1.589257e-07,4.688567,1.717488e-07,4,25.195268
...,...,...,...,...,...,...,...,...,...
69428426,316970,UK,B0BJCTH4NH,11.327528,1.041200e-04,10.629994,3.818184e-04,74,16.950000
69428427,316970,UK,B0BJTQQWLG,5.604142,3.403292e-07,6.052083,3.923694e-06,6,16.950000
69428428,316970,UK,B0BJV3RL4H,9.146974,1.176336e-05,7.667603,1.973815e-05,7,16.950000
69428429,316970,UK,B0BK7SPC84,-10.383047,3.879279e-14,-6.356799,1.601719e-11,0,16.950000
