## Challenge code
 
You must complete the `linear_combination` function so that it builds the circuit above, which implements the linear combination

$$
\alpha U + \beta V
$$

of two single-qubit unitaries $U$ and $V$, and returns the probabilities on the auxiliary register. For simplicity, we take $\alpha$ and $\beta$ to be positive real numbers.
 
You are also asked to complete the helper function `W`, which returns the unitary $W(\alpha,\beta)$.
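
The statement does not spell out the entries of $W(\alpha,\beta)$. The minimal NumPy sketch below shows one valid choice: the same real rotation that the `W` function in this notebook's solution returns. The helper name `w_matrix` is hypothetical, and plain `numpy` stands in for `pennylane.numpy`. For the outcome on the auxiliary qubit, only the first column matters, $W\vert 0\rangle = (\sqrt{\alpha}\,\vert 0\rangle + \sqrt{\beta}\,\vert 1\rangle)/\sqrt{\alpha+\beta}$. The second column just completes a unitary.

```python
import numpy as np


def w_matrix(alpha, beta):
    """Hypothetical helper: real 2x2 rotation with W|0> = (sqrt(a)|0> + sqrt(b)|1>) / sqrt(a + b)."""
    norm = np.sqrt(alpha + beta)
    return np.array([[np.sqrt(alpha), -np.sqrt(beta)],
                     [np.sqrt(beta), np.sqrt(alpha)]]) / norm


w_example = w_matrix(1, 3)
# W is real and orthogonal (W^T W = I), hence unitary.
print(np.allclose(w_example.T @ w_example, np.eye(2)))  # True
# First column holds the amplitudes sqrt(1/4) and sqrt(3/4).
print(w_example[:, 0])  # approximately [0.5, 0.866]
```

Choosing a real rotation keeps $W^\dagger$ equal to the transpose $W^T$, which makes the uncomputation step easy to check by hand.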

 ![img](images/spaceship_2.png)
 
 ### Input
 
 As input to this problem, you are given:
 
 - `U` (`list(list(float))`): A $2\times 2$ matrix representing the single-qubit unitary operator $U$.
 - `V` (`list(list(float))`): A $2\times 2$ matrix representing the single-qubit unitary operator $V$.
 - `alpha` (`float`): The prefactor $\alpha$ of $U$ in the linear combination, as above.
 - `beta` (`float`): The prefactor $\beta$ of $V$ in the linear combination, as above.
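
The inputs arrive as nested Python lists. The hedged sketch below shows one way to convert them and confirm the stated assumptions (2x2 unitary matrices, positive prefactors) before building the circuit. `validate_inputs` is a hypothetical helper that is not part of the challenge template. The default tolerance `1e-6` is an arbitrary choice, loose enough for the 8-decimal entries used in the test cases.

```python
import numpy as np


def validate_inputs(U, V, alpha, beta, tol=1e-6):
    """Hypothetical helper: convert list inputs to arrays and check the challenge's assumptions."""
    mats = []
    for name, M in (("U", U), ("V", V)):
        M = np.asarray(M, dtype=complex)
        if M.shape != (2, 2):
            raise ValueError(f"{name} must be 2x2, got shape {M.shape}")
        # A unitary satisfies M^dagger M = I, up to rounding in the given entries.
        if not np.allclose(M.conj().T @ M, np.eye(2), atol=tol):
            raise ValueError(f"{name} is not unitary within tol={tol}")
        mats.append(M)
    if alpha <= 0 or beta <= 0:
        raise ValueError("alpha and beta must be positive")
    return mats[0], mats[1]


# First test case: a rounded Hadamard and Pauli-Z.
U_arr, V_arr = validate_inputs([[0.70710678, 0.70710678], [0.70710678, -0.70710678]],
                               [[1, 0], [0, -1]], 1, 3)
```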
  
 ### Output
 
The output used to test your solution is a `float` corresponding to the probability of measuring $\vert 0 \rangle$ on the auxiliary register. This is the first element of the array returned by `linear_combination`, and the testing functions extract it for you.
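
A short derivation shows why this first element is the relevant quantity. It makes two assumptions. First, the main qubit starts in $\vert 0 \rangle$, the default initial state of `default.qubit`. Second, $W$ is the real rotation returned by this notebook's `W` function, $W = \frac{1}{\sqrt{\alpha+\beta}}\begin{pmatrix}\sqrt{\alpha} & -\sqrt{\beta}\\ \sqrt{\beta} & \sqrt{\alpha}\end{pmatrix}$. States are written as $\vert\text{auxiliary}\rangle\vert\text{main}\rangle$:

$$
\begin{aligned}
\vert 0 \rangle \vert 0 \rangle
&\xrightarrow{\;W\;} \frac{1}{\sqrt{\alpha+\beta}}\left(\sqrt{\alpha}\,\vert 0 \rangle + \sqrt{\beta}\,\vert 1 \rangle\right)\vert 0 \rangle \\
&\xrightarrow{\;\text{controlled-}U,\ \text{controlled-}V\;} \frac{1}{\sqrt{\alpha+\beta}}\left(\sqrt{\alpha}\,\vert 0 \rangle\, U\vert 0 \rangle + \sqrt{\beta}\,\vert 1 \rangle\, V\vert 0 \rangle\right) \\
&\xrightarrow{\;W^\dagger\;} \frac{1}{\alpha+\beta}\,\vert 0 \rangle\,(\alpha U + \beta V)\vert 0 \rangle + \frac{\sqrt{\alpha\beta}}{\alpha+\beta}\,\vert 1 \rangle\,(V - U)\vert 0 \rangle,
\end{aligned}
$$

so the probability of measuring $\vert 0 \rangle$ on the auxiliary register is

$$
p(0) = \frac{\lVert (\alpha U + \beta V)\vert 0 \rangle \rVert^2}{(\alpha+\beta)^2}.
$$

Take the second test case in this notebook, with $U = X$, $V = Z$, $\alpha = 1$ and $\beta = 2$. Here $(X + 2Z)\vert 0 \rangle = (2, 1)^T$, so $p(0) = 5/9 \approx 0.5556$, matching its expected output. When the outcome is $0$, the main qubit is left in the state proportional to $(\alpha U + \beta V)\vert 0 \rangle$.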
 
If your solution matches the correct one within the tolerance specified in `check` (here, `np.allclose` with a relative tolerance of `1e-4`), the output will be `"Correct!"`. Otherwise, the assertion in `check` fails with the message "Your circuit doesn't look quite right".
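
For reference, `np.allclose(a, b, rtol=1e-4)` accepts when $|a - b| \le \text{atol} + \text{rtol}\,|b|$. Here $b$ is the expected value and NumPy's default `atol` is `1e-8`. The small illustration below uses two made-up candidate outputs (`0.89016` and `0.8905` are hypothetical; only the expected value comes from the test cases). Plain `numpy` stands in for `pennylane.numpy`, which forwards `allclose` to NumPy.

```python
import numpy as np

expected = 0.8901650422902458  # expected output of the first test case

# Largest accepted deviation: atol + rtol * |expected|
allowed = 1e-8 + 1e-4 * abs(expected)
print(allowed)  # about 8.9e-05

print(np.allclose(0.89016, expected, rtol=1e-4))  # True: off by about 5e-6
print(np.allclose(0.8905, expected, rtol=1e-4))   # False: off by about 3.3e-4
```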
 
 Good luck!
 ### Imports
 The cell below specifies the libraries you should use in this challenge. Run the cell to import the libraries. ***Do not modify the cell.***

import json
import pennylane as qml
from pennylane import numpy as np

### Code
 Complete the code below. Note that during QHack, some sections were not editable. We've marked those sections accordingly here, but you can still edit them if you wish.

def W(alpha, beta):
    """ This function returns the matrix W in terms of
    the coefficients alpha and beta

    Args:
        - alpha (float): The prefactor alpha of U in the linear combination, as in the
        challenge statement.
        - beta (float): The prefactor beta of V in the linear combination, as in the
        challenge statement.
    Returns:
        -(numpy.ndarray): A 2x2 matrix representing the operator W,
        as defined in the challenge statement
    """
    # Return the real matrix of the unitary W, in terms of the coefficients.
    return np.array([[np.sqrt(alpha), -np.sqrt(beta)], 
                     [np.sqrt(beta), np.sqrt(alpha)]])/ np.sqrt(alpha + beta)


dev = qml.device('default.qubit', wires=2)


@qml.qnode(dev)
def linear_combination(U, V, alpha, beta):
    """This QNode builds the circuit that probabilistically implements the linear
    combination alpha*U + beta*V of the unitaries U and V.

    Args:
        - U (list(list(float))): A 2x2 matrix representing the single-qubit unitary operator U.
        - V (list(list(float))): A 2x2 matrix representing the single-qubit unitary operator V.
        - alpha (float): The prefactor alpha of U in the linear combination, as above.
        - beta (float): The prefactor beta of V in the linear combination, as above.

    Returns:
        -(numpy.tensor): Probabilities of measuring the computational
        basis states on the auxiliary wire. 
    """
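    # Prepare the auxiliary qubit: W|0> = (sqrt(alpha)|0> + sqrt(beta)|1>) / sqrt(alpha + beta)
    qml.QubitUnitary(W(alpha, beta), wires=[0])
    # Apply U to the main qubit when the auxiliary qubit is |0>, and V when it is |1>
    qml.ControlledQubitUnitary(U, control_wires=[0], wires=[1], control_values=[0])
    qml.ControlledQubitUnitary(V, control_wires=[0], wires=[1], control_values=[1])
    # Undo the state preparation on the auxiliary qubit
    qml.adjoint(qml.QubitUnitary)(W(alpha, beta), wires=[0])
    # Return the probabilities on the auxiliary wire (wire 0)
    return qml.probs(wires=[0])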


U = [[ 0.70710678,  0.70710678], 
     [ 0.70710678, -0.70710678]]
V = [[1, 0], 
     [0, -1]]
alpha = 1
beta = 3

print(qml.draw(linear_combination)(U, V, alpha, beta), "\n")
print(linear_combination(U, V, alpha, beta))

0: ──U(M0)─╭○─────╭●──────U(M0)†─┤  Probs
1: ────────╰U(M1)─╰U(M2)─────────┤       

M0 = 
[[ 0.5       -0.8660254]
 [ 0.8660254  0.5      ]]
M1 = 
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
M2 = 
[[ 1  0]
 [ 0 -1]] 

[0.89016504 0.10983496]
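
As a cross-check that does not depend on PennyLane, the same circuit can be simulated with explicit 4x4 matrices. The sketch below assumes PennyLane's wire ordering, in which wire 0 (the auxiliary qubit) is the most significant bit of the basis index. It merges the two controlled gates into one block-diagonal operator $\vert 0\rangle\langle 0\vert \otimes U + \vert 1\rangle\langle 1\vert \otimes V$. The function name `simulate_lcu` is hypothetical.

```python
import numpy as np


def simulate_lcu(U, V, alpha, beta):
    """Hypothetical NumPy re-implementation of the circuit: W, select(U, V), then W^dagger."""
    U = np.asarray(U, dtype=complex)
    V = np.asarray(V, dtype=complex)
    # Same W as in this notebook: first column is (sqrt(alpha), sqrt(beta)) / sqrt(alpha + beta).
    w = np.array([[np.sqrt(alpha), -np.sqrt(beta)],
                  [np.sqrt(beta), np.sqrt(alpha)]]) / np.sqrt(alpha + beta)
    I2 = np.eye(2)
    P0 = np.diag([1, 0])  # |0><0| on the auxiliary qubit
    P1 = np.diag([0, 1])  # |1><1| on the auxiliary qubit
    # Auxiliary qubit is the left tensor factor, matching wire 0 being most significant.
    select = np.kron(P0, U) + np.kron(P1, V)
    circuit = np.kron(w.conj().T, I2) @ select @ np.kron(w, I2)
    state = circuit @ np.array([1, 0, 0, 0], dtype=complex)  # start in |00>
    amps = state.reshape(2, 2)  # rows: auxiliary qubit, columns: main qubit
    # Marginal probabilities of the auxiliary qubit, like qml.probs(wires=[0]).
    return np.sum(np.abs(amps) ** 2, axis=1)


probs = simulate_lcu([[0.70710678, 0.70710678], [0.70710678, -0.70710678]],
                     [[1, 0], [0, -1]], 1, 3)
print(probs)  # approximately [0.89016504 0.10983496], as printed by the QNode above
```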


These functions are responsible for testing the solution. You will need to run the cell below. ***Do not modify the cell.***

def run(test_case_input: str) -> str:
    dev = qml.device('default.qubit', wires = 2)
    ins = json.loads(test_case_input)
    output = linear_combination(*ins)[0].numpy()

    return str(output)

def check(solution_output: str, expected_output: str) -> None:
    solution_output = json.loads(solution_output)
    expected_output = json.loads(expected_output)
    assert np.allclose(
        solution_output, expected_output, rtol=1e-4
    ), "Your circuit doesn't look quite right "

### Test cases
 Running the cell below will load the test cases. ***Do not modify the cell***.
 - input: [[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]],[[1, 0], [0, -1]], 1, 3]
 	+ expected output: 0.8901650422902458
 - input: [[[0, 1],[1, 0]],[[1, 0], [0, -1]], 1, 2]
 	+ expected output: 0.5555555555555559
 - input: [[[ 0.98877108, -0.14943813], [ 0.14943813,  0.98877108]],[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]], 2, 1]
 	+ expected output: 0.9132602008678633

test_cases = [['[[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]],[[1, 0], [0, -1]], 1, 3]', '0.8901650422902458'], ['[[[0, 1],[1, 0]],[[1, 0], [0, -1]], 1, 2]', '0.5555555555555559'], ['[[[ 0.98877108, -0.14943813], [ 0.14943813,  0.98877108]],[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]], 2, 1]', '0.9132602008678633']]
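
The expected outputs can also be reproduced from the closed form $p(0) = \lVert(\alpha U + \beta V)\vert 0\rangle\rVert^2/(\alpha+\beta)^2$. This follows from the circuit when the main qubit starts in $\vert 0\rangle$ and $W$ is chosen as in this notebook. The minimal NumPy sketch below checks it against the three test cases of this notebook, using the same relative tolerance as `check`. The function name `closed_form_p0` is hypothetical.

```python
import json

import numpy as np


def closed_form_p0(U, V, alpha, beta):
    """Hypothetical helper: probability of |0> on the auxiliary qubit, computed analytically."""
    U = np.asarray(U, dtype=complex)
    V = np.asarray(V, dtype=complex)
    # (alpha U + beta V) applied to |0> is the first column of the combined matrix.
    vec = (alpha * U + beta * V)[:, 0]
    return float(np.sum(np.abs(vec) ** 2) / (alpha + beta) ** 2)


# (JSON input, expected output) pairs copied from the test cases in this notebook.
cases = [
    ('[[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]],[[1, 0], [0, -1]], 1, 3]', 0.8901650422902458),
    ('[[[0, 1],[1, 0]],[[1, 0], [0, -1]], 1, 2]', 0.5555555555555559),
    ('[[[ 0.98877108, -0.14943813], [ 0.14943813,  0.98877108]],[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]], 2, 1]', 0.9132602008678633),
]

for raw, expected in cases:
    p0 = closed_form_p0(*json.loads(raw))
    print(p0, np.allclose(p0, expected, rtol=1e-4))
```

A closed form like this is only a sanity check on the expected values; the challenge still asks for the circuit itself.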

### Solution testing
 Once you have run every cell above, including the one with your code, run the cell below to test your solution. If every test case passes, your solution is correct. Otherwise, double-check your work. ***Do not modify the cell below.***

for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)

    except Exception as exc:
        print(f"Runtime Error. {exc}")

    else:
        if message := check(output, expected_output):
            print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")

        else:
            print("Correct!")

Running test case 0 with input '[[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]],[[1, 0], [0, -1]], 1, 3]'...
Correct!
Running test case 1 with input '[[[0, 1],[1, 0]],[[1, 0], [0, -1]], 1, 2]'...
Correct!
Running test case 2 with input '[[[ 0.98877108, -0.14943813], [ 0.14943813,  0.98877108]],[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]], 2, 1]'...
Correct!
