diff --git a/rustworkx/visualization/__init__.py b/rustworkx/visualization/__init__.py index 615932f9b..f2d91a5ca 100644 --- a/rustworkx/visualization/__init__.py +++ b/rustworkx/visualization/__init__.py @@ -11,7 +11,9 @@ __all__ = [ "mpl_draw", "graphviz_draw", + "have_dot", + "is_format_supported", ] from .matplotlib import mpl_draw -from .graphviz import graphviz_draw +from .graphviz import graphviz_draw, have_dot, is_format_supported diff --git a/rustworkx/visualization/graphviz.py b/rustworkx/visualization/graphviz.py index 4d5198814..8205f58a7 100644 --- a/rustworkx/visualization/graphviz.py +++ b/rustworkx/visualization/graphviz.py @@ -10,53 +10,67 @@ import tempfile import io +__all__ = ["graphviz_draw", "have_dot", "is_format_supported"] + +METHODS = {"twopi", "neato", "circo", "fdp", "sfdp", "dot"} + +_NO_PILLOW_MSG = """ +Pillow is necessary to use graphviz_draw() it can be installed +with 'pip install pydot pillow. +""" + +_NO_DOT_MSG = """ +Graphviz could not be found or run. This function requires that +Graphviz is installed. If you need to install Graphviz you can +refer to: https://graphviz.org/download/#executable-packages for +instructions. +""" + try: - from PIL import Image + import PIL - HAS_PILLOW = True -except ImportError: - HAS_PILLOW = False + HAVE_PILLOW = True +except Exception: + HAVE_PILLOW = False -__all__ = ["graphviz_draw"] -METHODS = {"twopi", "neato", "circo", "fdp", "sfdp", "dot"} -IMAGE_TYPES = { - "canon", - "cmap", - "cmapx", - "cmapx_np", - "dia", - "dot", - "fig", - "gd", - "gd2", - "gif", - "hpgl", - "imap", - "imap_np", - "ismap", - "jpe", - "jpeg", - "jpg", - "mif", - "mp", - "pcl", - "pdf", - "pic", - "plain", - "plain-ext", - "png", - "ps", - "ps2", - "svg", - "svgz", - "vml", - "vmlz" "vrml", - "vtx", - "wbmp", - "xdor", - "xlib", -} +# Return True if `dot` is found and executes. +def have_dot(): + try: + subprocess.run( + ["dot", "-V"], + cwd=tempfile.gettempdir(), + check=True, + capture_output=True, + ) + except Exception: + return False + return True + + +def _capture_support_string(): + try: + subprocess.check_output( + ["dot", "-T", "bogus_format"], + cwd=tempfile.gettempdir(), + stderr=subprocess.STDOUT, + ) + except subprocess.CalledProcessError as exerr: + return exerr.output.decode() + + +# Return collection of image formats supported by dot, as +# a `set` of `str`. +def _supported_image_formats(): + error_string = _capture_support_string() + # 7 is a magic number based error message. + # The words following the first seven are the formats. + return set(error_string.split()[7:]) + + +def is_format_supported(image_format: str): + """Return true if `image_format` is supported by the installed graphviz.""" + return image_format in _supported_image_formats() def graphviz_draw( @@ -141,38 +155,28 @@ def node_attr(node): graphviz_draw(graph, node_attr_fn=node_attr, method='sfdp') """ - if not HAS_PILLOW: - raise ImportError( - "Pillow is necessary to use graphviz_draw() " - "it can be installed with 'pip install pydot pillow'" - ) - try: - subprocess.run( - ["dot", "-V"], - cwd=tempfile.gettempdir(), - check=True, - capture_output=True, - ) - except Exception: - raise RuntimeError( - "Graphviz could not be found or run. This function requires that " - "Graphviz is installed. If you need to install Graphviz you can " - "refer to: https://graphviz.org/download/#executable-packages for " - "instructions." - ) + _have_dot = have_dot() + if not (HAVE_PILLOW and _have_dot): + raise RuntimeError(_NO_DOT_MSG + _NO_PILLOW_MSG) + if not HAVE_PILLOW: + raise ImportError(_NO_PILLOW_MSG) + if not _have_dot: + raise RuntimeError(_NO_DOT_MSG) dot_str = graph.to_dot(node_attr_fn, edge_attr_fn, graph_attr) if image_type is None: output_format = "png" else: - if image_type not in IMAGE_TYPES: - raise ValueError( - "The specified value for the image_type argument, " - f"'{image_type}' is not a valid choice. It must be one of: " - f"{IMAGE_TYPES}" - ) output_format = image_type + supported_formats = _supported_image_formats() + if output_format not in supported_formats: + raise ValueError( + "The specified value for the image_type argument, " + f"'{output_format}' is not a valid choice. It must be one of: " + f"{supported_formats}" + ) + if method is None: prog = "dot" else: @@ -193,7 +197,7 @@ def node_attr(node): text=False, ) dot_bytes_image = io.BytesIO(dot_result.stdout) - image = Image.open(dot_bytes_image) + image = PIL.Image.open(dot_bytes_image) return image else: subprocess.run( diff --git a/tests/rustworkx_tests/visualization/test_graphviz.py b/tests/rustworkx_tests/visualization/test_graphviz.py index af2733141..1152eed62 100644 --- a/tests/rustworkx_tests/visualization/test_graphviz.py +++ b/tests/rustworkx_tests/visualization/test_graphviz.py @@ -11,25 +11,17 @@ # under the License. import os -import subprocess -import tempfile import unittest import rustworkx -from rustworkx.visualization import graphviz_draw +from rustworkx.visualization import graphviz_draw, have_dot, is_format_supported try: import PIL - subprocess.run( - ["dot", "-V"], - cwd=tempfile.gettempdir(), - check=True, - capture_output=True, - ) - HAS_PILLOW = True + HAVE_PILLOW = True except Exception: - HAS_PILLOW = False + HAVE_PILLOW = False SAVE_IMAGES = os.getenv("RETWORKX_TEST_PRESERVE_IMAGES", None) @@ -39,7 +31,9 @@ def _save_image(image, path): image.save(path) -@unittest.skipUnless(HAS_PILLOW, "pillow and graphviz are required for running these tests") +@unittest.skipUnless( + HAVE_PILLOW and have_dot(), "pillow and graphviz are required for running these tests" +) class TestGraphvizDraw(unittest.TestCase): def test_draw_no_args(self): graph = rustworkx.generators.star_graph(24) @@ -117,6 +111,9 @@ def test_draw_graph_attr(self): self.assertIsInstance(image, PIL.Image.Image) _save_image(image, "test_graphviz_draw_graph_attr.png") + @unittest.skipUnless( + is_format_supported("jpg"), "Installed graphviz does not support jpg image format." + ) def test_image_type(self): graph = rustworkx.directed_gnp_random_graph(50, 0.8) image = graphviz_draw(graph, image_type="jpg")