In [None]:
# Geometry-Guided Quantum SAT Solver
# Copyright (c) 2025, Sabarikirishwaran Ponnambalam
#
# This software is licensed under the Polyform Noncommercial License.
# See the LICENSE file in the project root for full license terms.

!pip install --no-cache-dir torch torchvision torchaudio -q
!pip install --no-cache-dir custatevec-cu12 cudaq -q
!pip install --no-cache-dir pandas numpy matplotlib geoopt -q
# !pip install --no-cache-dir pytorch-lightning -q
# !pip install --no-cache-dir lightning pennylane-lightning-gpu -q
# !pip install --no-cache-dir "jax[cuda12]" pennylane-catalyst -q

## Full-state simulation for small-scale instances

In [None]:
# Geometry-Guided Quantum SAT Solver

import math
import itertools

import torch
import numpy as np
from tqdm import tqdm

import geoopt
from typing import List, Tuple
import matplotlib.pyplot as plt


# Clause type
Clause = List[Tuple[int, bool]]
Formula = List[Clause]

def generate_random_3sat(n_vars: int, n_clauses: int) -> Formula:
    """Draw a random 3-SAT formula.

    Each clause picks three distinct variable indices from ``range(n_vars)``
    and attaches an independent random boolean flag to each of them.

    Args:
        n_vars: Number of variables to choose from (must be at least 3,
            otherwise the underlying ``np.random.choice`` call fails).
        n_clauses: Number of clauses to generate.

    Returns:
        A list of ``n_clauses`` clauses, each a list of three
        ``(variable_index, flag)`` tuples.
    """
    def _draw_clause():
        # Same draw order as before: first the three variables, then one
        # flag per variable, so a seeded global RNG yields the same formula.
        picked = np.random.choice(n_vars, 3, replace=False)
        literals = []
        for var in picked:
            literals.append((int(var), bool(np.random.randint(2))))
        return literals

    return [_draw_clause() for _ in range(n_clauses)]

def eval_clause(bits: Tuple[int], clause: Clause) -> bool:
    """Return True if at least one literal of ``clause`` holds under ``bits``.

    A literal ``(v, is_neg)`` holds when ``bits[v] XOR is_neg`` is truthy,
    i.e. the bit is 1 for a plain literal or 0 for a negated one.
    An empty clause is never satisfied.
    """
    for var, negated in clause:
        if bits[var] ^ negated:
            return True
    return False

def basis_states(n: int) -> List[Tuple[int]]:
    """Enumerate all ``2**n`` bit tuples of length ``n`` in lexicographic order.

    The first tuple is all zeros and the last is all ones; position 0 of each
    tuple is the slowest-varying bit.
    """
    states = []
    for combo in itertools.product((0, 1), repeat=n):
        states.append(combo)
    return states

# Clause projector
def clause_projector(n: int, clause: Clause) -> torch.Tensor:
    """Build the dense diagonal projector onto assignments satisfying ``clause``.

    Returns a ``(2**n, 2**n)`` complex128 tensor whose diagonal entry ``i`` is
    1 when the ``i``-th bit tuple from ``basis_states(n)`` satisfies the clause
    and 0 otherwise. Note the matrix is dense, so memory grows as ``4**n``.
    """
    size = 2 ** n
    projector = torch.zeros((size, size), dtype=torch.cdouble)
    satisfied = [
        index
        for index, assignment in enumerate(basis_states(n))
        if eval_clause(assignment, clause)
    ]
    # Set all satisfied diagonal entries in one indexed assignment.
    hit = torch.tensor(satisfied, dtype=torch.long)
    projector[hit, hit] = 1.0
    return projector

def full_solution_projector(n: int, formula: Formula) -> torch.Tensor:
    """Build the dense diagonal projector onto assignments satisfying every clause.

    Returns a ``(2**n, 2**n)`` complex128 tensor with a 1 on the diagonal for
    each bit tuple from ``basis_states(n)`` that satisfies all clauses of
    ``formula`` (for an empty formula this is the identity).
    """
    size = 2 ** n
    projector = torch.zeros((size, size), dtype=torch.cdouble)
    solution_indices = []
    for index, assignment in enumerate(basis_states(n)):
        satisfies_all = True
        for clause in formula:
            if not eval_clause(assignment, clause):
                satisfies_all = False
                break
        if satisfies_all:
            solution_indices.append(index)
    hit = torch.tensor(solution_indices, dtype=torch.long)
    projector[hit, hit] = 1.0
    return projector

def fubini_study_angle(psi: torch.Tensor, phi: torch.Tensor) -> float:
    """Return ``acos(|<psi|phi>|)`` in radians as a Python float.

    The overlap magnitude is clamped to [0, 1] before the arccosine so that
    rounding slightly above 1 does not produce NaN. Both inputs are expected
    to be normalised by the caller; no normalisation is done here.
    """
    inner = torch.vdot(psi, phi)
    magnitude = inner.abs().clamp(min=0.0, max=1.0)
    return float(magnitude.acos())

# Geometry-guided unitary optimization on St(d,d)
def _retract_to_unitary(M: torch.Tensor) -> torch.Tensor:
    """Map a square matrix to a unitary one via its QR decomposition.

    The Q factor is rescaled column-wise so that the matching R factor has a
    non-negative real diagonal, which makes the result unique and continuous
    in ``M`` (a standard QR retraction).
    """
    Q, R = torch.linalg.qr(M)
    phases = torch.sgn(torch.diagonal(R))
    # A zero diagonal entry has no phase; leave that column untouched.
    phases = torch.where(phases == 0, torch.ones_like(phases), phases)
    return Q * phases  # broadcasts over rows: column j is scaled by phases[j]


def optimize_unitary_stiefel(psi: torch.Tensor, P: torch.Tensor, 
                              lr=0.1, steps=100) -> Tuple[torch.Tensor, List[float]]:
    """Find a unitary ``U`` that increases ``<psi| U^H P U |psi>``.

    Starting from the identity, performs ``steps`` iterations of Riemannian
    gradient ascent on the unitary group: the Euclidean gradient is projected
    onto the tangent space at ``U`` and the update is pulled back onto the
    group with a QR retraction.

    The previous implementation stepped a manifold parameter with plain
    ``torch.optim.SGD``, which applies an unconstrained Euclidean update. The
    iterate therefore left the unitary group, the reported "overlap" could
    exceed 1, and ``U @ psi`` was no longer normalised. The update is written
    out with torch primitives here so that it is valid for complex matrices.

    Args:
        psi: State vector of length ``d`` (complex128, assumed normalised).
        P: ``(d, d)`` projector (complex128).
        lr: Step size of the gradient ascent.
        steps: Number of iterations; ``0`` returns the identity.

    Returns:
        ``(U, overlaps)`` where ``U`` is the final ``(d, d)`` unitary (detached)
        and ``overlaps[k]`` is the objective value *before* update ``k``.
    """
    d = psi.shape[0]
    U = torch.eye(d, dtype=torch.cdouble, device=psi.device).requires_grad_(True)
    overlaps = []

    for _ in range(steps):
        U.grad = None
        # <psi|U^H P U|psi> == <U psi|P|U psi>; the latter needs only
        # matrix-vector products instead of a d x d x d matrix product.
        rotated = U @ psi
        overlap = torch.vdot(rotated, P @ rotated).real
        (-overlap).backward()
        overlaps.append(overlap.item())

        with torch.no_grad():
            egrad = U.grad
            # Tangent-space projection at U: G - U * herm(U^H G).
            UhG = U.conj().T @ egrad
            rgrad = egrad - U @ (0.5 * (UhG + UhG.conj().T))
            U.copy_(_retract_to_unitary(U - lr * rgrad))

    return U.detach(), overlaps

# Measurement & projection
def project_and_normalize(psi: torch.Tensor, P: torch.Tensor) -> Tuple[torch.Tensor, float]:
    """Apply projector ``P`` to ``psi`` and renormalise the result.

    Returns:
        ``(P psi / sqrt(p), p)`` with ``p = Re <psi|P|psi>``. When ``p`` is
        below ``1e-12`` the projection is treated as impossible and the
        *unchanged* input state is returned together with ``0.0``.
    """
    projected = P @ psi
    prob = torch.vdot(psi, projected).real.item()
    if prob < 1e-12:
        return psi, 0.0
    norm = torch.tensor(prob, dtype=torch.float64).sqrt()
    return projected / norm, prob

# Solver
def geometry_guided_solver(n: int, formula: Formula, steps=50, lr=0.1, verbose=False):
    """Process the clauses of ``formula`` one at a time on a full state vector.

    Starting from the uniform superposition over ``2**n`` basis states, for
    each clause the solver
      1. fits a matrix ``U`` with ``optimize_unitary_stiefel`` for the clause
         projector,
      2. applies ``U`` and projects onto the clause subspace,
      3. records the Fubini–Study angle between the new state and its
         component in the full solution subspace.
    Processing stops early once a clause projection succeeds with
    probability below ``1e-6``.

    Args:
        n: Number of variables (the state has ``2**n`` amplitudes).
        formula: List of clauses.
        steps: Optimisation steps per clause.
        lr: Learning rate for the per-clause optimisation.
        verbose: Print per-clause diagnostics when True.

    Returns:
        Dict with keys ``"final_state"``, ``"success_prob"`` (product of the
        per-clause projection probabilities), ``"fubini_angles"`` (one entry
        per processed clause) and ``"overlap_solution"``.
    """
    dim = 2 ** n
    psi = torch.ones(dim, dtype=torch.cdouble) / math.sqrt(dim)
    projectors = [clause_projector(n, c) for c in formula]
    P_solution = full_solution_projector(n, formula)

    total_p = 1.0
    fs_angles = []

    for P in tqdm(projectors):
        U, _ = optimize_unitary_stiefel(psi, P, lr=lr, steps=steps)
        psi, p_clause = project_and_normalize(U @ psi, P)
        total_p *= p_clause

        # Compute the solution-space component once (dense mat-vec), and do
        # not divide by a zero norm: if the state has no weight on satisfying
        # assignments (e.g. unsatisfiable formula) the angle is pi/2, not NaN.
        solution_part = P_solution @ psi
        solution_norm = torch.linalg.norm(solution_part).item()
        if solution_norm > 1e-12:
            fs_angle = fubini_study_angle(psi, solution_part / solution_norm)
        else:
            fs_angle = math.pi / 2.0
        fs_angles.append(fs_angle)

        if verbose:
            print(f"Clause success: {p_clause:.6f}, Fubini–Study angle: {fs_angle:.4f} rad")
        if p_clause < 1e-6:
            break

    return {
        "final_state": psi,
        "success_prob": total_p,
        "fubini_angles": fs_angles,
        "overlap_solution": torch.vdot(psi, P_solution @ psi).real.item()
    }

# Time-to-solution
def time_to_solution(p_success: float, t_run: float, p_target=0.99) -> float:
    """Estimate the time needed to see at least one success with confidence ``p_target``.

    Assumes independent repetitions, each taking ``t_run`` and succeeding with
    probability ``p_success``; the required number of runs is
    ``log(1 - p_target) / log(1 - p_success)``, never less than one run.

    Args:
        p_success: Per-run success probability.
        t_run: Duration of a single run.
        p_target: Desired overall success probability, must be below 1.

    Returns:
        ``inf`` when ``p_success <= 0``, ``t_run`` when a single run already
        reaches the target, otherwise ``t_run`` times the run count.

    Raises:
        ValueError: If ``p_target >= 1`` while ``0 < p_success < 1`` (certainty
            cannot be reached with finitely many imperfect runs).
    """
    if p_success <= 0.0:
        return float("inf")
    if p_success >= 1.0:
        return t_run
    if p_target >= 1.0:
        raise ValueError("p_target must be strictly less than 1")
    # log1p keeps precision for tiny p_success, where log(1 - p) evaluates to
    # exactly 0 and the division would raise ZeroDivisionError.
    num_runs = math.log1p(-p_target) / math.log1p(-p_success)
    # A fraction of a run is not possible; this also keeps the result
    # continuous with the p_success >= 1 branch above.
    if num_runs < 1.0:
        num_runs = 1.0
    return t_run * num_runs

# Baselines (nonguided + classical)
def nonguided_measurement_solver(n: int, formula: Formula):
    """Baseline: project the uniform state onto each clause in turn, with no rotation.

    Returns the product of the per-clause projection probabilities
    ``<psi|P|psi>``, where ``psi`` is renormalised after every projection.
    """
    dim = 2 ** n
    psi = torch.ones(dim, dtype=torch.cdouble) / math.sqrt(dim)
    running_prob = 1.0
    # Projectors are built lazily, one clause at a time.
    for P in (clause_projector(n, c) for c in formula):
        running_prob *= torch.vdot(psi, P @ psi).real.item()
        psi = project_and_normalize(psi, P)[0]
    return running_prob

# Benchmarking setup
def benchmark_experiment(n: int, m: int, trials=10):
    """Compare the geometry-guided solver with the non-guided baseline.

    For each of ``trials`` random 3-SAT instances with ``n`` variables and
    ``m`` clauses, runs both solvers on the same formula, then prints the mean
    success score of each.

    Returns:
        ``(geo_scores, nonguided_scores, fs_logs)``: per-trial guided scores,
        per-trial baseline scores, and per-trial lists of Fubini–Study angles.
    """
    geo_scores, nonguided_scores, fs_logs = [], [], []

    for _trial in range(trials):
        instance = generate_random_3sat(n, m)

        guided = geometry_guided_solver(n, instance, steps=30, lr=0.1)
        geo_scores.append(guided["success_prob"])
        fs_logs.append(guided["fubini_angles"])

        nonguided_scores.append(nonguided_measurement_solver(n, instance))

    print("Geometry-guided avg success:", np.mean(geo_scores))
    print("Non-guided baseline avg success:", np.mean(nonguided_scores))
    return geo_scores, nonguided_scores, fs_logs

if __name__ == "__main__":
    # Problem size: n variables, m clauses. Every projector is a dense
    # 2**n x 2**n matrix, so memory and runtime grow quickly with n.
    n = 12
    m = 6
    geo_scores, nonguided_scores, fs_logs = benchmark_experiment(n, m, trials=5)

    # One curve per trial. Previously every curve carried the same label and
    # no legend was drawn, so the trials could not be told apart.
    fig, ax = plt.subplots(figsize=(8, 4))
    for trial, angles in enumerate(fs_logs, start=1):
        ax.plot(angles, label=f"Trial {trial}")
    ax.set_xlabel("Clause step")
    ax.set_ylabel("Fubini–Study angle (rad)")
    ax.set_title("Geometric evolution of state vs solution space")
    if fs_logs:
        ax.legend(title="Fubini–Study angle")
    ax.grid()
    plt.show()

## Quantum Parametric Circuit

In [None]:
"""
Geometry-guided 3-SAT solver with CUDA-Q (cuQuantum backend) and Riemannian
optimization over a parametric circuit's angles (theta).

This implements the following design:
- 3-SAT -> diagonal SAT Hamiltonian
- parametric ansatz U(theta) defined as CUDA-Q kernel
- statevector simulation on GPU (cuQuantum)
- geometric flag regularizer + Fubini–Study smoothness regularizer
- Riemannian optimization over theta using geoopt with finite-diff gradients
"""

import math
import time
import itertools
from typing import List, Tuple, Dict, Any

import numpy as np
import torch
import geoopt

import cudaq
from tqdm.auto import tqdm 

# --------------------------------------------------------------------
# SAT encoding
# --------------------------------------------------------------------

Clause = List[Tuple[int, bool]]
Formula = List[Clause]


def generate_random_3sat(n_vars: int, n_clauses: int) -> Formula:
    """Build a random 3-SAT formula whose clauses each use three distinct variables.

    Every literal is a ``(variable_index, is_negated)`` pair; the negation
    flag is drawn uniformly from {False, True}. Uses the global NumPy RNG.
    """
    # The variables of a clause are drawn first, then one flag per variable.
    return [
        [
            (int(var), bool(np.random.randint(2)))
            for var in np.random.choice(n_vars, 3, replace=False)
        ]
        for _ in range(n_clauses)
    ]


def eval_clause(bits: Tuple[int, ...], clause: Clause) -> bool:
    """
    Decide whether a bitstring satisfies one clause.

    bits[i] is 0 or 1; clause is a list of (var_index, is_negated) pairs.
    A pair stands for the literal x when is_negated is False and for NOT x
    when it is True. The clause holds as soon as one literal evaluates to 1;
    an empty clause is never satisfied.
    """
    return any(
        (1 - bits[var_idx] if is_neg else bits[var_idx]) == 1
        for var_idx, is_neg in clause
    )


def bits_from_int(i: int, n: int) -> Tuple[int, ...]:
    """Return the low ``n`` bits of ``i`` as a tuple, most significant bit first.

    Position 0 of the tuple (qubit 0) holds bit ``n - 1`` of ``i``.
    """
    return tuple((i >> shift) & 1 for shift in range(n - 1, -1, -1))

# --------------------------------------------------------------------
# SAT Hamiltonian
# --------------------------------------------------------------------

def precompute_sat_data(n: int, formula: Formula):
    """
    Precompute per-assignment SAT data for all 2^n assignments.

    Assignment index i is read big-endian: variable/qubit 0 is the most
    significant bit of i, matching the bitstrings in basis_strings.

    The work is vectorised over assignments with NumPy (one pass per literal)
    instead of evaluating every clause on every assignment in Python, which
    took 2^n * m clause evaluations.

    Args:
        n: Number of variables.
        formula: List of clauses, each a list of (var_index, is_negated).
            Negative var_index values count from the end, as with tuple
            indexing.

    Returns:
        energies: float64 array of length 2^n; energies[i] is the number of
            clauses assignment i leaves unsatisfied.
        sat_masks: bool array of shape (m, 2^n); sat_masks[k, i] is True iff
            assignment i satisfies clauses 0..k (inclusive).
        basis_strings: list of the bitstrings '000..0' .. '111..1'.

    Raises:
        IndexError: If a clause refers to a variable outside range(-n, n).
    """
    num_states = 1 << n
    m = len(formula)

    energies = np.zeros(num_states, dtype=np.float64)
    sat_masks = np.zeros((m, num_states), dtype=bool)
    basis_strings = [format(i, f"0{n}b") for i in range(num_states)]

    indices = np.arange(num_states, dtype=np.int64)
    # prefix_ok[i]: assignment i satisfies every clause seen so far.
    prefix_ok = np.ones(num_states, dtype=bool)

    for k, clause in enumerate(formula):
        clause_sat = np.zeros(num_states, dtype=bool)
        for var_idx, is_neg in clause:
            if not -n <= var_idx < n:
                raise IndexError(
                    f"variable index {var_idx} out of range for {n} variables"
                )
            position = var_idx % n
            # Variable `position` lives at bit (n - 1 - position) of the index.
            bit_set = ((indices >> (n - 1 - position)) & 1).astype(bool)
            clause_sat |= ~bit_set if is_neg else bit_set

        energies += ~clause_sat      # one unit of energy per violated clause
        prefix_ok &= clause_sat
        sat_masks[k] = prefix_ok

    return energies, sat_masks, basis_strings

# --------------------------------------------------------------------
# CUDA-Q kernel for parametric ansatz U(theta)
# --------------------------------------------------------------------

@cudaq.kernel
def sat_ansatz(params: list[float], n_qubits: int):
    """
    Parametric ansatz U(theta), written as a CUDA-Q kernel.

    Circuit:
      - Start from |0..0> on n_qubits qubits
      - Apply H on each qubit (-> |+>^n)
      - depth layers, each made of an (RY, RZ) pair on every qubit followed
        by a CNOT chain q[0]->q[1]->...->q[n_qubits-1]

    Parameter layout (flat list, consumed in order):
      layer 0: ry(q0), rz(q0), ry(q1), rz(q1), ..., then layer 1, and so on.

    params length is expected to be 2 * n_qubits * depth.
    depth = len(params) // (2 * n_qubits), so any trailing entries that do
    not fill a complete layer are ignored.
    """
    q = cudaq.qvector(n_qubits)

    # Prepare reference state |+>^n
    for i in range(n_qubits):
        h(q[i])

    # Number of complete layers encoded in params (floor division).
    depth = int(len(params) // (2 * n_qubits))
    # Running cursor into the flat parameter list.
    idx = 0
    for _ in range(depth):
        # Single-qubit rotations: two consecutive parameters per qubit
        for i in range(n_qubits):
            ry(params[idx], q[i])
            idx += 1
            rz(params[idx], q[i])
            idx += 1
        # Entangling layer: CNOT chain (q[i] controls q[i + 1])
        for i in range(n_qubits - 1):
            x.ctrl(q[i], q[i + 1])

def prepare_state(theta: np.ndarray, n_qubits: int) -> cudaq.State:
    """
    Simulate the ``sat_ansatz`` kernel for the given angles.

    ``theta`` is converted to a plain Python list before being handed to
    ``cudaq.get_state``. The resulting ``cudaq.State`` is returned as-is;
    where it lives (GPU or CPU) depends on the CUDA-Q target that is active
    when this function is called.
    """
    return cudaq.get_state(sat_ansatz, theta.tolist(), n_qubits)

# --------------------------------------------------------------------
# Geometry guided loss and regularizers
# --------------------------------------------------------------------

def evaluate_loss(
    theta: np.ndarray,
    n: int,
    formula: Formula,
    energies: np.ndarray,
    sat_masks: np.ndarray,
    basis_strings: List[str],
    prev_state: cudaq.State | None,
    lambda_flag: float = 0.1,
    lambda_smooth: float = 0.01,
) -> Tuple[float, Dict[str, Any], cudaq.State]:
    """
    Compute loss and metrics at given theta.

    loss = main_loss + lambda_flag * R_flag + lambda_smooth * R_smooth, with
      - main_loss: expected fraction of unsatisfied clauses,
      - R_flag: mean squared Fubini–Study angle to the subspaces of
        assignments satisfying the first k clauses (k = 1..m),
      - R_smooth: squared Fubini–Study angle to prev_state (0 if None).

    Args:
      theta: circuit angles passed to prepare_state.
      n: number of qubits/variables.
      formula: the SAT formula (only its length is used here).
      energies: per-assignment unsatisfied-clause counts.
      sat_masks: (m, 2^n) bool prefix-satisfaction masks.
      basis_strings: bitstrings whose amplitudes are queried, in the same
        order as energies / sat_masks columns.
      prev_state: state of the previous iterate, or None.
      lambda_flag, lambda_smooth: regulariser weights.

    Returns:
      loss (scalar)
      metrics dict with keys "loss", "main_loss", "E", "sat_prob", "p_solution",
        "R_flag", "R_smooth", "theta_fs_prev". Note that "sat_prob" is the
        expected *fraction of satisfied clauses* (1 - main_loss; 0.0 when the
        formula is empty), whereas "p_solution" is the probability of
        measuring an assignment that satisfies every clause.
      current cudaq.State
    """
    m = len(formula)
    state = prepare_state(theta, n)

    # basis state amplitudes -> measurement probabilities
    amps = np.array(state.amplitudes(basis_strings), dtype=np.complex128)
    probs = np.abs(amps) ** 2

    # objective: E is the expected NUMBER of unsatisfied clauses,
    # main_loss the expected FRACTION.
    E = float(np.dot(probs, energies))
    main_loss = E / m if m > 0 else 0.0
    sat_prob = 1.0 - main_loss if m > 0 else 0.0

    # Probability mass on assignments violating no clause. This, not
    # sat_prob, is the per-run success probability of sampling a solution.
    p_solution = float(probs[energies == 0].sum())
    p_solution = max(0.0, min(1.0, p_solution))

    # Flag regularizer: FS angle to subspaces satisfying first k clauses
    m_clauses = sat_masks.shape[0]
    flag_terms: List[float] = []
    for k in range(m_clauses):
        p_k = float(probs[sat_masks[k]].sum())
        p_k = max(0.0, min(1.0, p_k))  # guard acos/sqrt against rounding
        if p_k <= 0.0:
            theta_fs_k = math.pi / 2.0
        else:
            theta_fs_k = math.acos(math.sqrt(p_k))
        flag_terms.append(theta_fs_k ** 2)
    R_flag = float(sum(flag_terms) / m_clauses) if m_clauses > 0 else 0.0

    # Smoothness regularizer: FS angle to previous state
    R_smooth = 0.0
    fs_angle_prev = None
    if prev_state is not None:
        overlap = state.overlap(prev_state)
        val = abs(overlap)
        val = max(0.0, min(1.0, val))
        fs_angle_prev = math.acos(val)
        R_smooth = fs_angle_prev ** 2

    loss = main_loss + lambda_flag * R_flag + lambda_smooth * R_smooth

    metrics = {
        "loss": loss,
        "main_loss": main_loss,
        "E": E,
        "sat_prob": sat_prob,
        "p_solution": p_solution,
        "R_flag": R_flag,
        "R_smooth": R_smooth,
        "theta_fs_prev": fs_angle_prev,
    }
    return loss, metrics, state


def finite_difference_grad(
    theta: np.ndarray,
    loss_fn,
    eps: float = 1e-2,
) -> np.ndarray:
    """
    Estimate the gradient of a scalar function by central differences.

    For each coordinate i the function is evaluated at theta + eps * e_i and
    then at theta - eps * e_i (each on a fresh copy, so theta itself is never
    modified), and the slope (f_plus - f_minus) / (2 * eps) is stored.

    loss_fn: callable(theta: np.ndarray) -> float
    Returns a float64 array shaped like theta; costs 2 * theta.size calls.
    """
    slopes = np.zeros_like(theta, dtype=np.float64)
    for coord in range(theta.size):
        shifted_up = theta.copy()
        shifted_up[coord] += eps
        shifted_down = theta.copy()
        shifted_down[coord] -= eps

        # Left operand is evaluated first: the "+eps" call precedes "-eps".
        slopes[coord] = (loss_fn(shifted_up) - loss_fn(shifted_down)) / (2.0 * eps)
    return slopes


# --------------------------------------------------------------------
# Riemannian optimization
# --------------------------------------------------------------------

def optimize_geometry_guided_sat(
    n: int,
    formula: Formula,
    depth: int = 1,
    max_iters: int = 20,
    lr: float = 0.1,
    lambda_flag: float = 0.1,
    lambda_smooth: float = 0.01,
    verbose: bool = True,
) -> Tuple[np.ndarray, List[Dict[str, Any]], float]:
    """
    Optimise the ansatz angles for one SAT instance.

    Each iteration evaluates the loss at the current angles, estimates the
    gradient by central finite differences (2 * num_params extra circuit
    simulations), and takes one geoopt RiemannianAdam step. The smoothness
    term of the loss always refers to the state of the previous iterate.

    Args:
      n: number of variables / qubits.
      formula: SAT formula.
      depth: number of ansatz layers; there are 2 * n * depth angles,
        initialised to zero.
      max_iters: number of optimisation iterations.
      lr: optimiser learning rate.
      lambda_flag, lambda_smooth: regulariser weights forwarded to
        evaluate_loss.
      verbose: show a tqdm progress bar with live metrics.

    Returns:
      best_theta (numpy array of the iterate with the highest "sat_prob";
        None if no iteration ran)
      history (list of metric dicts, one per iteration)
      total_run_time (seconds spent in the optimisation loop)
    """
    # Prefer a GPU target; if neither name is accepted, keep CUDA-Q's default.
    for target_name in ("nvidia", "cuquantum"):
        try:
            cudaq.set_target(target_name)
        except Exception:
            continue
        break

    precompute_started = time.time()
    energies, sat_masks, basis_strings = precompute_sat_data(n, formula)
    precompute_finished = time.time()

    print("SAT precompute data exec. time: ", (precompute_finished - precompute_started))

    # The angles live on a Euclidean manifold (presumably making the
    # Riemannian optimiser behave like its flat counterpart — confirm
    # against the geoopt documentation).
    theta = geoopt.ManifoldParameter(
        torch.zeros(2 * n * depth, dtype=torch.float64),
        manifold=geoopt.manifolds.Euclidean(),
    )
    optimizer = geoopt.optim.RiemannianAdam([theta], lr=lr)

    history: List[Dict[str, Any]] = []
    prev_state: cudaq.State | None = None
    best_theta = None
    best_sat_prob = -1.0

    def _evaluate(vec: np.ndarray):
        # Reads prev_state at call time; it only changes at the end of an
        # iteration, so every call within one iteration sees the same state.
        return evaluate_loss(
            vec,
            n,
            formula,
            energies,
            sat_masks,
            basis_strings,
            prev_state=prev_state,
            lambda_flag=lambda_flag,
            lambda_smooth=lambda_smooth,
        )

    def _loss_only(vec: np.ndarray) -> float:
        return _evaluate(vec)[0]

    start_time = time.perf_counter()

    progress = range(max_iters)
    if verbose:
        progress = tqdm(progress, desc="Optimization", unit="iter")

    for it in progress:
        optimizer.zero_grad(set_to_none=True)

        # astype() copies, so this snapshot is unaffected by optimizer.step().
        theta_np = theta.detach().cpu().numpy().astype(np.float64)

        # Loss/metrics at the current point, then the numerical gradient.
        _, metrics, state = _evaluate(theta_np)
        grad_np = finite_difference_grad(theta_np, _loss_only, eps=1e-2)
        theta.grad = torch.from_numpy(grad_np).to(theta)

        optimizer.step()

        prev_state = state
        history.append(metrics)

        if metrics["sat_prob"] > best_sat_prob:
            best_sat_prob = metrics["sat_prob"]
            best_theta = theta_np.copy()

        if not verbose:
            continue
        if isinstance(progress, tqdm):
            progress.set_postfix(
                loss=f"{metrics['loss']:.4f}",
                sat=f"{metrics['sat_prob']:.4f}",
                Rf=f"{metrics['R_flag']:.3f}",
                Rs=f"{metrics['R_smooth']:.3f}",
            )
        else:
            print(
                f"[Iter {it:02d}] "
                f"loss={metrics['loss']:.4f}, "
                f"sat_prob={metrics['sat_prob']:.4f}, "
                f"R_flag={metrics['R_flag']:.4f}, "
                f"R_smooth={metrics['R_smooth']:.4f}"
            )

    total_time = time.perf_counter() - start_time
    return best_theta, history, total_time


# --------------------------------------------------------------------
# TTS logic
# --------------------------------------------------------------------

def time_to_solution(p_success: float, t_run: float, p_target: float = 0.99) -> float:
    """
    Expected time to achieve overall success probability >= p_target,
    assuming independent repeats with per-run success p_success and
    runtime t_run per run.

    The number of repeats is log(1 - p_target) / log(1 - p_success), but
    never fewer than one run.

    Returns:
      inf if p_success <= 0, t_run if one run already reaches the target,
      otherwise t_run times the number of repeats.

    Raises:
      ValueError: if p_target >= 1 while 0 < p_success < 1 (certainty is
        unreachable with finitely many imperfect runs).
    """
    if p_success <= 0.0:
        return float("inf")
    if p_success >= 1.0:
        return t_run
    if p_target >= 1.0:
        raise ValueError("p_target must be strictly less than 1")
    # log1p stays accurate for tiny p_success, where log(1.0 - p) rounds to
    # exactly 0 and the division would raise ZeroDivisionError.
    num_runs = math.log1p(-p_target) / math.log1p(-p_success)
    # Fewer than one run is not meaningful; clamping also keeps the result
    # continuous with the p_success >= 1 branch above.
    if num_runs < 1.0:
        num_runs = 1.0
    return t_run * num_runs

# --------------------------------------------------------------------
# Driver
# --------------------------------------------------------------------
if __name__ == "__main__":    
    n = 18  # no.of. variables
    m = 91  # no.of. clauses to satisfy
    depth = 1

    formula = generate_random_3sat(n, m)
    print(f"Generated random 3-SAT with n={n}, m={m}")

    best_theta, hist, total_time = optimize_geometry_guided_sat(
        n,
        formula,
        depth=depth,
        max_iters=50,      
        lr=0.1,
        lambda_flag=0.1,
        lambda_smooth=0.01,
        verbose=True,
    )

    final_metrics = hist[-1]
    # "sat_prob" is the expected fraction of satisfied clauses, not a success
    # probability, so it must not be fed into the TTS formula.
    clause_fraction = final_metrics["sat_prob"]

    # Per-run success probability: the chance that measuring the best state
    # yields an assignment violating no clause (energy 0).
    energies, _, basis_strings = precompute_sat_data(n, formula)
    best_state = prepare_state(best_theta, n)
    best_amps = np.array(best_state.amplitudes(basis_strings), dtype=np.complex128)
    p_success = float((np.abs(best_amps) ** 2)[energies == 0].sum())
    p_success = max(0.0, min(1.0, p_success))  # guard against rounding
    tts_est = time_to_solution(p_success, total_time, p_target=0.99)

    print("\n=== Summary ===")
    print(f"Final sat_prob ≈ {clause_fraction:.4f}")
    print(f"Solution probability at best theta ≈ {p_success:.6f}")
    print(f"Total runtime (single run) ≈ {total_time:.3f} s")
    print(f"Estimated TTS (p_target=0.99) ≈ {tts_est:.3f} s")

## Upgraded