In [11]:
import gym
from collections import namedtuple
import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim as optim

In [12]:
# Cross-entropy method hyperparameters.
HIDDEN_SIZE = 128  # width of the policy network's hidden ReLU layer
BATCH_SIZE = 100   # finished episodes collected per training iteration
PERCENTILE = 30    # episodes whose discounted reward is at or below this percentile are dropped
GAMMA = 0.9        # discount factor raised to the episode length when ranking episodes

# A finished episode: its undiscounted total reward and the ordered list of Step records.
Episode = namedtuple('Episode', field_names=('reward', 'steps'))
# A single interaction: the observation the agent saw and the action it chose for it.
Step = namedtuple('Step', field_names=('observation', 'action'))

In [13]:
class DiscreteOneHotWrapper(gym.ObservationWrapper):
    """Re-encode a ``Discrete(n)`` observation space as a float32 one-hot ``Box``.

    Each integer observation ``k`` becomes a length-``n`` float32 vector that is
    0.0 everywhere except for a 1.0 at index ``k``.
    """

    def __init__(self, env):
        super().__init__(env)
        # Only discrete observation spaces can be one-hot encoded this way.
        assert isinstance(env.observation_space, gym.spaces.Discrete)
        n_states = env.observation_space.n
        self.observation_space = gym.spaces.Box(0.0, 1.0, (n_states,), dtype=np.float32)

    def observation(self, observation):
        # `low` is the all-zeros vector of the Box; copy it so the space is never mutated.
        one_hot = self.observation_space.low.copy()
        one_hot[observation] = 1.0
        return one_hot

In [18]:
class Net(nn.Module):
    """Two-layer MLP mapping an observation vector to unnormalised action scores.

    Args:
        obs_size: length of the input observation vector.
        hidden_size: width of the hidden ReLU layer.
        n_actions: number of output scores, one per discrete action.
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Returns raw logits; softmax / CrossEntropyLoss are applied by the callers.
        return self.net(x)

In [31]:
def iterate_batches(env, net, batch_size):
    """Endlessly play episodes with ``net`` as a stochastic policy and yield them in batches.

    Actions are sampled from the softmax of the network's scores for the current
    observation.

    Args:
        env: environment whose ``step`` returns ``(obs, reward, done, info)``.
        net: policy network returning one score per action for a ``(1, obs_size)`` input.
        batch_size: number of finished episodes per yielded batch.

    Yields:
        list of ``Episode``: ``batch_size`` episodes, each holding its undiscounted total
        reward and the ``Step(observation, action)`` records taken during it.
    """
    # Run inference wherever the network lives instead of hard-coding CUDA,
    # so CPU-only machines work too. Read lazily, i.e. after the caller moved `net`.
    device = next(net.parameters()).device
    sm = nn.Softmax(dim=1)
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs = env.reset()
    while True:
        # Acting needs no gradients; skipping autograd avoids building a graph every step.
        with torch.no_grad():
            obs_v = torch.as_tensor(np.array([obs]), dtype=torch.float32, device=device)
            act_probs = sm(net(obs_v)).cpu().numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, is_done, _ = env.step(action)
        episode_reward += reward
        episode_steps.append(Step(observation=obs, action=action))
        if is_done:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            next_obs = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs

In [32]:
def filter_batch(batch, percentile):
    """Keep only the episodes whose discounted reward beats the given percentile.

    Each episode's reward is scaled by ``GAMMA ** len(steps)``. With ``GAMMA < 1``,
    shorter successful episodes therefore rank above longer ones.

    Args:
        batch: list of ``Episode`` records.
        percentile: cut-off percentile (0-100) forwarded to ``np.percentile``.

    Returns:
        ``(elite_batch, train_obs, train_act, reward_bound)``:
        - the surviving episodes;
        - the flat list of their observations;
        - the matching flat list of actions;
        - the threshold that was applied.
    """
    disc_rewards = [episode.reward * (GAMMA ** len(episode.steps)) for episode in batch]
    reward_bound = np.percentile(disc_rewards, percentile)

    # Strict ">" means that when the bound is 0, zero-reward episodes are never kept.
    elite_batch = [
        episode for episode, disc in zip(batch, disc_rewards) if disc > reward_bound
    ]
    train_obs = [step.observation for episode in elite_batch for step in episode.steps]
    train_act = [step.action for episode in elite_batch for step in episode.steps]

    return elite_batch, train_obs, train_act, reward_bound

In [None]:
# Cross-entropy training loop. Each iteration:
# - collect a batch of episodes;
# - keep the elite ones (carrying elites over between iterations);
# - fit the policy to imitate their actions.
# Training stops once the mean batch reward exceeds 0.8.

# Fall back to CPU when no GPU is present instead of crashing on a hard-coded CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Upper bound on how many elite episodes are carried over to the next iteration.
MAX_ELITE_EPISODES = 500

env = DiscreteOneHotWrapper(gym.make("FrozenLake-v0"))
env = gym.wrappers.Monitor(env, directory="mon", force=True)

obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n
net = Net(obs_size, HIDDEN_SIZE, n_actions).to(device)
objective = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.001)
writer = SummaryWriter(comment="-frozenlake")

full_batch = []
try:
    for iter_no, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
        reward_mean = float(np.mean([episode.reward for episode in batch]))
        # Elite episodes from earlier iterations compete again with the new batch.
        full_batch, obs, acts, reward_bound = filter_batch(full_batch + batch, PERCENTILE)
        if not full_batch:
            continue
        # Stack into one ndarray first: building a tensor from a list of ndarrays is slow.
        obs_v = torch.as_tensor(np.array(obs), dtype=torch.float32, device=device)
        acts_v = torch.as_tensor(np.array(acts), dtype=torch.long, device=device)
        full_batch = full_batch[-MAX_ELITE_EPISODES:]

        optimizer.zero_grad()
        action_scores_v = net(obs_v)
        loss_v = objective(action_scores_v, acts_v)
        loss_v.backward()
        optimizer.step()
        print("%d: loss=%.3f, reward_mean=%.3f, reward_bound=%.3f, batch=%d" % (
            iter_no, loss_v.item(), reward_mean, reward_bound, len(full_batch)))
        writer.add_scalar("loss", loss_v.item(), iter_no)
        writer.add_scalar("reward_mean", reward_mean, iter_no)
        writer.add_scalar("reward_bound", reward_bound, iter_no)
        if reward_mean > 0.8:
            print("Solved!")
            break
finally:
    # Flush TensorBoard events and close the Monitor-wrapped env, even on interrupt/error.
    writer.close()
    env.close()


1: loss=1.388, reward_mean=0.010, reward_bound=0.000, batch=1
2: loss=1.380, reward_mean=0.000, reward_bound=0.000, batch=1
3: loss=1.373, reward_mean=0.000, reward_bound=0.000, batch=1
4: loss=1.370, reward_mean=0.010, reward_bound=0.000, batch=2
5: loss=1.372, reward_mean=0.010, reward_bound=0.000, batch=3
6: loss=1.366, reward_mean=0.010, reward_bound=0.000, batch=4
7: loss=1.378, reward_mean=0.020, reward_bound=0.000, batch=6
8: loss=1.380, reward_mean=0.010, reward_bound=0.000, batch=7
9: loss=1.379, reward_mean=0.000, reward_bound=0.000, batch=7
10: loss=1.378, reward_mean=0.000, reward_bound=0.000, batch=7
11: loss=1.379, reward_mean=0.010, reward_bound=0.000, batch=8
12: loss=1.384, reward_mean=0.020, reward_bound=0.000, batch=10
13: loss=1.382, reward_mean=0.010, reward_bound=0.000, batch=11
14: loss=1.380, reward_mean=0.010, reward_bound=0.000, batch=12
15: loss=1.381, reward_mean=0.010, reward_bound=0.000, batch=13
16: loss=1.379, reward_mean=0.020, reward_bound=0.000, batch

129: loss=1.338, reward_mean=0.030, reward_bound=0.068, batch=231
130: loss=1.336, reward_mean=0.040, reward_bound=0.072, batch=231
131: loss=1.336, reward_mean=0.000, reward_bound=0.000, batch=231
132: loss=1.336, reward_mean=0.010, reward_bound=0.080, batch=230
133: loss=1.336, reward_mean=0.050, reward_bound=0.096, batch=231
134: loss=1.334, reward_mean=0.020, reward_bound=0.098, batch=226
135: loss=1.333, reward_mean=0.010, reward_bound=0.000, batch=227
136: loss=1.332, reward_mean=0.030, reward_bound=0.105, batch=229
137: loss=1.330, reward_mean=0.050, reward_bound=0.109, batch=225
138: loss=1.330, reward_mean=0.020, reward_bound=0.024, batch=227
139: loss=1.330, reward_mean=0.020, reward_bound=0.064, batch=229
140: loss=1.329, reward_mean=0.020, reward_bound=0.096, batch=230
141: loss=1.329, reward_mean=0.020, reward_bound=0.122, batch=222
142: loss=1.328, reward_mean=0.030, reward_bound=0.041, batch=225
143: loss=1.328, reward_mean=0.020, reward_bound=0.027, batch=227
144: loss=

254: loss=1.195, reward_mean=0.020, reward_bound=0.167, batch=231
255: loss=1.194, reward_mean=0.040, reward_bound=0.314, batch=231
256: loss=1.194, reward_mean=0.040, reward_bound=0.349, batch=230
257: loss=1.175, reward_mean=0.070, reward_bound=0.387, batch=164
258: loss=1.172, reward_mean=0.110, reward_bound=0.000, batch=175
259: loss=1.169, reward_mean=0.040, reward_bound=0.000, batch=179
260: loss=1.167, reward_mean=0.040, reward_bound=0.000, batch=183
261: loss=1.170, reward_mean=0.080, reward_bound=0.000, batch=191
262: loss=1.169, reward_mean=0.050, reward_bound=0.000, batch=196
263: loss=1.168, reward_mean=0.040, reward_bound=0.000, batch=200
264: loss=1.169, reward_mean=0.080, reward_bound=0.000, batch=208
265: loss=1.166, reward_mean=0.070, reward_bound=0.007, batch=215
266: loss=1.165, reward_mean=0.080, reward_bound=0.091, batch=220
267: loss=1.165, reward_mean=0.040, reward_bound=0.069, batch=224
268: loss=1.166, reward_mean=0.040, reward_bound=0.096, batch=227
269: loss=

379: loss=1.092, reward_mean=0.070, reward_bound=0.145, batch=225
380: loss=1.090, reward_mean=0.030, reward_bound=0.167, batch=226
381: loss=1.082, reward_mean=0.070, reward_bound=0.229, batch=225
382: loss=1.086, reward_mean=0.090, reward_bound=0.254, batch=220
383: loss=1.088, reward_mean=0.040, reward_bound=0.130, batch=224
384: loss=1.084, reward_mean=0.040, reward_bound=0.185, batch=226
385: loss=1.084, reward_mean=0.040, reward_bound=0.207, batch=228
386: loss=1.080, reward_mean=0.090, reward_bound=0.282, batch=226
387: loss=1.080, reward_mean=0.030, reward_bound=0.164, batch=228
388: loss=1.085, reward_mean=0.070, reward_bound=0.314, batch=221
389: loss=1.082, reward_mean=0.070, reward_bound=0.314, batch=224
390: loss=1.081, reward_mean=0.060, reward_bound=0.308, batch=227
391: loss=1.081, reward_mean=0.040, reward_bound=0.342, batch=229
392: loss=1.082, reward_mean=0.070, reward_bound=0.349, batch=226
393: loss=1.081, reward_mean=0.060, reward_bound=0.241, batch=228
394: loss=

504: loss=1.010, reward_mean=0.080, reward_bound=0.282, batch=224
505: loss=1.020, reward_mean=0.080, reward_bound=0.314, batch=198
506: loss=1.013, reward_mean=0.110, reward_bound=0.098, batch=207
507: loss=1.005, reward_mean=0.070, reward_bound=0.000, batch=214
508: loss=1.005, reward_mean=0.060, reward_bound=0.020, batch=220
509: loss=1.000, reward_mean=0.120, reward_bound=0.162, batch=224
510: loss=1.004, reward_mean=0.050, reward_bound=0.167, batch=224
511: loss=1.004, reward_mean=0.140, reward_bound=0.206, batch=226
512: loss=1.005, reward_mean=0.090, reward_bound=0.284, batch=228
513: loss=1.013, reward_mean=0.070, reward_bound=0.349, batch=194
514: loss=1.009, reward_mean=0.060, reward_bound=0.000, batch=200
515: loss=1.002, reward_mean=0.070, reward_bound=0.000, batch=207
516: loss=0.995, reward_mean=0.050, reward_bound=0.000, batch=212
517: loss=0.995, reward_mean=0.050, reward_bound=0.000, batch=217
518: loss=0.994, reward_mean=0.040, reward_bound=0.000, batch=221
519: loss=

629: loss=1.007, reward_mean=0.070, reward_bound=0.298, batch=228
630: loss=1.007, reward_mean=0.030, reward_bound=0.317, batch=229
631: loss=1.004, reward_mean=0.100, reward_bound=0.387, batch=226
632: loss=1.003, reward_mean=0.050, reward_bound=0.368, batch=228
633: loss=1.004, reward_mean=0.090, reward_bound=0.387, batch=228
634: loss=1.003, reward_mean=0.120, reward_bound=0.392, batch=229
635: loss=1.010, reward_mean=0.170, reward_bound=0.430, batch=205
636: loss=1.010, reward_mean=0.040, reward_bound=0.000, batch=209
637: loss=1.005, reward_mean=0.100, reward_bound=0.135, batch=215
638: loss=1.004, reward_mean=0.030, reward_bound=0.000, batch=218
639: loss=1.005, reward_mean=0.060, reward_bound=0.137, batch=222
640: loss=1.006, reward_mean=0.090, reward_bound=0.191, batch=225
641: loss=1.006, reward_mean=0.040, reward_bound=0.206, batch=225
642: loss=1.007, reward_mean=0.080, reward_bound=0.282, batch=225
643: loss=1.007, reward_mean=0.100, reward_bound=0.314, batch=225
644: loss=

754: loss=0.992, reward_mean=0.110, reward_bound=0.331, batch=228
755: loss=0.995, reward_mean=0.100, reward_bound=0.349, batch=218
756: loss=0.997, reward_mean=0.050, reward_bound=0.190, batch=222
757: loss=0.997, reward_mean=0.010, reward_bound=0.000, batch=223
758: loss=0.995, reward_mean=0.070, reward_bound=0.220, batch=226
759: loss=0.994, reward_mean=0.050, reward_bound=0.229, batch=227
760: loss=0.991, reward_mean=0.030, reward_bound=0.163, batch=229
761: loss=0.993, reward_mean=0.030, reward_bound=0.213, batch=230
762: loss=0.998, reward_mean=0.090, reward_bound=0.314, batch=229
763: loss=0.995, reward_mean=0.060, reward_bound=0.349, batch=228
764: loss=0.994, reward_mean=0.040, reward_bound=0.293, batch=229
765: loss=0.992, reward_mean=0.060, reward_bound=0.387, batch=223
766: loss=0.990, reward_mean=0.040, reward_bound=0.183, batch=226
767: loss=0.991, reward_mean=0.090, reward_bound=0.387, batch=227
768: loss=0.989, reward_mean=0.050, reward_bound=0.373, batch=229
769: loss=

879: loss=0.982, reward_mean=0.020, reward_bound=0.154, batch=229
880: loss=0.984, reward_mean=0.070, reward_bound=0.324, batch=230
881: loss=0.983, reward_mean=0.040, reward_bound=0.338, batch=231
882: loss=0.984, reward_mean=0.090, reward_bound=0.387, batch=230
883: loss=0.983, reward_mean=0.090, reward_bound=0.430, batch=226
884: loss=0.983, reward_mean=0.100, reward_bound=0.454, batch=228
885: loss=0.983, reward_mean=0.110, reward_bound=0.397, batch=229
886: loss=0.984, reward_mean=0.080, reward_bound=0.478, batch=232
887: loss=0.984, reward_mean=0.030, reward_bound=0.445, batch=232
888: loss=0.984, reward_mean=0.020, reward_bound=0.363, batch=232
889: loss=0.990, reward_mean=0.070, reward_bound=0.478, batch=214
890: loss=0.989, reward_mean=0.050, reward_bound=0.000, batch=219
891: loss=0.991, reward_mean=0.070, reward_bound=0.140, batch=223
892: loss=0.992, reward_mean=0.080, reward_bound=0.206, batch=225
893: loss=0.991, reward_mean=0.100, reward_bound=0.321, batch=227
894: loss=

In [None]:
from torchsummary import summary

# Print a layer-by-layer summary of the trained policy network.
# Reuse the network's real input width and device instead of hard-coding 16 / CUDA.
# torchsummary expects the device as the string "cuda" or "cpu".
summary(net, (1, obs_size), device=next(net.parameters()).device.type)