Skip to content

Commit

Permalink
format tests and reword PR template (#627)
Browse files Browse the repository at this point in the history
* format [tests] - black

* impl [.github/workflows] - black tests on CI

* impl [makefile] - black tests on CI

* edit [.github/ PR_TEMPLATE] - code formatting

* edit [.github/PR_TEMPLATE] - fix wording
  • Loading branch information
sduquemesa committed Sep 23, 2021
1 parent a341f81 commit f7b9e70
Show file tree
Hide file tree
Showing 72 changed files with 1,197 additions and 1,016 deletions.
6 changes: 3 additions & 3 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ Please complete the following checklist when submitting a PR:

- [ ] Ensure that the test suite passes, by running `make test`.

- [ ] Ensure that code is properly formatted, by running `make format` or `black -l 100
strawberryfields`. You will need to have the Black code format installed: `pip install
black`.
- [ ] Ensure that code and tests are properly formatted, by running `make format` or `black -l 100
<filename>` on any relevant files. You will need to have the Black code format installed:
`pip install black`.

- [ ] Add a new entry to the `.github/CHANGELOG.md` file, summarizing the
change, and including a link back to the PR.
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ jobs:
- uses: actions/checkout@v2

- name: Run Black
run: black -l 100 strawberryfields/ --check
run: black -l 100 strawberryfields/ tests/ --check
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ clean-docs:

.PHONY : format
format:
black -l 100 strawberryfields
black -l 100 strawberryfields tests

test: test-frontend test-gaussian test-fock test-tf batch-test-tf test-apps test-api

Expand Down
34 changes: 24 additions & 10 deletions tests/api/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
test_host = "SomeHost"
test_token = "SomeToken"


class MockResponse:
"""A mock response with a JSON or binary body."""

Expand Down Expand Up @@ -98,10 +99,17 @@ def test_get_device_spec(self, prog, connection, monkeypatch):
monkeypatch.setattr(
requests,
"request",
mock_return(MockResponse(
200,
{"layout": "", "modes": 42, "compiler": [], "gate_parameters": {"param": [[0, 1]]}}
)),
mock_return(
MockResponse(
200,
{
"layout": "",
"modes": 42,
"compiler": [],
"gate_parameters": {"param": [[0, 1]]},
},
)
),
)

device_spec = connection.get_device_spec(target)
Expand All @@ -126,7 +134,9 @@ def test_create_job(self, prog, connection, monkeypatch):
id_, status = "123", JobStatus.QUEUED

monkeypatch.setattr(
requests, "request", mock_return(MockResponse(201, {"id": id_, "status": status})),
requests,
"request",
mock_return(MockResponse(201, {"id": id_, "status": status})),
)

job = connection.create_job("X8_01", prog, {"shots": 1})
Expand All @@ -153,7 +163,9 @@ def test_get_all_jobs(self, connection, monkeypatch):
for i in range(1, 10)
]
monkeypatch.setattr(
requests, "request", mock_return(MockResponse(200, {"data": jobs})),
requests,
"request",
mock_return(MockResponse(200, {"data": jobs})),
)

jobs = connection.get_all_jobs(after=datetime(2020, 1, 5))
Expand Down Expand Up @@ -233,7 +245,9 @@ def test_get_job_result(self, connection, result_dtype, monkeypatch):
np.save(buf, result_samples)
buf.seek(0)
monkeypatch.setattr(
requests, "request", mock_return(MockResponse(200, binary_body=buf.getvalue())),
requests,
"request",
mock_return(MockResponse(200, binary_body=buf.getvalue())),
)

result = connection.get_job_result("123")
Expand Down Expand Up @@ -285,7 +299,7 @@ def test_refresh_access_token(self, mocker, monkeypatch):
"""Test that the access token is created by passing the expected headers."""
path = "/auth/realms/platform/protocol/openid-connect/token"

data={
data = {
"grant_type": "refresh_token",
"refresh_token": test_token,
"client_id": "public",
Expand All @@ -297,8 +311,8 @@ def test_refresh_access_token(self, mocker, monkeypatch):
conn = Connection(token=test_token, host=test_host)
conn._refresh_access_token()
expected_headers = {
'Accept-Version': conn.api_version,
'User-Agent': conn.user_agent,
"Accept-Version": conn.api_version,
"User-Agent": conn.user_agent,
}
expected_url = f"https://{test_host}:443{path}"
spy.assert_called_once_with(expected_url, headers=expected_headers, data=data)
Expand Down
12 changes: 3 additions & 9 deletions tests/api/test_devicespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,14 @@

device_dict_tdm = {
"layout": mock_layout_tdm,
"modes": {
"concurrent": 2,
"spatial": 1,
"temporal": {
"max": 100
}
},
"modes": {"concurrent": 2, "spatial": 1, "temporal": {"max": 100}},
"compiler": ["TD2"],
"gate_parameters": {
"p0": [0.56],
"p1": [0, [0, 6.28]],
"p2": [0, [0, 3.14], 3.14],
"p3": [0, [0, 6.28]]
}
"p3": [0, [0, 6.28]],
},
}


Expand Down
3 changes: 2 additions & 1 deletion tests/api/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def test_incomplete_job_raises_on_result_access(self, connection):
job = Job("abc", status=JobStatus.QUEUED, connection=connection)

with pytest.raises(
AttributeError, match="The result is undefined for jobs that are not completed",
AttributeError,
match="The result is undefined for jobs that are not completed",
):
job.result

Expand Down
70 changes: 49 additions & 21 deletions tests/api/test_remote_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ def test_compilation(self, prog, monkeypatch, caplog):
engine = RemoteEngine("X8")
_, target, res_prog, _ = engine.run_async(prog, shots=10)

assert caplog.records[-1].message == "Compiling program for device X8_01 using compiler fock."
assert (
caplog.records[-1].message == "Compiling program for device X8_01 using compiler fock."
)
assert target == RemoteEngine.DEFAULT_TARGETS["X8"]

# check program is compiled to match the chip template
Expand Down Expand Up @@ -207,10 +209,14 @@ def test_default_compiler(self, prog, monkeypatch, caplog):
_, target, res_prog, _ = engine.run_async(prog, shots=10)

assert engine.device_spec.default_compiler == "Xunitary"
assert caplog.records[-1].message == "Compiling program for device X8_01 using compiler Xunitary."
assert (
caplog.records[-1].message
== "Compiling program for device X8_01 using compiler Xunitary."
)

class MockProgram:
"""A mock program for testing"""

def __init__(self):
self.run_options = {}

Expand All @@ -221,7 +227,9 @@ def test_compile_device_invalid_device_error(self, prog, monkeypatch, caplog):
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = []

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(
Connection, "create_job", lambda self, target, program, run_options: program
)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)
monkeypatch.setattr(Program, "compile", lambda *args, **kwargs: self.MockProgram())

Expand All @@ -230,9 +238,7 @@ def test_compile_device_invalid_device_error(self, prog, monkeypatch, caplog):
prog._compile_info = (X8_spec, "dummy_compiler")

engine = sf.RemoteEngine("X8")
with pytest.raises(
ValueError, match="Cannot use program compiled"
):
with pytest.raises(ValueError, match="Cannot use program compiled"):
program = engine.run_async(prog, shots=10)

def test_compile(self, prog, monkeypatch, caplog):
Expand All @@ -242,7 +248,9 @@ def test_compile(self, prog, monkeypatch, caplog):
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = []

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(
Connection, "create_job", lambda self, target, program, run_options: program
)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)
monkeypatch.setattr(Program, "compile", lambda *args, **kwargs: self.MockProgram())

Expand All @@ -253,7 +261,10 @@ def test_compile(self, prog, monkeypatch, caplog):
program = engine.run_async(prog, shots=10)

assert isinstance(program, self.MockProgram)
assert caplog.records[-1].message == "Compiling program for device X8_01 using compiler Xunitary."
assert (
caplog.records[-1].message
== "Compiling program for device X8_01 using compiler Xunitary."
)

def test_recompilation_run_async(self, prog, monkeypatch, caplog):
"""Test that recompilation happens when the recompile keyword argument
Expand All @@ -264,7 +275,9 @@ def test_recompilation_run_async(self, prog, monkeypatch, caplog):
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = compiler

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(
Connection, "create_job", lambda self, target, program, run_options: program
)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)

compile_options = {"compiler": compiler}
Expand All @@ -279,9 +292,11 @@ def test_recompilation_run_async(self, prog, monkeypatch, caplog):
program = engine.run_async(prog, shots=10, compile_options=compile_options, recompile=True)

# No recompilation, original Program
assert caplog.records[-1].message == (f"Recompiling program for device "
f"{device.target} using the specified compiler options: "
f"{compile_options}.")
assert caplog.records[-1].message == (
f"Recompiling program for device "
f"{device.target} using the specified compiler options: "
f"{compile_options}."
)

def test_recompilation_precompiled(self, prog, monkeypatch, caplog):
"""Test that recompilation happens when:
Expand All @@ -295,7 +310,9 @@ def test_recompilation_precompiled(self, prog, monkeypatch, caplog):
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = []

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(
Connection, "create_job", lambda self, target, program, run_options: program
)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)
monkeypatch.setattr(Program, "compile", lambda *args, **kwargs: self.MockProgram())

Expand All @@ -313,7 +330,10 @@ def test_recompilation_precompiled(self, prog, monkeypatch, caplog):
# Setting recompile in keyword arguments
program = engine.run_async(prog, shots=10, compile_options=compile_options, recompile=True)
assert isinstance(program, self.MockProgram)
assert caplog.records[-1].message == "Recompiling program for device X8_01 using compiler Xunitary."
assert (
caplog.records[-1].message
== "Recompiling program for device X8_01 using compiler Xunitary."
)

def test_recompilation_run(self, prog, monkeypatch, caplog):
"""Test that recompilation happens when the recompile keyword argument
Expand All @@ -324,7 +344,9 @@ def test_recompilation_run(self, prog, monkeypatch, caplog):
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = compiler

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: MockJob(program))
monkeypatch.setattr(
Connection, "create_job", lambda self, target, program, run_options: MockJob(program)
)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)

class MockJob:
Expand Down Expand Up @@ -353,9 +375,11 @@ def refresh(self):

# No recompilation, original Program
assert caplog.records[-1].message == ("The remote job 0 has been completed.")
assert caplog.records[-2].message == (f"Recompiling program for device "
f"{device.target} using the specified compiler options: "
f"{compile_options}.")
assert caplog.records[-2].message == (
f"Recompiling program for device "
f"{device.target} using the specified compiler options: "
f"{compile_options}."
)

def test_validation(self, prog, monkeypatch, caplog):
"""Test that validation happens (no recompilation) when the target
Expand All @@ -366,7 +390,9 @@ def test_validation(self, prog, monkeypatch, caplog):
test_device_dict = mock_device_dict.copy()
test_device_dict["compiler"] = compiler

monkeypatch.setattr(Connection, "create_job", lambda self, target, program, run_options: program)
monkeypatch.setattr(
Connection, "create_job", lambda self, target, program, run_options: program
)
monkeypatch.setattr(Connection, "_get_device_dict", lambda *args: test_device_dict)

engine = sf.RemoteEngine("X8_01")
Expand All @@ -379,5 +405,7 @@ def test_validation(self, prog, monkeypatch, caplog):
program = engine.run_async(prog, shots=10)

# No recompilation, original Program
assert caplog.records[-1].message == (f"Program previously compiled for {device.target} using {prog.compile_info[1]}. "
f"Validating program against the Xstrict compiler.")
assert caplog.records[-1].message == (
f"Program previously compiled for {device.target} using {prog.compile_info[1]}. "
f"Validating program against the Xstrict compiler."
)
4 changes: 1 addition & 3 deletions tests/api/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ class TestResult:
"""Tests for the ``Result`` class."""

def test_stateless_result_raises_on_state_access(self):
"""Tests that `result.state` raises an error for a stateless result.
"""
"""Tests that `result.state` raises an error for a stateless result."""
result = Result(np.array([[1, 2], [3, 4]]), is_stateful=False)

with pytest.raises(
Expand Down Expand Up @@ -79,4 +78,3 @@ def test_unkown_shape_print(self, stateful, capfd):
assert "shots" not in out
assert "timebins" not in out
assert f"contains state={stateful}" in out

10 changes: 5 additions & 5 deletions tests/apps/test_clique.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def test_is_output_clique(self, dim):
assert resized == list(range(dim))

def test_input_clique_then_output_clique(self, dim):
"""Test that if the input is already a clique, then the output is the same clique. """
"""Test that if the input is already a clique, then the output is the same clique."""
graph = nx.lollipop_graph(dim, dim)
subgraph = list(range(dim)) # this is a clique, the "candy" of the lollipop

Expand Down Expand Up @@ -484,7 +484,7 @@ def test_bad_node_select(self, dim):

@pytest.mark.parametrize("dim", range(2, 10))
class TestIsClique:
"""Tests for the function `strawberryfields.apps.clique.is_clique` """
"""Tests for the function `strawberryfields.apps.clique.is_clique`"""

def test_no_false_negatives(self, dim):
"""Tests that cliques are labelled as such"""
Expand Down Expand Up @@ -513,7 +513,7 @@ def test_correct_c_0(self, dim):
assert res | s == set(range(dim))

def test_c_0_comp_graph(self, dim):
""" Tests that the set :math:`c_0` for a node in a clique consists of all remaining nodes"""
"""Tests that the set :math:`c_0` for a node in a clique consists of all remaining nodes"""
A = nx.complete_graph(dim)
S = [dim - 1]
K = clique.c_0(S, A)
Expand All @@ -538,7 +538,7 @@ class TestC1:

def test_c_1_comp_graph(self, dim):
"""Tests that :math:`c_1` set is correctly generated for an almost-complete graph, where
edge (0, 1) is removed """
edge (0, 1) is removed"""
A = nx.complete_graph(dim)
A.remove_edge(0, 1)
S = list(range(1, dim))
Expand All @@ -547,7 +547,7 @@ def test_c_1_comp_graph(self, dim):
assert c1 == [(1, 0)]

def test_c_1_swap_to_clique(self, dim):
"""Tests that :math:`c_1` set gives a valid clique after swapping """
"""Tests that :math:`c_1` set gives a valid clique after swapping"""
A = nx.complete_graph(dim)
A.remove_edge(0, 1)
S = list(range(1, dim))
Expand Down
6 changes: 3 additions & 3 deletions tests/apps/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_invalid_adjacency(self, dim):

def test_invalid_n_samples(self, adj):
"""Test if function raises a ``ValueError`` when a number of samples less than one is
requested """
requested"""
with pytest.raises(ValueError, match="Number of samples must be at least one"):
sample.sample(A=adj, n_mean=1.0, n_samples=0)

Expand Down Expand Up @@ -127,7 +127,7 @@ class TestSampleIntegration:

def test_pnr_integration(self, adj, integration_sample_number):
"""Integration test to check if function returns samples of correct form, i.e., correct
number of samples, correct number of modes, all non-negative integers """
number of samples, correct number of modes, all non-negative integers"""
samples = np.array(
sample.sample(A=adj, n_mean=1.0, n_samples=integration_sample_number, threshold=False)
)
Expand All @@ -141,7 +141,7 @@ def test_pnr_integration(self, adj, integration_sample_number):

def test_threshold_integration(self, adj, integration_sample_number):
"""Integration test to check if function returns samples of correct form, i.e., correct
number of samples, correct number of modes, all integers of zeros and ones """
number of samples, correct number of modes, all integers of zeros and ones"""
samples = np.array(
sample.sample(A=adj, n_mean=1.0, n_samples=integration_sample_number, threshold=True)
)
Expand Down

0 comments on commit f7b9e70

Please sign in to comment.