From 82b9179fc4ec10d3051058634d682c6f930c0e19 Mon Sep 17 00:00:00 2001 From: Milad Abbasi Date: Thu, 21 Jan 2021 11:41:14 +0330 Subject: [PATCH] Add tests for file provider --- env_test.go | 18 +++++++ file.go | 15 ++++-- file_test.go | 122 +++++++++++++++++++++++++++++++++++++++++++ testdata/config.json | 6 +++ testdata/config.toml | 4 ++ testdata/config.yaml | 4 ++ testdata/config.yml | 4 ++ utils.go | 10 +++- 8 files changed, 177 insertions(+), 6 deletions(-) create mode 100644 file_test.go create mode 100644 testdata/config.json create mode 100644 testdata/config.toml create mode 100644 testdata/config.yaml create mode 100644 testdata/config.yml diff --git a/env_test.go b/env_test.go index 526c1c0..148f7fa 100644 --- a/env_test.go +++ b/env_test.go @@ -100,6 +100,24 @@ func TestEnvProvider_Fill(t *testing.T) { } }) + t.Run("config key", func(t *testing.T) { + os.Clearenv() + err := os.Setenv("CUSTOM_KEY", "env") + require.NoError(t, err) + + s := struct { + Env string `config:"CUSTOM_KEY"` + }{} + in, err := NewInput(&s) + require.NoError(t, err) + require.NotNil(t, in) + ep := EnvProvider{} + + err = ep.Fill(in) + require.NoError(t, err) + assert.Equal(t, "env", s.Env) + }) + t.Run("env prefix", func(t *testing.T) { os.Clearenv() err := os.Setenv("APP_Env", "env") diff --git a/file.go b/file.go index faddd5a..f0baa86 100644 --- a/file.go +++ b/file.go @@ -2,7 +2,9 @@ package gonfig import ( "encoding/json" + "errors" "fmt" + "io" "os" "path/filepath" @@ -15,8 +17,8 @@ const ( JSON = ".json" YML = ".yml" YAML = ".yaml" - ENV = ".env" TOML = ".toml" + ENV = ".env" ) // FileProvider loads values from file to provided struct @@ -89,6 +91,12 @@ func (fp *FileProvider) Fill(in *Input) error { // decode opens specified file and loads its content to input argument func (fp *FileProvider) decode(i interface{}) (err error) { + switch fp.FileExt { + case JSON, YML, YAML, TOML: + default: + return fmt.Errorf(unsupportedFileExtErrFormat, ErrUnsupportedFileExt, fp.FileExt) + } + f, err := os.Open(fp.FilePath) if err != nil { if os.IsNotExist(err) && !fp.Required { @@ -112,12 +120,9 @@ func (fp *FileProvider) decode(i interface{}) (err error) { case TOML: _, err = toml.DecodeReader(f, i) - - default: - err = fmt.Errorf(unsupportedFileExtErrFormat, ErrUnsupportedFileExt, fp.FileExt) } - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { return fmt.Errorf(decodeFailedErrFormat, err) } diff --git a/file_test.go b/file_test.go new file mode 100644 index 0000000..7735d47 --- /dev/null +++ b/file_test.go @@ -0,0 +1,122 @@ +package gonfig + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewFileProvider(t *testing.T) { + fp := NewFileProvider("file.yml") + require.NotNil(t, fp) + assert.Equal(t, "file.yml", fp.FilePath) + assert.Equal(t, ".yml", fp.FileExt) + assert.False(t, fp.Required) +} + +func TestFileProvider_Name(t *testing.T) { + fp := FileProvider{ + FileExt: ".json", + } + + assert.Equal(t, "File provider (json)", fp.Name()) +} + +func TestFileProvider_UnmarshalStruct(t *testing.T) { + t.Run("file existence", func(t *testing.T) { + fp := FileProvider{ + FilePath: "NotExistingFile.toml", + FileExt: ".toml", + Required: false, + } + + var i interface{} + err := fp.UnmarshalStruct(i) + assert.NoError(t, err) + + fp.Required = true + err = fp.UnmarshalStruct(i) + assert.Error(t, err) + }) + + t.Run("unsupported file extension", func(t *testing.T) { + fp := FileProvider{ + FileExt: ".ini", + } + + var i interface{} + err := fp.UnmarshalStruct(i) + require.Error(t, err) + assert.Truef( + t, + errors.Is(err, ErrUnsupportedFileExt), + "Error must wrap ErrUnsupportedFileExt error", + ) + }) + + t.Run("supported file extensions", func(t *testing.T) { + for _, e := range []string{".json", ".yml", ".yaml", ".toml"} { + s := struct{}{} + fp := FileProvider{ + FilePath: "testdata/config" + e, + FileExt: e, + Required: true, + } + + err := fp.UnmarshalStruct(&s) + require.NoError(t, err) + } + }) +} + +func TestFileProvider_Fill(t *testing.T) { + t.Run("should be set", func(t *testing.T) { + for _, e := range []string{".json", ".yml", ".yaml", ".toml"} { + s := struct { + Config struct { + Host string + } + }{} + in, err := NewInput(&s) + require.NoError(t, err) + require.NotNil(t, in) + + fp := FileProvider{ + FilePath: "testdata/config" + e, + FileExt: e, + Required: true, + } + + err = fp.Fill(in) + require.NoError(t, err) + for _, f := range in.Fields { + assert.True(t, f.IsSet) + } + } + }) + + t.Run("config key", func(t *testing.T) { + for _, e := range []string{".json", ".yml", ".yaml", ".toml"} { + s := struct { + Custom string `json:"custom_key" yaml:"custom_key" toml:"custom_key"` + }{} + in, err := NewInput(&s) + require.NoError(t, err) + require.NotNil(t, in) + + fp := FileProvider{ + FilePath: "testdata/config" + e, + FileExt: e, + Required: true, + } + + err = fp.Fill(in) + require.NoError(t, err) + for _, f := range in.Fields { + assert.True(t, f.IsSet) + } + } + }) +} diff --git a/testdata/config.json b/testdata/config.json new file mode 100644 index 0000000..aedbac6 --- /dev/null +++ b/testdata/config.json @@ -0,0 +1,6 @@ +{ + "config": { + "host": "golang.org" + }, + "custom_key": "custom" +} diff --git a/testdata/config.toml b/testdata/config.toml new file mode 100644 index 0000000..b97161b --- /dev/null +++ b/testdata/config.toml @@ -0,0 +1,4 @@ +custom_key = "custom" + +[config] +host = "golang.org" diff --git a/testdata/config.yaml b/testdata/config.yaml new file mode 100644 index 0000000..e01786d --- /dev/null +++ b/testdata/config.yaml @@ -0,0 +1,4 @@ +config: + host: golang.org + +custom_key: custom diff --git a/testdata/config.yml b/testdata/config.yml new file mode 100644 index 0000000..e01786d --- /dev/null +++ b/testdata/config.yml @@ -0,0 +1,4 @@ +config: + host: golang.org + +custom_key: custom diff --git a/utils.go b/utils.go index 5ed2850..af52642 100644 --- a/utils.go +++ b/utils.go @@ -58,7 +58,15 @@ func traverseMap(m map[string]interface{}, path []string) (string, bool) { nestedMap, ok := value.(map[string]interface{}) if !ok { - return "", false + nestedInterfaceMap, ok := value.(map[interface{}]interface{}) + if !ok { + return "", false + } + + nestedMap = make(map[string]interface{}) + for k, v := range nestedInterfaceMap { + nestedMap[fmt.Sprint(k)] = v + } } return traverseMap(nestedMap, path)