In [1]:
from sac import SAC
from replay_memory import ReplayMemory
import gym
from gym import spaces
from addict import Dict
import numpy as np
import itertools
from scipy.special import softmax
import random

In [2]:
WARM_UP = 0
ACT_THRESHOLD = 0
BATCH_SIZE = 256
UPDATES_PER_STEP = 1
PRINT_FREQ = 100

In [3]:
num_inputs = 4
num_hidden = 8
num_outputs = 1
bandwidth = 3

In [4]:
env = gym.make('CartPole-v0')
env.action_space.shape = (1,)
env.action_space.high = np.array([1])
env.action_space.low = np.array([0])

In [5]:
input_layer_action_space = gym.spaces.Box(low=0, high=1, shape=(num_hidden + bandwidth,))
hidden_layer_action_space = gym.spaces.Box(low=0, high=1, shape=(num_outputs + bandwidth,))
output_layer_action_space = env.action_space

In [6]:
args = Dict()
args.gamma = 0.99
args.tau = 0.005
args.alpha = 0.2
args.policy = 'Gaussian'
args.target_update_interval = 10
args.automatic_entropy_tuning = True
args.cuda = False
args.hidden_size = 256
args.lr = 0.003

In [7]:
agent = SAC(1, input_layer_action_space, args)

In [8]:
input_layer = [SAC(1, input_layer_action_space, args) for _ in range(num_inputs)]
input_memory = [ReplayMemory(1000000) for _ in range(num_inputs)]
hidden_layer = [SAC(bandwidth, hidden_layer_action_space, args) for _ in range(num_hidden)]
hidden_memory = [ReplayMemory(1000000) for _ in range(num_hidden)]
output_layer = [SAC(bandwidth, output_layer_action_space, args) for _ in range(num_outputs)]
output_memory = [ReplayMemory(1000000) for _ in range(num_outputs)]

In [9]:
def eval_layers(input_state):
    input_state_actions = [tuple() for _ in range(num_inputs)]
    hidden_state_actions = [tuple() for _ in range(num_hidden)]
    output_state_actions = [tuple() for _ in range(num_outputs)]
    input_actions = [agent.select_action(input_state[i:i+1]) for i, agent in enumerate(input_layer)]
    for i, input_action in enumerate(input_actions):
        input_action[:num_hidden] = softmax(input_action[:num_hidden])
        if max(input_action[:num_hidden]) > ACT_THRESHOLD:
            input_state_actions[i] = (input_state[i:i+1], input_action, True)
            input_actions[i] = (np.argmax(input_action[:num_hidden]), input_action[num_hidden:])
        else:
            input_state_actions[i] = (input_state[i:i+1], input_action, False)
            input_actions[i] = tuple()
    hidden_state = [tuple() for _ in range(num_hidden)]
    for input_action in input_actions:
        try:
            hidden_i, hidden_msg = input_action
        except:
            continue
        try:
            hidden_state[hidden_i] += hidden_msg
        except:
            hidden_state[hidden_i] = hidden_msg
    hidden_actions = []
    for i, _hidden_state in enumerate(hidden_state):
        if len(_hidden_state):
            hidden_action = hidden_layer[i].select_action(hidden_state[i]) 
            if hidden_action[0] > ACT_THRESHOLD:
                hidden_state_actions[i] = (_hidden_state, hidden_action, True)
                hidden_actions.append((0, hidden_action[1:]))
            else:
                hidden_state_actions[i] = (_hidden_state, hidden_action, False)
                hidden_actions.append(tuple())
    output_state = [tuple() for _ in range(num_outputs)]
    for hidden_action in hidden_actions:
        try:
            output_i, output_msg = hidden_action
        except:
            continue
        try:
            output_state[output_i] += output_msg
        except:
            output_state[output_i] = output_msg
    output_actions = [agent.select_action(output_state[i]) for i, agent in enumerate(output_layer) if len(output_state[i])]
    if output_actions:
        output_state_actions = [(output_state[0], output_actions[0], True)]
    else:
        output_state_actions = [tuple()]
    inner_activations = {
        'input_state_actions' : input_state_actions,
        'hidden_state_actions' : hidden_state_actions,
        'output_state_actions' : output_state_actions
    }
    try:
        if output_actions[0] > 0.5:
            return 1, inner_activations
        else:
            return 0, inner_activations
    except:
        return random.randint(0, 1), inner_activations

In [10]:
def sample_layers(input_state):
    input_state_actions = [tuple() for _ in range(num_inputs)]
    hidden_state_actions = [tuple() for _ in range(num_hidden)]
    output_state_actions = [tuple() for _ in range(num_outputs)]
    input_actions = [input_layer_action_space.sample() for _ in input_layer]
    for i, input_action in enumerate(input_actions):
        input_action[:num_hidden] = softmax(input_action[:num_hidden])
        if max(input_action[:num_hidden]) > ACT_THRESHOLD:
            input_state_actions[i] = (input_state[i:i+1], input_action, True)
            input_actions[i] = (np.argmax(input_action[:num_hidden]), input_action[num_hidden:])
        else:
            input_state_actions[i] = (input_state[i:i+1], input_action, False)
            input_actions[i] = tuple()
    hidden_state = [tuple() for _ in range(num_hidden)]
    for input_action in input_actions:
        try:
            hidden_i, hidden_msg = input_action
        except:
            continue
        try:
            hidden_state[hidden_i] += hidden_msg
        except:
            hidden_state[hidden_i] = hidden_msg
    hidden_actions = []
    for i, _hidden_state in enumerate(hidden_state):
        if len(_hidden_state):
            hidden_action = hidden_layer_action_space.sample()
            if hidden_action[0] > ACT_THRESHOLD:
                hidden_state_actions[i] = (_hidden_state, hidden_action, True)
                hidden_actions.append((0, hidden_action[1:]))
            else:
                hidden_state_actions[i] = (_hidden_state, hidden_action, False)
                hidden_actions.append(tuple())
    output_state = [tuple() for _ in range(num_outputs)]
    for hidden_action in hidden_actions:
        try:
            output_i, output_msg = hidden_action
        except:
            continue
        try:
            output_state[output_i] += output_msg
        except:
            output_state[output_i] = output_msg
    output_actions = [np.array([output_layer_action_space.sample()]) for i, _ in enumerate(output_layer) if len(output_state[i])]
    output_state_actions = [(output_state[0], output_actions[0], True)]
    inner_activations = {
        'input_state_actions' : input_state_actions,
        'hidden_state_actions' : hidden_state_actions,
        'output_state_actions' : output_state_actions
    }
    try:
        if output_actions[0] > 0.5:
            return 1, inner_activations
        else:
            return 0, inner_activations
    except:
        return random.randint(0, 1), inner_activations

In [11]:
def push_memory(inner_activations, reward, next_state, mask):
    input_state_actions = inner_activations['input_state_actions']
    hidden_state_actions = inner_activations['hidden_state_actions']
    output_state_actions = inner_activations['output_state_actions']
    for i, ((input_state, input_action, flag), mem) in enumerate(zip(input_state_actions, input_memory)):
        next_input_state = next_state[i:i+1]
        if flag:
            mem.push(input_state, input_action, reward, next_input_state, mask)
        else:
            mem.push(input_state, input_action, 0, next_input_state, mask)
    next_action, next_inner_activations = eval_layers(next_state)
    for hidden_state_action, mem, next_hidden_state_action in zip(hidden_state_actions, hidden_memory, next_inner_activations['hidden_state_actions']):
        try:
            hidden_state, hidden_action, flag = hidden_state_action
            next_hidden_state, _, _ = next_hidden_state_action
        except:
            continue
        if flag:
            mem.push(hidden_state, hidden_action, reward, next_hidden_state, mask)
        else:
            mem.push(hidden_state, hidden_action, 0, next_hidden_state, mask)
    for output_state_action, mem, next_output_state_action in zip(output_state_actions, output_memory, next_inner_activations['output_state_actions']):
        try:
            output_state, output_action, flag = output_state_action
            next_output_state, _, _ = next_output_state_action
        except:
            continue
        if flag:
            mem.push(output_state, output_action, reward, next_output_state, mask)
        else:
            mem.push(output_state, output_action, 0, next_output_state, mask)
    return next_action, next_inner_activations

In [12]:
def norm_stats(stats):
    for k in stats:
        stats[k] /= stats.cnt
    return stats
def average_stats(lst):
    avg = Dict()
    for stats in lst:
        for k in stats:
            avg[k] += stats[k]
        avg.cnt += 1
    return norm_stats(avg)
def get_avg_loss(train_stats):
    try:
        input_loss = (train_stats.input.critic_1_loss + train_stats.input.critic_2_loss + train_stats.input.policy_loss)/3
    except:
        input_loss = float('nan')
    try:
        hidden_loss = (train_stats.hidden.critic_1_loss + train_stats.hidden.critic_2_loss + train_stats.hidden.policy_loss)/3
    except:
        hidden_loss = float('nan')
    try:
        output_loss = (train_stats.output.critic_1_loss + train_stats.output.critic_2_loss + train_stats.output.policy_loss)/3
    except:
        output_loss = float('nan')
    return input_loss, hidden_loss, output_loss

In [13]:
total_steps = 0
updates = 0
for i_episode in itertools.count(1):
    episode_reward = 0
    episode_steps = 0
    reasoned_steps = 0
    done = False
    state = env.reset()
    inner_activations = {}
    action = None
    train_stats = Dict()
    train_stats.input = []
    train_stats.hidden = []
    train_stats.output = []
    while not done:
        if WARM_UP > total_steps:
            action, inner_activations = sample_layers(state)  # Sample random action
        else:
            if not inner_activations:
                action, inner_activations = eval_layers(state)  # Sample action from policy
            if inner_activations['output_state_actions'][0]:
                reasoned_steps += 1

        # Number of updates per step in environment
        for i in range(UPDATES_PER_STEP):
            # Update parameters of all the networks
            flag = False
            for agent, memory in zip(input_layer, input_memory):
                input_stats = Dict()
                if len(memory) > BATCH_SIZE:
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, BATCH_SIZE, updates)
                    input_stats.critic_1_loss += critic_1_loss
                    input_stats.critic_2_loss += critic_2_loss
                    input_stats.policy_loss += policy_loss
                    input_stats.ent_loss += ent_loss
                    input_stats.alpha += alpha
                    input_stats.cnt += 1
                    flag = True
                norm_stats(input_stats)
                train_stats.input.append(input_stats)
            for agent, memory in zip(hidden_layer, hidden_memory):
                hidden_stats = Dict()
                if len(memory) > BATCH_SIZE:
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, BATCH_SIZE, updates)
                    hidden_stats.critic_1_loss += critic_1_loss
                    hidden_stats.critic_2_loss += critic_2_loss
                    hidden_stats.policy_loss += policy_loss
                    hidden_stats.ent_loss += ent_loss
                    hidden_stats.alpha += alpha
                    hidden_stats.cnt += 1
                    flag = True
                norm_stats(hidden_stats)
                train_stats.hidden.append(hidden_stats)
            for agent, memory in zip(output_layer, output_memory):
                output_stats = Dict()
                if len(memory) > BATCH_SIZE:
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, BATCH_SIZE, updates)
                    output_stats.critic_1_loss += critic_1_loss
                    output_stats.critic_2_loss += critic_2_loss
                    output_stats.policy_loss += policy_loss
                    output_stats.ent_loss += ent_loss
                    output_stats.alpha += alpha
                    output_stats.cnt += 1
                    flag = True
                norm_stats(output_stats)
                train_stats.output.append(output_stats)
            updates += flag

        next_state, reward, done, _ = env.step(action) # Step
        episode_steps += 1
        total_steps += 1
        episode_reward += reward

        # Ignore the "done" signal if it comes from hitting the time horizon.
        # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
        mask = 1 if episode_steps == env._max_episode_steps else float(not done)

        action, inner_activations = push_memory(inner_activations, reward, next_state, mask)

        state = next_state
    train_stats.input = average_stats(train_stats.input)
    train_stats.hidden = average_stats(train_stats.hidden)
    train_stats.output = average_stats(train_stats.output)
    loss = get_avg_loss(train_stats)
    print("Episode: {}, total steps: {}, episode steps: {}, reward: {}, loss_i: {}, loss_h: {}, loss_o: {}".format(i_episode, total_steps, episode_steps, round(episode_reward, 2), *list(map(lambda x : round(x, 3), loss))))
    if updates > 0 and i_episode % PRINT_FREQ == 0:
        print('INPUT: %s' % str(train_stats.input))
        print('HIDDEN: %s' % str(train_stats.hidden))
        print('OUTPUT: %s' % str(train_stats.output))
    
    if i_episode % 10 == 0:
        avg_reward = 0.
        episodes = 10
        for _  in range(episodes):
            state = env.reset()
            episode_reward = 0
            done = False
            while not done:
                action, _ = eval_layers(state)

                next_state, reward, done, _ = env.step(action)
                episode_reward += reward


                state = next_state
            avg_reward += episode_reward
        avg_reward /= episodes

        print("----------------------------------------")
        print("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2)))
        print("----------------------------------------")

Episode: 1, total steps: 10, episode steps: 10, reward: 10.0, loss_i: nan, loss_h: nan, loss_o: nan
Episode: 2, total steps: 28, episode steps: 18, reward: 18.0, loss_i: nan, loss_h: nan, loss_o: nan
Episode: 3, total steps: 41, episode steps: 13, reward: 13.0, loss_i: nan, loss_h: nan, loss_o: nan
Episode: 4, total steps: 58, episode steps: 17, reward: 17.0, loss_i: nan, loss_h: nan, loss_o: nan
Episode: 5, total steps: 68, episode steps: 10, reward: 10.0, loss_i: nan, loss_h: nan, loss_o: nan
Episode: 6, total steps: 78, episode steps: 10, reward: 10.0, loss_i: nan, loss_h: nan, loss_o: nan
Episode: 7, total steps: 120, episode steps: 42, reward: 42.0, loss_i: nan, loss_h: nan, loss_o: nan
Episode: 8, total steps: 179, episode steps: 59, reward: 59.0, loss_i: nan, loss_h: nan, loss_o: nan
Episode: 9, total steps: 189, episode steps: 10, reward: 10.0, loss_i: nan, loss_h: nan, loss_o: nan
Episode: 10, total steps: 203, episode steps: 14, reward: 14.0, loss_i: nan, loss_h: nan, loss_o:

----------------------------------------
Test Episodes: 10, Avg. Reward: 18.0
----------------------------------------
Episode: 71, total steps: 1411, episode steps: 10, reward: 10.0, loss_i: -0.573, loss_h: -3.334, loss_o: -0.235
Episode: 72, total steps: 1448, episode steps: 37, reward: 37.0, loss_i: -0.573, loss_h: -23.075, loss_o: -0.236
Episode: 73, total steps: 1473, episode steps: 25, reward: 25.0, loss_i: -0.575, loss_h: -19.477, loss_o: -0.236
Episode: 74, total steps: 1489, episode steps: 16, reward: 16.0, loss_i: -0.577, loss_h: -15.341, loss_o: -0.238
Episode: 75, total steps: 1505, episode steps: 16, reward: 16.0, loss_i: -0.579, loss_h: -20.579, loss_o: -0.238
Episode: 76, total steps: 1520, episode steps: 15, reward: 15.0, loss_i: -0.571, loss_h: -20.235, loss_o: -0.24
Episode: 77, total steps: 1536, episode steps: 16, reward: 16.0, loss_i: -0.577, loss_h: -22.843, loss_o: -0.24
Episode: 78, total steps: 1547, episode steps: 11, reward: 11.0, loss_i: -0.571, loss_h: -15.

Episode: 132, total steps: 2525, episode steps: 9, reward: 9.0, loss_i: -0.549, loss_h: -0.221, loss_o: -0.278
Episode: 133, total steps: 2539, episode steps: 14, reward: 14.0, loss_i: -0.531, loss_h: -0.221, loss_o: -0.282
Episode: 134, total steps: 2551, episode steps: 12, reward: 12.0, loss_i: -0.541, loss_h: -0.222, loss_o: -0.284
Episode: 135, total steps: 2562, episode steps: 11, reward: 11.0, loss_i: -0.538, loss_h: -0.223, loss_o: -0.285
Episode: 136, total steps: 2574, episode steps: 12, reward: 12.0, loss_i: -0.555, loss_h: -0.223, loss_o: -0.284
Episode: 137, total steps: 2591, episode steps: 17, reward: 17.0, loss_i: -0.516, loss_h: -0.224, loss_o: -0.286
Episode: 138, total steps: 2607, episode steps: 16, reward: 16.0, loss_i: -0.528, loss_h: -0.225, loss_o: -0.288
Episode: 139, total steps: 2619, episode steps: 12, reward: 12.0, loss_i: -0.514, loss_h: -0.225, loss_o: -0.287
Episode: 140, total steps: 2629, episode steps: 10, reward: 10.0, loss_i: -0.528, loss_h: -0.226, 

KeyboardInterrupt: 