Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

import abc
import collections
import contextlib
import gc
import traceback
import typing as t
Expand Down Expand Up @@ -947,6 +946,7 @@ def test(
match_patterns: t.Optional[t.List[str]] = None,
tests: t.Optional[t.List[str]] = None,
verbose: bool = False,
stream: t.Optional[t.TextIO] = None,
) -> unittest.result.TestResult:
"""Discover and run model tests"""
verbosity = 2 if verbose else 1
Expand Down Expand Up @@ -977,6 +977,7 @@ def test(
models=self._models,
engine_adapter=self._test_engine_adapter,
verbosity=verbosity,
stream=stream,
)
finally:
self._test_engine_adapter.close()
Expand Down Expand Up @@ -1097,8 +1098,7 @@ def close(self) -> None:

def _run_tests(self) -> t.Tuple[unittest.result.TestResult, str]:
test_output_io = StringIO()
with contextlib.redirect_stderr(test_output_io):
result = self.test()
result = self.test(stream=test_output_io)
return result, test_output_io.getvalue()

def _run_plan_tests(
Expand Down
9 changes: 7 additions & 2 deletions sqlmesh/core/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import pathlib
import typing as t
import unittest

from sqlmesh.core.engine_adapter import EngineAdapter
Expand All @@ -20,6 +21,7 @@ def run_tests(
models: dict[str, Model],
engine_adapter: EngineAdapter,
verbosity: int = 1,
stream: t.TextIO | None = None,
) -> unittest.result.TestResult:
"""Create a test suite of ModelTest objects and run it.

Expand All @@ -39,7 +41,9 @@ def run_tests(
)
for metadata in model_test_metadata
)
return unittest.TextTestRunner(verbosity=verbosity, resultclass=ModelTextTestResult).run(suite)
return unittest.TextTestRunner(
stream=stream, verbosity=verbosity, resultclass=ModelTextTestResult
).run(suite)


def run_model_tests(
Expand All @@ -48,6 +52,7 @@ def run_model_tests(
engine_adapter: EngineAdapter,
verbosity: int = 1,
patterns: list[str] | None = None,
stream: t.TextIO | None = None,
) -> unittest.result.TestResult:
"""Load and run tests.

Expand All @@ -68,4 +73,4 @@ def run_model_tests(
loaded_tests.extend(load_model_test_file(path).values())
if patterns:
loaded_tests = filter_tests_by_patterns(loaded_tests, patterns)
return run_tests(loaded_tests, models, engine_adapter, verbosity)
return run_tests(loaded_tests, models, engine_adapter, verbosity, stream)
4 changes: 4 additions & 0 deletions sqlmesh/core/test/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def assert_equal(self, df1: pd.DataFrame, df2: pd.DataFrame) -> None:
def runTest(self) -> None:
raise NotImplementedError

def path_relative_to(self, other: pathlib.Path) -> pathlib.Path | None:
"""Compute a version of this test's path relative to the `other` path"""
return self.path.relative_to(other) if self.path else None

@staticmethod
def create_test(
body: dict[str, t.Any],
Expand Down
57 changes: 57 additions & 0 deletions tests/web/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,60 @@ def test_table_diff(web_sushi_context: Context) -> None:
assert response.status_code == 200
assert "schema_diff" in response.json()
assert "row_diff" in response.json()


def test_test(web_sushi_context: Context) -> None:
response = client.get("/api/commands/test")
assert response.status_code == 200
response_json = response.json()
assert response_json["tests_run"] == 2
assert response_json["failures"] == []

# Single test
response = client.get("/api/commands/test", params={"test": "tests/test_order_items.yaml"})
assert response.status_code == 200
response_json = response.json()
assert response_json["tests_run"] == 1
assert response_json["failures"] == []


def test_test_failure(project_context: Context) -> None:
models_dir = project_context.path / "models"
models_dir.mkdir()
sql_file = models_dir / "foo.sql"
sql_file.write_text("MODEL (name foo); SELECT 1 ds;")

tests_dir = project_context.path / "tests"
tests_dir.mkdir()
test_file = tests_dir / "test_foo.yaml"
test_file.write_text(
"""test_foo:
model: foo
outputs:
query:
- ds: 2
vars:
start: 2022-01-01
end: 2022-01-01
latest: 2022-01-01"""
)

project_context.load()
response = client.get("/api/commands/test")
assert response.status_code == 200
response_json = response.json()
assert response_json["tests_run"] == 1
assert response_json["failures"] == [
{
"name": "test_foo",
"path": "tests/test_foo.yaml",
"tb": """AssertionError: Data differs
- {'ds': 2}
? ^

+ {'ds': 1}
? ^

""",
}
]
53 changes: 53 additions & 0 deletions web/server/api/endpoints/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import asyncio
import functools
import io
import typing as t

import pandas as pd
from fastapi import APIRouter, Body, Depends, Request

from sqlmesh.core.context import Context
from sqlmesh.core.snapshot.definition import SnapshotChangeCategory
from sqlmesh.core.test import ModelTest
from sqlmesh.utils.errors import PlanError
from web.server import models
from web.server.exceptions import ApiException
Expand Down Expand Up @@ -155,3 +157,54 @@ async def render(
dialect = options.dialect or context.config.dialect

return models.Query(sql=rendered.sql(pretty=options.pretty, dialect=dialect))


@router.get("/test")
async def test(
test: t.Optional[str] = None,
verbose: bool = False,
context: Context = Depends(get_loaded_context),
) -> models.TestResult:
"""Run one or all model tests"""
test_output = io.StringIO()
try:
result = context.test(
tests=[str(context.path / test)] if test else None, verbose=verbose, stream=test_output
)
except Exception:
raise ApiException(
message="Unable to run tests",
origin="API -> commands -> test",
)
context.console.log_test_results(
result, test_output.getvalue(), context._test_engine_adapter.dialect
)
return models.TestResult(
errors=[
models.TestErrorOrFailure(
name=test.test_name,
path=test.path_relative_to(context.path),
tb=tb,
)
for test, tb in ((t.cast(ModelTest, test), tb) for test, tb in result.errors)
],
failures=[
models.TestErrorOrFailure(
name=test.test_name,
path=test.path_relative_to(context.path),
tb=tb,
)
for test, tb in ((t.cast(ModelTest, test), tb) for test, tb in result.failures)
],
skipped=[
models.TestSkipped(
name=test.test_name,
path=test.path_relative_to(context.path),
reason=reason,
)
for test, reason in (
(t.cast(ModelTest, test), reason) for test, reason in result.skipped
)
],
tests_run=result.testsRun,
)
20 changes: 20 additions & 0 deletions web/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,23 @@ class TableDiff(BaseModel):
schema_diff: SchemaDiff
row_diff: RowDiff
on: t.List[t.Tuple[str, str]]


class TestCase(BaseModel):
name: str
path: pathlib.Path


class TestErrorOrFailure(TestCase):
tb: str


class TestSkipped(TestCase):
reason: str


class TestResult(BaseModel):
tests_run: int
failures: t.List[TestErrorOrFailure]
errors: t.List[TestErrorOrFailure]
skipped: t.List[TestSkipped]