In [None]:
import os

import gymnasium
import torch
from torch.utils.tensorboard import SummaryWriter

import brl

In [None]:
# Name of this run; it is also the sub-directory that receives the TensorBoard
# event files.
experiment_name = "reinforce_4"
log_dir = os.path.join("../tensorboard", experiment_name)
# Use the explicitly imported SummaryWriter: `import torch` on its own does not
# load the `torch.utils.tensorboard` submodule, so attribute access through
# `torch.utils.tensorboard` is not guaranteed to resolve.
writer = SummaryWriter(log_dir=log_dir)

class MyPolicy(torch.nn.Module):
    """Small MLP policy that maps an observation vector to action logits.

    The network is ``Linear(n_observations, 64) -> ReLU -> Linear(64, n_actions)``.

    Args:
        n_observations: Length of the observation vector fed to the first layer.
        n_actions: Number of discrete actions, i.e. the size of the logit output.
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()

        self.layer1 = torch.nn.Linear(n_observations, 64)
        self.layer2 = torch.nn.Linear(64, n_actions)

    def forward(self, x):
        """Return the unnormalised action logits for the observation tensor ``x``."""
        hidden = torch.nn.functional.relu(self.layer1(x))
        return self.layer2(hidden)

    def act(self, state):
        """Sample one action index from the softmax distribution over the logits.

        Args:
            state: A single observation, given as a NumPy array, a sequence of
                numbers or a tensor.

        Returns:
            The sampled action index as a Python ``int``.
        """
        # Convert to the dtype/device of the network parameters so float64
        # arrays and plain Python sequences work as well as float32 arrays.
        reference = self.layer1.weight
        state = torch.as_tensor(state, dtype=reference.dtype, device=reference.device)

        # Only the sampled index leaves this method, so building an autograd
        # graph for the forward pass would be wasted work.
        with torch.no_grad():
            # Normalise over the action axis (the last one); for a 1-D state
            # this is the same axis as dim=0.
            probabilities = torch.nn.functional.softmax(self.forward(state), dim=-1)
            action = torch.multinomial(probabilities, num_samples=1).item()
        return action

def _init_weights(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.uniform_(m.weight)
        m.bias.data.fill_(0.01)

env = gymnasium.make("CartPole-v1")

# Size both ends of the network from the environment itself rather than
# hard-coding the action count, so the two stay consistent if the environment
# id above is changed to another one with a discrete action space.
policy = MyPolicy(
    n_observations=env.observation_space.shape[0],
    n_actions=int(env.action_space.n),
)
# Apply the custom initialisation to every sub-module of the policy.
policy.apply(_init_weights)

agent = brl.reinforce.Reinforce(policy=policy)

def training(env, agent, nb_epochs):
    """Run ``nb_epochs`` episodes back to back.

    Each episode is delegated to ``run_env``, which receives the zero-based
    episode index as its ``epoch`` argument.
    """
    for episode_index in range(nb_epochs):
        run_env(env, agent, episode_index)

def run_env(env, agent, epoch):
    """Play one episode, let the agent optimise, and log the results.

    The agent chooses an action for every observation and is shown each
    transition through ``agent.observe``. Once the episode is over
    (terminated or truncated), ``agent.optimize`` is called and the summed
    reward and ``agent.loss`` are written to the module-level ``writer`` at
    step ``epoch``.
    """
    obs, info = env.reset()
    episode_return = 0

    while True:
        action = agent.act(obs)
        next_obs, reward, terminated, truncated, info = env.step(action)
        # The episode ends on either signal; the transition is still shown to
        # the agent before the loop exits.
        done = terminated or truncated

        agent.observe(obs, action, reward, next_obs)
        episode_return += reward
        obs = next_obs

        if done:
            break

    agent.optimize()

    writer.add_scalar("train/reward", episode_return, epoch)
    writer.add_scalar("train/loss", agent.loss, epoch)

In [4]:
training(env, agent, nb_epochs=1000)
# Write any scalars still queued in the SummaryWriter to disk so the last
# episodes are visible in TensorBoard without waiting for its periodic flush.
writer.flush()

In [None]:
# Rebind `env` to a newly created CartPole-v1 environment instance.
env = gymnasium.make("CartPole-v1")