In [49]:
import os
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict
import datasets
from datasets import Dataset as TFDataset

In [50]:
def cast_dtype(df : pd.DataFrame):
    for k in df.columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [51]:
def merge_candidates(session_df, candidates_df_list, test=False):
    sess_id_list = []
    sess_locale_list = []
    product_list = []
    target_list = []
    for i in tqdm(range(session_df.shape[0])):
        sess_id = i
        sess_locale = session_df.iloc[i]['locale']
        if not test:
            sess_next_item = session_df.iloc[i]['next_item']

        candidates_set_list = [set(x.iloc[i]['candidates']) for x in candidates_df_list]
        candidates_set = candidates_set_list[0]
        for x in candidates_set_list[1:]:
            candidates_set = candidates_set.union(x)
        cur_product_list = list(candidates_set)
        cur_sess_id_list = [sess_id for _ in range(len(cur_product_list))]
        cur_sess_locale_list = [sess_locale for _ in range(len(cur_product_list))]
        if not test:
            cur_target_list = (np.array(cur_product_list) == sess_next_item).astype(np.float32).tolist()

        for j in range(len(cur_product_list)):
            sess_id_list.append(cur_sess_id_list[j])
            sess_locale_list.append(cur_sess_locale_list[j])
            product_list.append(cur_product_list[j])
            if not test:
                target_list.append(cur_target_list[j])

    df_dict = {'sess_id' : sess_id_list, 'sess_locale' : sess_locale_list, 'product' : product_list}
    if not test: 
        df_dict['target'] = target_list
    return pd.DataFrame(df_dict)

In [52]:
def get_sessions(df: pd.DataFrame, test=False, list_item=False) -> list:
    
    all_item = []
    if 'next_item' in df and not test:
        if list_item:
            for i in trange(len(df)):
                all_item.append(np.concatenate([np.array(df.loc[i, 'prev_items']), np.array(df.loc[i, 'next_item'])], axis=0))
        else:
            for i in trange(len(df)):
                all_item.append(eval((df.loc[i, 'prev_items'][:-1]+f" '{df.loc[i, 'next_item']}']").replace(" ", ",")))
    else:
        if list_item:
            all_item = df['prev_items']
        else:
            for i in trange(len(df)):
                all_item.append(eval((df.loc[i, 'prev_items']).replace(" ", ",")))
    return all_item

In [53]:
def get_co_occurence_dict(sessions: list, bidirection: bool=True, weighted: bool=False, max_dis=None) -> dict:
    res = {}
    for sess in tqdm(sessions):
        for i, id in enumerate(sess):
            if id not in res:
                res[id] = Counter()
            
            if max_dis == None:
                e = len(sess)
            else:
                e = min(i + max_dis + 1, len(sess))

            for j in range(i+1, e):
                if not weighted:
                    res[id][sess[j]] += 1
                else:
                    res[id][sess[j]] += 1 / (j-i)
                if bidirection:
                    if sess[j] not in res:
                        res[sess[j]] = Counter()
                    if not weighted:
                        res[sess[j]][id] += 1
                    else:
                        res[sess[j]][id] += 1 / (j-i)
    return res

In [54]:
def merge_candidates(session_df, candidates_df_list, test=False):
    sess_id_list = []
    sess_locale_list = []
    product_list = []
    target_list = []
    for i in tqdm(range(session_df.shape[0])):
        sess_id = i
        sess_locale = session_df.iloc[i]['locale']
        if not test:
            sess_next_item = session_df.iloc[i]['next_item']

        candidates_set_list = [set(x.iloc[i]['candidates']) for x in candidates_df_list]
        candidates_set = candidates_set_list[0]
        for x in candidates_set_list[1:]:
            candidates_set = candidates_set.union(x)
        cur_product_list = list(candidates_set)
        cur_sess_id_list = [sess_id for _ in range(len(cur_product_list))]
        cur_sess_locale_list = [sess_locale for _ in range(len(cur_product_list))]
        if not test:
            cur_target_list = (np.array(cur_product_list) == sess_next_item).astype(np.float32).tolist()

        for j in range(len(cur_product_list)):
            sess_id_list.append(cur_sess_id_list[j])
            sess_locale_list.append(cur_sess_locale_list[j])
            product_list.append(cur_product_list[j])
            if not test:
                target_list.append(cur_target_list[j])

    df_dict = {'sess_id' : sess_id_list, 'sess_locale' : sess_locale_list, 'product' : product_list}
    if not test: 
        df_dict['target'] = target_list
    return pd.DataFrame(df_dict)

# Merge Validation Candidates

In [55]:
train_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_train_sessions.csv'
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions.csv'
sasrec_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/seperate_locale/SASRec_Next_04_26_15_26_valid_100_with_score.parquet'
roberta_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_valid_100_with_score.parquet'
product_data_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/products_train.csv'

In [56]:
@lru_cache(maxsize=1)
def read_train_sessions():
    return pd.read_csv(train_sessions_path)

@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_valid_candidates():
    return pd.read_parquet(sasrec_valid_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_valid_candidates():
    return pd.read_parquet(roberta_valid_candidates_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(product_data_path)

In [57]:
train_sessions = read_train_sessions()
valid_sessions = read_valid_sessions()
sasrec_valid_candidates = read_sasrec_valid_candidates()
roberta_valid_candidates = read_roberta_valid_candidates()
product_data = read_product_data()

In [58]:
sasrec_valid_candidates

Unnamed: 0,locale,candidates,sess_id,scores
0,UK,"[B06XGDZVZR, B06XG1LZ6Z, B06XGD9VLV, B01MYUDYP...",0,"[18.85309600830078, 18.762922286987305, 17.422..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B092D5HM5S, B0797D6F3...",1,"[17.74812889099121, 14.109373092651367, 12.546..."
2,UK,"[B00L529BAC, B0B1VNZW6H, B09XBS6WCX, B07CFSC84...",2,"[16.962995529174805, 16.4238338470459, 16.2612..."
3,UK,"[024157563X, 1406392979, 0241457920, 024156343...",3,"[14.681435585021973, 14.633809089660645, 14.45..."
4,JP,"[B0B6PF619D, B09BJF6N8K, B0B6P77ZRN, B0B6P2PCM...",4,"[21.915809631347656, 20.597021102905273, 19.71..."
...,...,...,...,...
361576,UK,"[B08F5D8T22, B0050IG9DE, B00465F49W, B01A955L8...",361576,"[23.4932804107666, 20.638296127319336, 19.6963..."
361577,JP,"[B09BCM5NL1, B09B9V4PXC, B09XH1YGLL, B09XGRXXG...",361577,"[13.340217590332031, 12.540390014648438, 12.30..."
361578,DE,"[B0BC38GHB4, B00MXZEMBI, B078X1F2JD, B07KLCY8N...",361578,"[22.47096824645996, 19.17393684387207, 17.3051..."
361579,DE,"[B08RQR2NPB, B08RQDVX71, B07PY86YPJ, B08H8SYLM...",361579,"[19.952720642089844, 19.861900329589844, 19.27..."


In [59]:
roberta_valid_candidates['sess_id'] = np.arange(roberta_valid_candidates.shape[0])
roberta_valid_candidates

Unnamed: 0,candidates,scores,sess_id
0,"[B096ZT4DK4, B09XCGNN6F, B09XCHDLYR, B06XGD9VL...","[268.6368103027344, 268.5672607421875, 268.567...",0
1,"[B08ZHJKF28, B09LCPT9DQ, B09MRYK5CV, B09G9YRX1...","[269.99102783203125, 269.4064636230469, 269.30...",1
2,"[B06ZYHMRST, B01017JLP6, B007CW985C, B0017RTPK...","[264.7541198730469, 264.58624267578125, 264.49...",2
3,"[B01M5EBAO4, B07TYFS1K7, B08BX71RLT, B0BGY899Q...","[267.71771240234375, 266.8636779785156, 266.69...",3
4,"[B0B6NXJMJJ, B0B6PF619D, B0B6P77ZRN, B0B6NRKVZ...","[269.4986572265625, 269.49267578125, 269.47979...",4
...,...,...,...
361576,"[B08F5D8T22, B079FV57RR, B085WCD1CF, B09C3Q4V1...","[269.1492919921875, 269.11639404296875, 266.56...",361576
361577,"[B0B96YTD1B, B0B96YKVJ6, B09XGPX16G, B0B973WT3...","[269.2782897949219, 269.2471923828125, 269.101...",361577
361578,"[B09SXRYYMM, B0B6D8KGPC, B0B6D9V5CG, B0BC38GHB...","[266.98553466796875, 266.8766784667969, 266.80...",361578
361579,"[B08RQDVX71, B08RQBT2D8, B08RQR2NPB, B09D3VG5N...","[268.53466796875, 267.6655578613281, 267.64810...",361579


## get co-graph candidates

In [60]:
train_sess_item = get_sessions(train_sessions, list_item=False)
valid_sess_item = get_sessions(valid_sessions, test=True, list_item=False)

100%|██████████| 3557898/3557898 [01:49<00:00, 32539.04it/s]
100%|██████████| 361581/361581 [00:14<00:00, 25734.17it/s]


In [61]:
co_occurence_dict_bi = get_co_occurence_dict(train_sess_item + valid_sess_item, bidirection=True, weighted=False)

100%|██████████| 3919479/3919479 [01:20<00:00, 48638.98it/s]


In [87]:
valid_co_graph_candidates_dataset = TFDataset.from_dict({'sess_id' : list(range(len(valid_sessions)))})

In [88]:
def get_valid_session_co_graph_candidates(sess_id_example):
    sess = valid_sess_item[sess_id_example['sess_id']]
    prev_items = set()
    cand_counter = Counter()
    for item in sess:
        if item in co_occurence_dict_bi and item not in prev_items:
            cand_counter = cand_counter + co_occurence_dict_bi[item]
            prev_items.add(item) # one time for every item
    for item in sess:
        if item in cand_counter:
            cand_counter.pop(item) # remove history items 
    if len(cand_counter) > 0:
        candidates, _ = zip(*cand_counter.most_common(100))
    else:
        candidates = []
    return {'co_graph_candidates' : candidates}

In [89]:
# about 1 mins
datasets.set_progress_bar_enabled(False)
valid_co_graph_candidates_dataset = valid_co_graph_candidates_dataset.map(get_valid_session_co_graph_candidates, num_proc=8, batched=False)
datasets.set_progress_bar_enabled(True)

In [90]:
valid_co_graph_candidates = valid_co_graph_candidates_dataset['co_graph_candidates']

In [91]:
co_graph_valid_candidates = pd.DataFrame({'sess_id' : list(range(len(valid_sessions))), 'candidates' : valid_co_graph_candidates})
co_graph_valid_candidates

Unnamed: 0,sess_id,candidates
0,0,"[B06XG1LZ6Z, B06XGDZVZR, B06XGD9VLV, B076PN1SK..."
1,1,"[B09LCPT9DQ, B0BB5VQ1L8, B092D5HM5S, B01BEPNPF..."
2,2,"[B0B1VNZW6H, B08M1863YR, B003BQFX4S, B0B1VLKR1..."
3,3,"[B07G4MNXPZ, B09ZTGMS34, B09B7DGSDT, B097JK5BS..."
4,4,"[B0B6PF619D, B0B6P77ZRN, B0B6NRKVZF, B0B6NY4PN..."
...,...,...
361576,361576,"[B0050IG9DE, B00465F49W, B08F5D8T22, B09C23S3F..."
361577,361577,"[B09B9V4PXC, B09BCM5NL1, B09XGRXXG3, B09XH1YGL..."
361578,361578,"[B00E4JO8MO, B00ELRLN3G, B01N9BZMWS, B079L5WQR..."
361579,361579,"[B08RQDVX71, B08H8TLK4F, B08RQR2NPB, B08H8SYLM..."


In [92]:
len(valid_co_graph_candidates)

361581

In [93]:
valid_co_graph_candidates[1][:10], len(valid_co_graph_candidates[1])

(['B09LCPT9DQ',
  'B0BB5VQ1L8',
  'B092D5HM5S',
  'B01BEPNPF6',
  'B0BFD9P93X',
  'B084T9C6WD',
  'B01H6MF6Z8',
  'B09MRYK5CV',
  'B08GZR18S3',
  'B08RYKYFKD'],
 100)

In [94]:
co_occurence_dict_bi['B09WM9W6WQ']['B09LCPT9DQ']

692

## merge candidates 

In [95]:
valid_sessions

Unnamed: 0,prev_items,next_item,locale
0,['B09VSN9GLS' 'B09VSG9DCG' 'B0BJ5L1ZPH' 'B09VS...,B06XG1LZ6Z,UK
1,['B00390YWXE' 'B00390YWXE' 'B09WM9W6WQ'],B01MSUI4FE,JP
2,['B01BM9V6H8' 'B01MG55XDR' 'B07VYSSRL7'],B01M6625ME,UK
3,['B092ZG24S7' 'B09BNHWWZM' 'B08CB1WG5M' '17880...,0241558573,UK
4,['B0B6NY5RM8' 'B09BJGBBBR'],B09BJF6N8K,JP
...,...,...,...
361576,['B08HH6L4PB' 'B08L8N8HDR'],B00TS5UXGY,UK
361577,['B08X4L1KLZ' 'B09BBX1T4S' 'B09D76FT9D'],B09BCM5NL1,JP
361578,['B0098G6L3M' 'B00ELRLP3O' 'B00PLXGK82' 'B09GS...,B0BC38GHB4,DE
361579,['B07Q2CNLY3' 'B07Q2CNLY3' 'B07BR7DZWN' 'B07Q2...,B08H8SYLMQ,DE


In [96]:
# make sure that history items are not in candidates set
assert 'B07VYSSRL7' not in co_graph_valid_candidates.iloc[3]['candidates']
assert 'B07VYSSRL7' not in sasrec_valid_candidates.iloc[3]['candidates']
assert 'B07VYSSRL7' not in roberta_valid_candidates.iloc[3]['candidates']
assert 'B08L8N8HDR' not in co_graph_valid_candidates.iloc[361576]['candidates']
assert 'B08L8N8HDR' not in sasrec_valid_candidates.iloc[361576]['candidates']
assert 'B08L8N8HDR' not in roberta_valid_candidates.iloc[361576]['candidates']

In [97]:
merged_candidates_df = merge_candidates(valid_sessions, [sasrec_valid_candidates, roberta_valid_candidates, co_graph_valid_candidates])

100%|██████████| 361581/361581 [02:32<00:00, 2374.23it/s]


In [98]:
# recall 
(merged_candidates_df['target'] == 1.0).sum() / (merged_candidates_df['sess_id'].max() + 1)

0.7176898122412405

In [19]:
cast_dtype(merged_candidates_df)

In [48]:
merged_candidates_df.to_parquet('./candidates/merged_candidates_all_items.parquet', engine='pyarrow')

In [74]:
merged_candidates_df

Unnamed: 0,sess_id,sess_locale,product,target
0,0,UK,B0856JQ3WJ,0.0
1,0,UK,B07L3L4PQH,0.0
2,0,UK,B07L3LGRFH,0.0
3,0,UK,B08147T2RH,0.0
4,0,UK,B09TW5134P,0.0
...,...,...,...,...
87162772,361580,DE,B00816X8TU,0.0
87162773,361580,DE,B0187EEUHO,0.0
87162774,361580,DE,B084GPY76H,0.0
87162775,361580,DE,B07N8J3B19,0.0


# Merge Test Candidates

In [26]:
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task1.csv'
sasrec_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/seperate_locale/SASRec_Next_04_27_20_07_test_100_with_score.parquet'
roberta_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_test_100_with_score.parquet'
co_graph_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/co_graph/co_graph_test_100_with_normalized_score.parquet'

In [27]:
@lru_cache(maxsize=1)
def read_test_sessions():
    return pd.read_csv(test_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_test_candidates():
    return pd.read_parquet(sasrec_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_test_candidates():
    return pd.read_parquet(roberta_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_co_graph_test_candidates():
    return pd.read_parquet(co_graph_test_candidates_path, engine='pyarrow')

In [36]:
test_sessions = read_test_sessions()
sasrec_test_candidates = read_sasrec_test_candidates()
roberta_test_candidates = read_roberta_test_candidates()
co_graph_test_candidates = read_co_graph_test_candidates()

In [39]:
merged_candidates_df = merge_candidates(test_sessions, [sasrec_test_candidates, roberta_test_candidates, co_graph_test_candidates], test=True)

100%|██████████| 316971/316971 [01:36<00:00, 3298.02it/s]


In [41]:
cast_dtype(merged_candidates_df)

In [42]:
merged_candidates_df.to_parquet('./candidates/merged_candidates_test.parquet', engine='pyarrow')