In [73]:
import qiskit as qsk, numpy as np, matplotlib.pyplot as plt
from qiskit.quantum_info import SparsePauliOp, Statevector, Clifford
from typing import Callable

def random_thetas(shape: "shape"):
    possible_theta_values = np.pi / 2 * np.arange(4)
    return np.random.choice(possible_theta_values, shape)

def evolveSPO(op: SparsePauliOp, other: Clifford) -> SparsePauliOp:
    return SparsePauliOp(
        op.paulis.evolve(other)
    )

class ParametricCircuit:
    def __init__(self):
        self.n_parameters: int = None
        self.params_shape: "shape" = None
    
    def __call__(self, thetas: np.ndarray, *args, **kwds) -> qsk.QuantumCircuit:
        pass

class Model:
    def __init__(self,
                 parametric_circuit: ParametricCircuit,
                 encoding: Callable[..., Statevector],
                 observable: SparsePauliOp
                 ):
        self.param_circ = parametric_circuit
        self.enc = encoding
        self.obs = observable

    def __call__(self, thetas: np.ndarray, input) -> float:
        psi = self.enc(input)
        qc = self.param_circ(thetas)
        evolved_ops = evolveSPO(self.obs, Clifford(qc))
        return psi.expectation_value(evolved_ops)
    
    def param_shift_rule_grad(self, thetas: np.ndarray, input, index: tuple[int]) -> float:
        # print(f"Relevant theta = {thetas[index]}")
        thetas[index] += np.pi / 2
        # print(f"Relevant theta = {thetas[index]}")
        value_plus = self.__call__(thetas, input)
        # print(f"Value plus = {value_plus}")
        thetas[index] -= np.pi
        # print(f"Relevant theta = {thetas[index]}")
        value_minus = self.__call__(thetas, input)
        thetas[index] += np.pi/2
        # print(f"Value minus = {value_minus}")
        return (value_plus - value_minus) / 2
    
    def empirical_NTK(self, thetas: np.ndarray, input1, input2) -> float:
        grad = self.param_shift_rule_grad
        # print(
        #     np.array([
        #         self.param_shift_rule_grad(thetas, input1, index) #* grad(thetas, input2, index)
        #         for index in np.ndindex(thetas.shape)
        #     ])
        # )
        return np.sum([
            grad(thetas, input1, index) * grad(thetas, input2, index)
            for index in np.ndindex(thetas.shape)
        ])
    
    def analytic_NTK(self, input1, input2, n_shots):
        return np.mean([
            self.empirical_NTK(random_thetas(self.param_circ.params_shape),
                               input1, input2)
            for _ in range(n_shots)
        ])


class SimpleTestCircuit(ParametricCircuit):
    def __init__(self, n_qubits: int):
        self.n_parameters: int = n_qubits
        self.params_shape: tuple[int] = (n_qubits)
    
    def __call__(self, thetas: np.ndarray) -> qsk.QuantumCircuit:
        qc = qsk.QuantumCircuit(self.n_parameters)
        for i in range(self.n_parameters): qc.rx(thetas[i], i)
        return qc

In [74]:
def basis_encoding(int: int, n_qubits: int) -> Statevector:
    return Statevector.from_int(int, 2**n_qubits)

def sumZ(n_qubits: int) -> SparsePauliOp:
    return SparsePauliOp.from_sparse_list(
        [("Z", [i], 1) for i in range(n_qubits)],
        num_qubits = n_qubits
    )

n_qubits = 5

simple_model = Model(
    SimpleTestCircuit(n_qubits),
    lambda i: basis_encoding(i, n_qubits),
    sumZ(n_qubits)
)

In [77]:
thetas = np.zeros(n_qubits)
thetas[0] = np.pi/2
print(thetas)
display(simple_model.param_circ(thetas).draw("text"))
print(simple_model(thetas, 0))
print(simple_model.param_shift_rule_grad(thetas, 0, (0)))
print(f"theatas after \n{thetas}")
print(simple_model.empirical_NTK(thetas, 0, 0))

[1.57079633 0.         0.         0.         0.        ]


(4+0j)
(-1+0j)
theatas after 
[1.57079633 0.         0.         0.         0.        ]
(1+0j)


In [82]:
simple_model.analytic_NTK(0,1, 500)

np.complex128(1.53+0j)