From 069883c15a3f78c89ad65c38268eea8745011008 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 01:26:55 +0800 Subject: [PATCH 01/31] feat: gateway v2 scaffolding --- packages/api/api.go | 63 +++++ packages/api/model.go | 40 +++ packages/cmd/network.go | 263 ++++++++++++++++++ packages/gateway-v2/gateway.go | 428 +++++++++++++++++++++++++++++ packages/proxy/proxy.go | 486 +++++++++++++++++++++++++++++++++ 5 files changed, 1280 insertions(+) create mode 100644 packages/cmd/network.go create mode 100644 packages/gateway-v2/gateway.go create mode 100644 packages/proxy/proxy.go diff --git a/packages/api/api.go b/packages/api/api.go index a9b204b6..7d61eb8a 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -40,6 +40,9 @@ const ( operationCallExchangeRelayCertV1 = "CallExchangeRelayCertV1" operationCallGatewayHeartBeatV1 = "CallGatewayHeartBeatV1" operationCallBootstrapInstance = "CallBootstrapInstance" + operationCallRegisterInstanceProxy = "CallRegisterInstanceProxy" + operationCallRegisterOrgProxy = "CallRegisterOrgProxy" + operationCallRegisterGateway = "CallRegisterGateway" ) func CallGetEncryptedWorkspaceKey(httpClient *resty.Client, request GetEncryptedWorkspaceKeyRequest) (GetEncryptedWorkspaceKeyResponse, error) { @@ -671,3 +674,63 @@ func CallBootstrapInstance(httpClient *resty.Client, request BootstrapInstanceRe return resBody, nil } + +func CallRegisterInstanceProxy(httpClient *resty.Client, request RegisterProxyRequest) (RegisterProxyResponse, error) { + var resBody RegisterProxyResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v1/proxies/register-instance-proxy", config.INFISICAL_URL)) + + if err != nil { + return RegisterProxyResponse{}, NewGenericRequestError(operationCallRegisterInstanceProxy, err) + } + + if response.IsError() { + return RegisterProxyResponse{}, NewAPIErrorWithResponse(operationCallRegisterInstanceProxy, response, nil) + } + + return resBody, nil +} + +func CallRegisterProxy(httpClient *resty.Client, request RegisterProxyRequest) (RegisterProxyResponse, error) { + var resBody RegisterProxyResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v1/proxies/register-org-proxy", config.INFISICAL_URL)) + + if err != nil { + return RegisterProxyResponse{}, NewGenericRequestError(operationCallRegisterOrgProxy, err) + } + + if response.IsError() { + return RegisterProxyResponse{}, NewAPIErrorWithResponse(operationCallRegisterOrgProxy, response, nil) + } + + return resBody, nil +} + +func CallRegisterGateway(httpClient *resty.Client, request RegisterGatewayRequest) (RegisterGatewayResponse, error) { + var resBody RegisterGatewayResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v2/gateways", config.INFISICAL_URL)) + + if err != nil { + return RegisterGatewayResponse{}, NewGenericRequestError(operationCallRegisterGateway, err) + } + + if response.IsError() { + return RegisterGatewayResponse{}, NewAPIErrorWithResponse(operationCallRegisterGateway, response, nil) + } + + return resBody, nil +} diff --git a/packages/api/model.go b/packages/api/model.go index 3f10b4ca..ad172278 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -703,3 +703,43 @@ type BootstrapUser struct { Username string `json:"username"` SuperAdmin bool `json:"superAdmin"` } + +type RegisterProxyRequest struct { + IP string `json:"ip"` + Name string `json:"name"` +} + +type RegisterProxyResponse struct { + PKI struct { + ServerCertificate string `json:"serverCertificate"` + ServerCertificateChain string `json:"serverCertificateChain"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCA string `json:"clientCA"` + } `json:"pki"` + SSH struct { + ServerCertificate string `json:"serverCertificate"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCAPublicKey string `json:"clientCAPublicKey"` + } `json:"ssh"` +} + +type RegisterGatewayRequest struct { + ProxyName string `json:"proxyName"` + Name string `json:"name"` +} + +type RegisterGatewayResponse struct { + GatewayID string `json:"gatewayId"` + ProxyIP string `json:"proxyIp"` + PKI struct { + ServerCertificate string `json:"serverCertificate"` + ServerCertificateChain string `json:"serverCertificateChain"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCA string `json:"clientCA"` + } `json:"pki"` + SSH struct { + ClientCertificate string `json:"clientCertificate"` + ClientPrivateKey string `json:"clientPrivateKey"` + ServerCAPublicKey string `json:"serverCAPublicKey"` + } `json:"ssh"` +} diff --git a/packages/cmd/network.go b/packages/cmd/network.go new file mode 100644 index 00000000..8ec9c6b3 --- /dev/null +++ b/packages/cmd/network.go @@ -0,0 +1,263 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "sync/atomic" + "syscall" + "time" + + gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" + "github.com/Infisical/infisical-merge/packages/proxy" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +var networkCmd = &cobra.Command{ + Use: "network", + Short: "Network-related commands", + Long: "Network-related commands for Infisical", +} + +var networkProxyCmd = &cobra.Command{ + Use: "proxy", + Short: "Run the Infisical proxy component", + Long: "Run the Infisical proxy component", + Run: func(cmd *cobra.Command, args []string) { + + proxyName, err := cmd.Flags().GetString("name") + if err != nil || proxyName == "" { + util.HandleError(err, "unable to get name flag") + } + + ip, err := cmd.Flags().GetString("ip") + if err != nil || ip == "" { + util.HandleError(err, "unable to get ip flag") + } + + instanceType, err := cmd.Flags().GetString("type") + if err != nil { + util.HandleError(err, "unable to get type flag") + } + + proxyInstance, err := proxy.NewProxy(&proxy.ProxyConfig{ + ProxyName: proxyName, + SSHPort: "2222", + TLSPort: "443", + StaticIP: ip, + Type: instanceType, + }) + + if err != nil { + util.HandleError(err, "unable to create proxy instance") + } + + if instanceType == "instance" { + proxyAuthSecret := os.Getenv("PROXY_AUTH_SECRET") + if proxyAuthSecret == "" { + util.HandleError(fmt.Errorf("PROXY_AUTH_SECRET is not set"), "unable to get proxy auth secret") + } + + proxyInstance.SetToken(proxyAuthSecret) + } else { + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + proxyInstance.SetToken(accessToken.Load().(string)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down proxy...") + cancelCmd() + cancelSdk() + + // If we get a second signal, force exit + <-sigCh + log.Warn().Msgf("Force exit triggered") + os.Exit(1) + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + proxyInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + } + + // Use the same context for the proxy server + err = proxyInstance.Start(cmd.Context()) + if err != nil { + util.HandleError(err, "unable to start proxy instance") + } + }, +} + +var networkProxyInstallCmd = &cobra.Command{ + Use: "proxy install", + Short: "Install and enable systemd service for the proxy (requires sudo)", + Long: "Install and enable systemd service for the proxy. Must be run with sudo on Linux.", + Run: func(cmd *cobra.Command, args []string) { + // TODO: Implement this + }, +} + +var networkGatewayCmd = &cobra.Command{ + Use: "gateway", + Short: "Run the Infisical gateway component", + Long: "Run the Infisical gateway component", + Run: func(cmd *cobra.Command, args []string) { + + proxyName, err := cmd.Flags().GetString("proxy-name") + if err != nil || proxyName == "" { + util.HandleError(err, "unable to get proxy-name flag") + } + + gatewayName, err := cmd.Flags().GetString("name") + if err != nil || gatewayName == "" { + util.HandleError(err, "unable to get name flag") + } + + gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ + Name: gatewayName, + ProxyName: proxyName, + ReconnectDelay: 10 * time.Second, + }) + + if err != nil { + util.HandleError(err, "unable to create gateway instance") + } + + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + gatewayInstance.SetToken(accessToken.Load().(string)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down gateway...") + cancelCmd() + cancelSdk() + + // If we get a second signal, force exit + <-sigCh + log.Warn().Msgf("Force exit triggered") + os.Exit(1) + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + gatewayInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + + err = gatewayInstance.Start(ctx) + if err != nil { + util.HandleError(err, "unable to start gateway instance") + } + + }, +} + +func init() { + networkGatewayCmd.Flags().String("proxy-name", "", "The name of the proxy to connect to") + networkGatewayCmd.Flags().String("name", "", "The name of the gateway") + networkGatewayCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + networkGatewayCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + networkGatewayCmd.Flags().String("client-id", "", "client id for universal auth") + networkGatewayCmd.Flags().String("client-secret", "", "client secret for universal auth") + networkGatewayCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + networkGatewayCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + networkGatewayCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + networkGatewayCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + networkProxyCmd.Flags().String("type", "org", "The type of proxy to run. Must be either 'instance' or 'org'") + networkProxyCmd.Flags().String("ip", "", "The IP address of the proxy") + networkProxyCmd.Flags().String("name", "", "The name of the proxy") + networkProxyCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + networkProxyCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + networkProxyCmd.Flags().String("client-id", "", "client id for universal auth") + networkProxyCmd.Flags().String("client-secret", "", "client secret for universal auth") + networkProxyCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + networkProxyCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + networkProxyCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + networkProxyCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + networkProxyCmd.AddCommand(networkProxyInstallCmd) + + networkCmd.AddCommand(networkProxyCmd) + networkCmd.AddCommand(networkGatewayCmd) + + rootCmd.AddCommand(networkCmd) +} diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go new file mode 100644 index 00000000..2208e8cd --- /dev/null +++ b/packages/gateway-v2/gateway.go @@ -0,0 +1,428 @@ +package gatewayv2 + +import ( + "bytes" + "context" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "log" + "net" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/go-resty/resty/v2" + "golang.org/x/crypto/ssh" +) + +type GatewayConfig struct { + Name string + ProxyName string + IdentityToken string + SSHPort int + ReconnectDelay time.Duration +} + +type Gateway struct { + GatewayID string + + httpClient *resty.Client + config *GatewayConfig + sshClient *ssh.Client + + // Certificate storage + certificates *api.RegisterGatewayResponse + + // mTLS server components + tlsConfig *tls.Config + tlsCACert []byte + tlsCAKey *rsa.PrivateKey + + // Connection management + mu sync.RWMutex + isConnected bool + ctx context.Context + cancel context.CancelFunc +} + +// NewGateway creates a new gateway instance +func NewGateway(config *GatewayConfig) (*Gateway, error) { + httpClient, err := util.GetRestyClientWithCustomHeaders() + if err != nil { + return nil, fmt.Errorf("unable to get client with custom headers [err=%v]", err) + } + + httpClient.SetAuthToken(config.IdentityToken) + + ctx, cancel := context.WithCancel(context.Background()) + + // Set default SSH port if not specified + if config.SSHPort == 0 { + config.SSHPort = 2222 + } + + return &Gateway{ + httpClient: httpClient, + config: config, + ctx: ctx, + cancel: cancel, + }, nil +} + +// Change the Start method to accept a context +func (g *Gateway) Start(ctx context.Context) error { + log.Printf("Starting gateway") + for { + select { + case <-ctx.Done(): + log.Printf("Gateway stopped by context cancellation") + return nil + default: + if err := g.connectAndServe(); err != nil { + log.Printf("Connection failed: %v, retrying in %v...", err, g.config.ReconnectDelay) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(g.config.ReconnectDelay): + continue + } + } + // If we get here, the connection was closed gracefully + log.Printf("Connection closed, reconnecting in 10 seconds...") + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(10 * time.Second): + continue + } + } + } +} + +func (g *Gateway) SetToken(token string) { + g.httpClient.SetAuthToken(token) +} + +func (g *Gateway) Stop() { + g.cancel() + + g.mu.Lock() + if g.sshClient != nil { + g.sshClient.Close() + g.sshClient = nil + } + g.isConnected = false + g.mu.Unlock() +} + +func (g *Gateway) connectAndServe() error { + if err := g.registerGateway(); err != nil { + return fmt.Errorf("failed to register gateway: %v", err) + } + + // Create SSH client config + sshConfig, err := g.createSSHConfig() + if err != nil { + return fmt.Errorf("failed to create SSH config: %v", err) + } + + // Connect to Proxy server + log.Printf("Connecting to SSH server on %s:%d...", g.certificates.ProxyIP, g.config.SSHPort) + client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.ProxyIP, g.config.SSHPort), sshConfig) + if err != nil { + return fmt.Errorf("failed to connect to SSH server: %v", err) + } + + g.mu.Lock() + g.sshClient = client + g.isConnected = true + g.mu.Unlock() + + defer func() { + g.mu.Lock() + g.sshClient = nil + g.isConnected = false + g.mu.Unlock() + client.Close() + }() + + log.Printf("SSH connection established for gateway") + + // Handle incoming channels from the server + channels := client.HandleChannelOpen("direct-tcpip") + if channels == nil { + return fmt.Errorf("failed to handle channel open") + } + + // Process incoming channels + for newChannel := range channels { + go g.handleIncomingChannel(newChannel) + } + + return nil // Connection closed +} + +func (g *Gateway) registerGateway() error { + body := api.RegisterGatewayRequest{ + ProxyName: g.config.ProxyName, + Name: g.config.Name, + } + + certResp, err := api.CallRegisterGateway(g.httpClient, body) + if err != nil { + return fmt.Errorf("failed to register gateway: %v", err) + } + + g.GatewayID = certResp.GatewayID + g.certificates = &certResp + log.Printf("Successfully registered gateway and received certificates") + return nil +} + +func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { + privateKey, err := ssh.ParsePrivateKey([]byte(g.certificates.SSH.ClientPrivateKey)) + if err != nil { + return nil, fmt.Errorf("failed to parse SSH private key: %v", err) + } + + // Parse certificate + cert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ClientCertificate)) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %v", err) + } + + // Create certificate signer + certSigner, err := ssh.NewCertSigner(cert.(*ssh.Certificate), privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create certificate signer: %v", err) + } + + // Create SSH client config + config := &ssh.ClientConfig{ + User: g.GatewayID, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(certSigner), + }, + HostKeyCallback: g.createHostKeyCallback(), + Timeout: 30 * time.Second, + Config: ssh.Config{ + KeyExchanges: []string{ + "diffie-hellman-group14-sha256", + "diffie-hellman-group16-sha512", + "diffie-hellman-group18-sha512", + }, + Ciphers: []string{ + "aes128-ctr", + "aes192-ctr", + "aes256-ctr", + }, + MACs: []string{ + "hmac-sha2-256", + "hmac-sha2-512", + }, + }, + } + + return config, nil +} + +func (g *Gateway) createHostKeyCallback() ssh.HostKeyCallback { + // Parse CA public key once when creating the callback + caKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ServerCAPublicKey)) + if err != nil { + // Return a callback that always fails since we can't parse the CA key + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return fmt.Errorf("failed to parse CA public key: %v", err) + } + } + + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + cert, ok := key.(*ssh.Certificate) + if !ok { + return fmt.Errorf("host certificates required, raw host keys not allowed") + } + + return g.validateHostCertificate(cert, hostname, caKey) + } +} + +func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string, caKey ssh.PublicKey) error { + checker := &ssh.CertChecker{ + IsHostAuthority: func(auth ssh.PublicKey, address string) bool { + return bytes.Equal(auth.Marshal(), caKey.Marshal()) + }, + } + + if err := checker.CheckCert(hostname, cert); err != nil { + return fmt.Errorf("host certificate check failed: %v", err) + } + + log.Printf("Host certificate validated successfully for %s", hostname) + return nil +} + +func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { + var req struct { + Host string + Port uint32 + OriginHost string + OriginPort uint32 + } + + if err := ssh.Unmarshal(newChannel.ExtraData(), &req); err != nil { + log.Printf("Failed to parse channel request: %v", err) + newChannel.Reject(ssh.Prohibited, "invalid request") + return + } + + log.Printf("Incoming connection request to %s:%d from %s:%d", + req.Host, req.Port, req.OriginHost, req.OriginPort) + + // Accept the channel + channel, requests, err := newChannel.Accept() + if err != nil { + log.Printf("Failed to accept channel: %v", err) + return + } + defer channel.Close() + + go ssh.DiscardRequests(requests) + + // Determine the target address + target := fmt.Sprintf("%s:%d", req.Host, req.Port) + log.Printf("Creating TCP tunnel to: %s", target) + + // Create mTLS server configuration + tlsConfig, err := g.createMTLSConfig() + if err != nil { + log.Printf("Failed to create mTLS config: %v", err) + return + } + + // Create a virtual connection that pipes data between SSH channel and TLS + virtualConn := &virtualConnection{ + channel: channel, + } + + // Wrap the virtual connection with TLS + tlsConn := tls.Server(virtualConn, tlsConfig) + + // Perform TLS handshake + if err := tlsConn.Handshake(); err != nil { + log.Printf("TLS handshake failed: %v", err) + return + } + + log.Printf("mTLS connection established with client: %s", tlsConn.ConnectionState().ServerName) + + // Connect to local service + localConn, err := net.Dial("tcp", target) + if err != nil { + log.Printf("Failed to connect to local service %s: %v", target, err) + return + } + defer localConn.Close() + + log.Printf("TCP tunnel established to %s", target) + + // Create bidirectional tunnel with TLS + // Forward data from TLS connection to local service + go func() { + io.Copy(localConn, tlsConn) + localConn.Close() + log.Printf("TLS -> local service tunnel closed") + }() + + // Forward data from local service to TLS connection + io.Copy(tlsConn, localConn) + log.Printf("Local service -> TLS tunnel closed") +} + +func (g *Gateway) createMTLSConfig() (*tls.Config, error) { + // Parse server certificate + serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) + if serverCertBlock == nil { + return nil, fmt.Errorf("failed to decode server certificate") + } + + // Parse server private key + serverKeyBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerPrivateKey)) + if serverKeyBlock == nil { + return nil, fmt.Errorf("failed to decode server private key") + } + + serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse server private key: %v", err) + } + + // Parse client CA certificate + clientCABlock, _ := pem.Decode([]byte(g.certificates.PKI.ClientCA)) + if clientCABlock == nil { + return nil, fmt.Errorf("failed to decode client CA certificate") + } + + clientCA, err := x509.ParseCertificate(clientCABlock.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse client CA certificate: %v", err) + } + + // Create certificate pool for client CAs + clientCAPool := x509.NewCertPool() + clientCAPool.AddCert(clientCA) + + // Create TLS config + return &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{serverCertBlock.Bytes}, + PrivateKey: serverKey, + }, + }, + ClientCAs: clientCAPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + }, nil +} + +// virtualConnection implements net.Conn to bridge SSH channel and TLS +type virtualConnection struct { + channel ssh.Channel +} + +func (vc *virtualConnection) Read(b []byte) (n int, err error) { + return vc.channel.Read(b) +} + +func (vc *virtualConnection) Write(b []byte) (n int, err error) { + return vc.channel.Write(b) +} + +func (vc *virtualConnection) Close() error { + return vc.channel.Close() +} + +func (vc *virtualConnection) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (vc *virtualConnection) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (vc *virtualConnection) SetDeadline(t time.Time) error { + return nil +} + +func (vc *virtualConnection) SetReadDeadline(t time.Time) error { + return nil +} + +func (vc *virtualConnection) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go new file mode 100644 index 00000000..492dee61 --- /dev/null +++ b/packages/proxy/proxy.go @@ -0,0 +1,486 @@ +package proxy + +import ( + "bytes" + "context" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "log" + "net" + + "strconv" + "strings" + "sync" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/go-resty/resty/v2" + "golang.org/x/crypto/ssh" +) + +type ProxyConfig struct { + // API Configuration + Token string + ProxyName string + + Type string + + // Server Ports + SSHPort string + TLSPort string + + // Network Configuration + StaticIP string +} + +type Proxy struct { + httpClient *resty.Client + config *ProxyConfig + + // Certificate storage + certificates *api.RegisterProxyResponse + + // SSH server components + sshConfig *ssh.ServerConfig + sshCA ssh.Signer + + // TLS server components + tlsConfig *tls.Config + tlsCACert []byte + tlsCAKey *rsa.PrivateKey + + // Tunnel storage (Gateway ID -> SSH connection) + tunnels map[string]*ssh.ServerConn + mu sync.RWMutex + + // Server listeners + sshListener net.Listener + tlsListener net.Listener +} + +func NewProxy(config *ProxyConfig) (*Proxy, error) { + httpClient, err := util.GetRestyClientWithCustomHeaders() + if err != nil { + return nil, fmt.Errorf("unable to get client with custom headers [err=%v]", err) + } + + httpClient.SetAuthToken(config.Token) + + return &Proxy{ + httpClient: httpClient, + config: config, + tunnels: make(map[string]*ssh.ServerConn), + }, nil +} + +func (p *Proxy) SetToken(token string) { + p.httpClient.SetAuthToken(token) +} + +func (p *Proxy) Start(ctx context.Context) error { + // Register proxy and get certificates from API + if err := p.registerProxy(); err != nil { + return fmt.Errorf("failed to register proxy: %v", err) + } + + // Setup SSH server + if err := p.setupSSHServer(); err != nil { + return fmt.Errorf("failed to setup SSH server: %v", err) + } + + // Setup TLS server + if err := p.setupTLSServer(); err != nil { + return fmt.Errorf("failed to setup TLS server: %v", err) + } + + // Start SSH server + go p.startSSHServer() + + // Start TLS server + go p.startTLSServer() + + log.Printf("Proxy server started successfully") + + // Wait for context cancellation + <-ctx.Done() + + // Cleanup + p.cleanup() + return nil +} + +func (p *Proxy) registerProxy() error { + body := api.RegisterProxyRequest{ + IP: p.config.StaticIP, + Name: p.config.ProxyName, + } + + if p.config.Type == "instance" { + certResp, err := api.CallRegisterInstanceProxy(p.httpClient, body) + if err != nil { + return fmt.Errorf("failed to register instance proxy: %v", err) + } + p.certificates = &certResp + } else { + certResp, err := api.CallRegisterProxy(p.httpClient, body) + if err != nil { + return fmt.Errorf("failed to register org proxy: %v", err) + } + p.certificates = &certResp + } + + log.Printf("Successfully registered proxy and received certificates from API") + return nil +} + +func (p *Proxy) setupSSHServer() error { + // Parse SSH CA public key + sshCAPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(p.certificates.SSH.ClientCAPublicKey)) + if err != nil { + return fmt.Errorf("failed to parse SSH CA public key: %v", err) + } + + // Parse SSH server private key + sshServerKey, err := ssh.ParsePrivateKey([]byte(p.certificates.SSH.ServerPrivateKey)) + if err != nil { + return fmt.Errorf("failed to parse SSH server private key: %v", err) + } + + // Parse SSH server certificate + sshServerCert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(p.certificates.SSH.ServerCertificate)) + if err != nil { + return fmt.Errorf("failed to parse SSH server certificate: %v", err) + } + + // Create certificate signer + certSigner, err := ssh.NewCertSigner(sshServerCert.(*ssh.Certificate), sshServerKey) + if err != nil { + return fmt.Errorf("failed to create SSH certificate signer: %v", err) + } + + // Setup SSH server config + p.sshConfig = &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + // Check if this is an SSH certificate + cert, ok := key.(*ssh.Certificate) + if !ok { + log.Printf("Gateway '%s' tried to authenticate with raw public key (rejected)", conn.User()) + return nil, fmt.Errorf("certificates required, raw public keys not allowed") + } + + // Validate the certificate + if err := p.validateSSHCertificate(cert, conn.User(), sshCAPubKey); err != nil { + log.Printf("Gateway '%s' certificate validation failed: %v", conn.User(), err) + return nil, err + } + + gatewayId := "" + if len(cert.ValidPrincipals) > 0 { + gatewayId = cert.ValidPrincipals[0] + } + + if gatewayId == "" { + return nil, fmt.Errorf("gateway id is required") + } + + return &ssh.Permissions{ + Extensions: map[string]string{ + "gateway-id": gatewayId, + }, + }, nil + }, + } + + p.sshConfig.AddHostKey(certSigner) + return nil +} + +func (p *Proxy) setupTLSServer() error { + // Parse TLS server certificate + serverCertBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerCertificate)) + if serverCertBlock == nil { + return fmt.Errorf("failed to decode server certificate") + } + + // Note: serverCert is parsed for validation but not used in the TLS config + // since we use the raw bytes directly + _, err := x509.ParseCertificate(serverCertBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server certificate: %v", err) + } + + // Parse TLS server private key + serverKeyBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerPrivateKey)) + if serverKeyBlock == nil { + return fmt.Errorf("failed to decode server private key") + } + + serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server private key: %v", err) + } + + // Parse client CA certificate + clientCABlock, _ := pem.Decode([]byte(p.certificates.PKI.ClientCA)) + if clientCABlock == nil { + return fmt.Errorf("failed to decode client CA certificate") + } + + clientCA, err := x509.ParseCertificate(clientCABlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse client CA certificate: %v", err) + } + + // Create certificate pool for client CAs + clientCAPool := x509.NewCertPool() + clientCAPool.AddCert(clientCA) + + // Create TLS config + p.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{serverCertBlock.Bytes}, + PrivateKey: serverKey, + }, + }, + ClientCAs: clientCAPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + } + + return nil +} + +func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, caPubKey ssh.PublicKey) error { + // Check certificate type + if cert.CertType != ssh.UserCert { + return fmt.Errorf("invalid certificate type: %d", cert.CertType) + } + + // Check if certificate is signed by our CA + checker := &ssh.CertChecker{ + IsUserAuthority: func(auth ssh.PublicKey) bool { + return bytes.Equal(auth.Marshal(), caPubKey.Marshal()) + }, + } + + // Validate the certificate + if err := checker.CheckCert(username, cert); err != nil { + return fmt.Errorf("certificate check failed: %v", err) + } + + log.Printf("SSH certificate valid for user '%s', principals: %v", username, cert.ValidPrincipals) + return nil +} + +func (p *Proxy) startSSHServer() { + listener, err := net.Listen("tcp", ":"+p.config.SSHPort) + if err != nil { + log.Fatalf("Failed to start SSH server: %v", err) + } + p.sshListener = listener + + log.Printf("SSH server listening on :%s for gateways", p.config.SSHPort) + + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("Failed to accept SSH connection: %v", err) + continue + } + go p.handleSSHAgent(conn) + } +} + +func (p *Proxy) handleSSHAgent(conn net.Conn) { + defer conn.Close() + + // SSH handshake + sshConn, chans, _, err := ssh.NewServerConn(conn, p.sshConfig) + if err != nil { + log.Printf("SSH handshake failed: %v", err) + return + } + + gatewayId := sshConn.Permissions.Extensions["gateway-id"] + log.Printf("SSH handshake successful for gateway: %s", gatewayId) + + // Store the connection + p.mu.Lock() + p.tunnels[gatewayId] = sshConn + p.mu.Unlock() + + // Clean up when agent disconnects + defer func() { + p.mu.Lock() + delete(p.tunnels, gatewayId) + p.mu.Unlock() + log.Printf("Gateway %s disconnected", gatewayId) + }() + + for newChannel := range chans { + switch newChannel.ChannelType() { + case "session": + newChannel.Reject(ssh.Prohibited, "no shell access") + case "x11": + newChannel.Reject(ssh.Prohibited, "no X11 forwarding") + case "auth-agent": + newChannel.Reject(ssh.Prohibited, "no agent forwarding") + } + } +} + +func (p *Proxy) startTLSServer() { + listener, err := net.Listen("tcp", ":"+p.config.TLSPort) + if err != nil { + log.Fatalf("Failed to start TLS server: %v", err) + } + p.tlsListener = listener + + log.Printf("TLS server listening on :%s for clients", p.config.TLSPort) + + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("Failed to accept TLS connection: %v", err) + continue + } + go p.handleTLSClient(conn) + } +} + +func (p *Proxy) handleTLSClient(conn net.Conn) { + defer conn.Close() + + log.Printf("Client connected from %s", conn.RemoteAddr()) + + // Wrap connection with TLS + tlsConn := tls.Server(conn, p.tlsConfig) + if err := tlsConn.Handshake(); err != nil { + log.Printf("TLS handshake failed: %v", err) + return + } + + // Log client certificate info + if len(tlsConn.ConnectionState().PeerCertificates) > 0 { + cert := tlsConn.ConnectionState().PeerCertificates[0] + log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) + } + + p.handleClient(tlsConn) +} + +func (p *Proxy) handleClient(clientConn net.Conn) { + defer clientConn.Close() + + // Read the first few bytes to determine which agent to connect to + // Format: "agent1:host:port\n" or "agent1:host:port" followed by data + buffer := make([]byte, 1024) + n, err := clientConn.Read(buffer) + if err != nil { + log.Printf("Failed to read from client: %v", err) + return + } + + // Find the first newline to separate agent info from data + data := buffer[:n] + log.Printf("Received %d bytes from client: %q", n, string(data)) + newlineIndex := bytes.IndexByte(data, '\n') + + var gatewayId, targetHost string + var targetPort uint32 + var remainingData []byte + + if newlineIndex != -1 { + // Agent info is everything before the newline + agentInfo := string(data[:newlineIndex]) + remainingData = data[newlineIndex+1:] + + // Parse agent info in format "agent:host:port" + parts := strings.Split(agentInfo, ":") + if len(parts) != 3 { + log.Printf("Invalid client data format, expected 'agent:host:port', got: %s", agentInfo) + clientConn.Write([]byte("ERROR: Invalid format. Expected 'agent:host:port'\n")) + return + } + + gatewayId = parts[0] + targetHost = parts[1] + portStr := parts[2] + + // Parse port number + port, err := strconv.ParseUint(portStr, 10, 32) + if err != nil { + log.Printf("Invalid port number: %s", portStr) + clientConn.Write([]byte("ERROR: Invalid port number\n")) + return + } + targetPort = uint32(port) + + log.Printf("Extracted gateway: %s, target: %s:%d", gatewayId, targetHost, targetPort) + } else { + log.Printf("Invalid client data format - no newline found") + clientConn.Write([]byte("ERROR: Please use format 'gatewayId:host:port'\n")) + return + } + + // Get the SSH connection for this agent + p.mu.RLock() + conn, exists := p.tunnels[gatewayId] + p.mu.RUnlock() + + if !exists { + log.Printf("Gateway '%s' not connected", gatewayId) + clientConn.Write([]byte("ERROR: Gateway not connected\n")) + return + } + + log.Printf("Routing TCP connection to gateway: %s", gatewayId) + + // Open SSH channel to connect to agent's local service through the tunnel + payload := struct { + Host string + Port uint32 + _ string + _ uint32 + }{targetHost, targetPort, "", 0} + + channel, _, err := conn.OpenChannel("direct-tcpip", ssh.Marshal(&payload)) + if err != nil { + log.Printf("Failed to connect to agent: %v", err) + clientConn.Write([]byte("ERROR: Failed to connect to agent\n")) + return + } + defer channel.Close() + + // If we have remaining data from the initial read, write it to the channel + if len(remainingData) > 0 { + channel.Write(remainingData) + } + + // Bidirectional forwarding + go func() { + io.Copy(channel, clientConn) + channel.CloseWrite() + }() + + io.Copy(clientConn, channel) + log.Printf("Client %s disconnected", clientConn.RemoteAddr()) +} + +func (p *Proxy) cleanup() { + log.Printf("Shutting down proxy server...") + + if p.sshListener != nil { + p.sshListener.Close() + } + if p.tlsListener != nil { + p.tlsListener.Close() + } + + log.Printf("Proxy server shutdown complete") +} From 1fb0a482b9b89f421e8982e6bf95e93afb6ed3eb Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 01:58:09 +0800 Subject: [PATCH 02/31] misc: updated proxy to start tls server instead of tcp --- packages/proxy/proxy.go | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 492dee61..e177e10c 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -335,7 +335,7 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { } func (p *Proxy) startTLSServer() { - listener, err := net.Listen("tcp", ":"+p.config.TLSPort) + listener, err := tls.Listen("tcp", ":"+p.config.TLSPort, p.tlsConfig) if err != nil { log.Fatalf("Failed to start TLS server: %v", err) } @@ -349,34 +349,21 @@ func (p *Proxy) startTLSServer() { log.Printf("Failed to accept TLS connection: %v", err) continue } - go p.handleTLSClient(conn) + go p.handleClient(conn) } } -func (p *Proxy) handleTLSClient(conn net.Conn) { - defer conn.Close() - - log.Printf("Client connected from %s", conn.RemoteAddr()) - - // Wrap connection with TLS - tlsConn := tls.Server(conn, p.tlsConfig) - if err := tlsConn.Handshake(); err != nil { - log.Printf("TLS handshake failed: %v", err) - return - } - - // Log client certificate info - if len(tlsConn.ConnectionState().PeerCertificates) > 0 { - cert := tlsConn.ConnectionState().PeerCertificates[0] - log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) - } - - p.handleClient(tlsConn) -} - func (p *Proxy) handleClient(clientConn net.Conn) { defer clientConn.Close() + // Log client certificate info if this is a TLS connection + if tlsConn, ok := clientConn.(*tls.Conn); ok { + if len(tlsConn.ConnectionState().PeerCertificates) > 0 { + cert := tlsConn.ConnectionState().PeerCertificates[0] + log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) + } + } + // Read the first few bytes to determine which agent to connect to // Format: "agent1:host:port\n" or "agent1:host:port" followed by data buffer := make([]byte, 1024) From cda3ac3e49a1d00ef8def690b9d009ede0e7a70f Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 02:19:20 +0800 Subject: [PATCH 03/31] misc: added full server certificate chain to proxy tls --- packages/proxy/proxy.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index e177e10c..d2c61039 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -213,6 +213,18 @@ func (p *Proxy) setupTLSServer() error { return fmt.Errorf("failed to parse server certificate: %v", err) } + // Parse all certificates from the chain (intermediate + root CAs) + var chainCerts [][]byte + chainData := []byte(p.certificates.PKI.ServerCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest + } + // Parse TLS server private key serverKeyBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerPrivateKey)) if serverKeyBlock == nil { @@ -239,11 +251,15 @@ func (p *Proxy) setupTLSServer() error { clientCAPool := x509.NewCertPool() clientCAPool.AddCert(clientCA) + // Create certificate chain: server cert + chain certs (intermediate + root) + certChain := [][]byte{serverCertBlock.Bytes} + certChain = append(certChain, chainCerts...) + // Create TLS config p.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ { - Certificate: [][]byte{serverCertBlock.Bytes}, + Certificate: certChain, PrivateKey: serverKey, }, }, From 6f7eda5af231ac2c2d188726dc2ede3b47f74fc0 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 02:31:03 +0800 Subject: [PATCH 04/31] misc: added log --- packages/proxy/proxy.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index d2c61039..174e7ccd 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -255,6 +255,19 @@ func (p *Proxy) setupTLSServer() error { certChain := [][]byte{serverCertBlock.Bytes} certChain = append(certChain, chainCerts...) + // Debug: log the complete certificate chain as PEM + var chainPEM strings.Builder + for i, certBytes := range certChain { + chainPEM.WriteString(fmt.Sprintf("--- Certificate %d ---\n", i+1)) + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + chainPEM.Write(certPEM) + chainPEM.WriteString("\n") + } + log.Printf("Complete certificate chain PEM:\n%s", chainPEM.String()) + // Create TLS config p.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ From 33692075b2455e443973e94a503daa5134fcf53f Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 03:29:47 +0800 Subject: [PATCH 05/31] misc: updated proxy to fetch client pem chain --- packages/api/model.go | 3 +-- packages/proxy/proxy.go | 58 ++++++++++++++--------------------------- 2 files changed, 20 insertions(+), 41 deletions(-) diff --git a/packages/api/model.go b/packages/api/model.go index ad172278..f2128455 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -712,9 +712,8 @@ type RegisterProxyRequest struct { type RegisterProxyResponse struct { PKI struct { ServerCertificate string `json:"serverCertificate"` - ServerCertificateChain string `json:"serverCertificateChain"` ServerPrivateKey string `json:"serverPrivateKey"` - ClientCA string `json:"clientCA"` + ClientCertificateChain string `json:"clientCertificateChain"` } `json:"pki"` SSH struct { ServerCertificate string `json:"serverCertificate"` diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 174e7ccd..4d8794e3 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -213,18 +213,6 @@ func (p *Proxy) setupTLSServer() error { return fmt.Errorf("failed to parse server certificate: %v", err) } - // Parse all certificates from the chain (intermediate + root CAs) - var chainCerts [][]byte - chainData := []byte(p.certificates.PKI.ServerCertificateChain) - for { - block, rest := pem.Decode(chainData) - if block == nil { - break - } - chainCerts = append(chainCerts, block.Bytes) - chainData = rest - } - // Parse TLS server private key serverKeyBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerPrivateKey)) if serverKeyBlock == nil { @@ -236,43 +224,35 @@ func (p *Proxy) setupTLSServer() error { return fmt.Errorf("failed to parse server private key: %v", err) } - // Parse client CA certificate - clientCABlock, _ := pem.Decode([]byte(p.certificates.PKI.ClientCA)) - if clientCABlock == nil { - return fmt.Errorf("failed to decode client CA certificate") - } - - clientCA, err := x509.ParseCertificate(clientCABlock.Bytes) - if err != nil { - return fmt.Errorf("failed to parse client CA certificate: %v", err) - } - // Create certificate pool for client CAs clientCAPool := x509.NewCertPool() - clientCAPool.AddCert(clientCA) - // Create certificate chain: server cert + chain certs (intermediate + root) - certChain := [][]byte{serverCertBlock.Bytes} - certChain = append(certChain, chainCerts...) + var chainCerts [][]byte + chainData := []byte(p.certificates.PKI.ClientCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest + } - // Debug: log the complete certificate chain as PEM - var chainPEM strings.Builder - for i, certBytes := range certChain { - chainPEM.WriteString(fmt.Sprintf("--- Certificate %d ---\n", i+1)) - certPEM := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }) - chainPEM.Write(certPEM) - chainPEM.WriteString("\n") + for i, certBytes := range chainCerts { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + log.Printf("Failed to parse client chain certificate %d: %v", i+1, err) + continue + } + clientCAPool.AddCert(cert) + log.Printf("Added client CA certificate %d to pool: %s", i+1, cert.Subject.CommonName) } - log.Printf("Complete certificate chain PEM:\n%s", chainPEM.String()) // Create TLS config p.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ { - Certificate: certChain, + Certificate: [][]byte{serverCertBlock.Bytes}, PrivateKey: serverKey, }, }, From 97b9d174780c6f35a8f4d5b4d06f4a1ae0ed133f Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 04:26:29 +0800 Subject: [PATCH 06/31] misc: added log point --- packages/proxy/proxy.go | 64 ++++++----------------------------------- 1 file changed, 8 insertions(+), 56 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 4d8794e3..56979266 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -12,8 +12,6 @@ import ( "log" "net" - "strconv" - "strings" "sync" "github.com/Infisical/infisical-merge/packages/api" @@ -365,64 +363,23 @@ func (p *Proxy) startTLSServer() { func (p *Proxy) handleClient(clientConn net.Conn) { defer clientConn.Close() + var gatewayId string + // Log client certificate info if this is a TLS connection if tlsConn, ok := clientConn.(*tls.Conn); ok { + fmt.Println(tlsConn.ConnectionState().PeerCertificates) if len(tlsConn.ConnectionState().PeerCertificates) > 0 { cert := tlsConn.ConnectionState().PeerCertificates[0] log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) + gatewayId = cert.Subject.CommonName } } - // Read the first few bytes to determine which agent to connect to - // Format: "agent1:host:port\n" or "agent1:host:port" followed by data - buffer := make([]byte, 1024) - n, err := clientConn.Read(buffer) - if err != nil { - log.Printf("Failed to read from client: %v", err) - return - } - - // Find the first newline to separate agent info from data - data := buffer[:n] - log.Printf("Received %d bytes from client: %q", n, string(data)) - newlineIndex := bytes.IndexByte(data, '\n') - - var gatewayId, targetHost string - var targetPort uint32 - var remainingData []byte - - if newlineIndex != -1 { - // Agent info is everything before the newline - agentInfo := string(data[:newlineIndex]) - remainingData = data[newlineIndex+1:] - - // Parse agent info in format "agent:host:port" - parts := strings.Split(agentInfo, ":") - if len(parts) != 3 { - log.Printf("Invalid client data format, expected 'agent:host:port', got: %s", agentInfo) - clientConn.Write([]byte("ERROR: Invalid format. Expected 'agent:host:port'\n")) - return - } - - gatewayId = parts[0] - targetHost = parts[1] - portStr := parts[2] + fmt.Println("gatewayId", gatewayId) - // Parse port number - port, err := strconv.ParseUint(portStr, 10, 32) - if err != nil { - log.Printf("Invalid port number: %s", portStr) - clientConn.Write([]byte("ERROR: Invalid port number\n")) - return - } - targetPort = uint32(port) - - log.Printf("Extracted gateway: %s, target: %s:%d", gatewayId, targetHost, targetPort) - } else { - log.Printf("Invalid client data format - no newline found") - clientConn.Write([]byte("ERROR: Please use format 'gatewayId:host:port'\n")) - return - } + // TODO: extract these from the certificate + targetHost := "localhost" + targetPort := uint32(22) // Get the SSH connection for this agent p.mu.RLock() @@ -453,11 +410,6 @@ func (p *Proxy) handleClient(clientConn net.Conn) { } defer channel.Close() - // If we have remaining data from the initial read, write it to the channel - if len(remainingData) > 0 { - channel.Write(remainingData) - } - // Bidirectional forwarding go func() { io.Copy(channel, clientConn) From ef24451d49827862a1d30c1733449b2e5fac59d9 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 04:34:58 +0800 Subject: [PATCH 07/31] misc: added handshake forcing --- packages/proxy/proxy.go | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 56979266..b99a8f31 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -11,7 +11,7 @@ import ( "io" "log" "net" - + "strings" "sync" "github.com/Infisical/infisical-merge/packages/api" @@ -367,12 +367,33 @@ func (p *Proxy) handleClient(clientConn net.Conn) { // Log client certificate info if this is a TLS connection if tlsConn, ok := clientConn.(*tls.Conn); ok { - fmt.Println(tlsConn.ConnectionState().PeerCertificates) - if len(tlsConn.ConnectionState().PeerCertificates) > 0 { - cert := tlsConn.ConnectionState().PeerCertificates[0] + log.Printf("TLS connection detected, forcing handshake...") + err := tlsConn.Handshake() + if err != nil { + log.Printf("TLS handshake failed: %v", err) + return + } + + state := tlsConn.ConnectionState() + log.Printf("TLS handshake completed, peer certificates count: %d", len(state.PeerCertificates)) + + if len(state.PeerCertificates) > 0 { + cert := state.PeerCertificates[0] log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) - gatewayId = cert.Subject.CommonName + parts := strings.Split(cert.Subject.CommonName, ":") + if len(parts) >= 2 { + gatewayId = parts[1] + } else { + log.Printf("Invalid CommonName format, expected 'part1:part2', got: %s", cert.Subject.CommonName) + return + } + } else { + log.Printf("No peer certificates found") + return } + } else { + log.Printf("Not a TLS connection, connection type: %T", clientConn) + return } fmt.Println("gatewayId", gatewayId) From a233a3f3cbf81639403d798302189260e20b8d34 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sun, 31 Aug 2025 23:05:39 +0800 Subject: [PATCH 08/31] misc: updated gateway to fetch client certificate chain --- packages/api/model.go | 3 +-- packages/gateway-v2/gateway.go | 30 +++++++++++++++++++----------- packages/proxy/proxy.go | 3 --- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/packages/api/model.go b/packages/api/model.go index f2128455..c436d117 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -732,9 +732,8 @@ type RegisterGatewayResponse struct { ProxyIP string `json:"proxyIp"` PKI struct { ServerCertificate string `json:"serverCertificate"` - ServerCertificateChain string `json:"serverCertificateChain"` ServerPrivateKey string `json:"serverPrivateKey"` - ClientCA string `json:"clientCA"` + ClientCertificateChain string `json:"clientCertificateChain"` } `json:"pki"` SSH struct { ClientCertificate string `json:"clientCertificate"` diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 2208e8cd..12ad4acd 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -361,21 +361,29 @@ func (g *Gateway) createMTLSConfig() (*tls.Config, error) { return nil, fmt.Errorf("failed to parse server private key: %v", err) } - // Parse client CA certificate - clientCABlock, _ := pem.Decode([]byte(g.certificates.PKI.ClientCA)) - if clientCABlock == nil { - return nil, fmt.Errorf("failed to decode client CA certificate") + // Create certificate pool for client CAs + clientCAPool := x509.NewCertPool() + var chainCerts [][]byte + chainData := []byte(g.certificates.PKI.ClientCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest } - clientCA, err := x509.ParseCertificate(clientCABlock.Bytes) - if err != nil { - return nil, fmt.Errorf("failed to parse client CA certificate: %v", err) + for i, certBytes := range chainCerts { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + log.Printf("Failed to parse client chain certificate %d: %v", i+1, err) + continue + } + clientCAPool.AddCert(cert) + log.Printf("Added client CA certificate %d to pool: %s", i+1, cert.Subject.CommonName) } - // Create certificate pool for client CAs - clientCAPool := x509.NewCertPool() - clientCAPool.AddCert(clientCA) - // Create TLS config return &tls.Config{ Certificates: []tls.Certificate{ diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index b99a8f31..f57aaea2 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -365,7 +365,6 @@ func (p *Proxy) handleClient(clientConn net.Conn) { var gatewayId string - // Log client certificate info if this is a TLS connection if tlsConn, ok := clientConn.(*tls.Conn); ok { log.Printf("TLS connection detected, forcing handshake...") err := tlsConn.Handshake() @@ -396,8 +395,6 @@ func (p *Proxy) handleClient(clientConn net.Conn) { return } - fmt.Println("gatewayId", gatewayId) - // TODO: extract these from the certificate targetHost := "localhost" targetPort := uint32(22) From 74db2f340c5bea456baa51d0f16dc0e206ad3b11 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sun, 31 Aug 2025 23:25:42 +0800 Subject: [PATCH 09/31] misc: set target host of proxy to gateway --- packages/proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index f57aaea2..279dc2c8 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -396,7 +396,7 @@ func (p *Proxy) handleClient(clientConn net.Conn) { } // TODO: extract these from the certificate - targetHost := "localhost" + targetHost := "gateway" targetPort := uint32(22) // Get the SSH connection for this agent From 36c069dbeef68754f5438c85b0b8ab5622f4630d Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Mon, 1 Sep 2025 22:20:36 +0800 Subject: [PATCH 10/31] feat: added TCP and HTTP forward handling to gateway --- packages/cmd/network.go | 16 +- packages/gateway-v2/connection.go | 143 ++++++++++++++++ packages/gateway-v2/constants.go | 22 +++ packages/gateway-v2/gateway.go | 271 +++++++++++++++++++++--------- packages/proxy/proxy.go | 1 - 5 files changed, 360 insertions(+), 93 deletions(-) create mode 100644 packages/gateway-v2/connection.go create mode 100644 packages/gateway-v2/constants.go diff --git a/packages/cmd/network.go b/packages/cmd/network.go index 8ec9c6b3..2d6d67f4 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -57,9 +57,9 @@ var networkProxyCmd = &cobra.Command{ } if instanceType == "instance" { - proxyAuthSecret := os.Getenv("PROXY_AUTH_SECRET") + proxyAuthSecret := os.Getenv(gatewayv2.PROXY_AUTH_SECRET_ENV_NAME) if proxyAuthSecret == "" { - util.HandleError(fmt.Errorf("PROXY_AUTH_SECRET is not set"), "unable to get proxy auth secret") + util.HandleError(fmt.Errorf("%s is not set", gatewayv2.PROXY_AUTH_SECRET_ENV_NAME), "unable to get proxy auth secret") } proxyInstance.SetToken(proxyAuthSecret) @@ -145,14 +145,14 @@ var networkGatewayCmd = &cobra.Command{ Long: "Run the Infisical gateway component", Run: func(cmd *cobra.Command, args []string) { - proxyName, err := cmd.Flags().GetString("proxy-name") - if err != nil || proxyName == "" { - util.HandleError(err, "unable to get proxy-name flag") + proxyName, err := util.GetCmdFlagOrEnv(cmd, "proxy-name", []string{gatewayv2.PROXY_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get proxy-name flag or %s env", gatewayv2.PROXY_NAME_ENV_NAME)) } - gatewayName, err := cmd.Flags().GetString("name") - if err != nil || gatewayName == "" { - util.HandleError(err, "unable to get name flag") + gatewayName, err := util.GetCmdFlagOrEnv(cmd, "name", []string{gatewayv2.GATEWAY_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", gatewayv2.GATEWAY_NAME_ENV_NAME)) } gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ diff --git a/packages/gateway-v2/connection.go b/packages/gateway-v2/connection.go new file mode 100644 index 00000000..ad521392 --- /dev/null +++ b/packages/gateway-v2/connection.go @@ -0,0 +1,143 @@ +package gatewayv2 + +import ( + "bufio" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "time" + + "github.com/rs/zerolog/log" +) + +func buildHttpInternalServerError(message string) string { + return fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\n\r\n{\"message\": \"gateway: %s\"}", message) +} + +func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, targetURL string, caCert []byte, verifyTLS bool) error { + transport := &http.Transport{ + DisableKeepAlives: false, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + } + + if strings.HasPrefix(targetURL, "https://") { + tlsConfig := &tls.Config{} + + if len(caCert) > 0 { + caCertPool := x509.NewCertPool() + if caCertPool.AppendCertsFromPEM(caCert) { + tlsConfig.RootCAs = caCertPool + log.Info().Msg("Using provided CA certificate from gateway client") + } else { + log.Error().Msg("Failed to parse provided CA certificate") + } + } + + tlsConfig.InsecureSkipVerify = !verifyTLS + log.Info().Msgf("TLS verification set to: %v", verifyTLS) + + transport.TLSClientConfig = tlsConfig + } + + // Loop to handle multiple HTTP requests on the same connection + for { + log.Info().Msg("Attempting to read HTTP request...") + req, err := http.ReadRequest(reader) + + if err != nil { + if errors.Is(err, io.EOF) { + log.Info().Msg("Client closed HTTP connection") + return nil + } + + log.Error().Msgf("Failed to read HTTP request: %v", err) + return fmt.Errorf("failed to read HTTP request: %v", err) + } + log.Info().Msgf("Received HTTP request: %s", req.URL.Path) + + // Build full target URL + var targetFullURL string + if strings.HasPrefix(targetURL, "http://") || strings.HasPrefix(targetURL, "https://") { + baseURL := strings.TrimSuffix(targetURL, "/") + targetFullURL = baseURL + req.URL.Path + if req.URL.RawQuery != "" { + targetFullURL += "?" + req.URL.RawQuery + } + } else { + baseURL := strings.TrimSuffix("http://"+targetURL, "/") + targetFullURL = baseURL + req.URL.Path + if req.URL.RawQuery != "" { + targetFullURL += "?" + req.URL.RawQuery + } + } + + // create the request to the target + proxyReq, err := http.NewRequest(req.Method, targetFullURL, req.Body) + if err != nil { + log.Error().Msgf("Failed to create proxy request: %v", err) + conn.Write([]byte(buildHttpInternalServerError("failed to create proxy request"))) + continue // Continue to next request + } + proxyReq.Header = req.Header.Clone() + + log.Info().Msgf("Proxying %s %s to %s", req.Method, req.URL.Path, targetFullURL) + + client := &http.Client{ + Transport: transport, + Timeout: 30 * time.Second, + } + + resp, err := client.Do(proxyReq) + if err != nil { + log.Error().Msgf("Failed to reach target: %v", err) + conn.Write([]byte(buildHttpInternalServerError(fmt.Sprintf("failed to reach target due to networking error: %s", err.Error())))) + continue // Continue to next request + } + + // Write the entire response (status line, headers, body) to the connection + resp.Header.Del("Connection") + + log.Info().Msgf("Writing response to connection: %s", resp.Status) + + if err := resp.Write(conn); err != nil { + log.Error().Err(err).Msg("Failed to write response to connection") + resp.Body.Close() + return fmt.Errorf("failed to write response to connection: %w", err) + } + + resp.Body.Close() + + // Check if client wants to close connection + if req.Header.Get("Connection") == "close" { + log.Info().Msg("Client requested connection close") + return nil + } + } +} + +func handleTCPProxy(conn *tls.Conn, target string) error { + localConn, err := net.Dial("tcp", target) + if err != nil { + log.Error().Msgf("Failed to connect to local service %s: %v", target, err) + return fmt.Errorf("failed to connect to local service %s: %v", target, err) + } + defer localConn.Close() + + // Create bidirectional tunnel with TLS + // Forward data from TLS connection to local service + go func() { + io.Copy(localConn, conn) + localConn.Close() + }() + + // Forward data from local service to TLS connection + io.Copy(conn, localConn) + + return nil +} diff --git a/packages/gateway-v2/constants.go b/packages/gateway-v2/constants.go new file mode 100644 index 00000000..996068d6 --- /dev/null +++ b/packages/gateway-v2/constants.go @@ -0,0 +1,22 @@ +package gatewayv2 + +const ( + KUBERNETES_SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST" + KUBERNETES_SERVICE_PORT_HTTPS_ENV_NAME = "KUBERNETES_SERVICE_PORT_HTTPS" + KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" + KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" + + PROXY_NAME_ENV_NAME = "PROXY_NAME" + GATEWAY_NAME_ENV_NAME = "GATEWAY_NAME" + + PROXY_AUTH_SECRET_ENV_NAME = "PROXY_AUTH_SECRET" + + INFISICAL_HTTP_PROXY_ACTION_HEADER = "x-infisical-action" +) + +type HttpProxyAction string + +const ( + HttpProxyActionInjectGatewayK8sServiceAccountToken HttpProxyAction = "inject-k8s-sa-auth-token" + HttpProxyActionUseGatewayK8sServiceAccount HttpProxyAction = "use-k8s-sa" +) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 12ad4acd..85f7b4ab 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -1,25 +1,52 @@ package gatewayv2 import ( + "bufio" "bytes" "context" "crypto/rsa" "crypto/tls" "crypto/x509" + "encoding/base64" + "encoding/json" "encoding/pem" "fmt" - "io" - "log" "net" + "strconv" + "strings" "sync" "time" "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" ) +// ForwardMode represents the type of forwarding +type ForwardMode string + +const ( + ForwardModeHTTP ForwardMode = "HTTP" + ForwardModeTCP ForwardMode = "TCP" +) + +// ForwardConfig contains the configuration for forwarding +type ForwardConfig struct { + Mode ForwardMode + CACertificate []byte // Decoded CA certificate for HTTPS verification + VerifyTLS bool // Whether to verify TLS certificates + TargetHost string + TargetPort int +} + +// RoutingInfo represents the routing information embedded in client certificates +type RoutingInfo struct { + TargetHost string `json:"targetHost"` + TargetPort int `json:"targetPort"` +} + type GatewayConfig struct { Name string ProxyName string @@ -76,11 +103,11 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { // Change the Start method to accept a context func (g *Gateway) Start(ctx context.Context) error { - log.Printf("Starting gateway") + log.Info().Msgf("Starting gateway") for { select { case <-ctx.Done(): - log.Printf("Gateway stopped by context cancellation") + log.Info().Msgf("Gateway stopped by context cancellation") return nil default: if err := g.connectAndServe(); err != nil { @@ -93,7 +120,7 @@ func (g *Gateway) Start(ctx context.Context) error { } } // If we get here, the connection was closed gracefully - log.Printf("Connection closed, reconnecting in 10 seconds...") + log.Info().Msgf("Connection closed, reconnecting in 10 seconds...") select { case <-ctx.Done(): return ctx.Err() @@ -132,7 +159,7 @@ func (g *Gateway) connectAndServe() error { } // Connect to Proxy server - log.Printf("Connecting to SSH server on %s:%d...", g.certificates.ProxyIP, g.config.SSHPort) + log.Info().Msgf("Connecting to SSH server on %s:%d...", g.certificates.ProxyIP, g.config.SSHPort) client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.ProxyIP, g.config.SSHPort), sshConfig) if err != nil { return fmt.Errorf("failed to connect to SSH server: %v", err) @@ -151,7 +178,7 @@ func (g *Gateway) connectAndServe() error { client.Close() }() - log.Printf("SSH connection established for gateway") + log.Info().Msgf("SSH connection established for gateway") // Handle incoming channels from the server channels := client.HandleChannelOpen("direct-tcpip") @@ -180,7 +207,57 @@ func (g *Gateway) registerGateway() error { g.GatewayID = certResp.GatewayID g.certificates = &certResp - log.Printf("Successfully registered gateway and received certificates") + log.Info().Msgf("Successfully registered gateway and received certificates") + + // Create mTLS config once during registration + serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) + if serverCertBlock == nil { + return fmt.Errorf("failed to decode server certificate") + } + + serverKeyBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerPrivateKey)) + if serverKeyBlock == nil { + return fmt.Errorf("failed to decode server private key") + } + + serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server private key: %v", err) + } + + clientCAPool := x509.NewCertPool() + var chainCerts [][]byte + chainData := []byte(g.certificates.PKI.ClientCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest + } + + for i, certBytes := range chainCerts { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + log.Info().Msgf("Failed to parse client chain certificate %d: %v", i+1, err) + continue + } + clientCAPool.AddCert(cert) + } + + g.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{serverCertBlock.Bytes}, + PrivateKey: serverKey, + }, + }, + ClientCAs: clientCAPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + } + return nil } @@ -232,10 +309,8 @@ func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { } func (g *Gateway) createHostKeyCallback() ssh.HostKeyCallback { - // Parse CA public key once when creating the callback caKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ServerCAPublicKey)) if err != nil { - // Return a callback that always fails since we can't parse the CA key return func(hostname string, remote net.Addr, key ssh.PublicKey) error { return fmt.Errorf("failed to parse CA public key: %v", err) } @@ -262,7 +337,7 @@ func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string return fmt.Errorf("host certificate check failed: %v", err) } - log.Printf("Host certificate validated successfully for %s", hostname) + log.Info().Msgf("Host certificate validated successfully for %s", hostname) return nil } @@ -275,32 +350,24 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { } if err := ssh.Unmarshal(newChannel.ExtraData(), &req); err != nil { - log.Printf("Failed to parse channel request: %v", err) + log.Info().Msgf("Failed to parse channel request: %v", err) newChannel.Reject(ssh.Prohibited, "invalid request") return } - log.Printf("Incoming connection request to %s:%d from %s:%d", - req.Host, req.Port, req.OriginHost, req.OriginPort) - - // Accept the channel channel, requests, err := newChannel.Accept() if err != nil { - log.Printf("Failed to accept channel: %v", err) + log.Info().Msgf("Failed to accept channel: %v", err) return } defer channel.Close() go ssh.DiscardRequests(requests) - // Determine the target address - target := fmt.Sprintf("%s:%d", req.Host, req.Port) - log.Printf("Creating TCP tunnel to: %s", target) - // Create mTLS server configuration - tlsConfig, err := g.createMTLSConfig() - if err != nil { - log.Printf("Failed to create mTLS config: %v", err) + tlsConfig := g.tlsConfig + if tlsConfig == nil { + log.Info().Msgf("TLS config not initialized, cannot create mTLS server") return } @@ -314,88 +381,124 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { // Perform TLS handshake if err := tlsConn.Handshake(); err != nil { - log.Printf("TLS handshake failed: %v", err) + log.Info().Msgf("TLS handshake failed: %v", err) return } - log.Printf("mTLS connection established with client: %s", tlsConn.ConnectionState().ServerName) + log.Info().Msgf("mTLS connection established with client") - // Connect to local service - localConn, err := net.Dial("tcp", target) + // Create reader for the TLS connection + reader := bufio.NewReader(tlsConn) + + // Get the forward mode here + forwardConfig, err := g.parseForwardConfig(tlsConn, reader) if err != nil { - log.Printf("Failed to connect to local service %s: %v", target, err) + log.Info().Msgf("Failed to parse forward command: %v", err) return } - defer localConn.Close() - log.Printf("TCP tunnel established to %s", target) + // Use target from certificate + target := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) + log.Info().Msgf("Using target from certificate: %s", target) - // Create bidirectional tunnel with TLS - // Forward data from TLS connection to local service - go func() { - io.Copy(localConn, tlsConn) - localConn.Close() - log.Printf("TLS -> local service tunnel closed") - }() - - // Forward data from local service to TLS connection - io.Copy(tlsConn, localConn) - log.Printf("Local service -> TLS tunnel closed") + if forwardConfig.Mode == ForwardModeHTTP { + handleHTTPProxy(tlsConn, reader, target, forwardConfig.CACertificate, forwardConfig.VerifyTLS) + return + } else if forwardConfig.Mode == ForwardModeTCP { + handleTCPProxy(tlsConn, target) + return + } } -func (g *Gateway) createMTLSConfig() (*tls.Config, error) { - // Parse server certificate - serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) - if serverCertBlock == nil { - return nil, fmt.Errorf("failed to decode server certificate") - } +func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (*ForwardConfig, error) { + config := &ForwardConfig{} - // Parse server private key - serverKeyBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerPrivateKey)) - if serverKeyBlock == nil { - return nil, fmt.Errorf("failed to decode server private key") + if err := g.parseRoutingInfoFromCertificate(tlsConn, config); err != nil { + return nil, fmt.Errorf("failed to parse routing info from certificate: %v", err) } - serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) - if err != nil { - return nil, fmt.Errorf("failed to parse server private key: %v", err) + for { + msg, err := reader.ReadBytes('\n') + if err != nil { + return nil, fmt.Errorf("failed to read command: %v", err) + } + + cmd := strings.ToUpper(strings.TrimSpace(string(strings.Split(string(msg), " ")[0]))) + args := strings.TrimSpace(strings.TrimPrefix(string(msg), strings.Split(string(msg), " ")[0])) + + switch cmd { + case "FORWARD-TCP": + config.Mode = ForwardModeTCP + return config, nil + + case "FORWARD-HTTP": + config.Mode = ForwardModeHTTP + if args != "" { + if err := g.parseForwardHTTPParams(args, config); err != nil { + return nil, fmt.Errorf("failed to parse HTTP parameters: %v", err) + } + } + + return config, nil + + default: + return nil, fmt.Errorf("invalid forward command: %s", cmd) + } } +} - // Create certificate pool for client CAs - clientCAPool := x509.NewCertPool() - var chainCerts [][]byte - chainData := []byte(g.certificates.PKI.ClientCertificateChain) - for { - block, rest := pem.Decode(chainData) - if block == nil { - break +func (g *Gateway) parseForwardHTTPParams(params string, config *ForwardConfig) error { + parts := strings.Fields(params) + + for _, part := range parts { + if strings.HasPrefix(part, "ca=") { + caB64 := strings.TrimPrefix(part, "ca=") + caCert, err := base64.StdEncoding.DecodeString(caB64) + if err != nil { + return fmt.Errorf("invalid base64 CA certificate: %v", err) + } + config.CACertificate = caCert + } else if strings.HasPrefix(part, "verify=") { + verifyStr := strings.TrimPrefix(part, "verify=") + verify, err := strconv.ParseBool(verifyStr) + if err != nil { + return fmt.Errorf("invalid verify parameter: %s", verifyStr) + } + config.VerifyTLS = verify } - chainCerts = append(chainCerts, block.Bytes) - chainData = rest } - for i, certBytes := range chainCerts { - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - log.Printf("Failed to parse client chain certificate %d: %v", i+1, err) - continue + return nil +} + +// parseRoutingInfoFromCertificate extracts target host and port from client certificate custom extension +func (g *Gateway) parseRoutingInfoFromCertificate(tlsConn *tls.Conn, config *ForwardConfig) error { + const GATEWAY_ROUTING_INFO_OID = "1.3.6.1.4.1.12345.100.1" + + // Get the peer certificates + state := tlsConn.ConnectionState() + if len(state.PeerCertificates) == 0 { + return fmt.Errorf("no peer certificates found") + } + + clientCert := state.PeerCertificates[0] + + // Look for the routing extension + for _, ext := range clientCert.Extensions { + if ext.Id.String() == GATEWAY_ROUTING_INFO_OID { + var routingInfo RoutingInfo + if err := json.Unmarshal(ext.Value, &routingInfo); err != nil { + return fmt.Errorf("failed to parse routing info JSON: %v", err) + } + + config.TargetHost = routingInfo.TargetHost + config.TargetPort = routingInfo.TargetPort + + return nil } - clientCAPool.AddCert(cert) - log.Printf("Added client CA certificate %d to pool: %s", i+1, cert.Subject.CommonName) } - // Create TLS config - return &tls.Config{ - Certificates: []tls.Certificate{ - { - Certificate: [][]byte{serverCertBlock.Bytes}, - PrivateKey: serverKey, - }, - }, - ClientCAs: clientCAPool, - ClientAuth: tls.RequireAndVerifyClientCert, - MinVersion: tls.VersionTLS12, - }, nil + return fmt.Errorf("routing extension with OID %s not found in client certificate", GATEWAY_ROUTING_INFO_OID) } // virtualConnection implements net.Conn to bridge SSH channel and TLS diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 279dc2c8..7ae7589f 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -374,7 +374,6 @@ func (p *Proxy) handleClient(clientConn net.Conn) { } state := tlsConn.ConnectionState() - log.Printf("TLS handshake completed, peer certificates count: %d", len(state.PeerCertificates)) if len(state.PeerCertificates) > 0 { cert := state.PeerCertificates[0] From f7ed054857d4f203818197718c426b06d792f01d Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 2 Sep 2025 01:01:48 +0800 Subject: [PATCH 11/31] feat: added auth injection for k8 and platform checks --- packages/gateway-v2/connection.go | 61 +++++++++++++++++++++++++++- packages/gateway-v2/gateway.go | 46 ++++++++++++++------- packages/proxy/proxy.go | 66 ++++++++++++++----------------- 3 files changed, 119 insertions(+), 54 deletions(-) diff --git a/packages/gateway-v2/connection.go b/packages/gateway-v2/connection.go index ad521392..68517870 100644 --- a/packages/gateway-v2/connection.go +++ b/packages/gateway-v2/connection.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/http" + "os" "strings" "time" @@ -19,7 +20,11 @@ func buildHttpInternalServerError(message string) string { return fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\n\r\n{\"message\": \"gateway: %s\"}", message) } -func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, targetURL string, caCert []byte, verifyTLS bool) error { +func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *ForwardConfig) error { + targetURL := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) + caCert := forwardConfig.CACertificate + verifyTLS := forwardConfig.VerifyTLS + transport := &http.Transport{ DisableKeepAlives: false, MaxIdleConns: 10, @@ -61,6 +66,57 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, targetURL string, caC } log.Info().Msgf("Received HTTP request: %s", req.URL.Path) + actionHeader := HttpProxyAction(req.Header.Get(INFISICAL_HTTP_PROXY_ACTION_HEADER)) + + // Only platform actor can perform privileged actions + if actionHeader != "" && forwardConfig.ActorType == ActorTypePlatform { + if actionHeader == HttpProxyActionInjectGatewayK8sServiceAccountToken { + token, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH) + if err != nil { + conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa auth token"))) + continue // Continue to next request instead of returning + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(token))) + log.Info().Msgf("Injected gateway k8s SA auth token in request to %s", targetURL) + } else if actionHeader == HttpProxyActionUseGatewayK8sServiceAccount { // will work without a target URL set + // set the ca cert to the pod's k8s service account ca cert: + caCert, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH) + if err != nil { + conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa ca cert"))) + continue + } + + caCertPool := x509.NewCertPool() + if ok := caCertPool.AppendCertsFromPEM(caCert); !ok { + conn.Write([]byte(buildHttpInternalServerError("failed to parse k8s sa ca cert"))) + continue + } + + transport.TLSClientConfig = &tls.Config{ + RootCAs: caCertPool, + } + + // set authorization header to the pod's k8s service account token: + token, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH) + if err != nil { + conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa auth token"))) + continue + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(token))) + + // update the target URL to point to the kubernetes API server: + kubernetesServiceHost := os.Getenv(KUBERNETES_SERVICE_HOST_ENV_NAME) + kubernetesServicePort := os.Getenv(KUBERNETES_SERVICE_PORT_HTTPS_ENV_NAME) + + fullBaseUrl := fmt.Sprintf("https://%s:%s", kubernetesServiceHost, kubernetesServicePort) + targetURL = fullBaseUrl + + log.Info().Msgf("Redirected request to Kubernetes API server: %s", targetURL) + } + + req.Header.Del(INFISICAL_HTTP_PROXY_ACTION_HEADER) + } + // Build full target URL var targetFullURL string if strings.HasPrefix(targetURL, "http://") || strings.HasPrefix(targetURL, "https://") { @@ -121,7 +177,8 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, targetURL string, caC } } -func handleTCPProxy(conn *tls.Conn, target string) error { +func handleTCPProxy(conn *tls.Conn, forwardConfig *ForwardConfig) error { + target := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) localConn, err := net.Dial("tcp", target) if err != nil { log.Error().Msgf("Failed to connect to local service %s: %v", target, err) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 85f7b4ab..881876b9 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -32,6 +32,16 @@ const ( ForwardModeTCP ForwardMode = "TCP" ) +type ActorType string + +const ( + ActorTypePlatform ActorType = "platform" + ActorTypeUser ActorType = "user" +) + +const GATEWAY_ROUTING_INFO_OID = "1.3.6.1.4.1.12345.100.1" +const GATEWAY_ACTOR_OID = "1.3.6.1.4.1.12345.100.2" + // ForwardConfig contains the configuration for forwarding type ForwardConfig struct { Mode ForwardMode @@ -39,6 +49,7 @@ type ForwardConfig struct { VerifyTLS bool // Whether to verify TLS certificates TargetHost string TargetPort int + ActorType ActorType } // RoutingInfo represents the routing information embedded in client certificates @@ -47,6 +58,10 @@ type RoutingInfo struct { TargetPort int `json:"targetPort"` } +type ActorDetails struct { + Type string `json:"type"` +} + type GatewayConfig struct { Name string ProxyName string @@ -111,7 +126,7 @@ func (g *Gateway) Start(ctx context.Context) error { return nil default: if err := g.connectAndServe(); err != nil { - log.Printf("Connection failed: %v, retrying in %v...", err, g.config.ReconnectDelay) + log.Error().Msgf("Connection failed: %v, retrying in %v...", err, g.config.ReconnectDelay) select { case <-ctx.Done(): return ctx.Err() @@ -397,15 +412,13 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { return } - // Use target from certificate - target := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) - log.Info().Msgf("Using target from certificate: %s", target) + log.Info().Msgf("Forward config: %+v", forwardConfig) if forwardConfig.Mode == ForwardModeHTTP { - handleHTTPProxy(tlsConn, reader, target, forwardConfig.CACertificate, forwardConfig.VerifyTLS) + handleHTTPProxy(tlsConn, reader, forwardConfig) return } else if forwardConfig.Mode == ForwardModeTCP { - handleTCPProxy(tlsConn, target) + handleTCPProxy(tlsConn, forwardConfig) return } } @@ -413,7 +426,7 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (*ForwardConfig, error) { config := &ForwardConfig{} - if err := g.parseRoutingInfoFromCertificate(tlsConn, config); err != nil { + if err := g.parseDetailsFromCertificate(tlsConn, config); err != nil { return nil, fmt.Errorf("failed to parse routing info from certificate: %v", err) } @@ -471,10 +484,7 @@ func (g *Gateway) parseForwardHTTPParams(params string, config *ForwardConfig) e return nil } -// parseRoutingInfoFromCertificate extracts target host and port from client certificate custom extension -func (g *Gateway) parseRoutingInfoFromCertificate(tlsConn *tls.Conn, config *ForwardConfig) error { - const GATEWAY_ROUTING_INFO_OID = "1.3.6.1.4.1.12345.100.1" - +func (g *Gateway) parseDetailsFromCertificate(tlsConn *tls.Conn, config *ForwardConfig) error { // Get the peer certificates state := tlsConn.ConnectionState() if len(state.PeerCertificates) == 0 { @@ -483,8 +493,8 @@ func (g *Gateway) parseRoutingInfoFromCertificate(tlsConn *tls.Conn, config *For clientCert := state.PeerCertificates[0] - // Look for the routing extension for _, ext := range clientCert.Extensions { + // Extract target host and port from client certificate custom extension if ext.Id.String() == GATEWAY_ROUTING_INFO_OID { var routingInfo RoutingInfo if err := json.Unmarshal(ext.Value, &routingInfo); err != nil { @@ -493,12 +503,18 @@ func (g *Gateway) parseRoutingInfoFromCertificate(tlsConn *tls.Conn, config *For config.TargetHost = routingInfo.TargetHost config.TargetPort = routingInfo.TargetPort - - return nil + } + // Extract actor type from client certificate custom extension + if ext.Id.String() == GATEWAY_ACTOR_OID { + var actorDetails ActorDetails + if err := json.Unmarshal(ext.Value, &actorDetails); err != nil { + return fmt.Errorf("failed to parse actor details JSON: %v", err) + } + config.ActorType = ActorType(actorDetails.Type) } } - return fmt.Errorf("routing extension with OID %s not found in client certificate", GATEWAY_ROUTING_INFO_OID) + return nil } // virtualConnection implements net.Conn to bridge SSH channel and TLS diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 7ae7589f..1b050a6d 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -9,14 +9,13 @@ import ( "encoding/pem" "fmt" "io" - "log" "net" - "strings" "sync" "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" ) @@ -101,7 +100,7 @@ func (p *Proxy) Start(ctx context.Context) error { // Start TLS server go p.startTLSServer() - log.Printf("Proxy server started successfully") + log.Info().Msg("Proxy server started successfully") // Wait for context cancellation <-ctx.Done() @@ -131,7 +130,7 @@ func (p *Proxy) registerProxy() error { p.certificates = &certResp } - log.Printf("Successfully registered proxy and received certificates from API") + log.Info().Msg("Successfully registered proxy and received certificates from API") return nil } @@ -166,13 +165,13 @@ func (p *Proxy) setupSSHServer() error { // Check if this is an SSH certificate cert, ok := key.(*ssh.Certificate) if !ok { - log.Printf("Gateway '%s' tried to authenticate with raw public key (rejected)", conn.User()) + log.Warn().Msgf("Gateway '%s' tried to authenticate with raw public key (rejected)", conn.User()) return nil, fmt.Errorf("certificates required, raw public keys not allowed") } // Validate the certificate if err := p.validateSSHCertificate(cert, conn.User(), sshCAPubKey); err != nil { - log.Printf("Gateway '%s' certificate validation failed: %v", conn.User(), err) + log.Error().Msgf("Gateway '%s' certificate validation failed: %v", conn.User(), err) return nil, err } @@ -239,11 +238,10 @@ func (p *Proxy) setupTLSServer() error { for i, certBytes := range chainCerts { cert, err := x509.ParseCertificate(certBytes) if err != nil { - log.Printf("Failed to parse client chain certificate %d: %v", i+1, err) + log.Error().Msgf("Failed to parse client chain certificate %d: %v", i+1, err) continue } clientCAPool.AddCert(cert) - log.Printf("Added client CA certificate %d to pool: %s", i+1, cert.Subject.CommonName) } // Create TLS config @@ -268,7 +266,7 @@ func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, c return fmt.Errorf("invalid certificate type: %d", cert.CertType) } - // Check if certificate is signed by our CA + // Check if certificate is signed expected CA checker := &ssh.CertChecker{ IsUserAuthority: func(auth ssh.PublicKey) bool { return bytes.Equal(auth.Marshal(), caPubKey.Marshal()) @@ -280,23 +278,23 @@ func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, c return fmt.Errorf("certificate check failed: %v", err) } - log.Printf("SSH certificate valid for user '%s', principals: %v", username, cert.ValidPrincipals) + log.Debug().Msgf("SSH certificate valid for user '%s', principals: %v", username, cert.ValidPrincipals) return nil } func (p *Proxy) startSSHServer() { listener, err := net.Listen("tcp", ":"+p.config.SSHPort) if err != nil { - log.Fatalf("Failed to start SSH server: %v", err) + log.Fatal().Msgf("Failed to start SSH server: %v", err) } p.sshListener = listener - log.Printf("SSH server listening on :%s for gateways", p.config.SSHPort) + log.Info().Msgf("SSH server listening on :%s for gateways", p.config.SSHPort) for { conn, err := listener.Accept() if err != nil { - log.Printf("Failed to accept SSH connection: %v", err) + log.Error().Msgf("Failed to accept SSH connection: %v", err) continue } go p.handleSSHAgent(conn) @@ -309,12 +307,12 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { // SSH handshake sshConn, chans, _, err := ssh.NewServerConn(conn, p.sshConfig) if err != nil { - log.Printf("SSH handshake failed: %v", err) + log.Error().Msgf("SSH handshake failed: %v", err) return } gatewayId := sshConn.Permissions.Extensions["gateway-id"] - log.Printf("SSH handshake successful for gateway: %s", gatewayId) + log.Info().Msgf("SSH handshake successful for gateway: %s", gatewayId) // Store the connection p.mu.Lock() @@ -326,7 +324,7 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { p.mu.Lock() delete(p.tunnels, gatewayId) p.mu.Unlock() - log.Printf("Gateway %s disconnected", gatewayId) + log.Info().Msgf("Gateway %s disconnected", gatewayId) }() for newChannel := range chans { @@ -344,16 +342,16 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { func (p *Proxy) startTLSServer() { listener, err := tls.Listen("tcp", ":"+p.config.TLSPort, p.tlsConfig) if err != nil { - log.Fatalf("Failed to start TLS server: %v", err) + log.Fatal().Msgf("Failed to start TLS server: %v", err) } p.tlsListener = listener - log.Printf("TLS server listening on :%s for clients", p.config.TLSPort) + log.Info().Msgf("TLS server listening on :%s for clients", p.config.TLSPort) for { conn, err := listener.Accept() if err != nil { - log.Printf("Failed to accept TLS connection: %v", err) + log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } go p.handleClient(conn) @@ -366,10 +364,10 @@ func (p *Proxy) handleClient(clientConn net.Conn) { var gatewayId string if tlsConn, ok := clientConn.(*tls.Conn); ok { - log.Printf("TLS connection detected, forcing handshake...") + log.Debug().Msg("TLS connection detected, forcing handshake...") err := tlsConn.Handshake() if err != nil { - log.Printf("TLS handshake failed: %v", err) + log.Error().Msgf("TLS handshake failed: %v", err) return } @@ -377,20 +375,14 @@ func (p *Proxy) handleClient(clientConn net.Conn) { if len(state.PeerCertificates) > 0 { cert := state.PeerCertificates[0] - log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) - parts := strings.Split(cert.Subject.CommonName, ":") - if len(parts) >= 2 { - gatewayId = parts[1] - } else { - log.Printf("Invalid CommonName format, expected 'part1:part2', got: %s", cert.Subject.CommonName) - return - } + log.Info().Msgf("Client connected with certificate: %s", cert.Subject.CommonName) + gatewayId = cert.Subject.CommonName } else { - log.Printf("No peer certificates found") + log.Warn().Msg("No peer certificates found") return } } else { - log.Printf("Not a TLS connection, connection type: %T", clientConn) + log.Error().Msgf("Not a TLS connection, connection type: %T", clientConn) return } @@ -404,12 +396,12 @@ func (p *Proxy) handleClient(clientConn net.Conn) { p.mu.RUnlock() if !exists { - log.Printf("Gateway '%s' not connected", gatewayId) + log.Warn().Msgf("Gateway '%s' not connected", gatewayId) clientConn.Write([]byte("ERROR: Gateway not connected\n")) return } - log.Printf("Routing TCP connection to gateway: %s", gatewayId) + log.Info().Msgf("Routing TCP connection to gateway: %s", gatewayId) // Open SSH channel to connect to agent's local service through the tunnel payload := struct { @@ -421,7 +413,7 @@ func (p *Proxy) handleClient(clientConn net.Conn) { channel, _, err := conn.OpenChannel("direct-tcpip", ssh.Marshal(&payload)) if err != nil { - log.Printf("Failed to connect to agent: %v", err) + log.Error().Msgf("Failed to connect to agent: %v", err) clientConn.Write([]byte("ERROR: Failed to connect to agent\n")) return } @@ -434,11 +426,11 @@ func (p *Proxy) handleClient(clientConn net.Conn) { }() io.Copy(clientConn, channel) - log.Printf("Client %s disconnected", clientConn.RemoteAddr()) + log.Info().Msgf("Client %s disconnected", clientConn.RemoteAddr()) } func (p *Proxy) cleanup() { - log.Printf("Shutting down proxy server...") + log.Info().Msg("Shutting down proxy server...") if p.sshListener != nil { p.sshListener.Close() @@ -447,5 +439,5 @@ func (p *Proxy) cleanup() { p.tlsListener.Close() } - log.Printf("Proxy server shutdown complete") + log.Info().Msg("Proxy server shutdown complete") } From dc7a438f40ca56ebcb9f0dcedb924b504499f4a4 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 2 Sep 2025 21:49:11 +0800 Subject: [PATCH 12/31] feat: added heartbeat --- packages/api/api.go | 18 ++++++ packages/cmd/network.go | 26 ++++++--- packages/gateway-v2/connection.go | 89 +++++++++++++++++++++++++---- packages/gateway-v2/gateway.go | 95 ++++++++++++++++++++++++++----- 4 files changed, 197 insertions(+), 31 deletions(-) diff --git a/packages/api/api.go b/packages/api/api.go index 7d61eb8a..e20e6daf 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -39,6 +39,7 @@ const ( operationCallRegisterGatewayIdentityV1 = "CallRegisterGatewayIdentityV1" operationCallExchangeRelayCertV1 = "CallExchangeRelayCertV1" operationCallGatewayHeartBeatV1 = "CallGatewayHeartBeatV1" + operationCallGatewayHeartBeatV2 = "CallGatewayHeartBeatV2" operationCallBootstrapInstance = "CallBootstrapInstance" operationCallRegisterInstanceProxy = "CallRegisterInstanceProxy" operationCallRegisterOrgProxy = "CallRegisterOrgProxy" @@ -655,6 +656,23 @@ func CallGatewayHeartBeatV1(httpClient *resty.Client) error { return nil } +func CallGatewayHeartBeatV2(httpClient *resty.Client) error { + response, err := httpClient. + R(). + SetHeader("User-Agent", USER_AGENT). + Post(fmt.Sprintf("%v/v2/gateways/heartbeat", config.INFISICAL_URL)) + + if err != nil { + return NewGenericRequestError(operationCallGatewayHeartBeatV2, err) + } + + if response.IsError() { + return NewAPIErrorWithResponse(operationCallGatewayHeartBeatV2, response, nil) + } + + return nil +} + func CallBootstrapInstance(httpClient *resty.Client, request BootstrapInstanceRequest) (BootstrapInstanceResponse, error) { var resBody BootstrapInstanceResponse response, err := httpClient. diff --git a/packages/cmd/network.go b/packages/cmd/network.go index 2d6d67f4..2bf4d063 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -91,10 +91,15 @@ var networkProxyCmd = &cobra.Command{ cancelCmd() cancelSdk() - // If we get a second signal, force exit - <-sigCh - log.Warn().Msgf("Force exit triggered") - os.Exit(1) + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } }() // Token refresh goroutine - runs every 10 seconds @@ -192,10 +197,15 @@ var networkGatewayCmd = &cobra.Command{ cancelCmd() cancelSdk() - // If we get a second signal, force exit - <-sigCh - log.Warn().Msgf("Force exit triggered") - os.Exit(1) + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } }() // Token refresh goroutine - runs every 10 seconds diff --git a/packages/gateway-v2/connection.go b/packages/gateway-v2/connection.go index 68517870..141681f8 100644 --- a/packages/gateway-v2/connection.go +++ b/packages/gateway-v2/connection.go @@ -2,6 +2,7 @@ package gatewayv2 import ( "bufio" + "context" "crypto/tls" "crypto/x509" "errors" @@ -20,7 +21,7 @@ func buildHttpInternalServerError(message string) string { return fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\n\r\n{\"message\": \"gateway: %s\"}", message) } -func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *ForwardConfig) error { +func handleHTTPProxy(ctx context.Context, conn *tls.Conn, reader *bufio.Reader, forwardConfig *ForwardConfig) error { targetURL := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) caCert := forwardConfig.CACertificate verifyTLS := forwardConfig.VerifyTLS @@ -52,18 +53,45 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *Forwar // Loop to handle multiple HTTP requests on the same connection for { + select { + case <-ctx.Done(): + log.Info().Msg("Context cancelled, closing HTTP proxy connection") + return ctx.Err() + default: + } + log.Info().Msg("Attempting to read HTTP request...") - req, err := http.ReadRequest(reader) - if err != nil { + // Create a channel to receive the request or error + reqCh := make(chan *http.Request, 1) + errCh := make(chan error, 1) + + // Read request in a goroutine so we can cancel it + go func() { + req, err := http.ReadRequest(reader) + if err != nil { + errCh <- err + } else { + reqCh <- req + } + }() + + var req *http.Request + select { + case <-ctx.Done(): + log.Info().Msg("Context cancelled while reading HTTP request") + return ctx.Err() + case err := <-errCh: if errors.Is(err, io.EOF) { log.Info().Msg("Client closed HTTP connection") return nil } - log.Error().Msgf("Failed to read HTTP request: %v", err) return fmt.Errorf("failed to read HTTP request: %v", err) + case req = <-reqCh: + // Successfully received request } + log.Info().Msgf("Received HTTP request: %s", req.URL.Path) actionHeader := HttpProxyAction(req.Header.Get(INFISICAL_HTTP_PROXY_ACTION_HEADER)) @@ -78,7 +106,8 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *Forwar } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(token))) log.Info().Msgf("Injected gateway k8s SA auth token in request to %s", targetURL) - } else if actionHeader == HttpProxyActionUseGatewayK8sServiceAccount { // will work without a target URL set + } else if actionHeader == HttpProxyActionUseGatewayK8sServiceAccount { + // will work without a target URL set // set the ca cert to the pod's k8s service account ca cert: caCert, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH) if err != nil { @@ -177,7 +206,7 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *Forwar } } -func handleTCPProxy(conn *tls.Conn, forwardConfig *ForwardConfig) error { +func handleTCPProxy(ctx context.Context, conn *tls.Conn, forwardConfig *ForwardConfig) error { target := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) localConn, err := net.Dial("tcp", target) if err != nil { @@ -186,15 +215,55 @@ func handleTCPProxy(conn *tls.Conn, forwardConfig *ForwardConfig) error { } defer localConn.Close() - // Create bidirectional tunnel with TLS + // Create a context for this connection that gets cancelled when the parent context is cancelled + // or when either connection closes + connCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Error channel to collect errors from both copy goroutines + errCh := make(chan error, 2) + // Forward data from TLS connection to local service go func() { - io.Copy(localConn, conn) - localConn.Close() + defer cancel() + _, err := io.Copy(localConn, conn) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + log.Debug().Msgf("TLS to local copy ended normally: %v", err) + } else { + log.Error().Msgf("TLS to local copy failed: %v", err) + } + } + errCh <- err }() // Forward data from local service to TLS connection - io.Copy(conn, localConn) + go func() { + defer cancel() + _, err := io.Copy(conn, localConn) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + log.Debug().Msgf("Local to TLS copy ended normally: %v", err) + } else { + log.Error().Msgf("Local to TLS copy failed: %v", err) + } + } + errCh <- err + }() + + // Wait for either context cancellation or one of the copy operations to complete + select { + case <-connCtx.Done(): + log.Info().Msg("TCP proxy connection cancelled") + return connCtx.Err() + case err := <-errCh: + // One of the copy operations completed (or failed) + // The defer cancel() will stop the other goroutine + return err + } +} +func handlePing(ctx context.Context, conn *tls.Conn, reader *bufio.Reader) error { + conn.Write([]byte("PONG\n")) return nil } diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 881876b9..d7c2c61b 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -30,6 +30,7 @@ type ForwardMode string const ( ForwardModeHTTP ForwardMode = "HTTP" ForwardModeTCP ForwardMode = "TCP" + ForwardModePing ForwardMode = "PING" ) type ActorType string @@ -116,9 +117,59 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { }, nil } -// Change the Start method to accept a context +func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { + sendHeartbeat := func() { + if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil { + log.Warn().Msgf("Heartbeat failed: %v", err) + select { + case errCh <- err: + default: + log.Warn().Msg("Error channel full, skipping heartbeat error report") + } + } else { + log.Info().Msg("Gateway is reachable by Infisical") + } + } + + go func() { + select { + case <-ctx.Done(): + return + case <-time.After(10 * time.Second): + sendHeartbeat() + } + + ticker := time.NewTicker(30 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sendHeartbeat() + } + } + }() +} + func (g *Gateway) Start(ctx context.Context) error { log.Info().Msgf("Starting gateway") + + errCh := make(chan error, 1) + g.registerHeartBeat(ctx, errCh) + + go func() { + for { + select { + case <-ctx.Done(): + return + case err := <-errCh: + log.Warn().Msgf("Heartbeat error received: %v", err) + } + } + }() + for { select { case <-ctx.Done(): @@ -179,6 +230,7 @@ func (g *Gateway) connectAndServe() error { if err != nil { return fmt.Errorf("failed to connect to SSH server: %v", err) } + log.Info().Msgf("SSH connection established for gateway") g.mu.Lock() g.sshClient = client @@ -193,20 +245,33 @@ func (g *Gateway) connectAndServe() error { client.Close() }() - log.Info().Msgf("SSH connection established for gateway") - // Handle incoming channels from the server channels := client.HandleChannelOpen("direct-tcpip") if channels == nil { return fmt.Errorf("failed to handle channel open") } - // Process incoming channels - for newChannel := range channels { - go g.handleIncomingChannel(newChannel) - } + // Monitor for context cancellation and close SSH client + go func() { + <-g.ctx.Done() + log.Info().Msg("Context cancelled, closing SSH connection...") + client.Close() + }() - return nil // Connection closed + // Process incoming channels with context cancellation support + for { + select { + case <-g.ctx.Done(): + log.Info().Msg("Context cancelled, stopping channel processing") + return g.ctx.Err() + case newChannel, ok := <-channels: + if !ok { + log.Info().Msg("SSH channels closed") + return nil + } + go g.handleIncomingChannel(newChannel) + } + } } func (g *Gateway) registerGateway() error { @@ -352,7 +417,6 @@ func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string return fmt.Errorf("host certificate check failed: %v", err) } - log.Info().Msgf("Host certificate validated successfully for %s", hostname) return nil } @@ -400,8 +464,6 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { return } - log.Info().Msgf("mTLS connection established with client") - // Create reader for the TLS connection reader := bufio.NewReader(tlsConn) @@ -415,10 +477,13 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { log.Info().Msgf("Forward config: %+v", forwardConfig) if forwardConfig.Mode == ForwardModeHTTP { - handleHTTPProxy(tlsConn, reader, forwardConfig) + handleHTTPProxy(g.ctx, tlsConn, reader, forwardConfig) return } else if forwardConfig.Mode == ForwardModeTCP { - handleTCPProxy(tlsConn, forwardConfig) + handleTCPProxy(g.ctx, tlsConn, forwardConfig) + return + } else if forwardConfig.Mode == ForwardModePing { + handlePing(g.ctx, tlsConn, reader) return } } @@ -454,6 +519,10 @@ func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (* return config, nil + case "PING": + config.Mode = ForwardModePing + return config, nil + default: return nil, fmt.Errorf("invalid forward command: %s", cmd) } From 2dbb176e4a45e9b6ebfa755ff2fee7b6aff5c4ad Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 2 Sep 2025 23:54:55 +0800 Subject: [PATCH 13/31] feat: added systemd support --- packages/cmd/network.go | 105 +++++++++++++++++++++---- packages/gateway-v2/constants.go | 4 +- packages/gateway-v2/systemd.go | 128 +++++++++++++++++++++++++++++++ 3 files changed, 219 insertions(+), 18 deletions(-) create mode 100644 packages/gateway-v2/systemd.go diff --git a/packages/cmd/network.go b/packages/cmd/network.go index 2bf4d063..cbd43ca4 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/signal" + "runtime" "sync/atomic" "syscall" "time" @@ -24,9 +25,12 @@ var networkCmd = &cobra.Command{ } var networkProxyCmd = &cobra.Command{ - Use: "proxy", - Short: "Run the Infisical proxy component", - Long: "Run the Infisical proxy component", + Use: "proxy", + Short: "Run the Infisical proxy component", + Long: "Run the Infisical proxy component", + Example: "infisical network proxy --type=instance --ip= --name= --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { proxyName, err := cmd.Flags().GetString("name") @@ -135,19 +139,13 @@ var networkProxyCmd = &cobra.Command{ }, } -var networkProxyInstallCmd = &cobra.Command{ - Use: "proxy install", - Short: "Install and enable systemd service for the proxy (requires sudo)", - Long: "Install and enable systemd service for the proxy. Must be run with sudo on Linux.", - Run: func(cmd *cobra.Command, args []string) { - // TODO: Implement this - }, -} - var networkGatewayCmd = &cobra.Command{ - Use: "gateway", - Short: "Run the Infisical gateway component", - Long: "Run the Infisical gateway component", + Use: "gateway", + Short: "Run the Infisical gateway component", + Long: "Run the Infisical gateway component. Use 'network gateway install' to set up the systemd service.", + Example: "infisical network gateway --proxy-name= --name= --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { proxyName, err := util.GetCmdFlagOrEnv(cmd, "proxy-name", []string{gatewayv2.PROXY_NAME_ENV_NAME}) @@ -240,6 +238,75 @@ var networkGatewayCmd = &cobra.Command{ }, } +var networkGatewayInstallCmd = &cobra.Command{ + Use: "install", + Short: "Install and enable systemd service for the gateway (requires sudo)", + Long: "Install and enable systemd service for the gateway. Must be run with sudo on Linux.", + Example: "sudo infisical network gateway install --token= --domain= --name= --proxy-name=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + token, err := util.GetInfisicalToken(cmd) + if err != nil { + util.HandleError(err, "Unable to parse flag") + } + + if token == nil { + util.HandleError(errors.New("Token not found")) + } + + domain, err := cmd.Flags().GetString("domain") + if err != nil { + util.HandleError(err, "Unable to parse domain flag") + } + + gatewayName, err := cmd.Flags().GetString("name") + if err != nil { + util.HandleError(err, "Unable to parse name flag") + } + + proxyName, err := cmd.Flags().GetString("proxy-name") + if err != nil { + util.HandleError(err, "Unable to parse proxy-name flag") + } + + err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, proxyName) + if err != nil { + util.HandleError(err, "Unable to install systemd service") + } + }, +} + +var networkGatewayUninstallCmd = &cobra.Command{ + Use: "uninstall", + Short: "Uninstall and remove systemd service for the gateway (requires sudo)", + Long: "Uninstall and remove systemd service for the gateway. Must be run with sudo on Linux.", + Example: "sudo infisical network gateway uninstall", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + if err := gatewayv2.UninstallGatewaySystemdService(); err != nil { + util.HandleError(err, "Failed to uninstall systemd service") + } + }, +} + func init() { networkGatewayCmd.Flags().String("proxy-name", "", "The name of the proxy to connect to") networkGatewayCmd.Flags().String("name", "", "The name of the gateway") @@ -264,7 +331,13 @@ func init() { networkProxyCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") networkProxyCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") - networkProxyCmd.AddCommand(networkProxyInstallCmd) + networkGatewayInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") + networkGatewayInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") + networkGatewayInstallCmd.Flags().String("name", "", "The name of the gateway") + networkGatewayInstallCmd.Flags().String("proxy-name", "", "The name of the proxy") + + networkGatewayCmd.AddCommand(networkGatewayInstallCmd) + networkGatewayCmd.AddCommand(networkGatewayUninstallCmd) networkCmd.AddCommand(networkProxyCmd) networkCmd.AddCommand(networkGatewayCmd) diff --git a/packages/gateway-v2/constants.go b/packages/gateway-v2/constants.go index 996068d6..f746f558 100644 --- a/packages/gateway-v2/constants.go +++ b/packages/gateway-v2/constants.go @@ -6,8 +6,8 @@ const ( KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" - PROXY_NAME_ENV_NAME = "PROXY_NAME" - GATEWAY_NAME_ENV_NAME = "GATEWAY_NAME" + PROXY_NAME_ENV_NAME = "INFISICAL_PROXY_NAME" + GATEWAY_NAME_ENV_NAME = "INFISICAL_GATEWAY_NAME" PROXY_AUTH_SECRET_ENV_NAME = "PROXY_AUTH_SECRET" diff --git a/packages/gateway-v2/systemd.go b/packages/gateway-v2/systemd.go new file mode 100644 index 00000000..794509ea --- /dev/null +++ b/packages/gateway-v2/systemd.go @@ -0,0 +1,128 @@ +package gatewayv2 + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + + "github.com/rs/zerolog/log" +) + +const systemdServiceTemplate = `[Unit] +Description=Infisical Gateway Service +After=network.target + +[Service] +Type=notify +NotifyAccess=all +EnvironmentFile=/etc/infisical/gateway.conf +ExecStart=infisical network gateway +Restart=on-failure +InaccessibleDirectories=/home +PrivateTmp=yes +LimitCORE=infinity +LimitNOFILE=1000000 +LimitNPROC=60000 +LimitRTPRIO=infinity +LimitRTTIME=7000000 + +[Install] +WantedBy=multi-user.target +` + +func InstallGatewaySystemdService(token string, domain string, name string, proxyName string) error { + if runtime.GOOS != "linux" { + log.Info().Msg("Skipping systemd service installation - not on Linux") + return nil + } + + if os.Geteuid() != 0 { + log.Info().Msg("Skipping systemd service installation - not running as root/sudo") + return nil + } + + configDir := "/etc/infisical" + if err := os.MkdirAll(configDir, 0755); err != nil { + return fmt.Errorf("failed to create config directory: %v", err) + } + + configContent := fmt.Sprintf("INFISICAL_TOKEN=%s\n", token) + if domain != "" { + configContent += fmt.Sprintf("INFISICAL_API_URL=%s\n", domain) + } + + if name != "" { + configContent += fmt.Sprintf("%s=%s\n", GATEWAY_NAME_ENV_NAME, name) + } + if proxyName != "" { + configContent += fmt.Sprintf("%s=%s\n", PROXY_NAME_ENV_NAME, proxyName) + } + + configPath := filepath.Join(configDir, "gateway.conf") + if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { + return fmt.Errorf("failed to write config file: %v", err) + } + + servicePath := "/etc/systemd/system/infisical-gateway.service" + if err := os.WriteFile(servicePath, []byte(systemdServiceTemplate), 0644); err != nil { + return fmt.Errorf("failed to write systemd service file: %v", err) + } + + reloadCmd := exec.Command("systemctl", "daemon-reload") + if err := reloadCmd.Run(); err != nil { + return fmt.Errorf("failed to reload systemd: %v", err) + } + + log.Info().Msg("Successfully installed systemd service") + log.Info().Msg("To start the service, run: sudo systemctl start infisical-gateway") + log.Info().Msg("To enable the service on boot, run: sudo systemctl enable infisical-gateway") + + return nil +} + +func UninstallGatewaySystemdService() error { + if runtime.GOOS != "linux" { + log.Info().Msg("Skipping systemd service uninstallation - not on Linux") + return nil + } + + if os.Geteuid() != 0 { + log.Info().Msg("Skipping systemd service uninstallation - not running as root/sudo") + return nil + } + + // Stop the service if it's running + stopCmd := exec.Command("systemctl", "stop", "infisical-gateway") + if err := stopCmd.Run(); err != nil { + log.Warn().Msgf("Failed to stop service: %v", err) + } + + // Disable the service + disableCmd := exec.Command("systemctl", "disable", "infisical-gateway") + if err := disableCmd.Run(); err != nil { + log.Warn().Msgf("Failed to disable service: %v", err) + } + + // Remove the service file + servicePath := "/etc/systemd/system/infisical-gateway.service" + if err := os.Remove(servicePath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove systemd service file: %v", err) + } + + // Remove the configuration file + configPath := "/etc/infisical/gateway.conf" + if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove config file: %v", err) + } + + // Reload systemd to apply changes + reloadCmd := exec.Command("systemctl", "daemon-reload") + if err := reloadCmd.Run(); err != nil { + return fmt.Errorf("failed to reload systemd: %v", err) + } + + log.Info().Msg("Successfully uninstalled Infisical Gateway systemd service") + return nil +} From 085de6d98cca611a42de692350413704ea1ee563 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 01:17:58 +0800 Subject: [PATCH 14/31] misc: added proxy name validation --- packages/cmd/network.go | 7 ++++++- packages/proxy/proxy.go | 10 ++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/packages/cmd/network.go b/packages/cmd/network.go index cbd43ca4..237aff9b 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -131,7 +131,6 @@ var networkProxyCmd = &cobra.Command{ }() } - // Use the same context for the proxy server err = proxyInstance.Start(cmd.Context()) if err != nil { util.HandleError(err, "unable to start proxy instance") @@ -272,11 +271,17 @@ var networkGatewayInstallCmd = &cobra.Command{ if err != nil { util.HandleError(err, "Unable to parse name flag") } + if gatewayName == "" { + util.HandleError(errors.New("Gateway name is required")) + } proxyName, err := cmd.Flags().GetString("proxy-name") if err != nil { util.HandleError(err, "Unable to parse proxy-name flag") } + if proxyName == "" { + util.HandleError(errors.New("Proxy name is required")) + } err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, proxyName) if err != nil { diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 1b050a6d..dcd0f100 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -79,7 +79,6 @@ func (p *Proxy) SetToken(token string) { } func (p *Proxy) Start(ctx context.Context) error { - // Register proxy and get certificates from API if err := p.registerProxy(); err != nil { return fmt.Errorf("failed to register proxy: %v", err) } @@ -184,6 +183,13 @@ func (p *Proxy) setupSSHServer() error { return nil, fmt.Errorf("gateway id is required") } + // Validate that the user is authorized to connect to the current proxy + expectedKeyId := "client-" + p.config.ProxyName + if cert.KeyId != expectedKeyId { + log.Error().Msgf("Gateway '%s' certificate Key ID '%s' does not match expected '%s'", conn.User(), cert.KeyId, expectedKeyId) + return nil, fmt.Errorf("certificate Key ID does not match expected value") + } + return &ssh.Permissions{ Extensions: map[string]string{ "gateway-id": gatewayId, @@ -266,7 +272,7 @@ func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, c return fmt.Errorf("invalid certificate type: %d", cert.CertType) } - // Check if certificate is signed expected CA + // Check if certificate is signed by expected CA checker := &ssh.CertChecker{ IsUserAuthority: func(auth ssh.PublicKey) bool { return bytes.Equal(auth.Marshal(), caPubKey.Marshal()) From 99091419edcfa6e67435d9a17a4096a1bac4f141 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 01:47:46 +0800 Subject: [PATCH 15/31] misc: added proxy cert auto-renewal --- packages/proxy/proxy.go | 46 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index dcd0f100..b301b26b 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -11,6 +11,7 @@ import ( "io" "net" "sync" + "time" "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/util" @@ -93,6 +94,9 @@ func (p *Proxy) Start(ctx context.Context) error { return fmt.Errorf("failed to setup TLS server: %v", err) } + // Start certificate renewal goroutine + go p.startCertificateRenewal(ctx) + // Start SSH server go p.startSSHServer() @@ -447,3 +451,45 @@ func (p *Proxy) cleanup() { log.Info().Msg("Proxy server shutdown complete") } + +// startCertificateRenewal runs a background process to renew certificates every 24 hours +func (p *Proxy) startCertificateRenewal(ctx context.Context) { + log.Info().Msg("Starting certificate renewal goroutine") + ticker := time.NewTicker(30 * time.Second) // TODO: update this to be every 10 days + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info().Msg("Certificate renewal goroutine stopping...") + return + case <-ticker.C: + log.Info().Msg("Checking certificates for renewal...") + if err := p.renewCertificates(); err != nil { + log.Error().Msgf("Failed to renew certificates: %v", err) + } else { + log.Info().Msg("Certificates renewed successfully") + } + } + } +} + +// renewCertificates fetches new certificates and updates the server configurations +func (p *Proxy) renewCertificates() error { + // Re-register proxy to get fresh certificates + if err := p.registerProxy(); err != nil { + return fmt.Errorf("failed to register proxy: %v", err) + } + + // Update SSH server configuration + if err := p.setupSSHServer(); err != nil { + return fmt.Errorf("failed to setup SSH server: %v", err) + } + + // Update TLS server configuration + if err := p.setupTLSServer(); err != nil { + return fmt.Errorf("failed to setup TLS server: %v", err) + } + + return nil +} From 6d0a02105aec6e0ad4b97353b3ccd1723c511aa0 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 02:00:15 +0800 Subject: [PATCH 16/31] misc: updated proxy tls server handling for cert renewal --- packages/proxy/proxy.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index b301b26b..829e5dbe 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -350,7 +350,7 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { } func (p *Proxy) startTLSServer() { - listener, err := tls.Listen("tcp", ":"+p.config.TLSPort, p.tlsConfig) + listener, err := net.Listen("tcp", ":"+p.config.TLSPort) if err != nil { log.Fatal().Msgf("Failed to start TLS server: %v", err) } @@ -364,10 +364,27 @@ func (p *Proxy) startTLSServer() { log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } - go p.handleClient(conn) + go p.handleTLSClient(conn) } } +func (p *Proxy) handleTLSClient(conn net.Conn) { + defer conn.Close() + + // Perform TLS handshake using current TLS config + tlsConn := tls.Server(conn, p.tlsConfig) + defer tlsConn.Close() + + // Force TLS handshake + err := tlsConn.Handshake() + if err != nil { + log.Error().Msgf("TLS handshake failed: %v", err) + return + } + + p.handleClient(tlsConn) +} + func (p *Proxy) handleClient(clientConn net.Conn) { defer clientConn.Close() From b15233829bf7ec512a9d86f9fb99410fe4527b7b Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 02:08:23 +0800 Subject: [PATCH 17/31] misc: corrected client handling --- packages/proxy/proxy.go | 40 ++++++++++++---------------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 829e5dbe..97489bfa 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -385,35 +385,19 @@ func (p *Proxy) handleTLSClient(conn net.Conn) { p.handleClient(tlsConn) } -func (p *Proxy) handleClient(clientConn net.Conn) { - defer clientConn.Close() - +func (p *Proxy) handleClient(tlsConn *tls.Conn) { var gatewayId string + state := tlsConn.ConnectionState() - if tlsConn, ok := clientConn.(*tls.Conn); ok { - log.Debug().Msg("TLS connection detected, forcing handshake...") - err := tlsConn.Handshake() - if err != nil { - log.Error().Msgf("TLS handshake failed: %v", err) - return - } - - state := tlsConn.ConnectionState() - - if len(state.PeerCertificates) > 0 { - cert := state.PeerCertificates[0] - log.Info().Msgf("Client connected with certificate: %s", cert.Subject.CommonName) - gatewayId = cert.Subject.CommonName - } else { - log.Warn().Msg("No peer certificates found") - return - } + if len(state.PeerCertificates) > 0 { + cert := state.PeerCertificates[0] + log.Info().Msgf("Client connected with certificate: %s", cert.Subject.CommonName) + gatewayId = cert.Subject.CommonName } else { - log.Error().Msgf("Not a TLS connection, connection type: %T", clientConn) + log.Warn().Msg("No peer certificates found") return } - // TODO: extract these from the certificate targetHost := "gateway" targetPort := uint32(22) @@ -424,7 +408,7 @@ func (p *Proxy) handleClient(clientConn net.Conn) { if !exists { log.Warn().Msgf("Gateway '%s' not connected", gatewayId) - clientConn.Write([]byte("ERROR: Gateway not connected\n")) + tlsConn.Write([]byte("ERROR: Gateway not connected\n")) return } @@ -441,19 +425,19 @@ func (p *Proxy) handleClient(clientConn net.Conn) { channel, _, err := conn.OpenChannel("direct-tcpip", ssh.Marshal(&payload)) if err != nil { log.Error().Msgf("Failed to connect to agent: %v", err) - clientConn.Write([]byte("ERROR: Failed to connect to agent\n")) + tlsConn.Write([]byte("ERROR: Failed to connect to agent\n")) return } defer channel.Close() // Bidirectional forwarding go func() { - io.Copy(channel, clientConn) + io.Copy(channel, tlsConn) channel.CloseWrite() }() - io.Copy(clientConn, channel) - log.Info().Msgf("Client %s disconnected", clientConn.RemoteAddr()) + io.Copy(tlsConn, channel) + log.Info().Msgf("Client %s disconnected", tlsConn.RemoteAddr()) } func (p *Proxy) cleanup() { From 3bcf34c7ff0a8237e969fbec31f057f664ccb7aa Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 02:51:15 +0800 Subject: [PATCH 18/31] misc: addeed tls connection accept log --- packages/proxy/proxy.go | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 97489bfa..22b75fd2 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -364,6 +364,7 @@ func (p *Proxy) startTLSServer() { log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } + log.Info().Msgf("TLS connection accepted from %s", conn.RemoteAddr()) go p.handleTLSClient(conn) } } From 9ccf30bfdb5aab88f008360029b42f0db24cee33 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 03:02:02 +0800 Subject: [PATCH 19/31] misc: add connection deadline for unauthenticated requests --- packages/proxy/proxy.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 22b75fd2..3fbeb9b3 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -364,7 +364,6 @@ func (p *Proxy) startTLSServer() { log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } - log.Info().Msgf("TLS connection accepted from %s", conn.RemoteAddr()) go p.handleTLSClient(conn) } } @@ -376,13 +375,19 @@ func (p *Proxy) handleTLSClient(conn net.Conn) { tlsConn := tls.Server(conn, p.tlsConfig) defer tlsConn.Close() + // Set handshake timeout to avoid hanging on slow/malicious connections + tlsConn.SetDeadline(time.Now().Add(10 * time.Second)) + // Force TLS handshake err := tlsConn.Handshake() if err != nil { - log.Error().Msgf("TLS handshake failed: %v", err) + log.Debug().Msgf("TLS handshake failed from %s: %v", conn.RemoteAddr(), err) return } + // Clear deadline for actual data transfer + tlsConn.SetDeadline(time.Time{}) + p.handleClient(tlsConn) } From d39ef05297d27457fc9c8875724c1831548895cb Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 03:05:40 +0800 Subject: [PATCH 20/31] misc: finalized cert renewal interval to 10 days --- packages/proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 3fbeb9b3..8d1a327c 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -462,7 +462,7 @@ func (p *Proxy) cleanup() { // startCertificateRenewal runs a background process to renew certificates every 24 hours func (p *Proxy) startCertificateRenewal(ctx context.Context) { log.Info().Msg("Starting certificate renewal goroutine") - ticker := time.NewTicker(30 * time.Second) // TODO: update this to be every 10 days + ticker := time.NewTicker(10 * 24 * time.Hour) defer ticker.Stop() for { From 60655841cd5338ee9dfd66038c9d330c358408b4 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 03:23:34 +0800 Subject: [PATCH 21/31] misc: add cert renewal to gateway server --- packages/gateway-v2/gateway.go | 45 +++++++++++++++++++++++++++++++++- packages/proxy/proxy.go | 2 +- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index d7c2c61b..98ac68f0 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -159,6 +159,9 @@ func (g *Gateway) Start(ctx context.Context) error { errCh := make(chan error, 1) g.registerHeartBeat(ctx, errCh) + // Start certificate renewal goroutine + go g.startCertificateRenewal(ctx) + go func() { for { select { @@ -289,7 +292,15 @@ func (g *Gateway) registerGateway() error { g.certificates = &certResp log.Info().Msgf("Successfully registered gateway and received certificates") - // Create mTLS config once during registration + // Setup mTLS config + if err := g.setupTLSConfig(); err != nil { + return fmt.Errorf("failed to setup TLS config: %v", err) + } + + return nil +} + +func (g *Gateway) setupTLSConfig() error { serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) if serverCertBlock == nil { return fmt.Errorf("failed to decode server certificate") @@ -622,3 +633,35 @@ func (vc *virtualConnection) SetReadDeadline(t time.Time) error { func (vc *virtualConnection) SetWriteDeadline(t time.Time) error { return nil } + +// startCertificateRenewal runs a background process to renew certificates every 10 days +func (g *Gateway) startCertificateRenewal(ctx context.Context) { + log.Info().Msg("Starting gateway certificate renewal goroutine") + ticker := time.NewTicker(10 * 24 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info().Msg("Gateway certificate renewal goroutine stopping...") + return + case <-ticker.C: + log.Info().Msg("Renewing gateway certificates...") + if err := g.renewCertificates(); err != nil { + log.Error().Msgf("Failed to renew gateway certificates: %v", err) + } else { + log.Info().Msg("Gateway certificates renewed successfully") + } + } + } +} + +// renewCertificates fetches new certificates and updates the gateway configurations +func (g *Gateway) renewCertificates() error { + // Re-register gateway to get fresh certificates + if err := g.registerGateway(); err != nil { + return fmt.Errorf("failed to register gateway: %v", err) + } + + return nil +} diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 8d1a327c..7d5d6251 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -471,7 +471,7 @@ func (p *Proxy) startCertificateRenewal(ctx context.Context) { log.Info().Msg("Certificate renewal goroutine stopping...") return case <-ticker.C: - log.Info().Msg("Checking certificates for renewal...") + log.Info().Msg("Renewing certificates...") if err := p.renewCertificates(); err != nil { log.Error().Msgf("Failed to renew certificates: %v", err) } else { From 4e6ee387be5f00ca3c0d9cb59601b9682bd19d60 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 19:04:01 +0800 Subject: [PATCH 22/31] misc: used non-standard port for proxy TLS --- packages/cmd/network.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/cmd/network.go b/packages/cmd/network.go index 237aff9b..753e7d22 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -51,7 +51,7 @@ var networkProxyCmd = &cobra.Command{ proxyInstance, err := proxy.NewProxy(&proxy.ProxyConfig{ ProxyName: proxyName, SSHPort: "2222", - TLSPort: "443", + TLSPort: "8443", StaticIP: ip, Type: instanceType, }) From 8eaf2a5ff18e5c530aa4d713c3c0d433cfdb5dbf Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 19:55:22 +0800 Subject: [PATCH 23/31] misc: improved security posture of proxy server --- packages/proxy/proxy.go | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 7d5d6251..c029bf88 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -164,6 +164,12 @@ func (p *Proxy) setupSSHServer() error { // Setup SSH server config p.sshConfig = &ssh.ServerConfig{ + MaxAuthTries: 3, + AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) { + if err != nil { + log.Warn().Msgf("Auth failed for %s@%s using %s: %v", conn.User(), conn.RemoteAddr(), method, err) + } + }, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { // Check if this is an SSH certificate cert, ok := key.(*ssh.Certificate) @@ -315,7 +321,7 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { defer conn.Close() // SSH handshake - sshConn, chans, _, err := ssh.NewServerConn(conn, p.sshConfig) + sshConn, chans, reqs, err := ssh.NewServerConn(conn, p.sshConfig) if err != nil { log.Error().Msgf("SSH handshake failed: %v", err) return @@ -324,8 +330,16 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { gatewayId := sshConn.Permissions.Extensions["gateway-id"] log.Info().Msgf("SSH handshake successful for gateway: %s", gatewayId) - // Store the connection + // Store the connection (ensure only one connection per gateway) p.mu.Lock() + if existingConn, exists := p.tunnels[gatewayId]; exists { + p.mu.Unlock() + log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) + sshConn.Close() + existingConn.Close() // Also close the existing connection to force re-auth + return + } + p.tunnels[gatewayId] = sshConn p.mu.Unlock() @@ -337,14 +351,34 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { log.Info().Msgf("Gateway %s disconnected", gatewayId) }() + // Handle global requests (reject all for security) + go func() { + for req := range reqs { + log.Debug().Msgf("Rejecting global request: %s from gateway %s", req.Type, gatewayId) + if req.WantReply { + req.Reply(false, nil) + } + } + }() + + // Handle channel requests for newChannel := range chans { switch newChannel.ChannelType() { case "session": + log.Debug().Msgf("Rejecting session channel from gateway %s", gatewayId) newChannel.Reject(ssh.Prohibited, "no shell access") case "x11": + log.Debug().Msgf("Rejecting X11 forwarding from gateway %s", gatewayId) newChannel.Reject(ssh.Prohibited, "no X11 forwarding") case "auth-agent": + log.Debug().Msgf("Rejecting auth-agent forwarding from gateway %s", gatewayId) newChannel.Reject(ssh.Prohibited, "no agent forwarding") + case "forwarded-tcpip": + log.Debug().Msgf("Rejecting forwarded-tcpip from gateway %s", gatewayId) + newChannel.Reject(ssh.Prohibited, "no port forwarding") + default: + log.Warn().Msgf("Rejecting unknown channel type '%s' from gateway %s", newChannel.ChannelType(), gatewayId) + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") } } } From ce41396d4a278749eef85691a1ebe2003a1bfb1d Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 20:11:48 +0800 Subject: [PATCH 24/31] misc: added sending of error message when multiple gateway is detected --- packages/proxy/proxy.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index c029bf88..de59db16 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -332,11 +332,23 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { // Store the connection (ensure only one connection per gateway) p.mu.Lock() - if existingConn, exists := p.tunnels[gatewayId]; exists { + if _, exists := p.tunnels[gatewayId]; exists { p.mu.Unlock() log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) - sshConn.Close() - existingConn.Close() // Also close the existing connection to force re-auth + + // Send error message to the new connection before closing + go func() { + // Send a global request with error information + _, _, err := sshConn.SendRequest("duplicate-connection-error", false, []byte(fmt.Sprintf("Gateway '%s' already has an active connection. Only one connection per gateway is allowed.", gatewayId))) + if err != nil { + log.Debug().Msgf("Failed to send duplicate connection error message to gateway '%s': %v", gatewayId, err) + } + + // Give a moment for the message to be sent before closing + time.Sleep(1000 * time.Millisecond) + sshConn.Close() + }() + return } From c51d31f02f0a87f9e8cc1300391d6f65122bcc00 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 20:15:14 +0800 Subject: [PATCH 25/31] Revert "misc: added sending of error message when multiple gateway is detected" This reverts commit ce41396d4a278749eef85691a1ebe2003a1bfb1d. --- packages/proxy/proxy.go | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index de59db16..c029bf88 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -332,23 +332,11 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { // Store the connection (ensure only one connection per gateway) p.mu.Lock() - if _, exists := p.tunnels[gatewayId]; exists { + if existingConn, exists := p.tunnels[gatewayId]; exists { p.mu.Unlock() log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) - - // Send error message to the new connection before closing - go func() { - // Send a global request with error information - _, _, err := sshConn.SendRequest("duplicate-connection-error", false, []byte(fmt.Sprintf("Gateway '%s' already has an active connection. Only one connection per gateway is allowed.", gatewayId))) - if err != nil { - log.Debug().Msgf("Failed to send duplicate connection error message to gateway '%s': %v", gatewayId, err) - } - - // Give a moment for the message to be sent before closing - time.Sleep(1000 * time.Millisecond) - sshConn.Close() - }() - + sshConn.Close() + existingConn.Close() // Also close the existing connection to force re-auth return } From 21d61c1a1ec3714ff98e9418585da87485841904 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 20:16:11 +0800 Subject: [PATCH 26/31] misc: only close new connection for duplicate gateway --- packages/proxy/proxy.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index c029bf88..533028b9 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -332,11 +332,10 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { // Store the connection (ensure only one connection per gateway) p.mu.Lock() - if existingConn, exists := p.tunnels[gatewayId]; exists { + if _, exists := p.tunnels[gatewayId]; exists { p.mu.Unlock() log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) sshConn.Close() - existingConn.Close() // Also close the existing connection to force re-auth return } From 7d2276fd6f834aa84ad74bed9d8415ae36f9be47 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 20:47:31 +0800 Subject: [PATCH 27/31] misc: decreased tls deadline --- packages/proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 533028b9..bfbf0f5c 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -409,7 +409,7 @@ func (p *Proxy) handleTLSClient(conn net.Conn) { defer tlsConn.Close() // Set handshake timeout to avoid hanging on slow/malicious connections - tlsConn.SetDeadline(time.Now().Add(10 * time.Second)) + tlsConn.SetDeadline(time.Now().Add(5 * time.Second)) // Force TLS handshake err := tlsConn.Handshake() From e5a426d1680755c279c4b30b5de32191bf22566a Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 21:45:29 +0800 Subject: [PATCH 28/31] misc: addressed greptile --- packages/gateway-v2/constants.go | 2 +- packages/gateway-v2/gateway.go | 23 ++++++----------------- packages/proxy/proxy.go | 13 +------------ 3 files changed, 8 insertions(+), 30 deletions(-) diff --git a/packages/gateway-v2/constants.go b/packages/gateway-v2/constants.go index f746f558..de54cd6f 100644 --- a/packages/gateway-v2/constants.go +++ b/packages/gateway-v2/constants.go @@ -9,7 +9,7 @@ const ( PROXY_NAME_ENV_NAME = "INFISICAL_PROXY_NAME" GATEWAY_NAME_ENV_NAME = "INFISICAL_GATEWAY_NAME" - PROXY_AUTH_SECRET_ENV_NAME = "PROXY_AUTH_SECRET" + PROXY_AUTH_SECRET_ENV_NAME = "INFISICAL_PROXY_AUTH_SECRET" INFISICAL_HTTP_PROXY_ACTION_HEADER = "x-infisical-action" ) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 98ac68f0..bbad19f9 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "context" - "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/base64" @@ -83,8 +82,6 @@ type Gateway struct { // mTLS server components tlsConfig *tls.Config - tlsCACert []byte - tlsCAKey *rsa.PrivateKey // Connection management mu sync.RWMutex @@ -364,8 +361,13 @@ func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { return nil, fmt.Errorf("failed to parse certificate: %v", err) } + sshCert, ok := cert.(*ssh.Certificate) + if !ok { + return nil, fmt.Errorf("parsed key is not an SSH certificate, got type: %T", cert) + } + // Create certificate signer - certSigner, err := ssh.NewCertSigner(cert.(*ssh.Certificate), privateKey) + certSigner, err := ssh.NewCertSigner(sshCert, privateKey) if err != nil { return nil, fmt.Errorf("failed to create certificate signer: %v", err) } @@ -432,19 +434,6 @@ func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string } func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { - var req struct { - Host string - Port uint32 - OriginHost string - OriginPort uint32 - } - - if err := ssh.Unmarshal(newChannel.ExtraData(), &req); err != nil { - log.Info().Msgf("Failed to parse channel request: %v", err) - newChannel.Reject(ssh.Prohibited, "invalid request") - return - } - channel, requests, err := newChannel.Accept() if err != nil { log.Info().Msgf("Failed to accept channel: %v", err) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index bfbf0f5c..26af7e15 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -437,9 +437,6 @@ func (p *Proxy) handleClient(tlsConn *tls.Conn) { return } - targetHost := "gateway" - targetPort := uint32(22) - // Get the SSH connection for this agent p.mu.RLock() conn, exists := p.tunnels[gatewayId] @@ -453,15 +450,7 @@ func (p *Proxy) handleClient(tlsConn *tls.Conn) { log.Info().Msgf("Routing TCP connection to gateway: %s", gatewayId) - // Open SSH channel to connect to agent's local service through the tunnel - payload := struct { - Host string - Port uint32 - _ string - _ uint32 - }{targetHost, targetPort, "", 0} - - channel, _, err := conn.OpenChannel("direct-tcpip", ssh.Marshal(&payload)) + channel, _, err := conn.OpenChannel("direct-tcpip", nil) if err != nil { log.Error().Msgf("Failed to connect to agent: %v", err) tlsConn.Write([]byte("ERROR: Failed to connect to agent\n")) From fcdc1456df2c98d294bca56349cd3c46a05af950 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 21:48:55 +0800 Subject: [PATCH 29/31] misc: removed proxy auth logging --- packages/proxy/proxy.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 26af7e15..2a2a6301 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -165,11 +165,6 @@ func (p *Proxy) setupSSHServer() error { // Setup SSH server config p.sshConfig = &ssh.ServerConfig{ MaxAuthTries: 3, - AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) { - if err != nil { - log.Warn().Msgf("Auth failed for %s@%s using %s: %v", conn.User(), conn.RemoteAddr(), method, err) - } - }, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { // Check if this is an SSH certificate cert, ok := key.(*ssh.Certificate) From fc62acd90cbdf6ac0da64cdbfdf93a5a7cc8e740 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Thu, 4 Sep 2025 04:05:30 +0800 Subject: [PATCH 30/31] misc: updated gateway logs --- packages/gateway-v2/gateway.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index bbad19f9..46553798 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -225,12 +225,12 @@ func (g *Gateway) connectAndServe() error { } // Connect to Proxy server - log.Info().Msgf("Connecting to SSH server on %s:%d...", g.certificates.ProxyIP, g.config.SSHPort) + log.Info().Msgf("Connecting to proxy server %s on %s:%d...", g.config.ProxyName, g.certificates.ProxyIP, g.config.SSHPort) client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.ProxyIP, g.config.SSHPort), sshConfig) if err != nil { return fmt.Errorf("failed to connect to SSH server: %v", err) } - log.Info().Msgf("SSH connection established for gateway") + log.Info().Msgf("Proxy connection established for gateway") g.mu.Lock() g.sshClient = client @@ -254,7 +254,7 @@ func (g *Gateway) connectAndServe() error { // Monitor for context cancellation and close SSH client go func() { <-g.ctx.Done() - log.Info().Msg("Context cancelled, closing SSH connection...") + log.Info().Msg("Context cancelled, closing proxy connection...") client.Close() }() From 144d4e7d3ed23c2a85ba120749949ae711aafe23 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 6 Sep 2025 21:54:01 +0800 Subject: [PATCH 31/31] misc: updated to use relay and connector terminologies --- packages/api/api.go | 44 +-- packages/api/model.go | 16 +- packages/cmd/connector.go | 222 +++++++++++ packages/cmd/network.go | 351 ------------------ packages/cmd/relay.go | 156 ++++++++ .../{gateway-v2 => connector}/connection.go | 12 +- .../gateway.go => connector/connector.go} | 206 +++++----- .../{gateway-v2 => connector}/constants.go | 12 +- packages/{gateway-v2 => connector}/systemd.go | 36 +- packages/{proxy/proxy.go => relay/relay.go} | 234 ++++++------ 10 files changed, 658 insertions(+), 631 deletions(-) create mode 100644 packages/cmd/connector.go delete mode 100644 packages/cmd/network.go create mode 100644 packages/cmd/relay.go rename packages/{gateway-v2 => connector}/connection.go (95%) rename packages/{gateway-v2/gateway.go => connector/connector.go} (70%) rename packages/{gateway-v2 => connector}/constants.go (55%) rename packages/{gateway-v2 => connector}/systemd.go (73%) rename packages/{proxy/proxy.go => relay/relay.go} (60%) diff --git a/packages/api/api.go b/packages/api/api.go index e20e6daf..16ede69c 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -39,10 +39,10 @@ const ( operationCallRegisterGatewayIdentityV1 = "CallRegisterGatewayIdentityV1" operationCallExchangeRelayCertV1 = "CallExchangeRelayCertV1" operationCallGatewayHeartBeatV1 = "CallGatewayHeartBeatV1" - operationCallGatewayHeartBeatV2 = "CallGatewayHeartBeatV2" + operationCallConnectorHeartBeat = "CallConnectorHeartBeat" operationCallBootstrapInstance = "CallBootstrapInstance" - operationCallRegisterInstanceProxy = "CallRegisterInstanceProxy" - operationCallRegisterOrgProxy = "CallRegisterOrgProxy" + operationCallRegisterInstanceRelay = "CallRegisterInstanceRelay" + operationCallRegisterOrgRelay = "CallRegisterOrgRelay" operationCallRegisterGateway = "CallRegisterGateway" ) @@ -656,18 +656,18 @@ func CallGatewayHeartBeatV1(httpClient *resty.Client) error { return nil } -func CallGatewayHeartBeatV2(httpClient *resty.Client) error { +func CallConnectorHeartBeat(httpClient *resty.Client) error { response, err := httpClient. R(). SetHeader("User-Agent", USER_AGENT). - Post(fmt.Sprintf("%v/v2/gateways/heartbeat", config.INFISICAL_URL)) + Post(fmt.Sprintf("%v/v1/connectors/heartbeat", config.INFISICAL_URL)) if err != nil { - return NewGenericRequestError(operationCallGatewayHeartBeatV2, err) + return NewGenericRequestError(operationCallConnectorHeartBeat, err) } if response.IsError() { - return NewAPIErrorWithResponse(operationCallGatewayHeartBeatV2, response, nil) + return NewAPIErrorWithResponse(operationCallConnectorHeartBeat, response, nil) } return nil @@ -693,61 +693,61 @@ func CallBootstrapInstance(httpClient *resty.Client, request BootstrapInstanceRe return resBody, nil } -func CallRegisterInstanceProxy(httpClient *resty.Client, request RegisterProxyRequest) (RegisterProxyResponse, error) { - var resBody RegisterProxyResponse +func CallRegisterInstanceRelay(httpClient *resty.Client, request RegisterRelayRequest) (RegisterRelayResponse, error) { + var resBody RegisterRelayResponse response, err := httpClient. R(). SetResult(&resBody). SetHeader("User-Agent", USER_AGENT). SetBody(request). - Post(fmt.Sprintf("%v/v1/proxies/register-instance-proxy", config.INFISICAL_URL)) + Post(fmt.Sprintf("%v/v1/relays/register-instance-relay", config.INFISICAL_URL)) if err != nil { - return RegisterProxyResponse{}, NewGenericRequestError(operationCallRegisterInstanceProxy, err) + return RegisterRelayResponse{}, NewGenericRequestError(operationCallRegisterInstanceRelay, err) } if response.IsError() { - return RegisterProxyResponse{}, NewAPIErrorWithResponse(operationCallRegisterInstanceProxy, response, nil) + return RegisterRelayResponse{}, NewAPIErrorWithResponse(operationCallRegisterInstanceRelay, response, nil) } return resBody, nil } -func CallRegisterProxy(httpClient *resty.Client, request RegisterProxyRequest) (RegisterProxyResponse, error) { - var resBody RegisterProxyResponse +func CallRegisterRelay(httpClient *resty.Client, request RegisterRelayRequest) (RegisterRelayResponse, error) { + var resBody RegisterRelayResponse response, err := httpClient. R(). SetResult(&resBody). SetHeader("User-Agent", USER_AGENT). SetBody(request). - Post(fmt.Sprintf("%v/v1/proxies/register-org-proxy", config.INFISICAL_URL)) + Post(fmt.Sprintf("%v/v1/relays/register-org-relay", config.INFISICAL_URL)) if err != nil { - return RegisterProxyResponse{}, NewGenericRequestError(operationCallRegisterOrgProxy, err) + return RegisterRelayResponse{}, NewGenericRequestError(operationCallRegisterOrgRelay, err) } if response.IsError() { - return RegisterProxyResponse{}, NewAPIErrorWithResponse(operationCallRegisterOrgProxy, response, nil) + return RegisterRelayResponse{}, NewAPIErrorWithResponse(operationCallRegisterOrgRelay, response, nil) } return resBody, nil } -func CallRegisterGateway(httpClient *resty.Client, request RegisterGatewayRequest) (RegisterGatewayResponse, error) { - var resBody RegisterGatewayResponse +func CallRegisterConnector(httpClient *resty.Client, request RegisterConnectorRequest) (RegisterConnectorResponse, error) { + var resBody RegisterConnectorResponse response, err := httpClient. R(). SetResult(&resBody). SetHeader("User-Agent", USER_AGENT). SetBody(request). - Post(fmt.Sprintf("%v/v2/gateways", config.INFISICAL_URL)) + Post(fmt.Sprintf("%v/v1/connectors", config.INFISICAL_URL)) if err != nil { - return RegisterGatewayResponse{}, NewGenericRequestError(operationCallRegisterGateway, err) + return RegisterConnectorResponse{}, NewGenericRequestError(operationCallRegisterGateway, err) } if response.IsError() { - return RegisterGatewayResponse{}, NewAPIErrorWithResponse(operationCallRegisterGateway, response, nil) + return RegisterConnectorResponse{}, NewAPIErrorWithResponse(operationCallRegisterGateway, response, nil) } return resBody, nil diff --git a/packages/api/model.go b/packages/api/model.go index c436d117..8d85d453 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -704,12 +704,12 @@ type BootstrapUser struct { SuperAdmin bool `json:"superAdmin"` } -type RegisterProxyRequest struct { +type RegisterRelayRequest struct { IP string `json:"ip"` Name string `json:"name"` } -type RegisterProxyResponse struct { +type RegisterRelayResponse struct { PKI struct { ServerCertificate string `json:"serverCertificate"` ServerPrivateKey string `json:"serverPrivateKey"` @@ -722,15 +722,15 @@ type RegisterProxyResponse struct { } `json:"ssh"` } -type RegisterGatewayRequest struct { - ProxyName string `json:"proxyName"` +type RegisterConnectorRequest struct { + RelayName string `json:"relayName"` Name string `json:"name"` } -type RegisterGatewayResponse struct { - GatewayID string `json:"gatewayId"` - ProxyIP string `json:"proxyIp"` - PKI struct { +type RegisterConnectorResponse struct { + ConnectorID string `json:"connectorId"` + RelayIP string `json:"relayIp"` + PKI struct { ServerCertificate string `json:"serverCertificate"` ServerPrivateKey string `json:"serverPrivateKey"` ClientCertificateChain string `json:"clientCertificateChain"` diff --git a/packages/cmd/connector.go b/packages/cmd/connector.go new file mode 100644 index 00000000..6f2a211d --- /dev/null +++ b/packages/cmd/connector.go @@ -0,0 +1,222 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "runtime" + "sync/atomic" + "syscall" + "time" + + "github.com/Infisical/infisical-merge/packages/connector" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +var connectorCmd = &cobra.Command{ + Use: "connector", + Short: "Connector-related commands", + Long: "Connector-related commands for Infisical", +} + +var connectorStartCmd = &cobra.Command{ + Use: "start", + Short: "Start the Infisical connector component", + Long: "Start the Infisical connector component. Use 'connector install' to set up the systemd service.", + Example: "infisical connector start --relay=us-west-1 --name=my-connector --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + + relayName, err := util.GetCmdFlagOrEnv(cmd, "relay", []string{connector.RELAY_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get relay flag or %s env", connector.RELAY_NAME_ENV_NAME)) + } + + connectorName, err := util.GetCmdFlagOrEnv(cmd, "name", []string{connector.CONNECTOR_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", connector.CONNECTOR_NAME_ENV_NAME)) + } + + connectorInstance, err := connector.NewConnector(&connector.ConnectorConfig{ + Name: connectorName, + RelayName: relayName, + ReconnectDelay: 10 * time.Second, + }) + + if err != nil { + util.HandleError(err, "unable to create connector instance") + } + + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + connectorInstance.SetToken(accessToken.Load().(string)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down connector...") + cancelCmd() + cancelSdk() + + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + connectorInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + + err = connectorInstance.Start(ctx) + if err != nil { + util.HandleError(err, "unable to start connector instance") + } + + }, +} + +var connectorInstallCmd = &cobra.Command{ + Use: "install", + Short: "Install and enable systemd service for the connector (requires sudo)", + Long: "Install and enable systemd service for the connector. Must be run with sudo on Linux.", + Example: "sudo infisical connector install --token= --domain= --name= --relay=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + token, err := util.GetInfisicalToken(cmd) + if err != nil { + util.HandleError(err, "Unable to parse flag") + } + + if token == nil { + util.HandleError(errors.New("Token not found")) + } + + domain, err := cmd.Flags().GetString("domain") + if err != nil { + util.HandleError(err, "Unable to parse domain flag") + } + + connectorName, err := cmd.Flags().GetString("name") + if err != nil { + util.HandleError(err, "Unable to parse name flag") + } + if connectorName == "" { + util.HandleError(errors.New("Connector name is required")) + } + + relayName, err := cmd.Flags().GetString("relay") + if err != nil { + util.HandleError(err, "Unable to parse relay flag") + } + if relayName == "" { + util.HandleError(errors.New("Relay name is required")) + } + + err = connector.InstallConnectorSystemdService(token.Token, domain, connectorName, relayName) + if err != nil { + util.HandleError(err, "Unable to install systemd service") + } + }, +} + +var connectorUninstallCmd = &cobra.Command{ + Use: "uninstall", + Short: "Uninstall and remove systemd service for the connector (requires sudo)", + Long: "Uninstall and remove systemd service for the connector. Must be run with sudo on Linux.", + Example: "sudo infisical connector uninstall", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + if err := connector.UninstallConnectorSystemdService(); err != nil { + util.HandleError(err, "Failed to uninstall systemd service") + } + }, +} + +func init() { + connectorStartCmd.Flags().String("relay", "", "The name of the relay to connect to") + connectorStartCmd.Flags().String("name", "", "The name of the connector") + connectorStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + connectorStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + connectorStartCmd.Flags().String("client-id", "", "client id for universal auth") + connectorStartCmd.Flags().String("client-secret", "", "client secret for universal auth") + connectorStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + connectorStartCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + connectorStartCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + connectorStartCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + connectorInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") + connectorInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") + connectorInstallCmd.Flags().String("name", "", "The name of the connector") + connectorInstallCmd.Flags().String("relay", "", "The name of the relay") + + connectorCmd.AddCommand(connectorStartCmd) + connectorCmd.AddCommand(connectorInstallCmd) + connectorCmd.AddCommand(connectorUninstallCmd) + + rootCmd.AddCommand(connectorCmd) +} diff --git a/packages/cmd/network.go b/packages/cmd/network.go deleted file mode 100644 index 753e7d22..00000000 --- a/packages/cmd/network.go +++ /dev/null @@ -1,351 +0,0 @@ -package cmd - -import ( - "context" - "errors" - "fmt" - "os" - "os/signal" - "runtime" - "sync/atomic" - "syscall" - "time" - - gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" - "github.com/Infisical/infisical-merge/packages/proxy" - "github.com/Infisical/infisical-merge/packages/util" - "github.com/rs/zerolog/log" - "github.com/spf13/cobra" -) - -var networkCmd = &cobra.Command{ - Use: "network", - Short: "Network-related commands", - Long: "Network-related commands for Infisical", -} - -var networkProxyCmd = &cobra.Command{ - Use: "proxy", - Short: "Run the Infisical proxy component", - Long: "Run the Infisical proxy component", - Example: "infisical network proxy --type=instance --ip= --name= --token=", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - - proxyName, err := cmd.Flags().GetString("name") - if err != nil || proxyName == "" { - util.HandleError(err, "unable to get name flag") - } - - ip, err := cmd.Flags().GetString("ip") - if err != nil || ip == "" { - util.HandleError(err, "unable to get ip flag") - } - - instanceType, err := cmd.Flags().GetString("type") - if err != nil { - util.HandleError(err, "unable to get type flag") - } - - proxyInstance, err := proxy.NewProxy(&proxy.ProxyConfig{ - ProxyName: proxyName, - SSHPort: "2222", - TLSPort: "8443", - StaticIP: ip, - Type: instanceType, - }) - - if err != nil { - util.HandleError(err, "unable to create proxy instance") - } - - if instanceType == "instance" { - proxyAuthSecret := os.Getenv(gatewayv2.PROXY_AUTH_SECRET_ENV_NAME) - if proxyAuthSecret == "" { - util.HandleError(fmt.Errorf("%s is not set", gatewayv2.PROXY_AUTH_SECRET_ENV_NAME), "unable to get proxy auth secret") - } - - proxyInstance.SetToken(proxyAuthSecret) - } else { - infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) - if err != nil { - util.HandleError(err, "unable to get infisical client") - } - defer cancelSdk() - - var accessToken atomic.Value - accessToken.Store(infisicalClient.Auth().GetAccessToken()) - - if accessToken.Load().(string) == "" { - util.HandleError(errors.New("no access token found")) - } - - proxyInstance.SetToken(accessToken.Load().(string)) - - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - ctx, cancelCmd := context.WithCancel(cmd.Context()) - defer cancelCmd() - - go func() { - <-sigCh - log.Info().Msg("Received shutdown signal, shutting down proxy...") - cancelCmd() - cancelSdk() - - // Give graceful shutdown 10 seconds, then force exit on second signal - select { - case <-sigCh: - log.Warn().Msg("Second signal received, force exit triggered") - os.Exit(1) - case <-time.After(10 * time.Second): - log.Info().Msg("Graceful shutdown completed") - os.Exit(0) - } - }() - - // Token refresh goroutine - runs every 10 seconds - go func() { - tokenRefreshTicker := time.NewTicker(10 * time.Second) - defer tokenRefreshTicker.Stop() - - for { - select { - case <-tokenRefreshTicker.C: - if ctx.Err() != nil { - return - } - - newToken := infisicalClient.Auth().GetAccessToken() - if newToken != "" && newToken != accessToken.Load().(string) { - accessToken.Store(newToken) - proxyInstance.SetToken(newToken) - } - - case <-ctx.Done(): - return - } - } - }() - } - - err = proxyInstance.Start(cmd.Context()) - if err != nil { - util.HandleError(err, "unable to start proxy instance") - } - }, -} - -var networkGatewayCmd = &cobra.Command{ - Use: "gateway", - Short: "Run the Infisical gateway component", - Long: "Run the Infisical gateway component. Use 'network gateway install' to set up the systemd service.", - Example: "infisical network gateway --proxy-name= --name= --token=", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - - proxyName, err := util.GetCmdFlagOrEnv(cmd, "proxy-name", []string{gatewayv2.PROXY_NAME_ENV_NAME}) - if err != nil { - util.HandleError(err, fmt.Sprintf("unable to get proxy-name flag or %s env", gatewayv2.PROXY_NAME_ENV_NAME)) - } - - gatewayName, err := util.GetCmdFlagOrEnv(cmd, "name", []string{gatewayv2.GATEWAY_NAME_ENV_NAME}) - if err != nil { - util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", gatewayv2.GATEWAY_NAME_ENV_NAME)) - } - - gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ - Name: gatewayName, - ProxyName: proxyName, - ReconnectDelay: 10 * time.Second, - }) - - if err != nil { - util.HandleError(err, "unable to create gateway instance") - } - - infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) - if err != nil { - util.HandleError(err, "unable to get infisical client") - } - defer cancelSdk() - - var accessToken atomic.Value - accessToken.Store(infisicalClient.Auth().GetAccessToken()) - - if accessToken.Load().(string) == "" { - util.HandleError(errors.New("no access token found")) - } - - gatewayInstance.SetToken(accessToken.Load().(string)) - - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - ctx, cancelCmd := context.WithCancel(cmd.Context()) - defer cancelCmd() - - go func() { - <-sigCh - log.Info().Msg("Received shutdown signal, shutting down gateway...") - cancelCmd() - cancelSdk() - - // Give graceful shutdown 10 seconds, then force exit on second signal - select { - case <-sigCh: - log.Warn().Msg("Second signal received, force exit triggered") - os.Exit(1) - case <-time.After(10 * time.Second): - log.Info().Msg("Graceful shutdown completed") - os.Exit(0) - } - }() - - // Token refresh goroutine - runs every 10 seconds - go func() { - tokenRefreshTicker := time.NewTicker(10 * time.Second) - defer tokenRefreshTicker.Stop() - - for { - select { - case <-tokenRefreshTicker.C: - if ctx.Err() != nil { - return - } - - newToken := infisicalClient.Auth().GetAccessToken() - if newToken != "" && newToken != accessToken.Load().(string) { - accessToken.Store(newToken) - gatewayInstance.SetToken(newToken) - } - - case <-ctx.Done(): - return - } - } - }() - - err = gatewayInstance.Start(ctx) - if err != nil { - util.HandleError(err, "unable to start gateway instance") - } - - }, -} - -var networkGatewayInstallCmd = &cobra.Command{ - Use: "install", - Short: "Install and enable systemd service for the gateway (requires sudo)", - Long: "Install and enable systemd service for the gateway. Must be run with sudo on Linux.", - Example: "sudo infisical network gateway install --token= --domain= --name= --proxy-name=", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - if runtime.GOOS != "linux" { - util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) - } - - if os.Geteuid() != 0 { - util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) - } - - token, err := util.GetInfisicalToken(cmd) - if err != nil { - util.HandleError(err, "Unable to parse flag") - } - - if token == nil { - util.HandleError(errors.New("Token not found")) - } - - domain, err := cmd.Flags().GetString("domain") - if err != nil { - util.HandleError(err, "Unable to parse domain flag") - } - - gatewayName, err := cmd.Flags().GetString("name") - if err != nil { - util.HandleError(err, "Unable to parse name flag") - } - if gatewayName == "" { - util.HandleError(errors.New("Gateway name is required")) - } - - proxyName, err := cmd.Flags().GetString("proxy-name") - if err != nil { - util.HandleError(err, "Unable to parse proxy-name flag") - } - if proxyName == "" { - util.HandleError(errors.New("Proxy name is required")) - } - - err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, proxyName) - if err != nil { - util.HandleError(err, "Unable to install systemd service") - } - }, -} - -var networkGatewayUninstallCmd = &cobra.Command{ - Use: "uninstall", - Short: "Uninstall and remove systemd service for the gateway (requires sudo)", - Long: "Uninstall and remove systemd service for the gateway. Must be run with sudo on Linux.", - Example: "sudo infisical network gateway uninstall", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - if runtime.GOOS != "linux" { - util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) - } - - if os.Geteuid() != 0 { - util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) - } - - if err := gatewayv2.UninstallGatewaySystemdService(); err != nil { - util.HandleError(err, "Failed to uninstall systemd service") - } - }, -} - -func init() { - networkGatewayCmd.Flags().String("proxy-name", "", "The name of the proxy to connect to") - networkGatewayCmd.Flags().String("name", "", "The name of the gateway") - networkGatewayCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") - networkGatewayCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") - networkGatewayCmd.Flags().String("client-id", "", "client id for universal auth") - networkGatewayCmd.Flags().String("client-secret", "", "client secret for universal auth") - networkGatewayCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") - networkGatewayCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") - networkGatewayCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") - networkGatewayCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") - - networkProxyCmd.Flags().String("type", "org", "The type of proxy to run. Must be either 'instance' or 'org'") - networkProxyCmd.Flags().String("ip", "", "The IP address of the proxy") - networkProxyCmd.Flags().String("name", "", "The name of the proxy") - networkProxyCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") - networkProxyCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") - networkProxyCmd.Flags().String("client-id", "", "client id for universal auth") - networkProxyCmd.Flags().String("client-secret", "", "client secret for universal auth") - networkProxyCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") - networkProxyCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") - networkProxyCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") - networkProxyCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") - - networkGatewayInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") - networkGatewayInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") - networkGatewayInstallCmd.Flags().String("name", "", "The name of the gateway") - networkGatewayInstallCmd.Flags().String("proxy-name", "", "The name of the proxy") - - networkGatewayCmd.AddCommand(networkGatewayInstallCmd) - networkGatewayCmd.AddCommand(networkGatewayUninstallCmd) - - networkCmd.AddCommand(networkProxyCmd) - networkCmd.AddCommand(networkGatewayCmd) - - rootCmd.AddCommand(networkCmd) -} diff --git a/packages/cmd/relay.go b/packages/cmd/relay.go new file mode 100644 index 00000000..7c88fea9 --- /dev/null +++ b/packages/cmd/relay.go @@ -0,0 +1,156 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "sync/atomic" + "syscall" + "time" + + "github.com/Infisical/infisical-merge/packages/connector" + "github.com/Infisical/infisical-merge/packages/relay" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +var relayCmd = &cobra.Command{ + Use: "relay", + Short: "Relay-related commands", + Long: "Relay-related commands for Infisical", +} + +var relayStartCmd = &cobra.Command{ + Use: "start", + Short: "Start the Infisical relay component", + Long: "Start the Infisical relay component", + Example: "infisical relay start --type=instance --ip= --name= --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + + relayName, err := cmd.Flags().GetString("name") + if err != nil || relayName == "" { + util.HandleError(err, "unable to get name flag") + } + + ip, err := cmd.Flags().GetString("ip") + if err != nil || ip == "" { + util.HandleError(err, "unable to get ip flag") + } + + instanceType, err := cmd.Flags().GetString("type") + if err != nil { + util.HandleError(err, "unable to get type flag") + } + + relayInstance, err := relay.NewRelay(&relay.RelayConfig{ + RelayName: relayName, + SSHPort: "2222", + TLSPort: "8443", + StaticIP: ip, + Type: instanceType, + }) + + if err != nil { + util.HandleError(err, "unable to create relay instance") + } + + if instanceType == "instance" { + relayAuthSecret := os.Getenv(connector.RELAY_AUTH_SECRET_ENV_NAME) + if relayAuthSecret == "" { + util.HandleError(fmt.Errorf("%s is not set", connector.RELAY_AUTH_SECRET_ENV_NAME), "unable to get relay auth secret") + } + + relayInstance.SetToken(relayAuthSecret) + } else { + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + relayInstance.SetToken(accessToken.Load().(string)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down relay...") + cancelCmd() + cancelSdk() + + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + relayInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + } + + err = relayInstance.Start(cmd.Context()) + if err != nil { + util.HandleError(err, "unable to start relay instance") + } + }, +} + +func init() { + relayStartCmd.Flags().String("type", "org", "The type of relay to run. Must be either 'instance' or 'org'") + relayStartCmd.Flags().String("ip", "", "The IP address of the relay") + relayStartCmd.Flags().String("name", "", "The name of the relay") + relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + relayStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + relayStartCmd.Flags().String("client-id", "", "client id for universal auth") + relayStartCmd.Flags().String("client-secret", "", "client secret for universal auth") + relayStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + relayStartCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + relayStartCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + relayStartCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + relayCmd.AddCommand(relayStartCmd) + + rootCmd.AddCommand(relayCmd) +} diff --git a/packages/gateway-v2/connection.go b/packages/connector/connection.go similarity index 95% rename from packages/gateway-v2/connection.go rename to packages/connector/connection.go index 141681f8..b662625e 100644 --- a/packages/gateway-v2/connection.go +++ b/packages/connector/connection.go @@ -1,4 +1,4 @@ -package gatewayv2 +package connector import ( "bufio" @@ -18,7 +18,7 @@ import ( ) func buildHttpInternalServerError(message string) string { - return fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\n\r\n{\"message\": \"gateway: %s\"}", message) + return fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\n\r\n{\"message\": \"connector: %s\"}", message) } func handleHTTPProxy(ctx context.Context, conn *tls.Conn, reader *bufio.Reader, forwardConfig *ForwardConfig) error { @@ -39,7 +39,7 @@ func handleHTTPProxy(ctx context.Context, conn *tls.Conn, reader *bufio.Reader, caCertPool := x509.NewCertPool() if caCertPool.AppendCertsFromPEM(caCert) { tlsConfig.RootCAs = caCertPool - log.Info().Msg("Using provided CA certificate from gateway client") + log.Info().Msg("Using provided CA certificate from connector client") } else { log.Error().Msg("Failed to parse provided CA certificate") } @@ -98,15 +98,15 @@ func handleHTTPProxy(ctx context.Context, conn *tls.Conn, reader *bufio.Reader, // Only platform actor can perform privileged actions if actionHeader != "" && forwardConfig.ActorType == ActorTypePlatform { - if actionHeader == HttpProxyActionInjectGatewayK8sServiceAccountToken { + if actionHeader == HttpProxyActionInjectConnectorK8sServiceAccountToken { token, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH) if err != nil { conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa auth token"))) continue // Continue to next request instead of returning } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(token))) - log.Info().Msgf("Injected gateway k8s SA auth token in request to %s", targetURL) - } else if actionHeader == HttpProxyActionUseGatewayK8sServiceAccount { + log.Info().Msgf("Injected connector k8s SA auth token in request to %s", targetURL) + } else if actionHeader == HttpProxyActionUseConnectorK8sServiceAccount { // will work without a target URL set // set the ca cert to the pod's k8s service account ca cert: caCert, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH) diff --git a/packages/gateway-v2/gateway.go b/packages/connector/connector.go similarity index 70% rename from packages/gateway-v2/gateway.go rename to packages/connector/connector.go index 46553798..a4181d97 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/connector/connector.go @@ -1,4 +1,4 @@ -package gatewayv2 +package connector import ( "bufio" @@ -39,8 +39,8 @@ const ( ActorTypeUser ActorType = "user" ) -const GATEWAY_ROUTING_INFO_OID = "1.3.6.1.4.1.12345.100.1" -const GATEWAY_ACTOR_OID = "1.3.6.1.4.1.12345.100.2" +const CONNECTOR_ROUTING_INFO_OID = "1.3.6.1.4.1.12345.100.1" +const CONNECTOR_ACTOR_OID = "1.3.6.1.4.1.12345.100.2" // ForwardConfig contains the configuration for forwarding type ForwardConfig struct { @@ -62,23 +62,23 @@ type ActorDetails struct { Type string `json:"type"` } -type GatewayConfig struct { +type ConnectorConfig struct { Name string - ProxyName string + RelayName string IdentityToken string SSHPort int ReconnectDelay time.Duration } -type Gateway struct { - GatewayID string +type Connector struct { + ConnectorID string httpClient *resty.Client - config *GatewayConfig + config *ConnectorConfig sshClient *ssh.Client // Certificate storage - certificates *api.RegisterGatewayResponse + certificates *api.RegisterConnectorResponse // mTLS server components tlsConfig *tls.Config @@ -90,8 +90,8 @@ type Gateway struct { cancel context.CancelFunc } -// NewGateway creates a new gateway instance -func NewGateway(config *GatewayConfig) (*Gateway, error) { +// NewConnector creates a new connector instance +func NewConnector(config *ConnectorConfig) (*Connector, error) { httpClient, err := util.GetRestyClientWithCustomHeaders() if err != nil { return nil, fmt.Errorf("unable to get client with custom headers [err=%v]", err) @@ -106,7 +106,7 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { config.SSHPort = 2222 } - return &Gateway{ + return &Connector{ httpClient: httpClient, config: config, ctx: ctx, @@ -114,9 +114,9 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { }, nil } -func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { +func (c *Connector) registerHeartBeat(ctx context.Context, errCh chan error) { sendHeartbeat := func() { - if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil { + if err := api.CallConnectorHeartBeat(c.httpClient); err != nil { log.Warn().Msgf("Heartbeat failed: %v", err) select { case errCh <- err: @@ -124,7 +124,7 @@ func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { log.Warn().Msg("Error channel full, skipping heartbeat error report") } } else { - log.Info().Msg("Gateway is reachable by Infisical") + log.Info().Msg("Connector is reachable by Infisical") } } @@ -150,14 +150,14 @@ func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { }() } -func (g *Gateway) Start(ctx context.Context) error { - log.Info().Msgf("Starting gateway") +func (c *Connector) Start(ctx context.Context) error { + log.Info().Msgf("Starting connector") errCh := make(chan error, 1) - g.registerHeartBeat(ctx, errCh) + c.registerHeartBeat(ctx, errCh) // Start certificate renewal goroutine - go g.startCertificateRenewal(ctx) + go c.startCertificateRenewal(ctx) go func() { for { @@ -173,15 +173,15 @@ func (g *Gateway) Start(ctx context.Context) error { for { select { case <-ctx.Done(): - log.Info().Msgf("Gateway stopped by context cancellation") + log.Info().Msgf("Connector stopped by context cancellation") return nil default: - if err := g.connectAndServe(); err != nil { - log.Error().Msgf("Connection failed: %v, retrying in %v...", err, g.config.ReconnectDelay) + if err := c.connectAndServe(); err != nil { + log.Error().Msgf("Connection failed: %v, retrying in %v...", err, c.config.ReconnectDelay) select { case <-ctx.Done(): return ctx.Err() - case <-time.After(g.config.ReconnectDelay): + case <-time.After(c.config.ReconnectDelay): continue } } @@ -197,51 +197,51 @@ func (g *Gateway) Start(ctx context.Context) error { } } -func (g *Gateway) SetToken(token string) { - g.httpClient.SetAuthToken(token) +func (c *Connector) SetToken(token string) { + c.httpClient.SetAuthToken(token) } -func (g *Gateway) Stop() { - g.cancel() +func (c *Connector) Stop() { + c.cancel() - g.mu.Lock() - if g.sshClient != nil { - g.sshClient.Close() - g.sshClient = nil + c.mu.Lock() + if c.sshClient != nil { + c.sshClient.Close() + c.sshClient = nil } - g.isConnected = false - g.mu.Unlock() + c.isConnected = false + c.mu.Unlock() } -func (g *Gateway) connectAndServe() error { - if err := g.registerGateway(); err != nil { - return fmt.Errorf("failed to register gateway: %v", err) +func (c *Connector) connectAndServe() error { + if err := c.registerConnector(); err != nil { + return fmt.Errorf("failed to register connector: %v", err) } // Create SSH client config - sshConfig, err := g.createSSHConfig() + sshConfig, err := c.createSSHConfig() if err != nil { return fmt.Errorf("failed to create SSH config: %v", err) } - // Connect to Proxy server - log.Info().Msgf("Connecting to proxy server %s on %s:%d...", g.config.ProxyName, g.certificates.ProxyIP, g.config.SSHPort) - client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.ProxyIP, g.config.SSHPort), sshConfig) + // Connect to Relay server + log.Info().Msgf("Connecting to relay server %s on %s:%d...", c.config.RelayName, c.certificates.RelayIP, c.config.SSHPort) + client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", c.certificates.RelayIP, c.config.SSHPort), sshConfig) if err != nil { return fmt.Errorf("failed to connect to SSH server: %v", err) } - log.Info().Msgf("Proxy connection established for gateway") + log.Info().Msgf("Relay connection established for connector") - g.mu.Lock() - g.sshClient = client - g.isConnected = true - g.mu.Unlock() + c.mu.Lock() + c.sshClient = client + c.isConnected = true + c.mu.Unlock() defer func() { - g.mu.Lock() - g.sshClient = nil - g.isConnected = false - g.mu.Unlock() + c.mu.Lock() + c.sshClient = nil + c.isConnected = false + c.mu.Unlock() client.Close() }() @@ -253,57 +253,57 @@ func (g *Gateway) connectAndServe() error { // Monitor for context cancellation and close SSH client go func() { - <-g.ctx.Done() - log.Info().Msg("Context cancelled, closing proxy connection...") + <-c.ctx.Done() + log.Info().Msg("Context cancelled, closing relay connection...") client.Close() }() // Process incoming channels with context cancellation support for { select { - case <-g.ctx.Done(): + case <-c.ctx.Done(): log.Info().Msg("Context cancelled, stopping channel processing") - return g.ctx.Err() + return c.ctx.Err() case newChannel, ok := <-channels: if !ok { log.Info().Msg("SSH channels closed") return nil } - go g.handleIncomingChannel(newChannel) + go c.handleIncomingChannel(newChannel) } } } -func (g *Gateway) registerGateway() error { - body := api.RegisterGatewayRequest{ - ProxyName: g.config.ProxyName, - Name: g.config.Name, +func (c *Connector) registerConnector() error { + body := api.RegisterConnectorRequest{ + RelayName: c.config.RelayName, + Name: c.config.Name, } - certResp, err := api.CallRegisterGateway(g.httpClient, body) + certResp, err := api.CallRegisterConnector(c.httpClient, body) if err != nil { - return fmt.Errorf("failed to register gateway: %v", err) + return fmt.Errorf("failed to register connector: %v", err) } - g.GatewayID = certResp.GatewayID - g.certificates = &certResp - log.Info().Msgf("Successfully registered gateway and received certificates") + c.ConnectorID = certResp.ConnectorID + c.certificates = &certResp + log.Info().Msgf("Successfully registered connector and received certificates") // Setup mTLS config - if err := g.setupTLSConfig(); err != nil { + if err := c.setupTLSConfig(); err != nil { return fmt.Errorf("failed to setup TLS config: %v", err) } return nil } -func (g *Gateway) setupTLSConfig() error { - serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) +func (c *Connector) setupTLSConfig() error { + serverCertBlock, _ := pem.Decode([]byte(c.certificates.PKI.ServerCertificate)) if serverCertBlock == nil { return fmt.Errorf("failed to decode server certificate") } - serverKeyBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerPrivateKey)) + serverKeyBlock, _ := pem.Decode([]byte(c.certificates.PKI.ServerPrivateKey)) if serverKeyBlock == nil { return fmt.Errorf("failed to decode server private key") } @@ -315,7 +315,7 @@ func (g *Gateway) setupTLSConfig() error { clientCAPool := x509.NewCertPool() var chainCerts [][]byte - chainData := []byte(g.certificates.PKI.ClientCertificateChain) + chainData := []byte(c.certificates.PKI.ClientCertificateChain) for { block, rest := pem.Decode(chainData) if block == nil { @@ -334,7 +334,7 @@ func (g *Gateway) setupTLSConfig() error { clientCAPool.AddCert(cert) } - g.tlsConfig = &tls.Config{ + c.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ { Certificate: [][]byte{serverCertBlock.Bytes}, @@ -349,14 +349,14 @@ func (g *Gateway) setupTLSConfig() error { return nil } -func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { - privateKey, err := ssh.ParsePrivateKey([]byte(g.certificates.SSH.ClientPrivateKey)) +func (c *Connector) createSSHConfig() (*ssh.ClientConfig, error) { + privateKey, err := ssh.ParsePrivateKey([]byte(c.certificates.SSH.ClientPrivateKey)) if err != nil { return nil, fmt.Errorf("failed to parse SSH private key: %v", err) } // Parse certificate - cert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ClientCertificate)) + cert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(c.certificates.SSH.ClientCertificate)) if err != nil { return nil, fmt.Errorf("failed to parse certificate: %v", err) } @@ -374,11 +374,11 @@ func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { // Create SSH client config config := &ssh.ClientConfig{ - User: g.GatewayID, + User: c.ConnectorID, Auth: []ssh.AuthMethod{ ssh.PublicKeys(certSigner), }, - HostKeyCallback: g.createHostKeyCallback(), + HostKeyCallback: c.createHostKeyCallback(), Timeout: 30 * time.Second, Config: ssh.Config{ KeyExchanges: []string{ @@ -401,8 +401,8 @@ func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { return config, nil } -func (g *Gateway) createHostKeyCallback() ssh.HostKeyCallback { - caKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ServerCAPublicKey)) +func (c *Connector) createHostKeyCallback() ssh.HostKeyCallback { + caKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(c.certificates.SSH.ServerCAPublicKey)) if err != nil { return func(hostname string, remote net.Addr, key ssh.PublicKey) error { return fmt.Errorf("failed to parse CA public key: %v", err) @@ -415,11 +415,11 @@ func (g *Gateway) createHostKeyCallback() ssh.HostKeyCallback { return fmt.Errorf("host certificates required, raw host keys not allowed") } - return g.validateHostCertificate(cert, hostname, caKey) + return c.validateHostCertificate(cert, hostname, caKey) } } -func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string, caKey ssh.PublicKey) error { +func (c *Connector) validateHostCertificate(cert *ssh.Certificate, hostname string, caKey ssh.PublicKey) error { checker := &ssh.CertChecker{ IsHostAuthority: func(auth ssh.PublicKey, address string) bool { return bytes.Equal(auth.Marshal(), caKey.Marshal()) @@ -433,7 +433,7 @@ func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string return nil } -func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { +func (c *Connector) handleIncomingChannel(newChannel ssh.NewChannel) { channel, requests, err := newChannel.Accept() if err != nil { log.Info().Msgf("Failed to accept channel: %v", err) @@ -444,7 +444,7 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { go ssh.DiscardRequests(requests) // Create mTLS server configuration - tlsConfig := g.tlsConfig + tlsConfig := c.tlsConfig if tlsConfig == nil { log.Info().Msgf("TLS config not initialized, cannot create mTLS server") return @@ -468,7 +468,7 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { reader := bufio.NewReader(tlsConn) // Get the forward mode here - forwardConfig, err := g.parseForwardConfig(tlsConn, reader) + forwardConfig, err := c.parseForwardConfig(tlsConn, reader) if err != nil { log.Info().Msgf("Failed to parse forward command: %v", err) return @@ -477,21 +477,21 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { log.Info().Msgf("Forward config: %+v", forwardConfig) if forwardConfig.Mode == ForwardModeHTTP { - handleHTTPProxy(g.ctx, tlsConn, reader, forwardConfig) + handleHTTPProxy(c.ctx, tlsConn, reader, forwardConfig) return } else if forwardConfig.Mode == ForwardModeTCP { - handleTCPProxy(g.ctx, tlsConn, forwardConfig) + handleTCPProxy(c.ctx, tlsConn, forwardConfig) return } else if forwardConfig.Mode == ForwardModePing { - handlePing(g.ctx, tlsConn, reader) + handlePing(c.ctx, tlsConn, reader) return } } -func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (*ForwardConfig, error) { +func (c *Connector) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (*ForwardConfig, error) { config := &ForwardConfig{} - if err := g.parseDetailsFromCertificate(tlsConn, config); err != nil { + if err := c.parseDetailsFromCertificate(tlsConn, config); err != nil { return nil, fmt.Errorf("failed to parse routing info from certificate: %v", err) } @@ -512,7 +512,7 @@ func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (* case "FORWARD-HTTP": config.Mode = ForwardModeHTTP if args != "" { - if err := g.parseForwardHTTPParams(args, config); err != nil { + if err := c.parseForwardHTTPParams(args, config); err != nil { return nil, fmt.Errorf("failed to parse HTTP parameters: %v", err) } } @@ -529,7 +529,7 @@ func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (* } } -func (g *Gateway) parseForwardHTTPParams(params string, config *ForwardConfig) error { +func (c *Connector) parseForwardHTTPParams(params string, config *ForwardConfig) error { parts := strings.Fields(params) for _, part := range parts { @@ -553,7 +553,7 @@ func (g *Gateway) parseForwardHTTPParams(params string, config *ForwardConfig) e return nil } -func (g *Gateway) parseDetailsFromCertificate(tlsConn *tls.Conn, config *ForwardConfig) error { +func (c *Connector) parseDetailsFromCertificate(tlsConn *tls.Conn, config *ForwardConfig) error { // Get the peer certificates state := tlsConn.ConnectionState() if len(state.PeerCertificates) == 0 { @@ -564,7 +564,7 @@ func (g *Gateway) parseDetailsFromCertificate(tlsConn *tls.Conn, config *Forward for _, ext := range clientCert.Extensions { // Extract target host and port from client certificate custom extension - if ext.Id.String() == GATEWAY_ROUTING_INFO_OID { + if ext.Id.String() == CONNECTOR_ROUTING_INFO_OID { var routingInfo RoutingInfo if err := json.Unmarshal(ext.Value, &routingInfo); err != nil { return fmt.Errorf("failed to parse routing info JSON: %v", err) @@ -574,7 +574,7 @@ func (g *Gateway) parseDetailsFromCertificate(tlsConn *tls.Conn, config *Forward config.TargetPort = routingInfo.TargetPort } // Extract actor type from client certificate custom extension - if ext.Id.String() == GATEWAY_ACTOR_OID { + if ext.Id.String() == CONNECTOR_ACTOR_OID { var actorDetails ActorDetails if err := json.Unmarshal(ext.Value, &actorDetails); err != nil { return fmt.Errorf("failed to parse actor details JSON: %v", err) @@ -624,32 +624,32 @@ func (vc *virtualConnection) SetWriteDeadline(t time.Time) error { } // startCertificateRenewal runs a background process to renew certificates every 10 days -func (g *Gateway) startCertificateRenewal(ctx context.Context) { - log.Info().Msg("Starting gateway certificate renewal goroutine") +func (c *Connector) startCertificateRenewal(ctx context.Context) { + log.Info().Msg("Starting connector certificate renewal goroutine") ticker := time.NewTicker(10 * 24 * time.Hour) defer ticker.Stop() for { select { case <-ctx.Done(): - log.Info().Msg("Gateway certificate renewal goroutine stopping...") + log.Info().Msg("Connector certificate renewal goroutine stopping...") return case <-ticker.C: - log.Info().Msg("Renewing gateway certificates...") - if err := g.renewCertificates(); err != nil { - log.Error().Msgf("Failed to renew gateway certificates: %v", err) + log.Info().Msg("Renewing connector certificates...") + if err := c.renewCertificates(); err != nil { + log.Error().Msgf("Failed to renew connector certificates: %v", err) } else { - log.Info().Msg("Gateway certificates renewed successfully") + log.Info().Msg("Connector certificates renewed successfully") } } } } -// renewCertificates fetches new certificates and updates the gateway configurations -func (g *Gateway) renewCertificates() error { - // Re-register gateway to get fresh certificates - if err := g.registerGateway(); err != nil { - return fmt.Errorf("failed to register gateway: %v", err) +// renewCertificates fetches new certificates and updates the connector configurations +func (c *Connector) renewCertificates() error { + // Re-register connector to get fresh certificates + if err := c.registerConnector(); err != nil { + return fmt.Errorf("failed to register connector: %v", err) } return nil diff --git a/packages/gateway-v2/constants.go b/packages/connector/constants.go similarity index 55% rename from packages/gateway-v2/constants.go rename to packages/connector/constants.go index de54cd6f..360b201f 100644 --- a/packages/gateway-v2/constants.go +++ b/packages/connector/constants.go @@ -1,4 +1,4 @@ -package gatewayv2 +package connector const ( KUBERNETES_SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST" @@ -6,10 +6,10 @@ const ( KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" - PROXY_NAME_ENV_NAME = "INFISICAL_PROXY_NAME" - GATEWAY_NAME_ENV_NAME = "INFISICAL_GATEWAY_NAME" + RELAY_NAME_ENV_NAME = "INFISICAL_RELAY_NAME" + CONNECTOR_NAME_ENV_NAME = "INFISICAL_CONNECTOR_NAME" - PROXY_AUTH_SECRET_ENV_NAME = "INFISICAL_PROXY_AUTH_SECRET" + RELAY_AUTH_SECRET_ENV_NAME = "INFISICAL_RELAY_AUTH_SECRET" INFISICAL_HTTP_PROXY_ACTION_HEADER = "x-infisical-action" ) @@ -17,6 +17,6 @@ const ( type HttpProxyAction string const ( - HttpProxyActionInjectGatewayK8sServiceAccountToken HttpProxyAction = "inject-k8s-sa-auth-token" - HttpProxyActionUseGatewayK8sServiceAccount HttpProxyAction = "use-k8s-sa" + HttpProxyActionInjectConnectorK8sServiceAccountToken HttpProxyAction = "inject-k8s-sa-auth-token" + HttpProxyActionUseConnectorK8sServiceAccount HttpProxyAction = "use-k8s-sa" ) diff --git a/packages/gateway-v2/systemd.go b/packages/connector/systemd.go similarity index 73% rename from packages/gateway-v2/systemd.go rename to packages/connector/systemd.go index 794509ea..a0c4d67e 100644 --- a/packages/gateway-v2/systemd.go +++ b/packages/connector/systemd.go @@ -1,4 +1,4 @@ -package gatewayv2 +package connector import ( "fmt" @@ -11,14 +11,14 @@ import ( ) const systemdServiceTemplate = `[Unit] -Description=Infisical Gateway Service +Description=Infisical Connector Service After=network.target [Service] Type=notify NotifyAccess=all -EnvironmentFile=/etc/infisical/gateway.conf -ExecStart=infisical network gateway +EnvironmentFile=/etc/infisical/connector.conf +ExecStart=infisical connector start Restart=on-failure InaccessibleDirectories=/home PrivateTmp=yes @@ -32,7 +32,7 @@ LimitRTTIME=7000000 WantedBy=multi-user.target ` -func InstallGatewaySystemdService(token string, domain string, name string, proxyName string) error { +func InstallConnectorSystemdService(token string, domain string, name string, relayName string) error { if runtime.GOOS != "linux" { log.Info().Msg("Skipping systemd service installation - not on Linux") return nil @@ -54,18 +54,18 @@ func InstallGatewaySystemdService(token string, domain string, name string, prox } if name != "" { - configContent += fmt.Sprintf("%s=%s\n", GATEWAY_NAME_ENV_NAME, name) + configContent += fmt.Sprintf("%s=%s\n", CONNECTOR_NAME_ENV_NAME, name) } - if proxyName != "" { - configContent += fmt.Sprintf("%s=%s\n", PROXY_NAME_ENV_NAME, proxyName) + if relayName != "" { + configContent += fmt.Sprintf("%s=%s\n", RELAY_NAME_ENV_NAME, relayName) } - configPath := filepath.Join(configDir, "gateway.conf") + configPath := filepath.Join(configDir, "connector.conf") if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { return fmt.Errorf("failed to write config file: %v", err) } - servicePath := "/etc/systemd/system/infisical-gateway.service" + servicePath := "/etc/systemd/system/infisical-connector.service" if err := os.WriteFile(servicePath, []byte(systemdServiceTemplate), 0644); err != nil { return fmt.Errorf("failed to write systemd service file: %v", err) } @@ -76,13 +76,13 @@ func InstallGatewaySystemdService(token string, domain string, name string, prox } log.Info().Msg("Successfully installed systemd service") - log.Info().Msg("To start the service, run: sudo systemctl start infisical-gateway") - log.Info().Msg("To enable the service on boot, run: sudo systemctl enable infisical-gateway") + log.Info().Msg("To start the service, run: sudo systemctl start infisical-connector") + log.Info().Msg("To enable the service on boot, run: sudo systemctl enable infisical-connector") return nil } -func UninstallGatewaySystemdService() error { +func UninstallConnectorSystemdService() error { if runtime.GOOS != "linux" { log.Info().Msg("Skipping systemd service uninstallation - not on Linux") return nil @@ -94,25 +94,25 @@ func UninstallGatewaySystemdService() error { } // Stop the service if it's running - stopCmd := exec.Command("systemctl", "stop", "infisical-gateway") + stopCmd := exec.Command("systemctl", "stop", "infisical-connector") if err := stopCmd.Run(); err != nil { log.Warn().Msgf("Failed to stop service: %v", err) } // Disable the service - disableCmd := exec.Command("systemctl", "disable", "infisical-gateway") + disableCmd := exec.Command("systemctl", "disable", "infisical-connector") if err := disableCmd.Run(); err != nil { log.Warn().Msgf("Failed to disable service: %v", err) } // Remove the service file - servicePath := "/etc/systemd/system/infisical-gateway.service" + servicePath := "/etc/systemd/system/infisical-connector.service" if err := os.Remove(servicePath); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove systemd service file: %v", err) } // Remove the configuration file - configPath := "/etc/infisical/gateway.conf" + configPath := "/etc/infisical/connector.conf" if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) { return fmt.Errorf("failed to remove config file: %v", err) } @@ -123,6 +123,6 @@ func UninstallGatewaySystemdService() error { return fmt.Errorf("failed to reload systemd: %v", err) } - log.Info().Msg("Successfully uninstalled Infisical Gateway systemd service") + log.Info().Msg("Successfully uninstalled Infisical Connector systemd service") return nil } diff --git a/packages/proxy/proxy.go b/packages/relay/relay.go similarity index 60% rename from packages/proxy/proxy.go rename to packages/relay/relay.go index 2a2a6301..d2bb5da7 100644 --- a/packages/proxy/proxy.go +++ b/packages/relay/relay.go @@ -1,4 +1,4 @@ -package proxy +package relay import ( "bytes" @@ -20,10 +20,10 @@ import ( "golang.org/x/crypto/ssh" ) -type ProxyConfig struct { +type RelayConfig struct { // API Configuration Token string - ProxyName string + RelayName string Type string @@ -35,12 +35,12 @@ type ProxyConfig struct { StaticIP string } -type Proxy struct { +type Relay struct { httpClient *resty.Client - config *ProxyConfig + config *RelayConfig // Certificate storage - certificates *api.RegisterProxyResponse + certificates *api.RegisterRelayResponse // SSH server components sshConfig *ssh.ServerConfig @@ -51,7 +51,7 @@ type Proxy struct { tlsCACert []byte tlsCAKey *rsa.PrivateKey - // Tunnel storage (Gateway ID -> SSH connection) + // Tunnel storage (Connector ID -> SSH connection) tunnels map[string]*ssh.ServerConn mu sync.RWMutex @@ -60,7 +60,7 @@ type Proxy struct { tlsListener net.Listener } -func NewProxy(config *ProxyConfig) (*Proxy, error) { +func NewRelay(config *RelayConfig) (*Relay, error) { httpClient, err := util.GetRestyClientWithCustomHeaders() if err != nil { return nil, fmt.Errorf("unable to get client with custom headers [err=%v]", err) @@ -68,90 +68,90 @@ func NewProxy(config *ProxyConfig) (*Proxy, error) { httpClient.SetAuthToken(config.Token) - return &Proxy{ + return &Relay{ httpClient: httpClient, config: config, tunnels: make(map[string]*ssh.ServerConn), }, nil } -func (p *Proxy) SetToken(token string) { - p.httpClient.SetAuthToken(token) +func (r *Relay) SetToken(token string) { + r.httpClient.SetAuthToken(token) } -func (p *Proxy) Start(ctx context.Context) error { - if err := p.registerProxy(); err != nil { - return fmt.Errorf("failed to register proxy: %v", err) +func (r *Relay) Start(ctx context.Context) error { + if err := r.registerRelay(); err != nil { + return fmt.Errorf("failed to register relay: %v", err) } // Setup SSH server - if err := p.setupSSHServer(); err != nil { + if err := r.setupSSHServer(); err != nil { return fmt.Errorf("failed to setup SSH server: %v", err) } // Setup TLS server - if err := p.setupTLSServer(); err != nil { + if err := r.setupTLSServer(); err != nil { return fmt.Errorf("failed to setup TLS server: %v", err) } // Start certificate renewal goroutine - go p.startCertificateRenewal(ctx) + go r.startCertificateRenewal(ctx) // Start SSH server - go p.startSSHServer() + go r.startSSHServer() // Start TLS server - go p.startTLSServer() + go r.startTLSServer() - log.Info().Msg("Proxy server started successfully") + log.Info().Msg("Relay server started successfully") // Wait for context cancellation <-ctx.Done() // Cleanup - p.cleanup() + r.cleanup() return nil } -func (p *Proxy) registerProxy() error { - body := api.RegisterProxyRequest{ - IP: p.config.StaticIP, - Name: p.config.ProxyName, +func (r *Relay) registerRelay() error { + body := api.RegisterRelayRequest{ + IP: r.config.StaticIP, + Name: r.config.RelayName, } - if p.config.Type == "instance" { - certResp, err := api.CallRegisterInstanceProxy(p.httpClient, body) + if r.config.Type == "instance" { + certResp, err := api.CallRegisterInstanceRelay(r.httpClient, body) if err != nil { - return fmt.Errorf("failed to register instance proxy: %v", err) + return fmt.Errorf("failed to register instance relay: %v", err) } - p.certificates = &certResp + r.certificates = &certResp } else { - certResp, err := api.CallRegisterProxy(p.httpClient, body) + certResp, err := api.CallRegisterRelay(r.httpClient, body) if err != nil { - return fmt.Errorf("failed to register org proxy: %v", err) + return fmt.Errorf("failed to register org relay: %v", err) } - p.certificates = &certResp + r.certificates = &certResp } - log.Info().Msg("Successfully registered proxy and received certificates from API") + log.Info().Msg("Successfully registered relay and received certificates from API") return nil } -func (p *Proxy) setupSSHServer() error { +func (r *Relay) setupSSHServer() error { // Parse SSH CA public key - sshCAPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(p.certificates.SSH.ClientCAPublicKey)) + sshCAPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.certificates.SSH.ClientCAPublicKey)) if err != nil { return fmt.Errorf("failed to parse SSH CA public key: %v", err) } // Parse SSH server private key - sshServerKey, err := ssh.ParsePrivateKey([]byte(p.certificates.SSH.ServerPrivateKey)) + sshServerKey, err := ssh.ParsePrivateKey([]byte(r.certificates.SSH.ServerPrivateKey)) if err != nil { return fmt.Errorf("failed to parse SSH server private key: %v", err) } // Parse SSH server certificate - sshServerCert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(p.certificates.SSH.ServerCertificate)) + sshServerCert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.certificates.SSH.ServerCertificate)) if err != nil { return fmt.Errorf("failed to parse SSH server certificate: %v", err) } @@ -163,53 +163,53 @@ func (p *Proxy) setupSSHServer() error { } // Setup SSH server config - p.sshConfig = &ssh.ServerConfig{ + r.sshConfig = &ssh.ServerConfig{ MaxAuthTries: 3, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { // Check if this is an SSH certificate cert, ok := key.(*ssh.Certificate) if !ok { - log.Warn().Msgf("Gateway '%s' tried to authenticate with raw public key (rejected)", conn.User()) + log.Warn().Msgf("Connector '%s' tried to authenticate with raw public key (rejected)", conn.User()) return nil, fmt.Errorf("certificates required, raw public keys not allowed") } // Validate the certificate - if err := p.validateSSHCertificate(cert, conn.User(), sshCAPubKey); err != nil { - log.Error().Msgf("Gateway '%s' certificate validation failed: %v", conn.User(), err) + if err := r.validateSSHCertificate(cert, conn.User(), sshCAPubKey); err != nil { + log.Error().Msgf("Connector '%s' certificate validation failed: %v", conn.User(), err) return nil, err } - gatewayId := "" + connectorId := "" if len(cert.ValidPrincipals) > 0 { - gatewayId = cert.ValidPrincipals[0] + connectorId = cert.ValidPrincipals[0] } - if gatewayId == "" { - return nil, fmt.Errorf("gateway id is required") + if connectorId == "" { + return nil, fmt.Errorf("connector id is required") } - // Validate that the user is authorized to connect to the current proxy - expectedKeyId := "client-" + p.config.ProxyName + // Validate that the user is authorized to connect to the current relay + expectedKeyId := "client-" + r.config.RelayName if cert.KeyId != expectedKeyId { - log.Error().Msgf("Gateway '%s' certificate Key ID '%s' does not match expected '%s'", conn.User(), cert.KeyId, expectedKeyId) + log.Error().Msgf("Connector '%s' certificate Key ID '%s' does not match expected '%s'", conn.User(), cert.KeyId, expectedKeyId) return nil, fmt.Errorf("certificate Key ID does not match expected value") } return &ssh.Permissions{ Extensions: map[string]string{ - "gateway-id": gatewayId, + "connector-id": connectorId, }, }, nil }, } - p.sshConfig.AddHostKey(certSigner) + r.sshConfig.AddHostKey(certSigner) return nil } -func (p *Proxy) setupTLSServer() error { +func (r *Relay) setupTLSServer() error { // Parse TLS server certificate - serverCertBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerCertificate)) + serverCertBlock, _ := pem.Decode([]byte(r.certificates.PKI.ServerCertificate)) if serverCertBlock == nil { return fmt.Errorf("failed to decode server certificate") } @@ -222,7 +222,7 @@ func (p *Proxy) setupTLSServer() error { } // Parse TLS server private key - serverKeyBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerPrivateKey)) + serverKeyBlock, _ := pem.Decode([]byte(r.certificates.PKI.ServerPrivateKey)) if serverKeyBlock == nil { return fmt.Errorf("failed to decode server private key") } @@ -236,7 +236,7 @@ func (p *Proxy) setupTLSServer() error { clientCAPool := x509.NewCertPool() var chainCerts [][]byte - chainData := []byte(p.certificates.PKI.ClientCertificateChain) + chainData := []byte(r.certificates.PKI.ClientCertificateChain) for { block, rest := pem.Decode(chainData) if block == nil { @@ -256,7 +256,7 @@ func (p *Proxy) setupTLSServer() error { } // Create TLS config - p.tlsConfig = &tls.Config{ + r.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ { Certificate: [][]byte{serverCertBlock.Bytes}, @@ -271,7 +271,7 @@ func (p *Proxy) setupTLSServer() error { return nil } -func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, caPubKey ssh.PublicKey) error { +func (r *Relay) validateSSHCertificate(cert *ssh.Certificate, username string, caPubKey ssh.PublicKey) error { // Check certificate type if cert.CertType != ssh.UserCert { return fmt.Errorf("invalid certificate type: %d", cert.CertType) @@ -293,14 +293,14 @@ func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, c return nil } -func (p *Proxy) startSSHServer() { - listener, err := net.Listen("tcp", ":"+p.config.SSHPort) +func (r *Relay) startSSHServer() { + listener, err := net.Listen("tcp", ":"+r.config.SSHPort) if err != nil { log.Fatal().Msgf("Failed to start SSH server: %v", err) } - p.sshListener = listener + r.sshListener = listener - log.Info().Msgf("SSH server listening on :%s for gateways", p.config.SSHPort) + log.Info().Msgf("SSH server listening on :%s for connectors", r.config.SSHPort) for { conn, err := listener.Accept() @@ -308,47 +308,47 @@ func (p *Proxy) startSSHServer() { log.Error().Msgf("Failed to accept SSH connection: %v", err) continue } - go p.handleSSHAgent(conn) + go r.handleSSHAgent(conn) } } -func (p *Proxy) handleSSHAgent(conn net.Conn) { +func (r *Relay) handleSSHAgent(conn net.Conn) { defer conn.Close() // SSH handshake - sshConn, chans, reqs, err := ssh.NewServerConn(conn, p.sshConfig) + sshConn, chans, reqs, err := ssh.NewServerConn(conn, r.sshConfig) if err != nil { log.Error().Msgf("SSH handshake failed: %v", err) return } - gatewayId := sshConn.Permissions.Extensions["gateway-id"] - log.Info().Msgf("SSH handshake successful for gateway: %s", gatewayId) + connectorId := sshConn.Permissions.Extensions["connector-id"] + log.Info().Msgf("SSH handshake successful for connector: %s", connectorId) - // Store the connection (ensure only one connection per gateway) - p.mu.Lock() - if _, exists := p.tunnels[gatewayId]; exists { - p.mu.Unlock() - log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) + // Store the connection (ensure only one connection per connector) + r.mu.Lock() + if _, exists := r.tunnels[connectorId]; exists { + r.mu.Unlock() + log.Warn().Msgf("Connector '%s' already has an active connection, rejecting new connection", connectorId) sshConn.Close() return } - p.tunnels[gatewayId] = sshConn - p.mu.Unlock() + r.tunnels[connectorId] = sshConn + r.mu.Unlock() // Clean up when agent disconnects defer func() { - p.mu.Lock() - delete(p.tunnels, gatewayId) - p.mu.Unlock() - log.Info().Msgf("Gateway %s disconnected", gatewayId) + r.mu.Lock() + delete(r.tunnels, connectorId) + r.mu.Unlock() + log.Info().Msgf("Connector %s disconnected", connectorId) }() // Handle global requests (reject all for security) go func() { for req := range reqs { - log.Debug().Msgf("Rejecting global request: %s from gateway %s", req.Type, gatewayId) + log.Debug().Msgf("Rejecting global request: %s from connector %s", req.Type, connectorId) if req.WantReply { req.Reply(false, nil) } @@ -359,32 +359,32 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { for newChannel := range chans { switch newChannel.ChannelType() { case "session": - log.Debug().Msgf("Rejecting session channel from gateway %s", gatewayId) + log.Debug().Msgf("Rejecting session channel from connector %s", connectorId) newChannel.Reject(ssh.Prohibited, "no shell access") case "x11": - log.Debug().Msgf("Rejecting X11 forwarding from gateway %s", gatewayId) + log.Debug().Msgf("Rejecting X11 forwarding from connector %s", connectorId) newChannel.Reject(ssh.Prohibited, "no X11 forwarding") case "auth-agent": - log.Debug().Msgf("Rejecting auth-agent forwarding from gateway %s", gatewayId) + log.Debug().Msgf("Rejecting auth-agent forwarding from connector %s", connectorId) newChannel.Reject(ssh.Prohibited, "no agent forwarding") case "forwarded-tcpip": - log.Debug().Msgf("Rejecting forwarded-tcpip from gateway %s", gatewayId) + log.Debug().Msgf("Rejecting forwarded-tcpip from connector %s", connectorId) newChannel.Reject(ssh.Prohibited, "no port forwarding") default: - log.Warn().Msgf("Rejecting unknown channel type '%s' from gateway %s", newChannel.ChannelType(), gatewayId) + log.Warn().Msgf("Rejecting unknown channel type '%s' from connector %s", newChannel.ChannelType(), connectorId) newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") } } } -func (p *Proxy) startTLSServer() { - listener, err := net.Listen("tcp", ":"+p.config.TLSPort) +func (r *Relay) startTLSServer() { + listener, err := net.Listen("tcp", ":"+r.config.TLSPort) if err != nil { log.Fatal().Msgf("Failed to start TLS server: %v", err) } - p.tlsListener = listener + r.tlsListener = listener - log.Info().Msgf("TLS server listening on :%s for clients", p.config.TLSPort) + log.Info().Msgf("TLS server listening on :%s for clients", r.config.TLSPort) for { conn, err := listener.Accept() @@ -392,15 +392,15 @@ func (p *Proxy) startTLSServer() { log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } - go p.handleTLSClient(conn) + go r.handleTLSClient(conn) } } -func (p *Proxy) handleTLSClient(conn net.Conn) { +func (r *Relay) handleTLSClient(conn net.Conn) { defer conn.Close() // Perform TLS handshake using current TLS config - tlsConn := tls.Server(conn, p.tlsConfig) + tlsConn := tls.Server(conn, r.tlsConfig) defer tlsConn.Close() // Set handshake timeout to avoid hanging on slow/malicious connections @@ -416,39 +416,39 @@ func (p *Proxy) handleTLSClient(conn net.Conn) { // Clear deadline for actual data transfer tlsConn.SetDeadline(time.Time{}) - p.handleClient(tlsConn) + r.handleClient(tlsConn) } -func (p *Proxy) handleClient(tlsConn *tls.Conn) { - var gatewayId string +func (r *Relay) handleClient(tlsConn *tls.Conn) { + var connectorId string state := tlsConn.ConnectionState() if len(state.PeerCertificates) > 0 { cert := state.PeerCertificates[0] log.Info().Msgf("Client connected with certificate: %s", cert.Subject.CommonName) - gatewayId = cert.Subject.CommonName + connectorId = cert.Subject.CommonName } else { log.Warn().Msg("No peer certificates found") return } - // Get the SSH connection for this agent - p.mu.RLock() - conn, exists := p.tunnels[gatewayId] - p.mu.RUnlock() + // Get the SSH connection for this connector + r.mu.RLock() + conn, exists := r.tunnels[connectorId] + r.mu.RUnlock() if !exists { - log.Warn().Msgf("Gateway '%s' not connected", gatewayId) - tlsConn.Write([]byte("ERROR: Gateway not connected\n")) + log.Warn().Msgf("Connector '%s' not connected", connectorId) + tlsConn.Write([]byte("ERROR: Connector not connected\n")) return } - log.Info().Msgf("Routing TCP connection to gateway: %s", gatewayId) + log.Info().Msgf("Routing TCP connection to connector: %s", connectorId) channel, _, err := conn.OpenChannel("direct-tcpip", nil) if err != nil { - log.Error().Msgf("Failed to connect to agent: %v", err) - tlsConn.Write([]byte("ERROR: Failed to connect to agent\n")) + log.Error().Msgf("Failed to connect to connector: %v", err) + tlsConn.Write([]byte("ERROR: Failed to connect to connector\n")) return } defer channel.Close() @@ -463,21 +463,21 @@ func (p *Proxy) handleClient(tlsConn *tls.Conn) { log.Info().Msgf("Client %s disconnected", tlsConn.RemoteAddr()) } -func (p *Proxy) cleanup() { - log.Info().Msg("Shutting down proxy server...") +func (r *Relay) cleanup() { + log.Info().Msg("Shutting down relay server...") - if p.sshListener != nil { - p.sshListener.Close() + if r.sshListener != nil { + r.sshListener.Close() } - if p.tlsListener != nil { - p.tlsListener.Close() + if r.tlsListener != nil { + r.tlsListener.Close() } - log.Info().Msg("Proxy server shutdown complete") + log.Info().Msg("Relay server shutdown complete") } // startCertificateRenewal runs a background process to renew certificates every 24 hours -func (p *Proxy) startCertificateRenewal(ctx context.Context) { +func (r *Relay) startCertificateRenewal(ctx context.Context) { log.Info().Msg("Starting certificate renewal goroutine") ticker := time.NewTicker(10 * 24 * time.Hour) defer ticker.Stop() @@ -489,7 +489,7 @@ func (p *Proxy) startCertificateRenewal(ctx context.Context) { return case <-ticker.C: log.Info().Msg("Renewing certificates...") - if err := p.renewCertificates(); err != nil { + if err := r.renewCertificates(); err != nil { log.Error().Msgf("Failed to renew certificates: %v", err) } else { log.Info().Msg("Certificates renewed successfully") @@ -499,19 +499,19 @@ func (p *Proxy) startCertificateRenewal(ctx context.Context) { } // renewCertificates fetches new certificates and updates the server configurations -func (p *Proxy) renewCertificates() error { - // Re-register proxy to get fresh certificates - if err := p.registerProxy(); err != nil { - return fmt.Errorf("failed to register proxy: %v", err) +func (r *Relay) renewCertificates() error { + // Re-register relay to get fresh certificates + if err := r.registerRelay(); err != nil { + return fmt.Errorf("failed to register relay: %v", err) } // Update SSH server configuration - if err := p.setupSSHServer(); err != nil { + if err := r.setupSSHServer(); err != nil { return fmt.Errorf("failed to setup SSH server: %v", err) } // Update TLS server configuration - if err := p.setupTLSServer(); err != nil { + if err := r.setupTLSServer(); err != nil { return fmt.Errorf("failed to setup TLS server: %v", err) }