In [6]:
import time
import torch
import numpy as np
from vmas import make_env
from vmas.simulator.core import Agent
from vmas.simulator.scenario import BaseScenario
from typing import Union
from moviepy.editor import ImageSequenceClip
from IPython.display import HTML, display as ipython_display

class VMASEnvRunner:
    def __init__(
        self,
        render: bool,
        num_envs: int,
        n_steps: int,
        device: str,
        scenario: Union[str, BaseScenario],
        continuous_actions: bool,
        random_action: bool,
        **kwargs
    ):
        self.render = render
        self.num_envs = num_envs
        self.n_steps = n_steps
        self.device = device
        self.scenario = scenario
        self.continuous_actions = continuous_actions
        self.random_action = random_action
        self.kwargs = kwargs
        self.frame_list = []  # Initialize frame list
        self.q_tables = {}  # Dictionary to store Q-tables for each agent
        # Define discrete actions for navigation scenario
        self.discrete_actions = [
            torch.tensor([1]),  # Move right
            torch.tensor([-1]),  # Move left
            torch.tensor([0])  # Stay
        ]

    def _update_q_table(self, agent_index, state, action, reward, next_state, learning_rate, discount_factor, num_actions):
        # Implement Q-learning update rule for the specific agent
        if agent_index not in self.q_tables:
            self.q_tables[agent_index] = {}  # Initialize Q-table for the agent
        state_key = tuple(state)
        next_state_key = tuple(next_state)
        if state_key not in self.q_tables[agent_index]:
            self.q_tables[agent_index][state_key] = np.zeros((num_actions,))  # Initialize Q-values for new state
        if next_state_key not in self.q_tables[agent_index]:
            self.q_tables[agent_index][next_state_key] = np.zeros((num_actions,))  # Initialize Q-values for new state

        # Ensure action is an integer
        action = int(action)

        self.q_tables[agent_index][state_key][action] += learning_rate * (
            reward + discount_factor * np.max(self.q_tables[agent_index][next_state_key]) - self.q_tables[agent_index][state_key][action]
        )

    def _get_deterministic_action(self, agent, env, agent_index, state, num_actions):
        if not self.random_action:
            if agent_index not in self.q_tables:
                self.q_tables[agent_index] = {}  # Initialize Q-table for the agent
            state_key = tuple(state)
            if state_key not in self.q_tables[agent_index]:
                self.q_tables[agent_index][state_key] = np.zeros((num_actions,))  # Initialize Q-values for new state
            action_idx = np.argmax(self.q_tables[agent_index][state_key])  # Choose action with highest Q-value

            # Convert the action index to the corresponding tensor
            action_tensor = torch.tensor(self.discrete_actions[action_idx], device=env.device, dtype=torch.float32)

            return action_tensor

    def _get_random_action(self, agent, num_actions, env):
        # Choose a random action index for each environment
        action_indices = np.random.randint(num_actions, size=env.num_envs)
        # Create a tensor with random actions for each environment
        action_tensors = torch.stack([self.discrete_actions[idx] for idx in action_indices]).to(env.device)
        return action_tensors

    def generate_gif(self, scenario_name):
        fps = 30
        clip = ImageSequenceClip(self.frame_list, fps=fps)
        clip.write_gif(f'{scenario_name}.gif', fps=fps)

        # Return the HTML tag to display the GIF
        return HTML(f'<img src="{scenario_name}.gif">')

    def run_vmas_env(self, learning_rate, discount_factor):
        scenario_name = self.scenario if isinstance(self.scenario, str) else self.scenario.__class__.__name__

        env = make_env(
            scenario=self.scenario,
            num_envs=self.num_envs,
            device=self.device,
            continuous_actions=self.continuous_actions,
            seed=0,
            **self.kwargs
        )

        init_time = time.time()
        step = 0

        states = env.reset()  # Reset the environment and get initial states for all agents

        for s in range(self.n_steps):
            step += 1
            print(f"Step {step}")

            actions = []
            for i, agent in enumerate(env.agents):
                state = states[i]  # Get current state for the agent
                num_actions = len(self.discrete_actions)
                print(f"Agent {i}: {agent} of {num_actions}")

                if not self.random_action:
                    action = self._get_deterministic_action(agent, env, i, state, num_actions)
                else:
                    action = self._get_random_action(agent, num_actions, env)
                
                print(f"Action agent {i}: {action}")

                actions.append(action)

            print(f"Actions before stacking: {actions}")

            # Ensure each action is reshaped correctly
            actions_tensor = [a.view(self.num_envs, -1) for a in actions]

            print(f"Actions tensor shape: {[a.shape for a in actions_tensor]}")

            # Step the environment with the list of action tensors
            next_states, rewards, dones, infos = env.step(actions_tensor)
            
            # Update Q-tables
            for i, agent in enumerate(env.agents):
                state = states[i]
                action = actions[i].item()  # Convert tensor to scalar
                reward = rewards[i].item() if isinstance(rewards[i], torch.Tensor) else rewards[i]
                next_state = next_states[i]
                num_actions = len(self.discrete_actions)

                self._update_q_table(i, state, action, reward, next_state, learning_rate, discount_factor, num_actions)

            states = next_states  # Update states

            if self.render:
                frame = env.render(
                    mode="rgb_array",
                    agent_index_focus=None,
                )
                self.frame_list.append(frame)  # Append frame to the list

        total_time = time.time() - init_time
        print(
            f"It took: {total_time}s for {self.n_steps} steps of {self.num_envs} parallel environments on device {self.device} "
            f"for {scenario_name} scenario."
        )

if __name__ == "__main__":
    scenario_name = "navigation_discrete"

    learning_rate = 0.1  # Learning rate
    discount_factor = 0.9  # Discount factor

    env_runner = VMASEnvRunner(
        render=True,
        num_envs=1,  # Single environment
        n_steps=10,
        device="cuda",
        scenario=scenario_name,
        continuous_actions=False,
        random_action=True,  # Set to True to enable random actions
        # Environment specific variables
        n_agents=2,
    )
    # Run the VMAS environment
    env_runner.run_vmas_env(learning_rate, discount_factor)

    # Generate and display the GIF
    ipython_display(env_runner.generate_gif(scenario_name))


Step 1
Agent 0: <vmas.simulator.core.Agent object at 0x7bee06474e80> of 3
Action agent 0: tensor([[1]], device='cuda:0')
Agent 1: <vmas.simulator.core.Agent object at 0x7bee06474fd0> of 3
Action agent 1: tensor([[-1]], device='cuda:0')
Actions before stacking: [tensor([[1]], device='cuda:0'), tensor([[-1]], device='cuda:0')]
Actions tensor shape: [torch.Size([1, 1]), torch.Size([1, 1])]


RuntimeError: output with shape [1] doesn't match the broadcast shape [0]