In [1]:
import numpy as np
import json
import random
import matplotlib.pyplot as plt
import psutil
import pynvml
from collections import Counter
from collections import defaultdict
import ast
from environment_ma import Env

class ProblemSolver:
    def __init__(self, env, agent_id, alpha=0.1, gamma=0.99, epsilon=0.2, communication_weight=0.5):
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.q_table = {}
        self.env = env
        self.agent_id = agent_id
        self.communication_weight = communication_weight  # Weight parameter for incorporating messages

    def get_action(self, agent, env, agent_id, agent_obs):
        agent_obs_cpu = agent_obs[:6].cpu().numpy()  # Transfer only the required slice to CPU
        agent_obs = tuple(np.round(agent_obs_cpu, decimals=5))  # Round the observation

        if agent_obs not in self.q_table:
            self.q_table[agent_obs] = np.zeros(self.env.action_space[self.agent_id].n)

        if np.random.rand() < self.epsilon:
            # Select a random action
            action = np.random.randint(env.action_space[self.agent_id].n)
        else:
            # Select the action with the highest Q-value
            action = np.argmax(self.q_table[agent_obs])
        
        return (action,)  # Return as a tuple

    def update_q_table(self,  agent, env, agent_id, obs, action, reward, next_obs):
        obs_key = tuple(np.round(obs.cpu().numpy(), decimals=5))  # Only transfer to CPU when necessary
        next_obs_key = tuple(np.round(next_obs.cpu().numpy(), decimals=5))
        action = int(action.item())  # Convert tensor to Python scalar

        # print (f"reward obtained = {reward}")

       
        action_space_size = self.env.action_space[self.agent_id].n
        
        if obs_key not in self.q_table:
            self.q_table[obs_key] = np.zeros(action_space_size)

        if next_obs_key not in self.q_table:
            self.q_table[next_obs_key] = np.zeros(action_space_size)

        best_next_action = np.argmax(self.q_table[next_obs_key])
        td_target = reward + self.gamma * self.q_table[next_obs_key][best_next_action]

        td_error = td_target - self.q_table[obs_key][action]
        self.q_table[obs_key][action] += self.alpha * td_error

        print(f"Agent {self.agent_id} - Updated Q-table for obs {obs_key}, action {action}, reward {reward}, next_obs {next_obs_key}")

    
    def print_q_table(self):
        print(f"Q-table for Agent {self.agent_id}:")
        for state, actions in self.q_table.items():
            print(f"  State: {state}")
            for action, q_value in enumerate(actions):
                print(f"    Action: {action}, Q-value: {q_value:.5f}")
        print(f"End of Q-table for Agent {self.agent_id}\n")




class Case:
    added_states = set()  # Class attribute to store states already added to the case base

    def __init__(self, problem, solution, trust_value=1):
        self.problem = problem if isinstance(problem, list) else ast.literal_eval(problem)  # Convert problem to numpy array
        self.solution = solution
        self.trust_value = trust_value
    
    @staticmethod
    def sim_q(state1, state2):
        state1 = np.atleast_1d(state1)
        state2 = np.atleast_1d(state2)
        CNDMaxDist = 6  # Maximum distance between two nodes in the-0.9, 0.7, 0.0, 0.0, 0.0, -0.9) CND based on EOPRA reference
        v = state1.size  # Total number of objects the agent can perceive
        DistQ = np.sum([Case.dist_q(Objic, Objip) for Objic, Objip in zip(state1, state2)])
        similarity = (CNDMaxDist * v - DistQ) / (CNDMaxDist * v)
        return similarity

    @staticmethod
    def dist_q(X1, X2):
        return np.min(np.abs(X1 - X2))

    @staticmethod
    def retrieve(agent, env, state, case_base, threshold=0.1):

        # Slice the physical observations
        physical_obs = state[:6]

        if not agent.silent:
            comm_obs = state[6:]
            
        # print(f"physical_obs = {physical_obs}")

        # Ensure the state is in a list format to avoid issues with ast.literal_eval
        state_list = state.tolist() if isinstance(state, np.ndarray) else state
        state_str = json.dumps(state_list)  # Convert list to a JSON string for ast.literal_eval

        # Use ast.literal_eval safely to convert the string back to a list
        state = ast.literal_eval(state_str)

        similarities = {}
        for case in case_base:
            problem_numeric = np.array(case.problem, dtype=float)
            state_numeric = np.array(state, dtype=float)
            
            # print(f"state received = {state_numeric}")
            # print(f"case received = {problem_numeric}")
            # print("---------")
           
            similarities[case] = Case.sim_q(state_numeric, problem_numeric)  # Compare state with the problem part of the case

        sorted_similarities = sorted(similarities.items(), key=lambda x: x[1], reverse=True)

        if sorted_similarities:
            most_similar_case = sorted_similarities[0][0] if sorted_similarities[0][1] >= threshold else None
        else:
            most_similar_case = None

        return most_similar_case


    @staticmethod
    def reuse(c, temporary_case_base):
        temporary_case_base.append(c)

    @staticmethod
    def revise(case_base, temporary_case_base, successfull_task):
        for case in temporary_case_base:
            if successfull_task and case in case_base:
                case.trust_value += 0.1  # Increment trust value if the episode ended successfully and the case is in the case base
            elif not successfull_task and case in case_base:
                case.trust_value -= 0.1  # Decrement trust value if the episode ended unsuccessfully and the case is in the case base
            case.trust_value = max(0, min(case.trust_value, 1))  # Ensure trust value is within [0, 1]

    @staticmethod
    def retain(case_base, temporary_case_base, successfull_task, threshold=0.7):
        if successfull_task:
            # Iterate through the temporary case base to find the last occurrence of each unique state
            for case in reversed(temporary_case_base):
                state = tuple(np.atleast_1d(case.problem))
                # Check if the state is already in the case base or has been added previously
                if state not in Case.added_states:
                    # Add the case to the case base if the state is new
                    case_base.append(case)
                    Case.added_states.add(state)
                else:
                    # Find the index of the existing case in the case base
                    existing_index = next((i for i, c in enumerate(case_base) if tuple(np.atleast_1d(c.problem)) == state), None)
                    if existing_index is not None:
                        # Get the existing case from the case base
                        existing_case = case_base[existing_index]
                        # Update the trust value of the existing case with the new value from the revise step
                        existing_case.trust_value = case.trust_value

        # Filter case_base based on trust_value
        case_base = [case for case in case_base if case.trust_value >= threshold]
        return case_base



class QCBRL:
    def __init__(self, num_agents, is_agent_silent, num_actions, max_steps_per_episode):
        self.actions = num_actions
        self.max_steps_per_episode = max_steps_per_episode
        self.num_agents = num_agents
        self.is_agent_silent = is_agent_silent
        self.case_base = {i: [] for i in range(self.num_agents)}  # Separate case base for each agent
        self.temporary_case_base = {i: [] for i in range(self.num_agents)}  # Separate temporary case base for each agent

        
        self.learning_rate = 0.01
        self.discount_factor = 0.9
        self.epsilon = 0.1
        self.q_tables = [defaultdict(lambda: [0.0] * len(self.actions)) for _ in range(self.num_agents)]
        self.rewards_per_episode = [[] for _ in range(self.num_agents)]  # Initialize rewards list for each agent
        self.successful_episodes = [0] * self.num_agents  # Initialize successful episode count for each agent
        self.total_successful_episodes = 0  # Initialize total successful episode count for all agents
        self.problem_solver_agents = []

    def run(self, episodes, max_steps, alpha=0.1, gamma=0.9, epsilon=0.1, render=False):

        env = Env(num_agents=num_agents, is_agent_silent=is_agent_silent)
        
        rewards = []
        # episode_rewards = []
        memory_usage = []
        gpu_memory_usage = []
        successful_episodes = False
        num_successful_episodes = 0
        total_steps_list = []

        success_steps = []

        for agent_id, agent in enumerate(env.agents):
            self.problem_solver_agents.append(ProblemSolver(env, agent_id, communication_weight=0.5))


        for episode in range(episodes):
            states = env.reset()
            self.temporary_case_base = []

            total_reward = [0] * self.num_agents  # Initialize total reward for each agent
            step_count = 0
            dones = [False] * self.num_agents
            win_states = []
            success_episode = False
            success_count = [0] * self.num_agents  # Track successful episodes for each agent
            
            self.temporary_case_base = {i: [] for i in range(len(env.agents))}
            
            while not (all(dones) and all(win_states)) and self.max_steps_per_episode:
                combination_actions = []

                for agent_idx in range(self.num_agents):
                    state = states[agent_idx]
                    physical_action, comm_action = self.take_action(agent_idx, state, self.case_base[agent_idx])
                    combination_actions.append((physical_action, comm_action))
                    # print(f"state agents {agent_idx}: {state}")

                    new_case = Case(state, physical_action)
                    Case.reuse(new_case, self.temporary_case_base[agent_idx])

                next_states, rewards, dones = self.env.step(combination_actions)
                

                win_states = []
                for agent_idx in range(self.num_agents):
                    state = states[agent_idx]
                    physical_action = combination_actions[agent_idx][0]
                    comm_action = combination_actions[agent_idx][1]
                    reward = rewards[agent_idx]
                    next_state = next_states[agent_idx]
                    win_state = next_state[1]  

                    self.learn(agent_idx, state, physical_action, reward, next_state, comm_action)
                    self.problem_solver_agents[agent_idx].learn(agent_idx, state, physical_action, reward, next_state)

                    total_reward[agent_idx] += reward
                    # states[agent_idx] = next_state

                    # Check if agent reached target and mark episode as successful
                    if (win_state):  # Use circle's coordinates for target
                        success_count[agent_idx] += 1
                        print(f"agent{agent_idx} hit !!!!!")
                    else:
                        print(f"agent{agent_idx} not hit !!!!!")

                    # print(f"Reward for agent {agent_idx}: {reward}")
                    # print(f"next state for agent {agent_idx}: {next_state}")
                    # print(f"target: {env.get_circle_grid_position()}")
                    
                    
                    win_states.append(win_state)
                
                states = next_states  

                
                # print(f"State: {state} - Action: {action} - Reward: {reward} - Done: {done}")


            if all(dones) and all(win_states):
                self.total_successful_episodes += 1
                success_steps.append(step_count)
                success_episode = True

            for agent_idx in range(self.env.num_agents):
                self.rewards_per_episode[agent_idx].append(total_reward[agent_idx])
                print(f"Agent {agent_idx} Hit Count: {success_count[agent_idx]}")
                print(f"success hit rate for agent {agent_idx} at episode {episode}: {success_count[agent_idx]/step_count*100}%")

                Case.revise(self.case_base[agent_idx], self.temporary_case_base[agent_idx], success_episode)
                self.case_base[agent_idx] = Case.retain(
                    self.case_base[agent_idx], self.temporary_case_base[agent_idx], success_episode
                )

                self.save_case_base_temporary_eps(agent_idx, episode)  # Save temporary case base after each episode
                self.save_case_base_eps(agent_idx, episode)  # Save case base after each episode

            print(f"Episode: {episode + 1}, Total Steps: {step_count}, Total Rewards: {total_reward}, Status Episode: {success_episode}")
            print("--------------------")


        
        for agent_id, agent in enumerate(env.agents):
            self.save_case_base_temporary(agent_id)  # Save temporary case base after training
            self.save_case_base(agent_id)  # Save case base after training


        success_rate = (num_successful_episodes / episodes) * 100
        # print(f"Successful episodes: {num_successful_episodes}%")

        # env.close()
        return rewards, success_rate, memory_usage, gpu_memory_usage, total_steps_list

    def take_action(self, agent_idx, state, case_base):

        # print(f"agent: {agent}")

        case = Case.retrieve(agent_idx, state, case_base)
                    
        if case:
            physical_action = case.solution
            print(f"action type of agent {agent_idx}: case base")
        else:
            physical_action = self.get_action(agent_idx, state)
            print(f"action type of agent {agent_idx}: problem solver")
        
        if self.env.is_agent_silent:
            communication_action = None
        else:
            communication_action = f"Message send from agent {agent_idx}"
        
        
        return (physical_action, communication_action)

    
    def save_case_base_temporary(self, agent_id,):
        filename = f"case_base_temporary_{agent_id}.json"
        case_base_data = []
        for case in self.temporary_case_base[agent_id]:
            problem = case.problem.tolist() if isinstance(case.problem, np.ndarray) else case.problem
            solution = int(case.solution)
            trust_value = float(case.trust_value)
            
            case_base_data.append({
                "problem": problem,
                "solution": solution,
                "trust_value": trust_value
            })
        
        with open(filename, 'w') as file:
            json.dump(case_base_data, file)

        print("Temporary case base saved successfully.")

    
    
    def save_case_base(self, agent_id):
        filename = f"case_base_{agent_id}.json"
        case_base_data = []
        for case in self.case_base[agent_id]:
            problem = case.problem.tolist() if isinstance(case.problem, np.ndarray) else case.problem
            solution = int(case.solution)
            trust_value = float(case.trust_value)
            
            case_base_data.append({
                "problem": problem,
                "solution": solution,
                "trust_value": trust_value
            })
        
        with open(filename, 'w') as file:
            json.dump(case_base_data, file)

        print("Case base saved successfully.")


    def save_case_base_temporary_eps(self, agent_id, eps):
        filename = f"case_base_temporary_{agent_id}_{eps}.json"
        case_base_data = []
        for case in self.temporary_case_base[agent_id]:
            problem = case.problem.tolist() if isinstance(case.problem, np.ndarray) else case.problem
            solution = int(case.solution)
            trust_value = float(case.trust_value)
            
            case_base_data.append({
                "problem": problem,
                "solution": solution,
                "trust_value": trust_value
            })
        
        with open(filename, 'w') as file:
            json.dump(case_base_data, file)

        print("Temporary case base saved successfully.")



    def save_case_base_eps(self, agent_id, eps):
        filename = f"case_base_{agent_id}_{eps}.json"
        case_base_data = []
        for case in self.case_base[agent_id]:
            problem = case.problem.tolist() if isinstance(case.problem, np.ndarray) else case.problem
            solution = int(case.solution)
            trust_value = float(case.trust_value)
            
            case_base_data.append({
                "problem": problem,
                "solution": solution,
                "trust_value": trust_value
            })
        
        with open(filename, 'w') as file:
            json.dump(case_base_data, file)

        print("Case base saved successfully.")
        
    
    def load_case_base(self, agent_id):
        filename = f"case_base_{agent_id}.json"
        try:
            with open(filename, 'r') as file:
                case_base_data = json.load(file)
            self.case_base[agent_id] = [Case(problem=np.array(case["problem"]) if isinstance(case["problem"], list) else case["problem"],
                                            solution=case["solution"],
                                            trust_value=case["trust_value"]) for case in case_base_data]
        except FileNotFoundError:
            self.case_base[agent_id] = []

    
    def display_success_rate(self, success_rate):
        print(f"Success rate: {success_rate}%")


    def plot_rewards(self, rewards):
        plt.plot(rewards)
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.title('Rewards over Episodes')
        plt.grid(True)
        plt.show() 

    def plot_resources(self, memory_usage, gpu_memory_usage):
        plt.plot(memory_usage, label='Memory (%)')
        plt.plot(gpu_memory_usage, label='GPU Memory (MB)')
        plt.xlabel('Episode')
        plt.ylabel('Resource Usage')
        plt.title('Resource Usage over Episodes')
        plt.legend()
        plt.grid(True)
        plt.show()

    def plot_total_steps(self, total_steps_list):
        plt.plot(total_steps_list)
        plt.xlabel('Episode')
        plt.ylabel('Total Steps')
        plt.title('Total Steps for Successful Episodes over Episodes')
        plt.grid(True)
        plt.show() 


if __name__ == "__main__":
    num_agents = 1
    is_agent_silent = False  # Set to True or False to enable/disable communication
    num_actions = list(range(5)) + [(i, 'send') for i in range(5)]
    max_steps_per_episode = 100
    agent = QCBRL(num_agents, is_agent_silent, num_actions, max_steps_per_episode)

    rewards, success_rate, memory_usage, gpu_memory_usage, total_step_list = agent.run(episodes=50, max_steps=300, alpha=0.1, gamma=0.9, epsilon=0.1)


    agent.display_success_rate(success_rate)
    agent.plot_rewards(rewards)
    agent.plot_total_steps(total_step_list)
    agent.plot_resources(memory_usage, gpu_memory_usage)

TypeError: Case.retrieve() missing 1 required positional argument: 'case_base'