Skip to content

Commit

Permalink
Merge branch 'master' into tiff_ext
Browse files: browse the repository at this point in the history
  • Loading branch information
alex-l-kong committed Oct 27, 2020
2 parents ea341f8 + 05b956e commit 354bb0d
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 98 deletions.
49 changes: 25 additions & 24 deletions ark/segmentation/marker_quantification.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,14 +11,14 @@
from ark.segmentation import signal_extraction


def compute_marker_counts(input_images, segmentation_masks, nuclear_counts=False,
def compute_marker_counts(input_images, segmentation_labels, nuclear_counts=False,
regionprops_features=None, split_large_nuclei=False):
"""Extract single cell protein expression data from channel TIFs for a single fov
Args:
input_images (xarray.DataArray):
rows x columns x channels matrix of imaging data
segmentation_masks (numpy.ndarray):
segmentation_labels (numpy.ndarray):
rows x columns x compartment matrix of masks
nuclear_counts (bool):
boolean flag to determine whether nuclear counts are returned
Expand Down Expand Up @@ -48,37 +48,38 @@ def compute_marker_counts(input_images, segmentation_masks, nuclear_counts=False
regionprops_names.remove('centroid')
regionprops_names += ['centroid-0', 'centroid-1']

unique_cell_ids = np.unique(segmentation_masks[..., 0].values)
unique_cell_ids = np.unique(segmentation_labels[..., 0].values)
unique_cell_ids = unique_cell_ids[np.nonzero(unique_cell_ids)]

# create labels for array holding channel counts and morphology metrics
feature_names = np.concatenate((np.array('cell_size'), input_images.channels,
regionprops_names), axis=None)

# create np.array to hold compartment x cell x feature info
marker_counts_array = np.zeros((len(segmentation_masks.compartments), len(unique_cell_ids),
marker_counts_array = np.zeros((len(segmentation_labels.compartments), len(unique_cell_ids),
len(feature_names)))

marker_counts = xr.DataArray(copy.copy(marker_counts_array),
coords=[segmentation_masks.compartments,
coords=[segmentation_labels.compartments,
unique_cell_ids.astype('int'),
feature_names],
dims=['compartments', 'cell_id', 'features'])

# get regionprops for each cell
cell_props = pd.DataFrame(regionprops_table(segmentation_masks.loc[:, :, 'whole_cell'].values,
cell_props = pd.DataFrame(regionprops_table(segmentation_labels.loc[:, :, 'whole_cell'].values,
properties=regionprops_features))

if nuclear_counts:
nuc_mask = segmentation_masks.loc[:, :, 'nuclear'].values
nuc_labels = segmentation_labels.loc[:, :, 'nuclear'].values

if split_large_nuclei:
cell_mask = segmentation_masks.loc[:, :, 'whole_cell'].values
nuc_mask = segmentation_utils.split_large_nuclei(cell_segmentation_mask=cell_mask,
nuc_segmentation_mask=nuc_mask,
cell_ids=unique_cell_ids)
cell_labels = segmentation_labels.loc[:, :, 'whole_cell'].values
nuc_labels = \
segmentation_utils.split_large_nuclei(cell_segmentation_labels=cell_labels,
nuc_segmentation_labels=nuc_labels,
cell_ids=unique_cell_ids)

nuc_props = pd.DataFrame(regionprops_table(nuc_mask, properties=regionprops_features))
nuc_props = pd.DataFrame(regionprops_table(nuc_labels, properties=regionprops_features))

# TODO: There's some repeated code here, maybe worth refactoring? Maybe not
# loop through each cell in mask
Expand All @@ -103,8 +104,8 @@ def compute_marker_counts(input_images, segmentation_masks, nuclear_counts=False

if nuclear_counts:
# get id of corresponding nucleus
nuc_id = segmentation_utils.find_nuclear_mask_id(nuc_segmentation_mask=nuc_mask,
cell_coords=cell_coords)
nuc_id = segmentation_utils.find_nuclear_label_id(nuc_segmentation_labels=nuc_labels,
cell_coords=cell_coords)

if nuc_id is not None:
# get coordinates of corresponding nucleus
Expand Down Expand Up @@ -225,8 +226,8 @@ def create_marker_count_matrices(segmentation_labels, image_data, nuclear_counts
return normalized_data, arcsinh_data


def generate_cell_data(segmentation_labels, tiff_dir, img_sub_folder,
is_mibitiff=False, fovs=None, batch_size=5):
def generate_cell_table(segmentation_labels, tiff_dir, img_sub_folder,
is_mibitiff=False, fovs=None, batch_size=5):
"""
This function takes the segmented data and computes the expression matrices batch-wise
while also validating inputs
Expand Down Expand Up @@ -279,8 +280,8 @@ def generate_cell_data(segmentation_labels, tiff_dir, img_sub_folder,
cohort_len = len(fovs)

# create the final dfs to store the processed data
combined_cell_size_normalized_data = pd.DataFrame()
combined_arcsinh_transformed_data = pd.DataFrame()
combined_cell_table_size_normalized = pd.DataFrame()
combined_cell_table_arcsinh_transformed = pd.DataFrame()

# iterate over all the batches
for batch_names, batch_files in zip(
Expand All @@ -300,17 +301,17 @@ def generate_cell_data(segmentation_labels, tiff_dir, img_sub_folder,
current_labels = segmentation_labels.loc[batch_names, :, :, :]

# segment the imaging data
cell_size_normalized_data, arcsinh_transformed_data = create_marker_count_matrices(
cell_table_size_normalized, cell_table_arcsinh_transformed = create_marker_count_matrices(
segmentation_labels=current_labels,
image_data=image_data
)

# now append to the final dfs to return
combined_cell_size_normalized_data = combined_cell_size_normalized_data.append(
cell_size_normalized_data
combined_cell_table_size_normalized = combined_cell_table_size_normalized.append(
cell_table_size_normalized
)
combined_arcsinh_transformed_data = combined_arcsinh_transformed_data.append(
arcsinh_transformed_data
combined_cell_table_arcsinh_transformed = combined_cell_table_arcsinh_transformed.append(
cell_table_arcsinh_transformed
)

return combined_cell_size_normalized_data, combined_arcsinh_transformed_data
return combined_cell_table_size_normalized, combined_cell_table_arcsinh_transformed
57 changes: 26 additions & 31 deletions ark/segmentation/marker_quantification_test.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,17 +13,17 @@
def test_compute_marker_counts_base():
cell_mask, channel_data = test_utils.create_test_extraction_data()

segmentation_masks = test_utils.make_labels_xarray(label_data=cell_mask,
compartment_names=['whole_cell'])
segmentation_labels = test_utils.make_labels_xarray(label_data=cell_mask,
compartment_names=['whole_cell'])

input_images = test_utils.make_images_xarray(channel_data)

# test utils output is 4D but tests require 3D
segmentation_masks, input_images = segmentation_masks[0], input_images[0]
segmentation_labels, input_images = segmentation_labels[0], input_images[0]

segmentation_output = \
marker_quantification.compute_marker_counts(input_images=input_images,
segmentation_masks=segmentation_masks)
segmentation_labels=segmentation_labels)

# check that channel 0 counts are same as cell size
assert np.array_equal(segmentation_output.loc['whole_cell', :, 'cell_size'].values,
Expand Down Expand Up @@ -58,23 +58,20 @@ def test_compute_marker_counts_base():
def test_compute_marker_counts_equal_masks():
cell_mask, channel_data = test_utils.create_test_extraction_data()

segmentation_masks = test_utils.make_labels_xarray(label_data=cell_mask,
compartment_names=['whole_cell'])

# test whole_cell and nuclear compartments with same data
segmentation_masks_equal = test_utils.make_labels_xarray(
segmentation_labels_equal = test_utils.make_labels_xarray(
label_data=np.concatenate((cell_mask, cell_mask), axis=-1),
compartment_names=['whole_cell', 'nuclear']
)

input_images = test_utils.make_images_xarray(channel_data)

# test utils output is 4D but tests require 3D
segmentation_masks_equal, input_images = segmentation_masks_equal[0], input_images[0]
segmentation_labels_equal, input_images = segmentation_labels_equal[0], input_images[0]

segmentation_output_equal = \
marker_quantification.compute_marker_counts(input_images=input_images,
segmentation_masks=segmentation_masks_equal,
segmentation_labels=segmentation_labels_equal,
nuclear_counts=True)

assert np.all(segmentation_output_equal[0].values == segmentation_output_equal[1].values)
Expand All @@ -89,20 +86,21 @@ def test_compute_marker_counts_nuc_whole_cell_diff():
nuc_mask = np.expand_dims(nuc_mask, axis=-1)

unequal_masks = np.concatenate((cell_mask, nuc_mask), axis=-1)
segmentation_masks_unequal = test_utils.make_labels_xarray(
segmentation_labels_unequal = test_utils.make_labels_xarray(
label_data=unequal_masks,
compartment_names=['whole_cell', 'nuclear']
)

input_images = test_utils.make_images_xarray(channel_data)

# test utils output is 4D but tests require 3D
segmentation_masks_unequal, input_images = segmentation_masks_unequal[0], input_images[0]
segmentation_labels_unequal, input_images = segmentation_labels_unequal[0], input_images[0]

segmentation_output_unequal = \
marker_quantification.compute_marker_counts(input_images=input_images,
segmentation_masks=segmentation_masks_unequal,
nuclear_counts=True)
marker_quantification.compute_marker_counts(
input_images=input_images,
segmentation_labels=segmentation_labels_unequal,
nuclear_counts=True)

# make sure nuclear segmentations are smaller
assert np.all(segmentation_output_unequal.loc['nuclear', :, 'cell_size'].values <
Expand Down Expand Up @@ -140,26 +138,23 @@ def test_compute_marker_counts_nuc_whole_cell_diff():
def test_compute_marker_counts_diff_props():
cell_mask, channel_data = test_utils.create_test_extraction_data()

segmentation_masks = test_utils.make_labels_xarray(label_data=cell_mask,
compartment_names=['whole_cell'])

# test whole_cell and nuclear compartments with same data
segmentation_masks_equal = test_utils.make_labels_xarray(
segmentation_labels_equal = test_utils.make_labels_xarray(
label_data=np.concatenate((cell_mask, cell_mask), axis=-1),
compartment_names=['whole_cell', 'nuclear']
)

input_images = test_utils.make_images_xarray(channel_data)

segmentation_masks_equal, input_images = segmentation_masks_equal[0], input_images[0]
segmentation_labels_equal, input_images = segmentation_labels_equal[0], input_images[0]

# different object properties can be supplied
regionprops_features = ['label', 'area']
excluded_defaults = ['eccentricity']

segmentation_output_specified = \
marker_quantification.compute_marker_counts(input_images=input_images,
segmentation_masks=segmentation_masks_equal,
segmentation_labels=segmentation_labels_equal,
nuclear_counts=True,
regionprops_features=regionprops_features)

Expand All @@ -170,7 +165,7 @@ def test_compute_marker_counts_diff_props():
# these nuclei are all smaller than the cells, so we should get same result
segmentation_output_specified_split = \
marker_quantification.compute_marker_counts(input_images=input_images,
segmentation_masks=segmentation_masks_equal,
segmentation_labels=segmentation_labels_equal,
nuclear_counts=True,
regionprops_features=regionprops_features,
split_large_nuclei=True)
Expand All @@ -191,14 +186,14 @@ def test_create_marker_count_matrices_base():
tif_data[0, :, :, :] = channel_data[0, :, :, :]
tif_data[1, 5:, 5:, :] = channel_data[0, :-5, :-5, :]

segmentation_masks = test_utils.make_labels_xarray(
segmentation_labels = test_utils.make_labels_xarray(
label_data=cell_masks,
compartment_names=['whole_cell']
)

channel_data = test_utils.make_images_xarray(tif_data)

normalized, _ = marker_quantification.create_marker_count_matrices(segmentation_masks,
normalized, _ = marker_quantification.create_marker_count_matrices(segmentation_labels,
channel_data)

assert normalized.shape[0] == 7
Expand Down Expand Up @@ -233,15 +228,15 @@ def test_create_marker_count_matrices_multiple_compartments():

unequal_masks = np.concatenate((cell_masks, nuc_masks), axis=-1)

segmentation_masks_unequal = test_utils.make_labels_xarray(
segmentation_labels_unequal = test_utils.make_labels_xarray(
label_data=unequal_masks,
compartment_names=['whole_cell', 'nuclear']
)

channel_data = test_utils.make_images_xarray(channel_datas)

normalized, arcsinh = marker_quantification.create_marker_count_matrices(
segmentation_masks_unequal,
segmentation_labels_unequal,
channel_data,
nuclear_counts=True
)
Expand Down Expand Up @@ -304,21 +299,21 @@ def test_generate_cell_data_tree_loading():

with pytest.raises(ValueError):
# specifying fovs not in the original segmentation mask
marker_quantification.generate_cell_data(
marker_quantification.generate_cell_table(
segmentation_labels=segmentation_masks.loc[["fov1"]], tiff_dir=tiff_dir,
img_sub_folder=img_sub_folder, is_mibitiff=False, fovs=["fov1", "fov2"],
batch_size=5)

# generate sample norm and arcsinh data for all fovs
norm_data, arcsinh_data = marker_quantification.generate_cell_data(
norm_data, arcsinh_data = marker_quantification.generate_cell_table(
segmentation_labels=segmentation_masks, tiff_dir=tiff_dir,
img_sub_folder=img_sub_folder, is_mibitiff=False, fovs=None, batch_size=2)

assert norm_data.shape[0] > 0 and norm_data.shape[1] > 0
assert arcsinh_data.shape[0] > 0 and arcsinh_data.shape[1] > 0

# generate sample norm and arcsinh data for a subset of fovs
norm_data, arcsinh_data = marker_quantification.generate_cell_data(
norm_data, arcsinh_data = marker_quantification.generate_cell_table(
segmentation_labels=segmentation_masks, tiff_dir=tiff_dir,
img_sub_folder=img_sub_folder, is_mibitiff=False, fovs=fovs_subset, batch_size=2)

Expand Down Expand Up @@ -359,15 +354,15 @@ def test_generate_cell_data_mibitiff_loading():
)

# generate sample norm and arcsinh data for all fovs
norm_data, arcsinh_data = marker_quantification.generate_cell_data(
norm_data, arcsinh_data = marker_quantification.generate_cell_table(
segmentation_labels=segmentation_masks, tiff_dir=tiff_dir,
img_sub_folder=tiff_dir, is_mibitiff=True, fovs=None, batch_size=2)

assert norm_data.shape[0] > 0 and norm_data.shape[1] > 0
assert arcsinh_data.shape[0] > 0 and arcsinh_data.shape[1] > 0

# generate sample norm and arcsinh data for a subset of fovs
norm_data, arcsinh_data = marker_quantification.generate_cell_data(
norm_data, arcsinh_data = marker_quantification.generate_cell_table(
segmentation_labels=segmentation_masks, tiff_dir=tiff_dir,
img_sub_folder=tiff_dir, is_mibitiff=True, fovs=fovs_subset, batch_size=2)

Expand Down
Loading

0 comments on commit 354bb0d

Please sign in to comment.