Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def main():
pass

if __name__ == '__main__':
obj_watch = objwatch.watch(['distributed_module.py'], ranks=[0, 1, 2, 3], output='./dist.log, simple=False)
obj_watch = objwatch.watch(['distributed_module.py'], ranks=[0, 1, 2, 3], output='./dist.log', simple=False)
main()
obj_watch.stop()
```
Expand Down
43 changes: 43 additions & 0 deletions objwatch/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# MIT License
# Copyright (c) 2025 aeeeeeep


from enum import Enum
from types import FunctionType

try:
from types import NoneType # type: ignore
except ImportError:
NoneType = type(None) # type: ignore


class Constants:
"""
Constants class for managing magic values and configuration parameters in ObjWatch project.
"""

# Target processing related constants
MAX_TARGETS_DISPLAY = 8 # Maximum number of targets to display before truncation

# Sequence formatting related constants
MAX_SEQUENCE_ELEMENTS = 3 # Maximum number of elements to display when formatting sequences

# Logging related constants
LOG_INDENT_LEVEL = 2 # Default indentation level for JSON serialization

# Log element types
# Define types that are directly loggable
LOG_ELEMENT_TYPES = (
bool,
int,
float,
str,
NoneType,
FunctionType,
Enum,
)

# Log sequence types
# Define sequence types for logging
LOG_SEQUENCE_TYPES = (list, set, dict, tuple)
LOG_INDENT_LEVEL = 2
42 changes: 12 additions & 30 deletions objwatch/event_handls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,12 @@
import signal
import atexit
import xml.etree.ElementTree as ET
from enum import Enum
from types import FunctionType
from typing import Any, Optional

try:
from types import NoneType # type: ignore
except ImportError:
NoneType = type(None) # type: ignore

from typing import Any, Dict, Optional, Type
from .utils.logger import log_error, log_debug, log_warn, log_info
from .constants import Constants
from .events import EventType


# Define types that are directly loggable
log_element_types = (
bool,
int,
float,
str,
NoneType,
FunctionType,
Enum,
)

# Define sequence types for logging
log_sequence_types = (list, set, dict, tuple)
from .utils.logger import log_error, log_debug, log_warn, log_info


class EventHandls:
Expand Down Expand Up @@ -273,7 +253,9 @@ def determine_change_type(self, old_value_len: int, current_value_len: int) -> O
return None

@staticmethod
def format_sequence(seq: Any, max_elements: int = 3, func: Optional[FunctionType] = None) -> str:
def format_sequence(
seq: Any, max_elements: int = Constants.MAX_SEQUENCE_ELEMENTS, func: Optional[FunctionType] = None
) -> str:
"""
Format a sequence to display a limited number of elements.

Expand All @@ -290,21 +272,21 @@ def format_sequence(seq: Any, max_elements: int = 3, func: Optional[FunctionType
return f'({type(seq).__name__})[]'
display: Optional[list] = None
if isinstance(seq, list):
if all(isinstance(x, log_element_types) for x in seq[:max_elements]):
if all(isinstance(x, Constants.LOG_ELEMENT_TYPES) for x in seq[:max_elements]):
display = seq[:max_elements]
elif func is not None:
display = func(seq[:max_elements])
elif isinstance(seq, (set, tuple)):
seq_list = list(seq)[:max_elements]
if all(isinstance(x, log_element_types) for x in seq_list):
if all(isinstance(x, Constants.LOG_ELEMENT_TYPES) for x in seq_list):
display = seq_list
elif func is not None:
display = func(seq_list)
elif isinstance(seq, dict):
seq_keys = list(seq.keys())[:max_elements]
seq_values = list(seq.values())[:max_elements]
if all(isinstance(x, log_element_types) for x in seq_keys) and all(
isinstance(x, log_element_types) for x in seq_values
if all(isinstance(x, Constants.LOG_ELEMENT_TYPES) for x in seq_keys) and all(
isinstance(x, Constants.LOG_ELEMENT_TYPES) for x in seq_values
):
display = list(seq.items())[:max_elements]
elif func is not None:
Expand Down Expand Up @@ -333,9 +315,9 @@ def _format_value(value: Any) -> str:
Returns:
str: The formatted value string.
"""
if isinstance(value, log_element_types):
if isinstance(value, Constants.LOG_ELEMENT_TYPES):
return f"{value}"
elif isinstance(value, log_sequence_types):
elif isinstance(value, Constants.LOG_SEQUENCE_TYPES):
return EventHandls.format_sequence(value)
else:
try:
Expand Down
3 changes: 2 additions & 1 deletion objwatch/mp_handls.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# MIT License
# Copyright (c) 2025 aeeeeeep

from typing import Callable, Optional, Union
from types import FunctionType
from typing import Callable, Optional, Union

from .utils.logger import log_error, log_info


Expand Down
7 changes: 4 additions & 3 deletions objwatch/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import ast
import json
import inspect
import pkgutil
import importlib
import importlib.util
import pkgutil
from types import ModuleType, MethodType, FunctionType
from typing import Optional, Tuple, List, Union, Set

from .constants import Constants
from .utils.logger import log_error, log_warn

ClassType = type
Expand Down Expand Up @@ -534,7 +535,7 @@ def get_processed_targets(self) -> dict:
"""
return self.processed_targets

def serialize_targets(self, indent=2):
def serialize_targets(self, indent=Constants.LOG_INDENT_LEVEL):
"""Serialize objects that JSON cannot handle by default.

Converts sets to lists, and other objects to their __dict__ or string representation.
Expand All @@ -555,7 +556,7 @@ def target_handler(o):
return o.__dict__
return str(o)

if len(self.processed_targets) > 8:
if len(self.processed_targets) > Constants.MAX_TARGETS_DISPLAY:
truncated_obj = {key: "..." for key in self.processed_targets.keys()}
truncated_obj["Warning: too many top-level keys, only showing values like"] = "..."
return json.dumps(truncated_obj, indent=indent, default=target_handler)
Expand Down
17 changes: 9 additions & 8 deletions objwatch/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from types import FrameType
from typing import Optional, Any, Dict, Set

from .constants import Constants
from .config import ObjWatchConfig
from .targets import Targets
from .wrappers import ABCWrapper
from .events import EventType
from .event_handls import EventHandls, log_sequence_types
from .event_handls import EventHandls
from .mp_handls import MPHandls
from .utils.logger import log_info, log_debug, log_warn, log_error
from .utils.weak import WeakIdKeyDictionary
from .utils.logger import log_info, log_debug, log_warn, log_error


class Tracer:
Expand Down Expand Up @@ -402,7 +403,7 @@ def _update_objects_lens(self, frame: FrameType) -> None:
if obj not in self.tracked_objects_lens:
self.tracked_objects_lens[obj] = {}
for k, v in attrs.items():
if isinstance(v, log_sequence_types):
if isinstance(v, Constants.LOG_SEQUENCE_TYPES):
self.tracked_objects_lens[obj][k] = len(v)

def _get_function_info(self, frame: FrameType) -> dict:
Expand Down Expand Up @@ -530,7 +531,7 @@ def _track_object_change(self, frame: FrameType, lineno: int):

old_value = old_attrs.get(key, None)
old_value_len = old_attrs_lens.get(key, None)
is_current_seq = isinstance(current_value, log_sequence_types)
is_current_seq = isinstance(current_value, Constants.LOG_SEQUENCE_TYPES)
current_value_len = len(current_value) if old_value_len is not None and is_current_seq else None

self._handle_change_type(
Expand Down Expand Up @@ -577,15 +578,15 @@ def _track_locals_change(self, frame: FrameType, lineno: int):
abc_wrapper=self.abc_wrapper,
)

if isinstance(current_local, log_sequence_types):
if isinstance(current_local, Constants.LOG_SEQUENCE_TYPES):
self.tracked_locals_lens[frame][var] = len(current_local)

common_vars = set(old_locals.keys()) & set(current_locals.keys())
for var in common_vars:
old_local = old_locals[var]
old_local_len = old_locals_lens.get(var, None)
current_local = current_locals[var]
is_current_seq = isinstance(current_local, log_sequence_types)
is_current_seq = isinstance(current_local, Constants.LOG_SEQUENCE_TYPES)
current_local_len = len(current_local) if old_local_len is not None and is_current_seq else None

self._handle_change_type(lineno, "_", var, old_local, current_local, old_local_len, current_local_len)
Expand Down Expand Up @@ -621,7 +622,7 @@ def _track_globals_change(self, frame: FrameType, lineno: int):

old_value = self.tracked_globals[module_name].get(key, None)
old_value_len = self.tracked_globals_lens[module_name].get(key, None)
is_current_seq = isinstance(current_value, log_sequence_types)
is_current_seq = isinstance(current_value, Constants.LOG_SEQUENCE_TYPES)
current_value_len = len(current_value) if old_value_len is not None and is_current_seq else None

self._handle_change_type(lineno, "@", key, old_value, current_value, old_value_len, current_value_len)
Expand Down Expand Up @@ -679,7 +680,7 @@ def trace_func(frame: FrameType, event: str, arg: Any):
self.tracked_locals[frame] = local_vars
self.tracked_locals_lens[frame] = {}
for var, value in local_vars.items():
if isinstance(value, log_sequence_types):
if isinstance(value, Constants.LOG_SEQUENCE_TYPES):
self.tracked_locals_lens[frame][var] = len(value)

return trace_func
Expand Down
13 changes: 7 additions & 6 deletions objwatch/wrappers/abc_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# MIT License
# Copyright (c) 2025 aeeeeeep

from abc import ABC, abstractmethod
from types import FrameType
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple
from abc import ABC, abstractmethod

from ..event_handls import log_element_types, log_sequence_types, EventHandls
from ..constants import Constants
from ..event_handls import EventHandls


class ABCWrapper(ABC):
Expand Down Expand Up @@ -109,9 +110,9 @@ def _format_value(self, value: Any, is_return: bool = False) -> str:
Returns:
str: Formatted value string.
"""
if isinstance(value, log_element_types):
if isinstance(value, Constants.LOG_ELEMENT_TYPES):
formatted = f"{value}"
elif isinstance(value, log_sequence_types):
elif isinstance(value, Constants.LOG_SEQUENCE_TYPES):
formatted_sequence = EventHandls.format_sequence(value, func=self.format_sequence_func)
if formatted_sequence:
formatted = f"{formatted_sequence}"
Expand All @@ -124,7 +125,7 @@ def _format_value(self, value: Any, is_return: bool = False) -> str:
formatted = f"(type){type(value).__name__}"

if is_return:
if isinstance(value, log_sequence_types) and formatted:
if isinstance(value, Constants.LOG_SEQUENCE_TYPES) and formatted:
return f"[{formatted}]"
return f"{formatted}"
return formatted
Expand Down
3 changes: 2 additions & 1 deletion objwatch/wrappers/cpu_memory_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import psutil
from types import FrameType
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple

from .abc_wrapper import ABCWrapper


Expand Down
10 changes: 5 additions & 5 deletions objwatch/wrappers/tensor_shape_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from types import FrameType
from typing import Any, List, Optional, Tuple


from ..event_handls import log_element_types, log_sequence_types, EventHandls
from ..constants import Constants
from ..event_handls import EventHandls
from .abc_wrapper import ABCWrapper

try:
Expand Down Expand Up @@ -95,9 +95,9 @@ def _format_value(self, value: Any, is_return: bool = False) -> str:
"""
if torch is not None and isinstance(value, torch.Tensor):
formatted = f"{value.shape}"
elif isinstance(value, log_element_types):
elif isinstance(value, Constants.LOG_ELEMENT_TYPES):
formatted = f"{value}"
elif isinstance(value, log_sequence_types):
elif isinstance(value, Constants.LOG_SEQUENCE_TYPES):
formatted_sequence = EventHandls.format_sequence(value, func=self.format_sequence_func)
if formatted_sequence:
formatted = f"{formatted_sequence}"
Expand All @@ -112,7 +112,7 @@ def _format_value(self, value: Any, is_return: bool = False) -> str:
if is_return:
if isinstance(value, torch.Tensor):
return f"{value.shape}"
elif isinstance(value, log_sequence_types) and formatted:
elif isinstance(value, Constants.LOG_SEQUENCE_TYPES) and formatted:
return f"[{formatted}]"
return f"{formatted}"
return formatted
3 changes: 2 additions & 1 deletion objwatch/wrappers/torch_memory_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# Copyright (c) 2025 aeeeeeep

from types import FrameType
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple

from .abc_wrapper import ABCWrapper

try:
Expand Down
2 changes: 0 additions & 2 deletions requirements/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
-f https://download.pytorch.org/whl/torch_stable.html

psutil
pytest
mypy