diff --git a/docs/source/data.rst b/docs/source/data.rst index 0910001783..c968d72945 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -152,11 +152,24 @@ PILReader .. autoclass:: PILReader :members: +Whole slide image reader +------------------------ + +BaseWSIReader +~~~~~~~~~~~~~ +.. autoclass:: BaseWSIReader + :members: + WSIReader ~~~~~~~~~ .. autoclass:: WSIReader :members: +CuCIMWSIReader +~~~~~~~~~~~~~~ +.. autoclass:: CuCIMWSIReader + :members: + Image writer ------------ diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 19ca29eafa..ca4be87ef6 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -34,7 +34,7 @@ from .folder_layout import FolderLayout from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, @@ -87,3 +87,4 @@ worker_init_fn, zoom_affine, ) +from .wsi_reader import BaseWSIReader, CuCIMWSIReader, WSIReader diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py new file mode 100644 index 0000000000..4899fb8830 --- /dev/null +++ b/monai/data/wsi_reader.py @@ -0,0 +1,420 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np + +from monai.config import DtypeLike, PathLike +from monai.data.image_reader import ImageReader, _stack_images +from monai.data.utils import is_supported_format +from monai.transforms.utility.array import EnsureChannelFirst +from monai.utils import ensure_tuple, optional_import, require_pkg + +CuImage, _ = optional_import("cucim", name="CuImage") + +__all__ = ["BaseWSIReader", "WSIReader", "CuCIMWSIReader"] + + +class BaseWSIReader(ImageReader): + """ + An abstract class that defines APIs to load patches from whole slide image files. + + Typical usage of a concrete implementation of this class is: + + .. code-block:: python + + image_reader = MyWSIReader() + wsi = image_reader.read(, **kwargs) + img_data, meta_data = image_reader.get_data(wsi) + + - The `read` call converts an image filename into whole slide image object, + - The `get_data` call fetches the image data, as well as meta data. + + The following methods needs to be implemented for any concrete implementation of this class: + + - `read` reads a whole slide image object from a given file + - `get_size` returns the size of the whole slide image of a given wsi object at a given level. + - `get_level_count` returns the number of levels in the whole slide image + - `get_patch` extracts and returns a patch image form the whole slide image + - `get_metadata` extracts and returns metadata for a whole slide image and a specific patch. + + + """ + + supported_suffixes: List[str] = [] + + def __init__(self, level: int, **kwargs): + super().__init__() + self.level = level + self.kwargs = kwargs + self.metadata: Dict[Any, Any] = {} + + @abstractmethod + def get_size(self, wsi, level: int) -> Tuple[int, int]: + """ + Returns the size of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number where the size is calculated + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_level_count(self, wsi) -> int: + """ + Returns the number of levels in the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_patch( + self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + ) -> np.ndarray: + """ + Extracts and returns a patch image form the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file or a lis of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + """ + Extracts and returns metadata form the whole slide image. + + Args: + patch: extracted patch from whole slide image + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def get_data( + self, + wsi, + location: Tuple[int, int] = (0, 0), + size: Optional[Tuple[int, int]] = None, + level: Optional[int] = None, + dtype: DtypeLike = np.uint8, + mode: str = "RGB", + ) -> Tuple[np.ndarray, Dict]: + """ + Verifies inputs, extracts patches from WSI image and generates metadata, and return them. + + Args: + wsi: a whole slide image object loaded from a file or a list of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + Returns: + a tuples, where the first element is an image patch [CxHxW] or stack of patches, + and second element is a dictionary of metadata + """ + patch_list: List = [] + metadata = {} + # CuImage object is iterable, so ensure_tuple won't work on single object + if not isinstance(wsi, List): + wsi = [wsi] + for each_wsi in ensure_tuple(wsi): + # Verify magnification level + if level is None: + level = self.level + max_level = self.get_level_count(each_wsi) - 1 + if level > max_level: + raise ValueError(f"The maximum level of this image is {max_level} while level={level} is requested)!") + + # Verify location + if location is None: + location = (0, 0) + wsi_size = self.get_size(each_wsi, level) + if location[0] > wsi_size[0] or location[1] > wsi_size[1]: + raise ValueError(f"Location is outside of the image: location={location}, image size={wsi_size}") + + # Verify size + if size is None: + if location != (0, 0): + raise ValueError("Patch size should be defined to exctract patches.") + size = self.get_size(each_wsi, level) + else: + if size[0] <= 0 or size[1] <= 0: + raise ValueError(f"Patch size should be greater than zero, provided: patch size = {size}") + + # Extract a patch or the entire image + patch = self.get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype, mode=mode) + + # check if the image has three dimensions (2D + color) + if patch.ndim != 3: + raise ValueError( + f"The image dimension should be 3 but has {patch.ndim}. " + "`WSIReader` is designed to work only with 2D images with color channel." + ) + + # Create a list of patches + patch_list.append(patch) + + # Set patch-related metadata + each_meta = self.get_metadata(patch=patch, location=location, size=size, level=level) + metadata.update(each_meta) + + return _stack_images(patch_list, metadata), metadata + + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + """ + Verify whether the specified file or files format is supported by WSI reader. + + The list of supported suffixes are read from `self.supported_suffixes`. + + Args: + filename: filename or a list of filenames to read. + + """ + return is_supported_format(filename, self.supported_suffixes) + + +class WSIReader(BaseWSIReader): + """ + Read whole slide images and extract patches using different backend libraries + + Args: + backend: the name of backend whole slide image reader library, the default is cuCIM. + level: the level at which patches are extracted. + kwargs: additional arguments to be passed to the backend library + + """ + + def __init__(self, backend="cucim", level: int = 0, **kwargs): + super().__init__(level, **kwargs) + self.backend = backend.lower() + # Any new backend can be added below + if self.backend == "cucim": + self.reader = CuCIMWSIReader(level=level, **kwargs) + else: + raise ValueError("The supported backends are: cucim") + self.supported_suffixes = self.reader.supported_suffixes + + def get_level_count(self, wsi) -> int: + """ + Returns the number of levels in the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file + + """ + return self.reader.get_level_count(wsi) + + def get_size(self, wsi, level) -> Tuple[int, int]: + """ + Returns the size of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number where the size is calculated + + """ + return self.reader.get_size(wsi, level) + + def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + """ + Extracts and returns metadata form the whole slide image. + + Args: + patch: extracted patch from whole slide image + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + + """ + return self.reader.get_metadata(patch=patch, size=size, location=location, level=level) + + def get_patch( + self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + ) -> np.ndarray: + """ + Extracts and returns a patch image form the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file or a lis of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + """ + return self.reader.get_patch(wsi=wsi, location=location, size=size, level=level, dtype=dtype, mode=mode) + + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + """ + Read whole slide image objects from given file or list of files. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for the reader module (overrides `self.kwargs` for existing keys). + + Returns: + whole slide image object or list of such objects + + """ + return self.reader.read(data=data, **kwargs) + + +@require_pkg(pkg_name="cucim") +class CuCIMWSIReader(BaseWSIReader): + """ + Read whole slide images and extract patches without loading the whole slide image into the memory. + + Args: + level: the whole slide image level at which the image is extracted. (default=0) + This is overridden if the level argument is provided in `get_data`. + kwargs: additional args for `cucim.CuImage` module: + https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h + + """ + + supported_suffixes = ["tif", "tiff", "svs"] + + def __init__(self, level: int = 0, **kwargs): + super().__init__(level, **kwargs) + + @staticmethod + def get_level_count(wsi) -> int: + """ + Returns the number of levels in the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file + + """ + return wsi.resolutions["level_count"] # type: ignore + + @staticmethod + def get_size(wsi, level) -> Tuple[int, int]: + """ + Returns the size of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number where the size is calculated + + """ + return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0]) + + def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + """ + Extracts and returns metadata form the whole slide image. + + Args: + patch: extracted patch from whole slide image + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + + """ + metadata: Dict = { + "backend": "cucim", + "spatial_shape": np.asarray(patch.shape[1:]), + "original_channel_dim": 0, + "location": location, + "size": size, + "level": level, + } + return metadata + + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + """ + Read whole slide image objects from given file or list of files. + + Args: + data: file name or a list of file names to read. + kwargs: additional args that overrides `self.kwargs` for existing keys. + For more details look at https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h + + Returns: + whole slide image object or list of such objects + + """ + wsi_list: List = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for filename in filenames: + wsi = CuImage(filename, **kwargs_) + wsi_list.append(wsi) + + return wsi_list if len(filenames) > 1 else wsi_list[0] + + def get_patch( + self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + ) -> np.ndarray: + """ + Extracts and returns a patch image form the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file or a lis of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + """ + # Extract a patch or the entire image + # (reverse the order of location and size to become WxH for cuCIM) + patch: np.ndarray = wsi.read_region(location=location[::-1], size=size[::-1], level=level) + + # Convert to numpy + patch = np.asarray(patch, dtype=dtype) + + # Make it channel first + patch = EnsureChannelFirst()(patch, {"original_channel_dim": -1}) # type: ignore + + # Check if the color channel is 3 (RGB) or 4 (RGBA) + if mode == "RGBA" and patch.shape[0] != 4: + raise ValueError( + f"The image is expected to have four color channels in '{mode}' mode but has {patch.shape[0]}." + ) + + if mode in "RGB": + if patch.shape[0] not in [3, 4]: + raise ValueError( + f"The image is expected to have three or four color channels in '{mode}' mode but has {patch.shape[0]}. " + ) + patch = patch[:3] + + return patch diff --git a/tests/min_tests.py b/tests/min_tests.py index 66b6c9ff3d..25acbccb41 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -157,6 +157,7 @@ def run_testsuit(): "test_vitautoenc", "test_write_metrics_reports", "test_wsireader", + "test_wsireader_new", "test_zoom", "test_zoom_affine", "test_zoomd", diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py new file mode 100644 index 0000000000..7b288f6040 --- /dev/null +++ b/tests/test_wsireader_new.py @@ -0,0 +1,218 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from numpy.testing import assert_array_equal +from parameterized import parameterized + +from monai.data import DataLoader, Dataset +from monai.data.wsi_reader import WSIReader +from monai.transforms import Compose, LoadImaged, ToTensord +from monai.utils import first, optional_import +from monai.utils.enums import PostFix +from tests.utils import download_url_or_skip_test, testing_data_config + +cucim, has_cucim = optional_import("cucim") +has_cucim = has_cucim and hasattr(cucim, "CuImage") +openslide, has_osl = optional_import("openslide") +imwrite, has_tiff = optional_import("tifffile", name="imwrite") +_, has_codec = optional_import("imagecodecs") +has_tiff = has_tiff and has_codec + +FILE_KEY = "wsi_img" +FILE_URL = testing_data_config("images", FILE_KEY, "url") +base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) + +HEIGHT = 32914 +WIDTH = 46000 + +TEST_CASE_0 = [FILE_PATH, 2, (3, HEIGHT // 4, WIDTH // 4)] + +TEST_CASE_TRANSFORM_0 = [FILE_PATH, 4, (HEIGHT // 16, WIDTH // 16), (1, 3, HEIGHT // 16, WIDTH // 16)] + +TEST_CASE_1 = [ + FILE_PATH, + {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0}, + np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), +] + +TEST_CASE_2 = [ + FILE_PATH, + {"location": (0, 0), "size": (2, 1), "level": 2}, + np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), +] + +TEST_CASE_3 = [ + [FILE_PATH, FILE_PATH], + {"location": (0, 0), "size": (2, 1), "level": 2}, + np.concatenate( + [ + np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + ], + axis=0, + ), +] + +TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW + +TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW + +TEST_CASE_ERROR_GRAY = [np.ones((16, 16), dtype=np.uint8)] # no color channel +TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color + + +def save_rgba_tiff(array: np.ndarray, filename: str, mode: str): + """ + Save numpy array into a TIFF RGB/RGBA file + + Args: + array: numpy ndarray with the shape of CxHxW and C==3 representing a RGB image + filename: the filename to be used for the tiff file. '_RGB.tiff' or '_RGBA.tiff' will be appended to this filename. + mode: RGB or RGBA + """ + if mode == "RGBA": + array = np.concatenate([array, 255 * np.ones_like(array[0])[np.newaxis]]).astype(np.uint8) + + img_rgb = array.transpose(1, 2, 0) + imwrite(filename, img_rgb, shape=img_rgb.shape, tile=(16, 16)) + + return filename + + +def save_gray_tiff(array: np.ndarray, filename: str): + """ + Save numpy array into a TIFF file + + Args: + array: numpy ndarray with any shape + filename: the filename to be used for the tiff file. + """ + img_gray = array + imwrite(filename, img_gray, shape=img_gray.shape, photometric="rgb") + + return filename + + +@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!") +def setUpModule(): # noqa: N802 + hash_type = testing_data_config("images", FILE_KEY, "hash_type") + hash_val = testing_data_config("images", FILE_KEY, "hash_val") + download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val) + + +class WSIReaderTests: + class Tests(unittest.TestCase): + backend = None + + @parameterized.expand([TEST_CASE_0]) + def test_read_whole_image(self, file_path, level, expected_shape): + reader = WSIReader(self.backend, level=level) + with reader.read(file_path) as img_obj: + img = reader.get_data(img_obj)[0] + self.assertTupleEqual(img.shape, expected_shape) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_read_region(self, file_path, patch_info, expected_img): + kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} + reader = WSIReader(self.backend, **kwargs) + with reader.read(file_path, **kwargs) as img_obj: + if self.backend == "tifffile": + with self.assertRaises(ValueError): + reader.get_data(img_obj, **patch_info)[0] + else: + # Read twice to check multiple calls + img = reader.get_data(img_obj, **patch_info)[0] + img2 = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, img2.shape) + self.assertIsNone(assert_array_equal(img, img2)) + self.assertTupleEqual(img.shape, expected_img.shape) + self.assertIsNone(assert_array_equal(img, expected_img)) + + @parameterized.expand([TEST_CASE_3]) + def test_read_region_multi_wsi(self, file_path, patch_info, expected_img): + kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} + reader = WSIReader(self.backend, **kwargs) + img_obj = reader.read(file_path, **kwargs) + if self.backend == "tifffile": + with self.assertRaises(ValueError): + reader.get_data(img_obj, **patch_info)[0] + else: + # Read twice to check multiple calls + img = reader.get_data(img_obj, **patch_info)[0] + img2 = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, img2.shape) + self.assertIsNone(assert_array_equal(img, img2)) + self.assertTupleEqual(img.shape, expected_img.shape) + self.assertIsNone(assert_array_equal(img, expected_img)) + + @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) + @skipUnless(has_tiff, "Requires tifffile.") + def test_read_rgba(self, img_expected): + # skip for OpenSlide since not working with images without tiles + if self.backend == "openslide": + return + image = {} + reader = WSIReader(self.backend) + for mode in ["RGB", "RGBA"]: + file_path = save_rgba_tiff( + img_expected, + os.path.join(os.path.dirname(__file__), "testing_data", f"temp_tiff_image_{mode}.tiff"), + mode=mode, + ) + with reader.read(file_path) as img_obj: + image[mode], _ = reader.get_data(img_obj) + + self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) + self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) + + @parameterized.expand([TEST_CASE_ERROR_GRAY, TEST_CASE_ERROR_3D]) + @skipUnless(has_tiff, "Requires tifffile.") + def test_read_malformats(self, img_expected): + reader = WSIReader(self.backend) + file_path = save_gray_tiff( + img_expected, os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") + ) + with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)): + with reader.read(file_path) as img_obj: + reader.get_data(img_obj) + + @parameterized.expand([TEST_CASE_TRANSFORM_0]) + def test_with_dataloader(self, file_path, level, expected_spatial_shape, expected_shape): + train_transform = Compose( + [ + LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), + ToTensord(keys=["image"]), + ] + ) + dataset = Dataset([{"image": file_path}], transform=train_transform) + data_loader = DataLoader(dataset) + data: dict = first(data_loader) + for s in data[PostFix.meta("image")]["spatial_shape"]: + torch.testing.assert_allclose(s, expected_spatial_shape) + self.assertTupleEqual(data["image"].shape, expected_shape) + + +@skipUnless(has_cucim, "Requires cucim") +class TestCuCIM(WSIReaderTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "cucim" + + +if __name__ == "__main__": + unittest.main()