In [1]:
import os
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache
from tqdm import tqdm, trange
from collections import Counter, defaultdict



In [2]:
def cast_dtype(df : pd.DataFrame):
    for k in df.columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def merge_candidates(session_df, candidates_df_list, test=False):
    sess_id_list = []
    sess_locale_list = []
    product_list = []
    target_list = []
    for i in tqdm(range(session_df.shape[0])):
        sess_id = i
        sess_locale = session_df.iloc[i]['locale']
        if not test:
            sess_next_item = session_df.iloc[i]['next_item']

        candidates_set_list = [set(x.iloc[i]['candidates']) for x in candidates_df_list]
        candidates_set = candidates_set_list[0]
        for x in candidates_set_list[1:]:
            candidates_set = candidates_set.union(x)
        cur_product_list = list(candidates_set)
        cur_sess_id_list = [sess_id for _ in range(len(cur_product_list))]
        cur_sess_locale_list = [sess_locale for _ in range(len(cur_product_list))]
        if not test:
            cur_target_list = (np.array(cur_product_list) == sess_next_item).astype(np.float32).tolist()

        for j in range(len(cur_product_list)):
            sess_id_list.append(cur_sess_id_list[j])
            sess_locale_list.append(cur_sess_locale_list[j])
            product_list.append(cur_product_list[j])
            if not test:
                target_list.append(cur_target_list[j])

    df_dict = {'sess_id' : sess_id_list, 'sess_locale' : sess_locale_list, 'product' : product_list}
    if not test: 
        df_dict['target'] = target_list
    return pd.DataFrame(df_dict)

# Merge Validation Candidates

In [11]:
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions.csv'
sasrec_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/seperate_locale/SASRec_Next_04_26_15_26_valid_100_with_score.parquet'
roberta_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_valid_100_with_score.parquet'
co_graph_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/co_graph/co_graph_valid_100_with_normalized_score_2.parquet'

In [9]:
@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_valid_candidates():
    return pd.read_parquet(sasrec_valid_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_valid_candidates():
    return pd.read_parquet(roberta_valid_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_co_graph_valid_candidates():
    return pd.read_parquet(co_graph_valid_candidates_path, engine='pyarrow')

In [12]:
valid_sessions = read_valid_sessions()
sasrec_valid_candidates = read_sasrec_valid_candidates()
roberta_valid_candidates = read_roberta_valid_candidates()
co_graph_valid_candidates = read_co_graph_valid_candidates()

In [13]:
roberta_valid_candidates['sess_id'] = np.arange(roberta_valid_candidates.shape[0])
roberta_valid_candidates

Unnamed: 0,candidates,scores,sess_id
0,"[B096ZT4DK4, B09XCGNN6F, B09XCHDLYR, B06XGD9VL...","[268.6368103027344, 268.5672607421875, 268.567...",0
1,"[B08ZHJKF28, B09LCPT9DQ, B09MRYK5CV, B09G9YRX1...","[269.99102783203125, 269.4064636230469, 269.30...",1
2,"[B06ZYHMRST, B01017JLP6, B007CW985C, B0017RTPK...","[264.7541198730469, 264.58624267578125, 264.49...",2
3,"[B01M5EBAO4, B07TYFS1K7, B08BX71RLT, B0BGY899Q...","[267.71771240234375, 266.8636779785156, 266.69...",3
4,"[B0B6NXJMJJ, B0B6PF619D, B0B6P77ZRN, B0B6NRKVZ...","[269.4986572265625, 269.49267578125, 269.47979...",4
...,...,...,...
361576,"[B08F5D8T22, B079FV57RR, B085WCD1CF, B09C3Q4V1...","[269.1492919921875, 269.11639404296875, 266.56...",361576
361577,"[B0B96YTD1B, B0B96YKVJ6, B09XGPX16G, B0B973WT3...","[269.2782897949219, 269.2471923828125, 269.101...",361577
361578,"[B09SXRYYMM, B0B6D8KGPC, B0B6D9V5CG, B0BC38GHB...","[266.98553466796875, 266.8766784667969, 266.80...",361578
361579,"[B08RQDVX71, B08RQBT2D8, B08RQR2NPB, B09D3VG5N...","[268.53466796875, 267.6655578613281, 267.64810...",361579


## merge candidates 

In [17]:
valid_sessions

Unnamed: 0,prev_items,next_item,locale
0,['B09VSN9GLS' 'B09VSG9DCG' 'B0BJ5L1ZPH' 'B09VS...,B06XG1LZ6Z,UK
1,['B00390YWXE' 'B00390YWXE' 'B09WM9W6WQ'],B01MSUI4FE,JP
2,['B01BM9V6H8' 'B01MG55XDR' 'B07VYSSRL7'],B01M6625ME,UK
3,['B092ZG24S7' 'B09BNHWWZM' 'B08CB1WG5M' '17880...,0241558573,UK
4,['B0B6NY5RM8' 'B09BJGBBBR'],B09BJF6N8K,JP
...,...,...,...
361576,['B08HH6L4PB' 'B08L8N8HDR'],B00TS5UXGY,UK
361577,['B08X4L1KLZ' 'B09BBX1T4S' 'B09D76FT9D'],B09BCM5NL1,JP
361578,['B0098G6L3M' 'B00ELRLP3O' 'B00PLXGK82' 'B09GS...,B0BC38GHB4,DE
361579,['B07Q2CNLY3' 'B07Q2CNLY3' 'B07BR7DZWN' 'B07Q2...,B08H8SYLMQ,DE


In [18]:
merged_candidates_df = merge_candidates(valid_sessions, [sasrec_valid_candidates, roberta_valid_candidates, co_graph_valid_candidates])

100%|██████████| 361581/361581 [02:34<00:00, 2346.95it/s]


In [19]:
cast_dtype(merged_candidates_df)

In [None]:
merged_candidates_df.to_parquet('./candidates/merged_candidates_2.parquet', engine='pyarrow')

# Merge Test Candidates

In [26]:
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task1.csv'
sasrec_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/seperate_locale/SASRec_Next_04_27_20_07_test_100_with_score.parquet'
roberta_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_test_100_with_score.parquet'
co_graph_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/co_graph/co_graph_test_100_with_normalized_score.parquet'

In [27]:
@lru_cache(maxsize=1)
def read_test_sessions():
    return pd.read_csv(test_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_test_candidates():
    return pd.read_parquet(sasrec_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_test_candidates():
    return pd.read_parquet(roberta_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_co_graph_test_candidates():
    return pd.read_parquet(co_graph_test_candidates_path, engine='pyarrow')

In [36]:
test_sessions = read_test_sessions()
sasrec_test_candidates = read_sasrec_test_candidates()
roberta_test_candidates = read_roberta_test_candidates()
co_graph_test_candidates = read_co_graph_test_candidates()

In [39]:
merged_candidates_df = merge_candidates(test_sessions, [sasrec_test_candidates, roberta_test_candidates, co_graph_test_candidates], test=True)

100%|██████████| 316971/316971 [01:36<00:00, 3298.02it/s]


In [41]:
cast_dtype(merged_candidates_df)

In [42]:
merged_candidates_df.to_parquet('./candidates/merged_candidates_test.parquet', engine='pyarrow')