In [16]:
import sys
sys.path.append("../../")
import os
import numpy as np
import zipfile
from tqdm import tqdm
from tempfile import TemporaryDirectory
import tensorflow as tf
tf.get_logger().setLevel('ERROR') # only show error messages

from iterator import MINDIterator
from utils import prepare_hparams
from utils import get_mind_data_set
from model import NRMSModel
from model import LSTURModel
from layer import cal_metric
# Fixed seed handed to the model constructors in the cells below.
seed = 40

# Bound to MINDIterator itself (it is not called here); the model
# constructors below receive this object as their iterator argument.
iterator = MINDIterator

In [18]:
# Hyper-parameters for the NRMS model: the YAML config plus resource paths
# and run-time settings passed as keyword arguments.
# NOTE(review): this cell points at embedding.npy / word_dict.pkl while the
# LSTUR cell uses the *_all variants -- confirm the difference is intended.
hparams_nrms = prepare_hparams(
    'data/utils/nrms.yaml',
    wordEmb_file='data/utils/embedding.npy',
    wordDict_file='data/utils/word_dict.pkl',
    userDict_file='data/utils/uid2index.pkl',
    batch_size=32,
    show_step=10,
)

In [19]:
# Hyper-parameters for the LSTUR model: the YAML config plus resource paths
# and run-time settings passed as keyword arguments.
# NOTE(review): this cell points at embedding_all.npy / word_dict_all.pkl
# while the NRMS cell uses embedding.npy / word_dict.pkl -- confirm the
# difference is intended.
hparams_lstur = prepare_hparams(
    'data/utils/lstur.yaml',
    wordEmb_file='data/utils/embedding_all.npy',
    wordDict_file='data/utils/word_dict_all.pkl',
    userDict_file='data/utils/uid2index.pkl',
    batch_size=32,
    show_step=10,
)

In [35]:
# Build the NRMS model from its hyper-parameters, the shared iterator
# object and the notebook-wide seed.
model_nrms = NRMSModel(
    hparams_nrms,
    iterator,
    seed=seed,
)


Tensor("conv1d_1/Relu:0", shape=(?, 30, 400), dtype=float32)
Tensor("att_layer2_5/Sum_1:0", shape=(?, 400), dtype=float32)


In [None]:
# Show the summary of the NRMS wrapper's `model` attribute
# (presumably the training network -- TODO confirm against NRMSModel).
model_nrms.model.summary()

In [None]:
# Show the summary of the NRMS wrapper's `scorer` attribute
# (presumably the inference/scoring network -- TODO confirm against NRMSModel).
model_nrms.scorer.summary()

In [21]:

# Baseline: evaluate NRMS on the validation split before the checkpoint
# restore performed in the following cell.
# NOTE(review): `pretrianed_*` is a misspelling of `pretrained_*`; the names
# are kept unchanged because other cells may reference them.
(
    pretrianed_metric,
    pretrianed_labels,
    pretrianed_preds,
) = model_nrms.run_eval(
    'data/valid/news.tsv',
    'data/valid/behaviors.tsv',
)
print(pretrianed_metric)

586it [00:02, 222.40it/s]
236it [00:40,  5.80it/s]
7538it [00:01, 4032.97it/s]


{'group_auc': 0.517, 'mean_mrr': 0.2221, 'ndcg@5': 0.2296, 'ndcg@10': 0.2914}


In [22]:
# Restore the most recent NRMS checkpoint into the scorer network.
# tf.train.latest_checkpoint returns None when the directory holds no
# checkpoint; fail early with a clear message instead of letting
# load_weights choke on a None path.
nrms_ckpt_dir = "./data/nrms_3e-4/"
nrms_ckpt_path = tf.train.latest_checkpoint(nrms_ckpt_dir)
if nrms_ckpt_path is None:
    raise FileNotFoundError(
        "No TensorFlow checkpoint found in {!r}".format(nrms_ckpt_dir)
    )

# NOTE(review): weights are restored into `.scorer` here, whereas the LSTUR
# cell restores into `.model` -- confirm this matches how each checkpoint
# was saved.
# assert_existing_objects_matched() raises if an already-built object has no
# counterpart in the checkpoint, so a silent partial restore is not possible.
status_nrms = (
    model_nrms.scorer
    .load_weights(nrms_ckpt_path)
    .assert_existing_objects_matched()
)
status_nrms


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fc1386fc358>

In [23]:
# Re-run the validation evaluation for NRMS, this time after the checkpoint
# restore in the preceding cell, and print the resulting metric dict.
(
    trained_metrices_nrms,
    group_labels_nrms,
    group_preds_nrms,
) = model_nrms.run_eval(
    'data/valid/news.tsv',
    'data/valid/behaviors.tsv',
)
print(trained_metrices_nrms)

586it [00:02, 290.14it/s]
236it [00:38,  6.11it/s]
7538it [00:01, 4088.95it/s]


{'group_auc': 0.6005, 'mean_mrr': 0.2733, 'ndcg@5': 0.2977, 'ndcg@10': 0.3595}


In [24]:
# Build the LSTUR model from its hyper-parameters, the shared iterator
# object and the notebook-wide seed.
model_lstur = LSTURModel(
    hparams_lstur,
    iterator,
    seed=seed,
)

# Baseline: evaluate LSTUR on the validation split before the checkpoint
# restore performed in the following cell.
# NOTE(review): `pretrianed_*` is a misspelling of `pretrained_*`; the names
# are kept unchanged because other cells may reference them.
(
    pretrianed_metric_lstur,
    pretrianed_labels_lstur,
    pretrianed_preds_lstur,
) = model_lstur.run_eval(
    'data/valid/news.tsv',
    'data/valid/behaviors.tsv',
)
print(pretrianed_metric_lstur)

Tensor("conv1d/Relu:0", shape=(?, 30, 400), dtype=float32)
Tensor("att_layer2_2/Sum_1:0", shape=(?, 400), dtype=float32)


586it [00:02, 214.75it/s]
236it [00:27,  8.54it/s]
7538it [00:01, 3852.70it/s]


{'group_auc': 0.519, 'mean_mrr': 0.2209, 'ndcg@5': 0.2286, 'ndcg@10': 0.2908}


In [25]:
# Restore the most recent LSTUR checkpoint into the model network.
# tf.train.latest_checkpoint returns None when the directory holds no
# checkpoint; fail early with a clear message instead of letting
# load_weights choke on a None path.
lstur_ckpt_dir = "./data/lstur_3e-4/"
lstur_ckpt_path = tf.train.latest_checkpoint(lstur_ckpt_dir)
if lstur_ckpt_path is None:
    raise FileNotFoundError(
        "No TensorFlow checkpoint found in {!r}".format(lstur_ckpt_dir)
    )

# NOTE(review): weights are restored into `.model` here, whereas the NRMS
# cell restores into `.scorer` -- confirm this matches how each checkpoint
# was saved.
# assert_existing_objects_matched() raises if an already-built object has no
# counterpart in the checkpoint, so a silent partial restore is not possible.
status_lstur = (
    model_lstur.model
    .load_weights(lstur_ckpt_path)
    .assert_existing_objects_matched()
)
status_lstur


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fc0f47425f8>

In [26]:
# Re-run the validation evaluation for LSTUR, this time after the checkpoint
# restore in the preceding cell, and print the resulting metric dict.
(
    trained_metrices_lstur,
    group_labels_lstur,
    group_preds_lstur,
) = model_lstur.run_eval(
    'data/valid/news.tsv',
    'data/valid/behaviors.tsv',
)
print(trained_metrices_lstur)

586it [00:01, 334.21it/s]
236it [00:26,  8.79it/s]
7538it [00:01, 4078.20it/s]


{'group_auc': 0.6108, 'mean_mrr': 0.2798, 'ndcg@5': 0.3026, 'ndcg@10': 0.3654}
