In [44]:
import pennylane as qml
import numpy as np
import jax
import matplotlib.pyplot as plt
from io import StringIO
import pickle
import os
import cirq

# Configure JAX
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
jnp = jax.numpy


In [45]:
def ensure_directory_exists(directory):
    """Ensure that the specified directory exists, create it if it does not."""
    if not os.path.exists(directory):
        os.makedirs(directory)

def save_params(params, filename):
    """Save parameters to a file using pickle."""
    with open(filename, 'wb') as f:
        pickle.dump(params, f)

def load_params(filename):
    """Load parameters from a file using pickle."""
    with open(filename, 'rb') as f:
        return pickle.load(f)



In [46]:
def get_initial_params(params_file):
    """Get initial parameters from the file if it exists, otherwise generate new ones."""
    if os.path.exists(params_file):
        # Load parameters if they exist
        return load_params(params_file)
    else:
        # Generate random parameters and save them
        init_params = jnp.array(np.random.randn(15))
        save_params(init_params, params_file)
        return init_params


In [47]:
def create_matrix(params, num_wires):
    """Create a SU(4) unitary matrix from the given parameters."""
    return qml.SpecialUnitary.compute_matrix(theta = params, num_wires=num_wires)



In [48]:
def perform_kak_decomposition(U):
    """Perform the KAK decomposition on a unitary matrix and return the parameters."""
    qubits = [cirq.LineQubit(0), cirq.LineQubit(1)]
    operation = cirq.MatrixGate(U).on(*qubits)
    kak_decomp = cirq.kak_decomposition(operation)
    a0, a1 = kak_decomp.single_qubit_operations_before
    b0, b1 = kak_decomp.single_qubit_operations_after
    A = kak_decomp.interaction_coefficients
    gamma = kak_decomp.global_phase
    K1 = cirq.MatrixGate(a0 @ a1)
    K2 = cirq.MatrixGate(b0 @ b1)
    
    return K1, A, K2, gamma


In [49]:
def main():
    directory = 'data/mapper/'
    params_file = os.path.join(directory, 'init_params.pkl')

    # Ensure the directory exists
    ensure_directory_exists(directory)

    # Get the initial parameters
    init_params = get_initial_params(params_file)

    # Create the SU(4) unitary matrix
    unitary = create_matrix(init_params, 2)

    # Perform the KAK decomposition
    A, B, C, gamma = perform_kak_decomposition(unitary)

    print("KAK Decomposition Parameters:")
    print("A:", A)
    print("B:", B)
    print("C:", C)
    print("Gamma:", gamma)

if __name__ == "__main__":
    main()


KAK Decomposition Parameters:
A: [[-0.537-0.275j  0.364-0.709j]
 [-0.364-0.709j -0.537+0.275j]]
B: (0.6260655968986388, 0.3961425162467157, 0.3252429160052595)
C: [[-0.38 -0.671j -0.167-0.614j]
 [ 0.167-0.614j -0.38 +0.671j]]
Gamma: (3.8285686989269494e-16+1j)
