Skip to content

Commit

Permalink
removed get_shape and _inp_shape from cnn3d and sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kamnitsask committed Jan 18, 2020
1 parent cba7a6a commit 0cea7b5
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 58 deletions.
28 changes: 17 additions & 11 deletions deepmedic/dataManagement/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def get_samples_for_subepoch(log,
max_n_cases_per_subep,
n_samples_per_subep,
sampling_type,
inp_shapes_per_path,
# Paths to input files
paths_per_chan_per_subj,
paths_to_lbls_per_subj,
Expand Down Expand Up @@ -105,7 +106,8 @@ def get_samples_for_subepoch(log,

n_subjs_for_subep,
idxs_of_subjs_for_subep,
n_samples_per_subj
n_samples_per_subj,
inp_shapes_per_path
]

log.print3(sampler_id + " Will sample from [" + str(n_subjs_for_subep) +
Expand Down Expand Up @@ -255,7 +257,8 @@ def load_subj_and_sample(job_idx,
augm_sample_prms,
n_subjs_for_subep,
idxs_of_subjs_for_subep,
n_samples_per_subj):
n_samples_per_subj,
inp_shapes_per_path):
# train_val_or_test: 'train', 'val' or 'test'
# paths_per_chan_per_subj: [[ for chan-0 [ one path per subj ]], ..., [for chan-n [ one path per subj ] ]]
# n_samples_per_cat_per_subj: np arr, shape [num sampling categories, num subjects in subepoch]
Expand All @@ -271,7 +274,7 @@ def load_subj_and_sample(job_idx,
channs_of_samples_per_path = [[] for i in range(cnn3d.getNumPathwaysThatRequireInput())]
lbls_predicted_part_of_samples = [] # Labels only for the central/predicted part of segments.

dims_hres_segment = cnn3d.get_inp_shape_of_path(0, train_val_or_test)
dims_hres_segment = inp_shapes_per_path[0]

# Load images of subject
time_load_0 = time.time()
Expand Down Expand Up @@ -354,7 +357,8 @@ def load_subj_and_sample(job_idx,
cnn3d,
coord_center,
channels,
gt_lbl_img)
gt_lbl_img,
inp_shapes_per_path)

# Augmentation of segments
time_augm_sample_0 = time.time()
Expand Down Expand Up @@ -724,7 +728,8 @@ def extractSegmentGivenSliceCoords(train_val_or_test,
cnn3d,
coord_center,
channels,
gt_lbl_img):
gt_lbl_img,
inp_shapes_per_path):
# channels: numpy array [ n_channels, x, y, z ]
# coord_center: indeces of the central voxel for the patch to be extracted.

Expand All @@ -739,7 +744,7 @@ def extractSegmentGivenSliceCoords(train_val_or_test,
if cnn3d.pathways[path_idx].pType() == pt.FC:
continue
subSamplingFactor = cnn3d.pathways[path_idx].subs_factor()
pathwayInputShapeRcz = cnn3d.get_inp_shape_of_path(path_idx, train_val_or_test)
pathwayInputShapeRcz = inp_shapes_per_path[path_idx]
leftBoundaryRcz = [coord_center[0] - subSamplingFactor[0] * (pathwayInputShapeRcz[0] - 1) // 2,
coord_center[1] - subSamplingFactor[1] * (pathwayInputShapeRcz[1] - 1) // 2,
coord_center[2] - subSamplingFactor[2] * (pathwayInputShapeRcz[2] - 1) // 2]
Expand All @@ -760,7 +765,7 @@ def extractSegmentGivenSliceCoords(train_val_or_test,
if cnn3d.pathways[pathway_i].pType() == pt.FC or cnn3d.pathways[pathway_i].pType() == pt.NORM:
continue
# this datastructure is similar to channelsForThisImagePart, but contains voxels from the subsampled image.
dimsOfPrimarySegment = cnn3d.get_inp_shape_of_path(pathway_i, train_val_or_test)
dimsOfPrimarySegment = inp_shapes_per_path[pathway_i]

# rightmost are placeholders here.
slicesCoordsOfSegmForPrimaryPathway = [[leftBoundaryRcz[0], rightBoundaryRcz[0] - 1],
Expand All @@ -772,7 +777,7 @@ def extractSegmentGivenSliceCoords(train_val_or_test,
subsampledImageChannels=channels,
image_part_slices_coords=slicesCoordsOfSegmForPrimaryPathway,
subSamplingFactor=cnn3d.pathways[pathway_i].subs_factor(),
subsampledImagePartDimensions=cnn3d.get_inp_shape_of_path(pathway_i, train_val_or_test)
subsampledImagePartDimensions=inp_shapes_per_path[pathway_i]
)

channs_of_sample_per_path.append(channsForThisSubsampledPartAndPathway)
Expand Down Expand Up @@ -879,13 +884,14 @@ def get_slice_coords_of_all_img_tiles(log,
def extractSegmentsGivenSliceCoords(cnn3d,
sliceCoordsOfSegmentsToExtract,
channelsOfImageNpArray,
recFieldCnn):
recFieldCnn,
inp_shapes_per_path):
# channelsOfImageNpArray: numpy array [ n_channels, x, y, z ]
numberOfSegmentsToExtract = len(sliceCoordsOfSegmentsToExtract)
channsForSegmentsPerPathToReturn = [[] for i in range(
cnn3d.getNumPathwaysThatRequireInput())] # [pathway, image parts, channels, r, c, z]
# RCZ dims of input to primary pathway (NORMAL). Which should be the first one in .pathways.
dimsOfPrimarySegment = cnn3d.get_inp_shape_of_path(0, 'test')
dimsOfPrimarySegment = inp_shapes_per_path[0]

for segment_i in range(numberOfSegmentsToExtract):
rLowBoundary = sliceCoordsOfSegmentsToExtract[segment_i][0][0]
Expand Down Expand Up @@ -916,7 +922,7 @@ def extractSegmentsGivenSliceCoords(cnn3d,
subsampledImageChannels=channelsOfImageNpArray,
image_part_slices_coords=slicesCoordsOfSegmForPrimaryPathway,
subSamplingFactor=cnn3d.pathways[pathway_i].subs_factor(),
subsampledImagePartDimensions=cnn3d.get_inp_shape_of_path(pathway_i, 'test')
subsampledImagePartDimensions=inp_shapes_per_path[pathway_i]
)
channsForSegmentsPerPathToReturn[pathway_i].append(channsForThisSubsPathForThisSegm)

Expand Down
12 changes: 6 additions & 6 deletions deepmedic/frontEnd/testSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,13 @@ def run_session(self, *args):
cnn3d = Cnn3d()
with tf.compat.v1.variable_scope("net"):
cnn3d.make_cnn_model( *model_params.get_args_for_arch() ) # Creates the network's graph (without optimizer).
inp_plchldrs_test = cnn3d.create_inp_plchldrs(model_params.get_inp_dims_hr_path('test'), 'test')
p_y_given_x_test = cnn3d.apply(inp_plchldrs_test, 'infer', 'test', verbose=True, log=self._log)
inp_plchldrs, inp_shapes_per_path = cnn3d.create_inp_plchldrs(model_params.get_inp_dims_hr_path('test'), 'test')
p_y_given_x = cnn3d.apply(inp_plchldrs, 'infer', 'test', verbose=True, log=self._log)

self._log.print3("=========== Compiling the Testing Function ============")
self._log.print3("=======================================================\n")

cnn3d.setup_ops_n_feeds_to_test( self._log,
inp_plchldrs_test,
self._params.indices_fms_per_pathtype_per_layer_to_save )
cnn3d.setup_ops_n_feeds_to_test(self._log, inp_plchldrs, self._params.indices_fms_per_pathtype_per_layer_to_save )
# Create the saver
collection_vars_net = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="net")
saver_net = tf.compat.v1.train.Saver(var_list=collection_vars_net) # saver_net would suffice
Expand Down Expand Up @@ -131,7 +129,9 @@ def run_session(self, *args):
self._log.print3("=========== Testing with the CNN model ===============")
self._log.print3("======================================================")

res_code = inference_on_whole_volumes( *( [sessionTf, cnn3d] + self._params.get_args_for_testing() ) )
res_code = inference_on_whole_volumes(*([sessionTf, cnn3d] +\
self._params.get_args_for_testing() +\
[inp_shapes_per_path]))

self._log.print3("")
self._log.print3("======================================================")
Expand Down
14 changes: 7 additions & 7 deletions deepmedic/frontEnd/trainSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def run_session(self, *args):
with tf.compat.v1.variable_scope("net"):
cnn3d.make_cnn_model(*model_params.get_args_for_arch())
# I have now created the CNN graph. But not yet the Optimizer's graph.
inp_plchldrs_train = cnn3d.create_inp_plchldrs(model_params.get_inp_dims_hr_path('train'), 'train')
inp_plchldrs_val = cnn3d.create_inp_plchldrs(model_params.get_inp_dims_hr_path('val'), 'val')
inp_plchldrs_test = cnn3d.create_inp_plchldrs(model_params.get_inp_dims_hr_path('test'), 'test')
inp_plchldrs_train, inp_shapes_per_path_train = cnn3d.create_inp_plchldrs(model_params.get_inp_dims_hr_path('train'), 'train')
inp_plchldrs_val, inp_shapes_per_path_val = cnn3d.create_inp_plchldrs(model_params.get_inp_dims_hr_path('val'), 'val')
inp_plchldrs_test, inp_shapes_per_path_test = cnn3d.create_inp_plchldrs(model_params.get_inp_dims_hr_path('test'), 'test')
p_y_given_x_train = cnn3d.apply(inp_plchldrs_train, 'train', 'train', verbose=True, log=self._log)
p_y_given_x_val = cnn3d.apply(inp_plchldrs_val, 'infer', 'val', verbose=True, log=self._log)
p_y_given_x_test = cnn3d.apply(inp_plchldrs_test, 'infer', 'test', verbose=True, log=self._log)
Expand Down Expand Up @@ -141,9 +141,7 @@ def run_session(self, *args):

self._log.print3("=========== Compiling the Testing Function ============")
# For validation with full segmentation
cnn3d.setup_ops_n_feeds_to_test(self._log,
inp_plchldrs_test,
self._params.indices_fms_per_pathtype_per_layer_to_save)
cnn3d.setup_ops_n_feeds_to_test(self._log, inp_plchldrs_test, self._params.indices_fms_per_pathtype_per_layer_to_save)

# Create the savers
saver_all = tf.compat.v1.train.Saver() # Will be used during training for saving everything.
Expand Down Expand Up @@ -217,7 +215,9 @@ def run_session(self, *args):
self._log.print3("============== Training the CNN model =================")
self._log.print3("=======================================================")

do_training(*([sessionTf, saver_all, cnn3d, trainer, tensorboard_loggers] + self._params.get_args_for_train_routine()))
do_training(*([sessionTf, saver_all, cnn3d, trainer, tensorboard_loggers] +\
self._params.get_args_for_train_routine() +\
[inp_shapes_per_path_train, inp_shapes_per_path_val, inp_shapes_per_path_test]))

ckpt_all.save(file_prefix = filename_to_save_with+".all.FINAL.ckpt2")
ckpt_net.save(file_prefix = filename_to_save_with+".net.FINAL.ckpt2")
Expand Down
30 changes: 9 additions & 21 deletions deepmedic/neuralnet/cnn3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,8 @@ def __init__(self):

self.num_classes = None

#=====================================
self.recFieldCnn = ""

#======= Input tensors X. Placeholders OR given tensors =======
self._inp_shapes_per_path = {'train': None,
'val': None,
'test': None} # TODO: For sampling. In eager, remove updating calc_inp_dims_of_paths_from_hr_inp

#======= Output tensors Y_GT ========
# For each targetLayer, I should be placing a y_gt placeholder/feed.
self._output_gt_tensor_feeds = {'train': {},
Expand All @@ -54,13 +48,6 @@ def __init__(self):
self._feeds_main = {'train': {} , 'val': {}, 'test': {}}


def get_inp_shapes_per_path(self):
return self._inp_shapes_per_path # TODO: This is for wrapper. Remove.

def get_inp_shape_of_path(self, path_idx, mode): # Called for sampling. TODO: Remove for eager.
assert mode in ['train', 'val', 'test']
return self._inp_shapes_per_path[mode][path_idx]

def getNumSubsPathways(self):
count = 0
for pathway in self.pathways:
Expand Down Expand Up @@ -196,18 +183,19 @@ def setup_ops_n_feeds_to_test(self, log, inp_plchldrs_test, indices_fms_per_path
log.print3("Done.")


def _setup_inp_plchldrs(self, train_val_test): # TODO: REMOVE for eager


def create_inp_plchldrs(self, inp_dims, train_val_test): # TODO: Remove for eager
inp_shapes_per_path = self.calc_inp_dims_of_paths_from_hr_inp(inp_dims)
return self._setup_inp_plchldrs(train_val_test, inp_shapes_per_path), inp_shapes_per_path

def _setup_inp_plchldrs(self, train_val_test, inp_shapes_per_path): # TODO: REMOVE for eager
assert train_val_test in ['train', 'val', 'test']
inp_plchldrs = {}
inp_plchldrs['x'] = tf.compat.v1.placeholder(dtype="float32", shape=[None, self.pathways[0].get_n_fms_in()]+self._inp_shapes_per_path[train_val_test][0], name='inp_x_'+train_val_test)
inp_plchldrs['x'] = tf.compat.v1.placeholder(dtype="float32", shape=[None, self.pathways[0].get_n_fms_in()]+inp_shapes_per_path[0], name='inp_x_'+train_val_test)
for subpath_i in range(self.numSubsPaths): # if there are subsampled paths...
inp_plchldrs['x_sub_'+str(subpath_i)] = tf.compat.v1.placeholder(dtype="float32", shape=[None, self.pathways[0].get_n_fms_in()]+self._inp_shapes_per_path[train_val_test][subpath_i+1], name="inp_x_sub_"+str(subpath_i)+'_' + train_val_test)
inp_plchldrs['x_sub_'+str(subpath_i)] = tf.compat.v1.placeholder(dtype="float32", shape=[None, self.pathways[0].get_n_fms_in()]+inp_shapes_per_path[subpath_i+1], name="inp_x_sub_"+str(subpath_i)+'_' + train_val_test)
return inp_plchldrs

def create_inp_plchldrs(self, inp_dims, train_val_test): # TODO: Remove for eager
self._inp_shapes_per_path[train_val_test] = self.calc_inp_dims_of_paths_from_hr_inp(inp_dims)
return self._setup_inp_plchldrs(train_val_test)


def make_cnn_model( self,
log,
Expand Down
4 changes: 0 additions & 4 deletions deepmedic/neuralnet/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(self, cnn3d) :
# Cnn
self.num_classes = cnn3d.num_classes
self.recFieldCnn = cnn3d.recFieldCnn
self._inp_shapes_per_path = cnn3d.get_inp_shapes_per_path()
self.finalTargetLayer_outputShape = {"train": cnn3d.finalTargetLayer.output["train"].shape,
"val": cnn3d.finalTargetLayer.output["val"].shape,
"test": cnn3d.finalTargetLayer.output["test"].shape}
Expand All @@ -43,7 +42,4 @@ def __init__(self, cnn3d) :
def getNumPathwaysThatRequireInput(self) :
return self._numPathwaysThatRequireInput

def get_inp_shape_of_path(self, path_idx, mode):
# mode: 'train', 'val', 'test'
return self._inp_shapes_per_path[mode][path_idx]

15 changes: 9 additions & 6 deletions deepmedic/routines/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def prepare_feeds_dict(feeds, channs_of_tiles_per_path):


def predict_whole_volume_by_tiling(log, sessionTf, cnn3d,
channels, roi_mask, batchsize,
channels, roi_mask, inp_shapes_per_path, batchsize,
save_fms_flag, idxs_fms_to_save ):
# One of the main routines. Segment whole volume tile-by-tile.

Expand All @@ -246,7 +246,7 @@ def predict_whole_volume_by_tiling(log, sessionTf, cnn3d,

# Tile the image and get all slices of the tiles that it fully breaks down to.
slice_coords_all_tiles = get_slice_coords_of_all_img_tiles(log,
cnn3d.get_inp_shape_of_path(0, 'test'),
inp_shapes_per_path[0],
stride_of_tiling,
batchsize,
inp_chan_dims,
Expand All @@ -271,7 +271,8 @@ def predict_whole_volume_by_tiling(log, sessionTf, cnn3d,
channs_of_tiles_per_path = extractSegmentsGivenSliceCoords(cnn3d,
slice_coords_of_tiles_batch,
channels,
cnn3d.recFieldCnn)
cnn3d.recFieldCnn,
inp_shapes_per_path)

# ============================== Perform forward pass ====================================
t_fwd_start = time.time()
Expand Down Expand Up @@ -506,7 +507,9 @@ def inference_on_whole_volumes(sessionTf,
# Saving feature maps
save_fms_flag,
idxs_fms_to_save,
namesForSavingFms):
namesForSavingFms,
# Sampling
inp_shapes_per_path):
# save_fms_flag: should contain an entry per pathwayType, even if just []...
# ... If not [], the list should contain one entry per layer of the pathway, even if just [].
# ... The layer entries, if not [], they should have to integers, lower and upper FM to visualise.
Expand All @@ -524,7 +527,7 @@ def inference_on_whole_volumes(sessionTf,
NA_PATTERN = AccuracyMonitorForEpSegm.NA_PATTERN
n_classes = cnn3d.num_classes
n_subjects = len(paths_per_chan_per_subj)
dims_hres_segment = cnn3d.get_inp_shape_of_path(0, 'test') # pathway [0] is the high-res path.
dims_hres_segment = inp_shapes_per_path[0] # pathway [0] is the high-res path.

# One dice score for whole foreground (0) AND one for each actual class
# Dice1 - AllpredictedLes/AllLesions
Expand Down Expand Up @@ -567,7 +570,7 @@ def inference_on_whole_volumes(sessionTf,
# array_fms_to_save will be None if not saving them.
(prob_maps_vols,
array_fms_to_save) = predict_whole_volume_by_tiling(log, sessionTf, cnn3d,
channels, roi_mask, batchsize,
channels, roi_mask, inp_shapes_per_path, batchsize,
save_fms_flag, idxs_fms_to_save )

# ========================== Post-Processing =========================
Expand Down

0 comments on commit 0cea7b5

Please sign in to comment.