Skip to content

Commit

Permalink
Ensure interactive visualization can be run iteratively (#541)
Browse files Browse the repository at this point in the history
* Ensure interactive visualization can be run iteratively

* Handle case where numeric types are passed in as renamed clusters

* Remove extraneous comment

* Add updated notebook process

* Add tests for metaclustering visualization with rename column passed

* Remove extraneous comment from previous inline version

* Remove another extraneous cell notebook comment

* Turn off overlay axes and grid lines

* Turn off label for color axis of heatmap

* Need to add conditional check if row_colors or col_colors are specified

* Add test for row_ and col_colors

* Need to explicitly check for not None
  • Loading branch information
alex-l-kong committed May 19, 2022
1 parent 1b78fc1 commit ad9fd61
Show file tree
Hide file tree
Showing 12 changed files with 270 additions and 99 deletions.
31 changes: 28 additions & 3 deletions ark/analysis/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def draw_boxplot(cell_data, col_name, col_split=None,

def draw_heatmap(data, x_labels, y_labels, dpi=None, center_val=None, min_val=None, max_val=None,
cbar_ticks=None, colormap="vlag", row_colors=None, row_cluster=True,
col_colors=None, col_cluster=True, save_dir=None, save_file=None):
col_colors=None, col_cluster=True, left_start=None, right_start=None,
w_spacing=None, h_spacing=None, save_dir=None, save_file=None):
"""Plots the z scores between all phenotypes as a clustermap.
Args:
Expand Down Expand Up @@ -100,6 +101,14 @@ def draw_heatmap(data, x_labels, y_labels, dpi=None, center_val=None, min_val=No
Include these values as an additional color-coded cluster bar for column values
col_cluster (bool):
Whether to include dendrogram clustering for the columns
left_start (float):
The position to set the left edge of the figure to (from 0-1)
right_start (float):
The position to set the right edge of the figure to (from 0-1)
w_spacing (float):
The amount of spacing to put between the subplots width-wise (from 0-1)
h_spacing (float):
The amount of spacing to put between the subplots height-wise (from 0-1)
save_dir (str):
If specified, a directory where we will save the plot
save_file (str):
Expand All @@ -115,12 +124,28 @@ def draw_heatmap(data, x_labels, y_labels, dpi=None, center_val=None, min_val=No
data_df = pd.DataFrame(data, index=x_labels, columns=y_labels)
sns.set(font_scale=.7)

sns.clustermap(
heatmap = sns.clustermap(
data_df, cmap=colormap, center=center_val,
vmin=min_val, vmax=max_val, row_colors=row_colors, row_cluster=row_cluster,
col_colors=col_colors, col_cluster=col_cluster, cbar_kws={'ticks': cbar_ticks}
col_colors=col_colors, col_cluster=col_cluster,
cbar_kws={'ticks': cbar_ticks}
)

# ensure the row color axis doesn't have a label attached to it
if row_colors is not None:
_ = heatmap.ax_row_colors.xaxis.set_visible(False)

if col_colors is not None:
_ = heatmap.ax_col_colors.yaxis.set_visible(False)

# update the figure dimensions to accommodate Jupyter widget backend
_ = heatmap.gs.update(
left=left_start, right=right_start, wspace=w_spacing, hspace=h_spacing
)

# ensure the y-axis labels are horizontal, will be misaligned if vertical
_ = plt.setp(heatmap.ax_heatmap.get_yticklabels(), rotation=0)

if save_dir is not None:
misc_utils.save_figure(save_dir, save_file, dpi=dpi)

Expand Down
21 changes: 21 additions & 0 deletions ark/analysis/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ def test_draw_heatmap():
save_dir=temp_dir, save_file="z_score_viz.png")
assert os.path.exists(os.path.join(temp_dir, "z_score_viz.png"))

# test row_colors drawing functionality
row_colors = [(0.0, 0.0, 0.0, 0.0) for i in np.arange(26)]
visualize.draw_heatmap(
z, pheno_titles, pheno_titles, row_colors=row_colors, save_file="z_score_viz.png"
)
assert os.path.exists(os.path.join(temp_dir, "z_score_viz.png"))

# test col_colors drawing functionality
col_colors = [(0.0, 0.0, 0.0, 0.0) for i in np.arange(26)]
visualize.draw_heatmap(
z, pheno_titles, pheno_titles, col_colors=col_colors, save_file="z_score_viz.png"
)
assert os.path.exists(os.path.join(temp_dir, "z_score_viz.png"))

# test row_colors and col_colors
visualize.draw_heatmap(
z, pheno_titles, pheno_titles, row_colors=row_colors,
col_colors=col_colors, save_file="z_score_viz.png"
)
assert os.path.exists(os.path.join(temp_dir, "z_score_viz.png"))


def test_draw_boxplot():
# trim random data so we don't have to visualize as many facets
Expand Down
17 changes: 12 additions & 5 deletions ark/phenotyping/som_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,7 @@ def generate_weighted_channel_avg_heatmap(cell_cluster_channel_avg_path, cell_cl
Path to the file containing the average weighted channel expression per cell cluster
cell_cluster_col (str):
The name of the cell cluster col,
needs to be either 'cell_som_cluster' or 'cell_meta_cluster'
needs to be either 'cell_som_cluster' or 'cell_meta_cluster_rename'
channels (str):
The list of channels to visualize
raw_cmap (dict):
Expand All @@ -1883,7 +1883,7 @@ def generate_weighted_channel_avg_heatmap(cell_cluster_channel_avg_path, cell_cl
# verify the cell_cluster_col provided is valid
misc_utils.verify_in_list(
provided_cluster_col=[cell_cluster_col],
valid_cluster_cols=['cell_som_cluster', 'cell_meta_cluster']
valid_cluster_cols=['cell_som_cluster', 'cell_meta_cluster_rename']
)

# read the channel average path
Expand All @@ -1897,11 +1897,15 @@ def generate_weighted_channel_avg_heatmap(cell_cluster_channel_avg_path, cell_cl

# sort the data by the meta cluster value
# this ensures the meta clusters are grouped together when the colormap is displayed
cell_cluster_channel_avgs = cell_cluster_channel_avgs.sort_values(by='cell_meta_cluster')
cell_cluster_channel_avgs = cell_cluster_channel_avgs.sort_values(
by='cell_meta_cluster_rename'
)

# map raw_cmap onto cell_cluster_channel_avgs for the heatmap to display the side color bar
meta_cluster_index = cell_cluster_channel_avgs[cell_cluster_col].values
meta_cluster_mapping = pd.Series(cell_cluster_channel_avgs['cell_meta_cluster']).map(raw_cmap)
meta_cluster_mapping = pd.Series(
cell_cluster_channel_avgs['cell_meta_cluster_rename']
).map(renamed_cmap)
meta_cluster_mapping.index = meta_cluster_index

# draw the heatmap
Expand All @@ -1915,6 +1919,9 @@ def generate_weighted_channel_avg_heatmap(cell_cluster_channel_avg_path, cell_cl
cbar_ticks=np.arange(-3, 4),
row_colors=meta_cluster_mapping,
row_cluster=False,
left_start=0.0,
right_start=0.85,
w_spacing=0.2,
colormap='vlag'
)

Expand All @@ -1924,7 +1931,7 @@ def generate_weighted_channel_avg_heatmap(cell_cluster_channel_avg_path, cell_cl
handles,
renamed_cmap,
title='Meta cluster',
bbox_to_anchor=(1.1, 1),
bbox_to_anchor=(1, 1),
bbox_transform=plt.gcf().transFigure,
loc='upper right'
)
2 changes: 1 addition & 1 deletion ark/phenotyping/som_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2367,5 +2367,5 @@ def test_generate_weighted_channel_avg_heatmap():
# assert visualization runs
som_utils.generate_weighted_channel_avg_heatmap(
os.path.join(temp_dir, 'sample_channel_avg.csv'),
'cell_meta_cluster', ['chan1', 'chan2'], raw_cmap, renamed_cmap
'cell_meta_cluster_rename', ['chan1', 'chan2'], raw_cmap, renamed_cmap
)
24 changes: 22 additions & 2 deletions ark/utils/metacluster_remap_gui/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,22 @@ def simple_clusters_df():
(0.1, 0.1, 0.3, 2, 2),
(0.5, 0.1, 0.1, 3, 3),
(0.7, 0.2, 0.1, 4, 3),
]
]
return pd.DataFrame(data=clusters_data, columns=clusters_headers)


@pytest.fixture()
def simple_clusters_meta_rename_df():
    """Minimal example cluster data, including a renamed-metacluster column."""
    # Column-wise construction of the same 4-row table: three marker channels,
    # the SOM cluster id, its metacluster id, and the metacluster display name.
    column_data = {
        'CD163': [0.1, 0.1, 0.5, 0.7],
        'CD206': [0.2, 0.1, 0.1, 0.2],
        'CD31': [0.1, 0.3, 0.1, 0.1],
        'cluster': [1, 2, 3, 4],
        'metacluster': [1, 2, 3, 3],
        'metacluster_rename': ['cluster_1', 'cluster_2', 'cluster_3', 'cluster_3'],
    }
    return pd.DataFrame(column_data)


Expand All @@ -27,7 +42,7 @@ def simple_pixelcount_df():
(2, 10),
(3, 50),
(4, 77),
]
]
return pd.DataFrame(data=pixelcount_data, columns=pixelcount_headers)


Expand All @@ -49,6 +64,11 @@ def simple_metaclusterdata(simple_clusters_df, simple_pixelcount_df):
return MetaClusterData(simple_clusters_df, simple_pixelcount_df)


@pytest.fixture()
def simple_metaclusterdata_rename(simple_clusters_meta_rename_df, simple_pixelcount_df):
    """MetaClusterData built from cluster data that already carries a
    'metacluster_rename' column, i.e. a run after metaclusters were renamed."""
    return MetaClusterData(simple_clusters_meta_rename_df, simple_pixelcount_df)


@pytest.fixture(autouse=True)
def test_plot_fn(monkeypatch):
"""Make plt.show impotent for all tests"""
Expand Down
3 changes: 2 additions & 1 deletion ark/utils/metacluster_remap_gui/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def metaclusterdata_from_files(cluster_path, cluster_type='pixel', prefix_trim=N
# with {cluster_type}_{som/meta}_cluster, not high priority
cluster_data = cluster_data.rename(columns={
'%s_som_cluster' % cluster_type: 'cluster',
'%s_meta_cluster' % cluster_type: 'metacluster'
'%s_meta_cluster' % cluster_type: 'metacluster',
'%s_meta_cluster_rename' % cluster_type: 'metacluster_rename'
})

if 'cluster' not in cluster_data.columns:
Expand Down
30 changes: 28 additions & 2 deletions ark/utils/metacluster_remap_gui/metaclusterdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ def __init__(self, raw_clusters_df, raw_pixelcounts_df):
self._clusters = sorted_clusters_df.set_index('cluster').drop(columns='metacluster')
self.mapping = sorted_clusters_df[['cluster', 'metacluster']].set_index('cluster')
self._metacluster_displaynames_map = {}

# need to prefill the displaynames_map with the already renamed meta clusters
# on subsequent runs after the first to prevent automatic incremental rewriting
if 'metacluster_rename' in sorted_clusters_df.columns:
unique_mappings = sorted_clusters_df[
['metacluster', 'metacluster_rename']
].drop_duplicates()

self._metacluster_displaynames_map = {
mc['metacluster']: str(mc['metacluster_rename'])
for _, mc in unique_mappings.iterrows()
}

self._marker_order = list(range(len(self._clusters.columns)))
self._output_mapping_filename = None
self._cached_metaclusters = None
Expand All @@ -37,11 +50,24 @@ def output_mapping_filename(self, filepath):
@property
def clusters_with_metaclusters(self):
df = self._clusters.join(self.mapping).sort_values(by='metacluster')
return df.iloc[:, self._marker_order + [max(self._marker_order) + 1]]

# NOTE: this method takes into account both the initial run (without _rename column)
# and subsequent runs (with _rename columns)
return df.iloc[:, self._marker_order + list(
range(max(self._marker_order) + 1, len(df.columns.values))
)]

@property
def clusters(self):
return self.clusters_with_metaclusters.drop(columns='metacluster')
# maintain old clusters_with_metaclusters
clusters_data = self.clusters_with_metaclusters.copy()

# we need to drop the rename column on subsequent runs after the first
if 'metacluster_rename' in self.clusters_with_metaclusters.columns:
clusters_data = clusters_data.drop(columns='metacluster_rename')

# metacluster column needs to be dropped regardless of run
return clusters_data.drop(columns='metacluster')

@property
def metacluster_displaynames(self):
Expand Down
92 changes: 79 additions & 13 deletions ark/utils/metacluster_remap_gui/metaclusterdata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,42 @@
from .metaclusterdata import MetaClusterData


def test_can_get_mapping(simple_metaclusterdata: MetaClusterData):
def test_can_get_mapping(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData):
np.testing.assert_array_equal(
simple_metaclusterdata.mapping['metacluster'].values,
np.array((1, 2, 3, 3)))
np.array((1, 2, 3, 3))
)

np.testing.assert_array_equal(
simple_metaclusterdata_rename.mapping['metacluster'].values,
np.array((1, 2, 3, 3))
)


def test_can_remap(simple_metaclusterdata: MetaClusterData):
def test_can_remap(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData):
simple_metaclusterdata.remap(4, 1)
assert simple_metaclusterdata.mapping.loc[4, 'metacluster'] == 1

simple_metaclusterdata_rename.remap(4, 1)
assert simple_metaclusterdata_rename.mapping.loc[4, 'metacluster'] == 1


def test_can_create_new_metacluster(simple_metaclusterdata: MetaClusterData):
def test_can_create_new_metacluster(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData):
new_mc = simple_metaclusterdata.new_metacluster()

simple_metaclusterdata.remap(4, new_mc)
assert simple_metaclusterdata.mapping.loc[4, 'metacluster'] == 4

simple_metaclusterdata_rename.remap(4, new_mc)
assert simple_metaclusterdata_rename.mapping.loc[4, 'metacluster'] == 4


def test_can_save_mapping(simple_metaclusterdata: MetaClusterData, tmp_path):
def test_can_save_mapping(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData,
tmp_path):
simple_metaclusterdata.output_mapping_filename = tmp_path / 'output_mapping.csv'
simple_metaclusterdata.save_output_mapping()
with open(tmp_path / 'output_mapping.csv', 'r') as f:
Expand All @@ -31,40 +49,79 @@ def test_can_save_mapping(simple_metaclusterdata: MetaClusterData, tmp_path):
"2,2,2",
"3,3,3",
"4,3,3",
]
]

simple_metaclusterdata_rename.output_mapping_filename = tmp_path / 'output_mapping.csv'
simple_metaclusterdata_rename.save_output_mapping()
with open(tmp_path / 'output_mapping.csv', 'r') as f:
output = [ll.strip() for ll in f.readlines()]
assert output == [
"cluster,metacluster,mc_name",
"1,1,cluster_1",
"2,2,cluster_2",
"3,3,cluster_3",
"4,3,cluster_3",
]


def test_metaclusters_can_have_displaynames(simple_metaclusterdata: MetaClusterData):
def test_metaclusters_can_have_displaynames(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData):
assert simple_metaclusterdata.metacluster_displaynames == ['1', '2', '3']
assert simple_metaclusterdata_rename.metacluster_displaynames == \
['cluster_1', 'cluster_2', 'cluster_3']


def test_metaclusters_can_change_displaynames(simple_metaclusterdata: MetaClusterData):
def test_metaclusters_can_change_displaynames(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData):
simple_metaclusterdata.change_displayname(1, 'y2k')
assert simple_metaclusterdata.metacluster_displaynames == ['y2k', '2', '3']

simple_metaclusterdata_rename.change_displayname(1, 'y2k')
assert simple_metaclusterdata_rename.metacluster_displaynames == \
['y2k', 'cluster_2', 'cluster_3']


def test_can_find_which_metacluster_a_cluster_belongs_to(simple_metaclusterdata: MetaClusterData):
def test_can_match_cluster_to_metacluster(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData):
assert simple_metaclusterdata.which_metacluster(4) == 3

assert simple_metaclusterdata_rename.which_metacluster(4) == 3


def test_can_average_clusters_by_metacluster(simple_metaclusterdata: MetaClusterData):
def test_can_average_clusters_by_metacluster(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData):
simple_metaclusterdata.remap(4, 3)
clusters_data = np.array([
(0.1, 0.2, 0.1),
(0.1, 0.1, 0.3),
((0.5 * 50 + 0.7 * 77) / (50 + 77),
(0.1 * 50 + 0.2 * 77) / (50 + 77),
(0.1 * 50 + 0.1 * 77) / (50 + 77)),
])
])
np.testing.assert_equal(simple_metaclusterdata.metaclusters.values, clusters_data)

simple_metaclusterdata_rename.remap(4, 3)
clusters_data = np.array([
(0.1, 0.2, 0.1),
(0.1, 0.1, 0.3),
((0.5 * 50 + 0.7 * 77) / (50 + 77),
(0.1 * 50 + 0.2 * 77) / (50 + 77),
(0.1 * 50 + 0.1 * 77) / (50 + 77)),
])
np.testing.assert_equal(simple_metaclusterdata_rename.metaclusters.values, clusters_data)

def test_can_reorder_markers(simple_metaclusterdata: MetaClusterData):

def test_can_reorder_markers(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData):
simple_metaclusterdata.set_marker_order([0, 2, 1])
assert list(simple_metaclusterdata.marker_names) == ['CD163', 'CD31', 'CD206']

simple_metaclusterdata_rename.set_marker_order([0, 2, 1])
assert list(simple_metaclusterdata_rename.marker_names) == ['CD163', 'CD31', 'CD206']


def test_marker_orders_match(simple_metaclusterdata: MetaClusterData):
def test_marker_orders_match(simple_metaclusterdata: MetaClusterData,
simple_metaclusterdata_rename: MetaClusterData):
# access the properties first to reproduce a cache invalidation bug
_ = simple_metaclusterdata.clusters
_ = simple_metaclusterdata.metaclusters
Expand All @@ -73,3 +130,12 @@ def test_marker_orders_match(simple_metaclusterdata: MetaClusterData):
c_marks = list(simple_metaclusterdata.clusters.columns[0:3])
m_marks = list(simple_metaclusterdata.metaclusters.columns[0:3])
assert c_marks == m_marks

# access the properties first to reproduce a cache invalidation bug
_ = simple_metaclusterdata_rename.clusters
_ = simple_metaclusterdata_rename.metaclusters
_ = simple_metaclusterdata_rename.clusters_with_metaclusters
simple_metaclusterdata_rename.set_marker_order([0, 2, 1])
c_marks = list(simple_metaclusterdata_rename.clusters.columns[0:3])
m_marks = list(simple_metaclusterdata_rename.metaclusters.columns[0:3])
assert c_marks == m_marks
Loading

0 comments on commit ad9fd61

Please sign in to comment.