Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,22 +667,25 @@ def _get_spatial_shape(self, img):

class WSIReader(ImageReader):
"""
Read whole slide imaging and extract patches.
Read whole slide images and extract patches.

Args:
reader_lib: backend library to load the images, available options: "OpenSlide" or "cuCIM".
backend: backend library to load the images, available options: "OpenSlide" or "cuCIM".
level: the whole slide image level at which the image is extracted. (default=0)
Note that this is overridden by the level argument in `get_data`.

"""

def __init__(self, reader_lib: str = "OpenSlide"):
def __init__(self, backend: str = "OpenSlide", level: int = 0):
super().__init__()
self.reader_lib = reader_lib.lower()
if self.reader_lib == "openslide":
self.backend = backend.lower()
if self.backend == "openslide":
self.wsi_reader, *_ = optional_import("openslide", name="OpenSlide")
elif self.reader_lib == "cucim":
elif self.backend == "cucim":
self.wsi_reader, *_ = optional_import("cucim", name="CuImage")
else:
raise ValueError('`reader_lib` should be either "cuCIM" or "OpenSlide"')
raise ValueError('`backend` should be either "cuCIM" or "OpenSlide"')
self.level = level

def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
"""
Expand All @@ -696,19 +699,21 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:

def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs):
"""
Read image data from specified file or files.
Note that the returned object is CuImage or list of CuImage objects.
Read image data from given file or list of files.

Args:
data: file name or a list of file names to read.

Returns:
image object or list of image objects

"""
img_: List = []

filenames: Sequence[str] = ensure_tuple(data)
for name in filenames:
img = self.wsi_reader(name)
if self.reader_lib == "openslide":
if self.backend == "openslide":
img.shape = (img.dimensions[1], img.dimensions[0], 3)
img_.append(img)

Expand All @@ -719,7 +724,7 @@ def get_data(
img,
location: Tuple[int, int] = (0, 0),
size: Optional[Tuple[int, int]] = None,
level: int = 0,
level: Optional[int] = None,
dtype: DtypeLike = np.uint8,
grid_shape: Tuple[int, int] = (1, 1),
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
Expand All @@ -738,9 +743,11 @@ def get_data(
grid_shape: (row, columns) tuple define a grid to extract patches on that
patch_size: (height, width) the size of extracted patches at the given level
"""
if level is None:
level = self.level

if self.reader_lib == "openslide" and size is None:
# the maximum size is set to WxH
if self.backend == "openslide" and size is None:
# the maximum size is set to WxH at the specified level
size = (img.shape[0] // (2 ** level) - location[0], img.shape[1] // (2 ** level) - location[1])

region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype)
Expand Down Expand Up @@ -780,7 +787,7 @@ def _extract_region(

def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8):
"""Convert to RGB mode and numpy array"""
if self.reader_lib == "openslide":
if self.backend == "openslide":
# convert to RGB
raw_region = raw_region.convert("RGB")
# convert to numpy
Expand Down