In [1]:
import warnings
warnings.simplefilter('ignore')

import gc
import re
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)
from tqdm.auto import tqdm

In [2]:
# Load the training sessions. Per the rendered output, each row carries
# `prev_items` (a stringified array of item ids), `next_item` and `locale`.
df_sess = pd.read_csv('./raw_data/sessions_train.csv')
df_sess

Unnamed: 0,prev_items,next_item,locale
0,['B09W9FND7K' 'B09JSPLN1M'],B09M7GY217,DE
1,['B076THCGSG' 'B007MO8IME' 'B08MF65MLV' 'B001B...,B001B4THSA,DE
2,['B0B1LGXWDS' 'B00AZYORS2' 'B0B1LGXWDS' 'B00AZ...,B0767DTG2Q,DE
3,['B09XMTWDVT' 'B0B4MZZ8MB' 'B0B7HZ2GWX' 'B09XM...,B0B4R9NN4B,DE
4,['B09Y5CSL3T' 'B09Y5DPTXN' 'B09FKD61R8'],B0BGVBKWGZ,DE
...,...,...,...
3606244,['B086CYFSKW' 'B0874F9859' 'B086CYFSKW'],B07B5TYD76,IT
3606245,['B09NRZKZ7V' 'B08WJTPV93'],B08L1P4C3D,IT
3606246,['B085JFX7MP' 'B085JGHW8R'],B01MPWVD44,IT
3606247,['B00B0UING2' 'B00B0UING2'],B00D3HYEZ4,IT


In [3]:
# Load the test sessions that predictions are generated for.
# NOTE(review): this path points at the *task3* test file (56,420 rows in the
# output below), but the later cell outputs show 316,971 rows and the result
# is saved as `submission_task1.parquet` -- confirm which test file is
# intended and re-run the notebook top to bottom.
df_test = pd.read_csv('./raw_data/sessions_test_task3.csv')
df_test

Unnamed: 0,prev_items,locale
0,['B082DLM3NZ' 'B089X86H73'],ES
1,['B071WPLND2' 'B08TMJ9SDZ' 'B07XRCLVYG'],ES
2,['B094V8G54H' 'B094V97YV8'],ES
3,['B0B3DQXY57' 'B0B6W3GGTM'],ES
4,['B0765BPD7T' 'B00V4PQY3C' 'B09HWV4MBK'],ES
...,...,...
56416,['B08GNG5FMW' 'B08Q7MJW8W'],UK
56417,['B09YH16XH1' 'B09YGY96ZM'],UK
56418,['B00EXKSNNE' 'B005DBORH8' 'B005DBORCS' 'B005D...,UK
56419,['B007CJVZ1A' 'B07GCSPHNK' 'B07GCVF3N3'],UK


In [4]:
def str2list(x):
    """Parse a stringified item array such as "['A' 'B']" into ['A', 'B'].

    Square brackets and single quotes are dropped, newlines and carriage
    returns are turned into spaces, and what remains is split on whitespace.

    Args:
        x: String representation of an item-id array.

    Returns:
        List of item-id strings in their original order (empty if none).
    """
    # A single translate() pass stands in for the chain of str.replace() calls.
    cleanup = str.maketrans({'[': None, ']': None, "'": None, '\n': ' ', '\r': ' '})
    # str.split() with no argument never yields empty tokens.
    return x.translate(cleanup).split()

In [5]:
# Collect, for every item, the list of items observed immediately after it
# in the training sessions (consecutive pairs inside `prev_items`, plus the
# final transition from the last previous item to `next_item`).
next_item_dict = defaultdict(list)

session_pairs = zip(df_sess['prev_items'], df_sess['next_item'])
for raw_prev, target in tqdm(session_pairs, total=len(df_sess)):
    history = str2list(raw_prev)
    # Transitions between consecutive items of the session history; this is
    # an empty iteration when the history holds a single item.
    for current, following in zip(history, history[1:]):
        next_item_dict[current].append(following)
    # The label is the successor of the last item in the history.
    next_item_dict[history[-1]].append(target)

100%|██████████| 3606249/3606249 [02:45<00:00, 21802.13it/s]


In [7]:
# For each item keep its (up to) 100 most frequent successors, most common
# first; Counter.most_common yields (item, count) pairs, so only the item
# is retained.
next_item_map = {
    item: [candidate for candidate, _ in Counter(successors).most_common(100)]
    for item, successors in tqdm(next_item_dict.items())
}

100%|██████████| 1323630/1323630 [01:11<00:00, 18581.58it/s]


In [8]:
# Spot check: ranked successors of one item from the first training session.
next_item_map['B09W9FND7K']

['B09JSPLN1M',
 'B09W9FND7K',
 'B078WW2WN5',
 'B08PPNZXBC',
 'B08LYTWDT6',
 'B07ZF23YBP']

In [9]:
# Flatten the item -> successors mapping into one (item, next_item) row per
# observed transition.
k = list(next_item_dict.keys())
v = list(next_item_dict.values())

df_next = (
    pd.DataFrame({'item': k, 'next_item': v})
    .explode('next_item')
    .reset_index(drop=True)
)
df_next

Unnamed: 0,item,next_item
0,B09W9FND7K,B09JSPLN1M
1,B09W9FND7K,B09JSPLN1M
2,B09W9FND7K,B09JSPLN1M
3,B09W9FND7K,B09JSPLN1M
4,B09W9FND7K,B078WW2WN5
...,...,...
15306178,B09V7T4HXD,B09QHXKZXC
15306179,B0B14HBDHX,B092471PYL
15306180,B0B14HBDHX,B09X9DSQ7V
15306181,B07P6QPKNL,B07P6QGKTV


In [None]:
# The 200 most frequent successor items overall, used later as popularity
# filler when an item has too few (or no) known successors.
# NOTE(review): the ranking is computed over all transitions regardless of
# locale -- confirm that cross-locale filler items are acceptable.
top200 = df_next['next_item'].value_counts().head(200).index.tolist()

In [9]:
# The last item of each test session drives the prediction.
df_test['last_item'] = [str2list(raw_prev)[-1] for raw_prev in df_test['prev_items']]
# Look up ranked successors; items never seen in training map to NaN.
# The lists stored here are the very same objects held in `next_item_map`.
df_test['next_item_prediction'] = df_test['last_item'].map(next_item_map)
df_test

Unnamed: 0,prev_items,locale,last_item,next_item_prediction
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE,B099NQFMG7,"[B099NS1XPG, B08496TCCQ, B099NR3X6D, B09WDSH4C..."
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE,B00R9RZ9ZS,"[B004ZXMV4Q, B00R9R5ND6, B095TQTZXY, B086J6RTT..."
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE,B07G7Q5N6G,"[B08C9Q7QVK, B07G7Q5N6G, B07YSRXJD3, B07JD9P44..."
3,['B08KQBYV43' '3955350843' '3955350843' '39553...,DE,3955350843,"[395535086X, 3772476953, 3955350843, 395535087..."
4,['B09FPTCWMC' 'B09FPTQP68' 'B08HMRY8NG' 'B08TB...,DE,B09J945WQR,"[B09J8V18FL, B09J8T6TTH, B09J8SKX9G, B09J8V9RQ..."
...,...,...,...,...
316966,['B077SZ2C3Y' 'B0B14M3VZX'],UK,B0B14M3VZX,"[B08X9L5RGD, B07V5FL8G6, B09Y4HKGKT, B09MW64JGM]"
316967,['B08KFHDPY9' 'B0851KTSRZ' 'B08KFHDPY9' 'B0851...,UK,B081YDH55K,"[B0989BHLSY, B09895QPQF, B09CPNS7XV, B09CPP92Q..."
316968,['B07PY1N81F' 'B07Q1Z8SQN' 'B07PY1N81F' 'B07Q1...,UK,B09HL11V5B,"[B09HKZBNZH, B09HZSRJWW, B09HX9VGW4, B09HL141Q..."
316969,['B01MCQMORK' 'B09JYZ325W'],UK,B09JYZ325W,"[B07TR5LQSL, B08FB464L7, B08JG8TSCP, B09C8RQ8N..."


In [10]:
# Build the final 100-item recommendation list for every test session.
#
# Candidates come from the successor ranking of the session's last item.
# Short lists are padded with globally popular items (`top200`), skipping
# items already recommended or already present in the session. Sessions
# whose last item was never seen in training fall back to the first 100
# popular items.
preds = []

for _, row in tqdm(df_test.iterrows(), total=len(df_test)):
    pred_orig = row['next_item_prediction']
    prev_items = str2list(row['prev_items'])
    if not isinstance(pred_orig, list):
        # `.map(next_item_map)` leaves NaN when the last item is unknown.
        pred = top200[:100]
    else:
        # Slice to get a *copy*: the list in the column is the same object
        # stored in `next_item_map`, and is shared by every session with the
        # same last item. Appending to it in place would leak one session's
        # filler (filtered by that session's history) into all the others.
        pred = pred_orig[:100]
        if len(pred) < 100:
            # Set membership keeps the padding loop linear.
            excluded = set(pred)
            excluded.update(prev_items)
            for i in top200:
                if i not in excluded:
                    pred.append(i)
                    if len(pred) >= 100:
                        break
    preds.append(pred)

100%|██████████| 316971/316971 [00:29<00:00, 10927.84it/s]


In [11]:
# Replace the raw successor lists with the padded/truncated predictions.
# Note: this overwrites the column that the previous cell reads from, so
# re-running that cell afterwards operates on already-processed lists.
df_test['next_item_prediction'] = preds
df_test

Unnamed: 0,prev_items,locale,last_item,next_item_prediction
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE,B099NQFMG7,"[B099NS1XPG, B08496TCCQ, B099NR3X6D, B09WDSH4C..."
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE,B00R9RZ9ZS,"[B004ZXMV4Q, B00R9R5ND6, B095TQTZXY, B086J6RTT..."
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE,B07G7Q5N6G,"[B08C9Q7QVK, B07G7Q5N6G, B07YSRXJD3, B07JD9P44..."
3,['B08KQBYV43' '3955350843' '3955350843' '39553...,DE,3955350843,"[395535086X, 3772476953, 3955350843, 395535087..."
4,['B09FPTCWMC' 'B09FPTQP68' 'B08HMRY8NG' 'B08TB...,DE,B09J945WQR,"[B09J8V18FL, B09J8T6TTH, B09J8SKX9G, B09J8V9RQ..."
...,...,...,...,...
316966,['B077SZ2C3Y' 'B0B14M3VZX'],UK,B0B14M3VZX,"[B08X9L5RGD, B07V5FL8G6, B09Y4HKGKT, B09MW64JG..."
316967,['B08KFHDPY9' 'B0851KTSRZ' 'B08KFHDPY9' 'B0851...,UK,B081YDH55K,"[B0989BHLSY, B09895QPQF, B09CPNS7XV, B09CPP92Q..."
316968,['B07PY1N81F' 'B07Q1Z8SQN' 'B07PY1N81F' 'B07Q1...,UK,B09HL11V5B,"[B09HKZBNZH, B09HZSRJWW, B09HX9VGW4, B09HL141Q..."
316969,['B01MCQMORK' 'B09JYZ325W'],UK,B09JYZ325W,"[B07TR5LQSL, B08FB464L7, B08JG8TSCP, B09C8RQ8N..."


In [12]:
# Sanity check: summary statistics of the per-session prediction list length.
df_test['next_item_prediction'].map(len).describe()

count    316971.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: next_item_prediction, dtype: float64

In [13]:
# Write the submission: one row per test session with its locale and the
# list of predicted next items.
# NOTE(review): the output is named for task 1 while the load cell reads the
# task 3 test file -- confirm the two are meant to match.
df_test[['locale', 'next_item_prediction']].to_parquet('submission_task1.parquet', engine='pyarrow')

In [14]:
# Preview the first rows of the final frame.
df_test.head(5)

Unnamed: 0,prev_items,locale,last_item,next_item_prediction
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE,B099NQFMG7,"[B099NS1XPG, B08496TCCQ, B099NR3X6D, B09WDSH4C..."
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE,B00R9RZ9ZS,"[B004ZXMV4Q, B00R9R5ND6, B095TQTZXY, B086J6RTT..."
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE,B07G7Q5N6G,"[B08C9Q7QVK, B07G7Q5N6G, B07YSRXJD3, B07JD9P44..."
3,['B08KQBYV43' '3955350843' '3955350843' '39553...,DE,3955350843,"[395535086X, 3772476953, 3955350843, 395535087..."
4,['B09FPTCWMC' 'B09FPTQP68' 'B08HMRY8NG' 'B08TB...,DE,B09J945WQR,"[B09J8V18FL, B09J8T6TTH, B09J8SKX9G, B09J8V9RQ..."
