Skip to content

Commit

Permalink
Add connection to device spec API (#429)
Browse files Browse the repository at this point in the history
* Implement a get_device method

* Return DeviceSpec fro get_device

* Split get_device method

* Add devicespec class

* Fix formatting

* write out docstrings

* add to documentation

* linting

* update changelog

* add PR link

* done

* blacking

* undo weird string concatenation

* Fix minor error

* Add tests

* Fix failing test

* adding tests

* add initialization test

* Add kwargs and return to docstring

* Update .github/CHANGELOG.md

Co-authored-by: antalszava <antalszava@gmail.com>

* suggested changes

* Update strawberryfields/api/devicespec.py

Co-authored-by: Theodor <theodor@xanadu.ai>

* suggested changes

Co-authored-by: Josh Izaac <josh146@gmail.com>
Co-authored-by: antalszava <antalszava@gmail.com>
  • Loading branch information
3 people committed Jul 8, 2020
1 parent f77e879 commit e980b02
Show file tree
Hide file tree
Showing 8 changed files with 336 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .github/CHANGELOG.md
Expand Up @@ -120,6 +120,11 @@
[(#378)](https://github.com/XanaduAI/strawberryfields/pull/378)
[(#381)](https://github.com/XanaduAI/strawberryfields/pull/381)

* Strawberry Fields can now access the Xanadu Cloud device specifications API.
The ``Connection`` class has a new method ``Connection.get_device``,
which returns a ``DeviceSpec`` class.
[(#429)](https://github.com/XanaduAI/strawberryfields/pull/429)

<h3>Breaking Changes</h3>

* Removes support for Python 3.5.
Expand Down
3 changes: 2 additions & 1 deletion strawberryfields/api/__init__.py
Expand Up @@ -21,7 +21,8 @@
"""

from .connection import Connection, RequestFailedError
from .devicespec import DeviceSpec
from .job import InvalidJobOperationError, Job, JobStatus
from .result import Result

__all__ = ["Connection", "Job", "Result"]
__all__ = ["Connection", "DeviceSpec", "Job", "Result"]
24 changes: 24 additions & 0 deletions strawberryfields/api/connection.py
Expand Up @@ -28,6 +28,7 @@

from .job import Job, JobStatus
from .result import Result
from .devicespec import DeviceSpec

# pylint: disable=bad-continuation,protected-access

Expand Down Expand Up @@ -135,6 +136,29 @@ def use_ssl(self) -> bool:
"""
return self._use_ssl

def get_device_spec(self, target: str) -> DeviceSpec:
"""Gets the device specifications for target.
Args:
target (str): the target device
Returns:
strawberryfields.api.DeviceSpec: the created device specification
"""
device_dict = self._get_device_dict(target)
return DeviceSpec(target=target, spec=device_dict, connection=self)

def _get_device_dict(self, target: str) -> dict:
"""Returns the device specifications as a dictionary"""
path = f"/devices/{target}/specifications"
response = requests.get(self._url(path), headers=self._headers)

if response.status_code == 200:
return response.json()
raise RequestFailedError(
"Failed to get device specifications: {}".format(self._format_error_message(response))
)

def create_job(self, target: str, program: Program, run_options: dict = None) -> Job:
"""Creates a job with the given circuit.
Expand Down
141 changes: 141 additions & 0 deletions strawberryfields/api/devicespec.py
@@ -0,0 +1,141 @@
# Copyright 2020 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains a class that represents the specifications of
a device available via the API.
"""
from collections.abc import Sequence

import strawberryfields as sf
import blackbird

from strawberryfields.circuitspecs import Ranges


class DeviceSpec:
"""The specifications for a specific hardware device.
Args:
target (str): name of the target hardware device
spec (dict): dictionary representing the raw device specification.
This dictionary should contain the following key-value pairs:
- layout (str): string containing the Blackbird circuit layout
- modes (int): number of modes supported by the target
- compiler (list): list of supported compilers
- gate_parameters (dict): parameters for the circuit gates
connection (strawberryfields.api.Connection): connection over which the
job is managed
"""

def __init__(self, target, spec, connection):
self._target = target
self._connection = connection
self._spec = spec

@property
def target(self):
"""str: The name of the target hardware device."""
return self._target

@property
def layout(self):
"""str: Returns a string containing the Blackbird circuit layout."""
return self._spec["layout"]

@property
def modes(self):
"""int: Number of modes supported by the device."""
return self._spec["modes"]

@property
def compiler(self):
"""list[str]: A list of strings corresponding to Strawberry Fields compilers supported
by the hardware device."""
return self._spec["compiler"]

@property
def gate_parameters(self):
"""dict[str, strawberryfields.circuitspecs.Ranges]: A dictionary of gate parameters
and allowed ranges.
The parameter names correspond to those present in the Blackbird circuit layout.
**Example**
>>> spec.gate_parameters
{'squeezing_amplitude_0': x=0, x=1, 'phase_0': x=0, 0≤x≤6.283185307179586}
"""
gate_parameters = dict()

for gate_name, param_ranges in self._spec["gate_parameters"].items():
# convert gate parameter allowed ranges to Range objects
range_list = [[i] if not isinstance(i, Sequence) else i for i in param_ranges]
gate_parameters[gate_name] = Ranges(*range_list)

return gate_parameters

def create_program(self, **parameters):
"""Create a Strawberry Fields program matching the low-level layout of the
device.
Gate arguments should be passed as keyword arguments, with names
correspond to those present in the Blackbird circuit layout. Parameters not
present will be assumed to have a value of 0.
**Example**
Device specifications can be retrieved from the API by using the
:class:`~.Connection` class:
>>> spec.create_program(squeezing_amplitude_0=0.43)
<strawberryfields.program.Program at 0x7fd37e27ff50>
Keyword Args:
Supported parameter values for the specific device
Returns:
strawberryfields.program.Program: program compiled to the device
"""
bb = blackbird.loads(self.layout)

# check that all provided parameters are valid
for p, v in parameters.items():
if p in self.gate_parameters and v not in self.gate_parameters[p]:
# parameter is present in the device specifications
# but the user has provided a disallowed value
raise ValueError(
f"{p} has invalid value {v}. Only {self.gate_parameters[p]} allowed."
)

if p not in self.gate_parameters:
raise ValueError(f"Parameter {p} not a valid parameter for this device")

# determine parameter value if not provided
extra_params = set(self.gate_parameters) - set(parameters)

for p in extra_params:
# Set parameter value as the first allowed
# value in the gate parameters dictionary.
parameters[p] = self.gate_parameters[p].ranges[0].x

# evaluate the blackbird template
bb = bb(**parameters)
prog = sf.io.to_program(bb)
return prog

def refresh(self):
"""Refreshes the device specifications"""
self._spec = self._connection._get_device_dict(self.target)
4 changes: 2 additions & 2 deletions strawberryfields/circuitspecs/__init__.py
Expand Up @@ -36,7 +36,7 @@
corresponding CircuitSpecs instance with the same short name, used to validate Programs to be
executed on that backend.
"""
from .circuit_specs import CircuitSpecs
from .circuit_specs import CircuitSpecs, Ranges
from .X8 import X8Specs, X8_01
from .X12 import X12Specs, X12_01, X12_02
from .xcov import Xcov
Expand Down Expand Up @@ -66,4 +66,4 @@
"""dict[str, ~strawberryfields.circuitspecs.CircuitSpecs]: Map from circuit
family short name to the corresponding class."""

__all__ = ["circuit_db", "CircuitSpecs"] + [i.__name__ for i in specs]
__all__ = ["circuit_db", "CircuitSpecs", "Ranges"] + [i.__name__ for i in specs]
9 changes: 9 additions & 0 deletions strawberryfields/circuitspecs/circuit_specs.py
Expand Up @@ -345,6 +345,9 @@ def __repr__(self):

return "{}≤{}≤{}".format(self.x, self.name, self.y)

def __eq__(self, other):
return self.x == other.x and self.y == other.y


class Ranges:
"""Lightweight class for representing a set of ranges of floats.
Expand Down Expand Up @@ -380,3 +383,9 @@ def __contains__(self, item):

def __repr__(self):
return ", ".join([str(i) for i in self.ranges])

def __eq__(self, other):
if len(self.ranges) != len(other.ranges):
return False

return all(i == j for i, j in zip(self.ranges, other.ranges))
36 changes: 36 additions & 0 deletions tests/api/test_connection.py
Expand Up @@ -24,6 +24,8 @@

from strawberryfields.api import Connection, JobStatus, RequestFailedError
from strawberryfields import configuration as conf
from strawberryfields.circuitspecs import Ranges

from .conftest import mock_return

# pylint: disable=no-self-use,unused-argument
Expand Down Expand Up @@ -83,6 +85,40 @@ def test_init(self):
# pylint: disable=protected-access
assert connection._url("/abc") == "https://host:123/abc"

def test_get_device_spec(self, prog, connection, monkeypatch):
"""Tests a successful device spec request."""
target = "abc"
layout = ""
modes = 42
compiler = []
gate_parameters = {"param": Ranges([0, 1], variable_name="param")}

monkeypatch.setattr(
requests,
"get",
mock_return(MockResponse(
200,
{"layout": "", "modes": 42, "compiler": [], "gate_parameters": {"param": [[0, 1]]}}
)),
)

device_spec = connection.get_device_spec(target)

assert device_spec.target == target
assert device_spec.layout == layout
assert device_spec.modes == modes
assert device_spec.compiler == compiler

spec_params = device_spec.gate_parameters
assert gate_parameters == spec_params

def test_get_device_spec_error(self, connection, monkeypatch):
"""Tests a failed device spec request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get device specifications"):
connection.get_device_spec("123")

def test_create_job(self, prog, connection, monkeypatch):
"""Tests a successful job creation flow."""
id_, status = "123", JobStatus.QUEUED
Expand Down

0 comments on commit e980b02

Please sign in to comment.