diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 6d45dad19db..acee6e233e9 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -4,6 +4,9 @@

Improvements

+* ``qml.circuit_drawer.CircuitDrawer`` can accept a string for the ``charset`` keyword, instead of a ``CharSet`` object. + [(#1640)](https://github.com/PennyLaneAI/pennylane/pull/1640) +

Breaking changes

Bug fixes

diff --git a/pennylane/circuit_drawer/circuit_drawer.py b/pennylane/circuit_drawer/circuit_drawer.py index 3b4fb765c14..947beb6be0b 100644 --- a/pennylane/circuit_drawer/circuit_drawer.py +++ b/pennylane/circuit_drawer/circuit_drawer.py @@ -18,7 +18,7 @@ import pennylane as qml from pennylane.wires import Wires -from .charsets import UnicodeCharSet +from .charsets import CHARSETS, CharSet from .representation_resolver import RepresentationResolver from .grid import Grid @@ -46,7 +46,7 @@ class CircuitDrawer: raw_operation_grid (list[list[~.Operation]]): The CircuitGraph's operations raw_observable_grid (list[list[qml.operation.Observable]]): The CircuitGraph's observables wires (Wires): all wires on the device for which the circuit is drawn - charset (pennylane.circuit_drawer.CharSet, optional): The CharSet that shall be used for drawing. + charset (str, pennylane.circuit_drawer.CharSet, optional): The CharSet that shall be used for drawing. show_all_wires (bool): If True, all wires, including empty wires, are printed. """ @@ -55,21 +55,33 @@ def __init__( raw_operation_grid, raw_observable_grid, wires, - charset=UnicodeCharSet, + charset=None, show_all_wires=False, ): self.operation_grid = Grid(raw_operation_grid) self.observable_grid = Grid(raw_observable_grid) self.wires = wires self.active_wires = self.extract_active_wires(raw_operation_grid, raw_observable_grid) - self.charset = charset + + if charset is None: + self.charset = CHARSETS["unicode"]() + elif isinstance(charset, type) and issubclass(charset, CharSet): + self.charset = charset() + else: + if charset not in CHARSETS: + raise ValueError( + "Charset '{}' is not supported. Supported charsets: {}.".format( + charset, ", ".join(CHARSETS.keys()) + ) + ) + self.charset = CHARSETS[charset]() if show_all_wires: # if the provided wires include empty wires, make sure they are included # as active wires self.active_wires = wires.all_wires([wires, self.active_wires]) - self.representation_resolver = RepresentationResolver(charset) + self.representation_resolver = RepresentationResolver(self.charset) self.operation_representation_grid = Grid() self.observable_representation_grid = Grid() self.operation_decoration_indices = [] @@ -91,15 +103,15 @@ def __init__( CircuitDrawer.pad_representation( self.operation_representation_grid, - charset.WIRE, + self.charset.WIRE, "", - 2 * charset.WIRE, + 2 * self.charset.WIRE, self.operation_decoration_indices, ) CircuitDrawer.pad_representation( self.operation_representation_grid, - charset.WIRE, + self.charset.WIRE, "", "", set(range(self.operation_grid.num_layers)) - set(self.operation_decoration_indices), @@ -108,14 +120,14 @@ def __init__( CircuitDrawer.pad_representation( self.observable_representation_grid, " ", - charset.MEASUREMENT + " ", + self.charset.MEASUREMENT + " ", " ", self.observable_decoration_indices, ) CircuitDrawer.pad_representation( self.observable_representation_grid, - charset.WIRE, + self.charset.WIRE, "", "", set(range(self.observable_grid.num_layers)) - set(self.observable_decoration_indices), diff --git a/pennylane/circuit_graph.py b/pennylane/circuit_graph.py index 11c5a37b729..8a03ceede86 100644 --- a/pennylane/circuit_graph.py +++ b/pennylane/circuit_graph.py @@ -24,7 +24,7 @@ import numpy as np from pennylane.wires import Wires -from .circuit_drawer import CHARSETS, CircuitDrawer +from .circuit_drawer import CircuitDrawer def _by_idx(x): @@ -604,18 +604,11 @@ def draw(self, charset="unicode", wire_order=None, show_all_wires=False): grid, obs = self.greedy_layers(wire_order=wire_order, show_all_wires=show_all_wires) - if charset not in CHARSETS: - raise ValueError( - "Charset {} is not supported. Supported charsets: {}.".format( - charset, ", ".join(CHARSETS.keys()) - ) - ) - drawer = CircuitDrawer( grid, obs, wires=wire_order or self.wires, - charset=CHARSETS[charset], + charset=charset, show_all_wires=show_all_wires, ) diff --git a/tests/circuit_drawer/test_circuit_drawer.py b/tests/circuit_drawer/test_circuit_drawer.py index 09c31272c2c..3ad74e01b4c 100644 --- a/tests/circuit_drawer/test_circuit_drawer.py +++ b/tests/circuit_drawer/test_circuit_drawer.py @@ -22,6 +22,7 @@ from pennylane.circuit_drawer import CircuitDrawer from pennylane.circuit_drawer.circuit_drawer import _remove_duplicates from pennylane.circuit_drawer.grid import Grid, _transpose +from pennylane.circuit_drawer.charsets import CHARSETS, UnicodeCharSet, AsciiCharSet from pennylane.wires import Wires from pennylane.measure import state @@ -65,6 +66,41 @@ def test_remove_duplicates(self, input, expected_output): ] +class TestInitialization: + def test_charset_default(self): + + drawer_None = CircuitDrawer( + dummy_raw_operation_grid, dummy_raw_observable_grid, Wires(range(6)), charset=None + ) + + assert isinstance(drawer_None.charset, UnicodeCharSet) + + @pytest.mark.parametrize("charset", ("unicode", "ascii")) + def test_charset_string(self, charset): + + drawer_str = CircuitDrawer( + dummy_raw_operation_grid, dummy_raw_observable_grid, Wires(range(6)), charset=charset + ) + + assert isinstance(drawer_str.charset, CHARSETS[charset]) + + @pytest.mark.parametrize("charset", (UnicodeCharSet, AsciiCharSet)) + def test_charset_class(self, charset): + + drawer_class = CircuitDrawer( + dummy_raw_operation_grid, dummy_raw_observable_grid, Wires(range(6)), charset=charset + ) + + assert isinstance(drawer_class.charset, charset) + + def test_charset_error(self): + + with pytest.raises(ValueError, match=r"Charset 'nope' is not supported."): + CircuitDrawer( + dummy_raw_operation_grid, dummy_raw_observable_grid, Wires(range(6)), charset="nope" + ) + + @pytest.fixture def dummy_circuit_drawer(): """A dummy CircuitDrawer instance."""