In [1]:
import json
import pennylane as qml
import pennylane.numpy as np

In [129]:
np.random.seed(1967)

def get_matrix(params):
    """
    Args:
        - params (array): The four parameters of the model.
        
    Returns:
        - (matrix): The associated matrix to these parameters.
    """
    alpha, beta, gamma, phi = params
    # Put your code here #
    
    def circuit(alpha, beta, gamma, phi):
        qml.RZ(alpha, wires=0)
        qml.RX(beta, wires=0)
        qml.RZ(gamma, wires=0)
        qml.PhaseShift(phi, wires=0)
    
    matrix = qml.matrix(circuit,wire_order=[0])
    matrix_val = matrix(alpha, beta, gamma, phi)
    # Return the matrix
    return matrix_val

In [141]:
def error(U, params):
    """
    This function determines the similarity between your generated matrix and
    the target unitary.

    Args:
        - U (np.array): Goal matrix that we want to approach.
        - params (array): The four parameters of the model.

    Returns:
        - (float): Error associated with the quality of the solution.
    """

    matrix = get_matrix(params)
    # Put your code here #
    error = np.linalg.norm(matrix - U)
    print(error)
    # Return the error
    return error

In [142]:
def train_parameters(U):
    epochs = 100
    lr = 0.01

    grad = qml.grad(error, argnum=1)
    params = np.random.rand(4) * np.pi

    for epoch in range(epochs):
        params -= lr * grad(U, params)

    return params

In [143]:
# These functions are responsible for testing the solution.
def run(test_case_input: str) -> str:
    matrix = json.loads(test_case_input)
    params = [float(p) for p in train_parameters(matrix)]
    return json.dumps(params)


def check(solution_output: str, expected_output: str) -> None:
    matrix1 = get_matrix(json.loads(solution_output))
    matrix2 = json.loads(expected_output)
    assert not np.allclose(get_matrix(np.random.rand(4)), get_matrix(np.random.rand(4)))
    assert np.allclose(matrix1, matrix2, atol=0.2)

In [144]:
# These are the public test cases
test_cases = [
    ('[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]]', '[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]]'),
    ('[[ 1,  0], [ 0, -1]]', '[[ 1,  0], [ 0, -1]]')
]

In [145]:
# This will run the public test cases locally
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)

    except Exception as exc:
        print(f"Runtime Error. {exc}")

    else:
        if message := check(output, expected_output):
            print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")

        else:
            print("Correct!")

Running test case 0 with input '[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]]'...
Autograd ArrayBox with value 1.3500460105646268
Autograd ArrayBox with value 1.3509160978425234
Autograd ArrayBox with value 1.3517911602435262
Autograd ArrayBox with value 1.3526710141046043
Autograd ArrayBox with value 1.3535554728408286
Autograd ArrayBox with value 1.354444347165621
Autograd ArrayBox with value 1.355337445313746
Autograd ArrayBox with value 1.3562345732664893
Autograd ArrayBox with value 1.357135534978481
Autograd ArrayBox with value 1.3580401326056095
Autograd ArrayBox with value 1.358948166733489
Autograd ArrayBox with value 1.3598594366059458
Autograd ArrayBox with value 1.3607737403529965
Autograd ArrayBox with value 1.361690875217807
Autograd ArrayBox with value 1.3626106377821303
Autograd ArrayBox with value 1.3635328241897362
Autograd ArrayBox with value 1.364457230367365
Autograd ArrayBox with value 1.3653836522427523
Autograd ArrayBox with value 1.3663118859592907
A

AssertionError: 

In [136]:
def circuit(alpha, beta, gamma, phi):
    qml.RZ(alpha, wires=0)
    qml.RX(beta, wires=0)
    qml.RZ(gamma, wires=0)
    qml.PhaseShift(phi, wires=0)

matrix_fn = qml.matrix(circuit, wire_order=[0])
alpha, beta, gamma, phi = 2.27803033, 1.80051172, 2.0258912,  1.75304041
matrix_fn(alpha, beta, gamma, phi)

array([[-0.34115232-0.51938896j,  0.09851229-0.77726799j],
       [ 0.78225008+0.0439888j , -0.44895819-0.42963512j]])