In [1]:
from sac import SAC
from replay_memory import ReplayMemory
import gym
from addict import Dict
import numpy as np
import itertools

def norm_stats(stats):
    """Divide every entry of ``stats`` by its ``cnt`` entry, in place.

    The divisor is read once, before any entry is modified. Re-reading
    ``cnt`` inside the loop would make the result depend on key order:
    once ``cnt`` itself has been divided (becoming 1.0), every key visited
    afterwards would be left un-normalised.

    Args:
        stats: mutable mapping of numeric values that contains a ``cnt``
            key (an ``addict.Dict`` or a plain ``dict``). An empty mapping
            is returned untouched.

    Returns:
        The same ``stats`` object. Its ``cnt`` entry is divided by itself
        like every other entry and therefore ends up as 1.0.
    """
    if not stats:
        # Nothing was accumulated, so there is nothing to normalise.
        return stats
    cnt = stats["cnt"]
    for k in stats:
        # Only values are reassigned (no keys added or removed), so this is
        # safe to do while iterating over the mapping.
        stats[k] /= cnt
    return stats

def average_stats(lst):
    """Average a sequence of per-update stat mappings key by key.

    Every value is summed into a fresh ``Dict`` accumulator, together with
    a ``cnt`` entry counting how many mappings were seen; the accumulator
    is then handed to ``norm_stats`` for the division.

    Args:
        lst: iterable of mappings with numeric values.

    Returns:
        Whatever ``norm_stats`` returns for the accumulated totals.
    """
    totals = Dict()
    for entry in lst:
        for key, value in entry.items():
            # `+=` on a key that is not present yet relies on the
            # accumulator being an addict Dict.
            totals[key] += value
        totals.cnt += 1
    return norm_stats(totals)

def get_avg_loss(train_stats):
    """Return the mean of the two critic losses and the policy loss.

    Args:
        train_stats: object exposing ``critic_1_loss``, ``critic_2_loss``
            and ``policy_loss`` as attributes.

    Returns:
        ``(critic_1_loss + critic_2_loss + policy_loss) / 3``, or
        ``float('nan')`` when the losses are missing or not numeric
        (for example when no update ran during the episode).
    """
    try:
        total = (
            train_stats.critic_1_loss
            + train_stats.critic_2_loss
            + train_stats.policy_loss
        )
        return total / 3
    except (AttributeError, KeyError, TypeError):
        # Only "stats not available" failures map to NaN. A bare `except:`
        # would also swallow KeyboardInterrupt / SystemExit, which must
        # propagate so the training run can be stopped.
        return float('nan')

In [2]:
# Training configuration (tunable constants for the training cell).
WARM_UP = 10_000          # while total_steps is below this, actions are sampled from the action space
BUFFER_SIZE = 1_000_000   # passed to the ReplayMemory constructor
BATCH_SIZE = 256          # batch size handed to agent.update_parameters; updates start once memory holds more than this
UPDATES_PER_STEP = 1      # NOTE(review): name suggests agent updates per environment step — confirm the training loop reads it
PRINT_FREQ = 100          # the averaged train stats are printed every PRINT_FREQ episodes

In [3]:
# Build the CartPole environment and patch its action space with the
# `shape`, `high` and `low` attributes of a one-dimensional range [0, 1].
# Presumably this is what lets the SAC agent (built for continuous actions)
# accept the space; the training loop thresholds the agent's output at 0.5.
env = gym.make('CartPole-v0')
action_space = env.action_space
action_space.shape = (1,)
action_space.high = np.array([1])
action_space.low = np.array([0])

In [4]:
# Settings bundle handed to the SAC constructor.
# NOTE(review): tau=1 combined with target_update_interval=1000 looks like a
# hard target-network copy every 1000 updates — confirm against the SAC code.
args = Dict(
    gamma=0.99,
    tau=1,
    alpha=0.2,
    policy='Gaussian',
    target_update_interval=1000,
    automatic_entropy_tuning=False,
    cuda=False,
    hidden_size=256,
    lr=0.003,
)

In [5]:
# Derive the first constructor argument from the environment rather than
# hard-coding 4 (CartPole's observation vector has 4 entries), so swapping
# the environment does not silently leave a stale value behind.
# NOTE(review): assumes SAC's first positional argument is the observation
# dimension — confirm against sac.py.
agent = SAC(env.observation_space.shape[0], env.action_space, args)

In [6]:
# Replay buffer shared by the training loop: transitions are added with
# memory.push(...) and the buffer is handed to agent.update_parameters.
memory = ReplayMemory(BUFFER_SIZE)

In [7]:
EVAL_FREQ = 10      # run an evaluation round after every EVAL_FREQ-th training episode
EVAL_EPISODES = 10  # number of episodes played per evaluation round


def _policy_action(state):
    """Ask the agent for an action and threshold it at 0.5 into a 0/1 action."""
    return int(agent.select_action(state) > 0.5)


def _evaluate(episodes):
    """Play ``episodes`` episodes with the current policy, without training.

    Returns:
        ``(average_reward, max_reward)`` over the played episodes. The
        maximum is taken over the actual returns rather than seeded with 0,
        so it stays correct for environments whose returns can be negative.
    """
    # NOTE(review): select_action is called exactly as during training; if
    # the agent offers a deterministic/evaluation mode it is not used here.
    rewards = []
    for _ in range(episodes):
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            state, reward, done, _info = env.step(_policy_action(state))
            episode_reward += reward
        rewards.append(episode_reward)
    return sum(rewards) / episodes, max(rewards)


total_steps = 0
last_total_steps = 0  # not read anywhere in this cell
updates = 0
# The loop has no stop condition: it trains until the kernel is interrupted.
for i_episode in itertools.count(1):
    episode_reward = 0
    episode_steps = 0
    done = False
    state = env.reset()
    train_stats = []
    while not done:
        if WARM_UP > total_steps:
            # Warm-up phase: sample actions from the action space.
            action = env.action_space.sample()
        else:
            action = _policy_action(state)

        if len(memory) > BATCH_SIZE:
            # Honour UPDATES_PER_STEP (previously defined but never read).
            for _ in range(UPDATES_PER_STEP):
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, BATCH_SIZE, updates)
                _train_stats = Dict()
                _train_stats.critic_1_loss = critic_1_loss
                _train_stats.critic_2_loss = critic_2_loss
                _train_stats.policy_loss = policy_loss
                _train_stats.ent_loss = ent_loss
                _train_stats.alpha = alpha
                train_stats.append(_train_stats)
                updates += 1

        next_state, reward, done, _ = env.step(action) # Step
        episode_steps += 1
        total_steps += 1
        episode_reward += reward

        # Ignore the "done" signal if it comes from hitting the time horizon.
        # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
        mask = 1 if episode_steps == env._max_episode_steps else float(not done)
        memory.push(state, np.array([action]), reward, next_state, mask)

        state = next_state

    # Collapse the per-update stats of this episode into their mean.
    train_stats = average_stats(train_stats)
    loss = get_avg_loss(train_stats)
    print("Episode: {}, total steps: {}, episode steps: {}, reward: {}, loss: {}".format(i_episode, total_steps, episode_steps, round(episode_reward, 2), round(loss, 3)))
    if updates > 0 and i_episode % PRINT_FREQ == 0:
        print('TRAIN STATS: %s' % str(train_stats))

    if i_episode % EVAL_FREQ == 0:
        avg_reward, max_reward = _evaluate(EVAL_EPISODES)

        print("----------------------------------------")
        print("Test Episodes: {}, Avg. Reward: {}, Max. Reward: {}".format(EVAL_EPISODES, round(avg_reward, 2), round(max_reward, 2)))
        print("----------------------------------------")

Episode: 1, total steps: 31, episode steps: 31, reward: 31.0, loss: nan
Episode: 2, total steps: 48, episode steps: 17, reward: 17.0, loss: nan
Episode: 3, total steps: 68, episode steps: 20, reward: 20.0, loss: nan
Episode: 4, total steps: 146, episode steps: 78, reward: 78.0, loss: nan
Episode: 5, total steps: 163, episode steps: 17, reward: 17.0, loss: nan
Episode: 6, total steps: 216, episode steps: 53, reward: 53.0, loss: nan
Episode: 7, total steps: 254, episode steps: 38, reward: 38.0, loss: nan
Episode: 8, total steps: 280, episode steps: 26, reward: 26.0, loss: 0.415
Episode: 9, total steps: 312, episode steps: 32, reward: 32.0, loss: 0.035
Episode: 10, total steps: 334, episode steps: 22, reward: 22.0, loss: -0.134
----------------------------------------
Test Episodes: 10, Avg. Reward: 21.0, Max. Reward: 45.0
----------------------------------------
Episode: 11, total steps: 380, episode steps: 46, reward: 46.0, loss: -0.218
Episode: 12, total steps: 406, episode steps: 26, 

Episode: 92, total steps: 2045, episode steps: 44, reward: 44.0, loss: -0.596
Episode: 93, total steps: 2058, episode steps: 13, reward: 13.0, loss: -0.594
Episode: 94, total steps: 2066, episode steps: 8, reward: 8.0, loss: -0.595
Episode: 95, total steps: 2079, episode steps: 13, reward: 13.0, loss: -0.595
Episode: 96, total steps: 2096, episode steps: 17, reward: 17.0, loss: -0.594
Episode: 97, total steps: 2106, episode steps: 10, reward: 10.0, loss: -0.594
Episode: 98, total steps: 2135, episode steps: 29, reward: 29.0, loss: -0.592
Episode: 99, total steps: 2151, episode steps: 16, reward: 16.0, loss: -0.591
Episode: 100, total steps: 2175, episode steps: 24, reward: 24.0, loss: -0.593
TRAIN STATS: {'critic_1_loss': 0.03792622988112271, 'critic_2_loss': 0.03847720636986196, 'policy_loss': -1.8542451808849971, 'ent_loss': 0.0, 'alpha': 0.20000000298023224, 'cnt': 1.0}
----------------------------------------
Test Episodes: 10, Avg. Reward: 17.5, Max. Reward: 37.0
-----------------

Episode: 180, total steps: 4116, episode steps: 12, reward: 12.0, loss: -1.052
----------------------------------------
Test Episodes: 10, Avg. Reward: 24.2, Max. Reward: 43.0
----------------------------------------
Episode: 181, total steps: 4127, episode steps: 11, reward: 11.0, loss: -1.047
Episode: 182, total steps: 4141, episode steps: 14, reward: 14.0, loss: -1.039
Episode: 183, total steps: 4154, episode steps: 13, reward: 13.0, loss: -1.023
Episode: 184, total steps: 4170, episode steps: 16, reward: 16.0, loss: -1.034
Episode: 185, total steps: 4193, episode steps: 23, reward: 23.0, loss: -1.059
Episode: 186, total steps: 4204, episode steps: 11, reward: 11.0, loss: -1.018
Episode: 187, total steps: 4221, episode steps: 17, reward: 17.0, loss: -1.039
Episode: 188, total steps: 4239, episode steps: 18, reward: 18.0, loss: -1.035
Episode: 189, total steps: 4281, episode steps: 42, reward: 42.0, loss: -1.002
Episode: 190, total steps: 4362, episode steps: 81, reward: 81.0, loss: 

Episode: 266, total steps: 6077, episode steps: 25, reward: 25.0, loss: -1.408
Episode: 267, total steps: 6092, episode steps: 15, reward: 15.0, loss: -1.415
Episode: 268, total steps: 6109, episode steps: 17, reward: 17.0, loss: -1.393
Episode: 269, total steps: 6142, episode steps: 33, reward: 33.0, loss: -1.396
Episode: 270, total steps: 6162, episode steps: 20, reward: 20.0, loss: -1.429
----------------------------------------
Test Episodes: 10, Avg. Reward: 36.4, Max. Reward: 94.0
----------------------------------------
Episode: 271, total steps: 6176, episode steps: 14, reward: 14.0, loss: -1.377
Episode: 272, total steps: 6200, episode steps: 24, reward: 24.0, loss: -1.427
Episode: 273, total steps: 6215, episode steps: 15, reward: 15.0, loss: -1.43
Episode: 274, total steps: 6249, episode steps: 34, reward: 34.0, loss: -1.388
Episode: 275, total steps: 6287, episode steps: 38, reward: 38.0, loss: -1.442
Episode: 276, total steps: 6299, episode steps: 12, reward: 12.0, loss: -

Episode: 352, total steps: 7907, episode steps: 23, reward: 23.0, loss: -1.788
Episode: 353, total steps: 7945, episode steps: 38, reward: 38.0, loss: -1.777
Episode: 354, total steps: 7979, episode steps: 34, reward: 34.0, loss: -1.812
Episode: 355, total steps: 7991, episode steps: 12, reward: 12.0, loss: -1.874
Episode: 356, total steps: 8032, episode steps: 41, reward: 41.0, loss: -1.831
Episode: 357, total steps: 8058, episode steps: 26, reward: 26.0, loss: -1.848
Episode: 358, total steps: 8075, episode steps: 17, reward: 17.0, loss: -1.838
Episode: 359, total steps: 8104, episode steps: 29, reward: 29.0, loss: -1.805
Episode: 360, total steps: 8153, episode steps: 49, reward: 49.0, loss: -1.825
----------------------------------------
Test Episodes: 10, Avg. Reward: 47.2, Max. Reward: 108.0
----------------------------------------
Episode: 361, total steps: 8170, episode steps: 17, reward: 17.0, loss: -1.829
Episode: 362, total steps: 8183, episode steps: 13, reward: 13.0, loss:

Episode: 440, total steps: 9941, episode steps: 22, reward: 22.0, loss: -2.258
----------------------------------------
Test Episodes: 10, Avg. Reward: 70.1, Max. Reward: 176.0
----------------------------------------
Episode: 441, total steps: 9962, episode steps: 21, reward: 21.0, loss: -2.279
Episode: 442, total steps: 9981, episode steps: 19, reward: 19.0, loss: -2.28
Episode: 443, total steps: 10063, episode steps: 82, reward: 82.0, loss: -2.321
Episode: 444, total steps: 10176, episode steps: 113, reward: 113.0, loss: -2.307
Episode: 445, total steps: 10234, episode steps: 58, reward: 58.0, loss: -2.311
Episode: 446, total steps: 10293, episode steps: 59, reward: 59.0, loss: -2.367
Episode: 447, total steps: 10400, episode steps: 107, reward: 107.0, loss: -2.514
Episode: 448, total steps: 10486, episode steps: 86, reward: 86.0, loss: -2.512
Episode: 449, total steps: 10605, episode steps: 119, reward: 119.0, loss: -2.538
Episode: 450, total steps: 10646, episode steps: 41, reward

KeyboardInterrupt: 