Skip to content

Commit

Permalink
cleaned compute-mask-near-edges.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kamnitsask committed Mar 26, 2020
1 parent 9cf7bbd commit 201fb75
Showing 1 changed file with 48 additions and 59 deletions.
107 changes: 48 additions & 59 deletions deepmedic/dataManagement/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,11 @@ def load_subj_and_sample(job_idx,
continue # This should not be needed, the next func should also handle it. But whatever.

time_sample_idx0 = time.time()
(idxs_sampl_centers,
slice_idxs_sampl_segms) = sample_idxs_of_segments(log,
job_id,
n_samples_for_cat,
dims_hres_segment,
dims_of_scan,
sampling_map)
idxs_sampl_centers = sample_idxs_of_segments(log,
job_id,
n_samples_for_cat,
dims_hres_segment,
sampling_map)
time_sample_idxs += time.time() - time_sample_idx0
str_samples_per_cat += "[" + cat_str + ": " + str(len(idxs_sampl_centers[0])) + "/" + str(n_samples_for_cat) + "] "

Expand Down Expand Up @@ -492,6 +490,46 @@ def preproc_imgs_of_subj(log, job_id, channels, gt_lbl_img, roi_mask, wmaps_to_s
return channels, gt_lbl_img, roi_mask, wmaps_to_sample_per_cat, pad_left_right_per_axis


def comp_valid_sampling_mask_excluding_edges(dims_of_segment, shape):
# I look for lesions that are not closer to the image boundaries than the ImagePart dimensions allow.
# KernelDim is always odd. BUT ImagePart dimensions can be odd or even.
# If odd, ok, floor(dim/2) from central.
# If even, dim/2-1 voxels towards the begining of the axis and dim/2 towards the end.
# Ie, "central" imagePart voxel is 1 closer to begining.
# BTW imagePartDim takes kernel into account (ie if I want 9^3 voxels classified per imagePart with kernel 5x5,
# I want 13 dim ImagePart)

# number of voxels to exclude from edges of the image, left and right in each axis, when sampling...
# ...the center of a segment. So that the segment will be fully contained in the image. (half segm left & right)
# dim1: 1 row per r,c,z. Dim2: left/right width not to sample from (=half segment).
n_vox_excl_left_right = np.zeros((len(dims_of_segment), 2), dtype='int16')

# The below starts all zero. Will be Multiplied by other true-false arrays expressing if the relevant
# voxels are within boundaries.
# In the end, the final vector will be true only for the indices of lesions that are within all boundaries.
mask_excl_near_edges = np.zeros(shape, dtype='int8')

# This loop leads to mask_excl_near_edges to be true for the indices ...
# ...that allow getting an imagePart CENTERED on them, and be safely within image boundaries. Note that ...
# ... if the imagePart is of even dimension, the "central" voxel is one voxel to the left.
for rcz_i in range(len(dims_of_segment)):
if dims_of_segment[rcz_i] % 2 == 0: # even
dims_div_2 = dims_of_segment[rcz_i] // 2
# central of ImagePart is 1 vox closer to begining of axes.
n_vox_excl_left_right[rcz_i] = [dims_div_2 - 1, dims_div_2]
else: # odd
# If odd, middle voxel is the "central". Eg 5/2 = 2, with 3rd voxel being the central.
dims_div_2_floor = math.floor(dims_of_segment[rcz_i] // 2)
n_vox_excl_left_right[rcz_i] = [dims_div_2_floor, dims_div_2_floor]
# used to be [n_vox_excl_left_right[0][0]: -n_vox_excl_left_right[0][1]],
# but in 2D case n_vox_excl_left_right might be ==0, causes problem and you get a null slice.
mask_excl_near_edges[
n_vox_excl_left_right[0][0]: shape[0] - n_vox_excl_left_right[0][1],
n_vox_excl_left_right[1][0]: shape[1] - n_vox_excl_left_right[1][1],
n_vox_excl_left_right[2][0]: shape[2] - n_vox_excl_left_right[2][1]] = 1

return mask_excl_near_edges

def sampling_cumsum(p_1darr, n_samples):
p_cumsum = p_1darr.cumsum(dtype='float64') # This is dangerous, final elements go below or beyond 1.
p_cumsum = np.clip(p_cumsum, a_min=None, a_max=1., out=p_cumsum)
Expand Down Expand Up @@ -529,7 +567,6 @@ def sample_idxs_of_segments(log,
job_id,
n_samples,
dims_of_segment,
dims_of_scan,
sampling_map):
"""
sampling_map: np.array of shape (H,W,D), dtype="int16" or potentially float if weightmaps given by user.
Expand All @@ -551,48 +588,11 @@ def sample_idxs_of_segments(log,
log.print3(job_id + " WARN: Sampling map for category is just zeros! " +\
" No samples for category from subject!")
return [[[], [], []], [[], [], []]]

# Now out of these, I need to randomly select one, which will be an ImagePart's central voxel.
# But I need to be CAREFUL and get one that IS NOT closer to the image boundaries than the dimensions of the
# ImagePart permit.

# I look for lesions that are not closer to the image boundaries than the ImagePart dimensions allow.
# KernelDim is always odd. BUT ImagePart dimensions can be odd or even.
# If odd, ok, floor(dim/2) from central.
# If even, dim/2-1 voxels towards the begining of the axis and dim/2 towards the end.
# Ie, "central" imagePart voxel is 1 closer to begining.
# BTW imagePartDim takes kernel into account (ie if I want 9^3 voxels classified per imagePart with kernel 5x5,
# I want 13 dim ImagePart)

# number of voxels to exclude from edges of the image, left and right in each axis, when sampling...
# ...the center of a segment. So that the segment will be fully contained in the image. (half segm left & right)
# dim1: 1 row per r,c,z. Dim2: left/right width not to sample from (=half segment).
n_vox_excl_left_right = np.zeros((len(dims_of_segment), 2), dtype='int16')

# The below starts all zero. Will be Multiplied by other true-false arrays expressing if the relevant
# voxels are within boundaries.
# In the end, the final vector will be true only for the indices of lesions that are within all boundaries.
mask_excl_near_edges = np.zeros(sampling_map.shape, dtype='int8')

# This loop leads to mask_excl_near_edges to be true for the indices ...
# ...that allow getting an imagePart CENTERED on them, and be safely within image boundaries. Note that ...
# ... if the imagePart is of even dimension, the "central" voxel is one voxel to the left.
for rcz_i in range(len(dims_of_segment)):
if dims_of_segment[rcz_i] % 2 == 0: # even
dims_div_2 = dims_of_segment[rcz_i] // 2
# central of ImagePart is 1 vox closer to begining of axes.
n_vox_excl_left_right[rcz_i] = [dims_div_2 - 1, dims_div_2]
else: # odd
# If odd, middle voxel is the "central". Eg 5/2 = 2, with 3rd voxel being the central.
dims_div_2_floor = math.floor(dims_of_segment[rcz_i] // 2)
n_vox_excl_left_right[rcz_i] = [dims_div_2_floor, dims_div_2_floor]
# used to be [n_vox_excl_left_right[0][0]: -n_vox_excl_left_right[0][1]],
# but in 2D case n_vox_excl_left_right might be ==0, causes problem and you get a null slice.
mask_excl_near_edges[
n_vox_excl_left_right[0][0]: dims_of_scan[0] - n_vox_excl_left_right[0][1],
n_vox_excl_left_right[1][0]: dims_of_scan[1] - n_vox_excl_left_right[1][1],
n_vox_excl_left_right[2][0]: dims_of_scan[2] - n_vox_excl_left_right[2][1]] = 1

mask_excl_near_edges = comp_valid_sampling_mask_excluding_edges(dims_of_segment, sampling_map.shape)
sampling_map_excl_near_edges = np.multiply(sampling_map, mask_excl_near_edges, dtype=sampling_map.dtype)
# normalize the probabilities to sum to 1, cause the function needs it as so.
sum_sampl_map = np.sum(sampling_map_excl_near_edges)
Expand All @@ -604,18 +604,7 @@ def sample_idxs_of_segments(log,
# Sample indexes of pixels around which we should extract sample segments.
idxs_of_sampled_centers = sample_with_appropriate_algorithm(n_samples, sampling_map_excl_near_edges, sum_sampl_map)

# Array with shape: 3(rcz) x NumberOfImagePartSamples x 2.
# Last dimension has [0] for lowest boundary of slice, and [1] for highest boundary. INCLUSIVE BOTH SIDES.
slice_idxs_of_sampled_segms = np.zeros(list(idxs_of_sampled_centers.shape) + [2], dtype='int32')
# below, np.newaxis broadcasts. To broadcast the -+.
slice_idxs_of_sampled_segms[:, :, 0] = idxs_of_sampled_centers - n_vox_excl_left_right[:, np.newaxis, 0]
slice_idxs_of_sampled_segms[:, :, 1] = idxs_of_sampled_centers + n_vox_excl_left_right[:, np.newaxis, 1]

# idxs_of_sampled_centers: Array of dimensions 3(rcz) x NumberOfImagePartSamples.
# slice_idxs_of_sampled_segms: Array of dimensions 3(rcz) x NumberOfImagePartSamples x 2. ...
# ... The last dim has [0] for the lower boundary of the slice, and [1] for the higher boundary.
# ... The slice coordinates returned are INCLUSIVE BOTH sides.
return (idxs_of_sampled_centers, slice_idxs_of_sampled_segms)
return idxs_of_sampled_centers


def get_subsampl_segment(channels, segment_hr_slice_coords, subs_factor, segment_lr_dims):
Expand Down

0 comments on commit 201fb75

Please sign in to comment.