In [1]:
import sys
sys.path.append("../src")
from plugin_write_and_run import *

In [2]:
%%write_and_run ../src/networks.py
import sys
sys.path.append("../src")
from replay_buffer import *
from config import *
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.initializers import random_uniform
import tensorflow_probability as tfp

In [3]:
# Standard normal N(0, 1) distribution object.
# NOTE(review): this cell has no %%write_and_run magic, so presumably this name lives only in
# the notebook kernel and is not written to networks.py. Nothing in the visible code reads this
# global (Actor.evaluate does not reference it) — candidate for removal.
standard_normal = tfp.distributions.Normal(0, 1)

In [4]:
%%write_and_run -a ../src/networks.py

class Critic(tf.keras.Model):
    """State-action value network Q(s, a).

    The state and action are concatenated along axis 1 and passed through two
    ReLU hidden layers. A linear head then emits a single value per row.

    Args:
        name: identifier stored on ``self.net_name``.
        hidden_0: width of the first hidden layer.
        hidden_1: width of the second hidden layer.
    """

    def __init__(self, name, hidden_0=CRITIC_HIDDEN_0, hidden_1=CRITIC_HIDDEN_1):
        super().__init__()
        self.hidden_0, self.hidden_1 = hidden_0, hidden_1
        self.net_name = name

        # Keras tracks sublayers by attribute, so keep these names and their
        # creation order stable to stay compatible with any saved weights.
        self.dense_0 = Dense(hidden_0, activation='relu')
        self.dense_1 = Dense(hidden_1, activation='relu')
        self.q_value = Dense(1, activation=None)

    def call(self, state, action):
        """Return Q(state, action) for a batch of state/action rows."""
        features = tf.concat([state, action], axis=1)
        for hidden_layer in (self.dense_0, self.dense_1):
            features = hidden_layer(features)
        return self.q_value(features)
    
class CriticValue(tf.keras.Model):
    """State value network V(s).

    The state is passed through two ReLU hidden layers. A linear head then
    emits a single value per row.

    Args:
        name: identifier stored on ``self.net_name``.
        hidden_0: width of the first hidden layer.
        hidden_1: width of the second hidden layer.
    """

    def __init__(self, name, hidden_0=CRITIC_HIDDEN_0, hidden_1=CRITIC_HIDDEN_1):
        super().__init__()
        self.hidden_0, self.hidden_1 = hidden_0, hidden_1
        self.net_name = name

        # Keras tracks sublayers by attribute, so keep these names and their
        # creation order stable to stay compatible with any saved weights.
        self.dense_0 = Dense(hidden_0, activation='relu')
        self.dense_1 = Dense(hidden_1, activation='relu')
        self.value = Dense(1, activation=None)

    def call(self, state):
        """Return V(state) for a batch of states."""
        features = state
        for hidden_layer in (self.dense_0, self.dense_1):
            features = hidden_layer(features)
        return self.value(features)

class Actor(tf.keras.Model):
    """Gaussian policy network with a tanh-squashed, rescaled action output.

    ``call`` maps a state to the mean and clipped log standard deviation of a
    diagonal Gaussian over pre-squash actions u. ``evaluate`` works in three steps:
    it samples u, returns ``tanh(u) * upper_bound``, and returns the log-probability
    of the squashed action.

    Args:
        name: identifier stored on ``self.net_name``.
        actions_dim: number of action dimensions (width of the mean/log_std heads).
        upper_bound: scale applied to the tanh-squashed action.
        hidden_0, hidden_1: widths of the two hidden ReLU layers.
            NOTE(review): these default to the *critic* sizes from config.
            Confirm this is intended.
        noise: small constant added inside the log of the tanh Jacobian term,
            so the log never sees 0.
        log_std_min, log_std_max: clipping range for the predicted log std.
    """

    def __init__(self, name, actions_dim, upper_bound, hidden_0=CRITIC_HIDDEN_0, hidden_1=CRITIC_HIDDEN_1, noise=NOISE, log_std_min=LOG_STD_MIN, log_std_max=LOG_STD_MAX):
        super(Actor, self).__init__()
        self.hidden_0 = hidden_0
        self.hidden_1 = hidden_1
        self.actions_dim = actions_dim
        self.upper_bound = upper_bound
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.noise = noise

        self.net_name = name

        self.dense_0 = Dense(self.hidden_0, activation='relu')
        self.dense_1 = Dense(self.hidden_1, activation='relu')
        self.mean = Dense(self.actions_dim, activation=None)
        self.log_std = Dense(self.actions_dim, activation=None)

    def call(self, state):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian for a batch of states.

        ``log_std`` is clipped to ``[log_std_min, log_std_max]``.
        """
        x = self.dense_0(state)
        policy = self.dense_1(x)
        mean = self.mean(policy)
        log_std = self.log_std(policy)
        log_std = tf.clip_by_value(log_std, self.log_std_min, self.log_std_max)

        return mean, log_std

    def evaluate(self, state, reparameterization=False):
        """Sample squashed actions and their log-probabilities.

        Args:
            state: batch of states fed to ``call``.
            reparameterization: if True, the sample is
                ``u = mean + std * eps`` with ``eps ~ N(0, 1)``, so gradients flow
                through the action into ``mean``/``std`` (use for the policy loss).
                If False, the sampled ``u`` is treated as a constant
                (``tf.stop_gradient``).

        Returns:
            ``(action, log_probs)``:
            - ``action`` is ``tanh(u) * upper_bound``.
            - ``log_probs`` has shape (batch, 1). It is the log-density of ``tanh(u)``,
              summed over action dimensions (reduced along axis 1).
        """
        mean, log_std = self.call(state)
        std = tf.exp(log_std)
        normal_mean_std = tfp.distributions.Normal(mean, std)

        if reparameterization:
            # Reparameterisation trick. tf.shape handles unknown batch dims, and
            # matching the dtype avoids float32/float64 mismatches.
            epsilon = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
            pre_tanh = mean + std * epsilon
        else:
            # TFP's Normal.sample is itself reparameterized, so block gradients
            # explicitly to make this branch genuinely non-reparameterized.
            pre_tanh = tf.stop_gradient(normal_mean_std.sample())

        action = tf.math.tanh(pre_tanh)

        # Change of variables for a = tanh(u): the Gaussian density must be
        # evaluated at the *pre-squash* sample u, not at the squashed action:
        #   log pi(a) = log N(u; mean, std) - log(1 - tanh(u)^2)
        # NOTE(review): the constant Jacobian of the final `* upper_bound` scaling
        # is not included in log_probs.
        log_probs = normal_mean_std.log_prob(pre_tanh) - tf.math.log(1 - tf.math.pow(action, 2) + self.noise)
        log_probs = tf.math.reduce_sum(log_probs, axis=1, keepdims=True)

        return action * self.upper_bound, log_probs

In [5]:
import gym
import numpy as np

In [6]:
# Create the environment, read its action bounds/dimension, and build one
# instance of each network. The networks are built in a fixed order, so Keras
# auto-generated layer names stay the same.
env = gym.make(ENV_NAME)
upper_bound, lower_bound = env.action_space.high[0], env.action_space.low[0]
action_dim = env.action_space.shape[0]

actor = Actor("name", action_dim, upper_bound)
critic = Critic("name")
critic_value = CriticValue("name")

In [7]:
# NOTE(review): rebinds `env` to a fresh environment without closing the one created above.
# upper_bound / lower_bound / action_dim were read from that previous instance.
env = gym.make(ENV_NAME)

In [8]:
# Replay buffer constructed from the environment (ReplayBuffer presumably comes from the
# `from replay_buffer import *` above).
rb = ReplayBuffer(env)

In [9]:
# Reset the environment; the displayed array is the returned initial observation.
env.reset()

array([-0.63902324, -0.76918743,  0.86220174])

In [10]:
# Inspect the action space: bounds, shape and dtype.
env.action_space

Box(-2.0, 2.0, (1,), float32)

In [11]:
# Fixed one-dimensional test action, within the [-2, 2] bounds displayed for the action space.
action = np.array([-0.5])

In [12]:
# Sanity check: lower bound read from env.action_space.low[0].
lower_bound

-2.0

In [13]:
# Take one step with the test action.
# NOTE(review): unpacks a 4-tuple, which assumes the older gym step API (pre-0.26) — confirm the version.
state, reward, done, _ = env.step(action)

In [14]:
# Smoke-test fill: store the same single transition 1000 times, reusing `state` as next_state.
# This presumably puts enough records in the buffer for rb.get_minibatch() below.
for i in range(1000):
    rb.add_record(state, action, reward, state, done)

In [15]:
# Draw a minibatch. NOTE(review): this rebinds state/action/reward/done, which previously held
# the single env step above, to batched values.
state, action, reward, next_state, done = rb.get_minibatch()

In [16]:
# Scale that Actor.evaluate applies to the tanh-squashed actions.
actor.upper_bound

2.0

In [17]:
# Sample (non-reparameterized) actions and their log-probabilities for the minibatch states.
# This rebinds `action` to the scaled policy actions.
action, log_probs = actor.evaluate(state, False)

In [18]:
# Per-sample log-probabilities, shape (batch, 1).
log_probs

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[-0.93016297],
       [-0.87399185],
       [ 0.5622529 ],
       [-0.07912588],
       [-0.73927826],
       [-0.9283955 ],
       [-0.8197799 ],
       [ 1.3096818 ],
       [ 0.12989092],
       [-0.7514427 ],
       [ 0.21941304],
       [-0.89378655],
       [-0.85712874],
       [-0.18797326],
       [-0.9382204 ],
       [ 0.7751658 ],
       [ 0.4623152 ],
       [-0.37615317],
       [-0.82767   ],
       [-0.28302056],
       [-0.9386636 ],
       [ 0.38335252],
       [-0.6808816 ],
       [ 1.9553834 ],
       [-0.09767842],
       [-0.89835256],
       [-0.82583797],
       [-0.28303808],
       [ 0.06315625],
       [-0.87687254],
       [ 0.05845022],
       [-0.62154895],
       [-0.92982876],
       [ 0.7052827 ],
       [-0.8759372 ],
       [-0.17257178],
       [-0.08468437],
       [-0.9077453 ],
       [-0.7891208 ],
       [ 0.10489798],
       [-0.6896508 ],
       [-0.8994697 ],
       [-0.38904023],
     

In [19]:
# Q-values for the minibatch states paired with the sampled policy actions, one value per row.
critic(state, action)

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[ 0.00319149],
       [-0.07036109],
       [-0.05290864],
       [-0.14233482],
       [-0.09982853],
       [ 0.00455765],
       [ 0.00796022],
       [-0.05818208],
       [-0.04648896],
       [-0.0982159 ],
       [-0.048218  ],
       [-0.06142227],
       [-0.07640227],
       [-0.1388068 ],
       [-0.00869247],
       [-0.15915501],
       [-0.05177684],
       [-0.13101643],
       [ 0.0082804 ],
       [-0.1351302 ],
       [-0.01527549],
       [-0.15331878],
       [-0.10726266],
       [-0.16626868],
       [-0.1417462 ],
       [ 0.00995772],
       [-0.08535442],
       [-0.13512947],
       [-0.14643598],
       [-0.06925236],
       [-0.04485992],
       [-0.01261769],
       [-0.02714475],
       [-0.05431669],
       [-0.06961582],
       [-0.13934085],
       [-0.14215294],
       [-0.05197422],
       [-0.09252449],
       [-0.14750066],
       [-0.10625383],
       [-0.05807236],
       [-0.13036594],
     

In [20]:
# State values for the minibatch. NOTE(review): every row is identical because the buffer was
# filled with one repeated transition, so all minibatch states are the same.
critic_value(state)

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
       [0.00791488],
      