Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmesh/core/engine_adapter/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class RedshiftEngineAdapter(
exp.DataType.build("CHAR", dialect=DIALECT).this: 4096,
exp.DataType.build("VARCHAR", dialect=DIALECT).this: 65535,
},
precision_increase_allowed_types={exp.DataType.build("VARCHAR", dialect=DIALECT).this},
drop_cascade=True,
)
VARIABLE_LENGTH_DATA_TYPES = {
Expand Down
12 changes: 11 additions & 1 deletion sqlmesh/core/schema_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ class SchemaDiffer(PydanticModel):
coerceable_types_: t.Dict[exp.DataType, t.Set[exp.DataType]] = Field(
default_factory=dict, alias="coerceable_types"
)
precision_increase_allowed_types: t.Optional[t.Set[exp.DataType.Type]] = None
support_coercing_compatible_types: bool = False
drop_cascade: bool = False
parameterized_type_defaults: t.Dict[
Expand All @@ -367,7 +368,10 @@ def coerceable_types(self) -> t.Dict[exp.DataType, t.Set[exp.DataType]]:
def _is_compatible_type(self, current_type: exp.DataType, new_type: exp.DataType) -> bool:
# types are identical or both types are parameterized and new has higher precision
# - default parameter values are automatically provided if not present
if current_type == new_type or self._is_precision_increase(current_type, new_type):
if current_type == new_type or (
self._is_precision_increase_allowed(current_type)
and self._is_precision_increase(current_type, new_type)
):
return True
# types are un-parameterized and compatible
if current_type in self.compatible_types:
Expand All @@ -390,6 +394,12 @@ def _is_coerceable_type(self, current_type: exp.DataType, new_type: exp.DataType
return is_coerceable
return False

def _is_precision_increase_allowed(self, current_type: exp.DataType) -> bool:
return (
self.precision_increase_allowed_types is None
or current_type.this in self.precision_increase_allowed_types
)

def _is_precision_increase(self, current_type: exp.DataType, new_type: exp.DataType) -> bool:
if current_type.this == new_type.this and not current_type.is_type(
*exp.DataType.NESTED_TYPES
Expand Down
49 changes: 49 additions & 0 deletions tests/core/engine_adapter/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,55 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]:
]


def test_alter_table_precision_increase_varchar(adapter: t.Callable):
current_table_name = "test_table"
target_table_name = "target_table"

def table_columns(table_name: str) -> t.Dict[str, exp.DataType]:
if table_name == current_table_name:
return {
"id": exp.DataType.build("int"),
"test_column": exp.DataType.build("VARCHAR(10)"),
}
else:
return {
"id": exp.DataType.build("int"),
"test_column": exp.DataType.build("VARCHAR(20)"),
}

adapter.columns = table_columns

adapter.alter_table(adapter.get_alter_expressions(current_table_name, target_table_name))
assert to_sql_calls(adapter) == [
'ALTER TABLE "test_table" ALTER COLUMN "test_column" TYPE VARCHAR(20)',
]


def test_alter_table_precision_increase_decimal(adapter: t.Callable):
current_table_name = "test_table"
target_table_name = "target_table"

def table_columns(table_name: str) -> t.Dict[str, exp.DataType]:
if table_name == current_table_name:
return {
"id": exp.DataType.build("int"),
"test_column": exp.DataType.build("DECIMAL(10, 10)"),
}
else:
return {
"id": exp.DataType.build("int"),
"test_column": exp.DataType.build("DECIMAL(25, 10)"),
}

adapter.columns = table_columns

adapter.alter_table(adapter.get_alter_expressions(current_table_name, target_table_name))
assert to_sql_calls(adapter) == [
'ALTER TABLE "test_table" DROP COLUMN "test_column" CASCADE',
'ALTER TABLE "test_table" ADD COLUMN "test_column" DECIMAL(25, 10)',
]


def test_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
adapter = make_mocked_engine_adapter(RedshiftEngineAdapter)
mocker.patch(
Expand Down
36 changes: 36 additions & 0 deletions tests/core/test_schema_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,42 @@ def test_schema_diff_calculate_type_transitions():
],
{},
),
# Increase the precision of a type is ALTER when the type is supported
(
"STRUCT<id INT, address VARCHAR(120)>",
"STRUCT<id INT, address VARCHAR(121)>",
[
TableAlterOperation.alter_type(
TableAlterColumn.primitive("address"),
"VARCHAR(121)",
current_type="VARCHAR(120)",
expected_table_struct="STRUCT<id INT, address VARCHAR(121)>",
)
],
dict(
precision_increase_allowed_types={exp.DataType.build("VARCHAR").this},
),
),
# Increase the precision of a type is DROP/ADD when the type is not supported
(
"STRUCT<id INT, address VARCHAR(120)>",
"STRUCT<id INT, address VARCHAR(121)>",
[
TableAlterOperation.drop(
TableAlterColumn.primitive("address"),
"STRUCT<id INT>",
"VARCHAR(120)",
),
TableAlterOperation.add(
TableAlterColumn.primitive("address"),
"VARCHAR(121)",
expected_table_struct="STRUCT<id INT, address VARCHAR(121)>",
),
],
dict(
precision_increase_allowed_types={exp.DataType.build("DECIMAL").this},
),
),
# Decrease the precision of a type is DROP/ADD
(
"STRUCT<id INT, address VARCHAR(120)>",
Expand Down