In [1]:
%load_ext autoreload
%load_ext tensorboard
%autoreload 2

In [2]:
import sys
# Put the parent directory on the module search path. Presumably this is what
# lets the project-local `recs` package resolve when the notebook is run from
# its own sub-directory — TODO confirm the expected working directory.
sys.path.append("..")

In [3]:
from typing import Optional,List
from tqdm.notebook import tqdm
import datetime
import os
import copy
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import pickle

from recs.dataset import session_parallel_dataset
from recs.evaluator import metrics

import tensorflow as tf
from tensorflow import keras as tfk
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd


In [4]:
class QNet(tfk.layers.Layer):
    """GRU network that maps a sequence of item ids to one output per item.

    The id sequence is embedded, summarised by a single GRU layer (only the
    final hidden state is kept) and projected by a dense layer that has
    ``num_items`` units.

    Args:
        num_items: Used both as the embedding table size (``input_dim``) and
            as the number of units of the output dense layer.
        seq_len: Accepted for interface compatibility; the layer itself does
            not read it.
        hidden_dim: Number of GRU units.
        embed_dim: Dimensionality of the item embeddings.
        dropout_rate: Dropout fraction handed to the GRU (``dropout=``); it is
            only applied when the layer is called with ``training=True``.
        activation: Activation of the output dense layer (``None`` = linear).
        name: Keras name of the layer.
    """

    def __init__(
        self,
        num_items:int,
        seq_len:Optional[int]=3,
        hidden_dim:Optional[int]=100,
        embed_dim:Optional[int]=100,
        dropout_rate:Optional[float]=0.5,
        activation=None,
        name="QNet"
    ):
        super().__init__(name=name)

        # Index 0 is treated as padding and masked out for the GRU.
        # NOTE(review): with mask_zero=True Keras expects
        # input_dim == vocabulary size + 1 — confirm that `num_items` already
        # accounts for the padding index.
        self._embedding = tfk.layers.Embedding(num_items, embed_dim, mask_zero=True)
        self._gru = tfk.layers.GRU(
            hidden_dim,
            dropout=dropout_rate)

        self._qvalue_dense = tfk.layers.Dense(num_items, activation=activation)

    def call(
        self,
        item_seqs:tf.Tensor, # (batch_size, seq_len)
        training:Optional[bool]=False,
    ):
        """Run the network on a batch of item-id sequences.

        Args:
            item_seqs: Integer item ids, shape ``(batch_size, seq_len)``.
            training: Forwarded to the GRU so that its dropout is only active
                during training.

        Returns:
            Output of the dense head, one value per unit (``num_items``) for
            each sequence in the batch.
        """
        x = self._embedding(item_seqs)
        x = self._gru(x, training=training)
        out = self._qvalue_dense(x)
        return out

In [None]:
class ActorCritic(tfk.Model):
    
    def __init__(
        self,
        num_items:int,
        seq_len:Optional[int]=3,
        hidden_dim:Optional[int]=100,
        embed_dim:Optional[int]=100,
        dropout_rate:Optional[float]=0.5,
        name="ActorCritic"
    ):
        """Build the actor, the critic and their target copies.

        Args:
            num_items: Number of items; output size of the actor.
            seq_len: Length of the state sequences; used to build the
                networks with a dummy input.
            hidden_dim: Number of GRU units in every network.
            embed_dim: Embedding dimensionality in every network.
            dropout_rate: Dropout fraction handed to every network.
            name: Keras name of the model.
        """
        # Keras requires the base initialiser to run before any layer or
        # metric is assigned as an attribute; it also registers `name`.
        super().__init__(name=name)

        self._num_items = num_items
        self._seq_len = seq_len
        self._loss_tracker = tfk.metrics.Mean(name="loss")

        def build_actor(net_name):
            # Softmax head: one probability per item.
            return QNet(
                num_items,
                seq_len,
                hidden_dim,
                embed_dim,
                dropout_rate,
                activation="softmax",
                name=net_name)

        def build_critic(net_name):
            # Single linear output unit.
            # NOTE(review): QNet uses its first argument for both the Dense
            # units and the Embedding input_dim, so passing 1 also gives the
            # critic a one-row embedding table; item ids other than 0 would
            # fall outside it — confirm this is intended.
            return QNet(
                1,
                seq_len,
                hidden_dim,
                embed_dim,
                dropout_rate,
                name=net_name)

        self._actor = build_actor("Actor")
        self._critic = build_critic("Critic")

        # The target networks were called below but never created, which
        # raised AttributeError. They mirror the online networks and are
        # frozen so that gradient updates never touch them directly.
        self._target_actor = build_actor("TargetActor")
        self._target_critic = build_critic("TargetCritic")
        self._target_actor.trainable = False
        self._target_critic.trainable = False

        # Call every network once on an all-padding state so that its
        # variables are created.
        dummy_state = tf.zeros((1, seq_len), dtype=tf.int32)
        self._actor(dummy_state)
        self._critic(dummy_state)
        self._target_actor(dummy_state)
        self._target_critic(dummy_state)

        # Start the targets from the same weights as the online networks.
        self._target_actor.set_weights(self._actor.get_weights())
        self._target_critic.set_weights(self._critic.get_weights())
    
    def call(self, state):
        """Forward pass of the model: evaluate the actor network on ``state``.

        Returns whatever the actor produces for the given batch of states.
        """
        actor_output = self._actor(state)
        return actor_output
    
        
    def train_step(self, data):
        state, action, reward, n_state, done = data
        onehot_act = tf.one_hot(action-1, depth=self._num_items)
        
        with tf.GradientTape() as tape: