In [1]:
import pickle
import pandas as pd
import numpy as np
from functools import lru_cache

In [2]:
# Log in through the `aicrowd` command-line tool.
# NOTE(review): the recorded output of this cell shows the `aicrowd` command was
# not found when it was last run; presumably aicrowd-cli is not installed or
# not on PATH in that environment. TODO confirm before relying on this cell.
!aicrowd login

'aicrowd' is not recognized as an internal or external command,
operable program or batch file.


In [3]:
# Task identifier: selects which test-session file is read and how the
# submission file is named.
task = "task2"

In [4]:
# Cache loading of data for multiple calls

@lru_cache(maxsize=1)
def read_product_data():
    """Read products_train.csv into a DataFrame.

    The result is memoised by ``lru_cache``, so repeated calls return the
    same DataFrame object without re-reading the file.
    """
    products_csv = '../data/raw-data/products_train.csv'
    return pd.read_csv(products_csv)

@lru_cache(maxsize=1)
def read_train_data():
    """Read sessions_train.csv into a DataFrame.

    The result is memoised by ``lru_cache``, so repeated calls return the
    same DataFrame object without re-reading the file.
    """
    train_csv = '../data/raw-data/sessions_train.csv'
    return pd.read_csv(train_csv)

@lru_cache(maxsize=3)
def read_test_data(task):
    """Read the test-session CSV for ``task`` into a DataFrame.

    Parameters
    ----------
    task : hashable
        Inserted into the file name ``sessions_test_{task}.csv``.

    Up to three distinct ``task`` values are memoised by ``lru_cache``;
    repeated calls with the same value return the same DataFrame object.
    """
    test_csv = f'../data/raw-data/sessions_test_{task}.csv'
    return pd.read_csv(test_csv)

In [5]:
# Load the training sessions, then display five randomly sampled rows.
train_sessions = read_train_data()
train_sessions.sample(n=5)

Unnamed: 0,prev_items,next_item,locale
2587339,['B07WJHY2R7' 'B01N903LMM'],B072WGPK53,UK
2956781,['B09QLFV81Y' 'B08L5DX25C' 'B082WDM878'],B08X1T3W89,UK
159736,['B01FDHKJ96' 'B01FDHKJ96' 'B00AQJ6V4O' 'B07D9...,B07D9C8NP2,DE
1774349,['B00267T6L4' 'B00267T6LE' 'B00267T6LO' 'B00SU...,B07XV9CGJC,JP
2358753,['B0BJQ91MQF' 'B0BJ93KSC8' 'B07S5B9DVH' 'B0BJQ...,B08B1D8Z7Q,UK


In [6]:
# Load the test sessions for the selected task, then display five randomly
# sampled rows.
test_sessions = read_test_data(task)
test_sessions.sample(n=5)

Unnamed: 0,prev_items,locale
20915,['B00NOP8NUW' 'B07549RGZH'],IT
26621,['B098TVZ6TS' 'B095Y1JNHG'],IT
26858,['B00RTCUIX6' 'B00D3HXQ6C'],IT
2083,['B07MBTDL1H' 'B07KF9TC3J'],ES
17882,['B0868F8JTM' 'B09MTKBB6B'],FR


In [7]:
# Parallel accumulators: one locale label and one recommendation entry per row
# of the final predictions frame.
locale, result = [], []

In [8]:
# Locales whose recommendation files are merged into the submission; edit this
# list to include recommendations for other locales.
LOCALES = ["ES", "FR", "IT"]

# Start from empty accumulators every time this cell runs, so re-running it
# cannot append a second copy of the same recommendations.
locale = []
result = []

for locale_code in LOCALES:
    # NOTE: pickle.load can execute arbitrary code while unpickling; only load
    # files that were produced by your own pipeline.
    with open(f"../SR-GNN/test-result/{locale_code}-recs.pkl", "rb") as f:
        recs = pickle.load(f)

    # Keep the two lists aligned: one locale label per loaded entry.
    result.extend(recs)
    locale.extend([locale_code] * len(recs))

In [9]:
# Pair every recommendation entry with its locale label, one row per pair.
prediction_rows = list(zip(locale, result))
predictions = pd.DataFrame(
    prediction_rows,
    columns=['locale', 'next_item_prediction'],
)
predictions

Unnamed: 0,locale,next_item_prediction
0,ES,"[B08GY8NHF2, B073JYVKNX, B08GYBBBBH, B074RNRM2..."
1,ES,"[B085NGXGWM, B06XGHV89F, B08NYF9MBQ, B01M3RT7M..."
2,ES,"[B0B1DG29F4, B091FL1QFK, B09BD6SY6B, B01CQOIL3..."
3,ES,"[B07JCNMR9S, B07SPSZCF2, B07F38XNL3, B0859T1YC..."
4,ES,"[B099WT8P4D, B09V1NBZ47, B09V1LJLGH, B00Y4SN79..."
...,...,...
34683,IT,"[8894957314, B07WKNQ8JT, B09JQ5B3S5, B09JSZ2DH..."
34684,IT,"[8894957314, B07WKNQ8JT, B09JQ5B3S5, B09JSZ2DH..."
34685,IT,"[8894957314, B07WKNQ8JT, B09JQ5B3S5, B09JSZ2DH..."
34686,IT,"[8894957314, B07WKNQ8JT, B09JQ5B3S5, B09JSZ2DH..."


In [10]:
def check_predictions(predictions, check_products=False, sessions=None, products=None):
    """
    These tests need to pass as they will also be applied on the evaluator

    For every locale present in the test sessions, verify that the rows of
    ``predictions`` for that locale carry exactly the same index values as
    the test sessions of that locale.

    Parameters
    ----------
    predictions : pd.DataFrame
        Needs a 'locale' column; when ``check_products`` is true it also
        needs a 'next_item_prediction' column whose cells are iterables
        (e.g. lists) of product ids.
    check_products : bool, default False
        Also verify that every predicted product id is listed for the same
        locale in the product table.
    sessions : pd.DataFrame, optional
        Test sessions with a 'locale' column. Defaults to the module-level
        ``test_sessions``.
    products : pd.DataFrame, optional
        Product table with 'id' and 'locale' columns. Defaults to
        ``read_product_data()``; only used when ``check_products`` is true.

    Raises
    ------
    AssertionError
        If the session ids of a locale do not match, or (with
        ``check_products``) a predicted product is not listed for the locale.
    """
    if sessions is None:
        sessions = test_sessions
    if check_products and products is None:
        products = read_product_data()

    for locale_name in sessions['locale'].unique():
        # Boolean masks instead of an f-string query: no quoting pitfalls and
        # no redundant lookup of the locale value.
        sess_test = sessions[sessions['locale'] == locale_name]
        preds_locale = predictions[predictions['locale'] == locale_name]
        assert sorted(preds_locale.index.values) == sorted(sess_test.index.values), f"Session ids of {locale_name} doesn't match"

        if check_products:
            # This check is not done on the evaluator
            # but you can run it to verify there is no mixing of products between locales
            # Since the ground truth next item will always belong to the same locale
            valid_ids = set(products.loc[products['locale'] == locale_name, 'id'])
            # Flatten with a set comprehension so prediction lists of different
            # lengths are handled, and use hash-based membership rather than an
            # element-wise array comparison.
            predicted_products = {
                item
                for items in preds_locale['next_item_prediction']
                for item in items
            }
            assert predicted_products <= valid_ids, f"Invalid products in {locale_name} predictions"

In [11]:
# Run the session-id check (the one the evaluator also applies) before the
# submission file is written.
check_predictions(predictions)

# Optional and potentially slow: additionally verify that no predicted product
# comes from a different locale.
#check_predictions(predictions, True)

In [12]:
# It's important that the parquet file you submit is saved with the pyarrow backend
submission_path = f'submission_{task}.parquet'
predictions.to_parquet(submission_path, engine='pyarrow')

## Submit to AIcrowd 🚀

In [13]:
# You can submit with aicrowd-cli, or upload manually on the challenge page.
# NOTE: the file name and the challenge slug in this command are hard-coded for
# task 2, whereas the parquet file is written as f'submission_{task}.parquet';
# update this command if `task` is changed.
# NOTE(review): the recorded output of this cell shows the `aicrowd` command was
# not found when it was last run.
!aicrowd submission create -c task-2-next-product-recommendation-for-underrepresented-languages -f "submission_task2.parquet"

'aicrowd' is not recognized as an internal or external command,
operable program or batch file.
