In [1]:
import os
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict
import datasets
from datasets import Dataset as TFDataset
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def cast_dtype(df : pd.DataFrame):
    for k in df.columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def merge_candidates(session_df, candidates_df_list, test=False):
    sess_id_list = []
    sess_locale_list = []
    product_list = []
    target_list = []
    for i in tqdm(range(session_df.shape[0])):
        sess_id = i
        sess_locale = session_df.iloc[i]['locale']
        if not test:
            sess_next_item = session_df.iloc[i]['next_item']

        candidates_set_list = [set(x.iloc[i]['candidates']) for x in candidates_df_list]
        candidates_set = candidates_set_list[0]
        for x in candidates_set_list[1:]:
            candidates_set = candidates_set.union(x)
        cur_product_list = list(candidates_set)
        cur_sess_id_list = [sess_id for _ in range(len(cur_product_list))]
        cur_sess_locale_list = [sess_locale for _ in range(len(cur_product_list))]
        if not test:
            cur_target_list = (np.array(cur_product_list) == sess_next_item).astype(np.float32).tolist()

        for j in range(len(cur_product_list)):
            sess_id_list.append(cur_sess_id_list[j])
            sess_locale_list.append(cur_sess_locale_list[j])
            product_list.append(cur_product_list[j])
            if not test:
                target_list.append(cur_target_list[j])

    df_dict = {'sess_id' : sess_id_list, 'sess_locale' : sess_locale_list, 'product' : product_list}
    if not test: 
        df_dict['target'] = target_list
    return pd.DataFrame(df_dict)

In [4]:
def get_sessions(df: pd.DataFrame, test=False, list_item=False) -> list:
    
    all_item = []
    if 'next_item' in df and not test:
        if list_item:
            for i in trange(len(df)):
                all_item.append(np.concatenate([np.array(df.loc[i, 'prev_items']), np.array(df.loc[i, 'next_item'])], axis=0))
        else:
            for i in trange(len(df)):
                all_item.append(eval((df.loc[i, 'prev_items'][:-1]+f" '{df.loc[i, 'next_item']}']").replace(" ", ",")))
    else:
        if list_item:
            all_item = df['prev_items']
        else:
            for i in trange(len(df)):
                all_item.append(eval((df.loc[i, 'prev_items']).replace(" ", ",")))
    return all_item

In [5]:
def get_co_occurence_dict(sessions: list, bidirection: bool=True, weighted: bool=False, max_dis=None) -> dict:
    res = {}
    for sess in tqdm(sessions):
        for i, id in enumerate(sess):
            if id not in res:
                res[id] = Counter()
            
            if max_dis == None:
                e = len(sess)
            else:
                e = min(i + max_dis + 1, len(sess))

            for j in range(i+1, e):
                if not weighted:
                    res[id][sess[j]] += 1
                else:
                    res[id][sess[j]] += 1 / (j-i)
                if bidirection:
                    if sess[j] not in res:
                        res[sess[j]] = Counter()
                    if not weighted:
                        res[sess[j]][id] += 1
                    else:
                        res[sess[j]][id] += 1 / (j-i)
    return res

In [6]:
def merge_candidates(session_df, candidates_df_list, test=False):
    sess_id_list = []
    sess_locale_list = []
    product_list = []
    target_list = []
    for i in tqdm(range(session_df.shape[0])):
        sess_id = i
        sess_locale = session_df.iloc[i]['locale']
        if not test:
            sess_next_item = session_df.iloc[i]['next_item']

        candidates_set_list = [set(x.iloc[i]['candidates']) for x in candidates_df_list]
        candidates_set = candidates_set_list[0]
        for x in candidates_set_list[1:]:
            candidates_set = candidates_set.union(x)
        cur_product_list = list(candidates_set)
        cur_sess_id_list = [sess_id for _ in range(len(cur_product_list))]
        cur_sess_locale_list = [sess_locale for _ in range(len(cur_product_list))]
        if not test:
            cur_target_list = (np.array(cur_product_list) == sess_next_item).astype(np.float32).tolist()

        for j in range(len(cur_product_list)):
            sess_id_list.append(cur_sess_id_list[j])
            sess_locale_list.append(cur_sess_locale_list[j])
            product_list.append(cur_product_list[j])
            if not test:
                target_list.append(cur_target_list[j])

    df_dict = {'sess_id' : sess_id_list, 'sess_locale' : sess_locale_list, 'product' : product_list}
    if not test: 
        df_dict['target'] = target_list
    return pd.DataFrame(df_dict)

# Merge Validation Candidates

In [7]:
train_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_train_sessions_phase2.csv'
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions_phase2.csv'
sasrec_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates_phase2/SASRec/valid_top300_sasrec0531.parquet'
roberta_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates_phase2/roberta/phase2_task1_valid_300_filtered.parquet'
product_data_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/products_train.csv'

In [8]:
@lru_cache(maxsize=1)
def read_train_sessions():
    return pd.read_csv(train_sessions_path)

@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_valid_candidates():
    return pd.read_parquet(sasrec_valid_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_valid_candidates():
    return pd.read_parquet(roberta_valid_candidates_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(product_data_path)

In [9]:
train_sessions = read_train_sessions()
valid_sessions = read_valid_sessions()
sasrec_valid_candidates = read_sasrec_valid_candidates()
roberta_valid_candidates = read_roberta_valid_candidates()
product_data = read_product_data()

In [10]:
sasrec_valid_candidates

Unnamed: 0,next_item_prediction,locale
0,"[B09BNV7PBF, B08G56GFCV, B09BNW36Z2, B09Y96L4C...",DE
1,"[B0BGJCMQ5H, B08L9CHW69, B08TLR9773, B0BHJKVMB...",DE
2,"[B0B77KWZ7B, B0B77KRD1Q, B0B77M9DVM, B0B77PGVM...",DE
3,"[B0B9FYSBC9, B093F6V982, B01G7PUCZE, B085L9FT7...",DE
4,"[B0B7XBPC69, B09FT8ZVR7, B07FSQ6BGV, B0B2W3P3L...",DE
...,...,...
261811,"[B0B256HNBW, B0B24SG4YR, B09213J1RM, B091ZZ6VR...",UK
261812,"[B09RN57T97, B09DLCDML9, B07RGVSD61, B078H22NB...",UK
261813,"[B07D37YF4W, B09TPS3DZ2, B07D395945, B082DGC1K...",UK
261814,"[B07C9LQJFF, B073Q4XMXS, B00L2ILKUI, B01M1HF9M...",UK


In [11]:
roberta_valid_candidates['sess_id'] = np.arange(roberta_valid_candidates.shape[0])
roberta_valid_candidates

Unnamed: 0,next_item_prediction,sess_id
0,"[B09BNV7PBF, B09BNXMFCT, B07W6Q951M, B09BNVXRZ...",0
1,"[B0B82B3NRZ, B0B25PXM6S, B00E383IRC, B0B829WLD...",1
2,"[B0B77NDDG1, B0B77KVDYG, B0B77PJJ8Z, B0B77PG2K...",2
3,"[B077PL86PP, B08NPHSB2F, B0B1DBR6T5, B01G7PQ7L...",3
4,"[B07SW3YYTQ, B0B3GGVMR7, B09G6MBTS7, B0B2W3P3L...",4
...,...,...
261811,"[B01DY36JF4, B0835DXMY6, B08CD8KPRG, B094VLF9Q...",261811
261812,"[B09G9GZTHD, B0BB1ZJMY2, B09G9J91CZ, B09WDDYCS...",261812
261813,"[B084LL85L6, B08Z9Y379G, B0B3JK6V8T, B09TPT35W...",261813
261814,"[B073Q4XMXS, B0076126L2, B07B5GXS45, B00AGG3MZ...",261814


## get co-graph candidates

In [None]:
train_sess_item = get_sessions(train_sessions, list_item=False)
valid_sess_item = get_sessions(valid_sessions, test=True, list_item=False)

In [None]:
# valid sessions are included in train sessions 
co_occurence_dict_bi = get_co_occurence_dict(train_sess_item, bidirection=True, weighted=False)

100%|██████████| 3010900/3010900 [01:48<00:00, 27653.33it/s]


In [None]:
valid_co_graph_candidates_dataset = TFDataset.from_dict({'sess_id' : list(range(len(valid_sessions)))})

In [None]:
def get_valid_session_co_graph_candidates(sess_id_example):
    sess = valid_sess_item[sess_id_example['sess_id']]
    prev_items = set()
    cand_counter = Counter()
    for item in sess:
        if item in co_occurence_dict_bi and item not in prev_items:
            cand_counter = cand_counter + co_occurence_dict_bi[item]
            prev_items.add(item) # one time for every item
    for item in sess:
        if item in cand_counter:
            cand_counter.pop(item) # remove history items 
    if len(cand_counter) > 0:
        candidates, _ = zip(*cand_counter.most_common(100))
    else:
        candidates = []
    return {'co_graph_candidates' : candidates}

In [None]:
# about 1 mins
datasets.set_progress_bar_enabled(False)
valid_co_graph_candidates_dataset = valid_co_graph_candidates_dataset.map(get_valid_session_co_graph_candidates, num_proc=8, batched=False)
datasets.set_progress_bar_enabled(True)



In [None]:
valid_co_graph_candidates = valid_co_graph_candidates_dataset['co_graph_candidates']

In [None]:
co_graph_valid_candidates = pd.DataFrame({'sess_id' : list(range(len(valid_sessions))), 'candidates' : valid_co_graph_candidates})
co_graph_valid_candidates

Unnamed: 0,sess_id,candidates
0,0,"[B09YQSGV9B, B09BNV7PBF, B09QFP8FNZ, B09BNSD7G..."
1,1,"[B07CZ4DLCP, B09C7BRP5Y, B08TLR9773, B09C6RTP2..."
2,2,"[B0B77KWZ7B, B0B77PGVM9, B0B77M9DVM, B0B77KRD1..."
3,3,"[B0B9FYSBC9, B09QLGKG6M, B09HN456ZN, B0B8VM7QK..."
4,4,"[B09ZXYVP4Q, B09WDC1S17, B09WDCD4FT, B09ZYC44S..."
...,...,...
261811,261811,"[B0B256HNBW, B094VJV3PS, B0B24SG4YR, B08CD8KPR..."
261812,261812,"[B09Y52VXK5, B09RN57T97, B0B2KQQWSG, B07RNLDLR..."
261813,261813,"[B07D395945, B07D38HKCS, B0B92VZXP4, B09TPR7B7..."
261814,261814,"[B084V9PP9B, B073Q4XMXS, B075DFHSW6, B07C9LQJF..."


In [None]:
def get_valid_session_co_graph_candidates_150(sess_id_example):
    sess = valid_sess_item[sess_id_example['sess_id']]
    prev_items = set()
    cand_counter = Counter()
    for item in sess:
        if item in co_occurence_dict_bi and item not in prev_items:
            cand_counter = cand_counter + co_occurence_dict_bi[item]
            prev_items.add(item) # one time for every item
    for item in sess:
        if item in cand_counter:
            cand_counter.pop(item) # remove history items 
    if len(cand_counter) > 0:
        candidates, _ = zip(*cand_counter.most_common(150))
    else:
        candidates = []
    return {'co_graph_candidates' : candidates}

In [None]:
# about 1 mins
datasets.set_progress_bar_enabled(False)
valid_co_graph_candidates_150_dataset = valid_co_graph_candidates_dataset.map(get_valid_session_co_graph_candidates_150, num_proc=8, batched=False)
datasets.set_progress_bar_enabled(True)

In [None]:
valid_co_graph_candidates_150 = valid_co_graph_candidates_150_dataset['co_graph_candidates']

In [None]:
co_graph_valid_candidates_150 = pd.DataFrame({'sess_id' : list(range(len(valid_sessions))), 'candidates' : valid_co_graph_candidates_150})

Unnamed: 0,sess_id,candidates
0,0,"[B09YQSGV9B, B09BNV7PBF, B09QFP8FNZ, B09BNSD7G..."
1,1,"[B07CZ4DLCP, B09C7BRP5Y, B08TLR9773, B09C6RTP2..."
2,2,"[B0B77KWZ7B, B0B77PGVM9, B0B77M9DVM, B0B77KRD1..."
3,3,"[B0B9FYSBC9, B09QLGKG6M, B09HN456ZN, B0B8VM7QK..."
4,4,"[B09ZXYVP4Q, B09WDC1S17, B09WDCD4FT, B09ZYC44S..."
...,...,...
261811,261811,"[B0B256HNBW, B094VJV3PS, B0B24SG4YR, B08CD8KPR..."
261812,261812,"[B09Y52VXK5, B09RN57T97, B0B2KQQWSG, B07RNLDLR..."
261813,261813,"[B07D395945, B07D38HKCS, B0B92VZXP4, B09TPR7B7..."
261814,261814,"[B084V9PP9B, B073Q4XMXS, B075DFHSW6, B07C9LQJF..."


In [None]:
len(co_graph_valid_candidates_150.iloc[60]['candidates'])

150

In [51]:
def get_valid_session_co_graph_candidates_200(sess_id_example):
    sess = valid_sess_item[sess_id_example['sess_id']]
    prev_items = set()
    cand_counter = Counter()
    for item in sess:
        if item in co_occurence_dict_bi and item not in prev_items:
            cand_counter = cand_counter + co_occurence_dict_bi[item]
            prev_items.add(item) # one time for every item
    for item in sess:
        if item in cand_counter:
            cand_counter.pop(item) # remove history items 
    if len(cand_counter) > 0:
        candidates, _ = zip(*cand_counter.most_common(200))
    else:
        candidates = []
    return {'co_graph_candidates' : candidates}

In [52]:
# about 1 mins
datasets.set_progress_bar_enabled(False)
valid_co_graph_candidates_200_dataset = valid_co_graph_candidates_dataset.map(get_valid_session_co_graph_candidates_200, num_proc=8, batched=False)
datasets.set_progress_bar_enabled(True)

In [53]:
valid_co_graph_candidates_200 = valid_co_graph_candidates_200_dataset['co_graph_candidates']

In [54]:
co_graph_valid_candidates_200 = pd.DataFrame({'sess_id' : list(range(len(valid_sessions))), 'candidates' : valid_co_graph_candidates_200})

In [58]:
len(co_graph_valid_candidates_200.iloc[100]['candidates'])

200

In [19]:
len(valid_co_graph_candidates)

261816

In [23]:
valid_sessions.iloc[1]['prev_items']

"['B0BD5JBWKW' 'B09MSFTHDV' 'B0B8291RBN' 'B092VQCBRV' 'B06Y5G959K']"

In [20]:
valid_co_graph_candidates[1][:10], len(valid_co_graph_candidates[1])

(['B07CZ4DLCP',
  'B09C7BRP5Y',
  'B08TLR9773',
  'B09C6RTP2S',
  'B09MSHV8LT',
  'B0BGJCMQ5H',
  'B0B7JY3GDW',
  'B0B829LLGR',
  'B0B4BJG9L4',
  'B0B8D4CWZ4'],
 100)

In [25]:
co_occurence_dict_bi['B06Y5G959K']['B09C7BRP5Y']

6

## merge candidates 

In [13]:
valid_sessions

Unnamed: 0,prev_items,next_item,locale
0,['B08WX53RJD' 'B09JWGF3X2' 'B09BNTMYR8' 'B0B1M...,B09BNV7PBF,DE
1,['B0BD5JBWKW' 'B09MSFTHDV' 'B0B8291RBN' 'B092V...,B07CHK4N8S,DE
2,['B0B77MCDQB' 'B0B77MCDQB'],B0B7J4842S,DE
3,['B0BDV9S9J7' 'B0B9RQM5WT' 'B0B1VY21F4' 'B0B9R...,B0B2Q5C6ZK,DE
4,['B09YSZPS3H' 'B09YSZPS3H' 'B09WDBDGBZ' 'B09YS...,B09ZXYVP4Q,DE
...,...,...,...
261811,['B08SJWM6L2' 'B08ZKP7Z4T' 'B08SJWM6L2' 'B08CD...,B0B63T163B,UK
261812,['B07Z1V3WGM' 'B07QQC48FX' 'B07W92D6W5' 'B07KP...,B09CYG11VV,UK
261813,['B082DGJQCL' 'B09JV3Q5FX' 'B082DGJQCL' 'B082D...,B09PBFBR4L,UK
261814,['B003KU6GAU' 'B074JDHYCF' 'B003KU6GAU'],B096RVWZCV,UK


In [None]:
# make sure that history items are not in candidates set
assert 'B07VYSSRL7' not in co_graph_valid_candidates.iloc[2]['candidates']
assert 'B01MG55XDR' not in sasrec_valid_candidates.iloc[2]['next_item_prediction']
assert 'B01BM9V6H8' not in roberta_valid_candidates.iloc[2]['next_item_prediction']

In [14]:
sasrec_valid_candidates_100 = copy.deepcopy(sasrec_valid_candidates)
sasrec_valid_candidates_100['candidates'] = sasrec_valid_candidates['next_item_prediction'].apply(lambda x : x[:100])

roberta_valid_candidates_100 = copy.deepcopy(roberta_valid_candidates)
roberta_valid_candidates_100['candidates'] = roberta_valid_candidates['next_item_prediction'].apply(lambda x : x[:100])

In [15]:
merged_candidates_df = merge_candidates(valid_sessions, [sasrec_valid_candidates_100, roberta_valid_candidates_100])

100%|██████████| 261816/261816 [02:51<00:00, 1522.57it/s]


In [16]:
# recall 
(merged_candidates_df['target'] == 1.0).sum() / (merged_candidates_df['sess_id'].max() + 1)

0.7320217251810431

In [32]:
merged_candidates_df = merge_candidates(valid_sessions, [sasrec_valid_candidates_100, roberta_valid_candidates_100, co_graph_valid_candidates])

100%|██████████| 261816/261816 [03:27<00:00, 1261.58it/s]


In [33]:
# recall 
(merged_candidates_df['target'] == 1.0).sum() / (merged_candidates_df['sess_id'].max() + 1)

0.7416697222476853

In [65]:
len(merged_candidates_df) / (merged_candidates_df['sess_id'].max() + 1), len(merged_candidates_df)

(204.81605020319614, 53624119)

In [17]:
sasrec_valid_candidates_150 = copy.deepcopy(sasrec_valid_candidates)
sasrec_valid_candidates_150['candidates'] = sasrec_valid_candidates['next_item_prediction'].apply(lambda x : x[:150])

roberta_valid_candidates_150 = copy.deepcopy(roberta_valid_candidates)
roberta_valid_candidates_150['candidates'] = roberta_valid_candidates['next_item_prediction'].apply(lambda x : x[:150])

In [18]:
merged_candidates_df_150 = merge_candidates(valid_sessions, [sasrec_valid_candidates_150, roberta_valid_candidates_150])
# recall 
(merged_candidates_df_150['target'] == 1.0).sum() / (merged_candidates_df_150['sess_id'].max() + 1)

100%|██████████| 261816/261816 [03:19<00:00, 1309.33it/s]


0.7530517615424573

In [49]:
merged_candidates_df_150 = merge_candidates(valid_sessions, [sasrec_valid_candidates_150, roberta_valid_candidates_150, co_graph_valid_candidates_150])

100%|██████████| 261816/261816 [04:06<00:00, 1063.62it/s]


In [50]:
# recall 
(merged_candidates_df_150['target'] == 1.0).sum() / (merged_candidates_df_150['sess_id'].max() + 1)

0.7613629419134048

In [66]:
len(merged_candidates_df_150) / (merged_candidates_df_150['sess_id'].max() + 1), len(merged_candidates_df_150)

(301.13590842423685, 78842199)

In [20]:
sasrec_valid_candidates_200 = copy.deepcopy(sasrec_valid_candidates)
sasrec_valid_candidates_200['candidates'] = sasrec_valid_candidates['next_item_prediction'].apply(lambda x : x[:200])

roberta_valid_candidates_200 = copy.deepcopy(roberta_valid_candidates)
roberta_valid_candidates_200['candidates'] = roberta_valid_candidates['next_item_prediction'].apply(lambda x : x[:200])

In [21]:
merged_candidates_df_200 = merge_candidates(valid_sessions, [sasrec_valid_candidates_200, roberta_valid_candidates_200])
# recall 
(merged_candidates_df_200['target'] == 1.0).sum() / (merged_candidates_df_200['sess_id'].max() + 1)

100%|██████████| 261816/261816 [03:54<00:00, 1115.78it/s]


0.7662022183518196

In [60]:
merged_candidates_df_200 = merge_candidates(valid_sessions, [sasrec_valid_candidates_200, roberta_valid_candidates_200, co_graph_valid_candidates_200])

100%|██████████| 261816/261816 [09:43<00:00, 448.84it/s]


In [61]:
# recall 
(merged_candidates_df_200['target'] == 1.0).sum() / (merged_candidates_df_200['sess_id'].max() + 1)

0.7733790142695633

In [67]:
len(merged_candidates_df_200) / (merged_candidates_df_200['sess_id'].max() + 1), len(merged_candidates_df_200)

(394.99121902404744, 103415021)

In [68]:
cast_dtype(merged_candidates_df_150)

In [70]:
merged_candidates_df_150.to_parquet('./candidates_phase2/merged_candidates_150.parquet', engine='pyarrow')

In [69]:
merged_candidates_df_150

Unnamed: 0,sess_id,sess_locale,product,target
0,0,DE,B09BNW4F85,0.0
1,0,DE,B01J41G4SC,0.0
2,0,DE,B07FP2KPWV,0.0
3,0,DE,B09QFM2945,0.0
4,0,DE,B09QG4M23V,0.0
...,...,...,...,...
78842194,261815,UK,B07CF56HFY,0.0
78842195,261815,UK,B002SPGQV2,0.0
78842196,261815,UK,B07YQFYH54,0.0
78842197,261815,UK,B09N3QGQ2M,0.0


In [71]:
cast_dtype(merged_candidates_df)
merged_candidates_df.to_parquet('./candidates_phase2/merged_candidates_100.parquet', engine='pyarrow')

In [72]:
merged_candidates_df

Unnamed: 0,sess_id,sess_locale,product,target
0,0,DE,B09L8GZZVX,0.0
1,0,DE,B09BNW4F85,0.0
2,0,DE,B01J41G4SC,0.0
3,0,DE,B07FNWMBML,0.0
4,0,DE,B07FP2KPWV,0.0
...,...,...,...,...
53624114,261815,UK,B0037P6GC6,0.0
53624115,261815,UK,B07YQFYH54,0.0
53624116,261815,UK,B009SKGKDO,0.0
53624117,261815,UK,B07XPHH7S1,0.0


# Merge Test Candidates

In [7]:
train_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_train_sessions_phase2.csv'
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions_phase2.csv'
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task1_phase2.csv'
sasrec_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates_phase2/SASRec/test_top300_sasrec0531.parquet'
roberta_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates_phase2/roberta/phase2_task1_test_300_filtered.parquet'
product_data_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/products_train.csv'

In [8]:
@lru_cache(maxsize=1)
def read_train_sessions():
    return pd.read_csv(train_sessions_path)

@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_test_sessions():
    return pd.read_csv(test_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_test_candidates():
    return pd.read_parquet(sasrec_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_test_candidates():
    return pd.read_parquet(roberta_test_candidates_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(product_data_path)


In [9]:
train_sessions = read_train_sessions()
valid_sessions = read_valid_sessions()
test_sessions = read_test_sessions()
sasrec_test_candidates = read_sasrec_test_candidates()
roberta_test_candidates = read_roberta_test_candidates()
product_data = read_product_data()

In [10]:
sasrec_test_candidates

Unnamed: 0,prev_items,locale,next_item_prediction
0,['B087VLP2RT' 'B09BRQSHYH' 'B099KW4ZLV'],DE,"[B07SDFLVKD, B091CK241X, B0BGC82WVW, B07SR4R8K..."
1,['B08XW4W667' 'B096VMCJYF' 'B096VMCJYF'],DE,"[B084CB7GX9, B004P4QFJM, B08HQWQ1SK, B004P4OF1..."
2,['B09Z4T2GJ3' 'B09Z3FBXMB' 'B0936K9LTJ' 'B09Z4...,DE,"[B09Z4PZQBF, B08LLF9M11, B09GPJ15GS, B07HFTJLR..."
3,['B07T6Y2HG7' 'B07T2NBLX9' 'B07Y1G5F3Y'],DE,"[B07Y1KLF25, B07T5XJW9G, B07QQZD49D, B07VSNX4G..."
4,['B0B2DRKZ6X' 'B0B2DRKZ6X' 'B0B2DRKZ6X'],DE,"[B0B2JY9THB, B08P94RML3, B08SXLWXH9, B08SHZHRQ..."
...,...,...,...
316967,['B078RJX3CC' 'B07GKM97YF'],UK,"[B07GKP2LCF, B07GKYSHB4, B006DDGCI2, B081VSV2F..."
316968,['B01LX5Y7RG' 'B00M35Y2J0' 'B0BFR9D1Y2' 'B09BB...,UK,"[B00M35Y326, B085C7TCTC, B08L5Z8GPL, B00NVMIO0..."
316969,['B09HGRXXTM' 'B08VDNCZT9'],UK,"[B08VDHH6QF, B08VDSL596, B08VD5DC5L, B07SWLWPX..."
316970,['B089CVQ2FS' 'B089CVQ2FS'],UK,"[B089CZWB4C, B08W2JJZBM, B08T1ZJYHV, B09WCQYGX..."


In [11]:
roberta_test_candidates['sess_id'] = np.arange(roberta_test_candidates.shape[0])
roberta_test_candidates

Unnamed: 0,next_item_prediction,sess_id
0,"[B07GJXLBJ2, B078Y36QFJ, B08JW5DR79, B01L7PQBL...",0
1,"[B079NW6DQ1, B0B2Z8DJYP, B079NVK7CZ, B079NWL3L...",1
2,"[B09Z4PZQBF, B09Z3D15JK, B08C5F8M7V, B08C5FLXW...",2
3,"[B08FQ2XLFV, B07T5XY2CJ, B07Y1KLF25, B07SZHHK5...",3
4,"[B0B2JY9THB, B08LKGLSP8, B09JC6YKQ5, B09G2GHZV...",4
...,...,...
316967,"[B07GKP2LCF, B00V6FIFZ0, B095HM921Q, B07GKYSHB...",316967
316968,"[B01NCJMULG, B07KTJXWHH, B07R2KJTVY, B00M35Y2M...",316968
316969,"[B08VDSL596, B08VDHH6QF, B09NCC7GH2, B0B3XX8P9...",316969
316970,"[B08W2JJZBM, B089CZWB4C, B089CJ7P7Z, B08T1ZJYH...",316970


## get co-graph candidates

In [12]:
train_sess_item = get_sessions(train_sessions, list_item=False)
valid_sess_item = get_sessions(valid_sessions, test=False, list_item=False)
test_sess_item = get_sessions(test_sessions, test=True, list_item=False)

100%|██████████| 3010900/3010900 [03:36<00:00, 13881.74it/s]
100%|██████████| 261816/261816 [00:23<00:00, 11000.89it/s]
100%|██████████| 316972/316972 [00:16<00:00, 19727.28it/s]


In [13]:
co_occurence_dict_bi = get_co_occurence_dict(train_sess_item, bidirection=True, weighted=False)

100%|██████████| 3010900/3010900 [02:15<00:00, 22220.44it/s]


In [14]:
test_co_graph_candidates_dataset = TFDataset.from_dict({'sess_id' : list(range(len(test_sessions)))})

In [15]:
def get_test_session_co_graph_candidates(sess_id_example):
    sess = test_sess_item[sess_id_example['sess_id']]
    prev_items = set()
    cand_counter = Counter()
    for item in sess:
        if item in co_occurence_dict_bi and item not in prev_items:
            cand_counter = cand_counter + co_occurence_dict_bi[item]
            prev_items.add(item) # one time for every item
    for item in sess:
        if item in cand_counter:
            cand_counter.pop(item) # remove history items
    if len(cand_counter) > 0:
        candidates, _ = zip(*cand_counter.most_common(150))
    else:
        candidates = []
    return {'co_graph_candidates' : candidates}

In [16]:
# about 1 mins
datasets.set_progress_bar_enabled(False)
test_co_graph_candidates_dataset_150 = test_co_graph_candidates_dataset.map(get_test_session_co_graph_candidates, num_proc=8, batched=False)
datasets.set_progress_bar_enabled(True)



In [17]:
test_co_graph_candidates_150 = test_co_graph_candidates_dataset_150['co_graph_candidates']

In [18]:
co_graph_test_candidates_150 = pd.DataFrame({'sess_id' : list(range(len(test_sessions))), 'candidates' : test_co_graph_candidates_150})
co_graph_test_candidates_150

Unnamed: 0,sess_id,candidates
0,0,"[B07SDFLVKD, B08SRMPBRF, B091CK241X, B07SR4R8K..."
1,1,"[B004P4OF1C, B09YD8XV6M, B004P4QFJM, B084CB7GX..."
2,2,"[B09Z4PZQBF, B09GPJ15GS, B08LLF9M11, B07XFZWKX..."
3,3,"[B07QQZD49D, B07T5XY2CJ, B07VSNX4GG, B07WV6ZNQ..."
4,4,"[B08P94RML3, B0935DN1BN, B08R9PTZ5G, B07HKQRV8..."
...,...,...
316967,316967,"[B07GKP2LCF, B07GKYSHB4, B07B3PZK8V, B078RK3J1..."
316968,316968,"[B00NVMIO02, B0865RL4PH, B08TV24K42, B0989S5K3..."
316969,316969,"[B08VDSL596, B09HGSCL9Q, B08VDGMBGP, B08VDKWMR..."
316970,316970,"[B089CZWB4C, B09P12CMC2, B09LYL8SV8, B08PF7R4R..."


## merge candidates 

In [19]:
test_sessions

Unnamed: 0,prev_items,locale
0,['B087VLP2RT' 'B09BRQSHYH' 'B099KW4ZLV'],DE
1,['B08XW4W667' 'B096VMCJYF' 'B096VMCJYF'],DE
2,['B09Z4T2GJ3' 'B09Z3FBXMB' 'B0936K9LTJ' 'B09Z4...,DE
3,['B07T6Y2HG7' 'B07T2NBLX9' 'B07Y1G5F3Y'],DE
4,['B0B2DRKZ6X' 'B0B2DRKZ6X' 'B0B2DRKZ6X'],DE
...,...,...
316967,['B078RJX3CC' 'B07GKM97YF'],UK
316968,['B01LX5Y7RG' 'B00M35Y2J0' 'B0BFR9D1Y2' 'B09BB...,UK
316969,['B09HGRXXTM' 'B08VDNCZT9'],UK
316970,['B089CVQ2FS' 'B089CVQ2FS'],UK


In [21]:
# make sure that history items are not in candidates set
assert 'B099KW4ZLV' not in co_graph_test_candidates_150.iloc[0]['candidates']
assert 'B09BRQSHYH' not in sasrec_test_candidates.iloc[0]['next_item_prediction']
assert 'B087VLP2RT' not in roberta_test_candidates.iloc[0]['next_item_prediction']

In [23]:
sasrec_test_candidates_150 = copy.deepcopy(sasrec_test_candidates)
sasrec_test_candidates_150['candidates'] = sasrec_test_candidates['next_item_prediction'].apply(lambda x : x[:150])

roberta_test_candidates_150 = copy.deepcopy(roberta_test_candidates)
roberta_test_candidates_150['candidates'] = roberta_test_candidates['next_item_prediction'].apply(lambda x : x[:150])

In [24]:
merged_candidates_df = merge_candidates(test_sessions, [sasrec_test_candidates_150, roberta_test_candidates_150, co_graph_test_candidates_150], test=True)

100%|██████████| 316972/316972 [04:15<00:00, 1242.83it/s]


In [26]:
cast_dtype(merged_candidates_df)

In [None]:
merged_candidates_df.to_parquet('./candidates_phase2//merged_candidates_150_test.parquet', engine='pyarrow')

In [25]:
merged_candidates_df

Unnamed: 0,sess_id,sess_locale,product
0,0,DE,B08PZ7QKD2
1,0,DE,B01MY13UKE
2,0,DE,B09CHFG3JB
3,0,DE,B09C92RJN7
4,0,DE,B094Z7S1JB
...,...,...,...
96556030,316971,UK,B07FC7VZ8G
96556031,316971,UK,B08P1T8R46
96556032,316971,UK,B07SBL5GZB
96556033,316971,UK,B09B6RB4H9
