From 7c8db817e57c5ba295fc1b6cecedab421c111db5 Mon Sep 17 00:00:00 2001 From: Zhuangfang Yi Date: Tue, 14 Aug 2018 19:16:43 -0400 Subject: [PATCH] clean up scripts --- pixel_decoder/predict.py | 6 ++---- pixel_decoder/train.py | 4 ++-- pixel_decoder/utils.py | 4 ++-- setup.py | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pixel_decoder/predict.py b/pixel_decoder/predict.py index f6326fe..9ad3f9a 100644 --- a/pixel_decoder/predict.py +++ b/pixel_decoder/predict.py @@ -17,7 +17,7 @@ from pixel_decoder.utils import stats_data, open_image, preprocess_inputs_std, cache_stats from pixel_decoder.resnet_unet import get_resnet_unet -def predict(origin_shape, imgs_folder, models_folder, pred_folder, origin_shape_no=256, border_no=32, channel_no=3, model_id='resnet_unet'): +def predict(origin_shape, imgs_folder, models_folder, pred_folder, origin_shape_no, border_no, channel_no, model_id): origin_shape = (origin_shape_no, origin_shape_no) rgb_index = [0, 1, 2] border = (border_no, border_no) @@ -25,8 +25,6 @@ def predict(origin_shape, imgs_folder, models_folder, pred_folder, origin_shape_ means, stds = cache_stats(imgs_folder) if not path.isdir(pred_folder):mkdir(os.path.join(os.getcwd(),pred_folder)) if not path.isdir(path.join(pred_folder, model_id)):mkdir(path.join(pred_folder, model_id)) - if not path.isdir(path.join(pred_folder, model_id)):mkdir(path.join(pred_folder, model_id)) - if not path.isdir(path.join(path.join(pred_folder, model_name))):mkdir(path.join(path.join(pred_folder, model_id))) if model_id == 'resnet_unet': model = get_resnet_unet(input_shape, channel_no) else: @@ -64,4 +62,4 @@ def predict(origin_shape, imgs_folder, models_folder, pred_folder, origin_shape_ mask = mask[mask_index1:mask_index2, mask_index1:mask_index2, ...] mask = mask * 255 mask = mask.astype('uint8') - cv2.imwrite(path.join(pred_folder, model_name,'{}.png'.format(img_id)), mask, [cv2.IMWRITE_PNG_COMPRESSION, 9]) + cv2.imwrite(path.join(pred_folder, model_id,'{}.png'.format(img_id)), mask, [cv2.IMWRITE_PNG_COMPRESSION, 9]) diff --git a/pixel_decoder/train.py b/pixel_decoder/train.py index 5eafc8c..e47aa86 100644 --- a/pixel_decoder/train.py +++ b/pixel_decoder/train.py @@ -15,6 +15,7 @@ from keras import metrics from keras.callbacks import ModelCheckpoint from pixel_decoder.loss import dice_coef, dice_logloss2, dice_logloss3, dice_coef_rounded, dice_logloss +from pixel_decoder.resnet_unet import get_resnet_unet # import skimage.io import keras.backend as K @@ -28,14 +29,13 @@ # masks_folder = sys.argv[3] # models_folder =sys.argv[4] -def train(batch_size, imgs_folder, masks_folder, models_folder, model_id='resnet_unet', origin_shape_no=256, border_no=32, channel_no = 3): +def train(batch_size, imgs_folder, masks_folder, models_folder, model_id, origin_shape_no, border_no, channel_no): origin_shape = (origin_shape_no, origin_shape_no) border = (border_no, border_no) all_files, all_masks = datafiles(imgs_folder, masks_folder, models_folder) means, stds = cache_stats(imgs_folder) input_shape = (origin_shape[0] + border[0] + border[1] , origin_shape[1] + border[0] + border[1]) if model_id == 'resnet_unet': - from pixel_decoder.resnet_unet import get_resnet_unet model = get_resnet_unet(input_shape, channel_no) # elif model_id == 'inception_unet': # from pixel_decoder.inception_unet import get_inception_resnet_v2_unet diff --git a/pixel_decoder/utils.py b/pixel_decoder/utils.py index 76a3e0c..b74632b 100644 --- a/pixel_decoder/utils.py +++ b/pixel_decoder/utils.py @@ -101,7 +101,7 @@ def rotate_image(image, angle, scale, imgs_folder, masks_folder, models_folder): # = cache_stats(imgs_folder) -def batch_data_generator(train_idx, batch_size, means, stds, imgs_folder, masks_folder, models_folder, channel_no = 3, border_no=32, origin_shape_no = 256): +def batch_data_generator(train_idx, batch_size, means, stds, imgs_folder, masks_folder, models_folder, channel_no, border_no, origin_shape_no): # all_files, all_masks = datafiles() origin_shape = (origin_shape_no, origin_shape_no) border = (border_no, border_no) @@ -160,7 +160,7 @@ def batch_data_generator(train_idx, batch_size, means, stds, imgs_folder, masks_ inputs = [] outputs = [] -def val_data_generator(val_idx, batch_size, validation_steps, means, stds, imgs_folder, masks_folder, models_folder, channel_no = 3, border_no=32, origin_shape_no = 256): +def val_data_generator(val_idx, batch_size, validation_steps, means, stds, imgs_folder, masks_folder, models_folder, channel_no, border_no, origin_shape_no): origin_shape = (origin_shape_no, origin_shape_no) border = (border_no, border_no) all_files, all_masks = datafiles(imgs_folder, masks_folder, models_folder) diff --git a/setup.py b/setup.py index 28b7c02..88387ce 100755 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ #readme with open('README.md') as f: readme = f.read() - + setup( name='pixel_decoder', author='Zhuangfang NaNa Yi',