# Public functions

In [38]:
import pandas as pd 
import numpy as np
from functools import lru_cache
import os
from tqdm import tqdm 
import torch
import copy 
from collections import Counter

In [60]:
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/all_task_1_valid_sessions.csv'
sasrec_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/kdd_cup_2023/2023-04-12-14-21-03.parquet'
sasrec_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023/2023-04-12-23-41-31.parquet'
roberta_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/roberta/roberta_prediction_300.parquet'
roberta_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/roberta/roberta_test_prediction_300.parquet'

In [61]:
@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_valid_candidates():
    return pd.read_parquet(sasrec_valid_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_sasrec_test_candidates():
    return pd.read_parquet(sasrec_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_valid_candidates():
    return pd.read_parquet(roberta_valid_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_test_candidates():
    return pd.read_parquet(roberta_test_candidates_path, engine='pyarrow')

In [43]:
def cut_candidates(candidates_df, cut_len):
    new_candidates_df = copy.deepcopy(candidates_df)
    new_candidates_df['next_item_prediction'] = candidates_df['next_item_prediction'].apply(lambda x : x[:cut_len])
    new_candidates_df['scores'] = candidates_df['scores'].apply(lambda x : x[:cut_len])
    print(len(new_candidates_df.iloc[0]['next_item_prediction']), len(new_candidates_df.iloc[0]['scores']))
    return new_candidates_df

In [79]:
valid_sessions = read_valid_sessions()
valid_sessions.head(2)

Unnamed: 0,prev_items,next_item,locale
0,['B09VSN9GLS' 'B09VSG9DCG' 'B0BJ5L1ZPH' 'B09VS...,B06XG1LZ6Z,UK
1,['B00390YWXE' 'B00390YWXE' 'B09WM9W6WQ'],B01MSUI4FE,JP


In [18]:
roberta_valid_candidates = read_roberta_valid_candidates()
roberta_valid_candidates.head(2)

Unnamed: 0,next_item_prediction,scores
0,"[B084ZPL21L, B0B2JT2BND, B0B2JSKN7X, B08F7P8R6...","[266.0794982910156, 265.99432373046875, 265.99..."
1,"[B09WM9W6WQ, B09LCPT9DQ, B09MRYK5CV, B09LCRNQT...","[268.8392639160156, 268.33990478515625, 268.14..."


In [62]:
roberta_valid_candidates_100 = copy.deepcopy(roberta_valid_candidates)
roberta_valid_candidates_100['next_item_prediction'] = roberta_valid_candidates['next_item_prediction'].apply(lambda x : x[:100])
roberta_valid_candidates_100['scores'] = roberta_valid_candidates['scores'].apply(lambda x : x[:100])
len(roberta_valid_candidates_100.iloc[0]['next_item_prediction']), len(roberta_valid_candidates_100.iloc[0]['scores'])

(100, 100)

In [5]:
def normalize_score(candidates_df):
    for i in tqdm(range(len(candidates_df))):
        scores = candidates_df.iloc[i]['scores'] - candidates_df.iloc[i]['scores'].max()
        scores = torch.nn.functional.softmax(torch.from_numpy(scores))
        candidates_df.at[i, 'scores'] = scores.numpy()

In [85]:
for i in tqdm(range(len(roberta_valid_candidates_100))):
    scores = roberta_valid_candidates_100.iloc[i]['scores'] - roberta_valid_candidates_100.iloc[i]['scores'].max()
    scores = torch.nn.functional.softmax(torch.from_numpy(scores))
    roberta_valid_candidates_100.at[i, 'scores'] = scores.numpy()
roberta_valid_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 361581/361581 [00:53<00:00, 6734.73it/s]


Unnamed: 0,next_item_prediction,scores
0,"[B084ZPL21L, B0B2JT2BND, B0B2JSKN7X, B08F7P8R6...","[0.010051003401355616, 0.010038628362893742, 0..."
1,"[B09WM9W6WQ, B09LCPT9DQ, B09MRYK5CV, B09LCRNQT...","[0.010772150929588447, 0.010420460807982517, 0..."


In [96]:
for i in tqdm(range(len(sasrec_valid_candidates_100))):
    scores = sasrec_valid_candidates_100.iloc[i]['score'] - sasrec_valid_candidates_100.iloc[i]['score'].max()
    scores = torch.nn.functional.softmax(torch.from_numpy(scores))
    sasrec_valid_candidates_100.at[i, 'score'] = scores.numpy()
sasrec_valid_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 361581/361581 [00:55<00:00, 6470.64it/s]


Unnamed: 0,locale,candidates,sess_id,score
0,UK,"[B06XG1LZ6Z, B06XGD9VLV, B076PN1SKG, B01MYUDYP...",0,"[0.2712954212156488, 0.16770452100600358, 0.16..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W...",1,"[0.9104351964115006, 0.019139557601848656, 0.0..."


In [98]:
sasrec_valid_candidates_100.rename(columns={'score' : 'scores', 'candidates' : 'next_item_prediction'}, inplace=True)
sasrec_valid_candidates_100

Unnamed: 0,locale,next_item_prediction,sess_id,scores
0,UK,"[B06XG1LZ6Z, B06XGD9VLV, B076PN1SKG, B01MYUDYP...",0,"[0.2712954212156488, 0.16770452100600358, 0.16..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W...",1,"[0.9104351964115006, 0.019139557601848656, 0.0..."
2,UK,"[B09XBS6WCX, B01EYGW86Y, B01MDOBUCC, B01C5YK17...",2,"[0.5189365862585792, 0.11578885764040067, 0.08..."
3,UK,"[0241572614, 1406392979, 024157563X, 024147681...",3,"[0.17436096931167835, 0.1062810279624873, 0.09..."
4,JP,"[B0B6PF619D, B0B6P77ZRN, B0B6P2PCMP, B0B6NY4PN...",4,"[0.9418506919421911, 0.03955500319937586, 0.00..."
...,...,...,...,...
361576,UK,"[B0050IG9DE, B0B7LVKNK8, B07ZCWWZSM, B00465F49...",361576,"[0.3216314191614029, 0.03237086799685645, 0.03..."
361577,JP,"[B09B9V4PXC, B09BCM5NL1, B09XGRXXG3, B09XH1YGL...",361577,"[0.32195132928283726, 0.2734383066946782, 0.10..."
361578,DE,"[B0BC38GHB4, B07KLCY8NF, B00MXZEMBI, B07K6LTLW...",361578,"[0.5633314956366812, 0.27870872436215477, 0.06..."
361579,DE,"[B08RQR2NPB, B08RQDVX71, B08H8SYLMQ, B08H8TLK4...",361579,"[0.4668851399227642, 0.2816610932504468, 0.078..."


In [6]:
def merge_candidates(candidates_df_1, candidates_df_2, lam=5):
    new_candidates_list = []
    locale_list = []
    item_counter = Counter()
    for i in tqdm(range(len(candidates_df_1))):
        item_counter.clear()
        candidates_1 = candidates_df_1.iloc[i]['next_item_prediction']
        scores_1 = candidates_df_1.iloc[i]['scores']
        candidates_2 = candidates_df_2.iloc[i]['next_item_prediction']
        scores_2 = candidates_df_2.iloc[i]['scores']

        for can, score in zip(candidates_1, scores_1):
            item_counter[can] += score
        for can, score in zip(candidates_2, scores_2):
            item_counter[can] += lam * score
        
        new_candidates, _ = zip(*item_counter.most_common(100))

        new_candidates_list.append(list(new_candidates))
        locale_list.append(candidates_df_1.iloc[i]['locale'])

    merged_df = pd.DataFrame({'locale' : locale_list, 'next_item_prediction' : new_candidates_list})
    return merged_df

In [116]:
merged_df = merge_candidates(sasrec_valid_candidates_100, roberta_valid_candidates_100)
merged_df.head(2)

100%|██████████| 361581/361581 [05:30<00:00, 1092.68it/s]


Unnamed: 0,locale,next_item_prediction
0,UK,"[B06XG1LZ6Z, B06XGD9VLV, B01MYUDYP7, B076PN1SK..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W..."


In [125]:
merged_10_df = merge_candidates(sasrec_valid_candidates_100, roberta_valid_candidates_100, lam=10)
merged_10_df.head(2)

100%|██████████| 361581/361581 [05:32<00:00, 1085.94it/s]


Unnamed: 0,locale,next_item_prediction
0,UK,"[B06XG1LZ6Z, B06XGD9VLV, B01MYUDYP7, B076PN1SK..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W..."


In [13]:
merged_2_df = merge_candidates(sasrec_valid_candidates_100, roberta_valid_candidates_100, lam=2)
merged_2_df.head(2)

NameError: name 'sasrec_valid_candidates_100' is not defined

In [126]:
len(merged_10_df.iloc[0]['next_item_prediction'])

100

In [118]:
len(merged_df.iloc[0]['next_item_prediction'])

100

In [117]:
def cal_hit_and_mrr(ground_truth_list, candidates_list):
    hits, mrrs = [], []
    for i in tqdm(range(len(ground_truth_list))):
        ground_truth = ground_truth_list.iloc[i]
        candidates = candidates_list.iloc[i]
        hit, mrr = 0.0, 0.0
        for j in range(len(candidates)):
            if ground_truth == candidates[j]:
                hit = 1.0
                mrr = 1.0 / (j + 1)
                break
        hits.append(hit)
        mrrs.append(mrr)
    return np.array(hits).mean(), np.array(mrrs).mean()

In [121]:
cal_hit_and_mrr(valid_sessions['next_item'], merged_df['next_item_prediction'])

100%|██████████| 361581/361581 [00:17<00:00, 21150.92it/s]


(0.6287498513472777, 0.28694266710644833)

In [122]:
cal_hit_and_mrr(valid_sessions['next_item'], sasrec_valid_candidates_100['next_item_prediction'])

100%|██████████| 361581/361581 [00:19<00:00, 18773.39it/s]


(0.5756746067962641, 0.27497283950067)

In [127]:
cal_hit_and_mrr(valid_sessions['next_item'], merged_10_df['next_item_prediction'])

100%|██████████| 361581/361581 [00:16<00:00, 21836.06it/s]


(0.6228231018775876, 0.28704315354228643)

In [120]:
valid_sessions

Unnamed: 0,prev_items,next_item,locale
0,['B09VSN9GLS' 'B09VSG9DCG' 'B0BJ5L1ZPH' 'B09VS...,B06XG1LZ6Z,UK
1,['B00390YWXE' 'B00390YWXE' 'B09WM9W6WQ'],B01MSUI4FE,JP
2,['B01BM9V6H8' 'B01MG55XDR' 'B07VYSSRL7'],B01M6625ME,UK
3,['B092ZG24S7' 'B09BNHWWZM' 'B08CB1WG5M' '17880...,0241558573,UK
4,['B0B6NY5RM8' 'B09BJGBBBR'],B09BJF6N8K,JP
...,...,...,...
361576,['B08HH6L4PB' 'B08L8N8HDR'],B00TS5UXGY,UK
361577,['B08X4L1KLZ' 'B09BBX1T4S' 'B09D76FT9D'],B09BCM5NL1,JP
361578,['B0098G6L3M' 'B00ELRLP3O' 'B00PLXGK82' 'B09GS...,B0BC38GHB4,DE
361579,['B07Q2CNLY3' 'B07Q2CNLY3' 'B07BR7DZWN' 'B07Q2...,B08H8SYLMQ,DE


# Cut and save candidates

In [None]:
sasrec_valid_candidates = read_sasrec_valid_candidates()
sasrec_valid_candidates.head(2), len(sasrec_valid_candidates)

In [None]:
sasrec_valid_candidates.rename(columns={'candidates' : 'next_item_prediction', 'score' : 'scores'}, inplace=True)
sasrec_valid_candidates.head(2)

In [None]:
sasrec_valid_candidates.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_valid_300_with_score.parquet')

In [None]:
sasrec_valid_candidates_150 = cut_candidates(sasrec_valid_candidates, cut_len=150)

In [None]:
sasrec_valid_candidates_150.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_valid_150_with_score.parquet')

In [62]:
sasrec_test_candidates = read_sasrec_test_candidates()
sasrec_test_candidates.head(2), len(sasrec_test_candidates)

Unnamed: 0,locale,next_item_prediction,score
0,DE,"[B07LG5T3V9, B099NS1XPG, B099NR3X6D, B08QYYBTM...","[21.012548446655273, 19.94110107421875, 18.792..."
1,DE,"[B004ZXMV4Q, B095TQTZXY, B08BZCKDKQ, B010MJNUZ...","[24.912931442260742, 17.836830139160156, 17.03..."


In [66]:
print(len(sasrec_test_candidates.iloc[0]['next_item_prediction']))
sasrec_test_candidates.rename(columns={'score' : 'scores'}, inplace=True)

In [68]:
sasrec_test_candidates.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_test_300_with_score.parquet')

In [67]:
sasrec_test_candidates_150 = cut_candidates(sasrec_test_candidates, 150)
sasrec_test_candidates_150.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_test_150_with_score.parquet')

150 150


In [69]:
roberta_test_candidates = read_roberta_test_candidates()
roberta_test_candidates.head(2), len(roberta_test_candidates)

Unnamed: 0,next_item_prediction,scores
0,"[B08V12CT4C, B099NR3X6D, B099NS1XPG, B08DFMG6D...","[266.46697998046875, 266.4009094238281, 266.39..."
1,"[B00R9R5ND6, B004ZXMV4Q, B00R9RZ9ZS, B09P1XWJP...","[266.80084228515625, 266.1681823730469, 265.14..."


In [73]:
roberta_test_candidates.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_test_300_with_score.parquet')

In [70]:
roberta_test_candidates_150 = cut_candidates(roberta_test_candidates, 150)

150 150


In [72]:
roberta_test_candidates_150.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_test_150_with_score.parquet')

In [74]:
roberta_valid_candidates = read_roberta_valid_candidates()
roberta_valid_candidates.head(2), len(roberta_valid_candidates)

Unnamed: 0,next_item_prediction,scores
0,"[B084ZPL21L, B0B2JT2BND, B0B2JSKN7X, B08F7P8R6...","[266.0794982910156, 265.99432373046875, 265.99..."
1,"[B09WM9W6WQ, B09LCPT9DQ, B09MRYK5CV, B09LCRNQT...","[268.8392639160156, 268.33990478515625, 268.14..."


In [78]:
roberta_valid_candidates_150 = cut_candidates(roberta_valid_candidates, 150)

150 150


In [79]:
roberta_valid_candidates_150.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_valid_150_with_score.parquet')

# test dataset

In [11]:
normalize_score(sasrec_test_candidates_100)
sasrec_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:44<00:00, 7117.53it/s]


Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B07LG5T3V9, B099NS1XPG, B099NR3X6D, B08QYYBTM...","[0.6562827152580006, 0.2247849780236378, 0.071..."
1,DE,"[B004ZXMV4Q, B095TQTZXY, B08BZCKDKQ, B010MJNUZ...","[0.9985161592585081, 0.0008438074359781079, 0...."


In [20]:
normalize_score(roberta_test_candidates_100)
roberta_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:43<00:00, 7220.07it/s]


Unnamed: 0,next_item_prediction,scores
0,"[B08V12CT4C, B099NR3X6D, B099NS1XPG, B08DFMG6D...","[0.018416484674576083, 0.01723902334513837, 0...."
1,"[B00R9R5ND6, B004ZXMV4Q, B00R9RZ9ZS, B09P1XWJP...","[0.2799033881807676, 0.14867825211013236, 0.05..."


In [22]:
merged_2_5_df = merge_candidates(sasrec_test_candidates_100, roberta_test_candidates_100, lam=2.5)
merged_2_5_df.head(2)

100%|██████████| 316971/316971 [01:57<00:00, 2692.34it/s]


Unnamed: 0,locale,next_item_prediction
0,DE,"[B07LG5T3V9, B099NS1XPG, B099NR3X6D, B08QYYBTM..."
1,DE,"[B004ZXMV4Q, B095TQTZXY, B08BZCKDKQ, B010MJNUZ..."


In [23]:
merged_2_5_df.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/merge_prediction/merge_sasrec_and_roberta.parquet')

In [26]:
merged_5_df = merge_candidates(sasrec_test_candidates_100, roberta_test_candidates_100, lam=5)
merged_5_df.head(2)

100%|██████████| 316971/316971 [01:54<00:00, 2773.06it/s]


Unnamed: 0,locale,next_item_prediction
0,DE,"[B07LG5T3V9, B099NS1XPG, B099NR3X6D, B08QYYBTM..."
1,DE,"[B004ZXMV4Q, B095TQTZXY, B08BZCKDKQ, B010MJNUZ..."


In [27]:
merged_5_df.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/merge_prediction/merge_sasrec_and_roberta_5.parquet')

In [29]:
merged_10_df = merge_candidates(sasrec_test_candidates_100, roberta_test_candidates_100, lam=10)
merged_10_df.head(2)

100%|██████████| 316971/316971 [01:51<00:00, 2839.11it/s]


Unnamed: 0,locale,next_item_prediction
0,DE,"[B07LG5T3V9, B099NS1XPG, B099NR3X6D, B08QYYBTM..."
1,DE,"[B004ZXMV4Q, B095TQTZXY, B08BZCKDKQ, B010MJNUZ..."


In [30]:
merged_10_df.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/merge_prediction/merge_sasrec_and_roberta_10.parquet')

In [36]:
roberta_test_candidates_100.iloc[30]['scores']

array([0.01404141, 0.01404141, 0.01304059, 0.01283646, 0.01220759,
       0.01192225, 0.01181792, 0.01180279, 0.01160704, 0.01154874,
       0.01152445, 0.01151777, 0.01136345, 0.01135167, 0.01114162,
       0.01109649, 0.01108871, 0.01108871, 0.01108871, 0.01108871,
       0.01108871, 0.01103065, 0.01098832, 0.01094916, 0.01079389,
       0.01077414, 0.01077349, 0.01076954, 0.01059126, 0.01059126,
       0.01054642, 0.01054642, 0.01035539, 0.01033802, 0.01031061,
       0.01030841, 0.01015664, 0.01005518, 0.01002118, 0.00984566,
       0.00967115, 0.00964521, 0.0095474 , 0.00953837, 0.00953139,
       0.0094922 , 0.00941029, 0.00935446, 0.00931372, 0.00930349,
       0.00927939, 0.00926977, 0.00924717, 0.00924547, 0.00923532,
       0.00923222, 0.00913971, 0.00913887, 0.00913887, 0.00913887,
       0.00913887, 0.00913887, 0.00913887, 0.00913887, 0.00913887,
       0.00913887, 0.00913887, 0.00913887, 0.00913887, 0.00913887,
       0.00913887, 0.00913887, 0.00913887, 0.00913887, 0.00913

In [37]:
sasrec_valid_candidates['scores']

NameError: name 'sasrec_valid_candidates' is not defined