# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, Counter
import random
import matplotlib.pyplot as plt
import gym
from gym import spaces
from qsimbench import get_outcomes, get_index
from typing import Dict, List, Tuple
from tqdm import tqdm
import os
import json
from ast import literal_eval

# --- Oracle and Caching System ---
# This section defines the "teacher" for our RL agent. The Oracle determines
# the reference answer (the optimal number of shots), which is used to score
# the agent's performance and to provide its learning signal (the reward).

# In-memory memo of oracle results: (algorithm, size, backend) -> optimal shot count.
ORACLE_CACHE: Dict[Tuple[str, int, str], int] = dict()
# JSON file the memo is persisted to between runs (see save/load_oracle_cache).
CACHE_FILE: str = "oracle_cache_4_features.json"

def save_oracle_cache():
    """
    Saves the dictionary of computed optimal shots (ORACLE_CACHE) to the JSON file CACHE_FILE.

    This is a critical optimization: running the Oracle is the most time-consuming part of
    the program, and persisting its results lets later runs skip it (see load_oracle_cache).

    The write is atomic. The JSON is first dumped to a sibling temporary file, which then
    replaces CACHE_FILE via os.replace. An interruption mid-write therefore leaves the
    previous cache file intact instead of a truncated, unreadable one. Any error is re-raised
    after the temporary file has been cleaned up.
    """
    # JSON cannot serialize tuple keys, so convert them to their str() form;
    # load_oracle_cache turns them back into tuples with ast.literal_eval.
    string_key_cache = {str(k): v for k, v in ORACLE_CACHE.items()}
    tmp_path = CACHE_FILE + ".tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(string_key_cache, f)
        # Atomic on the same filesystem: readers see either the old or the new file, never a partial one.
        os.replace(tmp_path, CACHE_FILE)
    except BaseException:
        # Do not leave a half-written temp file behind; surface the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print(f"Oracle cache with {len(ORACLE_CACHE)} items saved to {CACHE_FILE}")

def load_oracle_cache():
    """
    Loads the pre-computed optimal shots from CACHE_FILE into ORACLE_CACHE, if the file exists.

    This dramatically speeds up subsequent training runs. The whole file is parsed into a
    staging dict first and merged into ORACLE_CACHE only if every entry parsed successfully.
    A corrupt or partially invalid file therefore never leaves the cache half-populated.
    On any read/parse error a warning is printed and ORACLE_CACHE is left untouched. This
    is deliberate best-effort behavior: a bad cache only costs recomputation.
    """
    if not os.path.exists(CACHE_FILE):
        return
    try:
        with open(CACHE_FILE, 'r') as f:
            string_key_cache = json.load(f)
        # Keys were written as str(tuple); convert them back into their original tuple format.
        # literal_eval only accepts Python literals, so it cannot execute code from the file.
        staged = {literal_eval(k_str): v for k_str, v in string_key_cache.items()}
    except Exception as e:
        print(f"Warning: Could not read cache file '{CACHE_FILE}'. Starting empty. Error: {e}")
        return
    ORACLE_CACHE.update(staged)
    print(f"Loaded {len(ORACLE_CACHE)} items from oracle cache.")

def find_optimal_shots(algorithm: str, size: int, backend: str, step_size=50, max_shots=20000, threshold=0.01, stability_k=3):
    """
    ================================================================================
    *** THE ORACLE ***
    ================================================================================
    PURPOSE:
        This function serves as the "teacher" by implementing the Incremental Execution (IE) algorithm described in the paper.
        It iteratively runs batches of shots and uses a statistical metric to find the point of convergence,
        which we define as the "optimal" shot count.

    ROLE IN THE SYSTEM:
        Its output (`optimal_shots`) is used ONLY to calculate the final reward for the RL agent during training
        and as a benchmark during evaluation. The agent itself never sees this value while making decisions.

    METHOD:
        It determines convergence by measuring the Total Variation Distance (TVD) between the cumulative
        probability distribution of outcomes at consecutive steps. If the TVD stays below `threshold` for
        `stability_k` consecutive iterations, the distribution is considered stable and the process stops.

    ARGS:
        algorithm, size, backend: identify the qsimbench problem; together they form the cache key.
        step_size: shots per batch; must be positive.
        max_shots: shot budget; the cumulative shot count never exceeds it.
        threshold: TVD below which two consecutive distributions count as "stable".
        stability_k: consecutive stable steps required to declare convergence.

    RETURNS:
        The cumulative shot count at convergence. Returns `max_shots` if convergence is not reached
        within the budget, or if a simulation call raises.

    RAISES:
        ValueError: if `step_size` is not positive.

    NOTE:
        The cache key is (algorithm, size, backend) only. A value computed with one set of
        hyperparameters is therefore returned for later calls that use different ones. This is kept
        for compatibility with existing cache files.
    """
    if step_size <= 0:
        # range() raises on 0 and silently yields nothing for negative steps.
        raise ValueError(f"step_size must be positive, got {step_size!r}")

    # Check if we have already computed this value to save time
    cache_key = (algorithm, size, backend)
    if cache_key in ORACLE_CACHE:
        return ORACLE_CACHE[cache_key]

    cumulative_counts = Counter()        # Aggregated outcomes so far (e.g., {'01': 10, '10': 12})
    stable_iterations = 0                # Consecutive steps whose TVD was below threshold
    prev_dist: Dict[str, float] = {}     # Probability distribution from the previous step

    def normalize_dist(counts: Dict[str, int]) -> Dict[str, float]:
        """Converts raw counts into a normalized probability distribution (empty if there are no counts)."""
        total = sum(counts.values())
        return {k: v / total for k, v in counts.items()} if total > 0 else {}

    def total_variation_distance(p: Dict[str, float], q: Dict[str, float]) -> float:
        """
        Calculates the TVD between two probability distributions, p and q.
        A small TVD (close to 0) means the distributions are very similar.
        Formula: TVD(p, q) = 0.5 * sum(|p(i) - q(i)| for all outcomes i)
        """
        all_keys = set(p) | set(q)  # Consider all outcomes present in either distribution
        return 0.5 * sum(abs(p.get(k, 0) - q.get(k, 0)) for k in all_keys)

    # Main loop of the Incremental Execution algorithm. The upper bound `max_shots + 1` keeps the
    # cumulative shot count within the budget even when max_shots is not a multiple of step_size.
    for total_shots_so_far in range(step_size, max_shots + 1, step_size):
        try:
            # Use qsimbench to simulate running a new batch of shots
            new_batch_counts = get_outcomes(algorithm, size, backend, shots=step_size, strategy='random', exact=True)
        except Exception:
            # Best-effort fallback: if the simulation fails for any reason, default to the maximum shot count.
            # NOTE: this fallback is cached too (and persisted by save_oracle_cache if that runs afterwards).
            ORACLE_CACHE[cache_key] = max_shots
            return max_shots

        cumulative_counts.update(new_batch_counts)
        curr_dist = normalize_dist(cumulative_counts)

        if prev_dist:
            tvd = total_variation_distance(prev_dist, curr_dist)
            if tvd < threshold:
                stable_iterations += 1
            else:
                stable_iterations = 0  # Reset if the distribution becomes unstable again

            # Stable for `stability_k` consecutive steps: this is the optimal point
            if stable_iterations >= stability_k:
                ORACLE_CACHE[cache_key] = total_shots_so_far
                return total_shots_so_far

        prev_dist = curr_dist

    # No stable point within the budget: the optimal is defined as the maximum allowed shots
    ORACLE_CACHE[cache_key] = max_shots
    return max_shots

# --- The Reinforcement Learning Environment ---

class IterativeQuantumEnv(gym.Env):
    """
    This class defines the "game" or environment for the RL agent, following the OpenAI Gym interface.
    An "episode" in this game consists of iteratively running batches of shots for a single quantum problem,
    where the agent's goal is to decide at each step whether to CONTINUE or STOP.

    Rewards:
        - Every CONTINUE step that does not end the episode yields -0.02.
        - The step that ends the episode (a STOP, or a CONTINUE that reaches `max_shots`) yields the
          terminal reward computed in `_terminate`.
    """

    # Divisor for the circuit-size feature used by earlier versions of this environment. It is kept as a
    # lower bound so state vectors are unchanged whenever every size in the index is <= 15.
    _MIN_SIZE_SCALE = 15.0

    def __init__(self, max_shots=20000, step_size=50):
        super().__init__()
        self.max_shots = max_shots
        self.step_size = step_size
        self.index = get_index()  # Get all available problems from qsimbench
        self.all_triplets = self._get_all_noisy_triplets()
        if not self.all_triplets:
            # Fail fast with a clear message instead of an opaque IndexError on the first reset().
            raise RuntimeError(
                "No noisy (algorithm, size, backend) triplets with an 'aer_simulator' reference were found in the qsimbench index"
            )
        self.alg_map, self.backend_map = self._create_mappings()
        # Largest circuit size in the dataset (but at least 15), so the size feature stays inside
        # the [0, 1] bounds declared by observation_space.
        self.size_scale = max(self._MIN_SIZE_SCALE, float(max(t[1] for t in self.all_triplets)))

        print("Pre-computing optimal shots for all triplets using the oracle...")
        # This loop populates the oracle cache to ensure training is fast.
        for triplet in tqdm(self.all_triplets, desc="Oracle Pre-computation"):
            self._query_oracle(triplet)
        print("Oracle pre-computation complete.")

        # Action space: two discrete actions
        # 0: CONTINUE (run another batch of shots)
        # 1: STOP (end the episode)
        self.action_space = spaces.Discrete(2)

        # Observation space: 4-dimensional vector [algorithm, size, backend, shots], each normalized to [0, 1]
        self.observation_space = spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)

    def _query_oracle(self, triplet: Tuple[str, int, str]) -> int:
        """Return the oracle's optimal shot count for `triplet`, using this environment's batch size and shot budget."""
        # NOTE: the oracle cache is keyed by triplet only, so an already-cached value is reused as-is.
        algorithm, size, backend = triplet
        return find_optimal_shots(algorithm, size, backend, step_size=self.step_size, max_shots=self.max_shots)

    def _get_all_noisy_triplets(self) -> List[Tuple[str, int, str]]:
        """Parse the qsimbench index into (algorithm, size, backend) triplets for every noisy backend that has an 'aer_simulator' sibling."""
        triplets = []
        for alg, sizes in self.index.items():
            for size, backends in sizes.items():
                if "aer_simulator" in backends:  # Ensure an ideal, noiseless simulator exists for comparison
                    for b in backends:
                        if b != "aer_simulator":  # Only train on noisy, realistic backends
                            triplets.append((alg, int(size), b))
        return triplets

    def _create_mappings(self) -> Tuple[Dict[str, int], Dict[str, int]]:
        """Map algorithm and backend names to integer indices (in sorted-name order) for normalization."""
        all_algs = sorted(set(t[0] for t in self.all_triplets))
        all_backends = sorted(set(t[2] for t in self.all_triplets))
        alg_map = {name: i for i, name in enumerate(all_algs)}
        backend_map = {name: i for i, name in enumerate(all_backends)}
        return alg_map, backend_map

    def reset(self) -> np.ndarray:
        """
        Start a new episode: pick a random problem, look up its optimal shot count, and reset the shot counter.

        Returns:
            The initial state vector of the new episode.
        """
        self.current_triplet = random.choice(self.all_triplets)
        # Served from the oracle cache populated in __init__
        self.optimal_shots = self._query_oracle(self.current_triplet)
        self.current_shots = 0
        return self._get_state()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """
        Process one agent action.

        Args:
            action: 1 = STOP; any other value = CONTINUE (run another `step_size` shots).

        Returns:
            (state, reward, done, info). `info` is empty except on the terminal step (see `_terminate`).
        """
        done = False
        reward = 0
        info = {}

        if action == 1:  # Agent chose to STOP
            state, reward, done, info = self._terminate()
        else:  # Agent chose to CONTINUE
            self.current_shots += self.step_size
            if self.current_shots >= self.max_shots:
                # Hitting the shot limit ends the episode automatically
                state, reward, done, info = self._terminate()
            else:
                # Small per-step penalty to encourage the agent to finish episodes efficiently
                reward = -0.02
                state = self._get_state()

        return state, reward, done, info

    def _terminate(self) -> Tuple[np.ndarray, float, bool, Dict]:
        """
        End the episode and compute the terminal reward.

        Returns:
            (state, final_reward, True, info), where info holds 'shots_used', 'optimal_shots',
            'error', 'triplet' and 'final_reward'.
        """
        # Positive for overshooting, negative for undershooting
        error = self.current_shots - self.optimal_shots

        # --- Asymmetric Reward Shaping ---
        if error < 0:
            # Undershooting: the result is not statistically stable. The penalty is at most -1.0 and grows
            # in proportion to the relative shortfall.
            final_reward = -1.0 - (abs(error) / self.optimal_shots)
        else:
            # Overshooting: a success whose reward decays exponentially with the wasted shots
            # (+1.0 for a perfect stop).
            final_reward = np.exp(-0.0005 * error)

        state = self._get_state()
        done = True
        info = {
            'shots_used': self.current_shots,
            'optimal_shots': self.optimal_shots,
            'error': error,
            'triplet': self.current_triplet,
            'final_reward': final_reward
        }
        return state, final_reward, done, info

    def _get_state(self) -> np.ndarray:
        """
        Build the agent's observation: [alg_norm, size_norm, backend_norm, shots_norm] as float32.

        Categorical features are index / (count - 1), or 0.5 when there is only one category.
        Size is divided by `size_scale`; shots are divided by `max_shots`.
        """
        alg, size, backend = self.current_triplet

        # Normalize features to [0, 1] so large-scale features (shot counts) don't dominate small-scale ones
        alg_norm = self.alg_map.get(alg, 0) / (len(self.alg_map) - 1) if len(self.alg_map) > 1 else 0.5
        size_norm = size / self.size_scale
        backend_norm = self.backend_map.get(backend, 0) / (len(self.backend_map) - 1) if len(self.backend_map) > 1 else 0.5
        shots_norm = self.current_shots / self.max_shots

        return np.array([alg_norm, size_norm, backend_norm, shots_norm], dtype=np.float32)

# --- RL Agent and Network ---
# This section defines the "brain" of our agent: a Deep Q-Network

class DQN(nn.Module):
    """
    Feed-forward Q-value approximator: maps a state vector to one Q-value per action.

    Architecture: Linear(input_size -> 256) -> ReLU -> Linear(256 -> 256) -> ReLU -> Linear(256 -> output_size).
    Instead of a Q-table over every possible state, the network learns a function from states to action values.
    """
    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        width = 256  # hidden-layer width
        layers = [
            nn.Linear(input_size, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, output_size),  # one output neuron per possible action
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Q-values (last dimension = output_size) for the state batch `x`."""
        q_values = self.net(x)
        return q_values

class Agent:
    """
    Epsilon-greedy DQN agent.

    Owns the online Q-network, a periodically synced target network, an experience-replay buffer,
    and the optimisation step.
    """
    def __init__(self, state_size: int, action_size: int, learning_rate: float = 1e-4, gamma: float = 0.99):
        # Online network: used for acting and updated on every training step.
        self.q_net = DQN(state_size, action_size)
        # Target network: a copy of q_net that is only refreshed by update_target(), giving stable bootstrap targets.
        self.target_net = DQN(state_size, action_size)
        self.update_target()

        # Replay buffer of (state, action, reward, next_state, done) tuples; oldest entries drop out first.
        self.memory = deque(maxlen=200000)
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=learning_rate)

        # Discount factor: closer to 1 means future rewards weigh more.
        self.gamma = gamma
        # Exploration schedule: start fully random, multiply by epsilon_decay after each training step,
        # and stop decaying once epsilon is no longer above epsilon_min.
        self.epsilon, self.epsilon_decay, self.epsilon_min = 1.0, 0.9998, 0.05

        self.action_size = action_size

    def act(self, state: np.ndarray) -> int:
        """Pick an action: uniformly random with probability epsilon, else the argmax of q_net's Q-values."""
        explore = random.random() < self.epsilon
        if explore:
            return random.randint(0, self.action_size - 1)
        with torch.no_grad():
            batch = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.q_net(batch)
        return q_values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def train(self, batch_size: int = 128):
        """
        Do one DQN gradient step on a random minibatch from replay memory.

        Does nothing until the buffer holds at least 10 * batch_size transitions. The loss is
        MSE between Q(s, a) and r + gamma * max_a' Q_target(s', a'), where the bootstrap term
        is zeroed for terminal transitions.
        """
        warmup = batch_size * 10
        if len(self.memory) < warmup:
            return

        sample = random.sample(self.memory, batch_size)
        s, a, r, s_next, terminal = zip(*sample)

        s = torch.from_numpy(np.array(s)).float()
        a = torch.LongTensor(a).unsqueeze(1)
        r = torch.FloatTensor(r).unsqueeze(1)
        s_next = torch.from_numpy(np.array(s_next)).float()
        terminal = torch.BoolTensor(terminal).unsqueeze(1)

        # Q(s, a) for the actions that were actually taken
        predicted = self.q_net(s).gather(1, a)

        # Bellman target computed with the frozen target network
        with torch.no_grad():
            bootstrap = self.target_net(s_next).max(dim=1).values.unsqueeze(1)
            target = r + (self.gamma * bootstrap * ~terminal)

        loss = F.mse_loss(predicted, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Anneal exploration after each gradient step
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def update_target(self):
        """Copy q_net's weights into target_net."""
        weights = self.q_net.state_dict()
        self.target_net.load_state_dict(weights)

# --- Main Training and Evaluation ---
def train_agent_sequential(episodes: int = 15000) -> Tuple[Agent, IterativeQuantumEnv, List[float]]:
    """
    Main training loop: runs a fresh agent through `episodes` episodes of IterativeQuantumEnv.

    Every environment step is stored in replay memory and followed by one agent.train() call.
    The target network is re-synced every 25 episodes, starting with episode 0.

    Args:
        episodes: number of training episodes to run.

    Returns:
        (agent, env, rewards_history). rewards_history[i] is the sum of the rewards env.step()
        returned during episode i: the per-step penalties plus the terminal reward.

    Side effects:
        Calls save_oracle_cache() when training ends, including when it is interrupted by an
        exception (e.g. KeyboardInterrupt), so the expensive oracle pre-computation is not lost.
    """
    rewards_history = []
    try:
        env = IterativeQuantumEnv()
        agent = Agent(env.observation_space.shape[0], env.action_space.n)

        for ep in tqdm(range(episodes), desc="Training Episodes"):
            state = env.reset()
            done = False
            total_episode_reward = 0.0

            # Run one full episode (one "game")
            while not done:
                action = agent.act(state)
                next_state, reward, done, info = env.step(action)

                # `reward` is the reward for this specific step: the small per-step penalty while
                # continuing, or the large terminal reward (== info['final_reward']) on the final step.
                agent.remember(state, action, reward, next_state, done)
                # Accumulate the rewards actually received rather than reconstructing them afterwards.
                # A reconstruction over-counts one penalty when the episode ends by hitting max_shots.
                total_episode_reward += reward

                state = next_state
                agent.train()  # One training step on a batch from memory

            rewards_history.append(total_episode_reward)

            # Periodically update the target network for stable training
            if ep % 25 == 0:
                agent.update_target()

        print("--- Training Finished ---")
    finally:
        save_oracle_cache()
    return agent, env, rewards_history

def evaluate_agent_sequential(agent: Agent, env: IterativeQuantumEnv, num_tests: int = 200) -> List[Dict]:
    """
    Evaluate the agent's greedy policy (exploration turned off).

    Runs `num_tests` episodes with `agent.epsilon` forced to 0.0 and collects the `info` dict
    returned by the final env.step() of each episode. The agent's previous epsilon is restored
    afterwards, even on error, so evaluating mid-training does not silently disable exploration
    for later training.

    Args:
        agent: agent whose policy is evaluated.
        env: environment to evaluate in.
        num_tests: number of evaluation episodes.

    Returns:
        One terminal info dict per evaluation episode.
    """
    previous_epsilon = agent.epsilon
    agent.epsilon = 0.0  # Greedy policy (always exploit) for the duration of the evaluation
    results = []
    try:
        for _ in tqdm(range(num_tests), desc="Evaluation"):
            state = env.reset()
            done = False
            info = {}
            while not done:
                action = agent.act(state)
                state, _, done, info = env.step(action)
            results.append(info)
    finally:
        # Restore the caller's exploration rate
        agent.epsilon = previous_epsilon
    return results


# --- Analysis and Plotting Dashboard ---

class AnalysisDashboard:
    """
    Visualize training and evaluation results.

    Expects `training_rewards` as one total reward per training episode, and `evaluation_results`
    as terminal info dicts with 'shots_used', 'optimal_shots', 'error' and 'triplet' keys.
    """
    def __init__(self, training_rewards: List[float], evaluation_results: List[Dict]):
        self.rewards = training_rewards
        self.results = evaluation_results
        self.shots_used = [r['shots_used'] for r in self.results]
        self.optimal_shots = [r['optimal_shots'] for r in self.results]
        self.errors = np.array([r['error'] for r in self.results])
        self.triplets = [r['triplet'] for r in self.results]

    def plot_dashboard(self):
        """Generate and display a 2x2 dashboard of performance plots."""
        fig, axes = plt.subplots(2, 2, figsize=(18, 14))
        fig.suptitle("RL Agent Performance Analysis Dashboard (4-Feature State)", fontsize=20, y=0.98)

        self._plot_training_rewards(axes[0, 0])
        self._plot_performance_scatter(axes[0, 1])
        self._plot_error_distribution(axes[1, 0])
        self._plot_error_by_algorithm(axes[1, 1])

        plt.tight_layout(rect=[0, 0, 1, 0.95])
        plt.show()

    def _plot_training_rewards(self, ax: plt.Axes):
        """
        PURPOSE: Visualize the agent's learning progress over time.

        HOW IT WORKS: Plots the raw (noisy) total reward for each episode. When there are at least
        200 episodes, it overlays a 200-episode moving average (`np.convolve`) to show the underlying trend.
        A clear upward trend in the moving average indicates the agent is learning to achieve higher rewards.
        """
        ax.plot(self.rewards, label='Total Reward per Episode', alpha=0.2, color='C0')
        if len(self.rewards) >= 200:
            moving_avg = np.convolve(self.rewards, np.ones(200)/200, mode='valid')
            ax.plot(np.arange(199, len(self.rewards)), moving_avg, color='C1', label='200-Episode Moving Average')
        ax.axhline(0, color='k', linestyle='--', alpha=0.5)
        ax.set_title("1. Training Rewards Over Time", fontsize=14)
        ax.set_xlabel("Episode")
        ax.set_ylabel("Total Episode Reward")
        ax.legend()
        ax.grid(True, linestyle='--')

    def _plot_performance_scatter(self, ax: plt.Axes):
        """
        PURPOSE: Directly compare the agent's final policy against the Oracle.
        (This is the most important plot for evaluating the final performance.)

        HOW IT WORKS: Scatter plot with one point per evaluation episode. x = optimal shots from the Oracle,
        y = shots the agent chose to use.

        INTERPRETATION:
        - The red dashed line (y=x) is the "Perfect Policy"; points on it are perfect decisions.
        - Points ABOVE the line are OVERSHOOTING (wasted resources).
        - Points BELOW the line are UNDERSHOOTING (statistically unstable results).
        - The Mean Absolute Error (MAE) in the title is the average shot deviation (0.0 when there are no results).
        """
        # Guard against empty evaluation results: np.mean([]) is nan and max([]) raises.
        mae = float(np.mean(np.abs(self.errors))) if self.errors.size else 0.0
        ax.scatter(self.optimal_shots, self.shots_used, alpha=0.6, edgecolors='k', s=50)
        upper = max(self.optimal_shots + self.shots_used, default=0)
        perfect_range = [0, upper]
        ax.plot(perfect_range, perfect_range, 'r--', linewidth=2, label='Perfect Policy (y=x)')
        ax.set_title(f"2. Agent Policy vs. Oracle (MAE: {mae:.1f} shots)", fontsize=14)
        ax.set_xlabel("Optimal Shots (Determined by Oracle)")
        ax.set_ylabel("Shots Used by Agent")
        ax.legend()
        ax.grid(True, linestyle='--')

    def _plot_error_distribution(self, ax: plt.Axes):
        """
        PURPOSE: Visualize the distribution of the agent's stopping errors.

        HOW IT WORKS: Histogram (40 bins) of the errors (`agent_shots - optimal_shots`).

        INTERPRETATION: Mass to the right of zero (positive error) indicates a cautious agent that tends to
        overshoot. Mass to the left indicates a risky agent that tends to undershoot. A tight distribution
        centered at zero is the ideal result.
        """
        ax.hist(self.errors, bins=40, edgecolor='k', alpha=0.7)
        ax.axvline(0, color='r', linestyle='--', linewidth=2, label='Zero Error (Perfect Stop)')
        ax.set_title("3. Distribution of Stopping Errors", fontsize=14)
        ax.set_xlabel("Error (Agent Shots - Optimal Shots)")
        ax.set_ylabel("Frequency of Occurrences")
        ax.legend()
        ax.grid(True, linestyle='--')

    def _plot_error_by_algorithm(self, ax: plt.Axes):
        """
        PURPOSE: Show whether the agent's performance varies across quantum algorithms.

        HOW IT WORKS: Groups evaluation errors by algorithm (first element of each triplet) and draws the
        mean error per algorithm as a bar chart, with algorithms in sorted order.

        INTERPRETATION: Helps identify whether the agent learned a generalizable policy or struggles with
        specific types of circuits. Large bars for certain algorithms indicate areas for improvement.
        """
        # Single pass grouping; every algorithm that appears has at least one error by construction.
        errors_by_alg: Dict[str, List[float]] = {}
        for triplet, err in zip(self.triplets, self.errors):
            errors_by_alg.setdefault(triplet[0], []).append(err)

        unique_algs = sorted(errors_by_alg)
        mean_errors_by_alg = [np.mean(errors_by_alg[alg]) for alg in unique_algs]

        ax.bar(unique_algs, mean_errors_by_alg, edgecolor='k', alpha=0.7, color='C2')
        ax.axhline(0, color='r', linestyle='--', linewidth=2, label='Zero Error')
        ax.set_title("4. Average Error by Algorithm Type", fontsize=14)
        ax.set_xlabel("Quantum Algorithm")
        ax.set_ylabel("Average Error (Agent Shots - Optimal)")
        ax.tick_params(axis='x', rotation=45, labelsize=8)  # Rotate labels for readability
        ax.legend()
        ax.grid(True, linestyle='--')


if __name__ == "__main__":
    # Warm-start the oracle cache from disk (does nothing if the cache file does not exist yet)
    load_oracle_cache()

    # 1) Train a fresh agent for 10,000 episodes
    agent, env, reward_curve = train_agent_sequential(episodes=10000)

    # 2) Greedy evaluation on 500 randomly drawn problems
    eval_results = evaluate_agent_sequential(agent, env, num_tests=500)

    # 3) Visualize training progress and evaluation accuracy
    AnalysisDashboard(reward_curve, eval_results).plot_dashboard()