Skip to content

Commit

Permalink
fix: overwrite update override columns on PUT /dataset (#20862)
Browse files Browse the repository at this point in the history
* update override columns

* save

* fix overwrite with session.flush

* write test

* write test

* layup

* address concerns

* address concerns
  • Loading branch information
hughhhh committed Jul 30, 2022
1 parent 67e3dc7 commit bc435e0
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 15 deletions.
3 changes: 2 additions & 1 deletion superset/datasets/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ def __init__(
self,
model_id: int,
data: Dict[str, Any],
override_columns: bool = False,
override_columns: Optional[bool] = False,
):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[SqlaTable] = None
self.override_columns = override_columns
self._properties["override_columns"] = override_columns

def run(self) -> Model:
self.validate()
Expand Down
50 changes: 36 additions & 14 deletions superset/datasets/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,22 @@ def validate_metrics_uniqueness(dataset_id: int, metrics_names: List[str]) -> bo

@classmethod
def update(
cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
cls,
model: SqlaTable,
properties: Dict[str, Any],
commit: bool = True,
) -> Optional[SqlaTable]:
"""
Updates a Dataset model on the metadata DB
"""

if "columns" in properties:
cls.update_columns(model, properties.pop("columns"), commit=commit)
cls.update_columns(
model,
properties.pop("columns"),
commit=commit,
override_columns=bool(properties.get("override_columns")),
)

if "metrics" in properties:
cls.update_metrics(model, properties.pop("metrics"), commit=commit)
Expand All @@ -167,6 +175,7 @@ def update_columns(
model: SqlaTable,
property_columns: List[Dict[str, Any]],
commit: bool = True,
override_columns: bool = False,
) -> None:
"""
Creates/updates and/or deletes a list of columns, based on a
Expand All @@ -180,24 +189,37 @@ def update_columns(

column_by_id = {column.id: column for column in model.columns}
seen = set()
original_cols = {obj.id for obj in model.columns}

for properties in property_columns:
if "id" in properties:
seen.add(properties["id"])
if override_columns:
for id_ in original_cols:
DatasetDAO.delete_column(column_by_id[id_], commit=False)

DatasetDAO.update_column(
column_by_id[properties["id"]],
properties,
commit=False,
)
else:
db.session.flush()

for properties in property_columns:
DatasetDAO.create_column(
{**properties, "table_id": model.id},
commit=False,
)

for id_ in {obj.id for obj in model.columns} - seen:
DatasetDAO.delete_column(column_by_id[id_], commit=False)
else:
for properties in property_columns:
if "id" in properties:
seen.add(properties["id"])

DatasetDAO.update_column(
column_by_id[properties["id"]],
properties,
commit=False,
)
else:
DatasetDAO.create_column(
{**properties, "table_id": model.id},
commit=False,
)

for id_ in {obj.id for obj in model.columns} - seen:
DatasetDAO.delete_column(column_by_id[id_], commit=False)

if commit:
db.session.commit()
Expand Down
50 changes: 50 additions & 0 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,56 @@ def test_update_dataset_item_w_override_columns(self):
db.session.delete(dataset)
db.session.commit()

def test_update_dataset_item_w_override_columns_same_columns(self):
"""
Dataset API: Test update dataset with override columns
"""
if backend() == "sqlite":
return

# Add default dataset
main_db = get_main_database()
dataset = self.insert_default_dataset()
prev_col_len = len(dataset.columns)

cols = [
{
"column_name": c.column_name,
"description": c.description,
"expression": c.expression,
"type": c.type,
"advanced_data_type": c.advanced_data_type,
"verbose_name": c.verbose_name,
}
for c in dataset.columns
]

cols.append(
{
"column_name": "new_col",
"description": "description",
"expression": "expression",
"type": "INTEGER",
"advanced_data_type": "ADVANCED_DATA_TYPE",
"verbose_name": "New Col",
}
)

self.login(username="admin")
dataset_data = {
"columns": cols,
}
uri = f"api/v1/dataset/{dataset.id}?override_columns=true"
rv = self.put_assert_metric(uri, dataset_data, "put")

assert rv.status_code == 200

columns = db.session.query(TableColumn).filter_by(table_id=dataset.id).all()
assert len(columns) != prev_col_len
assert len(columns) == 3
db.session.delete(dataset)
db.session.commit()

def test_update_dataset_create_column_and_metric(self):
"""
Dataset API: Test update dataset create column
Expand Down

0 comments on commit bc435e0

Please sign in to comment.