In [1]:
# Notebook extensions: `autoreload` re-imports edited project modules,
# `tensorboard` registers the TensorBoard notebook magic.
%load_ext autoreload
%load_ext tensorboard
# Mode 2: reload all imported modules before each cell is executed.
%autoreload 2

In [2]:
import sys
sys.path.append("..")

from typing import Optional,List
from tqdm.notebook import tqdm
import datetime
import os
import copy
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import pickle

from recs.dataset import session_parallel_dataset
from recs.evaluator import metrics

import tensorflow as tf
from tensorflow import keras as tfk
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd


In [14]:
class GRUKNN(tfk.Model):
    """GRU session encoder that scores items against its own embedding table.

    An item-id sequence is embedded, summarised by a GRU into one vector, and
    that vector is dotted with every row of the (tied) embedding matrix to
    produce one unnormalised score per item.  Evaluation reports recall@k.
    """

    def __init__(
        self,
        num_items: int,
        seq_len: int = 3,
        embed_dim: int = 100,
        dropout_rate: float = 0.5,
        k: int = 20,
        name="GRUKNN"
    ):
        """Create the layers and the recall tracker.

        Args:
            num_items: Size of the item vocabulary (rows of the embedding
                table and width of the score vector).
            seq_len: Accepted for interface compatibility; not used by the
                model itself.
            embed_dim: Width of the item embeddings and of the GRU state.
            dropout_rate: Input dropout applied inside the GRU.
            k: Cut-off used for recall@k in ``test_step``.
            name: Keras model name.
        """
        # Zero-argument super() keeps working when the class is re-defined
        # in a live notebook kernel.
        super().__init__(name=name)
        self._topk = k
        self._num_items = num_items
        # mask_zero=True: id 0 is treated as padding by the embedding layer.
        self._embedding = tfk.layers.Embedding(num_items, embed_dim, mask_zero=True)
        self._gru = tfk.layers.GRU(
            embed_dim,
            dropout=dropout_rate)

        self._recall_tracker = tfk.metrics.Recall(name="recall")

    def call(
        self,
        item_seqs: tf.Tensor,
        training: Optional[bool] = False
    ):
        """Return per-item scores (logits) for each input sequence.

        Args:
            item_seqs: Integer item ids, one sequence per row.
            training: Forwarded to the GRU (enables its dropout).

        Returns:
            Tensor of shape (batch_size, num_items).
        """
        embedded = self._embedding(item_seqs)
        session_vec = self._gru(embedded, training=training)  # (batch_size, embed_dim)

        # Output weights are tied to the input embedding table.
        # NOTE(review): score column j uses the embedding of input id j, while
        # the training cell feeds labels as `target - 1` -- confirm the offset
        # between input ids and label indices is intended.
        items = self._embedding.weights[0]  # (num_items, embed_dim)
        return tf.matmul(session_vec, items, transpose_b=True)  # (batch_size, num_items)

    def test_step(self, data):
        """Update and report recall@k for one evaluation batch.

        A sample counts as a true positive when its target index is among
        the k highest-scoring items, otherwise as a false negative.

        Args:
            data: Tuple ``(state, target)`` of input sequences and integer
                target item indices.

        Returns:
            Dict with the running ``"recall"`` value.
        """
        state, target = data

        scores = self(state, training=False)
        _, topk_items = tf.math.top_k(scores, k=self._topk)  # (batch_size, k), int32

        # Compare indices directly instead of materialising a
        # (batch_size, k, num_items) one-hot tensor; the resulting TP/FN
        # counts are the same but memory no longer scales with num_items.
        target = tf.reshape(tf.cast(target, tf.int32), (-1, 1))
        hit = tf.reduce_any(tf.equal(topk_items, target), axis=1)

        # Targets outside [0, num_items) produced an all-zero one-hot row
        # before (neither TP nor FN); keep ignoring them.
        in_range = tf.logical_and(target >= 0, target < self._num_items)
        in_range = tf.reshape(in_range, (-1,))

        self._recall_tracker.update_state(
            tf.cast(in_range, tf.int32),
            tf.cast(hit, tf.int32))

        return {"recall": self._recall_tracker.result()}

In [15]:
# --- Experiment identifiers and output locations -------------------------
dataname = "diginetica"
modelname = "GRUKNN"
default_logdir = "/home/inoue/work/recs/"
# One TensorBoard run directory per launch, suffixed with the start time.
log_dir = os.path.join(
    default_logdir,
    "logs/%s/%s/" % (dataname, modelname) + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

# --- Load the pre-processed training tuple --------------------------------
# NOTE(review): pickle executes arbitrary code on load; only read trusted files.
# The context manager closes the file handle that was previously leaked.
with open("/home/inoue/work/dataset/%s/derived/mdp_train.df" % dataname, "rb") as train_file:
    train = pickle.load(train_file)

# First 90% of the samples are used for fitting, the remainder for validation.
split_ind = int(len(train[0]) * 0.9)
data = pd.read_pickle("~/work/dataset/%s/derived/train.df" % dataname)
testdata = pd.read_pickle("~/work/dataset/%s/derived/test.df" % dataname)

# --- Hyper-parameters ------------------------------------------------------
# Vocabulary must cover the largest item id seen in either split.
num_items = max(data.itemId.max(), testdata.itemId.max()) + 1
emb_dim = 64
hidden_dim = 64
seq_len = train[1].shape[1]
batch_size = 500

# --- tf.data pipelines -----------------------------------------------------
# Presumably train[1] holds the input item-id sequences and train[2] the
# next-item targets -- TODO confirm against the code that wrote mdp_train.df.
# Targets are shifted down by one before being used as class indices
# (presumably 1-based ids with 0 reserved for padding -- confirm).
train_data = (
    tf.data.Dataset.from_tensor_slices(
        (train[1][:split_ind, :],
         train[2][:split_ind] - 1))
    .shuffle(len(train[0][:split_ind]))
    .batch(batch_size)
)
# NOTE(review): shuffling the validation split is kept as in the original,
# but looks unnecessary for an order-independent metric.
valid_data = (
    tf.data.Dataset.from_tensor_slices(
        (train[1][split_ind:, :],
         train[2][split_ind:] - 1))
    .shuffle(len(train[0][split_ind:]))
    .batch(batch_size)
)

In [16]:
# Instantiate the recommender with the dimensions derived from the data.
model = GRUKNN(
    num_items=num_items,
    seq_len=seq_len,
    embed_dim=emb_dim,
    dropout_rate=0.1,
)

# The model emits raw scores, hence from_logits=True with integer labels.
model.compile(
    optimizer=tfk.optimizers.Adam(learning_rate=0.01),
    loss=tfk.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# An explicit build is intentionally left disabled:
# model.build(input_shape=(1,seq_len))

In [17]:
# Where the best weights (by validation recall) are written.
checkpoint_path = os.path.join(default_logdir, "params/%s/checkpoint"%modelname)

# Keep only the weights of the epoch with the highest val_recall.
best_weights_saver = tfk.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor="val_recall",
    mode="max",
    save_best_only=True,
    save_weights_only=True,
)

# Stop once val_recall has not improved by at least 1e-4 for 3 epochs.
recall_early_stop = tfk.callbacks.EarlyStopping(
    monitor="val_recall",
    mode="max",
    min_delta=1e-4,
    patience=3,
    verbose=1,
)

training_callbacks = [
    tfk.callbacks.TensorBoard(log_dir=log_dir),
    best_weights_saver,
    recall_early_stop,
]

# Train for up to 100 epochs, validating after every epoch.
model.fit(
    train_data,
    epochs=100,
    validation_data=valid_data,
    validation_freq=1,
    callbacks=training_callbacks,
)

Epoch 1/100
Epoch 2/100

KeyboardInterrupt: 