diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index bf8fe39ec6..18a808c961 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -667,22 +667,25 @@ def _get_spatial_shape(self, img): class WSIReader(ImageReader): """ - Read whole slide imaging and extract patches. + Read whole slide images and extract patches. Args: - reader_lib: backend library to load the images, available options: "OpenSlide" or "cuCIM". + backend: backend library to load the images, available options: "OpenSlide" or "cuCIM". + level: the whole slide image level at which the image is extracted. (default=0) + Note that this is overridden by the level argument in `get_data`. """ - def __init__(self, reader_lib: str = "OpenSlide"): + def __init__(self, backend: str = "OpenSlide", level: int = 0): super().__init__() - self.reader_lib = reader_lib.lower() - if self.reader_lib == "openslide": + self.backend = backend.lower() + if self.backend == "openslide": self.wsi_reader, *_ = optional_import("openslide", name="OpenSlide") - elif self.reader_lib == "cucim": + elif self.backend == "cucim": self.wsi_reader, *_ = optional_import("cucim", name="CuImage") else: - raise ValueError('`reader_lib` should be either "cuCIM" or "OpenSlide"') + raise ValueError('`backend` should be either "cuCIM" or "OpenSlide"') + self.level = level def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ @@ -696,19 +699,21 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): """ - Read image data from specified file or files. - Note that the returned object is CuImage or list of CuImage objects. + Read image data from given file or list of files. Args: data: file name or a list of file names to read. + Returns: + image object or list of image objects + """ img_: List = [] filenames: Sequence[str] = ensure_tuple(data) for name in filenames: img = self.wsi_reader(name) - if self.reader_lib == "openslide": + if self.backend == "openslide": img.shape = (img.dimensions[1], img.dimensions[0], 3) img_.append(img) @@ -719,7 +724,7 @@ def get_data( img, location: Tuple[int, int] = (0, 0), size: Optional[Tuple[int, int]] = None, - level: int = 0, + level: Optional[int] = None, dtype: DtypeLike = np.uint8, grid_shape: Tuple[int, int] = (1, 1), patch_size: Optional[Union[int, Tuple[int, int]]] = None, @@ -738,9 +743,11 @@ def get_data( grid_shape: (row, columns) tuple define a grid to extract patches on that patch_size: (height, width) the size of extracted patches at the given level """ + if level is None: + level = self.level - if self.reader_lib == "openslide" and size is None: - # the maximum size is set to WxH + if self.backend == "openslide" and size is None: + # the maximum size is set to WxH at the specified level size = (img.shape[0] // (2 ** level) - location[0], img.shape[1] // (2 ** level) - location[1]) region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype) @@ -780,7 +787,7 @@ def _extract_region( def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8): """Convert to RGB mode and numpy array""" - if self.reader_lib == "openslide": + if self.backend == "openslide": # convert to RGB raw_region = raw_region.convert("RGB") # convert to numpy