Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add heatmap visualizations in for SOM results #410

Merged
merged 15 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions ark/analysis/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from ark.utils import misc_utils


def draw_boxplot(cell_data, col_name, col_split=None, split_vals=None, dpi=None, save_dir=None):
def draw_boxplot(cell_data, col_name, col_split=None,
split_vals=None, dpi=None, save_dir=None, save_file=None):
"""Draws a boxplot for a given column, optionally with help from a split column

Args:
Expand All @@ -23,6 +24,9 @@ def draw_boxplot(cell_data, col_name, col_split=None, split_vals=None, dpi=None,
The resolution of the image to save, ignored if save_dir is None
save_dir (str):
If specified, a directory where we will save the plot
save_file (str):
If save_dir specified, specify a file name you wish to save to.
Ignored if save_dir is None
"""

# the col_name must be valid
Expand Down Expand Up @@ -61,11 +65,11 @@ def draw_boxplot(cell_data, col_name, col_split=None, split_vals=None, dpi=None,

# save visualization to a directory if specified
if save_dir is not None:
misc_utils.save_figure(save_dir, "boxplot_viz.png", dpi=dpi)
misc_utils.save_figure(save_dir, save_file, dpi=dpi)


def draw_heatmap(data, x_labels, y_labels, dpi=None, center_val=None,
overlay_values=False, colormap="vlag", save_dir=None):
overlay_values=False, colormap="vlag", save_dir=None, save_file=None):
"""Plots the z scores between all phenotypes as a clustermap.

Args:
Expand All @@ -85,6 +89,9 @@ def draw_heatmap(data, x_labels, y_labels, dpi=None, center_val=None,
color scheme for visualization
save_dir (str):
If specified, a directory where we will save the plot
save_file (str):
If save_dir specified, specify a file name you wish to save to.
Ignored if save_dir is None
"""

# Replace the NA's and inf values with 0s
Expand All @@ -101,7 +108,7 @@ def draw_heatmap(data, x_labels, y_labels, dpi=None, center_val=None,
sns.clustermap(data_df, cmap=colormap, center=center_val)

if save_dir is not None:
misc_utils.save_figure(save_dir, "z_score_viz.png", dpi=dpi)
misc_utils.save_figure(save_dir, save_file, dpi=dpi)


def get_sorted_data(cell_data, sort_by_first, sort_by_second, is_normalized=False):
Expand Down
4 changes: 2 additions & 2 deletions ark/analysis/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_draw_heatmap():

# test that with save_dir, we do save
visualize.draw_heatmap(z, pheno_titles, pheno_titles,
save_dir=temp_dir)
save_dir=temp_dir, save_file="z_score_viz.png")
assert os.path.exists(os.path.join(temp_dir, "z_score_viz.png"))


Expand Down Expand Up @@ -66,7 +66,7 @@ def test_draw_boxplot():
with tempfile.TemporaryDirectory() as temp_dir:
visualize.draw_boxplot(cell_data=random_data, col_name="A",
col_split=settings.PATIENT_ID, split_vals=[1, 2],
save_dir=temp_dir)
save_dir=temp_dir, save_file="boxplot_viz.png")
assert os.path.exists(os.path.join(temp_dir, "boxplot_viz.png"))


Expand Down
6 changes: 5 additions & 1 deletion ark/phenotyping/consensus_cluster.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ clusterAvgPath <- args[6]
# get consensus clustered write path
pixelMatConsensus <- args[7]

# set the random seed
seed <- strtoi(args[8])
set.seed(seed)

# read cluster averaged data
print("Reading cluster averaged data")
clusterAvgs <- arrow::read_feather(clusterAvgPath)
Expand All @@ -46,7 +50,7 @@ clusterAvgsScale <- pmin(scale(clusterAvgs[markers]), cap)

# run the consensus clustering
print("Running consensus clustering")
consensusClusterResults <- ConsensusClusterPlus(t(clusterAvgsScale), maxK=maxK)
consensusClusterResults <- ConsensusClusterPlus(t(clusterAvgsScale), maxK=maxK, seed=seed)
hClust <- consensusClusterResults[[maxK]]$consensusClass
names(hClust) <- clusterAvgs$cluster

Expand Down
6 changes: 5 additions & 1 deletion ark/phenotyping/create_som_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ normValsPath <- args[5]
# get the weights write path
pixelWeightsPath <- args[6]

# set the random seed
seed <- strtoi(args[7])
set.seed(seed)

# read the subsetted pixel mat data for training
print("Reading the subsetted pixel matrix data for SOM training")
pixelSubsetData <- NULL
Expand Down Expand Up @@ -76,7 +80,7 @@ arrow::write_feather(as.data.table(normVals), normValsPath)

# run the SOM training step
print("Run the SOM training")
somResults <- SOM(data=pixelSubsetData, rlen=numPasses)
somResults <- SOM(data=pixelSubsetData, rlen=numPasses, alpha=c(0.05, 0.01))

# write the weights to HDF5
print("Save trained weights")
Expand Down
127 changes: 94 additions & 33 deletions ark/phenotyping/som_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,31 @@
from skimage.io import imread

import ark.settings as settings
from ark.analysis import visualize
from ark.utils import io_utils
from ark.utils import load_utils
from ark.utils import misc_utils


def compute_cluster_avg(fovs, channels, base_dir,
cluster_dir='pixel_mat_clustered',
cluster_avg_name='pixel_cluster_avg.feather'):
"""Averages channel values across all fovs in pixel_mat_clustered
def compute_cluster_avg(fovs, channels, base_dir, cluster_col,
cluster_dir='pixel_mat_clustered'):
"""For each fov, compute the average channel values across each SOM cluster

Args:
fovs (list):
The list of fovs to subset on
channels (list):
The list of channels to subset on
base_dir (str):
Name of the directory to save the pixel files to
The path to the data directories
cluster_col (str):
Name of the column to group by
cluster_dir (str):
Name of the file containing the pixel data with cluster labels
cluster_avg_name (str):
Name of file to save the averaged results to

Returns:
pandas.DataFrame:
Contains the average channel values for each SOM cluster
"""

# define the cluster averages DataFrame
Expand All @@ -42,8 +46,8 @@ def compute_cluster_avg(fovs, channels, base_dir,
)

# aggregate the sums and counts
sum_by_cluster = fov_pixel_data.groupby('cluster')[channels].sum()
count_by_cluster = fov_pixel_data.groupby('cluster')[channels].size().to_frame('count')
sum_by_cluster = fov_pixel_data.groupby(cluster_col)[channels].sum()
count_by_cluster = fov_pixel_data.groupby(cluster_col)[channels].size().to_frame('count')

# concat the results together
agg_results = pd.merge(
Expand All @@ -52,18 +56,15 @@ def compute_cluster_avg(fovs, channels, base_dir,
cluster_avgs = pd.concat([cluster_avgs, agg_results])

# sum the counts and the channel sums
sum_count_totals = cluster_avgs.groupby('cluster')[channels + ['count']].sum().reset_index()
sum_count_totals = cluster_avgs.groupby(cluster_col)[channels + ['count']].sum().reset_index()

# now compute the means using the count column
sum_count_totals[channels] = sum_count_totals[channels].div(sum_count_totals['count'], axis=0)

# drop the count column
sum_count_totals = sum_count_totals.drop('count', axis=1)

# save the DataFrame
feather.write_dataframe(sum_count_totals,
os.path.join(base_dir, cluster_avg_name),
compression='uncompressed')
return sum_count_totals


def create_fov_pixel_data(fov, channels, img_data, seg_labels,
Expand All @@ -90,10 +91,10 @@ def create_fov_pixel_data(fov, channels, img_data, seg_labels,

Returns:
tuple:
A tuple containing two pd.Dataframes:
Contains the following:

- The full preprocessed pixel dataset for a fov
- The subsetted pixel dataset for a fov
- pandas.DataFrame: Gaussian blurred and channel sum normalized pixel data for a fov
- pandas.DataFrame: subset of the preprocessed pixel dataset for a fov
"""

# for each marker, compute the Gaussian blur
Expand Down Expand Up @@ -129,21 +130,19 @@ def create_fov_pixel_data(fov, channels, img_data, seg_labels,
return pixel_mat, pixel_mat_subset


def create_pixel_matrix(fovs, channels, base_dir, tiff_dir, seg_dir,
def create_pixel_matrix(fovs, base_dir, tiff_dir, seg_dir,
pre_dir='pixel_mat_preprocessed',
sub_dir='pixel_mat_subsetted', is_mibitiff=False,
blur_factor=2, subset_proportion=0.1, seed=42):
"""Preprocess the images for FlowSOM clustering and creates a pixel-level matrix
"""For each fov, add a Gaussian blur to each channel and normalize channel sums for each pixel

Saves preprocessed data to pre_dir and subsetted data to sub_dir
Saves data to pre_dir and subsetted data to sub_dir

Args:
fovs (list):
List of fovs to subset over
channels (list):
List of channels to subset over
base_dir (str):
Name of the directory to save the pixel files to
The path to the data directories
tiff_dir (str):
Name of the directory containing the tiff files
seg_dir (str):
Expand Down Expand Up @@ -186,11 +185,11 @@ def create_pixel_matrix(fovs, channels, base_dir, tiff_dir, seg_dir,
# load img_xr from MIBITiff or directory with the fov
if is_mibitiff:
img_xr = load_utils.load_imgs_from_mibitiff(
tiff_dir, mibitiff_files=[fov], channels=channels, dtype="int16"
tiff_dir, mibitiff_files=[fov], dtype="int16"
)
else:
img_xr = load_utils.load_imgs_from_tree(
tiff_dir, fovs=[fov], channels=channels, dtype="int16"
tiff_dir, fovs=[fov], dtype="int16"
)

# load segmentation labels in for fov
Expand All @@ -201,7 +200,7 @@ def create_pixel_matrix(fovs, channels, base_dir, tiff_dir, seg_dir,

# create the full and subsetted fov matrices
pixel_mat, pixel_mat_subset = create_fov_pixel_data(
fov=fov, channels=channels, img_data=img_data, seg_labels=seg_labels,
fov=fov, channels=img_xr.channels.values, img_data=img_data, seg_labels=seg_labels,
blur_factor=blur_factor, subset_proportion=subset_proportion, seed=seed
)

Expand All @@ -222,7 +221,7 @@ def create_pixel_matrix(fovs, channels, base_dir, tiff_dir, seg_dir,

def train_som(fovs, channels, base_dir,
sub_dir='pixel_mat_subsetted', norm_vals_name='norm_vals.feather',
weights_name='weights.feather', num_passes=1):
weights_name='weights.feather', num_passes=1, seed=42):
"""Run the SOM training on the subsetted pixel data.

Saves weights to base_dir/weights_name.
Expand All @@ -233,7 +232,7 @@ def train_som(fovs, channels, base_dir,
channels (list):
The list of markers to subset on
base_dir (str):
The path to the data directory
The path to the data directories
sub_dir (str):
The name of the subsetted data directory
norm_vals_name (str):
Expand All @@ -242,6 +241,8 @@ def train_som(fovs, channels, base_dir,
The name of the weights file
num_passes (int):
The number of training passes to make through the dataset
seed (int):
The random seed to set for training
"""

# define the paths to the data
Expand All @@ -266,7 +267,8 @@ def train_som(fovs, channels, base_dir,

# run the SOM training process
process_args = ['Rscript', '/create_som_matrix.R', ','.join(fovs), ','.join(channels),
str(num_passes), subsetted_path, norm_vals_path, weights_path]
str(num_passes), subsetted_path, norm_vals_path, weights_path, str(seed)]

process = subprocess.Popen(process_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

# continuously poll the process for output/error to display in Jupyter notebook
Expand Down Expand Up @@ -365,16 +367,18 @@ def cluster_pixels(fovs, base_dir, pre_dir='pixel_mat_preprocessed',
def consensus_cluster(fovs, channels, base_dir, max_k=20, cap=3,
cluster_dir='pixel_mat_clustered',
cluster_avg_name='pixel_cluster_avg.feather',
consensus_dir='pixel_mat_consensus'):
consensus_dir='pixel_mat_consensus', seed=42):
"""Run consensus clustering algorithm on summed data across channels

Saves data with consensus cluster labels to consensus_dir

Args:
fovs (list):
The list of fovs to subset on
channels (list):
The list of channels to subset on
base_dir (str):
Name of the directory to save the pixel files to
The path to the data directory
max_k (int):
The number of consensus clusters
cap (int):
Expand All @@ -385,6 +389,8 @@ def consensus_cluster(fovs, channels, base_dir, max_k=20, cap=3,
Name of file to save the channel-averaged results to
consensus_dir (str):
Name of directory to save the consensus clustered results
seed (int):
The random seed to set for consensus clustering
"""

clustered_path = os.path.join(base_dir, cluster_dir)
Expand All @@ -396,15 +402,22 @@ def consensus_cluster(fovs, channels, base_dir, max_k=20, cap=3,
(base_dir, clustered_path))

# compute and write the averaged cluster results
compute_cluster_avg(fovs, channels, base_dir, cluster_dir)
cluster_avgs = compute_cluster_avg(fovs, channels, base_dir,
cluster_col='cluster', cluster_dir=cluster_dir)

# save the DataFrame
feather.write_dataframe(cluster_avgs,
os.path.join(base_dir, cluster_avg_name),
compression='uncompressed')

# make consensus_dir if it doesn't exist
if not os.path.exists(consensus_path):
os.mkdir(consensus_path)

# run the consensus clustering process
process_args = ['Rscript', '/consensus_cluster.R', ','.join(fovs), ','.join(channels),
str(max_k), str(cap), clustered_path, cluster_avg_path, consensus_path]
str(max_k), str(cap), clustered_path, cluster_avg_path, consensus_path,
str(seed)]

process = subprocess.Popen(process_args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

Expand All @@ -418,3 +431,51 @@ def consensus_cluster(fovs, channels, base_dir, max_k=20, cap=3,
break
if output:
print(output.strip())


def visualize_cluster_data(fovs, channels, base_dir, cluster_dir, cluster_col='cluster',
dpi=None, center_val=None, overlay_values=False, colormap="vlag",
save_dir=None, save_file=None):
"""Visualize the average cluster results for each cluster

Args:
fovs (list):
The list of fovs to subset on
channels (list):
The list of channels to subset on
base_dir (str):
The path to the data directories
cluster_dir (str):
Name of the directory containing the data to visualize
cluster_col (str):
Name of the column to group values by
dpi (float):
The resolution of the image to save, ignored if save_dir is None
center_val (float):
value at which to center the heatmap
overlay_values (bool):
whether to overlay the raw heatmap values on top
colormap (str):
color scheme for visualization
save_dir (str):
If specified, a directory where we will save the plot
save_file (str):
If save_dir specified, specify a file name you wish to save to.
Ignored if save_dir is None
"""

# average the channel values across the cluster column
cluster_avgs = compute_cluster_avg(fovs, channels, base_dir, cluster_col, cluster_dir)

# convert cluster column to integer type
cluster_avgs[cluster_col] = cluster_avgs[cluster_col].astype(int)

# sort cluster col in ascending order
cluster_avgs = cluster_avgs.sort_values(by=cluster_col)

# draw the heatmap
visualize.draw_heatmap(
data=cluster_avgs[channels].values, x_labels=cluster_avgs[cluster_col], y_labels=channels,
dpi=dpi, center_val=center_val, overlay_values=overlay_values,
colormap=colormap, save_dir=save_dir, save_file=save_file
)
Loading