Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating qnn.KerasLayer to work in tape mode #869

Merged
114 changes: 83 additions & 31 deletions pennylane/qnn/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import inspect
from collections.abc import Iterable
from typing import Optional
import pennylane as qml

try:
import tensorflow as tf
Expand Down Expand Up @@ -198,7 +199,30 @@ def __init__(
"https://www.tensorflow.org/install for detailed instructions."
)

self.sig = qnode.func.sig
self.weight_shapes = {
weight: (tuple(size) if isinstance(size, Iterable) else (size,) if size > 1 else ())
for weight, size in weight_shapes.items()
}

if qml.tape_mode_active():
self._signature_validation_tape_mode(qnode, weight_shapes)
self.qnode = qnode
self.qnode.to_tf(dtype=tf.float32)
else:
self._signature_validation(qnode, weight_shapes)
self.qnode = to_tf(qnode, dtype=tf.keras.backend.floatx())

# Allows output_dim to be specified as an int, e.g., 5, or as a length-1 tuple, e.g., (5,)
self.output_dim = output_dim[0] if isinstance(output_dim, Iterable) else output_dim

self.weight_specs = weight_specs if weight_specs is not None else {}

self.qnode_weights = {}

super().__init__(dynamic=True, **kwargs)

def _signature_validation_tape_mode(self, qnode, weight_shapes):
self.sig = inspect.signature(qnode.func).parameters

if self.input_arg not in self.sig:
raise TypeError(
Expand All @@ -213,39 +237,51 @@ def __init__(
"weight_shapes".format(self.input_arg)
)

if set(weight_shapes.keys()) | {self.input_arg} != set(self.sig.keys()):
raise ValueError("Must specify a shape for every non-input parameter in the QNode")
param_kinds = [p.kind for p in self.sig.values()]

if inspect.Parameter.VAR_POSITIONAL in param_kinds:
raise TypeError("Cannot have a variable number of positional arguments")

if inspect.Parameter.VAR_KEYWORD not in param_kinds:
if set(weight_shapes.keys()) | {self.input_arg} != set(self.sig.keys()):
raise ValueError("Must specify a shape for every non-input parameter in the QNode")

def _signature_validation(self, qnode, weight_shapes):
self.sig = qnode.func.sig

if self.input_arg not in self.sig:
raise TypeError(
"QNode must include an argument with name {} for inputting data".format(
self.input_arg
)
)

if self.input_arg in set(weight_shapes.keys()):
raise ValueError(
"{} argument should not have its dimension specified in "
"weight_shapes".format(self.input_arg)
)

if qnode.func.var_pos:
raise TypeError("Cannot have a variable number of positional arguments")

if qnode.func.var_keyword:
raise TypeError("Cannot have a variable number of keyword arguments")

self.qnode = to_tf(qnode, dtype=tf.keras.backend.floatx())
self.weight_shapes = {
weight: (tuple(size) if isinstance(size, Iterable) else (size,) if size > 1 else ())
for weight, size in weight_shapes.items()
}

# Allows output_dim to be specified as an int, e.g., 5, or as a length-1 tuple, e.g., (5,)
self.output_dim = output_dim[0] if isinstance(output_dim, Iterable) else output_dim
if set(weight_shapes.keys()) | {self.input_arg} != set(self.sig.keys()):
raise ValueError("Must specify a shape for every non-input parameter in the QNode")

defaults = {
name for name, sig in self.sig.items() if sig.par.default != inspect.Parameter.empty
}

self.input_is_default = self.input_arg in defaults

if defaults - {self.input_arg} != set():
raise TypeError(
"Only the argument {} is permitted to have a default".format(self.input_arg)
)

self.weight_specs = weight_specs if weight_specs is not None else {}

self.qnode_weights = {}

super().__init__(dynamic=True, **kwargs)

def build(self, input_shape):
"""Initializes the QNode weights.

Expand All @@ -270,24 +306,40 @@ def call(self, inputs):
outputs = []
for x in inputs: # iterate over batch

# The QNode can require some passed arguments to be positional and others to be keyword.
# The following loops through input arguments in order and uses functools.partial to
# bind the argument to the QNode.
qnode = self.qnode

for arg in self.sig:
if arg is not self.input_arg: # Non-input arguments must always be positional
w = self.qnode_weights[arg]
qnode = functools.partial(qnode, w)
else:
if self.input_is_default: # The input argument can be positional or keyword
qnode = functools.partial(qnode, **{self.input_arg: x})
if qml.tape_mode_active():
res = self._evaluate_qnode_tape_mode(x)
outputs.append(res)
else:
# The QNode can require some passed arguments to be positional and others to be
# keyword. The following loops through input arguments in order and uses
# functools.partial to bind the argument to the QNode.
qnode = self.qnode

for arg in self.sig:
if arg is not self.input_arg: # Non-input arguments must always be positional
w = self.qnode_weights[arg]
qnode = functools.partial(qnode, w)
else:
qnode = functools.partial(qnode, x)
outputs.append(qnode())
if self.input_is_default: # The input argument can be positional or keyword
qnode = functools.partial(qnode, **{self.input_arg: x})
else:
qnode = functools.partial(qnode, x)
outputs.append(qnode())
Comment on lines +317 to +331
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is all just indenting what we had before.


return tf.stack(outputs)

def _evaluate_qnode_tape_mode(self, x):
"""Evaluates a tape-mode QNode for a single input datapoint.

Args:
x (tensor): the datapoint

Returns:
tensor: output datapoint
"""
kwargs = {**{self.input_arg: x}, **{k: 1.0 * w for k, w in self.qnode_weights.items()}}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the 1.0 here, which converts a Variable into a Tensor, allowing us to do things like qml.Rot(*weight).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sneaky

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Devil's advocate: I wonder if there is any overhead associated in the Variable to Tensor conversion --- attempting to iterate over Variable objects is what causes the massive slowdown in qml.broadcast. Is it worth doing the conversion just to support *tf.Variable(...)?

Since this is a low-level TensorFlow implementation issue, and not a PL bug/issue, we could simply leave the weights as Variable objects, and expect users to know that the pattern qml.Rot(*tf.Variable(..)) is not supported by TF.

return self.qnode(**kwargs)

def compute_output_shape(self, input_shape):
"""Computes the output shape after passing data of shape ``input_shape`` through the
QNode.
Expand Down
22 changes: 14 additions & 8 deletions pennylane/tape/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
self.func = func
self.device = device
self.qtape = None
self.qfunc_output = None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything here is a direct copy of #865, so please defer to that for code review.


self._tape, self.interface, self.diff_method = self.get_tape(device, interface, diff_method)
self.diff_options = diff_options or {}
Expand Down Expand Up @@ -375,10 +376,12 @@ def construct(self, args, kwargs):
self.qtape = self._tape(caching=self._caching)

with self.qtape:
measurement_processes = self.func(*args, **kwargs)
self.qfunc_output = self.func(*args, **kwargs)

if not isinstance(measurement_processes, Sequence):
measurement_processes = (measurement_processes,)
if not isinstance(self.qfunc_output, Sequence):
measurement_processes = (self.qfunc_output,)
else:
measurement_processes = self.qfunc_output

if not all(isinstance(m, qml.tape.MeasurementProcess) for m in measurement_processes):
raise qml.QuantumFunctionError(
Expand Down Expand Up @@ -444,8 +447,14 @@ def __call__(self, *args, **kwargs):
# execute the tape
res = self.qtape.execute(device=self.device)

# HOTFIX: to maintain compatibility with core, we squeeze
# all outputs.
if self._caching:
self._cache_execute = self.qtape._cache_execute

if isinstance(self.qfunc_output, Sequence):
return res

# HOTFIX: Output is a single measurement function. To maintain compatibility
# with core, we squeeze all outputs.

# Get the namespace associated with the return type
res_type_namespace = res.__class__.__module__.split(".")[0]
Expand All @@ -455,9 +464,6 @@ def __call__(self, *args, **kwargs):
# 'squeeze' does not exist in the top-level of the namespace
return anp.squeeze(res)

if self._caching:
self._cache_execute = self.qtape._cache_execute

return __import__(res_type_namespace).squeeze(res)

def draw(self, charset="unicode"):
Expand Down
Loading