In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import gymnasium as gym
import matplotlib.pyplot as plt
from tqdm import tqdm

from rl_hockey.sac import SAC

In [None]:
# Gymnasium environment ID. It is used to build the training and evaluation
# environments, to label the progress bar, and as part of the checkpoint
# filename. Swap in the commented line below to use the alternative task.
env_name = 'Pendulum-v1'
# env_name = 'LunarLanderContinuous-v3'

In [None]:
# Build the training environment and wrap it so that the exposed action space
# is rescaled to the range [-1, 1].
env = gym.make(env_name)
env = gym.wrappers.RescaleAction(env, min_action=-1.0, max_action=1.0)

# Keep handles to both spaces; their shapes are used to size the agent.
o_space = env.observation_space
ac_space = env.action_space

In [None]:
# Training schedule.
# max_episodes: number of episodes the training loop iterates over.
max_episodes = 100
# max_episode_steps: upper bound on environment steps per episode in the
# training and evaluation loops; also forwarded to the SAC constructor.
# NOTE(review): the environment may truncate episodes earlier on its own.
max_episode_steps = 500
# updates_per_step: value passed to agent.train() after every environment step.
updates_per_step = 1

In [None]:
# Create the SAC agent, sized from the first dimension of the observation and
# action space shapes.
# NOTE(review): noise='pink' presumably selects temporally correlated (pink)
# exploration noise, with max_episode_steps as its sequence length - confirm
# against rl_hockey.sac.
agent = SAC(o_space.shape[0], action_dim=ac_space.shape[0], noise='pink', max_episode_steps=max_episode_steps)

In [None]:
# Training logs. Re-running this cell resets them.
# critic_losses / actor_losses: extended with the 'critic_loss' and
# 'actor_loss' entries of the stats returned by agent.train().
critic_losses = []
actor_losses = []
# rewards: one summed reward (episode return) appended per training episode.
rewards = []
# gradient_steps: running total of updates_per_step over all environment
# steps; used in the checkpoint filename.
gradient_steps = 0

In [None]:
# Training loop: for every episode, interact with the environment for at most
# `max_episode_steps` steps, storing each transition and calling the agent's
# train() after every environment step. Losses and episode returns are
# appended to the log lists so they can be plotted afterwards.
pbar = tqdm(range(max_episodes), desc=env_name)
for episode in pbar:
    total_reward = 0
    state, _ = env.reset()

    agent.on_episode_start(episode)

    for t in range(max_episode_steps):
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # Only the `terminated` flag is stored with the transition; truncation
        # merely ends the episode loop below.
        # NOTE(review): presumably this lets the agent keep bootstrapping from
        # time-limit truncations - confirm in rl_hockey.sac.
        agent.store_transition((state, action, reward, next_state, terminated))
        state = next_state

        stats = agent.train(updates_per_step)

        # Counts the number of updates requested from the agent.
        gradient_steps += updates_per_step
        total_reward += reward
        critic_losses.extend(stats['critic_loss'])
        actor_losses.extend(stats['actor_loss'])

        if terminated or truncated:
            break

    agent.on_episode_end(episode)

    rewards.append(total_reward)
    pbar.set_postfix({'total_reward': total_reward})

# Create the checkpoint directory up front so that a missing folder cannot
# cause the save to fail after the whole training run has completed.
checkpoint_path = f'../../../models/sac/{env_name}_{gradient_steps//1000}k.pt'
Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
agent.save(checkpoint_path)

In [None]:
def moving_average(data, window_size):
    """Return the trailing moving average of ``data``.

    Element ``i`` of the result is the mean of
    ``data[max(0, i - window_size + 1):i + 1]``. The first
    ``window_size - 1`` entries therefore average over a shorter, partially
    filled window instead of being dropped, and the output has the same
    length as the input.

    Args:
        data: Sequence of numbers supporting ``len`` and slicing.
        window_size: Maximum number of trailing samples per average (>= 1).

    Returns:
        A list with one average per element of ``data`` (empty for empty
        input).

    Raises:
        ValueError: If ``window_size`` is smaller than 1. Without this check
            a zero window fails with ``ZeroDivisionError`` and a negative
            window silently yields meaningless values.
    """
    if window_size < 1:
        raise ValueError(f'window_size must be >= 1, got {window_size}')

    averages = []
    for end in range(1, len(data) + 1):
        # Clamp the window start at 0 so early entries use a partial window.
        window = data[max(0, end - window_size):end]
        averages.append(sum(window) / len(window))
    return averages

In [None]:
# Episode returns, smoothed with a 10-episode trailing mean.
ax = plt.gca()
ax.plot(moving_average(rewards, 10))
ax.set(
    xlabel='Episodes',
    ylabel='Total Reward',
    title='Total Reward per Episode',
)
plt.show()

In [None]:
# Critic loss, smoothed with a trailing mean over 100 logged values.
ax = plt.gca()
ax.plot(moving_average(critic_losses, 100))
ax.set(
    xlabel='Training Steps',
    ylabel='Critic Loss',
    title='Critic Loss over Time',
)
plt.show()

In [None]:
# Actor loss, smoothed with a trailing mean over 100 logged values.
ax = plt.gca()
ax.plot(moving_average(actor_losses, 100))
ax.set(
    xlabel='Training Steps',
    ylabel='Actor Loss',
    title='Actor Loss over Time',
)
plt.show()

In [None]:
# Close whichever environment `env` currently refers to before rebinding the
# name; otherwise the previous instance is never closed (and re-running this
# cell would leave earlier render windows behind).
env.close()

# Recreate the environment with on-screen rendering for the evaluation run,
# using the same [-1, 1] action rescaling as during training.
env = gym.make(env_name, render_mode='human')
env = gym.wrappers.RescaleAction(env, min_action=-1.0, max_action=1.0)

In [None]:
# Roll out a single evaluation episode, requesting deterministic actions from
# the agent, and report the summed reward. The episode stops when the
# environment terminates or truncates, or after `max_episode_steps` steps.
total_reward = 0
state, _ = env.reset()
for t in range(max_episode_steps):
    action = agent.act(state, deterministic=True)
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward

    if terminated or truncated:
        break

print(f'total_reward: {total_reward}')

In [None]:
# Close the evaluation environment once the rollout is finished.
env.close()