In [1]:
import gym
from scipy.special import expit as sigmoid
import numpy as np
import numpy.random as npr
import time
from itertools import count
from collections import deque
import matplotlib.pyplot as plt
from a2c_ppo_acktr import utils
import random

from heap import Heap

In [5]:
class args(object):
    """Static hyperparameter namespace for this notebook.

    The class itself (never an instance) is handed to ``Heap(..., args=args)``
    and read directly by the training cell, so every setting is a plain class
    attribute.
    """

    # --- Read directly by the training cell ---------------------------------
    num_updates = 1e5        # horizon passed to utils.update_linear_schedule
    num_steps = 10           # environment steps collected per update
    fps = 3                  # forwarded as `n` on every heap(state, n=...) call
    log_interval = 1         # log every N updates; also the episode-window length
    lr = 1e-2                # initial learning rate given to the linear schedule
    sleep_freq = 0           # 0 disables the periodic heap.sleep() call
    sleep_skip = False       # forwarded as `skip` to heap.sleep()

    # --- Consumed by Heap (heap.py); semantics are not visible here ---------
    bandwidth = 3
    postwidth = 3
    threshold = 0.1
    slip_reward = -0.001
    signal_split = 1
    # NOTE(review): the factor 13 presumably mirrors the first positional
    # argument of Heap(13, 4, 1, ...) -- confirm against heap.py.
    memory_capacity = 13 * num_steps * fps

    # --- Names match the usual A2C/PPO settings; presumably forwarded to ----
    # --- the per-unit agents -- TODO confirm in heap.py ---------------------
    use_gae = False
    gamma = 0.99
    gae_lambda = 0.95
    use_proper_time_limits = False
    ppo_epoch = 1
    num_mini_batch = 10
    clip_param = 0.005
    value_loss_coef = 0.5
    entropy_coef = 0.01
    max_grad_norm = 0.005
    lbfgs_history = 1000

In [6]:
# Environment the heap is trained on: gym's registered 'CartPole-v0' task.
env = gym.make('CartPole-v0')
# Heap is project-local (heap.py). The meaning of the three positional arguments
# is not visible in this notebook -- TODO confirm. (CartPole observations have 4
# values and the action is a single scalar, so 4 and 1 are presumably the
# input/output sizes.) The `args` class itself is passed as the config object.
heap = Heap(13, 4, 1, args=args)

In [7]:
# Training cell: collect `args.num_steps` environment steps per update, feed the
# rewards to the heap, then let the heap update itself and log progress.

# Rolling windows over the most recent episodes; they feed the log line below.
episode_rewards = deque(maxlen=args.log_interval)
episode_relaxes = deque(maxlen=args.log_interval)
# Start "done" so the very first iteration resets the environment and the heap.
done = True

# Resume numbering after whatever earlier runs of this cell already recorded
# (presumably one `action_losses` entry per completed update -- confirm in heap.py).
first_update = len(heap.units[0].action_losses) + 1
# Bound the loop by the schedule horizon. The previous open-ended counter let
# `j` exceed `args.num_updates`, and a linear learning-rate decay evaluated
# beyond its horizon goes to zero and then negative.
last_update = int(args.num_updates)

for j in range(first_update, last_update):
    heap.clear_memory()
    if done:
        # New episode: fresh environment state and fresh per-episode counters.
        state = env.reset()
        episode_rewards.append(0)
        episode_relaxes.append(0)
        heap.reset()
    for unit in heap.units:
        utils.update_linear_schedule(unit.agent.optimizer, j, args.num_updates, args.lr)
    for step in range(args.num_steps):
        # Optional periodic "sleep": only when enabled, and only once the
        # running episode reward is a non-zero multiple of sleep_freq.
        if args.sleep_freq and episode_rewards[-1] % args.sleep_freq == 0 and episode_rewards[-1]:
            heap.sleep(skip=args.sleep_skip)
        action, zeros_mask = heap(state, n=args.fps)
        # Binarise the heap's output into CartPole's two discrete actions.
        action = int(action > 0.5)
        if len(zeros_mask[0]):
            # A non-empty zeros mask means the heap's output is overridden by a
            # rule keyed on state[2] (the pole angle in CartPole's observation).
            # NOTE(review): per the gym CartPole docs action 0 pushes the cart
            # left, so a positive angle is answered by pushing away from the
            # lean -- confirm this direction is intended.
            if state[2] > 0:
                action = 0
            else:
                action = 1
        # NOTE(review): this loop keeps stepping after `done` until num_steps is
        # reached; the reset only happens at the top of the next update. The
        # `and reward` check below relies on post-termination steps yielding a
        # zero reward. Left as-is because the heap may expect fixed-length
        # rollouts -- confirm before changing.
        state, reward, done, info = env.step(action)
        heap.reward(reward)
        episode_rewards[-1] += reward
        if done:
            heap.done()
        if len(zeros_mask[0]) and reward:
            episode_relaxes[-1] += 1
    heap.update()
    if j % args.log_interval == 0:
        total_reward = sum(episode_rewards)
        # Fraction of rewarded steps on which the fallback rule fired; guarded
        # so an all-zero reward window cannot raise ZeroDivisionError.
        relax_ratio = sum(episode_relaxes) / total_reward if total_reward else 0.0
        print('Iter: %d, Avg/Max/Min. reward: %0.1f/%0.1f/%0.1f, Avg relaxation: %0.2f' % (j, total_reward/len(episode_rewards), max(episode_rewards), min(episode_rewards), relax_ratio))

Iter: 1, Avg/Max/Min. reward: 10.0/10.0/10.0, Avg relaxation: 0.10




Iter: 2, Avg/Max/Min. reward: 16.0/16.0/16.0, Avg relaxation: 0.31
Iter: 3, Avg/Max/Min. reward: 9.0/9.0/9.0, Avg relaxation: 1.00
Iter: 4, Avg/Max/Min. reward: 9.0/9.0/9.0, Avg relaxation: 1.00
Iter: 5, Avg/Max/Min. reward: 9.0/9.0/9.0, Avg relaxation: 1.00
Iter: 6, Avg/Max/Min. reward: 10.0/10.0/10.0, Avg relaxation: 0.90
Iter: 7, Avg/Max/Min. reward: 10.0/10.0/10.0, Avg relaxation: 0.90
Iter: 8, Avg/Max/Min. reward: 9.0/9.0/9.0, Avg relaxation: 0.89
Iter: 9, Avg/Max/Min. reward: 8.0/8.0/8.0, Avg relaxation: 0.62
Iter: 10, Avg/Max/Min. reward: 10.0/10.0/10.0, Avg relaxation: 0.50
Iter: 11, Avg/Max/Min. reward: 14.0/14.0/14.0, Avg relaxation: 0.57
Iter: 12, Avg/Max/Min. reward: 10.0/10.0/10.0, Avg relaxation: 0.60
Iter: 13, Avg/Max/Min. reward: 10.0/10.0/10.0, Avg relaxation: 0.70
Iter: 14, Avg/Max/Min. reward: 16.0/16.0/16.0, Avg relaxation: 0.75
Iter: 15, Avg/Max/Min. reward: 10.0/10.0/10.0, Avg relaxation: 0.60
Iter: 16, Avg/Max/Min. reward: 9.0/9.0/9.0, Avg relaxation: 0.78
Iter: 

KeyboardInterrupt: 

In [None]:
# Plot the statistics the heap gathered during training. `n` is presumably a
# smoothing-window or sample size -- confirm against Heap.plot_stats in heap.py.
heap.plot_stats(n=500)