Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transform dispatcher to programs #4559

Merged
merged 17 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion pennylane/transforms/core/transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pennylane.typing import Result, ResultBatch
from pennylane.tape import QuantumTape

from .transform_dispatcher import TransformContainer, TransformError
from .transform_dispatcher import TransformContainer, TransformError, TransformDispatcher

PostProcessingFn = Callable[[ResultBatch], Result]
BatchPostProcessingFn = Callable[[ResultBatch], ResultBatch]
Expand Down Expand Up @@ -167,6 +167,61 @@ def insert_front(self, transform_container: TransformContainer):
)
self._transform_program.insert(0, transform_container)

def add_transform(self, transform: TransformDispatcher, *targs, **tkwargs):
"""Add a transform (dispatcher) to the end of the program.
rmoyard marked this conversation as resolved.
Show resolved Hide resolved

Args:
transform(TransformDispatcher): The transform program where the transform is added.
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
*targs: Any additional arguments that are passed to the transform.

Keyword Args:
**tkwargs: Any additional keyword arguments that are passed to the transform.

"""
if not isinstance(transform, TransformDispatcher):
raise TransformError("Only transform dispatcher can be added to the transform program.")

if transform.expand_transform:
self.push_back(TransformContainer(transform.expand_transform))
self.push_back(
TransformContainer(
transform.transform,
targs,
tkwargs,
transform.classical_cotransform,
transform.is_informative,
)
)

def insert_front_transform(self, transform: TransformDispatcher, *targs, **tkwargs):
"""Add a transform (dispatcher) to the beginning of the program.

Args:
transform(TransformDispatcher): The transform program where the transform is added.
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
*targs: Any additional arguments that are passed to the transform.

Keyword Args:
**tkwargs: Any additional keyword arguments that are passed to the transform.

"""
if transform.is_informative and not self.is_empty():
raise TransformError(
"Informative transforms can only be added at the end of the program."
timmysilv marked this conversation as resolved.
Show resolved Hide resolved
)

self.insert_front(
TransformContainer(
transform.transform,
targs,
tkwargs,
transform.classical_cotransform,
transform.is_informative,
)
)

if transform.expand_transform:
self.insert_front(TransformContainer(transform.expand_transform))

def pop_front(self):
"""Pop the transform container at the beginning of the program.

Expand Down
103 changes: 101 additions & 2 deletions tests/transforms/test_experimental/test_transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def first_valid_transform(
return [tape], lambda x: x


def expand_transform(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.QuantumTape], Callable):
"""A valid expand transform."""
return [tape], lambda x: x


def second_valid_transform(
tape: qml.tape.QuantumTape, index: int
) -> (Sequence[qml.tape.QuantumTape], Callable):
Expand Down Expand Up @@ -215,7 +220,7 @@ def test_empty_program(self):
):
program.get_last()

def test_basic_program(self):
def test_push_back(self):
"""Test to push back multiple transforms into a program and also the different methods of a program."""
transform_program = TransformProgram()

Expand Down Expand Up @@ -249,6 +254,55 @@ def test_basic_program(self):
):
transform_program.push_back(10.0)

def test_add_transform(self):
"""Test to add multiple transforms into a program and also the different methods of a program."""
transform_program = TransformProgram()

transform1 = transform(first_valid_transform)
transform_program.add_transform(transform1)

assert not transform_program.is_empty()
assert len(transform_program) == 1
assert isinstance(transform_program[0], TransformContainer)
assert transform_program[0].transform is first_valid_transform

transform2 = transform(second_valid_transform)
transform_program.add_transform(transform2)

assert not transform_program.is_empty()
assert len(transform_program) == 2
assert isinstance(transform_program[1], TransformContainer)
assert transform_program[1].transform is second_valid_transform

transform_program.add_transform(transform1)
transform_program.add_transform(transform2)

sub_program_transforms = transform_program[2:]
assert len(sub_program_transforms) == 2
assert sub_program_transforms[0].transform is first_valid_transform
assert sub_program_transforms[1].transform is second_valid_transform

with pytest.raises(
TransformError,
match="Only transform dispatcher can be added to the transform program.",
):
transform_program.add_transform(10.0)

def test_add_transform_with_expand(self):
"""Test to add a transform with expand into a program."""
transform_program = TransformProgram()

transform1 = transform(first_valid_transform, expand_transform=expand_transform)
transform_program.add_transform(transform1)

assert not transform_program.is_empty()
assert len(transform_program) == 2
assert isinstance(transform_program[0], TransformContainer)
assert transform_program[0].transform is expand_transform

assert isinstance(transform_program[1], TransformContainer)
assert transform_program[1].transform is first_valid_transform

def test_pop_front(self):
"""Test the pop front method of the transform program."""
transform_program = TransformProgram()
Expand All @@ -267,7 +321,7 @@ def test_pop_front(self):
assert transform_container is transform1

def test_insert_front(self):
"""Test to insert a transform at the beginning of a transform program."""
"""Test to insert a transform (container) at the beginning of a transform program."""
transform_program = TransformProgram()

transform1 = TransformContainer(transform=first_valid_transform)
Expand Down Expand Up @@ -296,6 +350,51 @@ def test_insert_front(self):
):
transform_program.insert_front(transform3)

def test_insert_transform(self):
"""Test to insert a transform (dispatcher) at the beginning of a transform program."""
transform_program = TransformProgram()

transform1 = transform(first_valid_transform)
transform_program.insert_front_transform(transform1)

assert not transform_program.is_empty()
assert len(transform_program) == 1
assert isinstance(transform_program[0], TransformContainer)
assert transform_program[0].transform is first_valid_transform

transform2 = transform(second_valid_transform)
transform_program.insert_front_transform(transform2)

assert not transform_program.is_empty()
assert len(transform_program) == 2
assert isinstance(transform_program[0], TransformContainer)
assert transform_program[0].transform is second_valid_transform
assert isinstance(transform_program[1], TransformContainer)
assert transform_program[1].transform is first_valid_transform

transform3 = transform(second_valid_transform, is_informative=True)

with pytest.raises(
TransformError,
match="Informative transforms can only be added at the end of the program.",
):
transform_program.insert_front_transform(transform3)

def test_insert_transform_with_expand(self):
"""Test to insert front a transform with expand into a program."""
transform_program = TransformProgram()

transform1 = transform(first_valid_transform, expand_transform=expand_transform)
transform_program.insert_front_transform(transform1)

assert not transform_program.is_empty()
assert len(transform_program) == 2
assert isinstance(transform_program[0], TransformContainer)
assert transform_program[0].transform is expand_transform

assert isinstance(transform_program[1], TransformContainer)
assert transform_program[1].transform is first_valid_transform

def test_valid_transforms(self):
"""Test that that it is only possible to create valid transforms."""
transform_program = TransformProgram()
Expand Down
Loading