From 4c2eb0d27d003e16e9b4754c9489d34a6496f220 Mon Sep 17 00:00:00 2001 From: Ramiro <64089641+ramiro-gamarra@users.noreply.github.com> Date: Tue, 7 Mar 2023 13:42:03 -0800 Subject: [PATCH] Improving port forwarding error handling (#1839) * adding error handling hook to portforwarder. documenting exported symbols. removing unnecessary build constraints * improving the port forward wrapper api --- test/integration/k8s_test.go | 142 ++++++++++++++++-------------- test/integration/label.go | 2 - test/integration/portforward.go | 151 +++++++++++++++++++++++--------- 3 files changed, 186 insertions(+), 109 deletions(-) diff --git a/test/integration/k8s_test.go b/test/integration/k8s_test.go index 7c5d30ac5b..448f232991 100644 --- a/test/integration/k8s_test.go +++ b/test/integration/k8s_test.go @@ -4,8 +4,6 @@ package k8s import ( "context" - "log" - //"dnc/test/integration/goldpinger" "errors" "flag" @@ -43,7 +41,6 @@ var ( kubeconfig = flag.String("test-kubeconfig", filepath.Join(homedir.HomeDir(), ".kube", "config"), "(optional) absolute path to the kubeconfig file") delegatedSubnetID = flag.String("delegated-subnet-id", "", "delegated subnet id for node labeling") delegatedSubnetName = flag.String("subnet-name", "", "subnet name for node labeling") - gpPodScaleCounts = []int{2, 10, 100, 2} ) func shouldLabelNodes() bool { @@ -141,8 +138,12 @@ func TestPodScaling(t *testing.T) { } }) + podsClient := clientset.CoreV1().Pods(deployment.Namespace) + + gpPodScaleCounts := []int{2, 10, 100, 2} for _, c := range gpPodScaleCounts { count := c + t.Run(fmt.Sprintf("replica count %d", count), func(t *testing.T) { replicaCtx, cancel := context.WithTimeout(ctx, (retryAttempts+1)*retryDelaySec) defer cancel() @@ -151,93 +152,98 @@ func TestPodScaling(t *testing.T) { t.Fatalf("could not scale deployment: %v", err) } - if !t.Run("all pods have IPs assigned", func(t *testing.T) { - podsClient := clientset.CoreV1().Pods(deployment.Namespace) + t.Log("checking that all pods have IPs assigned") - checkPodIPsFn := func() error { - podList, err := podsClient.List(ctx, metav1.ListOptions{LabelSelector: "app=goldpinger"}) - if err != nil { - return err - } + checkPodIPsFn := func() error { + podList, err := podsClient.List(ctx, metav1.ListOptions{LabelSelector: "app=goldpinger"}) + if err != nil { + return err + } - if len(podList.Items) == 0 { - return errors.New("no pods scheduled") - } + if len(podList.Items) == 0 { + return errors.New("no pods scheduled") + } - for _, pod := range podList.Items { - if pod.Status.Phase == apiv1.PodPending { - return errors.New("some pods still pending") - } + for _, pod := range podList.Items { + if pod.Status.Phase == apiv1.PodPending { + return errors.New("some pods still pending") } + } - for _, pod := range podList.Items { - if pod.Status.PodIP == "" { - return errors.New("a pod has not been allocated an IP") - } + for _, pod := range podList.Items { + if pod.Status.PodIP == "" { + return errors.New("a pod has not been allocated an IP") } - - return nil } - err := defaultRetrier.Do(ctx, checkPodIPsFn) + + return nil + } + + if err := defaultRetrier.Do(ctx, checkPodIPsFn); err != nil { + t.Fatalf("not all pods were allocated IPs: %v", err) + } + + t.Log("all pods have been allocated IPs") + t.Log("checking that all pods can ping each other") + + clusterCheckCtx, cancel := context.WithTimeout(ctx, 20*time.Minute) + defer cancel() + + pfOpts := PortForwardingOpts{ + Namespace: "default", + LabelSelector: "type=goldpinger-pod", + LocalPort: 9090, + DestPort: 8080, + } + + 
pingCheckFn := func() error { + pf, err := NewPortForwarder(restConfig, t, pfOpts) if err != nil { - t.Fatalf("not all pods were allocated IPs: %v", err) + t.Fatalf("could not build port forwarder: %v", err) } - t.Log("all pods have been allocated IPs") - }) { - errors.New("Pods don't have IP's") - return - } - t.Run("all pods can ping each other", func(t *testing.T) { - clusterCheckCtx, cancel := context.WithTimeout(ctx, 20*time.Minute) + portForwardCtx, cancel := context.WithTimeout(ctx, (retryAttempts+1)*retryDelaySec) defer cancel() - clusterCheckFn := func() error { - pf, err := NewPortForwarder(restConfig) - if err != nil { - t.Fatal(err) - } - - portForwardCtx, cancel := context.WithTimeout(ctx, (retryAttempts+1)*retryDelaySec) - defer cancel() - var streamHandle PortForwardStreamHandle - portForwardFn := func() error { - log.Printf("attempting port forward") - handle, err := pf.Forward(ctx, "default", "type=goldpinger-pod", 9090, 8080) - if err != nil { - return err - } + portForwardFn := func() error { + t.Log("attempting port forward") - streamHandle = handle - return nil + if err := pf.Forward(portForwardCtx); err != nil { + return fmt.Errorf("could not start port forward: %w", err) } - if err := defaultRetrier.Do(portForwardCtx, portForwardFn); err != nil { - t.Fatalf("could not start port forward within %v: %v", retryDelaySec.String(), err) - } - defer streamHandle.Stop() - gpClient := goldpinger.Client{Host: streamHandle.Url()} + return nil + } - clusterState, err := gpClient.CheckAll(clusterCheckCtx) - if err != nil { - return err - } + if err := defaultRetrier.Do(portForwardCtx, portForwardFn); err != nil { + t.Fatalf("could not start port forward within %v: %v", retryDelaySec.String(), err) + } - stats := goldpinger.ClusterStats(clusterState) - stats.PrintStats() - if stats.AllPingsHealthy() { - return nil - } + go pf.KeepAlive(clusterCheckCtx) + + defer pf.Stop() + + gpClient := goldpinger.Client{Host: pf.Address()} - return errors.New("not all pings are healthy") + clusterState, err := gpClient.CheckAll(clusterCheckCtx) + if err != nil { + return fmt.Errorf("could not check all goldpinger pods: %w", err) } - if err := defaultRetrier.Do(clusterCheckCtx, clusterCheckFn); err != nil { - t.Fatalf("cluster could not reach healthy state: %v", err) + stats := goldpinger.ClusterStats(clusterState) + stats.PrintStats() + if stats.AllPingsHealthy() { + return nil } - t.Log("all pings successful!") - }) + return errors.New("not all pings are healthy") + } + + if err := defaultRetrier.Do(clusterCheckCtx, pingCheckFn); err != nil { + t.Fatalf("cluster could not reach healthy state: %v", err) + } + + t.Log("all pings successful!") }) } } diff --git a/test/integration/label.go b/test/integration/label.go index 314ca954fc..a8f439f0c2 100644 --- a/test/integration/label.go +++ b/test/integration/label.go @@ -1,5 +1,3 @@ -//go:build integration - package k8s import ( diff --git a/test/integration/portforward.go b/test/integration/portforward.go index d3f601fb5b..605bd70602 100644 --- a/test/integration/portforward.go +++ b/test/integration/portforward.go @@ -1,5 +1,3 @@ -//go:build integration - package k8s import ( @@ -8,7 +6,10 @@ import ( "io" "math/rand" "net/http" + "sync" + "time" + "github.com/pkg/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" @@ -16,95 +17,167 @@ import ( "k8s.io/client-go/transport/spdy" ) +type logger interface { + Logf(format string, args ...any) +} + +// PortForwarder can manage a port forwarding 
session. type PortForwarder struct { clientset *kubernetes.Clientset transport http.RoundTripper upgrader spdy.Upgrader -} + logger logger -type PortForwardStreamHandle struct { - url string - stopChan chan struct{} -} + opts PortForwardingOpts -func (p *PortForwardStreamHandle) Stop() { - p.stopChan <- struct{}{} + stopChan chan struct{} + errChan chan error + address string + lazyAddress sync.Once } -func (p *PortForwardStreamHandle) Url() string { - return p.url +type PortForwardingOpts struct { + Namespace string + LabelSelector string + LocalPort int + DestPort int } -func NewPortForwarder(restConfig *rest.Config) (*PortForwarder, error) { +// NewPortForwarder creates a PortForwarder. +func NewPortForwarder(restConfig *rest.Config, logger logger, opts PortForwardingOpts) (*PortForwarder, error) { clientset, err := kubernetes.NewForConfig(restConfig) if err != nil { - return nil, fmt.Errorf("could not create clientset: %v", err) + return nil, fmt.Errorf("could not create clientset: %w", err) } + transport, upgrader, err := spdy.RoundTripperFor(restConfig) if err != nil { - return nil, fmt.Errorf("could not create spdy roundtripper: %v", err) + return nil, fmt.Errorf("could not create spdy roundtripper: %w", err) } + return &PortForwarder{ clientset: clientset, transport: transport, upgrader: upgrader, + logger: logger, + opts: opts, + stopChan: make(chan struct{}, 1), }, nil } // todo: can be made more flexible to allow a service to be specified -func (p *PortForwarder) Forward(ctx context.Context, namespace, labelSelector string, localPort, destPort int) (PortForwardStreamHandle, error) { - pods, err := p.clientset.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{LabelSelector: labelSelector, FieldSelector: "status.phase=Running"}) + +// Forward attempts to initiate port forwarding a pod and port using the configured namespace and labels. +// An error is returned if a port forwarding session could not be started. If no error is returned, the +// Address method can be used to communicate with the pod, and the Stop and KeepAlive methods can be used +// to manage the lifetime of the port forwarding session. +func (p *PortForwarder) Forward(ctx context.Context) error { + pods, err := p.clientset.CoreV1().Pods(p.opts.Namespace).List(ctx, metav1.ListOptions{LabelSelector: p.opts.LabelSelector, FieldSelector: "status.phase=Running"}) if err != nil { - return PortForwardStreamHandle{}, fmt.Errorf("could not list pods in %q with label %q: %v", namespace, labelSelector, err) + return fmt.Errorf("could not list pods in %q with label %q: %w", p.opts.Namespace, p.opts.LabelSelector, err) } + if len(pods.Items) < 1 { - return PortForwardStreamHandle{}, fmt.Errorf("no pods found in %q with label %q", namespace, labelSelector) + return fmt.Errorf("no pods found in %q with label %q", p.opts.Namespace, p.opts.LabelSelector) //nolint:goerr113 //no specific handling expected } + randomIndex := rand.Intn(len(pods.Items)) podName := pods.Items[randomIndex].Name portForwardURL := p.clientset.CoreV1().RESTClient().Post(). Resource("pods"). - Namespace(namespace). + Namespace(p.opts.Namespace). Name(podName). 
SubResource("portforward").URL() - stopChan := make(chan struct{}, 1) - errChan := make(chan error, 1) readyChan := make(chan struct{}, 1) - dialer := spdy.NewDialer(p.upgrader, &http.Client{Transport: p.transport}, http.MethodPost, portForwardURL) - ports := []string{fmt.Sprintf("%d:%d", localPort, destPort)} - pf, err := portforward.New(dialer, ports, stopChan, readyChan, io.Discard, io.Discard) + ports := []string{fmt.Sprintf("%d:%d", p.opts.LocalPort, p.opts.DestPort)} + pf, err := portforward.New(dialer, ports, p.stopChan, readyChan, io.Discard, io.Discard) if err != nil { - return PortForwardStreamHandle{}, fmt.Errorf("could not create portforwarder: %v", err) + return fmt.Errorf("could not create portforwarder: %w", err) } + errChan := make(chan error, 1) go func() { + // ForwardPorts is a blocking function thus it has to be invoked in a goroutine to allow callers to do + // other things, but it can return 2 kinds of errors: initial dial errors that will be caught in the select + // block below (Ready should not fire in these cases) and later errors if the connection is dropped. + // this is why we propagate the error channel to PortForwardStreamHandle: to allow callers to handle + // cases of eventual errors. errChan <- pf.ForwardPorts() }() var portForwardPort int select { case <-ctx.Done(): - return PortForwardStreamHandle{}, ctx.Err() + return fmt.Errorf("portforward cancelled: %w", ctx.Err()) case err := <-errChan: - return PortForwardStreamHandle{}, fmt.Errorf("portforward failed: %v", err) + return fmt.Errorf("portforward failed: %w", err) case <-pf.Ready: - ports, err := pf.GetPorts() + prts, err := pf.GetPorts() if err != nil { - return PortForwardStreamHandle{}, fmt.Errorf("get portforward port: %v", err) - } - for _, port := range ports { - portForwardPort = int(port.Local) - break + return fmt.Errorf("get portforward port: %w", err) } - if portForwardPort < 1 { - return PortForwardStreamHandle{}, fmt.Errorf("invalid port returned: %d", portForwardPort) + + if len(prts) < 1 { + return errors.New("no ports forwarded") } + + portForwardPort = int(prts[0].Local) } - return PortForwardStreamHandle{ - url: fmt.Sprintf("http://localhost:%d", portForwardPort), - stopChan: stopChan, - }, nil + // once successful, any subsequent port forwarding sessions from keep alive would yield the same address. + // since the address could be read at the same time as the session is renewed, it's appropriate to initialize + // lazily. + p.lazyAddress.Do(func() { + p.address = fmt.Sprintf("http://localhost:%d", portForwardPort) + }) + + p.errChan = errChan + + return nil +} + +// Address returns an address for communicating with a port-forwarded pod. +func (p *PortForwarder) Address() string { + return p.address +} + +// Stop terminates a port forwarding session. +func (p *PortForwarder) Stop() { + select { + case p.stopChan <- struct{}{}: + default: + } +} + +// KeepAlive can be used to restart the port forwarding session in the background. +func (p *PortForwarder) KeepAlive(ctx context.Context) { + for { + select { + case <-ctx.Done(): + p.logger.Logf("port forwarder: keep alive cancelled: %v", ctx.Err()) + return + case pfErr := <-p.errChan: + // as of client-go v0.26.1, if the connection is successful at first but then fails, + // an error is logged but only a nil error is sent to this channel. this will be fixed + // in v0.27.x, which at the time of writing has not been released. 
+ // + // see https://github.com/kubernetes/client-go/commit/d0842249d3b92ea67c446fe273f84fe74ebaed9f + // for the relevant change. + p.logger.Logf("port forwarder: received error signal: %v. restarting session", pfErr) + p.Stop() + if err := p.Forward(ctx); err != nil { + p.logger.Logf("port forwarder: could not restart session: %v. retrying", err) + + select { + case <-ctx.Done(): + p.logger.Logf("port forwarder: keep alive cancelled: %v", ctx.Err()) + return + case <-time.After(time.Second): // todo: make configurable? + continue + } + } + } + } }
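
Reviewer note (not part of the patch): below is a minimal sketch of how the reworked PortForwarder is intended to be driven end to end, assuming the goldpinger-style options used in TestPodScaling. The option values and the /check_all endpoint are illustrative assumptions; the constructor and the Forward/KeepAlive/Address/Stop methods are the ones added in test/integration/portforward.go, and *testing.T satisfies the logger interface via Logf.

package k8s

import (
	"context"
	"fmt"
	"net/http"
	"testing"
	"time"

	"k8s.io/client-go/rest"
)

// examplePortForwardLifecycle is a hypothetical helper showing the intended
// lifecycle: construct once, Forward to establish the session, KeepAlive in
// the background to survive dropped connections, talk to the pod through
// Address, and Stop when finished.
func examplePortForwardLifecycle(t *testing.T, restConfig *rest.Config) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()

	pf, err := NewPortForwarder(restConfig, t, PortForwardingOpts{
		Namespace:     "default", // illustrative values
		LabelSelector: "type=goldpinger-pod",
		LocalPort:     9090,
		DestPort:      8080,
	})
	if err != nil {
		return fmt.Errorf("could not build port forwarder: %w", err)
	}

	// Forward returns once the session is ready (or the initial dial fails);
	// the ForwardPorts goroutine keeps running in the background.
	if err := pf.Forward(ctx); err != nil {
		return fmt.Errorf("could not start port forward: %w", err)
	}
	defer pf.Stop()

	// KeepAlive drains the session's error channel and restarts the session
	// if the connection drops, until ctx is cancelled.
	go pf.KeepAlive(ctx)

	// Address stays the same across KeepAlive restarts (it is initialized once
	// via sync.Once), so it is safe to capture and reuse.
	resp, err := http.Get(pf.Address() + "/check_all") // hypothetical endpoint
	if err != nil {
		return fmt.Errorf("could not reach forwarded pod: %w", err)
	}
	defer resp.Body.Close()

	return nil
}

Because Stop sends on the buffered stopChan through a non-blocking select, the deferred Stop above and KeepAlive's internal restart path can both call it without coordinating with each other.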