Skip to content

Commit

Permalink
constraining sample maps near edges, outside sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kamnitsask committed Mar 26, 2020
1 parent 201fb75 commit a032685
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions deepmedic/dataManagement/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,19 @@ def get_n_samples_per_subj(n_samples, n_subjects):
return n_samples_per_subj


def constrain_sampling_maps_near_edges(sample_maps_per_cat, dims_sample):
# We wish to sample pixels around which the samples we will extract will be centered.
# But we need to be CAREFUL and get only pixels that are NOT closer to the image boundaries than the dimensions of the
# samples we wish to extract permit.
constrained_maps = []
mask_excl_edges = comp_valid_sampling_mask_excluding_edges(dims_sample, sample_maps_per_cat[0].shape)
for cat_i in range(len(sample_maps_per_cat)):
sampling_map = sample_maps_per_cat[cat_i]
sampling_map_excl_near_edges = np.multiply(sampling_map, mask_excl_edges, dtype=sampling_map.dtype)
constrained_maps.append(sampling_map_excl_near_edges)
return constrained_maps


def load_subj_and_sample(job_idx,
log,
train_val_or_test,
Expand Down Expand Up @@ -324,7 +337,9 @@ def load_subj_and_sample(job_idx,
gt_lbl_img,
roi_mask,
dims_of_scan)

sampling_maps_per_cat = constrain_sampling_maps_near_edges(sampling_maps_per_cat, dims_hres_segment)


# Get number of samples per sampling-category for the specific subject (class, foregr/backgr, etc)
(n_samples_per_cat, valid_cats) = sampling_type.distribute_n_samples_to_categs(n_samples_per_subj[job_idx],
sampling_maps_per_cat)
Expand All @@ -348,7 +363,6 @@ def load_subj_and_sample(job_idx,
idxs_sampl_centers = sample_idxs_of_segments(log,
job_id,
n_samples_for_cat,
dims_hres_segment,
sampling_map)
time_sample_idxs += time.time() - time_sample_idx0
str_samples_per_cat += "[" + cat_str + ": " + str(len(idxs_sampl_centers[0])) + "/" + str(n_samples_for_cat) + "] "
Expand Down Expand Up @@ -539,16 +553,16 @@ def sampling_cumsum(p_1darr, n_samples):
idxs_sampled = np.searchsorted(p_cumsum, random_ps)
return idxs_sampled

def sample_with_appropriate_algorithm(n_samples, sampling_map_excl_near_edges, sum_sampl_map):
def sample_with_appropriate_algorithm(n_samples, sampling_map, sum_sampl_map):

p_sampling_flat = sampling_map_excl_near_edges.ravel() # Unnormalized
p_sampling_flat = sampling_map.ravel() # Unnormalized
is_0 = p_sampling_flat == 0. # np.isclose(p_sampling_flat, 0.)
is_1 = p_sampling_flat == 1. # np.isclose(p_sampling_flat, 1.)
if np.all(np.logical_or(is_0, is_1)): # Whole sampling map is either 0 or 1. Do faster sampling only under valid idxs.
idxs = np.arange(p_sampling_flat.size, dtype='int32') # Unlikely large image size will need int64
valid_idxs = idxs[is_1>0]
idxs_sampled = np.random.choice(valid_idxs, size=n_samples, replace=True)
else: #if np.any( np.logical_and(sampling_map_excl_near_edges != 0., sampling_map_excl_near_edges != 1.)):
else: #if np.any( np.logical_and(sampling_map != 0., sampling_map != 1.)):
# Normalize, because sampling method np.random.choice requires it.
p_sampling_normed = np.multiply(p_sampling_flat, 1./sum_sampl_map, dtype='float64') # as before.
#idxs_sampled = sampling_cumsum(p_sampling_normed, n_samples)
Expand All @@ -558,51 +572,35 @@ def sample_with_appropriate_algorithm(n_samples, sampling_map_excl_near_edges, s
# where each of the array in the tuple has the same shape as the listOfIndices.
# They have the r/c/z coords that correspond to the index of the flattened version.
# So, idxs_sampled will be array of shape: 3(rcz) x n_samples.
idxs_sampled = np.asarray(np.unravel_index(idxs_sampled, sampling_map_excl_near_edges.shape))
idxs_sampled = np.asarray(np.unravel_index(idxs_sampled, sampling_map.shape))

return idxs_sampled

# made for 3d
def sample_idxs_of_segments(log,
job_id,
n_samples,
dims_of_segment,
sampling_map):
"""
sampling_map: np.array of shape (H,W,D), dtype="int16" or potentially float if weightmaps given by user.
Returns: [ idxs_of_sampled_centers, slice_idxs_of_sampled_segms ]
Returns: idxs_of_sampled_centers
Coordinates (xyz indices) of the "central" voxel of sampled segments (1 voxel to the left if dimension is even).
Also returns the indices of the image parts, left and right indices, INCLUSIVE BOTH SIDES.
> idxs_of_sampled_centers: array with shape: 3(xyz) x n_samples.
Example: [ xCoordsForCentralVoxelOfEachPart, yCoordsForCentralVoxelOfEachPart, zCoordsForCentralVoxelOfEachPart ]
>> x/y/z-CoordsForCentralVoxelOfEachPart: 1-dim array with n_samples, holding the x-indices of samples in image.
> slice_idxs_of_sampled_segms: 3(xyz) x NumberOfImagePartSamples x 2.
The last dimension has [0] for the lower boundary of the slice, and [1] for the higher boundary. INCLUSIVE BOTH SIDES.
Example: [ x-sliceCoordsOfImagePart, y-sliceCoordsOfImagePart, z-sliceCoordsOfImagePart ]
"""
# Check if the weight map is fully-zeros. In this case, return no element.
# Note: Currently, the caller function is checking this case already and does not let this being called.
# Which is still fine.
if np.isclose(np.sum(sampling_map), 0.):
log.print3(job_id + " WARN: Sampling map for category is just zeros! " +\
" No samples for category from subject!")
return [[[], [], []], [[], [], []]]

# Now out of these, I need to randomly select one, which will be an ImagePart's central voxel.
# But I need to be CAREFUL and get one that IS NOT closer to the image boundaries than the dimensions of the
# ImagePart permit.
mask_excl_near_edges = comp_valid_sampling_mask_excluding_edges(dims_of_segment, sampling_map.shape)
sampling_map_excl_near_edges = np.multiply(sampling_map, mask_excl_near_edges, dtype=sampling_map.dtype)
# normalize the probabilities to sum to 1, cause the function needs it as so.
sum_sampl_map = np.sum(sampling_map_excl_near_edges)
sum_sampl_map = np.sum(sampling_map)
if np.isclose(sum_sampl_map, 0.) : # is zero
log.print3(job_id + " WARN: AFTER EXCLUDING NEAR EDGES, sampling map for category is just zeros! " +\
log.print3(job_id + " WARN: Sampling map for category (after excluding near edges) is just zeros! " +\
" No samples for category from subject!")
return [ [[],[],[]], [[],[],[]] ]

# Sample indexes of pixels around which we should extract sample segments.
idxs_of_sampled_centers = sample_with_appropriate_algorithm(n_samples, sampling_map_excl_near_edges, sum_sampl_map)
idxs_of_sampled_centers = sample_with_appropriate_algorithm(n_samples, sampling_map, sum_sampl_map)

return idxs_of_sampled_centers

Expand Down

0 comments on commit a032685

Please sign in to comment.