In [8]:
import time
import torch
import json
import ast
from vmas import make_env
from vmas.simulator.core import Agent
from vmas.simulator.scenario import BaseScenario
from typing import Union
from moviepy.editor import ImageSequenceClip
from IPython.display import HTML, display as ipython_display
import numpy as np
import matplotlib.pyplot as plt
from gym.spaces import Discrete 

class ProblemSolver:
    def __init__(self, env, agent_id, alpha=0.1, gamma=0.99, epsilon=0.2, communication_weight=0.5):
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.q_table = {}
        self.env = env
        self.agent_id = agent_id
        self.communication_weight = communication_weight  # Weight parameter for incorporating messages

    def get_action(self, agent, env, agent_id, agent_obs):
        agent_obs_cpu = agent_obs[:6].cpu().numpy()  # Transfer only the required slice to CPU
        agent_obs = tuple(np.round(agent_obs_cpu, decimals=5))  # Round the observation

        if agent_obs not in self.q_table:
            self.q_table[agent_obs] = np.zeros(self.env.action_space[self.agent_id].n)

        if np.random.rand() < self.epsilon:
            # Select a random action
            action = np.random.randint(env.action_space[self.agent_id].n)
        else:
            # Select the action with the highest Q-value
            action = np.argmax(self.q_table[agent_obs])
        
        return (action,)  # Return as a tuple

    def update_q_table(self, obs, action, reward, next_obs):
        obs_key = tuple(np.round(obs.cpu().numpy(), decimals=5))  # Only transfer to CPU when necessary
        next_obs_key = tuple(np.round(next_obs.cpu().numpy(), decimals=5))
        action = int(action.item())  # Convert tensor to Python scalar

        if isinstance(self.env.action_space[self.agent_id], Discrete):
            action_space_size = self.env.action_space[self.agent_id].n
        else:
            raise ValueError("This Q-learning implementation requires a discrete action space.")

        if obs_key not in self.q_table:
            self.q_table[obs_key] = np.zeros(action_space_size)

        if next_obs_key not in self.q_table:
            self.q_table[next_obs_key] = np.zeros(action_space_size)

        best_next_action = np.argmax(self.q_table[next_obs_key])
        td_target = reward + self.gamma * self.q_table[next_obs_key][best_next_action]

        td_error = td_target - self.q_table[obs_key][action]
        self.q_table[obs_key][action] += self.alpha * td_error

class Case:
    added_states = set()  # Class attribute to store states already added to the case base

    def __init__(self, problem, solution, trust_value=1):
        self.problem = problem if isinstance(problem, list) else ast.literal_eval(problem)  # Convert problem to numpy array
        self.solution = solution
        self.trust_value = trust_value
    
    @staticmethod
    def sim_q(state1, state2):
        state1 = np.atleast_1d(state1)
        state2 = np.atleast_1d(state2)
        CNDMaxDist = 6  # Maximum distance between two nodes in the CND based on EOPRA reference
        v = state1.size  # Total number of objects the agent can perceive
        DistQ = np.sum([Case.dist_q(Objic, Objip) for Objic, Objip in zip(state1, state2)])
        similarity = (CNDMaxDist * v - DistQ) / (CNDMaxDist * v)
        return similarity

    @staticmethod
    def dist_q(X1, X2):
        return np.min(np.abs(X1 - X2))

    @staticmethod
    def retrieve(state, case_base, threshold=0.1):
        state = ast.literal_eval(state)

        similarities = {}
        for case in case_base:
            problem_numeric = np.array(case.problem, dtype=float)
            state_numeric = np.array(state, dtype=float)
            similarities[case] = Case.sim_q(state_numeric, problem_numeric)  # Compare state with the problem part of the case
        
        sorted_similarities = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        
        if sorted_similarities:
            most_similar_case = sorted_similarities[0][0] if sorted_similarities[0][1] >= threshold else None
        else:
            most_similar_case = None
        
        return most_similar_case

    @staticmethod
    def reuse(c, temporary_case_base):
        temporary_case_base.append(c)

    @staticmethod
    def revise(case_base, temporary_case_base, successful_episodes):
        for case in temporary_case_base:
            if successful_episodes and case in case_base:
                case.trust_value += 0.1  # Increment trust value if the episode ended successfully and the case is in the case base
            elif not successful_episodes and case in case_base:
                case.trust_value -= 0.1  # Decrement trust value if the episode ended unsuccessfully and the case is in the case base
            case.trust_value = max(0, min(case.trust_value, 1))  # Ensure trust value is within [0, 1]

    @staticmethod
    def retain(case_base, temporary_case_base, successful_episodes, threshold=0.7):
        if successful_episodes:
            # Iterate through the temporary case base to find the last occurrence of each unique state
            for case in reversed(temporary_case_base):
                state = tuple(np.atleast_1d(case.problem))
                # Check if the state is already in the case base or has been added previously
                if state not in Case.added_states:
                    # Add the case to the case base if the state is new
                    case_base.append(case)
                    Case.added_states.add(state)
                else:
                    # Find the index of the existing case in the case base
                    existing_index = next((i for i, c in enumerate(case_base) if tuple(np.atleast_1d(c.problem)) == state), None)
                    if existing_index is not None:
                        # Get the existing case from the case base
                        existing_case = case_base[existing_index]
                        # Update the trust value of the existing case with the new value from the revise step
                        existing_case.trust_value = case.trust_value

        # Filter case_base based on trust_value
        filtered_case_base = [case for case in case_base if case.trust_value >= threshold]
        return filtered_case_base

class QCBRL:
    def __init__(
        self,
        render: bool,
        num_envs: int,
        num_episodes: int,
        max_steps_per_episode: int,
        device: str,
        scenario: Union[str, BaseScenario],
        continuous_actions: bool,
        random_action: bool,
        n_agents: int,
        obs_discrete: bool = False,
        **kwargs
    ):
        self.render = render
        self.num_envs = num_envs
        self.num_episodes = num_episodes
        self.max_steps_per_episode = max_steps_per_episode
        self.device = device
        self.scenario = scenario
        self.continuous_actions = continuous_actions
        self.random_action = random_action
        self.obs_discrete = obs_discrete
        self.kwargs = kwargs
        self.frame_list = []  
        self.problem_solver_agents = []
        self.rewards_history = []  
        self.action_counts = {i: {} for i in range(n_agents)}  
        self.agent_rewards_history = {i: [] for i in range(n_agents)}
        self.successful_episodes = 0
        self.case_base = {i: [] for i in range(n_agents)}  # Separate case base for each agent
        self.temporary_case_base = {i: [] for i in range(n_agents)}  # Separate temporary case base for each agent

    def discretize(self, data, bins):
        bins = np.array(bins)
        if np.isscalar(data):
            data = np.array([data])
        bin_indices = np.digitize(data, bins) - 1  # np.digitize returns indices starting from 1
        bin_indices = np.clip(bin_indices, 0, len(bins) - 1)  # Ensure indices are within the valid range
        bin_values = bins[bin_indices]
        bin_values = np.round(bin_values, 2)  # Round the bin values to two decimal places
        return bin_indices, bin_values

    def discretize_tensor_slice(self, tensor_slice, bins):
        tensor_np = tensor_slice.cpu().numpy()  # Convert to numpy for easier handling
        indices, values = self.discretize(tensor_np, bins)
        indices = torch.tensor(indices, device=tensor_slice.device)
        values = torch.tensor(values, device=tensor_slice.device)
        return indices, values

    def _get_deterministic_obs(self, env, observation):
        pos_bins = np.linspace(-1, 1, 10)
        vel_bins = np.linspace(-1, 1, 10)
        dpos_bins = np.linspace(0, 1.5, 10)
        dvel_bins = np.linspace(0, 2, 10)

        num_envs, num_agents, obs_size = observation.size()
        discretized_obs = torch.zeros((num_envs, num_agents, obs_size), device=observation.device)

        for i in range(num_envs):
            for j in range(num_agents):
                agent_obs = observation[i, j, :]
                pos_indices, pos_values = self.discretize_tensor_slice(agent_obs[:2], pos_bins)
                vel_indices, vel_values = self.discretize_tensor_slice(agent_obs[2:4], vel_bins)
                dpos_indices, dpos_values = self.discretize_tensor_slice(agent_obs[4:6], dpos_bins)
                dvel_indices, dvel_values = self.discretize_tensor_slice(agent_obs[6:8], dvel_bins)
                discretized_obs[i, j, :2] = pos_values
                discretized_obs[i, j, 2:4] = vel_values
                discretized_obs[i, j, 4:6] = dpos_values
                discretized_obs[i, j, 6:8] = dvel_values

        return discretized_obs

    def run_vmas_env(self):
        env = make_env(
            scenario=self.scenario,
            num_envs=self.num_envs,
            device=self.device,
            continuous_actions=self.continuous_actions,
            **self.kwargs
        )
        env.reset()
        n_agents = len(env.agents)
        self.problem_solver_agents = [
            ProblemSolver(env, agent_id) for agent_id in range(n_agents)
        ]

        for episode in range(self.num_episodes):
            observation = env.reset()
            observation = self._get_deterministic_obs(env, observation)
            done = False
            episode_reward = 0

            for step in range(self.max_steps_per_episode):
                actions = {}
                for i, agent in enumerate(env.possible_agents):
                    agent_obs = observation[0, i, :].to(self.device)
                    case = Case.retrieve(
                        str(agent_obs.cpu().numpy()), self.case_base[i], threshold=0.1
                    )
                    if case:
                        actions[agent] = case.solution
                        Case.reuse(case, self.temporary_case_base[i])
                    else:
                        action = self.problem_solver_agents[i].get_action(
                            agent, env, i, agent_obs
                        )
                        actions[agent] = action
                        problem = agent_obs.cpu().numpy().tolist()
                        new_case = Case(problem, action)
                        self.temporary_case_base[i].append(new_case)
                next_observation, reward, done, _ = env.step(actions)
                next_observation = self._get_deterministic_obs(env, next_observation)
                episode_reward += sum(reward.values())

                for i, agent in enumerate(env.possible_agents):
                    agent_obs = observation[0, i, :].to(self.device)
                    next_agent_obs = next_observation[0, i, :].to(self.device)
                    self.problem_solver_agents[i].update_q_table(
                        agent_obs, torch.tensor(actions[agent]), reward[agent], next_agent_obs
                    )

                    action = actions[agent]
                    if action in self.action_counts[i]:
                        self.action_counts[i][action] += 1
                    else:
                        self.action_counts[i][action] = 1

                observation = next_observation
                if self.render:
                    self.frame_list.append(env.render(mode="rgb_array")[0])

                if all(done.values()):
                    break

            self.rewards_history.append(episode_reward)
            for i, agent in enumerate(env.possible_agents):
                self.agent_rewards_history[i].append(reward[agent])

            Case.revise(self.case_base, self.temporary_case_base, successful_episodes=True)
            for i in range(n_agents):
                self.case_base[i] = Case.retain(
                    self.case_base[i], self.temporary_case_base[i], successful_episodes=True
                )
            self.temporary_case_base = {i: [] for i in range(n_agents)}

        env.close()
        return self.rewards_history, self.agent_rewards_history, self.action_counts

    def generate_gif(self, scenario_name):
        clip = ImageSequenceClip(self.frame_list, fps=20)
        gif_path = f"{scenario_name}.gif"
        clip.write_gif(gif_path, fps=20)
        return gif_path

    def plot_rewards_history(self):
        plt.plot(self.rewards_history)
        plt.xlabel("Episodes")
        plt.ylabel("Rewards")
        plt.title("Training Rewards Over Episodes")
        plt.show()

if __name__ == "__main__":
    scenario_name = "navigation_comm"
    use_cuda = True

    env_runner = QCBRL(
        render=True,
        num_envs=1,
        num_episodes=50,
        max_steps_per_episode=200,
        device=torch.device("cuda" if use_cuda else "cpu"),
        scenario=scenario_name,
        continuous_actions=False,
        random_action=False,
        n_agents=2,
        obs_discrete=True,
        agents_with_same_goal=2,
        collisions=False,
        shared_rew=False,
    )

    env_runner.run_vmas_env()
    # for agent in env_runner.problem_solver_agents:
    #     agent.print_q_table()
    env_runner.plot_rewards_history()

    ipython_display(HTML(f'<img src="{env_runner.generate_gif(scenario_name)}">'))


AttributeError: 'list' object has no attribute 'size'