From 83d2657cd7eb458a1d6771e6515d24fd0e5ad33c Mon Sep 17 00:00:00 2001 From: tobymao Date: Tue, 7 Feb 2023 16:21:23 -0800 Subject: [PATCH 1/2] model magic for seeds fixes #225 --- examples/sushi/models/waiter_names.sql | 3 +- sqlmesh/core/context.py | 6 +++- sqlmesh/core/dialect.py | 18 +++++++---- sqlmesh/core/model/definition.py | 8 +---- sqlmesh/core/model/kind.py | 42 ++++++++++++++++--------- tests/core/test_snapshot.py | 6 ++-- tests/schedulers/airflow/test_client.py | 20 +++++++----- 7 files changed, 62 insertions(+), 41 deletions(-) diff --git a/examples/sushi/models/waiter_names.sql b/examples/sushi/models/waiter_names.sql index e40ed784f7..6479713239 100644 --- a/examples/sushi/models/waiter_names.sql +++ b/examples/sushi/models/waiter_names.sql @@ -4,5 +4,6 @@ MODEL ( path '../seeds/waiter_names.csv', batch_size 5 ), + dialect duckdb, owner jen -) +) \ No newline at end of file diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 7cd517df8f..60a7fc955a 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -49,7 +49,7 @@ from sqlmesh.core.config import Config, load_config_from_paths from sqlmesh.core.console import Console, get_console from sqlmesh.core.context_diff import ContextDiff -from sqlmesh.core.dialect import format_model_expressions, parse_model +from sqlmesh.core.dialect import format_model_expressions, pandas_to_sql, parse_model from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.environment import Environment from sqlmesh.core.hooks import hook @@ -503,6 +503,10 @@ def render( expand = self.dag.upstream(model.name) if expand is True else expand or [] + if model.is_seed: + df = next(model.render(self, start=start, end=end, latest=latest, **kwargs)) + return next(pandas_to_sql(df, model.columns_to_types)) + return model.render_query( start=start, end=end, diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 2dad872fc3..af68ab605f 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -228,6 +228,8 @@ def _create_parser( parser_type: t.Type[exp.Expression], table_keys: t.List[str] ) -> t.Callable: def parse(self: Parser) -> t.Optional[exp.Expression]: + from sqlmesh.core.model.kind import ModelKindName + expressions = [] while True: @@ -249,12 +251,13 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: if not id_var: value = None else: - id_var = id_var.name.lower() index = self._index - if id_var in ( - "incremental_by_time_range", - "incremental_by_unique_key", - "seed", + kind = ModelKindName[id_var.name.upper()] + + if kind in ( + ModelKindName.INCREMENTAL_BY_TIME_RANGE, + ModelKindName.INCREMENTAL_BY_UNIQUE_KEY, + ModelKindName.SEED, ) and self._match(TokenType.L_PAREN): self._retreat(index) props = self._parse_wrapped_csv( @@ -264,7 +267,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: props = None value = self.expression( ModelKind, - this=id_var, + this=kind.value, expressions=props, ) else: @@ -336,6 +339,9 @@ def format_model_expressions( Returns: A string with the formatted model. """ + if len(expressions) == 1: + return expressions[0].sql(pretty=True, dialect=dialect) + *statements, query = expressions query = query.copy() selects = [] diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index b7732bcb5d..a332dc59cd 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -162,7 +162,7 @@ def render_definition(self) -> t.List[exp.Expression]: expressions.append( exp.Property( this="kind", - value=field_value.to_expression(self.dialect), + value=field_value.to_expression(dialect=self.dialect), ) ) else: @@ -667,11 +667,6 @@ def render( ) -> t.Generator[QueryOrDF, None, None]: yield from self.seed.read(batch_size=self.kind.batch_size) - def render_definition(self) -> t.List[exp.Expression]: - result = super().render_definition() - result.append(exp.Literal.string(self.seed.content)) - return result - def text_diff(self, other: Model) -> str: if not isinstance(other, SeedModel): return super().text_diff(other) @@ -986,7 +981,6 @@ def create_seed_model( name, defaults=defaults, path=path, - depends_on=set(), seed=seed, kind=seed_kind, **kwargs, diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index 5b4e6216a8..b26e0f65ba 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -16,13 +16,13 @@ class ModelKindName(str, Enum): """The kind of model, determining how this data is computed and stored in the warehouse.""" - INCREMENTAL_BY_TIME_RANGE = "incremental_by_time_range" - INCREMENTAL_BY_UNIQUE_KEY = "incremental_by_unique_key" - FULL = "full" - SNAPSHOT = "snapshot" - VIEW = "view" - EMBEDDED = "embedded" - SEED = "seed" + INCREMENTAL_BY_TIME_RANGE = "INCREMENTAL_BY_TIME_RANGE" + INCREMENTAL_BY_UNIQUE_KEY = "INCREMENTAL_BY_UNIQUE_KEY" + FULL = "FULL" + SNAPSHOT = "SNAPSHOT" + VIEW = "VIEW" + EMBEDDED = "EMBEDDED" + SEED = "SEED" class ModelKind(PydanticModel): @@ -65,10 +65,8 @@ def only_latest(self) -> bool: """Whether or not this model only cares about latest date to render.""" return self.name in (ModelKindName.VIEW, ModelKindName.FULL) - def to_expression(self, *args: t.Any, **kwargs: t.Any) -> d.ModelKind: - return d.ModelKind( - this=self.name.value, - ) + def to_expression(self, **kwargs: t.Any) -> d.ModelKind: + return d.ModelKind(this=self.name.value.upper(), **kwargs) class TimeColumn(PydanticModel): @@ -126,9 +124,8 @@ def _parse_time_column(cls, v: t.Any) -> TimeColumn: return TimeColumn(column=v) return v - def to_expression(self, dialect: str) -> d.ModelKind: - return d.ModelKind( - this=self.name.value, + def to_expression(self, dialect: str = "", **kwargs: t.Any) -> d.ModelKind: + return super().to_expression( expressions=[ exp.Property( this="time_column", value=self.time_column.to_expression(dialect) @@ -171,6 +168,20 @@ def _parse_path(cls, v: t.Any) -> str: return v.this return str(v) + def to_expression(self, **kwargs: t.Any) -> d.ModelKind: + """Convert the seed kind into a SQLGlot expression.""" + return super().to_expression( + expressions=[ + exp.Property( + this=exp.Var(this="path"), value=exp.Literal.string(self.path) + ), + exp.Property( + this=exp.Var(this="batch_size"), + value=exp.Literal.number(self.batch_size), + ), + ], + ) + def _model_kind_validator(v: t.Any) -> ModelKind: if isinstance(v, ModelKind): @@ -201,7 +212,8 @@ def _model_kind_validator(v: t.Any) -> ModelKind: klass = ModelKind return klass(**v) - name = (v.name if isinstance(v, exp.Expression) else str(v)).lower() + name = (v.name if isinstance(v, exp.Expression) else str(v)).upper() + try: return ModelKind(name=ModelKindName(name)) except ValueError: diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 6a10af2818..d5c9d48624 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -65,7 +65,7 @@ def test_json(snapshot: Snapshot): "cron": "1 0 * * *", "batch_size": 30, "kind": { - "name": "incremental_by_time_range", + "name": "INCREMENTAL_BY_TIME_RANGE", "time_column": {"column": "ds"}, }, "start": "2020-01-01", @@ -86,7 +86,7 @@ def test_json(snapshot: Snapshot): "previous_versions": [], "indirect_versions": {}, "updated_ts": 1663891973000, - "version": snapshot.version, + "version": snapshot.fingerprint.dict(), } @@ -256,7 +256,7 @@ def test_fingerprint(model: Model, parent_model: Model): fingerprint = fingerprint_from_model(model, models={}) original_fingerprint = SnapshotFingerprint( - data_hash="713628577", + data_hash="457513203", metadata_hash="3589467163", ) diff --git a/tests/schedulers/airflow/test_client.py b/tests/schedulers/airflow/test_client.py index c02040d0f8..a168ae0cff 100644 --- a/tests/schedulers/airflow/test_client.py +++ b/tests/schedulers/airflow/test_client.py @@ -1,4 +1,5 @@ import json +from urllib.parse import urlencode import pytest import requests @@ -9,7 +10,7 @@ from sqlmesh.core.model import SqlModel from sqlmesh.core.snapshot import Snapshot, SnapshotNameVersion from sqlmesh.schedulers.airflow import common -from sqlmesh.schedulers.airflow.client import AirflowClient +from sqlmesh.schedulers.airflow.client import AirflowClient, _list_to_json from sqlmesh.utils.date import to_datetime @@ -81,7 +82,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot): "pre": [], "post": [], "kind": { - "name": "incremental_by_time_range", + "name": "INCREMENTAL_BY_TIME_RANGE", "time_column": {"column": "ds"}, }, "name": "test_model", @@ -130,6 +131,10 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot): } +def snapshot_url(snapshot_ids, key="ids") -> str: + return urlencode({key: _list_to_json(snapshot_ids)}) + + def test_get_snapshots(mocker: MockerFixture, snapshot: Snapshot): snapshots = common.SnapshotsResponse(snapshots=[snapshot]) @@ -147,7 +152,7 @@ def test_get_snapshots(mocker: MockerFixture, snapshot: Snapshot): assert result == [snapshot] get_snapshots_mock.assert_called_once_with( - "http://localhost:8080/sqlmesh/api/v1/snapshots?ids=%5B%7B%22name%22%3A%22test_model%22%2C%22identifier%22%3A%223654063500%22%7D%5D" + f"http://localhost:8080/sqlmesh/api/v1/snapshots?{snapshot_url([snapshot.snapshot_id])}" ) @@ -168,7 +173,7 @@ def test_snapshots_exist(mocker: MockerFixture, snapshot: Snapshot): assert result == {snapshot.snapshot_id} snapshots_exist_mock.assert_called_once_with( - "http://localhost:8080/sqlmesh/api/v1/snapshots?check_existence&ids=%5B%7B%22name%22%3A%22test_model%22%2C%22identifier%22%3A%223654063500%22%7D%5D" + f"http://localhost:8080/sqlmesh/api/v1/snapshots?return_ids&{snapshot_url([snapshot.snapshot_id])}" ) @@ -184,14 +189,13 @@ def test_get_snapshots_with_same_version(mocker: MockerFixture, snapshot: Snapsh client = AirflowClient( airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session() ) - result = client.get_snapshots_with_same_version( - [SnapshotNameVersion(name=snapshot.name, version=snapshot.version)] - ) + versions = [SnapshotNameVersion(name=snapshot.name, version=snapshot.version)] + result = client.get_snapshots_with_same_version(versions) assert result == [snapshot] get_snapshots_mock.assert_called_once_with( - "http://localhost:8080/sqlmesh/api/v1/snapshots?versions=%5B%7B%22name%22%3A%22test_model%22%2C%22version%22%3A%222710441016%22%7D%5D" + f"http://localhost:8080/sqlmesh/api/v1/snapshots?{snapshot_url(versions, 'versions')}" ) From d14387a74c62e88562c90a29edca08c78841edde Mon Sep 17 00:00:00 2001 From: tobymao Date: Wed, 8 Feb 2023 09:30:14 -0800 Subject: [PATCH 2/2] fix tests --- tests/core/test_integration.py | 9 +++++---- tests/schedulers/airflow/test_client.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 509554de56..256646336f 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -258,13 +258,14 @@ def validate_model_kind_change( "sushi.customer_revenue_by_day", "sushi.top_waiters", ] - kind: ModelKind = ModelKind(name=kind_name) - if kind_name == "incremental_by_time_range": - kind = IncrementalByTimeRangeKind( + if kind_name == ModelKindName.INCREMENTAL_BY_TIME_RANGE: + kind: ModelKind = IncrementalByTimeRangeKind( time_column=TimeColumn(column="ds", format="%Y-%m-%d") ) - elif kind_name == "incremental_by_unique_key": + elif kind_name == ModelKindName.INCREMENTAL_BY_UNIQUE_KEY: kind = IncrementalByUniqueKeyKind(unique_key="id") + else: + kind = ModelKind(name=kind_name) def _validate_plan(context, plan): validate_plan_changes(plan, modified=directly_modified + indirectly_modified) diff --git a/tests/schedulers/airflow/test_client.py b/tests/schedulers/airflow/test_client.py index a168ae0cff..9970af67f5 100644 --- a/tests/schedulers/airflow/test_client.py +++ b/tests/schedulers/airflow/test_client.py @@ -173,7 +173,7 @@ def test_snapshots_exist(mocker: MockerFixture, snapshot: Snapshot): assert result == {snapshot.snapshot_id} snapshots_exist_mock.assert_called_once_with( - f"http://localhost:8080/sqlmesh/api/v1/snapshots?return_ids&{snapshot_url([snapshot.snapshot_id])}" + f"http://localhost:8080/sqlmesh/api/v1/snapshots?check_existence&{snapshot_url([snapshot.snapshot_id])}" )