In [1]:
import numpy as np 
import pandas as pd 
import os 
from functools import lru_cache

In [2]:
prediction_df = pd.read_parquet('./predictions/SASRec2/kdd_cup_2023/2023-03-22-16-16-53.parquet', engine='pyarrow')

In [5]:
prediction_df.head(100)


Unnamed: 0,locale,next_item_prediction
0,DE,"[B07JG9TFSB, B099NS1XPG, B07LG5T3V9, B07TV364M..."
1,DE,"[B08NP5GN8R, B07VY9P2GR, B08TZ2RWF6, B09M7ZNQS..."
2,DE,"[B0BJF4KGCN, B000KSLHNQ, B09CTYS57Z, B0B5TFLBC..."
3,DE,"[B001RYNMDA, B01MXLEVR7, B0BD88WWQ8, B009Q8492..."
4,DE,"[B09J8SKX9G, B09MTBHT8P, B0BDDPDQSM, B0B51L89C..."
...,...,...
95,DE,"[B07SD6621Q, B07J12LPMD, B00HYAE0LO, B09LHT9T7..."
96,DE,"[B0B87R6VZ4, B0B87V78G4, B09N3NTVJT, B09LD6XRR..."
97,DE,"[B0B244R4KB, B0B8J8JCZ1, B0BDJM32H4, B0B5H56N3..."
98,DE,"[B003F0WKDM, B09CCVCM32, B07F38XNL3, B0924MT4H..."


In [6]:
(prediction_df['locale'] == 'UK').sum()

115936

In [7]:
prediction_df.iloc[0]['next_item_prediction']

array(['B07JG9TFSB', 'B099NS1XPG', 'B07LG5T3V9', 'B07TV364MZ',
       'B07TV22X9M', 'B08CRV3XXV', 'B0BGL7KC2D', 'B089FBHSJ8',
       'B0B53KBXR8', 'B08QYYBTMC', 'B07XG2PTH8', 'B096BGC3XF',
       'B01H1R0K68', 'B094R3R9XH', 'B089FBVXRZ', 'B07JDSHD4Z',
       'B0B1DJ7LMB', 'B07YWR9S66', 'B086CJF45F', 'B0B3DKVCC6',
       'B06XZYP5FW', 'B093PT1NL1', 'B07GPT8HPY', 'B07JW7K2M5',
       'B09NKFSLGB', 'B0B9HC9P9H', 'B0B68FM5ZL', 'B0BFJJG7RT',
       'B07ZTGVKWN', 'B09W5988V1', 'B07TW3NL2M', 'B07TS22D8Z',
       'B09129FV4D', 'B08Q391KS3', 'B095C1CHMQ', 'B09T2ZXL4V',
       'B0BB2B69CV', 'B091YCWH9S', 'B079JXY1RC', 'B09P3K5778',
       'B0BHHZ9LPT', 'B09Z67V8GM', 'B08RJ6QGFV', 'B0B96Y2N4D',
       'B0B5ND342Y', 'B082N332HB', 'B07JG9QZ2B', 'B0876NB9KT',
       'B07QH1MYPY', 'B08FDKF3MW', 'B07T14HSNQ', 'B0B7S7LBMB',
       'B07G2YM494', 'B09QFN2DYJ', 'B09N3CHFZK', 'B084RPXG3B',
       'B09PH91HX2', 'B09KG6VDF2', 'B09X1K2Y1K', 'B00AUMR30S',
       'B0BB6T261Z', 'B09QFPYX34', 'B09JRJ3NYV', 'B0B7W

In [8]:
test_sessions = pd.read_csv('./raw_data/sessions_test_task1.csv')
test_sessions.head(5)

Unnamed: 0,prev_items,locale
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE
3,['B08KQBYV43' '3955350843' '3955350843' '39553...,DE
4,['B09FPTCWMC' 'B09FPTQP68' 'B08HMRY8NG' 'B08TB...,DE


In [9]:
train_data_dir = './raw_data/'
test_data_dir = './raw_data/'
@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(os.path.join(train_data_dir, 'products_train.csv'))

@lru_cache(maxsize=1)
def read_train_data():
    return pd.read_csv(os.path.join(train_data_dir, 'sessions_train.csv'))

@lru_cache(maxsize=3)
def read_test_data(task):
    return pd.read_csv(os.path.join(test_data_dir, f'sessions_test_{task}.csv'))

In [10]:
def check_predictions(predictions, check_products=False):
    """
    These tests need to pass as they will also be applied on the evaluator
    """
    test_locale_names = test_sessions['locale'].unique()
    for locale in test_locale_names:
        sess_test = test_sessions.query(f'locale == "{locale}"')
        preds_locale =  predictions[predictions['locale'] == sess_test['locale'].iloc[0]]
        assert sorted(preds_locale.index.values) == sorted(sess_test.index.values), f"Session ids of {locale} doesn't match"

        if check_products:
            # This check is not done on the evaluator
            # but you can run it to verify there is no mixing of products between locales
            # Since the ground truth next item will always belong to the same locale
            # Warning - This can be slow to run
            products = read_product_data().query(f'locale == "{locale}"')
            predicted_products = np.unique( np.array(list(preds_locale["next_item_prediction"].values)) )
            assert np.all( np.isin(predicted_products, products['id']) ), f"Invalid products in {locale} predictions"

In [11]:
check_predictions(prediction_df)