In [2]:
import random
import gym
import gym.spaces
from collections import namedtuple
import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim as optim


# Number of units in the hidden layer of the policy network (passed to Net).
HIDDEN_SIZE = 128
# Number of complete episodes gathered before iterate_batches yields a batch.
BATCH_SIZE = 100
# Percentile of the discounted episode rewards used as the elite cut-off:
# only episodes scoring strictly above it are kept for training.
PERCENTILE = 30
# Discount applied per step when ranking episodes: an episode's score is
# reward * GAMMA ** len(steps), so for the same positive reward a shorter
# episode scores higher.
GAMMA = 0.9


class DiscreteOneHotWrapper(gym.ObservationWrapper):
    """Observation wrapper that one-hot encodes a discrete state index.

    The wrapped environment must have a ``gym.spaces.Discrete`` observation
    space with ``n`` states. The wrapper replaces it with a float32 ``Box``
    of shape ``(n,)`` bounded by 0.0 and 1.0.
    """

    def __init__(self, env):
        super().__init__(env)
        source_space = env.observation_space
        assert isinstance(source_space, gym.spaces.Discrete)
        n_states = source_space.n
        self.observation_space = gym.spaces.Box(
            0.0, 1.0, (n_states, ), dtype=np.float32)

    def observation(self, observation):
        """Return a vector of the Box's lower bound with 1.0 at ``observation``."""
        # Start from a fresh copy of the lower bound so the space itself is
        # never mutated, then flag the active state.
        one_hot = self.observation_space.low.copy()
        one_hot[observation] = 1.0
        return one_hot


class Net(nn.Module):
    """Fully connected network with one hidden ReLU layer.

    Args:
        obs_size: number of input features.
        hidden_size: number of units in the hidden layer.
        n_actions: number of output values.

    The output is the raw result of the last linear layer; no softmax is
    applied inside the module.
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        scores = self.net(x)
        return scores


# One finished episode: the sum of its step rewards and the list of its steps.
Episode = namedtuple('Episode', field_names=['reward', 'steps'])
# One step of an episode: the observation the action was chosen from, and
# the action that was sampled.
EpisodeStep = namedtuple('EpisodeStep', field_names=['observation', 'action'])


def iterate_batches(env, net, batch_size):
    """Play episodes with ``net`` forever and yield them in fixed-size batches.

    At every step the observation is fed to ``net``, a softmax over dim 1
    turns the output into action probabilities, and an action is sampled
    from them with ``np.random.choice``.

    Args:
        env: environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(obs, reward, done, info)``.
        net: callable mapping a ``(1, obs_size)`` float32 tensor to a
            ``(1, n_actions)`` tensor of action scores.
        batch_size: number of completed episodes per yielded batch.

    Yields:
        A list of ``batch_size`` ``Episode`` tuples. The generator never
        terminates on its own; the caller decides when to stop iterating.
    """
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs = env.reset()
    sm = nn.Softmax(dim=1)
    while True:
        # Convert with an explicit dtype so observations of any numeric dtype
        # are accepted (the legacy FloatTensor constructor rejects anything
        # that is not already float32).
        obs_v = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
        # Rollouts are only used to collect data; skip autograd bookkeeping.
        with torch.no_grad():
            act_probs_v = sm(net(obs_v))
        act_probs = act_probs_v.numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, is_done, _ = env.step(action)
        episode_reward += reward
        # Store the observation the action was chosen from, not the result.
        episode_steps.append(EpisodeStep(observation=obs, action=action))
        if is_done:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            next_obs = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs


def filter_batch(batch, percentile):
    """Select the elite episodes of ``batch`` and flatten them for training.

    Each episode is scored as ``reward * GAMMA ** len(steps)``. The cut-off
    is the given percentile of those scores, and only episodes scoring
    strictly above the cut-off are kept.

    Args:
        batch: sequence of episodes exposing ``reward`` and ``steps``.
        percentile: percentile (0-100) passed to ``np.percentile``.

    Returns:
        A tuple ``(elite_batch, train_obs, train_act, reward_bound)``: the
        kept episodes, the observations and actions of all their steps in
        order, and the score cut-off that was applied.
    """
    scores = [ep.reward * (GAMMA ** len(ep.steps)) for ep in batch]
    reward_bound = np.percentile(scores, percentile)

    elite_batch = [
        ep for ep, score in zip(batch, scores) if score > reward_bound
    ]
    # Flatten the surviving episodes step by step, preserving episode order.
    train_obs = [step.observation for ep in elite_batch for step in ep.steps]
    train_act = [step.action for ep in elite_batch for step in ep.steps]

    return elite_batch, train_obs, train_act, reward_bound


if __name__ == "__main__":
    # Actions are sampled with np.random and the network weights come from
    # torch's generator, so seeding `random` alone does not make a run
    # repeatable; seed all three.
    # NOTE(review): the environment's own RNG is not seeded here - confirm
    # whether the installed gym version needs env.seed()/reset(seed=...).
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    env = DiscreteOneHotWrapper(gym.make("FrozenLake-v1"))
    # env = gym.wrappers.Monitor(env, directory="mon", force=True)
    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    net = Net(obs_size, HIDDEN_SIZE, n_actions)
    objective = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=net.parameters(), lr=0.001)
    writer = SummaryWriter(comment="-frozenlake-tweaked")

    # Elite episodes carried over between iterations, so that rare
    # successful episodes keep contributing to later training steps.
    full_batch = []
    try:
        for iter_no, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
            # Mean undiscounted reward of the freshly played episodes only.
            reward_mean = float(np.mean([ep.reward for ep in batch]))
            full_batch, obs, acts, reward_bound = filter_batch(full_batch + batch, PERCENTILE)
            if not full_batch:
                continue
            # Stack the per-step observation arrays into one ndarray first:
            # building a tensor straight from a list of ndarrays is very slow
            # and triggers a PyTorch UserWarning.
            obs_v = torch.as_tensor(np.asarray(obs, dtype=np.float32))
            acts_v = torch.LongTensor(acts)
            # Cap the number of elite episodes carried to the next iteration.
            full_batch = full_batch[-500:]

            optimizer.zero_grad()
            action_scores_v = net(obs_v)
            loss_v = objective(action_scores_v, acts_v)
            loss_v.backward()
            optimizer.step()
            print("%d: loss=%.3f, reward_mean=%.3f, reward_bound=%.3f, batch=%d" % (
                iter_no, loss_v.item(), reward_mean, reward_bound, len(full_batch)))
            writer.add_scalar("loss", loss_v.item(), iter_no)
            writer.add_scalar("reward_mean", reward_mean, iter_no)
            writer.add_scalar("reward_bound", reward_bound, iter_no)
            if reward_mean > 0.8:
                print("Solved!")
                break
    finally:
        # Close the writer even when training is interrupted (e.g. Ctrl-C),
        # so the summary file is closed cleanly.
        writer.close()

  obs_v = torch.FloatTensor(obs)


0: loss=1.379, reward_mean=0.030, reward_bound=0.000, batch=3
1: loss=1.379, reward_mean=0.010, reward_bound=0.000, batch=4
2: loss=1.378, reward_mean=0.010, reward_bound=0.000, batch=5
3: loss=1.385, reward_mean=0.050, reward_bound=0.000, batch=10
4: loss=1.381, reward_mean=0.010, reward_bound=0.000, batch=11
5: loss=1.381, reward_mean=0.030, reward_bound=0.000, batch=14
6: loss=1.378, reward_mean=0.040, reward_bound=0.000, batch=18
7: loss=1.378, reward_mean=0.020, reward_bound=0.000, batch=20
8: loss=1.376, reward_mean=0.000, reward_bound=0.000, batch=20
9: loss=1.375, reward_mean=0.000, reward_bound=0.000, batch=20
10: loss=1.374, reward_mean=0.010, reward_bound=0.000, batch=21
11: loss=1.372, reward_mean=0.020, reward_bound=0.000, batch=23
12: loss=1.370, reward_mean=0.020, reward_bound=0.000, batch=25
13: loss=1.369, reward_mean=0.010, reward_bound=0.000, batch=26
14: loss=1.368, reward_mean=0.000, reward_bound=0.000, batch=26
15: loss=1.367, reward_mean=0.010, reward_bound=0.000

127: loss=1.252, reward_mean=0.020, reward_bound=0.206, batch=230
128: loss=1.242, reward_mean=0.060, reward_bound=0.229, batch=217
129: loss=1.243, reward_mean=0.010, reward_bound=0.000, batch=218
130: loss=1.242, reward_mean=0.050, reward_bound=0.049, batch=222
131: loss=1.241, reward_mean=0.030, reward_bound=0.056, batch=225
132: loss=1.241, reward_mean=0.020, reward_bound=0.037, batch=227
133: loss=1.241, reward_mean=0.050, reward_bound=0.240, batch=229
134: loss=1.237, reward_mean=0.030, reward_bound=0.254, batch=206
135: loss=1.236, reward_mean=0.020, reward_bound=0.000, batch=208
136: loss=1.238, reward_mean=0.030, reward_bound=0.000, batch=211
137: loss=1.236, reward_mean=0.030, reward_bound=0.000, batch=214
138: loss=1.234, reward_mean=0.050, reward_bound=0.000, batch=219
139: loss=1.230, reward_mean=0.060, reward_bound=0.103, batch=223
140: loss=1.229, reward_mean=0.050, reward_bound=0.144, batch=226
141: loss=1.227, reward_mean=0.030, reward_bound=0.158, batch=228
142: loss=

252: loss=1.129, reward_mean=0.070, reward_bound=0.349, batch=213
253: loss=1.128, reward_mean=0.030, reward_bound=0.000, batch=216
254: loss=1.126, reward_mean=0.020, reward_bound=0.000, batch=218
255: loss=1.126, reward_mean=0.060, reward_bound=0.161, batch=222
256: loss=1.127, reward_mean=0.060, reward_bound=0.254, batch=224
257: loss=1.125, reward_mean=0.080, reward_bound=0.282, batch=226
258: loss=1.123, reward_mean=0.050, reward_bound=0.298, batch=228
259: loss=1.122, reward_mean=0.050, reward_bound=0.353, batch=229
260: loss=1.122, reward_mean=0.070, reward_bound=0.364, batch=230
261: loss=1.129, reward_mean=0.070, reward_bound=0.387, batch=209
262: loss=1.126, reward_mean=0.070, reward_bound=0.003, batch=216
263: loss=1.129, reward_mean=0.040, reward_bound=0.000, batch=220
264: loss=1.129, reward_mean=0.040, reward_bound=0.006, batch=224
265: loss=1.126, reward_mean=0.040, reward_bound=0.122, batch=227
266: loss=1.128, reward_mean=0.060, reward_bound=0.182, batch=229
267: loss=

377: loss=1.049, reward_mean=0.100, reward_bound=0.314, batch=224
378: loss=1.051, reward_mean=0.070, reward_bound=0.349, batch=219
379: loss=1.049, reward_mean=0.120, reward_bound=0.364, batch=223
380: loss=1.049, reward_mean=0.070, reward_bound=0.229, batch=225
381: loss=1.053, reward_mean=0.050, reward_bound=0.303, batch=227
382: loss=1.052, reward_mean=0.070, reward_bound=0.302, batch=229
383: loss=1.053, reward_mean=0.050, reward_bound=0.349, batch=229
384: loss=1.055, reward_mean=0.140, reward_bound=0.387, batch=222
385: loss=1.053, reward_mean=0.080, reward_bound=0.283, batch=225
386: loss=1.053, reward_mean=0.080, reward_bound=0.321, batch=227
387: loss=1.054, reward_mean=0.120, reward_bound=0.430, batch=219
388: loss=1.052, reward_mean=0.060, reward_bound=0.239, batch=223
389: loss=1.051, reward_mean=0.070, reward_bound=0.254, batch=223
390: loss=1.052, reward_mean=0.040, reward_bound=0.211, batch=226
391: loss=1.053, reward_mean=0.120, reward_bound=0.282, batch=227
392: loss=

502: loss=0.979, reward_mean=0.080, reward_bound=0.418, batch=231
503: loss=0.979, reward_mean=0.080, reward_bound=0.387, batch=231
504: loss=0.979, reward_mean=0.060, reward_bound=0.387, batch=231
505: loss=0.981, reward_mean=0.100, reward_bound=0.430, batch=172
506: loss=0.974, reward_mean=0.060, reward_bound=0.000, batch=178
507: loss=0.973, reward_mean=0.060, reward_bound=0.000, batch=184
508: loss=0.976, reward_mean=0.060, reward_bound=0.000, batch=190
509: loss=0.981, reward_mean=0.020, reward_bound=0.000, batch=192
510: loss=0.983, reward_mean=0.080, reward_bound=0.000, batch=200
511: loss=0.984, reward_mean=0.090, reward_bound=0.000, batch=209
512: loss=0.984, reward_mean=0.060, reward_bound=0.000, batch=215
513: loss=0.978, reward_mean=0.060, reward_bound=0.057, batch=220
514: loss=0.977, reward_mean=0.100, reward_bound=0.162, batch=224
515: loss=0.974, reward_mean=0.080, reward_bound=0.165, batch=227
516: loss=0.979, reward_mean=0.050, reward_bound=0.167, batch=228
517: loss=

627: loss=0.983, reward_mean=0.070, reward_bound=0.349, batch=226
628: loss=0.982, reward_mean=0.100, reward_bound=0.351, batch=228
629: loss=0.977, reward_mean=0.060, reward_bound=0.387, batch=225
630: loss=0.979, reward_mean=0.120, reward_bound=0.365, batch=227
631: loss=0.980, reward_mean=0.050, reward_bound=0.263, batch=229
632: loss=0.979, reward_mean=0.070, reward_bound=0.405, batch=230
633: loss=0.981, reward_mean=0.090, reward_bound=0.430, batch=200
634: loss=0.977, reward_mean=0.080, reward_bound=0.000, batch=208
635: loss=0.978, reward_mean=0.060, reward_bound=0.000, batch=214
636: loss=0.977, reward_mean=0.090, reward_bound=0.108, batch=220
637: loss=0.978, reward_mean=0.070, reward_bound=0.157, batch=224
638: loss=0.975, reward_mean=0.090, reward_bound=0.226, batch=227
639: loss=0.974, reward_mean=0.070, reward_bound=0.229, batch=226
640: loss=0.974, reward_mean=0.030, reward_bound=0.254, batch=226
641: loss=0.975, reward_mean=0.070, reward_bound=0.268, batch=228
642: loss=

KeyboardInterrupt: 