Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 133 additions & 13 deletions internal/deploy/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
// Keymaster - SSH key management system
// This source code is licensed under the MIT license found in the LICENSE file.

// package deploy provides functionality for connecting to remote hosts via SSH
// Package deploy provides functionality for connecting to remote hosts via SSH
// and managing their authorized_keys files. This file contains the core SSH and
// SFTP client logic for connecting, authenticating, and transferring files.
//
// It includes configurable timeout support and enhanced error classification
// to provide better user feedback when connections fail.
package deploy // import "github.com/toeirei/keymaster/internal/deploy"

import (
Expand All @@ -13,33 +16,68 @@ import (
"io"
"net"
"path"
"strings"
"time"

"github.com/pkg/sftp"
"github.com/toeirei/keymaster/internal/db"
"golang.org/x/crypto/ssh"
)

// Default timeouts for SSH operations, used whenever a caller does not supply
// its own values.
const (
	// DefaultConnectionTimeout is the default limit for establishing an SSH connection.
	DefaultConnectionTimeout time.Duration = 10 * time.Second
	// DefaultCommandTimeout is the default limit for executing a command.
	DefaultCommandTimeout time.Duration = 30 * time.Second
	// DefaultHostKeyTimeout is the default limit for retrieving a host key.
	DefaultHostKeyTimeout time.Duration = 5 * time.Second
	// DefaultSFTPTimeout is the default limit for SFTP operations.
	DefaultSFTPTimeout time.Duration = 60 * time.Second
)

// ConnectionConfig holds the timeout configuration for SSH connections.
//
// ConnectionTimeout is handed to the SSH client as its dial timeout.
// CommandTimeout and SFTPTimeout carry the limits for command execution and
// SFTP operations respectively.
// NOTE(review): the consumers of CommandTimeout and SFTPTimeout are not
// visible in this part of the file — confirm where they are enforced.
type ConnectionConfig struct {
	ConnectionTimeout, CommandTimeout, SFTPTimeout time.Duration
}

// DefaultConnectionConfig returns a freshly allocated ConnectionConfig whose
// connection, command, and SFTP timeouts are set to the package defaults.
// Each call returns a new value, so callers may modify the result freely.
func DefaultConnectionConfig() *ConnectionConfig {
	cfg := new(ConnectionConfig)
	cfg.ConnectionTimeout = DefaultConnectionTimeout
	cfg.CommandTimeout = DefaultCommandTimeout
	cfg.SFTPTimeout = DefaultSFTPTimeout
	return cfg
}

// Deployer handles the connection and deployment to a remote host.
// It bundles the SSH client, the SFTP client opened on top of it, and the
// timeout configuration the connection was created with.
type Deployer struct {
// client is the established SSH connection to the remote host.
client *ssh.Client
// sftp is the SFTP client created from client.
sftp *sftp.Client
// config is the timeout configuration passed in when the Deployer was built.
config *ConnectionConfig
}

// NewDeployer creates a new SSH connection and returns a Deployer.
// The connection is opened in non-bootstrap mode; the updated call below uses
// the timeouts from DefaultConnectionConfig.
// For bootstrap connections, use NewBootstrapDeployer instead.
func NewDeployer(host, user, privateKey string) (*Deployer, error) {
// NOTE(review): this text comes from a rendered diff. The first return looks
// like the removed pre-change call and the second like its replacement —
// confirm against the real file before relying on either line.
return newDeployerInternal(host, user, privateKey, false)
return NewDeployerWithConfig(host, user, privateKey, DefaultConnectionConfig(), false)
}

// NewBootstrapDeployer creates a new SSH connection for bootstrap operations.
// It accepts any host key and saves it to the database for future connections.
// The updated call below requests bootstrap mode and uses the timeouts from
// DefaultConnectionConfig.
func NewBootstrapDeployer(host, user, privateKey string) (*Deployer, error) {
// NOTE(review): this text comes from a rendered diff. The first return looks
// like the removed pre-change call and the second like its replacement —
// confirm against the real file before relying on either line.
return newDeployerInternal(host, user, privateKey, true)
return NewDeployerWithConfig(host, user, privateKey, DefaultConnectionConfig(), true)
}

// NewDeployerWithConfig creates a new SSH connection with custom timeout configuration.
//
// host, user, and privateKey are forwarded unchanged to the internal
// constructor. isBootstrap selects bootstrap mode for the connection.
//
// config supplies the timeouts for the connection. A nil config is replaced
// with DefaultConnectionConfig() so that the internal constructor, which
// reads config.ConnectionTimeout, never dereferences a nil pointer.
func NewDeployerWithConfig(host, user, privateKey string, config *ConnectionConfig, isBootstrap bool) (*Deployer, error) {
	if config == nil {
		config = DefaultConnectionConfig()
	}
	return newDeployerInternal(host, user, privateKey, config, isBootstrap)
}

// newDeployerInternal is the internal implementation for creating deployers.
func newDeployerInternal(host, user, privateKey string, isBootstrap bool) (*Deployer, error) {
func newDeployerInternal(host, user, privateKey string, config *ConnectionConfig, isBootstrap bool) (*Deployer, error) {
// Define the host key callback based on bootstrap mode.
var hostKeyCallback ssh.HostKeyCallback

Expand Down Expand Up @@ -107,21 +145,24 @@ func newDeployerInternal(host, user, privateKey string, isBootstrap bool) (*Depl
if privateKey != "" {
signer, err := ssh.ParsePrivateKey([]byte(privateKey))
if err == nil {
config := &ssh.ClientConfig{
sshConfig := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: hostKeyCallback,
Timeout: 10 * time.Second,
Timeout: config.ConnectionTimeout,
}
client, err = ssh.Dial("tcp", addr, config)
client, err = ssh.Dial("tcp", addr, sshConfig)
if err == nil {
// Success! We connected with the system key.
sftpClient, sftpErr := sftp.NewClient(client)
if sftpErr != nil {
client.Close()
return nil, fmt.Errorf("failed to create sftp client: %w", sftpErr)
}
return &Deployer{client: client, sftp: sftpClient}, nil
return &Deployer{client: client, sftp: sftpClient, config: config}, nil
} else {
// Classify the error for better debugging
err = ClassifyConnectionError(host, err)
}
// If we provided a key and it failed, we will fall through to try the agent.
}
Expand All @@ -134,15 +175,16 @@ func newDeployerInternal(host, user, privateKey string, isBootstrap bool) (*Depl
return nil, fmt.Errorf("no authentication method available (system key failed and no ssh agent found)")
}

config := &ssh.ClientConfig{
sshConfig := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)},
HostKeyCallback: hostKeyCallback,
Timeout: 10 * time.Second,
Timeout: config.ConnectionTimeout,
}

client, err := ssh.Dial("tcp", addr, config)
client, err := ssh.Dial("tcp", addr, sshConfig)
if err != nil {
err = ClassifyConnectionError(host, err)
return nil, fmt.Errorf("connection with ssh agent failed: %w", err)
}

Expand All @@ -157,6 +199,7 @@ func newDeployerInternal(host, user, privateKey string, isBootstrap bool) (*Depl
return &Deployer{
client: client,
sftp: sftpClient,
config: config,
}, nil
}

Expand Down Expand Up @@ -253,8 +296,84 @@ func (d *Deployer) GetAuthorizedKeys() ([]byte, error) {
// in GetRemoteHostKey once the host key has been captured.
var ErrHostKeySuccessfullyRetrieved = errors.New("keymaster: successfully retrieved host key")

// Error classification functions for better error handling

// IsConnectionTimeoutError reports whether err was caused by a connection timeout.
//
// It returns false for a nil error. Otherwise it first looks through the
// error chain for a net.Error whose Timeout method returns true, which covers
// typed timeouts regardless of their message text. If none is found it falls
// back to matching well-known timeout phrases in the error message, so that
// errors whose type information was lost are still recognised.
func IsConnectionTimeoutError(err error) bool {
	if err == nil {
		return false
	}

	// Typed check: works even when the message does not contain "timeout",
	// e.g. the OS-level "connect: connection timed out".
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return true
	}

	// Textual fallback. "timeout" also covers "i/o timeout".
	msg := err.Error()
	return strings.Contains(msg, "timeout") ||
		strings.Contains(msg, "timed out") ||
		strings.Contains(msg, "deadline exceeded")
}

// IsConnectionRefusedError reports whether err indicates that the remote end
// refused the connection or could not be routed to. The decision is made by
// looking for known phrases in the error message; a nil error yields false.
func IsConnectionRefusedError(err error) bool {
	if err == nil {
		return false
	}

	msg := err.Error()
	for _, marker := range []string{"connection refused", "no route to host"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}

// IsAuthenticationError reports whether err looks like an authentication
// failure. It scans the error message for known authentication-related
// phrases; a nil error yields false.
func IsAuthenticationError(err error) bool {
	if err == nil {
		return false
	}

	markers := [...]string{
		"authentication failed",
		"permission denied",
		"public key",
		"unable to authenticate",
	}
	msg := err.Error()
	for _, marker := range markers {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}

// IsHostKeyError reports whether err stems from a host key verification
// failure, judged by known phrases in the error message. The match is
// case-sensitive, and a nil error yields false.
func IsHostKeyError(err error) bool {
	if err == nil {
		return false
	}

	switch msg := err.Error(); {
	case strings.Contains(msg, "HOST KEY MISMATCH"),
		strings.Contains(msg, "unknown host key"),
		strings.Contains(msg, "host key verification failed"):
		return true
	default:
		return false
	}
}

// ClassifyConnectionError wraps err in a more descriptive error for host,
// chosen by the kind of failure detected. The checks run in a fixed order —
// timeout, refused, authentication, host key — and the first match wins;
// anything else gets a generic "failed to connect" message. The original
// error is always wrapped with %w, and a nil err returns nil.
func ClassifyConnectionError(host string, err error) error {
	if err == nil {
		return nil
	}

	if IsConnectionTimeoutError(err) {
		return fmt.Errorf("connection to %s timed out (host may be unreachable or firewall blocking connection): %w", host, err)
	}
	if IsConnectionRefusedError(err) {
		return fmt.Errorf("connection to %s refused (SSH daemon may not be running or wrong port): %w", host, err)
	}
	if IsAuthenticationError(err) {
		return fmt.Errorf("authentication failed for %s (check SSH keys or credentials): %w", host, err)
	}
	if IsHostKeyError(err) {
		return fmt.Errorf("host key verification failed for %s (run 'keymaster trust-host %s' to accept): %w", host, host, err)
	}
	return fmt.Errorf("failed to connect to %s: %w", host, err)
}

// GetRemoteHostKey connects to host solely to fetch its public host key,
// using DefaultHostKeyTimeout as the connection timeout. It is a convenience
// wrapper around GetRemoteHostKeyWithTimeout.
func GetRemoteHostKey(host string) (ssh.PublicKey, error) {
	hostKey, err := GetRemoteHostKeyWithTimeout(host, DefaultHostKeyTimeout)
	return hostKey, err
}

// GetRemoteHostKeyWithTimeout connects to a host with a custom timeout to retrieve its public key.
func GetRemoteHostKeyWithTimeout(host string, timeout time.Duration) (ssh.PublicKey, error) {
keyChan := make(chan ssh.PublicKey, 1)

config := &ssh.ClientConfig{
Expand All @@ -266,7 +385,7 @@ func GetRemoteHostKey(host string) (ssh.PublicKey, error) {
// Return a specific error to gracefully stop the handshake.
return ErrHostKeySuccessfullyRetrieved
},
Timeout: 5 * time.Second,
Timeout: timeout,
}

addr := host
Expand All @@ -283,7 +402,8 @@ func GetRemoteHostKey(host string) (ssh.PublicKey, error) {
return <-keyChan, nil
}
// It's a different, real error (e.g., connection refused).
return nil, fmt.Errorf("failed to connect to %s: %w", host, err)
err = ClassifyConnectionError(host, err)
return nil, err
}

// This case should ideally not be reached if the callback returns an error.
Expand Down
Loading