[Hands on RL Policy Gradient](https://github.com/PacktPublishing/Hands-on-Reinforcement-Learning-with-PyTorch/blob/master/Section%204/4.3%20Policy%20Gradients%20REINFORCE.ipynb)

[Policy Gradient Math](https://towardsdatascience.com/policy-gradients-in-reinforcement-learning-explained-ecec7df94245)

A widely used variation of REINFORCE subtracts a baseline value from the return. This reduces the variance of the gradient estimate without introducing bias (something we always want when possible). A common choice of baseline is the state value $V(s)$; subtracting it from the action value $Q(s,a)$ gives the advantage:

$$
A(s,a) = Q(s,a) - V(s)
$$

which we then use in the gradient ascent update. This [post](https://danieltakeshi.github.io/2017/03/28/going-deeper-into-reinforcement-learning-fundamentals-of-policy-gradients/) explains nicely why a baseline reduces the variance, and also covers the fundamentals of policy gradients.

In [None]:
#!pip install swig
#!pip install gymnasium[box2d]

## Actor Critic

![Actor-Critic diagram](acritic.png)

In [26]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import gymnasium as gym
import numpy as np
from IPython.display import clear_output
from collections import deque

In [27]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [28]:
env_id = "LunarLander-v2"
# Add render_mode="human" to gym.make to watch the environment.
env = gym.make(env_id)

# Sizes used to build the networks further down.
s_size = env.observation_space.shape[0]  # length of one observation vector
a_size = env.action_space.n  # number of discrete actions

print("_____OBSERVATION SPACE_____ \n")
print("The State Space is: ", s_size)
# Draw one random observation to show what the state vector looks like.
print("Sample observation", env.observation_space.sample())

_____OBSERVATION SPACE_____ 

The State Space is:  8
Sample observation [ 0.18763174 -1.4121945   2.0076056  -3.8539116   0.7413541  -1.8592845
  0.39006355  0.3410315 ]


In [30]:
# Sanity-check the environment with a short random rollout.
# Use a separate env so the global `env` used by the training cells below is
# not closed. Set render_mode="human" to watch the lander.
render_mode = None
demo_env = gym.make(env_id, render_mode=render_mode)
demo_env.reset()

for _step in range(50):
    action = demo_env.action_space.sample()
    # gymnasium's step() returns (obs, reward, terminated, truncated, info).
    obs, reward, terminated, truncated, info = demo_env.step(action)
    # An episode ends on either termination or truncation; start a new one.
    if terminated or truncated:
        demo_env.reset()

demo_env.close()

In [31]:
class ActorNet(nn.Module):
    """Policy (actor) network: maps a state to probabilities over actions.

    Two ReLU hidden layers followed by a softmax over the last dimension, so it
    accepts a single state of shape (state_size,) or a batch (..., state_size).

    Args:
        state_size: Length of the observation vector.
        action_size: Number of discrete actions (size of the output).
        hidden_size: Width of both hidden layers.
        obs_clip: Inputs are clamped to [-obs_clip, obs_clip] before the first
            layer; pass ``None`` to disable clamping. The default of 1.1 keeps
            the original behaviour. Note that this is lossy for the observations
            printed in this notebook, which contain values around 1.42: every
            value beyond +/-1.1 reaches the network as the same number.
    """

    def __init__(self, state_size, action_size, hidden_size, obs_clip=1.1):
        super().__init__()
        self.obs_clip = obs_clip
        self.dense_layer_1 = nn.Linear(state_size, hidden_size)
        self.dense_layer_2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        """Return action probabilities (softmax over the last dim, plus 1e-8)."""
        if self.obs_clip is not None:
            x = torch.clamp(x, -self.obs_clip, self.obs_clip)
        x = F.relu(self.dense_layer_1(x))
        x = F.relu(self.dense_layer_2(x))
        # The 1e-8 keeps every probability strictly positive (presumably so a
        # later log() stays finite); the output therefore sums to marginally
        # more than 1.
        return F.softmax(self.output(x), dim=-1) + 1e-8
    
class CriticNet(nn.Module):
    """Value (critic) network: maps a state to a single scalar estimate.

    Two ReLU hidden layers followed by a linear output of size 1, so a state of
    shape (state_size,) gives (1,) and a batch (..., state_size) gives (..., 1).

    Args:
        state_size: Length of the observation vector.
        hidden_size: Width of both hidden layers.
        obs_clip: Inputs are clamped to [-obs_clip, obs_clip] before the first
            layer; pass ``None`` to disable clamping. The default of 1.1 keeps
            the original behaviour. Note that this is lossy for the observations
            printed in this notebook, which contain values around 1.42.
    """

    def __init__(self, state_size, hidden_size, obs_clip=1.1):
        super().__init__()
        self.obs_clip = obs_clip
        self.dense_layer_1 = nn.Linear(state_size, hidden_size)
        self.dense_layer_2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return the critic's scalar output for state(s) ``x``."""
        if self.obs_clip is not None:
            x = torch.clamp(x, -self.obs_clip, self.obs_clip)
        x = F.relu(self.dense_layer_1(x))
        x = F.relu(self.dense_layer_2(x))
        return self.output(x)

In [32]:
class ActorCriticAgent:
    """Hold an actor (policy) network and a critic network, each with its own Adam optimizer.

    Args:
        state_size: Length of the observation vector.
        action_size: Number of discrete actions.
        hidden_size: Hidden-layer width for both networks.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
        discount: Discount factor, stored as ``self.discount``.
    """

    def __init__(self, state_size, action_size, hidden_size, actor_lr, critic_lr, discount):
        self.action_size = action_size

        # Networks are created actor-first and moved to the shared device.
        self.actor_net = ActorNet(state_size, action_size, hidden_size).to(device)
        self.critic_net = CriticNet(state_size, hidden_size).to(device)

        actor_params = self.actor_net.parameters()
        critic_params = self.critic_net.parameters()
        self.actor_optimizer = optim.Adam(actor_params, lr=actor_lr)
        self.critic_optimizer = optim.Adam(critic_params, lr=critic_lr)

        self.discount = discount

    def select_action(self, state):
        """Sample one action index from the actor's probabilities for ``state``."""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).to(device)
            probs = self.actor_net(state_tensor)
            # np.random.choice needs a host-side NumPy probability vector.
            probs_np = probs.cpu().numpy()
            choices = np.arange(self.action_size)
            chosen = np.random.choice(choices, p=probs_np)
        return chosen

In [33]:
# Hyper-parameters
hidden_layer = 64  # hidden width for both the actor and the critic
gamma = 0.995  # passed to the agent as its discount factor
actor_lr = 1e-3  # Adam learning rate for the actor
critic_lr = 1e-3  # Adam learning rate for the critic
episodes = 100000  # presumably the training-episode budget (not used in the cells shown)
avg_win_size = 50  # how many recent results `epi_results` keeps
epi_results = deque(maxlen=avg_win_size)

# Build the agent; keyword arguments make the mapping to the constructor explicit.
agent = ActorCriticAgent(
    state_size=s_size,
    action_size=a_size,
    hidden_size=hidden_layer,
    actor_lr=actor_lr,
    critic_lr=critic_lr,
    discount=gamma,
)

In [35]:
# Quick check: push one freshly reset observation through the untrained actor.
s = env.reset()[0]  # reset() returns a tuple; element 0 is the observation
obs = torch.FloatTensor(s[np.newaxis, ...]).to(device)  # prepend a batch dimension
p_vals = agent.actor_net(obs)

In [36]:
# The action probabilities should add up to (approximately) 1.
torch.sum(p_vals)

tensor(1., device='cuda:0', grad_fn=<SumBackward0>)

In [37]:
for epi in range(10):
    # Roll out one episode with the current policy and record every transition.
    s = env.reset()[0]
    done, trunc = False, False
    states, rewards, next_states, actions, dones = [], [], [], [], []

    # Stop on either termination or truncation (gymnasium reports them separately).
    while not (done or trunc):
        states.append(s)
        # Same sampling as before: actor probabilities -> np.random.choice.
        a = agent.select_action(s)

        s_, r, done, trunc, _ = env.step(a)
        actions.append(a)
        rewards.append(r)
        next_states.append(s_)
        dones.append(int(done))
        # Copy so the stored state can never alias a buffer reused by the env.
        s = np.copy(s_)

    # Stack into a single ndarray first: building a tensor directly from a
    # list of ndarrays is very slow in PyTorch. After the loop, state_t holds
    # the states of the last episode.
    state_t = torch.FloatTensor(np.array(states)).to(device)


In [41]:
# Inspect the states collected during the most recent episode (one row per step).
state_t

tensor([[ 6.4694e-03,  1.4216e+00,  6.5528e-01,  4.7632e-01, -7.4898e-03,
         -1.4843e-01,  0.0000e+00,  0.0000e+00],
        [ 1.2873e-02,  1.4318e+00,  6.4615e-01,  4.5112e-01, -1.3170e-02,
         -1.1361e-01,  0.0000e+00,  0.0000e+00],
        [ 1.9342e-02,  1.4413e+00,  6.5419e-01,  4.2409e-01, -2.0457e-02,
         -1.4577e-01,  0.0000e+00,  0.0000e+00],
        [ 2.5810e-02,  1.4503e+00,  6.5421e-01,  3.9742e-01, -2.7743e-02,
         -1.4573e-01,  0.0000e+00,  0.0000e+00],
        [ 3.2364e-02,  1.4586e+00,  6.6488e-01,  3.7056e-01, -3.7164e-02,
         -1.8844e-01,  0.0000e+00,  0.0000e+00],
        [ 3.8918e-02,  1.4664e+00,  6.6491e-01,  3.4389e-01, -4.6583e-02,
         -1.8839e-01,  0.0000e+00,  0.0000e+00],
        [ 4.5531e-02,  1.4735e+00,  6.7241e-01,  3.1668e-01, -5.7506e-02,
         -2.1848e-01,  0.0000e+00,  0.0000e+00],
        [ 5.2069e-02,  1.4800e+00,  6.6287e-01,  2.9029e-01, -6.6506e-02,
         -1.8001e-01,  0.0000e+00,  0.0000e+00],
        [ 5.8522