In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# are re-imported before each cell execution; edits to the project code are
# then picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import hockey.hockey_env as h_env

from rl_hockey.sac import SAC
from rl_hockey.common import utils

In [None]:
# Training environment: the hockey env in its TRAIN_SHOOTING mode.
env = h_env.HockeyEnv(mode=h_env.Mode.TRAIN_SHOOTING)

# Keep handles to the spaces; their shapes size the agent's networks below.
o_space, ac_space = env.observation_space, env.action_space

In [None]:
# --- Training budget ---
max_episodes = 500        # number of episodes run by the training loop
max_episode_steps = 500   # hard cap on environment steps per episode

# --- Optimisation schedule ---
warmup_steps = 10_000     # transitions to gather before updates start
updates_per_step = 1      # value passed to agent.train() per environment step

In [None]:
# SAC agent sized from the environment's observation/action dimensions.
# NOTE(review): noise='pink' presumably selects pink (temporally correlated)
# exploration noise -- confirm against rl_hockey.sac.
agent = SAC(
    o_space.shape[0],
    action_dim=ac_space.shape[0],
    noise='pink',
    max_episode_steps=max_episode_steps,
)

In [None]:
# Histories filled in by the training loop: losses reported by agent.train()
# and the total reward of every episode.
critic_losses, actor_losses, rewards = [], [], []

# Running counters: environment steps taken and gradient updates requested.
steps = gradient_steps = 0

In [None]:
# Label for this run: shown as the progress-bar description and used as the
# prefix of the saved model's file name.
run_name = 'hockey-shooting'

In [None]:
# Main training loop.
#
# Each episode: reset the environment, act with the SAC agent until the
# episode terminates, is truncated, or hits `max_episode_steps`, and store
# every transition twice (original + mirrored copy). Once enough transitions
# have been collected, the agent is trained after every environment step.
# After the last episode the model is written to disk.
pbar = tqdm(range(max_episodes), desc=run_name)
for i in pbar:
    total_reward = 0
    episode_length = 0  # number of environment steps actually taken this episode
    state, _ = env.reset()

    agent.on_episode_start(i)

    for t in range(max_episode_steps):
        action = agent.act(state.astype(np.float32))
        (next_state, reward, done, trunc, _) = env.step(action)

        # Store the observed transition and a mirrored copy of it (built with
        # utils.mirror_state / utils.mirror_action). Only `done` is stored as
        # the terminal flag; `trunc` merely ends the rollout below.
        agent.store_transition((state, action, reward, next_state, done))
        agent.store_transition((utils.mirror_state(state), utils.mirror_action(action), reward, utils.mirror_state(next_state), done))
        state = next_state

        steps += 1
        episode_length += 1
        total_reward += reward

        if steps >= warmup_steps / 2:  # mirroring enables 2 transitions per step
            stats = agent.train(updates_per_step)

            gradient_steps += updates_per_step
            critic_losses.extend(stats['critic_loss'])
            actor_losses.extend(stats['actor_loss'])

        if done or trunc:
            break

    agent.on_episode_end(i)

    rewards.append(total_reward)

    pbar.set_postfix({
        'total_reward': total_reward,
        # Report the step count, not the last 0-based loop index.
        'episode_length': episode_length,
    })

# Make sure the target directory exists before saving, so a missing folder
# cannot fail the save after the whole training run has finished.
model_path = f'../../../models/sac/{run_name}_{gradient_steps//1000}k.pt'
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
agent.save(model_path)

In [None]:
def moving_average(data, window_size):
    """Return the trailing moving average of ``data``.

    Element ``k`` of the result is the mean of the last ``window_size``
    values up to and including ``data[k]``; near the start, where fewer
    values are available, the mean is taken over what exists so far.
    The result has the same length as ``data``.
    """
    def _trailing_mean(end):
        # `end` is exclusive; clamp the window start at the beginning of data.
        window = data[max(0, end - window_size):end]
        return sum(window) / len(window)

    return [_trailing_mean(end) for end in range(1, len(data) + 1)]

In [None]:
# Episode returns, smoothed with a 10-episode trailing average.
fig, ax = plt.subplots()
ax.plot(moving_average(rewards, 10))
ax.set_xlabel('Episodes')
ax.set_ylabel('Total Reward')
ax.set_title('Total Reward per Episode')
plt.show()

In [None]:
# Critic loss history, smoothed with a 100-entry trailing average.
fig, ax = plt.subplots()
ax.plot(moving_average(critic_losses, 100))
ax.set_xlabel('Training Steps')
ax.set_ylabel('Critic Loss')
ax.set_title('Critic Loss over Time')
plt.show()

In [None]:
# Actor loss history, smoothed with a 100-entry trailing average.
fig, ax = plt.subplots()
ax.plot(moving_average(actor_losses, 100))
ax.set_xlabel('Training Steps')
ax.set_ylabel('Actor Loss')
ax.set_title('Actor Loss over Time')
plt.show()

In [None]:
# Rebind `env` to a new TRAIN_SHOOTING environment instance for the rendered
# evaluation rollout in the next cell.
env = h_env.HockeyEnv(mode=h_env.Mode.TRAIN_SHOOTING)

In [None]:
# Play a single rendered episode with the agent acting deterministically
# and print the reward accumulated over it.
state, _ = env.reset()
total_reward = 0
for t in range(max_episode_steps):
    env.render(mode="human")

    action = agent.act(state.astype(np.float32), deterministic=True)
    next_state, reward, done, trunc, _ = env.step(action)

    total_reward += reward
    state = next_state

    # Stop as soon as the environment signals termination or truncation.
    if done or trunc:
        break

print(f'total_reward: {total_reward}')

In [None]:
# Close the evaluation environment now that the rendered rollout is done.
env.close()