diff --git a/pyproject.toml b/pyproject.toml index f0ab0ca892..2f57839500 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -208,3 +208,8 @@ module = [ "json_stream.*" ] ignore_missing_imports = true + +[tool.ruff.lint] +select = [ + "RET505", +] diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index fd3e592fb5..38b3cb395c 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -96,7 +96,7 @@ def cli( if ctx.invoked_subcommand in SKIP_CONTEXT_COMMANDS: ctx.obj = path return - elif ctx.invoked_subcommand in SKIP_LOAD_COMMANDS: + if ctx.invoked_subcommand in SKIP_LOAD_COMMANDS: load = False configs = load_configs(config, Context.CONFIG_TYPE, paths) diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py index 088e33a9ca..82ab9b294e 100644 --- a/sqlmesh/core/audit/definition.py +++ b/sqlmesh/core/audit/definition.py @@ -95,7 +95,7 @@ def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.A return dict([_maybe_parse_arg_pair(v.unnest())]) if isinstance(v, (exp.Tuple, exp.Array)): return dict(map(_maybe_parse_arg_pair, v.expressions)) - elif isinstance(v, dict): + if isinstance(v, dict): dialect = get_dialect(values) return { key: value @@ -103,10 +103,7 @@ def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.A else d.parse_one(str(value), dialect=dialect) for key, value in v.items() } - else: - raise_config_error( - "Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError - ) + raise_config_error("Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError) return {} diff --git a/sqlmesh/core/config/root.py b/sqlmesh/core/config/root.py index 9b2250ab1a..bdb8815f37 100644 --- a/sqlmesh/core/config/root.py +++ b/sqlmesh/core/config/root.py @@ -255,12 +255,9 @@ def get_gateway(self, name: t.Optional[str] = None) -> GatewayConfig: raise ConfigError(f"Missing gateway with name '{name}'.") return self.gateways[name] - else: - if name is not None: - raise ConfigError( - "Gateway name is not supported when only one gateway is configured." - ) - return self.gateways + if name is not None: + raise ConfigError("Gateway name is not supported when only one gateway is configured.") + return self.gateways def get_connection(self, gateway_name: t.Optional[str] = None) -> ConnectionConfig: return self.get_gateway(gateway_name).connection or self.default_connection diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 8c22580f1c..5a68e10c5c 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -375,33 +375,32 @@ def replace_query( column_descriptions=column_descriptions, **kwargs, ) - else: - if self_referencing: - with self.temp_table( - self._select_columns(columns_to_types).from_(target_table), - name=target_table, - columns_to_types=columns_to_types, - **kwargs, - ) as temp_table: - for source_query in source_queries: - source_query.add_transform( - lambda node: ( # type: ignore - temp_table # type: ignore - if isinstance(node, exp.Table) - and quote_identifiers(node) == quote_identifiers(target_table) - else node - ) + if self_referencing: + with self.temp_table( + self._select_columns(columns_to_types).from_(target_table), + name=target_table, + columns_to_types=columns_to_types, + **kwargs, + ) as temp_table: + for source_query in source_queries: + source_query.add_transform( + lambda node: ( # type: ignore + temp_table # type: ignore + if isinstance(node, exp.Table) + and quote_identifiers(node) == quote_identifiers(target_table) + else node ) - return self._insert_overwrite_by_condition( - target_table, - source_queries, - columns_to_types, ) - return self._insert_overwrite_by_condition( - target_table, - source_queries, - columns_to_types, - ) + return self._insert_overwrite_by_condition( + target_table, + source_queries, + columns_to_types, + ) + return self._insert_overwrite_by_condition( + target_table, + source_queries, + columns_to_types, + ) def create_index( self, diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 57122396e3..12338d4a52 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -602,9 +602,12 @@ def insert_overwrite_by_partition( raise SQLMeshError( f"The partition expression '{partition_sql}' doesn't contain a column." ) - with self.session({}), self.temp_table( - query_or_df, name=table_name, partitioned_by=partitioned_by - ) as temp_table_name: + with ( + self.session({}), + self.temp_table( + query_or_df, name=table_name, partitioned_by=partitioned_by + ) as temp_table_name, + ): if columns_to_types is None or columns_to_types[ partition_column.name ] == exp.DataType.build("unknown"): @@ -1158,7 +1161,7 @@ def _is_retryable(self, error: BaseException) -> bool: if isinstance(error, self.retryable_errors): return True - elif isinstance(error, Forbidden) and any( + if isinstance(error, Forbidden) and any( e["reason"] == "rateLimitExceeded" for e in error.errors ): return True diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 4928bc0620..beca4ea72d 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -281,7 +281,7 @@ def evaluate_macros( self.parse_one(node.sql(dialect=self.dialect, copy=False)) for node in transformed ] - elif isinstance(transformed, exp.Expression): + if isinstance(transformed, exp.Expression): return self.parse_one(transformed.sql(dialect=self.dialect, copy=False)) return transformed @@ -1265,7 +1265,7 @@ def resolve_template( if mode.lower() == "table": return exp.to_table(result, dialect=evaluator.dialect) return exp.Literal.string(result) - elif evaluator.runtime_stage != RuntimeStage.LOADING.value: + if evaluator.runtime_stage != RuntimeStage.LOADING.value: # only error if we are CREATING, EVALUATING or TESTING and @this_model is not present; this could indicate a bug # otherwise, for LOADING, it's a no-op raise SQLMeshError( @@ -1391,7 +1391,7 @@ def _coerce( return tuple(expr.expressions) if generic[-1] is ...: return tuple(_coerce(expr, generic[0], dialect, path) for expr in expr.expressions) - elif len(generic) == len(expr.expressions): + if len(generic) == len(expr.expressions): return tuple( _coerce(expr, generic[i], dialect, path) for i, expr in enumerate(expr.expressions) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index cbe31294c6..5930d8b835 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -2159,19 +2159,18 @@ def load_sql_based_model( time_column_format=time_column_format, **common_kwargs, ) - else: - seed_properties = { - p.name.lower(): p.args.get("value") for p in common_kwargs.pop("kind").expressions - } - try: - return create_seed_model( - name, - SeedKind(**seed_properties), - **common_kwargs, - ) - except Exception as ex: - raise_config_error(str(ex), path) - raise + seed_properties = { + p.name.lower(): p.args.get("value") for p in common_kwargs.pop("kind").expressions + } + try: + return create_seed_model( + name, + SeedKind(**seed_properties), + **common_kwargs, + ) + except Exception as ex: + raise_config_error(str(ex), path) + raise def create_sql_model( @@ -2565,7 +2564,7 @@ def _split_sql_model_statements( if not query_positions: return None, sql_statements, [], on_virtual_update, inline_audits - elif len(query_positions) > 1: + if len(query_positions) > 1: raise_config_error("Only one SELECT query is allowed per model", path) query, pos = query_positions[0] diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 70ab187d75..e7d5e2fe55 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -346,8 +346,7 @@ def _expand(node: exp.Expression) -> exp.Expression: alias=node.alias or model.view_name, copy=False, ) - else: - logger.warning("Failed to expand the nested model '%s'", name) + logger.warning("Failed to expand the nested model '%s'", name) return node expression = expression.transform(_expand, copy=False) # type: ignore diff --git a/sqlmesh/core/schema_diff.py b/sqlmesh/core/schema_diff.py index d1ca567630..d0f45f0d7c 100644 --- a/sqlmesh/core/schema_diff.py +++ b/sqlmesh/core/schema_diff.py @@ -90,15 +90,13 @@ def from_struct_kwarg(cls, struct: exp.ColumnDef) -> TableAlterColumn: if kwarg_type.is_type(exp.DataType.Type.STRUCT): return cls.struct(name, quoted=quoted) - elif kwarg_type.is_type(exp.DataType.Type.ARRAY): + if kwarg_type.is_type(exp.DataType.Type.ARRAY): if kwarg_type.expressions and kwarg_type.expressions[0].is_type( exp.DataType.Type.STRUCT ): return cls.array_of_struct(name, quoted=quoted) - else: - return cls.array_of_primitive(name, quoted=quoted) - else: - return cls.primitive(name, quoted=quoted) + return cls.array_of_primitive(name, quoted=quoted) + return cls.primitive(name, quoted=quoted) @property def is_array(self) -> bool: @@ -268,22 +266,21 @@ def expression( ) ], ) - elif self.is_add: + if self.is_add: alter_table = exp.Alter(this=exp.to_table(table_name), kind="TABLE") column = self.column_def(array_element_selector) alter_table.set("actions", [column]) if self.add_position: column.set("position", self.add_position.column_position_node) return alter_table - elif self.is_drop: + if self.is_drop: alter_table = exp.Alter(this=exp.to_table(table_name), kind="TABLE") drop_column = exp.Drop( this=self.column(array_element_selector), kind="COLUMN", cascade=self.cascade ) alter_table.set("actions", [drop_column]) return alter_table - else: - raise ValueError(f"Unknown operation {self.op}") + raise ValueError(f"Unknown operation {self.op}") class SchemaDiffer(PydanticModel): @@ -593,7 +590,7 @@ def _alter_operation( ) if self._is_coerceable_type(current_type, new_type): return [] - elif self._is_compatible_type(current_type, new_type): + if self._is_compatible_type(current_type, new_type): struct.expressions.pop(pos) struct.expressions.insert(pos, new_kwarg) col_pos = ( @@ -610,10 +607,9 @@ def _alter_operation( col_pos, ) ] - else: - return self._drop_operation( - columns, root_struct, pos, root_struct - ) + self._add_operation(columns, pos, new_kwarg, struct, root_struct) + return self._drop_operation(columns, root_struct, pos, root_struct) + self._add_operation( + columns, pos, new_kwarg, struct, root_struct + ) def _resolve_alter_operations( self, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 56b638cb5a..3acb1527d8 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -1243,7 +1243,7 @@ def model_kind_name(self) -> t.Optional[ModelKindName]: def node_type(self) -> NodeType: if self.node.is_model: return NodeType.MODEL - elif self.node.is_audit: + if self.node.is_audit: return NodeType.AUDIT raise SQLMeshError(f"Snapshot {self.snapshot_id} has an unknown node type.") diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index dd13162710..22246bd875 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -671,8 +671,9 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: physical_properties=rendered_physical_properties, ) - with adapter.transaction(), adapter.session( - snapshot.model.render_session_properties(**render_statements_kwargs) + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)), ): wap_id: t.Optional[str] = None if ( @@ -724,7 +725,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: # workaround for that would be to serialize pandas to disk and then read it back with Spark. # Note: We assume that if multiple things are yielded from `queries_or_dfs` that they are dataframes # and not SQL expressions. - elif ( + if ( adapter.INSERT_OVERWRITE_STRATEGY in ( InsertOverwriteStrategy.INSERT_OVERWRITE, @@ -772,8 +773,9 @@ def _create_snapshot( deployability_index=deployability_index, ) - with adapter.transaction(), adapter.session( - snapshot.model.render_session_properties(**create_render_kwargs) + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**create_render_kwargs)), ): rendered_physical_properties = snapshot.model.render_physical_properties( **create_render_kwargs @@ -892,8 +894,9 @@ def _migrate_snapshot( runtime_stage=RuntimeStage.CREATING, deployability_index=deployability_index, ) - with adapter.transaction(), adapter.session( - snapshot.model.render_session_properties(**render_kwargs) + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**render_kwargs)), ): self._execute_create( snapshot=snapshot, diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 83d3df1321..02387c61ca 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -505,8 +505,7 @@ def _refs(node: ManifestNode) -> t.Set[str]: ref_name = f"{ref_name}_v{r.version}" result.add(ref_name) return result - else: - return {".".join(r) for r in node.refs} # type: ignore + return {".".join(r) for r in node.refs} # type: ignore def _sources(node: ManifestNode) -> t.Set[str]: diff --git a/sqlmesh/dbt/profile.py b/sqlmesh/dbt/profile.py index 1c2ffa8726..72634833a6 100644 --- a/sqlmesh/dbt/profile.py +++ b/sqlmesh/dbt/profile.py @@ -73,7 +73,7 @@ def _find_profile(cls, project_root: Path) -> t.Optional[Path]: path = Path(project_root, dir, cls.PROFILE_FILE) if path.exists(): return path - elif dir: + if dir: return None path = Path(Path.home(), ".dbt", cls.PROFILE_FILE) diff --git a/sqlmesh/dbt/target.py b/sqlmesh/dbt/target.py index 7775888a9e..fe87c14424 100644 --- a/sqlmesh/dbt/target.py +++ b/sqlmesh/dbt/target.py @@ -85,23 +85,23 @@ def load(cls, data: t.Dict[str, t.Any]) -> TargetConfig: db_type = data["type"] if db_type == "databricks": return DatabricksConfig(**data) - elif db_type == "duckdb": + if db_type == "duckdb": return DuckDbConfig(**data) - elif db_type == "postgres": + if db_type == "postgres": return PostgresConfig(**data) - elif db_type == "redshift": + if db_type == "redshift": return RedshiftConfig(**data) - elif db_type == "snowflake": + if db_type == "snowflake": return SnowflakeConfig(**data) - elif db_type == "bigquery": + if db_type == "bigquery": return BigQueryConfig(**data) - elif db_type == "sqlserver": + if db_type == "sqlserver": return MSSQLConfig(**data) - elif db_type == "trino": + if db_type == "trino": return TrinoConfig(**data) - elif db_type == "clickhouse": + if db_type == "clickhouse": return ClickhouseConfig(**data) - elif db_type == "athena": + if db_type == "athena": return AthenaConfig(**data) raise ConfigError(f"{db_type} not supported.") @@ -424,8 +424,7 @@ def column_class(cls) -> t.Type[Column]: from dbt.adapters.redshift import RedshiftColumn # type: ignore return RedshiftColumn - else: - return super(RedshiftConfig, cls).column_class + return super(RedshiftConfig, cls).column_class def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return RedshiftConnectionConfig( diff --git a/sqlmesh/integrations/dlt.py b/sqlmesh/integrations/dlt.py index eb76b7c8f7..7ae0881e47 100644 --- a/sqlmesh/integrations/dlt.py +++ b/sqlmesh/integrations/dlt.py @@ -225,5 +225,4 @@ def get_start_date(load_ids: t.List[str]) -> str: if timestamps: start_timestamp = min(timestamps) - timedelta(days=1) return start_timestamp.strftime("%Y-%m-%d") - else: - return yesterday_ds() + return yesterday_ds() diff --git a/sqlmesh/integrations/github/cicd/command.py b/sqlmesh/integrations/github/cicd/command.py index b66e32fb6c..2bf9967548 100644 --- a/sqlmesh/integrations/github/cicd/command.py +++ b/sqlmesh/integrations/github/cicd/command.py @@ -219,7 +219,7 @@ def _run_all(controller: GithubController) -> None: if command.is_invalid: # Probably a comment unrelated to SQLMesh so we do nothing return - elif command.is_deploy_prod: + if command.is_deploy_prod: has_required_approval = True else: raise CICDBotError(f"Unsupported command: {command}") diff --git a/sqlmesh/utils/__init__.py b/sqlmesh/utils/__init__.py index 51624408cd..e54eca5c11 100644 --- a/sqlmesh/utils/__init__.py +++ b/sqlmesh/utils/__init__.py @@ -177,8 +177,7 @@ def sys_path(*paths: Path) -> t.Iterator[None]: def format_exception(exception: BaseException) -> t.List[str]: if sys.version_info < (3, 10): return traceback.format_exception(type(exception), exception, exception.__traceback__) # type: ignore - else: - return traceback.format_exception(exception) # type: ignore + return traceback.format_exception(exception) # type: ignore def word_characters_only(s: str, replacement_char: str = "_") -> str: diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index 010a5f14ff..3ef64b41d8 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -149,8 +149,7 @@ def __eq__(self, other: t.Any) -> bool: if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) < (2, 6): if isinstance(other, pydantic.BaseModel): return self.dict() == other.dict() - else: - return self.dict() == other + return self.dict() == other return super().__eq__(other) def __hash__(self) -> int: diff --git a/tests/conftest.py b/tests/conftest.py index ba8e8baeab..56c50a5851 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -181,8 +181,7 @@ def validate( assert list(results["event_date"].values()) == expected_dates return results - else: - raise NotImplementedError(f"Unknown model_name: {model_name}") + raise NotImplementedError(f"Unknown model_name: {model_name}") def pytest_collection_modifyitems(items, *args, **kwargs): diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py index 759a2db560..f12e4846da 100644 --- a/tests/core/engine_adapter/integration/__init__.py +++ b/tests/core/engine_adapter/integration/__init__.py @@ -210,7 +210,7 @@ def input_data( batch_end=sys.maxsize, columns_to_types=columns_to_types, ) - elif self.test_type == "pyspark": + if self.test_type == "pyspark": return self.engine_adapter.spark.createDataFrame(data) # type: ignore return self._format_df(data, to_datetime=self.dialect != "trino") @@ -550,7 +550,7 @@ def create_catalog(self, catalog_name: str): def drop_catalog(self, catalog_name: str): if self.dialect == "bigquery": return # bigquery cannot create/drop catalogs - elif self.dialect == "databricks": + if self.dialect == "databricks": self.engine_adapter.execute(f"DROP CATALOG IF EXISTS {catalog_name} CASCADE") else: self.engine_adapter.execute(f'DROP DATABASE IF EXISTS "{catalog_name}"') diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index 0e7cf22e2c..fdbe6dac82 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -901,8 +901,7 @@ def _from_structs( def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: if table_name == current_table_name: return {k: exp.DataType.build(v) for k, v in current_table.items()} - else: - return {k: exp.DataType.build(v) for k, v in target_table.items()} + return {k: exp.DataType.build(v) for k, v in target_table.items()} adapter.columns = table_columns diff --git a/tests/core/engine_adapter/test_clickhouse.py b/tests/core/engine_adapter/test_clickhouse.py index ec4c0992fb..9d63d9400e 100644 --- a/tests/core/engine_adapter/test_clickhouse.py +++ b/tests/core/engine_adapter/test_clickhouse.py @@ -167,10 +167,7 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: return { k: exp.DataType.build(v, dialect=adapter.dialect) for k, v in current_table.items() } - else: - return { - k: exp.DataType.build(v, dialect=adapter.dialect) for k, v in target_table.items() - } + return {k: exp.DataType.build(v, dialect=adapter.dialect) for k, v in target_table.items()} adapter.columns = table_columns # type: ignore diff --git a/tests/core/engine_adapter/test_postgres.py b/tests/core/engine_adapter/test_postgres.py index 3f628aacc0..f013914c3e 100644 --- a/tests/core/engine_adapter/test_postgres.py +++ b/tests/core/engine_adapter/test_postgres.py @@ -153,8 +153,7 @@ def test_alter_table_drop_column_cascade(make_mocked_engine_adapter: t.Callable) def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: if table_name == current_table_name: return {"id": exp.DataType.build("int"), "test_column": exp.DataType.build("int")} - else: - return {"id": exp.DataType.build("int")} + return {"id": exp.DataType.build("int")} adapter.columns = table_columns diff --git a/tests/core/engine_adapter/test_redshift.py b/tests/core/engine_adapter/test_redshift.py index 6d1bbcf61f..9fdb589bb1 100644 --- a/tests/core/engine_adapter/test_redshift.py +++ b/tests/core/engine_adapter/test_redshift.py @@ -355,8 +355,7 @@ def test_alter_table_drop_column_cascade(adapter: t.Callable): def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: if table_name == current_table_name: return {"id": exp.DataType.build("int"), "test_column": exp.DataType.build("int")} - else: - return {"id": exp.DataType.build("int")} + return {"id": exp.DataType.build("int")} adapter.columns = table_columns @@ -376,11 +375,10 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: "id": exp.DataType.build("int"), "test_column": exp.DataType.build("VARCHAR(10)"), } - else: - return { - "id": exp.DataType.build("int"), - "test_column": exp.DataType.build("VARCHAR(20)"), - } + return { + "id": exp.DataType.build("int"), + "test_column": exp.DataType.build("VARCHAR(20)"), + } adapter.columns = table_columns @@ -400,11 +398,10 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: "id": exp.DataType.build("int"), "test_column": exp.DataType.build("DECIMAL(10, 10)"), } - else: - return { - "id": exp.DataType.build("int"), - "test_column": exp.DataType.build("DECIMAL(25, 10)"), - } + return { + "id": exp.DataType.build("int"), + "test_column": exp.DataType.build("DECIMAL(25, 10)"), + } adapter.columns = table_columns diff --git a/tests/core/engine_adapter/test_spark.py b/tests/core/engine_adapter/test_spark.py index 0556a23e05..f1c658a23a 100644 --- a/tests/core/engine_adapter/test_spark.py +++ b/tests/core/engine_adapter/test_spark.py @@ -147,13 +147,12 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: "complex": exp.DataType.build("STRUCT"), "ds": exp.DataType.build("STRING"), } - else: - return { - "id": exp.DataType.build("BIGINT"), - "a": exp.DataType.build("STRING"), - "complex": exp.DataType.build("STRUCT"), - "ds": exp.DataType.build("INT"), - } + return { + "id": exp.DataType.build("BIGINT"), + "a": exp.DataType.build("STRING"), + "complex": exp.DataType.build("STRUCT"), + "ds": exp.DataType.build("INT"), + } adapter.columns = table_columns diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 47cbeb5986..5376e18b16 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -4855,8 +4855,9 @@ def execute( # monkey-patch PythonModel to default to kind: View again # and ViewKind to allow python models again - with mock.patch.object(ViewKind, "supports_python_models", return_value=True), mock.patch( - "sqlmesh.core.model.definition.PythonModel", OldPythonModel + with ( + mock.patch.object(ViewKind, "supports_python_models", return_value=True), + mock.patch("sqlmesh.core.model.definition.PythonModel", OldPythonModel), ): context.load() diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 59c5bbd965..5cc4364fc5 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -1169,11 +1169,10 @@ def columns(table_name): "c": exp.DataType.build("int"), "b": exp.DataType.build("int"), } - else: - return { - "c": exp.DataType.build("int"), - "a": exp.DataType.build("int"), - } + return { + "c": exp.DataType.build("int"), + "a": exp.DataType.build("int"), + } adapter.columns = columns # type: ignore adapter.table_exists = lambda _: True # type: ignore @@ -1702,11 +1701,10 @@ def columns(table_name): "c": exp.DataType.build("int"), "b": exp.DataType.build("int"), } - else: - return { - "c": exp.DataType.build("int"), - "a": exp.DataType.build("int"), - } + return { + "c": exp.DataType.build("int"), + "a": exp.DataType.build("int"), + } adapter.columns = columns # type: ignore @@ -3892,11 +3890,10 @@ def columns(table_name): "c": exp.DataType.build("int"), "b": exp.DataType.build("int"), } - else: - return { - "c": exp.DataType.build("int"), - "a": exp.DataType.build("int"), - } + return { + "c": exp.DataType.build("int"), + "a": exp.DataType.build("int"), + } adapter.columns = columns # type: ignore adapter_mock.columns = columns # type: ignore