Skip to content

Commit

Permalink
Refactor YAML loading to use add_representer
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Prescod committed Jun 6, 2022
1 parent 7844d2c commit 19944a0
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 48 deletions.
6 changes: 3 additions & 3 deletions snowfakery/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .data_gen_exceptions import DataGenError
from .plugins import SnowfakeryPlugin, PluginOption

from .utils.yaml_utils import SnowfakeryDumper, hydrate
from .utils.yaml_utils import SnowfakeryContinuationDumper, hydrate
from snowfakery.standard_plugins.UniqueId import UniqueId

# This tool is essentially a three stage interpreter.
Expand Down Expand Up @@ -95,9 +95,9 @@ def load_continuation_yaml(continuation_file: OpenFileLike):
def save_continuation_yaml(continuation_data: Globals, continuation_file: OpenFileLike):
"""Save the global interpreter state from Globals into a continuation_file"""
yaml.dump(
continuation_data.__getstate__(),
continuation_data,
continuation_file,
Dumper=SnowfakeryDumper,
Dumper=SnowfakeryContinuationDumper,
)


Expand Down
31 changes: 13 additions & 18 deletions snowfakery/data_generator_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import yaml

from .utils.template_utils import FakerTemplateLibrary
from .utils.yaml_utils import SnowfakeryDumper, hydrate
from .utils.yaml_utils import hydrate
from .row_history import RowHistory
from .template_funcs import StandardFuncs
from .data_gen_exceptions import DataGenSyntaxError, DataGenNameError
Expand All @@ -27,6 +27,7 @@
)
from snowfakery.plugins import PluginContext, SnowfakeryPlugin, ScalarTypes
from snowfakery.utils.collections import OrderedSet
from snowfakery.utils.yaml_utils import register_for_continuation

OutputStream = "snowfakery.output_streams.OutputStream"
VariableDefinition = "snowfakery.data_generator_runtime_object_model.VariableDefinition"
Expand Down Expand Up @@ -60,17 +61,15 @@ def generate_id(self, table_name: str) -> int:
def __getitem__(self, table_name: str) -> int:
return self.last_used_ids[table_name]

def __getstate__(self):
# TODO: Fix this to use the new convention of get_continuation_data
def get_continuation_state(self):
    """Return serializable continuation state: a plain-dict copy of the last-used id counters."""
    return {"last_used_ids": dict(self.last_used_ids)}

def __setstate__(self, state):
def restore_from_continuation(self, state):
    """Restore id counters from state produced by get_continuation_state."""
    # defaultdict so tables not seen in the previous run start counting at 0
    self.last_used_ids = defaultdict(lambda: 0, state["last_used_ids"])
    # the next run's ids begin just past the last id used for each table
    self.start_ids = {name: val + 1 for name, val in self.last_used_ids.items()}


SnowfakeryDumper.add_representer(defaultdict, SnowfakeryDumper.represent_dict)


class Dependency(NamedTuple):
table_name_from: str
table_name_to: str
Expand Down Expand Up @@ -195,29 +194,22 @@ def check_slots_filled(self):
def first_new_id(self, tablename):
return self.transients.first_new_id(tablename)

def __getstate__(self):
def serialize_dict_of_object_rows(dct):
return {k: v.__getstate__() for k, v in dct.items()}

persistent_nicknames = serialize_dict_of_object_rows(self.persistent_nicknames)
persistent_objects_by_table = serialize_dict_of_object_rows(
self.persistent_objects_by_table
)
def get_continuation_state(self):
    """Capture this object's interpreter state as a YAML-serializable dict.

    Nested objects (nicknames, object rows) are serialized by the
    representers registered on the continuation dumper, so they are
    included as-is rather than pre-converted.
    """
    # NOTE(review): diff residue duplicated the dict keys with both old
    # (__getstate__-based) and new values; only the post-refactor keys remain.
    intertable_dependencies = [
        dict(v._asdict()) for v in self.intertable_dependencies
    ]  # converts ordered-dict to dict for Python 3.6 and 3.7
    state = {
        "persistent_nicknames": self.persistent_nicknames,
        "persistent_objects_by_table": self.persistent_objects_by_table,
        "id_manager": self.id_manager.get_continuation_state(),
        "today": self.today,
        "nicknames_and_tables": self.nicknames_and_tables,
        "intertable_dependencies": intertable_dependencies,
    }
    return state

def __setstate__(self, state):
def restore_from_continuation(self, state):
def deserialize_dict_of_object_rows(dct):
return {k: hydrate(ObjectRow, v) for k, v in dct.items()}

Expand All @@ -244,6 +236,9 @@ def deserialize_dict_of_object_rows(dct):
self.reset_slots()


register_for_continuation(Globals, Globals.get_continuation_state)


class JinjaTemplateEvaluatorFactory:
def __init__(self, native_types: bool):
if native_types:
Expand Down
19 changes: 12 additions & 7 deletions snowfakery/object_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import yaml
import snowfakery # noqa
from .utils.yaml_utils import SnowfakeryDumper
from .utils.yaml_utils import register_for_continuation
from contextvars import ContextVar

IdManager = "snowfakery.data_generator_runtime.IdManager"
Expand All @@ -14,10 +14,6 @@ class ObjectRow:
Uses __getattr__ so that the template evaluator can use dot-notation."""

yaml_loader = yaml.SafeLoader
yaml_dumper = SnowfakeryDumper
yaml_tag = "!snowfakery_objectrow"

# be careful changing these slots because these objects must be serializable
# to YAML and JSON
__slots__ = ["_tablename", "_values", "_child_index"]
Expand Down Expand Up @@ -49,19 +45,28 @@ def __repr__(self):
except Exception:
return super().__repr__()

def __getstate__(self):
def get_continuation_state(self):
    """Get the state of this ObjectRow for continuation-file serialization.

    Related ObjectRows are deliberately excluded because circular
    references cause problems in serialization formats."""

    # If we decided to try to serialize hierarchies, we could
    # do it like this:
    # * keep track of whether an object has already been serialized using a
    #   property of the SnowfakeryContinuationDumper
    # * if so, output an ObjectReference instead of an ObjectRow
    values = {k: v for k, v in self._values.items() if not isinstance(v, ObjectRow)}
    return {"_tablename": self._tablename, "_values": values}

def __setstate__(self, state):
def restore_from_continuation(self, state):
    """Restore slot values saved by get_continuation_state."""
    # state keys match __slots__ names (e.g. _tablename, _values)
    for slot, value in state.items():
        setattr(self, slot, value)

register_for_continuation(ObjectRow, ObjectRow.get_continuation_state)


class ObjectReference(yaml.YAMLObject):
def __init__(self, tablename: str, id: int):
self._tablename = tablename
Expand Down
18 changes: 3 additions & 15 deletions snowfakery/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from functools import wraps
import typing as T

import yaml
from yaml.representer import Representer
from faker.providers import BaseProvider as FakerProvider
from dateutil.relativedelta import relativedelta

import snowfakery.data_gen_exceptions as exc
from .utils.yaml_utils import SnowfakeryDumper
from snowfakery.utils.yaml_utils import register_for_continuation
from .utils.collections import CaseInsensitiveDict

from numbers import Number
Expand Down Expand Up @@ -306,17 +304,7 @@ def _from_continuation(cls, args):

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
_register_for_continuation(cls)


def _register_for_continuation(cls):
SnowfakeryDumper.add_representer(cls, Representer.represent_object)
yaml.SafeLoader.add_constructor(
f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
lambda loader, node: cls._from_continuation(
loader.construct_mapping(node.value[0])
),
)
register_for_continuation(cls)


class PluginResultIterator(PluginResult):
Expand Down Expand Up @@ -372,4 +360,4 @@ def convert(self, value):


# round-trip PluginResult objects through continuation YAML if needed.
_register_for_continuation(PluginResult)
register_for_continuation(PluginResult)
4 changes: 2 additions & 2 deletions snowfakery/standard_plugins/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
memorable,
)
from snowfakery.utils.files import FileLike, open_file_like
from snowfakery.utils.yaml_utils import SnowfakeryDumper
from snowfakery.utils.yaml_utils import SnowfakeryContinuationDumper


def _open_db(db_url):
Expand Down Expand Up @@ -258,4 +258,4 @@ def chdir(path):
os.chdir(cwd)


SnowfakeryDumper.add_representer(quoted_name, Representer.represent_str)
SnowfakeryContinuationDumper.add_representer(quoted_name, Representer.represent_str)
39 changes: 36 additions & 3 deletions snowfakery/utils/yaml_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,44 @@
from yaml import SafeDumper
from typing import Callable
from yaml import SafeDumper, SafeLoader
from yaml.representer import Representer
from collections import defaultdict


class SnowfakeryDumper(SafeDumper):
class SnowfakeryContinuationDumper(SafeDumper):
    """SafeDumper subclass that continuation-aware representers are registered on."""

    pass


# dump defaultdicts (e.g. id counters) as plain YAML mappings
SnowfakeryContinuationDumper.add_representer(
    defaultdict, SnowfakeryContinuationDumper.represent_dict
)


def hydrate(cls, data):
    """Create an instance of `cls` from continuation state without calling __init__.

    The instance is allocated with __new__ and then populated via the class's
    restore_from_continuation hook.
    """
    obj = cls.__new__(cls)
    # NOTE(review): diff residue left the superseded obj.__setstate__(data)
    # call alongside its replacement; only the new hook call is kept so the
    # state is restored exactly once.
    obj.restore_from_continuation(data)
    return obj


# TODO: Evaluate whether it's cleaner for functions to bypass
# register_for_continuation and call
# SnowfakeryContinuationDumper.add_representer directly.


def represent_continuation(dumper: SnowfakeryContinuationDumper, data):
    """Represent continuation data for YAML output.

    Plain dicts become ordinary YAML mappings; anything else is represented
    as a tagged Python object via Representer.represent_object.
    """
    representer = (
        Representer.represent_dict
        if isinstance(data, dict)
        else Representer.represent_object
    )
    return representer(dumper, data)


def register_for_continuation(cls, dump_transformer: Callable = lambda x: x):
    """Register `cls` for round-tripping through continuation YAML.

    On dump, instances are first passed through `dump_transformer` (identity
    by default) and then serialized with represent_continuation. On load,
    SafeLoader reconstructs instances through cls._from_continuation.
    """

    def _represent(dumper, data):
        return represent_continuation(dumper, dump_transformer(data))

    SnowfakeryContinuationDumper.add_representer(cls, _represent)

    tag = f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}"

    def _construct(loader, node):
        # node is the sequence of "apply" arguments; the first (and only)
        # element is the mapping of continuation state.
        return cls._from_continuation(loader.construct_mapping(node.value[0]))

    SafeLoader.add_constructor(tag, _construct)

0 comments on commit 19944a0

Please sign in to comment.