In [7]:
import collections
import os

import gymnasium
import torch

import bright_rl as brl

In [8]:
from torch.utils import tensorboard

# TensorBoard output for this notebook: ../../tensorboard/<experiment_name>.
# NOTE(review): the path is relative, so it resolves against the kernel's
# working directory. The experiment name is fixed, so re-running this cell
# writes another event file into the same directory (TensorBoard will then
# overlay the runs). Consider bumping the suffix per run.
experiment_name = "cartpole/dqn-2"
log_dir = os.path.join("../../tensorboard", experiment_name)

# SummaryWriter's first positional parameter is the log directory.
writer = tensorboard.SummaryWriter(log_dir)

In [9]:
# CartPole-v1, with observations passed through Gymnasium's
# NormalizeObservation wrapper before the agent sees them.
# NOTE(review): the wrapper presumably keeps updating its normalisation
# statistics during training. If the DQN agent stores transitions in a replay
# buffer, older transitions were normalised with different statistics.
# Confirm this is intended.
env = gymnasium.wrappers.NormalizeObservation(gymnasium.make("CartPole-v1"))

In [10]:
from torch import nn

# Network handed to the DQN agent below: an MLP with one output unit per
# discrete action. Both the input and output widths are read from `env`,
# rather than hard-coding CartPole's 2 actions, so the cell stays correct for
# any env with a flat Box observation space and a Discrete action space.
HIDDEN_SIZE = 64
n_observations = env.observation_space.shape[0]
n_actions = env.action_space.n

policy = nn.Sequential(
    nn.Linear(n_observations, HIDDEN_SIZE),
    nn.ELU(inplace=True),
    nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
    nn.ELU(inplace=True),
    nn.Linear(HIDDEN_SIZE, n_actions),
)

In [11]:
# Removed the unused module-level `q = queue.Queue()`. Nothing referenced it,
# and it only left stray kernel state behind.
import queue

import numpy as np

def training(env, agent, nb_epochs, window=None):
    """Run `nb_epochs` episodes and log the mean episode reward to TensorBoard.

    Each episode is played by `run_env`, which also logs its per-episode
    metrics. After every episode this function logs ``train/mean_reward`` to
    the module-level `writer`.

    Args:
        env: Environment passed through to `run_env`.
        agent: Agent passed through to `run_env`.
        nb_epochs: Number of episodes to run.
        window: If None (default), ``train/mean_reward`` is the mean over all
            episodes so far (the original behaviour). If a positive int, it is
            a moving average over the last `window` episodes.

    Raises:
        ValueError: If `window` is not None and is smaller than 1.
    """
    if window is not None and window < 1:
        raise ValueError(f"window must be a positive int or None, got {window!r}")

    # Bounded deque for the moving-average case. For the cumulative mean, a
    # running total avoids re-averaging the whole history every episode; the
    # previous queue.Queue + np.mean approach was O(nb_epochs**2).
    recent = collections.deque(maxlen=window) if window is not None else None
    total, count = 0.0, 0

    for epoch in range(nb_epochs):
        tt_reward = run_env(env, agent, epoch)

        if recent is not None:
            recent.append(tt_reward)
            mean_reward = sum(recent) / len(recent)
        else:
            total += tt_reward
            count += 1
            mean_reward = total / count

        writer.add_scalar("train/mean_reward", mean_reward, epoch)

def run_env(env, agent, epoch):
    """Play one episode, letting `agent` act and learn, and log its metrics.

    Args:
        env: Gymnasium-style environment: ``reset()`` returns ``(obs, info)``
            and ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``.
        agent: Object exposing ``act(obs)``,
            ``observe(obs, action, reward, new_obs, terminated)`` and the
            ``epsilon`` / ``loss`` attributes that get logged.
        epoch: Episode index, used as the TensorBoard global step.

    Returns:
        The undiscounted sum of rewards collected during the episode.
    """
    sum_reward = 0

    obs, info = env.reset()
    done = False

    while not done:
        action = agent.act(obs)
        new_obs, reward, terminated, truncated, info = env.step(action)

        # Forward only `terminated`. A truncation (e.g. CartPole's time limit)
        # ends the episode, but the next state is not terminal. The DQN
        # presumably uses this flag to mask its bootstrap target, so passing
        # `terminated or truncated` here would bias the Q-values.
        # Confirm against bright_rl's DQN.observe.
        agent.observe(obs, action, reward, new_obs, terminated)

        done = terminated or truncated
        obs = new_obs
        sum_reward += reward

    writer.add_scalar("train/reward", sum_reward, epoch)
    writer.add_scalar("train/epsilon", agent.epsilon, epoch)
    # NOTE(review): assumes `agent.loss` can be None before the first gradient
    # update (e.g. while a replay buffer warms up). add_scalar cannot log None,
    # so the loss is skipped in that case.
    if agent.loss is not None:
        writer.add_scalar("train/loss", agent.loss, epoch)

    return sum_reward

In [None]:
# Train a DQN agent on the environment above. Per-episode metrics go to
# TensorBoard via `training` / `run_env`.
LEARNING_RATE = 3e-5
NB_EPOCHS = 20_000

agent = brl.dqn.DQN(policy=policy, learning_rate=LEARNING_RATE)

try:
    training(env, agent, nb_epochs=NB_EPOCHS)
finally:
    # SummaryWriter buffers events and flushes them only periodically. Flush
    # explicitly, even when a long run is interrupted (e.g. KeyboardInterrupt),
    # so the final episodes appear in TensorBoard.
    writer.flush()