Skip to content

Commit

Permalink
Sanitize argument-free object params before logging (#19771)
Browse files Browse the repository at this point in the history
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
  • Loading branch information
V0XNIHILI and awaelchli committed Jun 6, 2024
1 parent a611de0 commit 4f96c83
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added sanitization for classes before logging them as hyperparameters ([#19771](https://github.com/Lightning-AI/pytorch-lightning/pull/19771))

- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))

- Added the ability to explicitly mark forward methods in Fabric via `_FabricModule.mark_forward_method()` ([#19690](https://github.com/Lightning-AI/pytorch-lightning/pull/19690))
Expand Down
7 changes: 6 additions & 1 deletion src/lightning/fabric/utilities/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
from argparse import Namespace
from dataclasses import asdict, is_dataclass
Expand Down Expand Up @@ -52,8 +54,11 @@ def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""

def _sanitize_callable(val: Any) -> Any:
# Give them one chance to return a value. Don't go rabbit hole of recursive call
if inspect.isclass(val):
# If it's a class, don't try to instantiate it, just return the name
return val.__name__
if callable(val):
# Callables get a chance to return a name
try:
_val = val()
if callable(_val):
Expand Down
14 changes: 13 additions & 1 deletion tests/tests_fabric/utilities/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class B:


def test_sanitize_callable_params():
"""Callback function are not serializiable.
"""Callback functions are not serializable.
Therefore, we get them a chance to return something and if the returned type is not accepted, return None.
Expand All @@ -104,11 +104,21 @@ def return_something():
def wrapper_something():
return return_something

class ClassNoArgs:
def __init__(self):
pass

class ClassWithCall:
def __call__(self):
return "name"

params = Namespace(
foo="bar",
something=return_something,
wrapper_something_wo_name=(lambda: lambda: "1"),
wrapper_something=wrapper_something,
class_no_args=ClassNoArgs,
class_with_call=ClassWithCall,
)

params = _convert_params(params)
Expand All @@ -118,6 +128,8 @@ def wrapper_something():
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"
assert params["class_no_args"] == "ClassNoArgs"
assert params["class_with_call"] == "ClassWithCall"


def test_sanitize_params():
Expand Down

0 comments on commit 4f96c83

Please sign in to comment.