In [1]:
import sys
sys.path.append("../src/")

In [2]:
from plugin_write_and_run import *

In [3]:
%%write_and_run ../src/networks.py
import numpy as np
import json
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.initializers import random_uniform
import sys
sys.path.append("../src")
from config import *

In [4]:
%%write_and_run -a ../src/networks.py

class Critic(tf.keras.Model):
    """Q-value network that scores a state together with the actors' actions.

    The network is two ReLU hidden layers followed by one linear output unit.

    Args:
        name: Label stored on the instance as ``net_name``. It is not forwarded
            to the Keras base constructor.
        hidden_0: Number of units in the first hidden layer.
        hidden_1: Number of units in the second hidden layer.
    """

    def __init__(self, name, hidden_0=CRITIC_HIDDEN_0, hidden_1=CRITIC_HIDDEN_1):
        super().__init__()

        self.net_name = name
        self.hidden_0, self.hidden_1 = hidden_0, hidden_1

        # Layers are created in the same order in which call() applies them.
        self.dense_0 = Dense(units=hidden_0, activation="relu")
        self.dense_1 = Dense(units=hidden_1, activation="relu")
        self.q_value = Dense(units=1, activation=None)

    def call(self, state, actors_actions):
        """Return the Q-value for each row of ``state`` and ``actors_actions``.

        Args:
            state: Batch of states; joined with the actions along axis 1.
            actors_actions: Batch of actions (of every actor) with the same
                leading dimension as ``state``.

        Returns:
            Output of the final single-unit linear layer.
        """
        # The critic sees the state and all actions as one joined input.
        critic_input = tf.concat([state, actors_actions], axis=1)
        features = self.dense_1(self.dense_0(critic_input))
        return self.q_value(features)

In [5]:
%%write_and_run -a ../src/networks.py

class Actor(tf.keras.Model):
    """Policy network mapping a state to ``actions_dim`` values in (0, 1).

    The network is two ReLU hidden layers followed by a sigmoid output layer.

    Args:
        name: Label stored on the instance as ``net_name``. It is not forwarded
            to the Keras base constructor.
        actions_dim: Number of units in the output (policy) layer.
        hidden_0: Number of units in the first hidden layer.
        hidden_1: Number of units in the second hidden layer.
    """

    def __init__(self, name, actions_dim, hidden_0=ACTOR_HIDDEN_0, hidden_1=ACTOR_HIDDEN_1):
        super().__init__()

        self.net_name = name
        self.actions_dim = actions_dim
        self.hidden_0, self.hidden_1 = hidden_0, hidden_1

        # Layers are created in the same order in which call() applies them.
        self.dense_0 = Dense(units=hidden_0, activation="relu")
        self.dense_1 = Dense(units=hidden_1, activation="relu")
        # Sigmoid keeps every output component between zero and one.
        self.policy = Dense(units=actions_dim, activation="sigmoid")

    def call(self, state):
        """Run ``state`` through both hidden layers and the sigmoid policy head.

        Args:
            state: Batch of states fed to the first hidden layer.

        Returns:
            Output of the sigmoid policy layer.
        """
        hidden = self.dense_1(self.dense_0(state))
        return self.policy(hidden)

In [6]:
from make_env import *
from replay_buffer import *

In [7]:
# Build the environment selected by ENV_NAME.
# NOTE(review): make_env and ENV_NAME arrive through star-imports (presumably
# make_env.py and config.py) — confirm where each one is defined.
env = make_env(ENV_NAME)

In [8]:
# Create the replay buffer from the environment object.
rb = ReplayBuffer(env)

In [9]:
# Fill the replay buffer with 100 placeholder transitions so the cells below
# have a minibatch to run the networks on. The same hard-coded action list is
# sent for each of the three entries on every step.
# NOTE(review): no env.reset() call is visible before the first step — confirm
# that make_env() returns an environment that is ready to be stepped.
for i in range(100):
    actors_state, reward, done, info = env.step([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
    # Join the per-actor observations into one array.
    state = np.concatenate(actors_state)
    # The same post-step observation is passed twice for both the per-actor
    # and the joined state (presumably the "current" and "next" slots — check
    # the argument order against ReplayBuffer.add_record). These records are
    # therefore only meaningful for shape checks.
    rb.add_record(actors_state, actors_state, [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]], state, state, reward, done)

In [10]:
# Draw one minibatch and unpack its seven components.
# NOTE(review): the unpacking order here differs from the argument order used
# in rb.add_record above — verify it against ReplayBuffer.get_minibatch.
state, reward, next_state, done, actors_state, actors_new_state, actors_action = rb.get_minibatch()

In [11]:
# Sanity check: show the shapes of the joined-state, reward and done batches.
state.shape, reward.shape, next_state.shape, done.shape

((64, 28), (64, 3), (64, 28), (64, 3))

In [12]:
# Show the environment's observation spaces (one entry per list element).
env.observation_space

[Box(-inf, inf, (8,), float32),
 Box(-inf, inf, (10,), float32),
 Box(-inf, inf, (10,), float32)]

In [13]:
# Instantiate the critic with its default hidden-layer sizes.
critic = Critic("critic")

In [14]:
# Shape of the sampled actions when stacked into a single array.
np.array(actors_action).shape

(3, 64, 5)

In [15]:
# Joining the entries of actors_action along axis 1 yields one row per sample,
# which is the layout Critic.call concatenates with the state along axis 1.
np.concatenate(actors_action, axis=1).shape

(64, 15)

In [16]:
# Smoke test: forward pass of the critic on the sampled joined states and the
# joined actions.
critic(state, np.concatenate(actors_action, axis=1))

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[-4.7636404 ],
       [-6.28298   ],
       [-6.66727   ],
       [-4.9551125 ],
       [-3.250637  ],
       [-8.474953  ],
       [-3.1774144 ],
       [-3.3704672 ],
       [-1.4903333 ],
       [-8.756027  ],
       [-3.508485  ],
       [-6.5711966 ],
       [-3.5904212 ],
       [-8.568591  ],
       [-0.98046714],
       [-1.7612805 ],
       [-2.4080536 ],
       [-8.002979  ],
       [-2.4860454 ],
       [-5.808099  ],
       [-7.3356767 ],
       [-4.280472  ],
       [-3.0281017 ],
       [-2.278727  ],
       [-8.381312  ],
       [-3.9159412 ],
       [-2.7959354 ],
       [-1.0576344 ],
       [-0.9235202 ],
       [-1.2391566 ],
       [-1.9515522 ],
       [-3.8342366 ],
       [-2.210513  ],
       [-7.4310045 ],
       [-0.8963884 ],
       [-2.3456807 ],
       [-2.5644724 ],
       [-4.003454  ],
       [-1.6751487 ],
       [-5.5269833 ],
       [-8.943575  ],
       [-3.7526863 ],
       [-4.5711875 ],
     

In [17]:
# The `n` attribute of the first action space; used below as the Actor's
# output width.
env.action_space[0].n

5

In [18]:
# Instantiate an actor whose output width equals the first action space's `n`.
actor = Actor("actor", env.action_space[0].n)

In [19]:
# Smoke test: forward pass of the actor on the first entry of the sampled
# per-actor states.
actor(actors_state[0])

<tf.Tensor: shape=(64, 5), dtype=float32, numpy=
array([[0.47622827, 0.11387476, 0.5037373 , 0.26783657, 0.68668276],
       [0.47678116, 0.06080547, 0.5016193 , 0.2108672 , 0.7398893 ],
       [0.47717306, 0.05175623, 0.5005966 , 0.19808608, 0.75218403],
       [0.4758508 , 0.10545298, 0.50425446, 0.26023132, 0.69354343],
       [0.48342675, 0.22470403, 0.49762863, 0.34619972, 0.61740595],
       [0.47997385, 0.02366257, 0.4958104 , 0.14463973, 0.8049794 ],
       [0.48375463, 0.23252124, 0.4968544 , 0.35085416, 0.6135457 ],
       [0.4827053 , 0.20905048, 0.49864933, 0.33722255, 0.6257168 ],
       [0.52314854, 0.41540718, 0.49016726, 0.4589412 , 0.5287641 ],
       [0.4806851 , 0.0208607 , 0.49505585, 0.13721904, 0.81244516],
       [0.48199058, 0.19410524, 0.49940678, 0.3284037 , 0.6342562 ],
       [0.47707513, 0.05389142, 0.5008522 , 0.20122573, 0.7491473 ],
       [0.4816332 , 0.18694258, 0.49978557, 0.3240387 , 0.6384948 ],
       [0.480211  , 0.02269006, 0.49555916, 0.14212975