<i>Copyright (c) Microsoft Corporation. All rights reserved.</i>

<i>Licensed under the MIT License.</i>

In [2]:
import sys
import os
import logging
import papermill as pm
import scrapbook as sb
from tempfile import TemporaryDirectory
import pandas as pd
import numpy as np
import tensorflow as tf
tf.get_logger().setLevel('ERROR') # only show error messages

from recommenders.utils.timer import Timer
from recommenders.utils.constants import SEED
from recommenders.models.deeprec.deeprec_utils import prepare_hparams

from recommenders.datasets.amazon_reviews import download_and_extract, data_preprocessing
from recommenders.datasets.download_utils import maybe_download

from recommenders.models.deeprec.models.sequential.sli_rec import SLI_RECModel as SeqModel
from recommenders.models.deeprec.io.sequential_iterator import SequentialIterator

print("System version: {}".format(sys.version))
print("Tensorflow version: {}".format(tf.__version__))

System version: 3.7.0 (default, Jun 28 2018, 08:04:48) [MSC v.1912 64 bit (AMD64)]
Tensorflow version: 1.15.0


# 1. Input Data Format
The input data contains 8 tab-separated columns: `<label> <user_id> <item_id> <category_id> <timestamp> <history_item_ids> <history_category_ids> <history_timestamp>`. `item_id` and `category_id` denote the target item and its category: for each instance, we want to predict whether user `user_id` will interact with `item_id` at `timestamp`. The `<history_*>` columns record the user's behavior list up to `<timestamp>`, with elements separated by commas. `<label>` is a binary value: 1 for positive instances and 0 for negative instances. An example instance:

`1	523120	14414	3	1612282953577	14330,14135,5877	3,0,8	1611878652251,1611935700202,1612166343656`
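
As a rough illustration of the format described above, the following sketch splits one instance line into its 8 fields. The `Instance` type and the `parse_instance` helper are hypothetical names introduced here for illustration; they are not part of the `recommenders` library, whose `SequentialIterator` does its own parsing.

```python
from typing import List, NamedTuple


class Instance(NamedTuple):
    """One line of the input file (hypothetical container, for illustration only)."""
    label: int
    user_id: str
    item_id: str
    category_id: str
    timestamp: int
    history_item_ids: List[str]
    history_category_ids: List[str]
    history_timestamps: List[int]


def parse_instance(line: str) -> Instance:
    """Split a tab-separated instance line; history columns are comma-separated."""
    fields = line.rstrip("\n").split("\t")
    if len(fields) != 8:
        raise ValueError("expected 8 tab-separated columns, got {}".format(len(fields)))
    label, user_id, item_id, category_id, timestamp, items, cates, times = fields
    return Instance(
        label=int(label),
        user_id=user_id,
        item_id=item_id,
        category_id=category_id,
        timestamp=int(timestamp),
        history_item_ids=items.split(","),
        history_category_ids=cates.split(","),
        history_timestamps=[int(t) for t in times.split(",")],
    )


example = (
    "1\t523120\t14414\t3\t1612282953577\t14330,14135,5877\t3,0,8\t"
    "1611878652251,1611935700202,1612166343656"
)
inst = parse_instance(example)
print(inst.label, inst.user_id, inst.history_item_ids)
```

The three history lists are aligned by position, so the i-th item, category and timestamp describe the same past interaction.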

During data preprocessing, a script generates ID mapping dictionaries, so that user_id, item_id and category_id are mapped to integer indices starting from 1. You need to tell the input iterator where the ID mapping files are (for example, in the next section we use the mapping files user_vocab, item_vocab and cate_vocab). The data preprocessing script is at [recommenders/datasets/amazon_reviews.py](../nhat_hoang/recommenders/datasets/amazon_reviews.py). Note that the ID vocabulary is created only from the train_file, so new IDs in valid_file or test_file are treated as unknown IDs and assigned the default index 0.

We use Softmax as the loss function. In the training and evaluation stages, we group 1 positive instance with num_ngs negative instances. Pair-wise ranking can be regarded as a special case of Softmax ranking, where num_ngs is set to 1.

More specifically, for training and evaluation, you need to organize the data file so that each positive instance is followed by num_ngs negative instances. Our program takes 1+num_ngs lines as a unit for the Softmax calculation. num_ngs is a parameter you need to pass to the `prepare_hparams`, `fit` and `run_eval` functions. `train_num_ngs` in `prepare_hparams` denotes the number of negative instances for training; a recommended value is 4. `valid_num_ngs` in `fit` and `num_ngs` in `run_eval` denote the number of negative instances used in evaluation. In evaluation, the model calculates metrics among the 1+num_ngs instances of each group. For the `predict` function, since we only need to calculate a score for each individual instance, there is no need for a num_ngs setting. More details and examples are provided in the following sections.
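
To make the grouping concrete, here is a minimal NumPy sketch of a Softmax loss over units of 1+num_ngs scores. It assumes the positive instance is the first score of each unit, as in the file layout described above. `grouped_softmax_loss` is a hypothetical helper written for this explanation, not the loss implementation used inside the library.

```python
import numpy as np


def grouped_softmax_loss(logits, num_ngs):
    """Mean negative log-probability of the positive within each (1 + num_ngs) unit.

    Assumption: scores are ordered as [pos, neg_1, ..., neg_num_ngs, pos, ...].
    """
    logits = np.asarray(logits, dtype=np.float64)
    group_size = 1 + num_ngs
    if logits.size % group_size != 0:
        raise ValueError("number of scores must be a multiple of 1 + num_ngs")
    groups = logits.reshape(-1, group_size)
    # Subtract the per-group maximum for numerical stability.
    shifted = groups - groups.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Column 0 holds the positive instance of every group.
    return float(-log_probs[:, 0].mean())


# With identical scores the positive gets probability 1/5, so the loss is ln(5).
uniform_loss = grouped_softmax_loss([0.0] * 5, num_ngs=4)

# With num_ngs=1 the loss reduces to the pair-wise form log(1 + exp(s_neg - s_pos)).
pairwise_loss = grouped_softmax_loss([2.0, 0.0], num_ngs=1)
print(uniform_loss, pairwise_loss)
```

The second call illustrates why pair-wise ranking is the num_ngs=1 special case: the Softmax over two scores depends only on their difference.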

In [3]:
yaml_file = '../nhat_hoang/recommenders/models/deeprec/config/sli_rec.yaml'

In [4]:
EPOCHS = 10
BATCH_SIZE = 400
RANDOM_SEED = SEED  # Set None for non-deterministic result

data_path = os.path.join("..", "nhat_hoang", "test_slirec", "whole_dataset2")

In [5]:
# for test
train_file = os.path.join(data_path, r'train_data')
valid_file = os.path.join(data_path, r'valid_data')
test_file = os.path.join(data_path, r'test_data')
user_vocab = os.path.join(data_path, r'user_vocab.pkl')
item_vocab = os.path.join(data_path, r'item_vocab.pkl')
cate_vocab = os.path.join(data_path, r'category_vocab.pkl')
output_file = os.path.join(data_path, r'output.txt')

train_num_ngs = 4 # number of negative instances with a positive instance for training
valid_num_ngs = 4 # number of negative instances with a positive instance for validation
test_num_ngs = 9 # number of negative instances with a positive instance for testing
sample_rate = 1 # sample a small item set for training and testing here for fast example

if not os.path.exists(train_file):
    data_preprocessing(train_file, valid_file, test_file, user_vocab, item_vocab, cate_vocab,
                       sample_rate=sample_rate, valid_num_ngs=valid_num_ngs, test_num_ngs=test_num_ngs)

### 1.1 Prepare hyper-parameters

`prepare_hparams()` creates a full set of hyper-parameters for model training, such as learning rate, feature number, and dropout ratio. We can put those parameters in a yaml file (a complete list of parameters can be found under our config folder), or pass them as arguments to the function, which override the yaml settings.

Parameter hints:
`need_sample` controls whether to perform dynamic negative sampling within each mini-batch. `train_num_ngs` indicates how many negative instances follow each positive instance.
Examples:
- `need_sample=True` and `train_num_ngs=4`: There are only positive instances in your training file. Our model will dynamically sample 4 negative instances for each positive instance in the mini-batch. Note that if `need_sample` is set to True, `train_num_ngs` should be greater than zero.
- `need_sample=False` and `train_num_ngs=4`: In your training file, each positive line is followed by 4 negative lines. Note that if `need_sample` is set to False, you must provide a training file with negative instances, and `train_num_ngs` should match the number of negative instances per positive instance in the training file.
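
For the `need_sample=False` case, a quick sanity check on the label column can catch a mismatch between the file and `train_num_ngs` before training. The sketch below is a hypothetical helper (`check_group_layout` is not a library function) that operates on a list of labels already read from the first column of the file.

```python
def check_group_layout(labels, num_ngs):
    """Return True if labels follow the pattern 1, 0 * num_ngs, 1, 0 * num_ngs, ..."""
    group_size = 1 + num_ngs
    if len(labels) % group_size != 0:
        return False
    for start in range(0, len(labels), group_size):
        group = labels[start:start + group_size]
        # The first line of each unit must be the positive, the rest negatives.
        if group[0] != 1 or any(label != 0 for label in group[1:]):
            return False
    return True


presampled = [1, 0, 0, 0, 0, 1, 0, 0, 0, 0]   # two units with 4 negatives each
positives_only = [1, 1, 1, 1, 1]              # what need_sample=True expects

print(check_group_layout(presampled, num_ngs=4))
print(check_group_layout(positives_only, num_ngs=4))
```

A file that passes this check with `num_ngs=4` matches the second example above; a positives-only file fails it and would instead be used with `need_sample=True`.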

In [7]:
hparams = prepare_hparams(yaml_file, 
                          embed_l2=0., 
                          layer_l2=0., 
                          learning_rate=0.001,  # set to 0.01 if batch normalization is disabled
                          epochs=EPOCHS,
                          batch_size=BATCH_SIZE,
                          show_step=20,
                          MODEL_DIR=os.path.join(data_path, "model/"),
                          SUMMARIES_DIR=os.path.join(data_path, "summary/"),
                          user_vocab=user_vocab,
                          item_vocab=item_vocab,
                          cate_vocab=cate_vocab,
                          need_sample=True,
                          train_num_ngs=train_num_ngs, # provides the number of negative instances for each positive instance for loss computation.
)

### 1.2 Create data loader
Designate a data iterator for the model. All our sequential models use SequentialIterator. The data format is introduced above.

Validation and testing data are files that have already been negatively sampled offline, with `valid_num_ngs` and `test_num_ngs` negative instances per positive instance, respectively.

In [8]:
input_creator = SequentialIterator

# 2. Create Model

In [9]:
model = SeqModel(hparams, input_creator, seed=RANDOM_SEED)

# model.load_model(os.path.join(hparams.MODEL_DIR, "best_model"))

Performance before training

In [10]:
model.run_eval(test_file, num_ngs=test_num_ngs)

{'auc': 0.5268,
 'logloss': 0.6931,
 'mean_mrr': 0.283,
 'group_auc': 0.526,
 'ndcg@1': 0.0829,
 'ndcg@3': 0.2088,
 'ndcg@5': 0.2865,
 'hit@1': 0.0829,
 'hit@3': 0.3059,
 'hit@5': 0.4956}

An AUC around 0.5 corresponds to random guessing, so before training the model behaves like a random guesser.
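
The reported logloss of 0.6931 is consistent with this reading: under the standard binary log loss, predicting a probability of 0.5 for every instance gives exactly ln(2) ≈ 0.6931, regardless of the labels. The sketch below checks this with a small hand-written `binary_logloss` function, which is an illustrative helper and not necessarily how the library computes its metric.

```python
import math


def binary_logloss(labels, probs, eps=1e-15):
    """Standard binary cross-entropy, averaged over instances."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)


labels = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]  # one positive with 9 negatives, as in test_file
probs = [0.5] * len(labels)              # an uninformed model

loss = binary_logloss(labels, probs)
print(round(loss, 4), round(math.log(2), 4))
```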

### Train model

In [11]:
with Timer() as train_time:
    model = model.fit(train_file, valid_file, valid_num_ngs=valid_num_ngs) 

# valid_num_ngs is the number of negative lines after each positive line in your valid_file 
# we will evaluate the performance of model on valid_file every epoch
print('Time cost for training is {0:.2f} mins'.format(train_time.interval/60.0))

step 20 , total_loss: 1.3830, data_loss: 1.3830
step 40 , total_loss: 1.1761, data_loss: 1.1761
step 60 , total_loss: 1.1281, data_loss: 1.1281
step 80 , total_loss: 1.0655, data_loss: 1.0655
step 100 , total_loss: 1.0706, data_loss: 1.0706
step 120 , total_loss: 1.0018, data_loss: 1.0018
step 140 , total_loss: 1.0814, data_loss: 1.0814
step 160 , total_loss: 1.0410, data_loss: 1.0410
step 180 , total_loss: 1.0311, data_loss: 1.0311
step 200 , total_loss: 0.9807, data_loss: 0.9807
step 220 , total_loss: 1.0163, data_loss: 1.0163
step 240 , total_loss: 1.1000, data_loss: 1.1000
step 260 , total_loss: 1.0024, data_loss: 1.0024
step 280 , total_loss: 1.0444, data_loss: 1.0444
step 300 , total_loss: 0.9967, data_loss: 0.9967
step 320 , total_loss: 1.0112, data_loss: 1.0112
step 340 , total_loss: 0.9862, data_loss: 0.9862
step 360 , total_loss: 1.0074, data_loss: 1.0074
step 380 , total_loss: 0.9702, data_loss: 0.9702
step 400 , total_loss: 1.0179, data_loss: 1.0179
step 420 , total_loss: 0

step 3320 , total_loss: 0.9628, data_loss: 0.9628
step 3340 , total_loss: 0.8218, data_loss: 0.8218
step 3360 , total_loss: 0.8483, data_loss: 0.8483
step 3380 , total_loss: 0.8602, data_loss: 0.8602
step 3400 , total_loss: 0.9133, data_loss: 0.9133
step 3420 , total_loss: 0.8697, data_loss: 0.8697
step 3440 , total_loss: 0.8781, data_loss: 0.8781
step 3460 , total_loss: 0.9064, data_loss: 0.9064
step 3480 , total_loss: 0.8939, data_loss: 0.8939
step 3500 , total_loss: 0.8418, data_loss: 0.8418
step 3520 , total_loss: 0.9043, data_loss: 0.9043
step 3540 , total_loss: 0.8915, data_loss: 0.8915
step 3560 , total_loss: 0.9823, data_loss: 0.9823
step 3580 , total_loss: 0.8851, data_loss: 0.8851
step 3600 , total_loss: 0.9728, data_loss: 0.9728
step 3620 , total_loss: 0.9125, data_loss: 0.9125
step 3640 , total_loss: 0.9006, data_loss: 0.9006
step 3660 , total_loss: 0.8952, data_loss: 0.8952
step 3680 , total_loss: 0.8508, data_loss: 0.8508
step 3700 , total_loss: 0.8621, data_loss: 0.8621


step 6600 , total_loss: 0.9079, data_loss: 0.9079
step 6620 , total_loss: 0.8478, data_loss: 0.8478
step 6640 , total_loss: 0.8078, data_loss: 0.8078
step 6660 , total_loss: 0.7753, data_loss: 0.7753
step 6680 , total_loss: 0.8543, data_loss: 0.8543
step 6700 , total_loss: 0.8305, data_loss: 0.8305
step 6720 , total_loss: 0.8191, data_loss: 0.8191
step 6740 , total_loss: 0.8543, data_loss: 0.8543
step 6760 , total_loss: 0.7993, data_loss: 0.7993
step 6780 , total_loss: 0.7937, data_loss: 0.7937
step 6800 , total_loss: 0.7861, data_loss: 0.7861
step 6820 , total_loss: 0.8581, data_loss: 0.8581
step 6840 , total_loss: 0.8703, data_loss: 0.8703
step 6860 , total_loss: 0.8369, data_loss: 0.8369
step 6880 , total_loss: 0.8527, data_loss: 0.8527
step 6900 , total_loss: 0.8610, data_loss: 0.8610
step 6920 , total_loss: 0.8407, data_loss: 0.8407
step 6940 , total_loss: 0.8254, data_loss: 0.8254
step 6960 , total_loss: 0.8752, data_loss: 0.8752
step 6980 , total_loss: 0.7968, data_loss: 0.7968


step 2000 , total_loss: 0.7822, data_loss: 0.7822
step 2020 , total_loss: 0.7548, data_loss: 0.7548
step 2040 , total_loss: 0.8345, data_loss: 0.8345
step 2060 , total_loss: 0.6943, data_loss: 0.6943
step 2080 , total_loss: 0.7346, data_loss: 0.7346
step 2100 , total_loss: 0.8672, data_loss: 0.8672
step 2120 , total_loss: 0.7716, data_loss: 0.7716
step 2140 , total_loss: 0.7223, data_loss: 0.7223
step 2160 , total_loss: 0.7588, data_loss: 0.7588
step 2180 , total_loss: 0.7200, data_loss: 0.7200
step 2200 , total_loss: 0.8247, data_loss: 0.8247
step 2220 , total_loss: 0.7805, data_loss: 0.7805
step 2240 , total_loss: 0.8345, data_loss: 0.8345
step 2260 , total_loss: 0.8818, data_loss: 0.8818
step 2280 , total_loss: 0.7672, data_loss: 0.7672
step 2300 , total_loss: 0.7762, data_loss: 0.7762
step 2320 , total_loss: 0.7559, data_loss: 0.7559
step 2340 , total_loss: 0.7824, data_loss: 0.7824
step 2360 , total_loss: 0.8188, data_loss: 0.8188
step 2380 , total_loss: 0.7511, data_loss: 0.7511


step 5280 , total_loss: 0.6931, data_loss: 0.6931
step 5300 , total_loss: 0.7506, data_loss: 0.7506
step 5320 , total_loss: 0.7681, data_loss: 0.7681
step 5340 , total_loss: 0.7621, data_loss: 0.7621
step 5360 , total_loss: 0.8249, data_loss: 0.8249
step 5380 , total_loss: 0.8009, data_loss: 0.8009
step 5400 , total_loss: 0.7016, data_loss: 0.7016
step 5420 , total_loss: 0.8263, data_loss: 0.8263
step 5440 , total_loss: 0.7796, data_loss: 0.7796
step 5460 , total_loss: 0.7309, data_loss: 0.7309
step 5480 , total_loss: 0.7897, data_loss: 0.7897
step 5500 , total_loss: 0.8172, data_loss: 0.8172
step 5520 , total_loss: 0.7351, data_loss: 0.7351
step 5540 , total_loss: 0.7691, data_loss: 0.7691
step 5560 , total_loss: 0.7952, data_loss: 0.7952
step 5580 , total_loss: 0.7376, data_loss: 0.7376
step 5600 , total_loss: 0.7751, data_loss: 0.7751
step 5620 , total_loss: 0.7854, data_loss: 0.7854
step 5640 , total_loss: 0.7377, data_loss: 0.7377
step 5660 , total_loss: 0.7804, data_loss: 0.7804


step 680 , total_loss: 0.7468, data_loss: 0.7468
step 700 , total_loss: 0.7773, data_loss: 0.7773
step 720 , total_loss: 0.7204, data_loss: 0.7204
step 740 , total_loss: 0.7346, data_loss: 0.7346
step 760 , total_loss: 0.7641, data_loss: 0.7641
step 780 , total_loss: 0.8092, data_loss: 0.8092
step 800 , total_loss: 0.8741, data_loss: 0.8741
step 820 , total_loss: 0.7603, data_loss: 0.7603
step 840 , total_loss: 0.7729, data_loss: 0.7729
step 860 , total_loss: 0.7322, data_loss: 0.7322
step 880 , total_loss: 0.7096, data_loss: 0.7096
step 900 , total_loss: 0.7365, data_loss: 0.7365
step 920 , total_loss: 0.7877, data_loss: 0.7877
step 940 , total_loss: 0.7619, data_loss: 0.7619
step 960 , total_loss: 0.7841, data_loss: 0.7841
step 980 , total_loss: 0.7758, data_loss: 0.7758
step 1000 , total_loss: 0.7549, data_loss: 0.7549
step 1020 , total_loss: 0.7264, data_loss: 0.7264
step 1040 , total_loss: 0.7514, data_loss: 0.7514
step 1060 , total_loss: 0.7620, data_loss: 0.7620
step 1080 , tota

step 3980 , total_loss: 0.7732, data_loss: 0.7732
step 4000 , total_loss: 0.8071, data_loss: 0.8071
step 4020 , total_loss: 0.7033, data_loss: 0.7033
step 4040 , total_loss: 0.6941, data_loss: 0.6941
step 4060 , total_loss: 0.7393, data_loss: 0.7393
step 4080 , total_loss: 0.7626, data_loss: 0.7626
step 4100 , total_loss: 0.7461, data_loss: 0.7461
step 4120 , total_loss: 0.7046, data_loss: 0.7046
step 4140 , total_loss: 0.8072, data_loss: 0.8072
step 4160 , total_loss: 0.8397, data_loss: 0.8397
step 4180 , total_loss: 0.7513, data_loss: 0.7513
step 4200 , total_loss: 0.7593, data_loss: 0.7593
step 4220 , total_loss: 0.7634, data_loss: 0.7634
step 4240 , total_loss: 0.7079, data_loss: 0.7079
step 4260 , total_loss: 0.7360, data_loss: 0.7360
step 4280 , total_loss: 0.7395, data_loss: 0.7395
step 4300 , total_loss: 0.6931, data_loss: 0.6931
step 4320 , total_loss: 0.6998, data_loss: 0.6998
step 4340 , total_loss: 0.7369, data_loss: 0.7369
step 4360 , total_loss: 0.7029, data_loss: 0.7029


step 7260 , total_loss: 0.8260, data_loss: 0.8260
step 7280 , total_loss: 0.6796, data_loss: 0.6796
step 7300 , total_loss: 0.7883, data_loss: 0.7883
step 7320 , total_loss: 0.7960, data_loss: 0.7960
step 7340 , total_loss: 0.7596, data_loss: 0.7596
step 7360 , total_loss: 0.6771, data_loss: 0.6771
step 7380 , total_loss: 0.7299, data_loss: 0.7299
step 7400 , total_loss: 0.7430, data_loss: 0.7430
step 7420 , total_loss: 0.7250, data_loss: 0.7250
step 7440 , total_loss: 0.8082, data_loss: 0.8082
step 7460 , total_loss: 0.8504, data_loss: 0.8504
step 7480 , total_loss: 0.7445, data_loss: 0.7445
step 7500 , total_loss: 0.9196, data_loss: 0.9196
step 7520 , total_loss: 0.6923, data_loss: 0.6923
step 7540 , total_loss: 0.7600, data_loss: 0.7600
step 7560 , total_loss: 0.7489, data_loss: 0.7489
step 7580 , total_loss: 0.7472, data_loss: 0.7472
step 7600 , total_loss: 0.7784, data_loss: 0.7784
step 7620 , total_loss: 0.7390, data_loss: 0.7390
step 7640 , total_loss: 0.8423, data_loss: 0.8423


step 2660 , total_loss: 0.7377, data_loss: 0.7377
step 2680 , total_loss: 0.7144, data_loss: 0.7144
step 2700 , total_loss: 0.6978, data_loss: 0.6978
step 2720 , total_loss: 0.7442, data_loss: 0.7442
step 2740 , total_loss: 0.7500, data_loss: 0.7500
step 2760 , total_loss: 0.7169, data_loss: 0.7169
step 2780 , total_loss: 0.7687, data_loss: 0.7687
step 2800 , total_loss: 0.8277, data_loss: 0.8277
step 2820 , total_loss: 0.7594, data_loss: 0.7594
step 2840 , total_loss: 0.7376, data_loss: 0.7376
step 2860 , total_loss: 0.7695, data_loss: 0.7695
step 2880 , total_loss: 0.7530, data_loss: 0.7530
step 2900 , total_loss: 0.7710, data_loss: 0.7710
step 2920 , total_loss: 0.7362, data_loss: 0.7362
step 2940 , total_loss: 0.7246, data_loss: 0.7246
step 2960 , total_loss: 0.7351, data_loss: 0.7351
step 2980 , total_loss: 0.7960, data_loss: 0.7960
step 3000 , total_loss: 0.6911, data_loss: 0.6911
step 3020 , total_loss: 0.7217, data_loss: 0.7217
step 3040 , total_loss: 0.7775, data_loss: 0.7775


step 5940 , total_loss: 0.7860, data_loss: 0.7860
step 5960 , total_loss: 0.7748, data_loss: 0.7748
step 5980 , total_loss: 0.7805, data_loss: 0.7805
step 6000 , total_loss: 0.7597, data_loss: 0.7597
step 6020 , total_loss: 0.7648, data_loss: 0.7648
step 6040 , total_loss: 0.7254, data_loss: 0.7254
step 6060 , total_loss: 0.7215, data_loss: 0.7215
step 6080 , total_loss: 0.7726, data_loss: 0.7726
step 6100 , total_loss: 0.7278, data_loss: 0.7278
step 6120 , total_loss: 0.8334, data_loss: 0.8334
step 6140 , total_loss: 0.7320, data_loss: 0.7320
step 6160 , total_loss: 0.7119, data_loss: 0.7119
step 6180 , total_loss: 0.7145, data_loss: 0.7145
step 6200 , total_loss: 0.6533, data_loss: 0.6533
step 6220 , total_loss: 0.7112, data_loss: 0.7112
step 6240 , total_loss: 0.7818, data_loss: 0.7818
step 6260 , total_loss: 0.7945, data_loss: 0.7945
step 6280 , total_loss: 0.8001, data_loss: 0.8001
step 6300 , total_loss: 0.6834, data_loss: 0.6834
step 6320 , total_loss: 0.7685, data_loss: 0.7685


step 1340 , total_loss: 0.7862, data_loss: 0.7862
step 1360 , total_loss: 0.7359, data_loss: 0.7359
step 1380 , total_loss: 0.7265, data_loss: 0.7265
step 1400 , total_loss: 0.7735, data_loss: 0.7735
step 1420 , total_loss: 0.7199, data_loss: 0.7199
step 1440 , total_loss: 0.8501, data_loss: 0.8501
step 1460 , total_loss: 0.8033, data_loss: 0.8033
step 1480 , total_loss: 0.6798, data_loss: 0.6798
step 1500 , total_loss: 0.7673, data_loss: 0.7673
step 1520 , total_loss: 0.7026, data_loss: 0.7026
step 1540 , total_loss: 0.7573, data_loss: 0.7573
step 1560 , total_loss: 0.8079, data_loss: 0.8079
step 1580 , total_loss: 0.7605, data_loss: 0.7605
step 1600 , total_loss: 0.7167, data_loss: 0.7167
step 1620 , total_loss: 0.7672, data_loss: 0.7672
step 1640 , total_loss: 0.7206, data_loss: 0.7206
step 1660 , total_loss: 0.8060, data_loss: 0.8060
step 1680 , total_loss: 0.8129, data_loss: 0.8129
step 1700 , total_loss: 0.6538, data_loss: 0.6538
step 1720 , total_loss: 0.7837, data_loss: 0.7837


step 4620 , total_loss: 0.7508, data_loss: 0.7508
step 4640 , total_loss: 0.7728, data_loss: 0.7728
step 4660 , total_loss: 0.7668, data_loss: 0.7668
step 4680 , total_loss: 0.7853, data_loss: 0.7853
step 4700 , total_loss: 0.8450, data_loss: 0.8450
step 4720 , total_loss: 0.7273, data_loss: 0.7273
step 4740 , total_loss: 0.7610, data_loss: 0.7610
step 4760 , total_loss: 0.7285, data_loss: 0.7285
step 4780 , total_loss: 0.6970, data_loss: 0.6970
step 4800 , total_loss: 0.7425, data_loss: 0.7425
step 4820 , total_loss: 0.8274, data_loss: 0.8274
step 4840 , total_loss: 0.7418, data_loss: 0.7418
step 4860 , total_loss: 0.7718, data_loss: 0.7718
step 4880 , total_loss: 0.6615, data_loss: 0.6615
step 4900 , total_loss: 0.7447, data_loss: 0.7447
step 4920 , total_loss: 0.8301, data_loss: 0.8301
step 4940 , total_loss: 0.7034, data_loss: 0.7034
step 4960 , total_loss: 0.7066, data_loss: 0.7066
step 4980 , total_loss: 0.7664, data_loss: 0.7664
step 5000 , total_loss: 0.7375, data_loss: 0.7375


step 20 , total_loss: 0.6958, data_loss: 0.6958
step 40 , total_loss: 0.8007, data_loss: 0.8007
step 60 , total_loss: 0.7361, data_loss: 0.7361
step 80 , total_loss: 0.7344, data_loss: 0.7344
step 100 , total_loss: 0.7521, data_loss: 0.7521
step 120 , total_loss: 0.7465, data_loss: 0.7465
step 140 , total_loss: 0.7350, data_loss: 0.7350
step 160 , total_loss: 0.8370, data_loss: 0.8370
step 180 , total_loss: 0.8129, data_loss: 0.8129
step 200 , total_loss: 0.7260, data_loss: 0.7260
step 220 , total_loss: 0.7655, data_loss: 0.7655
step 240 , total_loss: 0.7267, data_loss: 0.7267
step 260 , total_loss: 0.7121, data_loss: 0.7121
step 280 , total_loss: 0.7445, data_loss: 0.7445
step 300 , total_loss: 0.7917, data_loss: 0.7917
step 320 , total_loss: 0.7275, data_loss: 0.7275
step 340 , total_loss: 0.7547, data_loss: 0.7547
step 360 , total_loss: 0.8077, data_loss: 0.8077
step 380 , total_loss: 0.7930, data_loss: 0.7930
step 400 , total_loss: 0.8003, data_loss: 0.8003
step 420 , total_loss: 0

step 3320 , total_loss: 0.6872, data_loss: 0.6872
step 3340 , total_loss: 0.7411, data_loss: 0.7411
step 3360 , total_loss: 0.7549, data_loss: 0.7549
step 3380 , total_loss: 0.7331, data_loss: 0.7331
step 3400 , total_loss: 0.7101, data_loss: 0.7101
step 3420 , total_loss: 0.7028, data_loss: 0.7028
step 3440 , total_loss: 0.7211, data_loss: 0.7211
step 3460 , total_loss: 0.7099, data_loss: 0.7099
step 3480 , total_loss: 0.6952, data_loss: 0.6952
step 3500 , total_loss: 0.7193, data_loss: 0.7193
step 3520 , total_loss: 0.7014, data_loss: 0.7014
step 3540 , total_loss: 0.6357, data_loss: 0.6357
step 3560 , total_loss: 0.7261, data_loss: 0.7261
step 3580 , total_loss: 0.7826, data_loss: 0.7826
step 3600 , total_loss: 0.7277, data_loss: 0.7277
step 3620 , total_loss: 0.7018, data_loss: 0.7018
step 3640 , total_loss: 0.7845, data_loss: 0.7845
step 3660 , total_loss: 0.7525, data_loss: 0.7525
step 3680 , total_loss: 0.7404, data_loss: 0.7404
step 3700 , total_loss: 0.7515, data_loss: 0.7515


step 6600 , total_loss: 0.8049, data_loss: 0.8049
step 6620 , total_loss: 0.6776, data_loss: 0.6776
step 6640 , total_loss: 0.8184, data_loss: 0.8184
step 6660 , total_loss: 0.7743, data_loss: 0.7743
step 6680 , total_loss: 0.7470, data_loss: 0.7470
step 6700 , total_loss: 0.7047, data_loss: 0.7047
step 6720 , total_loss: 0.7775, data_loss: 0.7775
step 6740 , total_loss: 0.7277, data_loss: 0.7277
step 6760 , total_loss: 0.7608, data_loss: 0.7608
step 6780 , total_loss: 0.6788, data_loss: 0.6788
step 6800 , total_loss: 0.7491, data_loss: 0.7491
step 6820 , total_loss: 0.7511, data_loss: 0.7511
step 6840 , total_loss: 0.7622, data_loss: 0.7622
step 6860 , total_loss: 0.7070, data_loss: 0.7070
step 6880 , total_loss: 0.7486, data_loss: 0.7486
step 6900 , total_loss: 0.7503, data_loss: 0.7503
step 6920 , total_loss: 0.7030, data_loss: 0.7030
step 6940 , total_loss: 0.7518, data_loss: 0.7518
step 6960 , total_loss: 0.7193, data_loss: 0.7193
step 6980 , total_loss: 0.7131, data_loss: 0.7131


step 2000 , total_loss: 0.7357, data_loss: 0.7357
step 2020 , total_loss: 0.7200, data_loss: 0.7200
step 2040 , total_loss: 0.7021, data_loss: 0.7021
step 2060 , total_loss: 0.7985, data_loss: 0.7985
step 2080 , total_loss: 0.7528, data_loss: 0.7528
step 2100 , total_loss: 0.6968, data_loss: 0.6968
step 2120 , total_loss: 0.7408, data_loss: 0.7408
step 2140 , total_loss: 0.7970, data_loss: 0.7970
step 2160 , total_loss: 0.7276, data_loss: 0.7276
step 2180 , total_loss: 0.7091, data_loss: 0.7091
step 2200 , total_loss: 0.7452, data_loss: 0.7452
step 2220 , total_loss: 0.8022, data_loss: 0.8022
step 2240 , total_loss: 0.7082, data_loss: 0.7082
step 2260 , total_loss: 0.7326, data_loss: 0.7326
step 2280 , total_loss: 0.6621, data_loss: 0.6621
step 2300 , total_loss: 0.7746, data_loss: 0.7746
step 2320 , total_loss: 0.6976, data_loss: 0.6976
step 2340 , total_loss: 0.7930, data_loss: 0.7930
step 2360 , total_loss: 0.6976, data_loss: 0.6976
step 2380 , total_loss: 0.8194, data_loss: 0.8194


step 5280 , total_loss: 0.6326, data_loss: 0.6326
step 5300 , total_loss: 0.6976, data_loss: 0.6976
step 5320 , total_loss: 0.6466, data_loss: 0.6466
step 5340 , total_loss: 0.6833, data_loss: 0.6833
step 5360 , total_loss: 0.7408, data_loss: 0.7408
step 5380 , total_loss: 0.8009, data_loss: 0.8009
step 5400 , total_loss: 0.8004, data_loss: 0.8004
step 5420 , total_loss: 0.7994, data_loss: 0.7994
step 5440 , total_loss: 0.6477, data_loss: 0.6477
step 5460 , total_loss: 0.6913, data_loss: 0.6913
step 5480 , total_loss: 0.6893, data_loss: 0.6893
step 5500 , total_loss: 0.7227, data_loss: 0.7227
step 5520 , total_loss: 0.7533, data_loss: 0.7533
step 5540 , total_loss: 0.7144, data_loss: 0.7144
step 5560 , total_loss: 0.7602, data_loss: 0.7602
step 5580 , total_loss: 0.7109, data_loss: 0.7109
step 5600 , total_loss: 0.7828, data_loss: 0.7828
step 5620 , total_loss: 0.7348, data_loss: 0.7348
step 5640 , total_loss: 0.7579, data_loss: 0.7579
step 5660 , total_loss: 0.7421, data_loss: 0.7421


step 680 , total_loss: 0.7127, data_loss: 0.7127
step 700 , total_loss: 0.7685, data_loss: 0.7685
step 720 , total_loss: 0.7478, data_loss: 0.7478
step 740 , total_loss: 0.7717, data_loss: 0.7717
step 760 , total_loss: 0.6687, data_loss: 0.6687
step 780 , total_loss: 0.6911, data_loss: 0.6911
step 800 , total_loss: 0.7693, data_loss: 0.7693
step 820 , total_loss: 0.7396, data_loss: 0.7396
step 840 , total_loss: 0.7606, data_loss: 0.7606
step 860 , total_loss: 0.7249, data_loss: 0.7249
step 880 , total_loss: 0.7725, data_loss: 0.7725
step 900 , total_loss: 0.7716, data_loss: 0.7716
step 920 , total_loss: 0.6925, data_loss: 0.6925
step 940 , total_loss: 0.7255, data_loss: 0.7255
step 960 , total_loss: 0.7717, data_loss: 0.7717
step 980 , total_loss: 0.7985, data_loss: 0.7985
step 1000 , total_loss: 0.6603, data_loss: 0.6603
step 1020 , total_loss: 0.7437, data_loss: 0.7437
step 1040 , total_loss: 0.6931, data_loss: 0.6931
step 1060 , total_loss: 0.6982, data_loss: 0.6982
step 1080 , tota

step 3980 , total_loss: 0.7167, data_loss: 0.7167
step 4000 , total_loss: 0.7948, data_loss: 0.7948
step 4020 , total_loss: 0.7237, data_loss: 0.7237
step 4040 , total_loss: 0.7902, data_loss: 0.7902
step 4060 , total_loss: 0.7031, data_loss: 0.7031
step 4080 , total_loss: 0.7055, data_loss: 0.7055
step 4100 , total_loss: 0.8386, data_loss: 0.8386
step 4120 , total_loss: 0.7812, data_loss: 0.7812
step 4140 , total_loss: 0.7997, data_loss: 0.7997
step 4160 , total_loss: 0.6990, data_loss: 0.6990
step 4180 , total_loss: 0.6967, data_loss: 0.6967
step 4200 , total_loss: 0.7517, data_loss: 0.7517
step 4220 , total_loss: 0.7147, data_loss: 0.7147
step 4240 , total_loss: 0.7338, data_loss: 0.7338
step 4260 , total_loss: 0.6725, data_loss: 0.6725
step 4280 , total_loss: 0.7220, data_loss: 0.7220
step 4300 , total_loss: 0.7603, data_loss: 0.7603
step 4320 , total_loss: 0.7409, data_loss: 0.7409
step 4340 , total_loss: 0.7611, data_loss: 0.7611
step 4360 , total_loss: 0.7769, data_loss: 0.7769


step 7260 , total_loss: 0.6927, data_loss: 0.6927
step 7280 , total_loss: 0.7610, data_loss: 0.7610
step 7300 , total_loss: 0.7160, data_loss: 0.7160
step 7320 , total_loss: 0.6736, data_loss: 0.6736
step 7340 , total_loss: 0.7250, data_loss: 0.7250
step 7360 , total_loss: 0.6634, data_loss: 0.6634
step 7380 , total_loss: 0.7682, data_loss: 0.7682
step 7400 , total_loss: 0.7618, data_loss: 0.7618
step 7420 , total_loss: 0.6993, data_loss: 0.6993
step 7440 , total_loss: 0.7814, data_loss: 0.7814
step 7460 , total_loss: 0.7028, data_loss: 0.7028
step 7480 , total_loss: 0.7425, data_loss: 0.7425
step 7500 , total_loss: 0.7436, data_loss: 0.7436
step 7520 , total_loss: 0.8060, data_loss: 0.8060
step 7540 , total_loss: 0.7304, data_loss: 0.7304
step 7560 , total_loss: 0.7392, data_loss: 0.7392
step 7580 , total_loss: 0.6730, data_loss: 0.6730
step 7600 , total_loss: 0.8157, data_loss: 0.8157
step 7620 , total_loss: 0.7136, data_loss: 0.7136
step 7640 , total_loss: 0.7590, data_loss: 0.7590


step 2660 , total_loss: 0.7204, data_loss: 0.7204
step 2680 , total_loss: 0.7621, data_loss: 0.7621
step 2700 , total_loss: 0.7473, data_loss: 0.7473
step 2720 , total_loss: 0.6815, data_loss: 0.6815
step 2740 , total_loss: 0.7304, data_loss: 0.7304
step 2760 , total_loss: 0.7403, data_loss: 0.7403
step 2780 , total_loss: 0.6658, data_loss: 0.6658
step 2800 , total_loss: 0.6986, data_loss: 0.6986
step 2820 , total_loss: 0.6805, data_loss: 0.6805
step 2840 , total_loss: 0.7905, data_loss: 0.7905
step 2860 , total_loss: 0.7200, data_loss: 0.7200
step 2880 , total_loss: 0.7194, data_loss: 0.7194
step 2900 , total_loss: 0.7234, data_loss: 0.7234
step 2920 , total_loss: 0.7218, data_loss: 0.7218
step 2940 , total_loss: 0.7273, data_loss: 0.7273
step 2960 , total_loss: 0.6762, data_loss: 0.6762
step 2980 , total_loss: 0.7421, data_loss: 0.7421
step 3000 , total_loss: 0.7678, data_loss: 0.7678
step 3020 , total_loss: 0.7559, data_loss: 0.7559
step 3040 , total_loss: 0.6873, data_loss: 0.6873


step 5940 , total_loss: 0.7256, data_loss: 0.7256
step 5960 , total_loss: 0.7189, data_loss: 0.7189
step 5980 , total_loss: 0.7114, data_loss: 0.7114
step 6000 , total_loss: 0.7430, data_loss: 0.7430
step 6020 , total_loss: 0.7731, data_loss: 0.7731
step 6040 , total_loss: 0.8101, data_loss: 0.8101
step 6060 , total_loss: 0.7935, data_loss: 0.7935
step 6080 , total_loss: 0.6630, data_loss: 0.6630
step 6100 , total_loss: 0.6714, data_loss: 0.6714
step 6120 , total_loss: 0.6444, data_loss: 0.6444
step 6140 , total_loss: 0.7444, data_loss: 0.7444
step 6160 , total_loss: 0.7666, data_loss: 0.7666
step 6180 , total_loss: 0.7545, data_loss: 0.7545
step 6200 , total_loss: 0.6957, data_loss: 0.6957
step 6220 , total_loss: 0.7217, data_loss: 0.7217
step 6240 , total_loss: 0.6943, data_loss: 0.6943
step 6260 , total_loss: 0.6813, data_loss: 0.6813
step 6280 , total_loss: 0.7276, data_loss: 0.7276
step 6300 , total_loss: 0.7135, data_loss: 0.7135
step 6320 , total_loss: 0.7767, data_loss: 0.7767


step 1340 , total_loss: 0.6918, data_loss: 0.6918
step 1360 , total_loss: 0.6968, data_loss: 0.6968
step 1380 , total_loss: 0.7668, data_loss: 0.7668
step 1400 , total_loss: 0.6626, data_loss: 0.6626
step 1420 , total_loss: 0.6775, data_loss: 0.6775
step 1440 , total_loss: 0.6931, data_loss: 0.6931
step 1460 , total_loss: 0.7642, data_loss: 0.7642
step 1480 , total_loss: 0.7178, data_loss: 0.7178
step 1500 , total_loss: 0.7585, data_loss: 0.7585
step 1520 , total_loss: 0.7060, data_loss: 0.7060
step 1540 , total_loss: 0.7125, data_loss: 0.7125
step 1560 , total_loss: 0.7017, data_loss: 0.7017
step 1580 , total_loss: 0.7433, data_loss: 0.7433
step 1600 , total_loss: 0.7332, data_loss: 0.7332
step 1620 , total_loss: 0.7079, data_loss: 0.7079
step 1640 , total_loss: 0.7602, data_loss: 0.7602
step 1660 , total_loss: 0.7502, data_loss: 0.7502
step 1680 , total_loss: 0.6621, data_loss: 0.6621
step 1700 , total_loss: 0.7462, data_loss: 0.7462
step 1720 , total_loss: 0.7111, data_loss: 0.7111


step 4620 , total_loss: 0.7641, data_loss: 0.7641
step 4640 , total_loss: 0.7353, data_loss: 0.7353
step 4660 , total_loss: 0.6632, data_loss: 0.6632
step 4680 , total_loss: 0.7638, data_loss: 0.7638
step 4700 , total_loss: 0.6947, data_loss: 0.6947
step 4720 , total_loss: 0.6969, data_loss: 0.6969
step 4740 , total_loss: 0.6935, data_loss: 0.6935
step 4760 , total_loss: 0.7950, data_loss: 0.7950
step 4780 , total_loss: 0.7760, data_loss: 0.7760
step 4800 , total_loss: 0.7151, data_loss: 0.7151
step 4820 , total_loss: 0.7991, data_loss: 0.7991
step 4840 , total_loss: 0.7564, data_loss: 0.7564
step 4860 , total_loss: 0.7063, data_loss: 0.7063
step 4880 , total_loss: 0.7511, data_loss: 0.7511
step 4900 , total_loss: 0.7720, data_loss: 0.7720
step 4920 , total_loss: 0.7264, data_loss: 0.7264
step 4940 , total_loss: 0.7086, data_loss: 0.7086
step 4960 , total_loss: 0.7330, data_loss: 0.7330
step 4980 , total_loss: 0.7477, data_loss: 0.7477
step 5000 , total_loss: 0.6629, data_loss: 0.6629


[(1, {'auc': 0.8593, 'logloss': 0.5418, 'mean_mrr': 0.7493, 'group_auc': 0.8489, 'ndcg@1': 0.559, 'ndcg@3': 0.7915, 'ndcg@5': 0.8133, 'hit@1': 0.559, 'hit@3': 0.9479, 'hit@5': 1.0}), (2, {'auc': 0.8684, 'logloss': 0.6311, 'mean_mrr': 0.7586, 'group_auc': 0.8566, 'ndcg@1': 0.5733, 'ndcg@3': 0.8005, 'ndcg@5': 0.8202, 'hit@1': 0.5733, 'hit@3': 0.9529, 'hit@5': 1.0}), (3, {'auc': 0.8708, 'logloss': 0.7495, 'mean_mrr': 0.7611, 'group_auc': 0.8586, 'ndcg@1': 0.5772, 'ndcg@3': 0.8028, 'ndcg@5': 0.8221, 'hit@1': 0.5772, 'hit@3': 0.9539, 'hit@5': 1.0}), (4, {'auc': 0.8722, 'logloss': 0.868, 'mean_mrr': 0.7634, 'group_auc': 0.8604, 'ndcg@1': 0.5809, 'ndcg@3': 0.8049, 'ndcg@5': 0.8239, 'hit@1': 0.5809, 'hit@3': 0.9547, 'hit@5': 1.0}), (5, {'auc': 0.8734, 'logloss': 1.0367, 'mean_mrr': 0.7645, 'group_auc': 0.8609, 'ndcg@1': 0.5831, 'ndcg@3': 0.8056, 'ndcg@5': 0.8246, 'hit@1': 0.5831, 'hit@3': 0.9545, 'hit@5': 1.0}), (6, {'auc': 0.8739, 'logloss': 1.1299, 'mean_mrr': 0.7657, 'group_auc': 0.8617, 'n

### Evaluate the performance after training

In [12]:
model.run_eval(test_file, num_ngs=test_num_ngs)

{'auc': 0.8731,
 'logloss': 2.2132,
 'mean_mrr': 0.6169,
 'group_auc': 0.8618,
 'ndcg@1': 0.3867,
 'ndcg@3': 0.6416,
 'ndcg@5': 0.6917,
 'hit@1': 0.3867,
 'hit@3': 0.8189,
 'hit@5': 0.9398}
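
To help read the ranking metrics above, the following sketch computes hit@k, ndcg@k and the reciprocal rank for a single group of 1+num_ngs scores. It assumes exactly one positive per group, placed first, and counts ties against the positive; `group_metrics` is a hypothetical helper, and the library's own metric code may differ in details such as tie handling. With a single positive, ndcg@1 equals hit@1, which matches the reported values. With only 4 negatives per group, as in the validation file, hit@5 is 1.0 by construction.

```python
import math


def rank_of_positive(scores):
    """1-based rank of scores[0] (the positive) within its group; ties count against it."""
    positive = scores[0]
    return 1 + sum(1 for s in scores[1:] if s >= positive)


def group_metrics(scores, k):
    """hit@k, ndcg@k and reciprocal rank for one group with a single positive."""
    rank = rank_of_positive(scores)
    in_top_k = rank <= k
    return {
        "rank": rank,
        "hit": 1.0 if in_top_k else 0.0,
        # With one relevant item the ideal DCG is 1, so NDCG is just the discounted gain.
        "ndcg": 1.0 / math.log2(1 + rank) if in_top_k else 0.0,
        "mrr": 1.0 / rank,
    }


# One positive (first score) and four negatives; one negative outscores the positive.
scores = [0.7, 0.9, 0.1, 0.3, 0.2]
at_1 = group_metrics(scores, k=1)
at_3 = group_metrics(scores, k=3)
print(at_1)
print(at_3)
```

Averaging these per-group values over all groups gives dataset-level figures of the same kind as those reported by `run_eval`.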