In [1]:
import os
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache
from tqdm import tqdm, trange
from collections import Counter, defaultdict

In [2]:
def cast_dtype(df : pd.DataFrame, columns=None):
    if columns is None:
        columns = df.columns
    for k in columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def cal_next_item_freq(item_counter:Counter, session_df:pd.DataFrame, test=False):
    for sess in tqdm(session_df.itertuples(), total=session_df.shape[0]):
        next_item = sess.next_item
        item_counter[next_item] += 1 

# Merge valid next item frequency

In [4]:
merged_candidates_feature_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates_phase2/merged_candidates_150_feature.parquet'
raw_train_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data_split/task13_4_task1_raw_train_sessions_phase2.csv'
raw_valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data_split/task13_4_task1_raw_valid_sessions_phase2.csv'
train_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_train_sessions_phase2.csv'

In [5]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature():
    return pd.read_parquet(merged_candidates_feature_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_raw_train_sessions():
    return pd.read_csv(raw_train_sessions_path)


@lru_cache(maxsize=1)
def read_train_sessions():
    return pd.read_csv(train_sessions_path)

In [6]:
merged_train_sessions_df = read_train_sessions()

In [7]:
train_sessions_df = read_raw_train_sessions()

In [8]:
train_sessions_df

Unnamed: 0,prev_items,next_item,locale
0,['3949568239' 'B09CLBRV16' 'B0B7237CF5' 'B09FJ...,B0B1MWMKVZ,DE
1,['B09QZWPX6T' 'B09CLH9TWB'],B00175X9QE,DE
2,['B08LMNLBCS' 'B08LMNLBCS'],B0045XA8J6,DE
3,['B0B6HFFLHB' 'B0B6C8MZDV'],B07L1BH4Q9,DE
4,['B089W594LC' 'B09W2G8NLH'],B093GB3CCW,DE
...,...,...,...
3010895,['B07NBVMD8B' 'B09S192YFG' 'B0BFPLN6NL' 'B09R8...,B098T4S62X,UK
3010896,['B08DMVBX4C' 'B08DMV8XPK' 'B08L1X3RJ1' 'B09W8...,B01EG533I0,UK
3010897,['B00UVOE51A' 'B00UVOE51A' 'B00UVOE6TG'],B082Q3LXS2,UK
3010898,['B075D87KJL' 'B08DTWJ8BX' 'B00KDL108M' 'B01CX...,B07V59FPRV,UK


In [8]:
merged_candidates_feature = read_merged_candidates_feature()

In [9]:
merged_candidates = merged_candidates_feature[['sess_id', 'sess_locale', 'product']]

In [8]:
item_counter = Counter()
cal_next_item_freq(item_counter, merged_train_sessions_df, test=False)

100%|██████████| 3966659/3966659 [00:08<00:00, 446806.03it/s]


In [9]:
item_counter.most_common(10)

[('B09QFPZ9B7', 224),
 ('B07QPV9Z7X', 212),
 ('B00NTCH52W', 204),
 ('B00CWNMV4G', 203),
 ('B07N8QY3YH', 183),
 ('B014I8SSD0', 177),
 ('B09QFPYX34', 175),
 ('B019GNUT0C', 167),
 ('B099DP3617', 166),
 ('B09QFJNDQX', 164)]

In [25]:
len(item_counter.values())

920149

In [23]:
sum(item_counter.values())

3966659

In [9]:
# item_counter = Counter()
# cal_next_item_freq(item_counter, train_sessions_df, test=False)
# cal_item_freq(item_counter, valid_sessions_df, test=True)

100%|██████████| 3010900/3010900 [00:07<00:00, 388906.73it/s]


In [42]:
# item_counter.values()

dict_values([6, 7, 2, 19, 5, 27, 17, 7, 5, 1, 6, 9, 14, 7, 9, 2, 10, 16, 8, 3, 31, 26, 8, 1, 12, 2, 11, 3, 2, 4, 9, 8, 2, 6, 5, 10, 2, 6, 8, 1, 10, 2, 6, 7, 1, 11, 37, 15, 7, 18, 28, 3, 2, 18, 4, 2, 29, 3, 28, 1, 2, 6, 1, 27, 28, 4, 30, 3, 32, 4, 2, 3, 7, 17, 16, 3, 15, 6, 15, 3, 14, 2, 5, 20, 7, 7, 5, 2, 27, 18, 20, 23, 23, 10, 15, 5, 22, 2, 25, 4, 2, 2, 1, 7, 8, 26, 16, 1, 9, 27, 2, 9, 29, 5, 5, 28, 7, 63, 2, 8, 9, 2, 6, 2, 37, 1, 13, 5, 2, 16, 4, 39, 14, 26, 13, 2, 1, 21, 1, 3, 3, 4, 2, 2, 1, 13, 2, 6, 4, 3, 2, 1, 1, 54, 1, 4, 2, 9, 16, 53, 8, 24, 29, 3, 21, 25, 6, 4, 5, 17, 12, 2, 4, 4, 5, 10, 6, 8, 2, 1, 2, 2, 11, 6, 1, 14, 30, 7, 3, 9, 1, 12, 2, 3, 1, 8, 10, 6, 3, 12, 32, 1, 14, 9, 4, 1, 3, 1, 10, 6, 3, 1, 43, 6, 4, 5, 6, 26, 9, 25, 6, 28, 4, 2, 2, 2, 30, 13, 3, 20, 15, 4, 5, 8, 31, 2, 11, 4, 44, 7, 5, 36, 4, 4, 25, 2, 1, 3, 3, 4, 3, 3, 1, 1, 1, 76, 5, 2, 11, 4, 10, 22, 1, 6, 3, 2, 9, 2, 3, 2, 5, 43, 5, 28, 9, 32, 11, 1, 26, 15, 6, 27, 4, 29, 28, 2, 1, 28, 6, 16, 5, 5, 21, 26, 2,

In [24]:
# sum(item_counter.values())

3010900

In [43]:
# item_counter.most_common(10)

[('B014I8T0YQ', 88),
 ('B07QS4NMW6', 87),
 ('B01N75EALQ', 85),
 ('B09QFPZ9B7', 85),
 ('B003JKFEL8', 84),
 ('B00CWNMV4G', 84),
 ('B09YCMWPF5', 83),
 ('B07KXQX3S3', 83),
 ('B00MNV8E0C', 83),
 ('B014I8SSD0', 82)]

In [22]:
products, counts = zip(*item_counter.items())
next_freq_df = pd.DataFrame({'product' : products, 'next_freq' : counts})

In [12]:
# item_freq_df_g = cudf.from_pandas(item_freq_df)
# merged_candidates_feature_g = cudf.from_pandas(merged_candidates_feature)

In [13]:
# merged_candidates_freq_g = merged_candidates_feature_g.merge(item_freq_df_g, how='left', left_on=['product'], right_on=['product'])
# merged_candidates_freq_g = merged_candidates_freq_g.sort_values(by=['sess_id', 'product']).reset_index(drop=True)
# merged_candidates_freq_g['product_freq'] = merged_candidates_freq_g['product_freq'].fillna(0)
# cast_dtype(merged_candidates_freq_g)

In [23]:
merged_candidates_next_freq = merged_candidates.merge(next_freq_df, how='left', left_on=['product'], right_on=['product'])
merged_candidates_next_freq = merged_candidates_next_freq.sort_values(by=['sess_id', 'product']).reset_index(drop=True)
merged_candidates_next_freq['next_freq'] = merged_candidates_next_freq['next_freq'].fillna(0)

In [24]:
merged_candidates_next_freq['next_freq']

0            7.0
1           18.0
2            1.0
3           40.0
4           50.0
            ... 
78842194     5.0
78842195    18.0
78842196     0.0
78842197     7.0
78842198    10.0
Name: next_freq, Length: 78842199, dtype: float64

In [15]:
merged_candidates_next_freq['next_freq_']

0            9.0
1           26.0
2            1.0
3           43.0
4           79.0
            ... 
78842194     5.0
78842195    21.0
78842196     0.0
78842197    12.0
78842198    11.0
Name: next_freq_, Length: 78842199, dtype: float64

In [25]:
merged_candidates_next_freq.columns

Index(['sess_id', 'sess_locale', 'product', 'next_freq'], dtype='object')

In [26]:
merged_candidates_feature['next_freq'] = merged_candidates_next_freq['next_freq']

In [27]:
# merged_candidates_freq = merged_candidates_freq_g.to_pandas()
cast_dtype(merged_candidates_feature, ['next_freq'])
merged_candidates_feature.to_parquet(merged_candidates_feature_path, engine='pyarrow')

In [None]:
merged_candidates_next_freq

Unnamed: 0,sess_id,sess_locale,product,target,sess_avg_price,product_price,sasrec_scores_3,normalized_sasrec_scores_3,sasrec_scores_2,normalized_sasrec_scores_2,...,normalized_gru4rec_feat_scores_2,sasrec_duorec_score,normalized_sasrec_duorec_score,w2v_l1_score,w2v_l2_score,w2v_l3_score,normalized_w2v_l1_score,normalized_w2v_l2_score,normalized_w2v_l3_score,next_freq
0,0,DE,355165591X,0.0,43.256542,8.990000,2.230508,7.658405e-09,0.512931,1.377575e-09,...,2.858669e-12,6.290242,9.662357e-07,25.177464,23.846624,22.635153,7.111217e-16,1.469563e-07,8.380214e-08,9.0
1,0,DE,3833237058,0.0,43.256542,22.000000,9.605231,1.221631e-05,9.325538,9.255110e-06,...,4.320219e-09,11.603280,1.961129e-04,32.207531,24.611195,25.308212,8.036434e-13,3.156726e-07,1.213808e-06,18.0
2,0,DE,B00CIXSI6U,0.0,43.256542,6.470000,0.714114,1.681035e-09,-0.115904,7.345399e-10,...,2.063835e-11,4.110237,1.092242e-07,19.747381,20.370945,19.463253,3.116658e-18,4.546946e-09,3.513310e-09,1.0
3,0,DE,B00NVDOWUW,0.0,43.256542,11.990000,8.750996,5.199363e-06,8.507557,4.084482e-06,...,1.319138e-10,11.578343,1.912829e-04,25.640152,22.638163,20.257607,1.129502e-15,4.388943e-08,7.774991e-09,47.0
4,0,DE,B00NVDP3ZU,0.0,43.256542,22.990000,8.056712,2.596729e-06,5.898870,3.007453e-07,...,6.495671e-08,10.304344,5.350390e-05,33.229935,25.067163,20.753508,2.234024e-12,4.980370e-07,1.276636e-08,69.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78842194,261815,UK,B0BCX524Y6,0.0,9.383333,16.990000,6.813615,1.076201e-03,7.203015,4.597607e-04,...,3.847656e-04,8.870004,9.947462e-04,25.790541,25.790541,25.963753,1.723611e-07,1.723611e-07,3.863373e-07,5.0
78842195,261815,UK,B0BCX6QB4L,0.0,9.383333,10.990000,9.030836,9.881445e-03,10.123234,8.526421e-03,...,8.951382e-03,11.360728,1.200661e-02,37.195541,37.195541,37.079201,1.547277e-02,1.547277e-02,2.596238e-02,19.0
78842196,261815,UK,B0BFPJYXQL,0.0,9.383333,10.560000,0.796892,2.623396e-06,1.711608,1.895152e-06,...,3.386215e-05,6.270709,7.393550e-05,18.659113,18.659113,17.611990,1.378160e-10,1.378160e-10,9.116796e-11,0.0
78842197,261815,UK,B0BH3X67S3,0.0,9.383333,6.830000,4.250781,8.296004e-05,6.447586,2.159998e-04,...,3.735876e-04,7.936618,3.911544e-04,28.308519,28.308519,26.345455,2.137881e-06,2.137881e-06,5.658977e-07,8.0


In [19]:
# del item_freq_df_g
# del merged_candidates_feature_g
# del merged_candidates_freq_g

In [28]:
merged_candidates_feature['next_freq']

0            7.0
1           18.0
2            1.0
3           40.0
4           50.0
            ... 
78842194     5.0
78842195    18.0
78842196     0.0
78842197     7.0
78842198    10.0
Name: next_freq, Length: 78842199, dtype: float32

In [23]:
merged_candidates_feature['next_freq_']

0            9.0
1           26.0
2            1.0
3           43.0
4           79.0
            ... 
78842194     5.0
78842195    21.0
78842196     0.0
78842197    12.0
78842198    11.0
Name: next_freq_, Length: 78842199, dtype: float32

In [None]:
merged_candidates_next_freq

Unnamed: 0,sess_id,sess_locale,product,target,sess_avg_price,product_price,sasrec_scores_3,normalized_sasrec_scores_3,sasrec_scores_2,normalized_sasrec_scores_2,...,normalized_gru4rec_feat_scores_2,sasrec_duorec_score,normalized_sasrec_duorec_score,w2v_l1_score,w2v_l2_score,w2v_l3_score,normalized_w2v_l1_score,normalized_w2v_l2_score,normalized_w2v_l3_score,next_freq
0,0,DE,355165591X,0.0,43.256542,8.990000,2.230508,7.658405e-09,0.512931,1.377575e-09,...,2.858669e-12,6.290242,9.662357e-07,25.177464,23.846624,22.635153,7.111217e-16,1.469563e-07,8.380214e-08,9.0
1,0,DE,3833237058,0.0,43.256542,22.000000,9.605231,1.221631e-05,9.325538,9.255110e-06,...,4.320219e-09,11.603280,1.961129e-04,32.207531,24.611195,25.308212,8.036434e-13,3.156726e-07,1.213808e-06,18.0
2,0,DE,B00CIXSI6U,0.0,43.256542,6.470000,0.714114,1.681035e-09,-0.115904,7.345399e-10,...,2.063835e-11,4.110237,1.092242e-07,19.747381,20.370945,19.463253,3.116658e-18,4.546946e-09,3.513310e-09,1.0
3,0,DE,B00NVDOWUW,0.0,43.256542,11.990000,8.750996,5.199363e-06,8.507557,4.084482e-06,...,1.319138e-10,11.578343,1.912829e-04,25.640152,22.638163,20.257607,1.129502e-15,4.388943e-08,7.774991e-09,47.0
4,0,DE,B00NVDP3ZU,0.0,43.256542,22.990000,8.056712,2.596729e-06,5.898870,3.007453e-07,...,6.495671e-08,10.304344,5.350390e-05,33.229935,25.067163,20.753508,2.234024e-12,4.980370e-07,1.276636e-08,69.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78842194,261815,UK,B0BCX524Y6,0.0,9.383333,16.990000,6.813615,1.076201e-03,7.203015,4.597607e-04,...,3.847656e-04,8.870004,9.947462e-04,25.790541,25.790541,25.963753,1.723611e-07,1.723611e-07,3.863373e-07,5.0
78842195,261815,UK,B0BCX6QB4L,0.0,9.383333,10.990000,9.030836,9.881445e-03,10.123234,8.526421e-03,...,8.951382e-03,11.360728,1.200661e-02,37.195541,37.195541,37.079201,1.547277e-02,1.547277e-02,2.596238e-02,19.0
78842196,261815,UK,B0BFPJYXQL,0.0,9.383333,10.560000,0.796892,2.623396e-06,1.711608,1.895152e-06,...,3.386215e-05,6.270709,7.393550e-05,18.659113,18.659113,17.611990,1.378160e-10,1.378160e-10,9.116796e-11,0.0
78842197,261815,UK,B0BH3X67S3,0.0,9.383333,6.830000,4.250781,8.296004e-05,6.447586,2.159998e-04,...,3.735876e-04,7.936618,3.911544e-04,28.308519,28.308519,26.345455,2.137881e-06,2.137881e-06,5.658977e-07,8.0


In [None]:
merged_candidates_next_freq['next_freq']

0            9.0
1           18.0
2            1.0
3           47.0
4           69.0
            ... 
78842194     5.0
78842195    19.0
78842196     0.0
78842197     8.0
78842198    11.0
Name: next_freq, Length: 78842199, dtype: float32