diff --git a/setup.py b/setup.py index 0d39b7b183..9e2aa83f3e 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ "requests", "rich", "ruamel.yaml", - "sqlglot==11.1.2", + "sqlglot>=11.2.0", ], extras_require={ "dev": [ diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index db9357a035..14eadc7aab 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -408,8 +408,14 @@ def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression if i < total - 1: chunks.append(([], False)) else: - if token.token_type == TokenType.BLOCK_START or ( - token.token_type == TokenType.STRING and JINJA_PATTERN.search(token.text) + if ( + token.token_type == TokenType.BLOCK_START + or ( + i < total - 1 + and token.token_type == TokenType.L_BRACE + and tokens[i + 1].token_type == TokenType.L_BRACE + ) + or (token.token_type == TokenType.STRING and JINJA_PATTERN.search(token.text)) ): chunks[-1] = (chunks[-1][0], True) chunks[-1][0].append(token) diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 5d1425f299..70a845c224 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -10,6 +10,7 @@ from sqlglot.optimizer import optimize from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.expand_laterals import expand_laterals +from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.simplify import simplify @@ -30,6 +31,8 @@ qualify_tables, qualify_columns, expand_laterals, + pushdown_projections, + annotate_types, ) @@ -151,8 +154,6 @@ def render( except SqlglotError as ex: raise_config_error(f"Invalid model query. {ex}", self._path) - self._query_cache[cache_key] = annotate_types(self._query_cache[cache_key]) - query = self._query_cache[cache_key] if expand: diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index 5e519a7174..c71261c272 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -27,7 +27,7 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]: A dictionary of macro name to macro definition. """ self.reset() - self.sql = jinja or "" + self.sql = jinja self._tokens = Dialect.get_or_raise(dialect)().tokenizer.tokenize(jinja) self._index = -1 self._advance() @@ -35,17 +35,26 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]: macros: t.Dict[str, MacroInfo] = {} while self._curr: - if self._curr.token_type == TokenType.BLOCK_START: + if self._at_block_start(): + if self._prev and self._prev.token_type == TokenType.L_BRACE: + self._advance() macro_start = self._curr elif self._tag == "MACRO" and self._next: name = self._next.text - while self._curr and self._curr.token_type != TokenType.BLOCK_END: + while self._curr and not self._at_block_end(): self._advance() + else: + if self._prev and self._prev.token_type == TokenType.R_BRACE: + self._advance() + body_start = self._next while self._curr and self._tag != "ENDMACRO": - if self._curr.token_type == TokenType.BLOCK_START: + if self._at_block_start(): body_end = self._prev + if self._prev and self._prev.token_type == TokenType.L_BRACE: + self._advance() + self._advance() calls = capture_jinja(self._find_sql(body_start, body_end)).calls @@ -55,11 +64,30 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]: return macros + def _at_block_start(self) -> bool: + return self._curr.token_type == TokenType.BLOCK_START or self._match_pair( + TokenType.L_BRACE, TokenType.L_BRACE, advance=False + ) + + def _at_block_end(self) -> bool: + return self._curr.token_type == TokenType.BLOCK_END or self._match_pair( + TokenType.R_BRACE, TokenType.R_BRACE, advance=False + ) + def _advance(self, times: int = 1) -> None: super()._advance(times) self._tag = ( self._curr.text.upper() - if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START + if self._curr + and self._prev + and ( + self._prev.token_type == TokenType.BLOCK_START + or ( + self._index > 1 + and self._tokens[self._index - 1].token_type == TokenType.L_BRACE + and self._tokens[self._index - 2].token_type == TokenType.L_BRACE + ) + ) else "" ) diff --git a/tests/core/test_audit.py b/tests/core/test_audit.py index 5c8afe1e74..116bf536df 100644 --- a/tests/core/test_audit.py +++ b/tests/core/test_audit.py @@ -140,7 +140,7 @@ def test_no_query(): def test_macro(model: Model): - expected_query = "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE a IS NULL" + expected_query = "SELECT * FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE _q_0.a IS NULL" audit = Audit( name="test_audit", @@ -163,7 +163,7 @@ def test_not_null_audit(model: Model): ) assert ( rendered_query_a.sql() - == "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE a IS NULL" + == "SELECT * FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE _q_0.a IS NULL" ) rendered_query_a_and_b = builtin.not_null_audit.render_query( @@ -172,7 +172,7 @@ def test_not_null_audit(model: Model): ) assert ( rendered_query_a_and_b.sql() - == "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE a IS NULL OR b IS NULL" + == "SELECT * FROM (SELECT test_model.a AS a, test_model.b AS b FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE _q_0.a IS NULL OR _q_0.b IS NULL" ) @@ -180,7 +180,7 @@ def test_unique_values_audit(model: Model): rendered_query_a = builtin.unique_values_audit.render_query(model, columns=[exp.to_column("a")]) assert ( rendered_query_a.sql() - == "SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY 1) AS a_rank FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01')) WHERE a_rank > 1" + == "SELECT _q_1.a_rank AS a_rank FROM (SELECT ROW_NUMBER() OVER (PARTITION BY _q_0.a ORDER BY 1) AS a_rank FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0) AS _q_1 WHERE _q_1.a_rank > 1" ) rendered_query_a_and_b = builtin.unique_values_audit.render_query( @@ -188,7 +188,7 @@ def test_unique_values_audit(model: Model): ) assert ( rendered_query_a_and_b.sql() - == "SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY 1) AS a_rank, ROW_NUMBER() OVER (PARTITION BY b ORDER BY 1) AS b_rank FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01')) WHERE a_rank > 1 OR b_rank > 1" + == "SELECT _q_1.a_rank AS a_rank, _q_1.b_rank AS b_rank FROM (SELECT ROW_NUMBER() OVER (PARTITION BY _q_0.a ORDER BY 1) AS a_rank, ROW_NUMBER() OVER (PARTITION BY _q_0.b ORDER BY 1) AS b_rank FROM (SELECT test_model.a AS a, test_model.b AS b FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0) AS _q_1 WHERE _q_1.a_rank > 1 OR _q_1.b_rank > 1" ) @@ -200,7 +200,7 @@ def test_accepted_values_audit(model: Model): ) assert ( rendered_query.sql() - == "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE NOT a IN ('value_a', 'value_b')" + == "SELECT * FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE NOT _q_0.a IN ('value_a', 'value_b')" ) @@ -211,5 +211,5 @@ def test_number_of_rows_audit(model: Model): ) assert ( rendered_query.sql() - == "SELECT 1 FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') HAVING COUNT(*) <= 0 LIMIT 0 + 1" + == """SELECT 1 AS "1" FROM (SELECT 1 AS _ FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 HAVING COUNT(*) <= 0 LIMIT 0 + 1""" )