diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 13796a513e..f314f704b6 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -60,6 +60,7 @@ class RedshiftEngineAdapter( exp.DataType.build("CHAR", dialect=DIALECT).this: 4096, exp.DataType.build("VARCHAR", dialect=DIALECT).this: 65535, }, + precision_increase_allowed_types={exp.DataType.build("VARCHAR", dialect=DIALECT).this}, drop_cascade=True, ) VARIABLE_LENGTH_DATA_TYPES = { diff --git a/sqlmesh/core/schema_diff.py b/sqlmesh/core/schema_diff.py index 674b1df76f..d1ca567630 100644 --- a/sqlmesh/core/schema_diff.py +++ b/sqlmesh/core/schema_diff.py @@ -341,6 +341,7 @@ class SchemaDiffer(PydanticModel): coerceable_types_: t.Dict[exp.DataType, t.Set[exp.DataType]] = Field( default_factory=dict, alias="coerceable_types" ) + precision_increase_allowed_types: t.Optional[t.Set[exp.DataType.Type]] = None support_coercing_compatible_types: bool = False drop_cascade: bool = False parameterized_type_defaults: t.Dict[ @@ -367,7 +368,10 @@ def coerceable_types(self) -> t.Dict[exp.DataType, t.Set[exp.DataType]]: def _is_compatible_type(self, current_type: exp.DataType, new_type: exp.DataType) -> bool: # types are identical or both types are parameterized and new has higher precision # - default parameter values are automatically provided if not present - if current_type == new_type or self._is_precision_increase(current_type, new_type): + if current_type == new_type or ( + self._is_precision_increase_allowed(current_type) + and self._is_precision_increase(current_type, new_type) + ): return True # types are un-parameterized and compatible if current_type in self.compatible_types: @@ -390,6 +394,12 @@ def _is_coerceable_type(self, current_type: exp.DataType, new_type: exp.DataType return is_coerceable return False + def _is_precision_increase_allowed(self, current_type: exp.DataType) -> bool: + return ( + self.precision_increase_allowed_types is None + or current_type.this in self.precision_increase_allowed_types + ) + def _is_precision_increase(self, current_type: exp.DataType, new_type: exp.DataType) -> bool: if current_type.this == new_type.this and not current_type.is_type( *exp.DataType.NESTED_TYPES diff --git a/tests/core/engine_adapter/test_redshift.py b/tests/core/engine_adapter/test_redshift.py index 930b0a607c..6d1bbcf61f 100644 --- a/tests/core/engine_adapter/test_redshift.py +++ b/tests/core/engine_adapter/test_redshift.py @@ -366,6 +366,55 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: ] +def test_alter_table_precision_increase_varchar(adapter: t.Callable): + current_table_name = "test_table" + target_table_name = "target_table" + + def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: + if table_name == current_table_name: + return { + "id": exp.DataType.build("int"), + "test_column": exp.DataType.build("VARCHAR(10)"), + } + else: + return { + "id": exp.DataType.build("int"), + "test_column": exp.DataType.build("VARCHAR(20)"), + } + + adapter.columns = table_columns + + adapter.alter_table(adapter.get_alter_expressions(current_table_name, target_table_name)) + assert to_sql_calls(adapter) == [ + 'ALTER TABLE "test_table" ALTER COLUMN "test_column" TYPE VARCHAR(20)', + ] + + +def test_alter_table_precision_increase_decimal(adapter: t.Callable): + current_table_name = "test_table" + target_table_name = "target_table" + + def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: + if table_name == current_table_name: + return { + "id": exp.DataType.build("int"), + "test_column": exp.DataType.build("DECIMAL(10, 10)"), + } + else: + return { + "id": exp.DataType.build("int"), + "test_column": exp.DataType.build("DECIMAL(25, 10)"), + } + + adapter.columns = table_columns + + adapter.alter_table(adapter.get_alter_expressions(current_table_name, target_table_name)) + assert to_sql_calls(adapter) == [ + 'ALTER TABLE "test_table" DROP COLUMN "test_column" CASCADE', + 'ALTER TABLE "test_table" ADD COLUMN "test_column" DECIMAL(25, 10)', + ] + + def test_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) mocker.patch( diff --git a/tests/core/test_schema_diff.py b/tests/core/test_schema_diff.py index 3a921598a1..85f4d424a3 100644 --- a/tests/core/test_schema_diff.py +++ b/tests/core/test_schema_diff.py @@ -884,6 +884,42 @@ def test_schema_diff_calculate_type_transitions(): ], {}, ), + # Increase the precision of a type is ALTER when the type is supported + ( + "STRUCT", + "STRUCT", + [ + TableAlterOperation.alter_type( + TableAlterColumn.primitive("address"), + "VARCHAR(121)", + current_type="VARCHAR(120)", + expected_table_struct="STRUCT", + ) + ], + dict( + precision_increase_allowed_types={exp.DataType.build("VARCHAR").this}, + ), + ), + # Increase the precision of a type is DROP/ADD when the type is not supported + ( + "STRUCT", + "STRUCT", + [ + TableAlterOperation.drop( + TableAlterColumn.primitive("address"), + "STRUCT", + "VARCHAR(120)", + ), + TableAlterOperation.add( + TableAlterColumn.primitive("address"), + "VARCHAR(121)", + expected_table_struct="STRUCT", + ), + ], + dict( + precision_increase_allowed_types={exp.DataType.build("DECIMAL").this}, + ), + ), # Decrease the precision of a type is DROP/ADD ( "STRUCT",