Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions internal/cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ func startManagedSSHForward(hostAlias string, sshSpec config.SSHSpec) error {
if err != nil {
return err
}
check := exec.Command("ssh", "-F", configPath, "-S", socketPath, "-O", "check", hostAlias)
check := execCommand("ssh", "-F", configPath, "-S", socketPath, "-O", "check", hostAlias)
if err := check.Run(); err == nil {
return nil
}
Expand All @@ -915,18 +915,30 @@ func startManagedSSHForwardWithForwards(hostAlias string, forwards []config.Port
if err != nil {
return err
}
check := exec.Command("ssh", "-F", configPath, "-S", socketPath, "-O", "check", hostAlias)
check := execCommand("ssh", "-F", configPath, "-S", socketPath, "-O", "check", hostAlias)
if err := check.Run(); err == nil {
return nil
}
args := managedSSHForwardArgs(hostAlias, configPath, socketPath, forwards, sshSpec)
cmd := exec.Command(args[0], args[1:]...)
// `-O check` failed, so any existing socket file is from a dead master and
// would otherwise prevent `ssh -M` from binding a new one.
if err := removeStaleSSHControlSocket(socketPath); err != nil {
return err
}
cmd := execCommand(args[0], args[1:]...)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("start managed ssh forward: %w (%s)", err, strings.TrimSpace(string(out)))
}
return nil
}

func removeStaleSSHControlSocket(socketPath string) error {
if err := os.Remove(socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("remove stale ssh control socket %q: %w", socketPath, err)
}
return nil
}

func managedSSHForwardArgs(hostAlias, configPath, socketPath string, forwards []config.PortMapping, sshSpec config.SSHSpec) []string {
args := []string{
"ssh",
Expand Down Expand Up @@ -965,7 +977,7 @@ func stopManagedSSHForward(hostAlias string) error {
if err != nil {
return err
}
cmd := exec.Command("ssh", "-F", configPath, "-S", socketPath, "-O", "exit", hostAlias)
cmd := execCommand("ssh", "-F", configPath, "-S", socketPath, "-O", "exit", hostAlias)
if out, err := cmd.CombinedOutput(); err != nil {
msg := strings.ToLower(strings.TrimSpace(string(out)))
if strings.Contains(msg, "no such file") || strings.Contains(msg, "control socket connect") || strings.Contains(msg, "master running") {
Expand Down
40 changes: 40 additions & 0 deletions internal/cli/ssh_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cli

import (
"os"
"os/exec"
"path/filepath"
"reflect"
"strings"
Expand Down Expand Up @@ -128,3 +129,42 @@ func TestManagedSSHForwardArgs(t *testing.T) {
t.Fatalf("unexpected args:\n got: %#v\nwant: %#v", args, want)
}
}

func TestStartManagedSSHForwardWithForwardsRemovesStaleSocketAfterFailedCheck(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)

origExecCommand := execCommand
t.Cleanup(func() {
execCommand = origExecCommand
})

socketPath := filepath.Join(home, ".okdev", "ssh", "okdev-test.sock")
if err := os.MkdirAll(filepath.Dir(socketPath), 0o700); err != nil {
t.Fatalf("mkdir socket dir: %v", err)
}
if err := os.WriteFile(socketPath, []byte("stale"), 0o600); err != nil {
t.Fatalf("write stale socket file: %v", err)
}

var calls [][]string
execCommand = func(name string, args ...string) *exec.Cmd {
calls = append(calls, append([]string{name}, args...))
if len(args) >= 4 && args[2] == "-O" && args[3] == "check" {
return exec.Command("sh", "-c", "exit 1")
}
cmd := exec.Command("sh", "-c", `[ ! -e "$SOCKET_PATH" ]`)
cmd.Env = append(os.Environ(), "SOCKET_PATH="+socketPath)
return cmd
}

if err := startManagedSSHForwardWithForwards("okdev-test", nil, config.SSHSpec{}); err != nil {
t.Fatalf("startManagedSSHForwardWithForwards: %v", err)
}
if len(calls) != 2 {
t.Fatalf("expected check and start calls, got %d", len(calls))
}
if _, err := os.Stat(socketPath); !os.IsNotExist(err) {
t.Fatalf("expected stale socket to be removed, got err=%v", err)
}
}
28 changes: 22 additions & 6 deletions internal/connect/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ func Run(ctx context.Context, client ExecClient, namespace, pod string, command
}

func RunWithRetry(ctx context.Context, client ExecClient, namespace, pod string, command []string, tty bool, stdin io.Reader, stdout io.Writer, stderr io.Writer, policy RetryPolicy) error {
return runExecWithRetry(ctx, stderr, policy, func() error {
return client.ExecInteractive(ctx, namespace, pod, tty, command, stdin, stdout, stderr)
})
}

func runExecWithRetry(ctx context.Context, stderr io.Writer, policy RetryPolicy, execFn func() error) error {
if policy.MaxAttempts <= 0 {
policy.MaxAttempts = 1
}
Expand All @@ -46,7 +52,7 @@ func RunWithRetry(ctx context.Context, client ExecClient, namespace, pod string,
if ctx.Err() != nil {
return ctx.Err()
}
err := client.ExecInteractive(ctx, namespace, pod, tty, command, stdin, stdout, stderr)
err := execFn()
if err == nil {
return nil
}
Expand Down Expand Up @@ -82,17 +88,27 @@ type ExecContainerClient interface {
ExecInteractiveInContainer(ctx context.Context, namespace, pod, container string, tty bool, command []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error
}

// RunOnContainer executes a command in a specific container within a pod.
// If container is empty, it falls back to the default ExecInteractive behavior.
// RunOnContainer executes a command in a specific container within a pod using
// DefaultRetryPolicy. Transient connect errors are retried the same way as
// RunWithRetry; non-transient errors (including non-zero command exit codes)
// surface immediately. If container is empty, it falls back to the
// pod-level ExecInteractive path.
func RunOnContainer(ctx context.Context, client ExecClient, namespace, pod, container string, command []string, tty bool, stdin io.Reader, stdout io.Writer, stderr io.Writer) error {
return RunOnContainerWithRetry(ctx, client, namespace, pod, container, command, tty, stdin, stdout, stderr, DefaultRetryPolicy)
}

// RunOnContainerWithRetry is the explicit-policy variant of RunOnContainer.
func RunOnContainerWithRetry(ctx context.Context, client ExecClient, namespace, pod, container string, command []string, tty bool, stdin io.Reader, stdout io.Writer, stderr io.Writer, policy RetryPolicy) error {
if container == "" {
return Run(ctx, client, namespace, pod, command, tty, stdin, stdout, stderr)
return RunWithRetry(ctx, client, namespace, pod, command, tty, stdin, stdout, stderr, policy)
}
ec, ok := client.(ExecContainerClient)
if !ok {
return Run(ctx, client, namespace, pod, command, tty, stdin, stdout, stderr)
return RunWithRetry(ctx, client, namespace, pod, command, tty, stdin, stdout, stderr, policy)
}
return ec.ExecInteractiveInContainer(ctx, namespace, pod, container, tty, command, stdin, stdout, stderr)
return runExecWithRetry(ctx, stderr, policy, func() error {
return ec.ExecInteractiveInContainer(ctx, namespace, pod, container, tty, command, stdin, stdout, stderr)
})
}

func isRetryableError(err error) bool {
Expand Down
46 changes: 46 additions & 0 deletions internal/connect/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ func (f *fakeExecClient) ExecInteractive(ctx context.Context, namespace, pod str
return err
}

type fakeExecContainerClient struct {
fakeExecClient
containerCalls int
}

func (f *fakeExecContainerClient) ExecInteractiveInContainer(ctx context.Context, namespace, pod, container string, tty bool, command []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error {
f.containerCalls++
return f.ExecInteractive(ctx, namespace, pod, tty, command, stdin, stdout, stderr)
}

func TestRunWithRetrySucceedsAfterTransient(t *testing.T) {
fc := &fakeExecClient{errs: []error{errors.New("EOF"), nil}}
var out bytes.Buffer
Expand Down Expand Up @@ -73,3 +83,39 @@ func TestRunWithRetryNoRetryOnNonTransient(t *testing.T) {
t.Fatalf("expected no retry, got calls=%d", fc.calls)
}
}

func TestRunOnContainerRetriesTransientContainerExec(t *testing.T) {
fc := &fakeExecContainerClient{
fakeExecClient: fakeExecClient{errs: []error{errors.New("EOF"), nil}},
}
var out bytes.Buffer
err := RunOnContainerWithRetry(context.Background(), fc, "ns", "pod", "dev", []string{"sh"}, true, &out, &out, &out, RetryPolicy{
MaxAttempts: 3,
InitialBackoff: 1 * time.Millisecond,
MaxBackoff: 2 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
if fc.containerCalls < 2 {
t.Fatalf("expected container exec retry, got calls=%d", fc.containerCalls)
}
}

func TestRunOnContainerDoesNotRetryNonTransientContainerExec(t *testing.T) {
fc := &fakeExecContainerClient{
fakeExecClient: fakeExecClient{errs: []error{errors.New("permission denied")}},
}
var out bytes.Buffer
err := RunOnContainerWithRetry(context.Background(), fc, "ns", "pod", "dev", []string{"sh"}, true, &out, &out, &out, RetryPolicy{
MaxAttempts: 3,
InitialBackoff: 1 * time.Millisecond,
MaxBackoff: 2 * time.Millisecond,
})
if err == nil {
t.Fatal("expected error")
}
if fc.containerCalls != 1 {
t.Fatalf("expected no retry for non-transient container error, got calls=%d", fc.containerCalls)
}
}
Loading