In [None]:
import os
from pathlib import Path

import numpy as np
import torch as tc
from torch import nn

import gymnasium as gym
import matplotlib.pyplot as plt
import mlflow
import mlflow.pytorch as mlflow_tc
from mlflow.models import infer_signature
from tqdm import tqdm, trange

from replay_buffer import ReplayBuffer
from scaled_reward_learner import ScaledRewardLearner, _obs_to_state

# Torch device used both to load the agent checkpoint and to build the
# observation tensors fed to it during the rollout.
DEVICE = "cpu"

In [None]:
# Notebook configuration. In this notebook only `params['env']` is read (it is
# forwarded to `gym.make`); the architecture entry is not used by any cell here.
params = dict(
    # Gymnasium environment settings, passed as keyword arguments to `gym.make`.
    env=dict(
        id="MountainCar-v0",
        max_episode_steps=1000,
    ),
    # Model architecture (excluding input & output size, which are determined
    # by the environment).
    architecture_policy_hidden=[256, nn.BatchNorm1d(256), 256],
)

EPOCH = 0           # epoch of the checkpoint passed to ScaledRewardLearner.load
SHOW_WINDOW = True  # render the rollout in a window ("human" render mode)

In [None]:
# Roll out the trained agent in the environment, optionally rendering it in a window.
# Adapted from the Gymnasium quick-start example: https://gymnasium.farama.org/
N_VIS_STEPS = 1000  # total environment steps; episodes are reset as they end
VIS_SEED = 42       # seed for the first reset so the rollout starts reproducibly

if SHOW_WINDOW:
    env_vis = gym.make(**params['env'], render_mode="human")
else:
    env_vis = gym.make(**params['env'])

# Action chosen at every step of the rollout (plotted by the histogram cell).
actions = []

# try/finally guarantees the render window is closed even if loading the agent
# fails or the rollout is interrupted (e.g. via the kernel's stop button).
try:
    # Environment dimensions; not needed by the rollout itself.
    state_size = env_vis.observation_space.shape[-1]  # type: ignore
    action_size = int(env_vis.action_space.n)  # type: ignore

    # Load the trained agent from its checkpoint and switch it to inference mode.
    agent_vis = ScaledRewardLearner.load("weights/srl_mountaincar", EPOCH, DEVICE)
    agent_vis.eval()

    observation, info = env_vis.reset(seed=VIS_SEED)
    for _ in trange(N_VIS_STEPS):
        with tc.no_grad():
            # Add a leading batch dimension before querying the policy.
            state = _obs_to_state(observation, DEVICE).unsqueeze(0)
            # The action space is Discrete; its contains() check accepts Python
            # ints and 0-d integer numpy arrays but not torch tensors, so convert.
            # NOTE(review): assumes act() yields a single action index — confirm.
            action = int(agent_vis.act(state).squeeze())

        observation, reward, terminated, truncated, info = env_vis.step(action)
        actions.append(action)

        # Start a new episode once the current one has ended.
        if terminated or truncated:
            observation, info = env_vis.reset()
finally:
    env_vis.close()

In [None]:
# Fallback: close the render window manually, e.g. if the rollout cell above
# was interrupted before it closed the environment itself. Calling close()
# again on an already-closed environment is intended to be harmless here.
env_vis.close()

In [None]:
# Distribution of the actions chosen by the agent during the rollout.
fig, ax = plt.subplots()
ax.hist(actions, bins=50)
ax.set(title="Actions chosen by the agent", xlabel="Action", ylabel="Count (steps)")
plt.show()

# Metric retrieval from MLflow

In [None]:
# Connect to the MLflow tracking server that holds the training runs. The
# MLFLOW_TRACKING_URI environment variable overrides the local default server.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:8080"))
# The experiment is named after the configured environment id
# ("MountainCar-v0" with the config above), so both stay in sync.
mlflow.set_experiment(params['env']['id'])

client = mlflow.tracking.MlflowClient() # type: ignore

In [None]:
# ID of the MLflow run whose training metrics should be inspected
# (copy it from the MLflow UI before running this cell).
run_id = ""


def get_history(name: str) -> np.ndarray:
    """Return every logged value of metric `name` for run `run_id`, ordered by step.

    Args:
        name: metric key as logged during training (e.g. "loss_q1").

    Returns:
        1-D array of the metric values, sorted by (step, timestamp).

    Raises:
        ValueError: if `run_id` has not been set.
    """
    if not run_id:
        raise ValueError("Set `run_id` to the MLflow run you want to inspect.")
    metrics = client.get_metric_history(run_id, name)
    # MLflow does not document an ordering for the returned history, so sort it
    # explicitly; ties on `step` are broken by the logging timestamp.
    metrics = sorted(metrics, key=lambda metric: (metric.step, metric.timestamp))
    return np.array([metric.value for metric in metrics])


history_lq1 = get_history("loss_q1")
history_lq2 = get_history("loss_q2")
history_el = get_history("episode_length")
history_q = get_history("q")
history_reward = get_history("mean_reward")

In [None]:
# Plot the main training curves against their logged index.
training_curves = {
    "Q loss": history_lq1,
    "Q value": history_q,
    "Reward": history_reward,
}

# Span the longest curve so series of different lengths are not clipped, and
# keep the x-range non-degenerate even if every history is empty or has one point.
xmin = 0
xmax = max(1, max(len(values) for values in training_curves.values()) - 1)

fig, ax = plt.subplots(figsize=(7, 5))
ax.axhline(0, color="black")  # zero reference line across the full x-range
for label, values in training_curves.items():
    ax.plot(values, label=label)

ax.set_xlim(xmin, xmax)
ax.grid()
ax.legend()
ax.set(title="SAC training metrics", xlabel="Epochs")
fig.tight_layout()
plt.show()