# Public functions

In [3]:
import pandas as pd 
import numpy as np
from functools import lru_cache
import os
from tqdm import tqdm 
import torch
import copy 
from collections import Counter

In [27]:
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/all_task_1_valid_sessions.csv'
sasrec_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/kdd_cup_2023/2023-04-12-14-21-03.parquet'
sasrec_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023/three_locale_prediction_0416_2120_with_score.parquet'
roberta_valid_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/roberta/roberta_prediction_300.parquet'
roberta_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/roberta/roberta_test_prediction_300.parquet'
gru4rec_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/GRU4Rec_Next/GRU_three_locale_prediction_0422_1925.parquet'
sasrec_feat_test_candidates_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next_Feat/kdd_cup_2023/SASRec_Feat_three_locale_prediction_0423_2156.parquet'

In [24]:
@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_sasrec_valid_candidates():
    return pd.read_parquet(sasrec_valid_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_sasrec_test_candidates():
    return pd.read_parquet(sasrec_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_sasrec_feat_test_candidates():
    return pd.read_parquet(sasrec_feat_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_valid_candidates():
    return pd.read_parquet(roberta_valid_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_roberta_test_candidates():
    return pd.read_parquet(roberta_test_candidates_path, engine='pyarrow')

lru_cache(maxsize=1)
def read_gru4rec_test_candidates():
    return pd.read_parquet(gru4rec_test_candidates_path, engine='pyarrow')

In [6]:
def cut_candidates(candidates_df, cut_len):
    new_candidates_df = copy.deepcopy(candidates_df)
    new_candidates_df['next_item_prediction'] = candidates_df['next_item_prediction'].apply(lambda x : x[:cut_len])
    new_candidates_df['scores'] = candidates_df['scores'].apply(lambda x : x[:cut_len])
    print(len(new_candidates_df.iloc[0]['next_item_prediction']), len(new_candidates_df.iloc[0]['scores']))
    return new_candidates_df

In [7]:
def normalize_score(candidates_df):
    normalized_score = []
    for i in tqdm(range(len(candidates_df))):
        scores = candidates_df.iloc[i]['scores'] - candidates_df.iloc[i]['scores'].max()
        scores = torch.nn.functional.softmax(torch.from_numpy(scores))
        normalized_score.append(scores.numpy())
    candidates_df['scores'] = normalized_score

In [8]:
def merge_candidates(candidates_df_1, candidates_df_2, lam=5):
    new_candidates_list = []
    locale_list = []
    item_counter = Counter()
    for i in tqdm(range(len(candidates_df_1))):
        item_counter.clear()
        candidates_1 = candidates_df_1.iloc[i]['next_item_prediction']
        scores_1 = candidates_df_1.iloc[i]['scores']
        candidates_2 = candidates_df_2.iloc[i]['next_item_prediction']
        scores_2 = candidates_df_2.iloc[i]['scores']

        for can, score in zip(candidates_1, scores_1):
            item_counter[can] += score
        for can, score in zip(candidates_2, scores_2):
            item_counter[can] += lam * score
        
        new_candidates, _ = zip(*item_counter.most_common(100))

        new_candidates_list.append(list(new_candidates))
        locale_list.append(candidates_df_1.iloc[i]['locale'])

    merged_df = pd.DataFrame({'locale' : locale_list, 'next_item_prediction' : new_candidates_list})
    return merged_df

In [9]:
# merge multi predictions 
def merge_multi_predictions(pred_df_list : list[pd.DataFrame], weights : list):
    new_pred_df = copy.deepcopy(pred_df_list[0])
    new_predictions = []
    score_counter = Counter()
    for i in tqdm(range(pred_df_list[0].shape[0])):
        score_counter.clear()
        for pred_df, w in zip(pred_df_list, weights):
            for item, score in zip(pred_df.iloc[i]['next_item_prediction'], pred_df.iloc[i]['scores']):
                score_counter[item] += w * score 
        new_pred, _ = zip(*score_counter.most_common(100))
        new_predictions.append(list(new_pred))
    new_pred_df['next_item_prediction'] = new_predictions
    
    return new_pred_df

In [10]:
def cal_hit_and_mrr(ground_truth_list, candidates_list):
    hits, mrrs = [], []
    for i in tqdm(range(len(ground_truth_list))):
        ground_truth = ground_truth_list.iloc[i]
        candidates = candidates_list.iloc[i]
        hit, mrr = 0.0, 0.0
        for j in range(len(candidates)):
            if ground_truth == candidates[j]:
                hit = 1.0
                mrr = 1.0 / (j + 1)
                break
        hits.append(hit)
        mrrs.append(mrr)
    return np.array(hits).mean(), np.array(mrrs).mean()

In [79]:
valid_sessions = read_valid_sessions()
valid_sessions.head(2)

Unnamed: 0,prev_items,next_item,locale
0,['B09VSN9GLS' 'B09VSG9DCG' 'B0BJ5L1ZPH' 'B09VS...,B06XG1LZ6Z,UK
1,['B00390YWXE' 'B00390YWXE' 'B09WM9W6WQ'],B01MSUI4FE,JP


In [18]:
roberta_valid_candidates = read_roberta_valid_candidates()
roberta_valid_candidates.head(2)

Unnamed: 0,next_item_prediction,scores
0,"[B084ZPL21L, B0B2JT2BND, B0B2JSKN7X, B08F7P8R6...","[266.0794982910156, 265.99432373046875, 265.99..."
1,"[B09WM9W6WQ, B09LCPT9DQ, B09MRYK5CV, B09LCRNQT...","[268.8392639160156, 268.33990478515625, 268.14..."


In [62]:
roberta_valid_candidates_100 = copy.deepcopy(roberta_valid_candidates)
roberta_valid_candidates_100['next_item_prediction'] = roberta_valid_candidates['next_item_prediction'].apply(lambda x : x[:100])
roberta_valid_candidates_100['scores'] = roberta_valid_candidates['scores'].apply(lambda x : x[:100])
len(roberta_valid_candidates_100.iloc[0]['next_item_prediction']), len(roberta_valid_candidates_100.iloc[0]['scores'])

(100, 100)

In [85]:
for i in tqdm(range(len(roberta_valid_candidates_100))):
    scores = roberta_valid_candidates_100.iloc[i]['scores'] - roberta_valid_candidates_100.iloc[i]['scores'].max()
    scores = torch.nn.functional.softmax(torch.from_numpy(scores))
    roberta_valid_candidates_100.at[i, 'scores'] = scores.numpy()
roberta_valid_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 361581/361581 [00:53<00:00, 6734.73it/s]


Unnamed: 0,next_item_prediction,scores
0,"[B084ZPL21L, B0B2JT2BND, B0B2JSKN7X, B08F7P8R6...","[0.010051003401355616, 0.010038628362893742, 0..."
1,"[B09WM9W6WQ, B09LCPT9DQ, B09MRYK5CV, B09LCRNQT...","[0.010772150929588447, 0.010420460807982517, 0..."


In [96]:
for i in tqdm(range(len(sasrec_valid_candidates_100))):
    scores = sasrec_valid_candidates_100.iloc[i]['score'] - sasrec_valid_candidates_100.iloc[i]['score'].max()
    scores = torch.nn.functional.softmax(torch.from_numpy(scores))
    sasrec_valid_candidates_100.at[i, 'score'] = scores.numpy()
sasrec_valid_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 361581/361581 [00:55<00:00, 6470.64it/s]


Unnamed: 0,locale,candidates,sess_id,score
0,UK,"[B06XG1LZ6Z, B06XGD9VLV, B076PN1SKG, B01MYUDYP...",0,"[0.2712954212156488, 0.16770452100600358, 0.16..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W...",1,"[0.9104351964115006, 0.019139557601848656, 0.0..."


In [98]:
sasrec_valid_candidates_100.rename(columns={'score' : 'scores', 'candidates' : 'next_item_prediction'}, inplace=True)
sasrec_valid_candidates_100

Unnamed: 0,locale,next_item_prediction,sess_id,scores
0,UK,"[B06XG1LZ6Z, B06XGD9VLV, B076PN1SKG, B01MYUDYP...",0,"[0.2712954212156488, 0.16770452100600358, 0.16..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W...",1,"[0.9104351964115006, 0.019139557601848656, 0.0..."
2,UK,"[B09XBS6WCX, B01EYGW86Y, B01MDOBUCC, B01C5YK17...",2,"[0.5189365862585792, 0.11578885764040067, 0.08..."
3,UK,"[0241572614, 1406392979, 024157563X, 024147681...",3,"[0.17436096931167835, 0.1062810279624873, 0.09..."
4,JP,"[B0B6PF619D, B0B6P77ZRN, B0B6P2PCMP, B0B6NY4PN...",4,"[0.9418506919421911, 0.03955500319937586, 0.00..."
...,...,...,...,...
361576,UK,"[B0050IG9DE, B0B7LVKNK8, B07ZCWWZSM, B00465F49...",361576,"[0.3216314191614029, 0.03237086799685645, 0.03..."
361577,JP,"[B09B9V4PXC, B09BCM5NL1, B09XGRXXG3, B09XH1YGL...",361577,"[0.32195132928283726, 0.2734383066946782, 0.10..."
361578,DE,"[B0BC38GHB4, B07KLCY8NF, B00MXZEMBI, B07K6LTLW...",361578,"[0.5633314956366812, 0.27870872436215477, 0.06..."
361579,DE,"[B08RQR2NPB, B08RQDVX71, B08H8SYLMQ, B08H8TLK4...",361579,"[0.4668851399227642, 0.2816610932504468, 0.078..."


In [116]:
merged_df = merge_candidates(sasrec_valid_candidates_100, roberta_valid_candidates_100)
merged_df.head(2)

100%|██████████| 361581/361581 [05:30<00:00, 1092.68it/s]


Unnamed: 0,locale,next_item_prediction
0,UK,"[B06XG1LZ6Z, B06XGD9VLV, B01MYUDYP7, B076PN1SK..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W..."


In [125]:
merged_10_df = merge_candidates(sasrec_valid_candidates_100, roberta_valid_candidates_100, lam=10)
merged_10_df.head(2)

100%|██████████| 361581/361581 [05:32<00:00, 1085.94it/s]


Unnamed: 0,locale,next_item_prediction
0,UK,"[B06XG1LZ6Z, B06XGD9VLV, B01MYUDYP7, B076PN1SK..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W..."


In [13]:
merged_2_df = merge_candidates(sasrec_valid_candidates_100, roberta_valid_candidates_100, lam=2)
merged_2_df.head(2)

NameError: name 'sasrec_valid_candidates_100' is not defined

In [126]:
len(merged_10_df.iloc[0]['next_item_prediction'])

100

In [118]:
len(merged_df.iloc[0]['next_item_prediction'])

100

In [121]:
cal_hit_and_mrr(valid_sessions['next_item'], merged_df['next_item_prediction'])

100%|██████████| 361581/361581 [00:17<00:00, 21150.92it/s]


(0.6287498513472777, 0.28694266710644833)

In [122]:
cal_hit_and_mrr(valid_sessions['next_item'], sasrec_valid_candidates_100['next_item_prediction'])

100%|██████████| 361581/361581 [00:19<00:00, 18773.39it/s]


(0.5756746067962641, 0.27497283950067)

In [127]:
cal_hit_and_mrr(valid_sessions['next_item'], merged_10_df['next_item_prediction'])

100%|██████████| 361581/361581 [00:16<00:00, 21836.06it/s]


(0.6228231018775876, 0.28704315354228643)

In [120]:
valid_sessions

Unnamed: 0,prev_items,next_item,locale
0,['B09VSN9GLS' 'B09VSG9DCG' 'B0BJ5L1ZPH' 'B09VS...,B06XG1LZ6Z,UK
1,['B00390YWXE' 'B00390YWXE' 'B09WM9W6WQ'],B01MSUI4FE,JP
2,['B01BM9V6H8' 'B01MG55XDR' 'B07VYSSRL7'],B01M6625ME,UK
3,['B092ZG24S7' 'B09BNHWWZM' 'B08CB1WG5M' '17880...,0241558573,UK
4,['B0B6NY5RM8' 'B09BJGBBBR'],B09BJF6N8K,JP
...,...,...,...
361576,['B08HH6L4PB' 'B08L8N8HDR'],B00TS5UXGY,UK
361577,['B08X4L1KLZ' 'B09BBX1T4S' 'B09D76FT9D'],B09BCM5NL1,JP
361578,['B0098G6L3M' 'B00ELRLP3O' 'B00PLXGK82' 'B09GS...,B0BC38GHB4,DE
361579,['B07Q2CNLY3' 'B07Q2CNLY3' 'B07BR7DZWN' 'B07Q2...,B08H8SYLMQ,DE


# Cut and save candidates

In [None]:
sasrec_valid_candidates = read_sasrec_valid_candidates()
sasrec_valid_candidates.head(2), len(sasrec_valid_candidates)

In [None]:
sasrec_valid_candidates.rename(columns={'candidates' : 'next_item_prediction', 'score' : 'scores'}, inplace=True)
sasrec_valid_candidates.head(2)

In [None]:
sasrec_valid_candidates.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_valid_300_with_score.parquet')

In [None]:
sasrec_valid_candidates_150 = cut_candidates(sasrec_valid_candidates, cut_len=150)

In [None]:
sasrec_valid_candidates_150.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_valid_150_with_score.parquet')

In [62]:
sasrec_test_candidates = read_sasrec_test_candidates()
sasrec_test_candidates.head(2), len(sasrec_test_candidates)

Unnamed: 0,locale,next_item_prediction,score
0,DE,"[B07LG5T3V9, B099NS1XPG, B099NR3X6D, B08QYYBTM...","[21.012548446655273, 19.94110107421875, 18.792..."
1,DE,"[B004ZXMV4Q, B095TQTZXY, B08BZCKDKQ, B010MJNUZ...","[24.912931442260742, 17.836830139160156, 17.03..."


In [66]:
print(len(sasrec_test_candidates.iloc[0]['next_item_prediction']))
sasrec_test_candidates.rename(columns={'score' : 'scores'}, inplace=True)

In [68]:
sasrec_test_candidates.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_test_300_with_score.parquet')

In [67]:
sasrec_test_candidates_150 = cut_candidates(sasrec_test_candidates, 150)
sasrec_test_candidates_150.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_test_150_with_score.parquet')

150 150


In [69]:
roberta_test_candidates = read_roberta_test_candidates()
roberta_test_candidates.head(2), len(roberta_test_candidates)

Unnamed: 0,next_item_prediction,scores
0,"[B08V12CT4C, B099NR3X6D, B099NS1XPG, B08DFMG6D...","[266.46697998046875, 266.4009094238281, 266.39..."
1,"[B00R9R5ND6, B004ZXMV4Q, B00R9RZ9ZS, B09P1XWJP...","[266.80084228515625, 266.1681823730469, 265.14..."


In [73]:
roberta_test_candidates.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_test_300_with_score.parquet')

In [70]:
roberta_test_candidates_150 = cut_candidates(roberta_test_candidates, 150)

150 150


In [72]:
roberta_test_candidates_150.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_test_150_with_score.parquet')

In [74]:
roberta_valid_candidates = read_roberta_valid_candidates()
roberta_valid_candidates.head(2), len(roberta_valid_candidates)

Unnamed: 0,next_item_prediction,scores
0,"[B084ZPL21L, B0B2JT2BND, B0B2JSKN7X, B08F7P8R6...","[266.0794982910156, 265.99432373046875, 265.99..."
1,"[B09WM9W6WQ, B09LCPT9DQ, B09MRYK5CV, B09LCRNQT...","[268.8392639160156, 268.33990478515625, 268.14..."


In [78]:
roberta_valid_candidates_150 = cut_candidates(roberta_valid_candidates, 150)

150 150


In [79]:
roberta_valid_candidates_150.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_valid_150_with_score.parquet')

# test dataset

In [19]:
sasrec_test_candidates = read_sasrec_test_candidates()
gru4rec_test_candidates = read_gru4rec_test_candidates()
len(sasrec_test_candidates), len(gru4rec_test_candidates)

(316971, 316971)

In [27]:
sasrec_test_candidates_100 = cut_candidates(sasrec_test_candidates, 100)
gru4rec_test_candidates_100 = cut_candidates(gru4rec_test_candidates, 100)

100 100
100 100


In [28]:
normalize_score(sasrec_test_candidates_100)
sasrec_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:49<00:00, 6408.21it/s]


Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM...","[0.5168281936136279, 0.31300851750540704, 0.14..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B010MJNUZ...","[0.9656834652778398, 0.013265541175667032, 0.0..."


In [29]:
normalize_score(gru4rec_test_candidates_100)
gru4rec_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:50<00:00, 6271.15it/s]


Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B099NR3X6D, B07LG5T3V9, B086CJF45...","[0.6378334874849756, 0.20950075503198057, 0.09..."
1,DE,"[B004ZXMV4Q, B094QTYN3Q, B097HPKM63, B095TQTZX...","[0.9604668765992949, 0.00758641279840183, 0.00..."


In [30]:
gru4rec_test_candidates_100.head(10)

Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B099NR3X6D, B07LG5T3V9, B086CJF45...","[0.6378334874849756, 0.20950075503198057, 0.09..."
1,DE,"[B004ZXMV4Q, B094QTYN3Q, B097HPKM63, B095TQTZX...","[0.9604668765992949, 0.00758641279840183, 0.00..."
2,DE,"[B0B5QNFWJ1, B099277D7Q, B0BJF4KGCN, B00C3TEAP...","[0.09298466166398507, 0.060870259701001704, 0...."
3,DE,"[395535086X, 3772476953, 3772477917, B0829LZFT...","[0.20584923570534397, 0.19317618014137544, 0.0..."
4,DE,"[B09J8SKX9G, B09J8V9RQQ, B09J8VPTTW, B09J8TWRV...","[0.2037403493859897, 0.17484534165480833, 0.07..."
5,DE,"[B09LD73G5Z, B0BHT6ZT5S, B0BHT7PKKX, B085ZSMK1...","[0.07069167096868216, 0.06507349711573505, 0.0..."
6,DE,"[B07BK9F7WB, B08QSHKW7V, B01M36DP4F, B09HCST7T...","[0.21888163977332273, 0.1449445389367416, 0.05..."
7,DE,"[B09T765VQ5, B09SPYK6BY, B0747WXNPT, B09WDNL3C...","[0.11643942978757407, 0.07144751563781769, 0.0..."
8,DE,"[B0B3BW9HV1, B0B3BMTSKT, B09JQWZPTS, B09JQZZRZ...","[0.2747262418880494, 0.10957367003731364, 0.07..."
9,DE,"[B06Y5MSGXC, B0BGHRSXR9, B07H2BZ5TP, B07HF7DLQ...","[0.10832743845124787, 0.08067419460462871, 0.0..."


In [31]:
sasrec_test_candidates_100.head(10)

Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM...","[0.5168281936136279, 0.31300851750540704, 0.14..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B010MJNUZ...","[0.9656834652778398, 0.013265541175667032, 0.0..."
2,DE,"[B0B5QNFWJ1, B0BJF4KGCN, B099277D7Q, B0B5TFLBC...","[0.09221292135188211, 0.06819531255747927, 0.0..."
3,DE,"[395535086X, 3772476953, 3772477917, B0829LZFT...","[0.3970605289476756, 0.23308572773878866, 0.06..."
4,DE,"[B09J8SKX9G, B09J8V9RQQ, B09J8VPTTW, B09J8TWRV...","[0.3992416904645631, 0.22623920832686648, 0.15..."
5,DE,"[B0BK5XDW6P, B0BHT7PKKX, B09LD6XRRJ, B0BHT6ZT5...","[0.5138915207150259, 0.3457729860463977, 0.048..."
6,DE,"[B07BK9F7WB, B09FLLC45X, B08QSHKW7V, B07C1J5WN...","[0.2556027254069821, 0.09568315775433477, 0.09..."
7,DE,"[B09SPYK6BY, B09WDNL3CM, B09T765VQ5, B08NPG71H...","[0.11194026108834021, 0.11026513181299256, 0.0..."
8,DE,"[B0B3BW9HV1, B0B3BMTSKT, B09JR3XTSX, B09JQZZRZ...","[0.21229931443172997, 0.1596694931394764, 0.11..."
9,DE,"[B06Y5MSGXC, B07HF7DLQ2, B07H2BZ5TP, B0BGHRSXR...","[0.20674877019612498, 0.14128828349307815, 0.0..."


In [32]:
merged_0_8_df = merge_candidates(sasrec_test_candidates_100, gru4rec_test_candidates_100, lam=0.8)
merged_0_8_df.head(2)

100%|██████████| 316971/316971 [02:42<00:00, 1948.92it/s]


Unnamed: 0,locale,next_item_prediction
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B094QTYN3..."


In [33]:
merged_0_8_df.head(20)

Unnamed: 0,locale,next_item_prediction
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B094QTYN3..."
2,DE,"[B0B5QNFWJ1, B0BJF4KGCN, B099277D7Q, B07XFMQ6F..."
3,DE,"[395535086X, 3772476953, 3772477917, B0829LZFT..."
4,DE,"[B09J8SKX9G, B09J8V9RQQ, B09J8VPTTW, B09J8TWRV..."
5,DE,"[B0BK5XDW6P, B0BHT7PKKX, B09LD6XRRJ, B0BHT6ZT5..."
6,DE,"[B07BK9F7WB, B08QSHKW7V, B09FLLC45X, B01M36DP4..."
7,DE,"[B09T765VQ5, B09SPYK6BY, B09WDNL3CM, B0747WXNP..."
8,DE,"[B0B3BW9HV1, B0B3BMTSKT, B09JR3XTSX, B09JQZZRZ..."
9,DE,"[B06Y5MSGXC, B07HF7DLQ2, B0BGHRSXR9, B07H2BZ5T..."


In [23]:
merged_0_8_df.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/merge_prediction/merge_sasrec_and_gru4rec_0.8.parquet')

In [26]:
merged_5_df = merge_candidates(sasrec_test_candidates_100, roberta_test_candidates_100, lam=5)
merged_5_df.head(2)

100%|██████████| 316971/316971 [01:54<00:00, 2773.06it/s]


Unnamed: 0,locale,next_item_prediction
0,DE,"[B07LG5T3V9, B099NS1XPG, B099NR3X6D, B08QYYBTM..."
1,DE,"[B004ZXMV4Q, B095TQTZXY, B08BZCKDKQ, B010MJNUZ..."


In [27]:
merged_5_df.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/merge_prediction/merge_sasrec_and_roberta_5.parquet')

In [29]:
merged_10_df = merge_candidates(sasrec_test_candidates_100, roberta_test_candidates_100, lam=10)
merged_10_df.head(2)

100%|██████████| 316971/316971 [01:51<00:00, 2839.11it/s]


Unnamed: 0,locale,next_item_prediction
0,DE,"[B07LG5T3V9, B099NS1XPG, B099NR3X6D, B08QYYBTM..."
1,DE,"[B004ZXMV4Q, B095TQTZXY, B08BZCKDKQ, B010MJNUZ..."


In [30]:
merged_10_df.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/merge_prediction/merge_sasrec_and_roberta_10.parquet')

In [36]:
roberta_test_candidates_100.iloc[30]['scores']

array([0.01404141, 0.01404141, 0.01304059, 0.01283646, 0.01220759,
       0.01192225, 0.01181792, 0.01180279, 0.01160704, 0.01154874,
       0.01152445, 0.01151777, 0.01136345, 0.01135167, 0.01114162,
       0.01109649, 0.01108871, 0.01108871, 0.01108871, 0.01108871,
       0.01108871, 0.01103065, 0.01098832, 0.01094916, 0.01079389,
       0.01077414, 0.01077349, 0.01076954, 0.01059126, 0.01059126,
       0.01054642, 0.01054642, 0.01035539, 0.01033802, 0.01031061,
       0.01030841, 0.01015664, 0.01005518, 0.01002118, 0.00984566,
       0.00967115, 0.00964521, 0.0095474 , 0.00953837, 0.00953139,
       0.0094922 , 0.00941029, 0.00935446, 0.00931372, 0.00930349,
       0.00927939, 0.00926977, 0.00924717, 0.00924547, 0.00923532,
       0.00923222, 0.00913971, 0.00913887, 0.00913887, 0.00913887,
       0.00913887, 0.00913887, 0.00913887, 0.00913887, 0.00913887,
       0.00913887, 0.00913887, 0.00913887, 0.00913887, 0.00913887,
       0.00913887, 0.00913887, 0.00913887, 0.00913887, 0.00913

In [37]:
sasrec_valid_candidates['scores']

NameError: name 'sasrec_valid_candidates' is not defined

## merge sasrec, roberta and gru4rec

In [11]:
sasrec_test_candidates = read_sasrec_test_candidates()
gru4rec_test_candidates = read_gru4rec_test_candidates()
roberta_test_candidates = read_roberta_test_candidates()
len(sasrec_test_candidates), len(gru4rec_test_candidates), len(roberta_test_candidates)

(316971, 316971, 316971)

In [12]:
sasrec_test_candidates_100 = cut_candidates(sasrec_test_candidates, 100)
gru4rec_test_candidates_100 = cut_candidates(gru4rec_test_candidates, 100)
roberta_test_candidates_100 = cut_candidates(roberta_test_candidates, 100)

100 100
100 100
100 100


In [13]:
normalize_score(sasrec_test_candidates_100)
sasrec_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:49<00:00, 6452.76it/s]


Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM...","[0.5168281936136279, 0.31300851750540704, 0.14..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B010MJNUZ...","[0.9656834652778398, 0.013265541175667032, 0.0..."


In [14]:
normalize_score(gru4rec_test_candidates_100)
gru4rec_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:47<00:00, 6672.65it/s]


Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B099NR3X6D, B07LG5T3V9, B086CJF45...","[0.6378334874849756, 0.20950075503198057, 0.09..."
1,DE,"[B004ZXMV4Q, B094QTYN3Q, B097HPKM63, B095TQTZX...","[0.9604668765992949, 0.00758641279840183, 0.00..."


In [15]:
normalize_score(roberta_test_candidates_100)
roberta_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:48<00:00, 6515.71it/s]


Unnamed: 0,next_item_prediction,scores
0,"[B08V12CT4C, B099NR3X6D, B099NS1XPG, B08DFMG6D...","[0.018416484674576083, 0.01723902334513837, 0...."
1,"[B00R9R5ND6, B004ZXMV4Q, B00R9RZ9ZS, B09P1XWJP...","[0.2799033881807676, 0.14867825211013236, 0.05..."


In [17]:
merged_predictions = merge_multi_predictions([sasrec_test_candidates_100, gru4rec_test_candidates_100, roberta_test_candidates_100], [1.0, 0.8, 5.0])

100%|██████████| 316971/316971 [03:33<00:00, 1485.27it/s]


In [21]:
merged_predictions.drop(columns=['scores'], inplace=True)

In [22]:
merged_predictions.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/merge_prediction/sasrec_gru4rec_and_roberta_1_0-8_5.parquet')

In [23]:
merged_predictions['next_item_prediction'].apply(len).describe()

count    316971.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: next_item_prediction, dtype: float64

## merge sasrec, sasrec_feat and gru4rec

In [29]:
sasrec_test_candidates = read_sasrec_test_candidates()
sasrec_feat_test_candidates = read_sasrec_feat_test_candidates()
gru4rec_test_candidates = read_gru4rec_test_candidates()
len(sasrec_test_candidates), len(gru4rec_test_candidates), len(sasrec_feat_test_candidates)

(316971, 316971, 316971)

In [30]:
sasrec_test_candidates_100 = cut_candidates(sasrec_test_candidates, 100)
gru4rec_test_candidates_100 = cut_candidates(gru4rec_test_candidates, 100)
sasrec_feat_test_candidates_100 = cut_candidates(sasrec_feat_test_candidates, 100)

100 100
100 100
100 100


In [31]:
normalize_score(sasrec_test_candidates_100)
sasrec_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:57<00:00, 5498.51it/s]


Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM...","[0.5168281936136279, 0.31300851750540704, 0.14..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B010MJNUZ...","[0.9656834652778398, 0.013265541175667032, 0.0..."


In [32]:
normalize_score(gru4rec_test_candidates_100)
gru4rec_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:59<00:00, 5329.86it/s]


Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B099NR3X6D, B07LG5T3V9, B086CJF45...","[0.6378334874849756, 0.20950075503198057, 0.09..."
1,DE,"[B004ZXMV4Q, B094QTYN3Q, B097HPKM63, B095TQTZX...","[0.9604668765992949, 0.00758641279840183, 0.00..."


In [33]:
normalize_score(sasrec_feat_test_candidates_100)
sasrec_feat_test_candidates_100.head(2)

  scores = torch.nn.functional.softmax(torch.from_numpy(scores))
100%|██████████| 316971/316971 [00:48<00:00, 6584.76it/s]


Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B089FBHSJ...","[0.594087111717978, 0.285478666638591, 0.09025..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B001LEO21...","[0.9431500041614154, 0.01586097917802657, 0.01..."


In [34]:
sasrec_feat_test_candidates_100

Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B089FBHSJ...","[0.594087111717978, 0.285478666638591, 0.09025..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B001LEO21...","[0.9431500041614154, 0.01586097917802657, 0.01..."
2,DE,"[B0B5QNFWJ1, B0BJF4KGCN, B0BF93M2Y8, B0B5TFLBC...","[0.07661601425065313, 0.057622333176242525, 0...."
3,DE,"[395535086X, 3772476953, B0829LZFT1, B09KNCTB5...","[0.3180099396063456, 0.2427390052636888, 0.152..."
4,DE,"[B09J8SKX9G, B09J8V9RQQ, B09J8VPTTW, B09J8TWRV...","[0.39815844695421876, 0.19126820612952602, 0.1..."
...,...,...,...
316966,UK,"[B08X9L5RGD, B09MW64JGM, B073FSR7WV, B09Y4HKGK...","[0.9788217617253151, 0.015382279607597867, 0.0..."
316967,UK,"[B0989BHLSY, B09CPNS7XV, B09895QPQF, B09CPP92Q...","[0.31026281623518676, 0.20198314922383795, 0.1..."
316968,UK,"[B09HKZBNZH, B09HZSRJWW, B09HL141QC, B09HSR3RF...","[0.5030406134834551, 0.2835089547346834, 0.069..."
316969,UK,"[B07TR5LQSL, B08FB464L7, B07L5YWCQ8, B0BGDK1J1...","[0.5176141894537712, 0.1690664293391734, 0.073..."


In [35]:
sasrec_test_candidates_100

Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM...","[0.5168281936136279, 0.31300851750540704, 0.14..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B010MJNUZ...","[0.9656834652778398, 0.013265541175667032, 0.0..."
2,DE,"[B0B5QNFWJ1, B0BJF4KGCN, B099277D7Q, B0B5TFLBC...","[0.09221292135188211, 0.06819531255747927, 0.0..."
3,DE,"[395535086X, 3772476953, 3772477917, B0829LZFT...","[0.3970605289476756, 0.23308572773878866, 0.06..."
4,DE,"[B09J8SKX9G, B09J8V9RQQ, B09J8VPTTW, B09J8TWRV...","[0.3992416904645631, 0.22623920832686648, 0.15..."
...,...,...,...
316966,UK,"[B08X9L5RGD, B09G9YY2C9, B09MW64JGM, B07V5FL8G...","[0.9786935497092883, 0.009526700519337919, 0.0..."
316967,UK,"[B0989BHLSY, B09895QPQF, B09CPNS7XV, B09L14HQF...","[0.31671145283779745, 0.2524957163610817, 0.17..."
316968,UK,"[B09HKZBNZH, B09HZSRJWW, B07PY1NG3X, B09HL141Q...","[0.6364462118668077, 0.22300407368544484, 0.03..."
316969,UK,"[B08FB464L7, B07TR5LQSL, B0BGDK1J1G, B00HEL380...","[0.3396318652839738, 0.26567075176176835, 0.13..."


In [36]:
merged_predictions = merge_multi_predictions([sasrec_test_candidates_100, gru4rec_test_candidates_100, sasrec_feat_test_candidates_100], [1.0, 0.8, 0.8])

100%|██████████| 316971/316971 [03:22<00:00, 1564.89it/s]


In [37]:
merged_predictions.drop(columns=['scores'], inplace=True)

In [38]:
merged_predictions.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/merge_prediction/sasrec_gru4rec_and_sasrec_feat_1_0-8_0-8.parquet')

In [39]:
merged_predictions['next_item_prediction'].apply(len).describe()

count    316971.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: next_item_prediction, dtype: float64