Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-3301] Refactor DoFn validation & allow specifying main inputs. #10991

Merged
merged 3 commits into from
Mar 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 216 additions & 66 deletions sdks/go/pkg/beam/core/graph/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,57 @@ func (f *DoFn) RestrictionT() *reflect.Type {
// a KV or not based on the other signatures (unless we're more loose about which
// sideinputs are present). Bind should respect that.

type mainInputs int

// The following constants prefixed with "Main" represent valid numbers of DoFn
// main inputs for DoFn construction and validation.
const (
MainUnknown mainInputs = -1 // Number of inputs is unknown for DoFn validation.
MainSingle mainInputs = 1 // Number of inputs for single value elements.
MainKv mainInputs = 2 // Number of inputs for KV elements.
)

// config stores the optional configuration parameters to NewDoFn.
type config struct {
numMainIn mainInputs
}

func defaultConfig() *config {
return &config{
numMainIn: MainUnknown,
}
}

// NumMainInputs is an optional config to NewDoFn which specifies the number
// of main inputs to the DoFn being created, allowing for more complete
// validation. Valid inputs are the package constants of type mainInputs.
//
// Example usage:
// graph.NewDoFn(fn, graph.NumMainInputs(graph.MainKv))
func NumMainInputs(num mainInputs) func(*config) {
return func(cfg *config) {
cfg.numMainIn = num
}
}

// NewDoFn constructs a DoFn from the given value, if possible.
func NewDoFn(fn interface{}) (*DoFn, error) {
func NewDoFn(fn interface{}, options ...func(*config)) (*DoFn, error) {
ret, err := NewFn(fn)
if err != nil {
return nil, errors.WithContext(errors.Wrapf(err, "invalid DoFn"), "constructing DoFn")
}
return AsDoFn(ret)
cfg := defaultConfig()
for _, opt := range options {
opt(cfg)
}
return AsDoFn(ret, cfg.numMainIn)
}

// AsDoFn converts a Fn to a DoFn, if possible.
func AsDoFn(fn *Fn) (*DoFn, error) {
// AsDoFn converts a Fn to a DoFn, if possible. numMainIn specifies how many
// main inputs are expected in the DoFn's method signatures. Valid inputs are
// the package constants of type mainInputs. If that number is MainUnknown then
// validation is done by best effort and may miss some edge cases.
func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) {
addContext := func(err error, fn *Fn) error {
return errors.WithContextf(err, "graph.AsDoFn: for Fn named %v", fn.Name())
}
Expand All @@ -239,52 +279,50 @@ func AsDoFn(fn *Fn) (*DoFn, error) {
return nil, addContext(err, fn)
}

// Start validating DoFn. First, check that ProcessElement has a main input.
// Validate ProcessElement has correct number of main inputs (as indicated by
// numMainIn), and that main inputs are before side inputs.
processFn := fn.methods[processElementName]
pos, num, ok := processFn.Inputs()
if ok {
first := processFn.Param[pos].Kind
if first != funcx.FnValue {
err := errors.New("side input parameters must follow main input parameter")
err = errors.SetTopLevelMsgf(err,
"Method %v of DoFns should always have a main input before side inputs, "+
"but it has side inputs (as Iters or ReIters) first in DoFn %v.",
processElementName, fn.Name())
err = errors.WithContextf(err, "method %v", processElementName)
return nil, addContext(err, fn)
}
if err := validateMainInputs(fn, processFn, processElementName, numMainIn); err != nil {
return nil, addContext(err, fn)
}

// If numMainIn is unknown, we can try inferring it from the second input in ProcessElement.
// If there is none, or it's not a FnValue type, then we can safely infer that there's only
// one main input.
pos, num, _ := processFn.Inputs()
if numMainIn == MainUnknown && (num == 1 || processFn.Param[pos+1].Kind != funcx.FnValue) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it make sense to infer the number of inputs before validateMainInputs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validateMainInputs performs error checks we need to do before we can infer # of main inputs (stuff like making sure we have at least 1 input present). So moving this before validateMainInputs would just mean moving those error checks back above the inferring and nothing really changes.

numMainIn = MainSingle
}

// If the ProcessElement function includes side inputs or emit functions those must also be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR but why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's part of the API for start/finishBundle. I don't remember why it's done that way though. lostluck@ might be able to answer why when he gets back.

There might be room to make the side inputs/emits in start/finishBundle optional, but I believe right now it's mandatory (if we don't catch and throw an error here, it'll just break later on in translation or execution or something).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At most relaxed we'd be able to either not require them at all if none are used, or isolate them by their types. All instances of a given side input or emit with the same type would need to be listed at once, since otherwise we have no way to distinguish them except by position. Permitting Nothing to be set would be the most convenient, or permitting only the Side Inputs and not requireing the Emits.

For now though, it's better to be more strict now and relax later, since the inverse is impossible, and such variety is harder to maintain if unnecessary.

// present in the signatures of startBundle and finishBundle.
if ok && num > 1 {
if startFn, ok := fn.methods[startBundleName]; ok {
processFnInputs := processFn.Param[pos : pos+num]
if err := validateMethodInputs(processFnInputs, startFn, startBundleName); err != nil {
return nil, addContext(err, fn)
}
processFnInputs := processFn.Param[pos : pos+num]
if startFn, ok := fn.methods[startBundleName]; ok {
if err := validateSideInputs(processFnInputs, startFn, startBundleName, numMainIn); err != nil {
return nil, addContext(err, fn)
}
if finishFn, ok := fn.methods[finishBundleName]; ok {
processFnInputs := processFn.Param[pos : pos+num]
if err := validateMethodInputs(processFnInputs, finishFn, finishBundleName); err != nil {
return nil, addContext(err, fn)
}
}
if finishFn, ok := fn.methods[finishBundleName]; ok {
if err := validateSideInputs(processFnInputs, finishFn, finishBundleName, numMainIn); err != nil {
return nil, addContext(err, fn)
}
}

pos, num, ok = processFn.Emits()
pos, num, ok := processFn.Emits()
var processFnEmits []funcx.FnParam
if ok {
if startFn, ok := fn.methods[startBundleName]; ok {
processFnEmits := processFn.Param[pos : pos+num]
if err := validateMethodEmits(processFnEmits, startFn, startBundleName); err != nil {
return nil, addContext(err, fn)
}
processFnEmits = processFn.Param[pos : pos+num]
} else {
processFnEmits = processFn.Param[0:0]
}
if startFn, ok := fn.methods[startBundleName]; ok {
if err := validateEmits(processFnEmits, startFn, startBundleName); err != nil {
return nil, addContext(err, fn)
}
if finishFn, ok := fn.methods[finishBundleName]; ok {
processFnEmits := processFn.Param[pos : pos+num]
if err := validateMethodEmits(processFnEmits, finishFn, finishBundleName); err != nil {
return nil, addContext(err, fn)
}
}
if finishFn, ok := fn.methods[finishBundleName]; ok {
if err := validateEmits(processFnEmits, finishFn, finishBundleName); err != nil {
return nil, addContext(err, fn)
}
}

Expand Down Expand Up @@ -328,32 +366,86 @@ func AsDoFn(fn *Fn) (*DoFn, error) {
return (*DoFn)(fn), nil
}

// validateMethodEmits compares the emits found in a DoFn method signature with the emits found in
// validateMainInputs checks that a method has the given number of main inputs
// and that main inputs are before any side inputs.
func validateMainInputs(fn *Fn, method *funcx.Fn, methodName string, numMainIn mainInputs) error {
if numMainIn == MainUnknown {
numMainIn = MainSingle // If unknown, validate for minimum number of inputs.
}

// Make sure there are enough inputs (at least numMainIn)
pos, num, ok := method.Inputs()
if !ok {
err := errors.Errorf("%v method has no main inputs", methodName)
err = errors.SetTopLevelMsgf(err,
"Method %v in DoFn %v is missing all inputs. A main input is required.",
methodName, fn.Name())
return err
}
if num < int(numMainIn) {
err := errors.Errorf("%v method has too few main inputs", methodName)
err = errors.SetTopLevelMsgf(err,
"Method %v in DoFn %v does not have enough main inputs. "+
"%v main inputs were expected, but only %v inputs were found.",
methodName, fn.Name(), numMainIn, num)
return err
}

// Check that the first numMainIn inputs are not side inputs (Iters or
// ReIters). We aren't able to catch singleton side inputs here since
// they're indistinguishable from main inputs.
mainInputs := method.Param[pos : pos+int(numMainIn)]
for i, p := range mainInputs {
if p.Kind != funcx.FnValue {
err := errors.Errorf("expected main input parameter but found "+
"side input parameter in position %v",
pos+i)
err = errors.SetTopLevelMsgf(err,
"Method %v in DoFn %v should have all main inputs before side inputs, "+
"but a side input (as Iter or ReIter) appears as parameter %v when a "+
"main input was expected.",
methodName, fn.Name(), pos+i)
err = errors.WithContextf(err, "method %v", methodName)
return err
}
}
return nil
}

// validateEmits compares the emits found in a DoFn method signature with the emits found in
// the signature for ProcessElement, and performs validation that those match. This function
// should only be used to validate methods that are expected to have the same emit parameters as
// ProcessElement.
func validateMethodEmits(processFnEmits []funcx.FnParam, method *funcx.Fn, methodName string) error {
methodPos, methodNum, ok := method.Emits()
func validateEmits(processFnEmits []funcx.FnParam, method *funcx.Fn, methodName string) error {
posMethodEmits, numMethodEmits, ok := method.Emits()
numProcessEmits := len(processFnEmits)

// Handle cases where method has no emits.
if !ok {
if numProcessEmits == 0 { // We're good, expected no emits.
return nil
}
// Error, missing emits.
err := errors.Errorf("emit parameters expected in method %v", methodName)
return errors.SetTopLevelMsgf(err,
"Missing emit parameters in the %v method of a DoFn. "+
"If emit parameters are present in %v those parameters must also be present in %v.",
methodName, processElementName, methodName)
}

processFnNum := len(processFnEmits)
if methodNum != processFnNum {
// Error if number of emits doesn't match.
if numMethodEmits != numProcessEmits {
err := errors.Errorf("number of emits in method %v does not match method %v: got %d, expected %d",
methodName, processElementName, methodNum, processFnNum)
methodName, processElementName, numMethodEmits, numProcessEmits)
return errors.SetTopLevelMsgf(err,
"Incorrect number of emit parameters in the %v method of a DoFn. "+
"The emit parameters should match those of the %v method.",
methodName, processElementName)
}

methodEmits := method.Param[methodPos : methodPos+methodNum]
for i := 0; i < processFnNum; i++ {
// Error if there's a type mismatch.
methodEmits := method.Param[posMethodEmits : posMethodEmits+numMethodEmits]
for i := 0; i < numProcessEmits; i++ {
if processFnEmits[i].T != methodEmits[i].T {
var err error = &funcx.TypeMismatchError{Got: methodEmits[i].T, Want: processFnEmits[i].T}
err = errors.Wrapf(err, "emit parameter in method %v does not match emit parameter in %v",
Expand All @@ -368,19 +460,74 @@ func validateMethodEmits(processFnEmits []funcx.FnParam, method *funcx.Fn, metho
return nil
}

// validateMethodInputs compares the inputs found in a DoFn method signature with the inputs found
// validateSideInputs compares the inputs found in a DoFn method signature with the inputs found
// in the signature for ProcessElement, and performs validation to check that the side inputs
// match. This function should only be used to validate methods that are expected to have matching
// side inputs to ProcessElement.
func validateMethodInputs(processFnInputs []funcx.FnParam, method *funcx.Fn, methodName string) error {
methodPos, methodNum, ok := method.Inputs()
func validateSideInputs(processFnInputs []funcx.FnParam, method *funcx.Fn, methodName string, numMainIn mainInputs) error {
if numMainIn == MainUnknown {
return validateSideInputsNumUnknown(processFnInputs, method, methodName)
}

numProcessIn := len(processFnInputs)
numSideIn := numProcessIn - int(numMainIn)
posMethodIn, numMethodIn, ok := method.Inputs()

// Handle cases where method has no inputs.
if !ok {
if numSideIn == 0 { // We're good, expected no side inputs.
return nil
}
// Error, missing side inputs.
err := errors.Errorf("side inputs expected in method %v", methodName)
return errors.SetTopLevelMsgf(err,
"Missing side inputs in the %v method of a DoFn. "+
"If side inputs are present in %v those side inputs must also be present in %v.",
methodName, processElementName, methodName)
}

// Error if number of side inputs doesn't match.
if numMethodIn != numSideIn {
err := errors.Errorf("number of side inputs in method %v does not match method %v: got %d, expected %d",
methodName, processElementName, numMethodIn, numSideIn)
return errors.SetTopLevelMsgf(err,
"Incorrect number of side inputs in the %v method of a DoFn. "+
"The side inputs should match those of the %v method.",
methodName, processElementName)
}

// Error if there's a type mismatch.
methodInputs := method.Param[posMethodIn : posMethodIn+numMethodIn]
sideInputs := processFnInputs[numMainIn:] // Skip main inputs in ProcessFn
for i := 0; i < len(sideInputs); i++ {
if sideInputs[i].T != methodInputs[i].T {
var err error = &funcx.TypeMismatchError{Got: methodInputs[i].T, Want: sideInputs[i].T}
err = errors.Wrapf(err, "side input in method %v does not match side input in %v",
methodName, processElementName)
return errors.SetTopLevelMsgf(err,
"Incorrect side inputs in the %v method of a DoFn. "+
"The side inputs should match those of the %v method.",
methodName, processElementName)
}
}

return nil
}

// Note: The second input to ProcessElements is not guaranteed to be a side input (it could be
// the Value part of a KV main input). Since we can't know whether to interpret it as a main or
// side input, some of these checks have to work around it in specific ways.
// validateSideInputsNumUnknown does similar validation as validateSideInputs, but for an unknown
// number of main inputs.
func validateSideInputsNumUnknown(processFnInputs []funcx.FnParam, method *funcx.Fn, methodName string) error {
// Note: By the time this is called, we should have already know that ProcessElement has at
// least two inputs, and the second input is ambiguous (could be either a main input or side
// input). Since we don't know how to interpret the second input, these checks will be more
// permissive than they would be otherwise.
posMethodIn, numMethodIn, ok := method.Inputs()
numProcessIn := len(processFnInputs)

// Handle cases where method has no inputs.
if !ok {
if len(processFnInputs) <= 2 {
return nil // This case is fine, since both ProcessElement inputs may be main inputs.
if numProcessIn <= int(MainKv) {
return nil // We're good, possible for there to be no side inputs.
}
err := errors.Errorf("side inputs expected in method %v", methodName)
return errors.SetTopLevelMsgf(err,
Expand All @@ -389,24 +536,27 @@ func validateMethodInputs(processFnInputs []funcx.FnParam, method *funcx.Fn, met
methodName, processElementName, methodName)
}

processFnNum := len(processFnInputs)
// The number of side inputs is the number of inputs minus 1 or 2 depending on whether the second
// input is a main or side input, so that's what we expect in the method's parameters.
// Ex. if ProcessElement has 7 inputs, method must have either 5 or 6 inputs.
if (methodNum != processFnNum-1) && (methodNum != processFnNum-2) {
// Error if number of side inputs doesn't match any of the possible numbers of side inputs,
// defined below.
numSideInSingle := numProcessIn - int(MainSingle)
numSideInKv := numProcessIn - int(MainKv)
if numMethodIn != numSideInSingle && numMethodIn != numSideInKv {
err := errors.Errorf("number of side inputs in method %v does not match method %v: got %d, expected either %d or %d",
methodName, processElementName, methodNum, processFnNum-1, processFnNum-2)
methodName, processElementName, numMethodIn, numSideInSingle, numSideInKv)
return errors.SetTopLevelMsgf(err,
"Incorrect number of side inputs in the %v method of a DoFn. "+
"The side inputs should match those of the %v method.",
methodName, processElementName)
}

methodInputs := method.Param[methodPos : methodPos+methodNum]
offset := processFnNum - methodNum // We need an offset to skip the main inputs in ProcessFnInputs
for i := 0; i < methodNum; i++ {
if processFnInputs[i+offset].T != methodInputs[i].T {
var err error = &funcx.TypeMismatchError{Got: methodInputs[i].T, Want: processFnInputs[i+offset].T}
// Error if there's a type mismatch.
methodInputs := method.Param[posMethodIn : posMethodIn+numMethodIn]
// If there's N inputs in the method, then we compare with the last N inputs to processElement.
offset := numProcessIn - numMethodIn
sideInputs := processFnInputs[offset:]
for i := 0; i < numMethodIn; i++ {
if sideInputs[i].T != methodInputs[i].T {
var err error = &funcx.TypeMismatchError{Got: methodInputs[i].T, Want: sideInputs[i].T}
err = errors.Wrapf(err, "side input in method %v does not match side input in %v",
methodName, processElementName)
return errors.SetTopLevelMsgf(err,
Expand Down
Loading