In [1]:
import time
import torch
import json
import ast
from vmas import make_env
from vmas.simulator.core import Agent
from vmas.simulator.scenario import BaseScenario
from typing import Union
from moviepy.editor import ImageSequenceClip
from IPython.display import HTML, display as ipython_display
import numpy as np
import matplotlib.pyplot as plt
from gym.spaces import Discrete 

class ProblemSolver:
    """Tabular epsilon-greedy Q-learning solver for one agent of the environment.

    Q-values live in ``self.q_table``, a dict that maps a rounded observation
    tuple to a NumPy vector holding one Q-value per discrete action.
    """

    def __init__(self, env, agent_id, alpha=0.1, gamma=0.99, epsilon=0.2, communication_weight=0.5):
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.q_table = {}
        self.env = env
        self.agent_id = agent_id
        self.communication_weight = communication_weight  # Weight parameter for incorporating messages

    @staticmethod
    def _state_key(obs, decimals=5):
        """Return a hashable Q-table key for ``obs``.

        The observation is moved to the CPU, flattened and rounded. Flattening
        matters for observations that carry extra dimensions (e.g. a leading
        batch axis): ``tuple()`` of a 2-D array yields NumPy rows, which are
        unhashable and cannot be used as dict keys.
        """
        if isinstance(obs, torch.Tensor):
            values = obs.detach().cpu().numpy()
        else:
            values = np.asarray(obs)
        return tuple(np.round(values.reshape(-1), decimals=decimals))

    def get_action(self, agent, env, agent_id, agent_obs):
        """Pick an action for ``agent_obs`` with an epsilon-greedy policy.

        Only the first six entries of ``agent_obs`` form the state key.
        ``agent``, ``env`` and ``agent_id`` are accepted for interface
        compatibility; the action space is always read from ``self.env`` so
        that random actions and Q-table rows share the same size.

        Returns:
            A one-element tuple ``(action,)`` holding a Python ``int``.
        """
        state = self._state_key(agent_obs[:6])
        n_actions = self.env.action_space[self.agent_id].n

        if state not in self.q_table:
            self.q_table[state] = np.zeros(n_actions)

        if np.random.rand() < self.epsilon:
            # Explore: uniformly random action
            action = int(np.random.randint(n_actions))
        else:
            # Exploit: action with the highest Q-value
            action = int(np.argmax(self.q_table[state]))

        return (action,)  # Return as a tuple

    def update_q_table(self,  agent, env, agent_id, obs, action, reward, next_obs):
        """Apply one Q-learning update for the transition (obs, action, reward, next_obs).

        ``agent``, ``env`` and ``agent_id`` are unused and kept for interface
        compatibility. ``action`` may be a one-element tensor/array or a plain
        number.

        Raises:
            ValueError: if the agent's action space is not ``Discrete``.
        """
        obs_key = self._state_key(obs)
        next_obs_key = self._state_key(next_obs)
        action = int(action.item()) if hasattr(action, "item") else int(action)

        action_space = self.env.action_space[self.agent_id]
        if not isinstance(action_space, Discrete):
            raise ValueError("This Q-learning implementation requires a discrete action space.")
        action_space_size = action_space.n

        # Unseen states start with all-zero Q-values.
        for key in (obs_key, next_obs_key):
            if key not in self.q_table:
                self.q_table[key] = np.zeros(action_space_size)

        # TD target bootstraps from the best Q-value of the next state.
        td_target = reward + self.gamma * np.max(self.q_table[next_obs_key])
        td_error = td_target - self.q_table[obs_key][action]
        self.q_table[obs_key][action] += self.alpha * td_error

        print(f"Agent {self.agent_id} - Updated Q-table for obs {obs_key}, action {action}, reward {reward}, next_obs {next_obs_key}")

    def print_q_table(self):
        """Print every stored state together with the Q-value of each action."""
        print(f"Q-table for Agent {self.agent_id}:")
        for state, actions in self.q_table.items():
            print(f"  State: {state}")
            for action, q_value in enumerate(actions):
                print(f"    Action: {action}, Q-value: {q_value:.5f}")
        print(f"End of Q-table for Agent {self.agent_id}\n")

class Case:
    """A (problem, solution) pair with a trust value, plus the CBR cycle steps.

    The static methods implement retrieve / reuse / revise / retain on plain
    Python lists of ``Case`` objects.
    """

    # Record of every state tuple that ``retain`` has added to any case base.
    # Kept (and still updated) for backward compatibility only: ``retain`` no
    # longer consults it, because a single class-level set is shared by every
    # case base and made a state retained for one agent block all others.
    added_states = set()

    def __init__(self, problem, solution, trust_value=1):
        self.problem = self._as_problem_list(problem)
        self.solution = solution
        self.trust_value = trust_value

    @staticmethod
    def _as_problem_list(problem):
        """Normalise ``problem`` to a plain Python list.

        Lists are stored as given; NumPy arrays, tensors and tuples are
        converted; anything else (typically a string such as ``"[0.1, 0.2]"``)
        is parsed with ``ast.literal_eval`` as before.
        """
        if isinstance(problem, list):
            return problem
        if isinstance(problem, torch.Tensor):
            return problem.detach().cpu().numpy().tolist()
        if isinstance(problem, np.ndarray):
            return problem.tolist()
        if isinstance(problem, tuple):
            return list(problem)
        return ast.literal_eval(problem)

    @staticmethod
    def _state_of(case):
        """Return the case's problem as a flat, hashable tuple of numbers."""
        return tuple(np.ravel(case.problem).tolist())

    @staticmethod
    def sim_q(state1, state2):
        """Return the similarity of two states as ``1 - DistQ / (CNDMaxDist * v)``.

        ``v`` is the number of elements in ``state1`` and ``DistQ`` is the sum of
        the element-wise ``dist_q`` distances over the zipped states.
        """
        state1 = np.atleast_1d(state1)
        state2 = np.atleast_1d(state2)
        CNDMaxDist = 6  # Maximum distance between two nodes in the CND, based on the EOPRA reference
        v = state1.size  # Total number of objects the agent can perceive
        DistQ = np.sum([Case.dist_q(Objic, Objip) for Objic, Objip in zip(state1, state2)])
        similarity = (CNDMaxDist * v - DistQ) / (CNDMaxDist * v)
        return similarity

    @staticmethod
    def dist_q(X1, X2):
        """Return the smallest absolute difference between ``X1`` and ``X2``."""
        return np.min(np.abs(X1 - X2))

    @staticmethod
    def retrieve(agent, env, state, case_base, threshold=0.1):
        """Return the case most similar to ``state``, or ``None``.

        ``state`` may be a tensor, NumPy array or list. The first case with the
        highest ``sim_q`` score wins; it is returned only if that score is at
        least ``threshold``. ``agent`` and ``env`` are unused and kept for
        interface compatibility.
        """
        if isinstance(state, torch.Tensor):
            state = state.detach().cpu().numpy()
        # Convert once, outside the loop; every case is compared to the same state.
        state_numeric = np.array(state, dtype=float)

        best_case = None
        best_similarity = None
        for case in case_base:
            problem_numeric = np.array(case.problem, dtype=float)
            similarity = Case.sim_q(state_numeric, problem_numeric)
            # Strict '>' keeps the earliest case on ties.
            if best_similarity is None or similarity > best_similarity:
                best_case, best_similarity = case, similarity

        if best_case is None or not best_similarity >= threshold:
            return None
        return best_case

    @staticmethod
    def reuse(c, temporary_case_base):
        """Record the retrieved case ``c`` in the episode's temporary case base."""
        temporary_case_base.append(c)

    @staticmethod
    def revise(case_base, temporary_case_base, successfull_task):
        """Adjust trust of the temporary cases that are stored in ``case_base``.

        Each occurrence of a stored case in ``temporary_case_base`` moves its
        trust by +0.1 on success or -0.1 on failure. Trust of every temporary
        case is then clamped to [0, 1].
        """
        # Case defines no __eq__, so list membership is identity; a set of ids
        # gives the same answer in O(1).
        stored_ids = {id(stored) for stored in case_base}
        for case in temporary_case_base:
            if id(case) in stored_ids:
                case.trust_value += 0.1 if successfull_task else -0.1
            case.trust_value = max(0, min(case.trust_value, 1))  # Ensure trust value is within [0, 1]

    @staticmethod
    def retain(case_base, temporary_case_base, successfull_task, threshold=0.7):
        """Merge the temporary cases into ``case_base`` and drop low-trust cases.

        On success the temporary cases are visited newest first. A case whose
        state is not yet in ``case_base`` is appended (in place); otherwise the
        trust value of the first stored case with that state is overwritten by
        the temporary case's trust value. Whether a state is new is decided from
        the contents of ``case_base`` itself, so separate case bases do not
        interfere with each other.

        Returns:
            A new list with the cases of ``case_base`` whose trust value is at
            least ``threshold``.
        """
        if successfull_task:
            # First stored case per state, mirroring a front-to-back search.
            stored_by_state = {}
            for stored in case_base:
                stored_by_state.setdefault(Case._state_of(stored), stored)

            for case in reversed(temporary_case_base):
                state = Case._state_of(case)
                existing_case = stored_by_state.get(state)
                if existing_case is None:
                    case_base.append(case)
                    stored_by_state[state] = case
                    Case.added_states.add(state)
                else:
                    # Carry over the trust value produced by the revise step.
                    existing_case.trust_value = case.trust_value

        # Filter case_base based on trust_value
        return [case for case in case_base if case.trust_value >= threshold]


class QCBRLVmasRunner:
    def __init__(
        self,
        render: bool,
        num_envs: int,
        num_episodes: int,
        max_steps_per_episode: int,
        device: str,
        scenario: Union[str, BaseScenario],
        continuous_actions: bool,
        random_action: bool,
        n_agents: int,
        obs_discrete: bool = False,
        **kwargs
    ):
        self.render = render
        self.num_envs = num_envs
        self.num_episodes = num_episodes
        self.max_steps_per_episode = max_steps_per_episode
        self.device = device
        self.scenario = scenario
        self.continuous_actions = continuous_actions
        self.random_action = random_action
        self.obs_discrete = obs_discrete
        self.kwargs = kwargs
        self.frame_list = []  
        self.problem_solver_agents = []
        self.rewards_history = []  
        self.action_counts = {i: {} for i in range(n_agents)}  
        self.agent_rewards_history = {i: [] for i in range(n_agents)}
        self.successful_episodes_individual = {i: 0 for i in range(n_agents)}  # Track successful episodes for each agent
        self.successful_episodes_all_agents = 0  # Track episodes where all agents succeed
        self.case_base = {i: [] for i in range(n_agents)}  # Separate case base for each agent
        self.temporary_case_base = {i: [] for i in range(n_agents)}  # Separate temporary case base for each agent

    def discretize(self, data, bins):
        bins = np.array(bins)
        if np.isscalar(data):
            data = np.array([data])
        bin_indices = np.digitize(data, bins) - 1  # np.digitize returns indices starting from 1
        bin_indices = np.clip(bin_indices, 0, len(bins) - 1)  # Ensure indices are within the valid range
        bin_values = bins[bin_indices]
        bin_values = np.round(bin_values, 4)  
        return bin_indices, bin_values

    def discretize_tensor_slice(self, tensor_slice, bins):
        tensor_np = tensor_slice.cpu().numpy()  # Convert to numpy for easier handling
        indices, values = self.discretize(tensor_np, bins)
        indices = torch.tensor(indices, device=tensor_slice.device)
        values = torch.tensor(values, device=tensor_slice.device)
        return indices, values

    def _get_deterministic_obs(self, env, observation):
        pos_bins = np.linspace(-1, 1, num=50)
        vel_bins = np.linspace(0, 0, num=50)
        lidar_bins = np.linspace(0, 1, num=50)

        pos = observation[0:2]
        vel = observation[2:4]
        goal_pose = observation[4:6]
        comms_data = observation[6:13]
        sensor_data = observation[13:]

        discrete_pos_indices, discrete_pos_values = self.discretize_tensor_slice(pos, pos_bins)
        discrete_vel_indices, discrete_vel_values = self.discretize_tensor_slice(vel, vel_bins)
        discrete_goal_pose_indices, discrete_goal_pose_values = self.discretize_tensor_slice(goal_pose, pos_bins)
        discrete_sensor_data_indices, discrete_sensor_data_values = self.discretize_tensor_slice(sensor_data, lidar_bins)

        concatenated_tensor_values = torch.cat(
            [discrete_pos_values, discrete_vel_values, discrete_goal_pose_values, comms_data, discrete_sensor_data_values],
            dim=0
        )

        return concatenated_tensor_values

    def _get_deterministic_action(self, agent: Agent, env, agent_id, agent_obs):
        if self.continuous_actions:
            if agent.silent:
                action = torch.tensor([[-1, 0.5]], device=env.device)
            else:
                if agent_id == 0:
                    action = torch.tensor([[-1, 0.5, 2]], device=env.device)
                else:
                    action = torch.tensor([[-1, 0.5, 1]], device=env.device)
        else:
            physical_obs = agent_obs[0:6]

            if not agent.silent:
                comm_obs = agent_obs[6:]
            
            physical_action = self.problem_solver_agents[agent_id].get_action(agent, env, agent_id, physical_obs)
            physical_action_tensor = torch.tensor(physical_action, device=self.device)

            if agent.silent:
                action = physical_action_tensor
            else:
                physical_obs_tensor = torch.tensor(physical_obs, device=self.device)
                comm_action_tensor = torch.cat([physical_obs_tensor, physical_action_tensor], dim=0) 

                zero_tensor = torch.zeros(6, dtype=torch.float64, device=self.device)
                first_row = torch.cat((physical_action_tensor, zero_tensor))
                action = torch.stack((first_row, comm_action_tensor)).unsqueeze(0)

        return action


    @staticmethod
    def _case_to_record(case):
        """Convert one case into a JSON-serialisable dict.

        Tensors with more than one element become nested lists, one-element
        tensors become scalars. List/tuple solutions (e.g. from a case base
        reloaded from JSON) are kept as lists instead of being passed to
        ``int()``, which would raise.
        """
        problem = case.problem.tolist() if isinstance(case.problem, np.ndarray) else case.problem

        solution = case.solution
        if torch.is_tensor(solution):
            solution = solution.tolist() if solution.numel() > 1 else int(solution.item())
        elif isinstance(solution, np.ndarray):
            solution = solution.tolist()
        elif isinstance(solution, (list, tuple)):
            solution = list(solution)
        else:
            solution = int(solution)

        trust_value = case.trust_value
        if torch.is_tensor(trust_value):
            trust_value = trust_value.tolist() if trust_value.numel() > 1 else float(trust_value.item())
        else:
            trust_value = float(trust_value)

        return {
            "problem": problem,
            "solution": solution,
            "trust_value": trust_value
        }

    def _write_cases(self, cases, filename, message):
        """Serialise ``cases`` to ``filename`` as JSON and print ``message``."""
        # Build all records first so a serialisation error cannot leave a
        # truncated file behind.
        case_base_data = [self._case_to_record(case) for case in cases]

        with open(filename, 'w') as file:
            json.dump(case_base_data, file)

        print(message)

    def save_case_base(self, agent_id):
        """Write the agent's case base to ``case_base_<agent_id>.json``."""
        self._write_cases(
            self.case_base[agent_id],
            f"case_base_{agent_id}.json",
            "Case base saved successfully.",
        )

    def save_case_base_eps(self, agent_id, eps):
        """Write the agent's case base to ``case_base_<agent_id>_<eps>.json``."""
        self._write_cases(
            self.case_base[agent_id],
            f"case_base_{agent_id}_{eps}.json",
            "Case base saved successfully.",
        )

    def save_case_base_temporary(self, agent_id,):
        """Write the agent's temporary case base to ``case_base_temporary_<agent_id>.json``."""
        self._write_cases(
            self.temporary_case_base[agent_id],
            f"case_base_temporary_{agent_id}.json",
            "Temporary case base saved successfully.",
        )

    def save_case_base_temporary_eps(self, agent_id, eps):
        """Write the agent's temporary case base to ``case_base_temporary_<agent_id>_<eps>.json``."""
        self._write_cases(
            self.temporary_case_base[agent_id],
            f"case_base_temporary_{agent_id}_{eps}.json",
            "Temporary case base saved successfully.",
        )

        
    def load_case_base(self, agent_id):
        """Load the agent's case base from ``case_base_<agent_id>.json``.

        A missing file yields an empty case base. Problems are handed to
        ``Case`` exactly as decoded from JSON; wrapping them in ``np.array``
        made ``Case.__init__`` call ``ast.literal_eval`` on an ndarray, which
        raises ``ValueError``.

        NOTE(review): solutions are restored as plain ints / nested lists, while
        ``run_vmas_env`` indexes actions like tensors — confirm whether they
        need converting back before loaded cases are reused.
        """
        filename = f"case_base_{agent_id}.json"
        try:
            with open(filename, 'r') as file:
                case_base_data = json.load(file)
        except FileNotFoundError:
            self.case_base[agent_id] = []
            return

        self.case_base[agent_id] = [
            Case(
                problem=entry["problem"],
                solution=entry["solution"],
                trust_value=entry["trust_value"],
            )
            for entry in case_base_data
        ]


    def generate_gif(self, scenario_name):
        """Write the recorded frames to ``<scenario_name>.gif`` at 25 fps.

        Returns an ``HTML`` object containing an ``<img>`` tag that points at
        the written file.
        """
        frame_rate = 25
        animation = ImageSequenceClip(self.frame_list, fps=frame_rate)
        animation.write_gif(f'{scenario_name}.gif', fps=frame_rate)
        return HTML(f'<img src="{scenario_name}.gif">')

    def plot_action_distribution(self):
        """Draw one bar series per agent from ``self.action_counts`` and show it.

        NOTE(review): the bars are built from the unique *values* stored in each
        agent's dict and how often each value occurs — confirm that this is the
        intended "action frequency". Nothing in this file fills
        ``self.action_counts`` after ``__init__``.
        """
        for agent_id, counts in self.action_counts.items():
            distinct_values, occurrences = np.unique(list(counts.values()), return_counts=True)
            histogram = dict(zip(distinct_values, occurrences))
            plt.bar(histogram.keys(), histogram.values(), label=f'Agent {agent_id}', alpha=0.7)

        plt.title('Action Distribution for Each Agent')
        plt.xlabel('Action')
        plt.ylabel('Frequency')
        plt.legend()
        plt.show()

    def plot_rewards_history(self):
        """Plot every agent's total reward per episode and show the figure.

        The x-axis is derived from the number of rewards actually recorded for
        each agent rather than from ``self.num_episodes``, so the plot also
        works after an interrupted run or after several runs on one runner
        (where the two lengths differ and matplotlib would reject the data).
        """
        for agent_id, rewards in self.agent_rewards_history.items():
            episodes = range(1, len(rewards) + 1)
            plt.plot(episodes, rewards, label=f'Agent {agent_id}')

        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.title('Total Reward per Episode for Each Agent')
        plt.legend()
        plt.show()


    def run_vmas_env(self):
        """Run the training loop: case retrieval first, Q-learning as fallback.

        Per step and agent: retrieve a case for the first six observation
        entries; if one is found its solution is the action, otherwise the
        action comes from ``env.get_random_action`` or the problem solver. After
        ``env.step`` the agent's Q-table is updated. After each episode the
        success statistics are printed, the case bases are revised/retained and
        both case bases are written to disk; they are written once more after
        the last episode.

        When ``self.obs_discrete`` is False the raw observations are used
        unchanged (previously the discretised observation was simply undefined
        in that configuration).

        NOTE(review): a new ProblemSolver per agent is appended on every call,
        so calling this method twice on one runner grows
        ``self.problem_solver_agents`` — confirm whether re-runs are expected.
        """
        scenario_name = self.scenario if isinstance(self.scenario, str) else self.scenario.__class__.__name__

        env = make_env(
            scenario=self.scenario,
            num_envs=self.num_envs,
            device=self.device,
            continuous_actions=self.continuous_actions,
            **self.kwargs
        )

        def observe(raw_obs):
            # Discretise only when requested; otherwise pass the observation through.
            return self._get_deterministic_obs(env, raw_obs) if self.obs_discrete else raw_obs

        for agent_id, agent in enumerate(env.agents):
            self.problem_solver_agents.append(ProblemSolver(env, agent_id, communication_weight=0.5))

        init_time = time.time()
        total_steps = 0

        for episode in range(self.num_episodes):
            print(f"Episode {episode}")
            obs_cont = env.reset()

            done = torch.tensor([False] * self.num_envs, device=self.device)
            step = 0

            episode_rewards = {i: 0 for i in range(len(self.problem_solver_agents))}
            episode_done_counts = {i: 0 for i in range(len(env.agents))}  # Track done counts for each agent

            self.temporary_case_base = {i: [] for i in range(len(env.agents))}

            while not torch.all(done).item() and step < self.max_steps_per_episode:
                step += 1
                total_steps += 1
                print(f"Step {step} of Episode {episode}")

                actions = []
                step_obs = []  # Observation each agent acted on, reused for the Q-update

                for i, agent in enumerate(env.agents):
                    discrete_obs = observe(obs_cont[i])
                    step_obs.append(discrete_obs)

                    case = Case.retrieve(agent, env, discrete_obs[0:6], self.case_base[i], threshold=0.1)

                    if case is not None:
                        action = case.solution
                        Case.reuse(case, self.temporary_case_base[i])
                        print(f"action type of agent {i}: case base")
                    else:
                        if self.random_action:
                            action = env.get_random_action(agent)
                        else:
                            action = self._get_deterministic_action(agent, env, i, discrete_obs)
                        print(f"action type of agent {i}: problem solver")

                    # Every step is also recorded as a fresh case in the temporary case base.
                    problem = discrete_obs[0:6].cpu().numpy().tolist()
                    new_case = Case(problem, action)
                    self.temporary_case_base[i].append(new_case)

                    actions.append(action)

                next_obs_cont, rews, dones, info = env.step(actions)

                done = dones

                for i, agent in enumerate(env.agents):
                    discrete_obs = step_obs[i]
                    discrete_next_obs = observe(next_obs_cont[i])

                    physical_obs_for_update = discrete_obs[0:6]
                    physical_nextobs_for_update = discrete_next_obs[0:6]
                    # Assumes a 3-D action tensor whose [0, 0, 0] entry is the
                    # physical action — TODO confirm for silent agents and for
                    # random / case-base actions.
                    physical_actions_for_update = actions[i][0, 0, 0].unsqueeze(0)

                    self.problem_solver_agents[i].update_q_table(agent, env, i,
                        physical_obs_for_update, physical_actions_for_update, rews[i].item(), physical_nextobs_for_update
                    )
                    # Accumulate rewards for each agent within the episode
                    episode_rewards[i] += rews[i].item()

                    if (done[0][i]):  # Increment individual agent's done count
                        episode_done_counts[i] += 1

                obs_cont = next_obs_cont

                if self.render:
                    frame = env.render(
                        mode="rgb_array",
                        agent_index_focus=None,
                    )
                    self.frame_list.append(frame)

            # Update rewards history after each episode
            for agent_id, total_reward in episode_rewards.items():
                self.agent_rewards_history[agent_id].append(total_reward)

            # Update success based on individual agent's done status
            for agent_id in range(len(env.agents)):
                print(f"done status for agent {agent_id}: {done[0][agent_id]}")
                if done[0,agent_id]:
                    self.successful_episodes_individual[agent_id] += episode_done_counts[agent_id] / step  # Calculate success rate for the agent

            # Update success based on all agents' done status
            if torch.all(done).item():
                self.successful_episodes_all_agents += 1

            # Calculate success percentages for each agent
            success_percents = [self.successful_episodes_individual[i] / (episode + 1) * 100 for i in range(len(env.agents))]
            overall_success_percent = self.successful_episodes_all_agents / (episode + 1) * 100

            print(f"Success percentage of each agent at the end of episode {episode}:")

            for i, percent in enumerate(success_percents):
                print(f"Agent {i}: {percent}%")

            print(f"Overall success percentage for all agents up to episode {episode}: {overall_success_percent}%")

            for i in range(len(env.agents)):
                print(f"done status for agent {i}: {done[0][i]}")

                # Each agent's case base is revised with that agent's own done flag.
                Case.revise(self.case_base[i], self.temporary_case_base[i], done[0][i])
                self.case_base[i] = Case.retain(
                    self.case_base[i], self.temporary_case_base[i], done[0][i]
                )

                self.save_case_base_temporary_eps(i, episode)  # Save temporary case base after each episode
                self.save_case_base_eps(i, episode)  # Save case base after each episode

            torch.cuda.empty_cache()  # Free up unused memory

        for agent_id, agent in enumerate(env.agents):
            self.save_case_base_temporary(agent_id)  # Save temporary case base after training
            self.save_case_base(agent_id)  # Save case base after training

        # Print overall success percentage for all agents (0 when no episode was run)
        if self.num_episodes:
            overall_success_percentage = self.successful_episodes_all_agents / self.num_episodes * 100
        else:
            overall_success_percentage = 0.0
        print(f"Overall success percentage for all agents = {overall_success_percentage}%")

        total_time = time.time() - init_time
        print(
            f"It took: {total_time}s for {total_steps} steps across {self.num_episodes} episodes of {self.num_envs} parallel environments on device {self.device} "
            f"for {scenario_name} scenario."
        )


if __name__ == "__main__":
    # Demo run: two agents on the "navigation_comm" scenario with discretised
    # observations, followed by a reward plot and a GIF of the rendered frames.
    scenario_name = "navigation_comm"
    # Use the GPU only when one is actually available, so the script also runs
    # on CPU-only machines instead of failing on a hard-coded "cuda" device.
    use_cuda = torch.cuda.is_available()

    env_runner = QCBRLVmasRunner(
        render=True,
        num_envs=1,
        num_episodes=3,
        max_steps_per_episode=5,
        device=torch.device("cuda" if use_cuda else "cpu"),
        scenario=scenario_name,
        continuous_actions=False,
        random_action=False,
        n_agents=2,
        obs_discrete=True,
        # Remaining keyword arguments are forwarded to make_env via **kwargs.
        agents_with_same_goal=2,
        collisions=False,
        shared_rew=False,
    )

    env_runner.run_vmas_env()
    # To inspect the learned Q-tables, call print_q_table() on each entry of
    # env_runner.problem_solver_agents.
    env_runner.plot_rewards_history()

    ipython_display(env_runner.generate_gif(scenario_name))


  import distutils.spawn


Episode 0
Step 1 of Episode 0
action type of agent 0: problem solver
action type of agent 1: problem solver


RuntimeError: Tensors must have same number of dimensions: got 4 and 5