Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
341 changes: 185 additions & 156 deletions sqlmesh/core/selector.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import fnmatch
import logging
import typing as t
from collections import defaultdict
from pathlib import Path

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import seq_get

from sqlmesh.core.dialect import normalize_model_name
from sqlmesh.core.environment import Environment
Expand All @@ -15,10 +17,9 @@
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.git import GitClient

logger = logging.getLogger(__name__)


if t.TYPE_CHECKING:
from typing_extensions import Literal as Lit # noqa
from sqlmesh.core.model import Model
from sqlmesh.core.state_sync import StateReader

Expand Down Expand Up @@ -162,163 +163,191 @@ def expand_model_selections(
Returns:
A set of model names.
"""
results: t.Set[str] = set()
models = models or self._models
models_by_tags: t.Optional[t.Dict[str, t.Set[str]]] = None

for selection in model_selections:
sub_results: t.Optional[t.Set[str]] = None

def add_sub_results(sr: t.Set[str]) -> None:
nonlocal sub_results
if sub_results is None:
sub_results = sr
else:
sub_results &= sr

sub_selections = [s.strip() for s in selection.split("&")]
for sub_selection in sub_selections:
if not sub_selection:
continue

if sub_selection.startswith("tag:"):
if models_by_tags is None:
models_by_tag = defaultdict(set)
for model in models.values():
for tag in model.tags:
models_by_tag[tag.lower()].add(model.fqn)
add_sub_results(
self._expand_model_tag(sub_selection[4:], models, models_by_tag)
)
elif sub_selection.startswith(("git:", "+git:")):
sub_selection = sub_selection.replace("git:", "")
add_sub_results(self._expand_git(sub_selection, models))
else:
add_sub_results(self._expand_model_name(sub_selection, models))

if sub_results:
results.update(sub_results)
else:
logger.warning(f"Expression '{selection}' doesn't match any models.")

return results

def _expand_git(self, target_branch: str, models: t.Dict[str, Model]) -> t.Set[str]:
results: t.Set[str] = set()

(
target_branch,
include_upstream,
include_downstream,
) = self._get_value_and_dependency_inclusion(target_branch)

git_modified_files = {
*self._git_client.list_untracked_files(),
*self._git_client.list_uncommitted_changed_files(),
*self._git_client.list_committed_changed_files(target_branch=target_branch),
}
matched_models = {m.fqn for m in self._models.values() if m._path in git_modified_files}

if not matched_models:
logger.warning(f"Expression 'git:{target_branch}' doesn't match any models.")
return matched_models

for model_fqn in matched_models:
results.update(
self._get_models(model_fqn, include_upstream, include_downstream, models)
)

return results

def _expand_model_name(self, selection: str, models: t.Dict[str, Model]) -> t.Set[str]:
results = set()

(
selection,
include_upstream,
include_downstream,
) = self._get_value_and_dependency_inclusion(selection)
node = parse(" | ".join(f"({s})" for s in model_selections))

matched_models = set()

if "*" in selection:
for model in models.values():
if fnmatch.fnmatchcase(model.name, selection):
matched_models.add(model.fqn)
models = models or self._models
models_by_tags: t.Dict[str, t.Set[str]] = {}

for fqn, model in models.items():
for tag in model.tags:
tag = tag.lower()
models_by_tags.setdefault(tag, set())
models_by_tags[tag].add(model.fqn)

def evaluate(node: exp.Expression) -> t.Set[str]:
if isinstance(node, exp.Var):
pattern = node.this
if "*" in pattern:
return {
fqn
for fqn, model in models.items()
if fnmatch.fnmatchcase(model.name, node.this)
}
fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
return {fqn} if fqn in models else set()
if isinstance(node, exp.And):
return evaluate(node.left) & evaluate(node.right)
if isinstance(node, exp.Or):
return evaluate(node.left) | evaluate(node.right)
if isinstance(node, exp.Paren):
return evaluate(node.this)
if isinstance(node, exp.Not):
return set(models) - evaluate(node.this)
if isinstance(node, Git):
target_branch = node.name
git_modified_files = {
*self._git_client.list_untracked_files(),
*self._git_client.list_uncommitted_changed_files(),
*self._git_client.list_committed_changed_files(target_branch=target_branch),
}
return {m.fqn for m in self._models.values() if m._path in git_modified_files}
if isinstance(node, Tag):
pattern = node.name.lower()

if "*" in pattern:
return {
model
for tag, models in models_by_tags.items()
for model in models
if fnmatch.fnmatchcase(tag, pattern)
}
return models_by_tags.get(pattern, set())
if isinstance(node, Direction):
selected = set()

for model_name in evaluate(node.this):
selected.add(model_name)
if node.args.get("up"):
for u in self._dag.upstream(model_name):
if u in models:
selected.add(u)
if node.args.get("down"):
selected.update(self._dag.downstream(model_name))
return selected
raise ParseError(f"Unexpected node {node}")

return evaluate(node)


class SelectorTokenizer(Tokenizer):
SINGLE_TOKENS = {
"(": TokenType.L_PAREN,
")": TokenType.R_PAREN,
"&": TokenType.AMP,
"|": TokenType.PIPE,
"^": TokenType.CARET,
"+": TokenType.PLUS,
"*": TokenType.STAR,
":": TokenType.COLON,
}

KEYWORDS = {}
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = []


class Git(exp.Expression):
pass


class Tag(exp.Expression):
pass


class Direction(exp.Expression):
pass


def parse(selector: str, dialect: DialectType = None) -> exp.Expression:
tokens = SelectorTokenizer().tokenize(selector)
i = 0

def _curr() -> t.Optional[Token]:
return seq_get(tokens, i)

def _prev() -> Token:
return tokens[i - 1]

def _advance(num: int = 1) -> Token:
nonlocal i
i += num
return _prev()

def _next() -> t.Optional[Token]:
return seq_get(tokens, i + 1)

def _error(msg: str) -> str:
return f"{msg} at index {i}: {selector}"

def _match(token_type: TokenType, raise_unmatched: bool = False) -> t.Optional[Token]:
token = _curr()
if token and token.token_type == token_type:
return _advance()
if raise_unmatched:
raise ParseError(_error(f"Expected {token_type}"))
return None

def _parse_kind(kind: str) -> bool:
token = _curr()
next_token = _next()

if (
token
and token.token_type == TokenType.VAR
and token.text.lower() == kind
and next_token
and next_token.token_type == TokenType.COLON
):
_advance(2)
return True
return False

def _parse_var() -> exp.Expression:
upstream = _match(TokenType.PLUS)
tag = _parse_kind("tag")
git = False if tag else _parse_kind("git")
lstar = "*" if _match(TokenType.STAR) else ""
directions = {}

if _match(TokenType.VAR):
name = _prev().text
rstar = "*" if _match(TokenType.STAR) else ""
downstream = _match(TokenType.PLUS)
this: exp.Expression = exp.Var(this=f"{lstar}{name}{rstar}")

if upstream:
directions["up"] = True
if downstream:
directions["down"] = True
elif _match(TokenType.L_PAREN):
this = exp.Paren(this=_parse_conjunction())
_match(TokenType.R_PAREN, True)
elif lstar:
this = exp.var("*")
else:
model_fqn = normalize_model_name(selection, self._default_catalog, self._dialect)
if model_fqn in models:
matched_models.add(model_fqn)
raise ParseError(_error("Expected model name."))

if not matched_models:
logger.warning(f"Expression '{selection}' doesn't match any models.")
if tag:
this = Tag(this=this)
if git:
this = Git(this=this)
if directions:
this = Direction(this=this, **directions)
return this

for model_fqn in matched_models:
results.update(
self._get_models(model_fqn, include_upstream, include_downstream, models)
)
return results
def _parse_unary() -> exp.Expression:
if _match(TokenType.CARET):
return exp.Not(this=_parse_unary())
return _parse_var()

def _expand_model_tag(
self, tag_selection: str, models: t.Dict[str, Model], models_by_tag: t.Dict[str, t.Set[str]]
) -> t.Set[str]:
"""
Expands a set of model tags into a set of model names.
The tag matching is case-insensitive and supports wildcards and + prefix and suffix to
include upstream and downstream models.
def _parse_conjunction() -> exp.Expression:
this = _parse_unary()

Args:
tag_selection: A tag to match models against.
if _match(TokenType.AMP):
this = exp.And(this=this, expression=_parse_unary())
if _match(TokenType.PIPE):
this = exp.Or(this=this, expression=_parse_conjunction())

Returns:
A set of model names.
"""
result = set()
matched_tags = set()
(
selection,
include_upstream,
include_downstream,
) = self._get_value_and_dependency_inclusion(tag_selection.lower())

if "*" in selection:
for model_tag in models_by_tag:
if fnmatch.fnmatchcase(model_tag, selection):
matched_tags.add(model_tag)
elif selection in models_by_tag:
matched_tags.add(selection)

if not matched_tags:
logger.warning(f"Expression 'tag:{tag_selection}' doesn't match any models.")

for tag in matched_tags:
for model in models_by_tag[tag]:
result.update(self._get_models(model, include_upstream, include_downstream, models))

return result

def _get_models(
self,
model_name: str,
include_upstream: bool,
include_downstream: bool,
models: t.Dict[str, Model],
) -> t.Set[str]:
result = {model_name}
if include_upstream:
result.update([u for u in self._dag.upstream(model_name) if u in models])
if include_downstream:
result.update(self._dag.downstream(model_name))
return result

@staticmethod
def _get_value_and_dependency_inclusion(value: str) -> t.Tuple[str, bool, bool]:
include_upstream = False
include_downstream = False
if value[0] == "+":
value = value[1:]
include_upstream = True
if value[-1] == "+":
value = value[:-1]
include_downstream = True
return value, include_upstream, include_downstream
return this

return _parse_conjunction()
Loading