# In [None]:
# End-to-end trainable quantum-kernel SVM with Qiskit + PyTorch
# - Learns SVM params (beta, b) and quantum feature map params theta
# - Backprop: parameter-shift through the kernel into the circuit
import math, numpy as np, torch
from dataclasses import dataclass
from typing import Tuple

from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector

from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


# Fix both RNGs so parameter initialisation and any NumPy randomness repeat.
torch.manual_seed(0)
np.random.seed(0)

# -----------------------------
# 1) Toy data (binary labels ±1)
# -----------------------------
X, y = make_moons(n_samples=1000, noise=0.9, random_state=0)
y = 2.0*y.astype(np.float32) - 1.0  # {0,1} -> {-1,+1}

# Split BEFORE standardising: the scaler must only see training samples,
# otherwise test-set statistics leak into the model and the reported test
# accuracy is optimistically biased.
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.3, random_state=0, stratify=y)
scaler = StandardScaler().fit(Xtr)

# Keep `X` bound to the standardised full data set (now using train-only statistics).
X = scaler.transform(X).astype(np.float32)
Xtr = torch.tensor(scaler.transform(Xtr).astype(np.float32))   # [n_train, d]
Xte = torch.tensor(scaler.transform(Xte).astype(np.float32))   # [n_test, d]
ytr = torch.tensor(ytr)                                        # [n_train]
yte = torch.tensor(yte)                                        # [n_test]

n_train, d = Xtr.shape
n_qubits = d   # one qubit per input feature
n_layers = 2   # feel free to increase (cost grows with params × kernel evals)

# ----------------------------------------------------
# 2) A parameterized Qiskit feature map  |phi_theta(x)>
#    - Fixed data encoders (RX(x_i), RZ(x_i^2))
#    - Trainable single-qubit layers + RZZ ring per layer
#    All trainables are pure rotation angles -> clean shift rule.
# ----------------------------------------------------
@dataclass
class AnsatzShape:
    """Dimensions of the trainable ansatz: number of layers and of qubits."""

    n_layers: int
    n_qubits: int

    @property
    def sizes(self):
        """Return the shape of every trainable parameter group.

        There are three groups -- RY angles, RZ angles and RZZ ring-entangler
        angles -- and each holds one value per (layer, qubit).
        """
        per_group = (self.n_layers, self.n_qubits)
        return {name: per_group for name in ("theta_y", "theta_z", "theta_zz")}

    @property
    def total_params(self):
        """Return the total number of trainable angles across the three groups."""
        return 3 * self.n_layers * self.n_qubits

def unflatten_theta(theta_flat: np.ndarray, shape: AnsatzShape):
    """Split a flat parameter vector into the three per-gate angle tables.

    Layout of ``theta_flat``: all RY angles, then all RZ angles, then all RZZ
    angles; each group is stored row-major as (n_layers, n_qubits).

    Args:
        theta_flat: array holding exactly ``shape.total_params`` values; it is
            flattened first, so any array shape with that many entries works.
        shape: ansatz dimensions (``n_layers``, ``n_qubits``, ``total_params``).

    Returns:
        Tuple ``(th_y, th_z, th_zz)`` of arrays, each shaped (n_layers, n_qubits).

    Raises:
        ValueError: if the number of values does not match ``shape.total_params``.
    """
    L, Q = shape.n_layers, shape.n_qubits
    flat = np.asarray(theta_flat).reshape(-1)
    # Explicit exception instead of `assert`: asserts vanish under `python -O`.
    if flat.size != shape.total_params:
        raise ValueError(
            f"expected {shape.total_params} parameters for "
            f"{L} layer(s) x {Q} qubit(s), got {flat.size}"
        )
    block = L * Q
    th_y, th_z, th_zz = (
        flat[i * block:(i + 1) * block].reshape(L, Q) for i in range(3)
    )
    return th_y, th_z, th_zz

def build_feature_map(x_vec: np.ndarray, theta_flat: np.ndarray, shape: AnsatzShape) -> QuantumCircuit:
    """Return a circuit preparing |phi_theta(x)> from |0...0>.

    Structure (data re-uploading)::

        E(x) . [ V_1(theta) . E(x) ] . ... . [ V_L(theta) . E(x) ]

    where ``E(x)`` is the fixed encoding RX(x_q), RZ(x_q^2) on every qubit and
    ``V_l`` is the trainable block RY, RZ on every qubit followed by an RZZ ring.

    Why the encoding is repeated after every trainable block: if all trainable
    gates came after the last data-dependent gate they would form one
    data-independent unitary V, and V^dagger V = I would cancel inside the
    overlap <phi(x_i)|phi(x_j)>, leaving the fidelity kernel independent of
    theta. Sandwiching each V_l between encodings keeps every angle relevant.

    Note: theta enters both states of such an overlap, so differentiating a
    fidelity kernel w.r.t. theta requires the product rule (shift the angle in
    one state at a time).

    Args:
        x_vec: feature vector; entries 0..n_qubits-1 are read.
        theta_flat: flat trainable angles, laid out as ``unflatten_theta`` expects.
        shape: ansatz dimensions.

    Returns:
        A ``QuantumCircuit`` on ``shape.n_qubits`` qubits with all angles bound.
    """
    L, Q = shape.n_layers, shape.n_qubits
    th_y, th_z, th_zz = unflatten_theta(theta_flat, shape)

    qc = QuantumCircuit(Q)

    def encode_data():
        # Fixed (non-trainable) data encoding on every qubit.
        for q in range(Q):
            qc.rx(float(x_vec[q]), q)     # RX(x_i)
            qc.rz(float(x_vec[q]**2), q)  # RZ(x_i^2)

    encode_data()
    for l in range(L):
        # Trainable single-qubit rotations
        for q in range(Q):
            qc.ry(float(th_y[l, q]), q)
            qc.rz(float(th_z[l, q]), q)
        # Entangling ring
        for q in range(Q):
            r = (q + 1) % Q
            if r == q:
                # Single-qubit register: the "ring" would pair a qubit with
                # itself, which is not a valid two-qubit gate. The matching
                # theta_zz entry is simply unused.
                continue
            qc.rzz(float(th_zz[l, q]), q, r)
        # Re-upload the data so this layer's angles cannot cancel in overlaps.
        encode_data()
    return qc

def statevector_from_circuit(qc: QuantumCircuit) -> np.ndarray:
    """Return the amplitudes produced by ``qc`` as a complex vector of length 2^n."""
    final_state = Statevector.from_instruction(qc)
    return final_state.data

def batch_states(X_np: np.ndarray, theta_np: np.ndarray, shape: AnsatzShape) -> np.ndarray:
    """Simulate the feature map for every sample and stack the results.

    Iterates over the first axis of ``X_np``; row ``i`` of the returned
    [N, 2^n] complex array is the statevector for ``X_np[i]`` at ``theta_np``.
    """
    circuits = (build_feature_map(sample, theta_np, shape) for sample in X_np)
    amplitudes = [statevector_from_circuit(circuit) for circuit in circuits]
    return np.stack(amplitudes, axis=0)

def kernel_from_states(S: np.ndarray) -> np.ndarray:
    """Return the fidelity kernel K_ij = |<phi_i|phi_j>|^2 for states S [N, D].

    A constant 1e-6 is added to the diagonal as a small stabiliser, and the
    result is returned as float32.
    """
    overlaps = S @ S.conj().T              # [N,N] complex Gram matrix
    fidelity = np.abs(overlaps) ** 2
    # Diagonal jitter (same effect as adding 1e-6 * identity).
    diag = np.arange(fidelity.shape[0])
    fidelity[diag, diag] += 1e-6
    return fidelity.astype(np.float32)

# ---------------------------------------------------------
# 3) Autograd: Quantum kernel forward + parameter-shift back
# ---------------------------------------------------------
class QKernelShift(torch.autograd.Function):
    """Fidelity-kernel matrix K(theta) with a parameter-shift backward pass.

    Forward simulates one statevector per row of ``X`` and returns
    ``K_ij = |<phi_theta(x_i)|phi_theta(x_j)>|^2`` (plus the diagonal jitter
    added by ``kernel_from_states``). Backward produces ``dL/dtheta`` only;
    ``X`` and the layer count receive no gradient.
    """

    @staticmethod
    def forward(ctx, X: torch.Tensor, theta: torch.Tensor, n_layers_q: int):
        """Compute the [N, N] float32 kernel matrix on CPU with Qiskit.

        Args:
            X: [N, d] inputs; ``d`` is used as the number of qubits.
            theta: trainable angles (flattened before use).
            n_layers_q: number of ansatz layers.
        """
        # Save static config
        shape = AnsatzShape(n_layers=int(n_layers_q), n_qubits=X.shape[1])
        ctx.shape = shape

        X_np = X.detach().cpu().numpy()
        th_np = theta.detach().cpu().numpy().reshape(-1)
        K = kernel_from_states(batch_states(X_np, th_np, shape))

        # For backward
        ctx.save_for_backward(X.detach(), theta.detach())
        # Hand the result back on theta's device so it composes with the
        # other trainable tensors (no-op on CPU).
        return torch.from_numpy(K).to(theta.device)  # [N,N] float32

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return ``(None, dL/dtheta, None)``.

        ``dL/dtheta_p = sum_ij (dL/dK_ij) * (dK_ij/dtheta_p)``.

        theta_p appears in BOTH states of the overlap, so K_ij as a function
        of theta_p contains second-harmonic terms and shifting both states at
        once by ±pi/2 is not an exact derivative. Instead use the product
        rule: keep one state at theta and shift only the other. With

            D_ij = 0.5 * ( |<phi_i(theta)|phi_j(theta + s e_p)>|^2
                         - |<phi_i(theta)|phi_j(theta - s e_p)>|^2 ),  s = pi/2

        (exact for a rotation angle that occurs in a single gate with a
        Pauli generator), symmetry of the fidelity gives
        ``dK/dtheta_p = D + D^T``.
        """
        X, theta = ctx.saved_tensors
        # Skip the 2*P batch simulations when theta needs no gradient.
        if not ctx.needs_input_grad[1]:
            return None, None, None

        shape: AnsatzShape = ctx.shape
        X_np = X.cpu().numpy()
        # float64 so that adding pi/2 does not lose precision.
        th_np = theta.cpu().numpy().astype(np.float64).reshape(-1)
        gK = grad_output.detach().cpu().numpy().astype(np.float64)
        # sum_ij gK_ij (D_ij + D_ji) == sum_ij (gK + gK^T)_ij D_ij
        gK_sym = gK + gK.T

        shift = math.pi/2
        # Unshifted states act as the fixed "bra" side of every overlap.
        bra = batch_states(X_np, th_np, shape).conj()      # [N, 2^n]

        def cross_fidelity(theta_vec: np.ndarray) -> np.ndarray:
            """|<phi_i(theta)|phi_j(theta_vec)>|^2 for all pairs (i, j)."""
            ket = batch_states(X_np, theta_vec, shape)
            return np.abs(bra @ ket.T)**2

        grad_theta = np.zeros(th_np.size, dtype=np.float64)
        for p in range(th_np.size):
            th_plus = th_np.copy();   th_plus[p]  += shift
            th_minus = th_np.copy();  th_minus[p] -= shift

            D = 0.5*(cross_fidelity(th_plus) - cross_fidelity(th_minus))
            grad_theta[p] = np.sum(gK_sym * D, dtype=np.float64)

        # No gradients for X (not learning the raw inputs) nor for the layer count.
        grad_th = torch.from_numpy(grad_theta).to(device=theta.device, dtype=theta.dtype)
        return None, grad_th.reshape(theta.shape), None

def quantum_kernel_matrix(X: torch.Tensor, theta: torch.Tensor, n_layers_q: int) -> torch.Tensor:
    """Return the [N, N] kernel matrix of ``X`` at ``theta`` via ``QKernelShift``.

    Going through ``QKernelShift.apply`` is what attaches the custom backward
    pass, so gradients of a downstream loss reach ``theta``.
    """
    kernel = QKernelShift.apply(X, theta, n_layers_q)
    return kernel

# --------------------------------------------
# 4) SVM-in-primal parameters (β in R^n, b∈R)
# --------------------------------------------
# Hyper-parameters
C = 5.0     # soft-margin weight
lr = 0.05   # Adam step size, shared by all three parameter tensors

# Ansatz dimensions, reused when building circuits for evaluation.
shape = AnsatzShape(n_layers=n_layers, n_qubits=n_qubits)

# Trainable tensors: feature-map angles (small random start), one SVM
# coefficient per training sample (zero start) and a scalar bias.
theta = torch.nn.Parameter(torch.randn(shape.total_params) * 0.3)
beta = torch.nn.Parameter(torch.zeros(n_train))
b = torch.nn.Parameter(torch.zeros(()))

# A single optimizer updates the circuit angles and the SVM parameters together.
optim = torch.optim.Adam(params=[theta, beta, b], lr=lr)

def svm_primal_loss(K: torch.Tensor, y: torch.Tensor, beta: torch.Tensor, b: torch.Tensor, C: float):
    """Kernel soft-margin SVM objective and decision scores.

    Computes ``0.5 * beta^T K beta + C * mean(max(0, 1 - y * s))`` with
    ``s = K beta + b``.

    Args:
        K: [n, n] kernel matrix.
        y: [n] labels (the hinge term treats them as ±1 targets).
        beta: [n] coefficients.
        b: scalar bias.
        C: weight of the mean hinge term.

    Returns:
        Tuple ``(loss, s)``: scalar loss tensor and the [n] scores.
    """
    # K @ beta feeds both the scores and the regulariser; compute it once.
    K_beta = K @ beta
    s = K_beta + b                     # [n]
    hinge = torch.clamp(1.0 - y * s, min=0.0)
    reg = 0.5 * (beta @ K_beta)        # 0.5 * β^T K β
    return reg + C * hinge.mean(), s

# -----------------------------------------
# 5) Training loop (backprop end-to-end)
# -----------------------------------------
for step in range(200):
    optim.zero_grad()

    K_tr = quantum_kernel_matrix(Xtr, theta, n_layers)  # backpropagates via parameter-shift
    loss, _ = svm_primal_loss(K_tr, ytr, beta, b, C)

    loss.backward()
    optim.step()

    if (step+1) % 40 == 0:
        # Periodic evaluation on the held-out set. Everything below is plain
        # NumPy on detached copies, so no autograd graph is built.
        # Cross-kernel K(Xte, Xtr): simulate each set once, then take overlaps.
        th_np = theta.detach().cpu().numpy()
        S_tr = batch_states(Xtr.cpu().numpy(), th_np, shape)
        S_te = batch_states(Xte.cpu().numpy(), th_np, shape)
        K_te_tr = (np.abs(S_te @ S_tr.conj().T)**2).astype(np.float32)

        test_scores = K_te_tr @ beta.detach().cpu().numpy() + b.detach().cpu().numpy()
        # Explicit threshold instead of np.sign: sign(0) == 0 matches neither
        # label, so a score of exactly zero would always count as an error.
        preds = np.where(test_scores >= 0.0, 1.0, -1.0)
        acc = (preds == yte.cpu().numpy()).mean()
        # NOTE: `loss` was computed before this step's update; accuracy uses
        # the parameters after the update.
        print(f"step {step+1:3d} | loss {loss.item():.4f} | test acc {acc*100:.1f}%")

print("Done.")
