<a href="https://colab.research.google.com/github/apof/Multi-Agent-RL/blob/main/Commnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

In [13]:
AGENTS_NUMBER = 5
ACTIONS_NUMBER = 5
STATE_LENGTH = 50
ENCODING_LENGTH = 200
COMMUNICATION_STEPS = 3
BATCH_SIZE = 64

In [14]:
class ReplayBuffer(object):
  def __init__(self,max_size,input_space,n_actions):
    self.mem_size = max_size
    self.mem_cntr = 0
    self.state_memory = np.zeros((self.mem_size,*input_shape))
    self.new_state_memory = np.zeros((self.mem_size,*input_shape))
    self.action_memory = np.zeros((self.mem_size,n_actions))
    self.reward_memory = np.zeros(self.mem_size)
    self.terminal_memory = np.zeros(self.memsize,dtype=np.float32)

  def store_transition(self,state,action,reward,state_,done):
    index = self.mem_cntr%self.mem_size
    self.state_memory[index] = state
    self.new_state_memory[index] = state_
    self.reward_memory[index] = reward
    self.terminal_memory[index] = 1-done
    self.mem_cntr += 1

  def sample_buffer(self,batch_size):
    max_mem = min(self.memcntr,self.mem_size)
    batch = np.random.choice(max_mem,batch_size)
    states = self.state_memory[batch]
    next_states = self.new_state_memory[batch]
    actions = self.action_memory[batch]
    rewards = self.reward_memory[batch]
    terminal = self.terminal_memory[batch]
    return states,actions,rewards,new_states,terminal

In [15]:
class CriticNetwork(nn.Module):
  def __init__(self,beta = 0.001,input_dims = ENCODING_LENGTH,fc1_dims = 300,fc2_dims = 30,n_actions = ACTIONS_NUMBER,device = None):
    super(CriticNetwork, self).__init__()
    self.input_dims = input_dims
    self.fc1_dims = fc1_dims
    self.fc2_dims = fc2_dims
    self.n_actions = n_actions
    self.device = device

    self.commnet = CommNet(STATE_LENGTH,ENCODING_LENGTH,AGENTS_NUMBER,COMMUNICATION_STEPS)

    ## define a fully connected layer that processes the encoded states
    self.fc1 = nn.Linear(self.input_dims,self.fc1_dims)
    self.fc2 = nn.Linear(self.fc1_dims,self.fc2_dims)
    ## define a layer that processes the action values
    self.action_value = nn.Linear(self.n_actions,self.fc2_dims)
    ## define a q value layer
    self.q = nn.Linear(self.fc2_dims,1)

    ## define the optimiser
    self.optimiser = optim.Adam(self.parameters(),lr=beta)

    self.to(self.device)

  def forward(self,state,action):

    encoded_state = self.commnet(state)

    ## process the encoded state
    state_out = []
    for agent_index in range(self.n_actions):
      state_value = self.fc1(encoded_state[:,agent_index,:])
      state_value = F.relu(state_value)
      state_value = self.fc2(state_value)
      state_value = F.relu(state_value)
      state_value = torch.reshape(state_value,(state_value.shape[0],1,state_value.shape[1]))
      state_out.append(state_value)
    state_out = torch.cat(state_out,axis=1)

    ## process the actions
    actions_out = []
    for agent_index in range(self.n_actions):
      action_value = F.relu(self.action_value(action[:,agent_index,:]))
      action_value = torch.reshape(action_value,(action_value.shape[0],1,action_value.shape[1]))
      actions_out.append(action_value)
    actions_out = torch.cat(actions_out,axis=1)

    ## combine states and actions
    state_action_value = F.relu(torch.add(state_out,actions_out))

    q_out = []
    for agent_index in range(self.n_actions):
      q_value = F.relu(self.q(state_action_value[:,agent_index,:]))
      q_value = torch.reshape(q_value,(q_value.shape[0],1,q_value.shape[1]))
      q_out.append(q_value)
    q_out = torch.cat(q_out,axis=1)

    print("Critic Network output: " + str(q_out.shape))

    return q_out

In [16]:
class ActorNetwork(nn.Module):
  def __init__(self,beta = 0.001,input_dims = ENCODING_LENGTH,fc1_dims = 300,fc2_dims = 30,n_actions = AGENTS_NUMBER,device = None):
    super(ActorNetwork, self).__init__()
    self.input_dims = input_dims
    self.fc1_dims = fc1_dims
    self.fc2_dims = fc2_dims
    self.n_actions = n_actions
    self.device = device

    self.commnet = CommNet(STATE_LENGTH,ENCODING_LENGTH,AGENTS_NUMBER,COMMUNICATION_STEPS)

    ## define a fully connected layer that processes the encoded states
    self.fc1 = nn.Linear(self.input_dims,self.fc1_dims)
    self.fc2 = nn.Linear(self.fc1_dims,self.fc2_dims)
    self.fc3 = nn.Linear(self.fc2_dims,self.n_actions)

    ## define the optimiser
    self.optimiser = optim.Adam(self.parameters(),lr=beta)

    self.to(self.device)

  def forward(self,state):

    encoded_state = self.commnet(state)

    ## process the encoded state
    out = []
    for agent_index in range(self.n_actions):
      state_value = F.relu(self.fc1(encoded_state[:,agent_index,:]))
      state_value = F.relu(self.fc2(state_value))
      state_value = torch.tanh(self.fc3(state_value))
      state_value = torch.reshape(state_value,(state_value.shape[0],1,state_value.shape[1]))
      out.append(state_value)
    out = torch.cat(out,axis=1)

    print("Actor Network output: " + str(out.shape))

    return out

In [17]:
class CommNet(nn.Module):
  def __init__(self,state_dimension,encoding_dimension,agents_num,communication_steps,output_type = 'actor'):
    super().__init__()
    self.state_dimension = state_dimension
    self.encoding_dimension = encoding_dimension
    self.agents_num = agents_num
    self.communication_steps = communication_steps
    ## define if the outpout of the network is actor or critic type
    self.output_type = output_type

    self.encoding_layer = nn.Sequential(nn.Linear(self.state_dimension, self.encoding_dimension),nn.ReLU())
    self.WH_layer = nn.ModuleList([nn.Linear(self.encoding_dimension, self.encoding_dimension) for i in range(self.communication_steps)])
    self.WC_layer = nn.ModuleList([nn.Linear(self.encoding_dimension, self.encoding_dimension) for i in range(self.communication_steps)])


  def encoding_step(self,states):
    print("Input shape: " + str(states.shape))
    ## encode the state of every agent
    encoded_inputs = []
    for agent_index in range(self.agents_num):
      encoded_input = self.encoding_layer(states[:,agent_index,:])
      encoded_input = torch.reshape(encoded_input,(encoded_input.shape[0],1,encoded_input.shape[1]))
      encoded_inputs.append(encoded_input)
    encoded_inputs = torch.cat(encoded_inputs,axis=1)
    print("Encoded Inputs shape: " + str(encoded_inputs.shape))
    return encoded_inputs

  def f(self,h,c,step_index):
    ## decide which weight to use
    return torch.tanh(self.WH_layer[step_index](h)  + self.WC_layer[step_index](c))

  def communication_step(self,H,C,step_index):

    print("Communication step: " + str(step_index))

    ## compute the next H
    next_H = []
    for agent_index in range(self.agents_num):
      h = H[:,agent_index,:]
      c = C[:,agent_index,:]
      next_h = self.f(h,c,step_index)
      next_h = torch.reshape(next_h,(next_h.shape[0],1,next_h.shape[1]))
      next_H.append(next_h)
    next_H = torch.cat(next_H,axis = 1)

    print(next_H.shape)

    next_C = []
    for i in range(self.agents_num):
      next_c = []
      for j in range(self.agents_num):
        if (i!=j):
          next_c.append(next_H[:,j,:])
      stacked_c = torch.stack(next_c,axis = 0)
      next_c = torch.mean(stacked_c,axis = 0)
      next_c = torch.reshape(next_c,(next_c.shape[0],1,next_c.shape[1]))
      next_C.append(next_c)
    next_C = torch.cat(next_C,axis = 1)

    print(next_C.shape)

    return next_H, next_C

  def forward(self,states):
    ## encode the state for every agent
    encoded_states = self.encoding_step(states)
    ## communication steps
    ## define the first communication vector H filled with zeros
    C0 = torch.zeros_like(encoded_states)
    H = encoded_states
    C = C0
    for step in range(self.communication_steps):
      H_new, C_new = self.communication_step(H,C,step)
      C = C_new
      H = H_new

    return H

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:
CriticNetwork = CriticNetwork(device = device)
ActorNetwork = ActorNetwork(device = device)

In [20]:
states = np.random.rand(BATCH_SIZE,AGENTS_NUMBER,STATE_LENGTH)
actions = np.random.rand(BATCH_SIZE,AGENTS_NUMBER,ACTIONS_NUMBER)
states = torch.from_numpy(states).float().to(CriticNetwork.device)
actions = torch.from_numpy(actions).float().to(CriticNetwork.device)

In [21]:
critic_out = CriticNetwork.forward(states,actions)

Input shape: torch.Size([64, 5, 50])
Encoded Inputs shape: torch.Size([64, 5, 200])
Communication step: 0
torch.Size([64, 5, 200])
torch.Size([64, 5, 200])
Communication step: 1
torch.Size([64, 5, 200])
torch.Size([64, 5, 200])
Communication step: 2
torch.Size([64, 5, 200])
torch.Size([64, 5, 200])
Critic Network output: torch.Size([64, 5, 1])


In [22]:
actor_out = ActorNetwork.forward(states)

Input shape: torch.Size([64, 5, 50])
Encoded Inputs shape: torch.Size([64, 5, 200])
Communication step: 0
torch.Size([64, 5, 200])
torch.Size([64, 5, 200])
Communication step: 1
torch.Size([64, 5, 200])
torch.Size([64, 5, 200])
Communication step: 2
torch.Size([64, 5, 200])
torch.Size([64, 5, 200])
Actor Network output: torch.Size([64, 5, 5])
