In [4]:
import time
import torch
from vmas import make_env
from vmas.simulator.core import Agent
from vmas.simulator.scenario import BaseScenario
from typing import Union
from moviepy.editor import ImageSequenceClip
from IPython.display import HTML, display as ipython_display

class VMASEnvRunner:
    """Roll out a VMAS scenario with fixed or random actions and record frames.

    The constructor arguments are stored as attributes of the same name. Extra
    keyword arguments (e.g. scenario-specific options such as ``n_agents``) are
    kept in ``kwargs`` and forwarded to ``make_env``. When ``render`` is true,
    RGB frames are collected in ``frame_list`` and can be written out as a GIF
    with :meth:`generate_gif`.
    """

    def __init__(
        self,
        render: bool,
        num_envs: int,
        num_episodes: int,
        max_steps_per_episode: int,
        device: str,
        # Quoted annotations are not evaluated when the class is defined.
        scenario: Union[str, "BaseScenario"],
        continuous_actions: bool,
        random_action: bool,
        **kwargs,
    ):
        self.render = render
        self.num_envs = num_envs
        self.num_episodes = num_episodes
        self.max_steps_per_episode = max_steps_per_episode
        self.device = device
        self.scenario = scenario
        self.continuous_actions = continuous_actions
        self.random_action = random_action
        self.kwargs = kwargs
        # RGB frames appended by run_vmas_env(); not cleared between runs.
        self.frame_list = []

    def get_continuous_action(self):
        """Placeholder, not implemented and unused by this class; returns None."""
        return None

    def get_discrete_action(self):
        """Placeholder, not implemented and unused by this class; returns None."""
        return None

    def _get_deterministic_action(self, agent: "Agent", env, agent_id):
        """Return a fixed, hand-picked action for ``agent`` in every parallel env.

        Args:
            agent: VMAS agent; only ``agent.silent`` is read.
            env: VMAS environment; ``env.device`` and ``env.batch_dim`` are read.
            agent_id: Index of the agent in ``env.agents``. It selects between
                the two hard-coded continuous actions for non-silent agents.

        Returns:
            A tensor of shape ``(env.batch_dim, k)`` on ``env.device``:
            - ``k == 1``, int64, value 8, for discrete actions.
            - ``k == 2``, float32, for silent agents with continuous actions.
            - ``k == 4``, float32, for non-silent agents with continuous actions.
        """
        if self.continuous_actions:
            # First pair: movement, range [-1, 1] per component.
            if agent.silent:
                values = [-1, 0.5]
            elif agent_id == 0:
                # The extra two entries are presumably the communication part
                # of the action (assumes the world's dim_c == 2 -- confirm).
                values = [-1, 0.5, 0, 0]
            else:
                values = [-1, 0.5, 1, 1]
        else:
            # Single discrete action index; the original comment gives the
            # valid range as [0, 8] -- confirm against the scenario.
            values = [8]

        row = torch.tensor([values], device=env.device)
        # env.step() expects one action row per parallel environment. expand()
        # alone returns a stride-0 view, so clone() gives every row its own
        # storage.
        return row.expand(env.batch_dim, -1).clone()

    def generate_gif(self, scenario_name, fps: int = 30):
        """Write the recorded frames to ``<scenario_name>.gif`` and return an <img> tag.

        Args:
            scenario_name: Base name of the GIF file, written to the current
                working directory.
            fps: Frame rate used both for the clip and for the written GIF.

        Returns:
            An IPython ``HTML`` object that displays the written GIF.

        Raises:
            ValueError: If no frames have been recorded (e.g. ``render`` was
                false or ``run_vmas_env`` has not been called yet).
        """
        if not self.frame_list:
            raise ValueError(
                "No frames recorded; call run_vmas_env() with render=True first."
            )
        gif_path = f"{scenario_name}.gif"
        clip = ImageSequenceClip(self.frame_list, fps=fps)
        clip.write_gif(gif_path, fps=fps)
        return HTML(f'<img src="{gif_path}">')

    def run_vmas_env(self):
        """Create the VMAS environment and run it for ``num_episodes`` episodes.

        Each episode resets the environment, then steps it until every parallel
        environment has reported done or ``max_steps_per_episode`` steps have
        been taken. Actions, observations and rewards are printed every step.
        When ``render`` is true, one RGB frame is appended to ``frame_list``
        after each step. At the end, the total wall-clock time and step count
        are printed.
        """
        if isinstance(self.scenario, str):
            scenario_name = self.scenario
        else:
            scenario_name = self.scenario.__class__.__name__

        env = make_env(
            scenario=self.scenario,
            num_envs=self.num_envs,
            device=self.device,
            continuous_actions=self.continuous_actions,
            seed=0,
            **self.kwargs,
        )

        init_time = time.time()
        total_steps = 0

        for e in range(self.num_episodes):
            print(f"Episode {e}")
            obs = env.reset()
            # Sticky per-environment done flags: once an environment reports
            # done it stays done for the rest of the episode.
            done = torch.zeros(self.num_envs, dtype=torch.bool, device=env.device)
            step = 0
            while not bool(done.all()) and step < self.max_steps_per_episode:
                step += 1
                total_steps += 1
                print(f"Step {step} of Episode {e}")

                actions = []
                for i, agent in enumerate(env.agents):
                    if self.random_action:
                        action = env.get_random_action(agent)
                    else:
                        action = self._get_deterministic_action(agent, env, i)
                    print(f"action agent {i}: {action}")
                    actions.append(action)

                obs, rews, dones, _info = env.step(actions)
                done |= torch.as_tensor(dones, dtype=torch.bool, device=done.device)

                print(f"obs agent after step {step}: {obs}")
                print(f"rews agent: {rews}")

                if self.render:
                    frame = env.render(
                        mode="rgb_array",
                        agent_index_focus=None,
                    )
                    self.frame_list.append(frame)

        total_time = time.time() - init_time
        print(
            f"It took: {total_time}s for {total_steps} steps across {self.num_episodes} episodes of {self.num_envs} parallel environments on device {self.device} "
            f"for {scenario_name} scenario."
        )
        


if __name__ == "__main__":
    scenario_name = "navigation_comm"

    # Use the GPU when available, but fall back to CPU so the demo also runs
    # on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    env_runner = VMASEnvRunner(
        render=True,
        num_envs=1,
        num_episodes=1,
        max_steps_per_episode=1,
        device=device,
        scenario=scenario_name,
        continuous_actions=False,
        random_action=False,
        # Scenario-specific options, forwarded to make_env via **kwargs.
        n_agents=3,
    )
    # Run the VMAS environment (records frames because render=True).
    env_runner.run_vmas_env()

    # Write the recorded frames to a GIF and display it inline.
    ipython_display(env_runner.generate_gif(scenario_name))


Episode 0
Step 1 of Episode 0
action agent 0: tensor([[8]], device='cuda:0')
action agent 1: tensor([[8]], device='cuda:0')
action agent 2: tensor([[8]], device='cuda:0')
obs agent after step 1: [tensor([[ 0.2838, -0.2976, -0.1000, -0.1000,  0.8630,  0.4957,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  1.0000,  0.0000,  1.0000,  0.0000]],
       device='cuda:0'), tensor([[-0.9425,  0.2652, -0.1000, -0.1000, -1.8352,  0.4201,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  1.0000,  0.0000,  1.0000,  0.0000]],
       device='cuda:0'), tensor([[-0.3584, -0.2285, -0.1000, -0.1000, -0.7769,  0.7597,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  1.0000,  0.0000,  1.0000,  0.0000]],
       device='cuda:0')]
rews agent: [tensor([0.0046], device='cuda:0'), 

                                                  