Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message in DAGCircuit.draw for invalid filenames #7447

Merged
merged 6 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions qiskit/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@ def __init__(
def __str__(self) -> str:
"""Return the message."""
return repr(self.message)


class InvalidFileError(QiskitError):
"""Raised when the file provided is not valid for the specific task."""
67 changes: 66 additions & 1 deletion qiskit/visualization/dag_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tempfile

from qiskit.dagcircuit.dagnode import DAGOpNode, DAGInNode, DAGOutNode
from qiskit.exceptions import MissingOptionalLibraryError
from qiskit.exceptions import MissingOptionalLibraryError, InvalidFileError
from .exceptions import VisualizationError

try:
Expand All @@ -31,6 +31,64 @@
except ImportError:
HAS_PIL = False

FILENAME_EXTENSIONS = {
"bmp",
"canon",
"cgimage",
"cmap",
"cmapx",
"cmapx_np",
"dot",
"dot_json",
"eps",
"exr",
"fig",
"gd",
"gd2",
"gif",
"gv",
"icns",
"ico",
"imap",
"imap_np",
"ismap",
"jp2",
"jpe",
"jpeg",
"jpg",
"json",
"json0",
"mp",
"pct",
"pdf",
"pic",
"pict",
"plain",
"plain-ext",
"png",
"pov",
"ps",
"ps2",
"psd",
"sgi",
"svg",
"svgz",
"tga",
"tif",
"tiff",
"tk",
"vdx",
"vml",
"vmlz",
"vrml",
"wbmp",
"webp",
"xdot",
"xdot1.2",
"xdot1.4",
"xdot_json",
}


def dag_drawer(dag, scale=0.7, filename=None, style="color"):
"""Plot the directed acyclic graph (dag) to represent operation dependencies
Expand Down Expand Up @@ -59,6 +117,7 @@ def dag_drawer(dag, scale=0.7, filename=None, style="color"):
Raises:
VisualizationError: when style is not recognized.
MissingOptionalLibraryError: when pydot or pillow are not installed.
InvalidFileError: when filename provided is not valid

Example:
.. jupyter-execute::
Expand Down Expand Up @@ -166,7 +225,13 @@ def edge_attr_func(edge):
dot = pydot.graph_from_dot_data(dot_str)[0]

if filename:
if "." not in filename:
raise InvalidFileError("Parameter 'filename' must be in format 'name.extension'")
extension = filename.split(".")[-1]
if extension not in FILENAME_EXTENSIONS:
jakelishman marked this conversation as resolved.
Show resolved Hide resolved
raise InvalidFileError(
"Filename extension must be one of: " + " ".join(FILENAME_EXTENSIONS)
)
dot.write(filename, format=extension)
return None
elif ("ipykernel" in sys.modules) and ("spyder" not in sys.modules):
Expand Down
13 changes: 13 additions & 0 deletions test/python/visualization/test_dag_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from qiskit import QuantumRegister, QuantumCircuit
from qiskit.test import QiskitTestCase
from qiskit.tools.visualization import dag_drawer
from qiskit.exceptions import InvalidFileError
from qiskit.visualization.exceptions import VisualizationError
from qiskit.converters import circuit_to_dag

Expand All @@ -36,6 +37,18 @@ def test_dag_drawer_invalid_style(self):
"""Test dag draw with invalid style."""
self.assertRaises(VisualizationError, dag_drawer, self.dag, style="multicolor")

def test_dag_drawer_checks_filename_correct_format(self):
"""filename must contain name and extension"""
with self.assertRaisesRegex(
InvalidFileError, "Parameter 'filename' must be in format 'name.extension'"
):
dag_drawer(self.dag, filename="aaabc")

def test_dag_drawer_checks_filename_extension(self):
"""filename must have a valid extension"""
with self.assertRaisesRegex(InvalidFileError, "Filename extension must be one of: .*"):
dag_drawer(self.dag, filename="aa.abc")


if __name__ == "__main__":
unittest.main(verbosity=2)