In [1]:
import gym
from collections import namedtuple
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np

In [2]:
# Cross-entropy-method hyper-parameters.
batch_size, n_neurons, percentile = (
    16,  # complete episodes played per training iteration
    64,  # width of the policy network's hidden layer
    70,  # reward percentile below which episodes are discarded
)

In [53]:
class model(nn.Module):
    """Policy network: a one-hidden-layer MLP from observations to action logits.

    Args:
        obs_size: Length of the observation vector fed to the first layer.
        n_neurons: Width of the hidden ReLU layer.
        n_actions: Number of output scores, one per discrete action.
    """

    def __init__(self, obs_size, n_neurons, n_actions):
        super().__init__()
        layers = [
            nn.Linear(obs_size, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, n_actions),
        ]
        self.pipe = nn.Sequential(*layers)

    def forward(self, x):
        # No softmax here: the output stays raw logits so it can be fed to
        # nn.CrossEntropyLoss; callers apply nn.Softmax when they need probabilities.
        return self.pipe(x)

In [54]:
# A single recorded transition: the observation that was seen and the action chosen for it.
Episode_Step = namedtuple('Episode_Step', 'observations actions')
# A finished episode: its accumulated reward and the ordered list of Episode_Step records.
Episode = namedtuple('Episode', 'rewards episode_step')

In [108]:
def iterate_batchs(env, model, batch_size):
    """Endlessly play episodes with the policy ``model`` and yield them in batches.

    At every step the current observation is pushed through ``model``; the output
    is turned into action probabilities with a softmax and an action is sampled
    from that distribution.

    Args:
        env: Environment using the classic gym API relied on in this notebook:
            ``reset() -> obs`` and ``step(a) -> (obs, reward, done, info)``.
        model: Network mapping a ``(1, obs_size)`` float32 tensor to
            ``(1, n_actions)`` logits.
        batch_size: Number of complete episodes collected per yielded batch.

    Yields:
        list[Episode]: ``batch_size`` episodes, each with its summed reward and the
        list of ``Episode_Step(observations, actions)`` taken during it.
    """
    batch = []
    episode_reward = 0.0
    steps = []
    obs = env.reset()
    sm = nn.Softmax(dim=1)

    while True:
        # Add a batch dimension; as_tensor also accepts float64 observations safely.
        obs_v = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
        # Acting needs no gradients, so don't build an autograd graph.
        with torch.no_grad():
            actions_prob_v = sm(model(obs_v))
        # Drop the batch dimension: np.random.choice needs a 1-D probability vector,
        # and len() of the (1, n_actions) array would be 1 instead of n_actions.
        actions_prob = actions_prob_v.numpy()[0]

        action = np.random.choice(len(actions_prob), p=actions_prob)
        next_obs, reward, is_done, _ = env.step(action)
        episode_reward += reward
        # Record the observation the action was chosen for (not next_obs).
        steps.append(Episode_Step(observations=obs, actions=action))

        if is_done:
            batch.append(Episode(rewards=episode_reward, episode_step=steps))
            episode_reward = 0.0
            next_obs = env.reset()
            steps = []

            if len(batch) == batch_size:
                yield batch
                batch = []

        obs = next_obs

In [109]:
def filter_batch(batch, percentile):
    """Keep only the "elite" episodes of a batch and flatten them into training tensors.

    Args:
        batch: Non-empty sequence of ``Episode`` records (``rewards``, ``episode_step``).
        percentile: Percentile (0-100) of episode rewards used as the cut-off;
            episodes whose reward is strictly below it are discarded.

    Returns:
        tuple: ``(train_obs_v, train_act_v, reward_bound, reward_mean)`` where
        ``train_obs_v`` is a float32 tensor of the stacked observations of every
        kept step, ``train_act_v`` an int64 tensor of the matching actions,
        ``reward_bound`` the percentile cut-off and ``reward_mean`` the mean reward
        over the whole batch (kept and discarded episodes alike).
    """
    rewards = [episode.rewards for episode in batch]
    reward_bound = np.percentile(rewards, percentile)
    reward_mean = float(np.mean(rewards))

    train_obs = []
    train_act = []
    for example in batch:
        if example.rewards < reward_bound:
            continue
        train_obs.extend(step.observations for step in example.episode_step)
        train_act.extend(step.actions for step in example.episode_step)

    # Stack into a single ndarray first: building a tensor from a list of ndarrays
    # is very slow (PyTorch warns about it), and this also pins the dtypes.
    train_obs_v = torch.as_tensor(np.asarray(train_obs, dtype=np.float32))
    train_act_v = torch.as_tensor(np.asarray(train_act, dtype=np.int64))

    return train_obs_v, train_act_v, reward_bound, reward_mean

In [None]:
if __name__ == "__main__":
    # Cross-entropy method: sample a batch of episodes with the current policy, keep
    # those at or above the reward percentile, and train the net to reproduce the
    # actions taken in them.
    env = gym.make('CartPole-v0')
    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n

    net = model(obs_size, n_neurons, n_actions)
    # The net outputs raw logits; CrossEntropyLoss applies log-softmax internally.
    objective = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.01)

    # CartPole-v0 caps episodes at 200 steps, so the old "mean reward > 300" stop
    # condition could never fire and training ran forever. Use the environment's
    # registered "solved" threshold instead (195.0 for CartPole-v0).
    reward_threshold = getattr(env.spec, "reward_threshold", None)
    if reward_threshold is None:
        reward_threshold = 195.0

    try:
        for iter_no, batch in enumerate(iterate_batchs(env, net, batch_size)):
            env.render()
            train_obs, train_act, reward_b, reward_m = filter_batch(batch, percentile)

            act_preds = net(train_obs)
            loss = objective(act_preds, train_act)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("%d: loss=%.3f, reward_mean=%.1f, reward_bound=%.1f" % (
                    iter_no, loss.item(), reward_m, reward_b))
            if reward_m > reward_threshold:
                print("Solved!")
                break
    finally:
        # Release the environment / render window even if training is interrupted.
        env.close()

In [48]:
# Debugging: pull a single batch out of the episode generator and run it through the filter.
# NOTE(review): relies on `env` and `net` already existing in the kernel (created by the training cell).
id_, batch = 0, next(iterate_batchs(env, net, batch_size))
train_obs, train_act, reward_b, reward_m = filter_batch(batch, percentile)

In [112]:
# Watch the trained policy play: sample each action from the net's softmax output
# and render every frame. Finished episodes are reset, because calling step() after
# done=True is invalid and makes gym emit a warning.
# NOTE(review): relies on `net` already existing in the kernel (created by the training cell).
env = gym.make('CartPole-v0')
obs = env.reset()
sm = nn.Softmax(dim = 1)

try:
    for t in range(1000):
        env.render()
        with torch.no_grad():
            obs_v = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            # [0] drops the batch dimension so np.random.choice gets a 1-D distribution.
            act_probs = sm(net(obs_v)).numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        obs, reward, is_done, _ = env.step(action)
        if is_done:
            obs = env.reset()
finally:
    env.close()

  logger.warn(
