Merge branch 'master' into feature/reliability-improvements
prescod committed Apr 30, 2020
2 parents f9063ea + b7b39da commit afe55c2
Showing 7 changed files with 191 additions and 22 deletions.
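In brief: when a recipe assigns a RecordType field to an object, the mapping generator now emits a separate insert step for each record type, backed by a pseudo-table, and adds RecordType filters so the steps do not pick up each other's rows. A sketch of the net effect, reconstructed from the new tests at the end of this diff — the step names and filters are asserted there; the remaining keys follow mappings_from_sorted_tables and are abridged:

    # recipe
    - object: Obj
      fields:
        RecordType: Bar

    # generated mapping (abridged)
    Insert Obj:
      sf_object: Obj
      table: Obj
      filters:
        - RecordType is NULL
    Insert Bar:
      sf_object: Obj
      table: Obj
      record_type: Bar
      filters:
        - RecordType = 'Bar'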
4 changes: 2 additions & 2 deletions snowfakery/data_generator.py
@@ -35,11 +35,11 @@ class ExecutionSummary:

def __init__(self, parse_results, runtime_results):
self.tables = parse_results.tables
-self.dom = parse_results.templates
+self.templates = parse_results.templates
self.intertable_dependencies = runtime_results.intertable_dependencies

def summarize_for_debugging(self):
-return self.intertable_dependencies, self.dom
+return self.intertable_dependencies, self.templates


def merge_options(option_definitions: List, user_options: Mapping) -> Tuple[Dict, set]:
3 changes: 0 additions & 3 deletions snowfakery/data_generator_runtime_dom.py
@@ -131,9 +131,6 @@ def _generate_row(self, storage, context: RuntimeContext) -> ObjectRow:
self._generate_fields(context, row)

try:
-# both of these lines loop over the fields so they could maybe
-# be combined but it kind of messes with the modularity of the
-# code.
self.register_row_intertable_references(row, context)
if not self.tablename.startswith("__"):
storage.write_row(self.tablename, row)
111 changes: 101 additions & 10 deletions snowfakery/generate_mapping_from_factory.py
@@ -1,17 +1,80 @@
from copy import deepcopy
from warnings import warn

+from snowfakery.data_generator import ExecutionSummary
from snowfakery.data_gen_exceptions import DataGenNameError, DataGenError
+from snowfakery.parse_factory_yaml import TableInfo
+from snowfakery.data_generator_runtime import Dependency

-def mapping_from_factory_templates(summary):

+def mapping_from_factory_templates(summary: ExecutionSummary):
"""Use the outputs of the factory YAML and convert to Mapping.yml format"""
+record_types = generate_record_type_pseudo_tables(summary)
dependencies, reference_fields = build_dependencies(summary.intertable_dependencies)
-table_order = sort_dependencies(dependencies, summary.tables)
-mappings = mappings_from_sorted_tables(
-summary.tables, table_order, reference_fields
-)
+tables = {**summary.tables, **record_types}
+table_order = sort_dependencies(dependencies, tables)
+mappings = mappings_from_sorted_tables(tables, table_order, reference_fields,)
return mappings


def find_record_type_field(fields, context_name):
"""Verify that the RecordType field has the right capitalization and return it."""

# theoretically a custom object could have a legitimate field named record_type,
# but more likely it is a miscapitalization of RecordType, so the lower()/replace()
# check below also catches recordtype, record_type, Record_Type, etc. and flags them
record_type_fields = [
field
for field in fields
if field.name.lower().replace("d_t", "dt", 1) == "recordtype"
]
if not record_type_fields:
return None
elif len(record_type_fields) > 1:
raise DataGenError(f"Only one RecordType field allowed: {context_name}")
elif record_type_fields[0].name != "RecordType":
raise DataGenNameError(
"Recordtype field needs this capitalization: RecordType", None, None
)

return record_type_fields[0]


def generate_record_type_pseudo_tables(summary):
"""Generate fake table objects for purposes of dependency sorting, lookups and mapping generation"""
record_types = {}
for template in summary.templates:
real_table_name = template.tablename
record_type_field = find_record_type_field(template.fields, real_table_name)
if not record_type_field:
continue

real_table = summary.tables[real_table_name]
record_type_name = record_type_field.definition.definition

# generate a pretend table for purposes of dependency sorting,
# lookups and mapping generation
record_type_pseudo_table = record_types.setdefault(
record_type_name, TableInfo(template.tablename)
)
record_type_pseudo_table.register(template)
record_type_pseudo_table.record_type = record_type_name

# copy over the dependencies from the real table
for dependency in summary.intertable_dependencies.copy():
if dependency.table_name_from == real_table_name:
summary.intertable_dependencies.add(
Dependency(record_type_name, *dependency[1:])
)

# the RecordType field is no longer helpful in the real table's TableInfo;
# the conditional ensures it's only deleted once, even when several templates share the table
if real_table.fields.get("RecordType"):
del real_table.fields["RecordType"]
real_table.has_record_types = True

return record_types


def build_dependencies(intertable_dependencies):
"""Figure out which tables depend on which other ones (through foreign keys)
@@ -64,7 +127,9 @@ def sort_dependencies(dependencies, tables):
return sorted_tables


-def mappings_from_sorted_tables(tables, table_order, reference_fields):
+def mappings_from_sorted_tables(
+tables: dict, table_order: list, reference_fields: dict
+):
"""Generate mapping.yml data structures."""
mappings = {}
for table_name in table_order:
@@ -73,6 +138,7 @@ def mappings_from_sorted_tables(tables, table_order, reference_fields):
fieldname: fieldname
for fieldname, fielddef in table.fields.items()
if (table_name, fieldname) not in reference_fields.keys()
+and fieldname != "RecordType"
}
lookups = {
fieldname: {
@@ -82,11 +148,36 @@ def mappings_from_sorted_tables(tables, table_order, reference_fields):
for fieldname, fielddef in table.fields.items()
if (table_name, fieldname) in reference_fields.keys()
}
mappings[f"Insert {table_name}"] = {
"sf_object": table_name,
"table": table_name,

if "RecordType" in table.fields:
fielddef = table.fields["RecordType"].definition
if not getattr(fielddef, "definition"):
raise DataGenError(
"Record type definitions must be simple, not computed"
)
record_type = fielddef.definition
filters = [f"RecordType = '{record_type}'"]
else:
record_type = None
# add a filter to avoid returning rows associated with record types
filters = (
[f"RecordType is NULL"]
if getattr(table, "has_record_types", False)
else []
)

mapping = {
"sf_object": table.name,
"table": table.name,
"fields": fields,
"lookups": lookups,
}
if record_type:
mapping["record_type"] = record_type
if filters:
mapping["filters"] = filters
if lookups:
mapping["lookups"] = lookups

mappings[f"Insert {table_name}"] = mapping

return mappings
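A note on the dependency-copying loop in generate_record_type_pseudo_tables above: Dependency is tuple-like, and only its first field (table_name_from) is visible in this diff, so the remaining field names below are assumptions; the point is that Dependency(record_type_name, *dependency[1:]) re-points the source end of each dependency at the pseudo-table while preserving the rest:

    from typing import NamedTuple

    # hypothetical stand-in; only table_name_from is confirmed by the diff
    class Dependency(NamedTuple):
        table_name_from: str
        table_name_to: str  # assumed field name
        field_name: str  # assumed field name

    dep = Dependency("Account", "Contact", "primary_contact")
    # swap the source table for the record-type pseudo-table, keep the rest
    print(Dependency("Household", *dep[1:]))
    # Dependency(table_name_from='Household', table_name_to='Contact', field_name='primary_contact')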
15 changes: 9 additions & 6 deletions snowfakery/parse_factory_yaml.py
@@ -46,7 +46,8 @@ class TableInfo:
unifies what we know about it.
"""

-def __init__(self):
+def __init__(self, name):
+self.name = name
self.fields = {}
self.friends = {}
self._templates = []
@@ -74,7 +75,7 @@ def __init__(self):
self.plugins = []
self.line_numbers = {}
self.options = []
-self.object_infos = {}
+self.table_infos = {}

def line_num(self, obj=None) -> Dict:
if not obj:
@@ -118,8 +119,10 @@ def register_template(self, template: ObjectTemplate) -> None:
We register templates so we can get a list of all fields that can
be generated. This can be used to create a dynamic schema.
"""
-table_info = self.object_infos.get(template.tablename, None) or TableInfo()
-self.object_infos[template.tablename] = table_info
+table_info = self.table_infos.get(template.tablename, None) or TableInfo(
+template.tablename
+)
+self.table_infos[template.tablename] = table_info
table_info.register(template)


@@ -490,10 +493,10 @@ def parse_factory(stream: IO[str]) -> ParseResult:
context = ParseContext()
objects = parse_file(stream, context)
templates = parse_object_template_list(objects, context)
-tables = context.object_infos
+tables = context.table_infos
tables = {
name: value
-for name, value in context.object_infos.items()
+for name, value in context.table_infos.items()
if not name.startswith("__")
}

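Style note on the get-or-create in register_template above: dict.setdefault would express the same idea in one call. A hypothetical equivalent — the commit's get()/or form has the advantage of constructing a TableInfo only when the entry is actually missing:

    table_info = self.table_infos.setdefault(
        template.tablename, TableInfo(template.tablename)
    )
    table_info.register(template)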
2 changes: 2 additions & 0 deletions snowfakery/template_funcs.py
@@ -1,4 +1,5 @@
import random
+from functools import lru_cache
from datetime import date, datetime
import dateutil.parser
from ast import literal_eval
@@ -108,6 +109,7 @@ def choice_wrapper(
return probability or when, pick


+@lru_cache(maxsize=512)
def parse_date(d: Union[str, datetime, date]) -> Optional[Union[datetime, date]]:
if isinstance(d, (datetime, date)):
return d
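The @lru_cache(maxsize=512) added above memoizes parse_date, so generating many rows that reuse the same date strings parses each distinct value only once. A self-contained illustration of the standard functools behavior, using a simplified stand-in rather than the real implementation:

    from functools import lru_cache
    from datetime import date

    @lru_cache(maxsize=512)
    def parse_date(d: str) -> date:
        print("parsing", d)  # runs only on a cache miss
        return date.fromisoformat(d)

    parse_date("2020-04-30")  # miss: prints "parsing 2020-04-30"
    parse_date("2020-04-30")  # hit: returned from the cache, no print
    print(parse_date.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=512, currsize=1)

The one prerequisite is that all arguments be hashable, which holds for the str, date, and datetime inputs the real function accepts.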
30 changes: 30 additions & 0 deletions tests/cci/record_types.yml
@@ -0,0 +1,30 @@
- object: Account
fields:
name: Bluth Family
RecordType: Household
primary_contact:
- object: Contact
nickname: Michael
fields:
name: Michael Bluth
- object: Account
fields:
name: Bluth Corporation
primary_contact:
reference: Michael
RecordType: Organization
- object: Account
fields:
name: The Windors
primary_contact:
- object: Contact
nickname: Liz
fields:
name: The Queen
RecordType: Household
- object: Account
fields:
name: The Firm
primary_contact:
reference: Liz
RecordType: Organization
48 changes: 47 additions & 1 deletion tests/test_generate_mapping.py
@@ -141,7 +141,7 @@ def test_cats_and_dogs(self):
assert mapping["Insert PetFood"]["sf_object"] == "PetFood"
assert mapping["Insert PetFood"]["table"] == "PetFood"
assert mapping["Insert PetFood"]["fields"] == {}
-assert mapping["Insert PetFood"]["lookups"] == {}
+assert not mapping["Insert PetFood"].get("lookups")
assert mapping["Insert Animal"]["sf_object"] == "Animal"
assert mapping["Insert Animal"]["table"] == "Animal"
assert mapping["Insert Animal"]["fields"] == {}
@@ -226,3 +226,49 @@ def test_table_is_free_simple(self):
},
["Grandparent", "Parent"],
)


class TestRecordTypes:
def test_basic_recordtypes(self):
yaml = """
- object: Obj
fields:
RecordType: Bar
"""
summary = generate(StringIO(yaml), {}, None)
mapping = mapping_from_factory_templates(summary)

assert mapping["Insert Bar"]["filters"][0] == "RecordType = 'Bar'"
assert mapping["Insert Obj"]["filters"][0] == "RecordType is NULL"

def test_recordtype_errors_on_wrong_capitalization(self):
yaml = """
- object: Obj
fields:
recordtype: Bar
"""
summary = generate(StringIO(yaml), {}, None)
with pytest.raises(DataGenError):
mapping_from_factory_templates(summary)

yaml = """
- object: Obj
fields:
record_type: Bar
"""
summary = generate(StringIO(yaml), {}, None)
with pytest.raises(DataGenError):
mapping_from_factory_templates(summary)

def test_recordtypes_and_lookups(self):
yaml = """
- object: Obj
fields:
RecordType: Bar
child:
- object: Child
"""
summary = generate(StringIO(yaml), {}, None)
mapping = mapping_from_factory_templates(summary)

assert mapping["Insert Bar"]["lookups"]["child"]["key_field"] == "child"
