Skip to content

Commit

Permalink
Removed rec_field attribute and wrapper.outputs, fixed bug of prev co…
Browse files Browse the repository at this point in the history
…mmit.
  • Loading branch information
Kamnitsask committed Jan 19, 2020
1 parent 14d8a31 commit 77bd7c6
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 45 deletions.
14 changes: 7 additions & 7 deletions deepmedic/dataManagement/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,21 @@ def calc_pad_per_axis(pad_input_imgs, dims_img, dims_rec_field, dims_highres_seg
# These pad/unpad should have their own class, and an instance should be created per subject.
# So that unpad gets how much to unpad from the pad.
def pad_imgs_of_case(channels, gt_lbl_img, roi_mask, wmaps_to_sample_per_cat,
pad_input_imgs, dims_rec_field, dims_highres_segment):
pad_input_imgs, unpred_margin):
# channels: np.array of dimensions [n_channels, x-dim, y-dim, z-dim]
# gt_lbl_img: np.array
# roi_mask: np.array
# wmaps_to_sample_per_cat: np.array of dimensions [num_categories, x-dim, y-dim, z-dim]
# dims_highres_segment: list [x,y,z] of dimensions of the normal-resolution samples for cnn.
# pad_input_imgs: Boolean, do padding or not.
# unpred_margin: [[pre-x, post-x], [pre-y, post-y], [pre-z, post-z]], number voxels not predicted
# Returns:
# pad_left_right_axes: Padding added before and after each axis. All 0s if no padding.

# Padding added before and after each axis. ((0, 0), (0, 0), (0, 0)) if no pad.
pad_left_right_per_axis = calc_pad_per_axis(pad_input_imgs,
channels[0].shape, dims_rec_field, dims_highres_segment)
if not pad_input_imgs:
return channels, gt_lbl_img, roi_mask, wmaps_to_sample_per_cat, pad_left_right_axes

# Padding added before and after each axis. ((0, 0), (0, 0), (0, 0)) if no pad.
pad_left_right_per_axis = unpred_margin

channels = pad_4d_arr(channels, pad_left_right_per_axis)

if gt_lbl_img is not None:
Expand All @@ -85,7 +85,7 @@ def pad_imgs_of_case(channels, gt_lbl_img, roi_mask, wmaps_to_sample_per_cat,

def pad_4d_arr(arr_4d, pad_left_right_per_axis_3d):
# Do not pad first dimension. E.g. for channels or weightmaps, [n_chans,x,y,z]
pad_left_right_per_axis_4d = ((0,0),) + pad_left_right_per_axis_3d
pad_left_right_per_axis_4d = [[0,0],] + pad_left_right_per_axis_3d
return np.lib.pad(arr_4d, pad_left_right_per_axis_4d, 'reflect')

def pad_3d_img(img, pad_left_right_per_axis):
Expand Down
17 changes: 9 additions & 8 deletions deepmedic/dataManagement/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def get_samples_for_subepoch(log,
sampling_type,
inp_shapes_per_path,
outp_pred_dims,
unpred_margin,
# Paths to input files
paths_per_chan_per_subj,
paths_to_lbls_per_subj,
Expand Down Expand Up @@ -109,7 +110,8 @@ def get_samples_for_subepoch(log,
idxs_of_subjs_for_subep,
n_samples_per_subj,
inp_shapes_per_path,
outp_pred_dims
outp_pred_dims,
unpred_margin
]

log.print3(sampler_id + " Will sample from [" + str(n_subjs_for_subep) +
Expand Down Expand Up @@ -261,7 +263,8 @@ def load_subj_and_sample(job_idx,
idxs_of_subjs_for_subep,
n_samples_per_subj,
inp_shapes_per_path,
outp_pred_dims):
outp_pred_dims,
unpred_margin):
# train_val_or_test: 'train', 'val' or 'test'
# paths_per_chan_per_subj: [[ for chan-0 [ one path per subj ]], ..., [for chan-n [ one path per subj ] ]]
# n_samples_per_cat_per_subj: np arr, shape [num sampling categories, num subjects in subepoch]
Expand Down Expand Up @@ -301,7 +304,7 @@ def load_subj_and_sample(job_idx,
pad_left_right_per_axis) = preproc_imgs_of_subj(log, job_id,
channels, gt_lbl_img, roi_mask, wmaps_to_sample_per_cat,
run_input_checks, cnn3d.num_classes, # checks
pad_input_imgs, cnn3d.receptive_field, dims_hres_segment, # pad
pad_input_imgs, unpred_margin,
norm_prms)
time_prep = time.time() - time_prep_0

Expand Down Expand Up @@ -466,9 +469,7 @@ def load_imgs_of_subject(log,


def preproc_imgs_of_subj(log, job_id, channels, gt_lbl_img, roi_mask, wmaps_to_sample_per_cat,
run_input_checks, n_classes,
pad_input_imgs, dims_rec_field, dims_hres_segment,
norm_prms):
run_input_checks, n_classes, pad_input_imgs, unpred_margin, norm_prms):
# job_id: Should be "" in testing.

if run_input_checks:
Expand All @@ -479,7 +480,7 @@ def preproc_imgs_of_subj(log, job_id, channels, gt_lbl_img, roi_mask, wmaps_to_s
roi_mask,
wmaps_to_sample_per_cat,
pad_left_right_per_axis) = pad_imgs_of_case(channels, gt_lbl_img, roi_mask, wmaps_to_sample_per_cat,
pad_input_imgs, dims_rec_field, dims_hres_segment)
pad_input_imgs, unpred_margin)

channels = normalize_int_of_subj(log, channels, roi_mask, norm_prms, job_id)

Expand Down Expand Up @@ -713,7 +714,7 @@ def extractSegmentGivenSliceCoords(train_val_or_test,
channs_of_sample_per_path.append(channsForThisSubsampledPartAndPathway)

# Get ground truth labels for training.
numOfCentralVoxelsClassifRcz = cnn3d.finalTargetLayer_outputShape[train_val_or_test][2:]
numOfCentralVoxelsClassifRcz = outp_pred_dims
leftBoundaryRcz = [coord_center[d] - (numOfCentralVoxelsClassifRcz[d] - 1) // 2 for d in range (3)]
rightBoundaryRcz = [leftBoundaryRcz[d] + numOfCentralVoxelsClassifRcz[d] for d in range(3)]
lbls_predicted_part_of_sample = gt_lbl_img[leftBoundaryRcz[0]: rightBoundaryRcz[0],
Expand Down
11 changes: 3 additions & 8 deletions deepmedic/neuralnet/cnn3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def __init__(self):

self.num_classes = None

self.receptive_field = ""

#======= Output tensors Y_GT ========
# For each targetLayer, I should be placing a y_gt placeholder/feed.
self._output_gt_tensor_feeds = {'train': {},
Expand Down Expand Up @@ -370,10 +368,6 @@ def make_cnn_model( self,
self._output_gt_tensor_feeds['train']['y_gt'] = tf.compat.v1.placeholder(dtype="int32", shape=[None, None, None, None], name="y_train")
self._output_gt_tensor_feeds['val']['y_gt'] = tf.compat.v1.placeholder(dtype="int32", shape=[None, None, None, None], name="y_val")

# ======== Calculated Attributes =========
#This recField CNN should in future be calculated with all non-secondary pathways, ie normal+fc. Use another variable for pathway.recField.
self.receptive_field = self._calc_rec_field_cnn_wrt_hr_inp()

log.print3("Finished building the CNN's model.")


Expand Down Expand Up @@ -432,7 +426,7 @@ def calc_inp_dims_of_paths_from_hr_inp(self, inp_hr_dims):
# [pathFc-in-dim-x, pathFc-in-dim-y, pathFc-in-dim-z] ]
return inp_shape_per_path

def _calc_rec_field_cnn_wrt_hr_inp(self):
def _calc_receptive_field_cnn_wrt_hr_inp(self):
rec_field_hr_path, strides_rf_at_end_of_hr_path = self.pathways[0].rec_field()
cnn_rf, _ = self.pathways[-1].rec_field(rec_field_hr_path, strides_rf_at_end_of_hr_path)
return cnn_rf
Expand All @@ -442,9 +436,10 @@ def calc_outp_dims_given_inp(self, inp_dims_hr_path):
return self.pathways[-1].calc_outp_dims_given_inp(outp_dims_hr_path)

def calc_unpredicted_margin(self, inp_dims_hr_path):
# unpred_margin: [[before-x, after-x], [before-y, after-y], [before-z, after-z]]
outp_dims = self.calc_outp_dims_given_inp(inp_dims_hr_path)
n_unpred_vox = [inp_dims_hr_path[d] - outp_dims[d] for d in range(3)]
unpred_margin = [n_unpred_vox[d] // 2 for d in range(3)]
unpred_margin = [[n_unpred_vox[d]//2, n_unpred_vox[d]-n_unpred_vox[d]//2] for d in range(3)]
return unpred_margin


4 changes: 0 additions & 4 deletions deepmedic/neuralnet/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ class CnnWrapperForSampling(object):
def __init__(self, cnn3d) :
# Cnn
self.num_classes = cnn3d.num_classes
self.receptive_field = cnn3d.receptive_field
self.finalTargetLayer_outputShape = {"train": cnn3d.finalTargetLayer.output["train"].shape,
"val": cnn3d.finalTargetLayer.output["val"].shape,
"test": cnn3d.finalTargetLayer.output["test"].shape}
# Pathways related
self._numPathwaysThatRequireInput = cnn3d.getNumPathwaysThatRequireInput()
self.numSubsPaths = cnn3d.numSubsPaths
Expand Down
33 changes: 15 additions & 18 deletions deepmedic/routines/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def stitch_predicted_to_prob_maps(prob_maps_per_class, idx_next_tile_in_pred_vol
slice_coords_tile = slice_coords[idx_next_tile_in_pred_vols]
top_left = [slice_coords_tile[0][0], slice_coords_tile[1][0], slice_coords_tile[2][0]]
prob_maps_per_class[:,
top_left[0] + unpred_margin[0]: top_left[0] + unpred_margin[0] + stride[0],
top_left[1] + unpred_margin[1]: top_left[1] + unpred_margin[1] + stride[1],
top_left[2] + unpred_margin[2]: top_left[2] + unpred_margin[2] + stride[2]
top_left[0] + unpred_margin[0][0]: top_left[0] + unpred_margin[0][0] + stride[0],
top_left[1] + unpred_margin[1][0]: top_left[1] + unpred_margin[1][0] + stride[1],
top_left[2] + unpred_margin[2][0]: top_left[2] + unpred_margin[2][0] + stride[2]
] = prob_maps_batch[tile_i]
idx_next_tile_in_pred_vols += 1

Expand Down Expand Up @@ -180,14 +180,14 @@ def stitch_predicted_to_fms(array_fms_to_save, idx_next_tile_in_fm_vols,
# newly created images all at once.
fm_to_reconstruct[:, # last dimension is the number-of-Fms, I create an image for each.

coords_top_left_voxel[0] + unpred_margin[0]:
coords_top_left_voxel[0] + unpred_margin[0] + stride[0],
coords_top_left_voxel[0] + unpred_margin[0][0]:
coords_top_left_voxel[0] + unpred_margin[0][0] + stride[0],

coords_top_left_voxel[1] + unpred_margin[1]:
coords_top_left_voxel[1] + unpred_margin[1] + stride[1],
coords_top_left_voxel[1] + unpred_margin[1][0]:
coords_top_left_voxel[1] + unpred_margin[1][0] + stride[1],

coords_top_left_voxel[2] + unpred_margin[2]:
coords_top_left_voxel[2] + unpred_margin[2] + stride[2]
coords_top_left_voxel[2] + unpred_margin[2][0]:
coords_top_left_voxel[2] + unpred_margin[2][0] + stride[2]

] = central_voxels_all_fms_batch[tile_batch_idx]

Expand Down Expand Up @@ -220,12 +220,10 @@ def prepare_feeds_dict(feeds, channs_of_tiles_per_path):


def predict_whole_volume_by_tiling(log, sessionTf, cnn3d,
channels, roi_mask, inp_shapes_per_path, batchsize,
save_fms_flag, idxs_fms_to_save ):
channels, roi_mask, inp_shapes_per_path, unpred_margin,
batchsize, save_fms_flag, idxs_fms_to_save):
# One of the main routines. Segment whole volume tile-by-tile.

# Receptive field is list [size-x, size-y, size-z]. -1 to exclude the central voxel.
unpred_margin = cnn3d.calc_unpredicted_margin(inp_shapes_per_path[0]) # Non pred voxels left.
# For tiling the volume: Stride is how much I move in each dimension to get the next tile.
# I stride exactly the number of voxels that are predicted per forward pass.
outp_pred_dims = cnn3d.calc_outp_dims_given_inp(inp_shapes_per_path[0])
Expand Down Expand Up @@ -526,8 +524,7 @@ def inference_on_whole_volumes(sessionTf,
NA_PATTERN = AccuracyMonitorForEpSegm.NA_PATTERN
n_classes = cnn3d.num_classes
n_subjects = len(paths_per_chan_per_subj)
dims_hres_segment = inp_shapes_per_path[0] # pathway [0] is the high-res path.

unpred_margin = cnn3d.calc_unpredicted_margin(inp_shapes_per_path[0])
# One dice score for whole foreground (0) AND one for each actual class
# Dice1 - AllpredictedLes/AllLesions
# Dice2 - predictedInsideRoiMask/AllLesions
Expand Down Expand Up @@ -559,7 +556,7 @@ def inference_on_whole_volumes(sessionTf,
pad_left_right_per_axis) = preproc_imgs_of_subj(log, "",
channels, gt_lbl_img, roi_mask, None,
run_input_checks, n_classes, # checks
pad_input, cnn3d.receptive_field, dims_hres_segment, # pad
pad_input, unpred_margin,
norm_prms)

# ============== Augmentation ==================
Expand All @@ -569,8 +566,8 @@ def inference_on_whole_volumes(sessionTf,
# array_fms_to_save will be None if not saving them.
(prob_maps_vols,
array_fms_to_save) = predict_whole_volume_by_tiling(log, sessionTf, cnn3d,
channels, roi_mask, inp_shapes_per_path, batchsize,
save_fms_flag, idxs_fms_to_save )
channels, roi_mask, inp_shapes_per_path, unpred_margin,
batchsize, save_fms_flag, idxs_fms_to_save )

# ========================== Post-Processing =========================
pred_seg = np.argmax(prob_maps_vols, axis=0) # The segmentation.
Expand Down
2 changes: 2 additions & 0 deletions deepmedic/routines/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def do_training(sessionTf,
sampling_type_inst_tr,
inp_shapes_per_path_train,
cnn3d.calc_outp_dims_given_inp(inp_shapes_per_path_train[0]),
cnn3d.calc_unpredicted_margin(inp_shapes_per_path_train[0]),
paths_per_chan_per_subj_train,
paths_to_lbls_per_subj_train,
paths_to_masks_per_subj_train,
Expand All @@ -214,6 +215,7 @@ def do_training(sessionTf,
sampling_type_inst_val,
inp_shapes_per_path_val,
cnn3d.calc_outp_dims_given_inp(inp_shapes_per_path_val[0]),
cnn3d.calc_unpredicted_margin(inp_shapes_per_path_val[0]),
paths_per_chan_per_subj_val,
paths_to_lbls_per_subj_val,
paths_to_masks_per_subj_val,
Expand Down

0 comments on commit 77bd7c6

Please sign in to comment.