Skip to content

Commit

Permalink
Enum typing doesn't require checks
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Apr 19, 2024
1 parent 8a52b24 commit 0d81470
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,27 +825,26 @@ def __init__(
gaussian_sigma: float, optional
If non-zero, acs_image well be calculated
espirit_threshold: float, optional
Threshold for the calibration matrix when `type_of_map`=="espirit". Default: 0.05.
Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
Default: 0.05.
espirit_kernel_size: int, optional
Kernel size for the calibration matrix when `type_of_map`=="espirit". Default: 6.
Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
Default: 6.
espirit_crop: float, optional
Output eigenvalue cropping threshold when `type_of_map`=="espirit". Default: 0.95.
Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
Default: 0.95.
espirit_max_iters: int, optional
Power method iterations when `type_of_map`=="espirit". Default: 30.
Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
"""
super().__init__()
self.backward_operator = backward_operator
self.kspace_key = kspace_key
if type_of_map not in ["unit", "rss_estimate", "espirit"]:
raise ValueError(
f"Expected type of map to be either `unit`, `rss_estimate`, `espirit`. Got {type_of_map}."
)
self.type_of_map = type_of_map

# RSS estimate attributes
self.gaussian_sigma = gaussian_sigma
# Espirit attributes
if type_of_map == "espirit":
if type_of_map == SensitivityMapType.ESPIRIT:
self.espirit_calibrator = EspiritCalibration(
backward_operator,
espirit_threshold,
Expand Down Expand Up @@ -917,7 +916,7 @@ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
sample: dict[str, Any]
Sample with key "sensitivity_map" with value the estimated sensitivity map.
"""
if self.type_of_map == "unit":
if self.type_of_map == SensitivityMapType.UNIT:
kspace = sample[self.kspace_key]
sensitivity_map = torch.zeros(kspace.shape).float()
# Assumes complex channel is last
Expand All @@ -926,7 +925,7 @@ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
# Shape (coil, height, width, complex=2)
sensitivity_map = sensitivity_map.to(kspace.device)

elif self.type_of_map == "rss_estimate":
elif self.type_of_map == SensitivityMapType.RSS_ESTIMATE:
# Shape (batch, coil, height, width, complex=2)
acs_image = self.estimate_acs_image(sample)
# Shape (batch, height, width)
Expand Down Expand Up @@ -1736,13 +1735,14 @@ def build_supervised_mri_transforms(
sensitivity_maps_gaussian : float
Optional sigma for gaussian weighting of sensitivity map.
sensitivity_maps_espirit_threshold : float, optional
Threshold for the calibration matrix when `type_of_map` is equal to "espirit". Default: 0.05.
Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
Default: 0.05.
sensitivity_maps_espirit_kernel_size : int, optional
Kernel size for the calibration matrix when `type_of_map` is equal to "espirit". Default: 6.
Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6.
sensitivity_maps_espirit_crop : float, optional
Output eigenvalue cropping threshold when `type_of_map` is equal to "espirit". Default: 0.95.
Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95.
sensitivity_maps_espirit_max_iters : int, optional
Power method iterations when `type_of_map` is equal to "espirit". Default: 30.
Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
delete_acs_mask : bool
If True will delete key `acs_mask`. Default: True.
delete_kspace : bool
Expand Down Expand Up @@ -1971,13 +1971,14 @@ def build_mri_transforms(
sensitivity_maps_gaussian : float
Optional sigma for gaussian weighting of sensitivity map.
sensitivity_maps_espirit_threshold : float, optional
Threshold for the calibration matrix when `type_of_map` is equal to "espirit". Default: 0.05.
Threshold for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`.
Default: 0.05.
sensitivity_maps_espirit_kernel_size : int, optional
Kernel size for the calibration matrix when `type_of_map` is equal to "espirit". Default: 6.
Kernel size for the calibration matrix when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 6.
sensitivity_maps_espirit_crop : float, optional
Output eigenvalue cropping threshold when `type_of_map` is equal to "espirit". Default: 0.95.
Output eigenvalue cropping threshold when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 0.95.
sensitivity_maps_espirit_max_iters : int, optional
Power method iterations when `type_of_map` is equal to "espirit". Default: 30.
Power method iterations when `type_of_map` is set to `SensitivityMapType.ESPIRIT`. Default: 30.
delete_acs_mask : bool
If True will delete key `acs_mask`. Default: True.
delete_kspace : bool
Expand Down

0 comments on commit 0d81470

Please sign in to comment.