<a href="https://colab.research.google.com/github/apof/Multi-Agent-RL/blob/main/Commnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np

In [2]:
AGENTS_NUMBER = 5
STATE_LENGTH = 50
ENCODING_LENGTH = 200
COMMUNICATION_STEPS = 3

In [3]:
class CommNet(nn.Module):
  def __init__(self,state_dimension,encoding_dimension,agents_num,communication_steps,output_type = 'actor'):
    super().__init__()
    self.state_dimension = state_dimension
    self.encoding_dimension = encoding_dimension
    self.agents_num = agents_num
    self.communication_steps = communication_steps
    ## define if the outpout of the network is actor or critic type
    self.output_type = output_type

    self.encoding_layer = nn.Sequential(nn.Linear(self.state_dimension, self.encoding_dimension),nn.ReLU())
    self.WH_layer = nn.ModuleList([nn.Linear(self.encoding_dimension, self.encoding_dimension) for i in range(self.communication_steps)])
    self.WC_layer = nn.ModuleList([nn.Linear(self.encoding_dimension, self.encoding_dimension) for i in range(self.communication_steps)])


  def encoding_step(self,states):
    print("Input shape: " + str(states.shape))
    ## encode the state of every agent
    encoded_inputs = []
    for agent_index in range(self.agents_num):
      encoded_input = self.encoding_layer(states[:,agent_index,:])
      encoded_input = torch.reshape(encoded_input,(encoded_input.shape[0],1,encoded_input.shape[1]))
      encoded_inputs.append(encoded_input)
    encoded_inputs = torch.cat(encoded_inputs,axis=1)
    print("Encoded Inputs shape: " + str(encoded_inputs.shape))
    return encoded_inputs

  def f(self,h,c,step_index):
    ## decide which weight to use
    return torch.tanh(self.WH_layer[step_index](h)  + self.WC_layer[step_index](c))

  def communication_step(self,H,C,step_index):

    print("Communication step: " + str(step_index))

    ## compute the next H
    next_H = []
    for agent_index in range(self.agents_num):
      h = H[:,agent_index,:]
      c = C[:,agent_index,:]
      next_h = self.f(h,c,step_index)
      next_h = torch.reshape(next_h,(next_h.shape[0],1,next_h.shape[1]))
      next_H.append(next_h)
    next_H = torch.cat(next_H,axis = 1)

    print(next_H.shape)

    next_C = []
    for i in range(self.agents_num):
      next_c = []
      for j in range(self.agents_num):
        if (i!=j):
          next_c.append(next_H[:,j,:])
      stacked_c = torch.stack(next_c,axis = 0)
      next_c = torch.mean(stacked_c,axis = 0)
      next_c = torch.reshape(next_c,(next_c.shape[0],1,next_c.shape[1]))
      next_C.append(next_c)
    next_C = torch.cat(next_C,axis = 1)

    print(next_C.shape)

    return next_H, next_C

  def forward(self,states):
    ## encode the state for every agent
    encoded_states = self.encoding_step(states)
    ## communication steps
    ## define the first communication vector H filled with zeros
    C0 = torch.zeros_like(encoded_states)
    H = encoded_states
    C = C0
    for step in range(self.communication_steps):
      H_new, C_new = self.communication_step(H,C,step)
      C = C_new
      H = H_new

    


In [4]:
samples = np.random.uniform(low=0.5, high=13.3, size=(64,5,50))
samples = torch.from_numpy(samples).float()
print(samples.shape)

torch.Size([64, 5, 50])


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
commnet = CommNet(STATE_LENGTH,ENCODING_LENGTH,AGENTS_NUMBER,COMMUNICATION_STEPS).to(device)

In [7]:
commnet.forward(samples)

Input shape: torch.Size([64, 5, 50])
Encoded Inputs shape: torch.Size([64, 5, 200])
Communication step: 0
torch.Size([64, 5, 200])
torch.Size([64, 5, 200])
Communication step: 1
torch.Size([64, 5, 200])
torch.Size([64, 5, 200])
Communication step: 2
torch.Size([64, 5, 200])
torch.Size([64, 5, 200])
