In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import json
import random
import matplotlib.pyplot as plt
import psutil
import pynvml
from collections import Counter
from collections import defaultdict
import ast
# from gym.envs.registration import register
from environment_static import Env

# Seed every random number generator this module imports so repeated runs
# start from the same state. The standard-library `random` module is imported
# above but was previously left unseeded.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Actor-critic network architecture
class ActorCritic(nn.Module):
    """Shared two-layer ReLU trunk feeding an action-logit head and a value head."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 64
        # Layers are created in a fixed order (fc1, fc2, actor, critic) so a
        # seeded run initialises the same weights every time.
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.actor = nn.Linear(hidden_units, action_dim)
        self.critic = nn.Linear(hidden_units, 1)

    def forward(self, state):
        """Return ``(logits, value)`` for a batch of states.

        ``logits`` are the raw (un-normalised) actor outputs with one column
        per action; ``value`` is the critic output with a single column.
        """
        hidden = self.fc1(state).relu()
        hidden = self.fc2(hidden).relu()
        return self.actor(hidden), self.critic(hidden)
    
class ProblemSolver:
    """Actor-critic learner that samples actions and applies a clipped policy update.

    Attributes:
        policy: network returning ``(logits, value)`` for a batch of states.
        learning_rate: step size handed to the Adam optimizer.
        optimizer: Adam optimizer over ``policy.parameters()``.
        gamma: discount factor used by ``calculate_advantages``.
        eps_clip: half-width of the ratio clipping interval in ``update_policy``.
        update_timestep: stored for callers; not read inside this class.
    """

    def __init__(self, num_states, num_actions, learning_rate=0.0005,
                 gamma=0.99, eps_clip=0.2, update_timestep=20):
        """Build the network and optimizer.

        Args:
            num_states: input width of the network.
            num_actions: number of discrete actions (actor output width).
            learning_rate: Adam learning rate.
            gamma: reward discount factor.
            eps_clip: clipping range for the probability ratio.
            update_timestep: value exposed as ``self.update_timestep``.
        """
        self.policy = ActorCritic(num_states, num_actions)
        self.learning_rate = learning_rate
        self.optimizer = optim.Adam(self.policy.parameters(), lr=self.learning_rate)
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.update_timestep = update_timestep

    def choose_action(self, state):
        """Sample one action index from the policy's softmax distribution.

        Args:
            state: a single state as a flat sequence of numbers.

        Returns:
            The sampled action as a Python ``int``.
        """
        state = torch.FloatTensor(state).unsqueeze(0)
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            logits, _ = self.policy(state)
        action_probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(action_probs, 1).item()

    def calculate_advantages(self, rewards, dones, values, next_value):
        """Return per-step advantages as discounted return minus value estimate.

        Args:
            rewards: per-step rewards.
            dones: per-step termination flags; a truthy flag stops the
                discounted sum from flowing across that step.
            values: per-step value estimates; only the first ``len(rewards)``
                entries are read, each must hold a single number.
            next_value: accepted for interface compatibility; not used, so
                the return after the last step is not bootstrapped.

        Returns:
            A list of floats, one advantage per step, in step order.
        """
        advantages = []
        discounted_sum = 0
        # Walk backwards so each step can reuse the return of its successor,
        # appending and reversing once instead of inserting at the front.
        for i in reversed(range(len(rewards))):
            discounted_sum = rewards[i] + self.gamma * discounted_sum * (1 - dones[i])
            advantages.append(discounted_sum - float(values[i]))
        advantages.reverse()
        return advantages

    def update_policy(self, states, actions, advantages, returns):
        """Take one optimizer step on the clipped surrogate plus value loss.

        Args:
            states: batch of states, one row per step.
            actions: action index taken at each step.
            advantages: advantage estimate for each step.
            returns: discounted return for each step (critic target).

        NOTE: the probabilities that were in effect when the actions were
        sampled are not passed in, so the "old" log-probabilities are the
        detached current ones. With a single gradient step per batch the ratio
        equals 1 and the update is a plain advantage-weighted policy gradient;
        the clipping only matters if this is extended to several steps per
        batch with stored old log-probabilities.
        """
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        advantages = torch.FloatTensor(advantages)
        returns = torch.FloatTensor(returns)

        logits, values = self.policy(states)
        # Drop only the trailing singleton so a one-sample batch stays 1-D.
        values = values.squeeze(-1)

        # Log-probability of the action actually taken at each step.
        log_probs = torch.log_softmax(logits, dim=-1)
        taken_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
        old_log_probs = taken_log_probs.detach()
        ratios = torch.exp(taken_log_probs - old_log_probs)

        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()

        critic_loss = nn.functional.mse_loss(values, returns)

        loss = actor_loss + 0.5 * critic_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

class Case:
    """One case-base entry: a problem (state), its solution (action) and a trust value.

    The static methods implement the retrieve / reuse / revise / retain steps
    over plain lists of ``Case`` objects.
    """

    # Record of every state key that `retain` has stored. Kept as a class
    # attribute for backward compatibility; `retain` no longer consults it,
    # because a process-wide set blocked re-adding states to a fresh or
    # filtered case base.
    added_states = set()

    def __init__(self, problem, solution, trust_value=1):
        """Store the case.

        Args:
            problem: the state, given as a list, a string holding a Python
                literal (e.g. ``"[0, 1]"``), a NumPy array or a tuple.
            solution: the action associated with the state.
            trust_value: initial trust in this case.
        """
        self.problem = Case._as_list(problem)
        self.solution = solution
        self.trust_value = trust_value

    def __repr__(self):
        return (f"Case(problem={self.problem!r}, solution={self.solution!r}, "
                f"trust_value={self.trust_value!r})")

    @staticmethod
    def _as_list(problem):
        """Normalise a state given as list, literal string, ndarray or tuple."""
        if isinstance(problem, list):
            return problem
        if isinstance(problem, str):
            # literal_eval only parses Python literals; it does not execute code.
            return ast.literal_eval(problem)
        if isinstance(problem, np.ndarray):
            return problem.tolist()
        if isinstance(problem, tuple):
            return list(problem)
        return problem

    @staticmethod
    def _state_key(problem):
        """Return a hashable, flattened tuple identifying a state."""
        return tuple(np.atleast_1d(problem).ravel().tolist())

    @staticmethod
    def _case_key(case):
        """Return a hashable key identifying a (state, solution) pair."""
        return Case._state_key(case.problem), case.solution

    @staticmethod
    def sim_q(state1, state2, max_dist=6):
        """Return the similarity of two states.

        Computed as ``(max_dist * v - dist) / (max_dist * v)``, where ``v`` is
        the number of elements in ``state1`` and ``dist`` is the sum of
        ``Dmin_phi`` over paired entries of the two states.

        Args:
            state1: first state (scalar or array-like).
            state2: second state (scalar or array-like).
            max_dist: maximum distance between two nodes in the CND; the
                previously hard-coded value 6 is the default.
        """
        state1 = np.atleast_1d(state1)
        state2 = np.atleast_1d(state2)
        v = state1.size  # total number of objects the agent can perceive
        dist_q = np.sum([Case.Dmin_phi(obj_c, obj_p) for obj_c, obj_p in zip(state1, state2)])
        return (max_dist * v - dist_q) / (max_dist * v)

    @staticmethod
    def Dmin_phi(X1, X2):
        """Return the largest absolute element-wise difference between X1 and X2."""
        return np.max(np.abs(X1 - X2))

    @staticmethod
    def retrieve(state, case_base, threshold=0.2):
        """Return the stored case most similar to ``state``, or ``None``.

        Args:
            state: the query state, as a literal string (e.g. ``"[0, 1]"``),
                list, tuple or NumPy array.
            case_base: list of ``Case`` objects to search.
            threshold: minimum similarity the best case must reach.

        Returns:
            The first case with the highest similarity if that similarity is
            at least ``threshold``; otherwise ``None`` (also for an empty
            case base).
        """
        state = Case._as_list(state)
        print(f"state to be measured: {state}")
        print(f"Case base to be measured: {case_base}")

        # The query does not change between cases, so convert it once.
        state_numeric = np.array(state, dtype=float)
        best_case = None
        best_similarity = None
        for case in case_base:
            problem_numeric = np.array(case.problem, dtype=float)
            similarity = Case.sim_q(state_numeric, problem_numeric)
            if np.isnan(similarity):
                continue  # an undefined similarity can never reach the threshold
            # Strict '>' keeps the earliest case on ties.
            if best_similarity is None or similarity > best_similarity:
                best_case, best_similarity = case, similarity

        if best_case is not None and best_similarity >= threshold:
            return best_case
        return None

    @staticmethod
    def reuse(c, temporary_case_base):
        """Append case ``c`` to the episode's temporary case base."""
        temporary_case_base.append(c)

    @staticmethod
    def revise(case_base, temporary_case_base, successful_episodes):
        """Adjust the trust of stored cases that were used during an episode.

        A stored case counts as used when a temporary case has the same state
        and the same solution. Matching is by content rather than object
        identity, because the episode's cases are usually freshly built
        objects that are never the stored objects themselves. Each match moves
        the stored case's trust by +0.1 (successful episode) or -0.1
        (unsuccessful episode), clamped to [0, 1].

        Args:
            case_base: stored cases whose trust may change.
            temporary_case_base: cases recorded during the episode.
            successful_episodes: truthy if the episode ended successfully.
        """
        delta = 0.1 if successful_episodes else -0.1

        stored_by_key = {}
        for stored_case in case_base:
            stored_by_key.setdefault(Case._case_key(stored_case), []).append(stored_case)

        for case in temporary_case_base:
            for match in stored_by_key.get(Case._case_key(case), ()):
                match.trust_value = max(0, min(match.trust_value + delta, 1))
            # Keep the episode's own case within [0, 1] as well.
            case.trust_value = max(0, min(case.trust_value, 1))

    @staticmethod
    def retain(case_base, temporary_case_base, successful_episodes, threshold=0):
        """Add new states from a successful episode and drop low-trust cases.

        Args:
            case_base: stored cases; new cases are appended to this list in place.
            temporary_case_base: cases recorded during the episode.
            successful_episodes: truthy if the episode ended successfully.
            threshold: minimum trust a case needs to stay in the result.

        Returns:
            ``case_base`` itself, untouched, after an unsuccessful episode.
            After a successful one, a new list holding every case of the
            extended ``case_base`` whose trust is at least ``threshold``.
        """
        if not successful_episodes:
            return case_base

        # De-duplicate against what this case base actually holds.
        known_states = {Case._state_key(case.problem) for case in case_base}

        # Iterating in reverse keeps the last occurrence of each new state.
        for case in reversed(temporary_case_base):
            state = Case._state_key(case.problem)
            if state not in known_states:
                case_base.append(case)
                known_states.add(state)
                Case.added_states.add(state)

        return [case for case in case_base if case.trust_value >= threshold]

            
class QCBRL:
    """Agent that acts from a case base when possible and from a learned policy otherwise.

    Attributes:
        num_states: input width passed to the problem solver.
        num_actions: number of discrete actions.
        env: environment; ``reset()`` returns a state and ``step(action)``
            returns ``(next_state, reward, done)``.
        problem_solver: the ``ProblemSolver`` used when no case matches.
        case_base: retained cases.
        temporary_case_base: cases recorded during the current episode.
    """

    def __init__(self, num_states, num_actions, env):
        self.num_states = num_states
        self.num_actions = num_actions
        self.env = env
        self.problem_solver = ProblemSolver(num_states, num_actions)
        self.case_base = []
        self.temporary_case_base = []

    def run(self, episodes, max_steps, alpha=0.1, gamma=0.9, epsilon=0.1, render=False):
        """Train for ``episodes`` episodes and save both case bases to JSON.

        Args:
            episodes: number of episodes to run.
            max_steps: maximum number of environment steps per episode.
            alpha: unused; kept for interface compatibility.
            gamma: unused; the problem solver's own ``gamma`` is applied.
            epsilon: forwarded to ``take_action``, which ignores it.
            render: unused; kept for interface compatibility.

        Returns:
            ``(episode_rewards, success_rate, memory_usage, gpu_memory_usage,
            total_steps_list)``. ``success_rate`` is the percentage of
            episodes with a positive total reward. ``total_steps_list`` holds
            the step count of each such episode and 0 for the others.
            ``gpu_memory_usage`` holds 0.0 for every episode when NVML cannot
            be initialised.
        """
        memory_usage = []
        gpu_memory_usage = []
        total_steps_list = []
        episode_rewards = []
        num_successful_episodes = 0

        gpu_handle = self._open_gpu_monitor()
        try:
            for episode in range(episodes):
                state = self.env.reset()
                states, actions, rewards, dones, next_states = [], [], [], [], []
                episode_reward = 0
                total_steps = 0
                # Reset per episode: an episode that hits max_steps without
                # terminating must not inherit the previous episode's outcome.
                successful_episodes = False
                self.temporary_case_base = []

                for _ in range(max_steps):
                    action = self.take_action(state, epsilon)
                    next_state, reward, done = self.env.step(action)

                    states.append(state)
                    actions.append(action)
                    rewards.append(reward)
                    dones.append(done)
                    next_states.append(next_state)

                    c = Case(self._state_text(state), action)
                    Case.reuse(c, self.temporary_case_base)

                    state = next_state
                    total_steps += 1
                    episode_reward += reward

                    if done:
                        # A terminal step with positive reward marks success.
                        successful_episodes = reward > 0
                        break

                if episode_reward > 0:
                    num_successful_episodes += 1
                    total_steps_list.append(total_steps)
                else:
                    total_steps_list.append(0)

                # NOTE: the policy is updated once per episode. An earlier gate
                # on `update_timestep` never skipped an update because its step
                # counter was never advanced, so the effective behaviour is kept.
                if states:
                    self._update_from_episode(states, actions, rewards, dones, next_states)

                episode_rewards.append(episode_reward)
                print(f"Episode {episode + 1}, Total Reward: {episode_reward}")

                Case.revise(self.case_base, self.temporary_case_base, successful_episodes)
                self.case_base = Case.retain(self.case_base, self.temporary_case_base, successful_episodes)

                memory_usage.append(psutil.virtual_memory().percent)
                gpu_memory_usage.append(self._gpu_memory_mb(gpu_handle))
        finally:
            if gpu_handle is not None:
                pynvml.nvmlShutdown()

        self.save_case_base_temporary()  # holds the last episode's cases only
        self.save_case_base()

        success_rate = (num_successful_episodes / episodes) * 100 if episodes else 0.0
        return episode_rewards, success_rate, memory_usage, gpu_memory_usage, total_steps_list

    def _update_from_episode(self, states, actions, rewards, dones, next_states):
        """Compute returns and advantages for one episode and update the policy."""
        solver = self.problem_solver

        # Discounted returns, accumulated backwards and reversed once.
        returns = []
        discounted_sum = 0
        for reward, done in zip(reversed(rewards), reversed(dones)):
            discounted_sum = reward + solver.gamma * discounted_sum * (1 - done)
            returns.append(discounted_sum)
        returns.reverse()

        # Value estimates are baselines only, so no gradients are needed here.
        with torch.no_grad():
            values = solver.policy(torch.FloatTensor(states))[1]
            next_value = solver.policy(torch.FloatTensor(next_states))[1]

        advantages = solver.calculate_advantages(rewards, dones, values, next_value)
        solver.update_policy(states, actions, advantages, returns)

    @staticmethod
    def _open_gpu_monitor():
        """Return an NVML handle for GPU 0, or ``None`` when NVML is unavailable."""
        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            return None
        try:
            return pynvml.nvmlDeviceGetHandleByIndex(0)
        except pynvml.NVMLError:
            pynvml.nvmlShutdown()
            return None

    @staticmethod
    def _gpu_memory_mb(gpu_handle):
        """Return used GPU memory in MiB, or 0.0 when no handle is available."""
        if gpu_handle is None:
            return 0.0
        return pynvml.nvmlDeviceGetMemoryInfo(gpu_handle).used / 1024**2

    @staticmethod
    def _state_text(state):
        """Return ``state`` as text that ``ast.literal_eval`` can parse.

        ``str()`` of a NumPy array omits the commas, so arrays are converted
        to nested lists first.
        """
        if isinstance(state, np.ndarray):
            state = state.tolist()
        return str(state)

    def take_action(self, state, epsilon):
        """Return the action of the most similar stored case, else a sampled one.

        Args:
            state: current environment state.
            epsilon: unused; kept for interface compatibility.
        """
        similar_solution = Case.retrieve(self._state_text(state), self.case_base)
        if similar_solution is None:
            action = self.problem_solver.choose_action(state)
            print("action from problem solver")
        else:
            action = similar_solution.solution
            print("action from case base")
        return action

    @staticmethod
    def _serialize_cases(cases):
        """Turn cases into JSON-ready dictionaries."""
        return [
            {
                "problem": case.problem.tolist() if isinstance(case.problem, np.ndarray) else case.problem,
                "solution": int(case.solution),
                # float, not int: trust moves in steps of 0.1 and int()
                # would truncate every fractional value to 0.
                "trust_value": float(case.trust_value),
            }
            for case in cases
        ]

    def save_case_base_temporary(self):
        """Write the temporary case base to ``case_base_temporary.json``."""
        filename = "case_base_temporary.json"
        case_base_data = self._serialize_cases(self.temporary_case_base)
        with open(filename, 'w', encoding="utf-8") as file:
            json.dump(case_base_data, file)
        print("Temporary case base saved successfully.")

    def save_case_base(self):
        """Write the case base to ``case_base.json``."""
        filename = "case_base.json"
        case_base_data = self._serialize_cases(self.case_base)
        with open(filename, 'w', encoding="utf-8") as file:
            json.dump(case_base_data, file)
        print("Case base saved successfully.")

    def load_case_base(self):
        """Replace the case base with the contents of ``case_base.json``.

        A missing file leaves the case base unchanged and prints a notice.
        """
        filename = "case_base.json"
        try:
            with open(filename, 'r', encoding="utf-8") as file:
                case_base_data = json.load(file)
        except FileNotFoundError:
            print("Case base file not found. Starting with an empty case base.")
            return
        # The decoded list is passed straight through: wrapping it in
        # np.array made the Case constructor try to parse an array as a string.
        self.case_base = [
            Case(case["problem"], case["solution"], case["trust_value"])
            for case in case_base_data
        ]
        print("Case base loaded successfully.")

    def display_success_rate(self, success_rate):
        """Print the success rate as a percentage."""
        print(f"Success rate: {success_rate}%")

    def plot_rewards(self, rewards):
        """Plot the total reward of each episode."""
        plt.plot(rewards)
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.title('Rewards over Episodes')
        plt.grid(True)
        plt.show()

    def plot_resources(self, memory_usage, gpu_memory_usage):
        """Plot system memory (%) and GPU memory (MB) sampled after each episode."""
        plt.plot(memory_usage, label='Memory (%)')
        plt.plot(gpu_memory_usage, label='GPU Memory (MB)')
        plt.xlabel('Episode')
        plt.ylabel('Resource Usage')
        plt.title('Resource Usage over Episodes')
        plt.legend()
        plt.grid(True)
        plt.show()

    def plot_total_steps(self, total_steps_list):
        """Plot the per-episode step counts (0 for unsuccessful episodes)."""
        plt.plot(total_steps_list)
        plt.xlabel('Episode')
        plt.ylabel('Total Steps')
        plt.title('Total Steps for Successful Episodes over Episodes')
        plt.grid(True)
        plt.show()
if __name__ == "__main__":
    # Build the environment first: its reset state and action space give the
    # input and output sizes of the agent's network.
    env = Env()
    state_dim = len(env.reset())
    action_dim = len(env.action_space)

    agent = QCBRL(state_dim, action_dim, env)
    (
        rewards,
        success_rate,
        memory_usage,
        gpu_memory_usage,
        total_step_list,
    ) = agent.run(
        episodes=100,
        max_steps=1000,
        alpha=0.1,
        gamma=0.9,
        epsilon=0.1,
    )

    # Report the outcome, then show the reward, step-count and resource plots.
    agent.display_success_rate(success_rate)
    agent.plot_rewards(rewards)
    agent.plot_total_steps(total_step_list)
    agent.plot_resources(memory_usage, gpu_memory_usage)

state to be measured: [0, 0]
Case base to be measured: []
action from problem solver
state to be measured: [1, 0]
Case base to be measured: []
action from problem solver
state to be measured: [1, 0]
Case base to be measured: []
action from problem solver
state to be measured: [1, 0]
Case base to be measured: []
action from problem solver
state to be measured: [0, 0]
Case base to be measured: []
action from problem solver
state to be measured: [0, 1]
Case base to be measured: []
action from problem solver
state to be measured: [0, 1]
Case base to be measured: []
action from problem solver
state to be measured: [1, 1]
Case base to be measured: []
action from problem solver
state to be measured: [0, 1]
Case base to be measured: []
action from problem solver
state to be measured: [0, 1]
Case base to be measured: []
action from problem solver
state to be measured: [1, 1]
Case base to be measured: []
action from problem solver
state to be measured: [1, 0]
Case base to be measured: []
action 

KeyboardInterrupt: 