In [1]:
import pandas as pd 
import numpy as np
from functools import lru_cache
import os
from tqdm import tqdm 
from copy import deepcopy

In [3]:
# Parquet files with the SASRec next-item predictions, one file per locale (UK, DE, JP).
# These globals are read by the lru_cache'd reader functions defined below.
# NOTE(review): the validation section and the task-1 section both load through these same
# three paths, yet their outputs show different row counts and columns ('candidates' vs
# 'next_item_prediction'), so the values were presumably edited between runs — confirm which
# run these timestamps belong to before re-executing the notebook top to bottom.
sasrec_UK_locale_prediction_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_UK/2023-05-07-03-21-09.parquet'
sasrec_DE_locale_prediction_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_DE/2023-05-10-19-55-32.parquet'
sasrec_JP_locale_prediction_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_JP/2023-05-11-04-25-17.parquet'

In [4]:

@lru_cache(maxsize=1)
def read_all_valid_sessions():
    """Read the task-1 validation sessions CSV and memoise the result.

    The function takes no arguments, so ``lru_cache`` hands back the very same
    DataFrame object on every call after the first; copy it before mutating.
    """
    valid_sessions_csv = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions.csv'
    valid_sessions = pd.read_csv(valid_sessions_csv)
    return valid_sessions

@lru_cache(maxsize=1)
def read_sasrec_UK_locale_prediction():
    """Read the UK-locale SASRec prediction parquet (pyarrow engine) and memoise it.

    The path global is only consulted on the first call; later calls return the
    cached DataFrame object, even if the path variable has been reassigned since.
    """
    uk_prediction = pd.read_parquet(sasrec_UK_locale_prediction_path, engine='pyarrow')
    return uk_prediction

@lru_cache(maxsize=1)
def read_sasrec_DE_locale_prediction():
    """Read the DE-locale SASRec prediction parquet (pyarrow engine) and memoise it.

    The path global is only consulted on the first call; later calls return the
    cached DataFrame object, even if the path variable has been reassigned since.
    """
    de_prediction = pd.read_parquet(sasrec_DE_locale_prediction_path, engine='pyarrow')
    return de_prediction

@lru_cache(maxsize=1)
def read_sasrec_JP_locale_prediction():
    """Read the JP-locale SASRec prediction parquet (pyarrow engine) and memoise it.

    The path global is only consulted on the first call; later calls return the
    cached DataFrame object, even if the path variable has been reassigned since.
    """
    jp_prediction = pd.read_parquet(sasrec_JP_locale_prediction_path, engine='pyarrow')
    return jp_prediction

# task1 validation

In [17]:
# Validation sessions for task 1; their 'locale' column drives the per-locale merge below and
# their 'next_item' column is the ground truth passed to cal_hit_and_mrr.
sasrec_valid_sessions = read_all_valid_sessions()

In [13]:
# Load the UK-model predictions, then show (row count, number of candidates in the first row).
sasrec_UK_locale_prediction = read_sasrec_UK_locale_prediction()
(
    sasrec_UK_locale_prediction.shape[0],
    len(sasrec_UK_locale_prediction['candidates'].iloc[0]),
)

(361581, 300)

In [14]:
# Load the DE-model predictions, then show (row count, number of candidates in the first row).
sasrec_DE_locale_prediction = read_sasrec_DE_locale_prediction()
(
    sasrec_DE_locale_prediction.shape[0],
    len(sasrec_DE_locale_prediction['candidates'].iloc[0]),
)

(361581, 300)

In [15]:
# Load the JP-model predictions, then show (row count, number of candidates in the first row).
sasrec_JP_locale_prediction = read_sasrec_JP_locale_prediction()
(
    sasrec_JP_locale_prediction.shape[0],
    len(sasrec_JP_locale_prediction['candidates'].iloc[0]),
)

(361581, 300)

In [16]:
# Inspect the JP prediction frame. In the rows displayed below, sessions that are not JP carry
# locale '[PAD]' and NaN scores, which is why the three per-locale frames have to be merged.
sasrec_JP_locale_prediction

Unnamed: 0,locale,candidates,sess_id,scores
0,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",0,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B092D5HM5S, B0797D6F3...",1,"[17.74812889099121, 14.109373092651367, 12.546..."
2,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",2,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
3,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",3,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
4,JP,"[B0B6PF619D, B09BJF6N8K, B0B6P77ZRN, B0B6P2PCM...",4,"[21.915809631347656, 20.597021102905273, 19.71..."
...,...,...,...,...
361576,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",361576,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
361577,JP,"[B09BCM5NL1, B09B9V4PXC, B09XH1YGLL, B09XGRXXG...",361577,"[13.340217590332031, 12.540390014648438, 12.30..."
361578,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",361578,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
361579,[PAD],"[B07BP8ZBJK, B01CODB5T0, B01CODB64E, B019BWJT0...",361579,"[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."


In [24]:
# Start the merged frame from an independent copy of the DE predictions so that the
# row overwrites below never touch the DataFrame cached by the lru_cache'd reader.
sasrec_merged_valid_prediction = sasrec_DE_locale_prediction.copy(deep=True)

In [25]:
# Replace every row whose validation session is JP with the JP model's prediction row.
# NOTE(review): this relies on the sessions frame and the prediction frames sharing the
# same index / row order — confirm.
jp_valid_session_mask = sasrec_valid_sessions['locale'].eq('JP')
sasrec_merged_valid_prediction[jp_valid_session_mask] = sasrec_JP_locale_prediction[jp_valid_session_mask]

In [27]:
# Replace every row whose validation session is UK with the UK model's prediction row.
# NOTE(review): this relies on the sessions frame and the prediction frames sharing the
# same index / row order — confirm.
uk_valid_session_mask = sasrec_valid_sessions['locale'].eq('UK')
sasrec_merged_valid_prediction[uk_valid_session_mask] = sasrec_UK_locale_prediction[uk_valid_session_mask]

In [29]:
def cal_hit_and_mrr(ground_truth_list, candidates_list, show_progress=True):
    """Compute the mean hit rate and mean reciprocal rank (MRR) of ranked candidates.

    For each session, the hit is 1.0 when the ground-truth item appears anywhere
    in that session's candidates, and the reciprocal rank is 1 / (1-based position
    of its first occurrence); both are 0.0 when the item is absent.

    Parameters
    ----------
    ground_truth_list : sequence (list, numpy array or pandas Series)
        One ground-truth item per session. Iterated positionally.
    candidates_list : sequence of sequences
        One ranked candidate sequence per session, in the same order as
        ``ground_truth_list``.
    show_progress : bool, default True
        Wrap the session loop in a ``tqdm`` progress bar. Pass False to run
        silently (no ``tqdm`` call is made in that case).

    Returns
    -------
    tuple
        ``(mean_hit, mean_mrr)``. ``(0.0, 0.0)`` for empty input.

    Raises
    ------
    ValueError
        If the two inputs do not contain the same number of sessions.
    """
    num_sessions = len(ground_truth_list)
    if num_sessions != len(candidates_list):
        # A silent truncation here would score misaligned rows without any warning.
        raise ValueError(
            f"ground_truth_list has {num_sessions} rows but candidates_list has {len(candidates_list)}"
        )
    if num_sessions == 0:
        # Averaging an empty array would yield NaN plus a RuntimeWarning.
        return 0.0, 0.0

    # zip iterates values positionally, so plain lists and Series with any index both work.
    session_pairs = zip(ground_truth_list, candidates_list)
    if show_progress:
        session_pairs = tqdm(session_pairs, total=num_sessions)

    hits, mrrs = [], []
    for ground_truth, candidates in session_pairs:
        hit, mrr = 0.0, 0.0
        for rank, candidate in enumerate(candidates, start=1):
            if ground_truth == candidate:
                hit = 1.0
                mrr = 1.0 / rank
                break  # only the first (best-ranked) occurrence counts
        hits.append(hit)
        mrrs.append(mrr)
    return np.array(hits).mean(), np.array(mrrs).mean()

In [31]:
# Score the merged candidate lists against the 'next_item' column of the validation sessions;
# the cell output is the (mean hit, mean MRR) pair. Both Series are walked positionally, so they
# must be in the same row order — NOTE(review): confirm alignment.
cal_hit_and_mrr(sasrec_valid_sessions['next_item'], sasrec_merged_valid_prediction['candidates'])

100%|██████████| 361581/361581 [00:09<00:00, 36467.17it/s]


(0.6773447719874661, 0.30861870792730345)

In [34]:
# Keep only the first 100 candidates of every session, and the first 100 scores alongside them.
sasrec_merged_valid_prediction['candidates'] = sasrec_merged_valid_prediction['candidates'].apply(
    lambda ranked_items: ranked_items[:100]
)
sasrec_merged_valid_prediction['scores'] = sasrec_merged_valid_prediction['scores'].apply(
    lambda ranked_scores: ranked_scores[:100]
)

In [36]:
# Sanity check after truncation: summary statistics of the per-row score-list lengths
# (min and max should both be 100).
sasrec_merged_valid_prediction['scores'].apply(len).describe()

count    361581.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: scores, dtype: float64

In [None]:
# Write the merged, truncated validation candidates (with scores) to parquet.
# Note the path is relative to the notebook's working directory, unlike the absolute paths used elsewhere.
sasrec_merged_valid_prediction.to_parquet('../candidates/SASRec_Next/seperate_locale/SASRec_Next_04_26_15_26_valid_with_score.parquet')

# task 1 

In [5]:
# Load the UK-model test predictions, then show (row count, length of the first prediction list).
sasrec_UK_locale_prediction = read_sasrec_UK_locale_prediction()
(
    sasrec_UK_locale_prediction.shape[0],
    len(sasrec_UK_locale_prediction['next_item_prediction'].iloc[0]),
)

(316971, 100)

In [6]:
# Load the DE-model test predictions, then show (row count, length of the first prediction list).
sasrec_DE_locale_prediction = read_sasrec_DE_locale_prediction()
(
    sasrec_DE_locale_prediction.shape[0],
    len(sasrec_DE_locale_prediction['next_item_prediction'].iloc[0]),
)

(316971, 100)

In [7]:
# Load the JP-model test predictions, then show (row count, length of the first prediction list).
sasrec_JP_locale_prediction = read_sasrec_JP_locale_prediction()
(
    sasrec_JP_locale_prediction.shape[0],
    len(sasrec_JP_locale_prediction['next_item_prediction'].iloc[0]),
)

(316971, 100)

In [8]:
# Number of rows in the DE prediction frame that are tagged with locale 'DE'.
int(sasrec_DE_locale_prediction['locale'].eq('DE').sum())

104568

In [9]:
# Number of rows in the JP prediction frame that are tagged with locale 'JP'.
int(sasrec_JP_locale_prediction['locale'].eq('JP').sum())

96467

In [10]:
# Number of rows in the UK prediction frame that are tagged with locale 'UK'.
int(sasrec_UK_locale_prediction['locale'].eq('UK').sum())

115936

In [11]:
# Stitch the three per-locale frames into one frame by position: the first block of rows comes
# from the DE model, the next block from the JP model and the remainder from the UK model.
# The block sizes are derived from the locale tags instead of being hard-coded, so the cell keeps
# working when the prediction files change (for the files inspected above they are 104568 and 96467).
# NOTE(review): this assumes the test sessions are ordered DE, then JP, then UK in contiguous
# blocks; new_pred_validation checks that assumption on the result.
num_de_test_sessions = int((sasrec_DE_locale_prediction['locale'] == 'DE').sum())
num_jp_test_sessions = int((sasrec_JP_locale_prediction['locale'] == 'JP').sum())
jp_block_end = num_de_test_sessions + num_jp_test_sessions
new_test_prediction = pd.concat(
    [
        sasrec_DE_locale_prediction.iloc[:num_de_test_sessions],
        sasrec_JP_locale_prediction.iloc[num_de_test_sessions:jp_block_end],
        sasrec_UK_locale_prediction.iloc[jp_block_end:],
    ],
    axis=0,
)

In [12]:
def new_pred_validation(new_pred, n_de=104568, n_jp=96467, expected_len=None):
    """Assert that a merged prediction frame is laid out as DE rows, then JP rows, then UK rows.

    Every row of each block is checked, not just the block boundaries, so a
    misplaced row in the middle of a block is detected as well.

    Parameters
    ----------
    new_pred : pandas.DataFrame
        Merged predictions; must have a 'locale' column.
    n_de : int, default 104568
        Number of leading rows expected to be 'DE'.
    n_jp : int, default 96467
        Number of rows after the DE block expected to be 'JP'. All remaining
        rows are expected to be 'UK'.
    expected_len : int or None, default None
        Required total row count. When None, the length of the notebook-level
        ``sasrec_DE_locale_prediction`` frame is used.

    Raises
    ------
    AssertionError
        If the length or any block's locale tags do not match.
    """
    if expected_len is None:
        expected_len = len(sasrec_DE_locale_prediction)
    assert len(new_pred) == expected_len, "merged frame has an unexpected number of rows"

    jp_end = n_de + n_jp
    # All three blocks must be non-empty, otherwise the .all() checks below pass vacuously.
    assert n_de > 0 and n_jp > 0 and jp_end < len(new_pred), "block sizes do not fit the frame"

    locales = new_pred['locale']
    assert (locales.iloc[:n_de] == 'DE').all(), "non-DE row found in the DE block"
    assert (locales.iloc[n_de:jp_end] == 'JP').all(), "non-JP row found in the JP block"
    assert (locales.iloc[jp_end:] == 'UK').all(), "non-UK row found in the UK block"
    print("Nice!!!")

In [13]:
# Check the DE / JP / UK layout of the stitched frame; prints "Nice!!!" when every assertion passes.
new_pred_validation(new_test_prediction)

Nice!!!


In [14]:
# Summary statistics of the per-row prediction-list lengths in the stitched frame.
new_test_prediction['next_item_prediction'].apply(len).describe()

count    316971.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: next_item_prediction, dtype: float64

In [15]:
# Cap every prediction list at its first 100 items (per the describe() output above,
# the lists are already exactly 100 long, so this only guards against longer inputs).
new_test_prediction['next_item_prediction'] = new_test_prediction['next_item_prediction'].apply(
    lambda ranked_items: ranked_items[:100]
)

In [19]:
# Write the stitched three-locale (DE, JP, UK) test predictions to parquet using the pyarrow engine.
new_test_prediction.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023/three_locale_prediction_three_layers_0511_2055.parquet', engine='pyarrow')

# task 3

In [3]:
import pandas as pd 
import numpy as np
from functools import lru_cache
import os
from tqdm import tqdm 

In [27]:
# Task-3 inputs: the raw test sessions CSV and one SASRec prediction parquet per locale
# (UK, DE, JP, IT, ES, FR).
# NOTE(review): 'taske' in taske_test_path looks like a typo for 'task3'; the name is left as is
# because read_task3_test refers to it.
taske_test_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task3.csv'
sasrec_UK_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_UK/2023-04-21-17-06-44.parquet'
sasrec_DE_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_DE/2023-04-21-17-04-13.parquet'
sasrec_JP_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_JP/2023-04-21-17-05-37.parquet'
sasrec_IT_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_IT/2023-04-21-17-07-42.parquet'
sasrec_ES_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_ES/2023-04-21-17-08-33.parquet'
sasrec_FR_locale_prediction_task3_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023_FR/2023-04-21-17-09-29.parquet'

In [31]:
@lru_cache(maxsize=1)
def read_task3_test():
    """Read the task-3 test sessions CSV with the pyarrow engine and memoise it.

    Calls after the first return the same cached DataFrame object.
    """
    task3_sessions = pd.read_csv(taske_test_path, engine='pyarrow')
    return task3_sessions

@lru_cache(maxsize=1)
def read_sasrec_UK_locale_prediction():
    """Read the task-3 UK-locale SASRec prediction parquet (pyarrow engine) and memoise it.

    This reuses the name of the reader defined earlier in the notebook, so running
    this cell replaces that function together with its cache.
    """
    uk_task3_prediction = pd.read_parquet(sasrec_UK_locale_prediction_task3_path, engine='pyarrow')
    return uk_task3_prediction

@lru_cache(maxsize=1)
def read_sasrec_DE_locale_prediction():
    """Read the task-3 DE-locale SASRec prediction parquet (pyarrow engine) and memoise it.

    This reuses the name of the reader defined earlier in the notebook, so running
    this cell replaces that function together with its cache.
    """
    de_task3_prediction = pd.read_parquet(sasrec_DE_locale_prediction_task3_path, engine='pyarrow')
    return de_task3_prediction

@lru_cache(maxsize=1)
def read_sasrec_JP_locale_prediction():
    """Read the task-3 JP-locale SASRec prediction parquet (pyarrow engine) and memoise it.

    This reuses the name of the reader defined earlier in the notebook, so running
    this cell replaces that function together with its cache.
    """
    jp_task3_prediction = pd.read_parquet(sasrec_JP_locale_prediction_task3_path, engine='pyarrow')
    return jp_task3_prediction

@lru_cache(maxsize=1)
def read_sasrec_IT_locale_prediction():
    """Read the task-3 IT-locale SASRec prediction parquet (pyarrow engine) and memoise it.

    Calls after the first return the same cached DataFrame object.
    """
    it_task3_prediction = pd.read_parquet(sasrec_IT_locale_prediction_task3_path, engine='pyarrow')
    return it_task3_prediction

@lru_cache(maxsize=1)
def read_sasrec_ES_locale_prediction():
    """Read the task-3 ES-locale SASRec prediction parquet (pyarrow engine) and memoise it.

    Calls after the first return the same cached DataFrame object.
    """
    es_task3_prediction = pd.read_parquet(sasrec_ES_locale_prediction_task3_path, engine='pyarrow')
    return es_task3_prediction

@lru_cache(maxsize=1)
def read_sasrec_FR_locale_prediction():
    """Read the task-3 FR-locale SASRec prediction parquet (pyarrow engine) and memoise it.

    Calls after the first return the same cached DataFrame object.
    """
    fr_task3_prediction = pd.read_parquet(sasrec_FR_locale_prediction_task3_path, engine='pyarrow')
    return fr_task3_prediction

In [32]:
# From each per-locale prediction file keep only the rows tagged with that file's own locale.
# The boolean filter returns a new frame, so the DataFrames cached by the readers stay untouched.
task3_DE_prediction = read_sasrec_DE_locale_prediction().loc[lambda frame: frame['locale'] == 'DE']

task3_JP_prediction = read_sasrec_JP_locale_prediction().loc[lambda frame: frame['locale'] == 'JP']

task3_UK_prediction = read_sasrec_UK_locale_prediction().loc[lambda frame: frame['locale'] == 'UK']

task3_IT_prediction = read_sasrec_IT_locale_prediction().loc[lambda frame: frame['locale'] == 'IT']

task3_ES_prediction = read_sasrec_ES_locale_prediction().loc[lambda frame: frame['locale'] == 'ES']

task3_FR_prediction = read_sasrec_FR_locale_prediction().loc[lambda frame: frame['locale'] == 'FR']

In [33]:
# Load the task-3 test sessions (displayed in the next cell to see the locale order of the rows).
task3_test_sessions = read_task3_test()

In [34]:
# Inspect the task-3 test sessions: columns are 'prev_items' and 'locale'; the displayed rows
# start with ES sessions and end with UK sessions.
task3_test_sessions

Unnamed: 0,prev_items,locale
0,['B082DLM3NZ' 'B089X86H73'],ES
1,['B071WPLND2' 'B08TMJ9SDZ' 'B07XRCLVYG'],ES
2,['B094V8G54H' 'B094V97YV8'],ES
3,['B0B3DQXY57' 'B0B6W3GGTM'],ES
4,['B0765BPD7T' 'B00V4PQY3C' 'B09HWV4MBK'],ES
...,...,...
56416,['B08GNG5FMW' 'B08Q7MJW8W'],UK
56417,['B09YH16XH1' 'B09YGY96ZM'],UK
56418,['B00EXKSNNE' 'B005DBORH8' 'B005DBORCS' 'B005D...,UK
56419,['B007CJVZ1A' 'B07GCSPHNK' 'B07GCVF3N3'],UK


In [43]:
# Stack the six per-locale frames in the locale order of the task-3 test rows:
# ES : [0, 6420], DE : [6421, 16420], FR : [16421, 26420], IT : [26421, 36420], JP : [36421, 46420], UK : [46421, 56420]
new_test_prediction_task3 = pd.concat(
    [
        task3_ES_prediction,
        task3_DE_prediction,
        task3_FR_prediction,
        task3_IT_prediction,
        task3_JP_prediction,
        task3_UK_prediction,
    ],
    axis=0,
)

In [52]:
# Write the stacked six-locale task-3 predictions to parquet using the pyarrow engine.
new_test_prediction_task3.to_parquet('/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023/task3_six_locale_prediction_0421_1731.parquet', engine='pyarrow')