In [1]:
from typing import Optional,List
from tqdm.notebook import tqdm
import copy

import gym
from collections import deque

import tensorflow as tf
from tensorflow import keras as tfk
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import pickle

In [2]:
class QNet(tfk.Model):
    
    def __init__(
        self,
        num_actions:int,
        emb_dim:Optional[int]=100,
        hidden_units:Optional[int]=200,
        seq_len:Optional[int]=10,
        batch_size:Optional[int]=32
    ):
        super().__init__()
        
        self._num_actions = num_actions
        self._emb_dim = emb_dim
        self._hidden_units = hidden_units
        self._seq_len = seq_len
        
        self._dropout = tfk.layers.Dropout(0.5)
        self._emb_layer = tfk.layers.Embedding(self._num_actions, self._emb_dim)
        self._gru_layer = tfk.layers.GRU(self._hidden_units, input_shape=(self._seq_len, self._emb_dim))
        self._linear1 = tfk.layers.Dense(self._num_actions)
        #self._linear2 = tfk.layers.Dense(self._num_actions)
        
        
        dummy_state = tf.zeros((batch_size, seq_len), dtype=tf.int32)
        self(dummy_state)
    def call(self, state):
        x = self._emb_layer(state)
        x = self._gru_layer(x)
        x = self._dropout(x)
        score = self._linear1(x)
        #q = self._linear2(x)

        #return score, q
        return score

In [None]:
class G

In [3]:
class DQN(object):
    
    def __init__(
        self,
        num_actions:int,
        emb_dim:Optional[int]=100,
        hidden_units:Optional[int]=200,
        seq_len:Optional[int]=10,
        gamma:Optional[float]=1.,
        lr:Optional[float]=0.01,
        batch_size:Optional[int]=256
    ):
        self._num_actions = num_actions
        self._emb_dim = emb_dim
        self._hidden_units = hidden_units
        self._seq_len = seq_len
        #self._gamma = gamma
        
        self._qnet = QNet(num_actions, emb_dim, hidden_units, seq_len, batch_size)
        #self._target_qnet = copy.deepcopy(self._qnet)
        
        #self._tdloss = tfk.losses.Huber()
        self._loss = tfk.losses.SparseCategoricalCrossentropy()
        self._optim = tfk.optimizers.Adam(learning_rate=lr)
        

    def compute_target_sa_values(self, state, action):
        _, q_values = self._target_qnet(state)
        action_one_hot = tf.one_hot(action, depth=self._num_actions)
        return tf.reduce_sum(q_values*action_one_hot, axis=1)

    def compute_predict_sa_values(self, state, action):
        _, q_values = self._qnet(state)
        action_one_hot = tf.one_hot(action, depth=self._num_actions)
        return tf.reduce_sum(q_values*action_one_hot, axis=1)
    
    def update_params(self, tau=0.9):
        for param, tar_param in zip(self._qnet.trainable_variables, self._target_qnet.trainable_variables):
            tar_param.assign(param*tau + (1-tau)*tar_param)


    @tf.function
    def _train_step(self, batch):
        #_, q_target = self._target_qnet(batch["n_state"])
        #q_target = batch["reward"] + self._gamma * self.compute_target_sa_values(batch["n_state"], tf.argmax(q_target, axis=1))
        #q_target = tf.stop_gradient(q_target)
        
        with tf.GradientTape() as tape:
            #q_pred = self.compute_predict_sa_values(batch["state"], batch["action"])
            #score, _ = self._qnet(batch["state"])
            #td_loss = self._tdloss(q_target, q_pred)
            score = self._qnet(batch["state"])
            loss = self._loss(batch["action"], score)
            #loss += td_loss
        grad = tape.gradient(loss, self._qnet.trainable_variables)
        self._optim.apply_gradients(zip(grad, self._qnet.trainable_variables))
        return loss
        
        
    
    def fit(
        self,
        train_data:tf.data.Dataset,
        n_epochs=10,
        update_iter=10,
        tau=0.9
    ):
        losses = []
        best_loss = np.Inf
        stop_count = 0
        for epoch in range(n_epochs):
            batch_loss = 0.
            with tqdm(train_data, desc="[Epoch%d]"%(epoch+1)) as ts:
                for i, batch in enumerate(ts):
                    loss = self._train_step(batch)
                    batch_loss += loss
                    ts.set_postfix_str("Loss=%4f"%(batch_loss / (i+1)))
                    
                    #if (i+1)%update_iter == 0:
                        #self.update_params(tau)
                batch_loss /= (i+1)

            if batch_loss >= best_loss:
                stop_count += 1
            else:
                best_loss = batch_loss
                stop_count = 0
                
            if stop_count > 3:
                break 
            losses += [batch_loss.numpy()]
        return losses

In [4]:
train_mdp = pickle.load(open("/home/inoue/work/dataset/diginetica2/derived/mdp_train.df", "rb"))

In [5]:
train_data = tf.data.Dataset.from_tensor_slices(
    {
        "sess" : train_mdp[0].astype(np.int32),
        "state":train_mdp[1].astype(np.int32),
        "action":train_mdp[2].astype(np.int32),
        "reward":train_mdp[3].astype(np.float32),
        "n_state":train_mdp[4].astype(np.int32),
        "done":train_mdp[5].astype(np.float32)
    }
)

In [6]:
model = DQN(42171, seq_len=3, batch_size=500)

In [7]:
losses = model.fit(train_data.batch(500),n_epochs=100)

[Epoch1]:   0%|          | 0/718 [00:00<?, ?it/s]

[Epoch2]:   0%|          | 0/718 [00:00<?, ?it/s]

[Epoch3]:   0%|          | 0/718 [00:00<?, ?it/s]

[Epoch4]:   0%|          | 0/718 [00:00<?, ?it/s]

[Epoch5]:   0%|          | 0/718 [00:00<?, ?it/s]

[Epoch6]:   0%|          | 0/718 [00:00<?, ?it/s]

[Epoch7]:   0%|          | 0/718 [00:00<?, ?it/s]

In [8]:
losses

[<tf.Tensor: shape=(), dtype=float32, numpy=10.697142>,
 <tf.Tensor: shape=(), dtype=float32, numpy=10.6494465>,
 <tf.Tensor: shape=(), dtype=float32, numpy=10.649439>,
 <tf.Tensor: shape=(), dtype=float32, numpy=10.649439>,
 <tf.Tensor: shape=(), dtype=float32, numpy=10.649439>,
 <tf.Tensor: shape=(), dtype=float32, numpy=10.649439>]