diff --git a/hamilton/plugins/jupyter_magic.py b/hamilton/plugins/jupyter_magic.py index ed0fab93f..2e6f0ffc2 100644 --- a/hamilton/plugins/jupyter_magic.py +++ b/hamilton/plugins/jupyter_magic.py @@ -15,6 +15,7 @@ """ import json +import os import sys from pathlib import Path from types import ModuleType @@ -154,6 +155,12 @@ def rebuild_drivers(shell, module_name: str, module_object: ModuleType, verbosit return drivers_rebuilt +def determine_notebook_type() -> str: + if "DATABRICKS_RUNTIME_VERSION" in os.environ: + return "databricks" + return "default" + + @magics_class class HamiltonMagics(Magics): """Magics to facilitate Hamilton development in Jupyter notebooks""" @@ -189,6 +196,9 @@ def cell_to_module(self, line, cell): print(" -d, --display: Flag to visualize dataflow.") print(" -v, --verbosity: of standard output. 0 to hide. 1 is normal, default.") return # Exit early + if not hasattr(self, "notebook_env"): + # doing this so I don't have to deal with the constructor + self.notebook_env = determine_notebook_type() # shell.ex() is equivalent to exec(), but in the user namespace (i.e. notebook context). # This allows imports and functions defined in the magic cell %%cell_to_module to be # directly accessed from the notebook @@ -228,9 +238,17 @@ def cell_to_module(self, line, cell): dr = driver.Builder().with_modules(module_object).with_config(display_config).build() self.shell.push({f"{module_name}_dr": dr}) if args.display: + graphviz_obj = dr.display_all_functions() + if self.notebook_env == "databricks" and graphviz_obj: + try: + display(HTML(graphviz_obj.pipe(format="svg").decode("utf-8"))) + except Exception as e: + print(f"Failed to display graph: {e}") + print("Please ensure graphviz is installed via `%sh apt install -y graphviz`") + return # return will go to the output cell. To display multiple elements, use # IPython.display.display(print("hello"), dr.display_all_functions(), ...) - return dr.display_all_functions() + return graphviz_obj @line_magic def insert_module(self, line):