Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .circleci/test_migration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ make install-dev

# Migrate and make sure the diff is empty
pushd $SUSHI_DIR
sqlmesh --gateway $GATEWAY_NAME migrate
sqlmesh --gateway $GATEWAY_NAME diff prod
SQLMESH_DEBUG=1 sqlmesh --gateway $GATEWAY_NAME migrate
SQLMESH_DEBUG=1 sqlmesh --gateway $GATEWAY_NAME diff prod
popd
4 changes: 2 additions & 2 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@ def call_macro(


def _coerce(
expr: exp.Expression,
expr: t.Any,
typ: t.Any,
dialect: DialectType,
path: Path,
Expand All @@ -1300,7 +1300,7 @@ def _coerce(
"""Coerces the given expression to the specified type on a best-effort basis."""
base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'."
try:
if typ is None or typ is t.Any:
if typ is None or typ is t.Any or not isinstance(expr, exp.Expression):
return expr
base = t.get_origin(typ) or typ

Expand Down
12 changes: 12 additions & 0 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,15 @@ def _render(e: exp.Expression) -> str | int | float | bool:
{k: _render(v) for k, v in signal.items()} for name, signal in self.signals if not name
]

def render_signal_calls(self) -> t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]]:
return {
name: {
k: seq_get(self._create_renderer(v).render() or [], 0) for k, v in kwargs.items()
}
for name, kwargs in self.signals
if name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's an example where name is falsey here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

airflow signals

}

def render_merge_filter(
self,
*,
Expand Down Expand Up @@ -2359,6 +2368,9 @@ def _create_model(

statements.extend(audit.query for audit in audit_definitions.values())

for _, kwargs in model.signals:
statements.extend(kwargs.values())

python_env = python_env or {}

make_python_env(
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,15 +909,15 @@ def check_ready_intervals(self, intervals: Intervals) -> Intervals:
Note that this will handle gaps in the provided intervals. The returned intervals
may introduce new gaps.
"""
signals = self.is_model and self.model.signals
signals = self.is_model and self.model.render_signal_calls()

if not signals:
return intervals

python_env = self.model.python_env
env = prepare_env(python_env)

for signal_name, kwargs in signals:
for signal_name, kwargs in signals.items():
try:
intervals = _check_ready_intervals(
env[signal_name],
Expand Down
28 changes: 28 additions & 0 deletions tests/core/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pytest_mock.plugin import MockerFixture
from sqlglot import exp, to_column

from sqlmesh.core import constants as c
from sqlmesh.core.audit import StandaloneAudit
from sqlmesh.core.config import (
AutoCategorizationMode,
Expand All @@ -35,6 +36,7 @@
)
from sqlmesh.core.model.kind import TimeColumn, ModelKindName
from sqlmesh.core.node import IntervalUnit
from sqlmesh.core.signal import signal
from sqlmesh.core.snapshot import (
DeployabilityIndex,
QualifiedViewName,
Expand Down Expand Up @@ -2802,3 +2804,29 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot):
assert snapshot_b.intervals == [
(to_timestamp("2020-01-01"), to_timestamp("2020-01-06")),
]


def test_render_signal(make_snapshot):
@signal()
def check_types(batch, env: str, default: int = 0):
if env != "in_memory" or not default == 0:
raise
return True

sql_model = load_sql_based_model(
parse(
"""
MODEL (
name test_schema.test_model,
signals check_types(env := @gateway)
);
SELECT a FROM tbl;
"""
),
variables={
c.GATEWAY: "in_memory",
},
signal_definitions=signal.get_registry(),
)
snapshot_a = make_snapshot(sql_model)
assert snapshot_a.check_ready_intervals([(0, 1)]) == [(0, 1)]