In [None]:
import os
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict
import datasets
from datasets import Dataset as TFDataset

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def cast_dtype(df : pd.DataFrame):
    """Downcast the columns of ``df`` in place to 32-bit dtypes.

    The target dtype of each column is inferred from the type of its first value:

    * float scalars  -> ``float32``
    * int scalars    -> ``int32``
    * ``list`` values -> every list becomes a 1-D ``np.float32`` / ``np.int32``
      array, chosen from the type of the first element of the first non-empty list.

    Columns of any other type (e.g. strings) are left untouched.

    Args:
        df: DataFrame to modify in place.

    Returns:
        The same ``df`` object, for chaining (callers that ignore the return value
        are unaffected).
    """
    if df.empty:
        # Nothing to inspect, and ``iloc[0]`` below would raise IndexError.
        return df
    for k in df.columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            # The first row may hold an empty list; take the element type from the
            # first non-empty list instead of crashing on ``[0]``.
            first_non_empty = next(
                (v for v in df[k] if isinstance(v, list) and len(v) > 0), None
            )
            if first_non_empty is None:
                continue
            dt_ = type(first_non_empty[0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))
    return df

In [None]:
def merge_candidates(session_df, candidates_df_list, test=False):
    """Explode per-session candidate lists into one row per (session, product).

    NOTE(review): merge_candidates is defined again in a later cell, and that later
    definition shadows this one. Keep only one copy.

    For every row ``i`` of ``session_df``, the ``'candidates'`` values at row ``i``
    of every frame in ``candidates_df_list`` are unioned (duplicates removed).
    One output row is emitted per distinct product.

    Args:
        session_df: sessions frame with a ``'locale'`` column and, unless ``test``,
            a ``'next_item'`` column.
        candidates_df_list: frames with a ``'candidates'`` column, row-aligned with
            ``session_df`` (row ``i`` describes session ``i``).
        test: when True no ``'target'`` column is produced.

    Returns:
        DataFrame with columns ``sess_id`` (positional row index into
        ``session_df``), ``sess_locale``, ``product`` and, unless ``test``,
        ``target`` (1.0 where ``product == next_item``, else 0.0).

    Products within a session are ordered by first appearance: the first source's
    candidates in their given order, then new ones from the second source, and so
    on. A ``set`` union would give a hash-dependent order that changes between
    interpreter runs.
    """
    # Extract columns once: ``df.iloc[i][col]`` materialises a whole row Series per access.
    locales = session_df['locale'].tolist()
    next_items = None if test else session_df['next_item'].tolist()
    candidate_columns = [x['candidates'].tolist() for x in candidates_df_list]

    sess_id_list = []
    sess_locale_list = []
    product_list = []
    target_list = []
    for i in tqdm(range(session_df.shape[0])):
        # Ordered, de-duplicated union across all candidate sources.
        cur_product_list = list(dict.fromkeys(
            product for column in candidate_columns for product in column[i]
        ))
        n = len(cur_product_list)
        sess_id_list.extend([i] * n)
        sess_locale_list.extend([locales[i]] * n)
        product_list.extend(cur_product_list)
        if not test:
            sess_next_item = next_items[i]
            # Plain comparison also works for an empty candidate list, unlike
            # ``np.array([]) == str``.
            target_list.extend(float(p == sess_next_item) for p in cur_product_list)

    df_dict = {'sess_id' : sess_id_list, 'sess_locale' : sess_locale_list, 'product' : product_list}
    if not test:
        df_dict['target'] = target_list
    return pd.DataFrame(df_dict)

In [None]:
def get_sessions(df: pd.DataFrame, test=False, list_item=False) -> list:
    """Return the item sequence of every session in ``df``, in row order.

    Args:
        df: sessions frame with a ``'prev_items'`` column and optionally a
            ``'next_item'`` column.
        test: when True (or when ``df`` has no ``'next_item'`` column) only
            ``prev_items`` is returned. Otherwise ``next_item`` is appended to
            each session.
        list_item: True if ``'prev_items'`` already holds array-likes. False if it
            holds the string form of a NumPy array as stored in the CSV, e.g.
            ``"['B08V12CT4C' 'B08V1KXBQD']"``.

    Returns:
        A list with one entry per row:
        * NumPy arrays when ``list_item`` and a next item is appended;
        * the raw ``prev_items`` values when ``list_item`` and nothing is appended;
        * Python lists when the strings are parsed.
    """
    import ast

    def _parse_items(items_str):
        # "['a' 'b']" -> "['a','b']" -> ['a', 'b']. literal_eval accepts only
        # literals, unlike eval() on file contents.
        return ast.literal_eval(items_str.replace(" ", ","))

    # Positional access: df.loc[i] would assume a 0..n-1 RangeIndex.
    prev_items = df['prev_items'].tolist()
    if 'next_item' in df and not test:
        next_items = df['next_item'].tolist()
        if list_item:
            # atleast_1d: next_item is normally a scalar id, and np.concatenate
            # rejects 0-d arrays.
            return [
                np.concatenate([np.asarray(prev_items[i]), np.atleast_1d(next_items[i])], axis=0)
                for i in trange(len(df))
            ]
        return [_parse_items(prev_items[i]) + [next_items[i]] for i in trange(len(df))]
    if list_item:
        return prev_items
    return [_parse_items(prev_items[i]) for i in trange(len(df))]

In [None]:
def get_co_occurence_dict(sessions: list, bidirection: bool=True, weighted: bool=False, max_dis=None) -> dict:
    """Count how often items co-occur within the same session.

    Consider every position pair ``i < j`` in a session with ``j - i <= max_dis``
    (all pairs when ``max_dis`` is None). For each such pair,
    ``res[sess[i]][sess[j]]`` is incremented by 1, or by ``1 / (j - i)`` when
    ``weighted``. With ``bidirection``, ``res[sess[j]][sess[i]]`` is incremented
    by the same amount. Repeated items in a session also produce self-pairs.

    Args:
        sessions: iterable of item sequences.
        bidirection: also count the backward direction.
        weighted: weight each pair by the inverse of its positional distance.
        max_dis: maximum positional distance between paired items, or None.

    Returns:
        Plain dict mapping item -> Counter of co-occurring items. Every item that
        appears in any session has an entry, possibly an empty Counter.

    Cost is quadratic in session length when ``max_dis`` is None.
    """
    res = defaultdict(Counter)
    for sess in tqdm(sessions):
        n_items = len(sess)
        for i, src in enumerate(sess):
            # Touch res[src] first so every item gets an entry, even if it has no
            # partner in range (e.g. the last item of a session).
            src_counter = res[src]
            end = n_items if max_dis is None else min(i + max_dis + 1, n_items)
            for j in range(i + 1, end):
                dst = sess[j]
                inc = 1 / (j - i) if weighted else 1
                src_counter[dst] += inc
                if bidirection:
                    res[dst][src] += inc
    # Plain dict: later look-ups of unknown items must not silently insert entries.
    return dict(res)

In [None]:
def merge_candidates(session_df, candidates_df_list, test=False):
    """Explode per-session candidate lists into one row per (session, product).

    NOTE(review): this re-defines merge_candidates from an earlier cell and is the
    definition actually used below. Keep only one copy.

    For every row ``i`` of ``session_df``, the ``'candidates'`` values at row ``i``
    of every frame in ``candidates_df_list`` are unioned (duplicates removed).
    One output row is emitted per distinct product.

    Args:
        session_df: sessions frame with a ``'locale'`` column and, unless ``test``,
            a ``'next_item'`` column.
        candidates_df_list: frames with a ``'candidates'`` column, row-aligned with
            ``session_df`` (row ``i`` describes session ``i``).
        test: when True no ``'target'`` column is produced.

    Returns:
        DataFrame with columns ``sess_id`` (positional row index into
        ``session_df``), ``sess_locale``, ``product`` and, unless ``test``,
        ``target`` (1.0 where ``product == next_item``, else 0.0).

    Products within a session are ordered by first appearance: the first source's
    candidates in their given order, then new ones from the second source, and so
    on. A ``set`` union would give a hash-dependent order that changes between
    interpreter runs.
    """
    # Extract columns once: ``df.iloc[i][col]`` materialises a whole row Series per access.
    locales = session_df['locale'].tolist()
    next_items = None if test else session_df['next_item'].tolist()
    candidate_columns = [x['candidates'].tolist() for x in candidates_df_list]

    sess_id_list = []
    sess_locale_list = []
    product_list = []
    target_list = []
    for i in tqdm(range(session_df.shape[0])):
        # Ordered, de-duplicated union across all candidate sources.
        cur_product_list = list(dict.fromkeys(
            product for column in candidate_columns for product in column[i]
        ))
        n = len(cur_product_list)
        sess_id_list.extend([i] * n)
        sess_locale_list.extend([locales[i]] * n)
        product_list.extend(cur_product_list)
        if not test:
            sess_next_item = next_items[i]
            # Plain comparison also works for an empty candidate list, unlike
            # ``np.array([]) == str``.
            target_list.extend(float(p == sess_next_item) for p in cur_product_list)

    df_dict = {'sess_id' : sess_id_list, 'sess_locale' : sess_locale_list, 'product' : product_list}
    if not test:
        df_dict['target'] = target_list
    return pd.DataFrame(df_dict)

# Merge Test Candidates

In [None]:
# Project workspace root. Set the KDDCUP_WORKSPACE environment variable to run
# elsewhere instead of editing every absolute path below.
WORKSPACE_DIR = os.environ.get('KDDCUP_WORKSPACE', '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23')

train_sessions_path = os.path.join(WORKSPACE_DIR, 'data_for_recstudio/task1_data/task13_4_task1_train_sessions.csv')
valid_sessions_path = os.path.join(WORKSPACE_DIR, 'data_for_recstudio/task1_data/task13_4_task1_valid_sessions.csv')
test_sessions_path = os.path.join(WORKSPACE_DIR, 'raw_data/sessions_test_task1.csv')
sasrec_test_candidates_path = os.path.join(WORKSPACE_DIR, 'candidates/SASRec_Next/seperate_locale/SASRec_Next_04_27_20_07_test_100_with_score.parquet')
roberta_test_candidates_path = os.path.join(WORKSPACE_DIR, 'candidates/roberta/roberta_test_100_with_score.parquet')
product_data_path = os.path.join(WORKSPACE_DIR, 'data_for_recstudio/products_train.csv')

In [None]:
# Each reader is memoised with lru_cache(maxsize=1), so re-running the loading
# cell does not re-read the files. The cached DataFrame object itself is
# returned, so in-place changes made by a caller persist in the cache.

@lru_cache(maxsize=1)
def read_train_sessions():
    """Load the training sessions CSV at ``train_sessions_path``."""
    return pd.read_csv(train_sessions_path)

@lru_cache(maxsize=1)
def read_valid_sessions():
    """Load the validation sessions CSV at ``valid_sessions_path``."""
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_test_sessions():
    """Load the test sessions CSV at ``test_sessions_path``."""
    return pd.read_csv(test_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_test_candidates():
    """Load the SASRec test candidates parquet file."""
    return pd.read_parquet(sasrec_test_candidates_path, engine='pyarrow')

@lru_cache(maxsize=1)  # the '@' was missing, so this reader was never cached
def read_roberta_test_candidates():
    """Load the RoBERTa test candidates parquet file."""
    return pd.read_parquet(roberta_test_candidates_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_product_data():
    """Load the product catalogue CSV at ``product_data_path``."""
    return pd.read_csv(product_data_path)


In [None]:
# Load every input frame; the calls run in the same order as the readers are listed.
train_sessions, valid_sessions, test_sessions = (
    read_train_sessions(),
    read_valid_sessions(),
    read_test_sessions(),
)
sasrec_test_candidates, roberta_test_candidates = (
    read_sasrec_test_candidates(),
    read_roberta_test_candidates(),
)
product_data = read_product_data()

In [None]:
# Peek at the SASRec test candidates: per-row locale plus parallel 'scores' / 'candidates' lists (see output).
sasrec_test_candidates

Unnamed: 0,locale,scores,candidates
0,DE,"[19.93616485595703, 19.43468475341797, 18.6949...","[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM..."
1,DE,"[22.813570022583008, 18.525903701782227, 18.01...","[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B010MJNUZ..."
2,DE,"[14.396807670593262, 14.095083236694336, 13.71...","[B0B5QNFWJ1, B0BJF4KGCN, B099277D7Q, B0B5TFLBC..."
3,DE,"[18.627986907958984, 18.095304489135742, 16.81...","[395535086X, 3772476953, 3772477917, B0829LZFT..."
4,DE,"[19.423419952392578, 18.855445861816406, 18.45...","[B09J8SKX9G, B09J8V9RQQ, B09J8VPTTW, B09J8TWRV..."
...,...,...,...
316966,UK,"[20.52749252319336, 15.89537239074707, 15.7825...","[B08X9L5RGD, B09G9YY2C9, B09MW64JGM, B07V5FL8G..."
316967,UK,"[17.313642501831055, 17.087045669555664, 16.69...","[B0989BHLSY, B09895QPQF, B09CPNS7XV, B09L14HQF..."
316968,UK,"[22.398101806640625, 21.34939193725586, 19.468...","[B09HKZBNZH, B09HZSRJWW, B07PY1NG3X, B09HL141Q..."
316969,UK,"[16.046977996826172, 15.801373481750488, 15.09...","[B08FB464L7, B07TR5LQSL, B0BGDK1J1G, B00HEL380..."


In [None]:
# Positional session id. This assumes row i of the RoBERTa candidates is test session i
# (TODO confirm). Note: this adds a column to the frame returned by the reader, in place.
roberta_test_candidates['sess_id'] = np.arange(roberta_test_candidates.shape[0])
roberta_test_candidates

Unnamed: 0,scores,candidates,sess_id
0,"[269.218017578125, 268.97796630859375, 268.910...","[B07TV22X9M, B08Q391KS3, B07TV364MZ, B01H1R0K6...",0
1,"[268.6519775390625, 267.3247985839844, 266.088...","[B004ZXMV4Q, B09P1XWJPS, B00R9RNWF2, B010MJNUZ...",1
2,"[269.5508117675781, 269.29168701171875, 269.25...","[B017MIKTKS, B017KCNNFY, B08G4KS3TJ, B0B8Z2RZ7...",2
3,"[268.4931640625, 268.36383056640625, 268.12060...","[B0BJYJQS7T, B08P27T2D1, B07PLZ3WZQ, B0BJBYY1T...",3
4,"[267.1663513183594, 267.1663513183594, 267.166...","[B09J8TTZ68, B09J8V9RQQ, B09J8TWRV3, B09J8VPTT...",4
...,...,...,...
316966,"[266.6226806640625, 266.5784606933594, 266.571...","[B09MMGQ1C9, B08XX5GR3F, B07KG8T5W8, B09JWJC2V...",316966
316967,"[268.3211364746094, 268.18267822265625, 268.13...","[B09Y532DXH, B09Y55H96M, B0851KN668, B09Y551L4...",316967
316968,"[269.5409851074219, 269.5409851074219, 269.446...","[B07PY1NG3X, B06XPKPJFB, B09HZSRJWW, B09HKZBNZ...",316968
316969,"[269.934814453125, 269.577880859375, 269.47412...","[B07QQZD49B, B07TR5LQSL, B08FB464L7, B08FFH7FD...",316969


## Get co-occurrence-graph candidates

In [None]:
# Parse the string-encoded 'prev_items' of each split into Python lists. For train/valid,
# get_sessions appends 'next_item' when that column exists. Slow (~11 min total, see output).
train_sess_item = get_sessions(train_sessions, list_item=False)
valid_sess_item = get_sessions(valid_sessions, test=False, list_item=False)
test_sess_item = get_sessions(test_sessions, test=True, list_item=False)

100%|██████████| 3557898/3557898 [10:07<00:00, 5861.43it/s] 
100%|██████████| 361581/361581 [00:52<00:00, 6872.99it/s] 
100%|██████████| 316971/316971 [00:31<00:00, 9914.32it/s] 


In [None]:
# Unweighted, bidirectional co-occurrence counts over train + valid + test sessions.
# '+' is list concatenation, so each get_sessions call above must return a list.
co_occurence_dict_bi = get_co_occurence_dict(train_sess_item + valid_sess_item + test_sess_item, bidirection=True, weighted=False)

100%|██████████| 4236450/4236450 [02:59<00:00, 23616.68it/s]


In [None]:
# One row per test session; 'sess_id' is the positional index into test_sess_item.
# (TFDataset is the HuggingFace datasets.Dataset alias imported at the top.)
test_co_graph_candidates_dataset = TFDataset.from_dict({'sess_id' : list(range(len(test_sessions)))})

In [None]:
def get_test_session_co_graph_candidates(sess_id_example, sessions=None, co_occurrence=None, topk=100):
    """Rank co-occurrence candidates for one test session.

    Sums the co-occurrence Counters of every distinct item in the session. Each
    item contributes once, however often it repeats. Items already in the session
    history are then dropped, and the ``topk`` highest-scoring items are kept.

    Args:
        sess_id_example: mapping with key ``'sess_id'``, e.g. a row passed by
            ``datasets.Dataset.map``.
        sessions: item sequences indexed by sess_id. Defaults to the
            notebook-global ``test_sess_item``.
        co_occurrence: dict item -> Counter. Defaults to the notebook-global
            ``co_occurence_dict_bi``.
        topk: maximum number of candidates (default 100; None keeps all).

    Returns:
        ``{'co_graph_candidates': [...]}`` ordered by descending score. The list
        is empty when no item of the session has co-occurrence data.
    """
    if sessions is None:
        sessions = test_sess_item
    if co_occurrence is None:
        co_occurrence = co_occurence_dict_bi

    sess = sessions[sess_id_example['sess_id']]
    cand_counter = Counter()
    for item in dict.fromkeys(sess):  # distinct items, first-occurrence order
        neighbours = co_occurrence.get(item)
        if neighbours is not None:
            # In-place update. ``cand_counter + neighbours`` rebuilt the whole
            # Counter for every history item.
            cand_counter.update(neighbours)
    for item in sess:
        cand_counter.pop(item, None)  # remove history items
    return {'co_graph_candidates': [item for item, _ in cand_counter.most_common(topk)]}

In [None]:
# Takes about 1 minute with 8 processes.
# Silence the per-process progress bars during the parallel map. try/finally
# re-enables them even if the map raises.
# NOTE(review): set_progress_bar_enabled is deprecated in newer `datasets` releases in
# favour of datasets.disable_progress_bar()/enable_progress_bar(); verify against the
# installed version.
datasets.set_progress_bar_enabled(False)
try:
    test_co_graph_candidates_dataset = test_co_graph_candidates_dataset.map(
        get_test_session_co_graph_candidates, num_proc=8, batched=False
    )
finally:
    datasets.set_progress_bar_enabled(True)



In [None]:
# Pull the computed column out of the mapped dataset: one candidate list per row.
test_co_graph_candidates = test_co_graph_candidates_dataset['co_graph_candidates']

In [None]:
# Wrap the co-graph candidates in a frame with a 'candidates' column, one row per test
# session, so it can be passed to merge_candidates alongside the other sources.
co_graph_test_candidates = pd.DataFrame({'sess_id' : list(range(len(test_sessions))), 'candidates' : test_co_graph_candidates})
co_graph_test_candidates

Unnamed: 0,sess_id,candidates
0,0,"[B07JG9TFSB, B07TV22X9M, B07JDSHD4Z, B07TV364M..."
1,1,"[B004ZXMV4Q, B097HPKM63, B09M455RWT, B08NJP33W..."
2,2,"[B0B5QNFWJ1, B00C3TEAPM, B07YSRH5LG, B000KSLHN..."
3,3,"[B09KF2JR9F, B09KNCTB5F, B0829LZFT1, 377247791..."
4,4,"[B09J8SKX9G, B08HV7LW3Q, B09J8V9RQQ, B09J8VPTT..."
...,...,...
316966,316966,"[B01JPAJST4, B08X9L5RGD, B0B28ZPRQS, B09HZJT52..."
316967,316967,"[B0851KN668, B0989BHLSY, B09895QPQF, B09CPNS7X..."
316968,316968,"[B07PY1NG3X, B09HKZBNZH, B07Q1ZMNZQ, B09HZSRJW..."
316969,316969,"[B07QNRB5S6, B07QQZD49B, B07TR5LQSL, B082FT7L5..."


## Merge candidates

In [None]:
# Inspect the raw test sessions: 'prev_items' is the string form of a NumPy array (see output).
test_sessions

Unnamed: 0,prev_items,locale
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE
3,['B08KQBYV43' '3955350843' '3955350843' '39553...,DE
4,['B09FPTCWMC' 'B09FPTQP68' 'B08HMRY8NG' 'B08TB...,DE
...,...,...
316966,['B077SZ2C3Y' 'B0B14M3VZX'],UK
316967,['B08KFHDPY9' 'B0851KTSRZ' 'B08KFHDPY9' 'B0851...,UK
316968,['B07PY1N81F' 'B07Q1Z8SQN' 'B07PY1N81F' 'B07Q1...,UK
316969,['B01MCQMORK' 'B09JYZ325W'],UK


In [None]:
# Spot-check that history items were filtered out of every candidate source.
# Session 1's history is ['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'] (see test_sessions output);
# the ids checked for row 316970 are presumably from that session's history.
assert 'B00R9RZ9ZS' not in co_graph_test_candidates.iloc[1]['candidates']
assert 'B00R9RZ9ZS' not in sasrec_test_candidates.iloc[1]['candidates']
assert 'B00R9R5ND6' not in roberta_test_candidates.iloc[1]['candidates']  # was iloc[2], a different session
assert 'B0BG2LZQSL' not in co_graph_test_candidates.iloc[316970]['candidates']
assert 'B09TN4MP6V' not in sasrec_test_candidates.iloc[316970]['candidates']
assert 'B0B8JX92YJ' not in roberta_test_candidates.iloc[316970]['candidates']

In [None]:
# Union the SASRec, RoBERTa and co-graph candidates per test session.
# test=True: there is no next_item, so no 'target' column is produced.
merged_candidates_df = merge_candidates(test_sessions, [sasrec_test_candidates, roberta_test_candidates, co_graph_test_candidates], test=True)

In [None]:
# Downcast numeric columns in place (e.g. sess_id -> int32) to reduce memory and file size.
cast_dtype(merged_candidates_df)

In [None]:
# Persist the merged test candidates. Create the output directory first so a
# fresh checkout does not fail on a missing ./candidates folder.
MERGED_TEST_CANDIDATES_PATH = './candidates/merged_candidates_test_2.parquet'
os.makedirs(os.path.dirname(MERGED_TEST_CANDIDATES_PATH), exist_ok=True)
merged_candidates_df.to_parquet(MERGED_TEST_CANDIDATES_PATH, engine='pyarrow')

In [None]:
# NOTE(review): the output below looks stale. It shows a 'target' column although the merge
# above uses test=True, and sess_id reaches 361580 (the valid-set size), not the test-set
# size. Re-run the notebook to refresh it.
merged_candidates_df

Unnamed: 0,sess_id,sess_locale,product,target
0,0,UK,B0856JQ3WJ,0.0
1,0,UK,B07L3L4PQH,0.0
2,0,UK,B07L3LGRFH,0.0
3,0,UK,B08147T2RH,0.0
4,0,UK,B09TW5134P,0.0
...,...,...,...,...
87162772,361580,DE,B00816X8TU,0.0
87162773,361580,DE,B0187EEUHO,0.0
87162774,361580,DE,B084GPY76H,0.0
87162775,361580,DE,B07N8J3B19,0.0
