Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/art/utils/strip_logprobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import copy
import logging
import sys
from typing import Any

logger = logging.getLogger(__name__)


def strip_logprobs(obj: Any) -> Any:
    """
    Return a deep copy of ``obj`` with nested 'logprobs' entries stripped out.

    The input is copied with ``copy.deepcopy`` before anything is removed, and
    the copy is then handed to ``_strip_logprobs``, which does the actual
    stripping. Dropping this data reduces storage size.

    Args:
        obj: Any nested data structure

    Returns:
        The stripped copy. If ``copy.deepcopy`` raises, a warning is logged
        and ``obj`` itself is returned unmodified.
    """
    try:
        clone = copy.deepcopy(obj)
    except Exception as e:
        # Best effort: a payload that cannot be copied is passed through
        # untouched instead of being stripped.
        message = (
            f"Failed to deepcopy object in strip_logprobs: {e}. "
            "Returning original object unchanged."
        )
        logger.warning(message)
        return obj
    else:
        # Errors raised while stripping are deliberately not caught here.
        return _strip_logprobs(clone)


def _strip_logprobs(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: _strip_logprobs(v) for k, v in obj.items() if k != "logprobs"}
elif isinstance(obj, (list, tuple)):
result = [_strip_logprobs(v) for v in obj]
return tuple(result) if isinstance(obj, tuple) else result
elif hasattr(obj, "__dict__"):
for k, v in obj.__dict__.items():
if k == "logprobs":
setattr(obj, k, None)
else:
setattr(obj, k, _strip_logprobs(v))
return obj
return obj
197 changes: 197 additions & 0 deletions tests/unit/test_strip_logprobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""Tests for strip_logprobs utility function."""

import copy
import logging
from unittest.mock import MagicMock

import pytest

from art.utils.strip_logprobs import strip_logprobs


class TestStripLogprobs:
    """Behavioural checks for the strip_logprobs utility."""

    def test_strip_dict_with_logprobs(self):
        """'logprobs' keys vanish from a dict and from a dict nested in it."""
        payload = {
            "data": "value",
            "logprobs": [0.1, 0.2, 0.3],
            "nested": {"key": "val", "logprobs": {"nested_log": 0.5}},
        }

        stripped = strip_logprobs(payload)

        assert stripped == {"data": "value", "nested": {"key": "val"}}
        # The caller's dict must be left as it was.
        assert payload["logprobs"] == [0.1, 0.2, 0.3]

    def test_strip_nested_dict(self):
        """'logprobs' keys are removed at every nesting depth."""
        payload = {
            "level1": {
                "level2": {
                    "level3": {"data": 1, "logprobs": "remove_me"},
                    "logprobs": [1, 2, 3],
                }
            },
            "logprobs": None,
        }

        stripped = strip_logprobs(payload)

        assert stripped == {"level1": {"level2": {"level3": {"data": 1}}}}

    def test_strip_list_with_logprobs(self):
        """Dicts held in a list are stripped element by element."""
        rows = [
            {"item": 1, "logprobs": 0.1},
            {"item": 2, "logprobs": 0.2},
            {"item": 3},
        ]

        stripped = strip_logprobs(rows)

        assert stripped == [{"item": 1}, {"item": 2}, {"item": 3}]

    def test_strip_tuple_with_logprobs(self):
        """Tuples are processed and come back as tuples."""
        rows = (
            {"item": 1, "logprobs": 0.1},
            {"item": 2},
            {"nested": {"logprobs": "remove"}},
        )

        stripped = strip_logprobs(rows)

        assert isinstance(stripped, tuple)
        assert stripped == ({"item": 1}, {"item": 2}, {"nested": {}})

    def test_strip_object_with_logprobs(self):
        """Objects with a __dict__ get logprobs nulled, not deleted."""

        class Holder:
            def __init__(self):
                self.data = "value"
                self.logprobs = [0.1, 0.2]
                self.nested = {"key": "val", "logprobs": "remove"}

        stripped = strip_logprobs(Holder())

        assert stripped.data == "value"
        # On objects the attribute stays but is reset to None.
        assert stripped.logprobs is None
        assert stripped.nested == {"key": "val"}

    def test_strip_mixed_nested_structure(self):
        """Dicts, lists and tuples nested in one another are all handled."""
        payload = {
            "list": [
                {"logprobs": 1},
                [{"nested_list": True, "logprobs": 2}],
            ],
            "tuple": ({"logprobs": 3}, {"keep": "me"}),
            "dict": {"nested": {"logprobs": 4, "data": "keep"}},
        }

        stripped = strip_logprobs(payload)

        assert stripped == {
            "list": [{}, [{"nested_list": True}]],
            "tuple": ({}, {"keep": "me"}),
            "dict": {"nested": {"data": "keep"}},
        }

    def test_strip_empty_structures(self):
        """Empty containers come back empty and of the same kind."""
        assert strip_logprobs({}) == {}
        assert strip_logprobs([]) == []
        assert strip_logprobs(()) == ()

    def test_strip_none_and_primitives(self):
        """None and scalar values pass straight through."""
        assert strip_logprobs(None) is None
        assert strip_logprobs(42) == 42
        assert strip_logprobs("string") == "string"
        assert strip_logprobs(3.14) == 3.14
        assert strip_logprobs(True) is True

    def test_no_logprobs_unchanged(self):
        """Input without any logprobs compares equal to the output."""
        payload = {
            "data": "value",
            "nested": {"key": "val"},
            "list": [1, 2, 3],
        }

        stripped = strip_logprobs(payload)

        assert stripped == payload

    def test_deepcopy_behavior(self):
        """Mutating the result must not leak back into the input."""
        shared_values = [1, 2, 3]
        payload = {
            "data": shared_values,
            "logprobs": "remove",
        }

        stripped = strip_logprobs(payload)
        stripped["data"].append(4)

        assert shared_values == [1, 2, 3]  # input list is untouched
        assert stripped["data"] == [1, 2, 3, 4]

    def test_deepcopy_failure_returns_original(self, caplog):
        """A failing deepcopy yields the original object plus one warning."""

        class UnCopyableObject:
            def __init__(self):
                self.data = "value"
                self.logprobs = "should_remain"

            def __deepcopy__(self, memo):
                raise RuntimeError("Cannot deepcopy this object")

        target = UnCopyableObject()

        with caplog.at_level(logging.WARNING):
            outcome = strip_logprobs(target)

        # Same object back, nothing stripped.
        assert outcome is target
        assert outcome.logprobs == "should_remain"

        # Exactly one warning, carrying the cause and the fallback notice.
        assert len(caplog.records) == 1
        assert "Failed to deepcopy object in strip_logprobs" in caplog.text
        assert "Cannot deepcopy this object" in caplog.text
        assert "Returning original object unchanged" in caplog.text

    def test_deepcopy_failure_with_recursion_error(self, caplog):
        """A RecursionError from deepcopy takes the same fallback path."""

        class RecursiveObject:
            def __init__(self):
                self.data = "value"
                self.logprobs = [1, 2, 3]

            def __deepcopy__(self, memo):
                raise RecursionError("maximum recursion depth exceeded")

        target = RecursiveObject()

        with caplog.at_level(logging.WARNING):
            outcome = strip_logprobs(target)

        # Same object back, nothing stripped.
        assert outcome is target
        assert outcome.logprobs == [1, 2, 3]

        # The warning names the failure and includes the error text.
        assert "Failed to deepcopy object in strip_logprobs" in caplog.text
        assert "maximum recursion depth exceeded" in caplog.text