Skip to content

Commit

Permalink
Updates jupyter magic to display in databricks
Browse files Browse the repository at this point in the history
Databricks doesn’t automatically render. So we
have to render it ourselves. I decided on HTML
over Image for no good reason really.

This should also handle errors if graphviz isn’t installed
and the visualization extra isn’t installed.
  • Loading branch information
skrawcz committed Feb 23, 2024
1 parent 5e7b24f commit 8a00d17
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion hamilton/plugins/jupyter_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import json
import os
import sys
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -154,6 +155,12 @@ def rebuild_drivers(shell, module_name: str, module_object: ModuleType, verbosit
return drivers_rebuilt


def determine_notebook_type() -> str:
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
return "databricks"
return "default"


@magics_class
class HamiltonMagics(Magics):
"""Magics to facilitate Hamilton development in Jupyter notebooks"""
Expand Down Expand Up @@ -189,6 +196,9 @@ def cell_to_module(self, line, cell):
print(" -d, --display: Flag to visualize dataflow.")
print(" -v, --verbosity: of standard output. 0 to hide. 1 is normal, default.")
return # Exit early
if not hasattr(self, "notebook_env"):
# doing this so I don't have to deal with the constructor
self.notebook_env = determine_notebook_type()
# shell.ex() is equivalent to exec(), but in the user namespace (i.e. notebook context).
# This allows imports and functions defined in the magic cell %%cell_to_module to be
# directly accessed from the notebook
Expand Down Expand Up @@ -228,9 +238,17 @@ def cell_to_module(self, line, cell):
dr = driver.Builder().with_modules(module_object).with_config(display_config).build()
self.shell.push({f"{module_name}_dr": dr})
if args.display:
graphviz_obj = dr.display_all_functions()
if self.notebook_env == "databricks" and graphviz_obj:
try:
display(HTML(graphviz_obj.pipe(format="svg").decode("utf-8")))
except Exception as e:
print(f"Failed to display graph: {e}")
print("Please ensure graphviz is installed via `%sh apt install -y graphviz`")
return
# return will go to the output cell. To display multiple elements, use
# IPython.display.display(print("hello"), dr.display_all_functions(), ...)
return dr.display_all_functions()
return graphviz_obj

@line_magic
def insert_module(self, line):
Expand Down

0 comments on commit 8a00d17

Please sign in to comment.