In [380]:
import tensorflow as tf
from tensorflow.keras.layers import Concatenate,Dense
import gym
import numpy as np
from typing import List
import tqdm
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tensorflow.keras.optimizers import Adam

In [381]:
env = gym.make('BipedalWalker-v3')
print(env.action_space.high)
print(env.action_space.low)
state_dim = env.observation_space.shape
action_dim = env.action_space.shape
print(state_dim[0],action_dim[0])

[1. 1. 1. 1.]
[-1. -1. -1. -1.]
24 4


In [382]:
def env_step(action: np.ndarray) -> List[np.ndarray]:
    state, reward, done, _ = env.step(action)
    return (
        state.astype(np.float64),
        np.array(reward, np.float32),
        np.array(done, np.int32)
    )


def tf_env_step(action: tf.Tensor):
    return tf.numpy_function(
        env_step, [action], [tf.float64, tf.float32, tf.int32]
    )


In [383]:
last_weight_init = tf.keras.initializers.RandomUniform(minval=-.003, maxval=.003)

In [384]:
bound = tf.constant(env.action_space.high)

In [426]:
# func approx for deterministic policy 
class Actor(tf.keras.Model):
    def __init__(self,action_shape):
        super(Actor,self).__init__()
        self.fc1 = Dense(400,activation='relu')
        self.fc2 = Dense(300,activation='relu')
        #fed through tanh to bound actions between (-1,1) 
        self.fc3 = Dense(action_shape,activation='tanh',kernel_initializer=last_weight_init)
    def call(self,inputs):
        x = self.fc1(inputs)
        x = self.fc2(x)
        x = self.fc3(x)
        return x * bound
    
# func approx for Q(s,a)
class Critic(tf.keras.Model):
    def __init__(self,action_shape):
        super(Critic,self).__init__()
        self.state_fc1 = Dense(400,activation='relu')
        self.state_fc2 = Dense(300,activation='relu')
        
        self.action_fc1 = Dense(400,activation='relu')
        
        self.concat = Concatenate()

        self.out = Dense(1,kernel_initializer=last_weight_init)

    def call(self,inputs,training):
        [state,action] = inputs
        state_x = self.state_fc1(state)
        state_x = self.state_fc2(state_x)
        
        action_x = self.action_fc1(action)
        
        concat_x = self.concat([state_x,action_x])
        
        return self.out(concat_x)
        

In [499]:

data_spec =  (
    tf.TensorSpec([state_dim[0]], tf.float64, 'state'),
    tf.TensorSpec([action_dim[0]], tf.float32, 'action'),
    tf.TensorSpec([1], tf.float32, 'reward'),
    tf.TensorSpec([state_dim[0]], tf.float64, 'next_state'),
    tf.TensorSpec([1], tf.int32, 'done'),
)

batch_size = 1
max_length = int(1e6)

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec,
    batch_size=batch_size,
    max_length=max_length)
replay_buffer.num_frames()

<tf.Tensor: shape=(), dtype=int64, numpy=0>

In [495]:
actor = Actor(env.action_space.shape[0])
critic = Critic(env.action_space.shape[0])

actor_t_net = Actor(env.action_space.shape[0])
critic_t_net = Critic(env.action_space.shape[0])

actor_t_net.set_weights(actor.get_weights())
critic_t_net.set_weights(critic.get_weights())

max_episodes = 100
gamma = .99
max_steps = 10000
actor_lr = .0001
critic_lr = .0001

actor_optim = Adam(learning_rate=actor_lr)
critic_optim = Adam(learning_rate=critic_lr)

In [519]:
state = tf.constant(env.reset(),dtype=tf.float64)
transition
while True:
    action = actor(tf.expand_dims(state,0))
    next_state,reward,done = tf_env_step(tf.squeeze(action))
    transition = (state,tf.squeeze(action),tf.expand_dims(reward,0),next_state,tf.expand_dims(done,0))
    transition = tf.nest.map_structure(lambda t: tf.stack([t] * 1),
                                       transition)
    replay_buffer.add_batch(transition)
#     env.render()
    if tf.cast(done,tf.bool):
        break
replay_buffer.add_batch(transition)
replay_buffer.add_batch(transition)
replay_buffer.add_batch(transition)
replay_buffer.add_batch(transition)
replay_buffer.add_batch(transition)
replay_buffer.add_batch(transition)
replay_buffer.add_batch(transition)
replay_buffer.add_batch(transition)
replay_buffer.add_batch(transition)


In [526]:
sample = replay_buffer.as_dataset()

In [516]:
@tf.function
def train_step(initial_state: tf.Tensor, gamma: float, max_steps: int, batch_size: int) -> tf.Tensor:
    
    state = initial_state
    state_shape = initial_state.shape
    
    reward_shape = (1,)
    episode_reward = tf.constant([0],dtype=tf.float32)
    episode_reward.set_shape(reward_shape)
    print(episode_reward.shape)
    for t in tf.range(max_steps):
        action = actor(tf.expand_dims(state,0))
        next_state,reward,done = tf_env_step(tf.squeeze(action))
        
        next_state.set_shape(state_shape)
        reward.set_shape(reward_shape)
        done.set_shape(reward_shape)

        transition = (state,tf.squeeze(action),reward,next_state,done)
        transition = tf.nest.map_structure(lambda t: tf.stack([t] * 1),
                                       transition)
        
        transitions_stored += 1
        replay_buffer.add_batch(transition)
        
        if transitions_stored >= tf.constant(batch_size):
            sample = replay_buffer.as_dataset(sample_batch_size=1,num_steps=1)
            iterator = iter(sample)
            (states,actions,rewards,next_states,dones),_ = iterator.next()
            with tf.GradientTape() as tape:
                td_target = rewards + gamma * critic_t_net([next_states,actions])
                td_pred = critic([states,actions])
                critic_loss = tf.reduce_mean(tf.math.square(td_pred-td_target))
            critic_grads = tape.gradient(critic_loss,critic.trainable_variables)
            critic_optim.apply_gradients(zip(critic_grads,critic.trainable_variables))
            
            with tf.GradientTape() as tape:
                actions = actor(states)
                action_values = critic([states,actions])
                actor_loss = -tf.math.reduce_mean(action_values)
            actor_grads = tape.gradient(actor_loss,actor.trainable_variables)
            actor_optim.apply_gradients(zip(actor_grads,actor.trainable_variables))
        
        state = next_state
        
        episode_reward += reward 
        episode_reward.set_shape(reward_shape)
    
        if tf.cast(done,tf.bool):
            break

    return episode_reward


In [517]:
running_reward = 0
reward_threshold = 195

with tqdm.trange(max_episodes) as t:
    for i in t:
        state = tf.constant(env.reset(),dtype=tf.float64)
        episode_reward = float(train_step(state,gamma,max_steps,64))
                
        running_reward = .99 * running_reward + .01 * episode_reward
    
        t.set_description(f"Episode {i}")
        t.set_postfix(episode_reward=episode_reward,running_reward=running_reward)
        
        if running_reward > reward_threshold:
            break

  0%|          | 0/100 [00:00<?, ?it/s]

(1,)
(1,)


  0%|          | 0/100 [00:01<?, ?it/s]


InvalidArgumentError:  Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [1], indices.shape [1], params.shape [1000000,1]
	 [[{{node while/body/_1/while/TFUniformReplayBuffer/ResourceScatterUpdate_3}}]] [Op:__inference_train_step_1110255]

Function call stack:
train_step
