Skip to content

Commit

Permalink
Fix offset (#8)
Browse files Browse the repository at this point in the history
* Changed PatchSampler to use the patch size as its base instead of the input image's chunk sizes

* Reverted change in the computation when mask elements are relatively smaller than patch sizes

* Fixed spatial chunk size computation when patch sizes are greater than the chunk size

* Fixed missing patches from chunks smaller than the input image chunk size

* Padding and stride added to PatchSampler and ImageBase classes to allow extraction of overlapping patches

* Added tests for stride and pad parameters of PatchSampler class

* Fixed patch slices generation in PatchSampler to always retrieve patches of the defined shape

* Standardized patch sampling method to handle smaller and bigger mask scales than image scale

* Added example notebook to documentation

* Fixed incorrect sampling of patches on masked regions
  • Loading branch information
fercer committed May 10, 2024
1 parent 4afb0bd commit e2bed14
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 21 deletions.
2 changes: 1 addition & 1 deletion docs/source/examples/advanced_example_pytorch_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pad = dict(Y=16, X=16)
patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True)
```

Create a dataset from the list of filenames. All those files should be stored within their respective group "0".
Create a dataset from the list of filenames. All those files should be stored within their respective group "4", which in this case corresponds to a version of the full-resolution image downsampled by a factor of 16.

Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly

Expand Down
58 changes: 38 additions & 20 deletions zarrdataset/_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,26 +142,32 @@ def _compute_corners(self, coordinates: np.ndarray, scale: np.ndarray

return corners

def _compute_reference_indices(self, reference_coordinates: np.ndarray
def _compute_reference_indices(self, reference_coordinates: np.ndarray,
reference_axes_sizes: np.ndarray
) -> Tuple[List[np.ndarray],
List[Tuple[int]]]:
reference_per_axis = list(map(
lambda coords: np.append(np.full((1, ), fill_value=-float("inf")),
np.unique(coords)),
reference_coordinates.T
lambda coords, axis_size: np.concatenate((
np.full((1, ), fill_value=-float("inf")),
np.unique(coords),
np.full((1, ), fill_value=np.max(coords) + axis_size))),
reference_coordinates.T,
reference_axes_sizes
))

reference_idx = map(
lambda coord_axis, ref_axis:
np.argmax(ref_axis[None, ...]
* (coord_axis[..., None] >= ref_axis[None, ...]),
axis=-1),
np.max(np.arange(ref_axis.size)
* (coord_axis.reshape(-1, 1) >= ref_axis[None, ...]),
axis=1),
reference_coordinates.T,
reference_per_axis
)
reference_idx = np.stack(tuple(reference_idx), axis=-1)
reference_idx = reference_idx.reshape(reference_coordinates.T.shape)

reference_idx = [
tuple(tls_coord - 1)
tuple(tls_coord)
for tls_coord in reference_idx.reshape(-1, len(reference_per_axis))
]

Expand All @@ -172,13 +178,14 @@ def _compute_overlap(self, corners_coordinates: np.ndarray,
np.ndarray]:
tls_idx = map(
lambda coord_axis, ref_axis:
np.argmax(ref_axis[None, None, ...]
* (coord_axis[..., None] >= ref_axis[None, None, ...]),
axis=-1),
np.max(np.arange(ref_axis.size)
* (coord_axis.reshape(-1, 1) >= ref_axis[None, ...]),
axis=1),
np.moveaxis(corners_coordinates, -1, 0),
reference_per_axis
)
tls_idx = np.stack(tuple(tls_idx), axis=-1)
tls_idx = tls_idx.reshape(corners_coordinates.shape)

tls_coordinates = map(
lambda tls_coord, ref_axis: ref_axis[tls_coord],
Expand All @@ -192,11 +199,12 @@ def _compute_overlap(self, corners_coordinates: np.ndarray,
dist2cut = np.fabs(corners_coordinates - corners_cut[None])
coverage = np.prod(dist2cut, axis=-1)

return coverage, tls_idx - 1
return coverage, tls_idx

def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
patch_size: dict,
image_size: dict,
min_area: float,
allow_incomplete_patches: bool = False):
mask_scale = np.array([mask.scale.get(ax, 1)
for ax in self.spatial_axes],
Expand All @@ -223,7 +231,7 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
]

if min(image_blocks) == 0:
return []
return []

image_scale = np.array([patch_size.get(ax, 1)
for ax in self.spatial_axes],
Expand Down Expand Up @@ -254,6 +262,7 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
for ax in self.spatial_axes],
dtype=np.float32
)

chunk_br_coordinates = np.array(
[chunk_tlbr[ax].stop
if chunk_tlbr[ax].stop is not None
Expand All @@ -271,32 +280,39 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
)
mask_coordinates = mask_coordinates[in_chunk]

# Translate the mask coordinates to the origin for comparison with
# image coordinates.
mask_coordinates -= chunk_tl_coordinates

if all(map(operator.ge, image_scale, mask_scale)):
mask_corners = self._compute_corners(mask_coordinates, mask_scale)

(reference_per_axis,
reference_idx) =\
self._compute_reference_indices(image_coordinates)
self._compute_reference_indices(image_coordinates, image_scale)

(coverage,
corners_idx) = self._compute_overlap(mask_corners,
reference_per_axis)

covered_indices = [
reference_idx.index(tuple(idx))
if tuple(idx) in reference_idx else len(reference_idx)
for idx in corners_idx.reshape(-1, len(self.spatial_axes))
]

patches_coverage = np.bincount(covered_indices,
weights=coverage.flatten(),
minlength=np.prod(image_blocks))
minlength=len(reference_idx) + 1)
patches_coverage = patches_coverage[:-1]

else:
image_corners = self._compute_corners(image_coordinates,
image_scale)

(reference_per_axis,
reference_idx) = self._compute_reference_indices(mask_coordinates)
reference_idx) = self._compute_reference_indices(mask_coordinates,
mask_scale)

(coverage,
corners_idx) = self._compute_overlap(image_corners,
Expand All @@ -309,10 +325,6 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,

patches_coverage = np.sum(covered_indices * coverage, axis=0)

min_area = self._min_area
if min_area < 1:
min_area *= np.prod(list(patch_size.values()))

minimum_covered_tls = image_coordinates[patches_coverage > min_area]
minimum_covered_tls = minimum_covered_tls.astype(np.int64)

Expand Down Expand Up @@ -396,6 +408,7 @@ def compute_chunks(self,
mask,
self._max_chunk_size,
image_size,
min_area=1,
allow_incomplete_patches=True
)

Expand Down Expand Up @@ -428,11 +441,16 @@ def compute_patches(self, image_collection: ImageCollection,
for ax in self.spatial_axes
}

min_area = self._min_area
if min_area < 1:
min_area *= np.prod(list(patch_size.values()))

valid_mask_toplefts = self._compute_valid_toplefts(
chunk_tlbr,
mask,
stride,
image_size=image_size,
min_area=min_area,
allow_incomplete_patches=self._allow_incomplete_patches
)

Expand Down

0 comments on commit e2bed14

Please sign in to comment.