In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.distributions.normal import Normal
from torch.optim import Adam
import numpy as np
import gym
from gym.spaces import Discrete, Box

In [2]:
def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Build a feedforward network as an ``nn.Sequential``.

    Args:
        sizes: Layer widths, input first and output last. Each consecutive
            pair ``(sizes[k], sizes[k+1])`` becomes one ``nn.Linear``.
        activation: Module class instantiated after every linear layer
            except the last one.
        output_activation: Module class instantiated after the last linear
            layer.

    Returns:
        An ``nn.Sequential`` alternating linear layers and activations
        (empty when ``sizes`` has fewer than two entries).
    """
    final_idx = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        # every layer but the last is followed by the hidden activation
        nonlinearity = output_activation if idx >= final_idx else activation
        modules.append(nn.Linear(n_in, n_out))
        modules.append(nonlinearity())
    return nn.Sequential(*modules)

In [19]:
# Experiment settings, all tunable from this one place.
env_name = 'CartPole-v0'   # id passed to gym.make
hidden_sizes = [32]        # widths of the policy network's hidden layers
lr = 1e-2                  # learning rate for the Adam optimizer
epochs = 500               # number of policy-gradient updates in the training loop
batch_size = 5000          # an epoch stops collecting once it holds more steps than this
render = False             # when True, render the first episode of each epoch

# Build the environment and confirm its spaces are the kinds this example
# supports before reading off the observation / action dimensions.
env = gym.make(env_name)
obs_space, act_space = env.observation_space, env.action_space
assert isinstance(obs_space, Box), \
    "This example only works for envs with continuous state spaces."
assert isinstance(act_space, Discrete), \
    "This example only works for envs with discrete action spaces."

obs_dim = obs_space.shape[0]
n_acts = act_space.n

# Core of the policy: an MLP that maps an observation to one logit per action.
layer_sizes = [obs_dim, *hidden_sizes, n_acts]
logits_net = mlp(sizes=layer_sizes)

# make function to compute action distribution
def get_policy(obs):
    """Return the action distribution for ``obs``.

    The observation tensor is passed through ``logits_net`` and the result
    is used as the logits of a ``Categorical`` distribution.
    """
    return Categorical(logits=logits_net(obs))

# make action selection function (outputs int actions, sampled from policy)
def get_action(obs):
    """Sample one action from the current policy and return it as a Python number.

    Args:
        obs: Observation tensor accepted by ``get_policy``. The sampled action
            must be a single-element tensor, since it is converted with
            ``.item()``.

    Returns:
        The sampled action as a plain Python value.
    """
    # Acting never needs gradients (the sample is reduced to a Python value
    # right away), so skip autograd tracking for the forward pass.
    with torch.no_grad():
        return get_policy(obs).sample().item()

# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(obs, act, weights):
    """Surrogate loss whose gradient, for the right data, is the policy gradient.

    Args:
        obs: Batch of observations passed to ``get_policy``.
        act: Actions whose log-probabilities are evaluated under the policy.
        weights: Per-sample weights multiplied into the log-probabilities.

    Returns:
        The negated mean of ``log_prob(act) * weights``.
    """
    policy = get_policy(obs)
    weighted_logp = policy.log_prob(act) * weights
    return -weighted_logp.mean()

# Adam over the policy network's parameters, using the configured learning rate.
optimizer = Adam(params=logits_net.parameters(), lr=lr)


In [20]:
# for training policy
def train_one_epoch():
    """Collect one batch of on-policy experience and take one gradient step.

    Episodes are rolled out in the module-level ``env`` with actions drawn
    from ``get_action``. Collection stops at the end of the first episode
    after which more than ``batch_size`` observations have been stored, so
    a batch always consists of whole episodes. Every step of an episode is
    weighted by that episode's total return R(tau).

    Relies on the module-level ``env``, ``render``, ``batch_size``,
    ``optimizer``, ``get_action`` and ``compute_loss``.

    Returns:
        tuple: ``(batch_loss, batch_rets, batch_lens, batch_obs)`` where
        ``batch_loss`` is the scalar loss tensor (detached from the autograd
        graph), ``batch_rets`` and ``batch_lens`` are per-episode returns and
        lengths, and ``batch_obs`` is the list of stored observations.
    """
    # per-step buffers covering the whole batch
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    # per-episode statistics
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths

    # reset episode-specific variables
    obs = env.reset()       # first obs comes from starting distribution
    ep_rews = []            # list for rewards accrued throughout ep

    # render first episode of each epoch
    finished_rendering_this_epoch = False

    # collect experience by acting in the environment with current policy
    while True:

        # rendering
        if (not finished_rendering_this_epoch) and render:
            env.render()

        # save obs (copied so later changes to `obs` cannot alias the buffer)
        batch_obs.append(obs.copy())

        # act in the environment
        act = get_action(torch.as_tensor(obs, dtype=torch.float32))
        obs, rew, done, _ = env.step(act)

        # save action, reward
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # if episode is over, record info about episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # the weight for each logprob(a|s) is R(tau)
            batch_weights += [ep_ret] * ep_len

            # reset episode-specific variables
            obs, ep_rews = env.reset(), []

            # won't render again this epoch
            finished_rendering_this_epoch = True

            # end experience loop if we have enough of it
            if len(batch_obs) > batch_size:
                break

    # take a single policy gradient update step
    optimizer.zero_grad()
    batch_loss = compute_loss(
        # Stack into one ndarray first: building a tensor directly from a
        # Python list of ndarrays is a slow path that PyTorch warns about.
        obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
        act=torch.as_tensor(batch_acts, dtype=torch.int32),
        weights=torch.as_tensor(batch_weights, dtype=torch.float32),
    )
    batch_loss.backward()
    optimizer.step()
    # Detach so that callers who keep the loss for logging do not also keep
    # a reference to this epoch's autograd graph.
    return batch_loss.detach(), batch_rets, batch_lens, batch_obs

In [21]:
# Run a single epoch on its own and keep its outputs for inspection.
# Note: train_one_epoch() ends with an optimizer step, so when the cells are
# run in order the training loop below starts from a policy that has already
# been updated once.
batch_loss, batch_rets, batch_lens, batch_obs = train_one_epoch()

In [22]:
# Training loop: one policy-gradient update per epoch. Each epoch's loss,
# episode returns and episode lengths are kept in `logs`, and a one-line
# summary is printed.
logs = []
progress_fmt = 'epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'
for i in range(epochs):
    batch_loss, batch_rets, batch_lens, _ = train_one_epoch()
    logs.append((batch_loss, batch_rets, batch_lens))
    mean_ret, mean_len = np.mean(batch_rets), np.mean(batch_lens)
    print(progress_fmt % (i, batch_loss, mean_ret, mean_len))

epoch:   0 	 loss: 17.527 	 return: 20.008 	 ep_len: 20.008
epoch:   1 	 loss: 19.639 	 return: 22.326 	 ep_len: 22.326
epoch:   2 	 loss: 21.569 	 return: 24.757 	 ep_len: 24.757
epoch:   3 	 loss: 24.799 	 return: 29.512 	 ep_len: 29.512
epoch:   4 	 loss: 28.229 	 return: 33.480 	 ep_len: 33.480
epoch:   5 	 loss: 30.426 	 return: 36.667 	 ep_len: 36.667
epoch:   6 	 loss: 35.599 	 return: 42.613 	 ep_len: 42.613
epoch:   7 	 loss: 30.436 	 return: 38.114 	 ep_len: 38.114
epoch:   8 	 loss: 34.329 	 return: 44.009 	 ep_len: 44.009
epoch:   9 	 loss: 44.061 	 return: 52.811 	 ep_len: 52.811
epoch:  10 	 loss: 42.245 	 return: 55.022 	 ep_len: 55.022
epoch:  11 	 loss: 44.553 	 return: 56.920 	 ep_len: 56.920
epoch:  12 	 loss: 42.712 	 return: 59.318 	 ep_len: 59.318
epoch:  13 	 loss: 49.204 	 return: 67.573 	 ep_len: 67.573
epoch:  14 	 loss: 44.268 	 return: 64.205 	 ep_len: 64.205
epoch:  15 	 loss: 46.768 	 return: 67.770 	 ep_len: 67.770
epoch:  16 	 loss: 48.568 	 return: 70.8

epoch: 133 	 loss: 95.401 	 return: 200.000 	 ep_len: 200.000
epoch: 134 	 loss: 97.923 	 return: 200.000 	 ep_len: 200.000
epoch: 135 	 loss: 96.441 	 return: 200.000 	 ep_len: 200.000
epoch: 136 	 loss: 95.791 	 return: 200.000 	 ep_len: 200.000
epoch: 137 	 loss: 97.099 	 return: 200.000 	 ep_len: 200.000
epoch: 138 	 loss: 96.288 	 return: 200.000 	 ep_len: 200.000
epoch: 139 	 loss: 98.436 	 return: 200.000 	 ep_len: 200.000
epoch: 140 	 loss: 96.416 	 return: 200.000 	 ep_len: 200.000
epoch: 141 	 loss: 96.994 	 return: 200.000 	 ep_len: 200.000
epoch: 142 	 loss: 98.508 	 return: 200.000 	 ep_len: 200.000
epoch: 143 	 loss: 97.955 	 return: 200.000 	 ep_len: 200.000
epoch: 144 	 loss: 97.948 	 return: 200.000 	 ep_len: 200.000
epoch: 145 	 loss: 98.229 	 return: 200.000 	 ep_len: 200.000
epoch: 146 	 loss: 99.813 	 return: 200.000 	 ep_len: 200.000
epoch: 147 	 loss: 98.970 	 return: 200.000 	 ep_len: 200.000
epoch: 148 	 loss: 101.252 	 return: 200.000 	 ep_len: 200.000
epoch: 

epoch: 264 	 loss: 90.768 	 return: 200.000 	 ep_len: 200.000
epoch: 265 	 loss: 90.539 	 return: 200.000 	 ep_len: 200.000
epoch: 266 	 loss: 90.336 	 return: 200.000 	 ep_len: 200.000
epoch: 267 	 loss: 89.556 	 return: 200.000 	 ep_len: 200.000
epoch: 268 	 loss: 90.468 	 return: 200.000 	 ep_len: 200.000
epoch: 269 	 loss: 89.792 	 return: 200.000 	 ep_len: 200.000
epoch: 270 	 loss: 88.002 	 return: 200.000 	 ep_len: 200.000
epoch: 271 	 loss: 87.937 	 return: 200.000 	 ep_len: 200.000
epoch: 272 	 loss: 88.799 	 return: 200.000 	 ep_len: 200.000
epoch: 273 	 loss: 88.190 	 return: 200.000 	 ep_len: 200.000
epoch: 274 	 loss: 88.062 	 return: 200.000 	 ep_len: 200.000
epoch: 275 	 loss: 88.551 	 return: 200.000 	 ep_len: 200.000
epoch: 276 	 loss: 86.425 	 return: 200.000 	 ep_len: 200.000
epoch: 277 	 loss: 89.508 	 return: 200.000 	 ep_len: 200.000
epoch: 278 	 loss: 86.008 	 return: 200.000 	 ep_len: 200.000
epoch: 279 	 loss: 88.257 	 return: 200.000 	 ep_len: 200.000
epoch: 2

epoch: 397 	 loss: 68.654 	 return: 200.000 	 ep_len: 200.000
epoch: 398 	 loss: 68.755 	 return: 200.000 	 ep_len: 200.000
epoch: 399 	 loss: 68.802 	 return: 200.000 	 ep_len: 200.000
epoch: 400 	 loss: 69.787 	 return: 200.000 	 ep_len: 200.000
epoch: 401 	 loss: 69.084 	 return: 200.000 	 ep_len: 200.000
epoch: 402 	 loss: 71.099 	 return: 200.000 	 ep_len: 200.000
epoch: 403 	 loss: 70.223 	 return: 200.000 	 ep_len: 200.000
epoch: 404 	 loss: 70.069 	 return: 200.000 	 ep_len: 200.000
epoch: 405 	 loss: 70.192 	 return: 199.962 	 ep_len: 199.962
epoch: 406 	 loss: 70.917 	 return: 199.000 	 ep_len: 199.000
epoch: 407 	 loss: 68.713 	 return: 199.115 	 ep_len: 199.115
epoch: 408 	 loss: 68.126 	 return: 200.000 	 ep_len: 200.000
epoch: 409 	 loss: 68.849 	 return: 199.769 	 ep_len: 199.769
epoch: 410 	 loss: 69.096 	 return: 200.000 	 ep_len: 200.000
epoch: 411 	 loss: 70.108 	 return: 200.000 	 ep_len: 200.000
epoch: 412 	 loss: 70.508 	 return: 199.962 	 ep_len: 199.962
epoch: 4

In [None]:
# Sanity check on the logs: number of logged epochs, then for the first epoch
# its loss and the sizes of its returns / lengths lists.
first_loss, first_rets, first_lens = logs[0]
print(len(logs), first_loss, len(first_rets), len(first_lens))

In [None]:
# Per-epoch lists of episode returns (the second field of every log entry).
# Reading the field directly avoids transposing all three log columns and
# yields an empty tuple, rather than an IndexError, when `logs` is empty.
returns = tuple(epoch_log[1] for epoch_log in logs)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Average episode return for each epoch.
mean_returns = list(map(np.mean, returns))

In [None]:
# Print the full list of per-epoch mean returns (one value per logged epoch).
print (mean_returns)

In [None]:
# Learning curve: mean episode return for each training epoch. Axis labels
# and a title are set so the figure can be read on its own; the trailing
# semicolon keeps the notebook from echoing the return value of `ax.set`.
fig, ax = plt.subplots()
ax.plot(mean_returns)
ax.set(xlabel='epoch', ylabel='mean episode return',
       title='Policy-gradient learning curve');