Skip to content

Commit

Permalink
Merge pull request #254 from saskra/lold_branch
Browse files Browse the repository at this point in the history
Enable look_one_level_down during training as well
  • Loading branch information
carsen-stringer committed May 7, 2021
2 parents babf2ef + f88b15a commit 35c16c9
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
4 changes: 3 additions & 1 deletion cellpose/__main__.py
@@ -1,5 +1,6 @@
import sys, os, argparse, glob, pathlib, time
import subprocess

import numpy as np
from natsort import natsorted
from tqdm import tqdm
Expand All @@ -22,6 +23,7 @@
import logging
logger = logging.getLogger(__name__)


def main():

parser = argparse.ArgumentParser(description='cellpose parameters')
Expand Down Expand Up @@ -213,7 +215,7 @@ def main():
szmean = 30.

test_dir = None if len(args.test_dir)==0 else args.test_dir
output = io.load_train_test_data(args.dir, test_dir, imf, args.mask_filter, args.unet)
output = io.load_train_test_data(args.dir, test_dir, imf, args.mask_filter, args.unet, args.look_one_level_down)
images, labels, image_names, test_images, test_labels, image_names_test = output

# training with all channels
Expand Down
6 changes: 3 additions & 3 deletions cellpose/io.py
Expand Up @@ -145,8 +145,8 @@ def get_label_files(image_names, mask_filter, imf=None):
return label_names, flow_names


def load_train_test_data(train_dir, test_dir=None, image_filter=None, mask_filter='_masks', unet=False):
image_names = get_image_files(train_dir, mask_filter, imf=image_filter)
def load_train_test_data(train_dir, test_dir=None, image_filter=None, mask_filter='_masks', unet=False, look_one_level_down=True):
image_names = get_image_files(train_dir, mask_filter, image_filter, look_one_level_down)
nimg = len(image_names)
images = [imread(image_names[n]) for n in range(nimg)]

Expand All @@ -165,7 +165,7 @@ def load_train_test_data(train_dir, test_dir=None, image_filter=None, mask_filte
# testing data
test_images, test_labels, image_names_test = None, None, None
if test_dir is not None:
image_names_test = get_image_files(test_dir, mask_filter, imf=image_filter)
image_names_test = get_image_files(test_dir, mask_filter, image_filter, look_one_level_down)
label_names_test, flow_names_test = get_label_files(image_names_test, mask_filter, imf=image_filter)
nimg = len(image_names_test)
test_images = [imread(image_names_test[n]) for n in range(nimg)]
Expand Down
12 changes: 12 additions & 0 deletions requirements.txt
@@ -0,0 +1,12 @@
numpy~=1.19.3
matplotlib~=3.3.2
mxnet~=1.5.0
natsort~=7.1.1
scipy~=1.6.0
pytest~=6.2.2
tifffile~=2021.2.1
pyqtgraph~=0.11.1
tqdm~=4.56.0
torch~=1.7.1
numba~=0.52.0
setuptools~=52.0.0

0 comments on commit 35c16c9

Please sign in to comment.