In [1]:
import os
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict
import datasets
from datasets import Dataset as TFDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def cast_dtype(df : pd.DataFrame):
    for k in df.columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [3]:
def merge_candidates(session_df, candidates_df_list, test=False):
    sess_id_list = []
    sess_locale_list = []
    product_list = []
    target_list = []
    for i in tqdm(range(session_df.shape[0])):
        sess_id = i
        sess_locale = session_df.iloc[i]['locale']
        if not test:
            sess_next_item = session_df.iloc[i]['next_item']

        candidates_set_list = [set(x.iloc[i]['candidates']) for x in candidates_df_list]
        candidates_set = candidates_set_list[0]
        for x in candidates_set_list[1:]:
            candidates_set = candidates_set.union(x)
        cur_product_list = list(candidates_set)
        cur_sess_id_list = [sess_id for _ in range(len(cur_product_list))]
        cur_sess_locale_list = [sess_locale for _ in range(len(cur_product_list))]
        if not test:
            cur_target_list = (np.array(cur_product_list) == sess_next_item).astype(np.float32).tolist()

        for j in range(len(cur_product_list)):
            sess_id_list.append(cur_sess_id_list[j])
            sess_locale_list.append(cur_sess_locale_list[j])
            product_list.append(cur_product_list[j])
            if not test:
                target_list.append(cur_target_list[j])

    df_dict = {'sess_id' : sess_id_list, 'sess_locale' : sess_locale_list, 'product' : product_list}
    if not test: 
        df_dict['target'] = target_list
    return pd.DataFrame(df_dict)

In [4]:
def get_sessions(df: pd.DataFrame, test=False, list_item=False) -> list:
    
    all_item = []
    if 'next_item' in df and not test:
        if list_item:
            for i in trange(len(df)):
                all_item.append(np.concatenate([np.array(df.loc[i, 'prev_items']), np.array(df.loc[i, 'next_item'])], axis=0))
        else:
            for i in trange(len(df)):
                all_item.append(eval((df.loc[i, 'prev_items'][:-1]+f" '{df.loc[i, 'next_item']}']").replace(" ", ",")))
    else:
        if list_item:
            all_item = df['prev_items']
        else:
            for i in trange(len(df)):
                all_item.append(eval((df.loc[i, 'prev_items']).replace(" ", ",")))
    return all_item

In [5]:
def get_co_occurence_dict(sessions: list, bidirection: bool=True, weighted: bool=False, max_dis=None) -> dict:
    res = {}
    for sess in tqdm(sessions):
        for i, id in enumerate(sess):
            if id not in res:
                res[id] = Counter()
            
            if max_dis == None:
                e = len(sess)
            else:
                e = min(i + max_dis + 1, len(sess))

            for j in range(i+1, e):
                if not weighted:
                    res[id][sess[j]] += 1
                else:
                    res[id][sess[j]] += 1 / (j-i)
                if bidirection:
                    if sess[j] not in res:
                        res[sess[j]] = Counter()
                    if not weighted:
                        res[sess[j]][id] += 1
                    else:
                        res[sess[j]][id] += 1 / (j-i)
    return res

In [6]:
def merge_candidates(session_df, candidates_df_list, test=False):
    sess_id_list = []
    sess_locale_list = []
    product_list = []
    target_list = []
    for i in tqdm(range(session_df.shape[0])):
        sess_id = i
        sess_locale = session_df.iloc[i]['locale']
        if not test:
            sess_next_item = session_df.iloc[i]['next_item']

        candidates_set_list = [set(x.iloc[i]['candidates']) for x in candidates_df_list]
        candidates_set = candidates_set_list[0]
        for x in candidates_set_list[1:]:
            candidates_set = candidates_set.union(x)
        cur_product_list = list(candidates_set)
        cur_sess_id_list = [sess_id for _ in range(len(cur_product_list))]
        cur_sess_locale_list = [sess_locale for _ in range(len(cur_product_list))]
        if not test:
            cur_target_list = (np.array(cur_product_list) == sess_next_item).astype(np.float32).tolist()

        for j in range(len(cur_product_list)):
            sess_id_list.append(cur_sess_id_list[j])
            sess_locale_list.append(cur_sess_locale_list[j])
            product_list.append(cur_product_list[j])
            if not test:
                target_list.append(cur_target_list[j])

    df_dict = {'sess_id' : sess_id_list, 'sess_locale' : sess_locale_list, 'product' : product_list}
    if not test: 
        df_dict['target'] = target_list
    return pd.DataFrame(df_dict)

# Merge Test Candidates

In [7]:
train_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_train_sessions.csv'
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions.csv'
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data//sessions_test_task1.csv'
sasrec_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/seperate_locale/SASRec_three_layers_05_23_23_13_test_100_with_score.parquet'
roberta_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/phase2_roberta_task1_test_100_with_score.parquet'
product_data_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/products_train.csv'

In [8]:
@lru_cache(maxsize=1)
def read_train_sessions():
    return pd.read_csv(train_sessions_path)

@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_test_sessions():
    return pd.read_csv(test_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_test_candidates():
    return pd.read_parquet(sasrec_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_test_candidates():
    return pd.read_parquet(roberta_test_candidates_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(product_data_path)


In [9]:
train_sessions = read_train_sessions()
valid_sessions = read_valid_sessions()
test_sessions = read_test_sessions()
sasrec_test_candidates = read_sasrec_test_candidates()
roberta_test_candidates = read_roberta_test_candidates()
product_data = read_product_data()

In [10]:
sasrec_test_candidates

Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B091CK241X, B07SDFLVKD, B093X59B31, B0BGC82WV...","[19.219451904296875, 18.955141067504883, 18.11..."
1,DE,"[B004P4OF1C, B084CB7GX9, B08HQWQ1SK, B09YD8XV6...","[15.721356391906738, 15.346590995788574, 15.30..."
2,DE,"[B09Z4PZQBF, B09GPJ15GS, B07HFTJLR8, B08LLF9M1...","[16.841047286987305, 13.765119552612305, 13.53..."
3,DE,"[B07Y1KLF25, B07QQZD49D, B07VSNX4GG, B09GKJ9RR...","[17.946931838989258, 17.033920288085938, 15.95..."
4,DE,"[B0B2JY9THB, B08SXLWXH9, B0B93LNSDL, B08SHZHRQ...","[16.5403995513916, 15.69337272644043, 14.64745..."
...,...,...,...
316967,UK,"[B07GKP2LCF, B07GKYSHB4, B006DDGCI2, B016RAAUE...","[17.98622703552246, 16.594831466674805, 13.547..."
316968,UK,"[B00M35Y326, B08L5Z8GPL, B085C7TCTC, B00NVMIO0...","[20.247331619262695, 16.64974594116211, 16.398..."
316969,UK,"[B08VDHH6QF, B08VD5DC5L, B08VDSL596, B09HGSCL9...","[15.693093299865723, 14.321159362792969, 13.85..."
316970,UK,"[B089CZWB4C, B08T1ZJYHV, B09WCQYGX8, B08W2JJZB...","[18.1264591217041, 16.0655460357666, 15.565340..."


In [11]:
roberta_test_candidates['sess_id'] = np.arange(roberta_test_candidates.shape[0])
roberta_test_candidates

Unnamed: 0,next_item_prediction,scores,sess_id
0,"[B01L7PQBL8, B01MY13UKE, B07SMRJG6L, B078Y36QF...","[270.6937255859375, 270.25555419921875, 269.76...",0
1,"[B079NW6DQ1, B0064LE9EC, B09TLDV2K8, B09MSBWMC...","[269.1681213378906, 268.6677551269531, 268.475...",1
2,"[B09Z4PZQBF, B0BDMLKLPP, B0BDML16KP, B0BDML947...","[268.52374267578125, 267.9551086425781, 267.95...",2
3,"[B09F7Z5GP1, B0B5R1D2M6, B07VSNX4GG, B0B254VD5...","[267.2171325683594, 267.0771179199219, 267.024...",3
4,"[B09JC6YKQ5, B09G2GHZVD, B09HKQTVK8, B09B2NYJ3...","[267.6343994140625, 267.4753112792969, 267.285...",4
...,...,...,...
316967,"[B09SZGKTS8, B07GKP2LCF, B07GKYSHB4, B09ZPL7WX...","[269.01617431640625, 268.9991455078125, 268.93...",316967
316968,"[B07W5FJJKZ, B0865SPP6Q, B07F8T4SB8, B07KTJXWH...","[268.2914733886719, 268.2543640136719, 267.986...",316968
316969,"[B08VDHH6QF, B09HGVNLFK, B08VDGZ8K5, B09NCC7GH...","[269.6565246582031, 269.5868835449219, 269.309...",316969
316970,"[B09WCQYGX8, B089CJ7P7Z, B08W2JJZBM, B089CZWB4...","[268.09698486328125, 268.09698486328125, 268.0...",316970


## get co-graph candidates

In [12]:
train_sess_item = get_sessions(train_sessions, list_item=False)
valid_sess_item = get_sessions(valid_sessions, test=False, list_item=False)
test_sess_item = get_sessions(test_sessions, test=True, list_item=False)

100%|██████████| 3557898/3557898 [02:19<00:00, 25469.93it/s]
100%|██████████| 361581/361581 [00:15<00:00, 23776.35it/s]
100%|██████████| 316972/316972 [00:07<00:00, 41556.83it/s]


In [13]:
co_occurence_dict_bi = get_co_occurence_dict(train_sess_item + valid_sess_item + test_sess_item, bidirection=True, weighted=False)

100%|██████████| 4236451/4236451 [01:41<00:00, 41754.21it/s]


In [14]:
test_co_graph_candidates_dataset = TFDataset.from_dict({'sess_id' : list(range(len(test_sessions)))})

In [15]:
def get_test_session_co_graph_candidates(sess_id_example):
    sess = test_sess_item[sess_id_example['sess_id']]
    prev_items = set()
    cand_counter = Counter()
    for item in sess:
        if item in co_occurence_dict_bi and item not in prev_items:
            cand_counter = cand_counter + co_occurence_dict_bi[item]
            prev_items.add(item) # one time for every item
    for item in sess:
        if item in cand_counter:
            cand_counter.pop(item) # remove history items 
    if len(cand_counter) > 0:
        candidates, _ = zip(*cand_counter.most_common(100))
    else:
        candidates = []
    return {'co_graph_candidates' : candidates}

In [16]:
# about 1 mins
datasets.set_progress_bar_enabled(False)
test_co_graph_candidates_dataset = test_co_graph_candidates_dataset.map(get_test_session_co_graph_candidates, num_proc=8, batched=False)
datasets.set_progress_bar_enabled(True)



In [17]:
test_co_graph_candidates = test_co_graph_candidates_dataset['co_graph_candidates']

In [18]:
co_graph_test_candidates = pd.DataFrame({'sess_id' : list(range(len(test_sessions))), 'candidates' : test_co_graph_candidates})
co_graph_test_candidates

Unnamed: 0,sess_id,candidates
0,0,"[B07SDFLVKD, B091CK241X, B08SRMPBRF, B07SR4R8K..."
1,1,"[B004P4OF1C, B09YD8XV6M, B004P4QFJM, B084CB7GX..."
2,2,"[B09Z4PZQBF, B0BDML9477, B08GWS298V, B0B6RSH9R..."
3,3,"[B07Y1KLF25, B07QQZD49D, B07T5XY2CJ, B07QQZFZK..."
4,4,"[B08P94RML3, B08SXLWXH9, B08SHZHRQ7, B0935DN1B..."
...,...,...
316967,316967,"[B07GKP2LCF, B07GKYSHB4, B078RK3J12, B07B3PZK8..."
316968,316968,"[B00NVMIO02, B0865RL4PH, B08TV24K42, B0989S5K3..."
316969,316969,"[B09HGSCL9Q, B08VDSL596, B08VDGMBGP, B08VDKWMR..."
316970,316970,"[B089CZWB4C, B098T5NZT9, B09PQJQYVC, B09J8T6TT..."


## merge candidates 

In [19]:
test_sessions

Unnamed: 0,prev_items,locale
0,['B087VLP2RT' 'B09BRQSHYH' 'B099KW4ZLV'],DE
1,['B08XW4W667' 'B096VMCJYF' 'B096VMCJYF'],DE
2,['B09Z4T2GJ3' 'B09Z3FBXMB' 'B0936K9LTJ' 'B09Z4...,DE
3,['B07T6Y2HG7' 'B07T2NBLX9' 'B07Y1G5F3Y'],DE
4,['B0B2DRKZ6X' 'B0B2DRKZ6X' 'B0B2DRKZ6X'],DE
...,...,...
316967,['B078RJX3CC' 'B07GKM97YF'],UK
316968,['B01LX5Y7RG' 'B00M35Y2J0' 'B0BFR9D1Y2' 'B09BB...,UK
316969,['B09HGRXXTM' 'B08VDNCZT9'],UK
316970,['B089CVQ2FS' 'B089CVQ2FS'],UK


In [25]:
sasrec_test_candidates.rename(columns={'next_item_prediction' : 'candidates'}, inplace=True)
roberta_test_candidates.rename(columns={'next_item_prediction' : 'candidates'}, inplace=True)

In [26]:
roberta_test_candidates

Unnamed: 0,candidates,scores,sess_id
0,"[B01L7PQBL8, B01MY13UKE, B07SMRJG6L, B078Y36QF...","[270.6937255859375, 270.25555419921875, 269.76...",0
1,"[B079NW6DQ1, B0064LE9EC, B09TLDV2K8, B09MSBWMC...","[269.1681213378906, 268.6677551269531, 268.475...",1
2,"[B09Z4PZQBF, B0BDMLKLPP, B0BDML16KP, B0BDML947...","[268.52374267578125, 267.9551086425781, 267.95...",2
3,"[B09F7Z5GP1, B0B5R1D2M6, B07VSNX4GG, B0B254VD5...","[267.2171325683594, 267.0771179199219, 267.024...",3
4,"[B09JC6YKQ5, B09G2GHZVD, B09HKQTVK8, B09B2NYJ3...","[267.6343994140625, 267.4753112792969, 267.285...",4
...,...,...,...
316967,"[B09SZGKTS8, B07GKP2LCF, B07GKYSHB4, B09ZPL7WX...","[269.01617431640625, 268.9991455078125, 268.93...",316967
316968,"[B07W5FJJKZ, B0865SPP6Q, B07F8T4SB8, B07KTJXWH...","[268.2914733886719, 268.2543640136719, 267.986...",316968
316969,"[B08VDHH6QF, B09HGVNLFK, B08VDGZ8K5, B09NCC7GH...","[269.6565246582031, 269.5868835449219, 269.309...",316969
316970,"[B09WCQYGX8, B089CJ7P7Z, B08W2JJZBM, B089CZWB4...","[268.09698486328125, 268.09698486328125, 268.0...",316970


In [27]:
merged_candidates_df = merge_candidates(test_sessions, [sasrec_test_candidates, roberta_test_candidates, co_graph_test_candidates], test=True)

100%|██████████| 316972/316972 [02:15<00:00, 2346.40it/s]


In [31]:
cast_dtype(merged_candidates_df)

In [33]:
merged_candidates_df.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates/merged_candidates_test.parquet', engine='pyarrow')

In [32]:
merged_candidates_df

Unnamed: 0,sess_id,sess_locale,product
0,0,DE,B0B51MHS12
1,0,DE,B09CHFG3JB
2,0,DE,B077YLDC5N
3,0,DE,B08SRMPBRF
4,0,DE,B08JW624NN
...,...,...,...
65960104,316971,UK,B07H4BPZZG
65960105,316971,UK,B08P517NW5
65960106,316971,UK,B08L7RSTS8
65960107,316971,UK,B07FC7VZ8G


In [29]:
merged_candidates_df.query("sess_id==0&product=='B099KW4ZLV'")

Unnamed: 0,sess_id,sess_locale,product


In [30]:
merged_candidates_df.query("sess_id==316967&product=='B07GKM97YF'")

Unnamed: 0,sess_id,sess_locale,product


In [34]:
len(merged_candidates_df) / 316972

208.09443420869982