Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/sushi/models/waiter_names.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ MODEL (
path '../seeds/waiter_names.csv',
batch_size 5
),
dialect duckdb,
owner jen
)
)
6 changes: 5 additions & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from sqlmesh.core.config import Config, load_config_from_paths
from sqlmesh.core.console import Console, get_console
from sqlmesh.core.context_diff import ContextDiff
from sqlmesh.core.dialect import format_model_expressions, parse_model
from sqlmesh.core.dialect import format_model_expressions, pandas_to_sql, parse_model
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.environment import Environment
from sqlmesh.core.hooks import hook
Expand Down Expand Up @@ -503,6 +503,10 @@ def render(

expand = self.dag.upstream(model.name) if expand is True else expand or []

if model.is_seed:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this logic be a part of the render_query overridden in SeedModel? I'd rather avoid ifs like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, it's a different purpose. render_query is used for CTAS and other things. the context render is purely for the user

df = next(model.render(self, start=start, end=end, latest=latest, **kwargs))
return next(pandas_to_sql(df, model.columns_to_types))

return model.render_query(
start=start,
end=end,
Expand Down
18 changes: 12 additions & 6 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def _create_parser(
parser_type: t.Type[exp.Expression], table_keys: t.List[str]
) -> t.Callable:
def parse(self: Parser) -> t.Optional[exp.Expression]:
from sqlmesh.core.model.kind import ModelKindName

expressions = []

while True:
Expand All @@ -249,12 +251,13 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
if not id_var:
value = None
else:
id_var = id_var.name.lower()
index = self._index
if id_var in (
"incremental_by_time_range",
"incremental_by_unique_key",
"seed",
kind = ModelKindName[id_var.name.upper()]

if kind in (
ModelKindName.INCREMENTAL_BY_TIME_RANGE,
ModelKindName.INCREMENTAL_BY_UNIQUE_KEY,
ModelKindName.SEED,
) and self._match(TokenType.L_PAREN):
self._retreat(index)
props = self._parse_wrapped_csv(
Expand All @@ -264,7 +267,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
props = None
value = self.expression(
ModelKind,
this=id_var,
this=kind.value,
expressions=props,
)
else:
Expand Down Expand Up @@ -336,6 +339,9 @@ def format_model_expressions(
Returns:
A string with the formatted model.
"""
if len(expressions) == 1:
return expressions[0].sql(pretty=True, dialect=dialect)

*statements, query = expressions
query = query.copy()
selects = []
Expand Down
8 changes: 1 addition & 7 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def render_definition(self) -> t.List[exp.Expression]:
expressions.append(
exp.Property(
this="kind",
value=field_value.to_expression(self.dialect),
value=field_value.to_expression(dialect=self.dialect),
)
)
else:
Expand Down Expand Up @@ -667,11 +667,6 @@ def render(
) -> t.Generator[QueryOrDF, None, None]:
yield from self.seed.read(batch_size=self.kind.batch_size)

def render_definition(self) -> t.List[exp.Expression]:
result = super().render_definition()
result.append(exp.Literal.string(self.seed.content))
return result

def text_diff(self, other: Model) -> str:
if not isinstance(other, SeedModel):
return super().text_diff(other)
Expand Down Expand Up @@ -986,7 +981,6 @@ def create_seed_model(
name,
defaults=defaults,
path=path,
depends_on=set(),
seed=seed,
kind=seed_kind,
**kwargs,
Expand Down
42 changes: 27 additions & 15 deletions sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
class ModelKindName(str, Enum):
"""The kind of model, determining how this data is computed and stored in the warehouse."""

INCREMENTAL_BY_TIME_RANGE = "incremental_by_time_range"
INCREMENTAL_BY_UNIQUE_KEY = "incremental_by_unique_key"
FULL = "full"
SNAPSHOT = "snapshot"
VIEW = "view"
EMBEDDED = "embedded"
SEED = "seed"
INCREMENTAL_BY_TIME_RANGE = "INCREMENTAL_BY_TIME_RANGE"
INCREMENTAL_BY_UNIQUE_KEY = "INCREMENTAL_BY_UNIQUE_KEY"
FULL = "FULL"
SNAPSHOT = "SNAPSHOT"
VIEW = "VIEW"
EMBEDDED = "EMBEDDED"
SEED = "SEED"


class ModelKind(PydanticModel):
Expand Down Expand Up @@ -65,10 +65,8 @@ def only_latest(self) -> bool:
"""Whether or not this model only cares about latest date to render."""
return self.name in (ModelKindName.VIEW, ModelKindName.FULL)

def to_expression(self, *args: t.Any, **kwargs: t.Any) -> d.ModelKind:
return d.ModelKind(
this=self.name.value,
)
def to_expression(self, **kwargs: t.Any) -> d.ModelKind:
return d.ModelKind(this=self.name.value.upper(), **kwargs)


class TimeColumn(PydanticModel):
Expand Down Expand Up @@ -126,9 +124,8 @@ def _parse_time_column(cls, v: t.Any) -> TimeColumn:
return TimeColumn(column=v)
return v

def to_expression(self, dialect: str) -> d.ModelKind:
return d.ModelKind(
this=self.name.value,
def to_expression(self, dialect: str = "", **kwargs: t.Any) -> d.ModelKind:
return super().to_expression(
expressions=[
exp.Property(
this="time_column", value=self.time_column.to_expression(dialect)
Expand Down Expand Up @@ -171,6 +168,20 @@ def _parse_path(cls, v: t.Any) -> str:
return v.this
return str(v)

def to_expression(self, **kwargs: t.Any) -> d.ModelKind:
"""Convert the seed kind into a SQLGlot expression."""
return super().to_expression(
expressions=[
exp.Property(
this=exp.Var(this="path"), value=exp.Literal.string(self.path)
),
exp.Property(
this=exp.Var(this="batch_size"),
value=exp.Literal.number(self.batch_size),
),
],
)


def _model_kind_validator(v: t.Any) -> ModelKind:
if isinstance(v, ModelKind):
Expand Down Expand Up @@ -201,7 +212,8 @@ def _model_kind_validator(v: t.Any) -> ModelKind:
klass = ModelKind
return klass(**v)

name = (v.name if isinstance(v, exp.Expression) else str(v)).lower()
name = (v.name if isinstance(v, exp.Expression) else str(v)).upper()

try:
return ModelKind(name=ModelKindName(name))
except ValueError:
Expand Down
9 changes: 5 additions & 4 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,14 @@ def validate_model_kind_change(
"sushi.customer_revenue_by_day",
"sushi.top_waiters",
]
kind: ModelKind = ModelKind(name=kind_name)
if kind_name == "incremental_by_time_range":
kind = IncrementalByTimeRangeKind(
if kind_name == ModelKindName.INCREMENTAL_BY_TIME_RANGE:
kind: ModelKind = IncrementalByTimeRangeKind(
time_column=TimeColumn(column="ds", format="%Y-%m-%d")
)
elif kind_name == "incremental_by_unique_key":
elif kind_name == ModelKindName.INCREMENTAL_BY_UNIQUE_KEY:
kind = IncrementalByUniqueKeyKind(unique_key="id")
else:
kind = ModelKind(name=kind_name)

def _validate_plan(context, plan):
validate_plan_changes(plan, modified=directly_modified + indirectly_modified)
Expand Down
6 changes: 3 additions & 3 deletions tests/core/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_json(snapshot: Snapshot):
"cron": "1 0 * * *",
"batch_size": 30,
"kind": {
"name": "incremental_by_time_range",
"name": "INCREMENTAL_BY_TIME_RANGE",
"time_column": {"column": "ds"},
},
"start": "2020-01-01",
Expand All @@ -86,7 +86,7 @@ def test_json(snapshot: Snapshot):
"previous_versions": [],
"indirect_versions": {},
"updated_ts": 1663891973000,
"version": snapshot.version,
"version": snapshot.fingerprint.dict(),
}


Expand Down Expand Up @@ -256,7 +256,7 @@ def test_fingerprint(model: Model, parent_model: Model):
fingerprint = fingerprint_from_model(model, models={})

original_fingerprint = SnapshotFingerprint(
data_hash="713628577",
data_hash="457513203",
metadata_hash="3589467163",
)

Expand Down
20 changes: 12 additions & 8 deletions tests/schedulers/airflow/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from urllib.parse import urlencode

import pytest
import requests
Expand All @@ -9,7 +10,7 @@
from sqlmesh.core.model import SqlModel
from sqlmesh.core.snapshot import Snapshot, SnapshotNameVersion
from sqlmesh.schedulers.airflow import common
from sqlmesh.schedulers.airflow.client import AirflowClient
from sqlmesh.schedulers.airflow.client import AirflowClient, _list_to_json
from sqlmesh.utils.date import to_datetime


Expand Down Expand Up @@ -81,7 +82,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot):
"pre": [],
"post": [],
"kind": {
"name": "incremental_by_time_range",
"name": "INCREMENTAL_BY_TIME_RANGE",
"time_column": {"column": "ds"},
},
"name": "test_model",
Expand Down Expand Up @@ -130,6 +131,10 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot):
}


def snapshot_url(snapshot_ids, key="ids") -> str:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, thanks 👍

return urlencode({key: _list_to_json(snapshot_ids)})


def test_get_snapshots(mocker: MockerFixture, snapshot: Snapshot):
snapshots = common.SnapshotsResponse(snapshots=[snapshot])

Expand All @@ -147,7 +152,7 @@ def test_get_snapshots(mocker: MockerFixture, snapshot: Snapshot):
assert result == [snapshot]

get_snapshots_mock.assert_called_once_with(
"http://localhost:8080/sqlmesh/api/v1/snapshots?ids=%5B%7B%22name%22%3A%22test_model%22%2C%22identifier%22%3A%223654063500%22%7D%5D"
f"http://localhost:8080/sqlmesh/api/v1/snapshots?{snapshot_url([snapshot.snapshot_id])}"
)


Expand All @@ -168,7 +173,7 @@ def test_snapshots_exist(mocker: MockerFixture, snapshot: Snapshot):
assert result == {snapshot.snapshot_id}

snapshots_exist_mock.assert_called_once_with(
"http://localhost:8080/sqlmesh/api/v1/snapshots?check_existence&ids=%5B%7B%22name%22%3A%22test_model%22%2C%22identifier%22%3A%223654063500%22%7D%5D"
f"http://localhost:8080/sqlmesh/api/v1/snapshots?check_existence&{snapshot_url([snapshot.snapshot_id])}"
)


Expand All @@ -184,14 +189,13 @@ def test_get_snapshots_with_same_version(mocker: MockerFixture, snapshot: Snapsh
client = AirflowClient(
airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()
)
result = client.get_snapshots_with_same_version(
[SnapshotNameVersion(name=snapshot.name, version=snapshot.version)]
)
versions = [SnapshotNameVersion(name=snapshot.name, version=snapshot.version)]
result = client.get_snapshots_with_same_version(versions)

assert result == [snapshot]

get_snapshots_mock.assert_called_once_with(
"http://localhost:8080/sqlmesh/api/v1/snapshots?versions=%5B%7B%22name%22%3A%22test_model%22%2C%22version%22%3A%222710441016%22%7D%5D"
f"http://localhost:8080/sqlmesh/api/v1/snapshots?{snapshot_url(versions, 'versions')}"
)


Expand Down