diff --git a/cmd/start.go b/cmd/start.go index 5c78be1..aad3a7b 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -147,7 +147,21 @@ var rootCmd = &cobra.Command{ if err != nil { return fmt.Errorf("failed to generate certificate: %w", err) } - newServer, err := stack.StartServer(ctx, logger, tlsConfig, urls) + var signingKey *ecdsa.PublicKey + if len(c.SigningPublicKey) > 0 { + block, _ := pem.Decode(c.SigningPublicKey) + if block == nil { + logger.Warn("received signing public key but PEM decode failed") + } else if pub, err := x509.ParsePKIXPublicKey(block.Bytes); err != nil { + logger.Warn("received signing public key but PKIX parse failed: %v", err) + } else if ecKey, ok := pub.(*ecdsa.PublicKey); !ok { + logger.Warn("received signing public key but unexpected type: %T", pub) + } else { + signingKey = ecKey + logger.Info("using upstream signing key for proxy verification") + } + } + newServer, err := stack.StartServer(ctx, logger, tlsConfig, urls, agent, signingKey) if err != nil { return fmt.Errorf("failed to start server: %w", err) } diff --git a/go.mod b/go.mod index 11396c5..839ebbd 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/agentuity/gravity go 1.26.1 require ( - github.com/agentuity/go-common v1.0.165 + github.com/agentuity/go-common v1.0.181 github.com/spf13/cobra v1.10.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) diff --git a/go.sum b/go.sum index 50479d5..87d5d85 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/agentuity/go-common v1.0.165 h1:hxaGFRg05/ir4OqoIpCQHhgGhs9VKeR564EsLhEN8Gc= -github.com/agentuity/go-common v1.0.165/go.mod h1:uW1IsiE9ydoK6HRwr8jgEE8wVXSXoFzhm/AJ8Q4xlos= +github.com/agentuity/go-common v1.0.181 h1:+mJSQhZdPj++ZxSyIwM3BtG7GcCmLAahWlfRcfaI2Lc= +github.com/agentuity/go-common v1.0.181/go.mod h1:YuiBVsz9WZ5K1vW2cjHvjUnuA65t7YZTHj4Nq11b0UQ= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cockroachdb/errors v1.12.0 h1:d7oCs6vuIMUQRVbi6jWWWEJZahLCfJpnJSVobd1/sUo= diff --git a/internal/stack/stack.go b/internal/stack/stack.go index 6413add..e268052 100644 --- a/internal/stack/stack.go +++ b/internal/stack/stack.go @@ -1,6 +1,7 @@ package stack import ( + "bytes" "context" "crypto/ecdsa" "crypto/tls" @@ -17,6 +18,7 @@ import ( "sync" "time" + agcrypto "github.com/agentuity/go-common/crypto" "github.com/agentuity/go-common/gravity" "github.com/agentuity/go-common/gravity/proto" "github.com/agentuity/go-common/gravity/provider" @@ -131,7 +133,7 @@ func GenerateCertificate(_ context.Context, logger _logger.Logger, bundle string return tlsConfig, nil } -func StartServer(ctx context.Context, logger _logger.Logger, tlsConfig *tls.Config, urls UrlsMetadata) (*http.Server, error) { +func StartServer(ctx context.Context, logger _logger.Logger, tlsConfig *tls.Config, urls UrlsMetadata, agent AgentMetadata, signingKey *ecdsa.PublicKey) (*http.Server, error) { // Set up reverse proxy to the agent server agentURL := fmt.Sprintf("http://127.0.0.1:%d", urls.LocalPort) @@ -140,31 +142,36 @@ func StartServer(ctx context.Context, logger _logger.Logger, tlsConfig *tls.Conf return nil, fmt.Errorf("failed to parse agent URL: %w", err) } - proxy := httputil.NewSingleHostReverseProxy(upstreamURL) - // Override the Director to restore the original public hostname from - // X-Forwarded-Host. Without this, the Host header may contain internal - // routing names (e.g., "*.agentuity-us.live.internal") that leak through - // to the local dev server. Vite and other dev servers check the Host - // header against their allowedHosts list and reject unrecognized hostnames. - defaultDirector := proxy.Director - proxy.Director = func(req *http.Request) { - defaultDirector(req) - if fwdHost := req.Header.Get("X-Forwarded-Host"); fwdHost != "" { - req.Host = fwdHost - } - } - proxy.FlushInterval = -1 - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - // Suppress expected context cancellation errors (client disconnect, WebSocket close) - if errors.Is(ctx.Err(), context.Canceled) || errors.Is(r.Context().Err(), context.Canceled) { - return - } - logger.Error("proxy error: %v", err) - http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) + proxy := &httputil.ReverseProxy{ + // Use Rewrite instead of the deprecated Director. Restore the original + // public hostname from X-Forwarded-Host so that the Host header sent + // to the local dev server matches the public URL. Vite and other dev + // servers check the Host header against their allowedHosts list and + // reject unrecognized hostnames. + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(upstreamURL) + if fwdHost := r.In.Header.Get("X-Forwarded-Host"); fwdHost != "" { + r.Out.Host = fwdHost + } + }, + FlushInterval: -1, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + // Suppress expected context cancellation errors (client disconnect, WebSocket close) + if errors.Is(ctx.Err(), context.Canceled) || errors.Is(r.Context().Err(), context.Canceled) { + return + } + logger.Error("proxy error: %v", err) + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) + }, + ModifyResponse: func(resp *http.Response) error { + logger.Trace("response %s: %d", resp.Request.URL.Path, resp.StatusCode) + return nil + }, } - proxy.ModifyResponse = func(resp *http.Response) error { - logger.Trace("response %s: %d", resp.Request.URL.Path, resp.StatusCode) - return nil + + // Log once if signing key is missing (avoid flooding logs on every request). + if signingKey == nil && agent.PrivateKey != nil { + logger.Warn("no upstream signing key available, skipping signature verification") } server := &http.Server{ @@ -186,6 +193,7 @@ func StartServer(ctx context.Context, logger _logger.Logger, tlsConfig *tls.Conf return default: } + proxy.ServeHTTP(w, r) }), } @@ -218,6 +226,60 @@ func StartServer(ctx context.Context, logger _logger.Logger, tlsConfig *tls.Conf return server, serverErr } +// verifyUpstreamSignature checks that the incoming request was signed by the +// Agentuity ion proxy. It reads the body so that HTTP trailers become +// available, verifies the signature, and restores the body and Content-Length +// for downstream proxying. +func verifyUpstreamSignature(logger _logger.Logger, publicKey *ecdsa.PublicKey, r *http.Request) error { + alg := r.Header.Get(agcrypto.HeaderSignatureAlg) + keyID := r.Header.Get(agcrypto.HeaderSignatureKeyID) + timestamp := r.Header.Get(agcrypto.HeaderSignatureTimestamp) + nonce := r.Header.Get(agcrypto.HeaderSignatureNonce) + + logger.Debug("upstream signature: alg=%s keyid=%s timestamp=%s nonce=%s via=%s", + alg, keyID, timestamp, nonce, r.Header.Get("Via")) + + if alg == "" { + return fmt.Errorf("no signature headers present") + } + + // Read the full body so HTTP trailers become available. Save the + // original Content-Length so we can restore it after verification — + // the reverse proxy needs it to avoid chunked encoding to localhost. + origContentLength := r.ContentLength + body, err := io.ReadAll(r.Body) + if err != nil { + return fmt.Errorf("read body: %w", err) + } + + // The ion proxy sends the Signature as an HTTP trailer (for streaming + // requests) or as a header (for WebSocket requests). + sig := r.Trailer.Get(agcrypto.HeaderSignature) + source := "trailer" + if sig == "" { + sig = r.Header.Get(agcrypto.HeaderSignature) + source = "header" + if sig == "" { + source = "missing" + } + } + logger.Debug("upstream signature value: source=%s present=%v", source, sig != "") + + // Verify the cryptographic signature against the upstream signing key. + verifyErr := agcrypto.VerifyHTTPRequest(publicKey, r, body, nil) + if verifyErr != nil { + logger.Debug("upstream signature verification failed: %v", verifyErr) + } else { + logger.Debug("upstream signature verification succeeded") + } + + // Restore the body and Content-Length so the reverse proxy can forward it. + r.Body = io.NopCloser(bytes.NewReader(body)) + r.ContentLength = origContentLength + + return verifyErr +} + func CreateNetworkStack(logger _logger.Logger, urls UrlsMetadata) (*stack.Stack, *channel.Endpoint, error) { s := stack.New(stack.Options{