From e579c76762c7fbac897d93a5fa98f931d65d853b Mon Sep 17 00:00:00 2001 From: acmore Date: Wed, 22 Apr 2026 23:45:24 +0800 Subject: [PATCH 1/2] fix: improve exec and ssh reliability --- internal/cli/ssh.go | 18 ++++++++++--- internal/cli/ssh_config_test.go | 46 +++++++++++++++++++++++++++++++++ internal/connect/runner.go | 20 +++++++++++--- internal/connect/runner_test.go | 46 +++++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 8 deletions(-) diff --git a/internal/cli/ssh.go b/internal/cli/ssh.go index a590f96f..4d936f98 100644 --- a/internal/cli/ssh.go +++ b/internal/cli/ssh.go @@ -899,7 +899,7 @@ func startManagedSSHForward(hostAlias string, sshSpec config.SSHSpec) error { if err != nil { return err } - check := exec.Command("ssh", "-F", configPath, "-S", socketPath, "-O", "check", hostAlias) + check := execCommand("ssh", "-F", configPath, "-S", socketPath, "-O", "check", hostAlias) if err := check.Run(); err == nil { return nil } @@ -915,18 +915,28 @@ func startManagedSSHForwardWithForwards(hostAlias string, forwards []config.Port if err != nil { return err } - check := exec.Command("ssh", "-F", configPath, "-S", socketPath, "-O", "check", hostAlias) + check := execCommand("ssh", "-F", configPath, "-S", socketPath, "-O", "check", hostAlias) if err := check.Run(); err == nil { return nil } args := managedSSHForwardArgs(hostAlias, configPath, socketPath, forwards, sshSpec) - cmd := exec.Command(args[0], args[1:]...) + if err := removeStaleSSHControlSocket(socketPath); err != nil { + return err + } + cmd := execCommand(args[0], args[1:]...) 
if out, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("start managed ssh forward: %w (%s)", err, strings.TrimSpace(string(out))) } return nil } +func removeStaleSSHControlSocket(socketPath string) error { + if err := os.Remove(socketPath); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("remove stale ssh control socket %q: %w", socketPath, err) + } + return nil +} + func managedSSHForwardArgs(hostAlias, configPath, socketPath string, forwards []config.PortMapping, sshSpec config.SSHSpec) []string { args := []string{ "ssh", @@ -965,7 +975,7 @@ func stopManagedSSHForward(hostAlias string) error { if err != nil { return err } - cmd := exec.Command("ssh", "-F", configPath, "-S", socketPath, "-O", "exit", hostAlias) + cmd := execCommand("ssh", "-F", configPath, "-S", socketPath, "-O", "exit", hostAlias) if out, err := cmd.CombinedOutput(); err != nil { msg := strings.ToLower(strings.TrimSpace(string(out))) if strings.Contains(msg, "no such file") || strings.Contains(msg, "control socket connect") || strings.Contains(msg, "master running") { diff --git a/internal/cli/ssh_config_test.go b/internal/cli/ssh_config_test.go index 66b0989f..cb54adb3 100644 --- a/internal/cli/ssh_config_test.go +++ b/internal/cli/ssh_config_test.go @@ -2,6 +2,7 @@ package cli import ( "os" + "os/exec" "path/filepath" "reflect" "strings" @@ -128,3 +129,48 @@ func TestManagedSSHForwardArgs(t *testing.T) { t.Fatalf("unexpected args:\n got: %#v\nwant: %#v", args, want) } } + +func TestStartManagedSSHForwardWithForwardsRemovesStaleSocketAfterFailedCheck(t *testing.T) { + home := t.TempDir() + origHome := os.Getenv("HOME") + if err := os.Setenv("HOME", home); err != nil { + t.Fatalf("set HOME: %v", err) + } + defer func() { + _ = os.Setenv("HOME", origHome) + }() + + origExecCommand := execCommand + t.Cleanup(func() { + execCommand = origExecCommand + }) + + socketPath := filepath.Join(home, ".okdev", "ssh", "okdev-test.sock") + if err := 
os.MkdirAll(filepath.Dir(socketPath), 0o700); err != nil { + t.Fatalf("mkdir socket dir: %v", err) + } + if err := os.WriteFile(socketPath, []byte("stale"), 0o600); err != nil { + t.Fatalf("write stale socket file: %v", err) + } + + var calls [][]string + execCommand = func(name string, args ...string) *exec.Cmd { + calls = append(calls, append([]string{name}, args...)) + if len(args) >= 4 && args[2] == "-O" && args[3] == "check" { + return exec.Command("sh", "-c", "exit 1") + } + cmd := exec.Command("sh", "-c", `[ ! -e "$SOCKET_PATH" ]`) + cmd.Env = append(os.Environ(), "SOCKET_PATH="+socketPath) + return cmd + } + + if err := startManagedSSHForwardWithForwards("okdev-test", nil, config.SSHSpec{}); err != nil { + t.Fatalf("startManagedSSHForwardWithForwards: %v", err) + } + if len(calls) != 2 { + t.Fatalf("expected check and start calls, got %d", len(calls)) + } + if _, err := os.Stat(socketPath); !os.IsNotExist(err) { + t.Fatalf("expected stale socket to be removed, got err=%v", err) + } +} diff --git a/internal/connect/runner.go b/internal/connect/runner.go index eb648207..cf4e05ff 100644 --- a/internal/connect/runner.go +++ b/internal/connect/runner.go @@ -30,6 +30,12 @@ func Run(ctx context.Context, client ExecClient, namespace, pod string, command } func RunWithRetry(ctx context.Context, client ExecClient, namespace, pod string, command []string, tty bool, stdin io.Reader, stdout io.Writer, stderr io.Writer, policy RetryPolicy) error { + return runExecWithRetry(ctx, stderr, policy, func() error { + return client.ExecInteractive(ctx, namespace, pod, tty, command, stdin, stdout, stderr) + }) +} + +func runExecWithRetry(ctx context.Context, stderr io.Writer, policy RetryPolicy, execFn func() error) error { if policy.MaxAttempts <= 0 { policy.MaxAttempts = 1 } @@ -46,7 +52,7 @@ func RunWithRetry(ctx context.Context, client ExecClient, namespace, pod string, if ctx.Err() != nil { return ctx.Err() } - err := client.ExecInteractive(ctx, namespace, pod, tty, command, 
stdin, stdout, stderr) + err := execFn() if err == nil { return nil } @@ -85,14 +91,20 @@ type ExecContainerClient interface { // RunOnContainer executes a command in a specific container within a pod. // If container is empty, it falls back to the default ExecInteractive behavior. func RunOnContainer(ctx context.Context, client ExecClient, namespace, pod, container string, command []string, tty bool, stdin io.Reader, stdout io.Writer, stderr io.Writer) error { + return runOnContainerWithRetry(ctx, client, namespace, pod, container, command, tty, stdin, stdout, stderr, DefaultRetryPolicy) +} + +func runOnContainerWithRetry(ctx context.Context, client ExecClient, namespace, pod, container string, command []string, tty bool, stdin io.Reader, stdout io.Writer, stderr io.Writer, policy RetryPolicy) error { if container == "" { - return Run(ctx, client, namespace, pod, command, tty, stdin, stdout, stderr) + return RunWithRetry(ctx, client, namespace, pod, command, tty, stdin, stdout, stderr, policy) } ec, ok := client.(ExecContainerClient) if !ok { - return Run(ctx, client, namespace, pod, command, tty, stdin, stdout, stderr) + return RunWithRetry(ctx, client, namespace, pod, command, tty, stdin, stdout, stderr, policy) } - return ec.ExecInteractiveInContainer(ctx, namespace, pod, container, tty, command, stdin, stdout, stderr) + return runExecWithRetry(ctx, stderr, policy, func() error { + return ec.ExecInteractiveInContainer(ctx, namespace, pod, container, tty, command, stdin, stdout, stderr) + }) } func isRetryableError(err error) bool { diff --git a/internal/connect/runner_test.go b/internal/connect/runner_test.go index d77cb3ec..f6c148af 100644 --- a/internal/connect/runner_test.go +++ b/internal/connect/runner_test.go @@ -23,6 +23,16 @@ func (f *fakeExecClient) ExecInteractive(ctx context.Context, namespace, pod str return err } +type fakeExecContainerClient struct { + fakeExecClient + containerCalls int +} + +func (f *fakeExecContainerClient) 
ExecInteractiveInContainer(ctx context.Context, namespace, pod, container string, tty bool, command []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error { + f.containerCalls++ + return f.ExecInteractive(ctx, namespace, pod, tty, command, stdin, stdout, stderr) +} + func TestRunWithRetrySucceedsAfterTransient(t *testing.T) { fc := &fakeExecClient{errs: []error{errors.New("EOF"), nil}} var out bytes.Buffer @@ -73,3 +83,39 @@ func TestRunWithRetryNoRetryOnNonTransient(t *testing.T) { t.Fatalf("expected no retry, got calls=%d", fc.calls) } } + +func TestRunOnContainerRetriesTransientContainerExec(t *testing.T) { + fc := &fakeExecContainerClient{ + fakeExecClient: fakeExecClient{errs: []error{errors.New("EOF"), nil}}, + } + var out bytes.Buffer + err := runOnContainerWithRetry(context.Background(), fc, "ns", "pod", "dev", []string{"sh"}, true, &out, &out, &out, RetryPolicy{ + MaxAttempts: 3, + InitialBackoff: 1 * time.Millisecond, + MaxBackoff: 2 * time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + if fc.containerCalls < 2 { + t.Fatalf("expected container exec retry, got calls=%d", fc.containerCalls) + } +} + +func TestRunOnContainerDoesNotRetryNonTransientContainerExec(t *testing.T) { + fc := &fakeExecContainerClient{ + fakeExecClient: fakeExecClient{errs: []error{errors.New("permission denied")}}, + } + var out bytes.Buffer + err := runOnContainerWithRetry(context.Background(), fc, "ns", "pod", "dev", []string{"sh"}, true, &out, &out, &out, RetryPolicy{ + MaxAttempts: 3, + InitialBackoff: 1 * time.Millisecond, + MaxBackoff: 2 * time.Millisecond, + }) + if err == nil { + t.Fatal("expected error") + } + if fc.containerCalls != 1 { + t.Fatalf("expected no retry for non-transient container error, got calls=%d", fc.containerCalls) + } +} From 446a7dab52f56b2e6b98ccb9cf09208f6bcaebbd Mon Sep 17 00:00:00 2001 From: acmore Date: Thu, 23 Apr 2026 08:51:53 +0800 Subject: [PATCH 2/2] Fix ssh stale socket recovery and container exec retry --- 
internal/cli/ssh.go | 2 ++ internal/cli/ssh_config_test.go | 8 +------- internal/connect/runner.go | 12 ++++++++---- internal/connect/runner_test.go | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/internal/cli/ssh.go b/internal/cli/ssh.go index 4d936f98..ed81f755 100644 --- a/internal/cli/ssh.go +++ b/internal/cli/ssh.go @@ -920,6 +920,8 @@ func startManagedSSHForwardWithForwards(hostAlias string, forwards []config.Port return nil } args := managedSSHForwardArgs(hostAlias, configPath, socketPath, forwards, sshSpec) + // `-O check` failed, so any existing socket file is from a dead master and + // would otherwise prevent `ssh -M` from binding a new one. if err := removeStaleSSHControlSocket(socketPath); err != nil { return err } diff --git a/internal/cli/ssh_config_test.go b/internal/cli/ssh_config_test.go index cb54adb3..d00caf25 100644 --- a/internal/cli/ssh_config_test.go +++ b/internal/cli/ssh_config_test.go @@ -132,13 +132,7 @@ func TestManagedSSHForwardArgs(t *testing.T) { func TestStartManagedSSHForwardWithForwardsRemovesStaleSocketAfterFailedCheck(t *testing.T) { home := t.TempDir() - origHome := os.Getenv("HOME") - if err := os.Setenv("HOME", home); err != nil { - t.Fatalf("set HOME: %v", err) - } - defer func() { - _ = os.Setenv("HOME", origHome) - }() + t.Setenv("HOME", home) origExecCommand := execCommand t.Cleanup(func() { diff --git a/internal/connect/runner.go b/internal/connect/runner.go index cf4e05ff..93408475 100644 --- a/internal/connect/runner.go +++ b/internal/connect/runner.go @@ -88,13 +88,17 @@ type ExecContainerClient interface { ExecInteractiveInContainer(ctx context.Context, namespace, pod, container string, tty bool, command []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error } -// RunOnContainer executes a command in a specific container within a pod. -// If container is empty, it falls back to the default ExecInteractive behavior. 
+// RunOnContainer executes a command in a specific container within a pod using +// DefaultRetryPolicy. Transient connect errors are retried the same way as +// RunWithRetry; non-transient errors (including non-zero command exit codes) +// surface immediately. If container is empty or client does not implement +// ExecContainerClient, it falls back to the pod-level ExecInteractive path. func RunOnContainer(ctx context.Context, client ExecClient, namespace, pod, container string, command []string, tty bool, stdin io.Reader, stdout io.Writer, stderr io.Writer) error { - return runOnContainerWithRetry(ctx, client, namespace, pod, container, command, tty, stdin, stdout, stderr, DefaultRetryPolicy) + return RunOnContainerWithRetry(ctx, client, namespace, pod, container, command, tty, stdin, stdout, stderr, DefaultRetryPolicy) } -func runOnContainerWithRetry(ctx context.Context, client ExecClient, namespace, pod, container string, command []string, tty bool, stdin io.Reader, stdout io.Writer, stderr io.Writer, policy RetryPolicy) error { +// RunOnContainerWithRetry is the explicit-policy variant of RunOnContainer. 
+func RunOnContainerWithRetry(ctx context.Context, client ExecClient, namespace, pod, container string, command []string, tty bool, stdin io.Reader, stdout io.Writer, stderr io.Writer, policy RetryPolicy) error { if container == "" { return RunWithRetry(ctx, client, namespace, pod, command, tty, stdin, stdout, stderr, policy) } diff --git a/internal/connect/runner_test.go b/internal/connect/runner_test.go index f6c148af..c2514e4d 100644 --- a/internal/connect/runner_test.go +++ b/internal/connect/runner_test.go @@ -89,7 +89,7 @@ func TestRunOnContainerRetriesTransientContainerExec(t *testing.T) { fakeExecClient: fakeExecClient{errs: []error{errors.New("EOF"), nil}}, } var out bytes.Buffer - err := runOnContainerWithRetry(context.Background(), fc, "ns", "pod", "dev", []string{"sh"}, true, &out, &out, &out, RetryPolicy{ + err := RunOnContainerWithRetry(context.Background(), fc, "ns", "pod", "dev", []string{"sh"}, true, &out, &out, &out, RetryPolicy{ MaxAttempts: 3, InitialBackoff: 1 * time.Millisecond, MaxBackoff: 2 * time.Millisecond, @@ -107,7 +107,7 @@ func TestRunOnContainerDoesNotRetryNonTransientContainerExec(t *testing.T) { fakeExecClient: fakeExecClient{errs: []error{errors.New("permission denied")}}, } var out bytes.Buffer - err := runOnContainerWithRetry(context.Background(), fc, "ns", "pod", "dev", []string{"sh"}, true, &out, &out, &out, RetryPolicy{ + err := RunOnContainerWithRetry(context.Background(), fc, "ns", "pod", "dev", []string{"sh"}, true, &out, &out, &out, RetryPolicy{ MaxAttempts: 3, InitialBackoff: 1 * time.Millisecond, MaxBackoff: 2 * time.Millisecond,