diff --git a/_examples/tunnel.yaml b/_examples/tunnel.yaml new file mode 100644 index 0000000..b72aea0 --- /dev/null +++ b/_examples/tunnel.yaml @@ -0,0 +1,7 @@ +apiVersion: core.apoxy.dev/v1alpha2 +kind: Tunnel +metadata: + name: example +spec: + egressGateway: + enabled: true \ No newline at end of file diff --git a/pkg/apiserver/controllers/tunnel_agent_reconciler.go b/pkg/apiserver/controllers/tunnel_agent_reconciler.go index f2497a5..124c00a 100644 --- a/pkg/apiserver/controllers/tunnel_agent_reconciler.go +++ b/pkg/apiserver/controllers/tunnel_agent_reconciler.go @@ -7,14 +7,12 @@ import ( "time" "github.com/go-logr/logr" - "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/client-go/util/retry" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" - "sigs.k8s.io/controller-runtime/pkg/event" controllerlog "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/predicate" @@ -28,6 +26,8 @@ import ( // +kubebuilder:rbac:groups=core.apoxy.dev/v1alpha2,resources=tunnelagents/finalizers,verbs=update // +kubebuilder:rbac:groups=core.apoxy.dev/v1alpha2,resources=tunnels,verbs=get;list;watch +const indexControllerOwnerUID = ".metadata.controllerOwnerUID" + type TunnelAgentReconciler struct { client client.Client agentIPAM tunnet.IPAM @@ -45,6 +45,8 @@ func NewTunnelAgentReconciler(c client.Client, agentIPAM tunnet.IPAM, vniPool *v func (r *TunnelAgentReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { log := controllerlog.FromContext(ctx, "name", req.Name) + log.Info("Reconciling TunnelAgent") + var agent corev1alpha2.TunnelAgent if err := r.client.Get(ctx, req.NamespacedName, &agent); err != nil { if apierrors.IsNotFound(err) { @@ -53,16 +55,19 @@ func (r *TunnelAgentReconciler) Reconcile(ctx context.Context, req ctrl.Request) return ctrl.Result{}, err } - // handle deletion + // Handle deletion if !agent.DeletionTimestamp.IsZero() { + log.Info("Handling deletion of TunnelAgent") + if controllerutil.ContainsFinalizer(&agent, ApiServerFinalizer) { + log.Info("Releasing resources for TunnelAgent") + changed, err := r.releaseResourcesIfPresent(ctx, log, req.NamespacedName) if err != nil { return ctrl.Result{}, fmt.Errorf("failed to release resources: %w", err) } - // releaseResourcesIfPresent potentially mutates the object, so we need - // to refetch it to avoid conflicts when we remove the finalizer. 
+ // Refetch to avoid conflicts if we modified the object if changed { if err := r.client.Get(ctx, req.NamespacedName, &agent); err != nil { if apierrors.IsNotFound(err) { @@ -72,6 +77,8 @@ func (r *TunnelAgentReconciler) Reconcile(ctx context.Context, req ctrl.Request) } } + log.Info("Removing finalizer from TunnelAgent") + // Remove finalizer controllerutil.RemoveFinalizer(&agent, ApiServerFinalizer) if err := r.client.Update(ctx, &agent); err != nil { @@ -82,7 +89,7 @@ func (r *TunnelAgentReconciler) Reconcile(ctx context.Context, req ctrl.Request) return ctrl.Result{}, nil } - // ensure finalizer + // Ensure finalizer if !controllerutil.ContainsFinalizer(&agent, ApiServerFinalizer) { controllerutil.AddFinalizer(&agent, ApiServerFinalizer) if err := r.client.Update(ctx, &agent); err != nil { @@ -90,7 +97,7 @@ func (r *TunnelAgentReconciler) Reconcile(ctx context.Context, req ctrl.Request) } } - // fetch owner Tunnel + // Fetch owner Tunnel tunnelName := agent.Spec.TunnelRef.Name if tunnelName == "" { // TODO: why would this happen? Should we mark the agent as failed. @@ -98,17 +105,18 @@ func (r *TunnelAgentReconciler) Reconcile(ctx context.Context, req ctrl.Request) return ctrl.Result{}, nil } + log.Info("Fetching owner Tunnel", "tunnelName", tunnelName) + var tunnel corev1alpha2.Tunnel if err := r.client.Get(ctx, client.ObjectKey{Name: tunnelName}, &tunnel); err != nil { if apierrors.IsNotFound(err) { - // TODO: why would this happen? Should we mark the agent as failed. log.Info("Referenced Tunnel not found; skipping", "tunnelName", tunnelName) return ctrl.Result{RequeueAfter: 30 * time.Second}, nil } return ctrl.Result{}, err } - // ensure controller ownerRef agent -> tunnel + // Ensure controller ownerRef agent -> tunnel changed, err := r.ensureControllerOwner(&agent, &tunnel) if err != nil { return ctrl.Result{}, err @@ -128,24 +136,26 @@ func (r *TunnelAgentReconciler) Reconcile(ctx context.Context, req ctrl.Request) } func (r *TunnelAgentReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error { - // Reconcile when spec generation changes OR when status (e.g., Connections) changes. - statusOrGenChanged := predicate.Funcs{ - CreateFunc: func(e event.CreateEvent) bool { return true }, - DeleteFunc: func(e event.DeleteEvent) bool { return true }, - UpdateFunc: func(e event.UpdateEvent) bool { - oldObj, ok1 := e.ObjectOld.(*corev1alpha2.TunnelAgent) - newObj, ok2 := e.ObjectNew.(*corev1alpha2.TunnelAgent) - if !ok1 || !ok2 { - return false + // Cache index to quickly look up TunnelAgents by their controller owner UID. + if err := mgr.GetFieldIndexer().IndexField( + ctx, + &corev1alpha2.TunnelAgent{}, + indexControllerOwnerUID, + func(obj client.Object) []string { + ta := obj.(*corev1alpha2.TunnelAgent) + for _, or := range ta.GetOwnerReferences() { + if or.Controller != nil && *or.Controller { + return []string{string(or.UID)} + } } - genChanged := oldObj.GetGeneration() != newObj.GetGeneration() - statusDiff := !equality.Semantic.DeepEqual(oldObj.Status, newObj.Status) - return genChanged || statusDiff + return nil }, + ); err != nil { + return err } return ctrl.NewControllerManagedBy(mgr). - For(&corev1alpha2.TunnelAgent{}, builder.WithPredicates(statusOrGenChanged)). + For(&corev1alpha2.TunnelAgent{}, builder.WithPredicates(&predicate.ResourceVersionChangedPredicate{})). 
Complete(r) } @@ -182,11 +192,10 @@ func (r *TunnelAgentReconciler) ensureConnectionAllocations( } conn.Address = pfx.String() newlyAllocatedPrefixes = append(newlyAllocatedPrefixes, pfx) - log.Info("Allocated overlay address", "connectionID", conn.ID, "address", conn.Address) } - // Allocate VNI if missing (nil means "unset"; zero can be valid but your pool won't return 0) + // Allocate VNI if missing if conn.VNI == nil { vni, err := r.vniPool.Allocate() if err != nil { @@ -201,13 +210,11 @@ func (r *TunnelAgentReconciler) ensureConnectionAllocations( } conn.VNI = &vni newlyAllocatedVNIs = append(newlyAllocatedVNIs, vni) - log.Info("Allocated VNI", "connectionID", conn.ID, "vni", *conn.VNI) } } if len(newlyAllocatedPrefixes) == 0 && len(newlyAllocatedVNIs) == 0 { - // nothing changed return nil } @@ -234,13 +241,11 @@ func (r *TunnelAgentReconciler) releaseResourcesIfPresent( ) (bool, error) { var changed bool err := retry.RetryOnConflict(retry.DefaultBackoff, func() error { - // Always work on a fresh copy to avoid write conflicts. var cur corev1alpha2.TunnelAgent if err := r.client.Get(ctx, key, &cur); err != nil { if apierrors.IsNotFound(err) { return nil } - return err } @@ -249,7 +254,6 @@ func (r *TunnelAgentReconciler) releaseResourcesIfPresent( return nil } - // Free resources that are still recorded in status and clear the fields. for i := range cur.Status.Connections { conn := &cur.Status.Connections[i] @@ -263,7 +267,7 @@ func (r *TunnelAgentReconciler) releaseResourcesIfPresent( return fmt.Errorf("failed to release address %q: %w", conn.Address, err) } log.Info("Released overlay address", "connectionID", conn.ID, "address", conn.Address) - conn.Address = "" // clear in status + conn.Address = "" changed = true } @@ -272,18 +276,16 @@ func (r *TunnelAgentReconciler) releaseResourcesIfPresent( vni := *conn.VNI r.vniPool.Release(vni) log.Info("Released VNI", "connectionID", conn.ID, "vni", vni) - conn.VNI = nil // clear in status + conn.VNI = nil changed = true } } - // Commit to status. if changed { if err := r.client.Status().Update(ctx, &cur); err != nil { return err } } - return nil }) @@ -297,14 +299,8 @@ func (r *TunnelAgentReconciler) ensureControllerOwner(child client.Object, owner } } - // Set controller reference (overwrites any existing controller owner) - if err := controllerutil.SetControllerReference( - owner, - child, - r.client.Scheme(), - ); err != nil { + if err := controllerutil.SetControllerReference(owner, child, r.client.Scheme()); err != nil { return false, err } - return true, nil } diff --git a/pkg/apiserver/controllers/tunnel_reconciler.go b/pkg/apiserver/controllers/tunnel_reconciler.go index 6e16414..bdb41e1 100644 --- a/pkg/apiserver/controllers/tunnel_reconciler.go +++ b/pkg/apiserver/controllers/tunnel_reconciler.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/base64" "io" + "time" apierrors "k8s.io/apimachinery/pkg/api/errors" ctrl "sigs.k8s.io/controller-runtime" @@ -39,13 +40,49 @@ func (r *TunnelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr if apierrors.IsNotFound(err) { return ctrl.Result{}, nil } - return ctrl.Result{}, err } - // handle deletion + // Handle deletion. if !tunnel.DeletionTimestamp.IsZero() { + log.Info("Handling deletion of Tunnel") + if controllerutil.ContainsFinalizer(&tunnel, ApiServerFinalizer) { + // Manually implement garbage collection of controller-owned TunnelAgents. + // This is due to us not using the built in gc controller from k8s.io/controller-manager. 
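The deletion path that follows lists controller-owned TunnelAgents through the controller-owner-UID field index, deletes any that remain, and requeues until they are gone before dropping the Tunnel's finalizer, emulating foreground deletion because this API server does not run the upstream garbage-collection controller. For comparison only, a minimal sketch of the declarative equivalent a GC-enabled API server would offer; `deleteTunnelForeground` is an illustrative helper, not code from this change:

```go
package example

import (
	"context"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"sigs.k8s.io/controller-runtime/pkg/client"

	corev1alpha2 "github.com/apoxy-dev/apoxy/api/core/v1alpha2"
)

// deleteTunnelForeground requests cascading deletion: the API server's GC
// (when present) removes controller-owned TunnelAgents via their
// ownerReferences before the parent Tunnel is allowed to disappear.
func deleteTunnelForeground(ctx context.Context, c client.Client, tunnel *corev1alpha2.Tunnel) error {
	return c.Delete(ctx, tunnel, client.PropagationPolicy(metav1.DeletePropagationForeground))
}
```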
+ + // List controller-owned TunnelAgents by indexed controller owner UID. + var agents corev1alpha2.TunnelAgentList + if err := r.client.List( + ctx, + &agents, + client.MatchingFields{indexControllerOwnerUID: string(tunnel.GetUID())}, + ); err != nil { + return ctrl.Result{}, err + } + + // Kick off deletion for any children that still exist. + stillPresent := false + for i := range agents.Items { + a := &agents.Items[i] + stillPresent = true + if a.DeletionTimestamp.IsZero() { + if err := r.client.Delete(ctx, a); err != nil && !apierrors.IsNotFound(err) { + return ctrl.Result{}, err + } + } + } + + // If any child remains (possibly terminating due to its own finalizers), + // requeue and keep the parent's finalizer to emulate foreground deletion. + if stillPresent { + log.Info("Waiting for controller-owned TunnelAgents to terminate", "remaining", len(agents.Items)) + return ctrl.Result{RequeueAfter: 2 * time.Second}, nil + } + + // No children remain → remove the parent's finalizer. + log.Info("All controller-owned TunnelAgents gone; removing Tunnel finalizer") + // Remove finalizer controllerutil.RemoveFinalizer(&tunnel, ApiServerFinalizer) if err := r.client.Update(ctx, &tunnel); err != nil { @@ -56,7 +93,7 @@ func (r *TunnelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr return ctrl.Result{}, nil } - // ensure finalizer + // Ensure finalizer. if !controllerutil.ContainsFinalizer(&tunnel, ApiServerFinalizer) { controllerutil.AddFinalizer(&tunnel, ApiServerFinalizer) if err := r.client.Update(ctx, &tunnel); err != nil { @@ -64,7 +101,7 @@ func (r *TunnelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr } } - // ensure bearer token + // Ensure bearer token in status. if tunnel.Status.Credentials == nil || tunnel.Status.Credentials.Token == "" { log.Info("Generating new bearer token for Tunnel") @@ -88,7 +125,7 @@ func (r *TunnelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr func (r *TunnelReconciler) SetupWithManager(mgr ctrl.Manager) error { return ctrl.NewControllerManagedBy(mgr). - For(&corev1alpha2.Tunnel{}, builder.WithPredicates(predicate.GenerationChangedPredicate{})). + For(&corev1alpha2.Tunnel{}, builder.WithPredicates(&predicate.ResourceVersionChangedPredicate{})). Complete(r) } diff --git a/pkg/apiserver/manager.go b/pkg/apiserver/manager.go index 84586a8..cc90323 100644 --- a/pkg/apiserver/manager.go +++ b/pkg/apiserver/manager.go @@ -49,6 +49,7 @@ import ( "github.com/apoxy-dev/apoxy/pkg/log" apoxynet "github.com/apoxy-dev/apoxy/pkg/tunnel/net" tunnet "github.com/apoxy-dev/apoxy/pkg/tunnel/net" + "github.com/apoxy-dev/apoxy/pkg/tunnel/vni" ctrlv1alpha1 "github.com/apoxy-dev/apoxy/api/controllers/v1alpha1" corev1alpha "github.com/apoxy-dev/apoxy/api/core/v1alpha" @@ -186,6 +187,7 @@ type options struct { resources []resource.Object proxyIPAM tunnet.IPAM agentIPAM tunnet.IPAM + vniPool *vni.VNIPool } // WithJWTKeys sets the JWT key pair. @@ -317,6 +319,13 @@ func WithAgentIPAM(ipam tunnet.IPAM) Option { } } +// WithVNIPool sets the VNI pool for tunnel agents. +func WithVNIPool(pool *vni.VNIPool) Option { + return func(o *options) { + o.vniPool = pool + } +} + func defaultResources() []resource.Object { // Higher versions need to be registered first as storage resources. 
return []resource.Object{ @@ -383,6 +392,8 @@ func defaultOptions(ctx context.Context) (*options, error) { return nil, fmt.Errorf("failed to create proxy IPAM: %w", err) } + vniPool := vni.NewVNIPool() + opts := &options{ clientConfig: NewClientConfig(), @@ -406,6 +417,7 @@ func defaultOptions(ctx context.Context) (*options, error) { proxyIPAM: proxyIPAM, agentIPAM: agentIPAM, + vniPool: vniPool, } // Generate default JWT key pair if not provided @@ -491,6 +503,7 @@ func (m *Manager) Start( return fmt.Errorf("failed to set up Proxy controller: %v", err) } + // Legacy v1alpha1 TunnelNode controller log.Infof("Registering TunnelNode controller") tunnelNodeReconciler := controllers.NewTunnelNodeReconciler( m.manager.GetClient(), @@ -513,6 +526,22 @@ func (m *Manager) Start( return nil }) + log.Infof("Registering Tunnel controller") + tunnelReconciler := controllers.NewTunnelReconciler(m.manager.GetClient()) + if err := tunnelReconciler.SetupWithManager(m.manager); err != nil { + return fmt.Errorf("failed to set up Tunnel controller: %v", err) + } + + log.Infof("Registering TunnelAgent controller") + tunnelAgentReconciler := controllers.NewTunnelAgentReconciler( + m.manager.GetClient(), + dOpts.agentIPAM, + dOpts.vniPool, + ) + if err := tunnelAgentReconciler.SetupWithManager(ctx, m.manager); err != nil { + return fmt.Errorf("failed to set up TunnelAgent controller: %v", err) + } + log.Infof("Registering Gateway controller") gwOpts := []gateway.Option{} if dOpts.enableKubeAPI { diff --git a/pkg/cmd/alpha/tunnel_relay.go b/pkg/cmd/alpha/tunnel_relay.go index 27b0e52..3d6ebf2 100644 --- a/pkg/cmd/alpha/tunnel_relay.go +++ b/pkg/cmd/alpha/tunnel_relay.go @@ -1,43 +1,59 @@ package alpha import ( + "bytes" "context" + "crypto/tls" "errors" "fmt" "log/slog" "net" - "net/netip" - "runtime" + "os" + goruntime "runtime" + "time" - "github.com/alphadose/haxmap" + "github.com/go-logr/logr" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/healthz" + metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" "github.com/apoxy-dev/icx" + corev1alpha2 "github.com/apoxy-dev/apoxy/api/core/v1alpha2" "github.com/apoxy-dev/apoxy/pkg/cryptoutils" "github.com/apoxy-dev/apoxy/pkg/tunnel" "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate" "github.com/apoxy-dev/apoxy/pkg/tunnel/controllers" "github.com/apoxy-dev/apoxy/pkg/tunnel/hasher" - tunnet "github.com/apoxy-dev/apoxy/pkg/tunnel/net" "github.com/apoxy-dev/apoxy/pkg/tunnel/router" - "github.com/apoxy-dev/apoxy/pkg/tunnel/vni" ) var ( + relayDevMode bool // whether to run in development mode (testing only) relayName string // the name for the relay - relayTunnel string // the name of the tunnel to serve extIfaceName string // the external interface name listenAddress string // the address to listen on for incoming connections userMode bool // whether to use user-mode routing (no special privileges required) relaySocksListenAddr string // when using user-mode routing, the address to listen on for SOCKS5 connections relayPcapPath string // optional pcap path + certFile string // path to TLS certificate (PEM) used when not in dev mode + keyFile string // path to TLS private key (PEM) used when not in dev mode + idSecretFile string // path to secret for the ID hasher used when not in dev mode + 
relayMetricsAddr string // bind address for the metrics endpoint + relayHealthAddr string // bind address for the health/ready probes + labelSelector string // label selector for controllers ) var tunnelRelayCmd = &cobra.Command{ Use: "relay", - Short: "Run a development mode tunnel relay", + Short: "Run a tunnel relay", Hidden: true, RunE: func(cmd *cobra.Command, args []string) error { routerOpts := []router.Option{ @@ -85,92 +101,103 @@ var tunnelRelayCmd = &cobra.Command{ handler = r.Handler } - idHasher := hasher.NewHasher([]byte("C0rr3ct-Horse-Battery-Staple_But_Salty_1x9Q7p3Z")) + var ( + idHasher *hasher.Hasher + cert tls.Certificate + ) - _, cert, err := cryptoutils.GenerateSelfSignedTLSCert(relayName) - if err != nil { - return fmt.Errorf("failed to generate self-signed TLS cert: %w", err) + // Use a self-signed cert and a fixed hasher secret in dev mode. + if relayDevMode { + idHasher = hasher.NewHasher([]byte("C0rr3ct-Horse-Battery-Staple_But_Salty_1x9Q7p3Z")) + + _, c, err := cryptoutils.GenerateSelfSignedTLSCert(relayName) + if err != nil { + return fmt.Errorf("failed to generate self-signed TLS cert: %w", err) + } + cert = c + } else { + if idSecretFile == "" { + return fmt.Errorf("when not in development mode, --id-secret-file is required") + } + if certFile == "" || keyFile == "" { + return fmt.Errorf("when not in development mode, both --cert-file and --key-file are required") + } + + secret, err := os.ReadFile(idSecretFile) + if err != nil { + return fmt.Errorf("failed to read id hasher secret file: %w", err) + } + idHasher = hasher.NewHasher(bytes.TrimSpace(secret)) + + c, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return fmt.Errorf("failed to load TLS certificate/key pair: %w", err) + } + cert = c } relay := tunnel.NewRelay(relayName, pcQuic, cert, handler, idHasher, rtr) - slog.Info("Configuring relay", slog.String("tunnelName", relayTunnel), slog.String("listenAddress", listenAddress), slog.String("externalInterface", extIfaceName)) + g, ctx := errgroup.WithContext(cmd.Context()) - relay.SetCredentials(relayTunnel, "letmein") - relay.SetRelayAddresses(relayTunnel, []string{pcQuic.LocalAddr().String()}) - relay.SetEgressGateway(true) + clientConfig := clientcmd.NewNonInteractiveDeferredLoadingClientConfig( + clientcmd.NewDefaultClientConfigLoadingRules(), + &clientcmd.ConfigOverrides{}, + ) - systemULA := tunnet.NewULA(cmd.Context(), tunnet.SystemNetworkID) - agentIPAM, err := systemULA.IPAM(cmd.Context(), 96) + config, err := clientConfig.ClientConfig() if err != nil { - return fmt.Errorf("failed to create system ULA IPAM: %w", err) + return fmt.Errorf("loading kubeconfig: %w", err) } - vpool := vni.NewVNIPool() - - type connectionMetadata struct { - prefix netip.Prefix - vni uint + scheme := runtime.NewScheme() + if err := corev1alpha2.Install(scheme); err != nil { + return fmt.Errorf("installing corev1alpha2 scheme: %w", err) } - connections := haxmap.New[string, connectionMetadata]() - - relay.SetOnConnect(func(_ context.Context, agentName string, conn controllers.Connection) error { - slog.Info("Connected", slog.String("agent", agentName), slog.String("connID", conn.ID())) - - pfx, err := agentIPAM.Allocate() - if err != nil { - return fmt.Errorf("failed to allocate prefix: %w", err) - } - slog.Info("Allocated prefix for connection", - slog.String("agent", agentName), slog.String("connID", conn.ID()), - slog.String("prefix", pfx.String())) + ctrl.SetLogger(logr.FromSlogHandler(slog.Default().Handler())) - if err := 
conn.SetOverlayAddress(pfx.String()); err != nil { - agentIPAM.Release(pfx) - return fmt.Errorf("failed to set overlay address on connection: %w", err) - } + mgr, err := ctrl.NewManager(config, ctrl.Options{ + Cache: cache.Options{ + SyncPeriod: ptr.To(30 * time.Second), + }, + Scheme: scheme, + LeaderElection: false, + Metrics: metricsserver.Options{BindAddress: relayMetricsAddr}, + HealthProbeBindAddress: relayHealthAddr, + }) + if err != nil { + return fmt.Errorf("unable to start manager: %w", err) + } - vni, err := vpool.Allocate() - if err != nil { - return fmt.Errorf("failed to allocate VNI: %w", err) - } + if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil { + return fmt.Errorf("failed to add healthz check: %w", err) + } - slog.Info("Allocated VNI for connection", - slog.String("agent", agentName), slog.String("connID", conn.ID()), - slog.Int("vni", int(vni))) + if err := mgr.AddReadyzCheck("readyz", healthz.Ping); err != nil { + return fmt.Errorf("failed to add readyz check: %w", err) + } - if err := conn.SetVNI(cmd.Context(), vni); err != nil { - return fmt.Errorf("failed to set VNI on connection: %w", err) - } + tunnelReconciler := controllers.NewTunnelReconciler(mgr.GetClient(), relay, labelSelector) + if err := tunnelReconciler.SetupWithManager(mgr); err != nil { + return fmt.Errorf("failed to setup tunnel reconciler: %w", err) + } - connections.Set(conn.ID(), connectionMetadata{prefix: pfx, vni: vni}) + tunnelAgentReconciler := controllers.NewTunnelAgentReconciler(mgr.GetClient(), relay, labelSelector) + if err := tunnelAgentReconciler.SetupWithManager(mgr); err != nil { + return fmt.Errorf("failed to setup tunnel agent reconciler: %w", err) + } - return nil + g.Go(func() error { + return mgr.Start(ctx) }) - relay.SetOnDisconnect(func(_ context.Context, agentName, id string) error { - if cm, ok := connections.Get(id); ok { - if err := agentIPAM.Release(cm.prefix); err != nil { - slog.Error("Failed to release prefix", slog.Any("error", err), - slog.String("agent", agentName), slog.String("connID", id), - slog.String("prefix", cm.prefix.String())) - } - - vpool.Release(cm.vni) - - connections.Del(id) - - slog.Info("Disconnected", slog.String("agent", agentName), slog.String("connID", id)) - } else { - return fmt.Errorf("unknown connection ID: %s", id) - } - - return nil + g.Go(func() error { + return relay.Start(ctx) }) - if err := relay.Start(cmd.Context()); err != nil && !errors.Is(err, context.Canceled) { - return fmt.Errorf("failed to start relay: %w", err) + if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) { + return fmt.Errorf("failed to run relay: %w", err) } return nil @@ -179,12 +206,18 @@ var tunnelRelayCmd = &cobra.Command{ func init() { tunnelRelayCmd.Flags().StringVarP(&relayName, "name", "n", "dev", "The name of the relay.") - tunnelRelayCmd.Flags().StringVarP(&relayTunnel, "tunnel-name", "t", "dev", "The name of the tunnel to serve.") - tunnelRelayCmd.Flags().StringVar(&extIfaceName, "ext-iface", "eth0", "External interface name.") + tunnelRelayCmd.Flags().BoolVar(&relayDevMode, "dev", false, "Run the relay in development mode (insecure).") + tunnelRelayCmd.Flags().StringVar(&extIfaceName, "ext-iface", "eth0", "External interface name (when not using --user-mode).") tunnelRelayCmd.Flags().StringVar(&listenAddress, "listen-addr", "127.0.0.1:6081", "The address to listen on for incoming connections.") - tunnelRelayCmd.Flags().BoolVar(&userMode, "user-mode", runtime.GOOS != "linux", "Use user-mode routing (no special privileges 
required).") + tunnelRelayCmd.Flags().BoolVar(&userMode, "user-mode", goruntime.GOOS != "linux", "Use user-mode routing (no special privileges required).") tunnelRelayCmd.Flags().StringVar(&relaySocksListenAddr, "socks-addr", "localhost:1080", "When using user-mode routing, the address to listen on for SOCKS5 connections.") tunnelRelayCmd.Flags().StringVarP(&relayPcapPath, "pcap", "p", "", "Path to an optional packet capture file to write.") + tunnelRelayCmd.Flags().StringVar(&certFile, "cert-file", "", "Path to a TLS certificate (PEM). Required when not running with --dev.") + tunnelRelayCmd.Flags().StringVar(&keyFile, "key-file", "", "Path to a TLS private key (PEM). Required when not running with --dev.") + tunnelRelayCmd.Flags().StringVar(&idSecretFile, "id-secret-file", "", "Path to the secret used for stable ID hashing. Required when not running with --dev.") + tunnelRelayCmd.Flags().StringVar(&relayMetricsAddr, "metrics-addr", "127.0.0.1:8081", "Bind address for the metrics endpoint.") + tunnelRelayCmd.Flags().StringVar(&relayHealthAddr, "health-addr", "127.0.0.1:8080", "Bind address for the health and readiness probes.") + tunnelRelayCmd.Flags().StringVar(&labelSelector, "label-selector", "", "Label selector to filter Tunnel and TunnelAgent objects (e.g. 'app=apoxy').") tunnelCmd.AddCommand(tunnelRelayCmd) } diff --git a/pkg/cmd/alpha/tunnel_run.go b/pkg/cmd/alpha/tunnel_run.go index a951282..97d1f6d 100644 --- a/pkg/cmd/alpha/tunnel_run.go +++ b/pkg/cmd/alpha/tunnel_run.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "math/rand/v2" "net" "net/http" "net/netip" @@ -19,15 +20,17 @@ import ( "github.com/dpeckett/network" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/tools/clientcmd" + "github.com/apoxy-dev/apoxy/client/versioned" "github.com/apoxy-dev/apoxy/pkg/netstack" "github.com/apoxy-dev/apoxy/pkg/tunnel/api" - "github.com/apoxy-dev/apoxy/pkg/tunnel/randalloc" - "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate" "github.com/apoxy-dev/apoxy/pkg/tunnel/conntrackpc" + "github.com/apoxy-dev/apoxy/pkg/tunnel/randalloc" "github.com/apoxy-dev/apoxy/pkg/tunnel/router" ) @@ -61,6 +64,40 @@ var tunnelRunCmd = &cobra.Command{ return fmt.Errorf("--min-conns must be at least 1") } + // Attempt kubernetes-based discovery if no relayAddr/token provided. 
+ if seedRelayAddr == "" || token == "" { + clientConfig := clientcmd.NewNonInteractiveDeferredLoadingClientConfig( + clientcmd.NewDefaultClientConfigLoadingRules(), + &clientcmd.ConfigOverrides{}, + ) + + config, err := clientConfig.ClientConfig() + if err != nil { + return fmt.Errorf("loading kubeconfig: %w", err) + } + + clientset, err := versioned.NewForConfig(config) + if err != nil { + return fmt.Errorf("creating clientset: %w", err) + } + + tunnel, err := clientset.CoreV1alpha2().Tunnels().Get(cmd.Context(), tunnelName, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("fetching Tunnel %q: %w", tunnelName, err) + } + + if len(tunnel.Status.Addresses) == 0 { + return fmt.Errorf("tunnel %q has no relay addresses configured", tunnelName) + } + + if tunnel.Status.Credentials == nil { + return fmt.Errorf("tunnel %q has no credentials configured", tunnelName) + } + + seedRelayAddr = tunnel.Status.Addresses[rand.IntN(len(tunnel.Status.Addresses))] + token = tunnel.Status.Credentials.Token + } + g, ctx := errgroup.WithContext(cmd.Context()) // Start health endpoint server if configured. @@ -131,7 +168,7 @@ var tunnelRunCmd = &cobra.Command{ // - when the session ends, releases the relay for i := 0; i < minConns; i++ { g.Go(func() error { - return manageConnectionSlot(ctx, packetPlane.QuicMux, handler, relayAddressPool, tlsConf) + return manageConnectionSlot(ctx, packetPlane.QuicMux, handler, r, relayAddressPool, tlsConf) }) } @@ -142,9 +179,9 @@ var tunnelRunCmd = &cobra.Command{ func init() { tunnelRunCmd.Flags().StringVarP(&agentName, "agent", "a", "", "The name of this agent.") tunnelRunCmd.Flags().StringVarP(&tunnelName, "name", "n", "", "The logical name of the tunnel to connect to.") - tunnelRunCmd.Flags().StringVarP(&seedRelayAddr, "relay-addr", "r", "", "Seed relay address (host:port). The client bootstraps here, then uses the returned relay list.") + tunnelRunCmd.Flags().StringVarP(&seedRelayAddr, "relay-addr", "r", "", "Seed relay address (host:port), required if not using kubernetes-based discovery.") tunnelRunCmd.Flags().IntVar(&minConns, "min-conns", 1, "Minimum number of relays to maintain connections to (randomly selected from the server-provided list).") - tunnelRunCmd.Flags().StringVarP(&token, "token", "k", "", "The token to use for authenticating with the tunnel relays.") + tunnelRunCmd.Flags().StringVarP(&token, "token", "k", "", "The token to use for authenticating with the tunnel relays, required if not using kubernetes-based discovery.") tunnelRunCmd.Flags().BoolVar(&insecureSkipVerify, "insecure-skip-verify", false, "Skip TLS certificate verification for relay connections.") tunnelRunCmd.Flags().StringVarP(&pcapPath, "pcap", "p", "", "Path to an optional packet capture file to write.") tunnelRunCmd.Flags().StringVar(&socksListenAddr, "socks-addr", "localhost:1080", "Listen address for SOCKS proxy.") @@ -152,8 +189,6 @@ func init() { cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("agent")) cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("name")) - cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("relay-addr")) - cobra.CheckErr(tunnelRunCmd.MarkFlagRequired("token")) tunnelCmd.AddCommand(tunnelRunCmd) } @@ -320,25 +355,6 @@ func initRouter( h := r.Handler - // Add assigned addresses. 
- for _, addrStr := range connectResp.Addresses { - slog.Info("Adding address", slog.String("address", addrStr)) - - addr, err := netip.ParsePrefix(addrStr) - if err != nil { - slog.Warn("Failed to parse address", - slog.String("address", addrStr), - slog.Any("error", err)) - continue - } - - if err := r.AddAddr(addr, nil); err != nil { - slog.Warn("Failed to add address", - slog.String("address", addrStr), - slog.Any("error", err)) - } - } - // Add routes. for _, rt := range connectResp.Routes { slog.Info("Adding route", slog.String("destination", rt.Destination)) @@ -462,6 +478,7 @@ func manageRelayConnectionOnce( ctx context.Context, pcQuic net.PacketConn, handler *icx.Handler, + r *router.ICXNetstackRouter, relayAddr string, tlsConf *tls.Config, onConnected func(*api.ConnectResponse), @@ -470,11 +487,22 @@ func manageRelayConnectionOnce( currentClient *api.Client currentConnID string connectResp *api.ConnectResponse + sessionAddrs []netip.Prefix ) // When this function returns, that relay session is down, so decrement // if we had actually marked it active. defer func() { + // Remove any addrs we attached for this session. + for _, a := range sessionAddrs { + if err := r.DelAddr(a); err != nil { + slog.Warn("Failed to remove address on disconnect", + slog.String("address", a.String()), + slog.Any("error", err)) + } else { + slog.Info("Removed address", slog.String("address", a.String())) + } + } if currentConnID != "" { connectionHealthCounter.Add(-1) } @@ -521,6 +549,25 @@ func manageRelayConnectionOnce( // Successful session establishment: mark this connection active. connectionHealthCounter.Add(1) + // Parse and attach the assigned addresses for this live session. + if connectResp != nil { + addrs, err := parsePrefixes(connectResp.Addresses) + if err != nil { + slog.Warn("Failed to parse assigned addresses", slog.Any("error", err)) + } else { + sessionAddrs = addrs + for _, a := range addrs { + if err := r.AddAddr(a, nil); err != nil { + slog.Warn("Failed to add address", + slog.String("address", a.String()), + slog.Any("error", err)) + } else { + slog.Info("Added address", slog.String("address", a.String())) + } + } + } + } + // Once connected, run key rotation and watchdog concurrently. sessionCtx, sessionCancel := context.WithCancel(ctx) defer sessionCancel() @@ -576,6 +623,7 @@ func manageConnectionSlot( ctx context.Context, pcQuicMux *conntrackpc.ConntrackPacketConn, handler *icx.Handler, + r *router.ICXNetstackRouter, relayAddressPool *randalloc.RandAllocator[string], tlsConf *tls.Config, ) error { @@ -643,7 +691,7 @@ func manageConnectionSlot( } // Run the actual session lifecycle (watchdog, key rotation, etc). - sessErr := manageRelayConnectionOnce(ctx, pcQuic, handler, relayAddr, tlsConf, onConnected) + sessErr := manageRelayConnectionOnce(ctx, pcQuic, handler, r, relayAddr, tlsConf, onConnected) if ctx.Err() != nil { return ctx.Err() diff --git a/pkg/cmd/alpha/tunnel_run_test.go b/pkg/cmd/alpha/tunnel_run_test.go index 9e8165c..93089d2 100644 --- a/pkg/cmd/alpha/tunnel_run_test.go +++ b/pkg/cmd/alpha/tunnel_run_test.go @@ -33,12 +33,12 @@ func TestTunnelRun(t *testing.T) { var connected bool // onConnect assigns VNI and overlay address so handleConnect can proceed. - onConnect := func(ctx context.Context, agent string, conn controllers.Connection) error { + onConnect := func(ctx context.Context, tunnelName, agentName string, conn controllers.Connection) error { // Choose a deterministic VNI for the test. 
conn.SetVNI(ctx, 101) conn.SetOverlayAddress("10.0.0.2/32") - t.Logf("onConnect called, agent=%s", agent) - if agent == "test-agent" { + t.Logf("onConnect called, agent=%s", agentName) + if agentName == "test-agent" { connected = true } return nil @@ -60,7 +60,7 @@ func TestTunnelRun(t *testing.T) { "alpha", "tunnel", "run", "--agent", "test-agent", "--name", "test-tunnel", - "--relay-addr", r.Address(), + "--relay-addr", r.Address().String(), "--token", "letmein", "--insecure-skip-verify", }) @@ -76,7 +76,7 @@ func TestTunnelRun(t *testing.T) { // TODO: verify traffic routing through the tunnel } -func startRelay(t *testing.T, token string, onConnect func(context.Context, string, controllers.Connection) error, onDisconnect func(context.Context, string, string) error) (*tunnel.Relay, tls.Certificate, func()) { +func startRelay(t *testing.T, token string, onConnect func(context.Context, string, string, controllers.Connection) error, onDisconnect func(context.Context, string, string) error) (*tunnel.Relay, tls.Certificate, func()) { t.Helper() pc, err := net.ListenPacket("udp", "127.0.0.1:0") diff --git a/pkg/tunnel/controllers/relay.go b/pkg/tunnel/controllers/relay.go index 44d0298..31594ba 100644 --- a/pkg/tunnel/controllers/relay.go +++ b/pkg/tunnel/controllers/relay.go @@ -18,7 +18,7 @@ type Relay interface { // SetEgressGateway enables or disables internet egress for the tunnel agents. SetEgressGateway(enabled bool) // SetOnConnect sets a callback that is invoked when a new connection is established to the relay. - SetOnConnect(onConnect func(ctx context.Context, agentName string, conn Connection) error) + SetOnConnect(onConnect func(ctx context.Context, tunnelName, agentName string, conn Connection) error) // SetOnDisconnect sets a callback that is invoked when a connection is closed. 
SetOnDisconnect(onDisconnect func(ctx context.Context, agentName, id string) error) } diff --git a/pkg/tunnel/controllers/tunnel_agent_reconciler.go b/pkg/tunnel/controllers/tunnel_agent_reconciler.go index 7390bac..6a86722 100644 --- a/pkg/tunnel/controllers/tunnel_agent_reconciler.go +++ b/pkg/tunnel/controllers/tunnel_agent_reconciler.go @@ -6,7 +6,6 @@ import ( "log/slog" "github.com/alphadose/haxmap" - "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -15,13 +14,13 @@ import ( "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" - "sigs.k8s.io/controller-runtime/pkg/event" + controllerlog "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/predicate" corev1alpha2 "github.com/apoxy-dev/apoxy/api/core/v1alpha2" ) -const tunnelRelayFinalizerTmpl = "tunnelrelay.apoxy.dev/%s/finalizer" +const tunnelRelayFinalizerTmpl = "tunnelrelay.apoxy.dev/%s-finalizer" type TunnelAgentReconciler struct { client client.Client @@ -45,6 +44,10 @@ func NewTunnelAgentReconciler(c client.Client, relay Relay, labelSelector string } func (r *TunnelAgentReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + log := controllerlog.FromContext(ctx, "name", req.Name) + + log.Info("Reconciling TunnelAgent") + var agent corev1alpha2.TunnelAgent if err := r.client.Get(ctx, req.NamespacedName, &agent); err != nil { if apierrors.IsNotFound(err) { @@ -105,37 +108,47 @@ func (r *TunnelAgentReconciler) SetupWithManager(mgr ctrl.Manager) error { return fmt.Errorf("failed to create label selector predicate: %w", err) } - // Reconcile when spec generation changes OR when status (e.g., Connections) changes. - statusOrGenChanged := predicate.Funcs{ - CreateFunc: func(e event.CreateEvent) bool { return true }, - DeleteFunc: func(e event.DeleteEvent) bool { return true }, - UpdateFunc: func(e event.UpdateEvent) bool { - oldObj, ok1 := e.ObjectOld.(*corev1alpha2.TunnelAgent) - newObj, ok2 := e.ObjectNew.(*corev1alpha2.TunnelAgent) - if !ok1 || !ok2 { - return false - } - genChanged := oldObj.GetGeneration() != newObj.GetGeneration() - statusDiff := !equality.Semantic.DeepEqual(oldObj.Status, newObj.Status) - return genChanged || statusDiff - }, - } - return ctrl.NewControllerManagedBy(mgr). - For(&corev1alpha2.TunnelAgent{}, builder.WithPredicates(ls, statusOrGenChanged)). + For(&corev1alpha2.TunnelAgent{}, builder.WithPredicates(&predicate.ResourceVersionChangedPredicate{}, ls)). Complete(r) } // AddConnection registers a new active connection for the given agent. -func (r *TunnelAgentReconciler) AddConnection(ctx context.Context, agentName string, conn Connection) error { +func (r *TunnelAgentReconciler) AddConnection(ctx context.Context, tunnelName, agentName string, conn Connection) error { // Track the connection in-memory. r.conns.Set(conn.ID(), conn) + // Get the parent Tunnel object. + var tunnel corev1alpha2.Tunnel + if err := r.client.Get(ctx, types.NamespacedName{Name: tunnelName}, &tunnel); err != nil { + return fmt.Errorf("failed to get parent Tunnel %q for TunnelAgent %q: %w", tunnelName, agentName, err) + } + // Upsert connection in status (first), so we truly have a connection before adding the finalizer. 
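One detail worth calling out in the hunk above: the finalizer template changed from "tunnelrelay.apoxy.dev/%s/finalizer" to "tunnelrelay.apoxy.dev/%s-finalizer". Standard Kubernetes object-metadata validation expects each metadata.finalizers entry to be a qualified name (an optional DNS prefix, at most one "/", then a name), so the old template produced an invalid value such as "tunnelrelay.apoxy.dev/dev/finalizer" for a relay named "dev". A quick way to sanity-check generated names (illustrative sketch, not part of the change):

```go
package example

import (
	"fmt"

	"k8s.io/apimachinery/pkg/util/validation"
)

// relayFinalizer builds the relay-scoped finalizer and rejects values that
// would fail qualified-name validation (e.g. names containing a second "/").
func relayFinalizer(relayName string) (string, error) {
	f := fmt.Sprintf("tunnelrelay.apoxy.dev/%s-finalizer", relayName)
	if errs := validation.IsQualifiedName(f); len(errs) > 0 {
		return "", fmt.Errorf("invalid finalizer %q: %v", f, errs)
	}
	return f, nil
}
```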
if err := retry.RetryOnConflict(retry.DefaultBackoff, func() error { - var cur corev1alpha2.TunnelAgent + cur := corev1alpha2.TunnelAgent{ + ObjectMeta: metav1.ObjectMeta{ + Name: agentName, + Labels: tunnel.ObjectMeta.Labels, + }, + Spec: corev1alpha2.TunnelAgentSpec{ + TunnelRef: corev1alpha2.TunnelRef{ + Name: tunnelName, + }, + }, + } + if err := r.client.Get(ctx, types.NamespacedName{Name: agentName}, &cur); err != nil { - return err + if apierrors.IsNotFound(err) { + // Create minimal object if missing. + slog.Info("Creating TunnelAgent object", slog.String("agent", agentName)) + + if err := r.client.Create(ctx, &cur); err != nil { + return fmt.Errorf("failed to create TunnelAgent %q: %w", agentName, err) + } + } else { + return fmt.Errorf("failed to get TunnelAgent %q: %w", agentName, err) + } } now := metav1.Now() @@ -193,8 +206,7 @@ func (r *TunnelAgentReconciler) AddConnection(ctx context.Context, agentName str // RemoveConnection deregisters a connection from the given agent by its ID. func (r *TunnelAgentReconciler) RemoveConnection(ctx context.Context, agentName, id string) error { // Drop from in-memory map. - conn, ok := r.conns.GetAndDel(id) - if ok { + if conn, ok := r.conns.GetAndDel(id); ok { if err := conn.Close(); err != nil { slog.Warn("Failed to close connection", slog.String("id", id), slog.Any("error", err)) } @@ -204,6 +216,9 @@ func (r *TunnelAgentReconciler) RemoveConnection(ctx context.Context, agentName, if err := retry.RetryOnConflict(retry.DefaultBackoff, func() error { var cur corev1alpha2.TunnelAgent if err := r.client.Get(ctx, types.NamespacedName{Name: agentName}, &cur); err != nil { + if apierrors.IsNotFound(err) { + return nil // already gone + } return err } @@ -221,12 +236,30 @@ func (r *TunnelAgentReconciler) RemoveConnection(ctx context.Context, agentName, } // If no connections remain for THIS relay, remove our relay-scoped finalizer. + // Additionally, if there are NO connections remaining at all, delete the TunnelAgent. return retry.RetryOnConflict(retry.DefaultBackoff, func() error { var cur corev1alpha2.TunnelAgent if err := r.client.Get(ctx, types.NamespacedName{Name: agentName}, &cur); err != nil { + if apierrors.IsNotFound(err) { + return nil + } return err } + // Check if any connections remain at all. + if len(cur.Status.Connections) == 0 { + // Ensure our finalizer (if present) is removed to avoid blocking deletion. + if controllerutil.ContainsFinalizer(&cur, r.finalizer) { + controllerutil.RemoveFinalizer(&cur, r.finalizer) + if err := r.client.Update(ctx, &cur); err != nil { + return err + } + } + // Delete the TunnelAgent object (ignore if it disappears between calls). + return client.IgnoreNotFound(r.client.Delete(ctx, &cur)) + } + + // Otherwise, only consider removing our finalizer if *this relay* no longer has any live connections. hasRelayConn := false for _, c := range cur.Status.Connections { if _, ok := r.conns.Get(c.ID); ok { diff --git a/pkg/tunnel/controllers/tunnel_agent_reconciler_test.go b/pkg/tunnel/controllers/tunnel_agent_reconciler_test.go index c091e24..516c94f 100644 --- a/pkg/tunnel/controllers/tunnel_agent_reconciler_test.go +++ b/pkg/tunnel/controllers/tunnel_agent_reconciler_test.go @@ -28,12 +28,14 @@ func TestTunnelAgentReconciler_AddConnection(t *testing.T) { scheme := runtime.NewScheme() require.NoError(t, corev1alpha2.Install(scheme)) - agent := mkAgent("agent-1") + tunnel := mkTunnel("tunnel-1") + + agent := mkAgent("tunnel-1", "agent-1") c := fakeclient.NewClientBuilder(). WithScheme(scheme). 
- WithStatusSubresource(&corev1alpha2.TunnelAgent{}). - WithObjects(agent). + WithStatusSubresource(&corev1alpha2.Tunnel{}, &corev1alpha2.TunnelAgent{}). + WithObjects(tunnel, agent). Build() relay := &mockRelay{} @@ -47,7 +49,7 @@ func TestTunnelAgentReconciler_AddConnection(t *testing.T) { conn := &mockConn{} conn.On("ID").Return("conn-123") - require.NoError(t, r.AddConnection(ctx, agent.Name, conn)) + require.NoError(t, r.AddConnection(ctx, tunnel.Name, agent.Name, conn)) var got corev1alpha2.TunnelAgent require.NoError(t, c.Get(ctx, types.NamespacedName{Name: agent.Name}, &got)) @@ -60,7 +62,7 @@ func TestTunnelAgentReconciler_AddConnection(t *testing.T) { assert.Nil(t, entry.VNI) assert.Equal(t, relay.Address().String(), entry.RelayAddress) - finalizer := "tunnelrelay.apoxy.dev/" + relay.Name() + "/finalizer" + finalizer := "tunnelrelay.apoxy.dev/" + relay.Name() + "-finalizer" assert.True(t, controllerutil.ContainsFinalizer(&got, finalizer)) conn.AssertExpectations(t) @@ -73,11 +75,14 @@ func TestTunnelAgentReconciler_RemoveConnection(t *testing.T) { scheme := runtime.NewScheme() require.NoError(t, corev1alpha2.Install(scheme)) - agent := mkAgent("agent-2") + tunnel := mkTunnel("tunnel-1") + + agent := mkAgent("tunnel-1", "agent-2") + c := fakeclient.NewClientBuilder(). WithScheme(scheme). - WithStatusSubresource(&corev1alpha2.TunnelAgent{}). - WithObjects(agent). + WithStatusSubresource(&corev1alpha2.Tunnel{}, &corev1alpha2.TunnelAgent{}). + WithObjects(tunnel, agent). Build() relay := &mockRelay{} @@ -87,7 +92,7 @@ func TestTunnelAgentReconciler_RemoveConnection(t *testing.T) { relay.On("SetOnDisconnect", mock.Anything).Return().Once() r := controllers.NewTunnelAgentReconciler(c, relay, "") - finalizer := "tunnelrelay.apoxy.dev/" + relay.Name() + "/finalizer" + finalizer := "tunnelrelay.apoxy.dev/" + relay.Name() + "-finalizer" // Two mock conns conn1 := &mockConn{} @@ -98,8 +103,8 @@ func TestTunnelAgentReconciler_RemoveConnection(t *testing.T) { conn2.On("ID").Return("c2").Maybe() conn2.On("Close").Return(nil).Once() - require.NoError(t, r.AddConnection(ctx, agent.Name, conn1)) - require.NoError(t, r.AddConnection(ctx, agent.Name, conn2)) + require.NoError(t, r.AddConnection(ctx, tunnel.Name, agent.Name, conn1)) + require.NoError(t, r.AddConnection(ctx, tunnel.Name, agent.Name, conn2)) // Remove first require.NoError(t, r.RemoveConnection(ctx, agent.Name, "c1")) @@ -108,11 +113,11 @@ func TestTunnelAgentReconciler_RemoveConnection(t *testing.T) { assert.Len(t, got.Status.Connections, 1) assert.True(t, controllerutil.ContainsFinalizer(&got, finalizer)) - // Remove second + // Remove second — the CR should be deleted now that no connections remain require.NoError(t, r.RemoveConnection(ctx, agent.Name, "c2")) - require.NoError(t, c.Get(ctx, types.NamespacedName{Name: agent.Name}, &got)) - assert.Empty(t, got.Status.Connections) - assert.False(t, controllerutil.ContainsFinalizer(&got, finalizer)) + err := c.Get(ctx, types.NamespacedName{Name: agent.Name}, &got) + require.Error(t, err) + assert.True(t, apierrors.IsNotFound(err)) relay.AssertExpectations(t) } @@ -123,11 +128,14 @@ func TestTunnelAgentReconciler_ClosesConnections(t *testing.T) { scheme := runtime.NewScheme() require.NoError(t, corev1alpha2.Install(scheme)) - agent := mkAgent("agent-3") + tunnel := mkTunnel("tunnel-1") + + agent := mkAgent("tunnel-1", "agent-3") + c := fakeclient.NewClientBuilder(). WithScheme(scheme). - WithStatusSubresource(&corev1alpha2.TunnelAgent{}). - WithObjects(agent). 
+ WithStatusSubresource(&corev1alpha2.Tunnel{}, &corev1alpha2.TunnelAgent{}). + WithObjects(tunnel, agent). Build() relay := &mockRelay{} @@ -137,7 +145,7 @@ func TestTunnelAgentReconciler_ClosesConnections(t *testing.T) { relay.On("SetOnDisconnect", mock.Anything).Return().Once() r := controllers.NewTunnelAgentReconciler(c, relay, "") - finalizer := "tunnelrelay.apoxy.dev/" + relay.Name() + "/finalizer" + finalizer := "tunnelrelay.apoxy.dev/" + relay.Name() + "-finalizer" // Mock conn that should be closed conn := &mockConn{} @@ -145,7 +153,7 @@ func TestTunnelAgentReconciler_ClosesConnections(t *testing.T) { conn.On("Close").Return(nil).Once() // Add connection -> status + in-memory tracking + finalizer - require.NoError(t, r.AddConnection(ctx, agent.Name, conn)) + require.NoError(t, r.AddConnection(ctx, tunnel.Name, agent.Name, conn)) // Ensure finalizer exists before deletion var cur corev1alpha2.TunnelAgent @@ -178,11 +186,14 @@ func TestTunnelAgentReconcile_SetsAddressAndVNI(t *testing.T) { scheme := runtime.NewScheme() require.NoError(t, corev1alpha2.Install(scheme)) - agent := mkAgent("agent-4") + tunnel := mkTunnel("tunnel-1") + + agent := mkAgent("tunnel-1", "agent-4") + c := fakeclient.NewClientBuilder(). WithScheme(scheme). - WithStatusSubresource(&corev1alpha2.TunnelAgent{}). - WithObjects(agent). + WithStatusSubresource(&corev1alpha2.Tunnel{}, &corev1alpha2.TunnelAgent{}). + WithObjects(tunnel, agent). Build() relay := &mockRelay{} @@ -197,7 +208,7 @@ func TestTunnelAgentReconcile_SetsAddressAndVNI(t *testing.T) { conn := &mockConn{} conn.On("ID").Return("live-1") - require.NoError(t, r.AddConnection(ctx, agent.Name, conn)) + require.NoError(t, r.AddConnection(ctx, tunnel.Name, agent.Name, conn)) // Simulate the apiserver reconciler filling status.address & vni var cur corev1alpha2.TunnelAgent @@ -210,7 +221,7 @@ func TestTunnelAgentReconcile_SetsAddressAndVNI(t *testing.T) { // Expect our live connection to receive SetOverlayAddress + SetVNI on reconcile conn.On("SetOverlayAddress", "10.123.0.5/32").Return(nil).Once() - conn.On("SetVNI", uint(4242)).Return(nil).Once() + conn.On("SetVNI", mock.Anything, uint(4242)).Return(nil).Once() _, err := r.Reconcile(ctx, ctrl.Request{NamespacedName: types.NamespacedName{Name: agent.Name}}) require.NoError(t, err) @@ -219,14 +230,31 @@ func TestTunnelAgentReconcile_SetsAddressAndVNI(t *testing.T) { relay.AssertExpectations(t) } -func mkAgent(name string) *corev1alpha2.TunnelAgent { +func mkTunnel(tunnelName string) *corev1alpha2.Tunnel { + return &corev1alpha2.Tunnel{ + TypeMeta: metav1.TypeMeta{ + Kind: "Tunnel", + APIVersion: "core.apoxy.dev/v1alpha2", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: tunnelName, + }, + } +} + +func mkAgent(tunnelName, agentName string) *corev1alpha2.TunnelAgent { return &corev1alpha2.TunnelAgent{ TypeMeta: metav1.TypeMeta{ Kind: "TunnelAgent", APIVersion: "core.apoxy.dev/v1alpha2", }, ObjectMeta: metav1.ObjectMeta{ - Name: name, + Name: agentName, + }, + Spec: corev1alpha2.TunnelAgentSpec{ + TunnelRef: corev1alpha2.TunnelRef{ + Name: tunnelName, + }, }, } } diff --git a/pkg/tunnel/controllers/tunnel_reconciler.go b/pkg/tunnel/controllers/tunnel_reconciler.go index 265e991..ce05461 100644 --- a/pkg/tunnel/controllers/tunnel_reconciler.go +++ b/pkg/tunnel/controllers/tunnel_reconciler.go @@ -11,6 +11,7 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" + controllerlog "sigs.k8s.io/controller-runtime/pkg/log" 
"sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/reconcile" @@ -32,6 +33,10 @@ func NewTunnelReconciler(c client.Client, relay Relay, labelSelector string) *Tu } func (r *TunnelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + log := controllerlog.FromContext(ctx, "name", req.Name) + + log.Info("Reconciling Tunnel") + var tunnel corev1alpha2.Tunnel if err := r.client.Get(ctx, req.NamespacedName, &tunnel); err != nil { if apierrors.IsNotFound(err) { @@ -43,10 +48,13 @@ func (r *TunnelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr // Update relay credentials if they have changed. if tunnel.Status.Credentials != nil { + log.Info("Updating credentials for tunnel") + r.relay.SetCredentials(tunnel.Name, tunnel.Status.Credentials.Token) } // Update relay addresses if they have changed. + log.Info("Updating relay addresses for tunnel") r.relay.SetRelayAddresses(tunnel.Name, tunnel.Status.Addresses) // Update egress gateway setting @@ -91,6 +99,6 @@ func (r *TunnelReconciler) SetupWithManager(mgr ctrl.Manager) error { } return ctrl.NewControllerManagedBy(mgr). - For(&corev1alpha2.Tunnel{}, builder.WithPredicates(predicate.GenerationChangedPredicate{}, ls)). + For(&corev1alpha2.Tunnel{}, builder.WithPredicates(&predicate.ResourceVersionChangedPredicate{}, ls)). Complete(r) } diff --git a/pkg/tunnel/controllers/tunnel_reconciler_test.go b/pkg/tunnel/controllers/tunnel_reconciler_test.go index 7fa647f..1c97b65 100644 --- a/pkg/tunnel/controllers/tunnel_reconciler_test.go +++ b/pkg/tunnel/controllers/tunnel_reconciler_test.go @@ -87,7 +87,7 @@ func (m *mockRelay) SetEgressGateway(enabled bool) { m.Called(enabled) } -func (m *mockRelay) SetOnConnect(onConnect func(ctx context.Context, agentName string, conn controllers.Connection) error) { +func (m *mockRelay) SetOnConnect(onConnect func(ctx context.Context, tunnelName, agentName string, conn controllers.Connection) error) { m.Called(onConnect) } diff --git a/pkg/tunnel/relay.go b/pkg/tunnel/relay.go index 8e24e11..a236a4a 100644 --- a/pkg/tunnel/relay.go +++ b/pkg/tunnel/relay.go @@ -30,7 +30,9 @@ import ( ) const ( - keyLifespan = 24 * time.Hour + keyLifespan = 24 * time.Hour + gcMaxSilence = 120 * time.Second + gcCheckInterval = 5 * time.Second ) type Relay struct { @@ -45,7 +47,8 @@ type Relay struct { tokens *haxmap.Map[string, string] // map[tunnelName]token relayAddrs *haxmap.Map[string, []string] // map[tunnelName][]string conns *haxmap.Map[string, *connection] // map[connectionID]Connection - onConnect func(ctx context.Context, agentName string, conn controllers.Connection) error + agents *haxmap.Map[string, string] // map[connectionID]agentName + onConnect func(ctx context.Context, tunnelName, agentName string, conn controllers.Connection) error onDisconnect func(ctx context.Context, agentName, id string) error } @@ -60,6 +63,7 @@ func NewRelay(name string, pc net.PacketConn, cert tls.Certificate, handler *icx tokens: haxmap.New[string, string](), relayAddrs: haxmap.New[string, []string](), conns: haxmap.New[string, *connection](), + agents: haxmap.New[string, string](), } } @@ -69,8 +73,9 @@ func (r *Relay) Name() string { } // Address is the underlay address of the relay. 
-func (r *Relay) Address() string { - return r.pc.LocalAddr().String() +func (r *Relay) Address() netip.AddrPort { + ua := r.pc.LocalAddr().(*net.UDPAddr) + return netip.AddrPortFrom(netip.MustParseAddr(ua.IP.String()), uint16(ua.Port)) } // SetCredentials sets the authentication token used by agents to authenticate with the relay. @@ -92,7 +97,7 @@ func (r *Relay) SetEgressGateway(enabled bool) { } // SetOnConnect sets a callback that is invoked when a new connection is established to the relay. -func (r *Relay) SetOnConnect(onConnect func(ctx context.Context, agentName string, conn controllers.Connection) error) { +func (r *Relay) SetOnConnect(onConnect func(ctx context.Context, tunnelName, agentName string, conn controllers.Connection) error) { r.mu.Lock() defer r.mu.Unlock() @@ -130,6 +135,16 @@ func (r *Relay) Start(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) + // Start the router to handle network traffic. + g.Go(func() error { + return r.router.Start(ctx) + }) + + // Start the garbage collector. + g.Go(func() error { + return r.startGC(ctx, gcMaxSilence, gcCheckInterval) + }) + g.Go(func() error { <-ctx.Done() @@ -151,11 +166,6 @@ func (r *Relay) Start(ctx context.Context) error { return srv.Close() }) - // Start the router to handle network traffic. - g.Go(func() error { - return r.router.Start(ctx) - }) - g.Go(func() error { slog.Info("Starting relay", slog.String("addr", ln.Addr().String())) if err := srv.ServeListener(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { @@ -164,7 +174,11 @@ func (r *Relay) Start(ctx context.Context) error { return nil }) - return g.Wait() + if err := g.Wait(); err != nil && !errors.Is(err, context.Canceled) { + return err + } + + return nil } func (r *Relay) handleConnect(w http.ResponseWriter, req *http.Request, ps httprouter.Params) { @@ -202,12 +216,14 @@ func (r *Relay) handleConnect(w http.ResponseWriter, req *http.Request, ps httpr } r.conns.Set(conn.ID(), conn) + r.agents.Set(conn.ID(), request.Agent) r.mu.Lock() onConnect := r.onConnect r.mu.Unlock() - if err := onConnect(req.Context(), request.Agent, conn); err != nil { + tunnelName := ps.ByName("name") + if err := onConnect(req.Context(), tunnelName, request.Agent, conn); err != nil { slog.Error("onConnect callback failed", slog.Any("error", err)) http.Error(w, "Failed to handle connection", http.StatusInternalServerError) return @@ -265,13 +281,11 @@ func (r *Relay) handleConnect(w http.ResponseWriter, req *http.Request, ps httpr } } if r.egressGateway { - // Default route for all traffic. 
routes = append(routes, api.Route{Destination: "0.0.0.0/0"}, api.Route{Destination: "::/0"}) } - tunnelName := ps.ByName("name") relayAddrs, _ := r.relayAddrs.Get(tunnelName) resp := api.ConnectResponse{ @@ -310,6 +324,7 @@ func (r *Relay) handleDisconnect(w http.ResponseWriter, req *http.Request, ps ht http.Error(w, "Connection not found", http.StatusNotFound) return } + r.agents.Del(request.ID) if err := conn.Close(); err != nil { slog.Warn("Failed to close connection", slog.Any("error", err)) @@ -409,7 +424,6 @@ func (r *Relay) withAuth(next httprouter.Handle) httprouter.Handle { } if storedToken, ok := r.tokens.Get(tunnelName); !ok || storedToken != tokenStr { - slog.Warn("Invalid token for tunnel", slog.String("tunnel", tunnelName)) http.Error(w, "Unauthorized", http.StatusUnauthorized) r.closeConn(w, http3.ErrCodeRequestRejected, "unauthorized") return @@ -431,6 +445,70 @@ func (r *Relay) closeConn(w http.ResponseWriter, code http3.ErrCode, msg string) _ = h3c.CloseWithError(quic.ApplicationErrorCode(code), msg) } +func (r *Relay) startGC(ctx context.Context, maxSilence, checkInterval time.Duration) error { + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + now := time.Now() + r.conns.ForEach(func(id string, conn *connection) bool { + vni := conn.VNI() + if vni == nil { + return true + } + + vnet, ok := r.handler.GetVirtualNetwork(*vni) + if !ok { + return true + } + + lastRxNs := vnet.Stats.LastRXUnixNano.Load() + lastRx := now + if lastRxNs != 0 { + lastRx = time.Unix(0, lastRxNs) + } + + if since := now.Sub(lastRx); since > maxSilence { + // Connection has been silent for too long — clean it up. + if _, ok := r.conns.GetAndDel(id); ok { + agentName, _ := r.agents.Get(id) + r.agents.Del(id) + + slog.Warn("GC: dropping idle connection", + slog.String("id", id), + slog.Duration("silence", since), + slog.Duration("maxSilence", maxSilence), + ) + + if err := conn.Close(); err != nil { + slog.Warn("GC: failed to close connection", + slog.String("id", id), + slog.Any("error", err)) + } + + r.mu.Lock() + onDisconnect := r.onDisconnect + r.mu.Unlock() + if onDisconnect != nil { + if err := onDisconnect(ctx, agentName, id); err != nil { + slog.Warn("GC: onDisconnect callback failed", + slog.String("id", id), + slog.String("agent", agentName), + slog.Any("error", err)) + } + } + } + } + return true + }) + } + } +} + func randomKey() (api.Key, error) { var key api.Key _, err := rand.Read(key[:]) diff --git a/pkg/tunnel/relay_test.go b/pkg/tunnel/relay_test.go index 0713ec0..44b819d 100644 --- a/pkg/tunnel/relay_test.go +++ b/pkg/tunnel/relay_test.go @@ -31,9 +31,7 @@ import ( func TestRelay_Connect_UpdateKeys_Disconnect(t *testing.T) { const goodToken = "secret-token" - // onConnect assigns VNI and overlay address so handleConnect can proceed. - onConnect := func(ctx context.Context, agent string, conn controllers.Connection) error { - // Choose a deterministic VNI for the test. 
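The garbage collector introduced above polls every gcCheckInterval (5s) and drops connections whose virtual network has been silent for longer than gcMaxSilence (120s); a connection with no recorded RX timestamp (LastRXUnixNano == 0) is treated as fresh and is never collected by this check alone, which is also why the test further down back-dates the timestamp instead of waiting out the idle window. The idleness check reduces to a small predicate (illustrative only, assuming the LastRXUnixNano semantics used above):

```go
package example

import "time"

// connIsIdle mirrors the relay GC check: idle once the last received packet
// is older than maxSilence; a zero timestamp means "no traffic seen yet" and
// is treated as fresh.
func connIsIdle(lastRxUnixNano int64, now time.Time, maxSilence time.Duration) bool {
	if lastRxUnixNano == 0 {
		return false
	}
	return now.Sub(time.Unix(0, lastRxUnixNano)) > maxSilence
}
```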
+	onConnect := func(ctx context.Context, tunnelName, agentName string, conn controllers.Connection) error {
 		conn.SetVNI(ctx, 101)
 		conn.SetOverlayAddress("10.0.0.2/32")
 		return nil
@@ -47,7 +45,7 @@ func TestRelay_Connect_UpdateKeys_Disconnect(t *testing.T) {
 		return nil
 	}
 
-	r, caCert, stop := startRelay(t, goodToken, onConnect, onDisconnect)
+	r, caCert, stop, _ := startRelay(t, goodToken, onConnect, onDisconnect)
 	t.Cleanup(stop)
 
 	c := clientForRelay(t, r, caCert, goodToken)
@@ -58,7 +56,6 @@ func TestRelay_Connect_UpdateKeys_Disconnect(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
 	t.Cleanup(cancel)
 
-	// Connect
 	connectResp, err := c.Connect(ctx)
 	require.NoError(t, err)
 	require.NotEmpty(t, connectResp.ID)
@@ -69,25 +66,130 @@ func TestRelay_Connect_UpdateKeys_Disconnect(t *testing.T) {
 	require.WithinDuration(t, time.Now().Add(24*time.Hour), connectResp.Keys.ExpiresAt, time.Minute)
 
 	firstEpoch := connectResp.Keys.Epoch
-	require.EqualValues(t, 0, firstEpoch, "initial key epoch should start at 0")
+	require.EqualValues(t, 0, firstEpoch)
 
-	// UpdateKeys
 	upd, err := c.UpdateKeys(ctx, connectResp.ID)
 	require.NoError(t, err)
 	require.Equal(t, connectResp.ID, connectResp.ID)
-	require.GreaterOrEqual(t, int(upd.Keys.Epoch), int(firstEpoch+1), "epoch must increment")
+	require.GreaterOrEqual(t, int(upd.Keys.Epoch), int(firstEpoch+1))
 	require.WithinDuration(t, time.Now().Add(24*time.Hour), upd.Keys.ExpiresAt, time.Minute)
 
-	// Disconnect
 	err = c.Disconnect(ctx, connectResp.ID)
 	require.NoError(t, err)
 
-	// Verify callback observed the same info
 	require.Equal(t, "it-agent", disc.Agent)
 	require.Equal(t, connectResp.ID, disc.ID)
 }
 
-func startRelay(t *testing.T, token string, onConnect func(context.Context, string, controllers.Connection) error, onDisconnect func(context.Context, string, string) error) (*tunnel.Relay, tls.Certificate, func()) {
+func TestRelay_InvalidAuthClosesQUIC(t *testing.T) {
+	const goodToken = "correct-token"
+	const badToken = "wrong-token"
+
+	onConnect := func(ctx context.Context, tunnelName, agentName string, conn controllers.Connection) error { return nil }
+	onDisconnect := func(ctx context.Context, agentName, id string) error { return nil }
+
+	r, caCert, stop, _ := startRelay(t, goodToken, onConnect, onDisconnect)
+	t.Cleanup(stop)
+
+	tlsCfg := &tls.Config{
+		RootCAs:    cryptoutils.CertPoolForCertificate(caCert),
+		ServerName: "localhost",
+	}
+
+	var captured quic.EarlyConnection
+	rt := &http3.Transport{
+		TLSClientConfig: tlsCfg,
+		Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
+			c, err := quic.DialAddrEarly(ctx, addr, tlsConf, cfg)
+			if err == nil {
+				captured = c
+			}
+			return c, err
+		},
+	}
+	t.Cleanup(func() { _ = rt.Close() })
+
+	h3Client := &http.Client{
+		Transport: rt,
+		Timeout:   3 * time.Second,
+	}
+
+	url := "https://" + r.Address().String() + "/v1/tunnel/test-tunnel"
+	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader([]byte(`{}`)))
+	require.NoError(t, err)
+	req.Header.Set("Authorization", "Bearer "+badToken)
+	req.Header.Set("Content-Type", "application/json")
+
+	_, err = h3Client.Do(req)
+	require.True(t, err != nil && strings.Contains(err.Error(), "H3_REQUEST_REJECTED"), "expected request to be rejected, got: %v", err)
+
+	require.NotNil(t, captured, "should have captured the QUIC connection")
+
+	select {
+	case <-captured.Context().Done():
+	case <-time.After(750 * time.Millisecond):
+		t.Fatalf("expected QUIC connection to be closed after unauthorized response, but it remained open")
+	}
+}
+
+func TestRelay_GarbageCollector_DropsIdleConnections(t *testing.T) {
+	const token = "gc-token"
+
+	discCh := make(chan api.Request, 1)
+	onDisconnect := func(ctx context.Context, agent, id string) error {
+		discCh <- api.Request{Agent: agent, ID: id}
+		return nil
+	}
+	onConnect := func(ctx context.Context, tunnelName, agentName string, conn controllers.Connection) error {
+		conn.SetVNI(ctx, 202)
+		conn.SetOverlayAddress("10.0.0.3/32")
+		return nil
+	}
+
+	r, caCert, stop, h := startRelay(t, token, onConnect, onDisconnect)
+	t.Cleanup(stop)
+
+	c := clientForRelay(t, r, caCert, token)
+	t.Cleanup(func() { require.NoError(t, c.Close()) })
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	t.Cleanup(cancel)
+
+	connResp, err := c.Connect(ctx)
+	require.NoError(t, err)
+	require.NotEmpty(t, connResp.ID)
+	require.Equal(t, uint(202), connResp.VNI)
+
+	vnet, ok := h.GetVirtualNetwork(connResp.VNI)
+	require.True(t, ok, "virtual network should exist")
+
+	old := time.Now().Add(-10 * time.Minute).UnixNano()
+	vnet.Stats.LastRXUnixNano.Store(old)
+
+	select {
+	case got := <-discCh:
+		require.Equal(t, connResp.ID, got.ID)
+		require.Equal(t, "it-agent", got.Agent)
+	default:
+		select {
+		case got := <-discCh:
+			require.Equal(t, connResp.ID, got.ID)
+			require.Equal(t, "it-agent", got.Agent)
+		case <-time.After(8 * time.Second):
+			t.Fatalf("expected GC to drop idle connection within a GC interval")
+		}
+	}
+
+	err = c.Disconnect(ctx, connResp.ID)
+	require.Error(t, err) // should error as connection already dropped by GC
+}
+
+func startRelay(
+	t *testing.T,
+	token string,
+	onConnect func(context.Context, string, string, controllers.Connection) error,
+	onDisconnect func(context.Context, string, string) error,
+) (*tunnel.Relay, tls.Certificate, func(), *icx.Handler) {
 	t.Helper()
 
 	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
@@ -96,8 +198,10 @@ func startRelay(t *testing.T, token string, onConnect func(context.Context, stri
 	caCert, serverCert, err := cryptoutils.GenerateSelfSignedTLSCert("localhost")
 	require.NoError(t, err)
 
-	h, err := icx.NewHandler(icx.WithLocalAddr(netstack.ToFullAddress(netip.MustParseAddrPort("127.0.0.1:6081"))),
-		icx.WithVirtMAC(tcpip.GetRandMacAddr()))
+	h, err := icx.NewHandler(
+		icx.WithLocalAddr(netstack.ToFullAddress(netip.MustParseAddrPort("127.0.0.1:6081"))),
+		icx.WithVirtMAC(tcpip.GetRandMacAddr()),
+	)
 	require.NoError(t, err)
 
 	idKey := make([]byte, 32)
@@ -107,7 +211,6 @@ func startRelay(t *testing.T, token string, onConnect func(context.Context, stri
 	idHasher := hasher.NewHasher(idKey)
 
 	rtr := &mockRouter{}
-	rtr.On("Start", mock.Anything).Return(nil)
 	rtr.On("Close").Return(nil)
 	rtr.On("AddAddr", mock.Anything, mock.Anything).Return(nil)
 
@@ -121,7 +224,6 @@ func startRelay(t *testing.T, token string, onConnect func(context.Context, stri
 	r.SetOnDisconnect(onDisconnect)
 
 	ctx, cancel := context.WithCancel(context.Background())
-
 	done := make(chan struct{})
 	go func() {
 		if err := r.Start(ctx); err != nil {
@@ -130,7 +232,6 @@ func startRelay(t *testing.T, token string, onConnect func(context.Context, stri
 		close(done)
 	}()
 
-	// Give the server a brief moment to bind and start serving.
 	time.Sleep(150 * time.Millisecond)
 
 	stop := func() {
@@ -138,76 +239,17 @@ func startRelay(t *testing.T, token string, onConnect func(context.Context, stri
 		select {
 		case <-done:
 		case <-time.After(5 * time.Second):
-			// if shutdown hangs, tests will fail below anyway
 		}
 		_ = pc.Close()
 	}
 
-	return r, caCert, stop
-}
-
-func TestRelay_InvalidAuthClosesQUIC(t *testing.T) {
-	const goodToken = "correct-token"
-	const badToken = "wrong-token"
-
-	// We don't expect to reach onConnect/onDisconnect for a bad token.
-	onConnect := func(ctx context.Context, agent string, conn controllers.Connection) error { return nil }
-	onDisconnect := func(ctx context.Context, agent, id string) error { return nil }
-
-	r, caCert, stop := startRelay(t, goodToken, onConnect, onDisconnect)
-	t.Cleanup(stop)
-
-	// Build a raw HTTP/3 client so we can inject a custom Dial that captures the QUIC connection.
-	tlsCfg := &tls.Config{
-		RootCAs:    cryptoutils.CertPoolForCertificate(caCert),
-		ServerName: "localhost",
-	}
-
-	var captured quic.EarlyConnection
-	rt := &http3.Transport{
-		TLSClientConfig: tlsCfg,
-		// Capture the QUIC connection used underneath, so we can verify it gets closed.
-		Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
-			c, err := quic.DialAddrEarly(ctx, addr, tlsConf, cfg)
-			if err == nil {
-				captured = c
-			}
-			return c, err
-		},
-	}
-	t.Cleanup(func() { _ = rt.Close() })
-
-	h3Client := &http.Client{
-		Transport: rt,
-		Timeout:   3 * time.Second,
-	}
-
-	// Send a request with an invalid token.
-	url := "https://" + r.Address() + "/v1/tunnel/test-tunnel"
-	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader([]byte(`{}`)))
-	require.NoError(t, err)
-	req.Header.Set("Authorization", "Bearer "+badToken)
-	req.Header.Set("Content-Type", "application/json")
-
-	_, err = h3Client.Do(req)
-	require.True(t, err != nil && strings.Contains(err.Error(), "H3_REQUEST_REJECTED"), "expected request to be rejected, got: %v", err)
-
-	// The relay calls CloseWithError on the underlying QUIC connection after writing 401.
-	// That should cause the client's connection context to be done very quickly.
-	require.NotNil(t, captured, "should have captured the QUIC connection")
-
-	select {
-	case <-captured.Context().Done():
-		// success: the QUIC connection was closed by the server
-	case <-time.After(750 * time.Millisecond):
-		t.Fatalf("expected QUIC connection to be closed after unauthorized response, but it remained open")
-	}
+	return r, caCert, stop, h
 }
 
 func clientForRelay(t *testing.T, r *tunnel.Relay, caCert tls.Certificate, token string) *api.Client {
 	t.Helper()
 
-	baseURL := "https://" + r.Address()
+	baseURL := "https://" + r.Address().String()
 	tlsCfg := &tls.Config{
 		RootCAs:    cryptoutils.CertPoolForCertificate(caCert),
 		ServerName: "localhost",