Skip to content

Commit

Permalink
Remove TensorFlow dependency from frontend tests (#639)
Browse files Browse the repository at this point in the history
* Lift TF installation assumption from frontend tests

* Update changelog

* Refactor mock TensorFlow imports

* Run black formatter

* Remove isolated .backend artefact from merge

* Use sf.LocalEngine() in frontend tests and gate TF version tests

* Restore test coverage to incorrect TF version check

* Restore TFBackend import to bottom of file

* Fix TF imports in frontend tests following #599

* Remove release note about TF test fixes
  • Loading branch information
Mandrenkov committed Nov 3, 2021
1 parent 243d5d7 commit d4c11b9
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 66 deletions.
2 changes: 1 addition & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

This release contains contributions from (in alphabetical order):

Sebastián Duque Mesa, Filippo Miatto, Nicolás Quesada, Antal Száva, Yuan Yao.
Mikhail Andrenkov, Sebastián Duque Mesa, Filippo Miatto, Nicolás Quesada, Antal Száva, Yuan Yao.

# Release 0.19.0 (current release)

Expand Down
3 changes: 2 additions & 1 deletion strawberryfields/backends/tfbackend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@
"""
import sys
from .backend import TFBackend

try:
import tensorflow
Expand Down Expand Up @@ -170,4 +169,6 @@ def excepthook(type, value, traceback):
raise ImportError(tf_info)


# The modules inside the tfbackend package assume TensorFlow is importable.
from .backend import TFBackend
from .ops import update_symplectic
5 changes: 3 additions & 2 deletions tests/backend/test_gaussian_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

import pytest
import numpy as np
import tensorflow as tf
from scipy.linalg import expm
from scipy.stats import unitary_group
from thewalrus.quantum.fock_tensors import fock_tensor
from thewalrus.symplectic import sympmat

tf = pytest.importorskip("tensorflow")

from strawberryfields.backends.tfbackend.ops import (
choi_trick,
n_mode_gate,
Expand Down
9 changes: 6 additions & 3 deletions tests/backend/test_states_probabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
import pytest

import numpy as np
import tensorflow as tf
from scipy.special import factorial as fac

from strawberryfields import backends
from strawberryfields import utils
try:
import tensorflow as tf
except ImportError:
import unittest.mock as mock

tf = mock.Mock()


MAG_ALPHAS = np.linspace(0, 0.8, 3)
Expand Down
49 changes: 12 additions & 37 deletions tests/backend/test_tf_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Unit tests for TensorFlow 2.x version checking"""
from importlib import reload
import sys
from unittest.mock import MagicMock

import pytest

import strawberryfields as sf


try:
import tensorflow
except (ImportError, ModuleNotFoundError):

except ImportError:
tf_available = False
import mock

tensorflow = mock.Mock(__version__="1.12.2")
else:
tf_available = True


class TestBackendImport:
"""Test importing the backend directly"""

def test_incorrect_tf_version(self, monkeypatch):
"""Test that an exception is raised if the version
of TensorFlow installed is not version 2.x"""
with monkeypatch.context() as m:
# force Python check to pass
m.setattr("sys.version_info", (3, 6, 3))
m.setattr(tensorflow, "__version__", "1.12.2")

with pytest.raises(ImportError, match="version 2.x of TensorFlow is required"):
reload(sf.backends.tfbackend)

@pytest.mark.skipif(tf_available, reason="Test only works if TF not installed")
def test_tensorflow_not_installed(self, monkeypatch):
"""Test that an exception is raised if TensorFlow is not installed"""
with monkeypatch.context() as m:
# force Python check to pass
m.setattr("sys.version_info", (3, 6, 3))

with pytest.raises(ImportError, match="version 2.x of TensorFlow is required"):
reload(sf.backends.tfbackend)


@pytest.mark.frontend
class TestFrontendImport:
"""Test importing via the frontend"""
Expand All @@ -64,21 +37,23 @@ def test_incorrect_tf_version(self, monkeypatch):
"""Test that an exception is raised if the version
of TensorFlow installed is not version 2.x"""
with monkeypatch.context() as m:
# force Python check to pass
m.setattr("sys.version_info", (3, 6, 3))
m.setattr(tensorflow, "__version__", "1.12.2")
# Force the Python check to pass.
m.setattr(sys, "version_info", (3, 6, 3))

# Unload the TF backend to ensure sf.LocalEngine() will run __init__.py.
m.delitem(sys.modules, "strawberryfields.backends.tfbackend", raising=False)
# Set the TF version in case the existing version is valid.
m.setitem(sys.modules, "tensorflow", MagicMock(__version__="1.2.3"))

with pytest.raises(ImportError, match="version 2.x of TensorFlow is required"):
reload(sf.backends.tfbackend)
sf.LocalEngine("tf")

@pytest.mark.skipif(tf_available, reason="Test only works if TF not installed")
def test_tensorflow_not_installed(self, monkeypatch):
"""Test that an exception is raised if TensorFlow is not installed"""
with monkeypatch.context() as m:
# force Python check to pass
m.setattr("sys.version_info", (3, 6, 3))
# Force the Python check to pass.
m.setattr(sys, "version_info", (3, 6, 3))

with pytest.raises(ImportError, match="version 2.x of TensorFlow is required"):
reload(sf.backends.tfbackend)
sf.LocalEngine("tf")
7 changes: 3 additions & 4 deletions tests/integration/test_engine_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,17 @@
import pytest

import numpy as np
import tensorflow as tf

import strawberryfields as sf
from strawberryfields import ops
from strawberryfields.backends import BaseGaussian, BaseFock
from strawberryfields.backends import GaussianBackend, FockBackend
from strawberryfields.backends import BaseFock, FockBackend, GaussianBackend
from strawberryfields.backends.states import BaseState


try:
from strawberryfields.backends.tfbackend import TFBackend
except (ImportError, ModuleNotFoundError, ValueError) as e:
import tensorflow as tf
except (ImportError, ValueError):
eng_backend_params = [("gaussian", GaussianBackend), ("fock", FockBackend)]
else:
eng_backend_params = [
Expand Down
31 changes: 17 additions & 14 deletions tests/integration/test_ops_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@
# limitations under the License.
r"""Integration tests for frontend operations applied to the backend"""
import pytest

import numpy as np
import tensorflow as tf
from thewalrus.random import random_symplectic
from scipy.stats import unitary_group

import strawberryfields as sf
from strawberryfields import ops

from strawberryfields.backends import BaseGaussian
from strawberryfields.backends.states import BaseFockState, BaseGaussianState

from thewalrus.quantum import is_valid_cov
from thewalrus.random import random_symplectic

from scipy.stats import unitary_group
try:
import tensorflow as tf
except:
backends = ["fock", "tf"]
else:
backends = ["fock"]

# make test deterministic
np.random.seed(42)
Expand Down Expand Up @@ -193,7 +195,7 @@ def test_passive_channel(self, M, setup_eng, tol):
class TestPreparationApplication:
"""Tests that involve state preparation application"""

@pytest.mark.backends("tf", "fock")
@pytest.mark.backends(*backends)
def test_ket_state_object(self, setup_eng, pure):
"""Test loading a ket from a prior state object"""
if not pure:
Expand All @@ -217,7 +219,7 @@ def test_ket_state_object(self, setup_eng, pure):
# verify it is the same state
assert state1 == state2

@pytest.mark.backends("tf", "fock")
@pytest.mark.backends(*backends)
def test_ket_gaussian_state_object(self, setup_eng):
"""Test exception if loading a ket from a Gaussian state object"""
eng = sf.Engine("gaussian")
Expand All @@ -231,7 +233,7 @@ def test_ket_gaussian_state_object(self, setup_eng):
with pytest.raises(ValueError, match="Gaussian states are not supported"):
ops.Ket(state) | q[0]

@pytest.mark.backends("tf", "fock")
@pytest.mark.backends(*backends)
def test_ket_mixed_state_object(self, setup_eng, pure):
"""Test exception if loading a ket from a prior mixed state object"""
if pure:
Expand All @@ -251,7 +253,7 @@ def test_ket_mixed_state_object(self, setup_eng, pure):
with pytest.raises(ValueError, match="Fock state is not pure"):
ops.Ket(state1) | q[0]

@pytest.mark.backends("tf", "fock")
@pytest.mark.backends(*backends)
def test_dm_state_object(self, setup_eng, tol):
"""Test loading a density matrix from a prior state object"""
eng, prog = setup_eng(1)
Expand All @@ -272,7 +274,7 @@ def test_dm_state_object(self, setup_eng, tol):
# verify it is the same state
assert np.allclose(state1.dm(), state2.dm(), atol=tol, rtol=0)

@pytest.mark.backends("tf", "fock")
@pytest.mark.backends(*backends)
def test_dm_gaussian_state_object(self, setup_eng):
"""Test exception if loading a ket from a Gaussian state object"""
eng = sf.Engine("gaussian")
Expand All @@ -287,7 +289,7 @@ def test_dm_gaussian_state_object(self, setup_eng):
ops.DensityMatrix(state) | q[0]


@pytest.mark.backends("fock", "tf")
@pytest.mark.backends(*backends)
class TestKetDensityMatrixIntegration:
"""Tests for the frontend Fock multi-mode state preparations"""

Expand Down Expand Up @@ -407,7 +409,8 @@ def test_dm_two_mode(self, setup_eng, hbar, cutoff, tol):
assert np.allclose(state1.dm(), state2.dm(), atol=tol, rtol=0)


@pytest.mark.backends("tf", "fock")
@pytest.mark.skipif("tf" not in backends, reason="Tests require TF")
@pytest.mark.backends(*backends)
class TestGaussianGateApplication:
def test_multimode_gaussian_gate(self, setup_backend, pure):
"""Test applying gaussian gate on multiple modes"""
Expand Down
7 changes: 3 additions & 4 deletions tests/integration/test_utils_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@

try:
import tensorflow as tf
except (ImportError, ModuleNotFoundError) as e:
import mock
except ImportError:
import unittest.mock as mock

tf = mock.MagicMock()
tf.Tensor = int
tf = mock.MagicMock(Tensor=int)

import strawberryfields as sf
import strawberryfields.ops as ops
Expand Down

0 comments on commit d4c11b9

Please sign in to comment.