diff --git a/.gitignore b/.gitignore index 4b3fd91e..b6288a4c 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ e2e/server.srl e2e/server.key e2e/server.csr e2e/server.crt +api-testing diff --git a/cmd/server.go b/cmd/server.go index e70d8bd9..c3aced41 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -56,6 +56,7 @@ import ( fakeruntime "github.com/linuxsuren/go-fake-runtime" "github.com/linuxsuren/oauth-hub" + "github.com/gorilla/websocket" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -379,6 +380,7 @@ func (o *serverOption) runE(cmd *cobra.Command, args []string) (err error) { ctx = context.WithValue(ctx, k, v) } + endpoint := pathParams["endpoint"] resp, err := extServer.GetPageOfServer(ctx, &server.SimpleName{Name: pathParams["extension"]}) if err != nil { fmt.Println(err) @@ -389,8 +391,18 @@ func (o *serverOption) runE(cmd *cobra.Command, args []string) (err error) { return } - fmt.Println("redirect to", resp.Message, "method", r.Method) - req, err := http.NewRequestWithContext(ctx, r.Method, resp.Message, r.Body) + api := resp.Message + "/" + endpoint + + // Check if this is a WebSocket request + if isWebSocketRequest(r) { + api = strings.ReplaceAll(api, "http://", "ws://") + fmt.Println("WebSocket request detected", api) + handleWebSocketProxy(w, r, api) + return + } + + fmt.Println("redirect to", api, "method", r.Method) + req, err := http.NewRequestWithContext(ctx, r.Method, api, r.Body) if err != nil { fmt.Println(err) return @@ -427,9 +439,10 @@ func (o *serverOption) runE(cmd *cobra.Command, args []string) (err error) { flusher.Flush() } } - mux.HandlePath(http.MethodPost, "/extensionProxy/{extension}", proxyHandler) - mux.HandlePath(http.MethodGet, "/extensionProxy/{extension}", proxyHandler) - mux.HandlePath(http.MethodDelete, "/extensionProxy/{extension}", proxyHandler) + mux.HandlePath(http.MethodPost, "/extensionProxy/{extension}/{endpoint}", proxyHandler) + mux.HandlePath(http.MethodGet, "/extensionProxy/{extension}/{endpoint}", proxyHandler) + mux.HandlePath(http.MethodDelete, "/extensionProxy/{extension}/{endpoint}", proxyHandler) + mux.HandlePath(http.MethodPost, "/extensionProxy/{extension}/{endpoint}", proxyHandler) mux.HandlePath(http.MethodGet, "/get", o.getAtestBinary) mux.HandlePath(http.MethodPost, "/runner/{suite}/{case}", service.WebRunnerHandler) mux.HandlePath(http.MethodGet, "/api/v1/sbom", service.SBomHandler) @@ -660,6 +673,84 @@ func (s *fakeGRPCServer) RegisterService(desc *grpc.ServiceDesc, impl interface{ // Do nothing due to this is a fake method } +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Allow connections from any origin + }, +} + +func isWebSocketRequest(r *http.Request) bool { + return strings.ToLower(r.Header.Get("Connection")) == "upgrade" && strings.ToLower(r.Header.Get("Upgrade")) == "websocket" +} + +func handleWebSocketProxy(w http.ResponseWriter, r *http.Request, targetURL string) { + // Upgrade the connection + clientConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + fmt.Println("Failed to upgrade connection:", err) + return + } + defer clientConn.Close() + + // Create a WebSocket connection to the target server + // Clone headers to avoid duplicate headers issue + headers := make(http.Header) + for k, v := range r.Header { + // Skip headers that will be set by the WebSocket dialer + if k == "Upgrade" || k == "Connection" || k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || k == "Sec-Websocket-Extensions" || + k == "Sec-Websocket-Protocol" { + continue + } + headers[k] = v + } + + targetConn, _, err := websocket.DefaultDialer.Dial(targetURL, headers) + if err != nil { + fmt.Println("Failed to connect to target:", err) + return + } + defer targetConn.Close() + + // Proxy messages between client and target + errChan := make(chan error, 2) + + // Client to target + go func() { + for { + messageType, message, err := clientConn.ReadMessage() + if err != nil { + errChan <- err + return + } + + if err := targetConn.WriteMessage(messageType, message); err != nil { + errChan <- err + return + } + } + }() + + // Target to client + go func() { + for { + messageType, message, err := targetConn.ReadMessage() + if err != nil { + errChan <- err + return + } + + if err := clientConn.WriteMessage(messageType, message); err != nil { + errChan <- err + return + } + } + }() + + // Wait for an error to occur + <-errChan +} + //go:embed data/index.js var uiResourceJS []byte diff --git a/go.mod b/go.mod index 6e3b6fe2..55fdf380 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( require ( github.com/evanphx/json-patch v0.5.2 + github.com/gorilla/websocket v1.5.3 github.com/linuxsuren/http-downloader v0.0.99 golang.org/x/mod v0.28.0 ) diff --git a/go.sum b/go.sum index b8957aab..95d6e8ef 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,8 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGa github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 h1:Wqo399gCIufwto+VfwCSvsnfGpF/w5E9CNxSwbpD6No= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0/go.mod h1:qmOFXW2epJhM0qSnUUYpldc7gVz2KMQwJ/QYCDIa7XU= github.com/h2non/gock v1.2.0 h1:K6ol8rfrRkUOefooBC8elXoaNGYkpp7y2qcxGG6BzUE=