diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md
index 6d45dad19db..acee6e233e9 100644
--- a/.github/CHANGELOG.md
+++ b/.github/CHANGELOG.md
@@ -4,6 +4,9 @@
Improvements
+* ``qml.circuit_drawer.CircuitDrawer`` can accept a string for the ``charset`` keyword, instead of a ``CharSet`` object.
+ [(#1640)](https://github.com/PennyLaneAI/pennylane/pull/1640)
+
Breaking changes
Bug fixes
diff --git a/pennylane/circuit_drawer/circuit_drawer.py b/pennylane/circuit_drawer/circuit_drawer.py
index 3b4fb765c14..947beb6be0b 100644
--- a/pennylane/circuit_drawer/circuit_drawer.py
+++ b/pennylane/circuit_drawer/circuit_drawer.py
@@ -18,7 +18,7 @@
import pennylane as qml
from pennylane.wires import Wires
-from .charsets import UnicodeCharSet
+from .charsets import CHARSETS, CharSet
from .representation_resolver import RepresentationResolver
from .grid import Grid
@@ -46,7 +46,7 @@ class CircuitDrawer:
raw_operation_grid (list[list[~.Operation]]): The CircuitGraph's operations
raw_observable_grid (list[list[qml.operation.Observable]]): The CircuitGraph's observables
wires (Wires): all wires on the device for which the circuit is drawn
- charset (pennylane.circuit_drawer.CharSet, optional): The CharSet that shall be used for drawing.
+ charset (str, pennylane.circuit_drawer.CharSet, optional): The CharSet that shall be used for drawing.
show_all_wires (bool): If True, all wires, including empty wires, are printed.
"""
@@ -55,21 +55,33 @@ def __init__(
raw_operation_grid,
raw_observable_grid,
wires,
- charset=UnicodeCharSet,
+ charset=None,
show_all_wires=False,
):
self.operation_grid = Grid(raw_operation_grid)
self.observable_grid = Grid(raw_observable_grid)
self.wires = wires
self.active_wires = self.extract_active_wires(raw_operation_grid, raw_observable_grid)
- self.charset = charset
+
+ if charset is None:
+ self.charset = CHARSETS["unicode"]()
+ elif isinstance(charset, type) and issubclass(charset, CharSet):
+ self.charset = charset()
+ else:
+ if charset not in CHARSETS:
+ raise ValueError(
+ "Charset '{}' is not supported. Supported charsets: {}.".format(
+ charset, ", ".join(CHARSETS.keys())
+ )
+ )
+ self.charset = CHARSETS[charset]()
if show_all_wires:
# if the provided wires include empty wires, make sure they are included
# as active wires
self.active_wires = wires.all_wires([wires, self.active_wires])
- self.representation_resolver = RepresentationResolver(charset)
+ self.representation_resolver = RepresentationResolver(self.charset)
self.operation_representation_grid = Grid()
self.observable_representation_grid = Grid()
self.operation_decoration_indices = []
@@ -91,15 +103,15 @@ def __init__(
CircuitDrawer.pad_representation(
self.operation_representation_grid,
- charset.WIRE,
+ self.charset.WIRE,
"",
- 2 * charset.WIRE,
+ 2 * self.charset.WIRE,
self.operation_decoration_indices,
)
CircuitDrawer.pad_representation(
self.operation_representation_grid,
- charset.WIRE,
+ self.charset.WIRE,
"",
"",
set(range(self.operation_grid.num_layers)) - set(self.operation_decoration_indices),
@@ -108,14 +120,14 @@ def __init__(
CircuitDrawer.pad_representation(
self.observable_representation_grid,
" ",
- charset.MEASUREMENT + " ",
+ self.charset.MEASUREMENT + " ",
" ",
self.observable_decoration_indices,
)
CircuitDrawer.pad_representation(
self.observable_representation_grid,
- charset.WIRE,
+ self.charset.WIRE,
"",
"",
set(range(self.observable_grid.num_layers)) - set(self.observable_decoration_indices),
diff --git a/pennylane/circuit_graph.py b/pennylane/circuit_graph.py
index 11c5a37b729..8a03ceede86 100644
--- a/pennylane/circuit_graph.py
+++ b/pennylane/circuit_graph.py
@@ -24,7 +24,7 @@
import numpy as np
from pennylane.wires import Wires
-from .circuit_drawer import CHARSETS, CircuitDrawer
+from .circuit_drawer import CircuitDrawer
def _by_idx(x):
@@ -604,18 +604,11 @@ def draw(self, charset="unicode", wire_order=None, show_all_wires=False):
grid, obs = self.greedy_layers(wire_order=wire_order, show_all_wires=show_all_wires)
- if charset not in CHARSETS:
- raise ValueError(
- "Charset {} is not supported. Supported charsets: {}.".format(
- charset, ", ".join(CHARSETS.keys())
- )
- )
-
drawer = CircuitDrawer(
grid,
obs,
wires=wire_order or self.wires,
- charset=CHARSETS[charset],
+ charset=charset,
show_all_wires=show_all_wires,
)
diff --git a/tests/circuit_drawer/test_circuit_drawer.py b/tests/circuit_drawer/test_circuit_drawer.py
index 09c31272c2c..3ad74e01b4c 100644
--- a/tests/circuit_drawer/test_circuit_drawer.py
+++ b/tests/circuit_drawer/test_circuit_drawer.py
@@ -22,6 +22,7 @@
from pennylane.circuit_drawer import CircuitDrawer
from pennylane.circuit_drawer.circuit_drawer import _remove_duplicates
from pennylane.circuit_drawer.grid import Grid, _transpose
+from pennylane.circuit_drawer.charsets import CHARSETS, UnicodeCharSet, AsciiCharSet
from pennylane.wires import Wires
from pennylane.measure import state
@@ -65,6 +66,41 @@ def test_remove_duplicates(self, input, expected_output):
]
+class TestInitialization:
+ def test_charset_default(self):
+
+ drawer_None = CircuitDrawer(
+ dummy_raw_operation_grid, dummy_raw_observable_grid, Wires(range(6)), charset=None
+ )
+
+ assert isinstance(drawer_None.charset, UnicodeCharSet)
+
+ @pytest.mark.parametrize("charset", ("unicode", "ascii"))
+ def test_charset_string(self, charset):
+
+ drawer_str = CircuitDrawer(
+ dummy_raw_operation_grid, dummy_raw_observable_grid, Wires(range(6)), charset=charset
+ )
+
+ assert isinstance(drawer_str.charset, CHARSETS[charset])
+
+ @pytest.mark.parametrize("charset", (UnicodeCharSet, AsciiCharSet))
+ def test_charset_class(self, charset):
+
+ drawer_class = CircuitDrawer(
+ dummy_raw_operation_grid, dummy_raw_observable_grid, Wires(range(6)), charset=charset
+ )
+
+ assert isinstance(drawer_class.charset, charset)
+
+ def test_charset_error(self):
+
+ with pytest.raises(ValueError, match=r"Charset 'nope' is not supported."):
+ CircuitDrawer(
+ dummy_raw_operation_grid, dummy_raw_observable_grid, Wires(range(6)), charset="nope"
+ )
+
+
@pytest.fixture
def dummy_circuit_drawer():
"""A dummy CircuitDrawer instance."""