# In [None]:
import os
import torch
import gymnasium
from queue import deque

from torch.utils.tensorboard import SummaryWriter

import brl

# TensorBoard setup: every scalar logged through `writer` lands in
# ../tensorboard/<experiment_name>.
experiment_name = "reinforce_1"
log_dir = os.path.join("../tensorboard", experiment_name)
writer = SummaryWriter(log_dir)

class MyPolicy(torch.nn.Module):
    """Small MLP policy that samples actions from a softmax over its outputs.

    The network is ``Linear(n_observations, 12) -> ReLU -> Linear(12, n_actions)``;
    ``forward`` returns the raw (unnormalised) outputs of the second layer.

    The ``epsilon``, ``epsilon_end`` and ``epsilon_decay`` attributes are
    stored for callers that read them, but ``act`` does not consult or update
    them, so ``epsilon`` keeps its initial value unless set from outside.
    """

    def __init__(self, n_observations, n_actions, epsilon_start=0.9, epsilon_end=0.05, epsilon_decay=1-1e-6):
        """Build the two linear layers and record the epsilon settings.

        Args:
            n_observations: Size of the observation vector fed to the network.
            n_actions: Number of discrete actions (size of the output layer).
            epsilon_start: Initial value stored in ``self.epsilon``.
            epsilon_end: Value stored in ``self.epsilon_end``.
            epsilon_decay: Value stored in ``self.epsilon_decay``.
        """
        super().__init__()
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

        self.n_actions = n_actions

        self.layer1 = torch.nn.Linear(n_observations, 12)
        self.layer2 = torch.nn.Linear(12, n_actions)

    def forward(self, x):
        """Return one unnormalised score per action for observation ``x``."""
        x = torch.nn.functional.relu(self.layer1(x))
        x = self.layer2(x)
        return x

    @torch.no_grad()
    def act(self, state):
        """Sample an action index from the softmax of the network output.

        Args:
            state: A single observation (NumPy array, sequence of numbers or
                tensor) of length ``n_observations``.

        Returns:
            The sampled action as a Python ``int``.
        """
        # Convert to the network's own dtype/device: observations that arrive
        # as float64 arrays or plain lists would otherwise fail inside the
        # linear layers, which hold float32 parameters by default.
        reference = self.layer1.weight
        state = torch.as_tensor(state, dtype=reference.dtype, device=reference.device)
        # Normalise over the action axis (the last one; the only one for a
        # single observation).
        probabilities = torch.nn.functional.softmax(self.forward(state), dim=-1)
        action = torch.multinomial(probabilities, num_samples=1).item()
        return action

def _init_weights(m):
    """Initialise a ``torch.nn.Linear`` module in place; ignore other modules.

    Intended for use with ``Module.apply``. Weights are drawn with
    ``torch.nn.init.uniform_`` using its default bounds, and the bias, when
    the layer has one, is set to the constant 0.01.

    Args:
        m: Any module; only ``torch.nn.Linear`` instances are modified.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.uniform_(m.weight)
    # Layers built with bias=False expose ``bias`` as None.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.01)

env = gymnasium.make("CartPole-v1")

# Size both ends of the network from the environment's spaces instead of
# hard-coding the action count (2 for CartPole-v1, so the value is unchanged).
policy = MyPolicy(
    n_observations=env.observation_space.shape[0],
    n_actions=int(env.action_space.n),
)
policy.apply(_init_weights)
# policy.compile(fullgraph=True)

agent = brl.reinforce.Reinforce(policy=policy, learning_rate=6e-3)

def training(env, agent, nb_epochs):
    """Run ``nb_epochs`` episodes and log a running mean of their rewards.

    Each epoch calls ``run_env`` once. The mean over the most recent episode
    rewards (at most 1000) is written to ``train/mean_reward`` and the epoch
    index to ``train/epoch``, both at step ``epoch``. Nothing is returned.

    Args:
        env: Environment handed through to ``run_env``.
        agent: Agent handed through to ``run_env``.
        nb_epochs: Number of episodes to run.
    """
    recent_rewards = deque(maxlen=1000)

    for epoch in range(nb_epochs):
        recent_rewards.append(run_env(env, agent, epoch))
        mean_reward = sum(recent_rewards) / len(recent_rewards)

        writer.add_scalar("train/mean_reward", mean_reward, epoch)
        writer.add_scalar("train/epoch", epoch, epoch)

def run_env(env, agent, epoch):
    """Play one episode, let the agent optimise, and log its loss and epsilon.

    The episode runs until ``env.step`` reports either ``terminated`` or
    ``truncated``. Every transition is passed to ``agent.observe``; after the
    episode ``agent.optimize()`` is called once, then ``agent.loss`` and
    ``agent.policy.epsilon`` are written to ``train/loss`` and
    ``train/epsilon`` at step ``epoch``.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``.
        agent: Object exposing ``act``, ``observe``, ``optimize``, ``loss``
            and ``policy.epsilon``.
        epoch: Step index used for the logged scalars.

    Returns:
        The sum of the rewards collected during the episode.
    """
    episode_return = 0
    state, _ = env.reset()
    done = False

    while not done:
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)

        agent.observe(state, action, reward, next_state)
        episode_return += reward

        state = next_state
        # Time-limit truncation ends the episode just like termination.
        done = terminated or truncated

    agent.optimize()

    writer.add_scalar("train/loss", agent.loss, epoch)
    writer.add_scalar("train/epsilon", agent.policy.epsilon, epoch)

    return episode_return

# NOTE: `training` has no return statement, so `sum_reward` is bound to None.
try:
    sum_reward = training(env, agent, nb_epochs=10000)
finally:
    # Write buffered TensorBoard events to disk even if the run is interrupted.
    writer.flush()
