Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion docs/concepts/tests.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Tests within a suite file contain the following attributes:
* The list of rows that are expected to be returned by the model's query defined as a mapping from a column name to a value associated with it
* [Optional] The list of expected rows per each individual [Common Table Expression](glossary.md#cte) (CTE) defined in the model's query
* [Optional] The dictionary of values for macro variables that will be set during model testing
* There are three special macros that can be overridden, `start`, `end`, and `execution_time`. Overriding each will allow you to override the date macros in your SQL queries. For example, setting execution_time: 2022-01-01 -> execution_ds in your queries.
* There are three special macro variables: `start`, `end`, and `execution_time`. Setting these allows you to override the date macros in your SQL queries. For example, `@execution_ds` will render to `2022-01-01` if `execution_time` is set to this value. Additionally, when `execution_time` is set, SQL expressions like `CURRENT_DATE` and `CURRENT_TIMESTAMP` will evaluate to the same datetime value as `execution_time`.

The YAML format is defined as follows:

Expand Down Expand Up @@ -210,6 +210,55 @@ test_example_full_model:
num_orders: 2
```

### Freezing Time

Some models may use SQL expressions that compute datetime values at a given point in time, such as `CURRENT_TIMESTAMP`. Since these expressions are non-deterministic, it's not enough to simply specify an expected output value in order to test them.

Setting the `execution_time` macro variable addresses this problem by mocking out the current time in the context of the test, thus making its value deterministic.

The following example demonstrates how `execution_time` can be used to test a column that is computed using `CURRENT_TIMESTAMP`. The model we're going to test is defined as:

```sql linenums="1"
MODEL (
  name colors,
  kind FULL
);

SELECT
  'Yellow' AS color,
  CURRENT_TIMESTAMP AS created_at
```

And the corresponding test is:

```yaml linenums="1"
test_colors:
  model: colors
  outputs:
    query:
      - color: "Yellow"
        created_at: "2023-01-01 12:05:03"
  vars:
    execution_time: "2023-01-01 12:05:03"
```

It's also possible to set a time zone for `execution_time`, by including it in the timestamp string.

If a time zone is provided, the test's _expected_ datetime values must currently be written as timestamps without a time zone. This means they need to be shifted by the time zone's offset, i.e. converted to UTC, as shown in the example below.

Here's how we would write the above test if we wanted to freeze the time at `12:05:03` in the UTC+2 time zone (which corresponds to `10:05:03` UTC):

```yaml linenums="1"
test_colors:
  model: colors
  outputs:
    query:
      - color: "Yellow"
        created_at: "2023-01-01 10:05:03"
  vars:
    execution_time: "2023-01-01 12:05:03+02:00"
```

## Automatic test generation

Creating tests manually can be repetitive and error-prone, which is why SQLMesh also provides a way to automate this process using the [`create_test` command](../reference/cli.md#create_test).
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"croniter",
"duckdb",
"dateparser",
"freezegun",
"hyperscript",
"ipywidgets",
"jinja2",
Expand Down Expand Up @@ -66,7 +67,6 @@
"dbt-core",
"dbt-duckdb>=1.7.1",
"Faker",
"freezegun",
"google-auth",
"google-cloud-bigquery",
"google-cloud-bigquery-storage",
Expand Down
45 changes: 33 additions & 12 deletions sqlmesh/core/test/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import typing as t
import unittest
from collections import Counter
from contextlib import AbstractContextManager, nullcontext
from pathlib import Path
from unittest.mock import patch

import numpy as np
import pandas as pd
from sqlglot import exp
from freezegun import freeze_time
from sqlglot import Dialect, exp
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

Expand Down Expand Up @@ -70,6 +73,20 @@ def __init__(
if depends_on not in inputs:
_raise_error(f"Incomplete test, missing input for table {depends_on}", path)

self._engine_adapter_dialect = Dialect.get_or_raise(self.engine_adapter.dialect)
self._transforms = self._engine_adapter_dialect.generator_class.TRANSFORMS

self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "")
if self._execution_time:
exec_time = exp.Literal.string(self._execution_time)
self._transforms = {
**self._transforms,
exp.CurrentDate: lambda self, _: self.sql(exp.cast(exec_time, "date")),
exp.CurrentDatetime: lambda self, _: self.sql(exp.cast(exec_time, "datetime")),
exp.CurrentTime: lambda self, _: self.sql(exp.cast(exec_time, "time")),
exp.CurrentTimestamp: lambda self, _: self.sql(exp.cast(exec_time, "timestamp")),
}

super().__init__()

def shortDescription(self) -> t.Optional[str]:
Expand Down Expand Up @@ -275,10 +292,6 @@ def _normalize_sources(sources: t.Dict, partial: bool = False) -> t.Dict:


class SqlModelTest(ModelTest):
def _execute(self, query: exp.Expression) -> pd.DataFrame:
    """Run ``query`` through the test's engine adapter.

    Args:
        query: The expression to execute.

    Returns:
        The DataFrame returned by the engine adapter's ``fetchdf``.
    """
    adapter = self.engine_adapter
    result_df = adapter.fetchdf(query)
    return result_df

def test_ctes(self, ctes: t.Dict[str, exp.Expression]) -> None:
"""Run CTE queries and compare output to expected output"""
for cte_name, values in self.body["outputs"].get("ctes", {}).items():
Expand Down Expand Up @@ -331,6 +344,11 @@ def runTest(self) -> None:

self.assert_equal(expected, actual, sort=sort, partial=partial)

def _execute(self, query: exp.Expression) -> pd.DataFrame:
"""Executes the query with the engine adapter and returns a DataFrame."""
with patch.dict(self._engine_adapter_dialect.generator_class.TRANSFORMS, self._transforms):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does Dialect.get_or_raise(self.engine_adapter.dialect) return a global instance or a new one each time? If it's the latter, can't we just set transformers once and not patch it?

Copy link
Copy Markdown
Collaborator Author

@georgesittas georgesittas Mar 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately we can't do this, because TRANSFORMS is a class attribute and since generator classes are singletons we'd end up mutating dialects globally if we didn't patch it.

return self.engine_adapter.fetchdf(query)


class PythonModelTest(ModelTest):
def __init__(
Expand Down Expand Up @@ -368,13 +386,6 @@ def __init__(
default_catalog=default_catalog,
)

def _execute_model(self) -> pd.DataFrame:
    """Render the python model and return the first item it yields.

    The test's ``vars`` mapping (empty when absent) is forwarded to
    ``model.render`` as keyword arguments.
    """
    test_vars = self.body.get("vars", {})
    rendered = self.model.render(context=self.context, **test_vars)
    first_df = next(rendered)
    return t.cast(pd.DataFrame, first_df)

def runTest(self) -> None:
values = self.body["outputs"].get("query")
if values is not None:
Expand All @@ -387,6 +398,16 @@ def runTest(self) -> None:

self.assert_equal(expected, actual_df, sort=False, partial=partial)

def _execute_model(self) -> pd.DataFrame:
    """Render the python model under the test's time and dialect overrides.

    While the model renders, the dialect generator's ``TRANSFORMS`` mapping is
    temporarily patched with ``self._transforms``, and, when an execution time
    was configured for the test, the clock is frozen at that time.

    Returns:
        The first item yielded by ``model.render``, cast to a DataFrame.
    """
    # Only freeze the clock when the test actually specifies an execution time.
    clock = t.cast(
        AbstractContextManager,
        freeze_time(self._execution_time) if self._execution_time else nullcontext(),
    )
    # patch.dict restores the class-level TRANSFORMS mapping on exit.
    transforms_patch = patch.dict(
        self._engine_adapter_dialect.generator_class.TRANSFORMS, self._transforms
    )

    with transforms_patch, clock:
        rendered = self.model.render(context=self.context, **self.body.get("vars", {}))
        return t.cast(pd.DataFrame, next(rendered))


def generate_test(
model: Model,
Expand Down
88 changes: 80 additions & 8 deletions tests/core/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
import typing as t
from pathlib import Path

import pandas as pd
import pytest
from pytest_mock.plugin import MockerFixture
from sqlglot import exp

from sqlmesh.cli.example_project import init_example_project
from sqlmesh.core import constants as c
from sqlmesh.core.config import Config, DuckDBConnectionConfig, ModelDefaultsConfig
from sqlmesh.core.context import Context
from sqlmesh.core.dialect import parse
from sqlmesh.core.model import SqlModel, load_sql_based_model
from sqlmesh.core.test.definition import SqlModelTest
from sqlmesh.core.model import PythonModel, SqlModel, load_sql_based_model, model
from sqlmesh.core.test.definition import PythonModelTest, SqlModelTest
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.yaml import load as load_yaml

Expand All @@ -25,13 +27,21 @@
SUSHI_FOO_META = "MODEL (name sushi.foo, kind FULL)"


@t.overload
def _create_test(
body: t.Dict[str, t.Any],
test_name: str,
model: SqlModel,
context: Context,
) -> SqlModelTest:
return SqlModelTest(
body: t.Dict[str, t.Any], test_name: str, model: SqlModel, context: Context
) -> SqlModelTest: ...


@t.overload
def _create_test(
body: t.Dict[str, t.Any], test_name: str, model: PythonModel, context: Context
) -> PythonModelTest: ...


def _create_test(body, test_name, model, context):
test_type = SqlModelTest if isinstance(model, SqlModel) else PythonModelTest
return test_type(
body=body[test_name],
test_name=test_name,
model=model,
Expand Down Expand Up @@ -841,6 +851,68 @@ def test_nested_data_types() -> None:
_check_successful_or_raise(result)


def test_freeze_time(mocker: MockerFixture) -> None:
# Checks that setting `execution_time` in a test's `vars` freezes time, first
# for a model built from a SQL query and then for a Python model.
#
# Part 1: a query selecting CURRENT_DATE, CURRENT_TIME and CURRENT_TIMESTAMP,
# tested with an execution_time that carries an explicit +00:00 offset.
test = _create_test(
body=load_yaml(
"""
test_foo:
model: xyz
outputs:
query:
- cur_date: 2023-01-01
cur_time: 12:05:03
cur_timestamp: "2023-01-01 12:05:03"
vars:
execution_time: "2023-01-01 12:05:03+00:00"
"""
),
test_name="test_foo",
model=_create_model(
"SELECT CURRENT_DATE AS cur_date, CURRENT_TIME AS cur_time, CURRENT_TIMESTAMP AS cur_timestamp"
),
context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))),
)

# Spy on the engine adapter's `_execute` so the SQL it receives can be inspected.
spy_execute = mocker.spy(test.engine_adapter, "_execute")
_check_successful_or_raise(test.run())

# The last executed SQL must have each CURRENT_* expression replaced by a CAST
# of the execution_time string literal to the corresponding type.
spy_execute.assert_called_with(
"SELECT "
"""CAST('2023-01-01 12:05:03+00:00' AS DATE) AS "cur_date", """
"""CAST('2023-01-01 12:05:03+00:00' AS TIME) AS "cur_time", """
'''CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMP) AS "cur_timestamp"''',
)

# Part 2: a Python model that reads the current time in two ways, via
# `datetime.datetime.now()` and via a CURRENT_TIMESTAMP query sent through the
# engine adapter. Both values are expected to match the frozen time.
@model("py_model", columns={"ts1": "timestamptz", "ts2": "timestamptz"})
def execute(context, start, end, execution_time, **kwargs):
datetime_now = datetime.datetime.now()

context.engine_adapter.execute(exp.select("CURRENT_TIMESTAMP"))
current_timestamp = context.engine_adapter.cursor.fetchone()[0]

return pd.DataFrame([{"ts1": datetime_now, "ts2": current_timestamp}])

# execution_time is 12:05:03 at +02:00, while the expected values are written
# as 10:05:03 without a time zone (the same instant shifted by the offset).
test = _create_test(
body=load_yaml(
"""
test_py_model:
model: py_model
outputs:
query:
- ts1: "2023-01-01 10:05:03"
ts2: "2023-01-01 10:05:03"
vars:
execution_time: "2023-01-01 12:05:03+02:00"
"""
),
test_name="test_py_model",
model=model.get_registry()["py_model"].model(module_path=Path("."), path=Path(".")),
context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))),
)

_check_successful_or_raise(test.run())


def test_successes(sushi_context: Context) -> None:
results = sushi_context.test()
successful_tests = [success.test_name for success in results.successes] # type: ignore
Expand Down