Skip to content

Commit

Permalink
Code quality fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Apr 18, 2024
1 parent daa28c5 commit 83ab7fb
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
5 changes: 2 additions & 3 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2055,12 +2055,11 @@ def build_mri_transforms(
use_seed=use_seed,
).transforms

mri_transforms += [AddBooleanKeysModule(["is_ssl"], [not transforms_type == TranformsType.SUPERVISED])]
mri_transforms += [AddBooleanKeysModule(["is_ssl"], [transforms_type != TranformsType.SUPERVISED])]

if transforms_type == TranformsType.SUPERVISED:
return Compose(mri_transforms)

if transforms_type == TranformsType.SSL_SSDU:
elif transforms_type == TranformsType.SSL_SSDU:
mask_splitter_kwargs = {
"ratio": mask_split_ratio,
"acs_region": mask_split_acs_region,
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/ssl/mri_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def log_first_training_example_and_model(self, data: dict[str, Any]) -> None:
"""
storage = get_event_storage()

self.logger.info(f"First case: slice_no: {data['slice_no'][0]}, filename: {data['filename'][0]}.")
self.logger.info("First case: slice_no: %s, filename: %s.", data["slice_no"][0], data["filename"][0])

if "input_sampling_mask" in data:
first_input_sampling_mask = data["input_sampling_mask"][0][0]
Expand Down
17 changes: 11 additions & 6 deletions direct/ssl/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,11 @@ def _gaussian_split(
center_x = nrow // 2
center_y = ncol // 2

if self.keep_acs and acs_mask is None:
raise ValueError("`keep_acs` is set to True but not received an input for `acs_mask`.")
mask = mask.clone() if not self.keep_acs else mask.clone() & (~acs_mask)
if self.keep_acs:
if acs_mask is None:
raise ValueError("`keep_acs` is set to True but not received an input for `acs_mask`.")
else:
mask = mask & (~acs_mask)

with temp_seed(self.rng, seed):
if seed is None:
Expand Down Expand Up @@ -273,9 +275,12 @@ def _uniform_split(
center_x = nrow // 2
center_y = ncol // 2

if self.keep_acs and acs_mask is None:
raise ValueError("`keep_acs` is set to True but not received an input for `acs_mask`.")
mask = mask.clone() if not self.keep_acs else mask.clone() & (~acs_mask)
if self.keep_acs:
if acs_mask is None:
raise ValueError("`keep_acs` is set to True but not received an input for `acs_mask`.")
else:
mask = mask & (~acs_mask)

temp_mask = mask.cpu().clone()

if not self.keep_acs:
Expand Down

0 comments on commit 83ab7fb

Please sign in to comment.