diff --git a/caravel/migrations/versions/1226819ee0e3_fix_wrong_constraint_on_table_columns.py b/caravel/migrations/versions/1226819ee0e3_fix_wrong_constraint_on_table_columns.py index af96538849aa..9b1d8018c341 100644 --- a/caravel/migrations/versions/1226819ee0e3_fix_wrong_constraint_on_table_columns.py +++ b/caravel/migrations/versions/1226819ee0e3_fix_wrong_constraint_on_table_columns.py @@ -12,27 +12,16 @@ from alembic import op import sqlalchemy as sa +from caravel.utils import generic_find_constraint_name naming_convention = { "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", } -def find_constraint_name(upgrade = True): - __table = 'columns' - __cols = {'column_name'} if upgrade else {'datasource_name'} - __referenced = 'datasources' - __ref_cols = {'datasource_name'} if upgrade else {'column_name'} - - engine = op.get_bind().engine - m = sa.MetaData({}) - t=sa.Table(__table,m, autoload=True, autoload_with=engine) - - for fk in t.foreign_key_constraints: - if fk.referred_table.name == __referenced and \ - set(fk.column_keys) == __cols: - return fk.name - return None +def find_constraint_name(upgrade=True): + cols = {'column_name'} if upgrade else {'datasource_name'} + return generic_find_constraint_name(table='columns', columns=cols, referenced='datasources') def upgrade(): constraint = find_constraint_name() or 'fk_columns_column_name_datasources' @@ -47,4 +36,3 @@ def downgrade(): naming_convention=naming_convention) as batch_op: batch_op.drop_constraint(constraint, type_="foreignkey") batch_op.create_foreign_key('fk_columns_column_name_datasources', 'datasources', ['column_name'], ['datasource_name']) - \ No newline at end of file diff --git a/caravel/utils.py b/caravel/utils.py index 15c0d75cd7ff..9a48b42a6736 100644 --- a/caravel/utils.py +++ b/caravel/utils.py @@ -11,7 +11,9 @@ from datetime import datetime import parsedatetime +import sqlalchemy as sa from dateutil.parser import parse +from alembic import op from flask import flash, Markup from flask_appbuilder.security.sqla import models as ab_models from markdown import markdown as md @@ -255,3 +257,18 @@ def readfile(filepath): with open(filepath) as f: content = f.read() return content + + +def generic_find_constraint_name(table, columns, referenced): + """ + Utility to find a constraint name in alembic migrations + """ + engine = op.get_bind().engine + m = sa.MetaData({}) + t = sa.Table(table, m, autoload=True, autoload_with=engine) + + for fk in t.foreign_key_constraints: + if fk.referred_table.name == referenced and \ + set(fk.column_keys) == columns: + return fk.name + return None