In [None]:
from utilities.plot_utils import *
def learn(agent, episodes=1000):
    """Train the agent, plot its training curves and save its weights.

    Args:
        agent: Object exposing ``train``, ``episode_rewards``, ``actor_losses``,
            ``critic_losses`` and ``save_model_weights``.
        episodes: Value forwarded to ``agent.train`` (presumably the number of
            training episodes). Defaults to 1000, the previously hard-coded value.
    """
    agent.train(episodes)

    # Reward curve first, then the actor/critic loss curves, each in its own figure.
    plot_episode_rewards(agent.episode_rewards, 'BipedalWalker: Episode Rewards - Training')
    plt.show()
    plot_actor_critic_losses(agent.actor_losses, agent.critic_losses)
    plt.show()

    # Persist the trained networks under the environment / algorithm names.
    agent.save_model_weights("bipedalwalker", "td3")

def test(agent):
    agent.test(10)
    plot_episode_rewards(agent.test_rewards, 'BipedalWalker: Episode Rewards - Testing')
    plt.show()
    # agent.save_plot_data("highway", "ddpg", "test_rewards", agent.test_rewards)

def visualize(agent, env, episodes=1):
    """Switch the agent to ``env`` and run its visualisation routine.

    Args:
        agent: Object exposing ``change_environment`` and ``visualize``.
        env: Environment handed to ``agent.change_environment``.
        episodes: Value forwarded to ``agent.visualize`` (presumably the number
            of episodes to show). Defaults to 1, the previously hard-coded value.
    """
    agent.change_environment(env)
    agent.visualize(episodes)

In [None]:
import gymnasium as gym
from td3.td3 import td3

# Hyper-parameters handed to the TD3 agent constructor.
params = {
    # Discount factor.
    "gamma": 0.99,
    # Learning rate for the critic network.
    "alpha_critic": 0.0001,
    # Learning rate for the actor network.
    "alpha_actor": 0.001,
    # Size of the replay buffer.
    "buffer_size": 10000,
    # Mini-batch size for back-propagation.
    "batch_size": 100,
    # Share of new values used when updating the target network.
    "tau": 0.001,
    # Update the network every n steps.
    "update_rate": 100,
    # Standard deviation of the Gaussian noise.
    "noise_scale": 0.1,
    # How often to update the actor network, every n updates.
    "actor_update_frequency": 2,
}

# Build the training environment (no render mode requested) and the TD3 agent on top of it.
env = gym.make(id="BipedalWalker-v3")
agent = td3(env, params)

In [None]:
# Fill the replay buffer with transitions collected by uniformly sampled actions.
# One transition is stored per step, up to the configured buffer size.
state, info = env.reset()
for _ in range(params['buffer_size']):
    action = env.action_space.sample()
    next_state, reward, terminated, truncated, info = env.step(action)
    # Only `terminated` is stored as the done flag, as before.
    agent.buffer.add(state, action, reward, next_state, terminated)
    if terminated or truncated:
        # The episode is over: stepping a finished environment is invalid,
        # so start a fresh episode before collecting the next transition.
        state, info = env.reset()
    else:
        state = next_state

In [None]:
# Run the learn helper: trains the agent, shows the training plots and saves the weights.
learn(agent)

In [None]:
# Draw the values stored in agent.noise_arr against their index on an explicit Axes.
noise_fig, noise_ax = plt.subplots()
noise_ax.plot(agent.noise_arr)
noise_ax.set_xlabel('Time Step')
noise_ax.set_ylabel('Noise Value')
noise_ax.set_title('Noise for TD3 Algorithm')
plt.show()

In [None]:
# Run the test helper: evaluates the agent and plots the rewards from agent.test_rewards.
test(agent)

In [None]:
from utilities.plot_utils import *
# Bundle the recorded training / evaluation curves and store them under the
# same environment name used for the model weights ("bipedalwalker"), so the
# data is not filed under another experiment's name.
# To reload later: plot_data = load_plot_data("bipedalwalker", "td3", "plot_data")
plot_dict = {
    "episode_rewards": agent.episode_rewards,
    "actor_losses": agent.actor_losses,
    "critic_losses": agent.critic_losses,
    "test_rewards": agent.test_rewards,
}
save_plot_data("bipedalwalker", "td3", "plot_data", plot_dict)

In [None]:
# Re-create the environment with on-screen rendering and pass it to the visualize helper.
# NOTE: this rebinds the notebook-level `env` to the rendering environment.
env = gym.make(id="BipedalWalker-v3", render_mode="human")
visualize(agent, env)