# In [73]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LayerNormalization
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.losses import mse
from tensorflow.keras.optimizers import Adam
from collections import deque
import numpy as np
import matplotlib.pyplot as plt
import math
from DickeStateEnv import DickeStateEnv


"""
Build the actor and critic networks (and their target copies) for the agents
"""
class MADDPG:
    """
    Holds the actor and critic networks (plus a duplicate target copy of each) for every
    agent, together with the action-selection and training steps that use them.

    Networks are stored in dictionaries keyed by agent name; ``init_agent_networks`` must be
    called before any action or training method is used. Methods that take an agent
    identifier accept either the agent's name or its integer position in ``agent_names``.
    """
    def __init__(self, observation_sizes, gamma=0.95, actor_structs=None, critic_structs=None, n_actions=None,
                agent_names=None):
        """
        Parameters
        ----------
        observation_sizes: (list) Observation length of each agent, one entry per agent.
        gamma: (float) Discount factor used in the critic's Bellman target.
        actor_structs: (list of lists) Hidden layer widths of each agent's actor.
                                       Defaults to [32, 32] for every agent.
        critic_structs: (list of lists) Hidden layer widths of each agent's critic.
                                        Defaults to [128, 64] for every agent.
        n_actions: (list) Number of action outputs of each agent. Defaults to 1 per agent.
        agent_names: (list) Keys used for the network dictionaries.
                            Defaults to 'agent_0', 'agent_1', ...

        Raises
        ------
        ValueError: if a per-agent list does not have one entry per agent.
        """
        self.gamma = gamma
        self.observation_sizes = observation_sizes
        self.n_agents = len(observation_sizes)

        if agent_names is None:
            self.agent_names = ['agent_{}'.format(i) for i in range(self.n_agents)]
        else:
            self.agent_names = list(agent_names)
        # The names used to be stored only under this private attribute while every other
        # method read ``agent_names``; keep both so either spelling works.
        self._agent_names = self.agent_names

        if n_actions is None:
            self.n_actions = [1] * self.n_agents
        else:
            self.n_actions = list(n_actions)

        # Build a fresh inner list per agent so the defaults are not aliases of one list.
        if actor_structs is None:
            self.actor_structs = [[32, 32] for _ in range(self.n_agents)]
        else:
            self.actor_structs = actor_structs

        if critic_structs is None:
            self.critic_structs = [[128, 64] for _ in range(self.n_agents)]
        else:
            self.critic_structs = critic_structs

        per_agent = (("agent_names", self.agent_names), ("n_actions", self.n_actions),
                     ("actor_structs", self.actor_structs), ("critic_structs", self.critic_structs))
        for label, values in per_agent:
            if len(values) != self.n_agents:
                raise ValueError("{} must have one entry per agent ({}), got {}".format(
                    label, self.n_agents, len(values)))

        # Must be computed after the n_actions default has been resolved.
        self.critic_input_size = sum(self.observation_sizes) + sum(self.n_actions)

        # Default optimizers are created lazily, one per agent and network, so this class can
        # be defined without building any TensorFlow object at import time.
        self._actor_optimizers = {}
        self._critic_optimizers = {}

    def _resolve_agent(self, index):
        """Return (position, name) for an agent given either its name or its integer position."""
        names = list(self.agent_names)
        if index in names:
            return names.index(index), index
        try:
            position = int(index)
            return position, names[position]
        except (TypeError, ValueError, IndexError) as err:
            raise KeyError("unknown agent: {!r}".format(index)) from err

    def _default_optimizer(self, cache, name, learning_rate):
        """Return the cached Adam optimizer for `name`, creating it on first use."""
        if name not in cache:
            cache[name] = Adam(learning_rate=learning_rate)
        return cache[name]

    def build_actor(self, obs_size, struct, num_actions):
        """
        Function for initializing and building the main and target actor networks for a single agent.

        Parameters
        ----------
        obs_size: (int) Length of the agent's observation vector.
        struct: (list) Defines both the number of nodes, and number of hidden layers. Layers are Dense
                       with relu activation.
        num_actions: (int) Number of outputs of the network; the output layer is a sigmoid,
                           so every action lies in [0, 1].

        Returns
        -------
        actor: (keras.model object) The network to be used as the actor within the agent.
        actor_: (keras.model object) Copy of actor with identical weights, used as the agent's
                                     actor target network.
        """
        # Leave the batch dimension open: 1 when acting, batch_size when training.
        obs_shape = (None, obs_size)
        actor = Sequential()
        for units in struct:
            actor.add(Dense(units, activation="relu"))
        actor.add(Dense(num_actions, activation="sigmoid"))
        actor.build(input_shape=obs_shape)
        actor_ = tf.keras.models.clone_model(actor)
        if not actor_.built:
            actor_.build(input_shape=obs_shape)
        actor_.set_weights(actor.get_weights())
        return actor, actor_

    def build_critic(self, struct):
        """
        Function for initializing and building the main and target critic networks for a single agent.

        The critic input has length ``self.critic_input_size``: all agents' observations
        followed by all agents' actions.

        Parameters
        ----------
        struct: (list) Defines both the number of nodes, and number of hidden layers. Each hidden
                       layer is a relu Dense layer followed by LayerNormalization.

        Returns
        -------
        critic: (keras.model object) The network to be used as the critic within the agent.
        critic_: (keras.model object) Copy of critic with identical weights, used as the agent's
                                      critic target network.
        """
        input_shape = (None, self.critic_input_size)
        critic = Sequential()
        for units in struct:
            critic.add(Dense(units, activation="relu"))
            critic.add(LayerNormalization())
        critic.add(Dense(1, activation=None))
        critic.build(input_shape=input_shape)
        critic_ = tf.keras.models.clone_model(critic)
        if not critic_.built:
            critic_.build(input_shape=input_shape)
        critic_.set_weights(critic.get_weights())
        return critic, critic_

    def init_agent_networks(self):
        """
        Build the actor, critic and both target networks of every agent and store them in
        ``self.actors``, ``self.actor_targets``, ``self.critics`` and ``self.critic_targets``,
        each keyed by agent name. Any previously built networks are replaced.
        """
        self.actors = {}
        self.actor_targets = {}
        self.critics = {}
        self.critic_targets = {}
        self.agent_dict = {}
        # Fresh networks get fresh default optimizers.
        self._actor_optimizers = {}
        self._critic_optimizers = {}
        for i, name in enumerate(self.agent_names):
            actor, actor_ = self.build_actor(obs_size=self.observation_sizes[i], struct=self.actor_structs[i],
                                            num_actions=self.n_actions[i])
            critic, critic_ = self.build_critic(struct=self.critic_structs[i])
            self.actors[name] = actor
            self.actor_targets[name] = actor_
            self.critics[name] = critic
            self.critic_targets[name] = critic_
        return

    def action_noise(self, action, scale=0.05):
        """
        Add zero-mean Gaussian noise with standard deviation `scale` to `action` and clip the
        result to [0, 1]. Noise is drawn independently for every element of `action`.

        Returns
        -------
        (numpy array) Noisy action with the same shape as `action`.
        """
        action = np.asarray(action)
        if not np.issubdtype(action.dtype, np.floating):
            action = action.astype(np.float64)
        noise = np.random.normal(loc=0.0, scale=scale, size=action.shape).astype(action.dtype)
        return np.clip(action + noise, 0, 1)

    def get_action_async(self, observation, agent_name, noise=False):
        """
        Compute the action of a single agent.

        Parameters
        ----------
        observation: (array-like) The agent's observation; reshaped to (1, observation_size).
        agent_name: Name (or integer position) of the agent.
        noise: (bool) If True, exploration noise from ``action_noise`` is added.

        Returns
        -------
        (numpy array) Action of shape (1, n_actions) for that agent.
        """
        position, name = self._resolve_agent(agent_name)
        obs = np.asarray(observation, dtype=np.float32).reshape((1, self.observation_sizes[position]))
        actor_output = self.actors[name](obs)
        if noise:
            return self.action_noise(actor_output)
        return actor_output.numpy()

    def get_action_sync(self, observations, noise=False):
        """
        Compute the actions of all agents.

        Parameters
        ----------
        observations: (list) One observation per agent, ordered like ``agent_names``.
        noise: (bool) If True, exploration noise from ``action_noise`` is added to every action.

        Returns
        -------
        (list) One numpy array of shape (1, n_actions) per agent, ordered like ``agent_names``.
        """
        return [self.get_action_async(observations[i], name, noise=noise)
                for i, name in enumerate(self.agent_names)]

    def train_critic(self, index, obs, actions, next_obs, rewards, dones,  optimizer=None):
        """
        Take one gradient step on the critic of one agent.

        The target is ``r + gamma * (1 - done) * Q_target(next_obs, target_actor(next_obs))``
        and the loss is the mean squared error between that target and the critic's prediction.

        Parameters
        ----------
        index: Integer position (or name) of the agent whose critic is trained.
        obs: (list) Batch of observations per agent, each of shape (batch_size, obs_len).
        actions: (list) Batch of actions per agent, each of shape (batch_size, n_actions).
        next_obs: (list) Batch of next observations per agent.
        rewards: (array) Shape (batch_size, n_agents).
        dones: (array) Shape (batch_size, n_agents), terminal flags.
        optimizer: Optimizer to apply the gradients with. Defaults to an Adam optimizer with
                   learning rate 0.01 that is kept per agent.

        Returns
        -------
        The scalar loss tensor of this step.
        """
        position, name = self._resolve_agent(index)
        critic_target_net = self.critic_targets[name]
        critic_net = self.critics[name]
        if optimizer is None:
            optimizer = self._default_optimizer(self._critic_optimizers, name, 0.01)

        # tf.concat needs one dtype; buffer arrays are float64 while network outputs are float32.
        obs_batches = [tf.cast(o, tf.float32) for o in obs]
        next_obs_batches = [tf.cast(o, tf.float32) for o in next_obs]
        action_batches = [tf.cast(a, tf.float32) for a in actions]

        # Target actors predict the next actions from the next observations.
        targ_actions = [self.actor_targets[agent](next_o)
                        for agent, next_o in zip(self.agent_names, next_obs_batches)]

        # Critic inputs have shape (batch_size, obs_len_1 + ... + obs_len_N + num_acts_1 + ... + num_acts_N)
        concated_next_obs_and_targ_acts = tf.concat(next_obs_batches + targ_actions, axis=1)
        concat_obs_and_acts = tf.concat(obs_batches + action_batches, axis=1)

        estimated_target_Qs = critic_target_net(concated_next_obs_and_targ_acts)
        # Reshape to (batch_size, 1) so the target does not broadcast to (batch_size, batch_size).
        reward = tf.reshape(tf.cast(np.asarray(rewards)[:, position], tf.float32), (-1, 1))
        not_done = 1.0 - tf.reshape(tf.cast(np.asarray(dones)[:, position], tf.float32), (-1, 1))
        y = reward + self.gamma * not_done * estimated_target_Qs
        with tf.GradientTape() as tape:
            current_predicted_Qs = critic_net(concat_obs_and_acts)
            loss = tf.reduce_mean(mse(y, current_predicted_Qs))
        grads = tape.gradient(loss, critic_net.trainable_variables)
        optimizer.apply_gradients(zip(grads, critic_net.trainable_variables))
        return loss

    def train_actor(self, index, obs, rewards, dones, optimizer=None):
        """
        Take one gradient step on the actor of one agent.

        All agents' current actors produce actions from `obs`; the selected agent's critic
        scores the joint observation/action vector and the loss is the negative mean Q value.
        Only the selected agent's actor variables are updated.

        Parameters
        ----------
        index: Integer position (or name) of the agent whose actor is trained.
        obs: (list) Batch of observations per agent, each of shape (batch_size, obs_len).
        rewards: Unused; kept so existing calls keep working.
        dones: Unused; kept so existing calls keep working.
        optimizer: Optimizer to apply the gradients with. Defaults to an Adam optimizer with
                   learning rate 0.001 that is kept per agent.

        Returns
        -------
        The scalar loss tensor of this step.
        """
        _, name = self._resolve_agent(index)
        actor_net = self.actors[name]
        critic_net = self.critics[name]
        if optimizer is None:
            optimizer = self._default_optimizer(self._actor_optimizers, name, 0.001)

        obs_batches = [tf.cast(o, tf.float32) for o in obs]
        with tf.GradientTape() as tape:
            policy_actions = [self.actors[agent](agent_obs)
                              for agent, agent_obs in zip(self.agent_names, obs_batches)]
            concat_obs_and_acts = tf.concat(obs_batches + policy_actions, axis=1)
            Q_values = critic_net(concat_obs_and_acts)
            loss = -1 * tf.reduce_mean(Q_values)
        # Differentiate with respect to the selected agent's actor only.
        grads = tape.gradient(loss, actor_net.trainable_variables)
        optimizer.apply_gradients(zip(grads, actor_net.trainable_variables))
        return loss

    def soft_update(self, tau = 0.01):
        """
        Move every target network towards its main network:
        ``target = tau * main + (1 - tau) * target`` for every weight array.
        """
        for name in self.agent_names:
            pairs = ((self.actors[name], self.actor_targets[name]),
                     (self.critics[name], self.critic_targets[name]))
            for main_net, target_net in pairs:
                blended = [tau * main_w + (1 - tau) * target_w
                           for main_w, target_w in zip(main_net.get_weights(), target_net.get_weights())]
                target_net.set_weights(blended)
    
#     def train_critic(self, critic_main, critic_targ, actor_main, actor_targ):
        

class MultiAgentReplayBuffer:
    """
    Fixed-size circular replay memory for several agents.

    Every agent has its own observation, next-observation and action array (their widths may
    differ between agents); rewards and terminal flags are stored in shared arrays with one
    column per agent. Once `max_size` transitions are stored, the oldest are overwritten.
    """
    def __init__(self, max_size, observation_sizes, n_actions, batch_size):
        """
        Parameters
        ----------
        max_size: (int) Maximum number of transitions kept in memory.
        observation_sizes: (list) Observation length of each agent ****Careful of ordering
        n_actions: (list) Number of actions of each agent, in the same order.
        batch_size: (int) Number of transitions returned by ``sample_buffer``.
        """
        self.n_agents = len(observation_sizes)
        # Max memory size
        self.mem_size = max_size
        # Total number of transitions stored so far (keeps growing past mem_size)
        self.mem_cntr = 0
        # Batch_size used in the learning process
        self.batch_size = batch_size
        # One array per agent: a single shared attribute would be overwritten by the next
        # agent and could not hold observations of different lengths.
        self.obs_memory = [np.zeros((self.mem_size, size)) for size in observation_sizes]
        self.next_obs_memory = [np.zeros((self.mem_size, size)) for size in observation_sizes]
        self.action_memory = [np.zeros((self.mem_size, n_actions[i])) for i in range(self.n_agents)]
        self.reward_memory = np.zeros((self.mem_size, self.n_agents))
        self.terminal_memory = np.zeros((self.mem_size, self.n_agents), dtype=bool)

    def store_transition(self, raw_obs, actions, rewards, next_raw_obs, dones):
        """
        Store one transition, overwriting the oldest one when the memory is full.

        Parameters
        ----------
        raw_obs: (list) One observation per agent.
        actions: (list) One action per agent; any shape that flattens to that agent's
                        number of actions, e.g. (1, n_actions).
        rewards: (list) One reward per agent.
        next_raw_obs: (list) One next observation per agent.
        dones: (list) One terminal flag per agent.
        """
        index = self.mem_cntr % self.mem_size
        for i in range(self.n_agents):
            self.obs_memory[i][index, :] = np.asarray(raw_obs[i]).reshape(-1)
            self.next_obs_memory[i][index, :] = np.asarray(next_raw_obs[i]).reshape(-1)
            self.action_memory[i][index, :] = np.asarray(actions[i]).reshape(-1)
        self.reward_memory[index, :] = np.asarray(rewards).reshape(-1)
        self.terminal_memory[index, :] = np.asarray(dones, dtype=bool).reshape(-1)

        self.mem_cntr += 1

    def sample_buffer(self):
        """
        returns list of batches of observations one for each agent, list of batches of next observations
        one for each agent, batch of rewards, list of batches of actions one for each agent,
        batch of done bools

        Transitions are drawn uniformly without replacement from the filled part of the memory.

        Raises
        ------
        ValueError: if fewer than `batch_size` transitions are available.
        """
        max_mem = min(self.mem_cntr, self.mem_size)
        if max_mem < self.batch_size:
            raise ValueError("cannot sample {} transitions, only {} stored".format(
                self.batch_size, max_mem))

        batch = np.random.choice(max_mem, self.batch_size, replace=False)
        agent_obs = [memory[batch] for memory in self.obs_memory]
        agent_next_obs = [memory[batch] for memory in self.next_obs_memory]
        agent_actions = [memory[batch] for memory in self.action_memory]
        rewards = self.reward_memory[batch]
        terminal = self.terminal_memory[batch]

        return agent_obs, agent_next_obs, rewards, agent_actions, terminal

    def ready(self):
        """Return True once enough transitions are available for ``sample_buffer``."""
        return min(self.mem_cntr, self.mem_size) >= self.batch_size