Skip to content

Commit

Permalink
fix: Fix migration for removing time_range_endpoints 3 (#19767)
Browse files Browse the repository at this point in the history
* fix migration

* so dumb

* update test

* add code change

* retest
  • Loading branch information
hughhhh committed Apr 19, 2022
1 parent 4ba62ca commit 7e92340
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""rm_time_range_endpoints_from_qc_3
Revision ID: ad07e4fdbaba
Revises: cecc6bf46990
Create Date: 2022-04-18 11:20:47.390901
"""

# revision identifiers, used by Alembic.
revision = "ad07e4fdbaba"
down_revision = "cecc6bf46990"

import json

import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base

from superset import db

Base = declarative_base()


class Slice(Base):
__tablename__ = "slices"
id = sa.Column(sa.Integer, primary_key=True)
query_context = sa.Column(sa.Text)
slice_name = sa.Column(sa.String(250))


def upgrade_slice(slc: Slice):
try:
query_context = json.loads(slc.query_context)
except json.decoder.JSONDecodeError:
return

query_context.get("form_data", {}).pop("time_range_endpoints", None)

if query_context.get("queries"):
queries = query_context["queries"]
for query in queries:
query.get("extras", {}).pop("time_range_endpoints", None)

slc.query_context = json.dumps(query_context)

return slc


def upgrade():
bind = op.get_bind()
session = db.Session(bind=bind)
slices_updated = 0
for slc in (
session.query(Slice)
.filter(Slice.query_context.like("%time_range_endpoints%"))
.all()
):
updated_slice = upgrade_slice(slc)
if updated_slice:
slices_updated += 1

print(f"slices updated with no time_range_endpoints: {slices_updated}")
session.commit()
session.close()


def downgrade():
pass
Original file line number Diff line number Diff line change
Expand Up @@ -26,48 +26,9 @@
revision = "cecc6bf46990"
down_revision = "9d8a8d575284"

import json

import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base

from superset import db

Base = declarative_base()


class Slice(Base):
__tablename__ = "slices"
id = sa.Column(sa.Integer, primary_key=True)
query_context = sa.Column(sa.Text)
slice_name = sa.Column(sa.String(250))


def upgrade_slice(slc: Slice):
try:
query_context = json.loads(slc.query_context)
except json.decoder.JSONDecodeError:
return

queries = query_context.get("queries")

for query in queries:
query.get("extras", {}).pop("time_range_endpoints", None)

slc.query_context = json.dumps(query_context)


def upgrade():
bind = op.get_bind()
session = db.Session(bind=bind)
for slc in session.query(Slice).filter(
Slice.query_context.like("%time_range_endpoints%")
):
upgrade_slice(slc)

session.commit()
session.close()
pass


def downgrade():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
import json

from superset.migrations.versions.cecc6bf46990_rm_time_range_endpoints_2 import (
from superset.migrations.versions.ad07e4fdbaba_rm_time_range_endpoints_from_qc_3 import (
Slice,
upgrade_slice,
)
Expand Down Expand Up @@ -106,7 +106,9 @@
"post_processing": [],
}
],
"form_data": {},
"form_data": {
"time_range_endpoints": ["inclusive", "exclusive"],
},
"result_format": "json",
"result_type": "full",
}
Expand All @@ -123,6 +125,9 @@ def test_upgrade():
extras = q.get("extras", {})
assert "time_range_endpoints" not in extras

form_data = query_context.get("form_data", {})
assert "time_range_endpoints" not in form_data


def test_upgrade_bad_json():
slc = Slice(slice_name="FOO", query_context="abc")
Expand Down

0 comments on commit 7e92340

Please sign in to comment.