In [52]:
import pennylane as qml
import numpy as np
import jax

from typing import List, Protocol, Callable, Tuple
from typing import List, Tuple, Callable

# JAX runtime configuration:
# * bind the NumPy-compatible JAX namespace as `jnp` (used for parameter arrays below);
# * enable 64-bit floats so circuits are evaluated in double precision;
# * pin computation to the CPU backend.
jnp = jax.numpy
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


### Transverse-field Ising Hamiltonian ###

In [53]:
def create_transverse_hamiltonian(num_wires: int, J: float = 1.0, h: float = 0.5):
    """Build an all-to-all transverse-field Ising Hamiltonian and its spectral bounds.

        H = -J * sum_{i<j} Z_i Z_j  -  h * sum_i X_i

    Every unordered pair of wires is coupled (not only nearest neighbours).

    Args:
        num_wires: Number of qubits; the operators act on wires ``0 .. num_wires - 1``.
        J: ZZ coupling strength.
        h: Transverse-field (X) strength.

    Returns:
        ``(hamiltonian, e_min, e_max)``: the ``qml.Hamiltonian`` together with its
        smallest and largest eigenvalue.

    Raises:
        ValueError: if ``num_wires < 1`` (the Hamiltonian would have no terms).
    """
    if num_wires < 1:
        raise ValueError(f"num_wires must be >= 1, got {num_wires}")

    coeffs = []
    obs = []

    # ZZ couplings between every pair of wires i < j.
    for i in range(num_wires):
        for j in range(i + 1, num_wires):
            coeffs.append(-J)
            obs.append(qml.PauliZ(i) @ qml.PauliZ(j))

    # Single-qubit transverse-field terms.
    for i in range(num_wires):
        coeffs.append(-h)
        obs.append(qml.PauliX(i))

    hamiltonian = qml.Hamiltonian(coeffs, obs)
    # Exact diagonalisation is the expensive step: compute the spectrum once and
    # reuse it for both bounds.
    eigenvalues = qml.eigvals(hamiltonian)
    e_min = min(eigenvalues)
    e_max = max(eigenvalues)
    return hamiltonian, e_min, e_max

### Helper methods for Pauli-coefficient extraction ###

In [54]:
from scipy.linalg import logm

# NOTE: `extract_pauli_coeff` is defined exactly once, further down in this cell.
# A verbatim duplicate definition here was shadowed by that one and has been removed.

from scipy.linalg import logm

def extract_pauli_coeff(U, threshold=1e-6):
    """Expand the principal matrix logarithm of ``U`` in the Pauli basis.

    ``logm(U)`` is a generator ``A`` with ``expm(A) == U``; for a unitary ``U`` it is
    anti-Hermitian (``A = 1j * H`` with ``H`` Hermitian). For every non-identity Pauli
    word ``P`` on ``n`` qubits this returns ``Tr(A @ P) / 2**n``. So for
    ``U = exp(1j * sum_k theta_k P_k)`` the coefficient of ``P_k`` is ``1j * theta_k``
    (whenever the principal logarithm recovers that exponent). The identity component
    (global phase) is not returned.

    Args:
        U: Square matrix of dimension ``2**n`` with ``n >= 1`` (a 4x4 two-qubit unitary
            in this notebook).
        threshold: Real and imaginary parts with magnitude below this are set to 0.

    Returns:
        dict mapping Pauli-word labels to complex coefficients. The labels are in
        lexicographic order over ``'IXYZ'`` with the all-identity word omitted, e.g.
        ``'IX', 'IY', ..., 'ZZ'`` for two qubits.

    Raises:
        ValueError: if ``U`` is not square with a power-of-two dimension >= 2.
    """
    import itertools  # local import keeps the notebook's top import cell untouched

    U = np.asarray(U)
    if U.ndim != 2 or U.shape[0] != U.shape[1]:
        raise ValueError(f"U must be a square matrix, got shape {U.shape}")
    dim = U.shape[0]
    num_qubits = dim.bit_length() - 1
    if num_qubits < 1 or 2**num_qubits != dim:
        raise ValueError(f"U must have dimension 2**n with n >= 1, got {dim}")

    single_qubit = {
        "I": np.eye(2),
        "X": np.array([[0, 1], [1, 0]]),
        "Y": np.array([[0, -1j], [1j, 0]]),
        "Z": np.array([[1, 0], [0, -1]]),
    }

    # Anti-Hermitian generator (1j * Hermitian) of U -- not Hermitian itself.
    generator = logm(U)

    coefficients = {}
    for word in itertools.product("IXYZ", repeat=num_qubits):
        if set(word) == {"I"}:
            continue  # identity component is a global phase; not part of the expansion
        pauli = np.ones((1, 1))
        for label in word:
            pauli = np.kron(pauli, single_qubit[label])
        # Pauli words satisfy Tr(P_a P_b) = dim * delta_ab, hence the 1/dim normalisation.
        coeff = complex(np.trace(generator @ pauli) / dim)
        # Snap numerically negligible real/imaginary parts to exactly zero.
        real = 0.0 if abs(coeff.real) < threshold else coeff.real
        imag = 0.0 if abs(coeff.imag) < threshold else coeff.imag
        coefficients["".join(word)] = complex(real, imag)
    return coefficients

def two_qubit_decomp(params, wires):
    """Queue a general two-qubit circuit with 3 CNOTs and 15 rotation angles.

    The gate layout follows Theorem 5 of https://arxiv.org/pdf/quant-ph/0308006.pdf.
    With ``wires = [i, j]`` the angles are consumed as:

    * ``params[0:3]``   - RZ, RY, RZ on ``i`` (leading local layer)
    * ``params[3:6]``   - RZ, RY, RZ on ``j`` (leading local layer)
    * ``params[6:9]``   - RZ on ``i``, RY on ``j``, RY on ``j`` inside the CNOT core
    * ``params[9:12]``  - RZ, RY, RZ on ``j`` (trailing local layer)
    * ``params[12:15]`` - RZ, RY, RZ on ``i`` (trailing local layer)
    """
    first, second = wires

    def zyz(offset, wire):
        # RZ -> RY -> RZ (circuit order) on one wire, reading three consecutive angles.
        qml.RZ(params[offset], wires=wire)
        qml.RY(params[offset + 1], wires=wire)
        qml.RZ(params[offset + 2], wires=wire)

    zyz(0, first)
    zyz(3, second)

    # Entangling core: three CNOTs of alternating direction with interleaved rotations.
    qml.CNOT(wires=[second, first])
    qml.RZ(params[6], wires=first)
    qml.RY(params[7], wires=second)
    qml.CNOT(wires=[first, second])
    qml.RY(params[8], wires=second)
    qml.CNOT(wires=[second, first])

    zyz(9, second)
    zyz(12, first)

# Dedicated two-wire simulator that turns `two_qubit_decomp` into a standalone circuit.
dev = qml.device("default.qubit", wires=2)


@qml.qnode(dev)
def two_qubit_decomp_circuit(params):
    """Run ``two_qubit_decomp`` on wires 0 and 1 and return the final state vector.

    ``decomp_to_su4`` wraps this QNode in ``qml.matrix`` to obtain the circuit's 4x4 unitary.
    """
    decomp_wires = [0, 1]
    two_qubit_decomp(params, wires=decomp_wires)
    return qml.state()


### Two-qubit decomposition and SU(4) gate implementations ###

In [55]:
class Operation(Protocol):
    """Structural interface for a parametrised gate block used as an ansatz.

    Any object with a matching ``apply`` method satisfies it; inheritance is not required.
    """

    def apply(self, params: np.ndarray, wires: List[int]) -> None:
        """Queue this operation's gates, parametrised by ``params``, on ``wires``."""

class TwoQubitDecomp:
    """Ansatz operation: the 15-angle, 3-CNOT two-qubit circuit of ``two_qubit_decomp``.

    The gate layout follows Theorem 5 of https://arxiv.org/pdf/quant-ph/0308006.pdf.
    """

    def apply(self, params: np.ndarray, wires: List[int]) -> None:
        """Apply the decomposition circuit with angles ``params[0..14]`` on ``wires = [i, j]``.

        Delegates to the module-level ``two_qubit_decomp`` so the gate sequence is
        defined in exactly one place.
        """
        two_qubit_decomp(params, wires)

class PauliRotSequence:
    """Ansatz operation backed by PennyLane's ``qml.ArbitraryUnitary`` template."""

    def apply(self, params: np.ndarray, wires: List[int]) -> None:
        """Queue ``qml.ArbitraryUnitary`` with angles ``params`` on ``wires``.

        NOTE(review): the template presumably expects ``4**len(wires) - 1`` angles
        (15 for two wires) -- confirm against the PennyLane docs.
        """
        target_wires = wires
        qml.ArbitraryUnitary(params, wires=target_wires)

class SpecialUnitaryGate:
    """Ansatz operation backed by ``qml.SpecialUnitary`` (Pauli-basis parametrisation)."""

    def apply(self, params: np.ndarray, wires: List[int]) -> None:
        """Queue ``qml.SpecialUnitary`` with Pauli-word coefficients ``params`` on ``wires``."""
        target_wires = wires
        qml.SpecialUnitary(params, wires=target_wires)


### Ansatz ###

In [56]:

# Two-wire simulator shared by every QNode built in `launch_circuit`.
ansatz_dev = qml.device("default.qubit", wires=2)

def parameterized_circuit(params, operation: Operation, hamiltonian):
    """Quantum function: apply ``operation`` on wires 0 and 1, then measure ``<hamiltonian>``.

    Args:
        params: Parameters forwarded to ``operation.apply``.
        operation: Any object implementing the ``Operation`` protocol.
        hamiltonian: Observable whose expectation value is returned.
    """
    target_wires = [0, 1]
    operation.apply(params, target_wires)
    return qml.expval(hamiltonian)


# Circuit launcher: bind a Hamiltonian to the ansatz and wrap it as a JAX-interface QNode.
def launch_circuit(
    hamiltonian_func: Callable[[], Tuple[qml.Hamiltonian, float, float]],
) -> Tuple[qml.QNode, float, float]:
    """Build a QNode that evaluates ``<H>`` for a pluggable two-qubit ansatz.

    Args:
        hamiltonian_func: Zero-argument factory returning ``(hamiltonian, e_min, e_max)``,
            e.g. ``lambda: create_transverse_hamiltonian(2)``.

    Returns:
        ``(qnode, e_min, e_max)``. ``qnode(params, operation)`` runs
        ``operation.apply(params, [0, 1])`` on ``ansatz_dev`` and returns the expectation
        value of the Hamiltonian.
    """
    hamiltonian, e_min, e_max = hamiltonian_func()

    def circuit_wrapped(params, operation: Operation):
        return parameterized_circuit(params, operation, hamiltonian)

    qnode = qml.QNode(circuit_wrapped, ansatz_dev, interface="jax")
    return qnode, e_min, e_max

### SU4 TO DECOMP ###

In [57]:
def su4_to_decomp(su4_params, wires=(0, 1)):
    """Convert ``qml.SpecialUnitary`` parameters into the 15 angles of ``two_qubit_decomp``.

    The SU(4) matrix is built with ``qml.SpecialUnitary.compute_matrix`` and decomposed
    with ``qml.ops.two_qubit_decomposition``. The rotation angles of the resulting gates
    are then flattened in gate order.

    Args:
        su4_params: 15 Pauli-word coefficients for a two-qubit ``qml.SpecialUnitary``.
        wires: Wire labels passed to the decomposition (default ``(0, 1)``).

    Returns:
        list of 15 Python floats, in the parameter layout consumed by
        ``two_qubit_decomp`` / ``TwoQubitDecomp``.

    Raises:
        ValueError: if the decomposition does not yield exactly 15 rotation angles.
            This happens, for example, when the unitary needs fewer than three CNOTs,
            so its gate layout does not match ``two_qubit_decomp``.
    """
    unitary = qml.SpecialUnitary.compute_matrix(theta=su4_params, num_wires=2)
    decomposition = qml.ops.two_qubit_decomposition(unitary, wires=list(wires))

    decomp_params = []
    for gate in decomposition:
        # A global phase does not change expectation values and has no slot in the
        # 15-angle layout; skip it if this PennyLane version emits one.
        if gate.name == "GlobalPhase":
            continue
        # Cast to plain Python floats (drops JAX/NumPy scalar wrappers).
        decomp_params.extend(float(p) for p in getattr(gate, "parameters", []))

    if len(decomp_params) != 15:
        raise ValueError(
            f"expected 15 rotation angles from the 3-CNOT decomposition, got "
            f"{len(decomp_params)} from gates {[gate.name for gate in decomposition]}"
        )
    return decomp_params

### DECOMP TO SU4 ###

In [58]:
def decomp_to_su4(decomp_params, wires=(0, 1)):
    """Convert ``two_qubit_decomp`` angles into ``qml.SpecialUnitary`` parameters.

    The circuit's 4x4 unitary ``U`` is taken from ``two_qubit_decomp_circuit`` via
    ``qml.matrix``. ``extract_pauli_coeff`` returns ``Tr(logm(U) @ P_k) / 4`` for each
    non-identity two-qubit Pauli word ``P_k``; dividing by ``1j`` gives ``theta_k`` in the
    ``exp(1j * sum_k theta_k P_k)`` form used by ``qml.SpecialUnitary``. The round-trip
    check in this notebook confirms the two energies agree. The global phase (identity
    component) is dropped.

    Args:
        decomp_params: 15 angles in the ``two_qubit_decomp`` layout.
        wires: Unused; kept for signature compatibility. The circuit always acts on
            wires 0 and 1 of its own two-wire device.

    Returns:
        np.ndarray of 15 coefficients in Pauli-word order ``IX, IY, ..., ZZ``. The array
        is real whenever the imaginary parts are negligible (``np.real_if_close``).
    """
    unitary = qml.matrix(two_qubit_decomp_circuit)(decomp_params)
    pauli_coeffs = extract_pauli_coeff(unitary)
    theta = np.array([c / 1j for c in pauli_coeffs.values()])
    # logm of a unitary is anti-Hermitian, so theta is real up to round-off; return real
    # angles so qml.SpecialUnitary is not fed a complex parameter vector.
    return np.real_if_close(theta)


In [59]:
def verify(decomp_params, su4_params):
    """Print the energies of both ansätze on the 2-qubit transverse-field Hamiltonian.

    ``decomp_params`` drive ``TwoQubitDecomp`` and ``su4_params`` drive
    ``SpecialUnitaryGate``. If both parameter sets describe the same unitary (up to a
    global phase), the printed difference should be at round-off level.
    """
    energy_qnode = launch_circuit(lambda: create_transverse_hamiltonian(2))[0]

    cost1 = energy_qnode(decomp_params, TwoQubitDecomp())
    cost2 = energy_qnode(su4_params, SpecialUnitaryGate())

    print(f"Cost1: {cost1}")
    print(f"Cost2: {cost2}")
    print(f"Difference: {cost1 - cost2}")

In [60]:
# TwoQubitToSU4: random decomposition angles -> SpecialUnitary parameters, then check
# that both ansätze give the same energy.
rng = np.random.default_rng(seed=0)  # seeded so Restart & Run All reproduces the output
decomp_params = jnp.array(rng.standard_normal(15))
su4_params = decomp_to_su4(decomp_params=decomp_params)

verify(decomp_params, su4_params)


Cost1: -0.19603727302074647
Cost2: -0.19603727302074692
Difference: 4.440892098500626e-16


In [61]:
# SU4ToTwoQubit: random SpecialUnitary parameters -> decomposition angles, then check
# that both ansätze give the same energy.
rng = np.random.default_rng(seed=1)  # seeded so Restart & Run All reproduces the output
su4_params = jnp.array(rng.standard_normal(15))

decomp_params = su4_to_decomp(su4_params=su4_params)

verify(decomp_params, su4_params)


Cost1: -0.12104634849046886
Cost2: -0.1210463484904688
Difference: -5.551115123125783e-17
