In [3]:
import pandas as pd 
import numpy as np
from functools import lru_cache
import os
from tqdm import tqdm 

In [2]:
# Directory that holds the RecStudio-formatted data; the validation-session
# CSV used in this notebook is read from here.
recstudio_data_dir = (
    '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23'
    '/data_for_recstudio'
)
# Directory of the raw data (defined for completeness; check later cells for its use).
raw_data_dir = (
    '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23'
    '/raw_data'
)

In [1]:
# Parquet file with the SASRec candidate lists.
sasrec_session_candidates_path = (
    '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23'
    '/candidates/2023-04-01-14-13-51.parquet'
)
# Parquet files with the RoBERTa predictions (the "300" and "150" variants).
roberta_candidates_300_path = (
    '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23'
    '/text_method/valid_results_epoch_4'
    '/roberta_prediction_300_filtered.parquet'
)
roberta_candidates_150_path = (
    '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23'
    '/text_method/valid_results_epoch_4'
    '/roberta_prediction_150_filtered.parquet'
)

In [4]:
@lru_cache(maxsize=1)
def read_sasrec_session_candidates():
    """Read the SASRec candidate parquet file into a DataFrame.

    The file at ``sasrec_session_candidates_path`` is parsed with the pyarrow
    engine. The ``lru_cache(maxsize=1)`` decorator memoises the result, so the
    file is read on the first call only and later calls return that same
    DataFrame object (in-place edits are therefore shared by all callers).

    Returns:
        pandas.DataFrame: the contents of the parquet file.
    """
    sasrec_frame = pd.read_parquet(sasrec_session_candidates_path, engine='pyarrow')
    return sasrec_frame

@lru_cache(maxsize=1)
def read_roberta_session_candidates_300():
    """Read the "300" RoBERTa prediction parquet file into a DataFrame.

    The file at ``roberta_candidates_300_path`` is parsed with the pyarrow
    engine. Thanks to ``lru_cache(maxsize=1)`` the read happens once; every
    further call returns the same cached DataFrame object.

    Returns:
        pandas.DataFrame: the contents of the parquet file.
    """
    roberta_300_frame = pd.read_parquet(roberta_candidates_300_path, engine='pyarrow')
    return roberta_300_frame

@lru_cache(maxsize=1)
def read_roberta_session_candidates_150():
    """Read the "150" RoBERTa prediction parquet file into a DataFrame.

    The file at ``roberta_candidates_150_path`` is parsed with the pyarrow
    engine. Thanks to ``lru_cache(maxsize=1)`` the read happens once; every
    further call returns the same cached DataFrame object.

    Returns:
        pandas.DataFrame: the contents of the parquet file.
    """
    roberta_150_frame = pd.read_parquet(roberta_candidates_150_path, engine='pyarrow')
    return roberta_150_frame

@lru_cache(maxsize=1)
def read_valid_sessions():
    """Read ``all_task_1_valid_sessions.csv`` from ``recstudio_data_dir``.

    The result is memoised by ``lru_cache(maxsize=1)``: the CSV is parsed on
    the first call and the same DataFrame object is returned afterwards.

    Returns:
        pandas.DataFrame: the parsed CSV.
    """
    valid_sessions_csv = os.path.join(recstudio_data_dir, 'all_task_1_valid_sessions.csv')
    return pd.read_csv(valid_sessions_csv)

In [5]:
# Load the SASRec candidates, then preview: first ten rows, number of rows,
# and how many candidates the first row carries.
sasrec_session_candidates_df = read_sasrec_session_candidates()
(
    sasrec_session_candidates_df.head(10),
    len(sasrec_session_candidates_df),
    len(sasrec_session_candidates_df.iloc[0]['candidates']),
)

(  locale                                         candidates sess_id
 0     UK  [B06XG1LZ6Z, B06XGD9VLV, B076PN1SKG, B01MYUDYP...       0
 1     JP  [B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W...       1
 2     UK  [B09XBS6WCX, B01EYGW86Y, B01MDOBUCC, B01C5YK17...       2
 3     UK  [0241572614, 1406392979, 024157563X, 024147681...       3
 4     JP  [B0B6PF619D, B0B6P77ZRN, B0B6P2PCMP, B0B6NY4PN...       4
 5     DE  [B0BD48G63Q, B0BD3DGNT9, B0B685KHK6, B0B54PQKG...       5
 6     DE  [B07Q82LRDK, B082XSF8XF, B00H37AVY8, B01EGP3LV...       6
 7     UK  [B013SL2712, B07V5LWSN5, B00CL6353A, B07W6JP97...       7
 8     JP  [B0BFPGHSYX, B09ZTN81QH, B00B57A5IY, B00LE7TO0...       8
 9     DE  [B07MH3K3S8, B07KQHHYQC, B00Z0BTBEA, B07YLZ67Q...       9,
 361581,
 300)

In [6]:
# Load the "300" RoBERTa predictions, then preview the first ten rows and the
# number of rows.
roberta_session_candidates_300_df = read_roberta_session_candidates_300()
(
    roberta_session_candidates_300_df.head(10),
    len(roberta_session_candidates_300_df),
)

(                                next_item_prediction  \
 0  [B096ZT4DK4, B09XCGNN6F, B09XCHDLYR, B06XGD9VL...   
 1  [B08ZHJKF28, B09LCPT9DQ, B09MRYK5CV, B09G9YRX1...   
 2  [B06ZYHMRST, B01017JLP6, B007CW985C, B0017RTPK...   
 3  [B01M5EBAO4, B07TYFS1K7, B08BX71RLT, B0BGY899Q...   
 4  [B0B6NXJMJJ, B0B6PF619D, B0B6P77ZRN, B0B6NRKVZ...   
 5  [B0BD48G63Q, B0BD3DGNT9, B0BD3FYDPB, B09BJNXGP...   
 6  [B00GYSC172, B088NQ34PW, B06ZZ6ZCD6, B07CNV566...   
 7  [B07V3HG2YD, B0B8NKMT5S, B07XRFMG5B, B019O1R4D...   
 8  [B0BFPGHSYX, B005A0PEQ0, B082FHGR57, B0152C7KP...   
 9  [B096DSVNSN, B096D4C8HT, B07MH3K3S8, B00Z0BTBE...   
 
                                               scores  
 0  [268.6368103027344, 268.5672607421875, 268.567...  
 1  [269.99102783203125, 269.4064636230469, 269.30...  
 2  [264.7541198730469, 264.58624267578125, 264.49...  
 3  [267.71771240234375, 266.8636779785156, 266.69...  
 4  [269.4986572265625, 269.49267578125, 269.47979...  
 5  [268.35125732421875, 267.995635

In [7]:
# Load the "150" RoBERTa predictions, then preview: first ten rows, number of
# rows, and how many predictions the first row carries.
roberta_session_candidates_150_df = read_roberta_session_candidates_150()
(
    roberta_session_candidates_150_df.head(10),
    len(roberta_session_candidates_150_df),
    len(roberta_session_candidates_150_df.iloc[0]['next_item_prediction']),
)

(                                next_item_prediction  \
 0  [B096ZT4DK4, B09XCGNN6F, B09XCHDLYR, B06XGD9VL...   
 1  [B08ZHJKF28, B09LCPT9DQ, B09MRYK5CV, B09G9YRX1...   
 2  [B06ZYHMRST, B01017JLP6, B007CW985C, B0017RTPK...   
 3  [B01M5EBAO4, B07TYFS1K7, B08BX71RLT, B0BGY899Q...   
 4  [B0B6NXJMJJ, B0B6PF619D, B0B6P77ZRN, B0B6NRKVZ...   
 5  [B0BD48G63Q, B0BD3DGNT9, B0BD3FYDPB, B09BJNXGP...   
 6  [B00GYSC172, B088NQ34PW, B06ZZ6ZCD6, B07CNV566...   
 7  [B07V3HG2YD, B0B8NKMT5S, B07XRFMG5B, B019O1R4D...   
 8  [B0BFPGHSYX, B005A0PEQ0, B082FHGR57, B0152C7KP...   
 9  [B096DSVNSN, B096D4C8HT, B07MH3K3S8, B00Z0BTBE...   
 
                                               scores  
 0  [268.6368103027344, 268.5672607421875, 268.567...  
 1  [269.99102783203125, 269.4064636230469, 269.30...  
 2  [264.7541198730469, 264.58624267578125, 264.49...  
 3  [267.71771240234375, 266.8636779785156, 266.69...  
 4  [269.4986572265625, 269.49267578125, 269.47979...  
 5  [268.35125732421875, 267.995635

In [8]:
# Load the validation sessions and show how many rows there are.
valid_sessions_df = read_valid_sessions()
valid_sessions_df.shape[0]

361581

In [10]:
def co_currence(length):
    """Mean top-``length`` overlap between the RoBERTa and SASRec candidate lists.

    For every position ``i`` in ``range(len(valid_sessions_df))`` the first
    ``length`` entries of
    ``roberta_session_candidates_300_df['next_item_prediction']`` and of
    ``sasrec_session_candidates_df['candidates']`` are converted to sets. The
    per-session rate is the size of their intersection divided by ``length``.
    The divisor is always ``length``, even when a list holds fewer than
    ``length`` distinct items.

    The three frames are read from module scope and are matched by row
    position, not by any session id.

    Args:
        length: Positive cut-off applied to both candidate lists.

    Returns:
        The mean of the per-session rates as a numpy float.

    Raises:
        ValueError: If ``length`` is not positive (the rate would otherwise be
            undefined or meaningless).
    """
    if length <= 0:
        raise ValueError(f'length must be a positive integer, got {length!r}')

    num_sessions = len(valid_sessions_df)
    # Pull each column out once. Indexing the DataFrame with .iloc[i][col]
    # inside the loop builds a complete row Series on every iteration.
    roberta_lists = roberta_session_candidates_300_df['next_item_prediction'].to_numpy()
    sasrec_lists = sasrec_session_candidates_df['candidates'].to_numpy()

    co_currence_rates = []
    for i in tqdm(range(num_sessions)):
        roberta_set = set(roberta_lists[i][:length])
        sasrec_set = set(sasrec_lists[i][:length])
        co_currence_rates.append(len(roberta_set & sasrec_set) / length)
    return np.array(co_currence_rates).mean()

In [11]:
# Mean overlap between the two candidate sources at cut-offs 300, 100 and 10.
co_currence_rate_300 = co_currence(length=300)
co_currence_rate_100 = co_currence(length=100)
co_currence_rate_10 = co_currence(length=10)
(
    co_currence_rate_300,
    co_currence_rate_100,
    co_currence_rate_10,
)

100%|██████████| 361581/361581 [01:00<00:00, 6009.64it/s]
100%|██████████| 361581/361581 [00:45<00:00, 7878.77it/s]
100%|██████████| 361581/361581 [00:33<00:00, 10718.18it/s]


(0.16374873126629994, 0.25498864708046054, 0.1887048821702468)

In [9]:
def cal_hit_and_mrr(ground_truth_list, candidates_list):
    """Compute the mean hit ratio and mean reciprocal rank of candidate lists.

    Rows are paired by position via ``.iloc``. For each row the candidates are
    scanned in order; the first entry equal to the ground truth gives a hit of
    1.0 and a reciprocal rank of ``1 / rank`` (rank counted from 1). A row with
    no match contributes 0.0 to both metrics.

    Args:
        ground_truth_list: Positionally indexable (``.iloc``) collection with
            one ground-truth item per row.
        candidates_list: Positionally indexable (``.iloc``) collection with one
            ordered candidate sequence per row.

    Returns:
        Tuple ``(mean_hit, mean_mrr)`` over all rows.
    """
    hit_flags, reciprocal_ranks = [], []
    for row in tqdm(range(len(ground_truth_list))):
        target = ground_truth_list.iloc[row]
        ranked_items = candidates_list.iloc[row]
        # 1-based rank of the first candidate equal to the target, else None.
        first_rank = next(
            (rank for rank, item in enumerate(ranked_items, start=1) if target == item),
            None,
        )
        if first_rank is None:
            hit_flags.append(0.0)
            reciprocal_ranks.append(0.0)
        else:
            hit_flags.append(1.0)
            reciprocal_ranks.append(1.0 / first_rank)
    return np.array(hit_flags).mean(), np.array(reciprocal_ranks).mean()

In [10]:
# Hit ratio and MRR of the "300" RoBERTa candidate frame.
# NOTE(review): this comment used to say "retromae", but the data comes from
# the RoBERTa prediction frame -- confirm which text model produced the file.
roberta_hit_ratio, roberta_mrr = cal_hit_and_mrr(
    ground_truth_list=valid_sessions_df['next_item'],
    candidates_list=roberta_session_candidates_300_df['next_item_prediction'],
)
roberta_hit_ratio, roberta_mrr

100%|██████████| 361581/361581 [00:13<00:00, 26763.77it/s]


(0.6903460082249897, 0.1774922584571136)

In [12]:
# Hit ratio and MRR of the "150" RoBERTa candidate frame.
# Note: this rebinds roberta_hit_ratio / roberta_mrr, replacing whatever was
# stored under those names before.
roberta_hit_ratio, roberta_mrr = cal_hit_and_mrr(
    ground_truth_list=valid_sessions_df['next_item'],
    candidates_list=roberta_session_candidates_150_df['next_item_prediction'],
)
roberta_hit_ratio, roberta_mrr

100%|██████████| 361581/361581 [00:10<00:00, 36047.67it/s]


(0.6325498297753477, 0.17721058083632896)

In [11]:
# Hit ratio and MRR of the SASRec candidate lists.
sasrec_hit_ratio, sasrec_mrr = cal_hit_and_mrr(
    ground_truth_list=valid_sessions_df['next_item'],
    candidates_list=sasrec_session_candidates_df['candidates'],
)
sasrec_hit_ratio, sasrec_mrr

100%|██████████| 361581/361581 [00:14<00:00, 24144.01it/s]


(0.6216615364192256, 0.27526101444400336)

In [13]:
# Hit by either candidate source
def cal_or_hit(ground_truth_list, candidates_list_1, candidates_list_2):
    """Fraction of rows whose ground truth is in at least one candidate list.

    Rows are paired by position via ``.iloc``. A row scores 1.0 when the
    ground-truth item is contained in the first or the second candidate
    collection, and 0.0 otherwise.

    Args:
        ground_truth_list: Positionally indexable collection of ground-truth items.
        candidates_list_1: Positionally indexable collection of candidate
            containers (first source).
        candidates_list_2: Positionally indexable collection of candidate
            containers (second source).

    Returns:
        Mean of the per-row scores.
    """
    union_hits = []
    for row in tqdm(range(len(ground_truth_list))):
        target = ground_truth_list.iloc[row]
        first_pool = candidates_list_1.iloc[row]
        second_pool = candidates_list_2.iloc[row]
        # De Morgan form of "in first_pool or in second_pool".
        missed_everywhere = (target not in first_pool) and (target not in second_pool)
        union_hits.append(0.0 if missed_everywhere else 1.0)
    return np.array(union_hits).mean()

In [14]:
# Share of sessions hit by the "300" RoBERTa frame or by SASRec.
or_hit = cal_or_hit(
    ground_truth_list=valid_sessions_df['next_item'],
    candidates_list_1=roberta_session_candidates_300_df['next_item_prediction'],
    candidates_list_2=sasrec_session_candidates_df['candidates'],
)
or_hit

100%|██████████| 361581/361581 [00:21<00:00, 17173.59it/s]


0.7468423396140838

In [15]:
# Share of sessions hit by the "150" RoBERTa frame or by SASRec.
# Note: this rebinds or_hit, replacing whatever was stored under that name before.
or_hit = cal_or_hit(
    ground_truth_list=valid_sessions_df['next_item'],
    candidates_list_1=roberta_session_candidates_150_df['next_item_prediction'],
    candidates_list_2=sasrec_session_candidates_df['candidates'],
)
or_hit

100%|██████████| 361581/361581 [00:17<00:00, 20479.41it/s]


0.7286085275498436

In [17]:
# Hit by both candidate sources at once
def cal_co_hit(ground_truth_list, candidates_list_1, candidates_list_2):
    """Fraction of rows whose ground truth is in both candidate lists.

    Rows are paired by position via ``.iloc``. A row scores 1.0 only when the
    ground-truth item is contained in the first and in the second candidate
    collection; otherwise it scores 0.0.

    Args:
        ground_truth_list: Positionally indexable collection of ground-truth items.
        candidates_list_1: Positionally indexable collection of candidate
            containers (first source).
        candidates_list_2: Positionally indexable collection of candidate
            containers (second source).

    Returns:
        Mean of the per-row scores.
    """
    joint_hits = []
    for row in tqdm(range(len(ground_truth_list))):
        target = ground_truth_list.iloc[row]
        first_pool = candidates_list_1.iloc[row]
        second_pool = candidates_list_2.iloc[row]
        found_in_both = (target in first_pool) and (target in second_pool)
        joint_hits.append(1.0 if found_in_both else 0.0)
    return np.array(joint_hits).mean()

In [18]:
# Share of sessions hit by both the "300" RoBERTa frame and SASRec.
co_hit = cal_co_hit(
    ground_truth_list=valid_sessions_df['next_item'],
    candidates_list_1=roberta_session_candidates_300_df['next_item_prediction'],
    candidates_list_2=sasrec_session_candidates_df['candidates'],
)
co_hit

100%|██████████| 361581/361581 [00:24<00:00, 14637.85it/s]


0.5651652050301316