Skip to content

Commit

Permalink
Close plugins during Interpreter clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Prescod committed Aug 1, 2022
1 parent 5e5747b commit 79f9f74
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 5 deletions.
4 changes: 4 additions & 0 deletions snowfakery/data_generator_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,10 @@ def __exit__(self, *args):
plugin.close()
except Exception as e:
warn(f"Could not close {plugin} because {e}")
self.current_context = None
self.plugin_instances = None
self.plugin_function_libraries = None
self.instance_states = None

def get_contextual_state(
self,
Expand Down
24 changes: 20 additions & 4 deletions snowfakery/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,16 +289,25 @@ def __init__(self, result: Mapping):

def __getattr__(self, name):
# ensures that it won't recurse
return self.__dict__["result"][name]
res = self.__dict__.get("result", {}).get(name, ...)
if res == ...:
raise AttributeError(name)

def __reduce__(self):
return (self.__class__, (dict(self.result),))

def __repr__(self):
return f"<{self.__class__} {repr(self.result)}>"
try:
rep = repr(self.result)
except Exception:
rep = ""
return f"<{self.__class__} {rep}>"

def __str__(self):
return str(self.result)
try:
return str(self.result)
except Exception:
return repr(self)

@classmethod
def _from_continuation(cls, args):
Expand All @@ -320,6 +329,8 @@ def _register_for_continuation(cls):


class PluginResultIterator(PluginResult):
closed = False

def __init__(self, repeat):
self.repeat = repeat

Expand Down Expand Up @@ -353,7 +364,12 @@ def restart(self):

def close(self):
"Subclasses should implement this if they need to clean up resources"
pass # pragma: no cover
pass

def __del__(self):
if not self.closed:
self.close()
self.close = True


class PluginOption:
Expand Down
70 changes: 69 additions & 1 deletion tests/test_custom_plugins_and_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
import operator
import time
from base64 import b64decode
import gc

from snowfakery import SnowfakeryPlugin, lazy
from snowfakery.plugins import PluginResult, PluginOption, memorable
from snowfakery.plugins import (
PluginResult,
PluginOption,
PluginResultIterator,
memorable,
)

from snowfakery.data_gen_exceptions import (
DataGenError,
Expand Down Expand Up @@ -40,6 +46,45 @@ class Functions:
def double(self, value):
return value * 2

@memorable
def fib(self, value):
return FibIterator()


class FibIterator(PluginResultIterator):
def __init__(self):
self.a = 1
self.b = 1

def close(self):
self.a = None
self.b = None

def next(self):
# import gc

# for ref in gc.get_referrers(self):
# print(ref, gc.get_referrers(ref))
# print(vars(self))
# print(self in gc.get_objects())
# print_referrers_bredth_first((self,), 0)
self.a, self.b = self.b, self.a + self.b
return self.b


def print_referrers_bredth_first(path, level):
if level > 3:
return
obj = path[0]
parents = tuple(gc.get_referrers(obj))
for parent in parents:
parts = (parent,) + path
print(" -> ".join(repr(part) for part in parts))

for parent in parents:
parts = (parent,) + path
print_referrers_bredth_first(parts, level + 1)


class DoubleVisionPlugin(SnowfakeryPlugin):
class Functions:
Expand Down Expand Up @@ -414,3 +459,26 @@ def test_plugin_does_not_close(self):
"""
with pytest.warns(UserWarning, match="close"):
generate_data(StringIO(yaml))


class TestPluginResultIterator:
@mock.patch(
"tests.test_custom_plugins_and_providers.FibIterator.close",
)
def test_plugin_result_iterator__closes(self, close, generated_rows):
yaml = """
- plugin: tests.test_custom_plugins_and_providers.SimpleTestPlugin
- object: OBJ
fields:
fibnum:
SimpleTestPlugin.fib:
"""
# it's better if closing of objects is triggered at a predictable
# time by the refcounter instead of by the cyclic GC
try:
gc.disable()
generate_data(StringIO(yaml), target_number=("OBJ", 3))
assert generated_rows.table_values("OBJ", field="fibnum") == [2, 3, 5]
assert len(close.mock_calls) == 1
finally:
gc.enable()

0 comments on commit 79f9f74

Please sign in to comment.