In [1]:
import gym
from scipy.special import expit as sigmoid
import numpy as np
import numpy.random as npr
import time
from itertools import count
from collections import deque
import matplotlib.pyplot as plt
from a2c_ppo_acktr import utils

from layer import Layer

In [2]:
class args(object):
    num_hidden = 8
    bandwidth = 3
    eval_interval = None
    log_interval = 250
    use_gae = False
    num_updates = 1e5
    num_steps = 32
    memory_capacity = num_steps
    ppo_epoch = 1
    num_mini_batch = 32
    value_loss_coef = 0.5
    entropy_coef = 0.01
    lr = 1e-3
    eps = 1e-5
    max_grad_norm = 0.2
    clip_param = 0.2
    gamma = 0.99
    gae_lambda = 0.95
    use_proper_time_limits = False

In [3]:
env = gym.make('CartPole-v0')
input = Layer(4, 1, args.bandwidth, post_width=args.num_hidden, args=args, post_process=None)
hidden = Layer(args.num_hidden, args.bandwidth, args.bandwidth, post_width=1, args=args, post_process=None)
output = Layer(1, args.bandwidth, 1, post_process=sigmoid, args=args)



In [4]:
episode_rewards = deque(maxlen=args.log_interval)
done = True

for j in count(len(input.units[0].action_losses) + 1):
    input.clear_memory()
    hidden.clear_memory()
    output.clear_memory()
    if done:
        state = env.reset()
        episode_rewards.append(0)
        input.reset()
        hidden.reset()
        output.reset()
    for unit in input.units + hidden.units + output.units:
        utils.update_linear_schedule(unit.agent.optimizer, j, args.num_updates, args.lr)
    for step in range(args.num_steps):
        action = output(hidden(input([state[i:i+1] for i in range(4)])))[0][0]
        action = int(action > 0.5)
        state, reward, done, info = env.step(action)
        input.reward(reward)
        hidden.reward(reward)
        output.reward(reward)
        episode_rewards[-1] += reward
        if done:
            input.done()
            hidden.done()
            output.done()
    input.update()
    hidden.update()
    output.update()
    if j % args.log_interval == 0:
        print('Iter: %d, Avg/Max/Min. reward: %0.1f/%0.1f/%0.1f' % (j, sum(episode_rewards)/len(episode_rewards), max(episode_rewards), min(episode_rewards)))



Iter: 250, Avg/Max/Min. reward: 20.7/63.0/8.0
Iter: 500, Avg/Max/Min. reward: 24.2/123.0/9.0
Iter: 750, Avg/Max/Min. reward: 24.9/84.0/8.0
Iter: 1000, Avg/Max/Min. reward: 23.6/86.0/9.0
Iter: 1250, Avg/Max/Min. reward: 26.1/93.0/9.0
Iter: 1500, Avg/Max/Min. reward: 24.4/114.0/9.0
Iter: 1750, Avg/Max/Min. reward: 22.2/81.0/9.0
Iter: 2000, Avg/Max/Min. reward: 22.4/85.0/9.0
Iter: 2250, Avg/Max/Min. reward: 22.0/71.0/9.0
Iter: 2500, Avg/Max/Min. reward: 22.5/71.0/9.0
Iter: 2750, Avg/Max/Min. reward: 22.2/78.0/9.0
Iter: 3000, Avg/Max/Min. reward: 22.5/84.0/8.0
Iter: 3250, Avg/Max/Min. reward: 22.6/83.0/8.0
Iter: 3500, Avg/Max/Min. reward: 22.6/68.0/8.0
Iter: 3750, Avg/Max/Min. reward: 22.1/68.0/8.0
Iter: 4000, Avg/Max/Min. reward: 21.6/62.0/9.0
Iter: 4250, Avg/Max/Min. reward: 22.3/67.0/9.0
Iter: 4500, Avg/Max/Min. reward: 22.4/71.0/9.0
Iter: 4750, Avg/Max/Min. reward: 23.5/77.0/9.0
Iter: 5000, Avg/Max/Min. reward: 22.4/69.0/9.0
Iter: 5250, Avg/Max/Min. reward: 21.9/69.0/9.0
Iter: 5500, Av

KeyboardInterrupt: 

In [None]:
print('INPUTS')
input.plot_stats(n=10000)
print('HIDDEN')
hidden.plot_stats(n=10000)
print('OUTPUTS')
output.plot_stats(n=10000)