Add drop_missing_schema feature to capture_sample_data. #3586

Closed
wants to merge 7 commits
3 changes: 2 additions & 1 deletion cumulusci/core/datasets.py
@@ -162,6 +162,7 @@ def extract(
         extraction_definition: T.Optional[Path] = None,
         opt_in_only: T.Sequence[str] = (),
         loading_rules_file: T.Optional[Path] = None,
+        drop_missing_schema: bool = False,
     ):
         options = options or {}
         logger = logger or DEFAULT_LOGGER
@@ -177,6 +178,7 @@
             org_config=self.org_config,
             sql_path=self.data_file,
             mapping=str(extract_mapping),
+            drop_missing_schema=drop_missing_schema,
         )
         task()
         loading_rules = self._parse_loading_rules_file(loading_rules_file)
@@ -233,7 +235,6 @@ def load(
         )
 
     def _sql_dataload(self, options: T.Dict):
-
         task = _make_task(
             LoadData,
             project_config=self.project_config,
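For orientation, here is a hedged sketch of how a caller reaches the new parameter through Dataset.extract. It is not code from the PR: the Dataset constructor arguments (project_config, sf, org_config, schema) are inferred from the test assertions further down, and the logger is a placeholder.

    import logging

    logger = logging.getLogger(__name__)

    # Dataset is a context manager; extract() now forwards drop_missing_schema
    # through _make_task to the underlying ExtractData task.
    with Dataset("default", project_config, sf, org_config, schema) as dataset:
        dataset.extract(
            {},      # options
            logger,
            drop_missing_schema=True,  # warn about and skip unknown objects/fields
        )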
@@ -1,4 +1,5 @@
 import typing as T
+from logging import Logger, getLogger
 
 from cumulusci.salesforce_api.filterable_objects import NOT_EXTRACTABLE
 from cumulusci.salesforce_api.org_schema import Schema
@@ -8,6 +9,8 @@
     synthesize_declaration_for_sobject,
 )
 
+DEFAULT_LOGGER = getLogger(__file__)
+
 
 class SObjDependency(T.NamedTuple):
     table_name_from: str
@@ -39,34 +42,48 @@ def _collect_dependencies_for_sobject(
     fields: T.List[str],
     schema: Schema,
     only_required_fields: bool,
+    logger: Logger = DEFAULT_LOGGER,
 ) -> T.Dict[str, T.List[SObjDependency]]:
     """Ensure that required lookups are fulfilled for a single SObject
 
     Do this by adding its referent tables (in full) to the extract.
     Module-internal helper function.
     """
     dependencies = {}
+    obj_info = schema[source_sfobject]
 
     for field_name in fields:
-        field_info = schema[source_sfobject].fields[field_name]
-        if not field_info.createable:  # pragma: no cover
-            continue
-        references = field_info.referenceTo
-        if len(references) == 1 and not references[0] == "RecordType":
-            target = references[0]
-
-            target_disallowed = target in NOT_EXTRACTABLE
-            field_disallowed = target_disallowed or not field_info.createable
-            field_allowed = not (only_required_fields or field_disallowed)
-            if field_info.requiredOnCreate or field_allowed:
-                dependencies.setdefault(source_sfobject, []).append(
-                    SObjDependency(
-                        source_sfobject, target, field_name, field_info.requiredOnCreate
-                    )
-                )
+        field_info = obj_info.fields.get(field_name)
+        if not field_info:
+            logger.warning(f"Cannot find field {field_name} in {obj_info.name}.")
+        if field_info and field_info.createable:
+            dep = dependency_for_field(
+                field_info, only_required_fields, source_sfobject, field_name
+            )
+            if dep:
+                sobjdeps = dependencies.setdefault(source_sfobject, [])
+                sobjdeps.append(dep)
 
     return dependencies
 
 
+def dependency_for_field(
+    field_info, only_required_fields, source_sfobject, field_name
+) -> T.Optional[SObjDependency]:
+    references = field_info.referenceTo
+    if len(references) == 1 and not references[0] == "RecordType":
+        target = references[0]
+
+        target_disallowed = target in NOT_EXTRACTABLE
+        field_disallowed = target_disallowed or not field_info.createable
+        field_allowed = not (only_required_fields or field_disallowed)
+        if field_info.requiredOnCreate or field_allowed:
+            return SObjDependency(
+                source_sfobject, target, field_name, field_info.requiredOnCreate
+            )
+    return None
+
+
 def extend_declarations_to_include_referenced_tables(
     decl_list: T.Sequence[SimplifiedExtractDeclaration], schema: Schema
 ) -> T.Sequence[SimplifiedExtractDeclaration]:
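Because dependency_for_field is now a pure helper, its classification logic can be illustrated in isolation. A minimal sketch, assuming a hypothetical field stub carrying the three attributes the function reads (referenceTo, createable, requiredOnCreate):

    from types import SimpleNamespace

    # Hypothetical stand-in for a schema Field: a required lookup to Account.
    account_lookup = SimpleNamespace(
        referenceTo=["Account"], createable=True, requiredOnCreate=True
    )

    dep = dependency_for_field(account_lookup, True, "Contact", "AccountId")
    # Even with only_required_fields=True, a required lookup yields a dependency:
    assert dep == SObjDependency("Contact", "Account", "AccountId", True)

A non-required lookup under only_required_fields=True would return None instead, which the caller simply skips.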
@@ -1,6 +1,7 @@
 import collections
 import re
 import typing as T
+from logging import Logger, getLogger
 
 from pydantic import validator
 
@@ -10,6 +11,8 @@
 from .extract_yml import ExtractDeclaration, SFFieldGroupTypes, SFObjectGroupTypes
 from .hardcoded_default_declarations import DEFAULT_DECLARATIONS
 
+DEFAULT_LOGGER = getLogger(__file__)
+
 
 class SimplifiedExtractDeclaration(ExtractDeclaration):
     # a model where sf_object references a single sf_object
@@ -70,14 +73,24 @@ def flatten_declarations(
 
 
 def _simplify_sfobject_declarations(
-    declarations, schema: Schema, opt_in_only: T.Sequence[str]
+    declarations: T.Iterable[ExtractDeclaration],
+    schema: Schema,
+    opt_in_only: T.Sequence[str],
+    logger: T.Optional[Logger] = DEFAULT_LOGGER,
 ) -> T.List[SimplifiedExtractDeclaration]:
     """Generate a new list of declarations such that all sf_object patterns
     (like OBJECTS(CUSTOM)) have been expanded into many declarations
     with specific names and defaults have been merged in."""
     atomic_declarations, group_declarations = partition(
         lambda d: d.is_group, declarations
     )
+    missing_objs, atomic_declarations = partition(
+        lambda d: d.sf_object in schema.keys(), declarations
+    )
+    if missing_objs:
+        logger.warning(
+            f"Cannot find objects: {','.join(o.sf_object for o in missing_objs)}"
+        )
     atomic_declarations = list(atomic_declarations)
     normalized_atomic_declarations = _normalize_user_supplied_simple_declarations(
         atomic_declarations, DEFAULT_DECLARATIONS
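The warning branch above relies on partition returning the items that fail the predicate first (the more_itertools convention), which is why the declarations whose sf_object is missing from the schema come out on the left. A quick illustration under that assumption:

    from more_itertools import partition

    odd, even = partition(lambda n: n % 2 == 0, [1, 2, 3, 4])
    # Items failing the predicate come first:
    assert (list(odd), list(even)) == ([1, 3], [2, 4])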
@@ -315,7 +315,7 @@ def test_required_lookups__pulled_in(self, org_config):
             )
         )
 
-    def test_parse_real_file(self, cumulusci_test_repo_root, org_config):
+    def test_parse_real_file(self, cumulusci_test_repo_root, org_config, caplog):
         declarations = ExtractRulesFile.parse_extract(
             cumulusci_test_repo_root / "datasets/test_minimal.extract.yml"
         )
@@ -338,11 +338,15 @@ def test_parse_real_file(self, cumulusci_test_repo_root, org_config):
             include_counts=True,
         ) as schema:
             decls = flatten_declarations(declarations.values(), schema)
+        logs = str(caplog.record_tuples)
+        assert "MissingFieldShouldWarn" in logs
+        assert "MissingObjectShouldWarn__c" in logs
         decls = {decl.sf_object: decl for decl in decls}
         assert decls["Opportunity"].fields == [
             "Name",
             "ContactId",
             "AccountId",
+            "MissingFieldShouldWarn",
             "CloseDate",  # pull these in because they're required
             "StageName",
         ]
@@ -98,8 +98,8 @@ def is_lookup(field_name):
         # Record types are not treated as lookups.
         if field_name == "RecordTypeId":
             return False
-        schema_info_for_field = sobject_schema_info.fields[field_name]
-        target = schema_info_for_field.referenceTo
+        schema_info_for_field = sobject_schema_info.fields.get(field_name)
+        target = schema_info_for_field.referenceTo if schema_info_for_field else None
         return target
 
     simple_fields, lookups = partition(is_lookup, decl.fields)
15 changes: 14 additions & 1 deletion cumulusci/tasks/sample_data/capture_sample_data.py
@@ -3,6 +3,7 @@
 from cumulusci.core.config.org_config import OrgConfig
 from cumulusci.core.datasets import Dataset
 from cumulusci.core.exceptions import TaskOptionsError
+from cumulusci.core.utils import process_bool_arg
 from cumulusci.salesforce_api.filterable_objects import OPT_IN_ONLY
 from cumulusci.salesforce_api.org_schema import Filters, get_org_schema
 from cumulusci.tasks.salesforce.BaseSalesforceApiTask import BaseSalesforceApiTask
@@ -30,16 +31,23 @@ class CaptureSampleData(BaseSalesforceApiTask):
                 "Multiple files can be comma separated."
             )
         },
+        "drop_missing_schema": {
+            "description": "Set to True to skip any missing objects or fields instead of stopping with an error."
+        },
     }
 
     org_config: OrgConfig
 
     def _init_options(self, kwargs):
         super()._init_options(kwargs)
         self.options.setdefault("dataset", "default")
+        self.options["drop_missing_schema"] = process_bool_arg(
+            self.options.get("drop_missing_schema") or False
+        )
 
     def _run_task(self):
         name = self.options["dataset"]
+        drop_missing_schema = self.options["drop_missing_schema"]
         if extraction_definition := self.options.get("extraction_definition"):
             extraction_definition = Path(extraction_definition)
             if not extraction_definition.exists():
@@ -71,7 +79,12 @@ def _run_task(self):
             opt_in_only += OPT_IN_ONLY
 
         self.return_values = dataset.extract(
-            {}, self.logger, extraction_definition, opt_in_only, loading_rules
+            {},
+            self.logger,
+            extraction_definition,
+            opt_in_only,
+            loading_rules,
+            drop_missing_schema,
         )
         self.logger.info(f"{verb} dataset '{name}' in 'datasets/{name}'")
         return self.return_values
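At the task boundary, the new option behaves like other CumulusCI boolean options: process_bool_arg normalizes CLI- or YAML-supplied values into real booleans. A small sketch (the string inputs are illustrative):

    from cumulusci.core.utils import process_bool_arg

    # Option values can arrive as strings from the command line or cumulusci.yml.
    assert process_bool_arg("True") is True
    assert process_bool_arg("false") is False
    assert process_bool_arg(False) is False

So a run such as cci task run capture_sample_data -o drop_missing_schema True enables the skip-and-warn behavior.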
11 changes: 9 additions & 2 deletions cumulusci/tasks/sample_data/test_capture_sample_data.py
@@ -59,8 +59,9 @@ def test_simple_extract(
         # default dataset should be created
         Dataset.assert_any_call("default", mock.ANY, mock.ANY, org_config, mock.ANY)
         # and extracted
+        drop_missing_schema = False
         Dataset().__enter__().extract.assert_called_with(
-            {}, task.logger, None, mock.ANY, None
+            {}, task.logger, None, mock.ANY, None, drop_missing_schema
         )

@mock.patch("cumulusci.tasks.sample_data.capture_sample_data.Dataset")
@@ -86,6 +87,7 @@ def test_named_extract(
                 "dataset": "mydataset",
                 "extraction_definition": extraction_definition,
                 "loading_rules": loading_rules,
+                "drop_missing_schema": True,
             },
         )
         task()
@@ -96,7 +98,12 @@ def test_named_extract(
         Dataset().__enter__().create.assert_called_with()
         # and extracted
         Dataset().__enter__().extract.assert_called_with(
-            {}, task.logger, extraction_definition, mock.ANY, loading_rules
+            {},
+            task.logger,
+            extraction_definition,
+            mock.ANY,
+            loading_rules,
+            task.options["drop_missing_schema"],
         )
 
     @mock.patch("cumulusci.tasks.sample_data.capture_sample_data.Dataset")
3 changes: 3 additions & 0 deletions datasets/test_minimal.extract.yml
@@ -5,9 +5,12 @@ extract:
             - Name
             - ContactId
             - AccountId
+            - MissingFieldShouldWarn
     Contact: # Extract required Contact fields
         fields: FIELDS(REQUIRED)
     Account: # Extract required Account fields
         fields: FIELDS(REQUIRED)
     OBJECTS(CUSTOM): # Extract every field on Custom objects
         fields: FIELDS(ALL)
+    MissingObjectShouldWarn__c: # Filtered out because not in org
+        fields: FIELDS(ALL)