In [1]:
import sys
sys.path.append("../src")
from plugin_write_and_run import *

In [2]:
%%write_and_run ../src/networks.py
import sys
sys.path.append("../src")
from replay_buffer import *
from config import *
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.initializers import random_uniform
import tensorflow_probability as tfp

In [3]:
%%write_and_run -a ../src/networks.py

class Critic(tf.keras.Model):
    def __init__(self, name, hidden_0=CRITIC_HIDDEN_0, hidden_1=CRITIC_HIDDEN_1):
        super(Critic, self).__init__()
        
        self.hidden_0 = hidden_0
        self.hidden_1 = hidden_1

        self.net_name = name

        self.dense_0 = Dense(self.hidden_0, activation='relu')
        self.dense_1 = Dense(self.hidden_1, activation='relu')
        self.q_value = Dense(1, activation=None)

    def call(self, state, action):
        state_action_value = self.dense_0(tf.concat([state, action], axis=1))
        state_action_value = self.dense_1(state_action_value)

        q_value = self.q_value(state_action_value)

        return q_value
    
class CriticValue(tf.keras.Model):
    def __init__(self, name, hidden_0=CRITIC_HIDDEN_0, hidden_1=CRITIC_HIDDEN_1):
        super(CriticValue, self).__init__()
        
        self.hidden_0 = hidden_0
        self.hidden_1 = hidden_1

        self.net_name = name

        self.dense_0 = Dense(self.hidden_0, activation='relu')
        self.dense_1 = Dense(self.hidden_1, activation='relu')
        self.value = Dense(1, activation=None)

    def call(self, state):
        state_value = self.dense_0(state)
        state_value = self.dense_1(state_value)

        value = self.value(state_value)

        return value

class Actor(tf.keras.Model):
    def __init__(self, name, actions_dim, upper_bound, hidden_0=CRITIC_HIDDEN_0, hidden_1=CRITIC_HIDDEN_1, noise=NOISE, log_std_min=LOG_STD_MIN, log_std_max=LOG_STD_MAX):
        super(Actor, self).__init__()
        self.hidden_0 = hidden_0
        self.hidden_1 = hidden_1
        self.actions_dim = actions_dim
        self.upper_bound = upper_bound
        self.noise = noise
        
        self.net_name = name

        self.dense_0 = Dense(self.hidden_0, activation='relu')
        self.dense_1 = Dense(self.hidden_1, activation='relu')
        self.mean = Dense(self.actions_dim, activation=None)
        self.std = Dense(self.actions_dim, activation=None)

    def call(self, state):
        policy = self.dense_0(state)
        policy = self.dense_1(policy)
        mean = self.mean(policy)
        std = self.std(policy)

        return mean, std
    
    def evaluate(self, state, reparameterization=False):
        mean, std = self.call(state)
        normal_mean_std = tfp.distributions.Normal(mean, std)
        
        if reparameterization:
            actions = normal_mean_std.sample()
        else:
            actions = normal_mean_std.sample()
            
        actions = tf.math.tanh(actions)*self.upper_bound
        
        log_probs = normal_mean_std.log_prob(actions) - tf.math.log(1 - tf.math.pow(actions, 2) + self.noise)
        log_probs = tf.math.reduce_sum(log_probs, axis=1, keepdims=True)
        
        return actions, log_probs

In [4]:
import gym
import numpy as np

In [5]:
env = gym.make(ENV_NAME)
upper_bound = env.action_space.high[0]
lower_bound = env.action_space.low[0]
action_dim = env.action_space.shape[0]
actor = Actor("name", action_dim, upper_bound)
critic = Critic("name")
critic_value = CriticValue("name")

In [6]:
env = gym.make(ENV_NAME)

In [7]:
rb = ReplayBuffer(env)

In [8]:
env.reset()

array([ 0.7620553 ,  0.64751194, -0.35071421])

In [9]:
env.action_space

Box(-2.0, 2.0, (1,), float32)

In [10]:
action = np.array([-0.5])

In [11]:
lower_bound

-2.0

In [12]:
state, reward, done, _ = env.step(action)

In [13]:
for i in range(1000):
    rb.add_record(state, action, reward, state, done)

In [14]:
state, action, reward, next_state, done = rb.get_minibatch()

In [15]:
actor.upper_bound

2.0

In [16]:
action, log_probs = actor.evaluate(state, False)

In [17]:
log_probs

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[-497.86215],
       [-610.0933 ],
       [-698.148  ],
       [-579.88385],
       [-716.94037],
       [-524.7273 ],
       [-597.0488 ],
       [-820.9631 ],
       [-504.02206],
       [-548.7138 ],
       [-641.68805],
       [-566.0127 ],
       [-715.3209 ],
       [-596.28595],
       [-605.1497 ],
       [-542.07526],
       [-622.1418 ],
       [-578.0609 ],
       [-686.9125 ],
       [-608.24963],
       [-672.1221 ],
       [-554.35345],
       [-721.1798 ],
       [-471.03934],
       [-491.364  ],
       [-516.0968 ],
       [-410.65094],
       [-624.62555],
       [-622.8024 ],
       [-666.7389 ],
       [-476.46133],
       [-675.2427 ],
       [-559.51245],
       [-651.2096 ],
       [-666.64825],
       [-601.14075],
       [-496.1668 ],
       [-686.37854],
       [-547.9794 ],
       [-525.89655],
       [-606.85156],
       [-571.5088 ],
       [-576.0764 ],
       [-571.0081 ],
       [-535.10815],
      

In [18]:
critic(state, action)

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[0.01252018],
       [0.0126493 ],
       [0.01276033],
       [0.01261003],
       [0.01278349],
       [0.01254734],
       [0.01263246],
       [0.01290666],
       [0.01252819],
       [0.01257222],
       [0.01268939],
       [0.01259161],
       [0.01278151],
       [0.01263146],
       [0.01264295],
       [0.01256538],
       [0.01266471],
       [0.01260761],
       [0.01274632],
       [0.01264695],
       [0.01272772],
       [0.012578  ],
       [0.01278869],
       [0.01248466],
       [0.01251164],
       [0.01253935],
       [0.0124006 ],
       [0.01266789],
       [0.01266554],
       [0.01272089],
       [0.01249192],
       [0.01273166],
       [0.01258332],
       [0.01270127],
       [0.01272078],
       [0.01263776],
       [0.01251794],
       [0.01274564],
       [0.01257146],
       [0.01254856],
       [0.01264514],
       [0.01259894],
       [0.01260499],
       [0.01259825],
       [0.01255818],
      

In [19]:
critic_value(state)

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
       [0.02027844],
      