In [1]:
from copy import deepcopy

import gym
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import jit, grad, vmap, pmap, random

In [2]:
# Global hyperparameters
BATCH_SIZE = 32  # transitions drawn from the replay buffer per training step
BUFFER_SIZE = 100000  # capacity of the replay buffer (oldest entries are overwritten)
SEED = 2022  # seeds the JAX PRNG key; NumPy's global RNG (used for exploration) is not seeded in this notebook
LEARNING_RATE = 5e-4  # Adam step size
DISCOUNT = 0.99  # discount factor applied to the bootstrapped next-state value
TARGET_UPDATE_PERIOD = 100  # episodes between copies of the online params into the target params
EPSILON = 1.0  # initial exploration probability; the training loop decays this variable in place
TRAIN_EVERY = 1  # a gradient step is taken every TRAIN_EVERY episodes
MIN_REPLAY_SIZE = 1000  # random-policy steps collected before training starts

In [3]:
def huber_loss(targets: jnp.array,
               predictions: jnp.array,
               delta: float = 1.0) -> jnp.ndarray:
  """Element-wise Huber loss between `targets` and `predictions`.

  With `x = |targets - predictions|`, the loss is quadratic for small errors
  and linear for large ones:
    `0.5 * x^2`                            when `x <= delta`
    `0.5 * delta^2 + delta * (x - delta)`  when `x > delta`

  Args:
    targets: Target values.
    predictions: Prediction values (broadcast against `targets`).
    delta: Absolute error at which the loss switches from quadratic to linear.
  Returns:
    Array of per-element losses; no reduction is applied.
  """
  abs_error = jnp.abs(targets - predictions)
  is_small = abs_error <= delta
  quadratic_branch = 0.5 * jnp.square(abs_error)
  linear_branch = 0.5 * delta**2 + delta * (abs_error - delta)
  return jnp.where(is_small, quadratic_branch, linear_branch)

def smoothed_l1_loss(targets: jnp.array,
               predictions: jnp.array,
               beta: float = 1.0) -> jnp.ndarray:
  """Element-wise smooth L1 loss with threshold `beta`.

  With `x = |targets - predictions|`:
    `0.5 * x^2 / beta`  when `x <= beta`
    `x - 0.5 * beta`    when `x > beta`

  `beta = 0` degenerates to the plain L1 loss `x` (instead of producing NaN
  from a division by zero).

  Args:
    targets: Target values.
    predictions: Prediction values (broadcast against `targets`).
    beta: Absolute error at which the loss switches from quadratic to linear.
  Returns:
    Array of per-element losses; no reduction is applied.
  """
  x = jnp.abs(targets - predictions)
  # Guard the denominator: with beta == 0 the quadratic branch is only selected
  # where x == 0, and must evaluate to 0 there rather than 0/0.
  safe_beta = jnp.where(beta > 0, beta, 1.0)
  quadratic = 0.5 * x**2 / safe_beta
  linear = x - 0.5 * beta
  return jnp.where(x <= beta, quadratic, linear)

def mse_loss(targets: jnp.array, predictions: jnp.array) -> jnp.ndarray:
  """Element-wise squared error `(targets - predictions)^2`.

  No reduction is applied here: the result has the broadcast shape of the
  inputs, so a caller wanting a scalar loss must take the mean itself.
  """
  residual = targets - predictions
  return jnp.square(residual)

In [4]:
# Root JAX PRNG key built from SEED. It is used to initialise the network
# parameters and is threaded through (and replaced by) replay-buffer sampling.
random_state = random.PRNGKey(SEED)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
class CentralControllerWrapper:
    """Expose a multi-agent environment as a single-agent one.

    A single "central controller" picks one integer joint action, which is
    translated into one action per agent. Per-agent observations are
    concatenated into one vector, per-agent rewards are summed, and the episode
    is done only when every agent is done.

    Attributes:
        env: The wrapped multi-agent environment.
        num_agents: Number of agents, read from `ma_env.n_agents`.
        action_mapping: Dict from joint action index to a list of per-agent actions.
        action_space: Number of joint actions (an int, not a gym space).
        observation_space: Length of the joint observation (an int, not a gym space).
    """

    def __init__(self, ma_env):
        self.env = ma_env
        self.num_agents = ma_env.n_agents
        self.action_mapping = self.enumerate_agent_actions()
        self.action_space = len(self.action_mapping)
        # Note: this resets the wrapped environment once, just to measure the
        # total length of all per-agent observations.
        self.observation_space = int(sum(len(agent_obs) for agent_obs in ma_env.reset()))

    def reset(self, ):
        """Reset the wrapped environment and return the joint observation."""
        obs_n = self.env.reset()
        return self.create_joint_obs(obs_n)

    def step(self, joint_action):
        """Apply a joint action.

        Args:
            joint_action: Key into `action_mapping` (an integer joint action index).
        Returns:
            Tuple `(joint_obs, team_reward, team_done, info)` where `team_reward`
            is the sum of the per-agent rewards (a JAX scalar) and `team_done`
            is True only when all agents report done.
        """
        action = self.action_mapping[joint_action]
        obs_n, reward_n, done_n, info = self.env.step(action)

        joint_obs = self.create_joint_obs(obs_n)
        team_reward = jnp.sum(jnp.array(reward_n))
        team_done = all(done_n)

        return joint_obs, team_reward, team_done, info

    def random_action(self,):
        """Draw a uniformly random joint action index from NumPy's global RNG."""
        return np.random.randint(low=0, high=self.action_space)

    def enumerate_agent_actions(self, ):
        """Build the joint-action table as the Cartesian product of agent actions.

        The row order produced by the meshgrid/transpose/reshape below defines
        what each joint action index means, so it must stay stable for any
        policy trained against this wrapper.
        """
        agent_actions = [np.arange(space.n) for space in self.env.action_space]
        enumerated_actions = np.array(np.meshgrid(*agent_actions)).T.reshape(-1, self.num_agents)
        return {int(i): list(action) for i, action in enumerate(enumerated_actions)}

    def create_joint_obs(self, env_obs):
        """Concatenate per-agent observations along the last axis.

        Each agent's observation is converted on its own, so agents whose
        observations have different lengths are supported (stacking them into
        one rectangular array first would fail for ragged input).
        """
        per_agent_obs = [np.asarray(agent_obs) for agent_obs in env_obs]
        return np.concatenate(per_agent_obs, axis=-1)

    def unwrapped_env(self):
        # NOTE(review): this returns the wrapper itself, not `self.env`. Given
        # the method name that looks unintended, but it is left unchanged here
        # because callers may rely on it — confirm and fix if appropriate.
        return self

In [6]:
### Getting environment details
# Create the ma_gym Switch2-v0 environment and wrap it so that a single
# controller issues one joint action and receives one joint observation.
env = gym.make('ma_gym:Switch2-v0')
env = CentralControllerWrapper(env)
# env = ConcatenateAgentIDs(env)
# After wrapping, both of these are plain integers: the number of joint
# actions and the length of the concatenated observation vector.
num_actions     = env.action_space
observation_dim = env.observation_space

In [7]:
# Very basic jax replay buffer

class JaxTransitionBuffer:
    """Minimal functional replay buffer backed by JAX arrays.

    The object itself holds no state: the buffer dict and the insertion counter
    are created by `create_buffer`, passed into every call, and returned
    updated, which keeps `add` compatible with `jax.jit`.
    """

    def create_buffer(
        self,
        buffer_size,
        observation_dim,
    ):
        """Allocate an empty buffer.

        Args:
            buffer_size: Maximum number of stored transitions.
            observation_dim: Length of one (flat) observation vector.
        Returns:
            Tuple `(buffer, counter)`: a dict of zero-filled arrays with keys
            `state`, `action`, `reward`, `state_`, `done`, and a counter of 0.
        """
        buffer = dict(
            state=jnp.zeros((buffer_size, observation_dim)),
            action=jnp.zeros(buffer_size),
            reward=jnp.zeros(buffer_size),
            state_=jnp.zeros((buffer_size, observation_dim)),
            done=jnp.zeros(buffer_size),
        )
        counter = 0

        return buffer, counter

    def add(
        self,
        buffer,
        counter,
        buffer_size,
        state,
        action,
        reward,
        done,
        state_,
    ):
        """Store one transition, overwriting the oldest slot once full.

        The input `buffer` dict is not modified; a new dict is returned.

        Returns:
            Tuple `(buffer, counter)` with the transition written at
            `counter % buffer_size` and the counter incremented by one.
        """
        index = counter % buffer_size
        transition = dict(
            state=state,
            action=action,
            reward=reward,
            state_=state_,
            done=done,
        )

        # Copy the mapping so the caller's dict is left untouched when this
        # method is used without jit.
        updated = dict(buffer)
        for name, value in transition.items():
            updated[name] = buffer[name].at[index].set(value)

        return updated, counter + 1

    def sample(
        self,
        buffer,
        batch_size,
        state,
        buffer_size,
        counter
    ):
        """Sample a batch of stored transitions uniformly, with replacement.

        Args:
            buffer: Buffer dict as returned by `create_buffer` / `add`.
            batch_size: Number of transitions to draw.
            state: JAX PRNG key. It is split; one half is consumed for sampling
                and the other half is returned for future use.
            buffer_size: Capacity the buffer was created with.
            counter: Number of transitions added so far (must be concrete).
        Returns:
            Tuple `(new_key, sampled)` where `sampled` has keys `states`,
            `actions`, `rewards`, `states_`, `dones`.
        """
        # Never hand back the key that was just consumed: sample with one half
        # of the split and return the other, unused half.
        next_key, sample_key = random.split(state)

        # Only the slots that have been written so far are eligible.
        num_filled = min(counter, buffer_size)
        indices = random.choice(sample_key, num_filled, shape=(batch_size,), replace=True)

        sampled = dict(
            states=buffer['state'][indices],
            actions=buffer['action'][indices],
            rewards=buffer['reward'][indices],
            states_=buffer['state_'][indices],
            dones=buffer['done'][indices],
        )

        return next_key, sampled

In [8]:
def net_fn(batch) -> jnp.ndarray:
    """Q-network: two 64-unit ReLU layers followed by a linear head.

    The input is cast to float32 and the output has one value per joint
    action (`num_actions` entries along the last axis).
    """
    layers = []
    for width in (64, 64):
        layers += [hk.Linear(width), jax.nn.relu]
    layers.append(hk.Linear(num_actions))

    network = hk.Sequential(layers)
    return network(batch.astype(jnp.float32))

In [9]:
# Placeholder batch of ones with the network's input shape; it is passed to
# q_network.init below so the parameters can be created.
dummy_pass_data = jnp.ones((BATCH_SIZE, observation_dim))

In [10]:
# Initialise the online and target Q-networks.
# without_apply_rng: apply() is called as apply(params, inputs), with no PRNG key.
q_network = hk.without_apply_rng(hk.transform(net_fn))
online_params = q_network.init(random_state, dummy_pass_data)
# The target network starts as an independent copy of the online parameters.
target_params = deepcopy(online_params)

In [11]:
# Adam optimiser for the online parameters; opt_state is passed into and
# returned from update() on every training step.
opt = optax.adam(LEARNING_RATE)
opt_state = opt.init(online_params)

In [12]:

def dqn_loss(online_params, target_params, batch) -> jnp.ndarray:
    """Mean squared one-step TD error of the online Q-network on a batch.

    The target is `r + DISCOUNT * (1 - done) * max_a Q_target(s', a)` and is
    treated as a constant (no gradient flows through it). No regularisation
    term is added.

    Args:
        online_params: Parameters of the network being trained.
        target_params: Parameters of the target network used for bootstrapping.
        batch: Dict with keys `states`, `actions`, `rewards`, `dones`, `states_`
            as produced by the replay buffer's `sample`.
    Returns:
        Scalar loss.
    """
    states = batch['states']
    # Actions are stored as floats in the replay buffer; cast back for indexing.
    actions = batch['actions'].astype(jnp.int32)
    rewards = batch['rewards']
    dones = batch['dones']
    next_states = batch['states_']

    # Q(s, a) for the action actually taken, picked with one vectorised gather
    # instead of a Python loop over the batch.
    q_values = q_network.apply(online_params, states)
    batch_indices = jnp.arange(q_values.shape[0])
    selected_q_values = q_values[batch_indices, actions]

    # Bootstrapped target from the target network; (1 - dones) removes the
    # bootstrap term for terminal transitions.
    next_q_values = q_network.apply(target_params, next_states)
    max_next_q_values = jnp.max(next_q_values, axis=1)
    target = jax.lax.stop_gradient(rewards + DISCOUNT * (1 - dones) * max_next_q_values)

    td_error = selected_q_values - target

    # The helper losses defined earlier in this notebook return element-wise
    # values, so they would need a jnp.mean(...) if swapped in here.
    return jnp.mean(td_error ** 2)

In [13]:
@jit
def update(online_params, target_params, opt_state, batch):
    """Take one optimiser step on the online parameters.

    Differentiates `dqn_loss` with respect to `online_params` only and applies
    the resulting optax update.

    Returns:
        Tuple `(new_online_params, new_opt_state)`.
    """
    gradients = jax.grad(dqn_loss)(online_params, target_params, batch)
    param_updates, next_opt_state = opt.update(gradients, opt_state)
    return optax.apply_updates(online_params, param_updates), next_opt_state

In [14]:
# Initialise the replay buffer: `buffer` is a dict of arrays and `counter`
# tracks how many transitions have been added so far.

replay_buffer = JaxTransitionBuffer()
buffer, counter = replay_buffer.create_buffer(BUFFER_SIZE, observation_dim)

# Jit the add method. No arguments are marked static, so counter and
# buffer_size are passed as ordinary (traced) values.
jit_add = jit(replay_buffer.add)

In [15]:
# Warm-up: fill the replay buffer with transitions from a uniformly random
# policy. The step count is only checked between episodes, so slightly more
# than MIN_REPLAY_SIZE transitions may be collected.
global_count = 0
while global_count <= MIN_REPLAY_SIZE:
    obs = env.reset()
    done = False
    while not done:
        action = jnp.array(env.random_action())

        obs_, reward, done, _ = env.step(action.tolist())
        buffer, counter = jit_add(
            buffer=buffer,
            counter=counter,
            buffer_size=BUFFER_SIZE,
            state=obs,
            action=action,
            reward=reward,
            done=done,
            state_=obs_,
        )

        obs = obs_
        global_count += 1

In [None]:
# Compile the reporting loss and the greedy Q-value lookup once, instead of
# re-running them op by op on every episode / environment step.
jit_loss = jit(dqn_loss)
jit_q_values = jit(q_network.apply)

global_count = 0
episode_returns = []
losses = []
loss = None  # most recent training loss; stays None until the first update
for episode in range(1, 10000):
    obs = env.reset()
    done = False
    episode_return = 0
    while not done:
        # Epsilon-greedy action selection over joint actions.
        if np.random.random() < EPSILON:
            action = jnp.array(env.random_action())
        else:
            action = jnp.argmax(jit_q_values(online_params, jnp.array(obs)))

        # Decay exploration after every environment step, floored at 0.05.
        # Note: this overwrites the module-level EPSILON, so re-running this
        # cell continues from the decayed value rather than the initial one.
        EPSILON = max(0.05, EPSILON * 0.99999)

        obs_, reward, done, _ = env.step(action.tolist())
        buffer, counter = jit_add(
            buffer=buffer,
            counter=counter,
            buffer_size=BUFFER_SIZE,
            state=obs,
            action=action,
            reward=reward,
            done=done,
            state_=obs_,
        )
        episode_return += reward
        obs = obs_

        global_count += 1

    # One gradient step per TRAIN_EVERY episodes (not per environment step).
    if episode % TRAIN_EVERY == 0:
        random_state, sampled_data = replay_buffer.sample(
            buffer, BATCH_SIZE, random_state, BUFFER_SIZE, counter)
        # Loss is evaluated before the update, for logging only.
        loss = jit_loss(online_params, target_params, sampled_data)
        losses.append(loss)
        online_params, opt_state = update(online_params, target_params, opt_state, sampled_data)

    # Periodically sync the target network with the online network.
    if episode % TARGET_UPDATE_PERIOD == 0:
        target_params = deepcopy(online_params)

    # Store a plain Python float so averaging and plotting do not have to
    # convert a list of device scalars each time.
    episode_returns.append(float(episode_return))
    average_return = np.mean(episode_returns[-100:])

    if episode % 50 == 0:
        print("Episode:", episode, "Average Return:", average_return, "Loss:", loss, "Epsilon:", EPSILON)

    # NOTE(review): 195 looks like a threshold carried over from a different
    # task — confirm it is reachable in this environment; if it is not, the
    # loop simply runs for all 9999 episodes.
    if average_return >= 195 and EPSILON <= 0.05:
        print("Training Done")
        break

Episode: 50 Average Return: -1.541618 Loss: 9.535177 Epsilon: 0.9573940491288276
Episode: 100 Average Return: -3.0942175 Loss: 3.2856796 Epsilon: 0.9146530681914871


In [None]:
# Learning curve: summed team reward of every training episode.
fig, ax = plt.subplots()
ax.plot(np.asarray(episode_returns))
ax.set(xlabel="Episode", ylabel="Episode return", title="Training return per episode")
plt.show()