Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 77 additions & 40 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,21 @@ def loading_stop(self, id: uuid.UUID) -> None:
class TerminalConsole(Console):
"""A rich based implementation of the console"""

def __init__(self, console: t.Optional[RichConsole] = None) -> None:
def __init__(self, console: t.Optional[RichConsole] = None, **kwargs: t.Any) -> None:
self.console: RichConsole = console or srich.console
self.progress: t.Optional[Progress] = None
self.tasks: t.Dict[str, t.Tuple[TaskID, int]] = {}
self.loading_status: t.Dict[uuid.UUID, Status] = {}

def _print(self, value: t.Any, **kwargs: t.Any) -> None:
self.console.print(value)

def _prompt(self, message: str, **kwargs: t.Any) -> t.Any:
return Prompt.ask(message, console=self.console, **kwargs)

def _confirm(self, message: str, **kwargs: t.Any) -> bool:
return Confirm.ask(message, console=self.console, **kwargs)

def start_snapshot_progress(self, snapshot_name: str, total_batches: int) -> None:
"""Indicates that a new load progress has begun."""
if not self.progress:
Expand Down Expand Up @@ -178,9 +187,7 @@ def show_model_difference_summary(
detailed: Show the actual SQL differences if True.
"""
if not context_diff.has_differences:
self.console.print(
Tree(f"[bold]No differences when compared to `{context_diff.environment}`")
)
self._print(Tree(f"[bold]No differences when compared to `{context_diff.environment}`"))
return
tree = Tree(f"[bold]Summary of differences against `{context_diff.environment}`:")

Expand Down Expand Up @@ -217,7 +224,7 @@ def show_model_difference_summary(
tree.add(indirect)
if metadata.children:
tree.add(metadata)
self.console.print(tree)
self._print(tree)

def plan(self, plan: Plan, auto_apply: bool) -> None:
"""The main plan flow.
Expand Down Expand Up @@ -249,7 +256,7 @@ def _prompt_categorize(self, plan: Plan, auto_apply: bool) -> None:
self._show_categorized_snapshots(plan)

for snapshot in plan.uncategorized:
self.console.print(Syntax(plan.context_diff.text_diff(snapshot.name), "sql"))
self._print(Syntax(plan.context_diff.text_diff(snapshot.name), "sql"))
tree = Tree(f"[bold][direct]Directly Modified: {snapshot.name}")
indirect_tree = None

Expand All @@ -258,7 +265,7 @@ def _prompt_categorize(self, plan: Plan, auto_apply: bool) -> None:
indirect_tree = Tree(f"[indirect]Indirectly Modified Children:")
tree.add(indirect_tree)
indirect_tree.add(f"[indirect]{child}")
self.console.print(tree)
self._print(tree)
self._get_snapshot_change_category(snapshot, plan, auto_apply)

def _show_categorized_snapshots(self, plan: Plan) -> None:
Expand All @@ -276,8 +283,8 @@ def _show_categorized_snapshots(self, plan: Plan) -> None:
indirect_tree = Tree(f"[indirect]Indirectly Modified Children:")
tree.add(indirect_tree)
indirect_tree.add(f"[indirect]{child}")
self.console.print(syntax_dff)
self.console.print(tree)
self._print(syntax_dff)
self._print(tree)

def _show_missing_dates(self, plan: Plan) -> None:
"""Displays the models with missing dates"""
Expand All @@ -286,7 +293,7 @@ def _show_missing_dates(self, plan: Plan) -> None:
backfill = Tree("[bold]Models needing backfill (missing dates):")
for missing in plan.missing_intervals:
backfill.add(f"{missing.snapshot_name}: {missing.format_missing_range()}")
self.console.print(backfill)
self._print(backfill)

def _prompt_backfill(self, plan: Plan, auto_apply: bool) -> None:
is_forward_only_dev = plan.is_dev and plan.forward_only
Expand All @@ -298,28 +305,25 @@ def _prompt_backfill(self, plan: Plan, auto_apply: bool) -> None:
if is_forward_only_dev
else "for the beginning of history"
)
start = Prompt.ask(
start = self._prompt(
f"Enter the {backfill_or_preview} start date (eg. '1 year', '2020-01-01') or blank {blank_meaning}",
console=self.console,
)
if start:
plan.start = start

if plan.is_end_allowed and not plan.override_end:
end = Prompt.ask(
end = self._prompt(
f"Enter the {backfill_or_preview} end date (eg. '1 month ago', '2020-01-01') or blank to {backfill_or_preview} up until now",
console=self.console,
)
if end:
plan.end = end

if not auto_apply and Confirm.ask(f"Apply - {backfill_or_preview.capitalize()} Tables"):
if not auto_apply and self._confirm(f"Apply - {backfill_or_preview.capitalize()} Tables"):
plan.apply()

def _prompt_promote(self, plan: Plan) -> None:
if Confirm.ask(
if self._confirm(
f"Apply - Logical Update",
console=self.console,
):
plan.apply()

Expand All @@ -328,36 +332,36 @@ def log_test_results(
) -> None:
divider_length = 70
if result.wasSuccessful():
self.console.print("=" * divider_length)
self.console.print(
self._print("=" * divider_length)
self._print(
f"Successfully Ran {str(result.testsRun)} tests against {target_dialect}",
style="green",
)
self.console.print("-" * divider_length)
self._print("-" * divider_length)
else:
self.console.print("-" * divider_length)
self.console.print("Test Failure Summary")
self.console.print("=" * divider_length)
self.console.print(
self._print("-" * divider_length)
self._print("Test Failure Summary")
self._print("=" * divider_length)
self._print(
f"Num Successful Tests: {result.testsRun - len(result.failures) - len(result.errors)}"
)
for test, _ in result.failures + result.errors:
if isinstance(test, ModelTest):
self.console.print(f"Failure Test: {test.model_name} {test.test_name}")
self.console.print("=" * divider_length)
self.console.print(output)
self._print(f"Failure Test: {test.model_name} {test.test_name}")
self._print("=" * divider_length)
self._print(output)

def show_sql(self, sql: str) -> None:
self.console.print(Syntax(sql, "sql"))
self._print(Syntax(sql, "sql"))

def log_status_update(self, message: str) -> None:
self.console.print(message)
self._print(message)

def log_error(self, message: str) -> None:
self.console.print(f"[red]{message}[/red]")
self._print(f"[red]{message}[/red]")

def log_success(self, message: str) -> None:
self.console.print(f"\n[green]{message}[/green]\n")
self._print(f"\n[green]{message}[/green]\n")

def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
id = uuid.uuid4()
Expand All @@ -373,9 +377,8 @@ def _get_snapshot_change_category(
self, snapshot: Snapshot, plan: Plan, auto_apply: bool
) -> None:
choices = self._snapshot_change_choices(snapshot)
response = Prompt.ask(
response = self._prompt(
"\n".join([f"[{i+1}] {choice}" for i, choice in enumerate(choices.values())]),
console=self.console,
show_choices=False,
choices=[f"{i+1}" for i in range(len(choices))],
)
Expand Down Expand Up @@ -432,10 +435,12 @@ class NotebookMagicConsole(TerminalConsole):
or capturing it and converting it into a widget.
"""

def __init__(self, display: t.Callable, console: t.Optional[RichConsole] = None) -> None:
def __init__(
self, display: t.Callable, console: t.Optional[RichConsole] = None, **kwargs: t.Any
) -> None:
import ipywidgets as widgets

super().__init__(console)
super().__init__(console, **kwargs)
self.display = display
self.missing_dates_output = widgets.Output()
self.dynamic_options_after_categorization_output = widgets.VBox()
Expand Down Expand Up @@ -637,10 +642,42 @@ def log_test_results(
self.display(widgets.VBox(children=[test_info, error_output], layout={"width": "100%"}))


def get_console() -> TerminalConsole:
class DatabricksMagicConsole(TerminalConsole):
"""
Note: Databricks Magic Console currently does not support progress bars while a plan is being applied. The
NotebookMagicConsole does support progress bars, but they will time out after 5 minutes of execution
and it makes it difficult to see the progress of the plan.
"""
Currently we only return TerminalConsole since the MagicConsole is only referenced in the magics and
called directly. Seems reasonable we will want dynamic consoles in the future based on runtime environment
so going to leave this for now.

def _print(self, value: t.Any, **kwargs: t.Any) -> None:
with self.console.capture() as capture:
self.console.print(value, **kwargs)
output = capture.get()
print(output)

def _prompt(self, message: str, **kwargs: t.Any) -> t.Any:
self._print(message, **kwargs)
return super()._prompt("", **kwargs)

def _confirm(self, message: str, **kwargs: t.Any) -> bool:
message = f"{message} \[y/n]"
self._print(message, **kwargs)
return super()._confirm("", **kwargs)


def get_console(**kwargs: t.Any) -> TerminalConsole | DatabricksMagicConsole | NotebookMagicConsole:
"""
Returns the console that is appropriate for the current runtime environment.

Note: Google Colab environment is untested and currently assumes is compatible with the base
NotebookMagicConsole.
"""
return TerminalConsole()
from sqlmesh import RuntimeEnv, runtime_env

runtime_env_mapping = {
RuntimeEnv.DATABRICKS: DatabricksMagicConsole,
RuntimeEnv.JUPYTER: NotebookMagicConsole,
RuntimeEnv.TERMINAL: TerminalConsole,
RuntimeEnv.GOOGLE_COLAB: NotebookMagicConsole,
}
return runtime_env_mapping[runtime_env](**kwargs)
4 changes: 2 additions & 2 deletions sqlmesh/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from IPython.core.magic import Magics, line_cell_magic, line_magic, magics_class
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring

from sqlmesh.core.console import NotebookMagicConsole
from sqlmesh.core.console import get_console
from sqlmesh.core.context import Context
from sqlmesh.core.dialect import format_model_expressions, parse_model
from sqlmesh.core.model import load_model
Expand Down Expand Up @@ -226,7 +226,7 @@ def plan(self, line: str) -> None:

# Since the magics share a context we want to clear out any state before generating a new plan
console = self._context.console
self._context.console = NotebookMagicConsole(self.display)
self._context.console = get_console(display=self.display)

self._context.plan(
args.environment,
Expand Down