In [1]:
import os
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache
from tqdm import tqdm, trange
from collections import Counter, defaultdict

In [2]:
def cast_dtype(df : pd.DataFrame):
    for k in df.columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def cal_item_freq(item_counter:Counter, session_df:pd.DataFrame, test=False):
    for i in tqdm(range(session_df.shape[0])):
        sess = session_df.iloc[i]
        prev_items = eval(sess['prev_items'].replace(' ', ','))
        for item in prev_items:
            item_counter[item] += 1
        if not test:
            next_item = sess['next_item'] 
            item_counter[next_item] += 1 

# Merge valid item frequency

In [4]:
merged_candidates_feature_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates_phase2/merged_candidates_150_feature.parquet'
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions_phase2.csv'
train_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_train_sessions_phase2.csv'
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task1_phase2.csv'

In [5]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature():
    return pd.read_parquet(merged_candidates_feature_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_train_sessions():
    return pd.read_csv(train_sessions_path)

In [6]:
valid_sessions_df = read_valid_sessions()
train_sessions_df = read_train_sessions()

In [7]:
merged_candidates_feature = read_merged_candidates_feature()

In [8]:
item_counter = Counter()
cal_item_freq(item_counter, train_sessions_df, test=False)
# cal_item_freq(item_counter, valid_sessions_df, test=True)

100%|██████████| 3966659/3966659 [05:02<00:00, 13129.68it/s]


In [9]:
item_counter.most_common(10)

[('B0BD5MFPMF', 3161),
 ('B07QPV9Z7X', 2919),
 ('B01MXLEVR7', 2818),
 ('B08CN3G4N9', 2661),
 ('B0BDML9477', 2659),
 ('B08GWS298V', 2655),
 ('B07N8QY3YH', 2472),
 ('B0BD88WWQ8', 2412),
 ('B09VKWKZ16', 2407),
 ('B07RHT52HX', 2326)]

In [10]:
item_counter.most_common(10)

[('B0BD5MFPMF', 3161),
 ('B07QPV9Z7X', 2919),
 ('B01MXLEVR7', 2818),
 ('B08CN3G4N9', 2661),
 ('B0BDML9477', 2659),
 ('B08GWS298V', 2655),
 ('B07N8QY3YH', 2472),
 ('B0BD88WWQ8', 2412),
 ('B09VKWKZ16', 2407),
 ('B07RHT52HX', 2326)]

In [11]:
products, counts = zip(*item_counter.items())
item_freq_df = pd.DataFrame({'product' : products, 'product_freq' : counts})

In [12]:
# item_freq_df_g = cudf.from_pandas(item_freq_df)
# merged_candidates_feature_g = cudf.from_pandas(merged_candidates_feature)

In [13]:
# merged_candidates_freq_g = merged_candidates_feature_g.merge(item_freq_df_g, how='left', left_on=['product'], right_on=['product'])
# merged_candidates_freq_g = merged_candidates_freq_g.sort_values(by=['sess_id', 'product']).reset_index(drop=True)
# merged_candidates_freq_g['product_freq'] = merged_candidates_freq_g['product_freq'].fillna(0)
# cast_dtype(merged_candidates_freq_g)

In [14]:
merged_candidates_freq = merged_candidates_feature.merge(item_freq_df, how='left', left_on=['product'], right_on=['product'])
merged_candidates_freq = merged_candidates_freq.sort_values(by=['sess_id', 'product']).reset_index(drop=True)
merged_candidates_freq['product_freq'] = merged_candidates_freq['product_freq'].fillna(0)

In [15]:
merged_candidates_freq['product_freq']

0            63.0
1           128.0
2             7.0
3           197.0
4           639.0
            ...  
78842194      7.0
78842195     63.0
78842196      7.0
78842197     61.0
78842198     31.0
Name: product_freq, Length: 78842199, dtype: float64

In [16]:
merged_candidates_freq.columns

Index(['sess_id', 'sess_locale', 'product', 'target', 'sess_avg_price',
       'product_price', 'sasrec_scores_3', 'normalized_sasrec_scores_3',
       'sasrec_scores_2', 'normalized_sasrec_scores_2', 'seqmlp_scores',
       'normalized_seqmlp_scores', 'narm_scores', 'normalized_narm_scores',
       'gru4rec_scores_2', 'normalized_gru4rec_scores_2', 'title_BM25_scores',
       'desc_BM25_scores', 'normalized_all_items_co_graph_count_0',
       'all_items_co_graph_count_0', 'normalized_all_items_co_graph_count_1',
       'all_items_co_graph_count_1', 'normalized_all_items_co_graph_count_2',
       'all_items_co_graph_count_2', 'co_graph_counts_0',
       'normalized_co_graph_counts_0', 'co_graph_counts_1',
       'normalized_co_graph_counts_1', 'co_graph_counts_2',
       'normalized_co_graph_counts_2', 'cos_text_bert_scores',
       'text_bert_scores', 'normalized_text_bert_scores', 'roberta_scores',
       'normalized_roberta_scores', 'product_freq'],
      dtype='object')

In [17]:
# merged_candidates_freq = merged_candidates_freq_g.to_pandas()
cast_dtype(merged_candidates_freq)
merged_candidates_freq.to_parquet(merged_candidates_feature_path, engine='pyarrow')

In [18]:
merged_candidates_freq

Unnamed: 0,sess_id,sess_locale,product,target,sess_avg_price,product_price,sasrec_scores_3,normalized_sasrec_scores_3,sasrec_scores_2,normalized_sasrec_scores_2,...,co_graph_counts_1,normalized_co_graph_counts_1,co_graph_counts_2,normalized_co_graph_counts_2,cos_text_bert_scores,text_bert_scores,normalized_text_bert_scores,roberta_scores,normalized_roberta_scores,product_freq
0,0,DE,355165591X,0.0,43.256542,8.990000,2.230508,7.658405e-09,0.512931,1.377575e-09,...,0.000000,0.000000,0.0,0.000000,0.903757,378.286041,1.296655e-08,276.525787,7.975509e-07,63.0
1,0,DE,3833237058,0.0,43.256542,22.000000,9.605231,1.221631e-05,9.325538,9.255110e-06,...,0.090909,0.000610,0.0,0.000000,0.921604,387.624756,1.474268e-04,284.460052,2.226209e-03,128.0
2,0,DE,B00CIXSI6U,0.0,43.256542,6.470000,0.714114,1.681035e-09,-0.115904,7.345399e-10,...,0.000000,0.000000,0.0,0.000000,0.901061,374.802551,3.980740e-10,278.039612,3.624132e-06,7.0
3,0,DE,B00NVDOWUW,0.0,43.256542,11.990000,8.750996,5.199363e-06,8.507557,4.084482e-06,...,0.000000,0.000000,0.0,0.000000,0.927298,385.701782,2.154962e-05,285.239197,4.852260e-03,197.0
4,0,DE,B00NVDP3ZU,0.0,43.256542,22.990000,8.056712,2.596729e-06,5.898870,3.007453e-07,...,0.000000,0.000000,0.0,0.000000,0.930655,385.398499,1.591202e-05,284.763611,3.015780e-03,639.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78842194,261815,UK,B0BCX524Y6,0.0,9.383333,16.990000,6.813615,1.076201e-03,7.203015,4.597607e-04,...,0.000000,0.000000,0.0,0.000000,0.972680,438.956238,1.636532e-04,281.902344,1.476179e-04,7.0
78842195,261815,UK,B0BCX6QB4L,0.0,9.383333,10.990000,9.030836,9.881445e-03,10.123234,8.526421e-03,...,4.750000,0.013025,2.0,0.009091,0.972680,438.956238,1.636532e-04,281.902344,1.476179e-04,63.0
78842196,261815,UK,B0BFPJYXQL,0.0,9.383333,10.560000,0.796892,2.623396e-06,1.711608,1.895152e-06,...,0.000000,0.000000,0.0,0.000000,0.953467,430.164368,2.486932e-08,283.306732,6.012531e-04,7.0
78842197,261815,UK,B0BH3X67S3,0.0,9.383333,6.830000,4.250781,8.296004e-05,6.447586,2.159998e-04,...,0.000000,0.000000,0.0,0.000000,0.961829,434.029083,1.186011e-06,273.954742,5.218429e-08,61.0


In [15]:
# del item_freq_df_g
# del merged_candidates_feature_g
# del merged_candidates_freq_g

In [19]:
merged_candidates_feature

Unnamed: 0,sess_id,sess_locale,product,target,sess_avg_price,product_price,sasrec_scores_3,normalized_sasrec_scores_3,sasrec_scores_2,normalized_sasrec_scores_2,...,normalized_co_graph_counts_0,co_graph_counts_1,normalized_co_graph_counts_1,co_graph_counts_2,normalized_co_graph_counts_2,cos_text_bert_scores,text_bert_scores,normalized_text_bert_scores,roberta_scores,normalized_roberta_scores
0,0,DE,355165591X,0.0,43.256542,8.990000,2.230508,7.658405e-09,0.512931,1.377575e-09,...,0.000000,0.000000,0.000000,0.0,0.000000,0.903757,378.286041,1.296655e-08,276.525787,7.975509e-07
1,0,DE,3833237058,0.0,43.256542,22.000000,9.605231,1.221631e-05,9.325538,9.255110e-06,...,0.003110,0.090909,0.000610,0.0,0.000000,0.921604,387.624756,1.474268e-04,284.460052,2.226209e-03
2,0,DE,B00CIXSI6U,0.0,43.256542,6.470000,0.714114,1.681035e-09,-0.115904,7.345399e-10,...,0.000000,0.000000,0.000000,0.0,0.000000,0.901061,374.802551,3.980740e-10,278.039612,3.624132e-06
3,0,DE,B00NVDOWUW,0.0,43.256542,11.990000,8.750996,5.199363e-06,8.507557,4.084482e-06,...,0.000000,0.000000,0.000000,0.0,0.000000,0.927298,385.701782,2.154962e-05,285.239197,4.852260e-03
4,0,DE,B00NVDP3ZU,0.0,43.256542,22.990000,8.056712,2.596729e-06,5.898870,3.007453e-07,...,0.000000,0.000000,0.000000,0.0,0.000000,0.930655,385.398499,1.591202e-05,284.763611,3.015780e-03
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78842194,261815,UK,B0BCX524Y6,0.0,9.383333,16.990000,6.813615,1.076201e-03,7.203015,4.597607e-04,...,0.000000,0.000000,0.000000,0.0,0.000000,0.972680,438.956238,1.636532e-04,281.902344,1.476179e-04
78842195,261815,UK,B0BCX6QB4L,0.0,9.383333,10.990000,9.030836,9.881445e-03,10.123234,8.526421e-03,...,0.016563,4.750000,0.013025,2.0,0.009091,0.972680,438.956238,1.636532e-04,281.902344,1.476179e-04
78842196,261815,UK,B0BFPJYXQL,0.0,9.383333,10.560000,0.796892,2.623396e-06,1.711608,1.895152e-06,...,0.000000,0.000000,0.000000,0.0,0.000000,0.953467,430.164368,2.486932e-08,283.306732,6.012531e-04
78842197,261815,UK,B0BH3X67S3,0.0,9.383333,6.830000,4.250781,8.296004e-05,6.447586,2.159998e-04,...,0.000000,0.000000,0.000000,0.0,0.000000,0.961829,434.029083,1.186011e-06,273.954742,5.218429e-08


In [17]:
merged_candidates_freq

Unnamed: 0,sess_id,sess_locale,product,target,sess_avg_price,product_price,product_freq
0,0,DE,355165591X,0.0,43.256542,8.990000,51.0
1,0,DE,3833237058,0.0,43.256542,22.000000,84.0
2,0,DE,B00CIXSI6U,0.0,43.256542,6.470000,7.0
3,0,DE,B00NVDOWUW,0.0,43.256542,11.990000,166.0
4,0,DE,B00NVDP3ZU,0.0,43.256542,22.990000,502.0
...,...,...,...,...,...,...,...
78842194,261815,UK,B0BCX524Y6,0.0,9.383333,16.990000,7.0
78842195,261815,UK,B0BCX6QB4L,0.0,9.383333,10.990000,53.0
78842196,261815,UK,B0BFPJYXQL,0.0,9.383333,10.560000,7.0
78842197,261815,UK,B0BH3X67S3,0.0,9.383333,6.830000,38.0
