In [1]:
import pandas as pd 
import numpy as np
from functools import lru_cache
import os
from tqdm import tqdm 

In [2]:
# Project root. Override with the KDDCUP_ROOT environment variable to run this notebook
# on another machine; the default is the original author's checkout location.
KDDCUP_ROOT = os.environ.get('KDDCUP_ROOT', '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23')
recstudio_data_dir = os.path.join(KDDCUP_ROOT, 'data_for_recstudio')
raw_data_dir = os.path.join(KDDCUP_ROOT, 'raw_data')

In [3]:
# Candidate files produced by the two retrievers. The root can be overridden with the
# KDDCUP_ROOT environment variable; the default is the original author's checkout location.
KDDCUP_ROOT = os.environ.get('KDDCUP_ROOT', '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23')
# SASRec per-session candidates (the filename is the timestamp of the run that produced it).
sasrec_session_candidates_path = os.path.join(KDDCUP_ROOT, 'candidates', '2023-04-01-14-13-51.parquet')
# RetroMAE validation predictions.
retromae_candidates_path = os.path.join(
    KDDCUP_ROOT, 'RetroMAE', 'examples', 'retriever', 'kdd_cup', 'valid_results', 'valid_prediction.parquet'
)

In [14]:
@lru_cache(maxsize=1)
def read_sasrec_session_candidates():
    """Read the SASRec candidate table from ``sasrec_session_candidates_path``.

    Memoised via ``lru_cache``: the file is read on the first call and the very same
    DataFrame object is returned afterwards, so callers should not mutate it in place.
    """
    candidates_df = pd.read_parquet(sasrec_session_candidates_path, engine='pyarrow')
    return candidates_df

@lru_cache(maxsize=1)
def read_retromae_session_candidates():
    """Read the RetroMAE prediction table from ``retromae_candidates_path``.

    Memoised via ``lru_cache``; repeated calls return the same cached DataFrame object.
    """
    predictions_df = pd.read_parquet(retromae_candidates_path, engine='pyarrow')
    return predictions_df

@lru_cache(maxsize=1)
def read_valid_sessions():
    """Read ``all_task_1_valid_sessions.csv`` from ``recstudio_data_dir``.

    Memoised via ``lru_cache``; repeated calls return the same cached DataFrame object.
    """
    sessions_csv = os.path.join(recstudio_data_dir, 'all_task_1_valid_sessions.csv')
    return pd.read_csv(sessions_csv)

In [15]:
sasrec_session_candidates_df = read_sasrec_session_candidates()
# Peek: first rows, total row count, and the size of the first session's candidate list.
(
    sasrec_session_candidates_df.head(10),
    len(sasrec_session_candidates_df),
    len(sasrec_session_candidates_df['candidates'].iloc[0]),
)

(  locale                                         candidates sess_id
 0     UK  [B06XG1LZ6Z, B06XGD9VLV, B076PN1SKG, B01MYUDYP...       0
 1     JP  [B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W...       1
 2     UK  [B09XBS6WCX, B01EYGW86Y, B01MDOBUCC, B01C5YK17...       2
 3     UK  [0241572614, 1406392979, 024157563X, 024147681...       3
 4     JP  [B0B6PF619D, B0B6P77ZRN, B0B6P2PCMP, B0B6NY4PN...       4
 5     DE  [B0BD48G63Q, B0BD3DGNT9, B0B685KHK6, B0B54PQKG...       5
 6     DE  [B07Q82LRDK, B082XSF8XF, B00H37AVY8, B01EGP3LV...       6
 7     UK  [B013SL2712, B07V5LWSN5, B00CL6353A, B07W6JP97...       7
 8     JP  [B0BFPGHSYX, B09ZTN81QH, B00B57A5IY, B00LE7TO0...       8
 9     DE  [B07MH3K3S8, B07KQHHYQC, B00Z0BTBEA, B07YLZ67Q...       9,
 361581,
 300)

In [16]:
# Keep only UK sessions; the comparison cells below pair rows with the RetroMAE table by position.
sasrec_session_candidates_df = sasrec_session_candidates_df.loc[
    sasrec_session_candidates_df['locale'].eq('UK')
]
len(sasrec_session_candidates_df)

130665

In [19]:
retromae_session_candidates_df = read_retromae_session_candidates()
# Peek: first rows, total row count, and the size of the first session's prediction list.
(
    retromae_session_candidates_df.head(10),
    len(retromae_session_candidates_df),
    len(retromae_session_candidates_df['next_item_prediction'].iloc[0]),
)

(  locale                               next_item_prediction
 0     UK  [B06XGD9VLV, B06XG1LZ6Z, B06XGDZVZR, B0BJQ95XW...
 1     UK  [B00JLH125A, B09GGGFKFR, B015DVC6G6, B017A0OY8...
 2     UK  [1406357383, 1406363073, B0BGY899Q5, B074G22VZ...
 3     UK  [B013SL2712, B07W6JP976, B07W6JJ253, B07W7LHVB...
 4     UK  [B0B7XCS3QT, B077P3D31Q, B098JJLSWX, B08PYWFGB...
 5     UK  [1529176913, 1399702742, B07VP5BKSW, B095J8R9Y...
 6     UK  [B00IG32NC6, B01JAQFAZE, B01JLTVHNY, B01JAQEW9...
 7     UK  [B09QKT69MR, B0B1JHRTQY, B08PPDTH9L, B0B9FHP56...
 8     UK  [B08YL9D3B1, B00MAIO9MA, B008MYJ57K, B00O9RYS7...
 9     UK  [B09PRQS53M, B09S11KJ9W, B09BPWHRYQ, B0918LLKV...,
 130665,
 300)

In [17]:
# Ground-truth validation sessions, restricted to the UK locale.
valid_sessions_df = read_valid_sessions()
valid_sessions_df = valid_sessions_df.loc[valid_sessions_df['locale'].eq('UK')]
len(valid_sessions_df)

130665

In [22]:
def co_currence(length, candidates_list_1=None, candidates_list_2=None):
    """Mean top-``length`` overlap rate between two position-aligned candidate Series.

    For each row, the first ``length`` items of both candidate sequences are turned into
    sets and intersected; the intersection size divided by ``length`` is that row's rate.
    The mean over all rows is returned.

    Args:
        length: Cutoff k. Must be positive. The denominator is always ``length``, even if
            a row holds fewer than ``length`` items.
        candidates_list_1: Series of per-session candidate sequences. Defaults to the
            notebook global ``retromae_session_candidates_df['next_item_prediction']``.
        candidates_list_2: Series of per-session candidate sequences. Defaults to the
            notebook global ``sasrec_session_candidates_df['candidates']``.

    Returns:
        Mean overlap rate as a float; NaN when there are no rows.

    Raises:
        ValueError: If ``length`` is not positive or the two Series differ in length.
    """
    if length <= 0:
        raise ValueError(f'length must be positive, got {length}')
    if candidates_list_1 is None:
        candidates_list_1 = retromae_session_candidates_df['next_item_prediction']
    if candidates_list_2 is None:
        candidates_list_2 = sasrec_session_candidates_df['candidates']
    # Rows are paired purely by position, so both Series must describe the same sessions
    # in the same order. Iterating a third, unrelated frame's length would hide a mismatch.
    if len(candidates_list_1) != len(candidates_list_2):
        raise ValueError(
            f'candidate Series differ in length: {len(candidates_list_1)} vs {len(candidates_list_2)}'
        )
    co_currence_rates = []
    pairs = zip(candidates_list_1, candidates_list_2)
    for cands_1, cands_2 in tqdm(pairs, total=len(candidates_list_1)):
        overlap = set(cands_1[:length]) & set(cands_2[:length])
        co_currence_rates.append(len(overlap) / length)
    if not co_currence_rates:
        return float('nan')
    return np.array(co_currence_rates).mean()

In [23]:
# Mean top-k overlap between the RetroMAE and SASRec candidate lists at three cutoffs.
co_currence_rate_300 = co_currence(length=300)
co_currence_rate_100 = co_currence(length=100)
co_currence_rate_10 = co_currence(length=10)
(co_currence_rate_300, co_currence_rate_100, co_currence_rate_10)

100%|██████████| 130665/130665 [00:24<00:00, 5361.93it/s]
100%|██████████| 130665/130665 [00:16<00:00, 7798.97it/s]
100%|██████████| 130665/130665 [00:12<00:00, 10662.42it/s]


(0.2960451026161047, 0.326827153407569, 0.2968859296674703)

In [10]:
def cal_hit_and_mrr(ground_truth_list, candidates_list):
    """Hit ratio and mean reciprocal rank of a ground-truth item within ranked candidates.

    ``ground_truth_list`` and ``candidates_list`` are paired row by row via ``.iloc``.
    For each row, the 1-based rank of the first candidate equal to the ground truth is
    found. A row scores hit=1.0 and rr=1/rank if found, and 0.0 for both otherwise.

    Returns:
        ``(hit_ratio, mrr)``: the means of the per-row hit flags and reciprocal ranks.
    """
    hit_flags = []
    reciprocal_ranks = []
    for row in tqdm(range(len(ground_truth_list))):
        target = ground_truth_list.iloc[row]
        ranked = candidates_list.iloc[row]
        first_rank = next(
            (pos for pos, item in enumerate(ranked, start=1) if target == item),
            None,
        )
        if first_rank is None:
            hit_flags.append(0.0)
            reciprocal_ranks.append(0.0)
        else:
            hit_flags.append(1.0)
            reciprocal_ranks.append(1.0 / first_rank)
    return np.mean(hit_flags), np.mean(reciprocal_ranks)

In [35]:
# RetroMAE: hit ratio and MRR of the true next item within its candidate list.
retromae_hit_ratio, retromae_mrr = cal_hit_and_mrr(
    valid_sessions_df['next_item'],
    retromae_session_candidates_df['next_item_prediction'],
)
(retromae_hit_ratio, retromae_mrr)

100%|██████████| 130665/130665 [00:03<00:00, 32683.51it/s]


(0.7314047373053227, 0.24190486057082744)

In [18]:
# SASRec: hit ratio and MRR of the true next item within its candidate list.
sasrec_hit_ratio, sasrec_mrr = cal_hit_and_mrr(
    valid_sessions_df['next_item'],
    sasrec_session_candidates_df['candidates'],
)
(sasrec_hit_ratio, sasrec_mrr)

100%|██████████| 130665/130665 [00:04<00:00, 26840.75it/s]


(0.6061454865495733, 0.26247991444563873)

In [28]:
# Union hit rate: the target is found by at least one of the two retrievers.
def cal_or_hit(ground_truth_list, candidates_list_1, candidates_list_2):
    """Fraction of rows whose ground-truth item occurs in EITHER candidate list.

    The three Series are paired by position via ``.iloc``; the loop covers
    ``len(ground_truth_list)`` rows.

    Returns:
        Mean of the per-row 0.0/1.0 hit flags.
    """
    outcomes = []
    for row in tqdm(range(len(ground_truth_list))):
        target = ground_truth_list.iloc[row]
        pools = (candidates_list_1.iloc[row], candidates_list_2.iloc[row])
        outcomes.append(1.0 if any(target in pool for pool in pools) else 0.0)
    return np.mean(outcomes)

In [29]:
# Share of UK validation sessions whose next item is retrieved by RetroMAE or SASRec.
or_hit = cal_or_hit(
    valid_sessions_df['next_item'],
    retromae_session_candidates_df['next_item_prediction'],
    sasrec_session_candidates_df['candidates'],
)
or_hit

100%|██████████| 130665/130665 [00:06<00:00, 20107.39it/s]


0.7527417441548999

In [30]:
# Intersection hit rate: the target is found by both retrievers at once.
def cal_co_hit(ground_truth_list, candidates_list_1, candidates_list_2):
    """Fraction of rows whose ground-truth item occurs in BOTH candidate lists.

    The three Series are paired by position via ``.iloc``; the loop covers
    ``len(ground_truth_list)`` rows.

    Returns:
        Mean of the per-row 0.0/1.0 hit flags.
    """
    flags = []
    for row in tqdm(range(len(ground_truth_list))):
        target = ground_truth_list.iloc[row]
        first_pool = candidates_list_1.iloc[row]
        second_pool = candidates_list_2.iloc[row]
        found_in_both = all(target in pool for pool in (first_pool, second_pool))
        flags.append(1.0 if found_in_both else 0.0)
    return np.mean(flags)

In [32]:
# Share of UK validation sessions whose next item is retrieved by both RetroMAE and SASRec.
co_hit = cal_co_hit(
    valid_sessions_df['next_item'],
    retromae_session_candidates_df['next_item_prediction'],
    sasrec_session_candidates_df['candidates'],
)
co_hit

100%|██████████| 130665/130665 [00:08<00:00, 16092.94it/s]


0.5848084796999962