# In [None]:  (notebook cell marker left over from export)
"""
Qiskit implementation: Fisher ratio loss for fidelity kernels (QC-computable)
-----------------------------------------------------------------------------

This script trains the parameters θ of a parametric quantum feature map U_θ(x)
by **maximizing the Fisher ratio** of the induced *fidelity kernel*
K_ij = |⟨ψ_θ(x_i)|ψ_θ(x_j)⟩|^2, computed via an **echo overlap circuit**
(U_θ(x_j)^† U_θ(x_i)) and measured with a Sampler (or Aer fallback).

Highlights
- Purely fidelity-based losses: only needs probabilities of measuring |0…0⟩.
- Parameter-shift gradients **on hardware**: each ∂K_ij/∂θ_k comes from 4
  shifted fidelity evaluations on the echo circuit (±π/2 applied to either
  the "i" or the "j" side of the overlap).
- Works with mini-batches: the loss is computed from batch sub-kernels.
- Two-class & multi-class Fisher criteria

Dependencies
- qiskit >= 0.46 (Terra 1.x) recommended
- qiskit-aer (optional but recommended for local simulation)

Quick start
-----------
1) Prepare your data X (shape [N, d]) and labels y (ints 0..C-1).
2) Run `python qiskit_fisher_fidelity_training.py` to see a demo with
   synthetic data on an Aer Sampler.
3) To plug your own feature map, implement `build_feature_map()` with your
   U_θ(x) and update `FeatureMapConfig` as needed.
"""
from __future__ import annotations

import math
import random
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np

from qiskit import QuantumCircuit
from qiskit.circuit import Parameter, ParameterVector

# Probe the optional Sampler implementations once at import time. Each probe
# records its outcome in a module-level flag that ProbZeroEstimator consults
# when choosing a backend. Any failure (not just ImportError) is treated as
# "backend unavailable" so that a broken optional install never prevents the
# module from importing.
try:
    from qiskit_aer.primitives import Sampler as AerSampler
    _HAVE_AER_SAMPLER = True
except Exception:
    _HAVE_AER_SAMPLER = False

try:
    from qiskit.primitives import Sampler
    _HAVE_TERRA_SAMPLER = True
except Exception:
    # The flag name must match the one read by ProbZeroEstimator; a misspelt
    # name here leaves the real flag undefined and causes a NameError later.
    _HAVE_TERRA_SAMPLER = False


# ------------------------------ Feature Map ------------------------------ #
@dataclass
class FeatureMapConfig:
    """Shape settings for the parametric feature map U_θ(x).

    The number of trainable parameters implied by these settings is
    ``(4 * n_qubits + (n_qubits - 1)) * n_layers`` (see `build_feature_map`).
    """
    # Width of the feature-map circuit.
    n_qubits: int
    # Number of data features; qubit q encodes feature ``q % data_dim``.
    data_dim: int
    # How many times the encode / rotate / entangle block is repeated.
    n_layers: int = 2
    # Label used in the circuit name (``U_<name>``).
    name: str = "EchoFidelityDefault"


def build_feature_map(cfg: FeatureMapConfig,
                      theta_i: ParameterVector,
                      x_i: ParameterVector) -> QuantumCircuit:
    """Construct a simple, expressive, trainable feature map U_θ(x).

    Layout, repeated ``cfg.n_layers`` times:
      - Data encoding: RZ(α·x) then RX(β·x) on every qubit, with learned
        scalings α, β taken from θ. Qubit ``q`` encodes feature
        ``x[q % cfg.data_dim]``; if ``n_qubits < data_dim`` the features with
        index >= n_qubits are never encoded.
      - Trainable single-qubit RY then RZ on every qubit.
      - Linear entanglement: CZ(q, q+1) followed by a trainable RZ on qubit
        ``q + 1``.

    Parameters are consumed left-to-right, top-to-bottom for determinism.
    Per layer the order is ``[α_0, β_0, …, α_{n-1}, β_{n-1}]`` (2n entries),
    ``[ry_0, rz_0, …]`` (2n entries), then ``[φ_0, …, φ_{n-2}]`` (n-1 entries),
    i.e. ``4n + (n - 1)`` parameters per layer.

    This is just a **reasonable default**—replace with your physics-informed
    ansatz if you have one.

    Args:
        cfg: Feature-map shape (qubits, data dimension, layers, name).
        theta_i: Trainable parameters; extra trailing entries are ignored.
        x_i: Data parameters; at least ``cfg.data_dim`` entries.

    Returns:
        The unmeasured feature-map circuit on ``cfg.n_qubits`` qubits.

    Raises:
        ValueError: If ``theta_i`` or ``x_i`` is too short, or if
            ``cfg.data_dim`` is smaller than 1.
    """
    n = cfg.n_qubits
    n_layers = cfg.n_layers

    # per-qubit: [alpha_k, beta_k, ry_k, rz_k] → 4*n
    # entanglers between q and q+1: [phi_q] → (n-1)
    per_layer = 4 * n + (n - 1)
    required = n_layers * per_layer

    # Validate before building anything. Explicit raises (rather than assert)
    # keep the checks active under ``python -O``.
    if len(theta_i) < required:
        raise ValueError(
            f"theta length {len(theta_i)} < required {required}")
    if cfg.data_dim < 1:
        raise ValueError(f"data_dim must be >= 1, got {cfg.data_dim}")
    if len(x_i) < cfg.data_dim:
        raise ValueError(
            f"x length {len(x_i)} < data_dim {cfg.data_dim}")

    qc = QuantumCircuit(n, name=f"U_{cfg.name}")

    th_idx = 0
    for _ in range(n_layers):
        # Data encoding with learned scalings (alpha, beta)
        for q in range(n):
            alpha = theta_i[th_idx]
            beta = theta_i[th_idx + 1]
            th_idx += 2
            # Map data dim -> qubits by modulo
            xparam = x_i[q % cfg.data_dim]
            qc.rz(alpha * xparam, q)
            qc.rx(beta * xparam, q)

        # Trainable single-qubit block
        for q in range(n):
            ry = theta_i[th_idx]
            rz = theta_i[th_idx + 1]
            th_idx += 2
            qc.ry(ry, q)
            qc.rz(rz, q)

        # Linear entanglement
        for q in range(n - 1):
            phi = theta_i[th_idx]
            th_idx += 1
            qc.cz(q, q + 1)
            qc.rz(phi, q + 1)

    return qc


# --------------------------- Echo overlap circuit ------------------------ #
@dataclass
class EchoCircuit:
    """Bundle returned by `build_echo_circuit`: the measured echo circuit
    together with the parameter vectors that appear in it."""
    # Feature-map settings the circuit was built from.
    cfg: FeatureMapConfig
    # Trainable and data parameters of the forward map U(x_i).
    theta_i: ParameterVector
    x_i: ParameterVector
    # Trainable and data parameters of the inverted map U(x_j)^†.
    theta_j: ParameterVector
    x_j: ParameterVector
    # U(x_i) followed by U(x_j)^†, with all qubits measured.
    circuit: QuantumCircuit


def build_echo_circuit(cfg: FeatureMapConfig) -> EchoCircuit:
    """Assemble the parameterized echo circuit for F_ij = |⟨ψ_i|ψ_j⟩|^2.

    The circuit applies U_θi(x_i), then the inverse of U_θj(x_j), and measures
    every qubit; the fidelity is the probability of reading all zeros.
    The two halves use independent ParameterVectors (θi, x_i) and (θj, x_j)
    so that parameter-shift gradients can shift one side at a time.
    """
    width = cfg.n_qubits
    # Same per-layer count as build_feature_map: 4 per qubit + (width-1) entanglers.
    n_trainable = (4 * width + (width - 1)) * cfg.n_layers

    theta_i = ParameterVector("theta_i", length=n_trainable)
    theta_j = ParameterVector("theta_j", length=len(theta_i))
    x_i = ParameterVector("x_i", length=cfg.data_dim)
    x_j = ParameterVector("x_j", length=cfg.data_dim)

    forward = build_feature_map(cfg, theta_i, x_i)
    backward = build_feature_map(cfg, theta_j, x_j).inverse()

    echo = QuantumCircuit(width, name="EchoOverlap")
    for part in (forward, backward):
        echo.compose(part, inplace=True)
    echo.measure_all()

    return EchoCircuit(
        cfg=cfg,
        theta_i=theta_i,
        x_i=x_i,
        theta_j=theta_j,
        x_j=x_j,
        circuit=echo,
    )


# --------------------------- Sampler helper ------------------------------ #
class ProbZeroEstimator:
    """Helper to get Pr(0…0) from circuits with bound parameters.

    Backends are tried in this order: the Aer Sampler primitive, the reference
    ``qiskit.primitives.Sampler``, and finally an ``AerSimulator`` driven via
    transpile + run + counts. The interface exposes a single method
    `batch_prob_zero(circuits, param_values)` returning a list of
    probabilities.

    Args:
        shots: Shots per circuit. ``None`` is forwarded as-is to the Sampler
            primitives; in the counts fallback it means "use the backend's
            default shot count".

    Raises:
        RuntimeError: If no backend at all can be set up.
    """

    def __init__(self, shots: Optional[int] = 4096):
        self.shots = shots
        self._mode = None
        self._sampler = None
        self._backend = None
        self._transpile = None

        # Read the import-probe flags defensively: an undefined flag simply
        # means that backend is unavailable instead of raising NameError.
        have_aer = globals().get("_HAVE_AER_SAMPLER", False)
        have_terra = globals().get("_HAVE_TERRA_SAMPLER", False)

        if have_aer:
            try:
                # The Aer Sampler constructor takes no `shots` keyword; the
                # shot count is supplied per call in batch_prob_zero().
                self._sampler = AerSampler()
                self._mode = "aer_sampler"
                return
            except Exception:
                pass  # best effort: fall through to the next backend

        if have_terra:
            try:
                self._sampler = Sampler()
                self._mode = "terra_sampler"
                return
            except Exception:
                pass  # best effort: fall through to the counts fallback

        # Last resort: AerSimulator counts
        try:
            from qiskit_aer import AerSimulator
            from qiskit import transpile
            self._backend = AerSimulator()
            self._transpile = transpile
            self._mode = "qasm"
        except Exception as e:
            raise RuntimeError(
                "No available Sampler or Aer qasm_simulator found.\n"
                "Install qiskit-aer or use a provider Sampler.") from e

    def batch_prob_zero(self,
                        circuits: List[QuantumCircuit],
                        param_values: List[Dict[Parameter, float]]) -> List[float]:
        """Return Pr(all qubits read 0) for each (circuit, binding) pair.

        Args:
            circuits: Measured circuits, one per evaluation (the same object
                may be repeated).
            param_values: For each circuit, a mapping from its parameters to
                numeric values.

        Returns:
            One probability per input, in input order.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(circuits) != len(param_values):
            raise ValueError(
                f"got {len(circuits)} circuits but {len(param_values)} bindings")
        if not circuits:
            return []

        if self._mode in {"aer_sampler", "terra_sampler"}:
            # Sampler path: one job for the whole batch.
            # Values must follow each circuit's own `parameters` ordering,
            # which is generally NOT the insertion order of the binding dict.
            ordered_values = [
                [float(binding[p]) for p in circ.parameters]
                for circ, binding in zip(circuits, param_values)
            ]
            job = self._sampler.run(circuits=circuits,
                                    parameter_values=ordered_values,
                                    shots=self.shots)
            # Quasi-distributions are keyed by the integer value of the
            # measured bitstring, so the all-zeros outcome is key 0.
            return [float(dist.get(0, 0.0)) for dist in job.result().quasi_dists]

        # Counts path
        run_opts = {} if self.shots is None else {"shots": self.shots}
        probs: List[float] = []
        for circ, binding in zip(circuits, param_values):
            bound = circ.assign_parameters(binding)
            compiled = self._transpile(bound, self._backend)
            counts = self._backend.run(compiled, **run_opts).result().get_counts()
            # Normalise by the shots actually returned so shots=None works.
            total = sum(counts.values())
            zeros = sum(c for bits, c in counts.items()
                        if set(bits.replace(" ", "")) <= {"0"})
            probs.append(float(zeros / total) if total else 0.0)
        return probs


# ----------------------- Fidelity and gradients -------------------------- #
class FidelityKernel:
    """Compute fidelities K_ij via echo circuits and parameter-shift grads.

    The gradient code assumes the parameter layout produced by
    `build_feature_map`: within each layer of ``4n + (n - 1)`` parameters the
    first ``2n`` entries are data-scaled (gate angle = θ_k · x[q % data_dim]
    with q = offset // 2) and the rest are used directly as gate angles.
    If you swap in a different ansatz, override `_encoding_scale` to match.
    """

    def __init__(self, cfg: FeatureMapConfig, shots: Optional[int] = 4096):
        self.cfg = cfg
        self.echo = build_echo_circuit(cfg)
        self.estimator = ProbZeroEstimator(shots=shots)

    def _bind(self,
              x_i: Sequence[float], theta_i: Sequence[float],
              x_j: Sequence[float], theta_j: Sequence[float]) -> Dict[Parameter, float]:
        """Map every echo-circuit parameter to a float value.

        Insertion order is x_i, x_j, theta_i, theta_j. Extra trailing values
        are ignored.

        Raises:
            ValueError: If any value sequence is shorter than its parameter
                vector (which would leave parameters unbound).
        """
        groups = (
            ("x_i", self.echo.x_i, x_i),
            ("x_j", self.echo.x_j, x_j),
            ("theta_i", self.echo.theta_i, theta_i),
            ("theta_j", self.echo.theta_j, theta_j),
        )
        bind = {}
        for label, params, values in groups:
            if len(values) < len(params):
                raise ValueError(
                    f"{label}: got {len(values)} values for {len(params)} parameters")
            for p, v in zip(params, values):
                bind[p] = float(v)
        return bind

    def fidelity(self,
                 x_i: Sequence[float], x_j: Sequence[float],
                 theta: Sequence[float]) -> float:
        """Compute F_ij with θ_i = θ_j = θ."""
        bind = self._bind(x_i, theta, x_j, theta)
        p0 = self.estimator.batch_prob_zero([self.echo.circuit], [bind])[0]
        return float(p0)

    def fidelities(self,
                   pairs: List[Tuple[int, int]],
                   X: np.ndarray,
                   theta: Sequence[float]) -> Dict[Tuple[int, int], float]:
        """Batch compute fidelities for selected index pairs (i,j)."""
        binds = [self._bind(X[i], theta, X[j], theta) for i, j in pairs]
        circs = [self.echo.circuit] * len(binds)
        probs = self.estimator.batch_prob_zero(circs, binds)
        return {(i, j): float(p) for (i, j), p in zip(pairs, probs)}

    # ----- Parameter-shift gradients ----- #
    def _encoding_scale(self, k: int, x: Sequence[float]) -> Optional[float]:
        """Return the data value multiplying θ_k inside its gate angle.

        Returns ``None`` when θ_k is used directly as a gate angle.
        ``k`` must be a non-negative parameter index.
        """
        n = self.cfg.n_qubits
        per_layer = 4 * n + (n - 1)
        offset = k % per_layer
        if offset >= 2 * n:
            return None
        return float(x[(offset // 2) % self.cfg.data_dim])

    def _shift_terms(self,
                     x_i: Sequence[float], x_j: Sequence[float],
                     theta: Sequence[float], k: int,
                     shift: float) -> List[Tuple[float, dict, dict]]:
        """Plan the shifted evaluations for ∂F_ij/∂θ_k.

        Returns a list of ``(weight, bind_plus, bind_minus)``; the gradient is
        ``Σ weight · (p0(bind_plus) − p0(bind_minus))``. There is one entry
        for the "i" side and one for the "j" side, except that a side whose
        encoding scale is exactly 0 is dropped (its gate angle does not
        depend on θ_k, so its contribution is 0).
        """
        theta = np.asarray(theta, dtype=float)
        k = int(k)
        if not -len(theta) <= k < len(theta):
            raise IndexError(
                f"parameter index {k} out of range for {len(theta)} parameters")
        k %= len(theta)

        # General two-term rule: df/dφ = (f(φ+s) − f(φ−s)) / (2 sin s).
        denom = 2.0 * math.sin(shift)
        if abs(denom) < 1e-12:
            raise ValueError("shift must not be a multiple of pi")

        terms = []
        for side, x_side in (("i", x_i), ("j", x_j)):
            scale = self._encoding_scale(k, x_side)
            if scale is None:
                scale = 1.0
            elif scale == 0.0:
                continue
            # The rule applies to the gate ANGLE φ = scale·θ_k: move θ_k by
            # shift/scale so that φ moves by `shift`, then apply the chain
            # rule dφ/dθ_k = scale.
            th_plus = theta.copy()
            th_minus = theta.copy()
            th_plus[k] += shift / scale
            th_minus[k] -= shift / scale
            if side == "i":
                bind_p = self._bind(x_i, th_plus, x_j, theta)
                bind_m = self._bind(x_i, th_minus, x_j, theta)
            else:
                bind_p = self._bind(x_i, theta, x_j, th_plus)
                bind_m = self._bind(x_i, theta, x_j, th_minus)
            terms.append((scale / denom, bind_p, bind_m))
        return terms

    def _run_shift_plans(self, plans: List[List[Tuple[float, dict, dict]]]) -> List[float]:
        """Evaluate several shift plans in a single estimator batch."""
        binds = [b for terms in plans for (_, bp, bm) in terms for b in (bp, bm)]
        probs = iter(
            self.estimator.batch_prob_zero([self.echo.circuit] * len(binds), binds)
            if binds else ()
        )
        grads = []
        for terms in plans:
            g = 0.0
            for weight, _, _ in terms:
                p_plus = next(probs)
                p_minus = next(probs)
                g += weight * (p_plus - p_minus)
            grads.append(float(g))
        return grads

    def fidelity_param_shift_grad(self,
                                  x_i: Sequence[float], x_j: Sequence[float],
                                  theta: Sequence[float], k: int,
                                  shift: float = math.pi / 2.0) -> float:
        """∂_θk F_ij with θ_i = θ_j = θ, via the parameter-shift rule.

        Uses up to 4 shifted echo evaluations (± on the "i" side, ± on the
        "j" side) and sums the two one-sided derivatives.

        Args:
            x_i, x_j: The two data points.
            theta: Current trainable parameters.
            k: Index of the parameter to differentiate.
            shift: Gate-angle shift; any value that is not a multiple of π.

        Raises:
            IndexError: If ``k`` is out of range.
            ValueError: If ``sin(shift)`` is (numerically) zero.
        """
        plan = self._shift_terms(x_i, x_j, theta, k, shift)
        return self._run_shift_plans([plan])[0]

    def dK_dtheta(
        self,
        pairs: List[Tuple[int, int]],
        X: np.ndarray,
        theta: Sequence[float],
        param_indices: Optional[Sequence[int]] = None,
    ) -> Dict[int, Dict[Tuple[int, int], float]]:
        """Compute ∂K_ij/∂θ_k for every requested pair and parameter.

        All shifted circuits are submitted in one estimator batch. To save
        shots, pass a subset of `param_indices` (default: all parameters).

        Returns:
            ``{k: {(i, j): ∂K_ij/∂θ_k}}`` so the trainer can combine the
            entries with its dL/dK weights.
        """
        if param_indices is None:
            param_indices = list(range(len(theta)))
        grads: Dict[int, Dict[Tuple[int, int], float]] = {k: {} for k in param_indices}

        keys = [((i, j), k) for (i, j) in pairs for k in param_indices]
        plans = [
            self._shift_terms(X[i], X[j], theta, k, math.pi / 2.0)
            for (i, j), k in keys
        ]
        for (pair, k), g in zip(keys, self._run_shift_plans(plans)):
            grads[k][pair] = g
        return grads


# ------------------------ Fisher loss and utilities ---------------------- #
@dataclass
class FisherLossOut:
    """Result of `fisher_loss_from_pairs`."""
    # -log(max(SB, 1e-12)) + log(SW + eps)
    loss: float
    # Accumulated <K, B> (between-class scatter trace estimate).
    SB: float
    # Accumulated <K, W> (within-class scatter trace estimate).
    SW: float
    # The B and W coefficient matrices returned by `fisher_mats(y)`.
    B_mat: np.ndarray
    W_mat: np.ndarray


def fisher_mats(y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return coefficient matrices (B, W) for the kernel Fisher criterion.

    They are built so that tr(S_B) = ⟨K, B⟩_F and tr(S_W) = ⟨K, W⟩_F:
      B = Σ_c A_c − (1/n)·𝟙𝟙ᵀ   and   W = Σ_c (P_c − A_c),
    where, for class c with n_c members, A_c holds 1/n_c on every pair of
    members and P_c is the diagonal indicator of the class.
    """
    n = len(y)
    class_avg = np.zeros((n, n), dtype=float)    # running Σ_c A_c
    class_diag = np.zeros((n, n), dtype=float)   # running Σ_c P_c

    for label in np.unique(y):
        members = np.where(y == label)[0]
        indicator = np.zeros(n)
        indicator[members] = 1.0
        class_avg += np.outer(indicator, indicator) / len(members)
        class_diag += np.diag(indicator)

    global_avg = np.ones((n, n), dtype=float) / n
    return class_avg - global_avg, class_diag - class_avg


def fisher_loss_from_pairs(
    n: int,
    pairs: List[Tuple[int, int]],
    K_pairs: Dict[Tuple[int, int], float],
    y: np.ndarray,
    eps: float = 1e-6,
) -> FisherLossOut:
    """Fisher loss using only the pairs we evaluated.

    We accumulate ⟨K,B⟩ and ⟨K,W⟩ using pair entries (symmetrised), assuming
    K_ii = 1 (pure-state fidelity). This supports mini-batch pair sampling:
    off-diagonal entries that are absent from `K_pairs` contribute nothing.

    Args:
        n: Number of leading samples whose diagonal entries are added.
        pairs: Unused; kept for call compatibility (the entries that are
            summed are exactly the keys of `K_pairs`).
        K_pairs: Evaluated fidelities keyed by (i, j). Each unordered pair is
            counted once: if both (i, j) and (j, i) are present, only the
            first one encountered is used.
        y: Class labels for all samples.
        eps: Added to SW inside the logarithm.

    Returns:
        FisherLossOut with loss = -log(max(SB, 1e-12)) + log(SW + eps).
    """
    B, W = fisher_mats(y)
    SB = 0.0
    SW = 0.0

    # Off-diagonals from the provided pairs. Each term is already symmetrised
    # (entry [i, j] plus entry [j, i]), so a mirrored duplicate must be
    # skipped or it would be counted twice.
    seen = set()
    for (i, j), v in K_pairs.items():
        if i == j or (i, j) in seen:
            continue
        seen.add((i, j))
        seen.add((j, i))
        SB += B[i, j] * v + B[j, i] * v
        SW += W[i, j] * v + W[j, i] * v

    # Diagonal contributions, added analytically with K_ii = 1.
    for i in range(n):
        SB += B[i, i] * 1.0
        SW += W[i, i] * 1.0

    # SB is clamped so the logarithm stays defined when the sampled pairs
    # give a non-positive between-class estimate.
    loss = -math.log(max(SB, 1e-12)) + math.log(SW + eps)
    return FisherLossOut(loss=loss, SB=float(SB), SW=float(SW), B_mat=B, W_mat=W)


# ------------------------------ Trainer --------------------------------- #
@dataclass
class TrainConfig:
    """Hyper-parameters consumed by `FisherFidelityTrainer`."""
    # Shots per circuit, forwarded to the FidelityKernel.
    shots: int = 4096
    # Number of sample indices drawn per training step.
    batch_size: int = 16
    # Number of training steps run by `fit`.
    max_iters: int = 200
    # SGD step size.
    lr: float = 0.1
    # How many parameters get a gradient each step (random subset);
    # None → all params.
    grad_params_per_step: Optional[int] = None  # None → all params
    # Seeds the trainer's pair/parameter sampling and the default θ init.
    seed: int = 7


class FisherFidelityTrainer:
    """SGD trainer that minimises -log(SB) + log(SW) of the fidelity kernel.

    Each step draws a mini-batch of sample indices, samples intra- and
    inter-class pairs from it, evaluates those fidelities, and updates θ with
    parameter-shift gradients.
    """

    def __init__(self, cfg: FeatureMapConfig, train_cfg: TrainConfig):
        self.cfg = cfg
        self.train_cfg = train_cfg
        self.kernel = FidelityKernel(cfg, shots=train_cfg.shots)
        self.rng = random.Random(train_cfg.seed)

    # ---- Pair selection ---- #
    def sample_pairs(self, idxs: List[int], y: np.ndarray,
                      intra_per_class: int = 4,
                      inter_pairs: int = 16) -> List[Tuple[int, int]]:
        """Sample unique index pairs (i < j) from the batch `idxs`.

        Up to `intra_per_class` same-class pairs are taken per class, then
        `inter_pairs` random draws are made and kept only when the two labels
        differ (so fewer than `inter_pairs` may survive). Returns an empty
        list when the batch holds fewer than two indices.
        """
        pairs: List[Tuple[int, int]] = []
        # Intra-class pairs: pair the k-th member with the k-th from the end.
        for c in np.unique(y[idxs]):
            class_idx = [i for i in idxs if y[i] == c]
            self.rng.shuffle(class_idx)
            for k in range(min(intra_per_class, max(0, len(class_idx) - 1))):
                pairs.append((class_idx[k], class_idx[-k - 1]))

        # Inter-class pairs (rejection sampling). Drawing 2 items needs a
        # population of at least 2, otherwise random.sample raises.
        others = list(idxs)
        self.rng.shuffle(others)
        if len(others) >= 2:
            for _ in range(inter_pairs):
                i, j = self.rng.sample(others, 2)
                if y[i] != y[j]:
                    pairs.append((i, j))

        # Deduplicate & order each tuple as (small, large); drop self-pairs.
        norm: List[Tuple[int, int]] = []
        seen = set()
        for (i, j) in pairs:
            a, b = (i, j) if i <= j else (j, i)
            if a != b and (a, b) not in seen:
                norm.append((a, b))
                seen.add((a, b))
        return norm

    def step(self, X: np.ndarray, y: np.ndarray,
             theta: np.ndarray) -> Tuple[float, Dict[str, float], np.ndarray]:
        """Run one SGD step.

        Returns:
            ``(loss, metrics, theta_new)`` where `metrics` holds "SB", "SW",
            "pairs" and "grad_norm".
        """
        # Work in float so integer-typed inputs do not truncate gradients.
        theta = np.asarray(theta, dtype=float)
        n = len(X)

        # Select a mini-batch of indices
        idxs = list(range(n))
        self.rng.shuffle(idxs)
        idxs = idxs[: self.train_cfg.batch_size]

        # Sample informative pairs from the batch
        pairs = self.sample_pairs(idxs, y)

        # 1) Forward: fidelities on sampled pairs
        K_pairs = self.kernel.fidelities(pairs, X, theta)

        # 2) Loss + dL/dK coefficients
        out = fisher_loss_from_pairs(n, pairs, K_pairs, y)
        # Gradient wrt K is: -B/SB + W/(SW+eps), with the same clamp and eps
        # that fisher_loss_from_pairs uses by default.
        dL_dK = -out.B_mat / max(out.SB, 1e-12) + out.W_mat / (out.SW + 1e-6)

        # 3) Parameter-shift gradients of K_ij, optionally on a random subset
        #    of parameters to save shots.
        if self.train_cfg.grad_params_per_step is None:
            param_indices = list(range(len(theta)))
        else:
            param_indices = self.rng.sample(
                list(range(len(theta))),
                k=min(self.train_cfg.grad_params_per_step, len(theta)))
        dK = self.kernel.dK_dtheta(pairs, X, theta, param_indices)

        # 4) Aggregate ∂L/∂θ_k = sum_{(i,j)} (dL/dK_ij + dL/dK_ji) * ∂K_ij/∂θ_k.
        # Only i<j pairs are stored, so both symmetric entries are added.
        grads = np.zeros_like(theta)
        for k in param_indices:
            grads[k] = sum(
                (dL_dK[i, j] + dL_dK[j, i]) * g_ij
                for (i, j), g_ij in dK[k].items()
            )

        # 5) Parameter update (SGD)
        theta_new = theta - self.train_cfg.lr * grads

        metrics = {
            "SB": out.SB,
            "SW": out.SW,
            "pairs": float(len(pairs)),
            "grad_norm": float(np.linalg.norm(grads)),
        }
        return float(out.loss), metrics, theta_new

    def fit(self, X: np.ndarray, y: np.ndarray, theta0: Optional[np.ndarray] = None) -> np.ndarray:
        """Train for `max_iters` steps and return the final θ.

        When `theta0` is None, θ starts from small seeded Gaussian noise of
        the length required by the feature map. `theta0` is not modified.
        Progress is printed on the first and every tenth iteration.
        """
        if theta0 is None:
            # Small random init to avoid symmetries
            theta0 = 0.01 * np.random.default_rng(self.train_cfg.seed).standard_normal(
                (4 * self.cfg.n_qubits + (self.cfg.n_qubits - 1)) * self.cfg.n_layers
            )
        theta = np.array(theta0, dtype=float)  # always a fresh float copy

        for it in range(1, self.train_cfg.max_iters + 1):
            loss, metrics, theta = self.step(X, y, theta)
            if it % 10 == 0 or it == 1:
                print(f"iter {it:4d} | loss={loss:.5f} | SB={metrics['SB']:.5f} | SW={metrics['SW']:.5f} | "
                      f"pairs={int(metrics['pairs'])} | grad_norm={metrics['grad_norm']:.3f}")
        return theta


# ------------------------------- Demo ----------------------------------- #
def _toy_data(N: int = 40, d: int = 3, seed: int = 7) -> Tuple[np.ndarray, np.ndarray]:
    """Generate a shuffled two-class toy set: Gaussian blobs in R^d.

    Class 0 (the first N // 2 draws) is centred at +1 in every coordinate and
    class 1 (the remaining draws) at -1, both with standard deviation 0.6.
    Returns ``(X, y)`` with X of shape (N, d) and integer labels y.
    """
    gen = np.random.default_rng(seed)
    n_pos = N // 2
    n_neg = N - n_pos

    # Draw class 0 first, then class 1, so the random stream order is fixed.
    blobs = [
        gen.normal(loc=centre, scale=0.6, size=(count, d))
        for centre, count in ((+1.0, n_pos), (-1.0, n_neg))
    ]
    features = np.concatenate(blobs, axis=0).astype(float)
    labels = np.concatenate([np.zeros(n_pos, dtype=int), np.ones(n_neg, dtype=int)])

    # Shuffle
    order = gen.permutation(N)
    return features[order], labels[order]


def main_demo():
    """Train on a small synthetic two-class set and print final Fisher stats."""
    n_samples, n_features = 24, 3
    X, y = _toy_data(N=n_samples, d=n_features)

    trainer = FisherFidelityTrainer(
        FeatureMapConfig(n_qubits=3, data_dim=n_features, n_layers=2),
        TrainConfig(shots=2048, batch_size=16, max_iters=30, lr=0.3, grad_params_per_step=12),
    )
    theta = trainer.fit(X, y)

    # Dense evaluation over every unordered pair (demo only).
    # Warning: O(N^2) echo calls; okay for small N.
    all_pairs = []
    for i in range(n_samples):
        all_pairs.extend((i, j) for j in range(i + 1, n_samples))
    kernel_entries = trainer.kernel.fidelities(all_pairs, X, theta)
    final = fisher_loss_from_pairs(n_samples, all_pairs, kernel_entries, y)
    print("\nFinal Fisher stats: SB=", final.SB, "SW=", final.SW, "Loss=", final.loss)


if __name__ == "__main__":
    main_demo()
