In [30]:
import gymnasium as gym
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch.optim import AdamW, RMSprop
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

from stable_baselines3.common.vec_env import VecEnv


In [31]:
class PolicyNet(nn.Module):
    def __init__(self, nvec_s: int, nvec_u: int):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(nvec_s, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, nvec_u)


    def forward(self, x, deterministic = False):
        x = F.tanh(self.fc1(x))
        x = F.tanh(self.fc2(x))
        x = self.fc3(x)
        dist = torch.distributions.Categorical(logits=x)
        if deterministic:
            action = torch.argmax(x)
            return action
        action = dist.sample()
        entropy = dist.entropy()
        log_prob = dist.log_prob(action)
        return action, log_prob, entropy
    
    def evaluate_actions(self,states, actions):
        x = F.tanh(self.fc1(states))
        x = F.tanh(self.fc2(x))
        x = self.fc3(x)
        dist = torch.distributions.Categorical(logits=x)
        log_prob = dist.log_prob(actions)
        entropy = dist.entropy()
        return log_prob, entropy
    


In [32]:
class ValueNet(nn.Module):
    def __init__(self, n_features, n_hidden):
        super(ValueNet, self).__init__()
        self.fc1 = nn.Linear(n_features, n_hidden)
        self.fc2 = nn.Linear(n_hidden,n_hidden)
        self.fc3 = nn.Linear(n_hidden, 1)

    def forward(self, x) -> torch.Tensor:
        x = F.tanh(self.fc1(x))
        x = F.tanh(self.fc2(x))
        return self.fc3(x)
    

In [33]:
class RolloutBuffer:

    def __init__(self, env:VecEnv, capacity, gamma, gae_lambda, device:torch.device, num_envs):
        self.env = env
        self.gamma = gamma
        self.gae_lambda = gae_lambda

        self.position = 0
        self.size = 0
        self.capacity = capacity

        self.n_actions = env.action_space.n 
        self.n_states = env.observation_space.shape[0]

        self.states = np.zeros((capacity,num_envs, self.n_states), dtype=np.float32)
        self.actions = np.zeros((capacity,num_envs), dtype=np.float32)
        self.rewards = np.zeros((capacity,num_envs), dtype=np.float32)
        self.dones = np.zeros((capacity,num_envs), dtype=np.float32)
        self.values = np.zeros((capacity,num_envs), dtype=np.float32)
        self.log_probs = np.zeros((capacity,num_envs), dtype=np.float32)
        self.entropy = np.zeros((capacity,num_envs), dtype=np.float32)
        self.advantages = np.zeros((capacity,num_envs), dtype=np.float32)
        self.retuns = np.zeros((capacity,num_envs), dtype=np.float32) 

        self.reset()

    def push(self, state:np.ndarray, action, reward, done, value:torch.Tensor, log_prob:torch.Tensor, entropy):
        self.states[self.position] = state
        self.actions[self.position] = action
        self.rewards[self.position] = reward
        self.dones[self.position] = done
        self.values[self.position] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.position] = log_prob.clone().cpu().numpy().flatten()
        self.entropy[self.position] = entropy

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def reset(self):
        self.position = 0
        self.size = 0

    def compute_advantages_and_returns(self, last_value:torch.Tensor, dones):
        last_value = last_value.cpu().numpy().flatten()
        last_gae = 0
        for step in reversed(range(len(self.rewards))):
            if step == len(self.rewards) - 1:
                next_value = last_value
                next_non_terminal = 1 - dones
            else:
                next_value = self.values[step + 1]
                next_non_terminal = 1 - self.dones[step + 1]   
            
            delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
            last_gae = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae
            self.advantages[step]= last_gae

        self.returns = self.advantages + self.values


In [34]:
class A2C:
    def __init__(
        self,
        env: VecEnv,
        lr,  # learning rate
        gamma,  # discount factor
        gae_lambda,  # Generalized Advantage Estimation lambda
        max_steps,  # max steps for training
        n_steps,  # number of steps to run before updating
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        self.env = env
        self.num_envs = env.num_envs
        self.lr = lr
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.max_steps = max_steps
        self.n_steps = n_steps
        self.device = device

        self.rollout_buffer = RolloutBuffer(
            env, n_steps, gamma, gae_lambda, device, num_envs=self.num_envs
        )
        self.policy_net = PolicyNet(env.observation_space.shape[0], env.action_space.n)
        self.value_net = ValueNet(env.observation_space.shape[0], 128)
        self.optimizer_policy = RMSprop(self.policy_net.parameters(), lr=self.lr, eps=1e-5)
        self.optimizer_value = RMSprop(self.value_net.parameters(), lr=self.lr, eps=1e-5)

        self.total_steps = 0

        self.last_state = self.env.reset()

        # stats
        self.episodes = 0
        self.total_rewards = np.zeros(self.env.num_envs)
        self.mean_episode_reward = 0

    def rollout(self):

        for _ in range(self.n_steps):

            with torch.no_grad():
                action, log_prob, entropy = self.policy_net(
                    torch.from_numpy(self.last_state).float().to(self.device)
                )
                value = self.value_net(
                    torch.tensor(self.last_state).float().to(self.device)
                )

            next_state, rewards, dones, infos = self.env.step(action.detach().numpy())

            # check if done is because of truncation
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = infos[idx]["terminal_observation"]
                    with torch.no_grad():
                        terminal_value = self.value_net(torch.tensor(terminal_obs).float().to(self.device))  # type: ignore[arg-type]
                    rewards[idx] += self.gamma * terminal_value

            self.rollout_buffer.push(
                self.last_state, action, rewards, done, value, log_prob, entropy
            )

            self.total_rewards += rewards
            self.total_steps += 1
            self.pbar.update(1)

            if dones[0]:
                self.last_state = self.env.reset()
                self.episodes += 1

                if self.episodes % 100 == 0:

                    self.mean_episode_reward = (
                        self.total_rewards[0]
                    ) / self.episodes
                    self.pbar.set_description(
                        f"Reward: {self.mean_episode_reward :.3f}"
                    )
                    # self.writer.add_scalar("reward", self.total_rewards, self.episodes)
                    self.episodes = 0
                    self.total_rewards = np.zeros(self.env.num_envs)

            self.last_state = next_state

        with torch.no_grad():
            last_value = self.value_net(
                torch.tensor(self.last_state).float().to(self.device)
            )

        self.rollout_buffer.compute_advantages_and_returns(last_value, dones)

    def learn(self):

        log_prob, entropy = self.policy_net.evaluate_actions(
            torch.tensor(self.rollout_buffer.states).float().to(self.device),
            torch.tensor(self.rollout_buffer.actions).to(self.device),
        )

        values = self.value_net(
            torch.tensor(self.rollout_buffer.states).float().to(self.device)
        ).squeeze()

        # advantages = self.rollout_buffer.advantages
        
        advantages = (self.rollout_buffer.advantages - self.rollout_buffer.advantages.mean())/ (self.rollout_buffer.advantages.std() + 1e-5)

        self.policy_loss = -torch.mean(
            torch.from_numpy(advantages) * log_prob
        )

        self.entropy_loss = -torch.mean(entropy)
        self.policy_loss = self.policy_loss + 0.0 * self.entropy_loss

        self.value_loss = F.mse_loss(
            torch.from_numpy(self.rollout_buffer.returns), values
        )

        self.optimizer_policy.zero_grad()
        self.policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5)
        self.optimizer_policy.step()

        self.optimizer_value.zero_grad()
        self.value_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5)
        self.optimizer_value.step()

    def train(self):
        self.pbar = tqdm(total=self.max_steps, position=0, leave=True)
        # self.writer = SummaryWriter(log_dir="runs/reinforce_logs/REINFORCE_BASELINE_2")

        while self.total_steps < self.max_steps:
            self.rollout_buffer.reset()
            self.rollout()
            self.learn()

In [35]:
from stable_baselines3.common.env_util import make_vec_env

def make_env():
    env = gym.make("CartPole-v1")
    # env = gym.wrappers.time_limit.TimeLimit(env, max_episode_steps=500)

    return env

In [36]:
env = make_vec_env(make_env, n_envs=8)

In [37]:
agent = A2C(
    env=env,
    lr=0.0001,
    gamma=0.99,
    gae_lambda=0.9,
    max_steps=500000,
    n_steps=8,
)

In [38]:
agent.train()

Reward: 9.430:  17%|█▋        | 84190/500000 [1:03:01<5:11:15, 22.27it/s]
Reward: 64.770:  29%|██▉       | 144667/500000 [57:40<2:21:39, 41.80it/s]
Reward: 55.200:  36%|███▌      | 178688/500000 [47:22<1:25:11, 62.86it/s]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

env = gym.make("CartPole-v1",
               render_mode='human'
               )

n_episodes = 100
for _ in range(n_episodes):
    obs, info = env.reset()
    terminated = False
    truncated = False
    while not terminated and not truncated:
        with torch.no_grad():
            action = agent.policy_net(torch.from_numpy(obs).float().to(agent.device), deterministic=True).item()
            obs, reward, terminated,  truncated, info = env.step(action)
            env.render()


KeyboardInterrupt: 

Reward: 55.200:  36%|███▌      | 178688/500000 [10:02<18:36, 287.69it/s]