In [1]:
import json
import pennylane as qml
import pennylane.numpy as np

In [2]:
def is_product(state, subsystem, wires):
    """Determines if a pure quantum state can be written as a product state between
    a subsystem of wires and their complement.

    The check reconstructs rho_A (x) rho_B from the two reduced density matrices and
    compares it with the full density matrix |psi><psi|. For a pure state these agree
    (up to the comparison tolerance) exactly when the state factorizes across the
    given bipartition.

    Args:
        state (numpy.array): The quantum state of interest.
        subsystem (list(int)): The subsystem used to determine if the state is a product state.
        wires (list(int)): The wire/qubit labels for the state. The labels are used
            directly as device wires, so they do not need to be ``0..n-1``.

    Returns:
        (str): "yes" if the state is a product state or "no" if it isn't.
    """
    # Compute the wires of the complementary subsystem, preserving the order of ``wires``
    subsystem_bar = [w for w in wires if w not in subsystem]

    # If either side of the bipartition is empty the state is trivially a product
    # state; returning early also avoids preparing/measuring on zero wires below.
    if len(subsystem) == 0 or len(subsystem_bar) == 0:
        return "yes"

    # Build the devices on the caller's actual wire labels (not range(len(wires))),
    # so arbitrary labels such as [1, 2] are valid targets for the operations below.
    pure_device = qml.device("default.qubit", wires=wires)
    mixed_device = qml.device("default.mixed", wires=wires)

    # Reduced density matrix of the input state on the wires ``w``
    @qml.qnode(device=pure_device)
    def subset_circuit(w):
        qml.QubitStateVector(np.array(state), wires=wires)
        return qml.density_matrix(wires=w)

    # Compute the reduced density matrices of the two subsystems of interest
    density_matrix_subsystem = subset_circuit(subsystem)
    density_matrix_subsystem_bar = subset_circuit(subsystem_bar)

    # Reconstruct a density matrix assuming a product state: rho_A (x) rho_B,
    # returned in the same wire order as the full density matrix below.
    @qml.qnode(device=mixed_device)
    def product_circuit():
        qml.QubitDensityMatrix(density_matrix_subsystem, wires=subsystem)
        qml.QubitDensityMatrix(density_matrix_subsystem_bar, wires=subsystem_bar)
        return qml.density_matrix(wires=wires)

    product_state = product_circuit()
    initial_state = subset_circuit(wires)

    # Compare the product of the partial traces with the input density matrix
    if np.allclose(product_state, initial_state, rtol=1e-5, atol=1e-8):
        is_product_str = "yes"
    else:
        is_product_str = "no"

    return is_product_str

In [3]:
# These functions are responsible for testing the solution.
def run(test_case_input: str) -> str:
    """Decode one JSON test case and evaluate ``is_product`` on it.

    The payload is a three-element JSON array ``[state, subsystem, wires]``;
    the state is converted to an array before being passed on.
    """
    raw_state, subsystem, wires = json.loads(test_case_input)
    return is_product(np.array(raw_state), subsystem, wires)

def check(solution_output: str, expected_output: str) -> None:
    """Raise AssertionError unless the solution's answer equals the expected one.

    Returns None on success; a mismatch is reported only via the exception.
    """
    answers_match = solution_output == expected_output
    assert answers_match

In [4]:
# Each entry is [json_input, expected_output]; json_input encodes [state, subsystem, wires].
test_cases = [
    ['[[0.707107, 0, 0, 0.707107], [0], [0, 1]]', 'no'],
    ['[[1, 0, 0, 0], [0], [0, 1]]', 'yes'],
    ['[[0.707107, 0, 0, 0.707107, 0, 0, 0, 0], [1, 2], [0, 1, 2]]', 'yes'],
    ['[[0.707107, 0, 0, 0.707107, 0, 0, 0, 0], [0, 2], [0, 1, 2]]', 'no'],
    ['[[0.5, 0, 0, 0.5, 0, 0, 0, 0, 0.5, 0, 0, 0.5, 0, 0, 0, 0], [0, 1], [0, 1, 2, 3]]', 'yes'],
]

In [5]:
# Run every test case, reporting runtime errors, wrong answers and passes per case
# without letting one failure stop the remaining cases.
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)
    except Exception as exc:
        print(f"Runtime Error. {exc}")
        continue

    # ``check`` returns None and signals a mismatch by raising AssertionError,
    # so a wrong answer has to be caught here; testing its return value for
    # truthiness would never detect a failure.
    try:
        check(output, expected_output)
    except AssertionError:
        print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")
    else:
        print("Correct!")

Running test case 0 with input '[[0.707107, 0, 0, 0.707107], [0], [0, 1]]'...
Correct!
Running test case 1 with input '[[1, 0, 0, 0], [0], [0, 1]]'...
Correct!
Running test case 2 with input '[[0.707107, 0, 0, 0.707107, 0, 0, 0, 0], [1, 2], [0, 1, 2]]'...
Correct!
Running test case 3 with input '[[0.707107, 0, 0, 0.707107, 0, 0, 0, 0], [0, 2], [0, 1, 2]]'...
Correct!
Running test case 4 with input '[[0.5, 0, 0, 0.5, 0, 0, 0, 0, 0.5, 0, 0, 0.5, 0, 0, 0, 0], [0, 1], [0, 1, 2, 3]]'...
Correct!
