In [12]:
# constants

import torch
import torch.nn as nn
import numpy as np
from collections import namedtuple
import gym
from tensorboard import SummaryWriter

# Number of complete episodes gathered before each training step (batch_size in iterate_batches).
BATCH_SIZE: int = 16
# Reward percentile used by filter_batch: episodes scoring below it are discarded.
PERCENTILE: int = 70
# Width of the single hidden layer of the policy network (Net).
HIDDEN_SIZE: int = 128

In [13]:
class Net(nn.Module):
    """Policy network mapping an observation vector to one raw score per action.

    Architecture: Linear(obs_size -> hidden_size) -> ReLU -> Linear(hidden_size -> output_size).
    No softmax is applied here; callers either apply one themselves or feed the
    raw scores into a loss that expects logits.

    Args:
        obs_size: size of the last dimension of the input tensor.
        hidden_size: width of the hidden layer.
        output_size: number of scores produced (one per action).
    """

    def __init__(self, obs_size, hidden_size, output_size):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        ]
        # Kept as a single Sequential named `net` so parameter names stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised action scores for ``x`` (last dim must equal obs_size)."""
        scores = self.net(x)
        return scores


In [14]:
# One finished episode: its summed reward and the ordered list of EpisodeStep records.
Episode = namedtuple("Episode", "reward steps")
# One environment step: the chosen action and the observation it was chosen from.
EpisodeStep = namedtuple("EpisodeStep", "action observation")

In [15]:
def _reset_obs(env):
    """Reset ``env`` and return only the initial observation.

    Older gym returns the observation directly; gym>=0.26 / gymnasium return
    an ``(obs, info)`` pair, whose ``info`` element is a dict.
    """
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Step ``env`` and return ``(next_obs, reward, is_done)`` for either gym step API."""
    result = env.step(action)
    if len(result) == 5:
        # gym>=0.26 / gymnasium: (obs, reward, terminated, truncated, info)
        next_obs, reward, terminated, truncated, _ = result
        return next_obs, reward, terminated or truncated
    # Older gym: (obs, reward, done, info)
    next_obs, reward, is_done, _ = result
    return next_obs, reward, is_done


def iterate_batches(env, net, batch_size):
    """Play episodes with ``net`` as a stochastic policy and yield them in batches.

    Infinite generator: each yielded value is a list of ``batch_size`` ``Episode``
    tuples. Each episode carries its summed reward and the ``EpisodeStep`` records
    (observation the action was chosen from, and the action). The caller decides
    when to stop iterating.

    Args:
        env: gym-style environment (old 4-tuple or new 5-tuple ``step`` API).
        net: module mapping a (1, obs_size) float tensor to raw action scores.
        batch_size: number of complete episodes per yielded batch.
    """
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs = _reset_obs(env)
    # Net outputs raw scores, so convert them to a probability distribution here.
    sm = nn.Softmax(dim=1)

    while True:
        # Stack to a float32 ndarray first: avoids torch's slow list-of-ndarray path.
        obs_v = torch.as_tensor(np.asarray(obs, dtype=np.float32)).unsqueeze(0)
        # Action sampling needs no gradients; skip building an autograd graph.
        with torch.no_grad():
            act_probs = sm(net(obs_v)).numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, is_done = _step_env(env, action)

        episode_reward += reward
        episode_steps.append(EpisodeStep(observation=obs, action=action))

        if is_done:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            next_obs = _reset_obs(env)
            if len(batch) == batch_size:
                yield batch
                batch = []

        obs = next_obs


In [16]:
def filter_batch(batch, percentile):
    """Keep the best-scoring episodes of ``batch`` and flatten them into training tensors.

    Args:
        batch: non-empty sequence of episodes. Each exposes ``reward`` and ``steps``,
            and every step exposes ``observation`` and ``action`` (the
            ``Episode`` / ``EpisodeStep`` tuples produced by ``iterate_batches``).
        percentile: percentile (0-100) of the batch's episode rewards used as the
            cut-off; episodes with reward strictly below it are dropped.

    Returns:
        ``(train_obs_v, train_act_v, reward_bound, reward_mean)``:
        - a float32 tensor of the kept observations, one row per kept step;
        - an int64 tensor of the matching actions;
        - the reward cut-off;
        - the mean reward over the whole (unfiltered) batch.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("filter_batch needs at least one episode")

    rewards = [episode.reward for episode in batch]
    reward_bound = np.percentile(rewards, percentile)
    reward_mean = float(np.mean(rewards))

    train_obs = []
    train_act = []
    for episode in batch:
        if episode.reward < reward_bound:
            continue
        train_obs.extend(step.observation for step in episode.steps)
        train_act.extend(step.action for step in episode.steps)

    # Convert once, after filtering (previously rebuilt inside the loop on every
    # kept episode). Stacking to one ndarray avoids torch's slow list-of-arrays path.
    train_obs_v = torch.as_tensor(np.asarray(train_obs, dtype=np.float32))
    train_act_v = torch.as_tensor(np.asarray(train_act, dtype=np.int64))
    return train_obs_v, train_act_v, reward_bound, reward_mean

In [17]:
if __name__ == "__main__":
    # Cross-entropy method on CartPole: sample episodes with the current policy,
    # keep the top ones (by PERCENTILE), and fit the policy to their actions.
    env = gym.make("CartPole-v0")
    # env = gym.wrappers.Monitor(env, directory="mon", force=True)
    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    net = Net(obs_size, HIDDEN_SIZE, n_actions)
    # Net returns raw scores, which is what CrossEntropyLoss expects.
    objective = nn.CrossEntropyLoss()
    learning_rate = 0.01
    optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)
    # Training stops once the batch mean reward exceeds this value.
    solved_reward = 199
    writer = SummaryWriter()

    try:
        for iter_no, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
            obs_v, acts_v, reward_b, reward_m = filter_batch(batch, PERCENTILE)
            optimizer.zero_grad()
            action_scores_v = net(obs_v)
            loss_v = objective(action_scores_v, acts_v)
            loss_v.backward()
            optimizer.step()

            loss = loss_v.item()
            print("%d: loss=%.3f, reward_mean=%.1f, reward_bound=%.1f" % (
                iter_no, loss, reward_m, reward_b))
            writer.add_scalar("loss", loss, iter_no)
            writer.add_scalar("reward_bound", reward_b, iter_no)
            writer.add_scalar("reward_mean", reward_m, iter_no)
            if reward_m > solved_reward:
                print("Solved!")
                break
    finally:
        # Release resources even if training raises or the kernel is interrupted.
        writer.close()
        env.close()




WARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.
0: loss=0.695, reward_mean=16.3, reward_bound=17.0
1: loss=0.688, reward_mean=26.5, reward_bound=28.5
2: loss=0.669, reward_mean=32.4, reward_bound=35.5
3: loss=0.660, reward_mean=28.4, reward_bound=32.5
4: loss=0.662, reward_mean=32.9, reward_bound=43.5
5: loss=0.638, reward_mean=36.8, reward_bound=43.5
6: loss=0.621, reward_mean=51.7, reward_bound=58.0
7: loss=0.619, reward_mean=43.2, reward_bound=49.0
8: loss=0.620, reward_mean=53.4, reward_bound=60.0
9: loss=0.589, reward_mean=53.8, reward_bound=62.5
10: loss=0.586, reward_mean=44.1, reward_bound=52.5
11: loss=0.599, reward_mean=52.1, reward_bound=52.0
12: loss=0.594, reward_mean=49.0, reward_bound=54.5
13: loss=0.610, reward_mean=58.0, reward_bound=56.5
14: loss=0.581, reward_mean=72.6, reward_bound=72.5
15: loss=0.592, reward_mean=62.2, reward_bound=73.5
16: loss=0.575, reward_mean=63.9, reward_bound=71.0
17: loss=0.586, re