Skip to content

Commit

Permalink
Update deepcell processing pipeline (#261)
Browse files Browse the repository at this point in the history
* Updated processing pipeline

* notebook description

* docstring description

* six>=1.15.0

* updated notebook (last commit was wrong)

* right parameters for load_imgs_from_dir

* fix for channels_first

* PEP 8

* PEP 8

* PEP 8

* forgot to commit changes in notebooks..

* notebooks

* notebooks

* PYCODESTYLE fix

* fixes

* validate xr_channel_names length

* fixed typo + added test

* remove old notebooks + refactoring

Co-authored-by: Noah Greenwald <noahfgreenwald@gmail.com>
  • Loading branch information
omerbt and ngreenwald committed Oct 13, 2020
1 parent 526ea6d commit 91610fa
Show file tree
Hide file tree
Showing 9 changed files with 487 additions and 563 deletions.
20 changes: 11 additions & 9 deletions ark/utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import math

import skimage.io as io
import numpy as np
import xarray as xr
Expand All @@ -9,8 +8,9 @@
# TODO: Add metadata for channel name (eliminates need for fixed-order channels)
def generate_deepcell_input(data_xr, data_dir, nuc_channels, mem_channels):
"""Saves nuclear and membrane channels into deepcell input format.
Either nuc_channels or mem_channels should be specified.
Writes summed channel images out as multitiffs
Writes summed channel images out as multitiffs (channels first)
Args:
data_xr (xr.DataArray):
Expand All @@ -21,19 +21,21 @@ def generate_deepcell_input(data_xr, data_dir, nuc_channels, mem_channels):
nuclear channels to be summed over
mem_channels (list):
membrane channels to be summed over
Raises:
ValueError:
Raised if nuc_channels and mem_channels are both None or empty
"""
if not nuc_channels and not mem_channels:
raise ValueError('Either nuc_channels or mem_channels should be non-empty.')

for fov in data_xr.fovs.values:
out = np.zeros((data_xr.shape[1], data_xr.shape[2], 2), dtype=data_xr.dtype)
out = np.zeros((2, data_xr.shape[1], data_xr.shape[2]), dtype=data_xr.dtype)

# sum over channels and add to output
if nuc_channels:
out[:, :, 0] = \
np.sum(data_xr.loc[fov, :, :, nuc_channels].values,
axis=2)
out[0] = np.sum(data_xr.loc[fov, :, :, nuc_channels].values, axis=2)
if mem_channels:
out[:, :, 1] = \
np.sum(data_xr.loc[fov, :, :, mem_channels].values,
axis=2)
out[1] = np.sum(data_xr.loc[fov, :, :, mem_channels].values, axis=2)

save_path = os.path.join(data_dir, f'{fov}.tif')
io.imsave(save_path, out, plugin='tifffile')
Expand Down
26 changes: 15 additions & 11 deletions ark/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import tempfile
from shutil import rmtree
import pytest

from ark.utils import data_utils, test_utils
import skimage.io as io
Expand All @@ -12,8 +13,8 @@ def test_generate_deepcell_input():
fovs = ['fov1', 'fov2']
chans = ['nuc1', 'nuc2', 'mem1', 'mem2']

data_xr = test_utils.make_images_xarray(tif_data=None, fov_ids=fovs, channel_names=chans,
dtype='int16')
data_xr = test_utils.make_images_xarray(tif_data=None, fov_ids=fovs,
channel_names=chans, dtype='int16')

fov1path = os.path.join(temp_dir, 'fov1.tif')
fov2path = os.path.join(temp_dir, 'fov2.tif')
Expand All @@ -23,8 +24,8 @@ def test_generate_deepcell_input():
mems = ['mem2']

data_utils.generate_deepcell_input(data_xr, temp_dir, nucs, mems)
fov1 = io.imread(fov1path)
fov2 = io.imread(fov2path)
fov1 = np.moveaxis(io.imread(fov1path), 0, -1)
fov2 = np.moveaxis(io.imread(fov2path), 0, -1)

assert np.array_equal(fov1, data_xr.loc['fov1', :, :, ['nuc2', 'mem2']].values)
assert np.array_equal(fov2, data_xr.loc['fov2', :, :, ['nuc2', 'mem2']].values)
Expand All @@ -34,8 +35,8 @@ def test_generate_deepcell_input():
mems = ['mem1', 'mem2']

data_utils.generate_deepcell_input(data_xr, temp_dir, nucs, mems)
fov1 = io.imread(fov1path)
fov2 = io.imread(fov2path)
fov1 = np.moveaxis(io.imread(fov1path), 0, -1)
fov2 = np.moveaxis(io.imread(fov2path), 0, -1)

nuc_sums = data_xr.loc[:, :, :, nucs].sum(dim='channels').values
mem_sums = data_xr.loc[:, :, :, mems].sum(dim='channels').values
Expand All @@ -49,8 +50,8 @@ def test_generate_deepcell_input():
nucs = None

data_utils.generate_deepcell_input(data_xr, temp_dir, nucs, mems)
fov1 = io.imread(fov1path)
fov2 = io.imread(fov2path)
fov1 = np.moveaxis(io.imread(fov1path), 0, -1)
fov2 = np.moveaxis(io.imread(fov2path), 0, -1)

assert np.all(fov1[:, :, 0] == 0)
assert np.array_equal(fov1[:, :, 1], mem_sums[0, :, :])
Expand All @@ -62,14 +63,18 @@ def test_generate_deepcell_input():
mems = None

data_utils.generate_deepcell_input(data_xr, temp_dir, nucs, mems)
fov1 = io.imread(fov1path)
fov2 = io.imread(fov2path)
fov1 = np.moveaxis(io.imread(fov1path), 0, -1)
fov2 = np.moveaxis(io.imread(fov2path), 0, -1)

assert np.all(fov1[:, :, 1] == 0)
assert np.array_equal(fov1[:, :, 0], data_xr.loc['fov1', :, :, 'nuc2'].values)
assert np.all(fov2[:, :, 1] == 0)
assert np.array_equal(fov2[:, :, 0], data_xr.loc['fov2', :, :, 'nuc2'].values)

# test nuc None and mem None
with pytest.raises(ValueError):
data_utils.generate_deepcell_input(data_xr, temp_dir, None, None)


def test_stitch_images():
fovs, chans = test_utils.gen_fov_chan_names(num_fovs=40, num_chans=4)
Expand All @@ -84,7 +89,6 @@ def test_stitch_images():

def test_split_img_stack():
with tempfile.TemporaryDirectory() as temp_dir:

fovs = ['stack_sample']
_, chans, names = test_utils.gen_fov_chan_names(num_fovs=0, num_chans=10, return_imgs=True)

Expand Down
189 changes: 82 additions & 107 deletions ark/utils/load_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,78 +77,6 @@ def load_imgs_from_mibitiff(data_dir, mibitiff_files=None, channels=None, delimi
return img_xr


def load_imgs_from_multitiff(data_dir, multitiff_files=None, channels=None, delimiter=None,
dtype='int16'):
"""Load images from a series of multi-channel tiff files.
This function takes a set of multi-channel tiff files and loads the images into an xarray.
The type used to store the images will be the same as that of the images stored in the
multi-channel tiff files.
This function differs from `load_imgs_from_mibitiff` in that proprietary metadata is unneeded,
which is useful for loading in more general multi-channel tiff files.
Args:
data_dir (str):
directory containing multitiffs
multitiff_files (list):
list of multi-channel tiff files to load. If None, all multitiff files in data_dir
are loaded.
channels (list):
optional list of channels to load. Unlike MIBItiff, this must be given as a numeric
list of indices, since there is no metadata containing channel names.
delimiter (str):
optional delimiter-character/string which separate fov names from the rest of the file
name. Default is None.
dtype (str/type):
optional specifier of image type. Overwritten with warning for float images
Returns:
xarray.DataArray:
xarray with shape [fovs, x_dim, y_dim, channels]
"""

if not multitiff_files:
multitiff_files = iou.list_files(data_dir, substrs=['.tif'])

# extract fov names w/ delimiter agnosticism
fovs = iou.extract_delimited_names(multitiff_files, delimiter=delimiter)

multitiff_files = [os.path.join(data_dir, mt_file)
for mt_file in multitiff_files]

test_img = io.imread(multitiff_files[0], plugin='tifffile')

# check to make sure that float dtype was supplied if image data is float
data_dtype = test_img.dtype
if np.issubdtype(data_dtype, np.floating):
if not np.issubdtype(dtype, np.floating):
warnings.warn(f"The supplied non-float dtype {dtype} was overwritten to {data_dtype}, "
f"because the loaded images are floats")
dtype = data_dtype

# extract data
img_data = []
for multitiff_file in multitiff_files:
img_data.append(io.imread(multitiff_file, plugin='tifffile'))
img_data = np.stack(img_data, axis=0)
img_data = img_data.astype(dtype)

if channels:
img_data = img_data[:, :, :, channels]

# create xarray with image data
img_xr = xr.DataArray(img_data,
coords=[fovs, range(img_data.shape[1]),
range(img_data.shape[2]),
channels if channels else range(img_data.shape[3])],
dims=["fovs", "rows", "cols", "channels"])

img_xr = img_xr.sortby('fovs').sortby('channels')

return img_xr


def load_imgs_from_tree(data_dir, img_sub_folder=None, fovs=None, channels=None,
dtype="int16", variable_sizes=False):
"""Takes a set of imgs from a directory structure and loads them into an xarray.
Expand Down Expand Up @@ -249,85 +177,132 @@ def load_imgs_from_tree(data_dir, img_sub_folder=None, fovs=None, channels=None,
return img_xr


def load_imgs_from_dir(data_dir, imgdim_name='compartments', image_name='img_data', delimiter=None,
dtype="int16", variable_sizes=False, force_ints=False):
"""Takes a set of images from a directory and loads them into an xarray based on filename
prefixes.
def load_imgs_from_dir(data_dir, files=None, delimiter=None, xr_dim_name='compartments',
xr_channel_names=None, dtype="int16", force_ints=False,
channel_indices=None):
"""Takes a set of images (possibly multitiffs) from a directory and loads them
into an xarray.
Args:
data_dir (str):
directory containing images
imgdim_name (str):
sets the name of the last dimension of the output xarray
image_name (str):
sets name of the last coordinate in the output xarray
files (list):
list of files (e.g. ['fov1.tif', 'fov2.tif']) to load.
If None, all (.tif, .jpg, .png) files in data_dir are loaded.
delimiter (str):
character used to determine the file-prefix containing the fov name. Default is None.
character used to determine the file-prefix containing the fov name.
Default is None.
xr_dim_name (str):
sets the name of the last dimension of the output xarray.
Default: 'compartments'
xr_channel_names (list):
sets the name of the coordinates in the last dimension of the output xarray.
dtype (str/type):
data type to load/store
variable_sizes (bool):
Dynamically determine image sizes and pad smaller imgs w/ zeros
force_ints (bool):
If dtype is an integer, forcefully convert float imgs to ints. Default is False.
channel_indices (list):
optional list of indices specifying which channels to load (by their indices).
if None or empty, the function loads all channels.
(Ignored if data is not multitiff).
Returns:
xarray.DataArray:
xarray with shape [fovs, x_dim, y_dim, 1]
xarray with shape [fovs, x_dim, y_dim, tifs]
Raises:
ValueError:
Raised in the following cases:
* data_dir is not a directory, <data_dir>/img is
not a file for some img in the input 'files' list, or no images are found.
* channel_indices are invalid according to the shape of the images.
* the provided dtype is too small to represent the data.
* The length of xr_channel_names (if provided) does not match the number
of channels in the input.
"""
if not os.path.isdir(data_dir):
raise ValueError(f"Invalid value for data_dir. {data_dir} is not a directory.")

imgs = iou.list_files(data_dir, substrs=['.tif', '.jpg', '.png'])

imgs.sort()
if files is None:
imgs = iou.list_files(data_dir, substrs=['.tif', '.jpg', '.png'])
else:
imgs = files
for img in imgs:
if not os.path.isfile(os.path.join(data_dir, img)):
raise ValueError(f"Invalid value for {img}. "
f"{os.path.join(data_dir, img)} is not a file.")

if len(imgs) == 0:
raise ValueError(f"No images found in directory, {data_dir}")

test_img = io.imread(os.path.join(data_dir, imgs[0]))

# check data format
multitiff = test_img.ndim == 3
channels_first = multitiff and test_img.shape[0] == min(test_img.shape)

# check to make sure all channel indices are valid given the shape of the image
n_channels = 1
if multitiff:
n_channels = test_img.shape[0] if channels_first else test_img.shape[2]
if channel_indices:
if max(channel_indices) >= n_channels or min(channel_indices) < 0:
raise ValueError(f'Invalid value for channel_indices. Indices should be'
f' between 0-{n_channels-1} for the given data.')
# make sure xr_channel_names has the same length as the number of channels in the image
if xr_channel_names and n_channels != len(xr_channel_names):
raise ValueError(f'Invalid value for xr_channel_names. xr_channel_names'
f' length should be {n_channels}, as the number of channels'
f' in the input data.')

# check to make sure that float dtype was supplied if image data is float
data_dtype = test_img.dtype
if force_ints and np.issubdtype(dtype, np.integer):
if not np.issubdtype(data_dtype, np.integer):
warnings.warn(f"The the loaded {data_dtype} images were forcefully "
warnings.warn(f"The loaded {data_dtype} images were forcefully "
f"overwritten with the supplied integer dtype {dtype}")
elif np.issubdtype(data_dtype, np.floating):
if not np.issubdtype(dtype, np.floating):
warnings.warn(f"The supplied non-float dtype {dtype} was overwritten to {data_dtype}, "
f"because the loaded images are floats")
dtype = data_dtype

if variable_sizes:
img_data = np.zeros((len(imgs), 1024, 1024, 1), dtype=dtype)
else:
img_data = np.zeros((len(imgs), test_img.shape[0], test_img.shape[1], 1),
dtype=dtype)
# extract data
img_data = []
for img in imgs:
v = io.imread(os.path.join(data_dir, img))
if not multitiff:
v = np.expand_dims(v, axis=2)
elif channels_first:
# convert channels_first to channels_last
v = np.moveaxis(v, 0, -1)
img_data.append(v)
img_data = np.stack(img_data, axis=0)

img_data = img_data.astype(dtype)

for img in range(len(imgs)):
if variable_sizes:
temp_img = io.imread(os.path.join(data_dir, imgs[img]))
img_data[img, :temp_img.shape[0], :temp_img.shape[1], 0] = temp_img.astype(dtype)
else:
img_data[img, :, :, 0] = io.imread(os.path.join(data_dir, imgs[img])).astype(dtype)
if channel_indices and multitiff:
img_data = img_data[:, :, :, channel_indices]

# check to make sure that dtype wasn't too small for range of data
if np.min(img_data) < 0:
raise ValueError("Integer overflow from loading TIF image, try a larger dtype")

if variable_sizes:
row_coords, col_coords = range(1024), range(1024)
if channels_first:
row_coords, col_coords = range(test_img.shape[1]), range(test_img.shape[2])
else:
row_coords, col_coords = range(test_img.shape[0]), range(test_img.shape[1])

# get fov name from imgs
fovs = iou.extract_delimited_names(imgs, delimiter=delimiter)

img_xr = xr.DataArray(img_data.astype(dtype),
coords=[fovs, row_coords, col_coords, [image_name]],
dims=["fovs", "rows", "cols",
imgdim_name])
# create xarray with image data
img_xr = xr.DataArray(img_data,
coords=[fovs, row_coords, col_coords,
xr_channel_names if xr_channel_names
else range(img_data.shape[3])],
dims=["fovs", "rows", "cols", xr_dim_name])

# sort for deterministic fov names
img_xr = img_xr.sortby('fovs')
img_xr = img_xr.sortby('fovs').sortby(xr_dim_name)

return img_xr
Loading

0 comments on commit 91610fa

Please sign in to comment.