In [None]:
import numpy as np
import pennylane as qml
from pennylane import numpy as pnp
import qmcpy as qmc
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from scipy.stats import qmc as scipy_qmc
import time
from typing import Tuple, List, Callable
import warnings
warnings.filterwarnings('ignore')

class VQEPointPlacement:
    """
    Variational Quantum Eigensolver for optimal point placement.
    Encodes discrepancy minimization as a quantum optimization problem.

    A parameterized circuit on ``n_qubits`` wires produces a probability
    distribution over basis states; sampled basis states are decoded into
    ``n_points`` points in ``[0, 1]^dimension`` and scored by an estimated
    star discrepancy.
    """

    # Number of variational layers; each layer applies RX, RY, RZ to every qubit.
    N_LAYERS = 3

    def __init__(self, n_points: int, dimension: int, n_qubits: int = 8):
        """
        Args:
            n_points: Number of points sampled from the circuit distribution.
            dimension: Number of coordinates per point.
            n_qubits: Number of wires of the simulated device.
        """
        self.n_points = n_points
        self.dimension = dimension
        self.n_qubits = n_qubits
        # Known up front, so nobody has to run the circuit just to learn it.
        self.n_params = self.N_LAYERS * n_qubits * 3  # RX, RY, RZ for each qubit

        # Initialize quantum device
        self.dev = qml.device('default.qubit', wires=n_qubits)

        # Create the quantum circuit
        self.circuit = self._create_circuit()

    def _prepare_params(self, params):
        """Pad with zeros or truncate ``params`` to ``n_params`` entries and
        reshape to ``(N_LAYERS, n_qubits, 3)``."""
        n_given = len(params)
        if n_given < self.n_params:
            params = np.concatenate([params, np.zeros(self.n_params - n_given)])
        elif n_given > self.n_params:
            params = params[:self.n_params]
        return params.reshape(self.N_LAYERS, self.n_qubits, 3)

    def _create_circuit(self):
        """Create parameterized quantum circuit for point generation.

        Returns:
            A QNode taking a flat parameter vector and returning the
            probabilities of all basis states of the ``n_qubits`` wires.
        """

        @qml.qnode(self.dev, diff_method="parameter-shift")
        def circuit(params):
            angles = self._prepare_params(params)

            # Initialize with Hadamard gates
            for wire in range(self.n_qubits):
                qml.Hadamard(wires=wire)

            # Variational layers
            for layer in range(self.N_LAYERS):
                # Rotation gates
                for wire in range(self.n_qubits):
                    qml.RX(angles[layer, wire, 0], wires=wire)
                    qml.RY(angles[layer, wire, 1], wires=wire)
                    qml.RZ(angles[layer, wire, 2], wires=wire)

                # Entangling gates (nearest neighbours)
                for wire in range(self.n_qubits - 1):
                    qml.CNOT(wires=[wire, wire + 1])

                # Ring connectivity: close the chain from the last wire to the first
                if self.n_qubits > 2:
                    qml.CNOT(wires=[self.n_qubits - 1, 0])

            # Return probabilities
            return qml.probs(wires=range(self.n_qubits))

        return circuit

    def _generate_points_from_probs(self, probs: np.ndarray) -> np.ndarray:
        """Convert quantum probabilities to point coordinates.

        Basis states are sampled according to ``probs``.  The ``n_qubits``-bit
        index of each state (most significant bit first) is split into
        ``dimension`` contiguous bit groups, and each group is scaled by its
        maximum value to give one coordinate in ``[0, 1]``.

        Args:
            probs: Probability of each basis state.

        Returns:
            Array of shape ``(n_points, dimension)``.
        """
        # Renormalise so round-off in the simulated probabilities cannot trip
        # the sum-to-one check of np.random.choice.
        probs = np.asarray(probs, dtype=float)
        probs = probs / probs.sum()

        # Sample states according to probabilities
        states = np.random.choice(len(probs), size=self.n_points, p=probs)

        points = np.zeros((self.n_points, self.dimension))
        for d in range(self.dimension):
            bit_start = d * self.n_qubits // self.dimension
            bit_end = (d + 1) * self.n_qubits // self.dimension
            width = bit_end - bit_start
            if width > 0:
                # Same value as int(binary[bit_start:bit_end], 2) on the
                # zero-padded, MSB-first bit string of each state.
                mask = (1 << width) - 1
                field = (states >> (self.n_qubits - bit_end)) & mask
                points[:, d] = field / mask
            else:
                # More dimensions than qubits: this dimension received no bits,
                # so fall back to the whole state index.
                points[:, d] = (states % (2**self.n_qubits)) / (2**self.n_qubits - 1)

        return points

    def star_discrepancy(self, points: np.ndarray) -> float:
        """
        Estimate the star discrepancy of a point set.
        Lower discrepancy indicates better uniformity.

        The estimate is the largest deviation found over randomly drawn
        anchored boxes ``[0, u1] x ... x [0, ud]``, so the result is itself
        random.

        Args:
            points: Array of shape ``(n, d)``.

        Returns:
            Largest observed gap between the fraction of points in a box and
            the box volume.
        """
        n, d = points.shape
        max_disc = 0.0

        # Sample test points for discrepancy calculation
        n_test = min(1000, 2**d * 10)  # Limit for computational efficiency

        for _ in range(n_test):
            # Random test box [0, u1] x [0, u2] x ... x [0, ud]
            u = np.random.uniform(0, 1, d)

            # Fraction of points in the box
            in_box = np.all(points <= u, axis=1)
            empirical_measure = np.sum(in_box) / n

            # Theoretical measure (volume of box)
            theoretical_measure = np.prod(u)

            # Update maximum discrepancy
            disc = abs(empirical_measure - theoretical_measure)
            max_disc = max(max_disc, disc)

        return max_disc

    def cost_function(self, params: np.ndarray) -> float:
        """Cost function combining discrepancy and uniformity measures.

        The cost is the estimated star discrepancy of points sampled from the
        circuit distribution, minus 0.1 times the entropy of that distribution
        normalised by its maximum.  Both the sampling and the discrepancy
        estimate are random, so repeated calls with the same ``params`` return
        different values.

        Returns:
            The cost, or 1.0 if any step raised (the error is printed).
        """
        try:
            # Get probabilities from quantum circuit
            probs = self.circuit(params)

            # Generate points
            points = self._generate_points_from_probs(probs)

            # Compute star discrepancy
            discrepancy = self.star_discrepancy(points)

            # Add entropy term to encourage exploration
            entropy = -np.sum(probs * np.log(probs + 1e-10))
            max_entropy = np.log(len(probs))

            # Combined cost (minimize discrepancy, maximize normalized entropy)
            cost = discrepancy - 0.1 * (entropy / max_entropy)

            return cost

        except Exception as e:
            # Deliberate best-effort: keep the optimizer running on a failed evaluation.
            print(f"Error in cost function: {e}")
            return 1.0  # Return high cost on error

    def optimize(self, max_iter: int = 100) -> Tuple[np.ndarray, float, List[float]]:
        """Optimize quantum circuit parameters using classical optimizer.

        Args:
            max_iter: Passed to COBYLA as ``maxiter``.

        Returns:
            ``(optimal_params, final_cost, costs)`` where ``costs`` holds the
            cost of every objective evaluation, in call order.
        """
        initial_params = np.random.uniform(0, 2*np.pi, self.n_params)

        print(f"Starting VQE optimization with {self.n_params} parameters...")

        # Track optimization progress
        costs: List[float] = []

        def tracked_cost(params):
            # Record the value the optimizer actually receives.  Re-evaluating
            # in a callback would cost an extra circuit simulation and, because
            # the cost is stochastic, log a different number.
            cost = self.cost_function(params)
            costs.append(cost)
            if len(costs) % 10 == 0:
                print(f"Iteration {len(costs)}: Cost = {cost:.6f}")
            return cost

        # Optimize using COBYLA (good for noisy objectives)
        result = minimize(
            tracked_cost,
            initial_params,
            method='COBYLA',
            options={'maxiter': max_iter, 'disp': True},
        )

        return result.x, result.fun, costs

    def generate_optimal_points(self, optimal_params: np.ndarray) -> np.ndarray:
        """Generate points using optimal parameters."""
        probs = self.circuit(optimal_params)
        return self._generate_points_from_probs(probs)


class PointPlacementComparison:
    """Compare VQE, QMC, and MC point placement strategies."""

    def __init__(self, n_points: int, dimension: int):
        """
        Args:
            n_points: Number of points each strategy generates.
            dimension: Number of coordinates per point.
        """
        self.n_points = n_points
        self.dimension = dimension

    def generate_monte_carlo_points(self) -> np.ndarray:
        """Generate random Monte Carlo points, shape ``(n_points, dimension)``."""
        return np.random.uniform(0, 1, (self.n_points, self.dimension))

    def generate_qmc_points(self, method: str = 'sobol') -> np.ndarray:
        """Generate Quasi-Monte Carlo points using qmcpy.

        Args:
            method: One of ``'sobol'``, ``'halton'`` or ``'lattice'``.

        Returns:
            The samples produced by the chosen qmcpy generator.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method == 'sobol':
            sampler = qmc.Sobol(self.dimension, randomize=True)
        elif method == 'halton':
            # qmcpy's Halton is configured through `randomize`, like the other
            # generators here; `scramble` is the scipy.stats.qmc keyword.
            sampler = qmc.Halton(self.dimension, randomize=True)
        elif method == 'lattice':
            # Use qmcpy's lattice sequence
            sampler = qmc.Lattice(self.dimension, randomize=True)
        else:
            raise ValueError(f"Unknown QMC method: {method}")

        return sampler.gen_samples(self.n_points)

    def evaluate_discrepancy(self, points: np.ndarray) -> dict:
        """Evaluate various discrepancy measures.

        Args:
            points: Array of shape ``(n, dimension)``.

        Returns:
            Dict with keys ``'star_discrepancy'``, ``'l2_discrepancy'`` and
            ``'min_distance'``.  The two discrepancies are estimated from
            random test boxes, so they vary between calls.
        """
        # The instance is only needed to reach star_discrepancy.
        vqe_placer = VQEPointPlacement(self.n_points, self.dimension)

        # Star discrepancy
        star_disc = vqe_placer.star_discrepancy(points)

        # L2 discrepancy approximation: root mean squared local discrepancy
        # over random anchored boxes.
        n_test = 1000
        l2_disc = 0.0
        for _ in range(n_test):
            u = np.random.uniform(0, 1, self.dimension)
            in_box = np.all(points <= u, axis=1)
            empirical = np.sum(in_box) / len(points)
            theoretical = np.prod(u)
            l2_disc += (empirical - theoretical)**2
        l2_disc = np.sqrt(l2_disc / n_test)

        # Minimum distance (measure of point clustering)
        if len(points) > 1:
            from scipy.spatial.distance import pdist
            min_dist = np.min(pdist(points))
        else:
            min_dist = 0.0

        return {
            'star_discrepancy': star_disc,
            'l2_discrepancy': l2_disc,
            'min_distance': min_dist
        }

    def _summarize(self, points: np.ndarray, elapsed: float) -> dict:
        """Bundle a point set, its generation time and its metrics, and print
        the time and star discrepancy."""
        metrics = self.evaluate_discrepancy(points)
        print(f"   Time: {elapsed:.3f}s")
        print(f"   Star Discrepancy: {metrics['star_discrepancy']:.6f}")
        return {
            'points': points,
            'time': elapsed,
            'metrics': metrics
        }

    def run_comparison(self) -> dict:
        """Run comprehensive comparison of all methods.

        Returns:
            Dict keyed by ``'monte_carlo'``, ``'qmc_sobol'``, ``'qmc_halton'``
            and ``'vqe'``.  Every entry has ``'points'``, ``'time'`` (seconds
            spent generating the points, metrics excluded) and ``'metrics'``;
            the ``'vqe'`` entry additionally has ``'optimization_costs'`` and
            ``'final_cost'``.
        """
        results = {}

        print("=== Point Placement Comparison ===\n")

        # 1. Monte Carlo
        print("1. Generating Monte Carlo points...")
        start_time = time.time()
        mc_points = self.generate_monte_carlo_points()
        mc_time = time.time() - start_time
        results['monte_carlo'] = self._summarize(mc_points, mc_time)

        # 2. Quasi-Monte Carlo methods
        qmc_methods = ['sobol', 'halton']
        for method in qmc_methods:
            print(f"\n2. Generating QMC points ({method})...")
            start_time = time.time()
            qmc_points = self.generate_qmc_points(method)
            qmc_time = time.time() - start_time
            results[f'qmc_{method}'] = self._summarize(qmc_points, qmc_time)

        # 3. VQE-based optimization
        print(f"\n3. Generating VQE-optimized points...")
        start_time = time.time()

        # Use smaller problem for demonstration: between 4 and 6 qubits
        vqe_n_qubits = min(6, max(4, int(np.log2(self.n_points)) + 1))
        vqe = VQEPointPlacement(self.n_points, self.dimension, vqe_n_qubits)

        optimal_params, final_cost, costs = vqe.optimize(max_iter=50)
        vqe_points = vqe.generate_optimal_points(optimal_params)
        vqe_time = time.time() - start_time

        vqe_entry = self._summarize(vqe_points, vqe_time)
        vqe_entry['optimization_costs'] = costs
        vqe_entry['final_cost'] = final_cost
        results['vqe'] = vqe_entry

        print(f"   Final VQE Cost: {final_cost:.6f}")

        return results


def visualize_results(results: dict, n_points: int, dimension: int):
    """Visualize comparison results."""

    # Create summary table
    print("\n=== COMPARISON SUMMARY ===")
    print(f"{'Method':<15} {'Time (s)':<10} {'Star Disc':<12} {'L2 Disc':<12} {'Min Dist':<10}")
    print("-" * 65)

    for method, data in results.items():
        metrics = data['metrics']
        print(f"{method:<15} {data['time']:<10.3f} "
              f"{metrics['star_discrepancy']:<12.6f} "
              f"{metrics['l2_discrepancy']:<12.6f} "
              f"{metrics['min_distance']:<10.6f}")

    # Create visualizations
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(f'Point Placement Comparison ({n_points} points, {dimension}D)',
                 fontsize=16)

    # Plot 1: Point distributions (2D projection if dimension > 2)
    for i, (method, data) in enumerate(results.items()):
        row, col = i // 3, i % 3