Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,8 @@ module = [
"json_stream.*"
]
ignore_missing_imports = true

[tool.ruff.lint]
select = [
"RET505",
]
2 changes: 1 addition & 1 deletion sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def cli(
if ctx.invoked_subcommand in SKIP_CONTEXT_COMMANDS:
ctx.obj = path
return
elif ctx.invoked_subcommand in SKIP_LOAD_COMMANDS:
if ctx.invoked_subcommand in SKIP_LOAD_COMMANDS:
load = False

configs = load_configs(config, Context.CONFIG_TYPE, paths)
Expand Down
7 changes: 2 additions & 5 deletions sqlmesh/core/audit/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,15 @@ def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.A
return dict([_maybe_parse_arg_pair(v.unnest())])
if isinstance(v, (exp.Tuple, exp.Array)):
return dict(map(_maybe_parse_arg_pair, v.expressions))
elif isinstance(v, dict):
if isinstance(v, dict):
dialect = get_dialect(values)
return {
key: value
if isinstance(value, exp.Expression)
else d.parse_one(str(value), dialect=dialect)
for key, value in v.items()
}
else:
raise_config_error(
"Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError
)
raise_config_error("Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError)
return {}


Expand Down
9 changes: 3 additions & 6 deletions sqlmesh/core/config/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,9 @@ def get_gateway(self, name: t.Optional[str] = None) -> GatewayConfig:
raise ConfigError(f"Missing gateway with name '{name}'.")

return self.gateways[name]
else:
if name is not None:
raise ConfigError(
"Gateway name is not supported when only one gateway is configured."
)
return self.gateways
if name is not None:
raise ConfigError("Gateway name is not supported when only one gateway is configured.")
return self.gateways

def get_connection(self, gateway_name: t.Optional[str] = None) -> ConnectionConfig:
return self.get_gateway(gateway_name).connection or self.default_connection
Expand Down
49 changes: 24 additions & 25 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,33 +375,32 @@ def replace_query(
column_descriptions=column_descriptions,
**kwargs,
)
else:
if self_referencing:
with self.temp_table(
self._select_columns(columns_to_types).from_(target_table),
name=target_table,
columns_to_types=columns_to_types,
**kwargs,
) as temp_table:
for source_query in source_queries:
source_query.add_transform(
lambda node: ( # type: ignore
temp_table # type: ignore
if isinstance(node, exp.Table)
and quote_identifiers(node) == quote_identifiers(target_table)
else node
)
if self_referencing:
with self.temp_table(
self._select_columns(columns_to_types).from_(target_table),
name=target_table,
columns_to_types=columns_to_types,
**kwargs,
) as temp_table:
for source_query in source_queries:
source_query.add_transform(
lambda node: ( # type: ignore
temp_table # type: ignore
if isinstance(node, exp.Table)
and quote_identifiers(node) == quote_identifiers(target_table)
else node
)
return self._insert_overwrite_by_condition(
target_table,
source_queries,
columns_to_types,
)
return self._insert_overwrite_by_condition(
target_table,
source_queries,
columns_to_types,
)
return self._insert_overwrite_by_condition(
target_table,
source_queries,
columns_to_types,
)
return self._insert_overwrite_by_condition(
target_table,
source_queries,
columns_to_types,
)

def create_index(
self,
Expand Down
11 changes: 7 additions & 4 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,12 @@ def insert_overwrite_by_partition(
raise SQLMeshError(
f"The partition expression '{partition_sql}' doesn't contain a column."
)
with self.session({}), self.temp_table(
query_or_df, name=table_name, partitioned_by=partitioned_by
) as temp_table_name:
with (
self.session({}),
self.temp_table(
query_or_df, name=table_name, partitioned_by=partitioned_by
) as temp_table_name,
):
if columns_to_types is None or columns_to_types[
partition_column.name
] == exp.DataType.build("unknown"):
Expand Down Expand Up @@ -1158,7 +1161,7 @@ def _is_retryable(self, error: BaseException) -> bool:

if isinstance(error, self.retryable_errors):
return True
elif isinstance(error, Forbidden) and any(
if isinstance(error, Forbidden) and any(
e["reason"] == "rateLimitExceeded" for e in error.errors
):
return True
Expand Down
6 changes: 3 additions & 3 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def evaluate_macros(
self.parse_one(node.sql(dialect=self.dialect, copy=False))
for node in transformed
]
elif isinstance(transformed, exp.Expression):
if isinstance(transformed, exp.Expression):
return self.parse_one(transformed.sql(dialect=self.dialect, copy=False))

return transformed
Expand Down Expand Up @@ -1265,7 +1265,7 @@ def resolve_template(
if mode.lower() == "table":
return exp.to_table(result, dialect=evaluator.dialect)
return exp.Literal.string(result)
elif evaluator.runtime_stage != RuntimeStage.LOADING.value:
if evaluator.runtime_stage != RuntimeStage.LOADING.value:
# only error if we are CREATING, EVALUATING or TESTING and @this_model is not present; this could indicate a bug
# otherwise, for LOADING, it's a no-op
raise SQLMeshError(
Expand Down Expand Up @@ -1391,7 +1391,7 @@ def _coerce(
return tuple(expr.expressions)
if generic[-1] is ...:
return tuple(_coerce(expr, generic[0], dialect, path) for expr in expr.expressions)
elif len(generic) == len(expr.expressions):
if len(generic) == len(expr.expressions):
return tuple(
_coerce(expr, generic[i], dialect, path)
for i, expr in enumerate(expr.expressions)
Expand Down
27 changes: 13 additions & 14 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2159,19 +2159,18 @@ def load_sql_based_model(
time_column_format=time_column_format,
**common_kwargs,
)
else:
seed_properties = {
p.name.lower(): p.args.get("value") for p in common_kwargs.pop("kind").expressions
}
try:
return create_seed_model(
name,
SeedKind(**seed_properties),
**common_kwargs,
)
except Exception as ex:
raise_config_error(str(ex), path)
raise
seed_properties = {
p.name.lower(): p.args.get("value") for p in common_kwargs.pop("kind").expressions
}
try:
return create_seed_model(
name,
SeedKind(**seed_properties),
**common_kwargs,
)
except Exception as ex:
raise_config_error(str(ex), path)
raise


def create_sql_model(
Expand Down Expand Up @@ -2565,7 +2564,7 @@ def _split_sql_model_statements(
if not query_positions:
return None, sql_statements, [], on_virtual_update, inline_audits

elif len(query_positions) > 1:
if len(query_positions) > 1:
raise_config_error("Only one SELECT query is allowed per model", path)

query, pos = query_positions[0]
Expand Down
3 changes: 1 addition & 2 deletions sqlmesh/core/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,7 @@ def _expand(node: exp.Expression) -> exp.Expression:
alias=node.alias or model.view_name,
copy=False,
)
else:
logger.warning("Failed to expand the nested model '%s'", name)
logger.warning("Failed to expand the nested model '%s'", name)
return node

expression = expression.transform(_expand, copy=False) # type: ignore
Expand Down
24 changes: 10 additions & 14 deletions sqlmesh/core/schema_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,13 @@ def from_struct_kwarg(cls, struct: exp.ColumnDef) -> TableAlterColumn:

if kwarg_type.is_type(exp.DataType.Type.STRUCT):
return cls.struct(name, quoted=quoted)
elif kwarg_type.is_type(exp.DataType.Type.ARRAY):
if kwarg_type.is_type(exp.DataType.Type.ARRAY):
if kwarg_type.expressions and kwarg_type.expressions[0].is_type(
exp.DataType.Type.STRUCT
):
return cls.array_of_struct(name, quoted=quoted)
else:
return cls.array_of_primitive(name, quoted=quoted)
else:
return cls.primitive(name, quoted=quoted)
return cls.array_of_primitive(name, quoted=quoted)
return cls.primitive(name, quoted=quoted)

@property
def is_array(self) -> bool:
Expand Down Expand Up @@ -268,22 +266,21 @@ def expression(
)
],
)
elif self.is_add:
if self.is_add:
alter_table = exp.Alter(this=exp.to_table(table_name), kind="TABLE")
column = self.column_def(array_element_selector)
alter_table.set("actions", [column])
if self.add_position:
column.set("position", self.add_position.column_position_node)
return alter_table
elif self.is_drop:
if self.is_drop:
alter_table = exp.Alter(this=exp.to_table(table_name), kind="TABLE")
drop_column = exp.Drop(
this=self.column(array_element_selector), kind="COLUMN", cascade=self.cascade
)
alter_table.set("actions", [drop_column])
return alter_table
else:
raise ValueError(f"Unknown operation {self.op}")
raise ValueError(f"Unknown operation {self.op}")


class SchemaDiffer(PydanticModel):
Expand Down Expand Up @@ -593,7 +590,7 @@ def _alter_operation(
)
if self._is_coerceable_type(current_type, new_type):
return []
elif self._is_compatible_type(current_type, new_type):
if self._is_compatible_type(current_type, new_type):
struct.expressions.pop(pos)
struct.expressions.insert(pos, new_kwarg)
col_pos = (
Expand All @@ -610,10 +607,9 @@ def _alter_operation(
col_pos,
)
]
else:
return self._drop_operation(
columns, root_struct, pos, root_struct
) + self._add_operation(columns, pos, new_kwarg, struct, root_struct)
return self._drop_operation(columns, root_struct, pos, root_struct) + self._add_operation(
columns, pos, new_kwarg, struct, root_struct
)

def _resolve_alter_operations(
self,
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ def model_kind_name(self) -> t.Optional[ModelKindName]:
def node_type(self) -> NodeType:
if self.node.is_model:
return NodeType.MODEL
elif self.node.is_audit:
if self.node.is_audit:
return NodeType.AUDIT
raise SQLMeshError(f"Snapshot {self.snapshot_id} has an unknown node type.")

Expand Down
17 changes: 10 additions & 7 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,9 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
physical_properties=rendered_physical_properties,
)

with adapter.transaction(), adapter.session(
snapshot.model.render_session_properties(**render_statements_kwargs)
with (
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)),
):
wap_id: t.Optional[str] = None
if (
Expand Down Expand Up @@ -724,7 +725,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
# workaround for that would be to serialize pandas to disk and then read it back with Spark.
# Note: We assume that if multiple things are yielded from `queries_or_dfs` that they are dataframes
# and not SQL expressions.
elif (
if (
adapter.INSERT_OVERWRITE_STRATEGY
in (
InsertOverwriteStrategy.INSERT_OVERWRITE,
Expand Down Expand Up @@ -772,8 +773,9 @@ def _create_snapshot(
deployability_index=deployability_index,
)

with adapter.transaction(), adapter.session(
snapshot.model.render_session_properties(**create_render_kwargs)
with (
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**create_render_kwargs)),
):
rendered_physical_properties = snapshot.model.render_physical_properties(
**create_render_kwargs
Expand Down Expand Up @@ -892,8 +894,9 @@ def _migrate_snapshot(
runtime_stage=RuntimeStage.CREATING,
deployability_index=deployability_index,
)
with adapter.transaction(), adapter.session(
snapshot.model.render_session_properties(**render_kwargs)
with (
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
):
self._execute_create(
snapshot=snapshot,
Expand Down
3 changes: 1 addition & 2 deletions sqlmesh/dbt/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,7 @@ def _refs(node: ManifestNode) -> t.Set[str]:
ref_name = f"{ref_name}_v{r.version}"
result.add(ref_name)
return result
else:
return {".".join(r) for r in node.refs} # type: ignore
return {".".join(r) for r in node.refs} # type: ignore


def _sources(node: ManifestNode) -> t.Set[str]:
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/dbt/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _find_profile(cls, project_root: Path) -> t.Optional[Path]:
path = Path(project_root, dir, cls.PROFILE_FILE)
if path.exists():
return path
elif dir:
if dir:
return None

path = Path(Path.home(), ".dbt", cls.PROFILE_FILE)
Expand Down
21 changes: 10 additions & 11 deletions sqlmesh/dbt/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,23 @@ def load(cls, data: t.Dict[str, t.Any]) -> TargetConfig:
db_type = data["type"]
if db_type == "databricks":
return DatabricksConfig(**data)
elif db_type == "duckdb":
if db_type == "duckdb":
return DuckDbConfig(**data)
elif db_type == "postgres":
if db_type == "postgres":
return PostgresConfig(**data)
elif db_type == "redshift":
if db_type == "redshift":
return RedshiftConfig(**data)
elif db_type == "snowflake":
if db_type == "snowflake":
return SnowflakeConfig(**data)
elif db_type == "bigquery":
if db_type == "bigquery":
return BigQueryConfig(**data)
elif db_type == "sqlserver":
if db_type == "sqlserver":
return MSSQLConfig(**data)
elif db_type == "trino":
if db_type == "trino":
return TrinoConfig(**data)
elif db_type == "clickhouse":
if db_type == "clickhouse":
return ClickhouseConfig(**data)
elif db_type == "athena":
if db_type == "athena":
return AthenaConfig(**data)

raise ConfigError(f"{db_type} not supported.")
Expand Down Expand Up @@ -424,8 +424,7 @@ def column_class(cls) -> t.Type[Column]:
from dbt.adapters.redshift import RedshiftColumn # type: ignore

return RedshiftColumn
else:
return super(RedshiftConfig, cls).column_class
return super(RedshiftConfig, cls).column_class

def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig:
return RedshiftConnectionConfig(
Expand Down
Loading