In [2]:
# conda env: centralised-agents

import gym 
import jax.numpy as jnp
import jax
import numpy as np
import haiku as hk
from copy import deepcopy
from jax import jit, grad, vmap, pmap, random
import optax
import chex
import rlax
from typing import Tuple

In [2]:
# global hypterparameters
BATCH_SIZE = 64
BUFFER_SIZE = 100_000
SEED = 2022
LEARNING_RATE = 5e-4
DISCOUNT = 0.99
# TARGET_UPDATE_PERIOD = 100
TARGET_UPDATE_PERIOD = 10
EPSILON = 1.0
# EPSILON_EXP_DECAY = 0.99999 
EPSILON_EXP_DECAY = 0.99995
TRAIN_EVERY = 1
MIN_REPLAY_SIZE = 1000

In [3]:
# first state
random_state = random.PRNGKey(SEED)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
class CentralControllerWrapper: 
    
    def __init__(self, ma_env):
        
        self.env = ma_env 
        self.num_agents = ma_env.n_agents 
        self.action_mapping = self.enumerate_agent_actions()
        self.action_space = len(self.action_mapping)
        self.observation_space = np.sum([len(i) for i in ma_env.reset()])
        
    def reset(self, ):
        
        obs_n = self.env.reset()
        joint_obs = self.create_joint_obs(obs_n)
        
        return joint_obs
    
    def step(self, joint_action): 
        
        action = self.action_mapping[joint_action]
        obs_n, reward_n, done_n, info = self.env.step(action)
        
        joint_obs = self.create_joint_obs(obs_n)
        team_reward = jnp.sum(jnp.array(reward_n))
        team_done = all(done_n)
        
        return joint_obs, team_reward, team_done, info
    
    def random_action(self,): 
        
        action = np.random.randint(low = 0, high = self.action_space)
        return action 
    
    def enumerate_agent_actions(self, ):
        
        agent_actions = [np.arange(self.env.action_space[i].n) for i in range(len(self.env.action_space))]
        enumerated_actions = np.array(np.meshgrid(*agent_actions)).T.reshape(-1,self.num_agents)
        action_mapping = {int(i): list(action) for i, action in enumerate(enumerated_actions)}
        return action_mapping
    
    def create_joint_obs(self, env_obs):
        
        array_obs = np.array(env_obs)
        joint_obs = np.concatenate(array_obs, axis = -1)
        
        return joint_obs
    
    def unwrapped_env(self):
        return self

In [5]:
# For debugging

class NormalGymWrapper: 
    
    def __init__(self, env):
        
        self.env = env  
        self.action_space = env.action_space.n
        self.observation_space = env.observation_space.shape[0]
        
    def reset(self, ):
        
        obs = self.env.reset()
        joint_obs = np.array(obs)
        
        return joint_obs
    
    def step(self, action): 
        
        obs, reward, done, info = self.env.step(action)
        
        return np.array(obs), jnp.array(reward), done, info
    
    def random_action(self,):
        
        return self.env.action_space.sample()
   
    def unwrapped_env(self):
        return self

In [6]:
# TODO create one hot agetn ID wrapper. 

class ConcatenateAgentIDs:
    pass
#     super().__init()
#     def __init__(self, wrapped_env):
        
#         self.num_agents = wrapped_env.num_agents
    
#     def create_joint_obs(self): 
        
#         env_obs = super().create_joint_obs()
#         array_obs = np.array(env_obs)
        
#         for agent in self.num_agents:
#             one_hot = np.zeros(self.num_agents)
#             one_hot[agent] = 1.0 
#             array_obs[agent] = np.concatenate((one_hot, array_obs[agent]), axis=-1)
        
#         joint_obs = np.concatenate(array_obs, axis = -1)

In [7]:
### Getting environment details 
env = gym.make('ma_gym:Checkers-v0')
# env = CentralControllerWrapper(env)
# env = ConcatenateAgentIDs(env)

# env = gym.make("CartPole-v0")
# env = NormalGymWrapper(env)
# num_actions     = env.action_space
# observation_dim = env.observation_space

In [8]:
env.observation_space

[Box([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], (47,), float32),
 Box([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], (47,), float32)]

In [8]:
@chex.dataclass
class BufferState: 
    states: jnp.ndarray
    actions: jnp.ndarray 
    rewards: jnp.ndarray 
    next_states: jnp.ndarray 
    dones: jnp.ndarray 
    counter: jnp.int32 
    key: chex.PRNGKey
    buffer_size: jnp.int32
    batch_size: jnp.int32 
    min_buffer_size: jnp.int32

In [9]:
@chex.dataclass
class ExperienceSample: 
    states: jnp.ndarray
    actions: jnp.ndarray 
    rewards: jnp.ndarray 
    next_states: jnp.ndarray 
    dones: jnp.ndarray 

In [10]:
# Very basic jax replay buffer

class JaxTransitionBuffer: 
    
    def create_buffer(
        self, 
        buffer_size: int, 
        observation_dim: int,
        min_buffer_size: int = 50, 
        batch_size: int = 32, 
        buffer_key: chex.PRNGKey = random.PRNGKey(0),
    ) -> BufferState:
        
        state_buffer = jnp.empty((buffer_size, observation_dim))
        action_buffer = jnp.empty(buffer_size)
        reward_buffer = jnp.empty(buffer_size)
        state_buffer_ = jnp.empty((buffer_size, observation_dim))
        done_buffer = jnp.empty(buffer_size) 
        
        counter = 0
        
        buffer_state = BufferState(
            states = state_buffer, 
            actions = action_buffer, 
            rewards = reward_buffer, 
            next_states = state_buffer_, 
            dones = done_buffer, 
            counter = jnp.array(counter, dtype=jnp.int32), 
            key = buffer_key, 
            buffer_size = jnp.array(buffer_size, dtype=jnp.int32), 
            batch_size = jnp.array(batch_size, dtype=jnp.int32), 
            min_buffer_size = jnp.array(min_buffer_size, dtype=jnp.int32), 
        )
        
        return buffer_state
    
    def add(
        self,
        buffer_state, 
        state, 
        action, 
        reward, 
        done, 
        state_,
    ) -> BufferState:
        
        index = buffer_state.counter % buffer_state.buffer_size
        #x = x.at[idx].set(y)
        buffer_state.states = buffer_state.states.at[index].set(state)
        buffer_state.actions = buffer_state.actions.at[index].set(action)
        buffer_state.rewards = buffer_state.rewards.at[index].set(reward)
        buffer_state.next_states = buffer_state.next_states.at[index].set(state_)
        buffer_state.dones = buffer_state.dones.at[index].set(done)
        
        buffer_state.counter += 1
        
        return buffer_state
    
    def sample(
        self, 
        buffer_state, 
    ) -> Tuple[BufferState, ExperienceSample]:
        
        key, sample_key = random.split(buffer_state.key)
        indices = random.choice(
            sample_key, 
            min(buffer_state.counter, buffer_state.buffer_size), 
            shape=(buffer_state.batch_size,), 
            replace=False)
        
        buffer_state.key = key 
        
        states = jnp.stack(buffer_state.states[indices])
        actions = buffer_state.actions[indices]
        rewards = buffer_state.rewards[indices]
        states_ = jnp.stack(buffer_state.next_states[indices])
        dones = buffer_state.dones[indices]
        
        sampled = ExperienceSample( 
                        states=states,
                        actions=actions, 
                        rewards=rewards, 
                        next_states=states_, 
                        dones=dones,
        )
        
        return buffer_state, sampled
    
    def can_sample_batch(
        self, 
        buffer_state 
    ):
        return np.greater_equal(buffer_state.counter, buffer_state.batch_size)
        
        

In [11]:
# Create buffer and jit adding 
buffer = JaxTransitionBuffer()
buffer_state = buffer.create_buffer(
    buffer_size=BUFFER_SIZE, 
    observation_dim=observation_dim,
)

jit_add = jit(buffer.add)

In [12]:
# Create feedforward Q_network 

def net_fn(batch) -> jnp.ndarray:
    """Standard MLP network."""
    x = batch.astype(jnp.float32)
    mlp = hk.Sequential([
        hk.Linear(64), jax.nn.relu,
        hk.Linear(64), jax.nn.relu,
        hk.Linear(num_actions),
    ])
    return mlp(x)

dummy_pass_data = jnp.ones((BATCH_SIZE, observation_dim))

# initialize online and target parameters  
q_network = hk.without_apply_rng(hk.transform(net_fn))
online_params = q_network.init(random_state, dummy_pass_data)
target_params = deepcopy(online_params)

# Intialize optimiser and optimiser state 
optimiser = optax.adam(LEARNING_RATE)
optimiser_state = optimiser.init(online_params)

In [13]:
def dqn_loss(online_params, target_params, batch: ExperienceSample) -> jnp.ndarray:
    """Compute the loss of the network"""
    states = batch.states
    actions = batch.actions.astype(jnp.int32)
    rewards = batch.rewards
    dones = batch.dones
    states_ = batch.next_states
    
    # TODO not sure how rlax does discounts 
    discounts = DISCOUNT * (1 - dones)
    
    q_values = q_network.apply(online_params, states)
    # Do this better. 
    selected_q_values = jnp.array([q_values[i][action] for i, action in enumerate(actions)]) 
    
    # stopping gradients 
    rewards = jax.lax.stop_gradient(rewards)
    dones   = jax.lax.stop_gradient(dones)
    states_ = jax.lax.stop_gradient(states_)
    
    next_q_values = jax.lax.stop_gradient(q_network.apply(target_params, states_))
    max_next_q_values = jax.lax.stop_gradient(jnp.max(next_q_values, axis = 1))
    
    target = jax.lax.stop_gradient(rewards + discounts * max_next_q_values)
    td_error = selected_q_values - target
    
    # TODO: Figure out rlax 
    # td_error = vmap(rlax.q_learning)(
    #     q_tm1=q_values, 
    #     a_tm1=actions, 
    #     r_t=rewards, 
    #     discount_t=dones, 
    #     q_t=next_q_values, 
    # )

    loss = jnp.mean(td_error **2)

    return loss

In [14]:
@jit
@chex.assert_max_traces(n=1)
def update(online_params, target_params, opt_state, batch):
    grads = jax.grad(dqn_loss, argnums=0)(online_params, target_params, batch)
    updates, new_optimiser_state = optimiser.update(grads, optimiser_state)
    new_online_params = optax.apply_updates(online_params, updates)
    return new_online_params, new_optimiser_state

In [15]:
episode_returns = []
losses = []
for episode in range(1, 10000):
    obs = env.reset()
    done = False
    episode_return = 0
    
    while not done: 
        if np.random.random() < EPSILON:
            action = env.random_action()
            action = jnp.array(action)
        else:
            action = jnp.argmax(q_network.apply(online_params, jnp.array(obs)))
            
        if buffer.can_sample_batch(buffer_state): 
            EPSILON = max(0.05, EPSILON*EPSILON_EXP_DECAY)
        
        obs_, reward, done, _  = env.step(action.tolist())
        
        buffer_state = jit_add(buffer_state=buffer_state, 
                                state=obs, 
                                action=action, 
                                reward=reward, 
                                done=done, 
                                state_=obs_,
                               ) 
        episode_return += reward 
        obs = obs_
        
        
    if (episode % TRAIN_EVERY == 0) and buffer.can_sample_batch(buffer_state):
        
        buffer_state, sampled_data = buffer.sample(buffer_state)
        
        # Computing loss can be done by returning aux from the update. 
        loss=dqn_loss(online_params, target_params, sampled_data)
        
        losses.append(loss)
        online_params, optimiser_state = update(online_params, target_params, optimiser_state, sampled_data)

    if episode % TARGET_UPDATE_PERIOD == 0:
        
        target_params = deepcopy(online_params)
        
    episode_returns.append(episode_return)
    
    if episode % 50 == 0:
        print("Episode:", episode, "Average Return:", np.mean(episode_returns[-100:]), "Loss:", loss, "Epsilon:", EPSILON)
    

Episode: 50 Average Return: 23.14 Loss: 0.8846806 Epsilon: 0.9453014512766509
Episode: 100 Average Return: 23.32 Loss: 1.2362819 Epsilon: 0.8913635811474366
Episode: 150 Average Return: 21.02 Loss: 2.5116284 Epsilon: 0.8509907382125926
Episode: 200 Average Return: 18.56 Loss: 4.5197144 Epsilon: 0.8123652739752802
Episode: 250 Average Return: 16.79 Loss: 2.599301 Epsilon: 0.7824649634575382
Episode: 300 Average Return: 14.73 Loss: 13.8580265 Epsilon: 0.7546833388280776
Episode: 350 Average Return: 14.98 Loss: 22.11323 Epsilon: 0.7259980099572511
Episode: 400 Average Return: 15.13 Loss: 21.493906 Epsilon: 0.6996962772089826
Episode: 450 Average Return: 14.48 Loss: 46.54582 Epsilon: 0.6752921848484298
Episode: 500 Average Return: 14.54 Loss: 40.16642 Epsilon: 0.6506322182395456
Episode: 550 Average Return: 14.49 Loss: 36.62677 Epsilon: 0.6280963965279573
Episode: 600 Average Return: 13.64 Loss: 161.01985 Epsilon: 0.6077373682729577
Episode: 650 Average Return: 13.39 Loss: 122.27694 Epsilo