In [32]:
import json
import pennylane as qml
import pennylane.numpy as np

In [179]:
def is_product(state, subsystem, wires):
    """Determines if a pure quantum state can be written as a product state between 
    a subsystem of wires and their compliment.

    Args:
        state (numpy.array): The quantum state of interest.
        subsystem (list(int)): The subsystem used to determine if the state is a product state.
        wires (list(int)): The wire/qubit labels for the state. Use these for creating a QNode if you wish!

    Returns:
        (str): "yes" if the state is a product state or "no" if it isn't.
    """
    dev = qml.device("default.qubit", wires=len(wires))

    # Compute the original density matrix
    @qml.qnode(device=dev)
    def total_circuit():
        qml.QubitStateVector(np.array(state), wires=wires)
        return qml.density_matrix(wires)

    # Compute the partial trace of the two subsystems of interest
    subsystembar = [w for w in wires if w not in subsystem]

    ## Subsystem
    @qml.qnode(device=dev)
    def sub_circuit():
        qml.QubitStateVector(np.array(state), wires=wires)
        return qml.density_matrix(subsystem)
    subsystem_res = sub_circuit()

    ## Residual subsystem
    @qml.qnode(device=dev)
    def sub_circuit_bar():
        qml.QubitStateVector(np.array(state), wires=wires)
        return qml.density_matrix(subsystembar)  
    subsystembar_res = sub_circuit_bar()

    @qml.qnode(device=dev)
    def product_circuit():
        qml.QubitStateVector(subsystem_res, wires=subsystem)
        qml.QubitStateVector(subsystembar_res, wires=subsystembar)
        return qml.density_matrix(wires)

    print(qml.draw(product_circuit)())
        
    # Compare the product state of the partial traces and the input density matrix
    # if np.allclose(np.kron(sub_circuit(), sub_circuit_bar()), state_dm, rtol=1e-6, atol=1e-8):
    if np.allclose(product_circuit(), total_circuit(), rtol=1e-5, atol=1e-8):
        print('hmmm')
        is_product_str = "yes"
    else:
        is_product_str = "no"
    
    return is_product_str

In [180]:
# These functions are responsible for testing the solution.
def run(test_case_input: str) -> str:
    ins = json.loads(test_case_input)
    state, subsystem, wires = ins
    state = np.array(state)
    output = is_product(state, subsystem, wires)
    return output

def check(solution_output: str, expected_output: str) -> None:
    assert solution_output == expected_output

In [181]:
test_cases = [['[[0.707107, 0, 0, 0.707107], [0], [0, 1]]', 'no'], 
              ['[[1, 0, 0, 0], [0], [0, 1]]', 'yes'],
              ['[[0.707107, 0, 0, 0.707107, 0, 0, 0, 0], [1, 2], [0, 1, 2]]', 'yes'],
              ['[[0.707107, 0, 0, 0.707107, 0, 0, 0, 0], [0, 2], [0, 1, 2]]', 'no'],
              ['[[0.5, 0, 0, 0.5, 0, 0, 0, 0, 0.5, 0, 0, 0.5, 0, 0, 0, 0], [0, 1], [0, 1, 2, 3]]', 'yes']]

In [182]:
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)

    except Exception as exc:
        print(f"Runtime Error. {exc}")

    else:
        if message := check(output, expected_output):
            print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")

        else:
            print("Correct!")

Running test case 0 with input '[[0.707107, 0, 0, 0.707107], [0], [0, 1]]'...
0: ──QubitStateVector(M0)─┤ ╭State
1: ──QubitStateVector(M0)─┤ ╰State
Runtime Error. Sum of amplitudes-squared does not equal one.
Running test case 1 with input '[[1, 0, 0, 0], [0], [0, 1]]'...
0: ──QubitStateVector(M0)─┤ ╭State
1: ──QubitStateVector(M0)─┤ ╰State
Runtime Error. Sum of amplitudes-squared does not equal one.
Running test case 2 with input '[[0.707107, 0, 0, 0.707107, 0, 0, 0, 0], [1, 2], [0, 1, 2]]'...
Runtime Error. The batch sizes of the quantum script operations do not match, they include 4 and 2.
Running test case 3 with input '[[0.707107, 0, 0, 0.707107, 0, 0, 0, 0], [0, 2], [0, 1, 2]]'...
Runtime Error. The batch sizes of the quantum script operations do not match, they include 4 and 2.
Running test case 4 with input '[[0.5, 0, 0, 0.5, 0, 0, 0, 0, 0.5, 0, 0, 0.5, 0, 0, 0, 0], [0, 1], [0, 1, 2, 3]]'...
0: ─╭QubitStateVector(M1)─┤ ╭State
1: ─╰QubitStateVector(M1)─┤ ├State
2: ─╭QubitStateVe

In [104]:
np.kron([0.707107, 0, 0, 0.707107], [0.707107, 0, 0, 0.707107])

tensor([0.50000031, 0.        , 0.        , 0.50000031, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.50000031, 0.        , 0.        ,
        0.50000031], requires_grad=True)

In [186]:
wires = [0, 1, 2]
state = [0.707107, 0, 0, 0.707107, 0, 0, 0, 0]

dev = qml.device("default.qubit", wires=len(wires))

# Compute the original density matrix
@qml.qnode(device=dev)
def total_circuit(w):
    qml.QubitStateVector(np.array(state), wires=wires)
    return qml.state()
print(total_circuit(w=[0]))
print(total_circuit(w=[1,2]))

# print(np.kron(total_circuit(w=[0]), total_circuit(w=[1,2])))
# print(total_circuit(wires))

[0.707107+0.j 0.      +0.j 0.      +0.j 0.707107+0.j 0.      +0.j
 0.      +0.j 0.      +0.j 0.      +0.j]
[0.707107+0.j 0.      +0.j 0.      +0.j 0.707107+0.j 0.      +0.j
 0.      +0.j 0.      +0.j 0.      +0.j]
