Skip to content

Commit

Permalink
Added support for output_dim to be a tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
kessler-frost committed Feb 6, 2021
1 parent 6f9e3e5 commit 67f5980
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pennylane/qnn/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,9 @@ def __init__(
self._signature_validation(qnode, weight_shapes)
self.qnode = to_tf(qnode, dtype=tf.keras.backend.floatx())

# Allows output_dim to be specified as an int, e.g., 5, or as a length-1 tuple, e.g., (5,)
self.output_dim = output_dim[0] if isinstance(output_dim, Iterable) else output_dim
# Allows output_dim to be specified as an int, e.g., 5, or as a tuple, e.g., (5, 2)
# However the final output_dim type will always be a tuple, e.g., 5 will become (5,)
self.output_dim = tuple(output_dim) if isinstance(output_dim, Iterable) else (output_dim,)

self.weight_specs = weight_specs if weight_specs is not None else {}

Expand Down

0 comments on commit 67f5980

Please sign in to comment.