In [None]:
import gym
import numpy as np
import tensorflow as tf

from buffer import Buffer
from ou_noise import OUActionNoise

In [None]:
def train_actor_critic(env, std_dev, actor_lr, critic_lr, gamma, tau, total_episodes,
                       buffer_capacity=50000, batch_size=64, avg_window=40, device="GPU:0"):
    """Train an actor-critic agent on ``env`` and record per-episode rewards.

    Each environment step:
    - selects an action with ``policy`` using ``OUActionNoise`` exploration noise;
    - stores the transition ``(state, action, reward, next_state)`` in a replay ``Buffer``;
    - calls ``buffer.learn()``;
    - soft-updates both target networks with ``update_target(..., tau)``.

    Args:
        env: Environment whose ``step`` returns
            ``(state, reward, terminated, truncated, info)``. It is assumed that
            ``reset()`` returns ``(observation, info)``, the API that pairs with
            that 5-tuple ``step``.
        std_dev: Standard deviation given to the exploration noise (1-D action).
        actor_lr: Adam learning rate for the actor optimizer.
        critic_lr: Adam learning rate for the critic optimizer.
        gamma: Discount factor. NOTE(review): not referenced anywhere in this
            function; it only has an effect if ``Buffer.learn`` obtains it some
            other way. Confirm.
        tau: Soft-update coefficient passed to ``update_target``.
        total_episodes: Number of episodes to run.
        buffer_capacity: Replay buffer capacity (default 50000, unchanged).
        batch_size: Replay minibatch size (default 64, unchanged).
        avg_window: Number of most recent episodes averaged for the running
            reward (default 40, unchanged).
        device: TensorFlow device string used as the placement scope for each
            step (default "GPU:0", unchanged).

    Returns:
        Tuple ``(ep_reward_list, avg_reward_list)``:
        - ``ep_reward_list``: the summed reward of each episode;
        - ``avg_reward_list``: the mean of the last ``avg_window`` entries of
          ``ep_reward_list``, taken after each episode.
    """
    ou_noise = OUActionNoise(mean=np.zeros(1), std_deviation=float(std_dev) * np.ones(1))

    actor_model = get_actor()
    critic_model = get_critic()

    # Target networks start as exact copies of the online networks.
    target_actor = get_actor()
    target_critic = get_critic()
    target_actor.set_weights(actor_model.get_weights())
    target_critic.set_weights(critic_model.get_weights())

    # NOTE(review): these optimizers, the models above and `gamma` are local
    # variables and are never passed to `policy` or `buffer.learn()`. If those
    # helpers read module-level globals, they will not see these locals, and the
    # models trained/selected may differ from the ones soft-updated here. Confirm.
    critic_optimizer = tf.keras.optimizers.Adam(critic_lr)
    actor_optimizer = tf.keras.optimizers.Adam(actor_lr)

    buffer = Buffer(buffer_capacity, batch_size)

    ep_reward_list = []
    avg_reward_list = []

    for ep in range(total_episodes):
        # Keep only the observation so `prev_state` and `state` always have the
        # same unbatched shape (the original relied on `[0]` indexing that meant
        # different things on the first and later steps).
        prev_state, _ = env.reset()
        episodic_reward = 0

        print("Episode: " + str(ep))
        while True:
            with tf.device(device):
                # The policy expects a batch of one state.
                tf_prev_state = tf.expand_dims(tf.convert_to_tensor(prev_state, dtype=tf.float32), 0)

                action = policy(tf_prev_state, ou_noise)

                state, reward, terminated, truncated, _ = env.step(action)

                # s and s' are both stored as raw, unbatched observations.
                buffer.record((prev_state, action, reward, state))
                episodic_reward += reward

                buffer.learn()
                update_target(target_actor.variables, actor_model.variables, tau)
                update_target(target_critic.variables, critic_model.variables, tau)

            if terminated or truncated:
                break

            prev_state = state

        ep_reward_list.append(episodic_reward)

        avg_reward = np.mean(ep_reward_list[-avg_window:])
        print("Episode * {} * Avg Reward is ==> {}".format(ep, avg_reward))
        avg_reward_list.append(avg_reward)

    return ep_reward_list, avg_reward_list


In [None]:
# Train on the Pendulum swing-up task.
# `train_actor_critic` uses the 5-tuple `step` API. Gym releases with that API
# only register `Pendulum-v1`; `Pendulum-v0` is rejected as deprecated.
env = gym.make('Pendulum-v1')
try:
    ep_reward_list, avg_reward_list = train_actor_critic(
        env,
        std_dev=0.2,
        actor_lr=0.001,
        critic_lr=0.002,
        gamma=0.99,
        tau=0.005,
        total_episodes=100,
    )
finally:
    # Release the environment even if training raises.
    env.close()
