diff --git a/CHANGELOG.md b/CHANGELOG.md index 7036ace1..194791e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,8 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - general: replace `${KEY}` in the TOML conf with the `$KEY` env var [#24](https://github.com/AdRoll/baker/pull/24) - input: add KCL input. [#36](https://github.com/AdRoll/baker/pull/36) - filter: add RegexMatch filter. [#37](https://github.com/AdRoll/baker/pull/37) -- filter: add Concatenate filter [#33](https://github.com/AdRoll/baker/pull/33) - filter: add NotNull filter [#43](https://github.com/AdRoll/baker/pull/43) +- filter: add Concatenate filter [#28](https://github.com/AdRoll/baker/pull/33) +- Required configuration fields are now handled by Baker rather than by each component. [#41](https://github.com/AdRoll/baker/pull/41) ### Changed diff --git a/config.go b/config.go index 489c7957..c2d1b029 100644 --- a/config.go +++ b/config.go @@ -211,6 +211,49 @@ func replaceEnvVars(f io.Reader, mapper func(string) string) (io.Reader, error) return strings.NewReader(os.Expand(buf.String(), mapper)), nil } +func decodeAndCheckConfig(md toml.MetaData, compCfg interface{}) error { + var ( + cfg *toml.Primitive // config + dcfg interface{} // decoded config + name string // component name + typ string // component type + ) + + switch t := compCfg.(type) { + case ConfigInput: + cfg, dcfg = t.Config, t.DecodedConfig + name, typ = t.Name, "input" + case ConfigFilter: + cfg, dcfg = t.Config, t.DecodedConfig + name, typ = t.Name, "filter" + case ConfigOutput: + cfg, dcfg = t.Config, t.DecodedConfig + name, typ = t.Name, "output" + case ConfigUpload: + cfg, dcfg = t.Config, t.DecodedConfig + name, typ = t.Name, "upload" + case ConfigMetrics: + cfg, dcfg = t.Config, t.DecodedConfig + name, typ = t.Name, "metrics" + default: + panic(fmt.Sprintf("unexpected type %#v", cfg)) + } + + if req := CheckRequiredFields(dcfg); req != "" { + return fmt.Errorf("%s %q: %w", typ, name, ErrorRequiredField{req}) + } + + if cfg == nil { + return nil + } + + if err := md.PrimitiveDecode(*cfg, dcfg); err != nil { + return fmt.Errorf("%s %q: error parsing config: %v", typ, name, err) + } + + return nil +} + // NewConfigFromToml creates a Config from a reader reading from a TOML // configuration. comp describes all the existing components. func NewConfigFromToml(f io.Reader, comp Components) (*Config, error) { @@ -287,44 +330,34 @@ func NewConfigFromToml(f io.Reader, comp Components) (*Config, error) { // Copy custom configuration structure, to prepare for re-reading cfg.Input.DecodedConfig = cfg.Input.desc.Config - if cfg.Input.Config != nil { - if err := md.PrimitiveDecode(*cfg.Input.Config, cfg.Input.DecodedConfig); err != nil { - return nil, fmt.Errorf("error parsing input config: %v", err) - } + if err := decodeAndCheckConfig(md, cfg.Input); err != nil { + return nil, err } for idx := range cfg.Filter { // Clone the configuration object to allow the use of multiple instances of the same filter cfg.Filter[idx].DecodedConfig = cloneConfig(cfg.Filter[idx].desc.Config) - if cfg.Filter[idx].Config != nil { - if err := md.PrimitiveDecode(*cfg.Filter[idx].Config, cfg.Filter[idx].DecodedConfig); err != nil { - return nil, fmt.Errorf("error parsing filter config: %v", err) - } + if err := decodeAndCheckConfig(md, cfg.Filter[idx]); err != nil { + return nil, err } } cfg.Output.DecodedConfig = cfg.Output.desc.Config - if cfg.Output.Config != nil { - if err := md.PrimitiveDecode(*cfg.Output.Config, cfg.Output.DecodedConfig); err != nil { - return nil, fmt.Errorf("error parsing output config: %v", err) - } + if err := decodeAndCheckConfig(md, cfg.Output); err != nil { + return nil, err } if cfg.Upload.Name != "" { cfg.Upload.DecodedConfig = cfg.Upload.desc.Config - if cfg.Upload.Config != nil { - if err := md.PrimitiveDecode(*cfg.Upload.Config, cfg.Upload.DecodedConfig); err != nil { - return nil, fmt.Errorf("error parsing upload config: %v", err) - } + if err := decodeAndCheckConfig(md, cfg.Upload); err != nil { + return nil, err } } if cfg.Metrics.Name != "" { cfg.Metrics.DecodedConfig = cfg.Metrics.desc.Config - if cfg.Metrics.Config != nil { - if err := md.PrimitiveDecode(*cfg.Metrics.Config, cfg.Metrics.DecodedConfig); err != nil { - return nil, fmt.Errorf("error parsing metrics config: %v", err) - } + if err := decodeAndCheckConfig(md, cfg.Metrics); err != nil { + return nil, err } } @@ -360,3 +393,63 @@ func NewConfigFromToml(f io.Reader, comp Components) (*Config, error) { // Fill-in with missing defaults return &cfg, cfg.fillDefaults() } + +// hasConfig returns true if the underlying structure has at least one field. +func hasConfig(cfg interface{}) bool { + tf := reflect.TypeOf(cfg).Elem() + return tf.NumField() != 0 +} + +// RequiredFields returns the names of the underlying configuration structure +// fields which are tagged as required. To tag a field as being required, a +// "required" struct struct tag must be present and set to true. +// +// RequiredFields doesn't support struct embedding other structs. +func RequiredFields(cfg interface{}) []string { + var fields []string + + tf := reflect.TypeOf(cfg).Elem() + for i := 0; i < tf.NumField(); i++ { + field := tf.Field(i) + + req := field.Tag.Get("required") + if req != "true" { + continue + } + + fields = append(fields, field.Name) + } + + return fields +} + +// CheckRequiredFields checks that all fields that are tagged as required in +// cfg's type have actually been set to a value other than the field type zero +// value. If not CheckRequiredFields returns the name of the first required +// field that is not set, or, it returns an empty string if all required fields +// are set of the struct doesn't have any required fields (or any fields at all). +// +// CheckRequiredFields doesn't support struct embedding other structs. +func CheckRequiredFields(cfg interface{}) string { + fields := RequiredFields(cfg) + + for _, name := range fields { + rv := reflect.ValueOf(cfg).Elem() + fv := rv.FieldByName(name) + if fv.IsZero() { + return name + } + } + + return "" +} + +// ErrorRequiredField describes the absence of a required field +// in a component configuration. +type ErrorRequiredField struct { + Field string // Field is the name of the missing field +} + +func (e ErrorRequiredField) Error() string { + return fmt.Sprintf("%q is a required field", e.Field) +} diff --git a/config_api_test.go b/config_api_test.go new file mode 100644 index 00000000..a139d06c --- /dev/null +++ b/config_api_test.go @@ -0,0 +1,170 @@ +package baker_test + +import ( + "errors" + "reflect" + "strings" + "testing" + + "github.com/AdRoll/baker" + "github.com/AdRoll/baker/filter/filtertest" + "github.com/AdRoll/baker/input" +) + +func TestRequiredFields(t *testing.T) { + type ( + test1 struct { + Name string + Value string `help:"field value" required:"false"` + } + + test2 struct { + Name string + Value string `help:"field value" required:"true"` + } + + test3 struct { + Name string `required:"true"` + Value string `help:"field value" required:"true"` + } + ) + + tests := []struct { + name string + cfg interface{} + want []string + }{ + { + name: "no required fields", + cfg: &test1{}, + want: nil, + }, + { + name: "one required field", + cfg: &test2{}, + want: []string{"Value"}, + }, + { + name: "all required fields", + cfg: &test3{}, + want: []string{"Name", "Value"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := baker.RequiredFields(tt.cfg); !reflect.DeepEqual(got, tt.want) { + t.Errorf("RequiredFields() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCheckRequiredFields(t *testing.T) { + type ( + test1 struct { + Name string + Value string `help:"field value" required:"false"` + } + + test2 struct { + Name string + Value string `help:"field value" required:"true"` + } + + test3 struct { + Name string `required:"true"` + Value string `help:"field value" required:"true"` + } + ) + + tests := []struct { + name string + val interface{} + want string + }{ + { + name: "no required fields", + val: &test1{}, + want: "", + }, + { + name: "one missing required field ", + val: &test2{Name: "name", Value: ""}, + want: "Value", + }, + { + name: "one present required field ", + val: &test2{Name: "name", Value: "value"}, + want: "", + }, + { + name: "all required fields and all are missing", + val: &test3{}, + want: "Name", + }, + { + name: "all required fields but the first missing", + val: &test3{Value: "value"}, + want: "Name", + }, + { + name: "all required fields and all are present", + val: &test3{Name: "name", Value: "value"}, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := baker.CheckRequiredFields(tt.val); got != tt.want { + t.Errorf("CheckRequiredFields() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewConfigFromTOMLRequiredField(t *testing.T) { + type dummyConfig struct { + Param1 string + Param2 string `required:"true"` + } + var dummyDesc = baker.OutputDesc{ + Name: "Dummy", + New: func(baker.OutputParams) (baker.Output, error) { return nil, nil }, + Config: &dummyConfig{}, + } + + toml := ` +[input] +name = "List" + +[input.config] +files=["testdata/input.csv.zst"] + +[output] +name = "Dummy" +procs=1 + [output.config] + param1="this parameter is set" + #param2="this parameter is not set" +` + + components := baker.Components{ + Inputs: []baker.InputDesc{input.ListDesc}, + Filters: []baker.FilterDesc{filtertest.PassThroughDesc}, + Outputs: []baker.OutputDesc{dummyDesc}, + } + + _, err := baker.NewConfigFromToml(strings.NewReader(toml), components) + if err == nil { + t.Fatal("expected an error") + } + + var errReq baker.ErrorRequiredField + if !errors.As(err, &errReq) { + t.Fatalf("got %q, want a ErrorRequiredField", err) + } + + if errReq.Field != "Param2" { + t.Errorf("got field=%q, want field=%q", errReq.Field, "Param2") + } +} diff --git a/config_test.go b/config_test.go index 63dd6dcd..19e81a13 100644 --- a/config_test.go +++ b/config_test.go @@ -67,14 +67,14 @@ func TestFillCreateRecordDefault(t *testing.T) { } func TestEnvVarBaseReplace(t *testing.T) { - src_toml := ` + src := ` [general] dont_validate_fields = ${DNT_VAL_FIELDS} alt_form = "$ALT_FORM" unexisting_var = "${THIS_DOESNT_EXIST}" ` - want_toml := ` + want := ` [general] dont_validate_fields = true alt_form = "ok" @@ -91,13 +91,13 @@ func TestEnvVarBaseReplace(t *testing.T) { return "" } - s, err := replaceEnvVars(strings.NewReader(src_toml), mapper) + s, err := replaceEnvVars(strings.NewReader(src), mapper) if err != nil { t.Fatalf("replaceEnvVars err: %v", err) } buf, _ := ioutil.ReadAll(s) - if want_toml != string(buf) { + if want != string(buf) { t.Fatalf("wrong toml: %s", string(buf)) } } diff --git a/help.go b/help.go index 056476d9..e87b2e24 100644 --- a/help.go +++ b/help.go @@ -129,17 +129,12 @@ func PrintHelp(w io.Writer, name string, comp Components) { } } -func hasConfig(cfg interface{}) bool { - tf := reflect.TypeOf(cfg).Elem() - return tf.NumField() != 0 -} - func dumpConfigHelp(w io.Writer, cfg interface{}) { - const sfmt = "%-18s | %-18s | %-26s | " + const sfmt = "%-18s | %-18s | %-18s | %-8s | " const sep = "----------------------------------------------------------------------------------------------------" - hpad := fmt.Sprintf(sfmt, "", "", "") - fmt.Fprintf(w, sfmt, "Name", "Type", "Default") + hpad := fmt.Sprintf(sfmt, "", "", "", "") + fmt.Fprintf(w, sfmt, "Name", "Type", "Default", "Required") fmt.Fprintf(w, "Help\n%s\n", sep) tf := reflect.TypeOf(cfg).Elem() @@ -180,9 +175,15 @@ func dumpConfigHelp(w io.Writer, cfg interface{}) { help := field.Tag.Get("help") def := field.Tag.Get("default") + req := field.Tag.Get("required") + if req == "true" { + req = "yes" + } else { + req = "no" + } - fmt.Fprintf(w, sfmt, field.Name, typ, def) - helpLines := strings.Split(wrapString(help, 40), "\n") + fmt.Fprintf(w, sfmt, field.Name, typ, def, req) + helpLines := strings.Split(wrapString(help, 60), "\n") if len(helpLines) > 0 { fmt.Fprint(w, helpLines[0], "\n") for _, h := range helpLines[1:] {