In [10]:
import os
from collections import OrderedDict

import gym
import numpy as np
import torch
import torch.nn as nn
from gym.spaces import Discrete, Box
from torch.distributions.categorical import Categorical
from torch.optim import Adam

In [19]:
# Hyperparameters
# -- policy network / optimisation
hidden_size = 32        # passed to Actor(obs_dim, hidden_size, n_acts); presumably the hidden-layer width
lr = 1e-2               # Adam learning rate
epochs = 2              # number of training epochs (one gradient step per epoch)
batch_size = 5000       # experience collection stops once more than this many env steps are stored

# -- runtime
render = True           # render the first episode collected in each epoch
use_cuda = False        # run on the GPU when one is available

# -- checkpointing
output_path_model = './models'   # directory the checkpoint file is written to
save_epochs = 10                 # checkpoint when epoch % save_epochs == 0

In [20]:
# Use the GPU only when it is both requested (use_cuda) and actually available.
if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [21]:
# Environment.
# The rest of the notebook reads `observation_space.shape[0]` and
# `action_space.n`, i.e. it assumes a continuous (Box) observation space and a
# discrete action space, so check both up front with a clear message.
env = gym.make('CartPole-v0')
assert isinstance(env.observation_space, Box), \
    "This example only works for envs with continuous state spaces."
assert isinstance(env.action_space, Discrete), \
    "This example only works for envs with discrete action spaces."

obs_dim = env.observation_space.shape[0]
n_acts = env.action_space.n

In [23]:
# Policy network.
# NOTE(review): `Actor` is not defined in any cell of this notebook; it
# presumably comes from a deleted cell or an earlier kernel session, so a
# Restart & Run All will fail here. TODO: add its definition or import.
actor = Actor(obs_dim, hidden_size, n_acts)
actor = actor.to(device)

# Optimizer over the parameters of `actor.logits_net`.
optimizer = Adam(params=actor.logits_net.parameters(), lr=lr)

In [24]:
# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(obs, act, weights):
    logp = actor._log_prob_from_distribution(obs,act)
    return -(logp * weights).mean()

In [25]:
# One epoch of training: collect a batch of on-policy experience, then take a
# single policy-gradient step (and checkpoint every `save_epochs` epochs).

def _collect_batch():
    """Roll out the current policy until more than ``batch_size`` steps are stored.

    Uses the notebook-level ``env``, ``actor``, ``device``, ``render`` and
    ``batch_size``. Only the first episode of the batch is rendered.

    Returns:
        (batch_obs, batch_acts, batch_weights, batch_rets, batch_lens) as lists.
        ``batch_weights[t]`` is the total return R(tau) of the episode that
        step ``t`` belongs to.
    """
    batch_obs = []          # observations
    batch_acts = []         # actions taken
    batch_weights = []      # R(tau) weighting for each logprob(a|s)
    batch_rets = []         # episode returns
    batch_lens = []         # episode lengths

    obs = env.reset()       # first obs comes from the starting distribution
    ep_rews = []            # rewards accrued during the current episode
    finished_rendering_this_epoch = False

    while True:
        if render and not finished_rendering_this_epoch:
            env.render()

        batch_obs.append(obs.copy())

        obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=device)
        act = actor.step(obs_tensor)
        obs, rew, done, _ = env.step(act)

        batch_acts.append(act.item())
        ep_rews.append(rew)

        if done:
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # every step of the episode is weighted by the episode's return
            batch_weights += [ep_ret] * ep_len

            obs, ep_rews = env.reset(), []

            # Close the viewer exactly once, right after the rendered episode
            # ends. Calling env.close() on every step tears the window down and
            # re-creates it each frame.
            if render and not finished_rendering_this_epoch:
                env.close()
            finished_rendering_this_epoch = True

            # stop once we have enough experience
            if len(batch_obs) > batch_size:
                break

    return batch_obs, batch_acts, batch_weights, batch_rets, batch_lens


def _policy_update(batch_obs, batch_acts, batch_weights):
    """Take one optimizer step on the surrogate loss and return the loss tensor."""
    optimizer.zero_grad()
    # Stack into a single ndarray first: torch.as_tensor on a list of ndarrays
    # is very slow (and warns in recent PyTorch versions).
    obs_t = torch.as_tensor(np.asarray(batch_obs), dtype=torch.float32, device=device)
    # Discrete actions are category indices, so keep them integral.
    acts_t = torch.as_tensor(batch_acts, dtype=torch.int64, device=device)
    weights_t = torch.as_tensor(batch_weights, dtype=torch.float32, device=device)

    loss = compute_loss(obs=obs_t, act=acts_t, weights=weights_t)
    loss.backward()
    optimizer.step()
    return loss


def _save_checkpoint():
    """Write actor and optimizer state to ``<output_path_model>/simple_pg_model.pth``."""
    # torch.save does not create missing directories.
    os.makedirs(output_path_model, exist_ok=True)
    torch.save(
        {
            "actor": actor.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        os.path.join(output_path_model, "simple_pg_model.pth"))


def train_one_epoch(epoch=None):
    """Run one epoch of simple policy-gradient training.

    Args:
        epoch: index of the current epoch. It is only used to decide whether to
            write a checkpoint (when ``epoch % save_epochs == 0``). If omitted,
            a notebook-level ``epoch`` variable is used when one exists, which
            preserves the old calling convention of reading a global. When
            neither is available, no checkpoint is written.

    Returns:
        (batch_loss, batch_rets, batch_lens): the scalar loss tensor from the
        update, plus the lists of returns and lengths of the collected episodes.
    """
    batch_obs, batch_acts, batch_weights, batch_rets, batch_lens = _collect_batch()
    batch_loss = _policy_update(batch_obs, batch_acts, batch_weights)

    # Save the model every `save_epochs` epochs.
    if epoch is None:
        epoch = globals().get("epoch")
    if epoch is not None and epoch % save_epochs == 0:
        _save_checkpoint()

    return batch_loss, batch_rets, batch_lens

In [18]:
# Training loop: one policy-gradient update per epoch, logging batch averages.
batch_avg_len = []
batch_avg_return = []

# The loop variable is named `epoch` on purpose: train_one_epoch reads a
# notebook-level `epoch` to decide when to write a checkpoint.
for epoch in range(epochs):
    batch_loss, batch_rets, batch_lens = train_one_epoch()
    avg_return = np.mean(batch_rets)
    avg_len = np.mean(batch_lens)
    batch_avg_len.append(avg_len)
    # Average this batch's episode returns. The previous code averaged the
    # history list itself, which is NaN on the first epoch.
    batch_avg_return.append(avg_return)
    print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
            (epoch, batch_loss, avg_return, avg_len))

epoch:   0 	 loss: 17.328 	 return: 19.952 	 ep_len: 19.952
epoch:   1 	 loss: 19.206 	 return: 21.765 	 ep_len: 21.765
