### Imports

In [1]:
import gymnasium as gym
import rware
import torch
import random
import numpy as np
import sys
sys.path.append('..')
from torch import nn, optim, Tensor
from torch.nn import functional as F
from tqdm.auto import trange
from collections import deque
from pressureplate.pressureplate.environment import PressurePlate
from functools import reduce

  from .autonotebook import tqdm as notebook_tqdm


### Device

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Keyword arguments passed to PressurePlate(...) when the environment is built.
config = dict(
    height=15,
    width=9,
    n_agents=4,
    sensor_range=0,
    layout='linear',
)

### Environment (PressurePlate)

In [4]:
# Build the PressurePlate environment from `config`. The cell's displayed output is the
# return value of reset(): a tuple of observation arrays (four of them for n_agents=4,
# presumably one per agent).
env = PressurePlate(**config)
env.reset()

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


(array([ 1.,  0.,  0.,  0.,  5., 13.], dtype=float32),
 array([ 1.,  0.,  0.,  0.,  5., 12.], dtype=float32),
 array([ 1.,  0.,  0.,  0.,  4., 13.], dtype=float32),
 array([ 1.,  0.,  0.,  0.,  4., 12.], dtype=float32))

In [5]:
def test():

    env.reset()

    for _ in range(100):

        env.render()
        action = env.action_space.sample()
        obs,reward,done,info = env.step(action)

        if all(done):
            break

    env.close()

In [6]:
# Smoke-test the raw environment: up to 100 random joint actions, rendered each step.
test()

### Model

In [7]:
def size(shape : int | tuple[int]) -> int:
    
    if isinstance(shape, tuple):
        return reduce(lambda x, y: x * y, shape)
    
    return shape

In [8]:
class MLP(nn.Module):
    """Feed-forward network split into an encoder stage and a head stage.

    ``forward_encoder`` maps ``in_dim`` features to ``encoder_hidden_dims[-1]`` features.
    ``forward_head`` maps its input to ``size(out_dim)`` outputs. When ``out_dim`` is a
    tuple, it reshapes the last axis to ``out_dim``. The head's first layer expects
    ``2 * encoder_hidden_dims[-1]`` input features when ``use_hiddens`` is True, and
    ``encoder_hidden_dims[-1]`` otherwise. The doubled width matches an encoding
    concatenated with a same-width vector, as CommNet produces. There is no ``forward``;
    callers chain the two stages themselves.

    Args:
        in_dim: number of input features.
        out_dim: number of outputs, or a tuple shape the outputs are reshaped to.
        encoder_hidden_dims: widths of the encoder layers (must be non-empty).
        head_hidden_dims: widths of the head layers before the output layer (must be non-empty).
        use_hiddens: whether the head input is twice the final encoder width.

    Raises:
        ValueError: if either hidden-dims list is empty.
    """

    def __init__(self, 
        in_dim : int,
        out_dim : int | tuple[int],
        encoder_hidden_dims : list[int] = [16, 32],
        head_hidden_dims : list[int] = [32, 32],
        use_hiddens : bool = True
    ) -> None:
        super().__init__()

        if not encoder_hidden_dims or not head_hidden_dims:
            raise ValueError("encoder_hidden_dims and head_hidden_dims must be non-empty")

        self.in_dim = in_dim
        self.out_dim = out_dim
        # Store copies: the defaults are shared list objects, so keeping a reference on the
        # instance would let a mutation on one model leak into every later default model.
        self.encoder_hidden_dims = list(encoder_hidden_dims)
        self.head_hidden_dims = list(head_hidden_dims)
        self.use_hiddens = use_hiddens

        self.encoder_layers = nn.ModuleList()
        self.head_layers = nn.ModuleList()

        enc_dims = self.encoder_hidden_dims
        head_dims = self.head_hidden_dims

        self.input_layer = self._linear_relu(in_dim, enc_dims[0])

        for d1, d2 in zip(enc_dims[:-1], enc_dims[1:]):
            self.encoder_layers.append(self._linear_relu(d1, d2))

        # Head input is doubled when the encoding is concatenated with another vector
        # of the same width before the head is applied.
        n = 2 if self.use_hiddens else 1

        self.inter_layer = self._linear_relu(n * enc_dims[-1], head_dims[0])

        for d1, d2 in zip(head_dims[:-1], head_dims[1:]):
            self.head_layers.append(self._linear_relu(d1, d2))

        self.output_layer = nn.Linear(head_dims[-1], size(out_dim))

    @staticmethod
    def _linear_relu(d_in : int, d_out : int) -> nn.Sequential:
        """Linear layer followed by an in-place ReLU."""
        return nn.Sequential(
            nn.Linear(d_in, d_out),
            nn.ReLU(inplace=True),
        )

    def forward_encoder(self, x : Tensor) -> Tensor:
        """Apply the encoder: ``(..., in_dim) -> (..., encoder_hidden_dims[-1])``."""
        x = self.input_layer(x)

        for layer in self.encoder_layers:
            x = layer(x)

        return x
    
    def forward_head(self, x : Tensor) -> Tensor:
        """Apply the head.

        Args:
            x: ``(..., head_input_dim)`` where head_input_dim is the encoder width, doubled
                when ``use_hiddens`` is True.

        Returns:
            ``(..., out_dim)`` for an int ``out_dim``. For a tuple ``out_dim`` the result
            is ``(..., *out_dim)``, keeping every leading axis of ``x``.
        """
        x = self.inter_layer(x)

        for layer in self.head_layers:
            x = layer(x)

        x = self.output_layer(x)

        if isinstance(self.out_dim, tuple):
            # Unflatten only the feature axis so inputs with several leading axes work
            # as well as the plain (B, F) case.
            x = x.view(*x.shape[:-1], *self.out_dim)

        return x

In [9]:
class QNet(nn.Module):
    """Pair of identically shaped MLPs: a trainable ``local_net`` and a ``target_net``.

    ``target_net`` is frozen (no gradients). It only changes when ``copy_params`` copies
    the weights of ``local_net`` into it. Every ``forward*`` method takes ``local`` to
    choose which of the two networks to run.
    """

    def __init__(self, 
        in_dim : int,
        out_dim : int | tuple[int],
        encoder_hidden_dims : list[int] = [16, 32],
        head_hidden_dims : list[int] = [32, 32],
        use_hiddens : bool = True
    ) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.encoder_hidden_dims = encoder_hidden_dims
        self.head_hidden_dims = head_hidden_dims
        self.use_hiddens = use_hiddens

        self.local_net = MLP(in_dim, out_dim, encoder_hidden_dims, head_hidden_dims, use_hiddens)
        self.target_net = MLP(in_dim, out_dim, encoder_hidden_dims, head_hidden_dims, use_hiddens)

        # The target network must only move via copy_params(). Freezing it keeps an optimizer
        # built from self.parameters() from updating it through gradients of a loss that
        # uses target-network outputs.
        self.target_net.requires_grad_(False)

        self.copy_params()

    def copy_params(self) -> None:
        """Overwrite ``target_net``'s weights with ``local_net``'s current weights."""
        # load_state_dict copies the values, so no explicit copy of the dict is needed.
        self.target_net.load_state_dict(self.local_net.state_dict())

    def _select(self, local : bool) -> MLP:
        """Return ``local_net`` when ``local`` is True, else ``target_net``."""
        return self.local_net if local else self.target_net

    def forward_encoder(self, 
        x : Tensor,
        local : bool = True,
    ) -> Tensor:
        """Run the encoder stage of the local (default) or target network."""
        return self._select(local).forward_encoder(x)
    
    def forward_head(self, 
        x : Tensor,
        local : bool = True
    ) -> Tensor:
        """Run the head stage of the local (default) or target network."""
        return self._select(local).forward_head(x)
    
    def forward(self, x : Tensor, local : bool = True) -> Tensor:
        """Run encoder then head in one call.

        Raises:
            RuntimeError: when ``use_hiddens`` is True. In that case the head expects the
                widened (concatenated) encoding, so the stages must be called separately.
        """
        if self.use_hiddens:
            raise RuntimeError("Use forward_encoder and forward_head instead")
        
        x = self.forward_encoder(x, local)
        x = self.forward_head(x, local)

        return x

In [10]:
class GlobalPolicy(QNet):
    """Centralised Q-network scoring every goal for every agent from one global state.

    Its output for input ``(B, in_dim)`` has shape ``(B, n_agents, out_dim)``: one Q-value
    per (agent, goal) pair. ``self.out_dim`` holds the tuple ``(n_agents, out_dim)``
    (inherited from QNet), while ``self.n_goals`` keeps the plain goal count.
    """

    def __init__(self,
        n_agents : int,
        in_dim : int,
        out_dim : int,
        encoder_hidden_dims : list[int] = [16, 32],
        head_hidden_dims : list[int] = [32, 32],
    ):
        
        super().__init__(
            in_dim=in_dim,
            out_dim=(n_agents, out_dim),
            encoder_hidden_dims=encoder_hidden_dims,
            head_hidden_dims=head_hidden_dims,
            use_hiddens=False
        )

        self.n_agents = n_agents
        self.n_goals = out_dim

    def get_goals(self, x : Tensor, eps : float = 0.0) -> list[int] | list[list[int]]:
        """Epsilon-greedy goal index for every agent.

        Args:
            x: global states of shape ``(B, in_dim)``.
            eps: probability of picking uniformly random goals instead of greedy ones.

        Returns:
            For ``B == 1``, a flat list of ``n_agents`` goal indices. Otherwise, ``B``
            lists of ``n_agents`` goal indices each.
        """
        batch_size = x.size(0)

        if random.random() < eps:
            goals = [
                [random.randrange(self.n_goals) for _ in range(self.n_agents)]
                for _ in range(batch_size)
            ]
        else:
            with torch.no_grad():
                q_values = self.forward(x, local=True)        # (B, N, G)
                goals = q_values.argmax(dim=-1).tolist()      # (B, N)

        return goals[0] if batch_size == 1 else goals


In [11]:
class CommNet(nn.Module):
    """Mean-of-others communication: each agent receives the average of its peers' vectors.

    With ``include_learnable_params`` the peers' vectors first pass through a learnable
    ``Linear(in_dim, in_dim)``; otherwise they are used as-is.
    """

    def __init__(self, 
        in_dim : int,
        include_learnable_params : bool = False,
    ) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.include_learnable_params = include_learnable_params

        self.W = nn.Linear(in_dim, in_dim) if include_learnable_params else nn.Identity()

    def forward(self, x : Tensor) -> Tensor:
        """Append each agent's message to its own vector.

        Args:
            x: ``(B, N, in_dim)`` with agents along dim 1.

        Returns:
            ``(B, N, 2 * in_dim)``: the original ``x`` concatenated with the mean of the
            other agents' transformed vectors. A lone agent (``N == 1``) gets a zero message.
        """
        _, n_agents, _ = x.size()

        h = self.W(x)

        # Sum over all agents minus self = sum over the others. Clamping the divisor keeps
        # N == 1 finite (the numerator is zero there).
        messages = (h.sum(dim=1, keepdim=True) - h) / max(n_agents - 1, 1)

        return torch.cat([x, messages], dim=-1)

In [12]:
class Networks(nn.Module):
    """One QNet per agent, with encodings mixed by a CommNet before each agent's head.

    Pipeline: each agent encodes its own input, CommNet appends the mean of the other
    agents' encodings, then each agent's head maps the widened vector to Q-values.
    """

    def __init__(self, 
        n_agents : int,
        in_dim : int,
        out_dim : int,
        encoder_hidden_dims : list[int] = [16, 32],
        head_hidden_dims : list[int] = [32, 16],
    ):
        super().__init__()

        self.networks = nn.ModuleList(
            [QNet(in_dim, out_dim, encoder_hidden_dims, head_hidden_dims) for _ in range(n_agents)]
        )

        self.comm = CommNet(
            in_dim=encoder_hidden_dims[-1],
            include_learnable_params=False,
        )

    def forward(self, 
        inputs : Tensor,
        local : bool = True,
    ) -> Tensor:
        """Compute Q-values for all agents.

        Args:
            inputs: iterated along dim 0, one slice per agent (``(N, B, in_dim)``).
            local: use the local (True) or target (False) networks.

        Returns:
            Q-values stacked per agent, ``(N, B, out_dim)``.
        """
        encodings = torch.stack(
            [qnet.forward_encoder(agent_input, local) for agent_input, qnet in zip(inputs, self.networks)],
            dim=1,
        )                                                   # (B, N, H)

        per_agent = self.comm(encodings).permute(1, 0, 2)   # (N, B, 2H)

        return torch.stack(
            [qnet.forward_head(agent_hidden, local) for agent_hidden, qnet in zip(per_agent, self.networks)],
            dim=0,
        )
    
    def get_actions(self, x : Tensor, eps : float = 0.0) -> list[int]:
        """Epsilon-greedy action per agent for a single joint observation ``x`` of shape (N, D).

        One coin flip decides for all agents: with probability ``eps`` every agent acts
        uniformly at random; otherwise every agent takes its greedy action.
        """
        if random.random() < eps:
            return [random.randint(0, self.networks[0].out_dim - 1) for _ in range(x.size(0))]

        with torch.no_grad():
            q_values = self.forward(x.unsqueeze(1), local=True)  # (N, 1, A)
            return q_values.argmax(dim=-1).squeeze(dim=1).tolist()
    
    def copy_params(self) -> None:
        """Sync every agent's target network with its local network."""
        for qnet in self.networks:
            qnet.copy_params()

- Sanity-check the output shapes of the models

In [13]:
# Shape check: Networks maps (n_agents, batch, in_dim) to (n_agents, batch, out_dim).
nets = Networks(
    n_agents=2,
    in_dim=4,
    out_dim=2,
    encoder_hidden_dims=[16, 32],
    head_hidden_dims=[32, 16],
)

x = torch.randn(2, 8, 4)  # (n_agents, batch, in_dim)
y = nets(x, local=True)
assert y.shape == (2, 8, 2), f"unexpected Q-value shape {tuple(y.shape)}"

print(y.shape)

torch.Size([2, 8, 2])


In [15]:
# Shape check: GlobalPolicy maps (batch, in_dim) to (batch, n_agents, n_goals).
global_policy = GlobalPolicy(
    n_agents=2,
    in_dim=4,
    out_dim=2,
    encoder_hidden_dims=[16, 32],
    head_hidden_dims=[32, 16],
)

x = torch.randn(16, 4)  # (batch, in_dim)
y = global_policy(x, local=True)
assert y.shape == (16, 2, 2), f"unexpected Q-value shape {tuple(y.shape)}"

print(y.shape)

torch.Size([16, 2, 2])


### Replay Memory

In [16]:
class ReplyMemory:

    def __init__(self,
        capacity: int,
        batch_size: int,
    ) -> None:
        self.capacity = capacity
        self.batch_size = batch_size

        self.memory = deque(maxlen=capacity)

    def remember(self, 
        state : np.ndarray | list[np.ndarray], 
        action : int | list[int] , 
        reward : float | list[float], 
        next_state : np.ndarray | list[np.ndarray], 
        done : bool | list[bool]   
    ) -> None:
        
        if isinstance(state, (list, tuple)):
            state = np.stack(state)

        if isinstance(next_state, (list, tuple)):
            next_state = np.stack(next_state)
        
        state = torch.tensor(state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)
        done = torch.tensor(done, dtype=torch.bool)

        self.memory.append((state, action, reward, next_state, done))

    def sample(self) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:

        batch = random.sample(self.memory, self.batch_size)

        state, action, reward, next_state, done = zip(*batch)

        return (
            torch.stack(state).to(DEVICE),
            torch.stack(action).to(DEVICE),
            torch.stack(reward).to(DEVICE),
            torch.stack(next_state).to(DEVICE),
            torch.stack(done).to(DEVICE),
        )
    
    def __len__(self) -> int:
        return len(self.memory)

### Environment

In [17]:
class HMARLEnv:
    """Hierarchical multi-agent wrapper around a gymnasium-style environment.

    It exposes two views:
    - a global view: the raw observation, reward and terminated flag;
    - per-agent goal-conditioned views: agent observation, goal reward, goal-done flag.

    It assumes ``env.reset()`` returns ``(obs, info)`` and ``env.step(actions)`` returns
    ``(obs, reward, terminated, truncated, info)``, following the gymnasium API.
    ``get_goal_actions_mask`` additionally requires ``env.action_space[0].n``.
    """

    def __init__(self, env : "gym.Env", n_agents : int = 2, n_goals : int = 3) -> None:
        """
        Args:
            env: the wrapped environment.
            n_agents: number of agents (default 2).
            n_goals: number of goals a global policy can assign (default 3).
        """
        self.n_agents = n_agents
        self.n_goals = n_goals
        # Per-agent goal slots; not read by this class itself.
        self.goals = [None for _ in range(self.n_agents)]
        self.env = env

    def is_goal_done(self, 
        state : np.ndarray,
        goal : int
    ) -> bool:
        """Whether ``goal`` is achieved in ``state``.

        TODO: placeholder. No goal is ever reported as done yet. It returns an explicit
        ``False`` rather than ``None``, because ``None`` is not a valid boolean done flag.
        """
        return False

    def goal_reward(self, 
        state : np.ndarray,
        goal : int
    ) -> float:
        """Sparse goal reward: 1.0 if ``goal`` is done in ``state``, else 0.0."""
        return 1.0 if self.is_goal_done(state, goal) else 0.0
    
    def get_agent_observation(self, 
        state : np.ndarray,
        agent : int
    ) -> np.ndarray:
        """Observation for ``agent``. Currently the full state is shared with every agent."""
        return state
    
    def get_goal_actions_mask(self, goal) -> np.ndarray:
        """Action mask for ``goal``. Currently allows every action (all ones)."""
        return np.ones(self.env.action_space[0].n)
    
    def reset(self) -> dict:
        """Reset the wrapped env.

        Returns:
            ``{'global': state, 'local': [per-agent observation, ...]}``.
        """
        state, _ = self.env.reset()

        return {
            'global': state,
            'local': [
                self.get_agent_observation(state, agent_id)
                for agent_id in range(self.n_agents)
            ],
        }
    
    def step(self, 
        actions : list[int],
        goals : list[int]
    ) -> dict:
        """Step the wrapped env with ``actions``, scoring each agent against its goal.

        Returns:
            A dict with these keys:
            - ``'global'``: ``(next_state, reward, terminated)``.
            - ``'local'``: ``[observations, goal_rewards, goal_dones]``, each a tuple with
              one entry per agent (only the first ``n_agents`` goals are used).
            - ``'truncated'``: the env's truncation flag. It is kept apart from
              ``terminated`` so a time-limit cut is not mistaken for a true terminal state.
        """
        next_state, reward, terminated, truncated, _ = self.env.step(actions)

        per_agent = [
            [
                self.get_agent_observation(next_state, agent_id),
                self.goal_reward(next_state, goal_id),
                self.is_goal_done(next_state, goal_id),
            ]
            for agent_id, goal_id in enumerate(goals[:self.n_agents])
        ]

        return {
            "global": (next_state, reward, terminated),
            # Transpose per-agent rows into per-quantity columns.
            "local": list(zip(*per_agent)),
            "truncated": truncated,
        }

### Trainer

In [12]:
class Trainer:
    """DQN-style trainer for the hierarchical multi-agent setup.

    A ``GlobalPolicy`` assigns a goal to every agent that currently has none, and
    per-agent ``Networks`` choose actions. Both are trained from their own replay memory.
    Their target networks are re-synced every ``target_update_freq`` environment steps.
    """

    ENV_METADATA = {
        "rware-tiny-2ag-v2" : {
            "n_agents" : 2,
            "n_actions" : 5,
            "n_goals" : 2
        }
    }
    
    def __init__(self,
        qnet_args : dict,
        commnet_args : dict,
        env : HMARLEnv,
        gamma : float = 0.99,      
        lr : float = 1e-3,    
        num_episodes : int = 10000,   
        max_steps : int = 1000,
        learning_freq : int = 4,
        target_update_freq : int = 50,
        memory_capacity : int = 10000,
        batch_size : int = 32,
        eps : float = 0.1,
    ) -> None:
        """
        Args:
            qnet_args: keyword arguments shared by GlobalPolicy and Networks (e.g. ``in_dim``,
                ``encoder_hidden_dims``, ``head_hidden_dims``). An ``out_dim`` entry, if
                present, is used only for the per-agent Networks. The global policy always
                scores ``env.n_goals`` goals, and without ``out_dim`` the Networks use
                ``env.env.action_space[0].n`` actions.
            commnet_args: stored only; Networks currently builds its own CommNet.
            env: hierarchical environment wrapper.
            gamma: discount factor.
            lr: Adam learning rate for both models.
            num_episodes: number of training episodes.
            max_steps: step cap per episode.
            learning_freq: run one update every this many environment steps.
            target_update_freq: sync target networks every this many environment steps.
            memory_capacity: capacity of each replay memory.
            batch_size: replay mini-batch size.
            eps: epsilon for epsilon-greedy goal and action selection.
        """
        self.qnet_args = qnet_args
        self.commnet_args = commnet_args
        self.env = env
        self.gamma = gamma
        self.lr = lr
        self.num_episodes = num_episodes
        self.max_steps = max_steps
        self.learning_freq = learning_freq
        self.target_update_freq = target_update_freq
        self.memory_capacity = memory_capacity
        self.batch_size = batch_size
        self.eps = eps

        self.global_memory = ReplyMemory(capacity=memory_capacity, batch_size=batch_size)
        self.agents_memory = ReplyMemory(capacity=memory_capacity, batch_size=batch_size)

        # The two models share qnet_args but need different output sizes. Passing
        # out_dim both via **qnet_args and explicitly would be a duplicate-keyword error.
        shared_args = {k: v for k, v in qnet_args.items() if k != "out_dim"}
        n_actions = qnet_args["out_dim"] if "out_dim" in qnet_args else env.env.action_space[0].n

        self.global_policy = GlobalPolicy(**shared_args, out_dim=env.n_goals, n_agents=env.n_agents).to(DEVICE)
        self.networks = Networks(**shared_args, out_dim=n_actions, n_agents=env.n_agents).to(DEVICE)

        self.global_policy_optimizer = optim.Adam(self.global_policy.parameters(), lr=lr)
        self.networks_optimizer = optim.Adam(self.networks.parameters(), lr=lr)

    def update_networks(self, x : tuple[Tensor]) -> None:
        """One DQN step for the per-agent networks on a batch from ``agents_memory``.

        ``x`` is ``(states, actions, rewards, next_states, dones)``. Given how ``train``
        stores per-agent lists, these are ``(B, N, D)``, ``(B, N)``, ``(B, N)``,
        ``(B, N, D)`` and ``(B, N)``.
        """
        states, actions, rewards, next_states, dones = x

        # Networks.forward iterates over dim 0 as the agent axis, so move agents first.
        states = states.transpose(0, 1)                   # (N, B, D)
        next_states = next_states.transpose(0, 1)         # (N, B, D)
        actions = actions.transpose(0, 1)                 # (N, B)
        rewards = rewards.transpose(0, 1)                 # (N, B)
        not_done = (~dones).transpose(0, 1).float()       # (N, B)

        q_values = self.networks.forward(states, local=True)                    # (N, B, A)
        q_taken = q_values.gather(-1, actions.unsqueeze(-1)).squeeze(-1)        # (N, B)

        with torch.no_grad():
            next_q = self.networks.forward(next_states, local=False).max(dim=-1).values  # (N, B)
            targets = rewards + self.gamma * not_done * next_q

        loss = F.mse_loss(q_taken, targets)

        self.networks_optimizer.zero_grad()
        loss.backward()
        self.networks_optimizer.step()

    def update_global_policy(self, x : tuple[Tensor]) -> None:
        """One DQN step for the global policy on a batch from ``global_memory``.

        ``x`` is ``(states, goals, rewards, next_states, dones)``, with ``states`` of shape
        ``(B, D)`` and ``goals`` of shape ``(B, N)``. A per-transition scalar reward/done
        of shape ``(B,)`` is broadcast to every agent's goal.
        """
        states, goals, rewards, next_states, dones = x

        q_values = self.global_policy.forward(states, local=True)               # (B, N, G)
        q_taken = q_values.gather(-1, goals.unsqueeze(-1)).squeeze(-1)          # (B, N)

        with torch.no_grad():
            next_q = self.global_policy.forward(next_states, local=False).max(dim=-1).values  # (B, N)

            not_done = (~dones).float()
            if rewards.dim() == 1:
                rewards = rewards.unsqueeze(-1)
            if not_done.dim() == 1:
                not_done = not_done.unsqueeze(-1)

            targets = rewards + self.gamma * not_done * next_q

        loss = F.mse_loss(q_taken, targets)

        self.global_policy_optimizer.zero_grad()
        loss.backward()
        self.global_policy_optimizer.step()

    @staticmethod
    def state_to_tensor(state : dict) -> dict[str, Tensor]:
        """Convert a ``{'global', 'local'}`` state dict into float32 tensors on ``DEVICE``.

        ``'global'`` becomes a batch of one (leading axis of size 1). ``'local'`` stacks the
        per-agent observations to ``(N, D)``, the layout ``Networks.get_actions`` takes.
        """
        return {
            'global': torch.as_tensor(np.asarray(state['global']), dtype=torch.float32, device=DEVICE).unsqueeze(0),
            'local': torch.as_tensor(np.stack(state['local']), dtype=torch.float32, device=DEVICE),
        }

    def _select_goals(self, global_state : Tensor) -> list[int]:
        """Epsilon-greedy goal for each agent from a ``(1, D)`` global state."""
        if random.random() < self.eps:
            return [random.randrange(self.env.n_goals) for _ in range(self.env.n_agents)]

        with torch.no_grad():
            q_values = self.global_policy.forward(global_state, local=True)  # (1, N, G)

        return q_values[0].argmax(dim=-1).tolist()

    def _learn(self) -> None:
        """Update each model whose replay memory holds more than one batch."""
        if len(self.agents_memory) > self.batch_size:
            self.update_networks(self.agents_memory.sample())

        if len(self.global_memory) > self.batch_size:
            self.update_global_policy(self.global_memory.sample())

    def _sync_targets(self) -> None:
        """Copy local weights into the target networks of both models."""
        self.networks.copy_params()
        self.global_policy.copy_params()

    def train(self) -> list[float]:
        """Run ``num_episodes`` episodes of interaction and learning.

        Returns:
            The summed global reward of each episode.
        """
        step = 0
        history = []
        
        for _ in trange(self.num_episodes):

            state = self.env.reset()
            agent_goals = [None for _ in range(self.env.n_agents)]
            episode_return = 0.0

            for _ in range(self.max_steps):

                state_tensor = self.state_to_tensor(state)

                ### Assign goals to agents that currently have none
                if any(goal is None for goal in agent_goals):
                    new_goals = self._select_goals(state_tensor['global'])

                    for i, goal in enumerate(agent_goals):
                        if goal is None:
                            agent_goals[i] = new_goals[i]

                ### Get actions
                # NOTE(review): the local networks only see observations, not their assigned
                # goal, so they cannot condition their behaviour on it.
                actions = self.networks.get_actions(state_tensor['local'], eps=self.eps)
                result = self.env.step(actions, agent_goals)

                g_next_state, g_reward, g_done = result['global']
                l_next_state, l_reward, l_done = result['local']
                truncated = result.get('truncated', False)

                self.global_memory.remember(state['global'], agent_goals, g_reward, g_next_state, g_done)
                self.agents_memory.remember(state['local'], actions, l_reward, l_next_state, l_done)

                step += 1
                episode_return += float(np.sum(g_reward))

                if step % self.learning_freq == 0:
                    self._learn()

                if step % self.target_update_freq == 0:
                    self._sync_targets()

                # np.all accepts both a scalar flag and a per-agent list of flags.
                if np.all(g_done) or np.all(truncated):
                    break

                # Keep the same dict layout that reset() returns.
                state = {'global': g_next_state, 'local': list(l_next_state)}

                # Agents that reached their goal get a new one on the next step.
                for i, done in enumerate(l_done):
                    if done:
                        agent_goals[i] = None

            history.append(episode_return)

        return history

            

- Training pseudocode

global_step = 0
history = []

for i in range(num_episodes):

    state = env.reset()

    for j in range(max_steps):

        for agent in agents:

            if goal is not assigned to agent:
                agent.goal = global_policy(global_state)[agent]  # the global policy scores goals for every agent from the global state

        actions = [agent.get_action() for agent in agents]
        next_state, reward, done, _, _ = env.step(actions) # NOTE: returns each agent's observation, goal reward and goal-done flag, not the global ones
        Reset the goals of agents that have completed them (goal re-assignment can be synchronized across agents or not)

        agents_memory.remember(state, action, reward, next_state, done)
        global_policy_memory.append((global_state, goals, global_reward, next_global_state, global_done))

        if global_step % learning_freq == 0:
            batch = sample
            update(agents, batch)

        if global_step % target_update_freq == 0:
            for agent in agents:
                agent.copy_params()

        if global_done:
            break

        state = next_state
        history.append(reward)

def is_goal_done(agent_obs, goal) -> bool
def goal_reward(agent_obs, goal) -> float
def get_agent_observation(global_state, agent) -> observation
def get_actions_mask(goal) -> ActionsMask