In [1]:
import numpy as np
import pandas as pd
from functools import lru_cache

In [4]:
# Task identifier: selects which sessions_test_<task>.csv is loaded and
# is embedded in the name of the submission parquet file.
task = 'task2'

In [5]:
# Cache loading of data for multiple calls

@lru_cache(maxsize=1)
def read_product_data():
    """Load the products table from its CSV file.

    The result is memoised with ``lru_cache``, so the file is parsed only on
    the first call; later calls hand back the very same DataFrame object.

    Returns:
        pandas.DataFrame: contents of ``products_train.csv``.
    """
    csv_path = '../data/raw-data/products_train.csv'
    products = pd.read_csv(csv_path)
    return products

@lru_cache(maxsize=1)
def read_train_data():
    """Load the training sessions from their CSV file.

    The result is memoised with ``lru_cache``, so the file is parsed only on
    the first call; later calls hand back the very same DataFrame object.

    Returns:
        pandas.DataFrame: contents of ``sessions_train.csv``.
    """
    csv_path = '../data/raw-data/sessions_train.csv'
    sessions = pd.read_csv(csv_path)
    return sessions

@lru_cache(maxsize=3)
def read_test_data(task):
    """Load the test sessions for one task from their CSV file.

    Results are memoised per ``task`` value (up to three distinct values are
    kept), so each file is parsed once and repeated calls return the same
    DataFrame object.

    Args:
        task: Task identifier interpolated into the file name
            (``sessions_test_<task>.csv``). Must be hashable.

    Returns:
        pandas.DataFrame: contents of the task's test-session CSV.
    """
    csv_path = f'../data/raw-data/sessions_test_{task}.csv'
    sessions = pd.read_csv(csv_path)
    return sessions

In [6]:
# Load the training sessions (memoised by read_train_data) and preview
# five randomly chosen rows; the sample is unseeded, so it differs per run.
train_sessions = read_train_data()
train_sessions.sample(5)

Unnamed: 0,prev_items,next_item,locale
1628637,['B09T3VFBHG' 'B09T3TQ6G2' 'B09K7PLYVQ'],B0991KLH15,JP
1155112,['4800562112' '4800562082' '4800562104' 'B0B4D...,4800562767,JP
2619786,['B093P2S78K' 'B093P2S78K' 'B086WLRYCZ' 'B086W...,B085W84NHT,UK
2483712,['1408363968' '1408363968'],1408357488,UK
1150745,['B00CP1L14G' 'B08SR95LZX' 'B0B87M86GQ'],B09VT29L69,JP


In [7]:
# Load the test sessions for the selected task and preview five randomly
# chosen rows; the sample is unseeded, so it differs per run.
test_sessions = read_test_data(task)
test_sessions.sample(5)

Unnamed: 0,prev_items,locale
30896,['B094R3R9XH' 'B07TV364MZ'],IT
6666,['B09K3ZSTY8' 'B0BF5BGL7Q' 'B00HCYTIYG'],ES
10791,['B08XXZJG53' 'B09LTXK21P' 'B08XXZJG53' 'B08TJ...,FR
10701,['B002TMCCA8' 'B002TGVA70'],FR
33271,['8858038630' 'B09JPDLS27' '8858043022'],IT


In [11]:
# Load previously generated predictions from disk.
# NOTE(review): CSV has no list type, so if result.csv holds per-session lists
# in 'next_item_prediction' they come back as plain strings here - confirm the
# column is parsed into real lists before validating/submitting.
predictions = pd.read_csv('result.csv')

In [12]:
def check_predictions(predictions, check_products=False, sessions=None):
    """
    Validate a predictions frame against the test sessions, locale by locale.

    These tests need to pass as they will also be applied on the evaluator.

    Args:
        predictions: DataFrame with a 'locale' column whose index holds the
            session ids. When ``check_products`` is True it must also have a
            'next_item_prediction' column containing either one product id
            per row or a list-like of product ids per row (lists may differ
            in length).
        check_products: When True, additionally verify that every predicted
            product id belongs to the product catalogue of the same locale.
        sessions: DataFrame of test sessions with a 'locale' column. Defaults
            to the notebook-level ``test_sessions`` so existing calls keep
            working unchanged.

    Raises:
        AssertionError: if, for some locale, the prediction index does not
            match the session index, or (with ``check_products``) a predicted
            id is not a product of that locale.
    """
    if sessions is None:
        # Backward-compatible fallback to the frame loaded earlier in the notebook.
        sessions = test_sessions

    for locale in sessions['locale'].unique():
        # Boolean masks instead of string-built queries: no quoting pitfalls.
        sess_test = sessions[sessions['locale'] == locale]
        preds_locale = predictions[predictions['locale'] == locale]
        assert sorted(preds_locale.index.values) == sorted(sess_test.index.values), f"Session ids of {locale} doesn't match"

        if check_products:
            # This check is not done on the evaluator
            # but you can run it to verify there is no mixing of products between locales
            # Since the ground truth next item will always belong to the same locale
            # Warning - This can be slow to run
            catalogue = read_product_data()
            valid_ids = catalogue.loc[catalogue['locale'] == locale, 'id']
            # explode() flattens list-like cells of any length (scalars pass
            # through untouched); an empty list becomes NaN and is therefore
            # reported as invalid rather than crashing the check.
            predicted_products = preds_locale["next_item_prediction"].explode().drop_duplicates()
            # Series.isin is hash-based, avoiding numpy's slow object-dtype path.
            assert predicted_products.isin(valid_ids).all(), f"Invalid products in {locale} predictions"

In [13]:
# Run the session-id validation only (check_products keeps its default of False).
check_predictions(predictions)

In [14]:
# It's important that the parquet file you submit is saved with the pyarrow backend.
# The output file name embeds the task id (submission_<task>.parquet).
predictions.to_parquet(f'submission_{task}.parquet', engine='pyarrow')