In [1]:
import sys
sys.path.append("../src")
from plugin_write_and_run import *

In [2]:
%%write_and_run ../src/networks.py
import sys
sys.path.append("../src")
from replay_buffer import *
from config import *
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, Concatenate
from tensorflow.keras.initializers import random_normal

In [3]:
%%write_and_run -a ../src/networks.py

import os
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Dense

class Critic(keras.Model):
    def __init__(self, fc1_dims=512, fc2_dims=512,
            name='critic'):
        super(Critic, self).__init__()
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims

        self.model_name = name

        self.fc1 = Dense(self.fc1_dims, activation='relu')
        self.fc2 = Dense(self.fc2_dims, activation='relu')
        self.q = Dense(1, activation=None)

    def call(self, state, action):
        action_value = self.fc1(tf.concat([state, action], axis=1))
        action_value = self.fc2(action_value)

        q = self.q(action_value)

        return q

class Actor(keras.Model):
    def __init__(self, fc1_dims=512, fc2_dims=512, n_actions=2, name='actor'):
        super(Actor, self).__init__()
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions

        self.model_name = name

        self.fc1 = Dense(self.fc1_dims, activation='relu')
        self.fc2 = Dense(self.fc2_dims, activation='relu')
        self.mu = Dense(self.n_actions, activation='tanh')

    def call(self, state):
        prob = self.fc1(state)
        prob = self.fc2(prob)

        mu = self.mu(prob)

        return mu

In [4]:
import gym
import numpy as np

In [5]:
env = gym.make(ENV_NAME)
upper_bound = env.action_space.high[0]
action_dim = env.action_space.shape[0]
actor = Actor(action_dim, upper_bound)

In [6]:
upper_bound

2.0

In [7]:
a = Actor()
c = Critic()

In [8]:
env = gym.make(ENV_NAME)

In [9]:
rb = ReplayBuffer(env)

In [10]:
env.reset()

array([0.72441326, 0.68936596, 0.16871488])

In [11]:
action = np.array([-0.5])

In [12]:
state, reward, done, _ = env.step(action)

In [13]:
for i in range(1000):
    rb.add_record(state, action, reward, state, done)

In [14]:
state, action, reward, next_state, done = rb.get_minibatch()

In [15]:
a(state)

<tf.Tensor: shape=(64, 2), dtype=float32, numpy=
array([[-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
       [-0.00344271,  0.02177251],
      

In [16]:
c(state, action)

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
       [-0.03718539],
     