## PG-based REINFORCE method to solve CartPole

$\text{Policy Gradient}: \min\limits_{\pi} \mathcal{L} := -\mathbb{E}_{\tau}[Q(s,a)\log \pi(a|s)], \tau=\{(s_i,a_i,r_i)\}$

$\text{REINFORCE: the vanilla PG method, which estimates}\;\; Q(s_t,a_t) \approx \sum\limits_{i=t}^{T} \gamma^{\,i-t} r_i \;\;\text{(the discounted return-to-go)}$

In [18]:
import os

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter

### step0. load the env

In [24]:
env_name = "CartPole-v1"
env = gym.make(env_name)
# Inspect the spaces: a random observation, its shape, a random action, and the action count.
(
    env.observation_space.sample(),
    env.observation_space.shape,
    env.action_space.sample(),
    env.action_space.n,
)

(array([-4.3404794e+00,  2.0558559e+38, -2.7077520e-01,  1.4372895e+38],
       dtype=float32),
 (4,),
 0,
 2)

In [8]:
def play(env_name, policy_func=None, max_steps=1000):
    """Run one rendered episode of ``env_name`` and print how it ended.

    Args:
        env_name: Environment id passed to ``gym.make`` (with ``render_mode="human"``).
        policy_func: Optional callable mapping an observation to an action.
            When ``None``, actions are sampled from ``env.action_space``.
        max_steps: Maximum number of steps to play before giving up.

    Returns:
        None. The episode length and total reward are printed.
    """
    env = gym.make(env_name, render_mode="human")
    # try/finally so the render window is closed even if the policy raises
    # or the run is interrupted (e.g. KeyboardInterrupt in a notebook).
    try:
        obs, _ = env.reset()
        total_rewards = 0

        for step in range(max_steps):
            if policy_func is not None:
                action = policy_func(obs)
            else:
                action = env.action_space.sample()
            obs, reward, terminated, truncated, _ = env.step(action)
            total_rewards += reward
            if terminated or truncated:
                print(f"The game is over with {step+1} steps and {total_rewards} epsiode reward")
                break
        else:
            # for/else: only reached when the loop ran out without a `break`.
            print(f"The step limit has been reached at {max_steps} steps")
    finally:
        env.close()

In [10]:
# Baseline: one rendered episode driven by randomly sampled actions.
play(env_name=env_name)

The game is over with 16 steps and 16.0 epsiode reward


### step1. build the model

In [13]:
# policy network: observation vector => action logits
class PolicyNet(nn.Module):
    """Two-layer MLP mapping an observation vector to unnormalized action logits.

    Args:
        input_size: Length of the flat observation vector.
        n_actions: Number of discrete actions (width of the output layer).
        hidden_size: Width of the hidden layer (default 128, the original value).
    """

    def __init__(self, input_size, n_actions, hidden_size=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x):
        """Return raw logits of shape ``(..., n_actions)``.

        No softmax is applied here; callers apply ``softmax``/``log_softmax``
        themselves.
        """
        return self.net(x)

In [28]:
# REINFORCE way to calculate Q values for each step in an episode
def cal_qvals(rewards, gamma=1.0):
    """Compute the discounted return-to-go for every step of one episode.

    REINFORCE uses the Monte-Carlo return as its Q estimate:
    ``Q_t = r_t + gamma * r_{t+1} + gamma**2 * r_{t+2} + ... + gamma**(T-t) * r_T``.

    Args:
        rewards: Per-step rewards ``[r_1, ..., r_T]`` of a single episode, in order.
        gamma: Discount factor.

    Returns:
        A list the same length as ``rewards`` whose t-th entry is the discounted
        sum of rewards from step t to the end of the episode, i.e.
        ``[r_1 + gamma*r_2 + gamma^2*r_3 + ..., ..., r_{T-1} + gamma*r_T, r_T]``.
        An empty ``rewards`` yields ``[]``.
    """
    res = []
    running = 0.0
    # Walk backwards so each return reuses the next one: G_t = r_t + gamma * G_{t+1}.
    for r in reversed(rewards):
        running = r + gamma * running
        res.append(running)
    res.reverse()
    return res

In [91]:
class REINFORCEAgent:
    """Agent class using the REINFORCE algorithm.

    Wraps a policy network that maps observations to action logits and
    provides a sampling policy (``policy_func``) and the REINFORCE loss
    (``loss_func``).

    Args:
        policy_net: ``nn.Module`` returning unnormalized action logits of
            shape ``(batch_size, n_actions)`` for a ``(batch_size, obs_dim)`` input.
        device: Torch device to run on. ``None`` (the default) selects
            ``"cuda:0"`` when CUDA is available and ``"cpu"`` otherwise.
    """

    def __init__(self, policy_net, device=None):
        if device is None:
            # A hard-coded "cuda:0" default would crash on CPU-only machines.
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.policy_net = policy_net.to(device)
        self.device = device

    def _to_obs_tensor(self, obs):
        """Convert observation(s) to a float32 tensor on ``self.device``."""
        # Stack via numpy first: torch.tensor(list_of_ndarrays) is very slow.
        return torch.as_tensor(np.asarray(obs, dtype=np.float32), device=self.device)

    def policy_func(self, obs):
        """Sample action(s) from the current policy without tracking gradients.

        Args:
            obs: A single observation of shape ``(obs_dim,)`` or a batch of
                shape ``(batch_size, obs_dim)``; must expose ``.shape``.

        Returns:
            A Python ``int`` for a single observation, or an int64
            ``np.ndarray`` of shape ``(batch_size,)`` for a batch.
        """
        batch_mode = len(obs.shape) > 1  # there's a batch dim

        with torch.no_grad():
            obs_t = self._to_obs_tensor(obs)
            if not batch_mode:
                obs_t = obs_t.unsqueeze(0)  # add a batch dim: (1, obs_dim)
            action_logits_t = self.policy_net(obs_t)
            action_probs_t = F.softmax(action_logits_t, dim=1)  # (batch_size, n_actions)
            actions_t = Categorical(action_probs_t).sample()  # (batch_size,)

        if batch_mode:
            return actions_t.long().cpu().numpy()
        return actions_t.long().cpu().item()

    def loss_func(self, batch_obs, batch_acts, batch_qvals):
        """REINFORCE loss: mean over steps of ``-Q(s_t, a_t) * log pi(a_t | s_t)``.

        ``Q(s_t, a_t)`` is the discounted return-to-go
        ``sum_{i>=t} gamma^(i-t) * r_i`` (in this notebook, computed by ``cal_qvals``).

        Args:
            batch_obs: Sequence of observations, each of shape ``(obs_dim,)``.
            batch_acts: Sequence of the integer actions taken, same length.
            batch_qvals: Sequence of Q estimates (floats), same length.

        Returns:
            A scalar tensor suitable for ``.backward()``.
        """
        obs_t = self._to_obs_tensor(batch_obs)  # (batch_size, obs_dim)
        acts_t = torch.as_tensor(batch_acts, dtype=torch.int64, device=self.device).unsqueeze(1)  # (batch_size, 1)
        qvals_t = torch.as_tensor(batch_qvals, dtype=torch.float32, device=self.device)  # (batch_size,)
        # log-probabilities of every action: (batch_size, n_actions)
        log_pdf_t = F.log_softmax(self.policy_net(obs_t), dim=1)
        # pick log pi(a_t | s_t) of the action actually executed at each step
        log_probs_t = log_pdf_t.gather(1, acts_t).squeeze(1)  # (batch_size,)
        return (-qvals_t * log_probs_t).mean()

    def save(self, save_path):
        """Save the policy network's ``state_dict`` to ``save_path``.

        The parent directory is created first if it does not exist.
        """
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(self.policy_net.state_dict(), save_path)

    def __str__(self):
        return str(self.policy_net)

In [93]:
# Size the policy net from the env's observation/action spaces and wrap it in the agent.
net = PolicyNet(
    input_size=env.observation_space.shape[0],
    n_actions=env.action_space.n,
)
# Choose the device explicitly so this cell also runs on machines without CUDA.
agent = REINFORCEAgent(net, device="cuda:0" if torch.cuda.is_available() else "cpu")
print(agent)

PolicyNet(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
)


### step2. build the dataset generator

In [84]:
def iter_exp(env, policy_func=None):
    """Yield one experience tuple per environment step, indefinitely.

    Each item is a 5-tuple ``(obs, action, reward, next_obs, done)`` where
    ``done`` is ``terminated or truncated``. It marks the last step of an
    episode, and the environment is reset before the following step.

    Args:
        env: Env whose ``reset()`` returns ``(obs, info)`` and whose ``step()``
            returns ``(obs, reward, terminated, truncated, info)``.
        policy_func: Optional callable mapping an observation to an action;
            when ``None``, actions are sampled from ``env.action_space``.
    """
    done = True

    while True:
        # start a new episode after the previous one ended
        if done:
            obs, _ = env.reset()
        if policy_func is not None:
            action = policy_func(obs)
        else:
            action = env.action_space.sample()
        next_obs, reward, terminated, truncated, _ = env.step(action)
        # Truncation (time limit) must end the episode too; otherwise the env
        # keeps being stepped past its limit and episodes are never separated.
        done = terminated or truncated
        yield obs, action, reward, next_obs, done
        obs = next_obs

In [101]:
def collect_exp_batch(env, policy_func=None, batch_episodes=16, gamma=1.0):
    """Play ``batch_episodes`` complete episodes and gather one training batch.

    Returns:
        A 4-tuple ``(batch_obs, batch_acts, batch_qvals, episode_rewards)``.
        The first three are flat per-step lists concatenated over all episodes
        (``batch_qvals`` holds each step's discounted return from ``cal_qvals``).
        ``episode_rewards`` holds one undiscounted total reward per episode.
    """
    batch_obs, batch_acts, batch_qvals = [], [], []
    episode_rewards = []
    step_rewards = []  # rewards of the episode currently being played
    experiences = iter_exp(env, policy_func)

    while len(episode_rewards) < batch_episodes:
        obs, act, reward, _next_obs, episode_over = next(experiences)
        batch_obs.append(obs)
        batch_acts.append(act)
        step_rewards.append(reward)
        if not episode_over:
            continue
        # episode finished: convert its rewards into per-step Q values
        batch_qvals.extend(cal_qvals(step_rewards, gamma=gamma))
        episode_rewards.append(float(np.sum(step_rewards)))
        step_rewards = []

    return batch_obs, batch_acts, batch_qvals, episode_rewards

### step3. train

In [102]:
def train(env_name, agent, writer,
          max_epochs=500, batch_episodes=4, lr=0.01, gamma=0.99, reward_bound=195, recent=50):
    """Train ``agent`` with REINFORCE until the running mean reward exceeds ``reward_bound``.

    Each epoch:
      1. collects ``batch_episodes`` episodes with the current policy;
      2. takes one Adam step on the REINFORCE loss;
      3. logs every episode reward and the mean of the last ``recent`` episode
         rewards to ``writer``;
      4. prints a one-line summary.

    Args:
        env_name: Environment id used with ``gym.make`` for the training env.
        agent: Object exposing ``policy_net``, ``policy_func`` and ``loss_func``.
        writer: TensorBoard writer; it is closed when training ends.
        max_epochs: Maximum number of optimization steps.
        batch_episodes: Episodes collected per optimization step.
        lr: Adam learning rate.
        gamma: Discount factor for the returns.
        reward_bound: Stop once the running mean reward is strictly greater than this.
            NOTE(review): 195 is the classic CartPole-v0 threshold; confirm it is
            intended for CartPole-v1.
        recent: Window size (in episodes) of the running mean reward.
    """
    env = gym.make(env_name)
    # try/finally: release the env and flush the writer even if training raises
    # or is interrupted (e.g. KeyboardInterrupt in a notebook).
    try:
        optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=lr)
        episode_rewards, mean_reward = [], 0.0
        for epoch in range(max_epochs):
            # batch data (obs, act, qvals) plus per-episode rewards for logging
            batch_obs, batch_acts, batch_qvals, ers = collect_exp_batch(
                env, agent.policy_func, batch_episodes, gamma
            )
            # one optimization step on this batch
            optimizer.zero_grad()
            loss = agent.loss_func(batch_obs, batch_acts, batch_qvals)
            loss.backward()
            optimizer.step()
            # log each episode reward and the running mean of the last `recent` episodes
            for er in ers:
                episode_rewards.append(er)
                episode_idx = len(episode_rewards)
                mean_reward = float(np.mean(episode_rewards[-recent:]))
                writer.add_scalar("episode reward", er, episode_idx)
                writer.add_scalar("mean reward", mean_reward, episode_idx)
            print(f"epoch {epoch} => loss: {loss.item()} | mean reward: {mean_reward}")
            if mean_reward > reward_bound:  # good enough
                break
    finally:
        env.close()
        writer.close()

In [87]:
# TensorBoard writer handed to train(), which logs per-episode rewards to it and closes it at the end.
writer = SummaryWriter(comment="-PG_REINFORCE_CartPole")

In [103]:
# Train with the default hyper-parameters; train() also closes `writer` when done.
train(env_name, agent, writer=writer)

epoch 0 => loss: 8.235739707946777 | mean reward: 41.0
epoch 1 => loss: 7.017628192901611 | mean reward: 36.625
epoch 2 => loss: 6.653279781341553 | mean reward: 34.666666666666664
epoch 3 => loss: 5.73574161529541 | mean reward: 32.75
epoch 4 => loss: 5.891140937805176 | mean reward: 31.75
epoch 5 => loss: 6.923009872436523 | mean reward: 31.208333333333332
epoch 6 => loss: 6.854440212249756 | mean reward: 30.607142857142858
epoch 7 => loss: 5.16597843170166 | mean reward: 29.90625
epoch 8 => loss: 4.680619716644287 | mean reward: 29.083333333333332
epoch 9 => loss: 5.615002632141113 | mean reward: 28.925
epoch 10 => loss: 5.906534671783447 | mean reward: 28.886363636363637
epoch 11 => loss: 5.32065486907959 | mean reward: 28.5625
epoch 12 => loss: 5.779853820800781 | mean reward: 28.24
epoch 13 => loss: 6.911384105682373 | mean reward: 27.94
epoch 14 => loss: 5.245478630065918 | mean reward: 27.32
epoch 15 => loss: 8.976016998291016 | mean reward: 28.24
epoch 16 => loss: 8.3384342193

### step4. test

In [104]:
# Watch one rendered episode driven by the trained policy.
play(env_name, policy_func=agent.policy_func)

The game is over with 454 steps and 454.0 epsiode reward


In [105]:
# save the trained policy network's weights; make sure the checkpoint dir exists first
save_path = "./ckpt/REINFORCE-CartPolev1-r213.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
agent.save(save_path)