In [601]:
import pennylane as qml
import numpy as np
import jax
import matplotlib.pyplot as plt
from io import StringIO
import os
import cirq
from cirq import KakDecomposition
from typing import Tuple
import json
# Configure JAX
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
jnp = jax.numpy


In [602]:
save = False

In [603]:
def ensure_directory_exists(directory):
    """Ensure that the specified directory exists, create it if it does not."""
    if not os.path.exists(directory):
        os.makedirs(directory)

def save_params(params, filename):
    """Save parameters to a file using JSON."""
    # Convert NumPy arrays to lists
    params_serializable = np.array(params).tolist()
    with open(filename, 'w') as f:
        json.dump(params_serializable, f)

def load_params(filename):
    """Load parameters from a file using JSON."""
    with open(filename, 'r') as f:
        params_serializable = json.load(f)
    # Convert lists back to NumPy arrays
    return jnp.array(params_serializable)



In [604]:
def get_initial_params(params_file):
    """Get initial parameters from the file if it exists, otherwise generate new ones."""
    if save and os.path.exists(params_file):
        # Load parameters if they exist
        return load_params(params_file)
    else:
        # Generate random parameters and save them
        init_params = jnp.array(np.random.randn(15))
        if save:
            save_params(init_params, params_file)
        return init_params


In [605]:
def create_matrix(params, num_wires):
    """Create a SU(4) unitary matrix from the given parameters."""
    return qml.SpecialUnitary.compute_matrix(theta = params, num_wires=num_wires)



In [606]:
def perform_kak_decomposition(gate, nqubits=2): 
    
    # Prepare Cirq circuit
    qubits = cirq.LineQubit.range(nqubits) 

    # Get the indices of the target gate
    # targets = gate.get_target_index_list() 

    cirq_gates = []
    
    gate = cirq.MatrixGate(gate).on(*qubits)

    # Perform KAK decomposition using Cirq
    kak_operations = cirq.kak_decomposition(gate) 
    
    # Get the tuple of coefficients(X,Y,Z)
    coef = kak_operations.interaction_coefficients
    
    # Get the tuple of gates before and after KAK decomposition
    initial_operations = kak_operations.single_qubit_operations_before 
    final_operations = kak_operations.single_qubit_operations_after 

    # Add initial operations to the list of Cirq gates
    cirq_gates.append(cirq.MatrixGate(initial_operations[0])(qubits[0]))
    cirq_gates.append(cirq.MatrixGate(initial_operations[1])(qubits[1]))
    
    # Add middle operations (XX+YY+ZZ) to the list of Cirq gates
    cirq_gates.append(coef[0])
    cirq_gates.append(coef[1])
    cirq_gates.append(coef[2])
    
    # Add final operations to the list of Cirq gates
    cirq_gates.append(cirq.MatrixGate(final_operations[0])(qubits[0]))
    cirq_gates.append(cirq.MatrixGate(final_operations[1])(qubits[1]))

    return cirq_gates

In [607]:
# def get_kak_vector(U):
#    qubits = [cirq.LineQubit(0), cirq.LineQubit(1)]
#    cirq.MatrixGate(U).on(*qubits)
#    return cirq.kak_vector(U)

In [608]:
def euler_decomposition(cirq_gate, nqubits=2):
    """
    Decompose a given single qubit cirq gate into a list of gates 
    using euler decomposition.
    
    Args:
        cirq_gate (cirq.ops.gate_operation.GateOperation) : Single qubit cirq gate to decompose
        nqubits (int) : Number of qubits in the quantum circuit

    Returns:
        cirq_gates (list) : List of decomposed Cirq gates
    """
    cirq_gates = []
    
    # Create a cirq circuit and append the gate
    qc = cirq.Circuit()
    qc.append(cirq_gate)
    
    # Get the unitary matrix of the input gate
    unitary_matrix = qc.unitary()
    
    # Prepare Cirq circuit
    circuit = cirq.LineQubit.range(nqubits)

    # Get the target qubit
    target_qubit = int(cirq_gate.qubits[0])
    
    # Perform Euler decomposition
    decomposition_angles = cirq.deconstruct_single_qubit_matrix_into_angles(unitary_matrix)
    
    # Add decomposed gates to the list
    cirq_gates.append(decomposition_angles[0])
    cirq_gates.append(decomposition_angles[1])
    cirq_gates.append(decomposition_angles[2])
    
    return cirq_gates

In [609]:
init_params1 = []
init_params2 = []
def run():
    directory = 'data/mapper/'
    params_file = os.path.join(directory, 'init_params.json')
    # Ensure the directory exists
    ensure_directory_exists(directory)

    # Get the initial parameters
    init_params1 = get_initial_params(params_file)
    # print(init_params1)
    # Create the SU(4) unitary matrix
    unitary = create_matrix(init_params1, 2)
    # Perform the KAK decomposition
    kak_operations = perform_kak_decomposition(unitary)
    # Perform Euler decomposition for initial and final operations separately
    initial_op0 = euler_decomposition(kak_operations[0])
    initial_op1 = euler_decomposition(kak_operations[1])
    final_op0 = euler_decomposition(kak_operations[5])
    final_op1 = euler_decomposition(kak_operations[6])
    #init_params2 = initial_op0+initial_op1+kak_operations[2:5]+final_op0+final_op1
    init_params2 = initial_op1+initial_op0+kak_operations[2:5][::-1]+final_op1+final_op0
    # print("params", init_params2 )
    return init_params1, init_params2
    #print("kak vector", get_kak_vector(unitary))
    # def __str__(self) -> str:
    #     xx = self.interaction_coefficients[0] * 4 / np.pi
    #     yy = self.interaction_coefficients[1] * 4 / np.pi
    #     zz = self.interaction_coefficients[2] * 4 / np.pi
    #     before0 = axis_angle(self.single_qubit_operations_before[0])
    #     before1 = axis_angle(self.single_qubit_operations_before[1])
    #     after0 = axis_angle(self.single_qubit_operations_after[0])
    #     after1 = axis_angle(self.single_qubit_operations_after[1])
    #     return (
    #         'KAK {\n'
    #         f'    xyz*(4/π): {xx:.3g}, {yy:.3g}, {zz:.3g}\n'
    #         f'    before: ({before0}) ⊗ ({before1})\n'
    #         f'    after: ({after0}) ⊗ ({after1})\n'
    #         '}'
    #     )

    # before0 = cirq.axis_angle(kak_decomp.single_qubit_operations_before[0])
    # before1 = cirq.axis_angle(kak_decomp.single_qubit_operations_before[1])
    # after0 = cirq.axis_angle(kak_decomp.single_qubit_operations_after[0])
    # after1 = cirq.axis_angle(kak_decomp.single_qubit_operations_after[1])
    # xx = kak_decomp.interaction_coefficients[0] * 4 / np.pi
    # yy = kak_decomp.interaction_coefficients[1] * 4 / np.pi
    # zz = kak_decomp.interaction_coefficients[2] * 4 / np.pi
    # print("A0", type(kak_decomp.single_qubit_operations_before[0]))
    # print("Euler Decomposition", euler_decomposition(cirq.MatrixGate(kak_decomp.single_qubit_operations_before[0]), 2))
    # print("KAK Decomposition Parameters:")
    # print(f"{before0}, {before1}, {xx}, {yy}, {zz}, {after0}, {after1}")
    # # print("a0:", a0)
    # print("a1:", a1)
    # print("B:", B)
    # print("b0:", b0)
    # print("b1:", b1)
    # print("Gamma:", gamma)



In [610]:
from typing import List, Protocol, Callable, Tuple
from typing import List, Tuple, Callable


# Configure JAX
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
jnp = jax.numpy

def create_transverse_hamiltonian(num_wires: int, J: float = 1.0, h: float = 0.5):
    coeffs = []
    obs = []

    # ZZ interactions
    for i in range(num_wires):
        for j in range(i + 1, num_wires):
            coeffs.append(-J)
            obs.append(qml.PauliZ(i) @ qml.PauliZ(j))

    # X interactions
    for i in range(num_wires):
        coeffs.append(-h)
        obs.append(qml.PauliX(i))

    hamiltonian = qml.Hamiltonian(coeffs, obs)
    e_min = min(qml.eigvals(hamiltonian))
    e_max = max(qml.eigvals(hamiltonian))
    # hamiltonian = qml.SparseHamiltonian(hamiltonian.sparse_matrix(), range(num_wires))
    return hamiltonian, e_min, e_max

def create_heisenberg_hamiltonian(num_wires: int, J: float = 1.0):
    coeffs = []
    obs = []

    # Heisenberg interactions (XX + YY + ZZ)
    for i in range(num_wires):
        for j in range(i + 1, num_wires):
            coeffs.extend([-J, -J, -J])
            obs.extend([qml.PauliX(i) @ qml.PauliX(j), qml.PauliY(i) @ qml.PauliY(j), qml.PauliZ(i) @ qml.PauliZ(j)])

    hamiltonian = qml.Hamiltonian(coeffs, obs)
    e_min = min(qml.eigvals(hamiltonian))
    e_max = max(qml.eigvals(hamiltonian))
    # hamiltonian = qml.SparseHamiltonian(hamiltonian.sparse_matrix(), range(num_wires))
    return hamiltonian, e_min, e_max

def create_longitudinal_ising_hamiltonian(num_wires: int, J: float = 1.0, hx: float = 0.5, hz: float = 0.5):
    coeffs = []
    obs = []

    # ZZ interactions
    for i in range(num_wires):
        for j in range(i + 1, num_wires):
            coeffs.append(-J)
            obs.append(qml.PauliZ(i) @ qml.PauliZ(j))

    # X interactions (transverse field)
    for i in range(num_wires):
        coeffs.append(-hx)
        obs.append(qml.PauliX(i))
    
    # Z interactions (longitudinal field)
    for i in range(num_wires):
        coeffs.append(-hz)
        obs.append(qml.PauliZ(i))

    hamiltonian = qml.Hamiltonian(coeffs, obs)
    e_min = min(qml.eigvals(hamiltonian))
    e_max = max(qml.eigvals(hamiltonian))
    # hamiltonian = qml.SparseHamiltonian(hamiltonian.sparse_matrix(), range(num_wires))

    return hamiltonian, e_min, e_max

In [611]:

class Operation(Protocol):
    def apply(self, params: np.ndarray, wires: List[int]) -> None:
        ...

class TwoQubitDecomp:
    def apply(self, params: np.ndarray, wires: List[int]) -> None:
        """Implement an arbitrary SU(4) gate on two qubits
        using the decomposition from Theorem 5 in
        https://arxiv.org/pdf/quant-ph/0308006.pdf"""
        i, j = wires
        qml.Rot(*params[:3], wires=i)
        qml.Rot(*params[3:6], wires=j)
        qml.CNOT(wires=[j, i])
        qml.RZ(params[6], wires=i)
        qml.RY(params[7], wires=j)
        qml.CNOT(wires=[i, j])
        qml.RY(params[8], wires=j)
        qml.CNOT(wires=[j, i])
        qml.Rot(*params[9:12], wires=i)
        qml.Rot(*params[12:15], wires=j)

class PauliRotSequence:
    def apply(self, params: np.ndarray, wires: List[int]) -> None:
        qml.ArbitraryUnitary(params, wires=wires)

class SpecialUnitaryGate:
    def apply(self, params: np.ndarray, wires: List[int]) -> None:
        qml.SpecialUnitary(params, wires=wires)

# Initialize parameters
num_wires = 6
loc = 2
learning_rate = 1e-4
num_steps = 1000
rebuild = False
# init_params = jax.numpy.array(init_params)
# # Define a function that returns a default value for the innermost dictionary
# def nested_dict():
#     return defaultdict(lambda: defaultdict(lambda: [None, None]))

# Initialize the top-level defaultdict
energies = {}

dev = qml.device("default.qubit", wires=num_wires)

def circuit(params, operation: Operation, hamiltonian):
        operation.apply(params, [0,1])
        return qml.expval(hamiltonian)


#  circuitLauncher
def launch_circuit(hamiltonian_func: Callable[[], Tuple[qml.Hamiltonian, float]]):
    hamiltonian, e_min, e_max = hamiltonian_func()
    def circuit_wrapped(params, operation: Operation):
        return circuit(params, operation, hamiltonian)

    qnode = qml.QNode(circuit_wrapped, dev, interface="jax")
    return qnode, e_min, e_max

In [742]:
if __name__ == "__main__":
    diffs = []
    for i in range(1):
      init_params1, init_params2 = run()
      print(init_params1)
      qnode, E_min, E_max = launch_circuit(lambda: create_transverse_hamiltonian(num_wires))
      cost1 = qnode(init_params1, SpecialUnitaryGate())
      init_params2[6] = 2* init_params1[6] - (np.pi/2)
      init_params2[7] =  (np.pi/2) - 2* init_params1[7]
      init_params2[8] = 2* init_params1[8] - (np.pi/2)
  
      cost2 = qnode(init_params2, TwoQubitDecomp())

      print(cost1 - cost2)
      #print(init_params1, init_params2)

[ 0.16497722 -0.95100118 -1.11113465 -1.86852571  0.6632018   1.48024618
  0.84686982  0.5990433   0.76125825  0.50805501 -0.3768341   0.53360786
  0.22889197  0.56115247  1.02434896]
-1.709252335095222




In [745]:
unitary = create_matrix(init_params1, 2)

decomp = qml.ops.two_qubit_decomposition(unitary, [0,1])
print(init_params1)
init_params2 = []
for i in range(18):
   for k in decomp[i].parameters: 
    init_params2.append(float(k))
print(init_params2)

cost1 = qnode(init_params1, SpecialUnitaryGate())
init_params2[6] = 2* init_params1[6] - (np.pi/2)
init_params2[7] =  (np.pi/2) - 2* init_params1[7]
init_params2[8] = 2* init_params1[8] - (np.pi/2)

cost2 = qnode(init_params2, TwoQubitDecomp())

print(cost1 - cost2)

[ 0.16497722 -0.95100118 -1.11113465 -1.86852571  0.6632018   1.48024618
  0.84686982  0.5990433   0.76125825  0.50805501 -0.3768341   0.53360786
  0.22889197  0.56115247  1.02434896]
[1.0976058591446565, 1.0489001593349658, 9.852303347601447, 5.144647895346997, 2.3374128707264323, 0.43695845184846327, 0.18300785177647289, -0.4369957649645462, -1.4503763170195014, 11.084578852933019, 1.5451648552381687, 1.9622858288051632, 12.28677428289651, 2.769454933848118, 0.8603980061431933]
-1.2769708347337492
