Skip to content

Commit

Permalink
fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Sep 26, 2020
1 parent 05254a9 commit 2581513
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 57 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ install:
- conda env create -f environment.yml
- source activate cellpose
- pip install .
- pip install matplotlib
- pip install coveralls
script:
- coverage run --source=cellpose setup.py test
Expand Down
8 changes: 6 additions & 2 deletions cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,12 @@ def main():
nimg = len(image_names)
labels = [io.imread(label_names[n]) for n in range(nimg)]
if flow_names is not None and not args.unet:
labels = [np.concatenate((labels[n][np.newaxis,:,:], io.imread(flow_names[n])), axis=0)
for n in range(nimg)]
for n in range(nimg):
flows = io.imread(flow_names[n])
if flows.shape[0]<4:
labels[n] = np.concatenate((labels[n][np.newaxis,:,:], flows), axis=0)
else:
labels[n] = flows

# testing data
test_images, test_labels, image_names_test = None, None, None
Expand Down
13 changes: 10 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
import os, sys
import os, sys, shutil
from cellpose import utils
from urllib.request import urlopen
from urllib.parse import urlparse
Expand All @@ -22,7 +22,7 @@ def data_dir(image_names):
data_dir_3D = data_dir.joinpath('3D')
data_dir_3D.mkdir(exist_ok=True)

for image_name in image_names:
for i,image_name in enumerate(image_names):
url = 'http://www.cellpose.org/static/data/' + image_name
if '2D' in image_name:
cached_file = str(data_dir_2D.joinpath(image_name))
Expand All @@ -34,14 +34,21 @@ def data_dir(image_names):
utils.download_url_to_file(url, cached_file)

# check if mask downloaded (and clear potential previous test data)
if i<2:
train_dir = data_dir_2D.joinpath('train')
train_dir.mkdir(exist_ok=True)
shutil.copyfile(cached_file, train_dir.joinpath(image_name))
name = os.path.splitext(cached_file)[0]
mask_file = name + '_cp_masks' + ext
if os.path.exists(mask_file):
os.remove(mask_file)
cached_mask_files = [name + '_cyto_masks' + ext, name + '_nuclei_masks' + ext]
for cached_mask_file in cached_mask_files:
for c,cached_mask_file in enumerate(cached_mask_files):
url = 'http://www.cellpose.org/static/data/' + os.path.split(cached_mask_file)[-1]
if not os.path.exists(cached_mask_file):
print(cached_mask_file)
utils.download_url_to_file(url, cached_mask_file, progress=True)
if i<2 and c==0:
shutil.copyfile(cached_mask_file,
train_dir.joinpath(os.path.splitext(image_name)[0] + '_cyto_masks' + ext))
return data_dir
114 changes: 63 additions & 51 deletions tests/test_output.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from cellpose import io, models, metrics, plot
from pathlib import Path
import subprocess, shlex
from subprocess import check_output, STDOUT
import os
import numpy as np
import matplotlib.pyplot as plt
try:
import matplotlib.pyplot as plt
MATPLOTLIB = True
except:
MATPLOTLIB = False

r_tol, a_tol = 1e-2, 1e-2

Expand All @@ -25,78 +30,90 @@ def clear_output(data_dir, image_names):
def test_class_2D(data_dir, image_names):
clear_output(data_dir, image_names)
img = io.imread(str(data_dir.joinpath('2D').joinpath('rgb_2D.png')))
model_types = ['cyto', 'nuclei']
chan = [2,1]
chan2 = [1,0]
model_types = ['nuclei']
chan = [1]
chan2 = [0]
for m,model_type in enumerate(model_types):
model = models.Cellpose(model_type=model_type)
masks, flows, _, _ = model.eval(img, diameter=0, channels=[chan[m],chan2[m]])
masks, flows, _, _ = model.eval(img, diameter=0, channels=[chan[m],chan2[m]], net_avg=False)
io.imsave(str(data_dir.joinpath('2D').joinpath('rgb_2D_cp_masks.png')), masks)
check_output(data_dir, image_names, '2D', model_type)
fig = plt.figure(figsize=(8,3))
plot.show_segmentation(fig, img, masks, flows[0], channels=[chan[m],chan2[m]])
compare_masks(data_dir, ['rgb_2D.png'], '2D', model_type)
clear_output(data_dir, image_names)
if MATPLOTLIB:
fig = plt.figure(figsize=(8,3))
plot.show_segmentation(fig, img, masks, flows[0], channels=[chan[m],chan2[m]])

def test_class_3D(data_dir, image_names):
clear_output(data_dir, image_names)
model = models.Cellpose()
img = io.imread(str(data_dir.joinpath('3D').joinpath('rgb_3D.tif')))
model_types = ['cyto', 'nuclei']
chan = [2,1]
chan2 = [1,0]
model_types = ['nuclei']
chan = [1]
chan2 = [0]
for m,model_type in enumerate(model_types):
masks = model.eval(img, diameter=0, channels=[chan[m],chan2[m]])[0]
model = models.Cellpose(model_type='nuclei')
masks = model.eval(img, do_3D=True, diameter=25, channels=[chan[m],chan2[m]], net_avg=False)[0]
io.imsave(str(data_dir.joinpath('3D').joinpath('rgb_3D_cp_masks.tif')), masks)
check_output(data_dir, image_names, '3D', model_type)
compare_masks(data_dir, ['rgb_3D.tif'], '3D', model_type)
clear_output(data_dir, image_names)

def test_cli_2D(data_dir, image_names):
clear_output(data_dir, image_names)
model_types = ['cyto', 'nuclei']
chan = [2,1]
chan2 = [1,0]
model_types = ['cyto']
chan = [2]
chan2 = [1]
for m,model_type in enumerate(model_types):
cmd = 'python -m cellpose --dir %s --pretrained_model %s --fast_mode --chan %d --chan2 %d --diameter 0 --save_png'%(str(data_dir.joinpath('2D')), model_type, chan[m], chan2[m])
process = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
stdout, stderr = process.communicate()
print(stdout)
print(stderr)
check_output(data_dir, image_names, '2D', model_type)
try:
cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode()
print(cmd_stdout)
except Exception as e:
print(e)
raise ValueError(e)
compare_masks(data_dir, image_names, '2D', model_type)
clear_output(data_dir, image_names)

def test_cli_3D(data_dir, image_names):
clear_output(data_dir, image_names)
model_types = ['cyto', 'nuclei']
chan = [2,1]
chan2 = [1,0]
model_types = ['cyto']
chan = [2]
chan2 = [1]
for m,model_type in enumerate(model_types):
cmd = 'python -m cellpose --dir %s --do_3D --pretrained_model %s --fast_mode --chan %d --chan2 %d --diameter 25 --save_tif'%(str(data_dir.joinpath('3D')), model_type, chan[m], chan2[m])
process = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
stdout, stderr = process.communicate()
print(stdout)
print(stderr)
check_output(data_dir, image_names, '3D', model_type)
try:
cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode()
except Exception as e:
print(e)
raise ValueError(e)
compare_masks(data_dir, image_names, '3D', model_type)
clear_output(data_dir, image_names)

def check_output(data_dir, image_names, runtype, model_type):
def test_cli_train(data_dir, image_names):
train_dir = str(data_dir.joinpath('2D').joinpath('train'))
cmd = 'python -m cellpose --train --train_size --n_epochs 10 --dir %s --mask_filter _cyto_masks --pretrained_model cyto --chan 2 --chan2 1 --diameter 30'%train_dir
try:
cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode()
except Exception as e:
print(e)
raise ValueError(e)

def compare_masks(data_dir, image_names, runtype, model_type):
"""
Helper function to check if outputs given by a test are exactly the same
as the ground truth outputs.
"""
data_dir_2D = data_dir.joinpath('2D')
data_dir_3D = data_dir.joinpath('3D')
for image_name in image_names:
check=False
if '2D' in runtype and '2D' in image_name:
image_file = str(data_dir_2D.joinpath(image_name))
name, ext = os.path.splitext(image_file)
name = os.path.splitext(image_file)[0]
output_test = name + '_cp_masks.png'
output_true = name + '_%s_masks.png'%model_type
check = True
elif '3D' in runtype and '3D' in image_name:
image_file = str(data_dir_3D.joinpath(image_name))
name, ext = os.path.splitext(image_file)
name = os.path.splitext(image_file)[0]
output_test = name + '_cp_masks.tif'
output_true = name + '_%s_masks.tif'%model_type
check = True
Expand All @@ -106,21 +123,16 @@ def check_output(data_dir, image_names, runtype, model_type):
print('checking output %s'%output_test)
masks_test = io.imread(output_test)
masks_true = io.imread(output_true)
ap = metrics.average_precision(masks_true, masks_test)

ap = metrics.average_precision(masks_true, masks_test)[0]
print('average precision of [%0.3f %0.3f %0.3f]'%(ap[0],ap[1],ap[2]))
yield np.allclose(ap, np.ones(3), rtol=r_tol, atol=a_tol)
ap_precision = np.allclose(ap, np.ones(3), rtol=r_tol, atol=a_tol)

matching_pix = np.logical_and(masks_test>0, masks_true>0).mean()
all_pix = (masks_test>0).mean()
yield np.allclose(all_pix, matching_pix, rtol=r_tol, atol=a_tol)
pix_precision = np.allclose(all_pix, matching_pix, rtol=r_tol, atol=a_tol)

assert all([ap_precision, pix_precision])
else:
print('ERROR: no file of name %s found'%output_test)
print('ERROR: no output file of name %s found'%output_test)
assert False
clear_output(data_dir, image_names)

#def test_cli_3D(data_dir):
# os.system('python -m cellpose --dir %s'%str(data_dir.join('3D').resolve()))

#def test_gray_2D(data_dir):
# os.system('python -m cellpose ')
# data_dir.join('2D').

0 comments on commit 2581513

Please sign in to comment.