Change slice ids in the position json during dashboard import. (#1380)
* Change slice ids in the position json during dashboard import.

* Update slice ids in the dashboard json metadata.
bkyryliuk committed Oct 20, 2016
1 parent ece69fb commit c198535
Showing 2 changed files with 103 additions and 12 deletions.
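
The gist of the change: slices get new ids on the importing instance, so the ids embedded in the dashboard's position_json and json_metadata have to be rewritten to match. As a minimal standalone sketch of the position remapping (plain functions over JSON strings, not the actual Dashboard model API):

import json

def remap_position_json(position_json, old_to_new_slc_id_dict):
    # position_json is a JSON array of widget placements; each entry may
    # carry a string "slice_id" that must point at the newly created slice.
    positions = json.loads(position_json) if position_json else []
    for position in positions:
        if 'slice_id' not in position:
            continue
        old_id = int(position['slice_id'])
        if old_id in old_to_new_slc_id_dict:
            position['slice_id'] = str(old_to_new_slc_id_dict[old_id])
    return json.dumps(positions)

# remap_position_json('[{"slice_id": "3610", "col": 5}]', {3610: 42})
# -> '[{"slice_id": "42", "col": 5}]'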
57 changes: 54 additions & 3 deletions caravel/models.py
@@ -449,6 +449,12 @@ def params(self):
def params(self, value):
self.json_metadata = value

@property
def position_array(self):
if self.position_json:
return json.loads(self.position_json)
return []

@classmethod
def import_obj(cls, dashboard_to_import, import_time=None):
"""Imports the dashboard from the object to the database.
@@ -460,6 +466,28 @@ def import_obj(cls, dashboard_to_import, import_time=None):
to import/export dashboards between multiple caravel instances.
Audit metadata isn't copied over.
"""
def alter_positions(dashboard, old_to_new_slc_id_dict):
""" Updates slice_ids in the position json.
Sample position json:
[{
"col": 5,
"row": 10,
"size_x": 4,
"size_y": 2,
"slice_id": "3610"
}]
"""
position_array = dashboard.position_array
for position in position_array:
if 'slice_id' not in position:
continue
old_slice_id = int(position['slice_id'])
if old_slice_id in old_to_new_slc_id_dict:
position['slice_id'] = '{}'.format(
old_to_new_slc_id_dict[old_slice_id])
dashboard.position_json = json.dumps(position_array)

logging.info('Started import of the dashboard: {}'
.format(dashboard_to_import.to_json()))
session = db.session
@@ -468,11 +496,25 @@ def import_obj(cls, dashboard_to_import, import_time=None):
# copy slices object as Slice.import_obj will mutate the slice
# and will remove the existing dashboard-slice association
slices = copy(dashboard_to_import.slices)
slice_ids = set()
old_to_new_slc_id_dict = {}
new_filter_immune_slices = []
new_expanded_slices = {}
i_params_dict = dashboard_to_import.params_dict
for slc in slices:
logging.info('Importing slice {} from the dashboard: {}'.format(
slc.to_json(), dashboard_to_import.dashboard_title))
slice_ids.add(Slice.import_obj(slc, import_time=import_time))
new_slc_id = Slice.import_obj(slc, import_time=import_time)
old_to_new_slc_id_dict[slc.id] = new_slc_id
# update json metadata that deals with slice ids
if ('filter_immune_slices' in i_params_dict and
slc.id in i_params_dict['filter_immune_slices']):
new_filter_immune_slices.append(new_slc_id)
new_slc_id_str = '{}'.format(new_slc_id)
old_slc_id_str = '{}'.format(slc.id)
if ('expanded_slices' in i_params_dict and
old_slc_id_str in i_params_dict['expanded_slices']):
new_expanded_slices[new_slc_id_str] = (
i_params_dict['expanded_slices'][old_slc_id_str])

# override the dashboard
existing_dashboard = None
@@ -483,8 +525,17 @@ def import_obj(cls, dashboard_to_import, import_time=None):
existing_dashboard = dash

dashboard_to_import.id = None
alter_positions(dashboard_to_import, old_to_new_slc_id_dict)
dashboard_to_import.alter_params(import_time=import_time)
new_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
if new_expanded_slices:
dashboard_to_import.alter_params(
expanded_slices=new_expanded_slices)
if new_filter_immune_slices:
dashboard_to_import.alter_params(
filter_immune_slices=new_filter_immune_slices)

new_slices = session.query(Slice).filter(
Slice.id.in_(old_to_new_slc_id_dict.values())).all()

if existing_dashboard:
existing_dashboard.override(dashboard_to_import)
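The json_metadata side follows the same pattern; a rough standalone sketch of how filter_immune_slices (integer ids) and expanded_slices (string-keyed) could be remapped, assuming a plain params dict rather than the model:

def remap_metadata_slice_ids(params_dict, old_to_new_slc_id_dict):
    # filter_immune_slices stores integer slice ids; expanded_slices is
    # keyed by the slice id rendered as a string, so both need remapping.
    remapped = {}
    if 'filter_immune_slices' in params_dict:
        remapped['filter_immune_slices'] = [
            old_to_new_slc_id_dict[old_id]
            for old_id in params_dict['filter_immune_slices']
            if old_id in old_to_new_slc_id_dict]
    if 'expanded_slices' in params_dict:
        remapped['expanded_slices'] = {
            str(old_to_new_slc_id_dict[int(old_id)]): value
            for old_id, value in params_dict['expanded_slices'].items()
            if int(old_id) in old_to_new_slc_id_dict}
    return remapped
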
58 changes: 49 additions & 9 deletions tests/import_export_tests.py
@@ -97,6 +97,10 @@ def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]):
def get_slice(self, slc_id):
return db.session.query(models.Slice).filter_by(id=slc_id).first()

def get_slice_by_name(self, name):
return db.session.query(models.Slice).filter_by(
slice_name=name).first()

def get_dash(self, dash_id):
return db.session.query(models.Dashboard).filter_by(
id=dash_id).first()
@@ -113,12 +117,11 @@ def get_table_by_name(self, name):
return db.session.query(models.SqlaTable).filter_by(
table_name=name).first()

def assert_dash_equals(self, expected_dash, actual_dash):
def assert_dash_equals(self, expected_dash, actual_dash,
check_position=True):
self.assertEquals(expected_dash.slug, actual_dash.slug)
self.assertEquals(
expected_dash.dashboard_title, actual_dash.dashboard_title)
self.assertEquals(
expected_dash.position_json, actual_dash.position_json)
self.assertEquals(
len(expected_dash.slices), len(actual_dash.slices))
expected_slices = sorted(
@@ -127,6 +130,9 @@ def assert_dash_equals(self, expected_dash, actual_dash):
actual_dash.slices, key=lambda s: s.slice_name)
for e_slc, a_slc in zip(expected_slices, actual_slices):
self.assert_slice_equals(e_slc, a_slc)
if check_position:
self.assertEquals(
expected_dash.position_json, actual_dash.position_json)

def assert_table_equals(self, expected_ds, actual_ds):
self.assertEquals(expected_ds.table_name, actual_ds.table_name)
@@ -221,7 +227,6 @@ def test_import_2_slices_for_same_table(self):
self.assert_slice_equals(slc_1, imported_slc_1)
self.assertEquals(imported_slc_1.datasource.perm, imported_slc_1.perm)


self.assertEquals(table_id, imported_slc_2.datasource_id)
self.assert_slice_equals(slc_2, imported_slc_2)
self.assertEquals(imported_slc_2.datasource.perm, imported_slc_2.perm)
@@ -246,37 +251,71 @@ def test_import_empty_dashboard(self):
imported_dash_id = models.Dashboard.import_obj(
empty_dash, import_time=1989)
imported_dash = self.get_dash(imported_dash_id)
self.assert_dash_equals(empty_dash, imported_dash)
self.assert_dash_equals(
empty_dash, imported_dash, check_position=False)

def test_import_dashboard_1_slice(self):
slc = self.create_slice('health_slc', id=10006)
dash_with_1_slice = self.create_dashboard(
'dash_with_1_slice', slcs=[slc], id=10002)
dash_with_1_slice.position_json = """
[{{
"col": 5,
"row": 10,
"size_x": 4,
"size_y": 2,
"slice_id": "{}"
}}]
""".format(slc.id)
imported_dash_id = models.Dashboard.import_obj(
dash_with_1_slice, import_time=1990)
imported_dash = self.get_dash(imported_dash_id)

expected_dash = self.create_dashboard(
'dash_with_1_slice', slcs=[slc], id=10002)
make_transient(expected_dash)
self.assert_dash_equals(expected_dash, imported_dash)
self.assert_dash_equals(
expected_dash, imported_dash, check_position=False)
self.assertEquals({"remote_id": 10002, "import_time": 1990},
json.loads(imported_dash.json_metadata))

expected_position = dash_with_1_slice.position_array
expected_position[0]['slice_id'] = '{}'.format(
imported_dash.slices[0].id)
self.assertEquals(expected_position, imported_dash.position_array)

def test_import_dashboard_2_slices(self):
e_slc = self.create_slice('e_slc', id=10007, table_name='energy_usage')
b_slc = self.create_slice('b_slc', id=10008, table_name='birth_names')
dash_with_2_slices = self.create_dashboard(
'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003)
dash_with_2_slices.json_metadata = json.dumps({
"remote_id": 10003,
"filter_immune_slices": [e_slc.id],
"expanded_slices": {e_slc.id: True, b_slc.id: False}
})

imported_dash_id = models.Dashboard.import_obj(
dash_with_2_slices, import_time=1991)
imported_dash = self.get_dash(imported_dash_id)

expected_dash = self.create_dashboard(
'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003)
make_transient(expected_dash)
self.assert_dash_equals(imported_dash, expected_dash)
self.assertEquals({"remote_id": 10003, "import_time": 1991},
self.assert_dash_equals(
imported_dash, expected_dash, check_position=False)
i_e_slc = self.get_slice_by_name('e_slc')
i_b_slc = self.get_slice_by_name('b_slc')
expected_json_metadata = {
"remote_id": 10003,
"import_time": 1991,
"filter_immune_slices": [i_e_slc.id],
"expanded_slices": {
'{}'.format(i_e_slc.id): True,
'{}'.format(i_b_slc.id): False
}
}
self.assertEquals(expected_json_metadata,
json.loads(imported_dash.json_metadata))

def test_import_override_dashboard_2_slices(self):
@@ -304,7 +343,8 @@ def test_import_override_dashboard_2_slices(self):
'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004)
make_transient(expected_dash)
imported_dash = self.get_dash(imported_dash_id_2)
self.assert_dash_equals(expected_dash, imported_dash)
self.assert_dash_equals(
expected_dash, imported_dash, check_position=False)
self.assertEquals({"remote_id": 10004, "import_time": 1992},
json.loads(imported_dash.json_metadata))
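
A small note on test_import_dashboard_1_slice: the doubled braces in the position_json template are str.format escapes, so only the slice id is substituted while the literal JSON braces survive, e.g.:

template = '[{{"col": 5, "slice_id": "{}"}}]'
print(template.format(10006))  # [{"col": 5, "slice_id": "10006"}]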

