In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import gym
from gym.spaces import Discrete, Box
from tqdm import trange
from time import sleep

In [2]:
def get_mlp(sizes, activation=nn.ReLU, output_activation=nn.Identity):
    """Build a fully connected network: Linear -> activation -> ... -> Linear -> output_activation.

    Args:
        sizes: layer widths, input first and output last, e.g. ``[obs_dim, 32, n_acts]``.
        activation: module class instantiated after every hidden Linear layer.
        output_activation: module class instantiated after the final Linear layer.
            Defaults to ``nn.Identity`` so the output is unbounded, which is what
            logits fed to ``Categorical(logits=...)`` require (a ReLU here would
            clamp every negative logit to 0).

    Returns:
        An ``nn.Sequential`` holding the layers in order.
    """
    layers = []
    n_linear = len(sizes) - 1
    for i in range(n_linear):
        # Only the last Linear layer gets the output activation.
        act_cls = activation if i < n_linear - 1 else output_activation
        layers += [nn.Linear(sizes[i], sizes[i + 1]), act_cls()]
    return nn.Sequential(*layers)

In [3]:
def get_action(policy, obs):
    """Sample one discrete action from the categorical distribution given by ``policy(obs)``.

    ``policy`` is called on ``obs`` and its output is treated as logits.
    Returns the sampled action index as a Python number (via ``.item()``).
    """
    logits = policy(obs)
    action_dist = Categorical(logits=logits)
    sampled_action = action_dist.sample()
    return sampled_action.item()

In [4]:
def get_loss(policy, obs, act, weights):
    """Policy-gradient surrogate loss: negative mean of ``log pi(act | obs) * weights``.

    ``policy(obs)`` is treated as a batch of logits. Minimizing the returned scalar
    with gradient descent ascends the weighted log-likelihood of the taken actions.
    """
    action_dist = Categorical(logits=policy(obs))
    weighted_log_probs = action_dist.log_prob(act) * weights
    return -weighted_log_probs.mean()

In [5]:
def get_reward_to_go(rewards, gamma=1.0):
    """Return the (optionally discounted) reward-to-go for every step of one episode.

    Element ``t`` of the result is ``sum_{k >= t} gamma**(k - t) * rewards[k]``.
    With the default ``gamma=1.0`` this is the plain undiscounted suffix sum.

    Args:
        rewards: per-step rewards of a single episode, in time order.
        gamma: discount factor applied to future rewards.

    Returns:
        A list of floats with the same length as ``rewards`` (empty for empty input).
    """
    # One trailing zero slot lets the backward recursion run without a special case.
    rewards_to_go = np.zeros(len(rewards) + 1)
    for t in reversed(range(len(rewards))):
        rewards_to_go[t] = rewards[t] + gamma * rewards_to_go[t + 1]
    return rewards_to_go[:-1].tolist()

In [6]:
def _collect_batch(env, policy, batch_size, render, rewards_to_go):
    """Roll out whole episodes with ``policy`` until more than ``batch_size`` steps are collected.

    Returns ``(batch_obs, batch_acts, batch_weights, batch_rets, batch_lens)``:
    per-step observations, actions and policy-gradient weights, plus the return
    and length of every finished episode. Only the first episode of the batch is
    rendered, and only when ``render`` is true.
    """
    batch_obs, batch_acts, batch_weights = [], [], []
    batch_rets, batch_lens = [], []
    render_this_episode = render

    while True:
        # Uses the pre-0.26 Gym API: reset() -> obs, step() -> (obs, rew, done, info).
        obs = env.reset()
        episode_rew = []
        done = False
        while not done:
            if render_this_episode:
                env.render()
            batch_obs.append(obs.copy())
            # Sampling needs no autograd graph; log-probs are recomputed in the loss.
            with torch.no_grad():
                act = get_action(policy, torch.as_tensor(obs, dtype=torch.float32))
            obs, rew, done, _ = env.step(act)
            batch_acts.append(act)
            episode_rew.append(rew)

        episode_ret, episode_len = sum(episode_rew), len(episode_rew)
        batch_rets.append(episode_ret)
        batch_lens.append(episode_len)
        batch_weights += get_reward_to_go(episode_rew) if rewards_to_go else [episode_ret] * episode_len
        render_this_episode = False

        # Stop only on an episode boundary so every stored step has a complete weight.
        if len(batch_obs) > batch_size:
            return batch_obs, batch_acts, batch_weights, batch_rets, batch_lens


def train(hidden_sizes, env_name, epochs, batch_size, lr=1e-2, activation=nn.ReLU, render=True, rewards_to_go=True):
    """Train a categorical MLP policy on a Gym environment with vanilla policy gradient.

    Args:
        hidden_sizes: widths of the hidden layers (any sequence of ints).
        env_name: environment id passed to ``gym.make``. The environment must have
            a ``Box`` observation space and a ``Discrete`` action space.
        epochs: number of policy updates; each uses one freshly sampled batch.
        batch_size: whole episodes are collected until more than this many steps
            have been gathered.
        lr: Adam learning rate.
        activation: hidden-layer activation class forwarded to ``get_mlp``.
        render: if true, render the first episode of every epoch.
        rewards_to_go: weight each log-prob by the reward-to-go from that step
            (True) or by the whole-episode return (False).

    Returns:
        The trained logits network built by ``get_mlp``.

    Raises:
        AssertionError: if the environment's spaces are not ``Box`` / ``Discrete``.
    """
    env = gym.make(env_name)
    try:
        assert isinstance(env.observation_space, Box), \
            "This example only works for envs with continuous state spaces."
        assert isinstance(env.action_space, Discrete), \
            "This example only works for envs with discrete action spaces."

        obs_dim = env.observation_space.shape[0]
        n_acts = env.action_space.n

        logits_net = get_mlp([obs_dim] + list(hidden_sizes) + [n_acts], activation)
        optimizer = Adam(logits_net.parameters(), lr=lr)

        with trange(epochs, unit="epochs") as pbar:
            for i in pbar:
                pbar.set_description(f"Epoch {i}")

                batch_obs, batch_acts, batch_weights, batch_rets, batch_lens = _collect_batch(
                    env, logits_net, batch_size, render, rewards_to_go)

                optimizer.zero_grad()
                # Stack observations into one ndarray first: building a tensor from a
                # list of ndarrays element by element is very slow.
                batch_loss = get_loss(logits_net,
                                      torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
                                      torch.as_tensor(batch_acts, dtype=torch.int64),
                                      torch.as_tensor(batch_weights, dtype=torch.float32))
                batch_loss.backward()
                optimizer.step()

                pbar.set_postfix(policy_grad=-batch_loss.item(), avg_ep_rew=np.mean(batch_rets), avg_ep_len=np.mean(batch_lens))
    finally:
        # Release the environment (and any render window) even if training fails.
        env.close()
    return logits_net

In [8]:
# MountainCar-v0: one hidden layer of 32 units, 100 epochs of >5000 steps each,
# rendering the first episode of every epoch.
# NOTE(review): the recorded output shows avg_ep_rew=-200 / avg_ep_len=200 after
# 100 epochs, i.e. no visible improvement. Worth confirming whether this env needs
# more exploration than vanilla policy gradient provides.
logits_net = train(
    hidden_sizes=[32],
    env_name='MountainCar-v0',
    epochs=100,
    batch_size=5000,
    render=True,
    rewards_to_go=True,
)

Epoch 99: 100%|██████████| 100/100 [08:48<00:00,  5.29s/epochs, avg_ep_len=200, avg_ep_rew=-200, policy_grad=110]


In [None]:
# Acrobot-v1: one hidden layer of 32 units, 100 epochs of >5000 steps each,
# with reward-to-go weighting and rendering of the first episode per epoch.
logits_net = train(
    hidden_sizes=[32],
    env_name='Acrobot-v1',
    epochs=100,
    batch_size=5000,
    render=True,
    rewards_to_go=True,
)