In [5]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from vmas import make_env
from vmas.simulator.core import Agent
from vmas.simulator.scenario import BaseScenario
from typing import Union
from moviepy.editor import ImageSequenceClip
from IPython.display import HTML, display as ipython_display

class ActorCritic(nn.Module):
    def __init__(self, input_dim, action_dim):
        super(ActorCritic, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        self.critic = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        value = self.critic(x)
        policy_dist = torch.tanh(self.actor(x))
        return policy_dist, value

class PPOAgent:
    def __init__(self, state_dim, action_dim, device, lr=3e-4, gamma=0.99, clip_epsilon=0.2, K_epochs=4, batch_size=64):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        self.lr = lr
        self.gamma = gamma
        self.clip_epsilon = clip_epsilon
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.memory = deque(maxlen=10000)

    def get_action(self, state):
        state = torch.FloatTensor(state).to(self.device) if not isinstance(state, torch.Tensor) else state.to(self.device)
        with torch.no_grad():
            policy_dist, _ = self.policy_old(state)
        action = policy_dist.cpu().numpy()
        return action

    def add_to_memory(self, transition):
        self.memory.append(transition)

    def compute_gae(self, rewards, masks, values, next_values):
        values = values + [next_values]
        gae = 0
        returns = []
        for step in reversed(range(len(rewards))):
            delta = rewards[step] + self.gamma * values[step + 1] * masks[step] - values[step]
            gae = delta + self.gamma * self.clip_epsilon * masks[step] * gae
            returns.insert(0, gae + values[step])
        return returns

    def train(self):
        if len(self.memory) < self.batch_size:
            return
        
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        _, values = self.policy(states)
        _, next_values = self.policy(next_states)
        returns = self.compute_gae(rewards, dones, values, next_values)
        returns = torch.cat(returns).detach()

        for _ in range(self.K_epochs):
            log_probs, state_values = self.policy(states)
            dist_entropy = -log_probs * actions
            ratios = torch.exp(log_probs - self.policy_old(states)[0].detach())
            advantages = returns - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
            loss = -torch.min(surr1, surr2) + 0.5 * nn.MSELoss()(state_values, returns) - 0.01 * dist_entropy

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())

class VMASEnvRunner:
    def __init__(
        self,
        render: bool,
        num_envs: int,
        num_episodes: int,
        max_steps_per_episode: int,
        device: str,
        scenario: Union[str, BaseScenario],
        continuous_actions: bool,
        random_action: bool,
        **kwargs
    ):
        self.render = render
        self.num_envs = num_envs
        self.num_episodes = num_episodes
        self.max_steps_per_episode = max_steps_per_episode
        self.device = device
        self.scenario = scenario
        self.continuous_actions = continuous_actions
        self.random_action = random_action
        self.kwargs = kwargs
        self.frame_list = []  # Initialize frame list

        # Initialize PPO agent
        self.ppo_agent = PPOAgent(state_dim=10, action_dim=4, device=device)  # Adjust state_dim and action_dim as needed

    def get_continuous_action(self, state):
        return self.ppo_agent.get_action(state)

    def get_discrete_action(self):
        pass

    def _get_deterministic_action(self, agent: Agent, env, obs, agent_id):
        state = obs[agent_id]  # Obtain the state from the environment
        state = torch.FloatTensor(state).to(self.device) if not isinstance(state, torch.Tensor) else state.to(self.device)

        if self.continuous_actions:
            action = self.get_continuous_action(state)
            if agent.silent:
                action = torch.tensor([[-1, 0.5]], device=self.device)  # Default silent action
            else:
                action = torch.tensor([action], device=self.device)  # Convert action index to tensor
        else:
            action = self.get_discrete_action()
            if(agent.silent):
                action = torch.tensor([[8]], device=self.device)  # range action is [0,8] of a single value
            else:
                action = torch.tensor([[8]], device=self.device)  # range action is [0,8] of a single value
        return action.clone()

    def generate_gif(self, scenario_name):
        fps = 30
        clip = ImageSequenceClip(self.frame_list, fps=fps)
        clip.write_gif(f'{scenario_name}.gif', fps=fps)

        # Return the HTML tag to display the GIF
        return HTML(f'<img src="{scenario_name}.gif">')

    def run_vmas_env(self):
        scenario_name = self.scenario if isinstance(self.scenario, str) else self.scenario.__class__.__name__

        env = make_env(
            scenario=self.scenario,
            num_envs=self.num_envs,
            device=self.device,
            continuous_actions=self.continuous_actions,
            seed=0,
            **self.kwargs
        )

        init_time = time.time()
        total_steps = 0
        
        for e in range(self.num_episodes):  # Loop over episodes
            print(f"Episode {e}")
            obs = env.reset()  # Reset environment at the start of each episode
            done = [False] * self.num_envs
            step = 0
            while not all(done) and step < self.max_steps_per_episode:  # Loop over steps within an episode
                step += 1
                total_steps += 1
                print(f"Step {step} of Episode {e}")

                actions = []
                for i, agent in enumerate(env.agents):
                    if not self.random_action:
                        action = self._get_deterministic_action(agent, env, obs, i)
                    else:
                        action = env.get_random_action(agent)

                    print(f"action agent {i}: {action}")

                    actions.append(action)

                next_obs, rews, dones, info = env.step(actions)
                done = [done or d for done, d in zip(done, dones)]

                # Add transition to PPO memory and train
                self.ppo_agent.add_to_memory((obs, actions, rews, next_obs, dones))
                self.ppo_agent.train()

                obs = next_obs

                print(f"obs agent after step {step}: {obs}")
                print(f"rews agent: {rews}")

                if self.render:
                    frame = env.render(
                        mode="rgb_array",
                        agent_index_focus=None,
                    )
                    self.frame_list.append(frame)  # Append frame to the list

        total_time = time.time() - init_time

        print(
            f"It took: {total_time}s for {total_steps} steps across {self.num_episodes} episodes of {self.num_envs} parallel environments on device {self.device} "
            f"for {scenario_name} scenario."
        )

if __name__ == "__main__":
    scenario_name = "navigation_comm"

    env_runner = VMASEnvRunner(
        render=True,
        num_envs=1,
        num_episodes=1,
        max_steps_per_episode=10,
        device="cuda",
        scenario=scenario_name,
        continuous_actions=True,
        random_action=False,
        # Environment specific variables
        n_agents=3,
    )
    # Run the VMAS environment
    env_runner.run_vmas_env()

    # Generate and display the GIF
    ipython_display(env_runner.generate_gif(scenario_name))


Episode 0
Step 1 of Episode 0


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x22 and 10x128)