Skip to content
12 changes: 5 additions & 7 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def weighted_patch_samples(


def correct_crop_centers(
centers: Union[NdarrayOrTensor, List[NdarrayOrTensor]],
centers: List[int],
spatial_size: Union[Sequence[int], int],
label_spatial_shape: Sequence[int],
allow_smaller: bool = False,
Expand Down Expand Up @@ -446,8 +446,7 @@ def correct_crop_centers(
valid_end[i] += 1
valid_centers = []
for c, v_s, v_e in zip(centers, valid_start, valid_end):
_c = int(convert_data_type(c, np.ndarray)[0]) # type: ignore
center_i = min(max(_c, v_s), v_e - 1)
center_i = min(max(c, v_s), v_e - 1)
valid_centers.append(int(center_i))
return valid_centers

Expand Down Expand Up @@ -503,7 +502,7 @@ def generate_pos_neg_label_crop_centers(
indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices
random_int = rand_state.randint(len(indices_to_use))
idx = indices_to_use[random_int]
center = unravel_index(idx, label_spatial_shape)
center = unravel_index(idx, label_spatial_shape).tolist()
# shift center to range of valid centers
centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller))

Expand Down Expand Up @@ -558,10 +557,9 @@ def generate_label_classes_crop_centers(
# randomly select the indices of a class based on the ratios
indices_to_use = indices[i]
random_int = rand_state.randint(len(indices_to_use))
center = unravel_index(indices_to_use[random_int], label_spatial_shape)
center = unravel_index(indices_to_use[random_int], label_spatial_shape).tolist()
# shift center to range of valid centers
center_ori = list(center)
centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape, allow_smaller))
centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller))

return centers

Expand Down