Skip to content

Commit

Permalink
clean up scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
Geoyi committed Aug 14, 2018
1 parent 4472e05 commit 7c8db81
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 9 deletions.
6 changes: 2 additions & 4 deletions pixel_decoder/predict.py
Expand Up @@ -17,16 +17,14 @@
from pixel_decoder.utils import stats_data, open_image, preprocess_inputs_std, cache_stats
from pixel_decoder.resnet_unet import get_resnet_unet

def predict(origin_shape, imgs_folder, models_folder, pred_folder, origin_shape_no=256, border_no=32, channel_no=3, model_id='resnet_unet'):
def predict(origin_shape, imgs_folder, models_folder, pred_folder, origin_shape_no, border_no, channel_no, model_id):
origin_shape = (origin_shape_no, origin_shape_no)
rgb_index = [0, 1, 2]
border = (border_no, border_no)
input_shape = (origin_shape[0] + border[0] + border[1] , origin_shape[1] + border[0] + border[1])
means, stds = cache_stats(imgs_folder)
if not path.isdir(pred_folder):mkdir(os.path.join(os.getcwd(),pred_folder))
if not path.isdir(path.join(pred_folder, model_id)):mkdir(path.join(pred_folder, model_id))
if not path.isdir(path.join(pred_folder, model_id)):mkdir(path.join(pred_folder, model_id))
if not path.isdir(path.join(path.join(pred_folder, model_name))):mkdir(path.join(path.join(pred_folder, model_id)))
if model_id == 'resnet_unet':
model = get_resnet_unet(input_shape, channel_no)
else:
Expand Down Expand Up @@ -64,4 +62,4 @@ def predict(origin_shape, imgs_folder, models_folder, pred_folder, origin_shape_
mask = mask[mask_index1:mask_index2, mask_index1:mask_index2, ...]
mask = mask * 255
mask = mask.astype('uint8')
cv2.imwrite(path.join(pred_folder, model_name,'{}.png'.format(img_id)), mask, [cv2.IMWRITE_PNG_COMPRESSION, 9])
cv2.imwrite(path.join(pred_folder, model_id,'{}.png'.format(img_id)), mask, [cv2.IMWRITE_PNG_COMPRESSION, 9])
4 changes: 2 additions & 2 deletions pixel_decoder/train.py
Expand Up @@ -15,6 +15,7 @@
from keras import metrics
from keras.callbacks import ModelCheckpoint
from pixel_decoder.loss import dice_coef, dice_logloss2, dice_logloss3, dice_coef_rounded, dice_logloss
from pixel_decoder.resnet_unet import get_resnet_unet
# import skimage.io
import keras.backend as K

Expand All @@ -28,14 +29,13 @@
# masks_folder = sys.argv[3]
# models_folder =sys.argv[4]

def train(batch_size, imgs_folder, masks_folder, models_folder, model_id='resnet_unet', origin_shape_no=256, border_no=32, channel_no = 3):
def train(batch_size, imgs_folder, masks_folder, models_folder, model_id, origin_shape_no, border_no, channel_no):
origin_shape = (origin_shape_no, origin_shape_no)
border = (border_no, border_no)
all_files, all_masks = datafiles(imgs_folder, masks_folder, models_folder)
means, stds = cache_stats(imgs_folder)
input_shape = (origin_shape[0] + border[0] + border[1] , origin_shape[1] + border[0] + border[1])
if model_id == 'resnet_unet':
from pixel_decoder.resnet_unet import get_resnet_unet
model = get_resnet_unet(input_shape, channel_no)
# elif model_id == 'inception_unet':
# from pixel_decoder.inception_unet import get_inception_resnet_v2_unet
Expand Down
4 changes: 2 additions & 2 deletions pixel_decoder/utils.py
Expand Up @@ -101,7 +101,7 @@ def rotate_image(image, angle, scale, imgs_folder, masks_folder, models_folder):

# = cache_stats(imgs_folder)

def batch_data_generator(train_idx, batch_size, means, stds, imgs_folder, masks_folder, models_folder, channel_no = 3, border_no=32, origin_shape_no = 256):
def batch_data_generator(train_idx, batch_size, means, stds, imgs_folder, masks_folder, models_folder, channel_no, border_no, origin_shape_no):
# all_files, all_masks = datafiles()
origin_shape = (origin_shape_no, origin_shape_no)
border = (border_no, border_no)
Expand Down Expand Up @@ -160,7 +160,7 @@ def batch_data_generator(train_idx, batch_size, means, stds, imgs_folder, masks_
inputs = []
outputs = []

def val_data_generator(val_idx, batch_size, validation_steps, means, stds, imgs_folder, masks_folder, models_folder, channel_no = 3, border_no=32, origin_shape_no = 256):
def val_data_generator(val_idx, batch_size, validation_steps, means, stds, imgs_folder, masks_folder, models_folder, channel_no, border_no, origin_shape_no):
origin_shape = (origin_shape_no, origin_shape_no)
border = (border_no, border_no)
all_files, all_masks = datafiles(imgs_folder, masks_folder, models_folder)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -18,7 +18,7 @@
#readme
with open('README.md') as f:
readme = f.read()

setup(
name='pixel_decoder',
author='Zhuangfang NaNa Yi',
Expand Down

0 comments on commit 7c8db81

Please sign in to comment.