Skip to content

Commit

Permalink
Merge branch 'master' into move-CY-to-controlled-ops
Browse files Browse the repository at this point in the history
  • Loading branch information
frederikwilde committed May 25, 2023
2 parents a1291fb + f7cc39d commit 722ce6a
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 58 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@

<h3>Bug fixes 🐛</h3>

* Fixes a bug where the wire ordering of the `wires` argument to `qml.density_matrix`
was not taken into account.
[(#4072)](https://github.com/PennyLaneAI/pennylane/pull/4072)

* Removes a patch in `interfaces/autograd.py` that checks for the `strawberryfields.gbs` device. That device
is pinned to PennyLane <= v0.29.0, so that patch is no longer necessary.

Expand Down
34 changes: 17 additions & 17 deletions pennylane/math/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from . import single_dispatch # pylint:disable=unused-import
from .multi_dispatch import diag, dot, scatter_element_add, einsum, get_interface
from .utils import is_abstract, allclose, cast, convert_like, cast_like
from .matrix_manipulation import _permute_dense_matrix

ABC_ARRAY = np.array(list(ABC))

Expand Down Expand Up @@ -208,21 +209,23 @@ def _density_matrix_from_matrix(density_matrix, indices, check_state=False):
"""
shape = density_matrix.shape[0]
num_indices = int(np.log2(shape))
dim = density_matrix.shape[0]
num_indices = int(np.log2(dim))

if check_state:
_check_density_matrix(density_matrix)

consecutive_indices = list(range(0, num_indices))
consecutive_indices = list(range(num_indices))

# Return the full density matrix if all the wires are given
if tuple(indices) == tuple(consecutive_indices):
return density_matrix
# Return the full density matrix if all the wires are given, potentially permuted
if len(indices) == num_indices:
return _permute_dense_matrix(density_matrix, consecutive_indices, indices, None)

# Compute the partial trace
traced_wires = [x for x in consecutive_indices if x not in indices]
density_matrix = _partial_trace(density_matrix, traced_wires)
return density_matrix
# Permute the remaining indices of the density matrix
return _permute_dense_matrix(density_matrix, sorted(indices), indices, None)


def _partial_trace(density_matrix, indices):
Expand Down Expand Up @@ -374,16 +377,16 @@ def _density_matrix_from_state_vector(state, indices, check_state=False):
[0.+0.j 0.+0.j]], shape=(2, 2), dtype=complex128)
"""
len_state = np.shape(state)[0]
dim = np.shape(state)[0]

# Check the format and norm of the state vector
if check_state:
_check_state_vector(state)

# Get dimension of the quantum system and reshape
num_indices = int(np.log2(len_state))
consecutive_wires = list(range(num_indices))
state = np.reshape(state, [2] * num_indices)
num_wires = int(np.log2(dim))
consecutive_wires = list(range(num_wires))
state = np.reshape(state, [2] * num_wires)

# Get the system to be traced
traced_system = [x for x in consecutive_wires if x not in indices]
Expand All @@ -392,7 +395,7 @@ def _density_matrix_from_state_vector(state, indices, check_state=False):
density_matrix = np.tensordot(state, np.conj(state), axes=(traced_system, traced_system))
density_matrix = np.reshape(density_matrix, (2 ** len(indices), 2 ** len(indices)))

return density_matrix
return _permute_dense_matrix(density_matrix, sorted(indices), indices, None)


def reduced_dm(state, indices, check_state=False, c_dtype="complex128"):
Expand Down Expand Up @@ -447,12 +450,9 @@ def reduced_dm(state, indices, check_state=False, c_dtype="complex128"):
len_state = state.shape[0]
# State vector
if state.shape == (len_state,):
density_matrix = _density_matrix_from_state_vector(state, indices, check_state)
return density_matrix

density_matrix = _density_matrix_from_matrix(state, indices, check_state)
return _density_matrix_from_state_vector(state, indices, check_state)

return density_matrix
return _density_matrix_from_matrix(state, indices, check_state)


def purity(state, indices, check_state=False, c_dtype="complex128"):
Expand Down
4 changes: 2 additions & 2 deletions tests/gradients/core/test_gradient_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class TestGradientTransformIntegration:
@pytest.mark.parametrize("slicing", [False, True])
def test_acting_on_qnodes_single_param(self, shots, slicing, atol):
"""Test that a gradient transform acts on QNodes with a single parameter correctly"""
np.random.seed(412)
dev = qml.device("default.qubit", wires=2, shots=shots)

@qml.qnode(dev)
Expand All @@ -214,8 +215,6 @@ def circuit(weights):
res = grad_fn(w)
assert circuit.interface == "auto"
expected = np.array([-np.sin(w[0] if slicing else w), 0])
print(expected)
print(res)
if isinstance(shots, list):
assert all(np.allclose(r, expected, atol=atol, rtol=0) for r in res)
else:
Expand All @@ -224,6 +223,7 @@ def circuit(weights):
@pytest.mark.parametrize("shots, atol", [(None, 1e-6), (1000, 1e-1), ([1000, 100], 2e-1)])
def test_acting_on_qnodes_multi_param(self, shots, atol):
"""Test that a gradient transform acts on QNodes with multiple parameters correctly"""
np.random.seed(412)
dev = qml.device("default.qubit", wires=2, shots=shots)

@qml.qnode(dev)
Expand Down
71 changes: 43 additions & 28 deletions tests/math/test_density_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Unit tests for density matrices functions.
"""
# pylint: disable=import-outside-toplevel

import numpy as onp
import pytest
Expand Down Expand Up @@ -84,13 +85,19 @@
]

multiple_wires_list = [
[0, 1]
[0, 1],
[1, 0],
]
# fmt: on

c_dtypes = ["complex64", "complex128"]


def permute_two_qubit_dm(dm):
"""Permute the two qubits of a density matrix by transposing."""
return fn.reshape(fn.transpose(fn.reshape(dm, [2] * 4), [1, 0, 3, 2]), (4, 4))


class TestDensityMatrixFromStateVectors:
"""Tests for creating a density matrix from state vectors."""

Expand All @@ -101,6 +108,7 @@ def test_density_matrix_from_state_vector_single_wires(
self, state_vector, wires, expected_density_matrix, array_func
):
"""Test the density matrix from state vectors for single wires."""
# pylint: disable=protected-access
state_vector = array_func(state_vector)
density_matrix = fn.quantum._density_matrix_from_state_vector(state_vector, indices=wires)
assert np.allclose(density_matrix, expected_density_matrix[wires[0]])
Expand All @@ -112,9 +120,13 @@ def test_density_matrix_from_state_vector_full_wires(
self, state_vector, wires, expected_density_matrix, array_func
):
"""Test the density matrix from state vectors for full wires."""
# pylint: disable=protected-access
state_vector = array_func(state_vector)
density_matrix = fn.quantum._density_matrix_from_state_vector(state_vector, indices=wires)
assert np.allclose(density_matrix, expected_density_matrix[2])
expected = expected_density_matrix[2]
if wires == [1, 0]:
expected = permute_two_qubit_dm(expected)
assert np.allclose(density_matrix, expected)

@pytest.mark.parametrize("array_func", array_funcs)
@pytest.mark.parametrize("state_vector, expected_density_matrix", state_vectors)
Expand All @@ -136,7 +148,10 @@ def test_reduced_dm_with_state_vector_full_wires(
"""Test the reduced_dm with state vectors for full wires."""
state_vector = array_func(state_vector)
density_matrix = fn.reduced_dm(state_vector, indices=wires)
assert np.allclose(density_matrix, expected_density_matrix[2])
expected = expected_density_matrix[2]
if wires == [1, 0]:
expected = permute_two_qubit_dm(expected)
assert np.allclose(density_matrix, expected)

@pytest.mark.parametrize("array_func", array_funcs)
@pytest.mark.parametrize("state_vector, expected_density_matrix", state_vectors)
Expand All @@ -147,8 +162,11 @@ def test_density_matrix_from_state_vector_check_state(
"""Test the density matrix from state vectors for single wires with state checking"""
state_vector = array_func(state_vector)
density_matrix = fn.quantum.reduced_dm(state_vector, indices=wires, check_state=True)
expected = expected_density_matrix[2]
if wires == [1, 0]:
expected = permute_two_qubit_dm(expected)

assert np.allclose(density_matrix, expected_density_matrix[2])
assert np.allclose(density_matrix, expected)

def test_state_vector_wrong_shape(self):
"""Test that wrong shaped state vector raises an error with check_state=True"""
Expand All @@ -166,8 +184,8 @@ def test_state_vector_wrong_norm(self):

def test_density_matrix_from_state_vector_jax_jit(self):
"""Test jitting the density matrix from state vector function."""
# pylint: disable=protected-access
from jax import jit
import jax.numpy as jnp

state_vector = jnp.array([1, 0, 0, 0])

Expand All @@ -180,8 +198,8 @@ def test_density_matrix_from_state_vector_jax_jit(self):

def test_wrong_shape_jax_jit(self):
"""Test jitting the density matrix from state vector with wrong shape."""
# pylint: disable=protected-access
from jax import jit
import jax.numpy as jnp

state_vector = jnp.array([1, 0, 0])

Expand All @@ -194,7 +212,6 @@ def test_wrong_shape_jax_jit(self):

def test_density_matrix_tf_jit(self):
"""Test jitting the density matrix from state vector function with Tf."""
import tensorflow as tf
from functools import partial

state_vector = tf.Variable([1, 0, 0, 0], dtype=tf.complex128)
Expand All @@ -211,11 +228,9 @@ def test_density_matrix_tf_jit(self):

@pytest.mark.parametrize("c_dtype", c_dtypes)
@pytest.mark.parametrize("array_func", array_funcs)
@pytest.mark.parametrize("state_vector, expected_density_matrix", state_vectors)
@pytest.mark.parametrize("state_vector", list(zip(*state_vectors))[0])
@pytest.mark.parametrize("wires", single_wires_list)
def test_density_matrix_c_dtype(
self, array_func, state_vector, wires, c_dtype, expected_density_matrix
):
def test_density_matrix_c_dtype(self, array_func, state_vector, wires, c_dtype):
"""Test different complex dtype."""
state_vector = array_func(state_vector)
if fn.get_interface(state_vector) == "jax" and c_dtype == "complex128":
Expand Down Expand Up @@ -278,24 +293,29 @@ def test_reduced_dm_with_matrix_single_wires(
density_matrix = fn.reduced_dm(density_matrix, indices=wires)
assert np.allclose(density_matrix, expected_density_matrix[wires[0]])

@pytest.mark.parametrize("density_matrix, expected_density_matrix", density_matrices)
@pytest.mark.parametrize("density_matrix", list(zip(*density_matrices))[0])
@pytest.mark.parametrize("wires", multiple_wires_list)
def test_reduced_dm_with_matrix_full_wires(
self, density_matrix, wires, expected_density_matrix
):
def test_reduced_dm_with_matrix_full_wires(self, density_matrix, wires):
"""Test the reduced_dm with matrix for full wires."""
returned_density_matrix = fn.reduced_dm(density_matrix, indices=wires)
expected = density_matrix
if wires == [1, 0]:
expected = permute_two_qubit_dm(expected)
assert np.allclose(returned_density_matrix, expected)

assert np.allclose(density_matrix, returned_density_matrix)

@pytest.mark.parametrize("density_matrix, expected_density_matrix", density_matrices)
@pytest.mark.parametrize("density_matrix", list(zip(*density_matrices))[0])
@pytest.mark.parametrize("wires", multiple_wires_list)
def test_density_matrix_from_matrix_check(self, density_matrix, wires, expected_density_matrix):
def test_density_matrix_from_matrix_check(self, density_matrix, wires):
"""Test the density matrix from matrices for single wires with state checking"""
# pylint: disable=protected-access
returned_density_matrix = fn.quantum._density_matrix_from_matrix(
density_matrix, indices=wires, check_state=True
)
assert np.allclose(density_matrix, returned_density_matrix)
expected = density_matrix
if wires == [1, 0]:
expected = permute_two_qubit_dm(expected)

assert np.allclose(returned_density_matrix, expected)

def test_matrix_wrong_shape(self):
"""Test that wrong shaped state vector raises an error with check_state=True"""
Expand Down Expand Up @@ -328,7 +348,6 @@ def test_matrix_not_positive_definite(self):
def test_density_matrix_from_state_vector_jax_jit(self):
"""Test jitting the density matrix from state vector function."""
from jax import jit
import jax.numpy as jnp

state_vector = jnp.array([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])

Expand All @@ -339,8 +358,8 @@ def test_density_matrix_from_state_vector_jax_jit(self):

def test_wrong_shape_jax_jit(self):
"""Test jitting the density matrix from state vector with wrong shape."""
# pylint: disable=protected-access
from jax import jit
import jax.numpy as jnp

state_vector = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]])

Expand All @@ -351,7 +370,6 @@ def test_wrong_shape_jax_jit(self):

def test_density_matrix_tf_jit(self):
"""Test jitting the density matrix from density matrix function with Tf."""
import tensorflow as tf
from functools import partial

d_mat = tf.Variable(
Expand All @@ -374,12 +392,9 @@ def test_density_matrix_tf_jit(self):
assert np.allclose(density_matrix, [[1, 0], [0, 0]])

@pytest.mark.parametrize("c_dtype", c_dtypes)
@pytest.mark.parametrize("array_func", array_funcs)
@pytest.mark.parametrize("density_matrix, expected_density_matrix", state_vectors)
@pytest.mark.parametrize("density_matrix", list(zip(*density_matrices))[0])
@pytest.mark.parametrize("wires", single_wires_list)
def test_density_matrix_c_dtype(
self, array_func, density_matrix, wires, c_dtype, expected_density_matrix
):
def test_density_matrix_c_dtype(self, density_matrix, wires, c_dtype):
"""Test different complex dtype."""
if fn.get_interface(density_matrix) == "jax" and c_dtype == "complex128":
pytest.skip("Jax does not support complex 128")
Expand Down

0 comments on commit 722ce6a

Please sign in to comment.