In [1]:
import sys
sys.path.append("../src")
from plugin_write_and_run import *

In [2]:
%%write_and_run ../src/networks.py
import sys
sys.path.append("../src")
from replay_buffer import *
from config import *
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.initializers import random_uniform

In [3]:
%%write_and_run -a ../src/networks.py

class Critic(tf.keras.Model):
    """State-action value network: two ReLU hidden layers feeding one linear unit.

    Args:
        name: Label stored on the instance as ``net_name``. It is not forwarded
            to the Keras base-class constructor.
        hidden_0: Number of units in the first hidden layer.
        hidden_1: Number of units in the second hidden layer.
    """

    def __init__(self, name, hidden_0=CRITIC_HIDDEN_0, hidden_1=CRITIC_HIDDEN_1):
        super().__init__()

        self.net_name = name
        self.hidden_0, self.hidden_1 = hidden_0, hidden_1

        # Layers are instantiated in the same order the forward pass uses them.
        self.dense_0 = Dense(hidden_0, activation='relu')
        self.dense_1 = Dense(hidden_1, activation='relu')
        # Single output unit without an activation (linear output).
        self.q_value = Dense(1, activation=None)

    def call(self, state, action):
        """Run the forward pass on a state/action pair and return the output tensor.

        ``state`` and ``action`` are concatenated along axis 1 before the first
        layer, so both are presumably 2-D ``(batch, features)`` tensors with the
        same leading dimension -- TODO confirm against callers.
        """
        joined = tf.concat([state, action], axis=1)
        hidden = self.dense_1(self.dense_0(joined))
        return self.q_value(hidden)

class Actor(tf.keras.Model):
    """Policy network: two ReLU hidden layers and a tanh head scaled by ``upper_bound``.

    Args:
        name: Label stored on the instance as ``net_name``. It is not forwarded
            to the Keras base-class constructor.
        actions_dim: Number of units in the output (policy) layer.
        upper_bound: Factor the tanh output is multiplied by in ``call``.
        hidden_0: Number of units in the first hidden layer.
        hidden_1: Number of units in the second hidden layer.
        init_minval: Lower limit of the uniform kernel initializer of the
            output layer.
        init_maxval: Upper limit of the uniform kernel initializer of the
            output layer.
    """

    # NOTE(review): the hidden-size defaults reuse the CRITIC_* constants. This
    # may be a copy-paste leftover; confirm whether config defines actor-specific
    # sizes before changing them.
    def __init__(
        self,
        name,
        actions_dim,
        upper_bound,
        hidden_0=CRITIC_HIDDEN_0,
        hidden_1=CRITIC_HIDDEN_1,
        init_minval=INIT_MINVAL,
        init_maxval=INIT_MAXVAL,
    ):
        super().__init__()

        self.net_name = name
        self.actions_dim = actions_dim
        self.upper_bound = upper_bound
        self.hidden_0, self.hidden_1 = hidden_0, hidden_1
        self.init_minval, self.init_maxval = init_minval, init_maxval

        # Layers are instantiated in the same order the forward pass uses them.
        self.dense_0 = Dense(hidden_0, activation='relu')
        self.dense_1 = Dense(hidden_1, activation='relu')

        # Only the output layer gets the custom uniform kernel initializer.
        head_init = random_uniform(minval=init_minval, maxval=init_maxval)
        self.policy = Dense(actions_dim, kernel_initializer=head_init, activation='tanh')

    def call(self, state):
        """Map ``state`` to an action tensor.

        The tanh head yields values in [-1, 1]; multiplying by ``upper_bound``
        rescales them to [-upper_bound, upper_bound].
        """
        hidden = self.dense_1(self.dense_0(state))
        return self.policy(hidden) * self.upper_bound

In [4]:
import gym
import numpy as np

In [20]:
# Smoke-test setup: create the environment, read the action-space limits and
# build one actor and one critic.
# ENV_NAME, Actor and Critic are not defined in this cell; they come from
# earlier cells (ENV_NAME presumably via the `config` star-import -- verify).
env = gym.make(ENV_NAME)
# Only the first component of the high/low limits is read, which assumes every
# action dimension shares the same limits -- TODO confirm for ENV_NAME.
upper_bound = env.action_space.high[0]
lower_bound = env.action_space.low[0]
action_dim = env.action_space.shape[0]
# Both networks receive the placeholder label "name"; hidden sizes use defaults.
actor = Actor("name", action_dim, upper_bound)
critic = Critic("name")

In [21]:
# Rebinds `env` to a freshly created environment, replacing the instance that
# was built in the earlier setup cell.
env = gym.make(ENV_NAME)

In [22]:
# Build the replay buffer from the environment. ReplayBuffer is not defined in
# this notebook (presumably from the `replay_buffer` star-import -- verify).
rb = ReplayBuffer(env)

In [23]:
# Reset the environment; the return value is not stored, only displayed as the
# cell output (an 8-element float32 array in the recorded run).
env.reset()

array([ 0.00653095,  1.4056934 ,  0.6614974 , -0.23232347, -0.00756093,
       -0.14983907,  0.        ,  0.        ], dtype=float32)

In [27]:
# Display the action space (recorded run: a 2-component Box in [-1.0, 1.0]).
env.action_space

Box(-1.0, 1.0, (2,), float32)

In [28]:
# Hand-picked two-component action for one manual environment step.
# np.array on Python floats yields float64, while the recorded action space
# reports float32.
action = np.array([-0.5, 0.5])

In [29]:
# Display the lower action limit read from the action space during setup.
lower_bound

-1.0

In [30]:
# Take one step with the hand-picked action; the fourth returned value is
# discarded. `state` here holds the observation returned by the step.
# NOTE(review): the four-value unpacking matches the older gym step API; newer
# gym/gymnasium releases return five values -- confirm the installed version.
state, reward, done, _ = env.step(action)

In [31]:
# Fill the buffer with 1000 copies of the same record so a minibatch can be
# drawn. The same `state` is passed in the second and fourth positions;
# presumably these are (state, action, reward, next_state, done) -- verify
# against ReplayBuffer.add_record.
for i in range(1000):
    rb.add_record(state, action, reward, state, done)

In [32]:
# Draw a minibatch. This rebinds `state`, `action`, `reward` and `done`, which
# held the single-step values until now, so cells below operate on batch data.
state, action, reward, next_state, done = rb.get_minibatch()

In [35]:
# Display the factor the actor multiplies its tanh output by.
actor.upper_bound

1.0

In [33]:
# Forward pass of the actor on the sampled states (recorded output shape:
# (64, 2)). The rows are identical because the buffer was filled with copies of
# one record.
actor(state)

<tf.Tensor: shape=(64, 2), dtype=float32, numpy=
array([[ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
       [ 0.03692073, -0.04923367],
      

In [34]:
# Forward pass of the critic on the sampled state/action batch (recorded output
# shape: (64, 1)). `action` is the minibatch value here, not the hand-picked one.
critic(state, action)

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
       [0.08359692],
      