In [None]:
import gymnasium as gym
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
from torch.optim import RMSprop, AdamW
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

from stable_baselines3.common.vec_env import VecEnv


In [None]:
from functools import partial


class PolicyNet(nn.Module):
    """Actor-critic network for discrete actions with separate policy and value trunks.

    Both trunks are two-layer Tanh MLPs of width ``hidden_dim``. ``action_net``
    maps policy features to ``nvec_u`` action logits and ``value_net`` maps
    value features to a single state-value estimate. All linear layers are
    orthogonally initialised with zero biases.

    Args:
        nvec_s: Size of the (flat) observation vector.
        nvec_u: Number of discrete actions.
        hidden_dim: Width of every hidden layer.
    """

    def __init__(self, nvec_s: int, nvec_u: int, hidden_dim):
        super(PolicyNet, self).__init__()

        self.policy_feature_extractor = nn.Sequential(
            nn.Linear(nvec_s, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        self.value_feature_extractor = nn.Sequential(
            nn.Linear(nvec_s, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        self.action_net = nn.Linear(hidden_dim, nvec_u)
        self.value_net = nn.Linear(hidden_dim, 1)

        # Per-module orthogonal gains: sqrt(2) for the hidden trunks, a small
        # gain for the policy head and unit gain for the value head.
        module_gains = {
            self.policy_feature_extractor: np.sqrt(2),
            self.value_feature_extractor: np.sqrt(2),
            self.action_net: 0.01,
            self.value_net: 1,
        }

        def init_weights(module: nn.Module, gain: float = 1) -> None:
            """Orthogonal initialization (used in PPO and A2C)."""
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                nn.init.orthogonal_(module.weight, gain=gain)
                if module.bias is not None:
                    module.bias.data.fill_(0.0)

        for module, gain in module_gains.items():
            module.apply(partial(init_weights, gain=gain))

    def forward(self, x, deterministic=False):
        """Select actions for the observation(s) ``x``.

        Args:
            x: Observation tensor whose last dimension is ``nvec_s``; any
                leading batch dimensions are preserved.
            deterministic: If true, return only the greedy action(s).

        Returns:
            If ``deterministic`` is true, the arg-max action over the last
            (action) dimension, one per observation. Otherwise the tuple
            ``(value, action, log_prob, entropy)`` for a sampled action.
        """
        action_logits = self.action_net(self.policy_feature_extractor(x))
        if deterministic:
            # Reduce over the action axis only, so batched inputs yield one
            # action per row instead of a single flattened index.
            return torch.argmax(action_logits, dim=-1)

        value = self.value_net(self.value_feature_extractor(x))
        dist = torch.distributions.Categorical(logits=action_logits)
        action = dist.sample()
        entropy = dist.entropy()
        log_prob = dist.log_prob(action)
        return value, action, log_prob, entropy

    def get_value(self, x):
        """Return the state-value estimate for ``x`` (trailing dimension of size 1)."""
        return self.value_net(self.value_feature_extractor(x))

    def evaluate_actions(self, states, actions):
        """Re-evaluate stored ``actions`` under the current policy.

        Returns:
            ``(values, log_prob, entropy)`` where ``values`` keeps a trailing
            dimension of size 1 and the other two have the batch shape of
            ``states`` without its last dimension.
        """
        action_logits = self.action_net(self.policy_feature_extractor(states))
        dist = torch.distributions.Categorical(logits=action_logits)
        log_prob = dist.log_prob(actions)
        entropy = dist.entropy()
        values = self.value_net(self.value_feature_extractor(states))
        return values, log_prob, entropy


In [None]:
class RolloutBuffer:
    """Fixed-length on-policy rollout storage with GAE for vectorized environments.

    Every array is laid out as ``(capacity, num_envs, ...)``. ``dones[t]`` holds
    the flag passed to :meth:`push` at step ``t``;
    :meth:`compute_advantages_and_returns` reads ``dones[t + 1]`` to decide
    whether step ``t`` may bootstrap from ``values[t + 1]``, so callers are
    expected to push, for each step, the ``done`` flag produced by the
    *previous* environment step (an "episode start" flag).

    Args:
        env: Vectorized environment; only its observation/action spaces are read.
        capacity: Number of time steps stored per rollout.
        gamma: Discount factor.
        gae_lambda: GAE lambda.
        device: Accepted for interface compatibility; not used by the buffer.
        num_envs: Number of parallel environments.

    Raises:
        NotImplementedError: If the observation space is neither ``Box`` nor
            ``Discrete``.
    """

    def __init__(self, env:VecEnv, capacity, gamma, gae_lambda, device:torch.device, num_envs):
        self.env = env
        self.gamma = gamma
        self.gae_lambda = gae_lambda

        self.position = 0
        self.size = 0
        self.capacity = capacity

        self.n_actions = env.action_space.n

        if isinstance(env.observation_space, gym.spaces.Box):
            self.n_states = env.observation_space.shape[0]
        elif isinstance(env.observation_space, gym.spaces.Discrete):
            self.n_states = 1
        else:
            # Fail loudly instead of hitting an AttributeError a few lines below.
            raise NotImplementedError(
                f"Unsupported observation space: {type(env.observation_space).__name__}"
            )

        self.states = np.zeros((capacity, num_envs, self.n_states), dtype=np.float32)
        self.actions = np.zeros((capacity, num_envs), dtype=np.float32)
        self.rewards = np.zeros((capacity, num_envs), dtype=np.float32)
        self.dones = np.zeros((capacity, num_envs), dtype=np.float32)
        self.values = np.zeros((capacity, num_envs), dtype=np.float32)
        self.log_probs = np.zeros((capacity, num_envs), dtype=np.float32)
        self.advantages = np.zeros((capacity, num_envs), dtype=np.float32)
        # Was misspelled ``retuns``, which left ``returns`` undefined until the
        # first call to ``compute_advantages_and_returns``.
        self.returns = np.zeros((capacity, num_envs), dtype=np.float32)

        self.reset()

    def push(self, state:np.ndarray, action, reward, done, value:torch.Tensor, log_prob:torch.Tensor):
        """Store one vectorized transition at the current write position.

        Args:
            state: Observations the actions were computed from.
            action: Actions taken, as a NumPy array or a torch tensor (any device).
            reward: Rewards received for the actions.
            done: Flag stored in ``dones`` for this step; ``None`` is stored
                as 1.0 rather than NaN.
            value: Value estimates for ``state``; flattened to ``(num_envs,)``.
            log_prob: Log-probabilities of ``action``; flattened likewise.
        """
        if isinstance(action, torch.Tensor):
            # NumPy cannot read tensors that live on an accelerator.
            action = action.detach().cpu().numpy()
        if done is None:
            # NOTE(review): None is what the caller passes before any env step,
            # i.e. right after reset, so it is recorded as an episode start.
            done = 1.0

        idx = self.position
        self.states[idx] = state
        self.actions[idx] = action
        self.rewards[idx] = reward
        self.dones[idx] = done
        self.values[idx] = value.detach().cpu().numpy().flatten()
        self.log_probs[idx] = log_prob.detach().cpu().numpy().flatten()

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def reset(self):
        """Rewind the write position; stored arrays are overwritten, not cleared."""
        self.position = 0
        self.size = 0

    def compute_advantages_and_returns(self, last_value:torch.Tensor, dones):
        """Fill ``advantages`` and ``returns`` using GAE.

        Iterates over all ``capacity`` steps, so the buffer is expected to be
        full when this is called.

        Args:
            last_value: Value estimates for the observations that follow the
                final stored step.
            dones: ``done`` flags returned by the final environment step.
        """
        last_value = last_value.detach().cpu().numpy().flatten()
        final_non_terminal = 1.0 - np.asarray(dones, dtype=np.float32)
        n_steps = len(self.rewards)
        last_gae = 0
        for step in reversed(range(n_steps)):
            if step == n_steps - 1:
                next_value = last_value
                next_non_terminal = final_non_terminal
            else:
                next_value = self.values[step + 1]
                next_non_terminal = 1 - self.dones[step + 1]

            # TD residual, masked so no value leaks across episode boundaries.
            delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
            last_gae = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae
            self.advantages[step] = last_gae

        self.returns = self.advantages + self.values


In [None]:
class A2C:
    """Synchronous Advantage Actor-Critic trainer for discrete-action vectorized envs.

    Each training iteration collects ``n_steps`` vectorized transitions with the
    current policy (:meth:`rollout`), computes GAE advantages in the rollout
    buffer, and performs a single RMSprop update on the combined
    policy/value/entropy loss (:meth:`learn`).

    Raises:
        NotImplementedError: If the observation space is neither ``Box`` nor
            ``Discrete``.
    """

    def __init__(
        self,
        env: VecEnv,
        lr,  # learning rate
        gamma,  # discount factor
        gae_lambda,  # Generalized Advantage Estimation lambda
        vf_coef,  # value function coefficient
        ent_coef,  # entropy coefficient
        max_steps,  # max steps for training
        n_steps,  # number of steps to run before updating
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        self.env = env
        self.num_envs = env.num_envs
        self.lr = lr
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.max_steps = max_steps
        self.n_steps = n_steps
        self.device = device

        self.rollout_buffer = RolloutBuffer(
            env, n_steps, gamma, gae_lambda, device, num_envs=self.num_envs
        )

        if isinstance(env.observation_space, gym.spaces.Box):
            self.n_states = env.observation_space.shape[0]
        elif isinstance(env.observation_space, gym.spaces.Discrete):
            self.n_states = 1
        else:
            raise NotImplementedError(
                f"Unsupported observation space: {type(env.observation_space).__name__}"
            )

        # The inputs are moved to ``device`` everywhere below, so the network
        # has to live there as well.
        self.policy_net = PolicyNet(self.n_states, env.action_space.n, 64).to(self.device)
        self.optimizer_policy = RMSprop(
            self.policy_net.parameters(),
            lr=self.lr,
            eps=1e-5,
            weight_decay=0,
            alpha=0.99,
        )

        self.total_steps = 0

        self.last_state = self.env.reset()
        # Every environment has just been reset, so each one starts an episode.
        self.last_episode_starts = np.ones(self.num_envs, dtype=bool)

        # stats
        self.episodes = 0
        self.total_rewards = np.zeros(self.env.num_envs)
        self.mean_episode_reward = 0

    def _bootstrap_truncated_rewards(self, rewards, dones, infos):
        """Return a copy of ``rewards`` with ``gamma * V(terminal_obs)`` added for truncated episodes.

        Only environments whose episode ended with ``TimeLimit.truncated`` set
        and a ``terminal_observation`` present in their info dict are adjusted.
        """
        adjusted = np.array(rewards, dtype=np.float32)
        for idx, done in enumerate(dones):
            if not done:
                continue
            info = infos[idx]
            terminal_obs = info.get("terminal_observation")
            if terminal_obs is None or not info.get("TimeLimit.truncated", False):
                continue
            with torch.no_grad():
                terminal_value = self.policy_net.get_value(
                    torch.as_tensor(terminal_obs, dtype=torch.float32, device=self.device)
                )
            # ``.item()`` avoids relying on implicit size-1 array -> scalar conversion.
            adjusted[idx] += self.gamma * terminal_value.item()
        return adjusted

    def rollout(self):
        """Collect ``n_steps`` vectorized transitions and compute advantages/returns.

        Requires ``self.pbar`` and ``self.writer`` to exist (created by :meth:`train`).
        """
        self.policy_net.eval()
        for _ in range(self.n_steps):

            with torch.no_grad():
                value, action, log_prob, entropy = self.policy_net(
                    torch.as_tensor(self.last_state, dtype=torch.float32, device=self.device)
                )

            # The environment and the buffer need CPU NumPy actions.
            actions_np = action.cpu().numpy()
            next_state, rewards, dones, infos = self.env.step(actions_np)

            # Accumulate statistics from the raw environment rewards, before the
            # truncation bootstrap is added to the training rewards.
            self.total_rewards += rewards
            train_rewards = self._bootstrap_truncated_rewards(rewards, dones, infos)

            self.rollout_buffer.push(
                self.last_state, actions_np, train_rewards, self.last_episode_starts, value, log_prob
            )

            self.total_steps += 1
            self.pbar.update(1)

            # Episode statistics track the first environment only and are
            # reported (then reset) every 10 of its episodes.
            if dones[0]:
                self.episodes += 1
                if self.episodes % 10 == 0:

                    self.mean_episode_reward = (self.total_rewards[0]) / self.episodes
                    self.pbar.set_description(
                        f"Reward: {self.mean_episode_reward :.3f} || Entropy: {entropy.mean():.3f}"
                    )
                    self.writer.add_scalar("Reward", self.mean_episode_reward, self.total_steps)
                    self.episodes = 0
                    self.total_rewards = np.zeros(self.env.num_envs)

            self.last_state = next_state
            self.last_episode_starts = dones

        with torch.no_grad():
            last_value = self.policy_net.get_value(
                torch.as_tensor(self.last_state, dtype=torch.float32, device=self.device)
            )

        self.rollout_buffer.compute_advantages_and_returns(last_value, dones)

    def learn(self):
        """Run one gradient step on the rollout currently held in the buffer."""
        self.policy_net.train()
        buffer = self.rollout_buffer
        states = torch.as_tensor(buffer.states, dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(buffer.actions, device=self.device).long()
        advantages = torch.as_tensor(buffer.advantages, dtype=torch.float32, device=self.device)
        returns = torch.as_tensor(buffer.returns, dtype=torch.float32, device=self.device)

        values, log_prob, entropy = self.policy_net.evaluate_actions(states, actions)

        # Drop only the trailing value dimension; a bare ``squeeze()`` would
        # also remove the step/env axis when it has size 1 and make
        # ``mse_loss`` broadcast against ``returns``.
        values = values.squeeze(-1)

        self.policy_loss = -torch.mean(advantages * log_prob)

        self.entropy_loss = -torch.mean(entropy)
        self.value_loss = F.mse_loss(values, returns)

        self.tot_policy_loss = self.policy_loss + self.vf_coef * self.value_loss + self.ent_coef * self.entropy_loss

        if self.total_steps % 100 == 0:
            self.pbar.set_postfix(
                policy_loss=self.policy_loss.item(),
                value_loss=self.value_loss.item(),
                entropy_loss=self.entropy_loss.item(),
                lr=self.optimizer_policy.param_groups[0]["lr"],
            )

        self.optimizer_policy.zero_grad()
        self.tot_policy_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5)
        self.optimizer_policy.step()

    def train(self):
        """Alternate rollouts and updates until ``max_steps`` vectorized steps are done."""
        self.pbar = tqdm(total=self.max_steps, position=0, leave=True)
        self.writer = SummaryWriter(log_dir="runs/a2c_logs/A2C")

        try:
            while self.total_steps < self.max_steps:
                self.rollout_buffer.reset()
                self.rollout()
                self.learn()
        finally:
            # Flush TensorBoard events and release the progress bar even if
            # training is interrupted.
            self.writer.close()
            self.pbar.close()

In [None]:
from stable_baselines3.common.env_util import make_vec_env

def make_env():
    """Create and return a fresh ``CartPole-v1`` Gymnasium environment."""
    return gym.make("CartPole-v1")

In [None]:
# Training environment: 8 copies of the environment built by `make_env`, stepped together.
env = make_vec_env(make_env, n_envs=8)

In [None]:
# Instantiate the A2C trainer on the vectorized training environment.
agent = A2C(
    env=env,
    lr=7e-4,  # learning rate
    gamma=0.99,  # discount factor
    gae_lambda=1.0,  # Generalized Advantage Estimation lambda
    vf_coef=1.0,  # value function coefficient
    ent_coef=0.001,  # entropy coefficient
    max_steps=25000,  # max steps for training
    n_steps=5,  # number of steps to run before updating
)

In [None]:
# Run the rollout/update loop until `max_steps` is reached (see `A2C.train`).
agent.train()

In [None]:
# Watch the trained policy act greedily in a single, human-rendered environment.
env = gym.make("CartPole-v1",
               render_mode='human'
               )

n_episodes = 10
try:
    for _ in range(n_episodes):
        obs, info = env.reset()
        terminated = False
        truncated = False
        tot_reward = 0
        while not (terminated or truncated):
            # Only the forward pass needs gradients disabled.
            with torch.no_grad():
                action = agent.policy_net(
                    torch.from_numpy(obs).float().to(agent.device), deterministic=True
                ).item()
            # With render_mode='human' the window is refreshed inside step(),
            # so no explicit env.render() call is needed.
            obs, reward, terminated, truncated, info = env.step(action)
            tot_reward += reward
        print(f"Episode reward: {tot_reward}", end="\r")
finally:
    # Always release the render window, even if an episode raises.
    env.close()

In [None]:
# torch.save(agent.policy_net.state_dict(), "../models/a2c/a2c_cartpole.pth")