In [13]:
import pennylane as qml
import numpy as np
import jax
import matplotlib.pyplot as plt
from io import StringIO
import pickle
import os
import cirq
from cirq import KakDecomposition
from typing import Tuple
# Configure JAX
# These options are set at import time, before any JAX array in this notebook is created.
jax.config.update("jax_enable_x64", True)  # allow 64-bit float/complex dtypes instead of JAX's 32-bit default
jax.config.update("jax_platform_name", "cpu")  # run JAX computations on the CPU backend
jnp = jax.numpy  # conventional alias, equivalent to `import jax.numpy as jnp`


In [14]:
def ensure_directory_exists(directory):
    """Ensure that ``directory`` exists, creating it (and any missing parents) if needed.

    Uses ``os.makedirs(..., exist_ok=True)`` instead of a separate existence
    check, so there is no window between "check" and "create" in which another
    process could create the directory and make this call fail.

    Args:
        directory: Path of the directory to create. An empty string refers to
            the current working directory, which always exists, so nothing is done.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
    """
    if not directory:
        # os.makedirs("") raises FileNotFoundError; "" already means the CWD.
        return
    os.makedirs(directory, exist_ok=True)

def save_params(params, filename):
    """Write ``params`` to ``filename`` as a pickle, replacing any existing file.

    Args:
        params: Any picklable object (here: the parameter array).
        filename: Destination path, opened in binary write mode.
    """
    with open(filename, mode='wb') as handle:
        handle.write(pickle.dumps(params))

def load_params(filename):
    """Read back the object pickled into ``filename`` (e.g. by ``save_params``).

    Security: unpickling can execute arbitrary code. Only load files this
    notebook wrote itself, never files obtained from untrusted sources.
    """
    with open(filename, mode='rb') as source:
        restored = pickle.load(source)
    return restored



In [15]:
def get_initial_params(params_file, num_wires=2, seed=None):
    """Return cached SpecialUnitary coefficients, generating and caching them if absent.

    ``qml.SpecialUnitary`` on ``num_wires`` wires takes ``4**num_wires - 1``
    real coefficients (15 for the default two-qubit / SU(4) case). On first
    use the coefficients are drawn from a standard normal distribution and
    pickled to ``params_file``, so later runs reuse exactly the same unitary.

    Args:
        params_file: Path of the pickle cache.
        num_wires: Number of qubits the parameters are intended for.
        seed: Optional seed used only when fresh parameters are generated;
            ``None`` draws from OS entropy. Ignored when the cache exists.

    Returns:
        A 1-D ``jnp`` array of length ``4**num_wires - 1``.

    Raises:
        ValueError: If the cached parameters do not have the expected length
            (e.g. a stale cache written for a different ``num_wires``).
    """
    num_params = 4 ** num_wires - 1
    if os.path.exists(params_file):
        # NOTE: pickle can execute arbitrary code; only load caches we wrote ourselves.
        cached = jnp.asarray(load_params(params_file))
        if cached.shape != (num_params,):
            raise ValueError(
                f"Cached parameters in {params_file!r} have shape {cached.shape}, "
                f"expected ({num_params},) for num_wires={num_wires}"
            )
        return cached

    rng = np.random.default_rng(seed)
    init_params = jnp.array(rng.standard_normal(num_params))
    save_params(init_params, params_file)
    return init_params


In [16]:
def create_matrix(params, num_wires):
    """Build the matrix of ``qml.SpecialUnitary`` from its Pauli-basis coefficients.

    The result is an element of SU(2**num_wires) — SU(4) for the two-qubit
    case used in this notebook — so ``params`` should hold
    ``4**num_wires - 1`` coefficients.
    """
    special_unitary = qml.SpecialUnitary.compute_matrix(
        theta=params,
        num_wires=num_wires,
    )
    return special_unitary



In [17]:
def perform_kak_decomposition(
    U, verbose: bool = True
) -> Tuple[
    cirq.MatrixGate,
    cirq.MatrixGate,
    Tuple[float, float, float],
    cirq.MatrixGate,
    cirq.MatrixGate,
    complex,
    KakDecomposition,
]:
    """KAK-decompose a two-qubit unitary with Cirq.

    Cirq factors ``U`` as
    ``global_phase * (after0 ⊗ after1) · exp(i(x·XX + y·YY + z·ZZ)) · (before0 ⊗ before1)``.

    Args:
        U: 4x4 unitary matrix acting on two qubits.
        verbose: If True (the default, matching earlier behavior), print the
            decomposition using ``KakDecomposition.__str__``.

    Returns:
        ``(before0, before1, (x, y, z), after0, after1, global_phase, kak)``:
        the single-qubit factors wrapped as ``cirq.MatrixGate`` (qubit 0 then
        qubit 1), the interaction coefficients, the global phase, and the full
        ``cirq.KakDecomposition`` object.
    """
    qubits = [cirq.LineQubit(0), cirq.LineQubit(1)]
    operation = cirq.MatrixGate(U).on(*qubits)
    kak_decomp = cirq.kak_decomposition(operation)

    before_0, before_1 = (cirq.MatrixGate(m) for m in kak_decomp.single_qubit_operations_before)
    after_0, after_1 = (cirq.MatrixGate(m) for m in kak_decomp.single_qubit_operations_after)

    if verbose:
        print(kak_decomp)

    return (
        before_0,
        before_1,
        kak_decomp.interaction_coefficients,
        after_0,
        after_1,
        kak_decomp.global_phase,
        kak_decomp,
    )


In [18]:
def main(directory='data/mapper/'):
    """Load (or create) cached SU(4) parameters, KAK-decompose the unitary, and print it.

    Args:
        directory: Folder holding the ``init_params.pkl`` cache; created if missing.
    """
    params_file = os.path.join(directory, 'init_params.pkl')

    # Ensure the cache directory exists before reading/writing the parameters.
    ensure_directory_exists(directory)

    init_params = get_initial_params(params_file)
    print(init_params)

    # Two-qubit SpecialUnitary -> 4x4 matrix in SU(4).
    unitary = create_matrix(init_params, 2)

    # Only the full KakDecomposition (last element) is needed here; the wrapped
    # gates/phase returned alongside it are intentionally not bound.
    kak_decomp = perform_kak_decomposition(unitary)[-1]

    # Axis-angle form of each single-qubit factor, as in KakDecomposition.__str__.
    before0, before1 = (cirq.axis_angle(m) for m in kak_decomp.single_qubit_operations_before)
    after0, after1 = (cirq.axis_angle(m) for m in kak_decomp.single_qubit_operations_after)

    # Interaction coefficients in units of pi/4 (same scaling cirq uses when printing).
    xx, yy, zz = (c * 4 / np.pi for c in kak_decomp.interaction_coefficients)

    print("KAK Decomposition Parameters:")
    print(f"{before0}, {before1}, {xx}, {yy}, {zz}, {after0}, {after1}")


if __name__ == "__main__":
    main()


[-1.58798692 -0.80052595 -0.18718267 -0.59787411  0.25895406  1.63709804
 -1.02576173 -0.93333536  1.77121638 -1.77456891 -0.05335546  0.40300334
  0.71158233 -0.19948615  0.28918582]
KAK {
    xyz*(4/π): 0.797, 0.504, 0.414
    before: (0.617*π around 0.536*X+0.799*Y+0.273*Z) ⊗ (-0.893*π around 0.381*X+0.112*Y+0.918*Z)
    after: (0.86*π around 0.563*X+0.824*Y-0.0658*Z) ⊗ (0.775*π around -0.221*X+0.8*Y+0.558*Z)
}
KAK Decomposition Parameters:
0.617*π around 0.536*X+0.799*Y+0.273*Z, -0.893*π around 0.381*X+0.112*Y+0.918*Z, 0.7971314755695709, 0.5043843170362101, 0.4141121423028734, 0.86*π around 0.563*X+0.824*Y-0.0658*Z, 0.775*π around -0.221*X+0.8*Y+0.558*Z
