In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
from gym.envs.registration import register
import random
from collections import deque
import matplotlib.pyplot as plt
import psutil
import pynvml
import json

# Register the custom FrozenLake environment
register(
    id='CustomRewardFrozenLake-v1',
    entry_point='gym.envs.toy_text:FrozenLakeEnv',
    kwargs={'map_name': '4x4', 'is_slippery': False},
    max_episode_steps=100,
    reward_threshold=0.78,  # Adjust the reward threshold if needed
)

# Define the custom FrozenLake environment with modified rewards
class CustomRewardFrozenLake(gym.Env):
    def __init__(self):
        self.env = gym.make("CustomRewardFrozenLake-v1")
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

    def step(self, action):
        state, reward, done, info = self.env.step(action)
        if reward == 0 and not done:
            reward = 0
        elif reward == 0 and done:
            reward = -5
        elif reward == 1:
            reward = 1
        return state, reward, done, info

    def reset(self):
        return self.env.reset()

    def render(self):
        self.env.render()

    def close(self):
        self.env.close()

# Define Actor Network
class Actor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.softmax(self.fc3(x))
        return x

# Define Critic Network
class Critic(nn.Module):
    def __init__(self, input_dim):
        print(f"Critic Input Dim: {input_dim}")
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Define Experience Buffer
class ExperienceBuffer:
    def __init__(self, buffer_size):
        self.buffer = deque(maxlen=buffer_size)

    def add(self, state, action, log_prob, value, reward, done):
        self.buffer.append((state, action, log_prob, value, reward, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, log_probs, values, rewards, dones = zip(*batch)
        return states, actions, log_probs, values, rewards, dones

# PPO loss function
def calculate_ppo_loss(actor, critic, states, actions, log_probs_old, values_old, advantages, epsilon=0.2, c1=0.5, c2=0.01):
    policy = actor(states)
    dist = torch.distributions.Categorical(policy)

    log_probs = dist.log_prob(actions)
    ratios = torch.exp(log_probs - log_probs_old)
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon) * advantages

    actor_loss = -torch.min(surr1, surr2).mean()

    values = critic(states)
    critic_loss = (values - values_old).pow(2).mean()

    entropy = dist.entropy().mean()

    total_loss = actor_loss + c1 * critic_loss - c2 * entropy

    return actor_loss, critic_loss

# Discounted rewards function
def discount_rewards(rewards, gamma=0.99):
    discounted_rewards = []
    running_add = 0
    for r in reversed(rewards):
        running_add = running_add * gamma + r
        discounted_rewards.insert(0, running_add)
    return discounted_rewards

# Function to compute advantages
def compute_advantages(critic, states, rewards):
    values = critic(states).squeeze()
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32)  # Convert rewards to a PyTorch tensor
    advantages = rewards_tensor - values
    return advantages

class Case:
    added_states = set()  # Class attribute to store states already added to the case base

    def __init__(self, problem, solution, trust_value=1):
        self.problem = np.array(problem)  # Convert problem to numpy array
        self.solution = solution
        self.trust_value = trust_value
    
    @staticmethod
    def sim_q(state1, state2):
        state1 = np.atleast_1d(state1)  # Ensure state1 is at least 1-dimensional
        state2 = np.atleast_1d(state2)  # Ensure state2 is at least 1-dimensional
        CNDMaxDist = 6  # Maximum distance between two nodes in the CND
        v = state1.size  # Total number of objects the agent can perceive
        DistQ = np.sum([Case.Dmin_phi(Objic, Objip) for Objic, Objip in zip(state1, state2)])
        similarity = (CNDMaxDist * v - DistQ) / (CNDMaxDist * v)
        return similarity

    @staticmethod
    def Dmin_phi(X1, X2):
        return np.max(np.abs(X1 - X2))
    

    @staticmethod
    def retrieve(state, case_base, threshold=0.2):
        similarities = {}
        for case in case_base:
            similarities[case] = Case.sim_q(state, case.problem)  # Compare state with the problem part of the case
        
        sorted_similarities = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        
        if sorted_similarities:
            most_similar_case = sorted_similarities[0][0] if sorted_similarities[0][1] >= threshold else None
        else:
            most_similar_case = None
        
        return most_similar_case

    @staticmethod
    def reuse(c, temporary_case_base):
        temporary_case_base.append(c)
        # Store the new case from the problem solver
        # if c not in temporary_case_base:
            # temporary_case_base.append(c)
        
        # Check if there are similar cases in case_base
        # similar_cases = [case for case in case_base if np.array_equal(case.problem, c.problem)]
        # for similar_case in similar_cases:
            # temporary_case_base.append(similar_case)
            # if similar_case not in temporary_case_base:
            #     temporary_case_base.append(similar_case)

    @staticmethod
    def revise(case_base, temporary_case_base, successful_episodes):
        for case in temporary_case_base:
            if successful_episodes and case in case_base:
                case.trust_value += 0.1  # Increment trust value if the episode ended successfully and the case is in the case base
            elif not successful_episodes and case in case_base:
                case.trust_value -= 0.1  # Decrement trust value if the episode ended unsuccessfully and the case is in the case base
            case.trust_value = max(0, min(case.trust_value,1))  # Ensure trust value is within[0,1]

    @staticmethod
    def retain(case_base, temporary_case_base, successful_episodes, threshold=0):
        if successful_episodes:
            # Iterate through the temporary case base to find the last occurrence of each unique state
            for case in reversed(temporary_case_base):
                state = tuple(np.atleast_1d(case.problem))
                # Check if the state is already in the case base or has been added previously
                if state not in Case.added_states:
                    # Add the case to the case base if the state is new
                    case_base.append(case)
                    Case.added_states.add(state)
            
            # Filter case_base based on trust_value
            filtered_case_base = []
            for case in case_base:
                # print(f"trust value >= Threshold?: {case.trust_value} >= {threshold}?")
                if case.trust_value >= threshold:
                    # print(f"problem | trust value: {case.problem} | {case.trust_value}")
                    # print("case saved dong")
                    filtered_case_base.append(case)
                else:
                    # print(f"problem | trust value: {case.problem} | {case.trust_value}")
                    # print("case unsaved dong")
                    pass

            return filtered_case_base
        else:
            return case_base  # Return original case_base if episode is not successful

class QCBRL:
    def __init__(self, num_actions, env):
        self.num_actions = num_actions
        self.env = env
        self.problem_solver = ProblemSolver(env)
        self.case_base = []
        self.temporary_case_base = []

    def run(self, episodes=100, max_steps=100, alpha=0.1, gamma=0.9, epsilon=0.1, render=False):
        rewards = []
        # episode_rewards = []
        memory_usage = []
        gpu_memory_usage = []
        successful_episodes = False
        num_successful_episodes = 0

        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)

        for episode in range(episodes):
            state = self.env.reset()
            # print(f"State: {state}")
            episode_reward = 0
            self.temporary_case_base = []
            
            states = []
            actions = []
            log_probs = []
            values = []
            rewards = []
            dones = []

            for _ in range(max_steps):
                if render:
                    env.render()
                action = self.take_action(state, epsilon)
                next_state, reward, done, _ = self.env.step(action)

                next_action = self.take_action(next_state, epsilon)

                state_one_hot = torch.zeros(1, env.observation_space.n)
                state_one_hot[0, state] = 1
                states.append(state_one_hot)
                actions.append(action)
                rewards.append(reward)
                dones.append(done)

                c = Case(state, action)
                Case.reuse(c, self.temporary_case_base)

                log_prob = None  # Replace None with the actual log probability
                value = None  # Replace None with the actual value
                log_probs.append(log_prob)  # Append the log probability to the list
                values.append(value)  # Append the value to the list
                rewards.append(reward)
                dones.append(done)

                # print("Shape of states:", torch.stack(states).shape)
                # states_reshaped = torch.stack(states).float().view(1, -1)
                

                state = next_state
                episode_reward += reward

                if done:
                    successful_episodes = reward > 0
                    break
            
            print(f"States: {states}")
            self.problem_solver.update_policy(states, actions, log_probs, values, rewards, dones)
            
            if episode_reward > 0:  # If the agent reached the goal state
                num_successful_episodes += 1

            rewards.append(episode_reward)
            print(f"Episode {episode + 1}, Total Reward: {episode_reward}")

            Case.revise(self.case_base, self.temporary_case_base, successful_episodes)
            self.case_base = Case.retain(self.case_base, self.temporary_case_base, successful_episodes)
            
            memory_usage.append(psutil.virtual_memory().percent)
            gpu_memory_usage.append(pynvml.nvmlDeviceGetMemoryInfo(handle).used / 1024**2)

        
        self.save_case_base_temporary()  # Save temporary case base after training
        self.save_case_base()  # Save case base after training

        success_rate = (num_successful_episodes / episodes) * 100
        # print(f"Successful episodes: {num_successful_episodes}%")

        # env.close()
        return rewards, success_rate, memory_usage, gpu_memory_usage

    def take_action(self, state, epsilon):

        # similar_solution = Case.retrieve(state, self.case_base)
        # if similar_solution is not None:
        #     action = similar_solution.solution
        #     # print("action from case base")
        # else:
        #     action = self.problem_solver.choose_action(state, epsilon)
        #     # print("action from problem solver")
        
        action = self.problem_solver.choose_action(state, epsilon)
        
        return action
    
    def save_case_base_temporary(self):
        filename = "case_base_temporary.json"
        case_base_data = [{"problem": case.problem.tolist() if isinstance(case.problem, np.ndarray) else int(case.problem), 
                        "solution": int(case.solution), 
                        "trust_value": int(case.trust_value)} for case in self.temporary_case_base]
        with open(filename, 'w') as file:
            json.dump(case_base_data, file)
        print("Temporary case base saved successfully.")

    def save_case_base(self):
        filename = "case_base.json"
        case_base_data = [{"problem": case.problem.tolist() if isinstance(case.problem, np.ndarray) else int(case.problem), 
                        "solution": int(case.solution), 
                        "trust_value": int(case.trust_value)} for case in self.case_base]
        with open(filename, 'w') as file:
            json.dump(case_base_data, file)

            print("Case base saved successfully.")  # Add this line to check if the case base is being saved
        
    def load_case_base(self):
        filename = "case_base.json"
        try:
            with open(filename, 'r') as file:
                case_base_data = json.load(file)
                self.case_base = [Case(np.array(case["problem"]), case["solution"], case["trust_value"]) for case in case_base_data]
                print("Case base loaded successfully.")  # Add this line to check if the case base is being loaded
        except FileNotFoundError:
            print("Case base file not found. Starting with an empty case base.")

    
    def display_success_rate(self, success_rate):
        print(f"Success rate: {success_rate}%")


    def plot_rewards(self, rewards):
        plt.plot(rewards)
        plt.xlabel('Episode')
        plt.ylabel('Total Reward')
        plt.title('Rewards over Episodes')
        plt.grid(True)
        plt.show() 

    def plot_resources(self, memory_usage, gpu_memory_usage):
        plt.plot(memory_usage, label='Memory (%)')
        plt.plot(gpu_memory_usage, label='GPU Memory (MB)')
        plt.xlabel('Episode')
        plt.ylabel('Resource Usage')
        plt.title('Resource Usage over Episodes')
        plt.legend()
        plt.grid(True)
        plt.show()

class ProblemSolver:
    def __init__(self, env):
        self.env = env
        self.actor = Actor(env.observation_space.n, env.action_space.n)
        self.critic = Critic(env.observation_space.n)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=0.001)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=0.001)

    def choose_action(self, state, epsilon):
        state_one_hot = torch.zeros(1, self.env.observation_space.n)
        state_one_hot[0, state] = 1
        policy = self.actor(state_one_hot)
        action_probs = torch.softmax(policy, dim=-1)
        action = np.random.choice(env.action_space.n, p=action_probs.detach().numpy().flatten())
        return action

    def update_policy(self, states, actions, log_probs, values, rewards, dones):
        discounted_rewards = discount_rewards(rewards)

        print(f"Shape before Advantages: {states}")
        advantages = compute_advantages(self.critic, torch.stack(states), discounted_rewards)
        
        actions_tensor = torch.tensor(actions, dtype=torch.int64)
        log_probs_tensor = torch.cat(log_probs)
        values_tensor = torch.cat(values)
        advantages_tensor = torch.tensor(advantages, dtype=torch.float32)

        actor_loss, critic_loss = calculate_ppo_loss(self.actor, self.critic, states, actions_tensor, log_probs_tensor, values_tensor, advantages_tensor)

        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()


if __name__ == "__main__":
    env = CustomRewardFrozenLake()
    num_states = env.observation_space.n
    num_actions = env.action_space.n

    agent = QCBRL(num_actions, env)
    rewards, success_rate, memory_usage, gpu_memory_usage = agent.run(episodes=1000, max_steps=1000, alpha=0.1, gamma=0.9, epsilon=0.1)

    agent.display_success_rate(success_rate)
    agent.plot_rewards(rewards)
    agent.plot_resources(memory_usage, gpu_memory_usage)


Critic Input Dim: 16
States: [tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])]
Shape before Advantages: [tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 0., 0.,

  logger.warn(f"Overriding environment {id}")


RuntimeError: The size of tensor a (14) must match the size of tensor b (7) at non-singleton dimension 0