In [None]:
import gc
from pathlib import Path

import torch 
import transformers4rec.torch as tr
from transformers4rec.torch import Trainer
from transformers4rec.config.trainer import T4RecTrainingArguments
from transformers4rec.torch.ranking_metric import NDCGAt, RecallAt
from transformers4rec.torch.utils.examples_utils import wipe_memory
from transformers4rec.torch.utils.data_utils import MerlinDataLoader

from merlin.io import Dataset
from merlin.schema import Schema

In [None]:
# Root of all inputs/outputs for this notebook: ./data relative to the kernel's working directory.
data_path = Path.cwd().joinpath('data')

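## LB

Train an XLNet-based Transformers4Rec next-item model on the `data/lb` split. Training runs incrementally over the windows in `sessions_by_week`, and each window's model is evaluated on the following window.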

In [None]:
# Input directory for the LB split, and the training output directory nested inside it.
lb_in = data_path / 'lb'
lb_out = lb_in / 'model'
# Single atomic call instead of check-then-create. exist_ok only tolerates an
# existing *directory*. Without parents=True this still fails fast if data/lb
# itself is missing.
lb_out.mkdir(exist_ok=True)

In [None]:
# Feature columns used by the model. There is one categorical item-id sequence
# and three continuous sequence features (all '-list' columns).
x_cat_names = ['aid-list']
x_cont_names = [
    'product_recency_day_log_norm-list',
    't_dow_sin-list',
    't_dow_cos-list',
]

# A single processed parquet part is loaded only to obtain its merlin schema,
# which is then narrowed to the feature columns above.
schema_path = lb_in / 'processed_nvt/part_0.parquet'
train = Dataset(schema_path.as_posix())
schema = train.schema.select_by_name(x_cat_names + x_cont_names)

In [None]:
# Maximum number of events per session fed to the model, and the hidden size
# shared by the input projection and the transformer.
sequence_length = 20
d_model = 192
# Cut-offs for the ranking metrics reported during evaluation.
metric_top_ks = [20, 40]

# Input module built from the selected schema. The per-position features are
# concatenated, and 'mlm' masking is configured here.
inputs = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=sequence_length,
    aggregation='concat',
    masking='mlm',
)

# XLNet backbone configured with the same sequence length as the inputs.
transformer_config = tr.XLNetConfig.build(
    d_model=d_model, n_head=4, n_layer=2, total_sequence_length=sequence_length
)

# Project concatenated features to d_model, then run the transformer. The
# transformer reuses the masking scheme configured on the inputs.
body = tr.SequentialBlock(
    inputs,
    tr.MLPBlock([d_model]),
    tr.TransformerBlock(transformer_config, masking=inputs.masking),
)

# Next-item prediction head, evaluated with NDCG@k and Recall@k.
# weight_tying=True reuses the item-embedding table for the output layer
# (per the Transformers4Rec docs).
head = tr.Head(
    body,
    tr.NextItemPredictionTask(
        weight_tying=True,
        metrics=[
            NDCGAt(top_ks=metric_top_ks, labels_onehot=True),
            RecallAt(top_ks=metric_top_ks, labels_onehot=True),
        ],
    ),
)

# NOTE(review): no random seed is set before the weights are initialised, so
# results can differ between runs. Consider torch.manual_seed(...) earlier.
model = tr.Model(head)

In [None]:
# Training configuration. max_sequence_length reuses `sequence_length` (instead
# of a second hard-coded 20), so the trainer and the model cannot disagree on
# sequence length.
training_args = T4RecTrainingArguments(
    output_dir=lb_out.as_posix(),
    max_sequence_length=sequence_length,
    data_loader_engine='merlin',
    num_train_epochs=10,
    dataloader_drop_last=False,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=0.000666,
    report_to=[],  # no external logging integrations
    logging_steps=200,
)

# compute_metrics=True asks the trainer to compute the metrics defined on the
# model's prediction task when evaluating.
trainer = Trainer(
    model=model,
    args=training_args,
    schema=schema,
    compute_metrics=True,
)

In [None]:
%%time

sessions_path = lb_in / 'sessions_by_week'
start_window_index = 1
end_window_index = 4

for time_index in range(start_window_index, end_window_index):
    time_index_train = time_index
    time_index_eval = time_index + 1
    
    train_path = (session_path / f'{time_index_train}/train.parquet').as_posix()
    eval_path = (session_path / f'{time_index_eval}/valid.parquet').as_posix()

    print('*'*20)
    print("Launch training for day %s are:" %time_index)
    print('*'*20 + '\n')

    trainer.train_dataset_or_path = train_path
    trainer.reset_lr_scheduler()
    trainer.train()
    trainer.state.global_step +=1
    
    trainer.eval_dataset_or_path = eval_path
    train_metrics = trainer.evaluate(metric_key_prefix='eval')
    print('*'*20)
    print("Eval results for day %s are:\t" %time_index_eval)
    print('\n' + '*'*20 + '\n')
    for key in sorted(train_metrics.keys()):
        print(" %s = %s" % (key, str(train_metrics[key]))) 
    wipe_memory()