<a href="https://colab.research.google.com/github/apof/Multi-Agent-RL/blob/main/Commnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

In [13]:
# Number of agents; VDN builds one value head and one advantage head per agent.
AGENTS_NUMBER = 2
# Size of the last axis of the `flags` tensor built for the smoke test.
# NOTE(review): presumably the number of discrete actions per agent — confirm.
ACTIONS_NUMBER = 5
# Per-agent observation length; passed to CommNet as `state_dimension`.
STATE_LENGTH = 50
# Per-agent hidden size; passed to CommNet as `encoding_dimension`.
ENCODING_LENGTH = 200
# Number of communication rounds; passed to CommNet as `communication_steps`.
COMMUNICATION_STEPS = 3
# Leading (batch) dimension of the dummy tensors used in the smoke-test cells.
BATCH_SIZE = 64

In [14]:
class ReplayBuffer(object):
  """Fixed-capacity circular store of (state, action, reward, next_state, done) transitions.

  Once `max_size` transitions have been written, new ones overwrite the
  oldest entries.

  Args:
    max_size: maximum number of transitions kept in memory.
    input_space: shape (iterable of ints) of a single stored state.
    n_actions: length of the action vector stored for each transition.
  """

  def __init__(self,max_size,input_space,n_actions):
    self.mem_size = max_size
    self.mem_cntr = 0  # total number of transitions ever stored
    self.state_memory = np.zeros((self.mem_size,*input_space))
    self.new_state_memory = np.zeros((self.mem_size,*input_space))
    self.action_memory = np.zeros((self.mem_size,n_actions))
    self.reward_memory = np.zeros(self.mem_size)
    self.terminal_memory = np.zeros(self.mem_size,dtype=np.float32)

  def store_transition(self,state,action,reward,state_,done):
    """Write one transition into the next slot, wrapping around when full.

    Args:
      state: state before the step, shaped like `input_space`.
      action: action vector of length `n_actions`.
      reward: scalar reward.
      state_: state after the step, shaped like `input_space`.
      done: truthy when the episode ended on this step.
    """
    index = self.mem_cntr % self.mem_size
    self.state_memory[index] = state
    self.new_state_memory[index] = state_
    self.action_memory[index] = action
    self.reward_memory[index] = reward
    # Stored as a "not done" mask: 0.0 for terminal steps, 1.0 otherwise.
    self.terminal_memory[index] = 1 - done
    self.mem_cntr += 1

  def sample_buffer(self,batch_size):
    """Draw `batch_size` stored transitions uniformly at random, with replacement.

    Returns:
      Tuple `(states, actions, rewards, next_states, terminal)` of numpy
      arrays whose first dimension is `batch_size`; `terminal` holds the
      "not done" mask written by `store_transition`.

    Raises:
      ValueError: if nothing has been stored yet.
    """
    max_mem = min(self.mem_cntr,self.mem_size)
    if max_mem == 0:
      raise ValueError("cannot sample from an empty replay buffer")
    # Only the filled part of the buffer is eligible for sampling.
    batch = np.random.choice(max_mem,batch_size)
    states = self.state_memory[batch]
    next_states = self.new_state_memory[batch]
    actions = self.action_memory[batch]
    rewards = self.reward_memory[batch]
    terminal = self.terminal_memory[batch]
    return states,actions,rewards,next_states,terminal

In [15]:
## This network takes the output of the CommNet and applies a value decomposition network (VDN) on top of it
class VDN(nn.Module):
  """CommNet encoder followed by one dueling Q-head per agent.

  Each agent has its own value stream (scalar) and advantage stream (one
  output per action). Per-agent Q-values are combined by summation in
  `get_current` and `get_next`.

  Args:
    linear_dim: hidden width of every value/advantage stream.
    device: torch device the module is moved to.
  """

  def __init__(self,linear_dim = 300,device = None):
    super().__init__()
    self.device = device
    self.commnet = CommNet(STATE_LENGTH,ENCODING_LENGTH,AGENTS_NUMBER,COMMUNICATION_STEPS,device)
    self.linear_dim = linear_dim
    self.value_stream = nn.ModuleList([nn.Sequential(nn.Linear(ENCODING_LENGTH, self.linear_dim),nn.ReLU(),nn.Linear(self.linear_dim,1)) for _ in range(AGENTS_NUMBER)])
    # The advantage head width follows ACTIONS_NUMBER instead of a hard-coded 5.
    self.advantage_stream = nn.ModuleList([nn.Sequential(nn.Linear(ENCODING_LENGTH, self.linear_dim),nn.ReLU(),nn.Linear(self.linear_dim,ACTIONS_NUMBER)) for _ in range(AGENTS_NUMBER)])
    self.to(device)

  def forward(self,states):
    """Return per-agent Q-values of shape (batch, AGENTS_NUMBER, n_actions).

    Args:
      states: tensor indexed as (batch, agent, state_feature).
    """
    encoded_state = self.commnet(states)

    out = []
    for agent_index in range(AGENTS_NUMBER):
      agent_encoding = encoded_state[:,agent_index,:]
      value = self.value_stream[agent_index](agent_encoding)
      advantage = self.advantage_stream[agent_index](agent_encoding)
      # Centre the advantages per sample (over the action axis only); a
      # global mean would couple unrelated samples of the batch.
      qvals = value + (advantage - advantage.mean(dim=1,keepdim=True))
      out.append(qvals.unsqueeze(1))
    out = torch.cat(out,dim=1)

    print("Forward pass: " + str(out.shape))

    return out

  def get_current(self,states,actions):
    """Return the joint Q-value of the taken actions, shape (batch, 1).

    Args:
      states: tensor indexed as (batch, agent, state_feature).
      actions: int64 tensor (batch, agent) of action indices (gather needs int64).
    """
    q_vals = self.forward(states)
    ## pick, for every agent, the Q-value of the action it took
    chosen = [q_vals[:,i,:].gather(dim=1, index=actions[:,i].unsqueeze(-1)) for i in range(AGENTS_NUMBER)]
    ## sum the action values across all the agents
    joint_actions = torch.cat(chosen,dim=1).sum(dim=1,keepdim=True)
    print("Joint Values: " + str(joint_actions.shape))

    return joint_actions

  def get_next(self,next_states, flags):
    """Return the detached maximum joint Q-value for `next_states`, shape (batch,).

    This evaluates whichever VDN instance it is called on; call it on the
    target network to obtain target values.

    Args:
      next_states: tensor indexed as (batch, agent, state_feature).
      flags: tensor (batch, agent, n_actions) multiplied element-wise with
        the Q-values before they are summed over agents.
        NOTE(review): presumably a per-agent action mask — confirm with caller.
    """
    qvals = self.forward(next_states)
    ## weight every agent's Q-values by its flags and sum over the agents
    joint_qval = (qvals * flags).sum(dim=1)
    ## return the maximum joint action value
    next_value = joint_qval.max(dim=1)[0].detach()

    print("Next Values: " + str(next_value.shape))

    return next_value




In [16]:
class CommNet(nn.Module):
  """Communication network: agents exchange averaged hidden states for a fixed number of steps.

  Every agent's state is encoded with a shared layer. At each communication
  step the hidden state is updated as tanh(W_H h + W_C c), where c is the
  mean of the other agents' hidden states (all zeros before the first step).

  Args:
    state_dimension: length of one agent's state vector.
    encoding_dimension: length of the hidden/encoding vector.
    agents_num: number of agents.
    communication_steps: number of communication rounds.
    device: torch device the module is moved to.
  """

  def __init__(self,state_dimension,encoding_dimension,agents_num,communication_steps,device):
    super().__init__()
    self.state_dimension = state_dimension
    self.encoding_dimension = encoding_dimension
    self.agents_num = agents_num
    self.communication_steps = communication_steps
    self.device = device

    self.encoding_layer = nn.Sequential(nn.Linear(self.state_dimension, self.encoding_dimension),nn.ReLU())
    ## one pair of weight matrices per communication step
    self.WH_layer = nn.ModuleList([nn.Linear(self.encoding_dimension, self.encoding_dimension) for _ in range(self.communication_steps)])
    self.WC_layer = nn.ModuleList([nn.Linear(self.encoding_dimension, self.encoding_dimension) for _ in range(self.communication_steps)])

    ## move to the device only after the layers exist, otherwise their
    ## parameters would stay where they were created
    self.to(device)

  def encoding_step(self,states):
    """Encode every agent's state with the shared encoder.

    Args:
      states: tensor indexed as (batch, agent, state_dimension).

    Returns:
      Tensor (batch, agents_num, encoding_dimension).
    """
    print("Input shape: " + str(states.shape))
    ## nn.Linear acts on the last axis, so all agents are encoded in one call
    encoded_inputs = self.encoding_layer(states[:,:self.agents_num,:])
    print("Encoded Inputs shape: " + str(encoded_inputs.shape))
    return encoded_inputs

  def f(self,h,c,step_index):
    """Hidden-state update tanh(W_H[step] h + W_C[step] c) with the weights of `step_index`."""
    return torch.tanh(self.WH_layer[step_index](h) + self.WC_layer[step_index](c))

  def communication_step(self,H,C,step_index):
    """Run one communication round.

    Args:
      H: hidden states, (batch, agent, encoding_dimension).
      C: communication vectors, same shape as H.
      step_index: selects the weight pair used for this round.

    Returns:
      `(next_H, next_C)`; `next_C[:, i]` is the mean of `next_H` over every
      agent except i, or zeros when there is no other agent.
    """
    print("Communication step: " + str(step_index))

    ## compute the next H for all agents at once
    H = H[:,:self.agents_num,:]
    C = C[:,:self.agents_num,:]
    next_H = self.f(H,C,step_index)

    print(next_H.shape)

    next_C = []
    for i in range(self.agents_num):
      others = [next_H[:,j,:] for j in range(self.agents_num) if j != i]
      if others:
        message = torch.stack(others,dim = 0).mean(dim = 0)
      else:
        ## a single agent receives no messages
        message = torch.zeros_like(next_H[:,i,:])
      next_C.append(message)
    next_C = torch.stack(next_C,dim = 1)

    print(next_C.shape)

    return next_H, next_C

  def forward(self,states):
    """Return the final hidden states, (batch, agents_num, encoding_dimension)."""
    ## encode the state for every agent
    H = self.encoding_step(states)
    ## the first communication vector C is filled with zeros
    C = torch.zeros_like(H)
    ## communication steps
    for step in range(self.communication_steps):
      H, C = self.communication_step(H,C,step)

    return H

In [17]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [18]:
# Build the VDN (which constructs its own CommNet) on the selected device.
model = VDN(device = device)

In [19]:
# Dummy inputs used only to check the tensor shapes produced by the model.
states = np.random.rand(BATCH_SIZE,AGENTS_NUMBER,STATE_LENGTH)
actions = np.zeros((BATCH_SIZE,AGENTS_NUMBER))
flags = np.zeros((BATCH_SIZE,AGENTS_NUMBER,ACTIONS_NUMBER))

states = torch.from_numpy(states).float().to(model.device)
# Cast to float32: a float64 `flags` would promote the result of get_next to float64.
flags = torch.from_numpy(flags).float().to(model.device)

# gather() requires int64 indices.
actions = torch.as_tensor(actions, dtype=torch.int64).to(model.device)

In [20]:
# Call the module itself rather than .forward() so nn.Module.__call__ (hooks) is honoured.
_ = model(states)

Input shape: torch.Size([64, 2, 50])
Encoded Inputs shape: torch.Size([64, 2, 200])
Communication step: 0
torch.Size([64, 2, 200])
torch.Size([64, 2, 200])
Communication step: 1
torch.Size([64, 2, 200])
torch.Size([64, 2, 200])
Communication step: 2
torch.Size([64, 2, 200])
torch.Size([64, 2, 200])
Forward pass: torch.Size([64, 2, 5])


In [21]:
# Smoke test: joint Q-value of the given actions; get_current prints the "Joint Values" shape.
current_values = model.get_current(states,actions)

Input shape: torch.Size([64, 2, 50])
Encoded Inputs shape: torch.Size([64, 2, 200])
Communication step: 0
torch.Size([64, 2, 200])
torch.Size([64, 2, 200])
Communication step: 1
torch.Size([64, 2, 200])
torch.Size([64, 2, 200])
Communication step: 2
torch.Size([64, 2, 200])
torch.Size([64, 2, 200])
Forward pass: torch.Size([64, 2, 5])
Joint Values: torch.Size([64, 1])


In [22]:
# Smoke test: the same `states` batch is reused here as the next-state input.
next_values = model.get_next(states, flags)

Input shape: torch.Size([64, 2, 50])
Encoded Inputs shape: torch.Size([64, 2, 200])
Communication step: 0
torch.Size([64, 2, 200])
torch.Size([64, 2, 200])
Communication step: 1
torch.Size([64, 2, 200])
torch.Size([64, 2, 200])
Communication step: 2
torch.Size([64, 2, 200])
torch.Size([64, 2, 200])
Forward pass: torch.Size([64, 2, 5])
Next Values: torch.Size([64])
