In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
import gym

In [3]:
class policynet(nn.Module):
    """Small MLP with one hidden layer that maps an observation to raw scores.

    Args:
        s: size of the input (state) vector.
        h: number of units in the hidden layer.
        a: number of outputs (one per action).
    """

    def __init__(self, s, h, a):
        super().__init__()
        # Attribute names are kept as-is so state_dict keys do not change.
        self.hl = nn.Linear(in_features=s, out_features=h)
        self.out = nn.Linear(in_features=h, out_features=a)

    def forward(self, x):
        """Return the output layer's unnormalized activations for ``x``.

        No softmax is applied here; one could be folded into the network
        (e.g. ``F.softmax(..., dim=1)``) if probabilities are wanted directly.
        """
        hidden = self.hl(x)
        activated = F.relu(hidden)
        return self.out(activated)
            
        

In [4]:
# id of the Gym environment to train on
env_name = 'CartPole-v0'

# hyperparameters
lr = 5e-3          # learning rate handed to the Adam optimizer
epochs = 200       # number of calls to train_one_epoch (one policy update each)
batch_size = 1000  # env steps to exceed before an epoch's collection loop stops

In [5]:

# make environment
env = gym.make(env_name)

# get dimensions for the policy network: length of the observation vector
# and number of discrete actions
obs_dim, n_acts = env.observation_space.shape[0], env.action_space.n

# make core of policy network (one hidden layer of 32 units); its outputs are
# consumed as logits by get_policy
logits_net = policynet(s=obs_dim, h=32, a=n_acts)

# this function gets action distribution
def get_policy(obs):
    """Return a Categorical distribution over actions for ``obs``.

    The observation tensor is passed through ``logits_net`` and the result is
    handed to ``Categorical`` as logits (unnormalized log-probabilities).
    """
    action_logits = logits_net(obs)
    policy = Categorical(logits=action_logits)
    return policy

# make action selection function (outputs int actions, sampled from policy)
def get_action(obs):
    """Sample one action from the current policy and return it as a Python scalar.

    ``obs`` is forwarded unchanged to ``get_policy``; the sampled tensor is
    converted with ``.item()``, so it must contain exactly one element.
    """
    action_dist = get_policy(obs)
    sampled = action_dist.sample()
    return sampled.item()

# compute loss: for VPG each log-prob is weighted by the advantage (return minus a value baseline).
# Instead we just use the full episode return, R(tau), since it works well enough
def compute_loss(obs, act, weights):
    """Return the policy-gradient surrogate loss for a batch.

    Args:
        obs: observations passed to ``get_policy``.
        act: actions whose log-probabilities are evaluated under the policy.
        weights: per-sample factors multiplied into each log-probability.

    Returns:
        The negated mean of ``log_prob(act) * weights``; minimizing it
        increases the weighted log-likelihood of the taken actions.
    """
    policy = get_policy(obs)
    log_probs = policy.log_prob(act)
    weighted_log_probs = log_probs * weights
    return -weighted_log_probs.mean()

# adam optimizer
# optimizes every parameter of the policy network with the configured learning rate
optimizer = torch.optim.Adam(params=logits_net.parameters(), lr=lr)

def train_one_epoch():
    """Collect one batch of on-policy experience and take one optimizer step.

    Episodes are run in the global ``env`` with the current policy until more
    than ``batch_size`` observations have been stored. The size check happens
    only when an episode ends, so every stored episode is complete. Each step
    of an episode is weighted by that episode's total return, R(tau).

    Returns:
        (batch_loss, batch_rets, batch_lens): the loss tensor used for the
        update, the list of per-episode returns, and the list of per-episode
        lengths.
    """
    # per-step buffers that feed the policy-gradient update
    batch_obs = []
    batch_acts = []
    batch_weights = []
    # per-episode statistics, used only for logging
    batch_rets = []
    batch_lens = []

    # first obs comes from the starting distribution
    obs = env.reset()

    while True:
        # roll out exactly one episode with the current policy
        ep_rews = []
        done = False
        while not done:
            # store the observation the action is chosen from
            batch_obs.append(obs.copy())

            obs_tensor = torch.as_tensor(obs, dtype=torch.float32)
            act = get_action(obs_tensor)
            obs, rew, done, _ = env.step(act)

            batch_acts.append(act)
            ep_rews.append(rew)

        # episode finished: record its stats
        ep_ret = sum(ep_rews)
        ep_len = len(ep_rews)
        batch_rets.append(ep_ret)
        batch_lens.append(ep_len)

        # the weight for each logprob(a|s) in this episode is R(tau)
        batch_weights.extend([ep_ret] * ep_len)

        # the env is reset after every episode, including the last one
        obs = env.reset()

        # stop collecting once we have enough experience
        if len(batch_obs) > batch_size:
            break

    # pack the buffers into tensors for the loss
    obs_batch = torch.as_tensor(np.array(batch_obs), dtype=torch.float32)
    act_batch = torch.as_tensor(batch_acts, dtype=torch.int32)
    weight_batch = torch.as_tensor(batch_weights, dtype=torch.float32)

    # single gradient step on the collected batch
    optimizer.zero_grad()
    batch_loss = compute_loss(obs=obs_batch, act=act_batch, weights=weight_batch)
    batch_loss.backward()
    optimizer.step()
    return batch_loss, batch_rets, batch_lens

# run the configured number of epochs, logging the loss and the batch-averaged
# episode return / length after each policy update
for epoch in range(epochs):
    batch_loss, batch_rets, batch_lens = train_one_epoch()
    mean_return = np.mean(batch_rets)
    mean_length = np.mean(batch_lens)
    print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'
          % (epoch, batch_loss, mean_return, mean_length))



epoch:   0 	 loss: 15.844 	 return: 19.385 	 ep_len: 19.385
epoch:   1 	 loss: 18.628 	 return: 21.319 	 ep_len: 21.319
epoch:   2 	 loss: 17.797 	 return: 21.723 	 ep_len: 21.723
epoch:   3 	 loss: 18.224 	 return: 21.146 	 ep_len: 21.146
epoch:   4 	 loss: 21.776 	 return: 24.095 	 ep_len: 24.095
epoch:   5 	 loss: 18.238 	 return: 21.312 	 ep_len: 21.312
epoch:   6 	 loss: 23.565 	 return: 25.425 	 ep_len: 25.425
epoch:   7 	 loss: 23.906 	 return: 28.800 	 ep_len: 28.800
epoch:   8 	 loss: 20.368 	 return: 24.732 	 ep_len: 24.732
epoch:   9 	 loss: 26.749 	 return: 29.647 	 ep_len: 29.647
epoch:  10 	 loss: 22.654 	 return: 28.389 	 ep_len: 28.389
epoch:  11 	 loss: 28.063 	 return: 31.844 	 ep_len: 31.844
epoch:  12 	 loss: 28.426 	 return: 30.441 	 ep_len: 30.441
epoch:  13 	 loss: 35.341 	 return: 40.480 	 ep_len: 40.480
epoch:  14 	 loss: 28.355 	 return: 33.258 	 ep_len: 33.258
epoch:  15 	 loss: 36.083 	 return: 40.360 	 ep_len: 40.360
epoch:  16 	 loss: 29.427 	 return: 36.0

epoch: 136 	 loss: 77.735 	 return: 144.429 	 ep_len: 144.429
epoch: 137 	 loss: 95.010 	 return: 191.333 	 ep_len: 191.333
epoch: 138 	 loss: 87.069 	 return: 176.667 	 ep_len: 176.667
epoch: 139 	 loss: 86.498 	 return: 170.167 	 ep_len: 170.167
epoch: 140 	 loss: 95.855 	 return: 194.167 	 ep_len: 194.167
epoch: 141 	 loss: 90.016 	 return: 178.500 	 ep_len: 178.500
epoch: 142 	 loss: 92.532 	 return: 178.000 	 ep_len: 178.000
epoch: 143 	 loss: 93.873 	 return: 193.833 	 ep_len: 193.833
epoch: 144 	 loss: 92.849 	 return: 184.167 	 ep_len: 184.167
epoch: 145 	 loss: 92.632 	 return: 181.833 	 ep_len: 181.833
epoch: 146 	 loss: 90.206 	 return: 185.833 	 ep_len: 185.833
epoch: 147 	 loss: 88.692 	 return: 182.833 	 ep_len: 182.833
epoch: 148 	 loss: 92.416 	 return: 191.000 	 ep_len: 191.000
epoch: 149 	 loss: 93.158 	 return: 191.333 	 ep_len: 191.333
epoch: 150 	 loss: 93.843 	 return: 181.167 	 ep_len: 181.167
epoch: 151 	 loss: 94.061 	 return: 199.000 	 ep_len: 199.000
epoch: 1

In [6]:
# test model

# Evaluate on a dedicated environment. train_one_epoch() reads the global `env`,
# so rebinding `env` to a Monitor wrapper that is closed at the end of this cell
# would break any later re-run of the training cell.
# Monitor saves video output to './'; to skip recording, use
# `test_env = gym.make(env_name)` instead.
test_env = gym.wrappers.Monitor(gym.make(env_name), './', force=True)

# test our agent
try:
    for i_episode in range(1):
        observation = test_env.reset()
        # upper bound on steps; the loop exits early when the env reports done
        for t in range(500):
            test_env.render()
            action = get_action(torch.as_tensor(observation, dtype=torch.float32))
            observation, reward, done, info = test_env.step(action)
            if done:
                print("Episode finished after {} timesteps".format(t+1))
                break
finally:
    # always release the environment (and the recorder), even if
    # render()/step() raises part-way through the episode
    test_env.close()

Episode finished after 200 timesteps
