# In [None]:
import numpy as np
import os
import pennylane as qml
from typing import List, Tuple
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from pennylane.optimize import AdamOptimizer
import matplotlib.pyplot as plt

class HardwareEfficientDatasetGenerator:
    """
    Generates entangled quantum states using pre-trained hardware-efficient ansatz weights
    from the NTangled_Datasets repository structure.
    """
    
    def __init__(self, base_path: str = "Hardware_Efficient"):
        """
        Initialize generator with path to hardware-efficient ansatz weights.
        
        Args:
            base_path: Path to the 'Hardware Efficient' directory containing qubit folders
        """
        self.base_path = base_path
        self.supported_qubits = [3, 4, 8]
        self.supported_depths_34 = list(range(1, 7))
        self.supported_depths_8 = [5, 6]
        self.supported_ce_34 = [0.05, 0.15, 0.25, 0.35]
        self.supported_ce_8_6 = [0.10, 0.25]
        self.supported_ce_8_5 = [0.15, 0.40, 0.45]

    def validate_params(self, qubits: int, depth: int, goal_ce: float) -> bool:
        # Validate inputs
        if qubits not in self.supported_qubits:
            raise ValueError(f"Unsupported qubit count: {qubits}. Choose from {self.supported_qubits}")
        
        if (qubits == 8):
            if depth not in self.supported_depths_8:
                raise ValueError(f"Unsupported depth: {depth}. Choose from {self.supported_depths_8}")
            if (depth == 6):
                if goal_ce not in self.supported_ce_8_6:
                    raise ValueError(f"Unsupported CE value: {goal_ce}. Choose from {self.supported_ce_8_6}")
            else:
                if goal_ce not in self.supported_ce_8_5:
                    raise ValueError(f"Unsupported CE value: {goal_ce}. Choose from {self.supported_ce_8_5}")
        else:
            if depth not in self.supported_depths_34:
                raise ValueError(f"Unsupported depth: {depth}. Choose from {self.supported_depths_34}")
            if goal_ce not in self.supported_ce_34:
                raise ValueError(f"Unsupported CE value: {goal_ce}. Choose from {self.supported_ce_34}")
        
        return True

    def _construct_filepath(self, qubits: int, depth: int, goal_ce: float) -> str:
        """
        Construct the full file path based on parameters and directory structure.
        
        Args:
            qubits: Number of qubits (3 or 4)
            depth: Circuit depth (1-6)
            goal_ce: Target concentratable entanglement (0.05, 0.15, 0.25, 0.35, 0.5)
            
        Returns:
            Full path to .npy weights file
        """         

        self.validate_params(qubits, depth, goal_ce)

        # Validate goal_ce format
        ce_str = f"{int(goal_ce * 100):02d}"
        
        # Construct file path
        qubit_dir = f"{qubits}_Qubits"
        depth_dir = f"Depth_{depth}"
        filename = f"hwe_{qubits}q_ps_{ce_str}_{depth}_weights.npy"
        
        return os.path.join(self.base_path, qubit_dir, depth_dir, filename)

    def load_weights(self, qubits: int, depth: int, goal_ce: float) -> np.ndarray:
        """
        Load weights from the repository.
        
        Args:
            qubits: Number of qubits
            input_type: Type of input states
            goal_ce: Goal concentratable entanglement (0-1)
            depth: Circuit depth
            
        Returns:
            Weights as numpy array
        """
        file_path = self._construct_filepath(qubits, depth, goal_ce)
        
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Weights file not found at: {file_path}")
            
        return np.load(file_path)

    def hardware_efficient_ansatz(self, params: np.ndarray, wires: List[int]):
        """
        Implements the hardware efficient ansatz.
        
        Args:
            params: The ansatz parameters
            wires: The qubits to apply the ansatz to
            depth: The circuit depth
        """
        n_qubits = len(wires)
        depth = params.shape[0]
        
        # Validate parameter dimensions
        if params.shape != (depth, n_qubits, 3):
            raise ValueError(
                f"Parameter shape mismatch. Expected (depth={depth}, qubits={n_qubits}, 3), "
                f"got {params.shape}"
            )

        for d in range(depth):
            # Get parameters for this layer (shape: [n_qubits, 3])
            layer_params = params[d]
            
            # Apply single-qubit rotations
            for i, wire in enumerate(wires):
                qml.RX(layer_params[i, 0], wires=wire)
                qml.RY(layer_params[i, 1], wires=wire)
                qml.RZ(layer_params[i, 2], wires=wire)
            
            # Apply entangling gates (nearest-neighbor + circular connectivity)
            for i in range(n_qubits - 1):
                qml.CNOT(wires=[wires[i], wires[i+1]])
            if n_qubits > 1:
                qml.CNOT(wires=[wires[-1], wires[0]])


    def generate_states(self, qubits: int, depth: int, goal_ce: float, 
                       num_samples: int = 100) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate entangled quantum states dataset.
        
        Args:
            ansatz_type: Type of ansatz ('hardware_efficient', 'strongly_entangling', or 'convolutional')
            qubits: Number of qubits
            input_type: Type of input states ('computational' or 'random')
            goal_ce: Goal concentratable entanglement (0-1)
            depth: Circuit depth
            num_samples: Number of states to generate
            file_format: Format of weight files ('npy' or 'txt')
            
        Returns:
            Array of generated quantum states
        """
        # Load pretrained weights
        weights = self.load_weights(qubits, depth, goal_ce)
        print("Loaded weights shape:", weights.shape)
        
        # Initialize quantum device
        dev = qml.device("default.qubit", wires=qubits)
        
        # Define quantum circuit
        @qml.qnode(dev)
        def circuit(input_state):
            # Prepare computational basis state
            for q in range(qubits):
                if (input_state >> q) & 1:
                    qml.PauliX(wires=q)
            
            # Apply hardware-efficient ansatz
            self.hardware_efficient_ansatz(weights, list(range(qubits)))
            
            return qml.state()
        
        # Generate states
        input_states = []
        output_states = []
        
        for _ in range(num_samples):
            # Generate random computational basis state
            input_state = np.random.randint(0, 2**qubits)
            input_states.append(input_state)
            
            # Get output state
            output_state = circuit(input_state)
            output_states.append(output_state)
            
        return np.array(input_states), np.array(output_states)

    def save_dataset(self, input_states: np.ndarray, output_states: np.ndarray, 
                    save_dir: str, filename_prefix: str):
        """
        Save generated dataset to files.
        
        Args:
            input_states: Array of input computational basis states
            output_states: Array of output quantum states
            save_dir: Directory to save files
            filename_prefix: Prefix for output files
        """
        os.makedirs(save_dir, exist_ok=True)
        
        # Save input states
        np.save(os.path.join(save_dir, f"{filename_prefix}_inputs.npy"), input_states)
        
        # Save output states
        np.save(os.path.join(save_dir, f"{filename_prefix}_outputs.npy"), output_states)

    def create_classification_dataset(self, 
                                    entanglement_levels: List[float] = [0.15, 0.35, 0.45],
                                    depths: int | list[int] = 3,
                                    num_samples: int = 3000,
                                    qubits: int = 4,
                                    random_state: int = 42) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create a classified dataset of quantum states with different entanglement levels.
        
        Args:
            entanglement_levels: List of target CE values for different classes
            depths: Single depth for all classes or list of depths per class
            num_samples: Total number of samples in the dataset
            qubits: Number of qubits in generated states
            random_state: Seed for reproducible shuffling
            
        Returns:
            X: Array of quantum states (num_samples, 2**qubits)
            y: Array of class labels (num_samples,)
            
        Raises:
            ValueError: For invalid input combinations
        """
        # Input validation
        if isinstance(depths, int):
            depths = [depths] * len(entanglement_levels)
        elif len(depths) != len(entanglement_levels):
            raise ValueError("Length of depths must match entanglement_levels")
            
        for ce_level, depth in zip(entanglement_levels, depths):
            self.validate_params(qubits, depth, ce_level)

        # Calculate samples per class with remainder distribution
        samples_per_class, remainder = divmod(num_samples, len(entanglement_levels))
        class_samples = [samples_per_class + (1 if i < remainder else 0) 
                        for i in range(len(entanglement_levels))]

        # Generate states for each class
        X, y = [], []
        for class_idx, (ce_level, depth, n_samples) in enumerate(zip(entanglement_levels, 
                                                                depths, 
                                                                class_samples)):
            print(f"Generating class {class_idx+1}/{len(entanglement_levels)}: "
                f"CE={ce_level}, depth={depth}, samples={n_samples}")
            
            _, states = self.generate_states(
                qubits=qubits,
                depth=depth,
                goal_ce=ce_level,
                num_samples=n_samples
            )
            
            X.append(states)
            y.append(np.full(n_samples, class_idx))

        # Combine and shuffle
        X = np.concatenate(X)
        y = np.concatenate(y)
        rng = np.random.default_rng(random_state)
        indices = rng.permutation(len(X))
        
        return X[indices], y[indices]


# # Example usage
# if __name__ == "__main__":
#     # Initialize generator with your directory structure
#     generator = HardwareEfficientDatasetGenerator(base_path="Hardware_Efficient")
    
#     # Generation parameters
#     params = {
#         'qubits': 3,
#         'depth': 2,
#         'goal_ce': 0.25,
#         'num_samples': 50
#     }
    
#     # Generate dataset
#     input_states, output_states = generator.generate_states(**params)
    
#     # Save dataset
#     save_dir = "generated_datasets"
#     filename = f"hwe_{params['qubits']}q_{params['depth']}d_{params['goal_ce']}ce"
#     generator.save_dataset(input_states, output_states, save_dir, filename)
    
#     print(f"Generated dataset with {len(output_states)} samples")
#     print(f"Input states shape: {input_states.shape}")
#     print(f"Output states shape: {output_states.shape}")
#     print(f"First output state amplitudes:\n{output_states[0]}")

# Example usage in the __main__ block
if __name__ == "__main__":
    # Demo: build a two-class dataset of 3-qubit states and write it to disk.
    generator = HardwareEfficientDatasetGenerator(base_path="Hardware_Efficient")

    # Classification dataset example
    print("\nGenerating classification dataset:")
    qubits = 3
    X, y = generator.create_classification_dataset(
        entanglement_levels=[0.15, 0.35],
        depths=[3, 5],
        qubits=qubits,
        num_samples=1000
    )

    print("\nClassification dataset stats:")
    print(f"Total samples: {len(X)}")
    print(f"Class distribution: {np.bincount(y)}")
    print(f"State shape: {X[0].shape}") # 2**qubits 
    print(f"First state:\n{X[0]}")
    print(f"Class label: {y[0]}")

    # Save classification dataset. save_dataset is reused here, so the state
    # vectors end up in "<prefix>_inputs.npy" and the labels in
    # "<prefix>_outputs.npy".
    generator.save_dataset(
        input_states=X,  # state vectors (the features)
        output_states=y,  # class labels
        save_dir="classification_datasets",
        filename_prefix=f"{qubits}q_ce_classification"
    )