In [6]:
import json
import pennylane as qml
import pennylane.numpy as np

In [7]:
# Fix the global NumPy RNG so random initialisations below are reproducible.
np.random.seed(1967)

def get_matrix(params):
    """Build the 2x2 complex matrix Phase(phi) @ RZ(gamma) @ RX(beta) @ RZ(alpha).

    Args:
        params: a 4-element sequence unpacked as (alpha, beta, gamma, phi).

    Returns:
        A complex-dtype 2x2 array. Phase(phi) is diag(exp(i*phi), 1), so it is
        not a pure global phase.
    """
    alpha, beta, gamma, phi = params

    def rz(theta):
        # Z rotation: diag(exp(-i*theta/2), exp(+i*theta/2)).
        return np.array([[np.exp(-0.5j * theta), 0],
                         [0, np.exp(0.5j * theta)]], dtype=complex)

    # The X rotation has identical diagonal entries and identical off-diagonal entries.
    diag_term = np.cos(0.5 * beta)
    off_term = -1j * np.sin(0.5 * beta)
    rx_beta = np.array([[diag_term, off_term],
                        [off_term, diag_term]], dtype=complex)

    phase_phi = np.array([[np.exp(1j * phi), 0],
                          [0, 1]], dtype=complex)

    # Rightmost factor acts first: RZ(alpha), then RX(beta), RZ(gamma), Phase(phi).
    return phase_phi @ rz(gamma) @ rx_beta @ rz(alpha)

In [8]:
def error(U, params):
    """Squared Frobenius distance between a target matrix and get_matrix(params).

    Args:
        U: target 2x2 matrix (nested list or array); subtracted element-wise.
        params: parameters forwarded unchanged to get_matrix.

    Returns:
        sum over all entries of |U - get_matrix(params)|**2.
    """
    residual = U - get_matrix(params)
    return np.sum(np.abs(residual) ** 2)

In [9]:
def train_parameters(U, epochs=1000, lr=0.01):
    """Fit the four get_matrix parameters to a target matrix by gradient descent.

    Args:
        U: target 2x2 matrix (nested list or array).
        epochs: number of gradient-descent steps (default 1000, the previous
            hard-coded value).
        lr: step size (default 0.01, the previous hard-coded value).

    Returns:
        The 4-element parameter array after `epochs` descent steps on
        error(U, params). The loss is non-convex, so a global minimum is not
        guaranteed.
    """
    # Differentiate with respect to params only (argument index 1); U is constant.
    grad_fn = qml.grad(error, argnum=1)
    # Random start in [0, pi), drawn from NumPy's global RNG.
    params = np.random.rand(4) * np.pi

    for _ in range(epochs):
        params -= lr * grad_fn(U, params)

    return params

In [10]:
# Test-harness helpers: `run` produces a solution, `check` validates it.
def run(test_case_input: str) -> str:
    """Decode a JSON matrix, train parameters for it, and return them as JSON.

    Args:
        test_case_input: JSON text of the target 2x2 matrix.

    Returns:
        JSON list of the trained parameters, each converted to a plain float.
    """
    target = json.loads(test_case_input)
    trained = train_parameters(target)
    return json.dumps(list(map(float, trained)))


def check(solution_output: str, expected_output: str) -> None:
    """Assert that the solution's parameters reproduce the expected matrix.

    Args:
        solution_output: JSON list of parameters, as returned by `run`.
        expected_output: JSON text of the expected 2x2 matrix.

    Raises:
        AssertionError: if get_matrix gives (near-)equal results for two random
            parameter draws, or if the solution matrix differs from the
            expected one beyond atol=0.2.
    """
    produced = get_matrix(json.loads(solution_output))
    wanted = json.loads(expected_output)

    # Guard against a degenerate get_matrix that ignores its parameters.
    first_random = get_matrix(np.random.rand(4))
    second_random = get_matrix(np.random.rand(4))
    assert not np.allclose(first_random, second_random)

    assert np.allclose(produced, wanted, atol=0.2)

# These are the public test cases
test_cases = [
    ('[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]]', '[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]]'),
    ('[[ 1,  0], [ 0, -1]]', '[[ 1,  0], [ 0, -1]]')
]

# This will run the public test cases locally.
# `check` signals failure by raising AssertionError (it returns None), so the
# check call gets its own try/except. Testing its return value could never
# detect a wrong answer, and an assertion raised in a try's `else:` clause
# would escape the loop.
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)
    except Exception as exc:
        print(f"Runtime Error. {exc}")
        continue

    try:
        check(output, expected_output)
    except AssertionError:
        print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")
    else:
        print("Correct!")

Running test case 0 with input '[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]]'...
Correct!
Running test case 1 with input '[[ 1,  0], [ 0, -1]]'...
Correct!
