In [1]:
! pip install tensorboardX

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/af/0c/4f41bcd45db376e6fe5c619c01100e9b7531c55791b7244815bac6eac32c/tensorboardX-2.1-py2.py3-none-any.whl (308kB)
[K     |█                               | 10kB 7.6MB/s eta 0:00:01[K     |██▏                             | 20kB 3.8MB/s eta 0:00:01[K     |███▏                            | 30kB 5.0MB/s eta 0:00:01[K     |████▎                           | 40kB 4.4MB/s eta 0:00:01[K     |█████▎                          | 51kB 5.4MB/s eta 0:00:01[K     |██████▍                         | 61kB 6.3MB/s eta 0:00:01[K     |███████▍                        | 71kB 6.6MB/s eta 0:00:01[K     |████████▌                       | 81kB 7.1MB/s eta 0:00:01[K     |█████████▌                      | 92kB 7.4MB/s eta 0:00:01[K     |██████████▋                     | 102kB 7.8MB/s eta 0:00:01[K     |███████████▊                    | 112kB 7.8MB/s eta 0:00:01[K     |████████████▊                   | 122kB 7.8

In [2]:
import random
import gym
import gym.spaces
from collections import namedtuple
import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
HIDDEN_SIZE = 128   # width of Net's single hidden layer
BATCH_SIZE = 100    # finished episodes per batch yielded by iterate_batches
PERCENTILE = 30     # elite cutoff: filter_batch keeps episodes scoring strictly above this percentile
GAMMA = 0.90        # per-step discount in filter_batch's score: reward * GAMMA ** len(steps)

In [4]:
class DiscreteOneHotWrapper(gym.ObservationWrapper):
    """Observation wrapper that one-hot encodes a Discrete observation space.

    The wrapped env must expose a ``gym.spaces.Discrete`` observation space
    with ``n`` states. Each integer observation ``k`` becomes a float32 vector
    of length ``n`` with 1.0 at index ``k`` and 0.0 everywhere else.
    """

    def __init__(self, env):
        super().__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Discrete)
        n_states = env.observation_space.n
        self.observation_space = gym.spaces.Box(
            0.0, 1.0, (n_states, ), dtype=np.float32)

    def observation(self, observation):
        # Start from a fresh copy of the all-zeros lower bound so the returned
        # vector never aliases the space's own ``low`` array.
        one_hot = self.observation_space.low.copy()
        one_hot[observation] = 1.0
        return one_hot

In [5]:
class Net(nn.Module):
    """Two-layer MLP mapping an observation vector to raw action scores.

    Architecture: Linear(obs_size, hidden_size) -> ReLU -> Linear(hidden_size, n_actions).
    No softmax is applied here; the output is left unnormalised so it can be
    fed to ``nn.CrossEntropyLoss`` or passed through a softmax by the caller.
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.net(x)
        return scores

# One finished episode: the summed reward and the list of EpisodeStep records.
Episode = namedtuple('Episode', ['reward', 'steps'])
# One (observation, action) pair recorded while playing an episode.
EpisodeStep = namedtuple('EpisodeStep', 'observation action')

In [6]:
def iterate_batches(env, net, batch_size):
    """Play episodes with the current policy and yield them in batches, forever.

    Actions are sampled from ``softmax(net(obs))``. Each finished episode is
    stored as ``Episode(reward=<sum of step rewards>, steps=[EpisodeStep, ...])``.
    Once ``batch_size`` episodes have accumulated, the list is yielded and a new
    batch is started. The generator never stops on its own; the caller breaks
    out of the loop.

    Args:
        env: environment with the classic gym API (``reset() -> obs``,
            ``step(a) -> (obs, reward, done, info)``).
        net: module mapping a ``(1, obs_size)`` float tensor to a
            ``(1, n_actions)`` tensor of unnormalised action scores.
        batch_size: number of episodes per yielded batch; must be >= 1.

    Yields:
        list[Episode]: exactly ``batch_size`` finished episodes.

    Raises:
        ValueError: on the first ``next()`` if ``batch_size`` < 1. Otherwise
            the generator would loop forever without ever yielding.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs = env.reset()
    sm = nn.Softmax(dim=1)
    while True:
        # Forward pass is for sampling only: skip building an autograd graph
        # on every environment step.
        with torch.no_grad():
            obs_v = torch.as_tensor(np.asarray(obs, dtype=np.float32)).unsqueeze(0)
            act_probs = sm(net(obs_v)).numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, is_done, _ = env.step(action)
        episode_reward += reward
        episode_steps.append(EpisodeStep(observation=obs, action=action))
        if is_done:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            next_obs = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs

In [7]:
def filter_batch(batch, percentile, gamma=None):
    """Select the elite episodes of a batch by discounted reward.

    Each episode is scored as ``reward * gamma ** len(steps)``, so among
    episodes with equal reward the shorter ones score higher. Episodes whose
    score is strictly greater than the ``percentile``-th percentile of all
    scores are kept.

    Args:
        batch: sequence of Episode-like objects (``.reward``, ``.steps``; each
            step has ``.observation`` and ``.action``).
        percentile: percentile in [0, 100] passed to ``np.percentile``.
        gamma: per-step discount factor. Defaults to the module-level
            ``GAMMA``, looked up at call time (same as before this parameter
            existed).

    Returns:
        tuple: ``(elite_batch, train_obs, train_act, reward_bound)``.
        ``train_obs``/``train_act`` are the observations/actions of all elite
        episodes, flattened in episode order; ``reward_bound`` is the cutoff.
    """
    if gamma is None:
        gamma = GAMMA
    disc_rewards = [episode.reward * (gamma ** len(episode.steps))
                    for episode in batch]
    reward_bound = np.percentile(disc_rewards, percentile)

    train_obs = []
    train_act = []
    elite_batch = []
    for example, discounted_reward in zip(batch, disc_rewards):
        # Strict '>': scores tied with the bound (e.g. many zero-reward
        # episodes) are all excluded.
        if discounted_reward > reward_bound:
            train_obs.extend(step.observation for step in example.steps)
            train_act.extend(step.action for step in example.steps)
            elite_batch.append(example)

    return elite_batch, train_obs, train_act, reward_bound

In [8]:
# Seed every RNG this run draws from. Python's `random` alone (all the
# original seeded) is not used anywhere visible in this notebook. Actions come
# from NumPy's global RNG (np.random.choice), weights from torch's RNG, and the
# environment keeps its own RNG, seeded via env.seed where the gym version
# provides it.
SEED = 12345
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
env = DiscreteOneHotWrapper(gym.make("FrozenLake-v0"))
if hasattr(env, "seed"):
    env.seed(SEED)
# env = gym.wrappers.Monitor(env, directory="mon", force=True)
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n

net = Net(obs_size, HIDDEN_SIZE, n_actions)
# CrossEntropyLoss takes Net's raw (unnormalised) outputs plus integer action labels.
objective = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.001)
writer = SummaryWriter(comment="-frozenlake-tweaked")

In [10]:
# Cross-entropy training loop:
# - keep a rolling pool of elite episodes, capped at MAX_ELITE_EPISODES;
# - fit the policy to the elites' actions;
# - stop once a fresh batch's mean episode reward exceeds SOLVED_REWARD_MEAN.
MAX_ELITE_EPISODES = 500
SOLVED_REWARD_MEAN = 0.8

full_batch = []
try:
    for iter_no, batch in enumerate(iterate_batches(
            env, net, BATCH_SIZE)):
        reward_mean = float(np.mean([episode.reward for episode in batch]))
        # Re-filter the elites carried over from earlier iterations together
        # with the fresh batch.
        full_batch, obs, acts, reward_bound = \
            filter_batch(full_batch + batch, PERCENTILE)
        if not full_batch:
            continue
        # Stack into a single ndarray first: building a tensor from a Python
        # list of ndarrays is very slow in torch.
        obs_v = torch.as_tensor(np.asarray(obs, dtype=np.float32))
        acts_v = torch.as_tensor(np.asarray(acts, dtype=np.int64))
        full_batch = full_batch[-MAX_ELITE_EPISODES:]

        optimizer.zero_grad()
        action_scores_v = net(obs_v)
        loss_v = objective(action_scores_v, acts_v)
        loss_v.backward()
        optimizer.step()
        print("%d: loss=%.3f, rw_mean=%.3f, "
              "rw_bound=%.3f, batch=%d" % (
            iter_no, loss_v.item(), reward_mean,
            reward_bound, len(full_batch)))
        writer.add_scalar("loss", loss_v.item(), iter_no)
        writer.add_scalar("reward_mean", reward_mean, iter_no)
        writer.add_scalar("reward_bound", reward_bound, iter_no)
        if reward_mean > SOLVED_REWARD_MEAN:
            print("Solved!")
            break
finally:
    # Flush and close the TensorBoard event file even when training is
    # interrupted (e.g. KeyboardInterrupt), not only when it is solved.
    writer.close()

0: loss=0.946, rw_mean=0.050, rw_bound=0.000, batch=5
1: loss=0.964, rw_mean=0.080, rw_bound=0.000, batch=13
2: loss=0.995, rw_mean=0.090, rw_bound=0.000, batch=22
3: loss=0.982, rw_mean=0.040, rw_bound=0.000, batch=26
4: loss=0.969, rw_mean=0.040, rw_bound=0.000, batch=30
5: loss=0.956, rw_mean=0.070, rw_bound=0.000, batch=37
6: loss=0.953, rw_mean=0.040, rw_bound=0.000, batch=41
7: loss=0.968, rw_mean=0.050, rw_bound=0.000, batch=46
8: loss=0.980, rw_mean=0.090, rw_bound=0.000, batch=55
9: loss=0.985, rw_mean=0.070, rw_bound=0.000, batch=62
10: loss=0.993, rw_mean=0.110, rw_bound=0.000, batch=73
11: loss=0.999, rw_mean=0.060, rw_bound=0.000, batch=79
12: loss=0.993, rw_mean=0.090, rw_bound=0.000, batch=88
13: loss=0.982, rw_mean=0.090, rw_bound=0.000, batch=97
14: loss=0.982, rw_mean=0.090, rw_bound=0.000, batch=106
15: loss=0.973, rw_mean=0.120, rw_bound=0.000, batch=118
16: loss=0.970, rw_mean=0.070, rw_bound=0.000, batch=125
17: loss=0.973, rw_mean=0.090, rw_bound=0.000, batch=134

KeyboardInterrupt: ignored