In [1]:
import pandas as pd 
import numpy as np
from functools import lru_cache
import os
from tqdm import tqdm 
from copy import deepcopy

In [2]:
sasrec_DE_locale_prediction_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next_Text/kdd_cup_2023_DE/2023-05-26-14-58-46.parquet'
sasrec_JP_locale_prediction_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next_Text/kdd_cup_2023_JP/2023-05-26-17-56-57.parquet'
sasrec_UK_locale_prediction_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next_Text/kdd_cup_2023_UK/2023-05-27-18-21-16.parquet'

In [3]:
@lru_cache(maxsize=1)
def read_all_valid_sessions():
    return pd.read_csv('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions.csv')

@lru_cache(maxsize=1)
def read_sasrec_UK_locale_prediction():
    return pd.read_parquet(sasrec_UK_locale_prediction_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_DE_locale_prediction():
    return pd.read_parquet(sasrec_DE_locale_prediction_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_JP_locale_prediction():
    return pd.read_parquet(sasrec_JP_locale_prediction_path, engine='pyarrow')

# task1 validation

In [17]:
sasrec_valid_sessions = read_all_valid_sessions()

In [13]:
sasrec_UK_locale_prediction = read_sasrec_UK_locale_prediction()
len(sasrec_UK_locale_prediction), len(sasrec_UK_locale_prediction.iloc[0]['candidates'])

(361581, 300)

In [14]:
sasrec_DE_locale_prediction = read_sasrec_DE_locale_prediction()
len(sasrec_DE_locale_prediction), len(sasrec_DE_locale_prediction.iloc[0]['candidates'])

(361581, 300)

In [15]:
sasrec_JP_locale_prediction = read_sasrec_JP_locale_prediction()
len(sasrec_JP_locale_prediction), len(sasrec_JP_locale_prediction.iloc[0]['candidates'])

(361581, 300)

In [16]:
sasrec_JP_locale_prediction

Unnamed: 0,locale,candidates,sess_id,scores
0,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",0,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B092D5HM5S, B0797D6F3...",1,"[17.74812889099121, 14.109373092651367, 12.546..."
2,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",2,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
3,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",3,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
4,JP,"[B0B6PF619D, B09BJF6N8K, B0B6P77ZRN, B0B6P2PCM...",4,"[21.915809631347656, 20.597021102905273, 19.71..."
...,...,...,...,...
361576,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",361576,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
361577,JP,"[B09BCM5NL1, B09B9V4PXC, B09XH1YGLL, B09XGRXXG...",361577,"[13.340217590332031, 12.540390014648438, 12.30..."
361578,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",361578,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
361579,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",361579,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."


In [24]:
sasrec_merged_valid_prediction = deepcopy(sasrec_DE_locale_prediction)

In [25]:
sasrec_merged_valid_prediction[sasrec_valid_sessions['locale'] == 'JP'] = sasrec_JP_locale_prediction[sasrec_valid_sessions['locale'] == 'JP']

In [27]:
sasrec_merged_valid_prediction[sasrec_valid_sessions['locale'] == 'UK'] = sasrec_UK_locale_prediction[sasrec_valid_sessions['locale'] == 'UK']

In [29]:
def cal_hit_and_mrr(ground_truth_list, candidates_list):
    hits, mrrs = [], []
    for i in tqdm(range(len(ground_truth_list))):
        ground_truth = ground_truth_list.iloc[i]
        candidates = candidates_list.iloc[i]
        hit, mrr = 0.0, 0.0
        for j in range(len(candidates)):
            if ground_truth == candidates[j]:
                hit = 1.0
                mrr = 1.0 / (j + 1)
                break
        hits.append(hit)
        mrrs.append(mrr)
    return np.array(hits).mean(), np.array(mrrs).mean()

In [31]:
cal_hit_and_mrr(sasrec_valid_sessions['next_item'], sasrec_merged_valid_prediction['candidates'])

100%|██████████| 361581/361581 [00:09<00:00, 36467.17it/s]


(0.6773447719874661, 0.30861870792730345)

In [34]:
sasrec_merged_valid_prediction['candidates'] = sasrec_merged_valid_prediction['candidates'].apply(lambda x : x[:100])
sasrec_merged_valid_prediction['scores'] = sasrec_merged_valid_prediction['scores'].apply(lambda x : x[:100])

In [36]:
sasrec_merged_valid_prediction['scores'].apply(len).describe()

count    361581.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: scores, dtype: float64

In [None]:
sasrec_merged_valid_prediction.to_parquet('../candidates/SASRec_Next/seperate_locale/SASRec_Next_04_26_15_26_valid_with_score.parquet')

# task 1 

In [4]:
sasrec_UK_locale_prediction = read_sasrec_UK_locale_prediction()
len(sasrec_UK_locale_prediction), len(sasrec_UK_locale_prediction.iloc[0]['next_item_prediction'])

(316972, 100)

In [5]:
sasrec_DE_locale_prediction = read_sasrec_DE_locale_prediction()
len(sasrec_DE_locale_prediction), len(sasrec_DE_locale_prediction.iloc[0]['next_item_prediction'])

(316972, 100)

In [6]:
sasrec_JP_locale_prediction = read_sasrec_JP_locale_prediction()
len(sasrec_JP_locale_prediction), len(sasrec_JP_locale_prediction.iloc[0]['next_item_prediction'])

(316972, 100)

In [7]:
len(sasrec_DE_locale_prediction[sasrec_DE_locale_prediction['locale'] == 'DE'])

104568

In [8]:
len(sasrec_JP_locale_prediction[sasrec_JP_locale_prediction['locale'] == 'JP'])

96467

In [9]:
len(sasrec_UK_locale_prediction[sasrec_UK_locale_prediction['locale'] == 'UK'])

115937

In [11]:
new_test_prediction = pd.concat([sasrec_DE_locale_prediction.iloc[ : 104568], sasrec_JP_locale_prediction.iloc[104568 : 104568 + 96467], sasrec_UK_locale_prediction.iloc[104568 + 96467: ]], axis=0)

In [12]:
def new_pred_validation(new_pred):
    assert len(new_pred) == len(sasrec_DE_locale_prediction)
    assert new_pred.iloc[104567]['locale'] == 'DE'
    assert new_pred.iloc[104568]['locale'] == 'JP'
    assert new_pred.iloc[104568 + 96467 - 1]['locale'] == 'JP'
    assert new_pred.iloc[104568 + 96467]['locale'] == 'UK'
    print("Nice!!!")

In [13]:
new_pred_validation(new_test_prediction)

Nice!!!


In [14]:
new_test_prediction['next_item_prediction'].apply(len).describe()

count    316972.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: next_item_prediction, dtype: float64

In [15]:
new_test_prediction.drop(columns=['scores'], inplace=True)

KeyError: "['scores'] not found in axis"

In [16]:
new_test_prediction

Unnamed: 0,locale,next_item_prediction
0,DE,"[B091CK241X, B07SDFLVKD, B0BGC82WVW, B093X59B3..."
1,DE,"[B004P4OF1C, B084CB7GX9, B09YD8XV6M, B004P4QFJ..."
2,DE,"[B09Z4PZQBF, B09GPJ15GS, B09JP5LQ2W, B07HFTJLR..."
3,DE,"[B07Y1KLF25, B07QMSVYL8, B07YCFFC44, B07T5XY2C..."
4,DE,"[B0B2JY9THB, B08SXLWXH9, B08SHZHRQ7, B01MRXVY2..."
...,...,...
316967,UK,"[B07GKP2LCF, B07GKYSHB4, B016RAAUEM, B006DDGCI..."
316968,UK,"[B00M35Y326, B08L5Z8GPL, B06X92Z7R3, B085C7TCT..."
316969,UK,"[B08VDHH6QF, B08VDSL596, B08VD5DC5L, B07QK2SPP..."
316970,UK,"[B089CZWB4C, B08W2JJZBM, B09WCQYGX8, B08T1ZJYH..."


In [17]:
new_test_prediction['next_item_prediction'] = new_test_prediction['next_item_prediction'].apply(lambda x : x[:100])

In [None]:
new_test_prediction.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/seperate_locale/SASRec_Text_three_layers_05_27_20_52.parquet', engine='pyarrow')

# task 3

In [1]:
import pandas as pd 
import numpy as np
from functools import lru_cache
import os
from tqdm import tqdm 

In [11]:
taske_test_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/phase_2/sessions_test_task3.csv'
sasrec_DE_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_DE/2023-05-23-19-11-33.parquet'
sasrec_JP_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_JP/2023-05-23-19-12-49.parquet'
sasrec_UK_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_UK/2023-05-23-19-13-41.parquet'
sasrec_IT_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_IT/2023-05-23-19-41-17.parquet'
sasrec_ES_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_ES/2023-05-23-19-42-14.parquet'
sasrec_FR_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_FR/2023-05-23-19-43-18.parquet'

In [12]:
@lru_cache(maxsize=1)
def read_task3_test():
    return pd.read_csv(taske_test_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_UK_locale_prediction():
    return pd.read_parquet(sasrec_UK_locale_prediction_task3_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_DE_locale_prediction():
    return pd.read_parquet(sasrec_DE_locale_prediction_task3_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_JP_locale_prediction():
    return pd.read_parquet(sasrec_JP_locale_prediction_task3_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_IT_locale_prediction():
    return pd.read_parquet(sasrec_IT_locale_prediction_task3_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_ES_locale_prediction():
    return pd.read_parquet(sasrec_ES_locale_prediction_task3_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_FR_locale_prediction():
    return pd.read_parquet(sasrec_FR_locale_prediction_task3_path, engine='pyarrow')

In [13]:
task3_DE_prediction = read_sasrec_DE_locale_prediction()
task3_DE_prediction = task3_DE_prediction[task3_DE_prediction['locale'] == 'DE']

task3_JP_prediction = read_sasrec_JP_locale_prediction()
task3_JP_prediction = task3_JP_prediction[task3_JP_prediction['locale'] == 'JP']

task3_UK_prediction = read_sasrec_UK_locale_prediction()
task3_UK_prediction = task3_UK_prediction[task3_UK_prediction['locale'] == 'UK']

task3_IT_prediction = read_sasrec_IT_locale_prediction()
task3_IT_prediction = task3_IT_prediction[task3_IT_prediction['locale'] == 'IT']

task3_ES_prediction = read_sasrec_ES_locale_prediction()
task3_ES_prediction = task3_ES_prediction[task3_ES_prediction['locale'] == 'ES']

task3_FR_prediction = read_sasrec_FR_locale_prediction()
task3_FR_prediction = task3_FR_prediction[task3_FR_prediction['locale'] == 'FR']

In [14]:
task3_test_sessions = read_task3_test()

In [15]:
task3_test_sessions

Unnamed: 0,prev_items,locale
0,['B0BF9JMVDG' 'B01ET9V90M'],ES
1,['B09QQG85HM' 'B09J4T4JF5'],ES
2,['B09NSKDG4K' 'B09YY6J1ZM'],ES
3,['B09B7NYDJ7' 'B09B7NYDJ7'],ES
4,['B0B6J17LK4' 'B0B6R7X6GY' 'B07HXY5SGH'],ES
...,...,...
56417,['B08S8N3YG4' 'B00HRRANWY'],UK
56418,['B09F5L55QV' 'B0B8BRLGY4' 'B00IKGH2TI' 'B00IK...,UK
56419,['B09FY2WLZ3' 'B09FY4F65Y'],UK
56420,['B001G4L9KO' 'B0BG373773' 'B001G4L9KO' 'B07Z2...,UK


In [16]:
(task3_test_sessions['locale'] == 'ES').sum(), (task3_test_sessions['locale'] == 'DE').sum(), (task3_test_sessions['locale'] == 'FR').sum()

(6422, 10000, 10000)

In [17]:
(task3_test_sessions['locale'] == 'IT').sum(), (task3_test_sessions['locale'] == 'JP').sum(), (task3_test_sessions['locale'] == 'UK').sum()

(10000, 10000, 10000)

In [18]:
# ES : [0, 6420], DE : [6421, 16420], FR : [16421, 26420], IT : [26421, 36420], JP : [36421, 46420], UK : [46421, 56420]
new_test_prediction_task3 = pd.concat([task3_ES_prediction, task3_DE_prediction, task3_FR_prediction, task3_IT_prediction, task3_JP_prediction, task3_UK_prediction], axis=0)

In [19]:
new_test_prediction_task3

Unnamed: 0,locale,next_item_prediction,scores
0,ES,"[B00XA0GJSE, B09HSK3MR5, B00EU8F8M8, B098Q9L54...","[12.420239448547363, 10.279866218566895, 9.359..."
1,ES,"[B08BRGH7S3, B09K7TDY1H, B09NT33LZN, B08C7G859...","[14.474514961242676, 14.411263465881348, 13.38..."
2,ES,"[B09XM6Z7VY, B09XM6G37B, B09CCQ9987, B07MC3GR8...","[15.859247207641602, 15.796967506408691, 14.34..."
3,ES,"[B014EWSGX2, B00MFEHQMO, B00MFEHQJ2, B013C053D...","[18.539295196533203, 17.649274826049805, 16.58..."
4,ES,"[B09DXT3NGZ, B087HJ5ZZS, B07S9HX4SX, B07GVM3ZL...","[16.20880699157715, 15.716256141662598, 15.214..."
...,...,...,...
56417,UK,"[B00HRRAK7W, B017UDNQAU, B0012NIL28, B01BREENB...","[16.211650848388672, 15.547916412353516, 15.51..."
56418,UK,"[B0BB7H3RL8, B09N2TT7WT, B08KT4N4RN, B0B93KGY2...","[26.357576370239258, 15.044347763061523, 14.95..."
56419,UK,"[B08L6FLJHB, B09FXW93LB, B08CYY2KR1, B09YDJWLK...","[24.867746353149414, 22.156007766723633, 18.20..."
56420,UK,"[B07Z27R3QH, B09PBNQ79F, B09K41DGTZ, B08LGD255...","[17.918062210083008, 10.027246475219727, 8.897..."


In [52]:
new_test_prediction_task3.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023/task3_six_locale_prediction_0523_1953.parquet', engine='pyarrow')