From 864ab2eb570ea154a2c744b7e8d70e2092b00933 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Thu, 13 Nov 2025 15:26:01 -0800 Subject: [PATCH 1/3] Add `strip_logprobs` function --- src/art/utils/strip_logprobs.py | 48 +++++++ tests/unit/test_strip_logprobs.py | 218 ++++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+) create mode 100644 src/art/utils/strip_logprobs.py create mode 100644 tests/unit/test_strip_logprobs.py diff --git a/src/art/utils/strip_logprobs.py b/src/art/utils/strip_logprobs.py new file mode 100644 index 00000000..2eaefa50 --- /dev/null +++ b/src/art/utils/strip_logprobs.py @@ -0,0 +1,48 @@ +import copy +import logging +import sys +from typing import Any + +logger = logging.getLogger(__name__) + + +def strip_logprobs(obj: Any) -> Any: + """ + Recursively remove 'logprobs' keys from nested data structures. + + Args: + obj: Any nested data structure + + Returns: + The same structure with 'logprobs' keys removed, or the original + object if deepcopy fails + """ + + try: + copied_obj = copy.deepcopy(obj) + except Exception as e: + logger.warning( + f"Failed to deepcopy object in strip_logprobs: {e}. " + "Returning original object unchanged." + ) + return obj + + result = _strip_logprobs(copied_obj) + + return result + + +def _strip_logprobs(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: _strip_logprobs(v) for k, v in obj.items() if k != "logprobs"} + elif isinstance(obj, (list, tuple)): + result = [_strip_logprobs(v) for v in obj] + return tuple(result) if isinstance(obj, tuple) else result + elif hasattr(obj, "__dict__"): + for k, v in obj.__dict__.items(): + if k == "logprobs": + setattr(obj, k, None) + else: + setattr(obj, k, _strip_logprobs(v)) + return obj + return obj diff --git a/tests/unit/test_strip_logprobs.py b/tests/unit/test_strip_logprobs.py new file mode 100644 index 00000000..f6da0196 --- /dev/null +++ b/tests/unit/test_strip_logprobs.py @@ -0,0 +1,218 @@ +"""Tests for strip_logprobs utility function.""" + +import copy +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from art.utils.strip_logprobs import strip_logprobs + + +class TestStripLogprobs: + """Test suite for strip_logprobs function.""" + + def test_strip_dict_with_logprobs(self): + """Test stripping logprobs from dictionary.""" + input_dict = { + "data": "value", + "logprobs": [0.1, 0.2, 0.3], + "nested": {"key": "val", "logprobs": {"nested_log": 0.5}}, + } + expected = {"data": "value", "nested": {"key": "val"}} + + with patch("builtins.print"): # Suppress debug prints + result = strip_logprobs(input_dict) + + assert result == expected + assert input_dict["logprobs"] == [0.1, 0.2, 0.3] # Original unchanged + + def test_strip_nested_dict(self): + """Test stripping logprobs from deeply nested dictionaries.""" + input_dict = { + "level1": { + "level2": { + "level3": {"data": 1, "logprobs": "remove_me"}, + "logprobs": [1, 2, 3], + } + }, + "logprobs": None, + } + expected = {"level1": {"level2": {"level3": {"data": 1}}}} + + with patch("builtins.print"): + result = strip_logprobs(input_dict) + + assert result == expected + + def test_strip_list_with_logprobs(self): + """Test stripping logprobs from lists.""" + input_list = [ + {"item": 1, "logprobs": 0.1}, + {"item": 2, "logprobs": 0.2}, + {"item": 3}, + ] + expected = [{"item": 1}, {"item": 2}, {"item": 3}] + + with patch("builtins.print"): + result = strip_logprobs(input_list) + + assert result == expected + + def test_strip_tuple_with_logprobs(self): + """Test stripping logprobs from tuples.""" + input_tuple = ( + {"item": 1, "logprobs": 0.1}, + {"item": 2}, + {"nested": {"logprobs": "remove"}}, + ) + expected = ({"item": 1}, {"item": 2}, {"nested": {}}) + + with patch("builtins.print"): + result = strip_logprobs(input_tuple) + + assert result == expected + assert isinstance(result, tuple) + + def test_strip_object_with_logprobs(self): + """Test stripping logprobs from objects with __dict__.""" + + class TestObject: + def __init__(self): + self.data = "value" + self.logprobs = [0.1, 0.2] + self.nested = {"key": "val", "logprobs": "remove"} + + obj = TestObject() + with patch("builtins.print"): + result = strip_logprobs(obj) + + assert result.data == "value" + assert result.logprobs is None # Set to None for objects + assert result.nested == {"key": "val"} + + def test_strip_mixed_nested_structure(self): + """Test stripping logprobs from mixed nested structures.""" + input_data = { + "list": [ + {"logprobs": 1}, + [{"nested_list": True, "logprobs": 2}], + ], + "tuple": ({"logprobs": 3}, {"keep": "me"}), + "dict": {"nested": {"logprobs": 4, "data": "keep"}}, + } + expected = { + "list": [{}, [{"nested_list": True}]], + "tuple": ({}, {"keep": "me"}), + "dict": {"nested": {"data": "keep"}}, + } + + with patch("builtins.print"): + result = strip_logprobs(input_data) + + assert result == expected + + def test_strip_empty_structures(self): + """Test stripping logprobs from empty structures.""" + assert strip_logprobs({}) == {} + assert strip_logprobs([]) == [] + assert strip_logprobs(()) == () + + def test_strip_none_and_primitives(self): + """Test stripping logprobs from None and primitive values.""" + with patch("builtins.print"): + assert strip_logprobs(None) is None + assert strip_logprobs(42) == 42 + assert strip_logprobs("string") == "string" + assert strip_logprobs(3.14) == 3.14 + assert strip_logprobs(True) is True + + def test_no_logprobs_unchanged(self): + """Test that structures without logprobs remain unchanged.""" + input_dict = { + "data": "value", + "nested": {"key": "val"}, + "list": [1, 2, 3], + } + + with patch("builtins.print"): + result = strip_logprobs(input_dict) + + assert result == input_dict + + def test_deepcopy_behavior(self): + """Test that the function creates a deep copy.""" + nested_list = [1, 2, 3] + input_dict = { + "data": nested_list, + "logprobs": "remove", + } + + with patch("builtins.print"): + result = strip_logprobs(input_dict) + + result["data"].append(4) + assert nested_list == [1, 2, 3] # Original unchanged + assert result["data"] == [1, 2, 3, 4] + + def test_debug_output(self, capsys): + """Test that debug output is printed.""" + input_dict = {"test": "data", "logprobs": "remove"} + + strip_logprobs(input_dict) + captured = capsys.readouterr() + + assert "====== CALLED STRIP_LOGPROBS ======" in captured.out + assert "====== INPUT ======" in captured.out + assert "====== OUTPUT ======" in captured.out + assert "====== END STRIP_LOGPROBS ======" in captured.out + + def test_deepcopy_failure_returns_original(self, caplog): + """Test that deepcopy failure returns original object and logs warning.""" + + class UnCopyableObject: + def __init__(self): + self.data = "value" + self.logprobs = "should_remain" + + def __deepcopy__(self, memo): + raise RuntimeError("Cannot deepcopy this object") + + obj = UnCopyableObject() + + with caplog.at_level(logging.WARNING): + result = strip_logprobs(obj) + + # Should return the original object unchanged + assert result is obj + assert result.logprobs == "should_remain" + + # Check that warning was logged + assert len(caplog.records) == 1 + assert "Failed to deepcopy object in strip_logprobs" in caplog.text + assert "Cannot deepcopy this object" in caplog.text + assert "Returning original object unchanged" in caplog.text + + def test_deepcopy_failure_with_recursion_error(self, caplog): + """Test handling of RecursionError during deepcopy.""" + + class RecursiveObject: + def __init__(self): + self.data = "value" + self.logprobs = [1, 2, 3] + + def __deepcopy__(self, memo): + raise RecursionError("maximum recursion depth exceeded") + + obj = RecursiveObject() + + with caplog.at_level(logging.WARNING): + result = strip_logprobs(obj) + + # Should return the original object unchanged + assert result is obj + assert result.logprobs == [1, 2, 3] + + # Check that warning was logged + assert "Failed to deepcopy object in strip_logprobs" in caplog.text + assert "maximum recursion depth exceeded" in caplog.text From b24066d9c49b0e7333d9e25d0789653c98e67ad5 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Thu, 13 Nov 2025 15:27:57 -0800 Subject: [PATCH 2/3] Justify strip_logprobs --- src/art/utils/strip_logprobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/art/utils/strip_logprobs.py b/src/art/utils/strip_logprobs.py index 2eaefa50..2d5c2693 100644 --- a/src/art/utils/strip_logprobs.py +++ b/src/art/utils/strip_logprobs.py @@ -8,7 +8,7 @@ def strip_logprobs(obj: Any) -> Any: """ - Recursively remove 'logprobs' keys from nested data structures. + Recursively remove 'logprobs' keys from nested data structures to reduce data storage size. Args: obj: Any nested data structure From fb2de71d0d85b75bd5ff8a4a435942c106e8714e Mon Sep 17 00:00:00 2001 From: arcticfly Date: Thu, 13 Nov 2025 15:46:13 -0800 Subject: [PATCH 3/3] Fix test_strip_logprobs --- tests/unit/test_strip_logprobs.py | 49 +++++++++---------------------- 1 file changed, 14 insertions(+), 35 deletions(-) diff --git a/tests/unit/test_strip_logprobs.py b/tests/unit/test_strip_logprobs.py index f6da0196..1879acb2 100644 --- a/tests/unit/test_strip_logprobs.py +++ b/tests/unit/test_strip_logprobs.py @@ -2,7 +2,7 @@ import copy import logging -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -21,8 +21,7 @@ def test_strip_dict_with_logprobs(self): } expected = {"data": "value", "nested": {"key": "val"}} - with patch("builtins.print"): # Suppress debug prints - result = strip_logprobs(input_dict) + result = strip_logprobs(input_dict) assert result == expected assert input_dict["logprobs"] == [0.1, 0.2, 0.3] # Original unchanged @@ -40,8 +39,7 @@ def test_strip_nested_dict(self): } expected = {"level1": {"level2": {"level3": {"data": 1}}}} - with patch("builtins.print"): - result = strip_logprobs(input_dict) + result = strip_logprobs(input_dict) assert result == expected @@ -54,8 +52,7 @@ def test_strip_list_with_logprobs(self): ] expected = [{"item": 1}, {"item": 2}, {"item": 3}] - with patch("builtins.print"): - result = strip_logprobs(input_list) + result = strip_logprobs(input_list) assert result == expected @@ -68,8 +65,7 @@ def test_strip_tuple_with_logprobs(self): ) expected = ({"item": 1}, {"item": 2}, {"nested": {}}) - with patch("builtins.print"): - result = strip_logprobs(input_tuple) + result = strip_logprobs(input_tuple) assert result == expected assert isinstance(result, tuple) @@ -84,8 +80,7 @@ def __init__(self): self.nested = {"key": "val", "logprobs": "remove"} obj = TestObject() - with patch("builtins.print"): - result = strip_logprobs(obj) + result = strip_logprobs(obj) assert result.data == "value" assert result.logprobs is None # Set to None for objects @@ -107,8 +102,7 @@ def test_strip_mixed_nested_structure(self): "dict": {"nested": {"data": "keep"}}, } - with patch("builtins.print"): - result = strip_logprobs(input_data) + result = strip_logprobs(input_data) assert result == expected @@ -120,12 +114,11 @@ def test_strip_empty_structures(self): def test_strip_none_and_primitives(self): """Test stripping logprobs from None and primitive values.""" - with patch("builtins.print"): - assert strip_logprobs(None) is None - assert strip_logprobs(42) == 42 - assert strip_logprobs("string") == "string" - assert strip_logprobs(3.14) == 3.14 - assert strip_logprobs(True) is True + assert strip_logprobs(None) is None + assert strip_logprobs(42) == 42 + assert strip_logprobs("string") == "string" + assert strip_logprobs(3.14) == 3.14 + assert strip_logprobs(True) is True def test_no_logprobs_unchanged(self): """Test that structures without logprobs remain unchanged.""" @@ -135,8 +128,7 @@ def test_no_logprobs_unchanged(self): "list": [1, 2, 3], } - with patch("builtins.print"): - result = strip_logprobs(input_dict) + result = strip_logprobs(input_dict) assert result == input_dict @@ -148,25 +140,12 @@ def test_deepcopy_behavior(self): "logprobs": "remove", } - with patch("builtins.print"): - result = strip_logprobs(input_dict) + result = strip_logprobs(input_dict) result["data"].append(4) assert nested_list == [1, 2, 3] # Original unchanged assert result["data"] == [1, 2, 3, 4] - def test_debug_output(self, capsys): - """Test that debug output is printed.""" - input_dict = {"test": "data", "logprobs": "remove"} - - strip_logprobs(input_dict) - captured = capsys.readouterr() - - assert "====== CALLED STRIP_LOGPROBS ======" in captured.out - assert "====== INPUT ======" in captured.out - assert "====== OUTPUT ======" in captured.out - assert "====== END STRIP_LOGPROBS ======" in captured.out - def test_deepcopy_failure_returns_original(self, caplog): """Test that deepcopy failure returns original object and logs warning."""