From c686711110f3e730e06050ea1ec82e462e7edfc7 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Wed, 5 Jun 2024 23:59:22 +0200 Subject: [PATCH 01/37] Add new transforms --- direct/data/datasets_config.py | 42 ++-- direct/data/mri_transforms.py | 443 ++++++++++++++++++++++++++++++++- direct/data/transforms.py | 87 +++++++ 3 files changed, 549 insertions(+), 23 deletions(-) diff --git a/direct/data/datasets_config.py b/direct/data/datasets_config.py index 21335408..77356d7a 100644 --- a/direct/data/datasets_config.py +++ b/direct/data/datasets_config.py @@ -2,8 +2,10 @@ """Classes holding the typed configurations for the datasets.""" +from __future__ import annotations + from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Optional from omegaconf import MISSING @@ -14,6 +16,7 @@ MaskSplitterType, RandomFlipType, ReconstructionType, + RescaleMode, SensitivityMapType, TransformsType, ) @@ -37,9 +40,17 @@ class SensitivityMapEstimationTransformConfig(BaseConfig): sensitivity_maps_gaussian: Optional[float] = 0.7 +@dataclass +class AugmentationTransformConfig(BaseConfig): + rescale: Optional[tuple[int, ...]] = None + rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST + rescale_2d_if_3d: Optional[bool] = False + pad: Optional[tuple[int, ...]] = None + + @dataclass class RandomAugmentationTransformsConfig(BaseConfig): - random_rotation_degrees: Tuple[int, ...] = (-90, 90) + random_rotation_degrees: tuple[int, ...] = (-90, 90) random_rotation_probability: float = 0.0 random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM random_flip_probability: float = 0.0 @@ -62,8 +73,10 @@ class TransformsConfig(BaseConfig): Configuration for the masking. cropping : CropTransformConfig Configuration for the cropping. + augmentation : AugmentationTransformConfig + Configuration for the augmentation. Currently only rescale and pad are supported. random_augmentations : RandomAugmentationTransformsConfig - Configuration for the random augmentations. + Configuration for the random augmentations. Currently only random rotation, flip and reverse are supported. padding_eps : float Padding epsilon. Default is 0.001. estimate_body_coil_image : bool @@ -89,10 +102,10 @@ class TransformsConfig(BaseConfig): To use SSL transforms, set transforms_type to `SSL_SSDU`. This will require additional parameters to be set: mask_split_ratio, mask_split_acs_region, mask_split_keep_acs, mask_split_type, mask_split_gaussian_std. Default is `TransformsType.SUPERVISED`. - mask_split_ratio : Tuple[float, ...] + mask_split_ratio : tuple[float, ...] Ratio of the mask to split into input and target mask. Ignored if transforms_type is not `SSL_SSDU`. Default is (0.4,). - mask_split_acs_region : Tuple[int, int] + mask_split_acs_region : tuple[int, int] Region of the ACS k-space to keep in the input mask. Ignored if transforms_type is not `SSL_SSDU`. Default is (0, 0). mask_split_keep_acs : bool, optional @@ -111,6 +124,7 @@ class TransformsConfig(BaseConfig): masking: Optional[MaskingConfig] = MaskingConfig() cropping: CropTransformConfig = CropTransformConfig() + augmentation: AugmentationTransformConfig = AugmentationTransformConfig() random_augmentations: RandomAugmentationTransformsConfig = RandomAugmentationTransformsConfig() padding_eps: float = 0.001 estimate_body_coil_image: bool = False @@ -123,8 +137,8 @@ class TransformsConfig(BaseConfig): use_seed: bool = True transforms_type: TransformsType = TransformsType.SUPERVISED # Next attributes are for the mask splitter in case of transforms_type is set to SSL_SSDU - mask_split_ratio: Tuple[float, ...] = (0.4,) - mask_split_acs_region: Tuple[int, int] = (0, 0) + mask_split_ratio: tuple[float, ...] = (0.4,) + mask_split_acs_region: tuple[int, int] = (0, 0) mask_split_keep_acs: Optional[bool] = False mask_split_type: MaskSplitterType = MaskSplitterType.GAUSSIAN mask_split_gaussian_std: float = 3.0 @@ -146,8 +160,8 @@ class H5SliceConfig(DatasetConfig): kspace_context: int = 0 pass_mask: bool = False data_root: Optional[str] = None - filenames_filter: Optional[List[str]] = None - filenames_lists: Optional[List[str]] = None + filenames_filter: Optional[list[str]] = None + filenames_lists: Optional[list[str]] = None filenames_lists_root: Optional[str] = None @@ -155,12 +169,12 @@ class H5SliceConfig(DatasetConfig): class CMRxReconConfig(DatasetConfig): regex_filter: Optional[str] = None data_root: Optional[str] = None - filenames_filter: Optional[List[str]] = None - filenames_lists: Optional[List[str]] = None + filenames_filter: Optional[list[str]] = None + filenames_lists: Optional[list[str]] = None filenames_lists_root: Optional[str] = None kspace_key: str = "kspace_full" compute_mask: bool = False - extra_keys: Optional[List[str]] = None + extra_keys: Optional[list[str]] = None kspace_context: Optional[str] = None @@ -181,11 +195,11 @@ class FakeMRIBlobsConfig(DatasetConfig): @dataclass class SheppLoganDatasetConfig(DatasetConfig): - shape: Tuple[int, int, int] = (100, 100, 30) + shape: tuple[int, int, int] = (100, 100, 30) num_coils: int = 12 seed: Optional[int] = None B0: float = 3.0 - zlimits: Tuple[float, float] = (-0.929, 0.929) + zlimits: tuple[float, float] = (-0.929, 0.929) @dataclass diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index e8b25655..c15f2fbe 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -528,6 +528,226 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: return sample +class RescaleMode(DirectEnum): + AREA = "area" + BICUBIC = "bicubic" + BILINEAR = "bilinear" + NEAREST = "nearest" + NEAREST_EXACT = "nearest-exact" + TRILINEAR = "trilinear" + + +class RescaleKspace(DirectTransform): + """Rescale k-space (downsample/upsample) module. + + Rescales the k-space: + * It first projects the k-space to the image-domain via the backward operator, + * It rescales the back-projected k-space to specified shape, + * It transforms the rescaled back-projected k-space to the k-space domain via the forward operator. + + Parameters + ---------- + shape : tuple or list of ints + Shape to rescale the input. Must be correspond to (height, width). + forward_operator : Callable + The forward operator, e.g. some form of FFT (centered or uncentered). + Default: :class:`direct.data.transforms.fft2`. + backward_operator : Callable + The backward operator, e.g. some form of inverse FFT (centered or uncentered). + Default: :class:`direct.data.transforms.ifft2`. + rescale_mode : RescaleMode + Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR, + RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are + supported for 2D or 3D data. Default: RescaleMode.NEAREST. + kspace_key : KspaceKey + K-space key. Default: KspaceKey.KSPACE. + rescale_2d_if_3d : bool, optional + If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions. + Default: False. + + Note + ---- + If the input k-space data is 3D, rescaling will be done only on the height and width dimensions if + `rescale_2d_if_3d` is set to True. + """ + + def __init__( + self, + shape: Union[tuple[int, int], list[int]], + forward_operator: Callable = T.fft2, + backward_operator: Callable = T.ifft2, + rescale_mode: RescaleMode = RescaleMode.NEAREST, + kspace_key: KspaceKey = KspaceKey.KSPACE, + rescale_2d_if_3d: Optional[bool] = None, + ) -> None: + """Inits :class:`RescaleKspace`. + + Parameters + ---------- + shape : tuple or list of ints + Shape to rescale the input. Must be correspond to (height, width). + forward_operator : Callable + The forward operator, e.g. some form of FFT (centered or uncentered). + Default: :class:`direct.data.transforms.fft2`. + backward_operator : Callable + The backward operator, e.g. some form of inverse FFT (centered or uncentered). + Default: :class:`direct.data.transforms.ifft2`. + rescale_mode : RescaleMode + Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR, + RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are + supported for 2D or 3D data. Default: RescaleMode.NEAREST. + kspace_key : KspaceKey + K-space key. Default: KspaceKey.KSPACE. + rescale_2d_if_3d : bool, optional + If True and input k-space data is 3D, rescaling will be done only on the height and width dimensions, + by combining the slice/time dimension with the batch dimension. + Default: False. + """ + super().__init__() + self.logger = logging.getLogger(type(self).__name__) + + if len(shape) not in [2, 3]: + raise ValueError( + f"Shape should be a list or tuple of two integers if 2D or three integers if 3D. " + f"Received: {shape}." + ) + self.shape = shape + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self.rescale_mode = rescale_mode + self.kspace_key = kspace_key + + self.rescale_2d_if_3d = rescale_2d_if_3d + if rescale_2d_if_3d and len(shape) == 3: + raise ValueError("Shape cannot have a length of 3 when rescale_2d_if_3d is set to True.") + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Calls :class:`RescaleKspace`. + + Parameters + ---------- + sample: Dict[str, Any] + Dict sample containing key `kspace`. + + Returns + ------- + Dict[str, Any] + Cropped and masked sample. + """ + kspace = sample[self.kspace_key] # shape (coil, [slice/time], height, width, complex=2) + + dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D + + backprojected_kspace = self.backward_operator(kspace, dim=dim) + + if kspace.ndim == 5 and self.rescale_2d_if_3d: + backprojected_kspace = backprojected_kspace.permute(1, 0, 2, 3, 4) + + if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d): + backprojected_kspace = backprojected_kspace.unsqueeze(0) + + rescaled_backprojected_kspace = T.complex_image_resize(backprojected_kspace, self.shape, self.rescale_mode) + + if (kspace.ndim == 4) or (kspace.ndim == 5 and not self.rescale_2d_if_3d): + rescaled_backprojected_kspace = rescaled_backprojected_kspace.squeeze(0) + + if kspace.ndim == 5 and self.rescale_2d_if_3d: + rescaled_backprojected_kspace = rescaled_backprojected_kspace.permute(1, 0, 2, 3, 4) + + # Compute new k-space from rescaled_backprojected_kspace + # shape (coil, [slice/time if rescale_2d_if_3d else new_slc_or_time], new_height, new_width, complex=2) + sample[self.kspace_key] = self.forward_operator(rescaled_backprojected_kspace, dim=dim) # The rescaled kspace + + return sample + + +class PadKspace(DirectTransform): + """Pad k-space with zeros to desired shape module. + + Rescales the k-space by: + * It first projects the k-space to the image-domain via the backward operator, + * It pads the back-projected k-space to specified shape, + * It transforms the rescaled back-projected k-space to the k-space domain via the forward operator. + + Parameters + ---------- + pad_shape : tuple or list of ints + Shape to zero-pad the input. Must be correspond to (height, width) or (slice/time, height, width). + forward_operator : Callable + The forward operator, e.g. some form of FFT (centered or uncentered). + Default: :class:`direct.data.transforms.fft2`. + backward_operator : Callable + The backward operator, e.g. some form of inverse FFT (centered or uncentered). + Default: :class:`direct.data.transforms.ifft2`. + kspace_key : KspaceKey + K-space key. Default: KspaceKey.KSPACE. + """ + + def __init__( + self, + pad_shape: Union[tuple[int, int], list[int]], + forward_operator: Callable = T.fft2, + backward_operator: Callable = T.ifft2, + kspace_key: KspaceKey = KspaceKey.KSPACE, + ) -> None: + """Inits :class:`RescaleKspace`. + + Parameters + ---------- + pad_shape : tuple or list of ints + Shape to zero-pad the input. Must be correspond to (height, width) or (slice/time, height, width). + forward_operator : Callable + The forward operator, e.g. some form of FFT (centered or uncentered). + Default: :class:`direct.data.transforms.fft2`. + backward_operator : Callable + The backward operator, e.g. some form of inverse FFT (centered or uncentered). + Default: :class:`direct.data.transforms.ifft2`. + kspace_key : KspaceKey + K-space key. Default: KspaceKey.KSPACE. + """ + super().__init__() + self.logger = logging.getLogger(type(self).__name__) + + if len(pad_shape) not in [2, 3]: + raise ValueError(f"Shape should be a list or tuple of two or three integers. Received: {pad_shape}.") + + self.shape = pad_shape + self.forward_operator = forward_operator + self.backward_operator = backward_operator + self.kspace_key = kspace_key + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Calls :class:`PadKspace`. + + Parameters + ---------- + sample: dict[str, Any] + Dict sample containing key `kspace`. + + Returns + ------- + dict[str, Any] + Cropped and masked sample. + """ + kspace = sample[self.kspace_key] # shape (coil, [slice or time], height, width, complex=2) + shape = kspace.shape + + sample["original_size"] = shape[1:-1] + + dim = self.spatial_dims.TWO_D if kspace.ndim == 4 else self.spatial_dims.THREE_D + + backprojected_kspace = self.backward_operator(kspace, dim=dim) + backprojected_kspace = T.view_as_complex(backprojected_kspace) + + padded_backprojected_kspace = T.pad_tensor(backprojected_kspace, self.shape) + padded_backprojected_kspace = T.view_as_real(padded_backprojected_kspace) + + # shape (coil, [slice or time], height, width, complex=2) + sample[self.kspace_key] = self.forward_operator(padded_backprojected_kspace, dim=dim) # The padded kspace + + return sample + + class ComputeZeroPadding(DirectTransform): r"""Computes zero padding present in multi-coil kspace input. @@ -987,6 +1207,78 @@ def forward(self, sample: dict[str, Any]) -> dict[str, Any]: return sample +class CompressCoilModule(DirectModule): + """Compresses k-space coils using SVD.""" + + def __init__(self, kspace_key: KspaceKey, num_coils: int) -> None: + """Inits :class:`CompressCoilModule`. + + Parameters + ---------- + kspace_key : KspaceKey + K-space key. + num_coils : int + Number of coils to compress. + """ + super().__init__() + self.kspace_key = kspace_key + self.num_coils = num_coils + + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: + """Performs coil compression to input k-space. + + Parameters + ---------- + sample : dict[str, Any] + Dict sample containing key `kspace_key`. Assumes coil dimension is first axis. + + Returns + ------- + sample : dict[str, Any] + Dict sample with `kspace_key` compressed to num_coils. + """ + k_space = sample[self.kspace_key].clone() # shape (batch, coil, [slice/time], height, width, complex=2) + + if k_space.shape[1] <= self.num_coils: + return sample + + ndim = k_space.ndim + + k_space = torch.view_as_complex(k_space) + + if ndim == 6: # If 3D sample reshape slice into batch dimension as sensitivities are computed 2D + num_slice_or_time = k_space.shape[2] + k_space = k_space.permute(0, 2, 1, 3, 4) + k_space = k_space.reshape(k_space.shape[0] * num_slice_or_time, *k_space.shape[2:]) + + shape = k_space.shape + + # Reshape the k-space data to combine spatial dimensions + k_space_reshaped = k_space.reshape(shape[0], shape[1], -1) + + # Compute the coil combination matrix using Singular Value Decomposition (SVD) + U, _, _ = torch.linalg.svd(k_space_reshaped, full_matrices=False) + + # Select the top ncoils_new singular vectors from the decomposition + U_new = U[:, :, : self.num_coils] + + # Perform coil compression + compressed_k_space = torch.matmul(U_new.transpose(1, 2), k_space_reshaped) + + # Reshape the compressed k-space back to its original shape + compressed_k_space = compressed_k_space.reshape(shape[0], self.num_coils, *shape[2:]) + + if ndim == 6: + compressed_k_space = compressed_k_space.reshape( + shape[0] // num_slice_or_time, num_slice_or_time, self.num_coils, *shape[2:] + ).permute(0, 2, 1, 3, 4) + + compressed_k_space = torch.view_as_real(compressed_k_space) + sample[self.kspace_key] = compressed_k_space # shape (batch, new coil, [slice/time], height, width, complex=2) + + return sample + + class DeleteKeysModule(DirectModule): """Remove keys from the sample if present.""" @@ -1364,6 +1656,7 @@ def __call__(self, *args, **kwargs) -> SubWrapper: EstimateSensitivityMap = ModuleWrapper(EstimateSensitivityMapModule, toggle_dims=True) DeleteKeys = ModuleWrapper(DeleteKeysModule, toggle_dims=False) RenameKeys = ModuleWrapper(RenameKeysModule, toggle_dims=False) +CompressCoil = ModuleWrapper(CompressCoilModule, toggle_dims=True) PadCoilDimension = ModuleWrapper(PadCoilDimensionModule, toggle_dims=True) ComputeScalingFactor = ModuleWrapper(ComputeScalingFactorModule, toggle_dims=True) Normalize = ModuleWrapper(NormalizeModule, toggle_dims=False) @@ -1431,6 +1724,10 @@ def build_pre_mri_transforms( mask_func: Optional[Callable], crop: Optional[Union[tuple[int, int], str]] = None, crop_type: Optional[str] = "uniform", + rescale: Optional[Union[tuple[int, int], list[int]]] = None, + rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST, + rescale_2d_if_3d: Optional[bool] = False, + pad: Optional[Union[tuple[int, int], list[int]]] = None, image_center_crop: bool = True, random_rotation_degrees: Optional[Sequence[int]] = (-90, 90), random_rotation_probability: float = 0.0, @@ -1466,6 +1763,21 @@ def build_pre_mri_transforms( a key "reconstruction_size" must be present in the sample. Default: None. crop_type : Optional[str] Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform". + rescale : tuple or list, optional + If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back. + Must correspond to (height, width). This is ignored if `rescale` is None. Default: None. + It is not recommended to be used in combination with `crop`. + rescale_mode : RescaleMode + Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR, + RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are + supported for 2D or 3D data. Default: RescaleMode.NEAREST. + rescale_2d_if_3d : bool, optional + If True and k-space data is 3D, rescaling will be done only on the height + and width dimensions, by combining the slice/time dimension with the batch dimension. + This is ignored if `rescale` is None. Default: False. + pad : tuple or list, optional + If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width) + or (slice/time, height, width). Default: None. image_center_crop : bool If True the backprojected kspace will be cropped around the center, otherwise randomly. This will be ignored if `crop` is None. Default: True. @@ -1493,6 +1805,8 @@ def build_pre_mri_transforms( An MRI transformation object. """ # pylint: disable=too-many-locals + logger = logging.getLogger(build_pre_mri_transforms.__name__) + mri_transforms: list[Callable] = [ToTensor()] if crop: mri_transforms += [ @@ -1505,6 +1819,28 @@ def build_pre_mri_transforms( random_crop_sampler_use_seed=use_seed, ) ] + if rescale: + if crop: + logger.warning( + "Rescale and crop are both given. Rescale will be applied after cropping. This is not recommended." + ) + mri_transforms += [ + RescaleKspace( + shape=rescale, + forward_operator=forward_operator, + backward_operator=backward_operator, + rescale_mode=rescale_mode, + rescale_2d_if_3d=rescale_2d_if_3d, + ) + ] + if pad: + mri_transforms += [ + PadKspace( + pad_shape=pad, + forward_operator=forward_operator, + backward_operator=backward_operator, + ) + ] if random_rotation_probability > 0.0: mri_transforms += [ RandomRotation( @@ -1521,10 +1857,13 @@ def build_pre_mri_transforms( keys_to_flip=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP), ) ] - if mask_func: + if padding_eps > 0.0: mri_transforms += [ ComputeZeroPadding(KspaceKey.KSPACE, "padding", padding_eps), ApplyZeroPadding(KspaceKey.KSPACE, "padding"), + ] + if mask_func: + mri_transforms += [ CreateSamplingMask( mask_func, shape=(None if (isinstance(crop, str)) else crop), @@ -1654,6 +1993,10 @@ def build_supervised_mri_transforms( mask_func: Optional[Callable], crop: Optional[Union[tuple[int, int], str]] = None, crop_type: Optional[str] = "uniform", + rescale: Optional[Union[tuple[int, int], list[int]]] = None, + rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST, + rescale_2d_if_3d: Optional[bool] = False, + pad: Optional[Union[tuple[int, int], list[int]]] = None, image_center_crop: bool = True, random_rotation_degrees: Optional[Sequence[int]] = (-90, 90), random_rotation_probability: float = 0.0, @@ -1672,6 +2015,7 @@ def build_supervised_mri_transforms( delete_acs_mask: bool = True, delete_kspace: bool = True, image_recon_type: ReconstructionType = ReconstructionType.RSS, + compress_coils: Optional[int] = None, pad_coils: Optional[int] = None, scaling_key: TransformKey = TransformKey.MASKED_KSPACE, scale_percentile: Optional[float] = 0.99, @@ -1683,8 +2027,11 @@ def build_supervised_mri_transforms( * Converts input to (complex-valued) tensor. * Applies k-space (center) crop if requested. + * Applies k-space rescaling if requested. + * Applies k-space padding if requested. * Applies random augmentations (rotation, flip, reverse) if requested. * Adds a sampling mask if `mask_func` is defined. + * Compreses the coil dimension if requested. * Pads the coil dimension if requested. * Adds coil sensitivities and / or the body coil_image * Masks the fully sampled k-space, if there is a mask function or a mask in the sample. @@ -1707,6 +2054,21 @@ def build_supervised_mri_transforms( a key "reconstruction_size" must be present in the sample. Default: None. crop_type : Optional[str] Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform". + rescale : tuple or list, optional + If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back. + Must correspond to (height, width). This is ignored if `rescale` is None. Default: None. + It is not recommended to be used in combination with `crop`. + rescale_mode : RescaleMode + Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR, + RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are + supported for 2D or 3D data. Default: RescaleMode.NEAREST. + rescale_2d_if_3d : bool, optional + If True and k-space data is 3D, rescaling will be done only on the height + and width dimensions, by combining the slice/time dimension with the batch dimension. + This is ignored if `rescale` is None. Default: False. + pad : tuple or list, optional + If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width) + or (slice/time, height, width). Default: None. image_center_crop : bool If True the backprojected kspace will be cropped around the center, otherwise randomly. This will be ignored if `crop` is None. Default: True. @@ -1749,6 +2111,9 @@ def build_supervised_mri_transforms( If True will delete key `kspace` (fully sampled k-space). Default: True. image_recon_type : ReconstructionType Type to reconstruct target image. Default: ReconstructionType.RSS. + compress_coils : int, optional + Number of coils to compress input k-space. It is not recommended to be used in combination with `pad_coils`. + Default: None. pad_coils : int Number of coils to pad data to. scaling_key : TransformKey @@ -1776,6 +2141,26 @@ def build_supervised_mri_transforms( random_crop_sampler_use_seed=use_seed, ) ] + if rescale: + mri_transforms += [ + RescaleKspace( + shape=rescale, + forward_operator=forward_operator, + backward_operator=backward_operator, + rescale_mode=rescale_mode, + rescale_2d_if_3d=rescale_2d_if_3d, + kspace_key=KspaceKey.KSPACE, + ) + ] + if pad: + mri_transforms += [ + PadKspace( + pad_shape=pad, + forward_operator=forward_operator, + backward_operator=backward_operator, + kspace_key=KspaceKey.KSPACE, + ) + ] if random_rotation_probability > 0.0: mri_transforms += [ RandomRotation( @@ -1799,10 +2184,13 @@ def build_supervised_mri_transforms( keys_to_reverse=(TransformKey.KSPACE, TransformKey.SENSITIVITY_MAP), ) ] - if mask_func: + if padding_eps > 0.0: mri_transforms += [ ComputeZeroPadding(KspaceKey.KSPACE, "padding", padding_eps), ApplyZeroPadding(KspaceKey.KSPACE, "padding"), + ] + if mask_func: + mri_transforms += [ CreateSamplingMask( mask_func, shape=(None if (isinstance(crop, str)) else crop), @@ -1810,7 +2198,8 @@ def build_supervised_mri_transforms( return_acs=estimate_sensitivity_maps, ), ] - + if compress_coils: + mri_transforms += [CompressCoil(num_coils=compress_coils, kspace_key=KspaceKey.KSPACE)] if pad_coils: mri_transforms += [PadCoilDimension(pad_coils=pad_coils, key=KspaceKey.KSPACE)] @@ -1830,10 +2219,8 @@ def build_supervised_mri_transforms( espirit_max_iters=sensitivity_maps_espirit_max_iters, ) ] - if delete_acs_mask: mri_transforms += [DeleteKeys(keys=["acs_mask"])] - mri_transforms += [ ApplyMask( sampling_mask_key="sampling_mask", @@ -1841,7 +2228,6 @@ def build_supervised_mri_transforms( target_kspace_key=KspaceKey.MASKED_KSPACE, ), ] - mri_transforms += [ ComputeScalingFactor( normalize_key=scaling_key, percentile=scale_percentile, scaling_factor_key=TransformKey.SCALING_FACTOR @@ -1854,7 +2240,6 @@ def build_supervised_mri_transforms( ], # Only these two keys are in the sample here ), ] - mri_transforms += [ ComputeImage( kspace_key=KspaceKey.KSPACE, @@ -1863,7 +2248,6 @@ def build_supervised_mri_transforms( type_reconstruction=image_recon_type, ) ] - if delete_kspace: mri_transforms += [DeleteKeys(keys=[KspaceKey.KSPACE])] @@ -1882,6 +2266,10 @@ def build_mri_transforms( mask_func: Optional[Callable], crop: Optional[Union[tuple[int, int], str]] = None, crop_type: Optional[str] = "uniform", + rescale: Optional[Union[tuple[int, int], list[int]]] = None, + rescale_mode: Optional[RescaleMode] = RescaleMode.NEAREST, + rescale_2d_if_3d: Optional[bool] = False, + pad: Optional[Union[tuple[int, int], list[int]]] = None, image_center_crop: bool = True, random_rotation_degrees: Optional[Sequence[int]] = (-90, 90), random_rotation_probability: float = 0.0, @@ -1900,6 +2288,7 @@ def build_mri_transforms( delete_acs_mask: bool = True, delete_kspace: bool = True, image_recon_type: ReconstructionType = ReconstructionType.RSS, + compress_coils: Optional[int] = None, pad_coils: Optional[int] = None, scaling_key: TransformKey = TransformKey.MASKED_KSPACE, scale_percentile: Optional[float] = 0.99, @@ -1918,8 +2307,11 @@ def build_mri_transforms( * Converts input to (complex-valued) tensor. * Applies k-space (center) crop if requested. + * Applies k-space rescaling if requested. + * Applies k-space padding if requested. * Applies random augmentations (rotation, flip, reverse) if requested. * Adds a sampling mask if `mask_func` is defined. + * Compreses the coil dimension if requested. * Pads the coil dimension if requested. * Adds coil sensitivities and / or the body coil_image * Masks the fully sampled k-space, if there is a mask function or a mask in the sample. @@ -1943,6 +2335,21 @@ def build_mri_transforms( a key "reconstruction_size" must be present in the sample. Default: None. crop_type : Optional[str] Type of cropping, either "gaussian" or "uniform". This will be ignored if `crop` is None. Default: "uniform". + rescale : tuple or list, optional + If not None, this will transform the "kspace" to the image domain, rescale it, and transform it back. + Must correspond to (height, width). This is ignored if `rescale` is None. Default: None. + It is not recommended to be used in combination with `crop`. + rescale_mode : RescaleMode + Mode to be used for rescaling. Can be RescaleMode.AREA, RescaleMode.BICUBIC, RescaleMode.BILINEAR, + RescaleMode.NEAREST, RescaleMode.NEAREST_EXACT, or RescaleMode.TRILINEAR. Note that not all modes are + supported for 2D or 3D data. Default: RescaleMode.NEAREST. + rescale_2d_if_3d : bool, optional + If True and k-space data is 3D, rescaling will be done only on the height + and width dimensions, by combining the slice/time dimension with the batch dimension. + This is ignored if `rescale` is None. Default: False. + pad : tuple or list, optional + If not None, this will zero-pad the "kspace" to the given size. Must correspond to (height, width) + or (slice/time, height, width). Default: None. image_center_crop : bool If True the backprojected kspace will be cropped around the center, otherwise randomly. This will be ignored if `crop` is None. Default: True. @@ -1985,6 +2392,9 @@ def build_mri_transforms( If True will delete key `kspace` (fully sampled k-space). Default: True. image_recon_type : ReconstructionType Type to reconstruct target image. Default: ReconstructionType.RSS. + compress_coils : int, optional + Number of coils to compress input k-space. It is not recommended to be used in combination with `pad_coils`. + Default: None. pad_coils : int Number of coils to pad data to. scaling_key : TransformKey @@ -2024,8 +2434,18 @@ def build_mri_transforms( DirectTransform An MRI transformation object. """ + logger = logging.getLogger(build_mri_transforms.__name__) + logger.info("Creating %s MRI transforms.", transforms_type) - logging.getLogger(build_mri_transforms.__name__).info("Creating %s MRI transforms.", transforms_type) + if crop and rescale: + logger.warning( + "Rescale and crop are both given. Rescale will be applied after cropping. This is not recommended." + ) + + if compress_coils and pad_coils: + logger.warning( + "Compress coils and pad coils are both given. Compress coils will be applied before padding. This is not recommended." + ) mri_transforms = build_supervised_mri_transforms( forward_operator=forward_operator, @@ -2033,6 +2453,10 @@ def build_mri_transforms( mask_func=mask_func, crop=crop, crop_type=crop_type, + rescale=rescale, + rescale_mode=rescale_mode, + rescale_2d_if_3d=rescale_2d_if_3d, + pad=pad, image_center_crop=image_center_crop, random_rotation_degrees=random_rotation_degrees, random_rotation_probability=random_rotation_probability, @@ -2051,6 +2475,7 @@ def build_mri_transforms( delete_acs_mask=delete_acs_mask if transforms_type == TransformsType.SUPERVISED else False, delete_kspace=delete_kspace if transforms_type == TransformsType.SUPERVISED else False, image_recon_type=image_recon_type, + compress_coils=compress_coils, pad_coils=pad_coils, scaling_key=scaling_key, scale_percentile=scale_percentile, diff --git a/direct/data/transforms.py b/direct/data/transforms.py index b589c5a8..746d5c59 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -13,6 +13,7 @@ from numpy.typing import ArrayLike from direct.data.bbox import crop_to_bbox +from direct.types import DirectEnum from direct.utils import ensure_list, is_complex_data, is_power_of_two from direct.utils.asserts import assert_complex, assert_same_shape @@ -941,3 +942,89 @@ def expand_operator( assert_complex(sensitivity_map, complex_last=True) return complex_multiplication(sensitivity_map, data.unsqueeze(dim)) + + +def complex_image_resize( + complex_image: torch.Tensor, + resize_shape: Union[tuple[int, int], list[int]], + mode: str = "nearest", +) -> torch.Tensor: + """Resize a complex tensor to a new size. + + Parameters + ---------- + complex_image : torch.Tensor + Complex image tensor with shape (B, C, [D], [H,] W, 2) representing real and imaginary parts + resize_shape : tuple or list of two integers + Shape to resize image to. + mode : str + Algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area' | + 'nearest-exact'. Default: 'nearest' + + Returns + ------- + resized_image : torch.Tensor + Resized complex image tensor with shape (B, C, [new_depth,] [new_height,] new_width, 2) + """ + resize_shape = tuple(resize_shape) + if (complex_image.ndim - 3) != len(resize_shape): + raise ValueError( + f"Received resize shape {resize_shape} and {complex_image.ndim - 3}D tensor input with shape " + f"{complex_image.shape[2:-1]}. Dimensions of resize shape and input tensor should match." + ) + + # Extract the real and imaginary parts separately + real_part = complex_image[..., 0] + imag_part = complex_image[..., 1] + + interpolate_args = {"size": resize_shape, "mode": mode} + + if mode in ["bilinear", "bicubic", "trilinear"]: + interpolate_args.update({"align_corners": True}) + + # Reshape and resize the real and imaginary parts independently + real_resized = torch.nn.functional.interpolate(real_part, **interpolate_args) + imag_resized = torch.nn.functional.interpolate(imag_part, **interpolate_args) + + # Combine the resized real and imaginary parts into a complex tensor + resized_image = torch.stack((real_resized, imag_resized), dim=-1) + + return resized_image + + +def pad_tensor(input_image: torch.Tensor, target_shape: tuple[int, int], value: float = 0) -> torch.Tensor: + """Pads an input image tensor to a desired shape. + + Parameters + ---------- + input_image : torch.Tensor + The input image tensor of shape (..., x, y) or (..., z, x, y). + target_shape : tuple of integers + The desired shape (X, Y) or (Z, X, Y) for the padded image. + value : float + Padding value. Default: 0. + + Returns + ------- + torch.Tensor + The padded image tensor + """ + if len(target_shape) == 2: + input_shape = input_image.shape[-2:] + elif len(target_shape) == 3: + input_shape = input_image.shape[-3:] + else: + raise ValueError(f"Target shape not supported. Received `target_shape`={target_shape}.") + + # Calculate the required padding + pad = [] + for i in range(len(target_shape)): + diff = target_shape[i] - input_shape[i] + pad_before = max(0, diff // 2) + pad_after = max(0, diff - pad_before) + pad.extend([pad_before, pad_after]) + + pad = pad[::-1] + padded_image = torch.nn.functional.pad(input_image, pad, mode="constant", value=value) + + return padded_image From c1dd9a51403d1f290a59a93f90b0aa95685652ca Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 00:09:30 +0200 Subject: [PATCH 02/37] Typing fix --- direct/data/transforms.py | 69 +++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/direct/data/transforms.py b/direct/data/transforms.py index 746d5c59..26e81a9d 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -1,11 +1,19 @@ -# coding=utf-8 # Copyright (c) DIRECT Contributors # Code and comments can be shared with code of FastMRI under the same MIT license: # https://github.com/facebookresearch/fastMRI/ # The code can have been adjusted to our needs. -from typing import Callable, List, Optional, Tuple, Union +"""Direct transforms module. + +This module contains functions for complex-valued data manipulation in PyTorch. This includes functions for complex +multiplication, division, modulus, fft, ifft, fftshift, ifftshift, and more. The functions are designed to work with +complex-valued data where the last axis denotes the real and imaginary parts respectively. The functions are designed to +work with complex-valued data where the last axis denotes the real and imaginary parts respectively.""" + +from __future__ import annotations + +from typing import Callable, Optional, Union import numpy as np import torch @@ -13,7 +21,6 @@ from numpy.typing import ArrayLike from direct.data.bbox import crop_to_bbox -from direct.types import DirectEnum from direct.utils import ensure_list, is_complex_data, is_power_of_two from direct.utils.asserts import assert_complex, assert_same_shape @@ -35,7 +42,7 @@ def to_tensor(data: np.ndarray) -> torch.Tensor: return torch.from_numpy(data) -def verify_fft_dtype_possible(data: torch.Tensor, dims: Tuple[int, ...]) -> bool: +def verify_fft_dtype_possible(data: torch.Tensor, dims: tuple[int, ...]) -> bool: """fft and ifft can only be performed on GPU in float16 if the shapes are powers of 2. This function verifies if this is the case. @@ -98,7 +105,7 @@ def view_as_real(data): def fft2( data: torch.Tensor, - dim: Tuple[int, ...] = (1, 2), + dim: tuple[int, ...] = (1, 2), centered: bool = True, normalized: bool = True, complex_input: bool = True, @@ -159,7 +166,7 @@ def fft2( def ifft2( data: torch.Tensor, - dim: Tuple[int, ...] = (1, 2), + dim: tuple[int, ...] = (1, 2), centered: bool = True, normalized: bool = True, complex_input: bool = True, @@ -300,8 +307,8 @@ def roll_one_dim(data: torch.Tensor, shift: int, dim: int) -> torch.Tensor: def roll( data: torch.Tensor, - shift: List[int], - dim: Union[List[int], Tuple[int, ...]], + shift: list[int], + dim: Union[list[int], tuple[int, ...]], ) -> torch.Tensor: """Similar to numpy roll but applies to pytorch tensors. @@ -309,7 +316,7 @@ def roll( ---------- data: torch.Tensor shift: tuple, int - dim: List or tuple of ints + dim: list or tuple of ints Returns ------- @@ -325,14 +332,14 @@ def roll( return data -def fftshift(data: torch.Tensor, dim: Union[List[int], Tuple[int, ...], None] = None) -> torch.Tensor: +def fftshift(data: torch.Tensor, dim: Union[list[int], tuple[int, ...], None] = None) -> torch.Tensor: """Similar to numpy fftshift but applies to pytorch tensors. Parameters ---------- data: torch.Tensor Input data. - dim: List or tuple of ints or None + dim: list or tuple of ints or None Default: None. Returns @@ -353,14 +360,14 @@ def fftshift(data: torch.Tensor, dim: Union[List[int], Tuple[int, ...], None] = return roll(data, shift, dim) -def ifftshift(data: torch.Tensor, dim: Union[List[int], Tuple[int, ...], None] = None) -> torch.Tensor: +def ifftshift(data: torch.Tensor, dim: Union[list[int], tuple[int, ...], None] = None) -> torch.Tensor: """Similar to numpy ifftshift but applies to pytorch tensors. Parameters ---------- data: torch.Tensor Input data. - dim: List or tuple of ints or None + dim: list or tuple of ints or None Default: None. Returns @@ -413,7 +420,7 @@ def complex_multiplication(input_tensor: torch.Tensor, other_tensor: torch.Tenso return multiplication -def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: List[int]) -> torch.Tensor: +def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: list[int]) -> torch.Tensor: r"""Computes the dot product of the complex tensors :math:`a` and :math:`b`: :math:`a^{*}b = `. Parameters @@ -422,7 +429,7 @@ def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: List[int]) -> tor Input :math:`a`. b : torch.Tensor Input :math:`b`. - dim : List[int] + dim : list[int] Dimensions which will be suppressed. Useful when inputs are batched. Returns @@ -583,7 +590,7 @@ def apply_mask( mask_func: Union[Callable, torch.Tensor], seed: Optional[int] = None, return_mask: bool = True, -) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: +) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Subsample kspace by setting kspace to zero as given by a binary mask. Parameters @@ -665,13 +672,13 @@ def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1) return torch.sqrt((data**2).sum(dim)) -def center_crop(data: torch.Tensor, shape: Union[List[int], Tuple[int, ...]]) -> torch.Tensor: +def center_crop(data: torch.Tensor, shape: Union[list[int], tuple[int, ...]]) -> torch.Tensor: """Apply a center crop along the last two dimensions. Parameters ---------- data: torch.Tensor - shape: List or tuple of ints + shape: list or tuple of ints The output shape, should be smaller than the corresponding data dimensions. Returns @@ -691,19 +698,19 @@ def center_crop(data: torch.Tensor, shape: Union[List[int], Tuple[int, ...]]) -> def complex_center_crop( - data_list: Union[List[torch.Tensor], torch.Tensor], - crop_shape: Union[List[int], Tuple[int, ...]], + data_list: Union[list[torch.Tensor], torch.Tensor], + crop_shape: Union[list[int], tuple[int, ...]], offset: int = 1, contiguous: bool = False, -) -> Union[List[torch.Tensor], torch.Tensor]: +) -> Union[list[torch.Tensor], torch.Tensor]: """Apply a center crop to the input data, or to a list of complex images. Parameters ---------- - data_list: Union[List[torch.Tensor], torch.Tensor] + data_list: Union[list[torch.Tensor], torch.Tensor] The complex input tensor to be center cropped. It should have at least 3 dimensions and the cropping is applied along dimensions didx and didx+1 and the last dimensions should have a size of 2. - crop_shape: List[int] or Tuple[int, ...] + crop_shape: list[int] or tuple[int, ...] The output shape. The shape should be smaller than the corresponding dimensions of data. If one value is None, this is filled in by the image shape. offset: int @@ -713,7 +720,7 @@ def complex_center_crop( Returns ------- - Union[List[torch.Tensor], torch.Tensor] + Union[list[torch.Tensor], torch.Tensor] The center cropped input_image(s). """ data_list = ensure_list(data_list) @@ -747,22 +754,22 @@ def complex_center_crop( def complex_random_crop( - data_list: Union[List[torch.Tensor], torch.Tensor], - crop_shape: Union[List[int], Tuple[int, ...]], + data_list: Union[list[torch.Tensor], torch.Tensor], + crop_shape: Union[list[int], tuple[int, ...]], offset: int = 1, contiguous: bool = False, sampler: str = "uniform", - sigma: Union[float, List[float], None] = None, + sigma: Union[float, list[float], None] = None, seed: Union[None, int, ArrayLike] = None, -) -> Union[List[torch.Tensor], torch.Tensor]: +) -> Union[list[torch.Tensor], torch.Tensor]: """Apply a random crop to the input data tensor or a list of complex. Parameters ---------- - data_list: Union[List[torch.Tensor], torch.Tensor] + data_list: Union[list[torch.Tensor], torch.Tensor] The complex input tensor to be center cropped. It should have at least 3 dimensions and the cropping is applied along dimensions -3 and -2 and the last dimensions should have a size of 2. - crop_shape: List[int] or Tuple[int, ...] + crop_shape: list[int] or tuple[int, ...] The output shape. The shape should be smaller than the corresponding dimensions of data. offset: int Starting dimension for cropping. @@ -776,7 +783,7 @@ def complex_random_crop( Returns ------- - Union[List[torch.Tensor], torch.Tensor] + Union[list[torch.Tensor], torch.Tensor] The center cropped input tensor or list of tensors. """ if sampler == "uniform" and sigma is not None: From 357a68041de7a7131a5dca3e17d7fe1190b5410b Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 01:46:22 +0200 Subject: [PATCH 03/37] Remove python 3.8 checks cause they keep failing --- .github/workflows/coverage.yml | 2 +- .github/workflows/pylint.yml | 4 ++-- .github/workflows/tox.yml | 2 +- contributing.rst | 2 +- docker/Dockerfile | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 306bb592..1b89bfc4 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: [3.9] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index ede2845a..b36a53aa 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 301611ca..7c4c9004 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9] + python-version: [3.9] steps: - uses: actions/checkout@v2 diff --git a/contributing.rst b/contributing.rst index c44ad74b..b8d9d57d 100644 --- a/contributing.rst +++ b/contributing.rst @@ -110,7 +110,7 @@ Before you submit a pull request, check that it meets these guidelines: #. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. -#. The pull request should work for Python 3.8 and 3.9 and for PyPy. Check Github actions and see that all tests pass. +#. The pull request should work for Python 3.9 and for PyPy. Check Github actions and see that all tests pass. Tests ^^^^^ diff --git a/docker/Dockerfile b/docker/Dockerfile index 6fe58b15..3288b09b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,6 +1,6 @@ ARG CUDA="11.3.0" ARG PYTORCH="1.10" -ARG PYTHON="3.8" +ARG PYTHON="3.9" # TODO: conda installs its own version of cuda FROM nvidia/cuda:${CUDA}-devel-ubuntu18.04 From 7cd2fa0dbdf055e2256629cf304289fb08864851 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 01:51:19 +0200 Subject: [PATCH 04/37] Try 3.10 --- .github/workflows/tox.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 7c4c9004..1db273bc 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9] + python-version: [3.9, 3.10] steps: - uses: actions/checkout@v2 From 6821dac6a2bb3abfb4c1514747c456abfac7da1e Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 01:53:34 +0200 Subject: [PATCH 05/37] Try 3.10 --- .github/workflows/tox.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 1db273bc..861f544b 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9, 3.10] + python-version: [3.9, 310] steps: - uses: actions/checkout@v2 From 309d51596cc47eee3403af0ae0998a7ad62a27c2 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 01:54:44 +0200 Subject: [PATCH 06/37] Try 3.10 --- .github/workflows/tox.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 861f544b..455c50aa 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9, 310] + python-version: [3.10.0] steps: - uses: actions/checkout@v2 From a9a7bb6bb1cfc2dd7d387bd9b5fa2974760d79ab Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 01:58:55 +0200 Subject: [PATCH 07/37] Try 3.10 --- .github/workflows/tox.yml | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 455c50aa..11a6636f 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.10.0] + python-version: [3.10, 3.9] steps: - uses: actions/checkout@v2 diff --git a/pyproject.toml b/pyproject.toml index 42f3a7c2..11f16161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ skip_missing_interpreters=true python = 3.8: py38 3.9: py39 + 3.10: py310 [testenv] deps = pytest From 6d34de5a05b4f7f09a0655156eb0caf262e30546 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 01:59:30 +0200 Subject: [PATCH 08/37] Try 3.10 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 11f16161..6d48620c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ generated-members=['numpy.*', 'torch.*', 'np.*'] [tool.tox] legacy_tox_ini = """ [tox] -envlist = py38, py39 +envlist = py38, py39, py310 skip_missing_interpreters=true [gh-actions] From 15bc1bb2584c48f62d7e6682a09bb42b3d9e4110 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Thu, 6 Jun 2024 02:04:26 +0200 Subject: [PATCH 09/37] Update tox.yml --- .github/workflows/tox.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 11a6636f..a0c7b52e 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -9,12 +9,12 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.10, 3.9] + python-version: [3.10] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies From fedef5515b3ce6c70069c1f51b4c02b5532a5b71 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Thu, 6 Jun 2024 02:05:37 +0200 Subject: [PATCH 10/37] Update tox.yml --- .github/workflows/tox.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index a0c7b52e..1662c4c5 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.10] + python-version: ['3.10'] steps: - uses: actions/checkout@v2 From f59668dcb52b52bfce0485603fc314d4bdab0092 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Thu, 6 Jun 2024 02:12:59 +0200 Subject: [PATCH 11/37] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f5f247d8..7c2c5984 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ def finalize_options(self): "direct=direct.cli:main", ], }, - setup_requires=["numpy", "cython"], + setup_requires=["numpy", "cython>=3.0"], install_requires=[ "numpy>=1.21.2", "h5py==3.3.0", From 9a35a0fbcf80b8d69003ae77489e9eeab1cfe661 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Thu, 6 Jun 2024 02:17:20 +0200 Subject: [PATCH 12/37] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7c2c5984..f420e87e 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ def finalize_options(self): "direct=direct.cli:main", ], }, - setup_requires=["numpy", "cython>=3.0"], + setup_requires=["numpy>=1.21.2", "cython>=3.0"], install_requires=[ "numpy>=1.21.2", "h5py==3.3.0", From 9bbdaffc33b137700c19db0afe6e6f59ca20a2b5 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Thu, 6 Jun 2024 13:06:58 +0200 Subject: [PATCH 13/37] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f420e87e..7cbb1a2b 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ def finalize_options(self): setup_requires=["numpy>=1.21.2", "cython>=3.0"], install_requires=[ "numpy>=1.21.2", - "h5py==3.3.0", + "h5py==3.11.0", "omegaconf==2.1.1", "torch>=1.10.2", "torchvision", From 2d7ef736efb9bff5287d31e9f71bb5c1734a8297 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 13:18:46 +0200 Subject: [PATCH 14/37] Change py version --- .github/workflows/black.yml | 2 +- .github/workflows/coverage.yml | 2 +- .github/workflows/pylint.yml | 4 ++-- .github/workflows/tox.yml | 2 +- pyproject.toml | 7 +++---- setup.py | 5 ++--- 6 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index 31e04660..738c7a26 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -7,5 +7,5 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v4 - uses: psf/black@22.12.0 diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 1b89bfc4..1c039ebd 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9] + python-version: ['3.10', '3.11'] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index b36a53aa..99c065d5 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.10 uses: actions/setup-python@v1 with: - python-version: 3.9 + python-version: 3.10 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 1662c4c5..6fd69bb9 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.10'] + python-version: ['3.10', '3.11'] steps: - uses: actions/checkout@v2 diff --git a/pyproject.toml b/pyproject.toml index 6d48620c..4d2a5871 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ [tool.black] line-length = 119 # PyCharm line length -target-version = ['py38'] +target-version = ['py310'] include = '\.pyi?$' exclude = ''' /( @@ -43,15 +43,14 @@ generated-members=['numpy.*', 'torch.*', 'np.*'] [tool.tox] legacy_tox_ini = """ [tox] -envlist = py38, py39, py310 +envlist = py39, py310, py311 skip_missing_interpreters=true [gh-actions] python = - 3.8: py38 3.9: py39 3.10: py310 - + 3.11: py311 [testenv] deps = pytest extras = dev diff --git a/setup.py b/setup.py index 7cbb1a2b..00a4ec34 100644 --- a/setup.py +++ b/setup.py @@ -36,16 +36,15 @@ def finalize_options(self): setup( author="Jonas Teuwen, George Yiasemis", author_email="j.teuwen@nki.nl, g.yiasemis@nki.nl", - python_requires=">=3.8", + python_requires=">=3.10", classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], description="DIRECT - Deep Image REConsTruction - is a deep learning" " framework for MRI reconstruction.", entry_points={ From c6b4c0aa3be1fb8bd8a42232bb9dcdabfdafc7bc Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 13:45:55 +0200 Subject: [PATCH 15/37] Change py version --- .github/workflows/coverage.yml | 2 +- .github/workflows/tox.yml | 2 +- pyproject.toml | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 1c039ebd..ed7c35e6 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.10', '3.11'] + python-version: ['3.9'] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 6fd69bb9..c5e7ed87 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.10', '3.11'] + python-version: ['3.9', '3.10'] steps: - uses: actions/checkout@v2 diff --git a/pyproject.toml b/pyproject.toml index 4d2a5871..c0cceada 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,14 +43,13 @@ generated-members=['numpy.*', 'torch.*', 'np.*'] [tool.tox] legacy_tox_ini = """ [tox] -envlist = py39, py310, py311 +envlist = py39, py310 skip_missing_interpreters=true [gh-actions] python = 3.9: py39 3.10: py310 - 3.11: py311 [testenv] deps = pytest extras = dev From fbad9e135fb220913c723170a06b8bd23afe8b47 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 13:50:53 +0200 Subject: [PATCH 16/37] Add forgotten argument --- direct/data/datasets_config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/direct/data/datasets_config.py b/direct/data/datasets_config.py index 77356d7a..d0defa43 100644 --- a/direct/data/datasets_config.py +++ b/direct/data/datasets_config.py @@ -92,6 +92,9 @@ class TransformsConfig(BaseConfig): Default is True. image_recon_type : ReconstructionType Image reconstruction type. Default is ReconstructionType.RSS. + compress_coils : int, optional + Number of coils to compress input k-space. It is not recommended to be used in combination with `pad_coils`. + Default is None. pad_coils : int, optional Pad coils. Default is None. use_seed : bool @@ -133,6 +136,7 @@ class TransformsConfig(BaseConfig): delete_acs_mask: bool = True delete_kspace: bool = True image_recon_type: ReconstructionType = ReconstructionType.RSS + compress_coils: Optional[int] = None pad_coils: Optional[int] = None use_seed: bool = True transforms_type: TransformsType = TransformsType.SUPERVISED From 28412155a57d00eaf684e29fbf54c086b2e17562 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 13:54:54 +0200 Subject: [PATCH 17/37] Minor fix --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 00a4ec34..8b0702f9 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ def finalize_options(self): setup( author="Jonas Teuwen, George Yiasemis", author_email="j.teuwen@nki.nl, g.yiasemis@nki.nl", - python_requires=">=3.10", + python_requires=">=3.9", classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", From e824db28adc8763c79c6077b83c3df6494dcd42c Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 13:56:39 +0200 Subject: [PATCH 18/37] Minor fix --- .github/workflows/pylint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 99c065d5..4c2321c2 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -10,7 +10,7 @@ jobs: - name: Set up Python 3.10 uses: actions/setup-python@v1 with: - python-version: 3.10 + python-version: '3.10' - name: Install dependencies run: | python -m pip install --upgrade pip From 354a00894f35aa0d9d0b945a7fffd4492e88ced0 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 15:16:42 +0200 Subject: [PATCH 19/37] Update for codacy, add tests --- direct/data/mri_transforms.py | 2 +- direct/data/transforms.py | 4 +- tests/tests_data/test_mri_transforms.py | 61 ++++++++++++++++++++++++- tests/tests_data/test_transforms.py | 54 ++++++++++++++++++++++ 4 files changed, 116 insertions(+), 5 deletions(-) diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index c15f2fbe..c5841981 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -685,7 +685,7 @@ class PadKspace(DirectTransform): def __init__( self, - pad_shape: Union[tuple[int, int], list[int]], + pad_shape: Union[tuple[int, ...], list[int]], forward_operator: Callable = T.fft2, backward_operator: Callable = T.ifft2, kspace_key: KspaceKey = KspaceKey.KSPACE, diff --git a/direct/data/transforms.py b/direct/data/transforms.py index 26e81a9d..42d71b50 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -1025,8 +1025,8 @@ def pad_tensor(input_image: torch.Tensor, target_shape: tuple[int, int], value: # Calculate the required padding pad = [] - for i in range(len(target_shape)): - diff = target_shape[i] - input_shape[i] + for _, (target_dim, input_dim) in enumerate(zip(target_shape, input_shape)): + diff = target_dim - input_dim pad_before = max(0, diff // 2) pad_after = max(0, diff - pad_before) pad.extend([pad_before, pad_after]) diff --git a/tests/tests_data/test_mri_transforms.py b/tests/tests_data/test_mri_transforms.py index 118e7093..145a30eb 100644 --- a/tests/tests_data/test_mri_transforms.py +++ b/tests/tests_data/test_mri_transforms.py @@ -14,6 +14,7 @@ ApplyMask, ApplyZeroPadding, Compose, + CompressCoil, ComputeImage, ComputeScalingFactor, ComputeZeroPadding, @@ -24,19 +25,22 @@ EstimateSensitivityMap, Normalize, PadCoilDimension, + PadKspace, RandomFlip, RandomFlipType, RandomReverse, RandomRotation, ReconstructionType, + RescaleKspace, + RescaleMode, SensitivityMapType, ToTensor, WhitenData, build_mri_transforms, ) -from direct.data.transforms import fft2, ifft2, modulus +from direct.data.transforms import fft2, ifft2 from direct.exceptions import ItemNotFoundException -from direct.types import IntegerListOrTupleString +from direct.types import IntegerListOrTupleString, KspaceKey def create_sample(shape, **kwargs): @@ -262,6 +266,41 @@ def test_CropKspace( assert sample["kspace"].shape == (shape[0],) + crop_shape + (2,) +@pytest.mark.parametrize( + "shape, pad_shape, mode", + [ + [(3, 10, 16), (20, 26), RescaleMode.NEAREST], + [(3, 10, 16), (20, 26), RescaleMode.AREA], + [(3, 10, 16), (20, 26), RescaleMode.BILINEAR], + [(3, 10, 16), (20, 26), RescaleMode.BICUBIC], + [(3, 21, 10, 16), (30, 20, 26), RescaleMode.NEAREST], + [(3, 21, 10, 16), (30, 20, 26), RescaleMode.AREA], + [(3, 21, 10, 16), (30, 20, 26), RescaleMode.TRILINEAR], + ], +) +def test_RescaleKspace(shape, pad_shape, mode): + sample = create_sample(shape=shape + (2,)) + transform = RescaleKspace(pad_shape, rescale_mode=mode) + + sample = transform(sample) + assert sample["kspace"].shape == shape[: -len(pad_shape)] + pad_shape + (2,) + + +@pytest.mark.parametrize( + "shape, pad_shape", + [ + [(3, 10, 16), (20, 26)], + [(3, 21, 10, 16), (30, 20, 26)], + ], +) +def test_PadKspace(shape, pad_shape): + sample = create_sample(shape=shape + (2,)) + transform = PadKspace(pad_shape) + + sample = transform(sample) + assert sample["kspace"].shape == shape[: -len(pad_shape)] + pad_shape + (2,) + + @pytest.mark.parametrize( "shape", [(3, 21, 10, 16)], @@ -557,6 +596,24 @@ def test_DeleteKeys(shape, delete_keys): assert key not in sample +@pytest.mark.parametrize( + "shape, compress_coils", + [ + [(5, 7, 6), 4], + [(4, 5, 5), 4], + [(4, 5, 5), 3], + [(5, 4, 5, 5), 6], + [(5, 4, 5, 5), 4], + ], +) +def test_CompressCoil(shape, compress_coils): + sample = create_sample(shape=shape + (2,)) + transform = CompressCoil(kspace_key=KspaceKey.KSPACE, num_coils=compress_coils) + + sample = transform(sample) + assert sample["kspace"].shape == (compress_coils if compress_coils < shape[0] else shape[0],) + shape[1:] + (2,) + + @pytest.mark.parametrize( "shape, pad_coils", [[(3, 10, 16), 5], [(5, 7, 6), 5], [(4, 5, 5), 2], [(4, 5, 5), None], [(3, 4, 6, 4), 4], [(5, 3, 3, 4), 3]], diff --git a/tests/tests_data/test_transforms.py b/tests/tests_data/test_transforms.py index 3c84edc9..616b5b11 100644 --- a/tests/tests_data/test_transforms.py +++ b/tests/tests_data/test_transforms.py @@ -523,3 +523,57 @@ def test_apply_padding(shape): padded_data = transforms.apply_padding(data, padding) assert torch.allclose(data * (~padding), padded_data) + + +@pytest.mark.parametrize( + "input_shape, resize_shape, mode", + [ + ((1, 6, 3, 2), (5,), "nearest"), + ((1, 3, 6, 3, 2), (5, 5), "nearest"), + ((1, 7, 3, 6, 3, 2), (5, 5, 5), "nearest"), + ((1, 6, 3, 2), (5,), "area"), + ((1, 3, 6, 3, 2), (5, 5), "area"), + ((1, 7, 3, 6, 3, 2), (5, 5, 5), "area"), + ((1, 6, 3, 2), (5,), "linear"), + ((1, 3, 6, 3, 2), (5, 5), "bilinear"), + ((1, 3, 6, 3, 2), (5, 5), "bicubic"), + ((1, 7, 3, 6, 3, 2), (5, 5, 5), "trilinear"), + ], +) +def test_complex_image_resize(input_shape, resize_shape, mode): + # Create a random complex_image tensor with the specified input shape + complex_image = torch.randn(input_shape) + + # Perform the resize operation + resized_image = transforms.complex_image_resize(complex_image, resize_shape, mode) + + # Determine the expected shape based on the resize_shape + expected_shape = input_shape[: -(len(resize_shape) + 1)] + tuple(resize_shape) + (2,) + + # Assert that the shape of the resized image matches the expected shape + assert resized_image.shape == expected_shape + + +@pytest.mark.parametrize( + "input_shape, target_shape, value", + [ + [(30, 20, 25), (40, 40), 0], + [(30, 20, 25), (40, 40), 3], + [(30, 20, 25), (40, 39, 28), 0], + [(30, 20, 25), (40, 39, 28), 6], + [(11, 30, 20, 25), (40, 39, 28), 1], + [(11, 30, 20, 25), (40, 10, 20), 1], + ], +) +def test_pad_tensor(input_shape, target_shape, value): + data = torch.ones(input_shape) + padded_data = transforms.pad_tensor(data, target_shape, value) + + expected_shape = list(input_shape[: -len(target_shape)]) + list(target_shape) + for i in range(1, len(target_shape) + 1): + if target_shape[-i] < input_shape[-i]: + expected_shape[-i] = input_shape[-i] + + assert list(padded_data.shape) == expected_shape + + assert data.sum() + (value * (np.prod(expected_shape) - np.prod(input_shape))) == padded_data.sum() From fd08d069dfe867c2b3f9ce722ea6a853ac1333e3 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 6 Jun 2024 15:17:48 +0200 Subject: [PATCH 20/37] Update for codacy, add tests --- direct/data/mri_transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index c5841981..606c56e5 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -2444,7 +2444,8 @@ def build_mri_transforms( if compress_coils and pad_coils: logger.warning( - "Compress coils and pad coils are both given. Compress coils will be applied before padding. This is not recommended." + "Compress coils and pad coils are both given. Compress coils will be applied before padding. " + "This is not recommended." ) mri_transforms = build_supervised_mri_transforms( From 31c6cfa491e704c9eace332dea05ca0005d2cc64 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Thu, 6 Jun 2024 15:35:52 +0200 Subject: [PATCH 21/37] Update pylint.yml --- .github/workflows/pylint.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 4c2321c2..d9381901 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.10 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: '3.10' + python-version: '3.9' - name: Install dependencies run: | python -m pip install --upgrade pip From 76d06f2081b40f98a09c1b5783ac44876fd200f4 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Thu, 13 Jun 2024 11:23:04 +0200 Subject: [PATCH 22/37] Update coverage.yml --- .github/workflows/coverage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index ed7c35e6..5a216756 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9'] + python-version: ['3.10'] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} From 531e35c53fb58d8ce50c22e58e4edc4a92486270 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 13 Jun 2024 15:18:28 +0200 Subject: [PATCH 23/37] Update py version --- contributing.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contributing.rst b/contributing.rst index b8d9d57d..962ed0b2 100644 --- a/contributing.rst +++ b/contributing.rst @@ -110,7 +110,7 @@ Before you submit a pull request, check that it meets these guidelines: #. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. -#. The pull request should work for Python 3.9 and for PyPy. Check Github actions and see that all tests pass. +#. The pull request should work for Python 3.10 and for PyPy. Check Github actions and see that all tests pass. Tests ^^^^^ From bf813b922aeeed93a3dc4ee012a6975034440108 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 13 Jun 2024 15:21:03 +0200 Subject: [PATCH 24/37] Update py version --- docker/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 3288b09b..fcf8b2ff 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,9 +1,9 @@ ARG CUDA="11.3.0" ARG PYTORCH="1.10" -ARG PYTHON="3.9" +ARG PYTHON="3.10" # TODO: conda installs its own version of cuda -FROM nvidia/cuda:${CUDA}-devel-ubuntu18.04 +FROM nvidia/cuda:${CUDA}-devel-ubuntu20.04 ENV CUDA_PATH /usr/local/cuda ENV CUDA_ROOT /usr/local/cuda/bin From a8f70d486285531f3edc67ccb0ffbd831deb1808 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 13 Jun 2024 15:25:44 +0200 Subject: [PATCH 25/37] Update versions --- setup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 8b0702f9..ce9a2956 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ def finalize_options(self): "OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ], @@ -56,9 +57,9 @@ def finalize_options(self): install_requires=[ "numpy>=1.21.2", "h5py==3.11.0", - "omegaconf==2.1.1", - "torch>=1.10.2", - "torchvision", + "omegaconf==2.3.0", + "torch>=2.0.0", + "torchvision==0.17.0", "scikit-image>=0.19.0", "scikit-learn>=1.0.1", "tensorboard>=2.7.0", From 6fd3629a629ac1a3b028a1d94fa824ec45c642ac Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 13 Jun 2024 16:09:34 +0200 Subject: [PATCH 26/37] Typing --- direct/data/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/direct/data/transforms.py b/direct/data/transforms.py index 42d71b50..22fa6b08 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -105,7 +105,7 @@ def view_as_real(data): def fft2( data: torch.Tensor, - dim: tuple[int, ...] = (1, 2), + dim: tuple[int, int] = (1, 2), centered: bool = True, normalized: bool = True, complex_input: bool = True, @@ -166,7 +166,7 @@ def fft2( def ifft2( data: torch.Tensor, - dim: tuple[int, ...] = (1, 2), + dim: tuple[int, int] = (1, 2), centered: bool = True, normalized: bool = True, complex_input: bool = True, From 6716770f202d8153317cd365f0c0cd5655eea5af Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 13 Jun 2024 16:10:03 +0200 Subject: [PATCH 27/37] Update versions --- setup.py | 4 ++-- tests/test_train.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index ce9a2956..9c1c5f14 100644 --- a/setup.py +++ b/setup.py @@ -58,8 +58,8 @@ def finalize_options(self): "numpy>=1.21.2", "h5py==3.11.0", "omegaconf==2.3.0", - "torch>=2.0.0", - "torchvision==0.17.0", + "torch>=2.2.0", + "torchvision==0.18.0", "scikit-image>=0.19.0", "scikit-learn>=1.0.1", "tensorboard>=2.7.0", diff --git a/tests/test_train.py b/tests/test_train.py index 39630c0a..f64cbb40 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -31,7 +31,7 @@ def create_test_transform_cfg(transforms_type): transforms_config = TransformsConfig( normalization=NormalizationTransformConfig(scaling_key="masked_kspace"), masking=MaskingConfig(name="FastMRIRandom"), - cropping=CropTransformConfig(crop=(32, 32)), + cropping=CropTransformConfig(crop="(32, 32)"), sensitivity_map_estimation=SensitivityMapEstimationTransformConfig(estimate_sensitivity_maps=True), transforms_type=transforms_type, ) From 8b9c4c6ee503678714e7e665e29cfae62bb5d5a2 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 13 Jun 2024 17:35:05 +0200 Subject: [PATCH 28/37] Minor fix --- direct/data/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/direct/data/transforms.py b/direct/data/transforms.py index 22fa6b08..7aca292e 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -42,14 +42,14 @@ def to_tensor(data: np.ndarray) -> torch.Tensor: return torch.from_numpy(data) -def verify_fft_dtype_possible(data: torch.Tensor, dims: tuple[int, ...]) -> bool: +def verify_fft_dtype_possible(data: torch.Tensor, dims: tuple[int, int] | tuple[int, int, int]) -> bool: """fft and ifft can only be performed on GPU in float16 if the shapes are powers of 2. This function verifies if this is the case. Parameters ---------- data: torch.Tensor - dims: tuple + dims: tuple of two or three ints Returns ------- From e41f3b301329b052fea8d530d490e7cbe0d8d8c8 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Sun, 23 Jun 2024 12:52:29 +0200 Subject: [PATCH 29/37] Revert back --- docker/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index fcf8b2ff..6fe58b15 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,9 +1,9 @@ ARG CUDA="11.3.0" ARG PYTORCH="1.10" -ARG PYTHON="3.10" +ARG PYTHON="3.8" # TODO: conda installs its own version of cuda -FROM nvidia/cuda:${CUDA}-devel-ubuntu20.04 +FROM nvidia/cuda:${CUDA}-devel-ubuntu18.04 ENV CUDA_PATH /usr/local/cuda ENV CUDA_ROOT /usr/local/cuda/bin From e1290bce441c8958d12b601f29691581d3742872 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Mon, 24 Jun 2024 10:36:58 +0200 Subject: [PATCH 30/37] New version (pip?) causes problems with cnp.int_t --- direct/common/_gaussian.pyx | 4 ++-- direct/common/_poisson.pyx | 6 +++--- direct/ssl/_gaussian_fill.pyx | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/direct/common/_gaussian.pyx b/direct/common/_gaussian.pyx index 8c4db1ff..7a05682b 100644 --- a/direct/common/_gaussian.pyx +++ b/direct/common/_gaussian.pyx @@ -57,7 +57,7 @@ def gaussian_mask_1d( int n, int center, double std, - cnp.ndarray[cnp.int_t, ndim=1, mode='c'] mask, + cnp.ndarray[cnp.int32_t, ndim=1, mode='c'] mask, int seed, ): cdef int count, ind @@ -82,7 +82,7 @@ def gaussian_mask_2d( int center_x, int center_y, cnp.ndarray[cnp.float_t, ndim=1, mode='c'] std, - cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.int32_t, ndim=2, mode='c'] mask, int seed, ): cdef int count, indx, indy diff --git a/direct/common/_poisson.pyx b/direct/common/_poisson.pyx index 5b1ad183..a93e3263 100644 --- a/direct/common/_poisson.pyx +++ b/direct/common/_poisson.pyx @@ -38,7 +38,7 @@ def poisson( int nx, int ny, int max_attempts, - cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.int32_t, ndim=2, mode='c'] mask, cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] radius_x, cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] radius_y, int seed @@ -62,8 +62,8 @@ def poisson( cdef Py_ssize_t startx, endx, starty, endy, px, py # initialize active list - cdef cnp.ndarray[cnp.int_t, ndim=1, mode='c'] pxs = np.empty(nx * ny, dtype=int) - cdef cnp.ndarray[cnp.int_t, ndim=1, mode='c'] pys = np.empty(nx * ny, dtype=int) + cdef cnp.ndarray[cnp.int32_t, ndim=1, mode='c'] pxs = np.empty(nx * ny, dtype=int) + cdef cnp.ndarray[cnp.int32_t, ndim=1, mode='c'] pys = np.empty(nx * ny, dtype=int) srand(seed) diff --git a/direct/ssl/_gaussian_fill.pyx b/direct/ssl/_gaussian_fill.pyx index d880916a..406d99c5 100644 --- a/direct/ssl/_gaussian_fill.pyx +++ b/direct/ssl/_gaussian_fill.pyx @@ -45,8 +45,8 @@ def gaussian_fill( int center_x, int center_y, double std_scale, - cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, - cnp.ndarray[cnp.int_t, ndim=2, mode='c'] output_mask, + cnp.ndarray[cnp.int32_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.int32_t, ndim=2, mode='c'] output_mask, int seed, ): cdef int count, indx, indy From 6d118d6ef1e5a40864c7f1812ebb721ec8229d86 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Mon, 24 Jun 2024 10:44:41 +0200 Subject: [PATCH 31/37] Try int_t again --- direct/common/_gaussian.pyx | 6 +++--- direct/common/_poisson.pyx | 6 +++--- direct/ssl/_gaussian_fill.pyx | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/direct/common/_gaussian.pyx b/direct/common/_gaussian.pyx index 7a05682b..0e6e8e1a 100644 --- a/direct/common/_gaussian.pyx +++ b/direct/common/_gaussian.pyx @@ -10,7 +10,7 @@ import numpy as np -cimport numpy as cnp +z from libc.math cimport cos, log, pi, sin, sqrt from libc.stdlib cimport RAND_MAX, rand, srand @@ -57,7 +57,7 @@ def gaussian_mask_1d( int n, int center, double std, - cnp.ndarray[cnp.int32_t, ndim=1, mode='c'] mask, + cnp.ndarray[cnp.int_t, ndim=1, mode='c'] mask, int seed, ): cdef int count, ind @@ -82,7 +82,7 @@ def gaussian_mask_2d( int center_x, int center_y, cnp.ndarray[cnp.float_t, ndim=1, mode='c'] std, - cnp.ndarray[cnp.int32_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, int seed, ): cdef int count, indx, indy diff --git a/direct/common/_poisson.pyx b/direct/common/_poisson.pyx index a93e3263..5b1ad183 100644 --- a/direct/common/_poisson.pyx +++ b/direct/common/_poisson.pyx @@ -38,7 +38,7 @@ def poisson( int nx, int ny, int max_attempts, - cnp.ndarray[cnp.int32_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] radius_x, cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] radius_y, int seed @@ -62,8 +62,8 @@ def poisson( cdef Py_ssize_t startx, endx, starty, endy, px, py # initialize active list - cdef cnp.ndarray[cnp.int32_t, ndim=1, mode='c'] pxs = np.empty(nx * ny, dtype=int) - cdef cnp.ndarray[cnp.int32_t, ndim=1, mode='c'] pys = np.empty(nx * ny, dtype=int) + cdef cnp.ndarray[cnp.int_t, ndim=1, mode='c'] pxs = np.empty(nx * ny, dtype=int) + cdef cnp.ndarray[cnp.int_t, ndim=1, mode='c'] pys = np.empty(nx * ny, dtype=int) srand(seed) diff --git a/direct/ssl/_gaussian_fill.pyx b/direct/ssl/_gaussian_fill.pyx index 406d99c5..d880916a 100644 --- a/direct/ssl/_gaussian_fill.pyx +++ b/direct/ssl/_gaussian_fill.pyx @@ -45,8 +45,8 @@ def gaussian_fill( int center_x, int center_y, double std_scale, - cnp.ndarray[cnp.int32_t, ndim=2, mode='c'] mask, - cnp.ndarray[cnp.int32_t, ndim=2, mode='c'] output_mask, + cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.int_t, ndim=2, mode='c'] output_mask, int seed, ): cdef int count, indx, indy From 175371b2d71283938297d241ba46f385bd667da0 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Mon, 24 Jun 2024 10:46:32 +0200 Subject: [PATCH 32/37] Try int64_t --- direct/common/_gaussian.pyx | 4 ++-- direct/common/_poisson.pyx | 6 +++--- direct/ssl/_gaussian_fill.pyx | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/direct/common/_gaussian.pyx b/direct/common/_gaussian.pyx index 0e6e8e1a..7625c745 100644 --- a/direct/common/_gaussian.pyx +++ b/direct/common/_gaussian.pyx @@ -57,7 +57,7 @@ def gaussian_mask_1d( int n, int center, double std, - cnp.ndarray[cnp.int_t, ndim=1, mode='c'] mask, + cnp.ndarray[cnp.int64_t, ndim=1, mode='c'] mask, int seed, ): cdef int count, ind @@ -82,7 +82,7 @@ def gaussian_mask_2d( int center_x, int center_y, cnp.ndarray[cnp.float_t, ndim=1, mode='c'] std, - cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.int64_t, ndim=2, mode='c'] mask, int seed, ): cdef int count, indx, indy diff --git a/direct/common/_poisson.pyx b/direct/common/_poisson.pyx index 5b1ad183..59d78b7e 100644 --- a/direct/common/_poisson.pyx +++ b/direct/common/_poisson.pyx @@ -38,7 +38,7 @@ def poisson( int nx, int ny, int max_attempts, - cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.int64_t, ndim=2, mode='c'] mask, cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] radius_x, cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] radius_y, int seed @@ -62,8 +62,8 @@ def poisson( cdef Py_ssize_t startx, endx, starty, endy, px, py # initialize active list - cdef cnp.ndarray[cnp.int_t, ndim=1, mode='c'] pxs = np.empty(nx * ny, dtype=int) - cdef cnp.ndarray[cnp.int_t, ndim=1, mode='c'] pys = np.empty(nx * ny, dtype=int) + cdef cnp.ndarray[cnp.int64_t, ndim=1, mode='c'] pxs = np.empty(nx * ny, dtype=int) + cdef cnp.ndarray[cnp.int64_t, ndim=1, mode='c'] pys = np.empty(nx * ny, dtype=int) srand(seed) diff --git a/direct/ssl/_gaussian_fill.pyx b/direct/ssl/_gaussian_fill.pyx index d880916a..23d27c92 100644 --- a/direct/ssl/_gaussian_fill.pyx +++ b/direct/ssl/_gaussian_fill.pyx @@ -45,8 +45,8 @@ def gaussian_fill( int center_x, int center_y, double std_scale, - cnp.ndarray[cnp.int_t, ndim=2, mode='c'] mask, - cnp.ndarray[cnp.int_t, ndim=2, mode='c'] output_mask, + cnp.ndarray[cnp.int64_t, ndim=2, mode='c'] mask, + cnp.ndarray[cnp.int64_t, ndim=2, mode='c'] output_mask, int seed, ): cdef int count, indx, indy From 716ea39f814aadab4631eea5fb7ed760ddfc527d Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Mon, 24 Jun 2024 10:48:23 +0200 Subject: [PATCH 33/37] Typo --- direct/common/_gaussian.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/direct/common/_gaussian.pyx b/direct/common/_gaussian.pyx index 7625c745..e319728b 100644 --- a/direct/common/_gaussian.pyx +++ b/direct/common/_gaussian.pyx @@ -10,7 +10,7 @@ import numpy as np -z +cimport numpy as cnp from libc.math cimport cos, log, pi, sin, sqrt from libc.stdlib cimport RAND_MAX, rand, srand From 202dbaf409bf0b1ef51c4d16a32cd77f1fd6196e Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Mon, 24 Jun 2024 10:59:17 +0200 Subject: [PATCH 34/37] Deprecated np.product -> np.prod --- tests/tests_data/test_transforms.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/tests_data/test_transforms.py b/tests/tests_data/test_transforms.py index 616b5b11..2f7ebf83 100644 --- a/tests/tests_data/test_transforms.py +++ b/tests/tests_data/test_transforms.py @@ -212,7 +212,7 @@ def test_complex_center_crop(shape, target_shape): ], ) def test_roll(shift, dims, shape): - data = np.arange(np.product(shape)).reshape(shape) + data = np.arange(np.prod(shape)).reshape(shape) torch_tensor = torch.from_numpy(data) if not isinstance(shift, int) and not isinstance(dims, int) and len(shift) != len(dims): with pytest.raises(ValueError): @@ -232,7 +232,7 @@ def test_roll(shift, dims, shape): ], ) def test_complex_multiplication(shape): - data_0 = np.arange(np.product(shape)).reshape(shape) + 1j * (np.arange(np.product(shape)).reshape(shape) + 1) + data_0 = np.arange(np.prod(shape)).reshape(shape) + 1j * (np.arange(np.prod(shape)).reshape(shape) + 1) data_1 = data_0 + 0.5 + 1j torch_tensor_0 = transforms.to_tensor(data_0) torch_tensor_1 = transforms.to_tensor(data_1) @@ -247,8 +247,8 @@ def test_complex_multiplication(shape): [[3, 7], [5, 6, 2], [3, 4, 5], [4, 20, 42], [3, 4, 20, 40]], ) def test_complex_division(shape): - data_0 = np.arange(np.product(shape)).reshape(shape) + 1j * (np.arange(np.product(shape)).reshape(shape) + 1) - data_1 = np.arange(np.product(shape)).reshape(shape) + 1j * (np.arange(np.product(shape)).reshape(shape) + 1) + data_0 = np.arange(np.prod(shape)).reshape(shape) + 1j * (np.arange(np.prod(shape)).reshape(shape) + 1) + data_1 = np.arange(np.prod(shape)).reshape(shape) + 1j * (np.arange(np.prod(shape)).reshape(shape) + 1) torch_tensor_0 = transforms.to_tensor(data_0) torch_tensor_1 = transforms.to_tensor(data_1) out_torch = tensor_to_complex_numpy(transforms.complex_division(torch_tensor_0, torch_tensor_1)) @@ -369,7 +369,7 @@ def test_complex_bmm(shapes, batch_size): ], ) def test_conjugate(shape): - data = np.arange(np.product(shape)).reshape(shape) + 1j * (np.arange(np.product(shape)).reshape(shape) + 1) + data = np.arange(np.prod(shape)).reshape(shape) + 1j * (np.arange(np.prod(shape)).reshape(shape) + 1) torch_tensor = transforms.to_tensor(data) out_torch = tensor_to_complex_numpy(transforms.conjugate(torch_tensor)) @@ -379,7 +379,7 @@ def test_conjugate(shape): @pytest.mark.parametrize("shape", [[5, 3], [2, 4, 6], [2, 11, 4, 7]]) def test_fftshift(shape): - data = np.arange(np.product(shape)).reshape(shape) + data = np.arange(np.prod(shape)).reshape(shape) torch_tensor = torch.from_numpy(data) out_torch = transforms.fftshift(torch_tensor).numpy() out_numpy = np.fft.fftshift(data) @@ -395,7 +395,7 @@ def test_fftshift(shape): ], ) def test_ifftshift(shape): - data = np.arange(np.product(shape)).reshape(shape) + data = np.arange(np.prod(shape)).reshape(shape) torch_tensor = torch.from_numpy(data) out_torch = transforms.ifftshift(torch_tensor).numpy() out_numpy = np.fft.ifftshift(data) From 341ceb78597562e1cc4d9b71130f34623f3104f7 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Mon, 24 Jun 2024 11:17:09 +0200 Subject: [PATCH 35/37] Pylint failing --- direct/functionals/challenges.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/direct/functionals/challenges.py b/direct/functionals/challenges.py index 34105e18..70f600d2 100644 --- a/direct/functionals/challenges.py +++ b/direct/functionals/challenges.py @@ -1,8 +1,12 @@ -# coding=utf-8 # Copyright (c) DIRECT Contributors + +""" Direct metrics for the FastMRI and Calgary-Campinas challenges.""" + import numpy as np import torch +from skimage.metrics import structural_similarity, peak_signal_noise_ratio + __all__ = ( "fastmri_ssim", "fastmri_psnr", @@ -21,7 +25,6 @@ def _to_numpy(tensor): def fastmri_ssim(gt, target): """Compute Structural Similarity Index Measure (SSIM) compatible with the FastMRI challenge.""" - from skimage.metrics import structural_similarity gt = _to_numpy(gt)[:, 0, ...] target = _to_numpy(target)[:, 0, ...] @@ -70,14 +73,10 @@ def _calgary_campinas_metric(gt, pred, metric_func): def calgary_campinas_ssim(gt, pred): - from skimage.metrics import structural_similarity - return _calgary_campinas_metric(gt, pred, structural_similarity) def calgary_campinas_psnr(gt, pred): - from skimage.metrics import peak_signal_noise_ratio - return _calgary_campinas_metric(gt, pred, peak_signal_noise_ratio) From e50b77bfbd968b982767f87cafd14d68f21e9fb7 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Mon, 24 Jun 2024 11:25:35 +0200 Subject: [PATCH 36/37] Pylint failing --- direct/functionals/challenges.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/direct/functionals/challenges.py b/direct/functionals/challenges.py index 70f600d2..03614cb3 100644 --- a/direct/functionals/challenges.py +++ b/direct/functionals/challenges.py @@ -3,10 +3,9 @@ """ Direct metrics for the FastMRI and Calgary-Campinas challenges.""" import numpy as np +import skimage.metrics import torch -from skimage.metrics import structural_similarity, peak_signal_noise_ratio - __all__ = ( "fastmri_ssim", "fastmri_psnr", @@ -28,7 +27,7 @@ def fastmri_ssim(gt, target): gt = _to_numpy(gt)[:, 0, ...] target = _to_numpy(target)[:, 0, ...] - out = structural_similarity( + out = skimage.metrics.structural_similarity( gt.transpose(1, 2, 0), target.transpose(1, 2, 0), channel_axis=-1, @@ -41,9 +40,8 @@ def fastmri_psnr(gt, pred): """Compute Peak Signal to Noise Ratio metric (PSNR) compatible with the FastMRI challenge.""" gt = _to_numpy(gt)[:, 0, ...] pred = _to_numpy(pred)[:, 0, ...] - from skimage.metrics import peak_signal_noise_ratio as psnr - out = psnr(image_true=gt, image_test=pred, data_range=gt.max()) + out = skimage.metrics.peak_signal_noise_ratio(image_true=gt, image_test=pred, data_range=gt.max()) return torch.from_numpy(np.array(out)).float() @@ -73,11 +71,11 @@ def _calgary_campinas_metric(gt, pred, metric_func): def calgary_campinas_ssim(gt, pred): - return _calgary_campinas_metric(gt, pred, structural_similarity) + return _calgary_campinas_metric(gt, pred, skimage.metrics.structural_similarity) def calgary_campinas_psnr(gt, pred): - return _calgary_campinas_metric(gt, pred, peak_signal_noise_ratio) + return _calgary_campinas_metric(gt, pred, skimage.metrics.peak_signal_noise_ratio) def calgary_campinas_vif(gt, pred): From f3aa68ce02ccf4904c46f12cfd67324429e7a6f8 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Mon, 24 Jun 2024 11:51:37 +0200 Subject: [PATCH 37/37] Try pylint 3.10 --- .github/workflows/pylint.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index d9381901..4c2321c2 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.10 uses: actions/setup-python@v1 with: - python-version: '3.9' + python-version: '3.10' - name: Install dependencies run: | python -m pip install --upgrade pip