In [1]:
# conda env: centralised-agents

import gym 
import jax.numpy as jnp
import jax
import numpy as np
import haiku as hk
from copy import deepcopy
from jax import jit, grad, vmap, pmap, random
import optax
import chex
import rlax
from typing import Tuple
import distrax

In [2]:
# global hypterparameters
HORIZON = 100
NUM_EPOCHS = 2
NUM_MINIBATCHES = 2
SEED = 2022
LEARNING_RATE = 5e-4
DISCOUNT = 0.99
GAE_LAMBDA = 0.95 
CLIPPING_EPSILON = 0.2

In [3]:
# first state
random_state = random.PRNGKey(SEED)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
class CentralControllerWrapper: 
    
    def __init__(self, ma_env):
        
        self.env = ma_env 
        self.num_agents = ma_env.n_agents 
        self.action_mapping = self.enumerate_agent_actions()
        self.action_space = len(self.action_mapping)
        self.observation_space = np.sum([len(i) for i in ma_env.reset()])
        
    def reset(self, ):
        
        obs_n = self.env.reset()
        joint_obs = self.create_joint_obs(obs_n)
        
        return joint_obs
    
    def step(self, joint_action): 
        
        action = self.action_mapping[joint_action]
        obs_n, reward_n, done_n, info = self.env.step(action)
        
        joint_obs = self.create_joint_obs(obs_n)
        team_reward = jnp.sum(jnp.array(reward_n))
        team_done = all(done_n)
        
        return joint_obs, team_reward, team_done, info
    
    def random_action(self,): 
        
        action = np.random.randint(low = 0, high = self.action_space)
        return action 
    
    def enumerate_agent_actions(self, ):
        
        agent_actions = [np.arange(self.env.action_space[i].n) for i in range(len(self.env.action_space))]
        enumerated_actions = np.array(np.meshgrid(*agent_actions)).T.reshape(-1,self.num_agents)
        action_mapping = {int(i): list(action) for i, action in enumerate(enumerated_actions)}
        return action_mapping
    
    def create_joint_obs(self, env_obs):
        
        array_obs = np.array(env_obs)
        joint_obs = np.concatenate(array_obs, axis = -1)
        
        return joint_obs
    
    def unwrapped_env(self):
        return self

In [5]:
# For debugging

class NormalGymWrapper: 
    
    def __init__(self, env):
        
        self.env = env  
        self.action_space = env.action_space.n
        self.observation_space = env.observation_space.shape[0]
        
    def reset(self, ):
        
        obs = self.env.reset()
        joint_obs = np.array(obs)
        
        return joint_obs
    
    def step(self, action): 
        
        obs, reward, done, info = self.env.step(action)
        
        return np.array(obs), jnp.array(reward), done, info
    
    def unwrapped_env(self):
        return self

In [6]:
### Getting environment details 
# env = gym.make('ma_gym:Switch2-v0')
# env = CentralControllerWrapper(env)

env = gym.make("CartPole-v0")
env = NormalGymWrapper(env)
num_actions     = env.action_space
observation_dim = env.observation_space

In [7]:
@chex.dataclass
class BufferState: 
    states: jnp.ndarray
    actions: jnp.ndarray 
    rewards: jnp.ndarray 
    dones: jnp.ndarray
    log_probs: jnp.ndarray
    values: jnp.ndarray
    advantages: jnp.ndarray
    returns: jnp.ndarray
    counter: jnp.int32 
    key: chex.PRNGKey
    buffer_size: jnp.int32
    gae_lambda: jnp.float32 
    discount: jnp.float32
    num_minibatches: jnp.int32

In [8]:
@chex.dataclass
class MiniBatch:
    states: jnp.ndarray
    actions: jnp.ndarray 
    rewards: jnp.ndarray 
    dones: jnp.ndarray
    log_probs: jnp.ndarray
    values: jnp.ndarray
    advantages: jnp.ndarray
    returns: jnp.ndarray
    key: chex.PRNGKey
    

In [9]:
# Very basic jax replay buffer

class JaxTrajectoryBuffer: 
    
    def create_buffer(
        self, 
        buffer_size: int, 
        observation_dim: int,
        gae_lambda: float, 
        discount: float, 
        num_minibatches: int,
        buffer_key: chex.PRNGKey = random.PRNGKey(0),
    ) -> BufferState:
        
        state_buffer = jnp.empty((buffer_size + 1, observation_dim), dtype=jnp.float32)
        action_buffer = jnp.empty(buffer_size + 1, dtype=jnp.int32)
        reward_buffer = jnp.empty(buffer_size + 1, dtype=jnp.float32)
        done_buffer = jnp.empty(buffer_size + 1, dtype=bool) 
        log_probs_buffer = jnp.empty(buffer_size + 1, dtype=jnp.float32)
        values_buffer = jnp.empty(buffer_size + 1, dtype=jnp.float32)
        advantages_buffer = jnp.empty(buffer_size, dtype=jnp.float32)
        returns_buffer = jnp.empty(buffer_size, dtype=jnp.float32)
        
        buffer_state = BufferState(
            states = state_buffer, 
            actions = action_buffer, 
            rewards = reward_buffer, 
            dones = done_buffer, 
            log_probs=log_probs_buffer, 
            values=values_buffer,
            advantages=advantages_buffer,
            returns=returns_buffer, 
            counter = jnp.array(0, dtype=jnp.int32), 
            key = buffer_key, 
            buffer_size = jnp.array(buffer_size, dtype=jnp.int32), 
            gae_lambda = jnp.array(gae_lambda, dtype=jnp.float32), 
            discount = jnp.array(discount, dtype=jnp.float32),
            num_minibatches = jnp.array(num_minibatches, dtype=jnp.int32),
        )
        
        return buffer_state
    
    def add(
        self,
        buffer_state, 
        state, 
        action, 
        reward, 
        done, 
        log_prob,
        value, 
    ) -> BufferState:
        
        index = buffer_state.counter
        #x = x.at[idx].set(y)
        buffer_state.states = buffer_state.states.at[index].set(state)
        buffer_state.actions = buffer_state.actions.at[index].set(action)
        buffer_state.rewards = buffer_state.rewards.at[index].set(reward)
        buffer_state.dones = buffer_state.dones.at[index].set(done)
        buffer_state.log_probs = buffer_state.log_probs.at[index].set(log_prob)
        buffer_state.values = buffer_state.values.at[index].set(value)
        
        buffer_state.counter += 1
        
        return buffer_state
    
    def compute_advantages(
        self, 
        buffer_state, 
    ) -> BufferState:
        
        # Returns array of length [0:k-1]
        
        advantages = rlax.truncated_generalized_advantage_estimation(
            r_t = buffer_state.rewards[1: ], 
            discount_t = (1 - buffer_state.dones[1: ]) * buffer_state.discount, 
            lambda_ = buffer_state.gae_lambda, 
            values = buffer_state.values,
        )
        
        # Don't have to add a zero just make other arrays shorter during training. 
        
        buffer_state.advantages = advantages
        
        # can now get the returns by saying 
        # returns = advantages - values[0:k-1]  -> essentially Adv - V_{t}
        returns = advantages + buffer_state.values[:-1]
        buffer_state.returns = returns 
        
        return buffer_state
        
        
    def get_epoch_indices(
        self, 
        buffer_state, 
    ) -> jnp.ndarray:
        
        key, sample_key = random.split(buffer_state.key)
        
        shuffled_idx = random.permutation(sample_key, buffer_state.buffer_size)
        
        buffer_state.key = key 
        
        # Split indices into minibatches 
        # right now the type here is a list. might have to be cast into 
        # something else. 
        minibatch_idxs = jnp.split(shuffled_idx, buffer_state.num_minibatches)
        
        return buffer_state, minibatch_idxs
    
    def should_train(
        self, 
        buffer_state 
    ) -> bool:
        
        return jnp.equal(buffer_state.counter, buffer_state.buffer_size + 1)
    
    def reset(
        self, 
        buffer_state: BufferState
    ) -> BufferState:
        
        state_buffer = jnp.empty((buffer_state.buffer_size + 1, observation_dim), dtype=jnp.float32)
        action_buffer = jnp.empty(buffer_state.buffer_size + 1, dtype=jnp.int32)
        reward_buffer = jnp.empty(buffer_state.buffer_size + 1, dtype=jnp.float32)
        done_buffer = jnp.empty(buffer_state.buffer_size + 1, dtype=bool) 
        log_probs_buffer = jnp.empty(buffer_state.buffer_size + 1, dtype=jnp.float32)
        values_buffer = jnp.empty(buffer_state.buffer_size + 1, dtype=jnp.float32)
        advantages_buffer = jnp.empty(buffer_state.buffer_size, dtype=jnp.float32)
        returns_buffer = jnp.empty(buffer_state.buffer_size, dtype=jnp.float32)
        
        buffer_state.states = state_buffer
        buffer_state.actions = action_buffer
        buffer_state.rewards = reward_buffer
        buffer_state.dones = done_buffer
        buffer_state.log_probs = log_probs_buffer
        buffer_state.values = values_buffer
        buffer_state.advantages = advantages_buffer
        buffer_state.returns = returns_buffer
        buffer_state.counter = jnp.array(0, dtype=jnp.int32)
        
        return buffer_state
        
        

In [10]:
# create replay buffer 
buffer = JaxTrajectoryBuffer()
buffer_state = buffer.create_buffer(
        buffer_size = HORIZON,  
        observation_dim = observation_dim,
        gae_lambda = GAE_LAMBDA, 
        discount = DISCOUNT, 
        num_minibatches = NUM_MINIBATCHES,
)
jit_add = jax.jit(buffer.add)

In [11]:
# create networks 

network_state, policy_init_state = random.split(random_state)
network_state, value_init_state  = random.split(network_state)

# Create feedforward policy and value network  

def policy_fn(batch) -> jnp.ndarray:
    """Standard MLP network."""
    x = batch.astype(jnp.float32)
    mlp = hk.Sequential([
        hk.Linear(64), jax.nn.relu,
        hk.Linear(64), jax.nn.relu,
        hk.Linear(num_actions),
    ])
    return mlp(x)

def value_fn(batch) -> jnp.ndarray:
    """Standard MLP network."""
    x = batch.astype(jnp.float32)
    mlp = hk.Sequential([
        hk.Linear(64), jax.nn.relu,
        hk.Linear(64), jax.nn.relu,
        hk.Linear(1),
    ])
    return mlp(x)

dummy_pass_data = jnp.ones((1, observation_dim))

# initialize policy and value parameters  
policy_network = hk.without_apply_rng(hk.transform(policy_fn))
policy_params  = policy_network.init(policy_init_state, dummy_pass_data)

value_network = hk.without_apply_rng(hk.transform(value_fn))
value_params  = value_network.init(value_init_state, dummy_pass_data)


# Intialize optimisers and optimiser states 
policy_optimiser = optax.adam(LEARNING_RATE)
policy_optimiser_state = policy_optimiser.init(policy_params)

value_optimiser = optax.adam(LEARNING_RATE)
value_optimiser_state = value_optimiser.init(value_params)

In [12]:
# @jit
# @chex.assert_max_traces(n=1)
def get_action_logprob(observation, action_key, policy_params):
    
    """Given an observation, returns the action to take in the environment, \
       the related log probability, the state value, and the distribution entropy.
       """
    
    logits = policy_network.apply(policy_params, observation)
    distribution = distrax.Categorical(logits=logits)
    
    action, logprob = distribution.sample_and_log_prob(
        seed = action_key, 
    )
    
    entropy = distribution.entropy()
    
    return jnp.squeeze(action), jnp.squeeze(logprob), jnp.squeeze(entropy)

In [13]:
# @jit
# @chex.assert_max_traces(n=1)
def get_value(observation, value_params):
    
    """Given an observation, returns the action to take in the environment, \
       the related log probability, the state value, and the distribution entropy.
       """
    
    value = value_network.apply(value_params, observation)
    
    return jnp.squeeze(value)

In [14]:
def ppo_policy_loss(policy_params, minibatch: MiniBatch):
    
    # TODO: pass in somehow 
    CLIP_EPSILON = 0.2
    
    states = minibatch.states 
    actions = minibatch.actions
    old_log_probs = minibatch.log_probs
    old_values = minibatch.values 
    key = minibatch.key
    advantages = minibatch.advantages
    
    key, train_key = random.split(key)
    batch_train_keys = random.split(train_key, len(states))
    
    _, new_log_probs, entropy = vmap(get_action_logprob, in_axes = (0, 0, None))(
        states, 
        batch_train_keys, 
        policy_params)
    
    ratio = jnp.exp(new_log_probs - old_log_probs)
    
    term_1 = ratio * advantages 
    term_2 = jnp.clip(a = ratio, a_min= 1- CLIP_EPSILON, a_max = 1 + CLIP_EPSILON) * advantages
    
    jax.debug.print("term policy loss 1 {x}:", x=term_1.shape)
    jax.debug.print("term policy loss 2 {y}:", y=term_2.shape)
    jax.debug.print("policy loss before mean {z}:", z=jnp.minimum(term_1, term_2).shape)
    
    
    # negative loss for gradient ascent 
    loss = -jnp.mean(jnp.minimum(term_1, term_2))
    
    # TODO maybe stop gradient for all value stuff? 
    
    return loss 
    

In [15]:
def ppo_value_loss(value_params, minibatch: MiniBatch):
    
    states = minibatch.states 
    returns = minibatch.returns 
    key = minibatch.key
    
    key, train_key = random.split(key)
    batch_train_key = random.split(train_key, len(states))
    
    new_values = vmap(get_value, in_axes = (0, None))(
        states,
        value_params, 
    )
    
    jax.debug.print("values value loss {x}:", x=new_values.shape)
    jax.debug.print("returns value loss {y}:", y=returns.shape)
    jax.debug.print("loss pre mean {y}:", y=((returns - new_values)**2).shape)
    
    loss = jnp.mean((returns - new_values)**2)
    
    return loss 

In [16]:
# @jit
# @chex.assert_max_traces(n=1)
def update_policy(policy_params, policy_optimiser_state, minibatch):
    grads = jax.grad(ppo_policy_loss, argnums=0)(policy_params, minibatch)
    updates, new_pol_optimiser_state = policy_optimiser.update(grads, policy_optimiser_state)
    new_policy_params = optax.apply_updates(policy_params, updates)
    return new_policy_params, new_pol_optimiser_state


# @jit
# @chex.assert_max_traces(n=1)
def update_value(value_params, value_optimiser_state, minibatch):
    grads = jax.grad(ppo_value_loss, argnums=0)(value_params, minibatch)
    updates, new_val_optimiser_state = value_optimiser.update(grads, value_optimiser_state)
    new_value_params = optax.apply_updates(value_params, updates)
    return new_value_params, new_val_optimiser_state
       

In [17]:
# training loop 

episode_returns = []
policy_losses = []
value_losses = []

for episode in range(1, 10000):
    obs = env.reset()
    done = False
    episode_return = 0
    
    while not done: 
        # select action
        
        key, action_key = random.split(buffer_state.key)
        buffer_state.key = key
        action, log_prob, entropy = get_action_logprob(obs, action_key, policy_params)
        value = get_value(obs, value_params)
        
        obs_, reward, done, _  = env.step(action.tolist())
        
        
        buffer_state = jit_add(buffer_state=buffer_state, 
                                state=obs, 
                                action=action, 
                                reward=reward, 
                                done=done, 
                                log_prob=log_prob,
                                value=value, 
                                ) 
        
        episode_return += reward 
        obs = obs_
        
        # whether should train of not 
        # followed by training logic 
        if buffer.should_train(buffer_state): 
            # Compute advantages 
            buffer_state = buffer.compute_advantages(buffer_state)
            policy_losses = []
            value_losses = []
            
            for epoch in range(NUM_EPOCHS):
                buffer_state, minibatch_idxs = buffer.get_epoch_indices(buffer_state)
                
                for minibatch in minibatch_idxs: 
                    
                    # TODO only get what is really needed.
                    train_minibatch = MiniBatch(
                        states=buffer_state.states[minibatch],
                        actions=buffer_state.actions[minibatch],
                        rewards=buffer_state.rewards[minibatch],  
                        dones=buffer_state.dones[minibatch],
                        log_probs=buffer_state.log_probs[minibatch],
                        values=buffer_state.values[minibatch],
                        advantages=buffer_state.advantages[minibatch],
                        returns=buffer_state.returns[minibatch],
                        key=buffer_state.key, 
                    )
                    
                    policy_params, policy_optimiser_state = update_policy(
                        policy_params, 
                        policy_optimiser_state, 
                        train_minibatch,
                    )
                    policy_loss = ppo_policy_loss(policy_params=policy_params, minibatch=train_minibatch)
                    policy_losses.append(policy_loss)
                    
                    value_params, value_optimiser_state = update_value(
                        value_params, 
                        value_optimiser_state, 
                        train_minibatch,
                    )
                    value_loss = ppo_value_loss(value_params=value_params, minibatch=train_minibatch)
                    value_losses.append(value_loss)
                    
                    key, new_key = random.split(train_minibatch.key)
                    buffer_state.key = new_key
                    
            buffer_state = buffer.reset(buffer_state)
        
        
        
    episode_returns.append(episode_return)
    
    if episode % 10 == 0:
        print("Episode:", episode, "Average Return:", np.mean(episode_returns[-100:]))
        print("Ave policy loss", np.mean(policy_losses), "Ave value loss", np.mean(value_losses))

    

term policy loss 1 [ 4.954503  13.028123  13.000199   1.6912967  5.8092785  9.745354
 13.395067  11.636981  10.648446   1.9259645 12.766047   8.793589
  6.110169   2.0673223  8.889917   5.654637   8.766969   3.5000653
 10.169163   1.0651058 10.932582   8.249395  10.184782  12.759099
  0.7738145  5.592564  13.9128685  9.551195   9.944553  10.159391
 13.120686  10.0078535  7.48395    5.805062   7.7717986  4.107291
  4.9499145  4.239339   7.6211767  2.56833    9.645289  11.0597515
  2.3789756  1.7513583  5.812869   2.8808389  7.0829563  4.093126
 11.506878   8.314521 ]:
term policy loss 2 [ 4.954503  13.028123  13.000199   1.6912967  5.8092785  9.745354
 13.395067  11.636981  10.648446   2.0793455 12.766047   8.793589
  6.110169   2.0673223  8.889917   5.654637   8.766969   3.5000653
 10.169163   1.0651058 10.932582   8.249395  10.184782  12.759099
  0.7738145  5.0887628 13.9128685  9.551195   9.944553  10.159391
 13.120686  10.0078535  7.48395    5.805062   7.7717986  4.107291
  4.949914

KeyboardInterrupt: 