diff --git a/monai/utils/nvtx.py b/monai/utils/nvtx.py index 2dfbd03529..1980ceef71 100644 --- a/monai/utils/nvtx.py +++ b/monai/utils/nvtx.py @@ -92,7 +92,7 @@ def _decorate_method(self, obj, method, append_method_name): name = self.name # Get the class for special functions - if method.startswith("_"): + if method.startswith("__"): owner = type(obj) else: owner = obj @@ -109,7 +109,16 @@ def range_wrapper(*args, **kwargs): return output # Replace the method with the wrapped version - setattr(owner, method, range_wrapper) + if method.startswith("__"): + # If it is a special method, it requires special attention + class NVTXRangeDecoratedClass(owner): + ... + + setattr(NVTXRangeDecoratedClass, method, range_wrapper) + obj.__class__ = NVTXRangeDecoratedClass + + else: + setattr(owner, method, range_wrapper) def _get_method(self, obj: Any) -> tuple: if isinstance(obj, Module):