In [117]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.optim import AdamW
import numpy as np
import random
from collections import deque
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader
import warnings
import gym
from gym.spaces import Discrete, Box
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt

In [118]:
env_name = 'LunarLander-v2'

In [119]:
class Data(IterableDataset):
    def __init__(self, env, policy, steps, gamma):
        self.env = env
        self.policy = policy
        self.steps = steps
        self.gamma = gamma
        self.obs, self.info = env.reset()


    def __iter__(self):
        transitions = []

        for step in range(self.steps):
            with torch.no_grad():
                action = self.policy(torch.as_tensor(self.obs, dtype=torch.float32))
            action = action.multinomial(1).cpu().numpy()
            next_obs, reward, terminate, truncate, info = self.env.step(action.flatten())
            transitions.append((self.obs, action, reward, terminate))
            self.obs = next_obs

        obs_b, action_b, reward_b, done_b = map(np.stack, zip(*transitions))

        running_return = np.zeros(self.env.num_envs, dtype=np.float32)
        return_b = np.zeros_like(reward_b)

        for row in range(self.steps-1,-1,-1):
            running_return = reward_b[row] + (1-done_b[row]) * self.gamma * running_return
            return_b[row] = running_return

        num_samples = self.env.num_envs * self.steps
        obs_b = obs_b.reshape(num_samples, -1)
        action_b = action_b.reshape(num_samples, -1)
        return_b = return_b.reshape(num_samples, -1)

        return_b = (return_b - np.mean(return_b)) / np.std(return_b + 1e-06)

        idx = list(range(num_samples))
        random.shuffle(idx)

        for i in idx:
            yield obs_b[i], action_b[i], return_b[i]

In [120]:
class PolicyNet(nn.Module):
    def __init__(self, input_size, hidden_units, output_size):
        super(PolicyNet, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, output_size),
            nn.Softmax(dim=-1)
        )
        
    def forward(self, x):
        probs = self.model(x)
        return probs

    def __call__(self, x):
        out = self.forward(x)
        return out

In [121]:
def initialize_weights(m):
  if isinstance(m, nn.Conv2d):
      nn.init.kaiming_uniform_(m.weight.data,nonlinearity='relu')
      if m.bias is not None:
          nn.init.constant_(m.bias.data, 0)
  elif isinstance(m, nn.BatchNorm2d):
      nn.init.constant_(m.weight.data, 1)
      nn.init.constant_(m.bias.data, 0)
  elif isinstance(m, nn.Linear):
      nn.init.kaiming_uniform_(m.weight.data)
      nn.init.constant_(m.bias.data, 0)

In [122]:
env = gym.vector.make(env_name, num_envs=5, asynchronous=False)

obs_dim = env.single_observation_space.shape[0]
n_acts = env.single_action_space.n

# make core of policy network
hidden_sizes = 32
logits_net = PolicyNet(obs_dim, hidden_sizes, n_acts)
logits_net.apply(initialize_weights)
# make optimizer
optimizer = AdamW(logits_net.parameters(), lr=0.0003)
gamma = 0.99

In [123]:
# make function to compute action distribution
def get_policy(obs):
    probs = logits_net(obs)
    return Categorical(probs=probs)


# make action selection function (outputs int actions, sampled from policy)
def get_action(obs):
    return get_policy(obs).sample().item()


# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(obs, act, weights):
    probs = logits_net(obs)
    log_probs = torch.log(probs + 1e-6)
    action_log_prob = log_probs.gather(1, act)
    return -(action_log_prob * weights).mean()

In [124]:
data = Data(env, logits_net, 256 , 0.99)
loader = DataLoader(data, batch_size=10)

In [125]:
def train_one_epoch():
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient

    # collect experience by acting in the environment with current policy
    for batch in loader:
        with torch.no_grad():
            batch_obs, batch_acts, batch_weights = batch

        # take a single policy gradient update step
        optimizer.zero_grad()
        batch_loss = compute_loss(obs=batch_obs, act=batch_acts, weights=batch_weights)
        batch_loss.backward()
        optimizer.step()
    return 0

In [126]:
final_scores=[]
scores = []
def run_test(trajectories, policy):
    scores = deque(maxlen=50)
    env2 = gym.make(env_name)
    for trajectory in range(trajectories):
        trajectory_return = 0
        obs, info = env2.reset()
        truncate, terminate = False, False
        while not terminate and not truncate:
            action = get_action(torch.as_tensor(obs, dtype=torch.float32))
            next_obs, reward, terminate, truncate, _ = env2.step(action)
            obs = next_obs
            trajectory_return += reward
        scores.append(trajectory_return)
    final_scores.append(np.mean(scores))
    del env2
    return np.mean(scores)

In [127]:
for epoch in range(1000):
    train_one_epoch()
    score = run_test(5, logits_net)
    if epoch>1 and epoch % 50 == 0:
        print(f'Episode: {epoch}  Average Score: {score}' )
    if score >= 195:
        print(f'Solved! Episode: {epoch} Average Score: {score}')
        break

Episode: 50  Average Score: -108.02601847984656
Episode: 100  Average Score: 5.604392287029273
Episode: 150  Average Score: -43.639536650226596
Episode: 200  Average Score: 127.23838553842316
Episode: 250  Average Score: 95.11367991144292
Solved! Episode: 275 Average Score: 197.3313038308102


In [1]:
fig = plt.figure(figsize=(9, 9))
ax = fig.add_subplot(111)
plt.plot(np.arange(len(final_scores)),final_scores, marker='.')
plt.title('Policy Gradient Neural Net (PGNN) Parallel')
plt.ylabel('Rewards')
plt.xlabel('Episodes')
plt.show()

NameError: name 'plt' is not defined