-
Notifications
You must be signed in to change notification settings - Fork 575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updating qnn.KerasLayer to work in tape mode #869
Changes from 5 commits
3866532
6a1ac95
e0ffcf6
1bb7f7b
2e21a86
a3c860f
e520b58
06df1b0
eedd3b9
92db470
b4bdaa0
ebb90e1
8357265
9667aa1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
import inspect | ||
from collections.abc import Iterable | ||
from typing import Optional | ||
import pennylane as qml | ||
|
||
try: | ||
import tensorflow as tf | ||
|
@@ -198,7 +199,30 @@ def __init__( | |
"https://www.tensorflow.org/install for detailed instructions." | ||
) | ||
|
||
self.sig = qnode.func.sig | ||
self.weight_shapes = { | ||
weight: (tuple(size) if isinstance(size, Iterable) else (size,) if size > 1 else ()) | ||
for weight, size in weight_shapes.items() | ||
} | ||
|
||
if qml.tape_mode_active(): | ||
self._signature_validation_tape_mode(qnode, weight_shapes) | ||
self.qnode = qnode | ||
self.qnode.to_tf(dtype=tf.float32) | ||
else: | ||
self._signature_validation(qnode, weight_shapes) | ||
self.qnode = to_tf(qnode, dtype=tf.keras.backend.floatx()) | ||
|
||
# Allows output_dim to be specified as an int, e.g., 5, or as a length-1 tuple, e.g., (5,) | ||
self.output_dim = output_dim[0] if isinstance(output_dim, Iterable) else output_dim | ||
|
||
self.weight_specs = weight_specs if weight_specs is not None else {} | ||
|
||
self.qnode_weights = {} | ||
|
||
super().__init__(dynamic=True, **kwargs) | ||
|
||
def _signature_validation_tape_mode(self, qnode, weight_shapes): | ||
self.sig = inspect.signature(qnode.func).parameters | ||
|
||
if self.input_arg not in self.sig: | ||
raise TypeError( | ||
|
@@ -213,39 +237,51 @@ def __init__( | |
"weight_shapes".format(self.input_arg) | ||
) | ||
|
||
if set(weight_shapes.keys()) | {self.input_arg} != set(self.sig.keys()): | ||
raise ValueError("Must specify a shape for every non-input parameter in the QNode") | ||
param_kinds = [p.kind for p in self.sig.values()] | ||
|
||
if inspect.Parameter.VAR_POSITIONAL in param_kinds: | ||
raise TypeError("Cannot have a variable number of positional arguments") | ||
|
||
if inspect.Parameter.VAR_KEYWORD not in param_kinds: | ||
if set(weight_shapes.keys()) | {self.input_arg} != set(self.sig.keys()): | ||
raise ValueError("Must specify a shape for every non-input parameter in the QNode") | ||
|
||
def _signature_validation(self, qnode, weight_shapes): | ||
self.sig = qnode.func.sig | ||
|
||
if self.input_arg not in self.sig: | ||
raise TypeError( | ||
"QNode must include an argument with name {} for inputting data".format( | ||
self.input_arg | ||
) | ||
) | ||
|
||
if self.input_arg in set(weight_shapes.keys()): | ||
raise ValueError( | ||
"{} argument should not have its dimension specified in " | ||
"weight_shapes".format(self.input_arg) | ||
) | ||
|
||
if qnode.func.var_pos: | ||
raise TypeError("Cannot have a variable number of positional arguments") | ||
|
||
if qnode.func.var_keyword: | ||
raise TypeError("Cannot have a variable number of keyword arguments") | ||
|
||
self.qnode = to_tf(qnode, dtype=tf.keras.backend.floatx()) | ||
self.weight_shapes = { | ||
weight: (tuple(size) if isinstance(size, Iterable) else (size,) if size > 1 else ()) | ||
for weight, size in weight_shapes.items() | ||
} | ||
|
||
# Allows output_dim to be specified as an int, e.g., 5, or as a length-1 tuple, e.g., (5,) | ||
self.output_dim = output_dim[0] if isinstance(output_dim, Iterable) else output_dim | ||
if set(weight_shapes.keys()) | {self.input_arg} != set(self.sig.keys()): | ||
raise ValueError("Must specify a shape for every non-input parameter in the QNode") | ||
|
||
defaults = { | ||
name for name, sig in self.sig.items() if sig.par.default != inspect.Parameter.empty | ||
} | ||
|
||
self.input_is_default = self.input_arg in defaults | ||
|
||
if defaults - {self.input_arg} != set(): | ||
raise TypeError( | ||
"Only the argument {} is permitted to have a default".format(self.input_arg) | ||
) | ||
|
||
self.weight_specs = weight_specs if weight_specs is not None else {} | ||
|
||
self.qnode_weights = {} | ||
|
||
super().__init__(dynamic=True, **kwargs) | ||
|
||
def build(self, input_shape): | ||
"""Initializes the QNode weights. | ||
|
||
|
@@ -270,24 +306,40 @@ def call(self, inputs): | |
outputs = [] | ||
for x in inputs: # iterate over batch | ||
|
||
# The QNode can require some passed arguments to be positional and others to be keyword. | ||
# The following loops through input arguments in order and uses functools.partial to | ||
# bind the argument to the QNode. | ||
qnode = self.qnode | ||
|
||
for arg in self.sig: | ||
if arg is not self.input_arg: # Non-input arguments must always be positional | ||
w = self.qnode_weights[arg] | ||
qnode = functools.partial(qnode, w) | ||
else: | ||
if self.input_is_default: # The input argument can be positional or keyword | ||
qnode = functools.partial(qnode, **{self.input_arg: x}) | ||
if qml.tape_mode_active(): | ||
res = self._evaluate_qnode_tape_mode(x) | ||
outputs.append(res) | ||
else: | ||
# The QNode can require some passed arguments to be positional and others to be | ||
# keyword. The following loops through input arguments in order and uses | ||
# functools.partial to bind the argument to the QNode. | ||
qnode = self.qnode | ||
|
||
for arg in self.sig: | ||
if arg is not self.input_arg: # Non-input arguments must always be positional | ||
w = self.qnode_weights[arg] | ||
qnode = functools.partial(qnode, w) | ||
else: | ||
qnode = functools.partial(qnode, x) | ||
outputs.append(qnode()) | ||
if self.input_is_default: # The input argument can be positional or keyword | ||
qnode = functools.partial(qnode, **{self.input_arg: x}) | ||
else: | ||
qnode = functools.partial(qnode, x) | ||
outputs.append(qnode()) | ||
|
||
return tf.stack(outputs) | ||
|
||
def _evaluate_qnode_tape_mode(self, x): | ||
"""Evaluates a tape-mode QNode for a single input datapoint. | ||
|
||
Args: | ||
x (tensor): the datapoint | ||
|
||
Returns: | ||
tensor: output datapoint | ||
""" | ||
kwargs = {**{self.input_arg: x}, **{k: 1.0 * w for k, w in self.qnode_weights.items()}} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note the 1.0 here, which converts a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sneaky There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Devil's advocate: I wonder if there is any overhead associated in the Since this is a low-level TensorFlow implementation issue, and not a PL bug/issue, we could simply leave the weights as |
||
return self.qnode(**kwargs) | ||
|
||
def compute_output_shape(self, input_shape): | ||
"""Computes the output shape after passing data of shape ``input_shape`` through the | ||
QNode. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -150,6 +150,7 @@ def __init__( | |
self.func = func | ||
self.device = device | ||
self.qtape = None | ||
self.qfunc_output = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Everything here is a direct copy of #865, so please defer to that for code review. |
||
|
||
self._tape, self.interface, self.diff_method = self.get_tape(device, interface, diff_method) | ||
self.diff_options = diff_options or {} | ||
|
@@ -375,10 +376,12 @@ def construct(self, args, kwargs): | |
self.qtape = self._tape(caching=self._caching) | ||
|
||
with self.qtape: | ||
measurement_processes = self.func(*args, **kwargs) | ||
self.qfunc_output = self.func(*args, **kwargs) | ||
|
||
if not isinstance(measurement_processes, Sequence): | ||
measurement_processes = (measurement_processes,) | ||
if not isinstance(self.qfunc_output, Sequence): | ||
measurement_processes = (self.qfunc_output,) | ||
else: | ||
measurement_processes = self.qfunc_output | ||
|
||
if not all(isinstance(m, qml.tape.MeasurementProcess) for m in measurement_processes): | ||
raise qml.QuantumFunctionError( | ||
|
@@ -444,8 +447,14 @@ def __call__(self, *args, **kwargs): | |
# execute the tape | ||
res = self.qtape.execute(device=self.device) | ||
|
||
# HOTFIX: to maintain compatibility with core, we squeeze | ||
# all outputs. | ||
if self._caching: | ||
self._cache_execute = self.qtape._cache_execute | ||
|
||
if isinstance(self.qfunc_output, Sequence): | ||
return res | ||
|
||
# HOTFIX: Output is a single measurement function. To maintain compatibility | ||
# with core, we squeeze all outputs. | ||
|
||
# Get the namespace associated with the return type | ||
res_type_namespace = res.__class__.__module__.split(".")[0] | ||
|
@@ -455,9 +464,6 @@ def __call__(self, *args, **kwargs): | |
# 'squeeze' does not exist in the top-level of the namespace | ||
return anp.squeeze(res) | ||
|
||
if self._caching: | ||
self._cache_execute = self.qtape._cache_execute | ||
|
||
return __import__(res_type_namespace).squeeze(res) | ||
|
||
def draw(self, charset="unicode"): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is all just indenting what we had before.