diff --git a/internal/api/handlers/monitor.go b/internal/api/handlers/monitor.go index e421283f..f1c357fa 100644 --- a/internal/api/handlers/monitor.go +++ b/internal/api/handlers/monitor.go @@ -83,24 +83,53 @@ func monitorStream(mon net.Conn, ws *websocket.Conn) { }() } +func splitOrigin(origin string) (scheme, host, port string, err error) { + parts := strings.SplitN(origin, "://", 2) + if len(parts) != 2 { + return "", "", "", fmt.Errorf("invalid origin format: %s", origin) + } + scheme = parts[0] + hostPort := parts[1] + hostParts := strings.SplitN(hostPort, ":", 2) + host = hostParts[0] + if len(hostParts) == 2 { + port = hostParts[1] + } else { + port = "*" + } + return scheme, host, port, nil +} + func checkOrigin(origin string, allowedOrigins []string) bool { + scheme, host, port, err := splitOrigin(origin) + if err != nil { + slog.Error("WebSocket origin check failed", slog.String("origin", origin), slog.String("error", err.Error())) + return false + } for _, allowed := range allowedOrigins { - if strings.HasSuffix(allowed, "*") { - // String ends with *, match the prefix - if strings.HasPrefix(origin, strings.TrimSuffix(allowed, "*")) { - return true - } - } else { - // Exact match - if allowed == origin { - return true - } + allowedScheme, allowedHost, allowedPort, err := splitOrigin(allowed) + if err != nil { + panic(err) + } + if allowedScheme != scheme { + continue } + if allowedHost != host && allowedHost != "*" { + continue + } + if allowedPort != port && allowedPort != "*" { + continue + } + return true } + slog.Error("WebSocket origin check failed", slog.String("origin", origin)) return false } func HandleMonitorWS(allowedOrigins []string) http.HandlerFunc { + // Do a dry-run of checkorigin, so it can panic if misconfigured now, not on first request + _ = checkOrigin("http://example.com:8000", allowedOrigins) + upgrader := websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, diff --git a/internal/api/handlers/monitor_test.go b/internal/api/handlers/monitor_test.go new file mode 100644 index 00000000..3f02e676 --- /dev/null +++ b/internal/api/handlers/monitor_test.go @@ -0,0 +1,51 @@ +// This file is part of arduino-app-cli. +// +// Copyright 2025 ARDUINO SA (http://www.arduino.cc/) +// +// This software is released under the GNU General Public License version 3, +// which covers the main part of arduino-app-cli. +// The terms of this license can be found at: +// https://www.gnu.org/licenses/gpl-3.0.en.html +// +// You can be released from the requirements of the above licenses by purchasing +// a commercial license. Buying such a license is mandatory if you want to +// modify or otherwise use the software for commercial activities involving the +// Arduino software without disclosing the source code of your own applications. +// To purchase a commercial license, send an email to license@arduino.cc. + +package handlers + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCheckOrigin(t *testing.T) { + origins := []string{ + "wails://wails", + "wails://wails.localhost:*", + "http://wails.localhost:*", + "http://localhost:*", + "https://localhost:*", + "http://example.com:7000", + "https://*:443", + } + + allow := func(origin string) { + require.True(t, checkOrigin(origin, origins), "Expected origin %s to be allowed", origin) + } + deny := func(origin string) { + require.False(t, checkOrigin(origin, origins), "Expected origin %s to be denied", origin) + } + allow("wails://wails") + allow("wails://wails:8000") + allow("http://wails.localhost") + allow("http://example.com:7000") + allow("https://blah.com:443") + deny("wails://evil.com") + deny("https://wails.localhost:8000") + deny("http://example.com:8000") + deny("http://blah.com:443") + deny("https://blah.com:8080") +}