In [25]:
import time
import torch
import numpy as np
from vmas import make_env
from vmas.simulator.core import Agent
from vmas.simulator.scenario import BaseScenario
from typing import Union
from moviepy.editor import ImageSequenceClip
from IPython.display import HTML, display as ipython_display

class QLearningAgent(Agent):
    def __init__(self, name, action_space_size, observation_space_size, learning_rate=0.1, discount_factor=0.99, epsilon=0.1):
        super(QLearningAgent, self).__init__(name)
        self.action_space_size = action_space_size
        self.observation_space_size = observation_space_size
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.q_table = np.zeros((observation_space_size, action_space_size))

    def choose_action(self, state):
        if np.random.uniform(0, 1) < self.epsilon:
            return np.random.choice(range(self.action_space_size))
        else:
            return np.argmax(self.q_table[state])

    def learn(self, state, action, reward, next_state):
        q_predict = self.q_table[state, action]
        q_target = reward + self.discount_factor * np.max(self.q_table[next_state])
        self.q_table[state, action] += self.learning_rate * (q_target - q_predict)

class QLearningVMASEnvRunner:
    def __init__(
        self,
        env,
        render: bool,
        num_envs: int,
        n_steps: int,
        device: str,
        scenario: Union[str, BaseScenario],
        continuous_actions: bool,
        random_action: bool,
        **kwargs
    ):
        self.render = render
        self.num_envs = num_envs
        self.n_steps = n_steps
        self.device = device
        self.scenario = scenario
        self.continuous_actions = continuous_actions
        self.random_action = random_action
        self.kwargs = kwargs
        self.frame_list = []  # Initialize frame list

        # Initialize agents
        self.agents = [QLearningAgent(name=f"Agent_{i}", action_space_size=env.get_agent_action_size(agent), observation_space_size=env.get_agent_observation_space(agent, self.obs), learning_rate=0.1, discount_factor=0.99, epsilon=0.1) for i, agent in enumerate(env.agents)]

    def run_vmas_env(self):
        scenario_name = self.scenario if isinstance(self.scenario, str) else self.scenario.__class__.__name__

        env = make_env(
            scenario=self.scenario,
            num_envs=self.num_envs,
            device=self.device,
            continuous_actions=self.continuous_actions,
            seed=0,
            **self.kwargs
        )

        init_time = time.time()
        step = 0

        for s in range(self.n_steps):
            step += 1
            print(f"Step {step}")

            actions = []
            for i, agent in enumerate(self.agents):
                if not self.random_action:
                    if isinstance(agent, QLearningAgent):  # Check if agent is an instance of QLearningAgent
                        action = self._get_q_learning_action(agent)
                    else:
                        raise ValueError("Agent must be an instance of QLearningAgent for Q-learning action.")
                else:
                    action = self._get_random_action(agent)

                actions.append(action)

            obs, rews, dones, info = env.step(actions)

            if self.render:
                frame = env.render(
                    mode="rgb_array",
                    agent_index_focus=None,
                )
                self.frame_list.append(frame)  # Append frame to the list

        total_time = time.time() - init_time
        print(
            f"It took: {total_time}s for {self.n_steps} steps of {self.num_envs} parallel environments on device {self.device} "
            f"for {scenario_name} scenario."
        )

    def _get_q_learning_action(self, agent: QLearningAgent):
        state = agent.state  # Assuming the agent has a state attribute
        action = agent.choose_action(state)
        return torch.tensor([action], device=self.device, dtype=torch.long).unsqueeze(-1).expand(env.batch_dim, 1)

    def _get_random_action(self, agent: Agent):
        return torch.randint(0, agent.action_space_size, (env.batch_dim, 1), device=self.device, dtype=torch.long)

if __name__ == "__main__":
      
    scenario_name = "waterfall"

    env_runner = QLearningVMASEnvRunner(
        render=True,
        num_envs=32,
        n_steps=10,
        device="cuda",
        scenario=scenario_name,
        continuous_actions=False,
        random_action=False,
        # Environment specific variables
        n_agents=2,
    )
    # Run the VMAS environment
    env_runner.run_vmas_env()



TypeError: QLearningVMASEnvRunner.__init__() missing 1 required positional argument: 'env'