In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, Counter
import random
import matplotlib.pyplot as plt
import gym
from gym import spaces
from qsimbench import get_outcomes, get_index
from typing import Dict, List, Tuple
from tqdm import tqdm
import os
import json
from ast import literal_eval

# --- Oracle and caching system ---
# This section defines the "teacher" for our RL agent.
# The Oracle determines the correct answer (the optimal number of shots), which is used to score
# the agent's performance and to provide the learning signal (the reward).

# In-memory memo of Oracle results: maps the tuple cache key built in find_optimal_shots
# to the shot count that function returned for it
ORACLE_CACHE = {}
# Path of the JSON file that save_oracle_cache writes and load_oracle_cache reads
CACHE_FILE = "oracle_cache_enhanced.json"

def save_oracle_cache():
    """
    Persist the in-memory Oracle cache (ORACLE_CACHE) to CACHE_FILE as JSON

    The Oracle is the slowest part of the program, so its results are written to disk
    and can be reloaded later instead of being recomputed

    Keys are tuples in memory; JSON object keys must be strings, so every key is
    stored as its `str()` form. The file is overwritten on each call
    """
    serializable = {}
    for key, shots in ORACLE_CACHE.items():
        serializable[str(key)] = shots

    with open(CACHE_FILE, 'w') as cache_handle:
        json.dump(serializable, cache_handle)

    print(f"Oracle cache with {len(ORACLE_CACHE)} items saved to {CACHE_FILE}")

def load_oracle_cache():
    """
    Merge the pre-computed optimal shots stored in CACHE_FILE into ORACLE_CACHE

    Does nothing when the file does not exist. Loading is best-effort: any error
    while reading or decoding the file is reported with a warning and the function
    returns without raising

    The file is decoded completely before ORACLE_CACHE is touched, so a corrupt file
    can never leave the cache partially populated: either every entry of the file is
    merged, or none is
    """
    if not os.path.exists(CACHE_FILE):
        return

    try:
        with open(CACHE_FILE, 'r') as f:
            string_key_cache = json.load(f)
        # Keys were written with str(tuple); literal_eval turns them back into tuples
        # (it only evaluates literals, so it does not execute code from the file)
        restored = {literal_eval(k_str): v for k_str, v in string_key_cache.items()}
    except Exception as e:
        # Deliberately broad: a missing/corrupt cache must never prevent a run, it only costs time
        print(f"Warning: Could not read cache file '{CACHE_FILE}'. Starting empty. Error: {e}")
        return

    # Single merge step, reached only when the whole file decoded cleanly
    ORACLE_CACHE.update(restored)
    print(f"Loaded {len(ORACLE_CACHE)} items from oracle cache.")

# Mirrors the default values of (step_size, max_shots, threshold, stability_k) in the
# signature of find_optimal_shots below; keep the two in sync
_DEFAULT_ORACLE_SETTINGS = (50, 20000, 0.01, 3)


def find_optimal_shots(algorithm: str, size: int, backend: str, step_size=50, max_shots=20000, threshold=0.01, stability_k=3):
    """
    ================================================================================
    *** THE ORACLE ***
    ================================================================================
    PURPOSE:
        Implements the Incremental Execution (IE) procedure: shots are requested in batches
        of `step_size`, accumulated, and the cumulative shot count at which the outcome
        distribution stops changing is returned as the "optimal" shot count

    ROLE IN THE SYSTEM:
        The returned value is used as the reference when computing the agent's final reward
        and when categorizing problems by difficulty

    METHOD:
        After every batch the cumulative counts are normalized into a probability distribution
        and compared with the distribution from the previous step using the Total Variation
        Distance (TVD). When the TVD is below `threshold` for `stability_k` consecutive steps,
        the current cumulative shot count is returned

    PARAMETERS:
        algorithm, size, backend: the problem triplet, forwarded to qsimbench `get_outcomes`
        step_size:   shots requested per batch
        max_shots:   upper bound of the search; also the fallback result
        threshold:   TVD below which a step counts as "stable"
        stability_k: number of consecutive stable steps required

    RETURNS:
        The cumulative shot count at convergence, or `max_shots` when convergence was not
        reached or when `get_outcomes` raised

    CACHING:
        Results are memoized in ORACLE_CACHE. Calls with the default settings use the key
        (algorithm, size, backend), so previously saved cache files remain valid. Calls with
        any non-default setting use a key extended with the four settings, so they never
        receive (or overwrite) a result that was computed under different settings
    """
    cache_key = (algorithm, size, backend)
    settings = (step_size, max_shots, threshold, stability_k)
    if settings != _DEFAULT_ORACLE_SETTINGS:
        cache_key = cache_key + settings

    # Check if we have already computed this value to save time
    if cache_key in ORACLE_CACHE:
        return ORACLE_CACHE[cache_key]

    cumulative_counts = Counter()  # Aggregated outcomes over all batches (e.g., {'01': 10, '10': 12})
    stable_iterations = 0          # Number of consecutive steps with TVD < threshold
    prev_dist = {}                 # Cumulative distribution after the previous batch

    def normalize_dist(counts: Dict[str, int]) -> Dict[str, float]:
        """Converts raw counts into a normalized probability distribution ({} if there are no counts)"""
        total = sum(counts.values())
        return {k: v / total for k, v in counts.items()} if total > 0 else {}

    def total_variation_distance(p: Dict[str, float], q: Dict[str, float]) -> float:
        """
        TVD(p, q) = 0.5 * sum(|p(i) - q(i)|) over every outcome i present in p or q
        0 means identical distributions, 1 means disjoint support
        """
        all_keys = set(p) | set(q)
        return 0.5 * sum(abs(p.get(k, 0) - q.get(k, 0)) for k in all_keys)

    # Main loop of the Incremental Execution algorithm
    for total_shots_so_far in range(step_size, max_shots + step_size, step_size):
        try:
            # Use qsimbench to simulate running a new batch of shots
            new_batch_counts = get_outcomes(algorithm, size, backend, shots=step_size, strategy='random', exact=True)
        except Exception:
            # Sampling failed: fall back to the maximum and remember that answer
            ORACLE_CACHE[cache_key] = max_shots
            return max_shots

        # Accumulate the new outcomes
        cumulative_counts.update(new_batch_counts)
        current_dist = normalize_dist(cumulative_counts)

        # The first batch has nothing to be compared with, so the check starts at the second batch
        if prev_dist:
            tvd = total_variation_distance(current_dist, prev_dist)
            if tvd < threshold:
                stable_iterations += 1
                if stable_iterations >= stability_k:
                    # Convergence achieved
                    ORACLE_CACHE[cache_key] = total_shots_so_far
                    return total_shots_so_far
            else:
                # Distribution changed significantly, reset stability counter
                stable_iterations = 0

        prev_dist = current_dist

    # If we reach this point, convergence was not achieved, so return the maximum
    ORACLE_CACHE[cache_key] = max_shots
    return max_shots

# --- Quantum environment ---

class IterativeQuantumEnv(gym.Env):
    """
    ENHANCEMENTS IN THIS VERSION:
    1. Difficulty-biased sampling: The environment oversamples harder problems (those requiring more shots)
    2. Enhanced state features: Added entropy and variance metrics to help the agent distinguish between similar-looking problems
    """
    def __init__(self, max_shots=20000, step_size=50, hard_problem_bias=0.7):
        """
        Build the environment: collect the problems, warm the Oracle cache, define the spaces

        max_shots:         shot budget of an episode
        step_size:         shots added by one CONTINUE action
        hard_problem_bias: probability used by reset() when deciding to draw from the "hard" pool
        """
        super().__init__()

        # --- Episode configuration ---
        self.max_shots = max_shots
        self.step_size = step_size
        self.hard_problem_bias = hard_problem_bias

        # --- Problem catalogue (from the qsimbench index) ---
        self.index = get_index()
        self.all_triplets = self._get_all_noisy_triplets()
        self.alg_map, self.backend_map = self._create_mappings()

        # --- Oracle warm-up: after this loop every triplet is answered from ORACLE_CACHE ---
        print("Pre-computing optimal shots for all triplets using the oracle...")
        for alg, size, backend in tqdm(self.all_triplets, desc="Oracle Pre-computation"):
            find_optimal_shots(alg, size, backend)
        print("Oracle pre-computation complete.")

        # Split the catalogue into easy / hard pools (uses the cached Oracle answers)
        self._categorize_by_difficulty()

        # --- Gym spaces ---
        # Two discrete actions: 0 = CONTINUE (run another batch), 1 = STOP (end the episode)
        self.action_space = spaces.Discrete(2)
        # Observations are 8-dimensional vectors declared to lie in [0, 1]
        self.observation_space = spaces.Box(low=0, high=1, shape=(8,), dtype=np.float32)

        # Per-episode list of the outcome batches collected so far (one entry per successful step)
        self.outcome_history = []

    def _get_all_noisy_triplets(self) -> List[Tuple[str, int, str]]:
        """Helper function to parse the qsimbench index and find all valid noisy problems"""
        triplets = []
        for alg, sizes in self.index.items():
            for size, backends in sizes.items():
                if "aer_simulator" in backends: # Ensure an ideal, noiseless simulator exists for comparison
                    for b in backends:
                        if b != "aer_simulator": # We only want to train on noisy, realistic backends
                            triplets.append((alg, int(size), b))
        return triplets

    def _create_mappings(self) -> Tuple[Dict[str, int], Dict[str, int]]:
        """Creates dictionaries to map string names (e.g., 'qaoa') to integer indices for normalization"""
        all_algs = sorted(list(set(t[0] for t in self.all_triplets)))
        all_backends = sorted(list(set(t[2] for t in self.all_triplets)))
        alg_map = {name: i for i, name in enumerate(all_algs)}
        backend_map = {name: i for i, name in enumerate(all_backends)}
        return alg_map, backend_map

    def _categorize_by_difficulty(self, difficulty_threshold=5000):
        """
        ENHANCEMENT #1 - Difficulty-Based Sampling

        Splits self.all_triplets into two pools according to the Oracle's answer:
        - self.hard_triplets: optimal shots  > difficulty_threshold
        - self.easy_triplets: optimal shots <= difficulty_threshold

        WHY:
        - reset() can then oversample the hard pool, giving the agent more practice on the
          problems that need many shots (a form of "hard example mining")

        Prints a short summary of the split
        """
        easy = self.easy_triplets = []
        hard = self.hard_triplets = []

        for problem in self.all_triplets:
            needed_shots = find_optimal_shots(*problem)
            pool = hard if needed_shots > difficulty_threshold else easy
            pool.append(problem)

        print(f"Categorized {len(self.hard_triplets)} hard problems (>{difficulty_threshold} shots)")
        print(f"Categorized {len(self.easy_triplets)} easy problems (<={difficulty_threshold} shots)")
        print(f"Hard problem bias set to {self.hard_problem_bias:.1%}")

    def reset(self) -> np.ndarray:
        """
        Start a new episode on a freshly sampled problem and return its initial state

        Sampling is biased: with probability self.hard_problem_bias (and only if a hard pool
        exists) the problem is drawn from self.hard_triplets; otherwise it is drawn from
        self.all_triplets, which keeps easy problems in the mix
        """
        use_hard_pool = random.random() < self.hard_problem_bias and len(self.hard_triplets) > 0
        candidates = self.hard_triplets if use_hard_pool else self.all_triplets
        self.current_triplet = random.choice(candidates)

        # Reference answer for this problem (served from the Oracle cache once it is warm)
        self.optimal_shots = find_optimal_shots(*self.current_triplet)

        # Fresh episode: no shots spent, no batches observed
        self.current_shots = 0
        self.outcome_history = []

        return self._get_state()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """
        Advance the episode by one agent decision

        action == 1 (STOP):  the episode ends and the result of self._terminate() is returned
        otherwise (CONTINUE): self.step_size shots are added to self.current_shots and one new
                              batch of outcomes is requested for the enhanced state features

        Returns (state, reward, done, info). A CONTINUE step that does not exhaust the budget
        yields reward -0.02, done=False and an empty info dict
        """
        if action == 1:  # Agent chose to STOP
            state, reward, done, info = self._terminate()
            return state, reward, done, info

        # Agent chose to CONTINUE
        self.current_shots += self.step_size

        alg, size, backend = self.current_triplet
        try:
            batch_outcomes = get_outcomes(alg, size, backend, shots=self.step_size, strategy='random', exact=True)
        except Exception:
            # Best-effort: if sampling fails the shots are still counted, only the history is not
            # updated. `Exception` (not a bare except) so Ctrl-C / SystemExit are not swallowed
            pass
        else:
            self.outcome_history.append(batch_outcomes)

        if self.current_shots >= self.max_shots:
            # Shot budget exhausted: the episode ends automatically
            state, reward, done, info = self._terminate()
            return state, reward, done, info

        # Small per-step penalty to encourage the agent to finish episodes efficiently
        return self._get_state(), -0.02, False, {}

    def _terminate(self) -> Tuple[np.ndarray, float, bool, Dict]:
        """
        Ends the episode and calculates the final reward based on performance
        This is where the agent receives its main learning signal
        """
        # Calculate the error: positive for overshooting, negative for undershooting
        error = self.current_shots - self.optimal_shots

        # --- Asymmetric Reward Shaping Logic ---
        # The design of the reward is crucial for guiding the agent's behavior
        if error < 0:
            # SEVERE PENALTY FOR UNDERSHOOTING: Stopping too early is a critical failure because the result is not statistically stable
            # The penalty is proportional to how badly it undershot, teaching the agent that this is the worst possible outcome
            final_reward = -1.0 - (abs(error) / self.optimal_shots)
        else:
            # DECAYING REWARD FOR OVERSHOOTING: This is considered a success, but it becomes less successful as more shots are wasted
            # The reward is +1.0 for a perfect stop (error=0) and decays exponentially, encouraging the agent to be precise
            final_reward = np.exp(-0.0005 * error)

        state = self._get_state()
        done = True
        info = {
            'shots_used': self.current_shots,
            'optimal_shots': self.optimal_shots,
            'error': error,
            'triplet': self.current_triplet,
            'final_reward': final_reward
        }
        return state, final_reward, done, info

    def _compute_distribution_entropy(self, outcomes: Dict[str, int]) -> float:
        """
        ================================================================================
        SHANNON ENTROPY COMPUTATION (Feature 5)
        ================================================================================
        
        Computes the Shannon entropy of a probability distribution derived from
        measurement outcomes.
        
        MATHEMATICAL FORMULA:
            H(X) = -Σ p(x) * log₂(p(x))  for all outcomes x
        
        WHERE:
            - p(x) = count(x) / total_counts  (probability of outcome x)
            - The sum is taken over all unique measurement outcomes
        
        COMPUTATION STEPS:
            1. Normalize raw counts into probabilities
               Input:  {'00': 55, '01': 22, '10': 13, '11': 10}  (total=100)
               Output: {'00': 0.55, '01': 0.22, '10': 0.13, '11': 0.10}
            
            2. Apply Shannon entropy formula for each outcome:
               For '00': -0.55 * log₂(0.55) ≈ 0.528
               For '01': -0.22 * log₂(0.22) ≈ 0.474
               For '10': -0.13 * log₂(0.13) ≈ 0.382
               For '11': -0.10 * log₂(0.10) ≈ 0.332
               Total: H ≈ 1.716 bits
            
            3. Normalize by maximum possible entropy:
               - For n qubits, max entropy = log₂(2^n) = n bits
               - We use a cap of 4 bits (assumes ≤4 qubit circuits)
               - Normalized value: 1.716 / 4.0 ≈ 0.429
        
        INTERPRETATION:
            - RETURN VALUE in [0, 1]
            - 0.0 = Deterministic (one outcome has probability 1)
            - 1.0 = Maximum entropy (perfectly uniform distribution)
            - 0.5 = Moderate spread
        
        WHY THIS HELPS THE AGENT:
            - Entropy captures the "shape" of the distribution
            - High entropy -> uniform superposition -> might need more shots for precision
            - Low entropy -> peaked distribution -> might converge faster
            - Different algorithms produce different entropy signatures:
              * Grover's algorithm: low entropy (searching for specific state)
              * Random circuit sampling: high entropy (uniform superposition)
        """
        total = sum(outcomes.values())
        if total == 0:
            return 0.0
        
        entropy = 0.0
        for count in outcomes.values():
            if count > 0:
                p = count / total  # Compute probability p(x)
                entropy -= p * np.log2(p)  # Accumulate: -p*log₂(p)
        
        # Normalize by theoretical maximum entropy for a 2-qubit system (adjust if needed)
        # For n qubits, max entropy is log2(2^n) = n
        # We'll assume a reasonable upper bound of 4 qubits = 4 bits of entropy
        max_entropy = 4.0
        return min(entropy / max_entropy, 1.0)  # Clip to [0, 1]

    def _compute_distribution_variance(self, outcomes: Dict[str, int]) -> float:
        """
        ================================================================================
        DISTRIBUTION VARIANCE COMPUTATION (Feature 6)
        ================================================================================
        
        Computes the variance of the probability distribution as a measure of spread.
        
        MATHEMATICAL FORMULA:
            Var(P) = (1/n) * Σ(pᵢ - μ)²
        
        WHERE:
            - pᵢ = probability of outcome i
            - μ = mean probability = (1/n) for uniform distribution
            - n = number of unique outcomes
        
        COMPUTATION STEPS:
            1. Convert counts to probabilities:
               Input:  {'00': 90, '01': 5, '10': 3, '11': 2}  (total=100)
               Output: [0.90, 0.05, 0.03, 0.02]
            
            2. Calculate mean probability:
               mean_p = (0.90 + 0.05 + 0.03 + 0.02) / 4 = 0.25
            
            3. Compute variance:
               Var = [(0.90-0.25)² + (0.05-0.25)² + (0.03-0.25)² + (0.02-0.25)²] / 4
                   = [0.4225 + 0.04 + 0.0484 + 0.0529] / 4
                   ≈ 0.141
            
            4. Scale and normalize to [0, 1]:
               scaled_variance = min(0.141 * 10.0, 1.0) ≈ 1.0
        
        INTERPRETATION:
            - RETURN VALUE in [0, 1]
            - 0.0 = Uniform distribution (all probabilities equal)
            - 1.0 = Highly peaked (one dominant outcome)
            - Intermediate values = Partially peaked distribution
        
        WHY THIS COMPLEMENTS ENTROPY:
            - Entropy measures "uniformity" globally
            - Variance measures "peakedness" (distance from mean)
            - Example where they differ:
              * {'00': 0.7, '01': 0.1, '10': 0.1, '11': 0.1}
              * Entropy: moderate (some spread)
              * Variance: high (one outcome dominates)
            - Together, they give a richer picture of distribution shape
        
        NORMALIZATION NOTE:
            - The scaling factor of 10.0 is empirical
            - Maximum theoretical variance for k outcomes ≈ 1/k
            - This scaling brings typical variances into the [0, 1] range
        """
        if not outcomes:
            return 0.0
        
        total = sum(outcomes.values())
        if total == 0:
            return 0.0
        
        # Step 1: Convert to probabilities
        probs = [count / total for count in outcomes.values()]
        
        # Step 2 & 3: Calculate mean and variance
        mean_prob = np.mean(probs)
        variance = np.var(probs)  # NumPy computes: (1/n)Σ(pᵢ - μ)²
        
        # Step 4: Normalize
        # Max variance occurs when one probability is 1, others are 0
        # For a distribution with k outcomes, max variance ≈ 1/k (when perfectly unbalanced)
        # We use a conservative scaling factor to map typical values to [0, 1]
        return min(variance * 10.0, 1.0)

    def _compute_rate_of_change(self) -> float:
        """
        ================================================================================
        RATE OF CHANGE COMPUTATION (Feature 7)
        ================================================================================
        
        Estimates how rapidly the probability distribution is evolving by measuring
        the Total Variation Distance (TVD) between recent and previous distributions.
        
        MATHEMATICAL FORMULA:
            TVD(P, Q) = 0.5 * Σ|P(x) - Q(x)|  for all outcomes x
        
        WHERE:
            - P = probability distribution from all previous batches (cumulative)
            - Q = probability distribution from the most recent batch
            - The sum is over all possible measurement outcomes
        
        COMPUTATION STEPS:
            1. Separate outcome history into two parts:
               - Previous: All batches except the last
                 Example: [{'00': 25, '01': 15, '10': 7, '11': 3},
                          {'00': 28, '01': 12, '10': 6, '11': 4}]
                 Aggregate: {'00': 53, '01': 27, '10': 13, '11': 7} (total=100)
               
               - Recent: Only the last batch
                 Example: {'00': 30, '01': 10, '10': 5, '11': 5} (total=50)
            
            2. Normalize both to probability distributions:
               - prev_dist = {'00': 0.53, '01': 0.27, '10': 0.13, '11': 0.07}
               - recent_dist = {'00': 0.60, '01': 0.20, '10': 0.10, '11': 0.10}
            
            3. Calculate TVD:
               TVD = 0.5 * (|0.53-0.60| + |0.27-0.20| + |0.13-0.10| + |0.07-0.10|)
                   = 0.5 * (0.07 + 0.07 + 0.03 + 0.03)
                   = 0.5 * 0.20
                   = 0.10
        
        INTERPRETATION:
            - RETURN VALUE in [0, 1]
            - 1.0 = Maximum change (distributions are completely different)
            - 0.0 = No change (distributions are identical -> convergence!)
            - 0.10 = Small change (10% difference, approaching stability)
        
        WHY THIS IS CRITICAL FOR THE AGENT:
            - This is a "convergence velocity" metric
            - EARLY IN EXECUTION:
              * Distribution changes significantly with each new batch
              * TVD might be 0.15-0.30 (15-30% change)
              * Signal to agent: "Keep going, not stable yet"
            
            - NEAR CONVERGENCE:
              * Distribution barely changes with new batches
              * TVD drops below 0.01 (1% change)
              * Signal to agent: "Safe to stop, distribution is stable"
            
            - This is exactly what the Incremental Execution algorithm monitors
            - By learning this feature, the agent can anticipate convergence
        
        SPECIAL CASES:
            - If len(outcome_history) < 2: Return 1.0 (maximum uncertainty)
            - If total counts are 0: Return 1.0 (no data to compare)
        
        RELATIONSHIP TO ORACLE:
            - The Oracle stops when TVD < threshold for k consecutive iterations
            - This feature gives the agent direct access to the TVD signal
            - But the agent must learn the appropriate threshold through RL
        """
        if len(self.outcome_history) < 2:
            # Not enough data to compute rate of change
            return 1.0  # Maximum rate of change (unknown state)
        
        # Step 1: Aggregate all previous batches (excluding the most recent)
        prev_outcomes = Counter()
        for outcomes in self.outcome_history[:-1]:
            prev_outcomes.update(outcomes)
        
        # Get the most recent batch
        recent_outcomes = self.outcome_history[-1]
        
        # Step 2: Normalize both to probability distributions
        prev_total = sum(prev_outcomes.values())
        recent_total = sum(recent_outcomes.values())
        
        if prev_total == 0 or recent_total == 0:
            return 1.0  # Cannot compute meaningful distance
        
        prev_dist = {k: v/prev_total for k, v in prev_outcomes.items()}
        recent_dist = {k: v/recent_total for k, v in recent_outcomes.items()}
        
        # Step 3: Compute Total Variation Distance (TVD)
        # Must consider all outcomes that appear in either distribution
        all_keys = set(prev_dist.keys()) | set(recent_dist.keys())
        tvd = 0.5 * sum(abs(prev_dist.get(k, 0) - recent_dist.get(k, 0)) for k in all_keys)
        
        return tvd  # Already in [0, 1] range

    def _get_state(self) -> np.ndarray:
        """
        ================================================================================
        FEATURE VECTOR CONSTRUCTION - 8-Dimensional State Representation
        ================================================================================
        
        This function constructs the state observation that the RL agent uses to make
        decisions. The state is an 8-dimensional vector where each feature is normalized
        to the range [0, 1] to facilitate neural network training.
        
        ORIGINAL 4 FEATURES:
        1. Algorithm (normalized index)
        2. Problem size (normalized)
        3. Backend (normalized index)
        4. Current shots (normalized)
        
        NEW 4 ADDITIONAL FEATURES:
        5. Distribution entropy (measures spread of outcomes)
        6. Distribution variance (alternative spread metric)
        7. Rate of change (how fast distribution is evolving)
        8. Relative progress (current_shots / typical_convergence_point)
        """
        alg, size, backend = self.current_triplet

        # ============================================================================
        # FEATURE 1: ALGORITHM TYPE (Normalized Categorical Encoding)
        # ============================================================================
        # HOW IT'S COMPUTED:
        #   - Each algorithm (e.g., 'grover', 'qaoa', 'vqe') is assigned a unique integer
        #     index during initialization via self.alg_map (e.g., {'grover': 0, 'qaoa': 1})
        #   - This index is retrieved: self.alg_map.get(alg, 0)
        #   - Normalized to [0, 1] by dividing by (total_algorithms - 1)
        #     Example: If there are 3 algorithms, indices 0,1,2 become 0.0, 0.5, 1.0
        # PURPOSE:
        #   - Different quantum algorithms have fundamentally different convergence behaviors
        #   - This feature allows the agent to learn algorithm-specific stopping strategies
        # ============================================================================
        alg_norm = self.alg_map.get(alg, 0) / (len(self.alg_map) - 1) if len(self.alg_map) > 1 else 0.5
        
        # ============================================================================
        # FEATURE 2: PROBLEM SIZE (Normalized Continuous Value)
        # ============================================================================
        # HOW IT'S COMPUTED:
        #   - Directly from the problem triplet: (algorithm, SIZE, backend)
        #   - The size represents the number of qubits in the quantum circuit
        #   - Normalized by dividing by an assumed maximum size of 15 qubits
        #     (based on typical values in the qsimbench dataset)
        # PURPOSE:
        #   - Larger circuits generally require more shots to converge due to:
        #     * Exponentially larger state space (2^n possible outcomes)
        #     * More complex probability distributions
        #   - This feature helps the agent adjust its patience based on problem scale
        # NORMALIZATION EXAMPLE:
        #   - 3-qubit circuit: 3/15 = 0.2
        #   - 10-qubit circuit: 10/15 ≈ 0.67
        # ============================================================================
        size_norm = size / 15.0
        
        # ============================================================================
        # FEATURE 3: BACKEND TYPE (Normalized Categorical Encoding)
        # ============================================================================
        # HOW IT'S COMPUTED:
        #   - Each quantum backend (e.g., 'ibmq_lima', 'ionq_harmony') is assigned
        #     a unique integer index via self.backend_map during initialization
        #   - This index is retrieved: self.backend_map.get(backend, 0)
        #   - Normalized to [0, 1] by dividing by (total_backends - 1)
        # PURPOSE:
        #   - Different hardware backends have distinct noise profiles:
        #     * Superconducting qubits (IBM) vs. trapped ions (IonQ)
        #     * Different gate fidelities, coherence times, and error rates
        #   - These noise characteristics directly affect convergence speed
        #   - This feature enables backend-specific learned policies
        # ============================================================================
        backend_norm = self.backend_map.get(backend, 0) / (len(self.backend_map) - 1) if len(self.backend_map) > 1 else 0.5
        
        # ============================================================================
        # FEATURE 4: CURRENT SHOT COUNT (Normalized Progress Indicator)
        # ============================================================================
        # HOW IT'S COMPUTED:
        #   - Directly from self.current_shots (accumulated across all batches)
        #   - Normalized by dividing by self.max_shots (typically 20,000)
        # PURPOSE:
        #   - Critical feature for stopping decision: "How many resources have I used?"
        #   - Provides a hard constraint awareness: must stop before reaching max_shots
        #   - Helps agent learn when it's approaching the budget limit
        # NORMALIZATION EXAMPLE:
        #   - 1000 shots used: 1000/20000 = 0.05 (5% of budget)
        #   - 15000 shots used: 15000/20000 = 0.75 (75% of budget, time to consider stopping)
        # ============================================================================
        shots_norm = self.current_shots / self.max_shots

        # ============================================================================
        # ENHANCED FEATURES (5-8): Statistical Properties of the Measured Distribution
        # ============================================================================
        # These features are computed from self.outcome_history, which stores the
        # measurement results from each batch of shots executed so far
        
        if self.outcome_history:
            # Aggregate all outcomes across all batches to get cumulative distribution
            cumulative_outcomes = Counter()
            for batch in self.outcome_history:
                cumulative_outcomes.update(batch)
            # Example result: {'00': 55, '01': 22, '10': 13, '11': 10}
            
            # ========================================================================
            # FEATURE 5: SHANNON ENTROPY (Distribution Shape Metric)
            # ========================================================================
            # HOW IT'S COMPUTED:
            #   1. Convert counts to probabilities: p_i = count_i / total_counts
            #      Example: {'00': 0.55, '01': 0.22, '10': 0.13, '11': 0.10}
            #   2. Apply Shannon entropy formula: H = -Σ(p_i * log₂(p_i))
            #   3. Normalize by max entropy (log₂(2^n) = n qubits), capped at 4
            # PURPOSE:
            #   - Quantifies the "uniformity" of the probability distribution
            #   - HIGH ENTROPY (->1.0): Outcomes are evenly distributed (uniform)
            #     Example: {'00': 0.25, '01': 0.25, '10': 0.25, '11': 0.25} -> H≈2.0 (for 2 qubits)
            #   - LOW ENTROPY (->0.0): Outcomes are concentrated in few states (peaked)
            #     Example: {'00': 0.95, '01': 0.02, '10': 0.02, '11': 0.01} -> H≈0.3
            #   - Helps distinguish between algorithms that produce uniform superpositions
            #     (e.g., Hadamard-heavy circuits) vs. peaked distributions (e.g., optimized QAOA)
            # ========================================================================
            entropy = self._compute_distribution_entropy(cumulative_outcomes)
            
            # ========================================================================
            # FEATURE 6: DISTRIBUTION VARIANCE (Alternative Spread Metric)
            # ========================================================================
            # HOW IT'S COMPUTED:
            #   1. Convert counts to probabilities: [p₁, p₂, ..., pₙ]
            #   2. Calculate variance: Var = (1/n)Σ(pᵢ - mean_p)²
            #   3. Apply scaling factor and clip to [0, 1]
            # PURPOSE:
            #   - Provides a complementary measure to entropy
            #   - Captures how "spread out" probabilities are from their mean
            #   - HIGH VARIANCE: Very uneven distribution (some outcomes dominant)
            #   - LOW VARIANCE: More uniform distribution
            #   - Together with entropy, gives agent richer information about distribution shape
            # EXAMPLE:
            #   - Peaked distribution {'0000': 0.9, others: ~0.0067 each} -> high variance
            #   - Flat distribution {'00': 0.25, '01': 0.25, ...} -> low variance
            # ========================================================================
            variance = self._compute_distribution_variance(cumulative_outcomes)
        else:
            # No measurements yet (first step), use neutral default values
            entropy = 0.5
            variance = 0.5
        
        # ========================================================================
        # FEATURE 7: RATE OF CHANGE (Convergence Velocity Indicator)
        # ========================================================================
        # HOW IT'S COMPUTED:
        #   1. Take the last batch of outcomes from self.outcome_history[-1]
        #   2. Take all previous batches and aggregate them
        #   3. Normalize both into probability distributions
        #   4. Calculate Total Variation Distance (TVD) between them:
        #      TVD = 0.5 * Σ|p_prev(x) - p_recent(x)| for all outcomes x
        # PURPOSE:
        #   - Measures how rapidly the distribution is changing
        #   - HIGH RATE (1.0): Distribution still evolving significantly -> Agent should probably continue collecting more shots
        #   - LOW RATE (0.0): Distribution has stabilized -> Agent can consider stopping (convergence achieved)
        #   - This is a "forward-looking" metric that anticipates convergence
        # INTUITION:
        #   - If adding 50 more shots changes the distribution dramatically -> keep going
        #   - If adding 50 more shots barely changes anything -> time to stop
        # EXAMPLE:
        #   Early in execution: TVD = 0.15 (15% change per batch)
        #   Near convergence: TVD = 0.005 (0.5% change per batch)
        # ========================================================================
        rate_of_change = self._compute_rate_of_change()
        
        # ========================================================================
        # FEATURE 8: RELATIVE PROGRESS (Problem-Size-Aware Progress Metric)
        # ========================================================================
        # HOW IT'S COMPUTED:
        #   1. Estimate typical shots needed based on problem size:
        #      typical_shots = size * 500
        #      (Heuristic: larger circuits generally need more shots)
        #   2. Compute ratio: current_shots / typical_shots
        #   3. Clip to maximum of 1.0
        # PURPOSE:
        #   - Provides context-aware progress information
        #   - Helps agent answer: "Am I close to done, given this problem's size?"
        #   - Different from Feature 4 (shots_norm) which is absolute budget usage
        # WHY THIS HELPS:
        #   - A 3-qubit circuit at 1000 shots: 1000/(3*500) = 0.67 -> likely near done
        #   - A 12-qubit circuit at 1000 shots: 1000/(12*500) = 0.17 -> just getting started
        #   - Same absolute shot count, different interpretation based on problem scale
        # HEURISTIC JUSTIFICATION:
        #   - Larger state space (2^n) requires more sampling for statistical stability
        #   - This feature encodes this intuition as a learned signal
        # ========================================================================
        typical_shots = size * 500
        relative_progress = min(self.current_shots / typical_shots, 1.0)

        # ============================================================================
        # FINAL STATE VECTOR ASSEMBLY
        # ============================================================================
        # All 8 features are combined into a single numpy array with dtype float32
        # This vector is fed directly into the neural network as input
        state = np.array([
            alg_norm,           # [0] Algorithm type (categorical, normalized)
            size_norm,          # [1] Problem size in qubits (continuous, normalized)
            backend_norm,       # [2] Hardware backend (categorical, normalized)
            shots_norm,         # [3] Absolute shot budget usage (continuous, normalized)
            entropy,            # [4] Distribution entropy (statistical, normalized)
            variance,           # [5] Distribution variance (statistical, normalized)
            rate_of_change,     # [6] Convergence velocity (statistical, normalized)
            relative_progress   # [7] Size-aware progress (heuristic, normalized)
        ], dtype=np.float32)
        
        return state

# --- RL Agent and Network ---

class DQN(nn.Module):
    """
    Fully connected Q-network mapping a state vector to one Q-value per action.

    Layout: three hidden Linear+ReLU stages of width 512, 256 and 128, with
    light dropout (p=0.1) after the first two stages, followed by a linear
    output layer of size ``output_size``.
    """
    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        # (hidden width, dropout probability or None) for each hidden stage
        hidden_spec = ((512, 0.1), (256, 0.1), (128, None))

        stack = []
        fan_in = input_size
        for width, drop_p in hidden_spec:
            stack.append(nn.Linear(fan_in, width))
            stack.append(nn.ReLU())
            if drop_p is not None:
                stack.append(nn.Dropout(drop_p))  # Light regularization
            fan_in = width
        # Output head: one Q-value per action
        stack.append(nn.Linear(fan_in, output_size))

        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Q-values produced by the layer stack for input ``x``."""
        q_values = self.net(x)
        return q_values

class Agent:
    """
    The RL agent that encapsulates the DQN model and all the logic for acting, remembering, and learning.

    Holds two DQN instances (online ``q_net`` and ``target_net``), a bounded
    replay memory, an Adam optimizer over ``q_net`` and the epsilon-greedy
    exploration schedule.

    Note on dropout: DQN contains ``nn.Dropout`` layers, which are only
    disabled in eval mode. Inference paths (greedy action selection and the
    bootstrap targets) therefore run in eval mode so that they are
    deterministic; only the gradient step in ``replay`` uses dropout.
    """
    def __init__(self, state_size: int, action_size: int, learning_rate: float = 1e-4, gamma: float = 0.99):
        """
        Args:
            state_size: Length of the state vector fed to the networks.
            action_size: Number of discrete actions (network output size).
            learning_rate: Adam learning rate for the online network.
            gamma: Discount factor used in the Bellman target.
        """
        # The main network that the agent uses to make decisions. Its weights are constantly updated.
        self.q_net = DQN(state_size, action_size)
        # The target network is a copy of the q_net. It is held constant for a period to provide
        # a stable target during training, preventing oscillations
        self.target_net = DQN(state_size, action_size)
        self.update_target()

        # Experience Replay Memory: stores past (state, action, reward, next_state, done) tuples
        self.memory = deque(maxlen=200000)
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=learning_rate)

        # --- Hyperparameters ---
        # Gamma (discount factor): determines how much the agent values future rewards
        # A value closer to 1 makes the agent more "farsighted"
        self.gamma = gamma
        # Epsilon: The exploration rate for the epsilon-greedy policy
        # It starts at 1.0 (100% random actions) and decays over time to a minimum value.
        self.epsilon = 1.0
        self.epsilon_decay = 0.9998
        self.epsilon_min = 0.05

        self.action_size = action_size

    def act(self, state: np.ndarray) -> int:
        """
        Chooses an action based on the current state using an epsilon-greedy policy.

        Args:
            state: Current state vector.

        Returns:
            Index of the chosen action in ``[0, action_size)``.
        """
        # Exploration: with probability epsilon, choose a random action.
        if random.random() < self.epsilon:
            return random.randint(0, self.action_size - 1)

        # Exploitation: use the online network to pick the action with the highest Q-value
        state_tensor = torch.FloatTensor(state).unsqueeze(0)  # Convert to tensor and add batch dimension

        # Run inference in eval mode so dropout does not randomize the greedy choice,
        # then restore whatever mode the network was in (normally training mode).
        was_training = self.q_net.training
        self.q_net.eval()
        try:
            with torch.no_grad():  # No gradients needed for action selection
                q_values = self.q_net(state_tensor)
        finally:
            self.q_net.train(was_training)
        return torch.argmax(q_values).item()  # Return the action with the highest Q-value

    def remember(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool):
        """Stores an experience tuple in the replay memory"""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self, batch_size: int = 64):
        """
        *** THE CORE LEARNING STEP ***
        Samples a batch of past experiences from memory and trains the network using the Bellman equation

        Does nothing until the memory holds at least ``batch_size`` experiences.
        Each call that trains also applies one step of epsilon decay.
        """
        if len(self.memory) < batch_size:
            return  # Not enough experiences yet to train

        # Randomly sample a minibatch of experiences. This breaks correlations in the data.
        minibatch = random.sample(self.memory, batch_size)

        # Separate the batch into its components
        states, actions, rewards, next_states, dones = zip(*minibatch)

        # Convert to tensors for PyTorch
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(dones)

        # --- Compute Q-values for the current state using the main network ---
        # Select, per sample, the Q-value of the action that was actually taken
        current_q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # --- Compute target Q-values using the Bellman equation and the target network ---
        with torch.no_grad():
            # Target network (kept in eval mode by update_target) gives a stable,
            # dropout-free Q-value estimate for the next state
            next_q_values = self.target_net(next_states).max(1)[0]
            # Bellman equation: Q_target = reward + gamma * max_a' Q(s', a')
            # If the episode is done (terminal state), the future value is 0
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # --- Calculate loss and update weights ---
        # Mean Squared Error between predicted Q and target Q
        loss = F.mse_loss(current_q_values, target_q_values)

        # Standard PyTorch training step: zero gradients, backpropagate, update weights
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay epsilon (reduce exploration over time), never dropping below epsilon_min.
        # The guard leaves an externally forced value (e.g. 0.0 for evaluation) untouched.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def update_target(self):
        """Copies the weights from the main network to the target network"""
        self.target_net.load_state_dict(self.q_net.state_dict())
        # The target network is only ever used for inference: eval mode disables its dropout
        self.target_net.eval()

# --- Training Loop ---

def train_agent(num_episodes: int = 15000, update_target_every: int = 500) -> Tuple:
    """
    The main training loop for the RL agent

    Builds the environment and agent, runs ``num_episodes`` episodes of
    act -> step -> remember -> replay, and syncs the target network every
    ``update_target_every`` episodes (also printing a progress line).

    The oracle cache is written back in a ``finally`` clause, so whatever is
    in it is persisted even if training is interrupted or raises part-way.

    Args:
        num_episodes: Number of training episodes to run.
        update_target_every: Episode interval between target-network syncs.

    Returns:
        A tuple ``(agent, env, rewards_history)`` where ``rewards_history``
        holds the total reward of each completed episode.
    """
    load_oracle_cache()  # Load any previously computed oracle results to save time

    env = IterativeQuantumEnv(
        max_shots=20000,
        step_size=50,
        hard_problem_bias=0.7  # Now 70% of training samples will be hard problems
    )

    # Initialize the agent with the correct input size (now 8 instead of 4)
    agent = Agent(state_size=8, action_size=2, learning_rate=1e-4, gamma=0.99)

    rewards_history = []

    try:
        for episode in tqdm(range(num_episodes), desc="Training Episodes"):
            state = env.reset()
            total_reward = 0
            done = False

            # Run one episode (one complete problem)
            while not done:
                action = agent.act(state)
                next_state, reward, done, info = env.step(action)
                agent.remember(state, action, reward, next_state, done)
                agent.replay(batch_size=64)
                state = next_state
                total_reward += reward

            rewards_history.append(total_reward)

            # Periodically update the target network
            if episode % update_target_every == 0 and episode > 0:
                agent.update_target()
                print(f"\nEpisode {episode}: Avg Reward (last 100) = {np.mean(rewards_history[-100:]):.2f}, Epsilon = {agent.epsilon:.3f}")
    finally:
        # Save even on KeyboardInterrupt / errors so the oracle cache is not thrown away
        save_oracle_cache()  # Save the oracle cache for future runs

    return agent, env, rewards_history

def evaluate_agent_sequential(agent: Agent, env: IterativeQuantumEnv, num_tests: int = 200) -> List[Dict]:
    """
    Evaluates the final trained agent's performance with exploration turned off

    Runs ``num_tests`` full episodes with a purely greedy policy and collects
    the ``info`` dict returned by the last ``env.step`` of each episode.

    Side effects:
        - ``agent.epsilon`` is set to 0.0 and stays 0.0 after the call.
        - ``agent.q_net`` is put in eval mode for the duration of the
          evaluation (so its dropout layers do not randomize the greedy
          decisions) and restored to its previous mode afterwards.

    Args:
        agent: The trained agent to evaluate.
        env: Environment to run the evaluation episodes in.
        num_tests: Number of evaluation episodes.

    Returns:
        One ``info`` dict per evaluation episode.
    """
    agent.epsilon = 0.0  # Set to a greedy policy (always exploit)

    # Disable dropout while measuring the policy; remember the mode to restore it later
    was_training = agent.q_net.training
    agent.q_net.eval()

    results = []
    try:
        for _ in tqdm(range(num_tests), desc="Evaluation"):
            state = env.reset()
            done = False
            info = {}
            while not done:
                action = agent.act(state)
                state, _, done, info = env.step(action)
            results.append(info)
    finally:
        agent.q_net.train(was_training)
    return results


# --- Analysis and Plotting Dashboard ---

class AnalysisDashboard:
    """
    A class to handle the visualization of the training and evaluation results

    Each evaluation result is a dict that must provide the keys
    ``'shots_used'``, ``'optimal_shots'``, ``'error'`` and ``'triplet'``.
    Problems are split into "easy" and "hard" by comparing their optimal shot
    count against ``difficulty_threshold``.

    An empty ``evaluation_results`` list is supported: evaluation panels show a
    "no data" message and the text summary reports that nothing was evaluated.
    """
    def __init__(self, training_rewards: List[float], evaluation_results: List[Dict],
                 difficulty_threshold: int = 5000):
        """
        Args:
            training_rewards: Total reward of each training episode.
            evaluation_results: One info dict per evaluation episode.
            difficulty_threshold: Optimal-shot count above which a problem is
                counted as "hard" (``<=`` threshold is "easy").
        """
        self.rewards = training_rewards
        self.results = evaluation_results
        self.difficulty_threshold = difficulty_threshold
        self.shots_used = [r['shots_used'] for r in self.results]
        self.optimal_shots = [r['optimal_shots'] for r in self.results]
        self.errors = np.array([r['error'] for r in self.results])
        self.triplets = [r['triplet'] for r in self.results]

    # --- Small shared helpers ---

    @staticmethod
    def _mean_abs(values) -> float:
        """Mean absolute value of ``values``; NaN (without warnings) when empty."""
        if len(values) == 0:
            return float('nan')
        return float(np.mean(np.abs(values)))

    def _difficulty_masks(self):
        """Return boolean (easy_mask, hard_mask) arrays aligned with self.errors."""
        optimal = np.array(self.optimal_shots)
        return optimal <= self.difficulty_threshold, optimal > self.difficulty_threshold

    @staticmethod
    def _plot_no_data(ax: plt.Axes, title: str):
        """Fill a panel with a centered placeholder when there is nothing to plot."""
        ax.set_title(title, fontsize=14)
        ax.text(0.5, 0.5, "No evaluation results", ha='center', va='center', transform=ax.transAxes)

    # --- Plots ---

    def plot_dashboard(self):
        """Generates and displays a 2x2 dashboard of performance plots"""
        fig, axes = plt.subplots(2, 2, figsize=(18, 14))
        fig.suptitle("Enhanced RL Agent Performance Analysis (8-Feature State + Hard Problem Bias)", fontsize=20, y=0.98)

        self._plot_training_rewards(axes[0, 0])
        self._plot_performance_scatter(axes[0, 1])
        self._plot_error_distribution(axes[1, 0])
        self._plot_error_by_difficulty(axes[1, 1])

        plt.tight_layout(rect=[0, 0, 1, 0.95])
        plt.show()

    def _plot_training_rewards(self, ax: plt.Axes):
        """
        PURPOSE: To visualize the agent's learning progress over time
        """
        ax.plot(self.rewards, label='Total Reward per Episode', alpha=0.2, color='C0')
        # The moving average is only drawn once a full 200-episode window exists
        if len(self.rewards) >= 200:
            moving_avg = np.convolve(self.rewards, np.ones(200)/200, mode='valid')
            ax.plot(np.arange(199, len(self.rewards)), moving_avg, color='C1', label='200-Episode Moving Average')
        ax.axhline(0, color='k', linestyle='--', alpha=0.5)
        ax.set_title("1. Training Rewards Over Time", fontsize=14)
        ax.set_xlabel("Episode")
        ax.set_ylabel("Total Episode Reward")
        ax.legend()
        ax.grid(True, linestyle='--')

    def _plot_performance_scatter(self, ax: plt.Axes):
        """
        PURPOSE: To directly compare the agent's final policy against the Oracle's perfect policy
        """
        if not self.results:
            # max() below would raise on empty sequences
            self._plot_no_data(ax, "2. Agent Policy vs. Oracle")
            return

        mae = self._mean_abs(self.errors)
        ax.scatter(self.optimal_shots, self.shots_used, alpha=0.6, edgecolors='k', s=50)
        # Diagonal y=x spanning the largest value on either axis
        perfect_range = [0, max(max(self.optimal_shots), max(self.shots_used))]
        ax.plot(perfect_range, perfect_range, 'r--', linewidth=2, label='Perfect Policy (y=x)')
        ax.set_title(f"2. Agent Policy vs. Oracle (MAE: {mae:.1f} shots)", fontsize=14)
        ax.set_xlabel("Optimal Shots (Determined by Oracle)")
        ax.set_ylabel("Shots Used by Agent")
        ax.legend()
        ax.grid(True, linestyle='--')

    def _plot_error_distribution(self, ax: plt.Axes):
        """
        PURPOSE: To visualize the distribution of the agent's stopping errors
        """
        if not self.results:
            self._plot_no_data(ax, "3. Distribution of Stopping Errors")
            return

        ax.hist(self.errors, bins=40, edgecolor='k', alpha=0.7)
        ax.axvline(0, color='r', linestyle='--', linewidth=2, label='Zero Error (Perfect Stop)')
        ax.set_title("3. Distribution of Stopping Errors", fontsize=14)
        ax.set_xlabel("Error (Agent Shots - Optimal Shots)")
        ax.set_ylabel("Frequency of Occurrences")
        ax.legend()
        ax.grid(True, linestyle='--')

    def _plot_error_by_difficulty(self, ax: plt.Axes):
        """
        Shows how the agent performs on easy vs. hard problems
        """
        if not self.results:
            self._plot_no_data(ax, "4. Error Distribution by Problem Difficulty")
            return

        threshold = self.difficulty_threshold
        easy_errors = [e for e, opt in zip(self.errors, self.optimal_shots) if opt <= threshold]
        hard_errors = [e for e, opt in zip(self.errors, self.optimal_shots) if opt > threshold]

        ax.hist([easy_errors, hard_errors], bins=30,
                label=[f'Easy (≤{threshold})', f'Hard (>{threshold})'], alpha=0.7, edgecolor='k')
        ax.axvline(0, color='r', linestyle='--', linewidth=2, label='Perfect')
        ax.set_title("4. Error Distribution by Problem Difficulty", fontsize=14)
        ax.set_xlabel("Error (Agent Shots - Optimal Shots)")
        ax.set_ylabel("Frequency")
        ax.legend()
        ax.grid(True, linestyle='--')

    # --- Text summary ---

    def print_summary_statistics(self):
        """Generates a detailed text summary of the agent's performance"""
        print("\n" + "="*80)
        print("ENHANCED AGENT EVALUATION SUMMARY (8 Features + Hard Problem Bias)")
        print("="*80)

        total = len(self.errors)
        if total == 0:
            # Without this guard every metric below would be NaN (with numpy warnings)
            print("\nNo evaluation results to summarize.")
            print("\n" + "="*80 + "\n")
            return

        # Overall metrics
        mae = self._mean_abs(self.errors)
        rmse = np.sqrt(np.mean(self.errors**2))
        median_error = np.median(self.errors)

        print(f"\nOverall Performance Metrics:")
        print(f"   • Mean Absolute Error (MAE):    {mae:.1f} shots")
        print(f"   • Root Mean Squared Error:      {rmse:.1f} shots")
        print(f"   • Median Error:                 {median_error:.1f} shots")

        # Breakdown by difficulty
        threshold = self.difficulty_threshold
        easy_mask, hard_mask = self._difficulty_masks()

        if np.any(easy_mask):
            easy_mae = self._mean_abs(self.errors[easy_mask])
            print(f"\nEasy Problems (≤{threshold} optimal shots):")
            print(f"   • Count: {np.sum(easy_mask)}")
            print(f"   • MAE:   {easy_mae:.1f} shots")

        if np.any(hard_mask):
            hard_mae = self._mean_abs(self.errors[hard_mask])
            print(f"\nHard Problems (>{threshold} optimal shots):")
            print(f"   • Count: {np.sum(hard_mask)}")
            print(f"   • MAE:   {hard_mae:.1f} shots")

        # Efficiency analysis: sign of the error tells over- vs. under-shooting
        overshots = np.sum(self.errors > 0)
        undershots = np.sum(self.errors < 0)
        perfect = np.sum(self.errors == 0)

        print(f"\nDecision Breakdown:")
        print(f"   • Overshooting (wasteful):      {overshots} ({100*overshots/total:.1f}%)")
        print(f"   • Undershooting (unstable):     {undershots} ({100*undershots/total:.1f}%)")
        print(f"   • Perfect stops:                {perfect} ({100*perfect/total:.1f}%)")

        print("\n" + "="*80 + "\n")

# --- Main Execution ---

if __name__ == "__main__":
    # Startup banner describing this version of the experiment.
    # Emitted as a single write; the text is identical to printing line by line.
    print("\n".join([
        "\n" + "="*80,
        "ENHANCED QUANTUM SHOT ALLOCATION RL AGENT",
        "="*80,
        "\nEnhancements in this version:",
        "  1. Training Set Tuning: 70% of samples are from hard problems (>5000 shots)",
        "  2. Enhanced State Features: 8 features instead of 4",
        "     - Added: entropy, variance, rate of change, relative progress",
        "  3. Deeper Network: More capacity to learn from richer state representation",
        "\n" + "="*80 + "\n",
    ]))

    # Stage 1: train
    agent, env, rewards = train_agent(num_episodes=15000, update_target_every=500)

    # Stage 2: greedy evaluation of the trained agent on fresh episodes
    print("\nEvaluating trained agent...")
    results = evaluate_agent_sequential(agent, env, num_tests=200)

    # Stage 3: text summary first, then the plot dashboard
    dashboard = AnalysisDashboard(rewards, results)
    dashboard.print_summary_statistics()
    dashboard.plot_dashboard()

ModuleNotFoundError: No module named 'github_auth_patch'