Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions sqlmesh/core/audit/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ class StandaloneAudit(_Node, AuditMixin):
"""
Args:
depends_on: A list of tables this audit depends on.
hash_raw_query: Whether to hash the raw query or the rendered query.
python_env: Dictionary containing all global variables needed to render the audit's macros.
"""

Expand All @@ -281,7 +280,6 @@ class StandaloneAudit(_Node, AuditMixin):
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
default_catalog: t.Optional[str] = None
depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on")
hash_raw_query: bool = False
python_env_: t.Optional[t.Dict[str, Executable]] = Field(default=None, alias="python_env")

source_type: Literal["audit"] = "audit"
Expand Down Expand Up @@ -353,7 +351,7 @@ def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str:
self.stamp,
]

query = self.query if self.hash_raw_query else self.render_query(self) or self.query
query = self.render_query(self) or self.query
data.append(query.sql(comments=False))

return hash_data(data)
Expand Down Expand Up @@ -602,6 +600,5 @@ def _maybe_parse_arg_pair(e: exp.Expression) -> t.Tuple[str, exp.Expression]:
"standalone": exp.convert,
"depends_on_": lambda value: exp.Tuple(expressions=sorted(value)),
"tags": _single_value_or_tuple,
"hash_raw_query": exp.convert,
"default_catalog": exp.to_identifier,
}
42 changes: 20 additions & 22 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,9 @@ def _create_renderer(expression: exp.Expression) -> ExpressionRenderer:
)

def _render(e: exp.Expression) -> str | int | float | bool:
rendered_exprs = _create_renderer(e).render(
start=start, end=end, execution_time=execution_time
rendered_exprs = (
_create_renderer(e).render(start=start, end=end, execution_time=execution_time)
or []
)
if len(rendered_exprs) != 1:
raise SQLMeshError(f"Expected one expression but got {len(rendered_exprs)}")
Expand Down Expand Up @@ -770,9 +771,7 @@ def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str:
elif audit_name in audits:
audit = audits[audit_name]
query = (
audit.query
if self.hash_raw_query
else audit.render_query(self, **t.cast(t.Dict[str, t.Any], audit_args))
audit.render_query(self, **t.cast(t.Dict[str, t.Any], audit_args))
or audit.query
)
metadata.extend(
Expand Down Expand Up @@ -884,7 +883,7 @@ def _render_statements(
for statement in statements
if not isinstance(statement, d.MacroDef)
)
return [r for expressions in rendered for r in expressions]
return [r for expressions in rendered if expressions for r in expressions]

def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer:
expression_key = id(expression)
Expand All @@ -903,21 +902,21 @@ def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer:

@property
def _data_hash_values(self) -> t.List[str]:
statements = (
self._additional_metadata
if self.hash_raw_query
else [gen(e) for e in (*self.render_pre_statements(), *self.render_post_statements())]
)
return [
*super()._data_hash_values,
*statements,
]
data_hash_values = super()._data_hash_values

@property
def _additional_metadata(self) -> t.List[str]:
return [
gen(s) for s in (*self.pre_statements, *self.post_statements, *self.macro_definitions)
]
for statement in (*self.pre_statements, *self.post_statements):
statement_exprs: t.List[exp.Expression] = []
if isinstance(statement, d.MacroDef):
statement_exprs = [statement]
else:
rendered = self._statement_renderer(statement).render()
if rendered is not None:
statement_exprs = rendered
else:
statement_exprs = [statement]
data_hash_values.extend(gen(e) for e in statement_exprs)

return data_hash_values


class SqlModel(_SqlBasedModel):
Expand Down Expand Up @@ -1115,7 +1114,7 @@ def _query_renderer(self) -> QueryRenderer:
def _data_hash_values(self) -> t.List[str]:
data = super()._data_hash_values

query = self.query if self.hash_raw_query else self.render_query() or self.query
query = self.render_query() or self.query
data.append(gen(query))
data.extend(self.jinja_macros.data_hash_values)
return data
Expand Down Expand Up @@ -1935,7 +1934,6 @@ def _refs_to_sql(values: t.Any) -> exp.Expression:
"tags": _single_value_or_tuple,
"grains": _refs_to_sql,
"references": _refs_to_sql,
"hash_raw_query": exp.convert,
"table_properties_": lambda value: value,
"session_properties_": lambda value: value,
"allow_partials": exp.convert,
Expand Down
1 change: 0 additions & 1 deletion sqlmesh/core/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class ModelMeta(_Node):
audits: t.List[AuditReference] = []
grains: t.List[exp.Expression] = []
references: t.List[exp.Expression] = []
hash_raw_query: bool = False
physical_schema_override: t.Optional[str] = None
table_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="table_properties")
session_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="session_properties")
Expand Down
21 changes: 12 additions & 9 deletions sqlmesh/core/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,15 +306,18 @@ def render(
deployability_index: t.Optional[DeployabilityIndex] = None,
expand: t.Iterable[str] = tuple(),
**kwargs: t.Any,
) -> t.List[exp.Expression]:
expressions = super()._render(
start=start,
end=end,
execution_time=execution_time,
snapshots=snapshots,
deployability_index=deployability_index,
**kwargs,
)
) -> t.Optional[t.List[exp.Expression]]:
try:
expressions = super()._render(
start=start,
end=end,
execution_time=execution_time,
snapshots=snapshots,
deployability_index=deployability_index,
**kwargs,
)
except ParsetimeAdapterCallError:
return None

return [
self._resolve_tables(
Expand Down
1 change: 0 additions & 1 deletion sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]:
}.union({source.canonical_name(context) for source in model_context.sources.values()}),
"jinja_macros": jinja_macros,
"path": self.path,
"hash_raw_query": True,
"pre_statements": [d.jinja_statement(hook.sql) for hook in self.pre_hook],
"post_statements": [d.jinja_statement(hook.sql) for hook in self.post_hook],
"tags": self.tags,
Expand Down
1 change: 0 additions & 1 deletion sqlmesh/dbt/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def to_sqlmesh(self, context: DbtContext) -> Audit:
{source.canonical_name(context) for source in test_context.sources.values()}
),
tags=self.tags,
hash_raw_query=True,
default_catalog=context.target.database,
**self.sqlmesh_config_kwargs,
)
Expand Down
57 changes: 57 additions & 0 deletions sqlmesh/migrations/v0041_remove_hash_raw_query_attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Remove hash_raw_query from existing snapshots."""

import json

import pandas as pd
from sqlglot import exp

from sqlmesh.utils.migration import index_text_type


def migrate(state_sync, **kwargs):  # type: ignore
    """Strip the removed ``hash_raw_query`` field from every persisted snapshot.

    Loads all rows from the snapshots state table, drops the attribute from each
    serialized node payload, and rewrites the table in full with the cleaned rows.
    """
    engine_adapter = state_sync.engine_adapter
    snapshots_table = f"{state_sync.schema}._snapshots" if state_sync.schema else "_snapshots"

    select_expr = exp.select(
        "name", "identifier", "version", "snapshot", "kind_name", "expiration_ts"
    ).from_(snapshots_table)

    migrated_rows = []
    for name, identifier, version, raw_snapshot, kind_name, expiration_ts in engine_adapter.fetchall(
        select_expr, quote_identifiers=True
    ):
        payload = json.loads(raw_snapshot)
        # The attribute no longer exists on the model, so stale copies must not round-trip.
        payload["node"].pop("hash_raw_query", None)

        migrated_rows.append(
            {
                "name": name,
                "identifier": identifier,
                "version": version,
                "snapshot": json.dumps(payload),
                "kind_name": kind_name,
                "expiration_ts": expiration_ts,
            }
        )

    if migrated_rows:
        # Full rewrite: clear the table, then bulk-insert the cleaned rows.
        engine_adapter.delete_from(snapshots_table, "TRUE")

        index_type = index_text_type(engine_adapter.dialect)

        engine_adapter.insert_append(
            snapshots_table,
            pd.DataFrame(migrated_rows),
            columns_to_types={
                "name": exp.DataType.build(index_type),
                "identifier": exp.DataType.build(index_type),
                "version": exp.DataType.build(index_type),
                "snapshot": exp.DataType.build("text"),
                "kind_name": exp.DataType.build(index_type),
                "expiration_ts": exp.DataType.build("bigint"),
            },
            contains_json=True,
        )
26 changes: 26 additions & 0 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlglot.expressions import DataType

from sqlmesh.core import constants as c
from sqlmesh.core import dialect as d
from sqlmesh.core.config import AutoCategorizationMode
from sqlmesh.core.console import Console
from sqlmesh.core.context import Context
Expand Down Expand Up @@ -1040,6 +1041,31 @@ def test_select_models_for_backfill(init_and_plan_context: t.Callable):
)


@freeze_time("2023-01-08 15:00:00")
def test_dbt_select_star_is_directly_modified(sushi_test_dbt_context: Context):
    """Editing a dbt model must also directly modify a downstream SELECT * model,
    since the star projection is re-rendered against the new upstream schema."""
    ctx = sushi_test_dbt_context

    upstream = ctx.get_model("sushi.simple_model_a")
    replacement_attrs = {
        **upstream.dict(),
        "query": d.parse_one("SELECT 1 AS a, 2 AS b"),
    }
    ctx.upsert_model(SqlModel.parse_obj(replacement_attrs))

    a_id = ctx.get_snapshot("sushi.simple_model_a").snapshot_id  # type: ignore
    b_id = ctx.get_snapshot("sushi.simple_model_b").snapshot_id  # type: ignore
    expected = {a_id, b_id}

    plan = ctx.plan_builder("dev", skip_tests=True).build()
    assert plan.directly_modified == expected
    assert {interval.snapshot_id for interval in plan.missing_intervals} == expected

    # Neither change alters downstream-visible behavior destructively.
    for sid in (a_id, b_id):
        assert plan.snapshots[sid].change_category == SnapshotChangeCategory.NON_BREAKING


@pytest.mark.parametrize(
"context_fixture",
["sushi_context", "sushi_dbt_context", "sushi_test_dbt_context", "sushi_no_default_catalog"],
Expand Down
5 changes: 2 additions & 3 deletions tests/core/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def test_json(snapshot: Snapshot):
"tags": [],
"grains": [],
"references": [],
"hash_raw_query": False,
"allow_partials": False,
"signals": [],
},
Expand Down Expand Up @@ -628,13 +627,13 @@ def test_fingerprint(model: Model, parent_model: Model):
fingerprint = fingerprint_from_node(model, nodes={})
assert new_fingerprint != fingerprint
assert new_fingerprint.data_hash != fingerprint.data_hash
assert new_fingerprint.metadata_hash != fingerprint.metadata_hash
assert new_fingerprint.metadata_hash == fingerprint.metadata_hash

model = SqlModel(**{**original_model.dict(), "post_statements": [parse_one("DROP TABLE test")]})
fingerprint = fingerprint_from_node(model, nodes={})
assert new_fingerprint != fingerprint
assert new_fingerprint.data_hash != fingerprint.data_hash
assert new_fingerprint.metadata_hash != fingerprint.metadata_hash
assert new_fingerprint.metadata_hash == fingerprint.metadata_hash


def test_fingerprint_seed_model():
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/dbt/sushi_test/models/simple_model_a.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

SELECT 1 AS a
2 changes: 2 additions & 0 deletions tests/fixtures/dbt/sushi_test/models/simple_model_b.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

SELECT * FROM {{ ref("simple_model_a") }}
1 change: 0 additions & 1 deletion tests/schedulers/airflow/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot):
"source_type": "sql",
"tags": [],
"grains": [],
"hash_raw_query": False,
"allow_partials": False,
"signals": [],
},
Expand Down