Skip to content

Commit

Permalink
Perform context unrolling inside declarations
Browse files Browse the repository at this point in the history
Unrolling the extra context for a declaration might depend on the
declaration's internals; that's typically the case with factory.Maybe:
the inner declarations might depend on the actual declaration used.

This adds `evaluate_pre` and a `evaluate_post` entrypoints to
declarations, more readable with regard to which build phase they are
used in.

Each of those will perform unrolling before calling the semi-public
actual function entrypoint (self.evaluate() for evaluate_pre,
self.call() for evaluate_post).

As a side effect, this fixes the issues with factory.Faker() when called
inside a factory.Maybe().

Closes #785 #786 #787 #788 #790 #796.
  • Loading branch information
rbarrois committed Dec 23, 2020
1 parent 15e11e7 commit e19142c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 35 deletions.
31 changes: 4 additions & 27 deletions factory/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@
)


PostGenerationContext = collections.namedtuple(
'PostGenerationContext',
['value_provided', 'value', 'extra'],
)


class DeclarationSet:
"""A set of declarations, including the recursive parameters.
Expand Down Expand Up @@ -274,21 +268,10 @@ def build(self, parent_step=None, force_sequence=None):
postgen_results = {}
for declaration_name in post.sorted():
declaration = post[declaration_name]
unrolled_context = declaration.declaration.unroll_context(
instance=instance,
step=step,
context=declaration.context,
)

postgen_context = PostGenerationContext(
value_provided='' in unrolled_context,
value=unrolled_context.get(''),
extra={k: v for k, v in unrolled_context.items() if k != ''},
)
postgen_results[declaration_name] = declaration.declaration.call(
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
instance=instance,
step=step,
context=postgen_context,
overrides=declaration.context,
)
self.factory_meta.use_postgeneration_results(
instance=instance,
Expand Down Expand Up @@ -358,16 +341,10 @@ def __getattr__(self, name):
if enums.get_builder_phase(value) == enums.BuilderPhase.ATTRIBUTE_RESOLUTION:
self.__pending.append(name)
try:
context = value.unroll_context(
instance=self,
step=self.__step,
context=declaration.context,
)

value = value.evaluate(
value = value.evaluate_pre(
instance=self,
step=self.__step,
extra=context,
overrides=declaration.context,
)
finally:
last = self.__pending.pop()
Expand Down
39 changes: 31 additions & 8 deletions factory/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import itertools
import logging
import typing as T

from . import enums, errors, utils

Expand Down Expand Up @@ -34,6 +35,10 @@ def unroll_context(self, instance, step, context):
subfactory = factory.base.DictFactory
return step.recurse(subfactory, context, force_sequence=step.sequence)

def evaluate_pre(self, instance, step, overrides):
context = self.unroll_context(instance, step, overrides)
return self.evaluate(instance, step, context)

def evaluate(self, instance, step, extra):
"""Evaluate this declaration.
Expand Down Expand Up @@ -477,36 +482,39 @@ def __init__(self, decider, yes_declaration=SKIP, no_declaration=SKIP):

self.FACTORY_BUILDER_PHASE = used_phases.pop() if used_phases else enums.BuilderPhase.ATTRIBUTE_RESOLUTION

def call(self, instance, step, context):
def evaluate_post(self, instance, step, overrides):
"""Handle post-generation declarations"""
decider_phase = enums.get_builder_phase(self.decider)
if decider_phase == enums.BuilderPhase.ATTRIBUTE_RESOLUTION:
# Note: we work on the *builder stub*, not on the actual instance.
# This gives us access to all Params-level definitions.
choice = self.decider.evaluate(instance=step.stub, step=step, extra=context.extra)
choice = self.decider.evaluate_pre(
instance=step.stub, step=step, overrides=overrides)
else:
assert decider_phase == enums.BuilderPhase.POST_INSTANTIATION
choice = self.decider.call(instance, step, context)
choice = self.decider.evaluate_post(
instance=instance, step=step, overrides={})

target = self.yes if choice else self.no
if enums.get_builder_phase(target) == enums.BuilderPhase.POST_INSTANTIATION:
return target.call(
return target.evaluate_post(
instance=instance,
step=step,
context=context,
overrides=overrides,
)
else:
# Flat value (can't be ATTRIBUTE_RESOLUTION, checked in __init__)
return target

def evaluate(self, instance, step, extra):
def evaluate_pre(self, instance, step, overrides):
choice = self.decider.evaluate(instance=instance, step=step, extra={})
target = self.yes if choice else self.no

if isinstance(target, BaseDeclaration):
return target.evaluate(
return target.evaluate_pre(
instance=instance,
step=step,
extra=extra,
overrides=overrides,
)
else:
# Flat value (can't be POST_INSTANTIATION, checked in __init__)
Expand Down Expand Up @@ -596,11 +604,26 @@ def __repr__(self):
# ===============


class PostGenerationContext(T.NamedTuple):
value_provided: bool
value: T.Any
extra: T.Dict[str, T.Any]


class PostGenerationDeclaration(BaseDeclaration):
"""Declarations to be called once the model object has been generated."""

FACTORY_BUILDER_PHASE = enums.BuilderPhase.POST_INSTANTIATION

def evaluate_post(self, instance, step, overrides):
context = self.unroll_context(instance, step, overrides)
postgen_context = PostGenerationContext(
value_provided=bool('' in context),
value=context.get(''),
extra={k: v for k, v in context.items() if k != ''},
)
return self.call(instance, step, postgen_context)

def call(self, instance, step, context): # pragma: no cover
"""Call this hook; no return value is expected.
Expand Down

0 comments on commit e19142c

Please sign in to comment.