In [1]:
import json
import pennylane as qml
import pennylane.numpy as np

def W(params):
    """Apply a trainable block: layers of RY rotations, each followed by a CNOT ring.

    Args:
        params: 2-D array-like of rotation angles indexed ``params[layer][wire]``.
            The number of wires is taken from ``params[0]``; every layer is
            assumed to have that same length.
    """
    num_wires = len(params[0])

    for layer_params in params:
        # One RY rotation per wire; the angle's position in the layer is its wire.
        for wire, angle in enumerate(layer_params):
            qml.RY(angle, wires=wire)
        # Entangle neighbours in a closed ring: (0,1), (1,2), ..., (n-1,0).
        # A CNOT needs two distinct wires, so a single-wire block has no ring.
        if num_wires > 1:
            for wire in range(num_wires):
                qml.CNOT(wires=[wire, (wire + 1) % num_wires])

def S(g, x, num_wires):
    """Apply the data-encoding block exp(i * x * G) independently on each wire.

    Args:
        g: callable mapping a wire index to the generator operator G acting on
            that wire (e.g. ``qml.PauliX``).
        x: scalar input feature to encode.
        num_wires: number of wires, starting from wire 0, to encode on.
    """
    for wire in range(num_wires):
        # qml.exp(op, coeff) represents exp(coeff * op), so coeff = 1j * x gives
        # exp(i x G) directly, without wrapping G in a one-term sum.
        qml.exp(g(wire), 1j * x)
        

# State-vector simulator with four wires, labelled 0, 1, 2, 3.
dev = qml.device("default.qubit", wires=4)

@qml.qnode(dev)
def quantum_model(param_set, g, x):
    """Alternate trainable and encoding blocks (W, S, W, ..., S, W), then measure.

    Args:
        param_set: sequence of trainable blocks; ``param_set[i]`` is passed to
            ``W`` and is indexed ``[layer][wire]``.
        g: callable mapping a wire index to the encoding generator operator.
        x: scalar input feature, re-encoded by every ``S`` block.

    Returns:
        Computational-basis probabilities on wire 0.
    """
    num_wires = len(param_set[0][0])
    last_index = len(param_set) - 1

    # Every trainable block except the last is followed by an encoding block,
    # so the circuit both starts and ends with W.
    for i, block_params in enumerate(param_set):
        W(block_params)
        if i < last_index:
            S(g, x, num_wires)

    return qml.probs(wires=0)

# These functions are used to test the solution
def run(test_case_input: str) -> str:
    """Evaluate ``quantum_model`` on one JSON-encoded test case.

    The input is a JSON array ``[param_set, generator_name, x]``, where
    ``generator_name`` is the name of a ``qml`` attribute (the public cases
    use ``"PauliX"`` and ``"PauliY"``).

    Returns:
        ``str`` of the Python list of probabilities returned by the model.
    """
    payload = json.loads(test_case_input)
    param_set = np.array(payload[0])
    # NOTE(review): any attribute of `qml` named by the input is accepted here.
    generator = getattr(qml, payload[1])
    feature = payload[2]
    probabilities = quantum_model(param_set, generator, feature)
    return str(probabilities.tolist())

def check(solution_output: str, expected_output: str) -> None:
    """Compare two JSON-encoded probability lists within an absolute tolerance.

    Args:
        solution_output: JSON list produced by ``run``.
        expected_output: JSON list of reference probabilities.

    Raises:
        AssertionError: if the two lists differ in shape, or any entry differs
            by more than ``atol=1e-3`` (plus ``allclose``'s default ``rtol``).
    """
    solution = np.array(json.loads(solution_output))
    expected = np.array(json.loads(expected_output))

    # Compare shapes explicitly: allclose broadcasts, so a one-element answer
    # would otherwise be matched against every expected entry.
    # Raise rather than `assert` so the check is not stripped under `python -O`.
    if solution.shape != expected.shape or not np.allclose(
        solution, expected, atol=1e-3
    ):
        raise AssertionError("Not the correct probabilities for the quantum model.")

# Public test cases: each entry is a pair (input JSON, expected output JSON).
# Input is [param_set, generator name, x]; here param_set holds 2 trainable
# blocks of 2 layers x 4 angles. The expected output is the list of wire-0
# probabilities, compared with atol=1e-3 by check().
test_cases = [
    ('[[[[1.0472, 0.7854, 3.1416, 0.3927],[1.0472, 0.7854, 3.1416, 0.5236]],[[1.0472, 0.7854, 1.5708, 0.3927],[0.7854, 0.7854, 1.5708, 0.7854]]],"PauliX", 0.7854]', '[0.46653, 0.53347]'),
    ('[[[[0.62832, 0.3927, 1.0472, 0.7854],[0.7854, 0.31416, 0.62832, 0.5236]],[[0.31416, 0.7854, 0.7854, 0.3927],[0.31416, 0.3927, 0.31416, 0.3927]]],"PauliY", 0.5236]', '[0.68594, 0.31406]')
]

# Run each public test case. An exception raised by run() is reported and the
# loop moves on; a wrong answer (AssertionError from check) stops the loop.
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")
    try:
        output = run(input_)
    except Exception as exc:
        print(f"Runtime Error: {exc}")
        continue
    check(output, expected_output)
    print("Correct!")


Running test case 0 with input '[[[[1.0472, 0.7854, 3.1416, 0.3927],[1.0472, 0.7854, 3.1416, 0.5236]],[[1.0472, 0.7854, 1.5708, 0.3927],[0.7854, 0.7854, 1.5708, 0.7854]]],"PauliX", 0.7854]'...
Correct!
Running test case 1 with input '[[[[0.62832, 0.3927, 1.0472, 0.7854],[0.7854, 0.31416, 0.62832, 0.5236]],[[0.31416, 0.7854, 0.7854, 0.3927],[0.31416, 0.3927, 0.31416, 0.3927]]],"PauliY", 0.5236]'...
Correct!


