Skip to content

Commit

Permalink
Make global setter for graphviz
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Nov 11, 2023
1 parent 61303d4 commit edcb415
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
26 changes: 25 additions & 1 deletion neuralogic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
_initial_seed = _seed
_rnd_generator = None
_max_memory_size = None
_graphviz_path = None

jvm_params = {
"classpath": os.path.join(os.path.abspath(os.path.dirname(__file__)), "jar", "NeuraLogic.jar"),
Expand Down Expand Up @@ -106,8 +107,31 @@ def is_initialized() -> bool:
return _is_initialized


def set_graphviz_path(path: Optional[str]):
"""
Set the default path to Graphviz
Parameters
----------
path : Optional[str]
The Graphviz path
"""
global _graphviz_path
_graphviz_path = path


def get_default_graphviz_path() -> Optional[str]:
"""
Get the default path to Graphviz
"""
return _graphviz_path


def initialize(
debug_mode: bool = False, debug_port: int = 12999, is_debug_server: bool = True, debug_suspend: bool = True
debug_mode: bool = False,
debug_port: int = 12999,
is_debug_server: bool = True,
debug_suspend: bool = True,
):
"""
Initialize the NeuraLogic backend. This function is called implicitly when needed and should be called
Expand Down
15 changes: 13 additions & 2 deletions neuralogic/utils/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@

import jpype

from neuralogic import get_default_graphviz_path
from neuralogic.core.settings import Settings, SettingsProxy


def get_graphviz_path(path: Optional[str] = None) -> str:
"""
Get the path to the Graphviz executable
"""
if path is not None:
return path
return get_default_graphviz_path()


def get_drawing_settings(
img_type: str = "png", value_detail: int = 0, graphviz_path: Optional[str] = None
) -> SettingsProxy:
Expand All @@ -19,8 +29,9 @@ def get_drawing_settings(
"""
settings = Settings().create_proxy()

if graphviz_path is not None:
settings.settings.graphvizPath = graphviz_path
graphviz = get_graphviz_path(graphviz_path)
if graphviz is not None:
settings.settings.graphvizPath = graphviz

settings.settings.drawing = False
settings.settings.storeNotShow = True
Expand Down

0 comments on commit edcb415

Please sign in to comment.