In [2]:
import time
import torch
from vmas import make_env
from vmas.simulator.core import Agent
from vmas.simulator.scenario import BaseScenario
from typing import Union
from moviepy.editor import ImageSequenceClip
from IPython.display import HTML, display as ipython_display
import numpy as np

class VMASEnvRunner:
    def __init__(
        self,
        render: bool,
        num_envs: int,
        num_episodes: int,
        max_steps_per_episode: int,
        device: str,
        scenario: Union[str, BaseScenario],
        continuous_actions: bool,
        random_action: bool,
        obs_discrete: bool = False,
        **kwargs
    ):
        self.render = render
        self.num_envs = num_envs
        self.num_episodes = num_episodes
        self.max_steps_per_episode = max_steps_per_episode
        self.device = device
        self.scenario = scenario
        self.continuous_actions = continuous_actions
        self.random_action = random_action
        self.obs_discrete = obs_discrete
        self.kwargs = kwargs
        self.frame_list = []  # Initialize frame list

    def qlearning_action(self, env, agent_id, agent_obs):
        pass



    def discretize(self, data, bins):
   
        # Ensure the bins are a numpy array
        bins = np.array(bins)
        
        # Check if the input data is a single value
        if np.isscalar(data):
            bin_indices = np.digitize([data], bins)[0]
            bin_values = bins[bin_indices - 1] if bin_indices > 0 else bins[0]
            return bin_indices, bin_values
        else:
            # Use numpy.digitize to get the bin indices for array-like data
            bin_indices = np.digitize(data, bins)
            # Map indices to bin values
            bin_values = [bins[index - 1] if index > 0 else bins[0] for index in bin_indices]
            return bin_indices, bin_values


    # Function to discretize a tensor slice
    def discretize_tensor_slice(self, tensor_slice, bins):
        indices = []
        values = []
        for row in tensor_slice:
            row_indices = []
            row_values = []
            for item in row:
                bin_idx, bin_val = self.discretize(item.item(), bins)
                row_indices.append(bin_idx)
                row_values.append(bin_val)
            indices.append(row_indices)
            values.append(row_values)
        return torch.tensor(indices, device=tensor_slice.device), torch.tensor(values, device=tensor_slice.device)



    def _get_deterministic_obs(self, env, observation):
        pos_bins = np.linspace(-1, 1, num=21)
        vel_bins = np.linspace(0, 0, num=21)
        lidar_bins = np.linspace(0, 1, num=11)

        pos = observation[0:2]
        vel = observation[2:4]
        goal_pose = observation[4:6]
        comms_data = observation[6:13]
        sensor_data = observation[13:]

        discrete_pos_indices, discrete_pos_values = self.discretize_tensor_slice(pos, pos_bins)
        discrete_vel_indices, discrete_vel_values = self.discretize_tensor_slice(vel, vel_bins)
        discrete_goal_pose_indices, discrete_goal_pose_values = self.discretize_tensor_slice(goal_pose, pos_bins)
        discrete_sensor_data_indices, discrete_sensor_data_values = self.discretize_tensor_slice(sensor_data, lidar_bins)

        concatenated_tensor_values = torch.cat(
            [discrete_pos_values, discrete_vel_values, discrete_goal_pose_values, comms_data, discrete_sensor_data_values],
            dim=0
        )

        return concatenated_tensor_values

    def _get_deterministic_action(self, agent: Agent, env, agent_id, agent_obs):
        if self.continuous_actions:
            if agent.silent:
                action = torch.tensor([[-1, 0.5]], device=env.device) # range action is [-1, 1] of a paired value and agent silent = true
            else:
                if agent_id == 0:
                    action = torch.tensor([[-1, 0.5, [2,1]]], device=env.device) # range action is [-1, 1] of a paired value and agent silent = false and c=1
                else:
                    action = torch.tensor([[-1, 0.5, 3]], device=env.device) # range action is [-1, 1] of a paired value and agent silent = false and c=1
        else:
            if agent.silent:
                action = torch.tensor([[1]], device=env.device) 
            else:
                action = torch.tensor([[6,[9,4]]], device=env.device) 
            # range discrete action for c=1 = [0:stay, 1:down, 2:up, 3:left, 4:bottom-left, 5:top-left, 6:right, 7:bottom-right, 8:top-right] of a single value
            # range discrete action for c=2 = [0/1:stay, 2/3:down, 4/5:up, 6/7:left, 8/9:bottom-left, 10/11:top-left, 12/13:right, 14/15:bottom-right, 16/17:top-right] of a single value
            # action = self.qlearning_action(self, env, agent_id, agent_obs)
            # print (f"action space: {env.action_space}")
        
        return action.clone()


    def generate_gif(self, scenario_name):
        fps = 30
        clip = ImageSequenceClip(self.frame_list, fps=fps)
        clip.write_gif(f'{scenario_name}.gif', fps=fps)

        # Return the HTML tag to display the GIF
        return HTML(f'<img src="{scenario_name}.gif">')

    def run_vmas_env(self):
        scenario_name = self.scenario if isinstance(self.scenario, str) else self.scenario.__class__.__name__

        env = make_env(
            scenario=self.scenario,
            num_envs=self.num_envs,
            device=self.device,
            continuous_actions=self.continuous_actions,
            seed=0,
            **self.kwargs
        )

        init_time = time.time()
        total_steps = 0
        
        for e in range(self.num_episodes):  # Loop over episodes
            print(f"Episode {e}")
            obs = env.reset()  # Reset environment at the start of each episode
            done = [False] * self.num_envs
            step = 0
            while not all(done) and step < self.max_steps_per_episode:  # Loop over steps within an episode
                step += 1
                total_steps += 1
                print(f"Step {step} of Episode {e}")

                actions = []
                for i, agent in enumerate(env.agents):

                    if self.obs_discrete:
                        print(f"obs agent {i}: {obs[i]}")
                        discrete_obs = self._get_deterministic_obs(env, obs[i])
                        obs[i] = discrete_obs
                        print(f"discrete obs of agent {i}: {obs[i]}")

                    if self.random_action:
                        action = env.get_random_action(agent)
                    else:
                        action = self._get_deterministic_action(agent, env, i, obs[i])
                    
                    print(f"action agent {i}: {action}")
                    actions.append(action)

                obs, rews, dones, info = env.step(actions)
                done = [done or d for done, d in zip(done, dones)]
                print(f"obs agent after step {step}: {obs}")
                print(f"rews agent: {rews}")

                if self.render:
                    frame = env.render(
                        mode="rgb_array",
                        agent_index_focus=None,
                    )
                    self.frame_list.append(frame)  # Append frame to the list

        total_time = time.time() - init_time
        print(
            f"It took: {total_time}s for {total_steps} steps across {self.num_episodes} episodes of {self.num_envs} parallel environments on device {self.device} "
            f"for {scenario_name} scenario."
        )

if __name__ == "__main__":
    scenario_name = "navigation_comm"

    env_runner = VMASEnvRunner(
        render=True,
        num_envs=1,
        num_episodes=1,
        max_steps_per_episode=10,
        device="cuda",
        scenario=scenario_name,
        continuous_actions=False,
        random_action=False,
        obs_discrete=True,  
        n_agents=2,
    )
    # Run the VMAS environment
    env_runner.run_vmas_env()

    # Generate and display the GIF
    ipython_display(env_runner.generate_gif(scenario_name))


Episode 0
Step 1 of Episode 0
obs agent 0: tensor([-0.9000,  0.9000,  0.0000,  0.0000,  0.0000, -0.8000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000], device='cuda:0')


TypeError: iteration over a 0-d tensor