diff --git a/easyssh.go b/easyssh.go index 15e0740..233313d 100644 --- a/easyssh.go +++ b/easyssh.go @@ -28,6 +28,11 @@ var ( defaultBufferSize = 4096 ) +var ( + // ErrProxyDialTimeout is returned when proxy dial connection times out + ErrProxyDialTimeout = errors.New("proxy dial timeout") +) + type Protocol string const ( @@ -253,7 +258,43 @@ func (ssh_conf *MakeConfig) Connect() (*ssh.Session, *ssh.Client, error) { return nil, nil, err } - conn, err := proxyClient.Dial(string(ssh_conf.Protocol), net.JoinHostPort(ssh_conf.Server, ssh_conf.Port)) + // Apply timeout to the connection from proxy to target server + timeout := ssh_conf.Timeout + if timeout == 0 { + timeout = defaultTimeout + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + type connResult struct { + conn net.Conn + err error + } + + connCh := make(chan connResult, 1) + go func() { + conn, err := proxyClient.Dial(string(ssh_conf.Protocol), net.JoinHostPort(ssh_conf.Server, ssh_conf.Port)) + select { + case connCh <- connResult{conn: conn, err: err}: + // Successfully sent result + case <-ctx.Done(): + // Context was cancelled, clean up the connection if it was established + if conn != nil { + conn.Close() + } + } + }() + + var conn net.Conn + select { + case result := <-connCh: + conn = result.conn + err = result.err + case <-ctx.Done(): + return nil, nil, fmt.Errorf("%w: %v", ErrProxyDialTimeout, ctx.Err()) + } + if err != nil { return nil, nil, err } @@ -413,6 +454,10 @@ func (ssh_conf *MakeConfig) Stream(command string, timeout ...time.Duration) (<- func (ssh_conf *MakeConfig) Run(command string, timeout ...time.Duration) (outStr string, errStr string, isTimeout bool, err error) { stdoutChan, stderrChan, doneChan, errChan, err := ssh_conf.Stream(command, timeout...) if err != nil { + // Check if the error is from a proxy dial timeout + if errors.Is(err, ErrProxyDialTimeout) { + isTimeout = true + } return outStr, errStr, isTimeout, err } // read from the output channel until the done signal is passed diff --git a/easyssh_test.go b/easyssh_test.go index 5381d93..ec3edeb 100644 --- a/easyssh_test.go +++ b/easyssh_test.go @@ -2,9 +2,11 @@ package easyssh import ( "context" + "errors" "os" "os/user" "path" + "runtime" "testing" "time" @@ -512,3 +514,155 @@ func TestCommandTimeout(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, "Run Command Timeout: "+context.DeadlineExceeded.Error(), err.Error()) } + +// TestProxyTimeoutHandling tests that timeout is properly respected when using proxy connections +// This test uses a non-existent proxy server to force a timeout during proxy connection +func TestProxyTimeoutHandling(t *testing.T) { + ssh := &MakeConfig{ + Server: "example.com", + User: "testuser", + Port: "22", + KeyPath: "./tests/.ssh/id_rsa", + Timeout: 1 * time.Second, // Short timeout for testing + Proxy: DefaultConfig{ + User: "testuser", + Server: "10.255.255.1", // Non-routable IP that should timeout + Port: "22", + KeyPath: "./tests/.ssh/id_rsa", + Timeout: 1 * time.Second, + }, + } + + // Test Connect() method directly to test proxy connection timeout + start := time.Now() + session, client, err := ssh.Connect() + elapsed := time.Since(start) + + // Should timeout within reasonable bounds + assert.True(t, elapsed < 3*time.Second, "Connection should timeout within 3 seconds, took %v", elapsed) + assert.True(t, elapsed >= 1*time.Second, "Connection should take at least 1 second (timeout value), took %v", elapsed) + + // Should return nil session and client + assert.Nil(t, session) + assert.Nil(t, client) + + // Should have error + assert.NotNil(t, err) +} + +// TestProxyDialTimeout tests the specific scenario described in issue #93 +// where proxy dial timeout should be respected and properly detected +func TestProxyDialTimeout(t *testing.T) { + ssh := &MakeConfig{ + Server: "10.255.255.1", // Non-routable IP that should timeout + User: "testuser", + Port: "22", + KeyPath: "./tests/.ssh/id_rsa", + Timeout: 2 * time.Second, // Short timeout for testing + Proxy: DefaultConfig{ + User: "testuser", + Server: "10.255.255.2", // Another non-routable IP for proxy + Port: "22", + KeyPath: "./tests/.ssh/id_rsa", + Timeout: 2 * time.Second, + }, + } + + // Test Connect() method directly to avoid SSH server dependency + start := time.Now() + session, client, err := ssh.Connect() + elapsed := time.Since(start) + + // Should timeout within reasonable bounds + assert.True(t, elapsed < 5*time.Second, "Connection should timeout within 5 seconds, took %v", elapsed) + assert.True(t, elapsed >= 2*time.Second, "Connection should take at least 2 seconds (timeout value), took %v", elapsed) + + // Should return nil session and client + assert.Nil(t, session) + assert.Nil(t, client) + + // Should have error + assert.NotNil(t, err) + // Note: This will timeout at the proxy connection level, not at proxy dial level + // so it won't be ErrProxyDialTimeout, but we can still verify the timeout behavior +} + +// TestProxyDialTimeoutInRun tests timeout detection in Run method +func TestProxyDialTimeoutInRun(t *testing.T) { + ssh := &MakeConfig{ + Server: "example.com", + User: "testuser", + Port: "22", + KeyPath: "./tests/.ssh/id_rsa", + Timeout: 2 * time.Second, + Proxy: DefaultConfig{ + User: "testuser", + Server: "127.0.0.1", // Assume localhost SSH exists + Port: "22", + KeyPath: "./tests/.ssh/id_rsa", + Timeout: 2 * time.Second, + }, + } + + // Mock a scenario where Connect() returns ErrProxyDialTimeout + // by temporarily changing the target to a non-routable address + ssh.Server = "10.255.255.1" + + start := time.Now() + outStr, errStr, isTimeout, err := ssh.Run("whoami") + elapsed := time.Since(start) + + // Should timeout within reasonable bounds + assert.True(t, elapsed < 5*time.Second, "Should timeout within 5 seconds, took %v", elapsed) + + // Should return empty output + assert.Equal(t, "", outStr) + assert.Equal(t, "", errStr) + + // Should have error + assert.NotNil(t, err) + + // If it's specifically a proxy dial timeout, isTimeout should be true + if errors.Is(err, ErrProxyDialTimeout) { + assert.True(t, isTimeout, "isTimeout should be true for proxy dial timeout") + } +} + +// TestProxyGoroutineLeak tests that no goroutines are leaked when proxy dial times out +func TestProxyGoroutineLeak(t *testing.T) { + // Get initial goroutine count + initialGoroutines := runtime.NumGoroutine() + + ssh := &MakeConfig{ + Server: "10.255.255.1", // Non-routable IP that should timeout + User: "testuser", + Port: "22", + KeyPath: "./tests/.ssh/id_rsa", + Timeout: 1 * time.Second, // Short timeout + Proxy: DefaultConfig{ + User: "testuser", + Server: "10.255.255.2", // Another non-routable IP for proxy + Port: "22", + KeyPath: "./tests/.ssh/id_rsa", + Timeout: 1 * time.Second, + }, + } + + // Run multiple timeout operations + for i := 0; i < 5; i++ { + _, _, err := ssh.Connect() + assert.NotNil(t, err) // Should have error due to timeout + } + + // Give some time for goroutines to cleanup + time.Sleep(100 * time.Millisecond) + runtime.GC() // Force garbage collection + + // Check final goroutine count - should not have grown significantly + finalGoroutines := runtime.NumGoroutine() + + // Allow for some variance due to test framework overhead, but shouldn't grow by more than 2-3 goroutines + assert.True(t, finalGoroutines <= initialGoroutines+3, + "Goroutine leak detected: initial=%d, final=%d", initialGoroutines, finalGoroutines) +} +