In [1]:
# %pip install -U lightgbm

In [1]:
import pandas as pd
import numpy as np

import pickle
from tqdm import tqdm
from pathlib import Path
import gc

In [2]:
import warnings
import sys
from IPython.core.interactiveshell import InteractiveShell

# Notebook-wide setup. The three statements are independent of each other:
#   * display the value of every top-level expression in a cell, not only the last,
#   * add ../src/ to the import search path (the project modules imported in the
#     next cell presumably live there -- confirm),
#   * suppress all warnings for the rest of the session.
InteractiveShell.ast_node_interactivity = "all"
sys.path.append("../src/")
warnings.filterwarnings("ignore")

In [3]:
from data import DataHelper, map_at_k, hr_at_k, recall_at_k
from retrieval.rules import OrderHistory, ItemPair, TimeHistory, OutOfStock
from retrieval.collector import RuleCollector

In [5]:
# Root folder for all data files, given relative to the notebook's working directory.
data_dir = Path("../data/")
# Project helper bound to data_dir; later cells call its load_data / split_data methods.
dh = DataHelper(data_dir)

In [6]:
# data = dh.preprocess_data(save=True) # run only once

In [7]:
# Load the dataset stored under the name "encoded_full". The result is indexed by
# key below (data["inter"] is used as the transaction table in the retrieval section).
data = dh.load_data(name="encoded_full")

## Retrieval

In [8]:
# Split the interaction table into train / valid using the two boundary dates.
# NOTE(review): presumably valid covers 2020-09-16 .. 2020-09-23 -- confirm
# against DataHelper.split_data.
trans = data["inter"]
train, valid = dh.split_data(trans, "2020-09-16", "2020-09-23")

# Training rows dated on or after 2020-09-09 feed the time-based rule below.
is_recent = train["t_dat"] >= "2020-09-09"
last_week = train.loc[is_recent]

# Customers we need candidates for.
customer_list = valid["customer_id"].values

In [13]:
# Collect rule-based candidates for the validation customers.
collector = RuleCollector()

# A second, separate OrderHistory instance supplies the input for ItemPair,
# exactly as before (the two instances are not shared).
validation_rules = [
    OrderHistory(train, 7),
    ItemPair(OrderHistory(train, 7).retrieve()),
    TimeHistory(last_week, 12),
]
validation_filters = [OutOfStock(trans)]

candidates = collector.collect(
    customer_list=customer_list,
    rules=validation_rules,
    filters=validation_filters,
    compress=True,
)

Retrieve items by rules: 100%|██████████| 3/3 [00:45<00:00, 15.02s/it]


0.025353589326755178

In [16]:
# Call the retrieved items "prediction" so they do not clash with the
# ground-truth article_id column of valid, then attach them per customer.
# The left join keeps every validation row, even those without candidates.
candidates = candidates.rename(columns={'article_id': 'prediction'})
valid2 = valid.merge(candidates, on="customer_id", how="left")

In [18]:
# Score the retrieved candidates against the ground truth at k=12.
# Each metric stays a top-level expression so that all three values are shown.
actual_items, predicted_items = valid2["article_id"], valid2["prediction"]
map_at_k(actual_items, predicted_items, k=12)
hr_at_k(actual_items, predicted_items, k=12)
recall_at_k(actual_items, predicted_items, k=12)

0.025353589326755178

0.1117795430824539

0.05785989807865099

## Predict

In [19]:
def _load_pickle(path):
    """Unpickle the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes (or fails), instead of being left open until garbage
    collection.

    Args:
        path: Location of the pickle file (``str`` or ``pathlib.Path``).

    Returns:
        The unpickled Python object.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # this on files produced by this project's own preprocessing.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Lookup tables between raw ids and encoded indices, used to translate the
# submission file into index space and the predictions back into raw ids.
uid2idx = _load_pickle(data_dir/"index_id_map/user_id2index.pkl")
idx2uid = _load_pickle(data_dir/"index_id_map/user_index2id.pkl")
idx2iid = _load_pickle(data_dir/"index_id_map/item_index2id.pkl")

In [30]:
# Read the sample submission and translate its customer ids through uid2idx,
# so they can be matched against the encoded transaction data.
submission = pd.read_csv(data_dir/"raw"/'sample_submission.csv').assign(
    customer_id=lambda frame: frame['customer_id'].map(uid2idx)
)

In [27]:
# Same retrieval as for validation, but built on the full transaction table
# and run for every customer in the submission file.
is_recent = trans["t_dat"] >= "2020-09-16"
last_week = trans.loc[is_recent]

collector = RuleCollector()
submission_customers = submission['customer_id'].values

# As above, ItemPair is fed by its own separate OrderHistory instance.
submission_rules = [
    OrderHistory(trans, 7),
    ItemPair(OrderHistory(trans, 7).retrieve()),
    TimeHistory(last_week, 12),
]
submission_filters = [OutOfStock(trans)]

# compress=False here, unlike the validation run; the next cell groups the
# result per customer itself.
candidates = collector.collect(
    customer_list=submission_customers,
    rules=submission_rules,
    filters=submission_filters,
    compress=False,
)

Retrieve items by rules: 100%|██████████| 3/3 [00:47<00:00, 15.81s/it]


In [28]:
# Turn the candidate rows into one space-separated prediction string per customer.
# Each item index is mapped through idx2iid, converted to text and prefixed with
# '0' (presumably restoring the leading zero of the raw article ids -- confirm).
article_codes = '0' + candidates['article_id'].map(idx2iid).map(str)

candidates = (
    candidates.assign(article_id=article_codes)
    .groupby('customer_id')['article_id']
    .agg(' '.join)
    .reset_index()
    .rename(columns={'article_id': 'prediction'})
)

In [33]:
# Replace the prediction column that came with the loaded submission frame by
# the retrieved candidates, then translate customer ids back through idx2uid.
# The left join keeps every submission row, including customers without candidates.
submission = (
    submission.drop(columns='prediction')
    .merge(candidates, on='customer_id', how='left')
    .assign(customer_id=lambda frame: frame['customer_id'].map(idx2uid))
)

In [None]:
# Write the final submission to submission.csv in the current working directory,
# without the DataFrame index column.
submission.to_csv('submission.csv', index=False)