From b1e4a50d437ea96499aac709eaf50ce3f0e84232 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Wed, 26 Nov 2025 10:11:42 +0000 Subject: [PATCH 01/12] Perceptual loss changes. --- monai/losses/perceptual.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index ee653fac9d..3739e0e9f6 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -18,12 +18,16 @@ from monai.utils import optional_import from monai.utils.enums import StrEnum +from huggingface_hub import hf_hub_download LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") class PercetualNetworkType(StrEnum): + """Types of neural networks that are supported by perceptua loss. + """ + alex = "alex" vgg = "vgg" squeeze = "squeeze" @@ -108,9 +112,12 @@ def __init__( self.spatial_dims = spatial_dims self.perceptual_function: nn.Module + + # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualSimilarity( - net=network_type, verbose=False, channel_wise=channel_wise + net=network_type, verbose=False, channel_wise=channel_wise, + cache_dir=cache_dir ) elif "radimagenet_" in network_type: self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) @@ -122,7 +129,9 @@ def __init__( pretrained_state_dict_key=pretrained_state_dict_key, ) else: + # VGG, AlexNet and SqueezeNet are independently handled by LPIPS. self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False) + self.is_fake_3d = is_fake_3d self.fake_3d_ratio = fake_3d_ratio self.channel_wise = channel_wise @@ -194,7 +203,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): """ Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from - "Warvito/MedicalNet-models". + "Project-MONAI/perceptual-models". Args: net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} @@ -205,11 +214,12 @@ class MedicalNetPerceptualSimilarity(nn.Module): """ def __init__( - self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False + self, net: str = "medicalnet_resnet_10_23datasets", verbose: bool = False, channel_wise: bool = False, + cache_dir: str | None = None, ) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True - self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose) + self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir) self.eval() self.channel_wise = channel_wise @@ -287,7 +297,7 @@ class RadImageNetPerceptualSimilarity(nn.Module): """ Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class - uses torch Hub to download the networks from "Warvito/radimagenet-models". + uses torch Hub to download the networks from "Project-MONAI/perceptual-models". Args: net: {``"radimagenet_resnet50"``} @@ -295,9 +305,12 @@ class RadImageNetPerceptualSimilarity(nn.Module): verbose: if false, mute messages from torch Hub load function. """ - def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None: + def __init__(self, net: str = "radimagenet_resnet50", + verbose: bool = False, + cache_dir: str | None = None) -> None: super().__init__() - self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose) + self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, + cache_dir=cache_dir) self.eval() for param in self.parameters(): From fa0639be83e9ebd97ed494174a7891500fadb067 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Wed, 26 Nov 2025 10:35:03 +0000 Subject: [PATCH 02/12] Fixes --- monai/losses/perceptual.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 3739e0e9f6..9cbf078271 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -15,18 +15,17 @@ import torch import torch.nn as nn +from huggingface_hub import hf_hub_download from monai.utils import optional_import from monai.utils.enums import StrEnum -from huggingface_hub import hf_hub_download LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") class PercetualNetworkType(StrEnum): - """Types of neural networks that are supported by perceptua loss. - """ + """Types of neural networks that are supported by perceptua loss.""" alex = "alex" vgg = "vgg" @@ -116,8 +115,7 @@ def __init__( # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualSimilarity( - net=network_type, verbose=False, channel_wise=channel_wise, - cache_dir=cache_dir + net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir ) elif "radimagenet_" in network_type: self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) @@ -214,12 +212,17 @@ class MedicalNetPerceptualSimilarity(nn.Module): """ def __init__( - self, net: str = "medicalnet_resnet_10_23datasets", verbose: bool = False, channel_wise: bool = False, + self, + net: str = "medicalnet_resnet_10_23datasets", + verbose: bool = False, + channel_wise: bool = False, cache_dir: str | None = None, ) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True - self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir) + self.model = torch.hub.load( + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir + ) self.eval() self.channel_wise = channel_wise @@ -305,12 +308,9 @@ class RadImageNetPerceptualSimilarity(nn.Module): verbose: if false, mute messages from torch Hub load function. """ - def __init__(self, net: str = "radimagenet_resnet50", - verbose: bool = False, - cache_dir: str | None = None) -> None: + def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: super().__init__() - self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, - cache_dir=cache_dir) + self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir) self.eval() for param in self.parameters(): From 915de5fec45c87aafef2ef2ec36cea57a29900e3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Dec 2025 14:59:08 +0000 Subject: [PATCH 03/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/perceptual.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 9cbf078271..6696847853 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -15,7 +15,6 @@ import torch import torch.nn as nn -from huggingface_hub import hf_hub_download from monai.utils import optional_import from monai.utils.enums import StrEnum From 5594bfe47f217134c35d29ca4e93f15f20ef7fe7 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Wed, 26 Nov 2025 10:35:03 +0000 Subject: [PATCH 04/12] Unnecessary import --- monai/losses/perceptual.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 3739e0e9f6..6696847853 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -18,15 +18,13 @@ from monai.utils import optional_import from monai.utils.enums import StrEnum -from huggingface_hub import hf_hub_download LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") class PercetualNetworkType(StrEnum): - """Types of neural networks that are supported by perceptua loss. - """ + """Types of neural networks that are supported by perceptua loss.""" alex = "alex" vgg = "vgg" @@ -116,8 +114,7 @@ def __init__( # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualSimilarity( - net=network_type, verbose=False, channel_wise=channel_wise, - cache_dir=cache_dir + net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir ) elif "radimagenet_" in network_type: self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) @@ -214,12 +211,17 @@ class MedicalNetPerceptualSimilarity(nn.Module): """ def __init__( - self, net: str = "medicalnet_resnet_10_23datasets", verbose: bool = False, channel_wise: bool = False, + self, + net: str = "medicalnet_resnet_10_23datasets", + verbose: bool = False, + channel_wise: bool = False, cache_dir: str | None = None, ) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True - self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir) + self.model = torch.hub.load( + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir + ) self.eval() self.channel_wise = channel_wise @@ -305,12 +307,9 @@ class RadImageNetPerceptualSimilarity(nn.Module): verbose: if false, mute messages from torch Hub load function. """ - def __init__(self, net: str = "radimagenet_resnet50", - verbose: bool = False, - cache_dir: str | None = None) -> None: + def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: super().__init__() - self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, - cache_dir=cache_dir) + self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir) self.eval() for param in self.parameters(): From c99e16eb050848681d83e211205e8eaa998606c4 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Fri, 5 Dec 2025 17:01:50 +0000 Subject: [PATCH 05/12] Add check of network name --- monai/losses/perceptual.py | 41 ++++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 6696847853..6de5d8bad4 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -17,12 +17,21 @@ import torch.nn as nn from monai.utils import optional_import + from monai.utils.enums import StrEnum +# Valid model name to download from the repository +HF_MONAI_MODELS = ( + "medicalnet_resnet10_23datasets", + "medicalnet_resnet50_23datasets", + "radimagenet_resnet50", +) + LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") + class PercetualNetworkType(StrEnum): """Types of neural networks that are supported by perceptua loss.""" @@ -86,13 +95,18 @@ def __init__( if spatial_dims not in [2, 3]: raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") - if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: - raise ValueError( - "MedicalNet networks are only compatible with ``spatial_dims=3``." - "Argument is_fake_3d must be set to False." - ) - if channel_wise and "medicalnet_" not in network_type: + # Strict validation for MedicalNet + if "medicalnet_" in network_type: + if spatial_dims == 2 or is_fake_3d: + raise ValueError( + "MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False." + ) + if not channel_wise: + warnings.warn("MedicalNet networks support channel-wise loss. Consider setting channel_wise=True.") + + # Channel-wise only for MedicalNet + elif channel_wise: raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") if network_type.lower() not in list(PercetualNetworkType): @@ -219,8 +233,14 @@ def __init__( ) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True + if net not in HF_MONAI_MODELS: + raise ValueError( + f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." + ) + self.model = torch.hub.load( - "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, + trust_repo=True, ) self.eval() @@ -309,7 +329,12 @@ class RadImageNetPerceptualSimilarity(nn.Module): def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: super().__init__() - self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir) + if net not in HF_MONAI_MODELS: + raise ValueError( + f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." + ) + self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir, + trust_repo=True) self.eval() for param in self.parameters(): From 717b99be246f4776f7bad8ba220738fa6d142d44 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Fri, 5 Dec 2025 17:32:51 +0000 Subject: [PATCH 06/12] Update monai/losses/perceptual.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/losses/perceptual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 6de5d8bad4..9a04ffa776 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -33,7 +33,7 @@ class PercetualNetworkType(StrEnum): - """Types of neural networks that are supported by perceptua loss.""" + """Types of neural networks that are supported by perceptual loss.""" alex = "alex" vgg = "vgg" From b276f3ce309181540000e0ecfdeb3b54960b5066 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Fri, 5 Dec 2025 17:34:17 +0000 Subject: [PATCH 07/12] Update monai/losses/perceptual.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/losses/perceptual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 9a04ffa776..efb3fa3c63 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -226,7 +226,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): def __init__( self, - net: str = "medicalnet_resnet_10_23datasets", + net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False, cache_dir: str | None = None, From 2156b847234954c498906ffe7b128442a119ed00 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Fri, 5 Dec 2025 17:35:10 +0000 Subject: [PATCH 08/12] Update monai/losses/perceptual.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/losses/perceptual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index efb3fa3c63..d31788c633 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -333,7 +333,7 @@ def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cac raise ValueError( f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." ) - self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir, + self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True) self.eval() From e2b982ea1b9514cba028e62a249cac7ba7910002 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com> Date: Tue, 9 Dec 2025 15:40:24 +0000 Subject: [PATCH 09/12] Update monai/losses/perceptual.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com> --- monai/losses/perceptual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index d31788c633..847988ac11 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -131,7 +131,7 @@ def __init__( net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir ) elif "radimagenet_" in network_type: - self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) + self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False, cache_dir=cache_dir) elif network_type == "resnet50": self.perceptual_function = TorchvisionModelPerceptualSimilarity( net=network_type, From e3be8de4cd7c5b643d856d676875f755c08c84d4 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Tue, 9 Dec 2025 16:04:35 +0000 Subject: [PATCH 10/12] Bug --- monai/losses/perceptual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 6de5d8bad4..1c6d3f82d4 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -290,7 +290,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: for i in range(input.shape[1]): l_idx = i * feats_per_ch r_idx = (i + 1) * feats_per_ch - results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1) + results[:, i, ...] = feats_diff[:, l_idx : r_idx, ...].sum(dim=1) else: results = feats_diff.sum(dim=1, keepdim=True) From 6dfc2095985092b88750e7935d94d1e2e5857309 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Tue, 9 Dec 2025 16:10:53 +0000 Subject: [PATCH 11/12] DCO Remediation Commit for Virginia Fernandez I, Virginia Fernandez , hereby add my Signed-off-by to this commit: b1e4a50d437ea96499aac709eaf50ce3f0e84232 I, Virginia Fernandez , hereby add my Signed-off-by to this commit: fa0639be83e9ebd97ed494174a7891500fadb067 I, Virginia Fernandez , hereby add my Signed-off-by to this commit: 5594bfe47f217134c35d29ca4e93f15f20ef7fe7 I, Virginia Fernandez , hereby add my Signed-off-by to this commit: c99e16eb050848681d83e211205e8eaa998606c4 I, Virginia Fernandez , hereby add my Signed-off-by to this commit: e3be8de4cd7c5b643d856d676875f755c08c84d4 Signed-off-by: Virginia Fernandez --- monai/losses/perceptual.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index e2bd4d65ff..78cd3fc848 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -21,18 +21,18 @@ from monai.utils.enums import StrEnum # Valid model name to download from the repository -HF_MONAI_MODELS = ( +HF_MONAI_MODELS = frozenset(( "medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50", -) +)) LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") -class PercetualNetworkType(StrEnum): +class PerceptualNetworkType(StrEnum): """Types of neural networks that are supported by perceptual loss.""" alex = "alex" @@ -81,7 +81,7 @@ class PerceptualLoss(nn.Module): def __init__( self, spatial_dims: int, - network_type: str = PercetualNetworkType.alex, + network_type: str = PerceptualNetworkType.alex, is_fake_3d: bool = True, fake_3d_ratio: float = 0.5, cache_dir: str | None = None, @@ -103,16 +103,16 @@ def __init__( "MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False." ) if not channel_wise: - warnings.warn("MedicalNet networks support channel-wise loss. Consider setting channel_wise=True.") + warnings.warn("MedicalNet networks supp, ort channel-wise loss. Consider setting channel_wise=True.", stacklevel=2) # Channel-wise only for MedicalNet elif channel_wise: raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") - if network_type.lower() not in list(PercetualNetworkType): + if network_type.lower() not in list(PerceptualNetworkType): raise ValueError( "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" - % ", ".join(PercetualNetworkType) + % ", ".join(PerceptualNetworkType) ) if cache_dir: @@ -232,7 +232,6 @@ def __init__( cache_dir: str | None = None, ) -> None: super().__init__() - torch.hub._validate_not_a_forked_repo = lambda a, b, c: True if net not in HF_MONAI_MODELS: raise ValueError( f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." From d258390ffb6d65b711caea50ee665fa0132432a3 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Tue, 9 Dec 2025 16:18:11 +0000 Subject: [PATCH 12/12] Reformatting Signed-off-by: Virginia Fernandez --- monai/losses/perceptual.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 78cd3fc848..b2563aaf57 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -17,21 +17,17 @@ import torch.nn as nn from monai.utils import optional_import - from monai.utils.enums import StrEnum # Valid model name to download from the repository -HF_MONAI_MODELS = frozenset(( - "medicalnet_resnet10_23datasets", - "medicalnet_resnet50_23datasets", - "radimagenet_resnet50", -)) +HF_MONAI_MODELS = frozenset( + ("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50") +) LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") - class PerceptualNetworkType(StrEnum): """Types of neural networks that are supported by perceptual loss.""" @@ -95,7 +91,6 @@ def __init__( if spatial_dims not in [2, 3]: raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") - # Strict validation for MedicalNet if "medicalnet_" in network_type: if spatial_dims == 2 or is_fake_3d: @@ -103,7 +98,9 @@ def __init__( "MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False." ) if not channel_wise: - warnings.warn("MedicalNet networks supp, ort channel-wise loss. Consider setting channel_wise=True.", stacklevel=2) + warnings.warn( + "MedicalNet networks supp, ort channel-wise loss. Consider setting channel_wise=True.", stacklevel=2 + ) # Channel-wise only for MedicalNet elif channel_wise: @@ -131,7 +128,9 @@ def __init__( net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir ) elif "radimagenet_" in network_type: - self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False, cache_dir=cache_dir) + self.perceptual_function = RadImageNetPerceptualSimilarity( + net=network_type, verbose=False, cache_dir=cache_dir + ) elif network_type == "resnet50": self.perceptual_function = TorchvisionModelPerceptualSimilarity( net=network_type, @@ -233,13 +232,10 @@ def __init__( ) -> None: super().__init__() if net not in HF_MONAI_MODELS: - raise ValueError( - f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." - ) + raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.") self.model = torch.hub.load( - "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, - trust_repo=True, + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True ) self.eval() @@ -289,7 +285,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: for i in range(input.shape[1]): l_idx = i * feats_per_ch r_idx = (i + 1) * feats_per_ch - results[:, i, ...] = feats_diff[:, l_idx : r_idx, ...].sum(dim=1) + results[:, i, ...] = feats_diff[:, l_idx:r_idx, ...].sum(dim=1) else: results = feats_diff.sum(dim=1, keepdim=True) @@ -329,11 +325,10 @@ class RadImageNetPerceptualSimilarity(nn.Module): def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: super().__init__() if net not in HF_MONAI_MODELS: - raise ValueError( - f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}." + raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.") + self.model = torch.hub.load( + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True ) - self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, - trust_repo=True) self.eval() for param in self.parameters():