# In [20]:
import torch, gym
import torch.nn as nn
from collections import namedtuple
import numpy as np
from torch.utils.tensorboard import SummaryWriter
#tensorboard --logdir 'runs\RL' --host localhost --port 8888
import datetime

# In [22]:
HIDDEN_SIZE = 128  # width of the single hidden layer in Net
BATCH_SIZE = 16    # number of complete episodes gathered per training iteration
PERCENTILE = 70    # episodes with reward at/above this percentile are kept for training


class Net(nn.Module):
    """Small MLP policy: observation -> hidden ReLU layer -> unnormalised action scores.

    No softmax is applied here; callers apply it themselves (or feed the raw
    scores to CrossEntropyLoss, which expects logits).
    """

    def __init__(self, obs_size, hidden_size, n_actions) -> None:
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        # Kept under the attribute name ``net`` so state_dict keys stay ``net.0.*`` / ``net.2.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw action scores for the batch of observations ``x``."""
        return self.net(x)
    

# One finished episode: its total (undiscounted) reward and the list of EpisodeStep records.
Episode = namedtuple('Episode', field_names=['reward', 'steps'])
# A single (observation, chosen action) pair recorded while playing an episode.
EpisodeStep = namedtuple('EpisodeStep', field_names=['observation', 'action'])


def iterate_batches(env, net, batch_size):
    """Play episodes with the current policy forever, yielding them in batches.

    Actions are sampled from ``softmax(net(obs))``. After each finished episode
    the env is reset, and a list of ``batch_size`` ``Episode`` tuples is yielded
    whenever that many episodes have been collected. This is an infinite generator.

    The env is expected to follow the Gym >= 0.26 API: ``reset()`` returns
    ``(obs, info)`` and ``step()`` returns
    ``(obs, reward, terminated, truncated, info)``.

    Args:
        env: environment to play in.
        net: module mapping a ``(1, obs_dim)`` float tensor to action scores.
        batch_size: number of complete episodes per yielded batch.

    Yields:
        list[Episode] of length ``batch_size``.
    """
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs = env.reset()[0]
    sm = nn.Softmax(dim=1)
    while True:
        # Convert via a single float32 ndarray; torch.FloatTensor([ndarray]) is slow and warns.
        obs_v = torch.as_tensor(np.asarray(obs, dtype=np.float32)).unsqueeze(0)
        # Acting only needs the probabilities, so skip building an autograd graph.
        with torch.no_grad():
            act_probs = sm(net(obs_v)).numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        # An episode ends on termination OR on time-limit truncation; ignoring
        # ``truncated`` would keep stepping past the env's episode limit.
        is_done = terminated or truncated
        episode_reward += reward
        episode_steps.append(EpisodeStep(observation=obs, action=action))

        if is_done:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            next_obs = env.reset()[0]
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs


def filter_batch(batch, percentile):
    """Keep the best episodes of a batch and flatten them into training tensors.

    Episodes whose reward is strictly below the ``percentile``-th percentile of
    the batch rewards are dropped. The (observation, action) pairs of the rest
    are concatenated.

    Args:
        batch: non-empty sequence of ``(reward, steps)`` episodes, where each step
            has ``.observation`` and ``.action`` attributes.
        percentile: percentile (0-100) used as the reward cut-off.

    Returns:
        ``(train_obs_v, train_act_v, reward_bound, reward_mean)``:
        a float32 tensor of kept observations (``(N, obs_dim)`` assuming each
        observation is a 1-D array), an int64 tensor of the matching actions,
        the reward cut-off, and the mean reward over the whole batch.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("batch must contain at least one episode")
    rewards = [episode.reward for episode in batch]
    reward_bound = np.percentile(rewards, percentile)
    reward_mean = float(np.mean(rewards))
    train_obs = []
    train_act = []
    for reward, steps in batch:
        if reward < reward_bound:
            continue
        train_obs.extend(step.observation for step in steps)
        train_act.extend(step.action for step in steps)
    # Stack into one ndarray first: building a tensor from a list of ndarrays is very slow.
    train_obs_v = torch.as_tensor(np.asarray(train_obs, dtype=np.float32))
    # CrossEntropyLoss targets must be int64 (np's default int is int32 on some platforms).
    train_act_v = torch.as_tensor(np.asarray(train_act, dtype=np.int64))
    return train_obs_v, train_act_v, reward_bound, reward_mean

if __name__ == "__main__":
    # Train a cross-entropy-method policy on CartPole-v1 and log progress to TensorBoard.
    env = gym.make("CartPole-v1")
    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    net = Net(obs_size, HIDDEN_SIZE, n_actions)
    objective = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=net.parameters(), lr=0.01)

    # Use the env's registered "solved" threshold (475 for CartPole-v1, whose
    # episodes run up to 500 steps). The old hard-coded 199 was the CartPole-v0
    # limit and declared success far too early on v1.
    solved_reward = getattr(getattr(env, "spec", None), "reward_threshold", None)
    if solved_reward is None:
        solved_reward = 475.0

    timestamp = datetime.datetime.now().strftime("%H_%M_%S")
    # SummaryWriter ignores ``comment`` when ``log_dir`` is given, so put the
    # run tag directly in the directory name.
    writer = SummaryWriter(f"runs/RL/{timestamp}-cartpole")

    try:
        for iter_no, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
            obs_v, acts_v, reward_b, reward_m = filter_batch(batch, PERCENTILE)
            optimizer.zero_grad()
            action_scores_v = net(obs_v)
            loss_v = objective(action_scores_v, acts_v)
            loss_v.backward()
            optimizer.step()
            print("%d: loss=%.3f, reward_mean=%.1f, rw_bound=%.1f" % (
                iter_no, loss_v.item(), reward_m, reward_b))
            writer.add_scalar("loss", loss_v.item(), iter_no)
            writer.add_scalar("reward_bound", reward_b, iter_no)
            writer.add_scalar("reward_mean", reward_m, iter_no)

            if reward_m >= solved_reward:
                print("SOLVED!")
                break
    finally:
        # Flush logs and release the env even if training is interrupted.
        writer.close()
        env.close()

# Sample training output:
# 0: loss=0.702, reward_mean=16.2, rw_bound=17.0
# 1: loss=0.664, reward_mean=14.9, rw_bound=16.5
# 2: loss=0.632, reward_mean=16.2, rw_bound=20.0
# 3: loss=0.687, reward_mean=20.3, rw_bound=22.5
# 4: loss=0.676, reward_mean=19.9, rw_bound=20.0
# 5: loss=0.666, reward_mean=21.0, rw_bound=26.0
# 6: loss=0.652, reward_mean=21.6, rw_bound=25.5
# 7: loss=0.661, reward_mean=34.5, rw_bound=34.0
# 8: loss=0.630, reward_mean=37.0, rw_bound=43.0
# 9: loss=0.629, reward_mean=38.9, rw_bound=52.5
# 10: loss=0.610, reward_mean=50.6, rw_bound=56.5
# 11: loss=0.609, reward_mean=52.2, rw_bound=57.5
# 12: loss=0.612, reward_mean=35.1, rw_bound=37.0
# 13: loss=0.609, reward_mean=38.4, rw_bound=39.5
# 14: loss=0.594, reward_mean=48.1, rw_bound=50.5
# 15: loss=0.569, reward_mean=50.8, rw_bound=60.5
# 16: loss=0.584, reward_mean=56.1, rw_bound=64.0
# 17: loss=0.558, reward_mean=52.4, rw_bound=60.5
# 18: loss=0.570, reward_mean=46.7, rw_bound=52.5
# 19: loss=0.580, reward_mean=48.2, rw_bound=54.0
# 20: loss=0.572, reward_mean=48.6, rw_bound=50.0
# 21  (output truncated here)