diff --git a/pkg/vault/auth.go b/pkg/vault/auth.go index 464444e..e4c15d5 100644 --- a/pkg/vault/auth.go +++ b/pkg/vault/auth.go @@ -38,16 +38,27 @@ func (v *Vault) Login() error { // If not, prompt for login isTokenValid := v.isCurrentTokenValid() if isTokenValid == false { - log.Debug("No valid tokens found, need to login") + v.Debug("No valid tokens found, need to login") err := v.userLogin() if err != nil { - return err + return v.parseError(err) } } return nil } +// GetToken returns the raw token +func (v *Vault) GetToken() (string, error) { + v.tokenHelper = token.InternalTokenHelper{} + token, err := v.tokenHelper.Get() + if err != nil { + return "", v.parseError(err) + } + + return token, nil +} + // isCurrentTokenValid returns flase if user needs to relogin // I am not happy with this way of testing the token. // This function doesn't check for errors. @@ -87,21 +98,21 @@ func (v *Vault) userLogin() error { "password": password, }) if err != nil { - log.Debug("Do you have a bad username or password?") - return err + v.Debug("Do you have a bad username or password?") + return v.parseError(err) } v.client.SetToken(secret.Auth.ClientToken) // Write token to user's dot file err = v.tokenHelper.Store(secret.Auth.ClientToken) if err != nil { - return err + return v.parseError(err) } // Lookup the token to get the entity ID secret, err = v.client.Auth().Token().Lookup(v.client.Token()) if err != nil { - return err + return v.parseError(err) } // spew.Dump(secret) entityID := secret.Data["entity_id"].(string) @@ -132,7 +143,7 @@ func (v *Vault) getCredentials() (string, string, error) { if len(username) <= 0 { // If user just clicked enter if v.config.Username == "" { // If there also isn't default - return "", "", errors.New("No username given") + return "", "", v.newError("No username given") } username = v.config.Username } else { @@ -142,7 +153,7 @@ func (v *Vault) getCredentials() (string, string, error) { fmt.Print("Password: ") bytePassword, err := terminal.ReadPassword(int(syscall.Stdin)) if err != nil { - return "", "", err + return "", "", v.parseError(err) } fmt.Println("") password := string(bytePassword) diff --git a/pkg/vault/error.go b/pkg/vault/error.go new file mode 100644 index 0000000..bb87847 --- /dev/null +++ b/pkg/vault/error.go @@ -0,0 +1,60 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "os" + "strings" + "syscall" +) + +// Error is the custom error type for this package +type Error struct { + MessageParts []string + OriginalError error +} + +// Error returns the error string +func (verr Error) Error() string { + return fmt.Sprintf("Vault Error: %s", strings.Join(verr.MessageParts, "; ")) +} + +// parseError parses known errors into more user-friendly messages +func (v *Vault) parseError(err error) Error { + + var verr Error + verr.OriginalError = err + + // Catch some known HTTP errors + if uerr, ok := err.(*url.Error); ok { + if oerr, ok := uerr.Err.(*net.OpError); ok { + if addr, ok := oerr.Addr.(*net.TCPAddr); ok { + if addr.IP.String() == "127.0.0.1" { + verr.MessageParts = append(verr.MessageParts, "Vault appears to be connecting to localhost, ensure correct Vault address is set") + } + } + + if serr, ok := oerr.Err.(*os.SyscallError); ok { + if serr.Err == syscall.ECONNREFUSED { + verr.MessageParts = append(verr.MessageParts, "Connection Refused") + } + } + } + } + + if err == context.DeadlineExceeded { + verr.MessageParts = append(verr.MessageParts, fmt.Sprintf("Timeout connecting after %v seconds. Ensure connectivity to Vault.", v.config.Timeout)) + } + + verr.MessageParts = append(verr.MessageParts, fmt.Sprintf("%v", err)) + + return verr +} + +// newError returns a new error based on a given string +func (v *Vault) newError(msg string) Error { + return v.parseError(errors.New(msg)) +} diff --git a/pkg/vault/mounts.go b/pkg/vault/mounts.go index 52ef289..7a16f93 100644 --- a/pkg/vault/mounts.go +++ b/pkg/vault/mounts.go @@ -11,7 +11,7 @@ func (v *Vault) GetMounts(mountType string) ([]string, error) { mounts, err := v.client.Sys().ListMounts() if err != nil { - return nil, err + return nil, v.parseError(err) } var result []string diff --git a/pkg/vault/secrets.go b/pkg/vault/secrets.go index 1b17bc0..c9d94e8 100644 --- a/pkg/vault/secrets.go +++ b/pkg/vault/secrets.go @@ -2,7 +2,6 @@ package vault import ( "github.com/hashicorp/vault/api" - "errors" "path/filepath" ) @@ -16,17 +15,17 @@ func (v *Vault) GetSecretKey(path string, key string) (string, error) { secret, err := v.client.Logical().Read(path) if err != nil { - return "", err + return "", v.parseError(err) } // If we got back an empty response, fail if secret == nil { - return "", errors.New("Could not find secret `" + path + "`") + return "", v.newError("Could not find secret `" + path + "`") } // If the provided key doesn't exist, fail if secret.Data[key] == nil { - return "", errors.New("Vault: Could not find key `" + key + "` for secret `" + path + "`") + return "", v.newError("Vault: Could not find key `" + key + "` for secret `" + path + "`") } return secret.Data[key].(string), nil @@ -38,12 +37,12 @@ func (v *Vault) GetSecretKeys(path string) (map[string]string, error) { secret, err := v.client.Logical().Read(path) if err != nil { - return nil, err + return nil, v.parseError(err) } // If we got back an empty response, fail if secret == nil { - return nil, errors.New("Could not find secret `" + path + "`") + return nil, v.newError("Could not find secret `" + path + "`") } // Loop through and get all the keys @@ -62,12 +61,12 @@ func (v *Vault) ListSecrets(path string) ([]string, error) { secret, err := v.client.Logical().List(path) if err != nil { - return nil, err + return nil, v.parseError(err) } // If we got back an empty response, fail if secret == nil { - return nil, errors.New("Could not find secret `" + path + "`") + return nil, v.newError("Could not find secret `" + path + "`") } // Loop through and get all the keys diff --git a/pkg/vault/status.go b/pkg/vault/status.go index e2e4a61..e1f8716 100644 --- a/pkg/vault/status.go +++ b/pkg/vault/status.go @@ -8,9 +8,10 @@ import ( ) func (v *Vault) isVaultHealthy() (bool, error) { + result, err := v.client.Sys().Health() if err != nil { - return false, err + return false, v.parseError(err) } log.Debug("Vault server info from (" + v.client.Address() + ")") diff --git a/pkg/vault/vault.go b/pkg/vault/vault.go index 166ac12..c5708b0 100644 --- a/pkg/vault/vault.go +++ b/pkg/vault/vault.go @@ -1,10 +1,10 @@ package vault import ( + "fmt" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/command/token" "github.com/readytalk/stim/pkg/log" - "errors" "time" ) @@ -17,49 +17,65 @@ type Vault struct { } type Config struct { - Noprompt bool - Address string - Username string - Timeout time.Duration - Log log.Logger - InitialTokenDuration time.Duration + Noprompt bool + Address string + Username string + Timeout int + InitialTokenDuration time.Duration + Logger } -func New(config *Config) (*Vault, error) { - // Ensure that the Vault address is set - if config.Address == "" { - return nil, errors.New("Vault address not set") +type Logger interface { + Debug(args ...interface{}) + Info(args ...interface{}) +} + +func (v *Vault) Debug(message string) { + if v.config.Logger != nil { + v.config.Debug(message) + } +} + +func (v *Vault) Info(message string) { + if v.config.Logger != nil { + v.config.Info(message) + } else { + fmt.Println(message) } +} + +func New(config *Config) (*Vault, error) { v := &Vault{config: config} log.SetLogger(config.Log) - if v.config.Timeout == 0 { - v.config.Timeout = time.Second * 10 // No need to wait over a minite from default + // Ensure that the Vault address is set + if config.Address == "" { + return nil, v.newError("Vault address not set") } // Configure new Vault Client apiConfig := api.DefaultConfig() apiConfig.Address = v.config.Address // Since we read the env we can override - // apiConfig.HttpClient.Timeout = v.config.Timeout + apiConfig.Timeout = time.Duration(v.config.Timeout) * time.Second // Create our new API client var err error v.client, err = api.NewClient(apiConfig) if err != nil { - return nil, err + return nil, v.parseError(err) } // Ensure Vault is up and Healthy _, err = v.isVaultHealthy() if err != nil { - return nil, err + return nil, v.parseError(err) } // Run Login logic err = v.Login() if err != nil { - return nil, err + return nil, v.parseError(err) } // If user wants, extend the token timeout diff --git a/stim/rootcmd.go b/stim/rootcmd.go index d7235e0..e7bce34 100644 --- a/stim/rootcmd.go +++ b/stim/rootcmd.go @@ -36,5 +36,8 @@ func (stim *Stim) rootCommand(viper *viper.Viper) *cobra.Command { stim.config.SetDefault("homedir", homeDir) } + // Set some defaults + viper.SetDefault("vault-timeout", 15) + return cmd } diff --git a/stim/vault.go b/stim/vault.go index fdcfe2a..7b4e90c 100644 --- a/stim/vault.go +++ b/stim/vault.go @@ -35,12 +35,12 @@ func (stim *Stim) Vault() *vault.Vault { vault, err := vault.New(&vault.Config{ Address: stim.GetConfig("vault-address"), // Default is 127.0.0.1 Noprompt: stim.GetConfigBool("noprompt") == false && stim.IsAutomated(), - Log: stim.log, // Pass in the global logger object + Logger: stim.log, // Pass in the global logger object Username: username, // If set in the configs, pass in user InitialTokenDuration: timeInDuration, }) if err != nil { - stim.log.Fatal("Stim-Vault: Error Initializaing: ", err) + stim.log.Fatal(err) } stim.vault = vault