Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"requests",
"rich",
"ruamel.yaml",
"sqlglot==11.1.2",
"sqlglot>=11.2.0",
],
extras_require={
"dev": [
Expand Down
10 changes: 8 additions & 2 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,14 @@ def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression
if i < total - 1:
chunks.append(([], False))
else:
if token.token_type == TokenType.BLOCK_START or (
token.token_type == TokenType.STRING and JINJA_PATTERN.search(token.text)
if (
token.token_type == TokenType.BLOCK_START
or (
i < total - 1
and token.token_type == TokenType.L_BRACE
and tokens[i + 1].token_type == TokenType.L_BRACE
)
or (token.token_type == TokenType.STRING and JINJA_PATTERN.search(token.text))
):
chunks[-1] = (chunks[-1][0], True)
chunks[-1][0].append(token)
Expand Down
5 changes: 3 additions & 2 deletions sqlmesh/core/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlglot.optimizer import optimize
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.simplify import simplify
Expand All @@ -30,6 +31,8 @@
qualify_tables,
qualify_columns,
expand_laterals,
pushdown_projections,
annotate_types,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to do pushdown projections here instead

and keep annotate types outside in case it doesn’t work right

Copy link
Contributor Author

@georgesittas georgesittas Feb 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Annotate types requires at least qualified columns to work properly, because it needs col.table not to be empty. I don't think it'll crash, since it just skips columns for which this isn't true:

for col in scope.columns:  # line 282 in annotate_types.py
    if not col.table:
        continue
...

However, if optimize does fail for some reason and we end up not having qualified columns, I'm not sure if it's worth running just annotate_types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added pushdown projections.

)


Expand Down Expand Up @@ -151,8 +154,6 @@ def render(
except SqlglotError as ex:
raise_config_error(f"Invalid model query. {ex}", self._path)

self._query_cache[cache_key] = annotate_types(self._query_cache[cache_key])

query = self._query_cache[cache_key]

if expand:
Expand Down
38 changes: 33 additions & 5 deletions sqlmesh/utils/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,34 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
A dictionary of macro name to macro definition.
"""
self.reset()
self.sql = jinja or ""
self.sql = jinja
self._tokens = Dialect.get_or_raise(dialect)().tokenizer.tokenize(jinja)
self._index = -1
self._advance()

macros: t.Dict[str, MacroInfo] = {}

while self._curr:
if self._curr.token_type == TokenType.BLOCK_START:
if self._at_block_start():
if self._prev and self._prev.token_type == TokenType.L_BRACE:
self._advance()
macro_start = self._curr
elif self._tag == "MACRO" and self._next:
name = self._next.text
while self._curr and self._curr.token_type != TokenType.BLOCK_END:
while self._curr and not self._at_block_end():
self._advance()
else:
if self._prev and self._prev.token_type == TokenType.R_BRACE:
self._advance()

body_start = self._next

while self._curr and self._tag != "ENDMACRO":
if self._curr.token_type == TokenType.BLOCK_START:
if self._at_block_start():
body_end = self._prev
if self._prev and self._prev.token_type == TokenType.L_BRACE:
self._advance()

self._advance()

calls = capture_jinja(self._find_sql(body_start, body_end)).calls
Expand All @@ -55,11 +64,30 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:

return macros

def _at_block_start(self) -> bool:
return self._curr.token_type == TokenType.BLOCK_START or self._match_pair(
TokenType.L_BRACE, TokenType.L_BRACE, advance=False
)

def _at_block_end(self) -> bool:
return self._curr.token_type == TokenType.BLOCK_END or self._match_pair(
TokenType.R_BRACE, TokenType.R_BRACE, advance=False
)

def _advance(self, times: int = 1) -> None:
super()._advance(times)
self._tag = (
self._curr.text.upper()
if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START
if self._curr
and self._prev
and (
self._prev.token_type == TokenType.BLOCK_START
or (
self._index > 1
and self._tokens[self._index - 1].token_type == TokenType.L_BRACE
and self._tokens[self._index - 2].token_type == TokenType.L_BRACE
)
)
else ""
)

Expand Down
14 changes: 7 additions & 7 deletions tests/core/test_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_no_query():


def test_macro(model: Model):
expected_query = "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE a IS NULL"
expected_query = "SELECT * FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE _q_0.a IS NULL"

audit = Audit(
name="test_audit",
Expand All @@ -163,7 +163,7 @@ def test_not_null_audit(model: Model):
)
assert (
rendered_query_a.sql()
== "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE a IS NULL"
== "SELECT * FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE _q_0.a IS NULL"
)

rendered_query_a_and_b = builtin.not_null_audit.render_query(
Expand All @@ -172,23 +172,23 @@ def test_not_null_audit(model: Model):
)
assert (
rendered_query_a_and_b.sql()
== "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE a IS NULL OR b IS NULL"
== "SELECT * FROM (SELECT test_model.a AS a, test_model.b AS b FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE _q_0.a IS NULL OR _q_0.b IS NULL"
)


def test_unique_values_audit(model: Model):
rendered_query_a = builtin.unique_values_audit.render_query(model, columns=[exp.to_column("a")])
assert (
rendered_query_a.sql()
== "SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY 1) AS a_rank FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01')) WHERE a_rank > 1"
== "SELECT _q_1.a_rank AS a_rank FROM (SELECT ROW_NUMBER() OVER (PARTITION BY _q_0.a ORDER BY 1) AS a_rank FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0) AS _q_1 WHERE _q_1.a_rank > 1"
)

rendered_query_a_and_b = builtin.unique_values_audit.render_query(
model, columns=[exp.to_column("a"), exp.to_column("b")]
)
assert (
rendered_query_a_and_b.sql()
== "SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY 1) AS a_rank, ROW_NUMBER() OVER (PARTITION BY b ORDER BY 1) AS b_rank FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01')) WHERE a_rank > 1 OR b_rank > 1"
== "SELECT _q_1.a_rank AS a_rank, _q_1.b_rank AS b_rank FROM (SELECT ROW_NUMBER() OVER (PARTITION BY _q_0.a ORDER BY 1) AS a_rank, ROW_NUMBER() OVER (PARTITION BY _q_0.b ORDER BY 1) AS b_rank FROM (SELECT test_model.a AS a, test_model.b AS b FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0) AS _q_1 WHERE _q_1.a_rank > 1 OR _q_1.b_rank > 1"
)


Expand All @@ -200,7 +200,7 @@ def test_accepted_values_audit(model: Model):
)
assert (
rendered_query.sql()
== "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE NOT a IN ('value_a', 'value_b')"
== "SELECT * FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE NOT _q_0.a IN ('value_a', 'value_b')"
)


Expand All @@ -211,5 +211,5 @@ def test_number_of_rows_audit(model: Model):
)
assert (
rendered_query.sql()
== "SELECT 1 FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') HAVING COUNT(*) <= 0 LIMIT 0 + 1"
== """SELECT 1 AS "1" FROM (SELECT 1 AS _ FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 HAVING COUNT(*) <= 0 LIMIT 0 + 1"""
)