In [61]:
import os
import torch
import gymnasium

import bright_rl as brl

In [62]:
from torch.utils import tensorboard

# Label for this run; it is also the name of the sub-directory the logs go to.
experiment_name = "reinforce"

# Logs are written under ../../tensorboard/<experiment_name>, relative to the
# notebook's working directory.
log_dir = os.path.join("../../tensorboard", experiment_name)
writer = tensorboard.SummaryWriter(log_dir)

In [63]:
# CartPole-v1 wrapped in Gymnasium's NormalizeObservation wrapper, so the agent
# is fed the wrapper's normalized observations rather than the raw ones.
env = gymnasium.wrappers.NormalizeObservation(gymnasium.make("CartPole-v1"))

In [64]:
from torch import nn

# Small MLP policy: observation vector -> 24 hidden units -> one output per action.
# Both the input and the output widths are read from the environment so the
# network stays consistent with `env` if the environment is swapped.
# NOTE(review): the outputs are presumably consumed as action logits by the
# bright_rl agent; confirm against the library.
policy = nn.Sequential(
    nn.Linear(env.observation_space.shape[0], 24),
    nn.ReLU(inplace=True),
    # int(...) because Gymnasium may expose `n` as a NumPy integer.
    nn.Linear(24, int(env.action_space.n)),
)

In [65]:
# REINFORCE agent driving `policy`. The optimizer is handed over as a class
# together with its keyword arguments; the agent is expected to instantiate it.
# NOTE(review): nb_episodes is presumably the number of episodes gathered per
# policy update; confirm against bright_rl.
agent = brl.reinforce.Reinforce(
    policy=policy,
    nb_episodes=1,
    optimizer=torch.optim.Adagrad,
    optimizer_parameters={"lr": 3e-2, "lr_decay": 1e-4},
)

def training(env, agent, nb_epochs):
    """Run `nb_epochs` consecutive episodes of `agent` in `env`.

    Each episode is delegated to `run_env`, which receives the zero-based
    episode index as its `epoch` argument.

    Args:
        env: environment passed through to `run_env`.
        agent: agent passed through to `run_env`.
        nb_epochs: number of episodes to run.
    """
    for episode_index in range(nb_epochs):
        run_env(env, agent, episode_index)

def run_env(env, agent, epoch):
    """Play one full episode and log its statistics to TensorBoard.

    The agent picks an action for every observation and is shown each
    transition through `agent.observe`. Once the episode is over, the summed
    reward and the current value of `agent.loss` are written to the
    module-level `writer` under "train/reward" and "train/loss".

    Args:
        env: environment exposing `reset()` and the five-value `step()` API.
        agent: object exposing `act`, `observe` and a `loss` attribute.
        epoch: step index used for both logged scalars.
    """
    obs, info = env.reset()
    episode_return = 0
    done = False

    while not done:
        action = agent.act(obs)
        next_obs, reward, terminated, truncated, info = env.step(action)

        # The episode ends on a terminal state or on truncation; the agent
        # receives the combined flag.
        done = terminated or truncated
        agent.observe(obs, action, reward, next_obs, done)

        episode_return += reward
        obs = next_obs

    writer.add_scalar("train/reward", episode_return, epoch)
    writer.add_scalar("train/loss", agent.loss, epoch)

In [66]:
# Train for 300 episodes. The writer buffers events, so flush it once the run
# ends (also when the cell is interrupted) to get the last scalars onto disk.
try:
    training(env, agent, nb_epochs=300)
finally:
    writer.flush()