In [2]:
import warnings
warnings.simplefilter('ignore')

import gc
import re
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)
from tqdm.auto import tqdm

In [3]:
df_sess = pd.read_csv('./raw_data/sessions_train.csv')
df_sess

Unnamed: 0,prev_items,next_item,locale
0,['B09W9FND7K' 'B09JSPLN1M'],B09M7GY217,DE
1,['B076THCGSG' 'B007MO8IME' 'B08MF65MLV' 'B001B...,B001B4THSA,DE
2,['B0B1LGXWDS' 'B00AZYORS2' 'B0B1LGXWDS' 'B00AZ...,B0767DTG2Q,DE
3,['B09XMTWDVT' 'B0B4MZZ8MB' 'B0B7HZ2GWX' 'B09XM...,B0B4R9NN4B,DE
4,['B09Y5CSL3T' 'B09Y5DPTXN' 'B09FKD61R8'],B0BGVBKWGZ,DE
...,...,...,...
3606244,['B086CYFSKW' 'B0874F9859' 'B086CYFSKW'],B07B5TYD76,IT
3606245,['B09NRZKZ7V' 'B08WJTPV93'],B08L1P4C3D,IT
3606246,['B085JFX7MP' 'B085JGHW8R'],B01MPWVD44,IT
3606247,['B00B0UING2' 'B00B0UING2'],B00D3HYEZ4,IT


In [4]:
idx_permu = np.random.permutation(len(df_sess))
train_num, valid_num = int(len(df_sess) * 0.8), int(len(df_sess) * 0.1)
test_num = len(df_sess) - train_num - valid_num 
print(train_num, valid_num, test_num)
train_idx, valid_idx, test_idx = np.split(idx_permu, [train_num, train_num + valid_num])
train_idx, valid_idx, test_idx

2884999 360624 360626


(array([1070973, 2705391, 1965079, ...,  944527,  723883,  163305]),
 array([3275574,  575404, 1544217, ..., 3279956, 1506532, 1392533]),
 array([2170579, 2845910, 2934909, ..., 2359272, 1718538, 1953586]))

In [5]:
df_train_sess = df_sess.loc[train_idx]
df_valid_sess = df_sess.loc[valid_idx]
df_test_sess = df_sess.loc[test_idx]
df_train_sess = df_train_sess.reset_index(drop=True)
df_valid_sess = df_valid_sess.reset_index(drop=True)
df_test_sess = df_test_sess.reset_index(drop=True)

In [6]:
df_pred = pd.read_csv('./raw_data/sessions_test_task1.csv')
df_pred

Unnamed: 0,prev_items,locale
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE
3,['B08KQBYV43' '3955350843' '3955350843' '39553...,DE
4,['B09FPTCWMC' 'B09FPTQP68' 'B08HMRY8NG' 'B08TB...,DE
...,...,...
316966,['B077SZ2C3Y' 'B0B14M3VZX'],UK
316967,['B08KFHDPY9' 'B0851KTSRZ' 'B08KFHDPY9' 'B0851...,UK
316968,['B07PY1N81F' 'B07Q1Z8SQN' 'B07PY1N81F' 'B07Q1...,UK
316969,['B01MCQMORK' 'B09JYZ325W'],UK


In [7]:
def str2list(x):
    x = x.replace('[', '').replace(']', '').replace("'", '').replace('\n', ' ').replace('\r', ' ')
    l = [i for i in x.split() if i]
    return l

In [8]:
next_item_dict = defaultdict(list)

for _, row in tqdm(df_train_sess.iterrows(), total=len(df_train_sess)):
    prev_items = str2list(row['prev_items'])
    next_item = row['next_item']
    prev_items_length = len(prev_items)
    if prev_items_length <= 1:
        next_item_dict[prev_items[0]].append(next_item)
    else:
        for i, item in enumerate(prev_items[:-1]):
            next_item_dict[item].append(prev_items[i+1])
        next_item_dict[prev_items[-1]].append(next_item)

100%|██████████| 2884999/2884999 [03:03<00:00, 15724.33it/s]


In [9]:
next_item_map = {}

for item in tqdm(next_item_dict):
    counter = Counter(next_item_dict[item])
    next_item_map[item] = [i[0] for i in counter.most_common(100)]

100%|██████████| 1269983/1269983 [00:18<00:00, 70325.41it/s] 


In [10]:
next_item_map

{'B00X81MA0Q': ['B00X81MA0Q',
  'B093STBTW6',
  'B00X81M4FW',
  'B01N24HPR2',
  'B00JWKQ9JK',
  'B09MM529WV',
  'B077SMRXB8',
  'B00X81GP9I',
  'B093SVT4FX',
  'B00X9X4JFW',
  'B07Y1RZMHJ',
  'B07G77V6GG',
  'B0035FZIBI',
  'B012NPT51Y',
  'B00NXXNRBU',
  'B0B5DN36ZZ',
  'B01ID7YJ1M',
  'B0BDMGX45M',
  'B07ZY6W39F',
  'B07NYFG5DG',
  'B08ZJ76RH2',
  'B0BBQZLGCQ',
  'B0B3DMNRKL',
  'B092ZM6JQK',
  'B00X81GTCQ',
  'B08P71LNNY',
  'B0B554CWZX',
  'B07G2H9NVT',
  'B07MMSHC4R',
  'B0888KRK3Z',
  'B08QR6ZXQW',
  'B01MR866BB',
  'B007H4B768',
  'B09PZ4H9YM',
  'B0963CHQ6R',
  'B012SRRMLM',
  'B081TRYVS4',
  'B017DO98JY',
  'B00JWKQ7VA'],
 'B07Y1RZMHJ': ['B086V6B85B', 'B07PZK2JSV', 'B09Y91XW98', 'B01G93MLGC'],
 'B086V6B85B': ['B07Y1RZMHJ', 'B07S9NV1LC', 'B091YBGGBY', 'B0B98ZN7Z3'],
 'B07S9NV1LC': ['B0BD9DWXBY',
  'B09D178XW4',
  'B09JCJT88S',
  'B08D688LHQ',
  'B00K6O5TM4',
  'B09CP8FHT2',
  'B07ZNB7VY4',
  'B09MWH1PDT'],
 'B0BD9DWXBY': ['B0BD9DWXBY',
  'B09SW762K6',
  'B08ZS9WTJ7',
  'B09SP5P

In [11]:
k = []
v = []

for item in next_item_dict:
    k.append(item)
    v.append(next_item_dict[item])
    
df_next = pd.DataFrame({'item': k, 'next_item': v})
df_next = df_next.explode('next_item').reset_index(drop=True)
df_next

Unnamed: 0,item,next_item
0,B00X81MA0Q,B07Y1RZMHJ
1,B00X81MA0Q,B07G77V6GG
2,B00X81MA0Q,B00X81MA0Q
3,B00X81MA0Q,B0035FZIBI
4,B00X81MA0Q,B00X81MA0Q
...,...,...
12245636,B08F5L1V72,B08F5L1V72
12245637,B08F5L1V72,B0995MT4H4
12245638,B08PHT7SC7,B08685TJPN
12245639,B08LVHF6GG,B08LVPH7CV


In [12]:
next_item_map['B09W9FND7K']

['B09JSPLN1M', 'B078WW2WN5', 'B07ZF23YBP', 'B09W9FND7K', 'B08PPNZXBC']

In [13]:
top200 = df_next['next_item'].value_counts().index.tolist()[:200]
top200

['B07QPV9Z7X',
 'B09NQGVSPD',
 'B0BD5MFPMF',
 'B01NCX3W3O',
 'B00NTCH52W',
 'B01N40PO2M',
 'B0B1MPZWJG',
 'B08GWS298V',
 'B09VKWKZ16',
 'B07XKBLL8F',
 'B01MXLEVR7',
 'B07RHT52HX',
 'B0BD3XQ9RQ',
 'B0BDML9477',
 'B07N8QY3YH',
 'B0BD88WWQ8',
 'B09VKWL89R',
 'B088FSHMQ3',
 'B08CN3G4N9',
 'B091V2X68B',
 'B0922JX27X',
 'B0BDZ16ZG4',
 'B0B19YG4MB',
 'B019GNUT0C',
 'B07S9FXKZQ',
 'B09D76FT9D',
 'B001RYNMDA',
 'B08YKJ17NF',
 'B00NTCHCU2',
 'B07WD58H6R',
 'B09BBX1T4S',
 'B00CWNMV4G',
 'B0935DN1BN',
 'B09QFPZ9B7',
 'B09HGGV5R5',
 'B004BIG55Q',
 'B014I8SSD0',
 'B0B18DBBFB',
 'B098B8PFXY',
 'B0991FVFQJ',
 'B08GYKNCCP',
 'B09BTKNGN5',
 'B08H93ZRK9',
 'B09QFPYX34',
 'B09SWS16W6',
 'B01B8R6PF2',
 'B08F2NDB39',
 'B07H27J698',
 'B0B18CK9DV',
 'B014I8SIJY',
 'B0BCQLVFBM',
 'B0931VRJT5',
 'B01KNWYGMW',
 'B07MLFBJG3',
 'B0B1Q3QML1',
 'B00HZV9WTM',
 'B08ZMQT158',
 'B077B9X3SB',
 'B0B9B1Y8R5',
 'B08LSNJQ1N',
 'B07XXZP2CK',
 'B07PNL5STG',
 'B009Q8492W',
 'B09C7BRP5Y',
 'B0BFH812PF',
 'B013ICNQLQ',
 'B0BFXGM3

In [14]:
df_valid_sess['last_item'] = df_valid_sess['prev_items'].apply(lambda x: str2list(x)[-1])
df_valid_sess['next_item_prediction'] = df_valid_sess['last_item'].map(next_item_map)
df_valid_sess

Unnamed: 0,prev_items,next_item,locale,last_item,next_item_prediction
0,['B002IJM4DW' 'B006ZZ7OYO' 'B00B81RHGK' 'B002I...,B09J286JSD,ES,B0B65L184L,[B09J286JSD]
1,['B0BGG3G8KB' 'B0BGG42C62'],B0BGG553DF,DE,B0BGG42C62,"[B0BGG3G8KB, B0BGG553DF, B0BGG3BS14, B0B24QYW2..."
2,['B09D3NXCBL' 'B09D3NXCBL' 'B0B7WBCT28'],B0B7VKTZ61,JP,B0B7WBCT28,"[B0B7VKTZ61, B0B48YPVS6, B09NBM7H1L]"
3,['B08VJJWCR4' 'B00M8VCL2Y'],B07SY9MRWJ,JP,B00M8VCL2Y,"[B079QB3DS5, B079PTGNTW, B07J266Z6N, B0BFGDZ1B..."
4,['B07DXVRY1C' 'B093K2BHBS' 'B08156Z37H'],B09QXG9HSL,UK,B08156Z37H,"[B07DXVRY1C, B093K2BHBS, B07DXVSR99, B07KGZGWF..."
...,...,...,...,...,...
360619,['B07TJTQN2F' 'B07TQ8KP4Y'],B07TJTQWR6,DE,B07TQ8KP4Y,"[B00CHU8T22, B07TJTQWR6, B07TP46QQY, B07TQ8KP4..."
360620,['B08P2N83KV' 'B08QMRWQMF'],B08P2QM1DN,FR,B08QMRWQMF,"[B08P2N83KV, B08P2QM1DN, B08CZ1F3RH, B01F8TY7I..."
360621,['B08V4Q8VJ1' 'B09M7LCX8P' 'B079QF67S8' 'B007Z...,B01G7G6XDI,ES,B007ZY4COE,"[B07GTNW9FV, B007ZY4D0W, B0173NA0AG, B01G70GFA..."
360622,['B09JFJNPRZ' 'B00C6VIA6C'],B07PSL8RBN,JP,B00C6VIA6C,"[B01966X1CC, B00C6VIA6C, B06XNVVD51, B09H2PV7L..."


In [15]:
df_test_sess['last_item'] = df_test_sess['prev_items'].apply(lambda x: str2list(x)[-1])
df_test_sess['next_item_prediction'] = df_test_sess['last_item'].map(next_item_map)
df_test_sess

Unnamed: 0,prev_items,next_item,locale,last_item,next_item_prediction
0,['B08GYKF2BH' 'B08H1H4B7P'],B08C7YCWQK,UK,B08H1H4B7P,"[B08C7YCWQK, B08GYKF2BH, B09CNP8LSK, B000TRQGN4]"
1,['B01M62IRCV' 'B01M62IRCV' 'B08HLSXQS9' 'B08DS...,B07259BPG2,UK,B0B1LX67RZ,"[B0B5QGZHR8, B09XC9DV7G, B09SYM81WM, B07259BPG..."
2,['B0B51PL552' 'B07L559LLY' 'B004X56S5Y'],B004X5AZJ4,UK,B004X56S5Y,"[B004X5AZJ4, B09Z3X7DN4]"
3,['B08Z4BGRFV' 'B09T77T7N9'],B09FPQ6NXB,UK,B09T77T7N9,"[B08GPVCWSK, B098KYLL2B, B09F6DQTH8, B0B6GRK25W]"
4,['B07ZYWFFFP' 'B07N9JFGZD' 'B07LCBKPQ2' 'B07LC...,B00B5AN3M6,UK,B09FY639HT,"[B086C6WYTB, B09QG2SDDL]"
...,...,...,...,...,...
360621,['B014R0RUEM' 'B002SWIJ7Y' 'B00QK0KOY6' 'B002S...,B001FMHUNW,JP,B002SWIJ7Y,"[B00QK0KOY6, B014R0RUEM, B00CW9QXG2]"
360622,['B0B9JRSDSX' 'B0B9JQBN63'],B07BB5QRTS,JP,B0B9JQBN63,"[B0B9JRSDSX, B0B9JRMZCY, 484707257X, B0BDKLK1Z..."
360623,['B0B8ND2ZCV' 'B0BHZ6NGR3' 'B0BHZ2GNNR' 'B0B8N...,B09JGFGV3N,UK,B0B8NGDRKZ,"[B0BC888VLM, B0BC8BKBWJ]"
360624,['B09FDV3QDW' 'B09FDV3QDW'],B09TKR9BGM,JP,B09FDV3QDW,"[B0B1J5CX4W, B09FDV3QDW, B0B7S525JP, B07Y3GXQJ..."


In [16]:
preds = []

for _, row in tqdm(df_test_sess.iterrows(), total=len(df_test_sess)):
    pred_orig = row['next_item_prediction']
    pred = pred_orig
    prev_items = str2list(row['prev_items'])
    if type(pred) == float:
        pred = top200[:100]
    else:
        if len(pred_orig) < 100:
            for i in top200:
                if i not in pred_orig and i not in prev_items:
                    pred.append(i)
                if len(pred) >= 100:
                    break
        else:
            pred = pred[:100]
    preds.append(pred)

100%|██████████| 360626/360626 [00:41<00:00, 8708.18it/s] 


In [17]:
df_test_sess['next_item_prediction'] = preds
df_test_sess

Unnamed: 0,prev_items,next_item,locale,last_item,next_item_prediction
0,['B08GYKF2BH' 'B08H1H4B7P'],B08C7YCWQK,UK,B08H1H4B7P,"[B08C7YCWQK, B08GYKF2BH, B09CNP8LSK, B000TRQGN..."
1,['B01M62IRCV' 'B01M62IRCV' 'B08HLSXQS9' 'B08DS...,B07259BPG2,UK,B0B1LX67RZ,"[B0B5QGZHR8, B09XC9DV7G, B09SYM81WM, B07259BPG..."
2,['B0B51PL552' 'B07L559LLY' 'B004X56S5Y'],B004X5AZJ4,UK,B004X56S5Y,"[B004X5AZJ4, B09Z3X7DN4, B07QPV9Z7X, B09NQGVSP..."
3,['B08Z4BGRFV' 'B09T77T7N9'],B09FPQ6NXB,UK,B09T77T7N9,"[B08GPVCWSK, B098KYLL2B, B09F6DQTH8, B0B6GRK25..."
4,['B07ZYWFFFP' 'B07N9JFGZD' 'B07LCBKPQ2' 'B07LC...,B00B5AN3M6,UK,B09FY639HT,"[B086C6WYTB, B09QG2SDDL, B07QPV9Z7X, B09NQGVSP..."
...,...,...,...,...,...
360621,['B014R0RUEM' 'B002SWIJ7Y' 'B00QK0KOY6' 'B002S...,B001FMHUNW,JP,B002SWIJ7Y,"[B00QK0KOY6, B014R0RUEM, B00CW9QXG2, B07QPV9Z7..."
360622,['B0B9JRSDSX' 'B0B9JQBN63'],B07BB5QRTS,JP,B0B9JQBN63,"[B0B9JRSDSX, B0B9JRMZCY, 484707257X, B0BDKLK1Z..."
360623,['B0B8ND2ZCV' 'B0BHZ6NGR3' 'B0BHZ2GNNR' 'B0B8N...,B09JGFGV3N,UK,B0B8NGDRKZ,"[B0BC888VLM, B0BC8BKBWJ, B07QPV9Z7X, B09NQGVSP..."
360624,['B09FDV3QDW' 'B09FDV3QDW'],B09TKR9BGM,JP,B09FDV3QDW,"[B0B1J5CX4W, B09FDV3QDW, B0B7S525JP, B07Y3GXQJ..."


In [18]:
preds = []

for _, row in tqdm(df_valid_sess.iterrows(), total=len(df_valid_sess)):
    pred_orig = row['next_item_prediction']
    pred = pred_orig
    prev_items = str2list(row['prev_items'])
    if type(pred) == float:
        pred = top200[:100]
    else:
        if len(pred_orig) < 100:
            for i in top200:
                if i not in pred_orig and i not in prev_items:
                    pred.append(i)
                if len(pred) >= 100:
                    break
        else:
            pred = pred[:100]
    preds.append(pred)

100%|██████████| 360624/360624 [00:33<00:00, 10893.38it/s]


In [19]:
df_valid_sess['next_item_prediction'] = preds
df_valid_sess

Unnamed: 0,prev_items,next_item,locale,last_item,next_item_prediction
0,['B002IJM4DW' 'B006ZZ7OYO' 'B00B81RHGK' 'B002I...,B09J286JSD,ES,B0B65L184L,"[B09J286JSD, B07QPV9Z7X, B09NQGVSPD, B0BD5MFPM..."
1,['B0BGG3G8KB' 'B0BGG42C62'],B0BGG553DF,DE,B0BGG42C62,"[B0BGG3G8KB, B0BGG553DF, B0BGG3BS14, B0B24QYW2..."
2,['B09D3NXCBL' 'B09D3NXCBL' 'B0B7WBCT28'],B0B7VKTZ61,JP,B0B7WBCT28,"[B0B7VKTZ61, B0B48YPVS6, B09NBM7H1L, B07QPV9Z7..."
3,['B08VJJWCR4' 'B00M8VCL2Y'],B07SY9MRWJ,JP,B00M8VCL2Y,"[B079QB3DS5, B079PTGNTW, B07J266Z6N, B0BFGDZ1B..."
4,['B07DXVRY1C' 'B093K2BHBS' 'B08156Z37H'],B09QXG9HSL,UK,B08156Z37H,"[B07DXVRY1C, B093K2BHBS, B07DXVSR99, B07KGZGWF..."
...,...,...,...,...,...
360619,['B07TJTQN2F' 'B07TQ8KP4Y'],B07TJTQWR6,DE,B07TQ8KP4Y,"[B00CHU8T22, B07TJTQWR6, B07TP46QQY, B07TQ8KP4..."
360620,['B08P2N83KV' 'B08QMRWQMF'],B08P2QM1DN,FR,B08QMRWQMF,"[B08P2N83KV, B08P2QM1DN, B08CZ1F3RH, B01F8TY7I..."
360621,['B08V4Q8VJ1' 'B09M7LCX8P' 'B079QF67S8' 'B007Z...,B01G7G6XDI,ES,B007ZY4COE,"[B07GTNW9FV, B007ZY4D0W, B0173NA0AG, B01G70GFA..."
360622,['B09JFJNPRZ' 'B00C6VIA6C'],B07PSL8RBN,JP,B00C6VIA6C,"[B01966X1CC, B00C6VIA6C, B06XNVVD51, B09H2PV7L..."


In [20]:
df_valid_sess.iloc[14]

prev_items                                    ['B015SXO11M' 'B013QNUN3Q']
next_item                                                      B07C6MQJN2
locale                                                                 DE
last_item                                                      B013QNUN3Q
next_item_prediction    [B015SXO11M, B07C6MQJN2, B088VSJMFJ, B07QPV9Z7...
Name: 14, dtype: object

In [28]:
# mrr for valid
mrr_valid = []
for i in tqdm(range(len(df_valid_sess))):
    mrr = 0.0
    ground_truth = df_valid_sess.iloc[i]['next_item']
    for i, item in enumerate(df_valid_sess.iloc[i]['next_item_prediction']):
        if item == ground_truth:
            mrr = 1.0 / (i + 1)
            break
    mrr_valid.append(mrr)
mrr_valid = np.array(mrr_valid).mean()
mrr_valid

100%|██████████| 360624/360624 [00:31<00:00, 11363.23it/s]


0.24021385273916246

In [29]:
# mrr for test
mrr_test = []
for i in tqdm(range(len(df_test_sess))):
    mrr = 0.0
    ground_truth = df_test_sess.iloc[i]['next_item']
    for i, item in enumerate(df_test_sess.iloc[i]['next_item_prediction']):
        if item == ground_truth:
            mrr = 1.0 / (i + 1)
            break
    mrr_test.append(mrr)
mrr_test = np.array(mrr_test).mean()
mrr_test

100%|██████████| 360626/360626 [00:33<00:00, 10914.50it/s]


0.24074462120638251

In [30]:
df_pred['last_item'] = df_pred['prev_items'].apply(lambda x: str2list(x)[-1])
df_pred['next_item_prediction'] = df_pred['last_item'].map(next_item_map)
df_pred

Unnamed: 0,prev_items,locale,last_item,next_item_prediction
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE,B099NQFMG7,"[B099NS1XPG, B099NR3X6D, B08496TCCQ, B09WDSH4C..."
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE,B00R9RZ9ZS,"[B004ZXMV4Q, B00R9R5ND6, B095TQTZXY, B08WC5CF5..."
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE,B07G7Q5N6G,"[B08C9Q7QVK, B07G7Q5N6G, B07YSRXJD3, B0B5QNFWJ..."
3,['B08KQBYV43' '3955350843' '3955350843' '39553...,DE,3955350843,"[395535086X, 3955350843, 3772476953, 395535087..."
4,['B09FPTCWMC' 'B09FPTQP68' 'B08HMRY8NG' 'B08TB...,DE,B09J945WQR,"[B09J8V18FL, B09J8T6TTH, B09J8SKX9G, B09J8V9RQ..."
...,...,...,...,...
316966,['B077SZ2C3Y' 'B0B14M3VZX'],UK,B0B14M3VZX,"[B08X9L5RGD, B09Y4HKGKT, B07V5FL8G6]"
316967,['B08KFHDPY9' 'B0851KTSRZ' 'B08KFHDPY9' 'B0851...,UK,B081YDH55K,"[B0989BHLSY, B09895QPQF, B09CPNS7XV, B09CPP92Q..."
316968,['B07PY1N81F' 'B07Q1Z8SQN' 'B07PY1N81F' 'B07Q1...,UK,B09HL11V5B,"[B09HKZBNZH, B09HZSRJWW, B09HL141QC, B07PY1NG3..."
316969,['B01MCQMORK' 'B09JYZ325W'],UK,B09JYZ325W,"[B08FB464L7, B09C8RQ8NT, B08QTGSCL3, B07L5YWCQ..."


In [31]:
preds = []

for _, row in tqdm(df_pred.iterrows(), total=len(df_pred)):
    pred_orig = row['next_item_prediction']
    pred = pred_orig
    prev_items = str2list(row['prev_items'])
    if type(pred) == float:
        pred = top200[:100]
    else:
        if len(pred_orig) < 100:
            for i in top200:
                if i not in pred_orig and i not in prev_items:
                    pred.append(i)
                if len(pred) >= 100:
                    break
        else:
            pred = pred[:100]
    preds.append(pred)

100%|██████████| 316971/316971 [00:31<00:00, 10034.69it/s]


In [32]:
df_pred['next_item_prediction'] = preds
df_pred

Unnamed: 0,prev_items,locale,last_item,next_item_prediction
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE,B099NQFMG7,"[B099NS1XPG, B099NR3X6D, B08496TCCQ, B09WDSH4C..."
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE,B00R9RZ9ZS,"[B004ZXMV4Q, B00R9R5ND6, B095TQTZXY, B08WC5CF5..."
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE,B07G7Q5N6G,"[B08C9Q7QVK, B07G7Q5N6G, B07YSRXJD3, B0B5QNFWJ..."
3,['B08KQBYV43' '3955350843' '3955350843' '39553...,DE,3955350843,"[395535086X, 3955350843, 3772476953, 395535087..."
4,['B09FPTCWMC' 'B09FPTQP68' 'B08HMRY8NG' 'B08TB...,DE,B09J945WQR,"[B09J8V18FL, B09J8T6TTH, B09J8SKX9G, B09J8V9RQ..."
...,...,...,...,...
316966,['B077SZ2C3Y' 'B0B14M3VZX'],UK,B0B14M3VZX,"[B08X9L5RGD, B09Y4HKGKT, B07V5FL8G6, B07QPV9Z7..."
316967,['B08KFHDPY9' 'B0851KTSRZ' 'B08KFHDPY9' 'B0851...,UK,B081YDH55K,"[B0989BHLSY, B09895QPQF, B09CPNS7XV, B09CPP92Q..."
316968,['B07PY1N81F' 'B07Q1Z8SQN' 'B07PY1N81F' 'B07Q1...,UK,B09HL11V5B,"[B09HKZBNZH, B09HZSRJWW, B09HL141QC, B07PY1NG3..."
316969,['B01MCQMORK' 'B09JYZ325W'],UK,B09JYZ325W,"[B08FB464L7, B09C8RQ8NT, B08QTGSCL3, B07L5YWCQ..."


In [33]:
df_pred['next_item_prediction'].apply(len).describe()

count    316971.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: next_item_prediction, dtype: float64

In [35]:
df_pred[['locale', 'next_item_prediction']].to_parquet('submission_task1_2.parquet', engine='pyarrow')

In [3]:
df_my_pred = pd.read_parquet("/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023/2023-03-28-01-44-20.parquet")
df_my_pred.head(5)

Unnamed: 0,locale,next_item_prediction
0,DE,"[B0B6BHYW9N, B092D6TVGH, B0B139RW78, B09Z16NCJ..."
1,DE,"[B094P14PNJ, B08R5KGSXJ, B08Z3DG88N, B08PVJNH6..."
2,DE,"[B07Q329STR, B07V6QXGHX, B07JN5LBS7, B09QFWT3G..."
3,DE,"[B077PBQF2X, B07RKWV54M, B07BR4X8L6, B07X43718..."
4,DE,"[B0BGMHNNGQ, B09CXDNFRC, B08WZ2F7K9, B08PT4XZF..."


In [5]:
df_my_pred2 = pd.read_parquet("/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec2/kdd_cup_2023/2023-03-27-20-06-42.parquet")
df_my_pred2.head(5)

Unnamed: 0,locale,next_item_prediction
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B086CJF45..."
1,DE,"[B004ZXMV4Q, B095TQTZXY, B07WDK8PPN, B08WC5CF5..."
2,DE,"[B000KSLHNQ, B0BJ7H1822, B0B6MZZ674, B0B5QNFWJ..."
3,DE,"[395535086X, B089K81TDV, B09KNCTB5F, 377247695..."
4,DE,"[B09J8TWRV3, B09J8SKX9G, B089CVQ2FS, B07JGZH51..."


In [13]:
df_my_pred.iloc[123456]

locale                                                                 JP
next_item_prediction    [B09HKQQBCH, B096ZQNMHH, B07Q73VM9Y, B099K6PP3...
Name: 123456, dtype: object

In [12]:
df_my_pred2.iloc[123456]

locale                                                                 JP
next_item_prediction    [B00A6LHPZ6, B000MU34Y2, B000MU4YKA, B00AXUC3M...
Name: 123456, dtype: object

In [4]:
df_my_pred['next_item_prediction'].apply(len).describe()

count    316971.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: next_item_prediction, dtype: float64

In [41]:
df_pred[['locale', 'next_item_prediction']].iloc[0]['next_item_prediction']

['B099NS1XPG',
 'B099NR3X6D',
 'B08496TCCQ',
 'B09WDSH4CD',
 'B07QPV9Z7X',
 'B09NQGVSPD',
 'B0BD5MFPMF',
 'B01NCX3W3O',
 'B00NTCH52W',
 'B01N40PO2M',
 'B0B1MPZWJG',
 'B08GWS298V',
 'B09VKWKZ16',
 'B07XKBLL8F',
 'B01MXLEVR7',
 'B07RHT52HX',
 'B0BD3XQ9RQ',
 'B0BDML9477',
 'B07N8QY3YH',
 'B0BD88WWQ8',
 'B09VKWL89R',
 'B088FSHMQ3',
 'B08CN3G4N9',
 'B091V2X68B',
 'B0922JX27X',
 'B0BDZ16ZG4',
 'B0B19YG4MB',
 'B019GNUT0C',
 'B07S9FXKZQ',
 'B09D76FT9D',
 'B001RYNMDA',
 'B08YKJ17NF',
 'B00NTCHCU2',
 'B07WD58H6R',
 'B09BBX1T4S',
 'B00CWNMV4G',
 'B0935DN1BN',
 'B09QFPZ9B7',
 'B09HGGV5R5',
 'B004BIG55Q',
 'B014I8SSD0',
 'B0B18DBBFB',
 'B098B8PFXY',
 'B0991FVFQJ',
 'B08GYKNCCP',
 'B09BTKNGN5',
 'B08H93ZRK9',
 'B09QFPYX34',
 'B09SWS16W6',
 'B01B8R6PF2',
 'B08F2NDB39',
 'B07H27J698',
 'B0B18CK9DV',
 'B014I8SIJY',
 'B0BCQLVFBM',
 'B0931VRJT5',
 'B01KNWYGMW',
 'B07MLFBJG3',
 'B0B1Q3QML1',
 'B00HZV9WTM',
 'B08ZMQT158',
 'B077B9X3SB',
 'B0B9B1Y8R5',
 'B08LSNJQ1N',
 'B07XXZP2CK',
 'B07PNL5STG',
 'B009Q849

In [42]:
df_my_pred['next_item_prediction'].iloc[0]

array(['B08L9CZ7BW', 'B095C349SJ', 'B07PHQ22FY', 'B099XL3VS4',
       'B08X12TS8R', 'B09Z16NCJT', 'B0B2X3LC93', 'B08DY7S8DR',
       'B08HDCY9VR', 'B0BCGXM4WH', 'B0BGJDSD38', 'B085QYQ112',
       'B0BDFTFPZ3', 'B088GNP6TB', 'B09F3P3DQD', 'B0851LN8QG',
       'B0BH8NYPV6', 'B0BL768TBK', 'B08FCZPSH2', 'B07Y22TWNF',
       '3735880266', 'B08CY2WJFF', 'B0B7R1SVSG', 'B09YPR8BBJ',
       'B08Q7SX3V6', 'B0BDM9JYGJ', 'B096HSYFL2', 'B0B63D692W',
       'B07XZF3JFZ', 'B08X18GCYD', 'B0B7CBD6M6', 'B08CS1WFN8',
       'B08QW43S2L', 'B07CZ4DLCP', 'B09P4B2VS2', 'B08151FJRJ',
       'B0851MXXR3', 'B09FPV7NF8', 'B00F5UX7EQ', 'B095CSPRRF',
       'B093SYL61B', 'B09DDDR9WV', 'B0B3264C64', 'B07H2BZ5TP',
       'B0B9XTSD8Z', 'B08X23BH9M', 'B085QZ2FJT', 'B085QZ6FLD',
       'B08Q7HL9QV', 'B0B5CYGH1B', 'B0BD5JBWKW', 'B0B6ZPHNHQ',
       'B07GQNC5KQ', 'B099MWSH7B', 'B00FFTDZTY', '3473491764',
       'B09NLCC7TD', 'B0BCGV6372', 'B0BHJKVMB4', 'B09P4BKWPL',
       'B09JCH4D5B', 'B08X16LL57', 'B09JRGQ336', 'B095C