In [None]:
import numpy as np
import torch as tc
from torch import nn

import gymnasium as gym

from agents import *
from envs import *
from train import obs_to_state

import mlflow
from tqdm import trange
import matplotlib.pyplot as plt
import toml

# Device string handed to `Model.load(...)` and `obs_to_state(...)` in the rollout cell.
DEVICE = "cpu"

In [None]:
# --- viewer options (tune these by hand) ---------------------------------
EPOCH = 300          # checkpoint index handed to `Model.load`
SHOW_WINDOW = True   # when true, the rollout cell requests render_mode="human"
N_EPISODES = 5       # rollout length = max_episode_steps * N_EPISODES steps

# --- values read from the settings file ----------------------------------
path_settings = "settings_pendulum.toml"
settings = toml.load(path_settings)
params = settings['parameters']

path_weight, env_params = params['save_path'], params['env']
discretise: int | bool = params.get('discretise', False)

# NOTE(review): `eval` executes whatever expression the settings file holds in
# `model_class`; only point `path_settings` at files you trust.
Model: type[Learner] = eval(params['model_class'])

# override the episode length for visualisation (writes into the loaded
# settings dict, since `env_params` is the same object as `params['env']`)
env_params.update(max_episode_steps=100)

In [None]:
# Roll out the checkpointed agent in a single-environment vector env and record
# every visited state and chosen action in `states` / `actions`.

# choose render mode
if SHOW_WINDOW:
    env_params['render_mode'] = "human"

# `bool` is a subclass of `int`, so the default `discretise = False` passes a
# plain `isinstance(discretise, int)` check and the env would be wrapped with
# `n_actions=False`. Only a genuine integer requests discretisation.
use_discretisation = isinstance(discretise, int) and not isinstance(discretise, bool)


def env_generator():
    """Build one environment from `env_params`, discretising its actions if requested."""
    env = gym.make(**env_params)
    if use_discretisation:
        env = DiscretiseAction(env, n_actions=discretise)
    return env


# load agent from checkpoint (done before the env exists so a failed load
# cannot leave an environment / render window open)
with tc.serialization.safe_globals([nn.BatchNorm1d]):
    agent_vis = Model.load(path_weight, EPOCH, DEVICE)
agent_vis.eval()

# keep track of states & actions
states = []
actions = []

# create environment (using a SyncVectorEnv to be consistent with training; not necessary)
env_vis = gym.vector.SyncVectorEnv([env_generator])

try:
    # create action lookup
    # NOTE(review): `nvec` is only available on a MultiDiscrete action space,
    # i.e. this cell presumably requires discrete (or discretised) actions.
    n_actions = int(env_vis.action_space.nvec[0]) # type: ignore
    actions_onehot = tc.eye(n_actions, dtype=tc.int, device=agent_vis.device)

    # reset environment
    observation, info = env_vis.reset()

    tqdm_iter = trange(env_params['max_episode_steps'] * N_EPISODES)
    for t in tqdm_iter:
        with tc.no_grad():
            state = obs_to_state(observation, DEVICE)
            action = agent_vis.act(state, actions_onehot)

        # update progress bar title
        tqdm_iter.set_description(f"Chose action {action}")

        # perform step
        observation, reward, terminated, truncated, info = env_vis.step(action)

        # store state & action
        states.append(state)
        actions.append(float(action.item()))

        # the vector env reports per-environment flag arrays; with the single
        # env used here `.any()` is simply that env's flag
        if terminated.any() or truncated.any():
            observation, info = env_vis.reset()
finally:
    # always release the environment (and its window), even when the rollout
    # raises or is interrupted
    env_vis.close()

In [None]:
# Flatten all of the agent's parameters into one 1-D tensor `v`.
v = tc.nn.utils.parameters_to_vector(agent_vis.parameters())

In [None]:
# Write the flat vector `v` back into the agent's parameters.
tc.nn.utils.vector_to_parameters(v, agent_vis.parameters())

In [None]:
# Manual cleanup: closes the visualisation environment; presumably kept as its
# own cell so it can be run after the rollout cell was interrupted.
env_vis.close()

In [None]:
# Histogram of the actions recorded during the rollout.
# (The former bare `plt.xticks()` call only queried the ticks and changed
# nothing, so it was dropped; labels were added so the figure stands alone.)
fig, ax = plt.subplots()
ax.hist(actions, bins=50)
ax.set_xlabel("Action")
ax.set_ylabel("Count")
ax.set_title("Actions chosen during rollout")
plt.show()

In [None]:
# Trajectory of the states recorded during the rollout, one line per component.
# Each recorded state may still carry the vector env's leading batch axis;
# stacking such states gives a 3-D array, which matplotlib refuses to plot.
# Flattening every state first always yields a (n_steps, n_features) array.
# NOTE(review): assumes each state converts via `np.asarray` (CPU tensor or
# array) — confirm against `obs_to_state`.
states_2d = np.asarray([np.asarray(s).reshape(-1) for s in states])

fig, ax = plt.subplots()
ax.plot(states_2d)
ax.set_xlabel("Step")
ax.set_ylabel("State component")
ax.set_title("States visited during rollout")
plt.show()

# Metric retrieval

In [None]:
# Where the training runs were logged.
MLFLOW_TRACKING_URI = "http://10.30.20.11:5000"
MLFLOW_EXPERIMENT = "jheis_SRL_MountainCar-v0"

# point mlflow at that server / experiment
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(MLFLOW_EXPERIMENT)

# client used below to pull metric histories
client = mlflow.tracking.MlflowClient() # type: ignore

In [None]:
# run whose metrics are inspected
run_id = "f70604c2f06147de9d7577b04016208a"


def get_history(name):
    """Return all values logged for metric `name` of run `run_id` as a NumPy array."""
    logged = client.get_metric_history(run_id, name)
    return np.array([entry.value for entry in logged])


# one array per tracked metric
history_lq1 = get_history("loss_q1")
history_lq2 = get_history("loss_q2")
history_el = get_history("episode_length")
history_q = get_history("q")
history_reward = get_history("mean_reward")

In [None]:
# Q value and reward histories on one axis, with a zero baseline for reference.
fig, ax = plt.subplots(figsize=(7, 5))

# the x-range follows the length of the Q-value history
xmin = 0
xmax = len(history_q) - 1

ax.hlines([0], xmin, xmax, colors="black")
ax.plot(history_q, label="Q value")
ax.plot(history_reward, label="Reward")
# (the other fetched histories — losses, episode length — are not drawn here)

ax.set_xlim(xmin, xmax)
ax.grid()
ax.legend()
ax.set_title("SAC training metrics")
ax.set_xlabel("Epochs")
fig.tight_layout()
plt.show()