Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get list of supported image formats from dot instead of hardcoding #830

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion rustworkx/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
__all__ = [
"mpl_draw",
"graphviz_draw",
"have_dot",
"is_format_supported",
]

from .matplotlib import mpl_draw
from .graphviz import graphviz_draw
from .graphviz import graphviz_draw, have_dot, is_format_supported
142 changes: 73 additions & 69 deletions rustworkx/visualization/graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,53 +10,67 @@
import tempfile
import io

__all__ = ["graphviz_draw", "have_dot", "is_format_supported"]

METHODS = {"twopi", "neato", "circo", "fdp", "sfdp", "dot"}

_NO_PILLOW_MSG = """
Pillow is necessary to use graphviz_draw() it can be installed
with 'pip install pydot pillow.
"""

_NO_DOT_MSG = """
Graphviz could not be found or run. This function requires that
Graphviz is installed. If you need to install Graphviz you can
refer to: https://graphviz.org/download/#executable-packages for
instructions.
"""

try:
from PIL import Image
import PIL

HAS_PILLOW = True
except ImportError:
HAS_PILLOW = False
HAVE_PILLOW = True
except Exception:
HAVE_PILLOW = False

__all__ = ["graphviz_draw"]

METHODS = {"twopi", "neato", "circo", "fdp", "sfdp", "dot"}
IMAGE_TYPES = {
"canon",
"cmap",
"cmapx",
"cmapx_np",
"dia",
"dot",
"fig",
"gd",
"gd2",
"gif",
"hpgl",
"imap",
"imap_np",
"ismap",
"jpe",
"jpeg",
"jpg",
"mif",
"mp",
"pcl",
"pdf",
"pic",
"plain",
"plain-ext",
"png",
"ps",
"ps2",
"svg",
"svgz",
"vml",
"vmlz" "vrml",
"vtx",
"wbmp",
"xdor",
"xlib",
}
# Return True if `dot` is found and executes.
def have_dot():
try:
subprocess.run(
["dot", "-V"],
cwd=tempfile.gettempdir(),
check=True,
capture_output=True,
)
except Exception:
return False
return True


def _capture_support_string():
try:
subprocess.check_output(
["dot", "-T", "bogus_format"],
cwd=tempfile.gettempdir(),
stderr=subprocess.STDOUT,
)
except subprocess.CalledProcessError as exerr:
return exerr.output.decode()


# Return collection of image formats supported by dot, as
# a `set` of `str`.
def _supported_image_formats():
error_string = _capture_support_string()
# 7 is a magic number based error message.
# The words following the first seven are the formats.
return set(error_string.split()[7:])


def is_format_supported(image_format: str):
"""Return true if `image_format` is supported by the installed graphviz."""
return image_format in _supported_image_formats()


def graphviz_draw(
Expand Down Expand Up @@ -141,38 +155,28 @@ def node_attr(node):
graphviz_draw(graph, node_attr_fn=node_attr, method='sfdp')

"""
if not HAS_PILLOW:
raise ImportError(
"Pillow is necessary to use graphviz_draw() "
"it can be installed with 'pip install pydot pillow'"
)
try:
subprocess.run(
["dot", "-V"],
cwd=tempfile.gettempdir(),
check=True,
capture_output=True,
)
except Exception:
raise RuntimeError(
"Graphviz could not be found or run. This function requires that "
"Graphviz is installed. If you need to install Graphviz you can "
"refer to: https://graphviz.org/download/#executable-packages for "
"instructions."
)
_have_dot = have_dot()
if not (HAVE_PILLOW and _have_dot):
raise RuntimeError(_NO_DOT_MSG + _NO_PILLOW_MSG)
if not HAVE_PILLOW:
raise ImportError(_NO_PILLOW_MSG)
if not _have_dot:
raise RuntimeError(_NO_DOT_MSG)

dot_str = graph.to_dot(node_attr_fn, edge_attr_fn, graph_attr)
if image_type is None:
output_format = "png"
else:
if image_type not in IMAGE_TYPES:
raise ValueError(
"The specified value for the image_type argument, "
f"'{image_type}' is not a valid choice. It must be one of: "
f"{IMAGE_TYPES}"
)
output_format = image_type

supported_formats = _supported_image_formats()
if output_format not in supported_formats:
raise ValueError(
"The specified value for the image_type argument, "
f"'{output_format}' is not a valid choice. It must be one of: "
f"{supported_formats}"
)

if method is None:
prog = "dot"
else:
Expand All @@ -193,7 +197,7 @@ def node_attr(node):
text=False,
)
dot_bytes_image = io.BytesIO(dot_result.stdout)
image = Image.open(dot_bytes_image)
image = PIL.Image.open(dot_bytes_image)
return image
else:
subprocess.run(
Expand Down
21 changes: 9 additions & 12 deletions tests/rustworkx_tests/visualization/test_graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,17 @@
# under the License.

import os
import subprocess
import tempfile
import unittest

import rustworkx
from rustworkx.visualization import graphviz_draw
from rustworkx.visualization import graphviz_draw, have_dot, is_format_supported

try:
import PIL

subprocess.run(
["dot", "-V"],
cwd=tempfile.gettempdir(),
check=True,
capture_output=True,
)
HAS_PILLOW = True
HAVE_PILLOW = True
except Exception:
HAS_PILLOW = False
HAVE_PILLOW = False

SAVE_IMAGES = os.getenv("RETWORKX_TEST_PRESERVE_IMAGES", None)

Expand All @@ -39,7 +31,9 @@ def _save_image(image, path):
image.save(path)


@unittest.skipUnless(HAS_PILLOW, "pillow and graphviz are required for running these tests")
@unittest.skipUnless(
HAVE_PILLOW and have_dot(), "pillow and graphviz are required for running these tests"
)
class TestGraphvizDraw(unittest.TestCase):
def test_draw_no_args(self):
graph = rustworkx.generators.star_graph(24)
Expand Down Expand Up @@ -117,6 +111,9 @@ def test_draw_graph_attr(self):
self.assertIsInstance(image, PIL.Image.Image)
_save_image(image, "test_graphviz_draw_graph_attr.png")

@unittest.skipUnless(
is_format_supported("jpg"), "Installed graphviz does not support jpg image format."
)
def test_image_type(self):
graph = rustworkx.directed_gnp_random_graph(50, 0.8)
image = graphviz_draw(graph, image_type="jpg")
Expand Down