diff --git a/.gitignore b/.gitignore index f838c5b6..fc746ac4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ .idea .vscode +# Local development +.claude/settings.local.json + # Build artifacts /ocap-webserver /ocap-webserver.exe diff --git a/README.md b/README.md index c7be865d..5e18c683 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,9 @@ The configuration file is called `setting.json`. All settings can also be set vi }, "streaming": { "enabled": true + }, + "cors": { + "allowedOrigins": [] } } ``` @@ -220,6 +223,22 @@ Live mission data can be streamed to the server via WebSocket. | `streaming.pingInterval` | `OCAP_STREAMING_PINGINTERVAL` | Interval between WebSocket keepalive pings | `30s` | | `streaming.pingTimeout` | `OCAP_STREAMING_PINGTIMEOUT` | Timeout waiting for pong response | `10s` | +### CORS + +All responses include CORS headers so external services and web apps can fetch from the API. + +| Setting | Env Var | Description | Default | +|---------|---------|-------------|---------| +| `cors.allowedOrigins` | `OCAP_CORS_ALLOWEDORIGINS` | Origins allowed to make cross-origin requests. Empty list permits all origins (`*`). Comma-separated in env var. | `[]` (all origins) | + +When `allowedOrigins` is empty the server responds with `Access-Control-Allow-Origin: *`, which is appropriate for public read APIs. Restrict to specific origins if you want to limit which external sites can call admin endpoints: + +```json +"cors": { + "allowedOrigins": ["https://admin.example.com", "https://replay.example.com"] +} +``` + ## Large Recording Support ### Overview diff --git a/internal/server/handler.go b/internal/server/handler.go index 983776fa..1cdf12d3 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -3,13 +3,13 @@ package server import ( "bufio" "bytes" - "errors" "compress/gzip" "encoding/json" + "errors" "fmt" - "log/slog" "io" "io/fs" + "log/slog" "net/http" "net/url" "os" @@ -59,8 +59,8 @@ type Handler struct { repoAmmo *RepoAmmo setting Setting jwt *JWTManager - conversionTrigger ConversionTrigger // optional, nil if conversion disabled - staticFS fs.FS // optional, nil disables static file serving + conversionTrigger ConversionTrigger // optional, nil if conversion disabled + staticFS fs.FS // optional, nil disables static file serving maptoolMgr *maptool.JobManager // optional, nil if maptool disabled maptoolCfg *maptoolConfig // optional, nil if maptool disabled openIDVerifier openIDVerifier @@ -137,6 +137,8 @@ func NewHandler( prefixURL := strings.TrimRight(hdlr.setting.PrefixURL, "/") g := fuego.Group(s, prefixURL) + fuego.Use(g, newCORSMiddleware(hdlr.setting.CORS.AllowedOrigins)) + bearerAuth := openapi3.SecurityRequirement{"bearerAuth": {}} // Health & info @@ -199,6 +201,41 @@ func NewHandler( } } +// newCORSMiddleware returns a CORS middleware. When allowedOrigins is empty, +// all origins are permitted via the wildcard (*). When specific origins are +// listed, Vary: Origin is always set (so caches key on it) and the +// Allow-Origin header is only set for matching origins. +func newCORSMiddleware(allowedOrigins []string) func(http.Handler) http.Handler { + originSet := make(map[string]struct{}, len(allowedOrigins)) + for _, o := range allowedOrigins { + originSet[o] = struct{}{} + } + wildcard := len(allowedOrigins) == 0 + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if wildcard { + w.Header().Set("Access-Control-Allow-Origin", "*") + } else { + w.Header().Add("Vary", "Origin") + if origin := r.Header.Get("Origin"); origin != "" { + if _, ok := originSet[origin]; ok { + w.Header().Set("Access-Control-Allow-Origin", origin) + } + } + } + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.Header().Set("Access-Control-Max-Age", "86400") + if r.Method == http.MethodOptions && r.Header.Get("Origin") != "" { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) + } +} + func (*Handler) cacheControl(duration time.Duration) func(http.Handler) http.Handler { var header string if duration < time.Second { diff --git a/internal/server/handler_test.go b/internal/server/handler_test.go index c3eaac4a..c99ec6da 100644 --- a/internal/server/handler_test.go +++ b/internal/server/handler_test.go @@ -627,6 +627,109 @@ func TestCacheControl(t *testing.T) { }) } +func TestCORSMiddleware(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + t.Run("wildcard (no origins configured)", func(t *testing.T) { + mw := newCORSMiddleware(nil) + + t.Run("sets wildcard origin and max-age on GET", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/operations", nil) + rec := httptest.NewRecorder() + mw(inner).ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin")) + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Methods"), "GET") + assert.Contains(t, rec.Header().Get("Access-Control-Allow-Headers"), "Authorization") + assert.Equal(t, "86400", rec.Header().Get("Access-Control-Max-Age")) + }) + + t.Run("preflight OPTIONS returns 204 without calling inner handler", func(t *testing.T) { + called := false + guarded := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + }) + + req := httptest.NewRequest(http.MethodOptions, "/api/v1/operations", nil) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", "GET") + rec := httptest.NewRecorder() + mw(guarded).ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNoContent, rec.Code) + assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin")) + assert.False(t, called, "inner handler must not be called for preflight") + }) + }) + + t.Run("wildcard OPTIONS without Origin passes through to inner handler", func(t *testing.T) { + mw := newCORSMiddleware(nil) + called := false + guarded := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodOptions, "/api/v1/operations", nil) + rec := httptest.NewRecorder() + mw(guarded).ServeHTTP(rec, req) + + assert.True(t, called, "OPTIONS without Origin header must not be intercepted") + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("specific origins configured", func(t *testing.T) { + mw := newCORSMiddleware([]string{"https://allowed.example.com"}) + + t.Run("allows matching origin and sets Vary", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/operations", nil) + req.Header.Set("Origin", "https://allowed.example.com") + rec := httptest.NewRecorder() + mw(inner).ServeHTTP(rec, req) + + assert.Equal(t, "https://allowed.example.com", rec.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "Origin", rec.Header().Get("Vary")) + assert.Equal(t, "86400", rec.Header().Get("Access-Control-Max-Age")) + }) + + t.Run("sets Vary even for non-matching origin", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/operations", nil) + req.Header.Set("Origin", "https://evil.com") + rec := httptest.NewRecorder() + mw(inner).ServeHTTP(rec, req) + + assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "Origin", rec.Header().Get("Vary")) + }) + + t.Run("sets Vary even when no Origin header", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/operations", nil) + rec := httptest.NewRecorder() + mw(inner).ServeHTTP(rec, req) + + assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "Origin", rec.Header().Get("Vary")) + }) + + t.Run("OPTIONS without Origin passes through to inner handler", func(t *testing.T) { + called := false + guarded := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodOptions, "/api/v1/operations", nil) + rec := httptest.NewRecorder() + mw(guarded).ServeHTTP(rec, req) + + assert.True(t, called, "OPTIONS without Origin header must not be intercepted") + }) + }) +} + func TestWithConversionTrigger(t *testing.T) { trigger := &mockConversionTrigger{} diff --git a/internal/server/setting.go b/internal/server/setting.go index 6b256276..b581cc14 100644 --- a/internal/server/setting.go +++ b/internal/server/setting.go @@ -28,6 +28,7 @@ type Setting struct { Streaming Streaming `json:"streaming" yaml:"streaming"` Auth Auth `json:"auth" yaml:"auth"` HttpServer HttpServer `json:"httpServer" yaml:"httpServer"` + CORS CORSConfig `json:"cors" yaml:"cors"` } type Conversion struct { @@ -61,6 +62,10 @@ type Streaming struct { PingTimeout time.Duration `json:"pingTimeout" yaml:"pingTimeout"` } +type CORSConfig struct { + AllowedOrigins []string `json:"allowedOrigins" yaml:"allowedOrigins"` +} + type HttpServer struct { ReadTimeout time.Duration `json:"readTimeout" yaml:"readTimeout"` ReadHeaderTimeout time.Duration `json:"readHeaderTimeout" yaml:"readHeaderTimeout"` @@ -113,6 +118,7 @@ func NewSetting() (setting Setting, err error) { viper.SetDefault("auth.adminSteamIds", []string{}) viper.SetDefault("auth.steamApiKey", "") + viper.SetDefault("cors.allowedOrigins", []string{}) viper.SetDefault("httpServer.readTimeout", "120s") viper.SetDefault("httpServer.readHeaderTimeout", "30s") viper.SetDefault("httpServer.writeTimeout", "120s") @@ -129,10 +135,6 @@ func NewSetting() (setting Setting, err error) { return } - // Viper doesn't split comma-separated env var strings into slices, - // so a value like "id1,id2" ends up as ["id1,id2"]. Expand it. - setting.Auth.AdminSteamIDs = splitCSV(setting.Auth.AdminSteamIDs) - // Viper can't unmarshal a JSON string env var into map[string]string, // so parse OCAP_CUSTOMIZE_CSSOVERRIDES manually if set. Env var takes // precedence over config file. @@ -160,17 +162,3 @@ func NewSetting() (setting Setting, err error) { return } - -// splitCSV expands a []string where one element may contain comma-separated -// values (from an env var) into individual trimmed entries. -func splitCSV(in []string) []string { - var out []string - for _, s := range in { - for _, part := range strings.Split(s, ",") { - if v := strings.TrimSpace(part); v != "" { - out = append(out, v) - } - } - } - return out -} diff --git a/internal/server/setting_test.go b/internal/server/setting_test.go index 077e6421..a3cfb6fb 100644 --- a/internal/server/setting_test.go +++ b/internal/server/setting_test.go @@ -500,30 +500,6 @@ func TestSetting_HttpServer(t *testing.T) { assert.Equal(t, 140*time.Second, setting.HttpServer.IdleTimeout) } -func TestSplitCSV(t *testing.T) { - tests := []struct { - name string - in []string - want []string - }{ - {"nil input", nil, nil}, - {"empty slice", []string{}, nil}, - {"single value", []string{"abc"}, []string{"abc"}}, - {"already split", []string{"a", "b"}, []string{"a", "b"}}, - {"comma-separated single element", []string{"a,b,c"}, []string{"a", "b", "c"}}, - {"mixed", []string{"a,b", "c"}, []string{"a", "b", "c"}}, - {"whitespace trimmed", []string{" a , b , c "}, []string{"a", "b", "c"}}, - {"empty parts skipped", []string{"a,,b,"}, []string{"a", "b"}}, - {"all empty", []string{",,"}, nil}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := splitCSV(tt.in) - assert.Equal(t, tt.want, got) - }) - } -} - func TestNewSetting_NoConfigFile(t *testing.T) { viper.Reset() // Use a directory with no config file diff --git a/setting.json.example b/setting.json.example index e80257b9..5198c35c 100644 --- a/setting.json.example +++ b/setting.json.example @@ -42,5 +42,8 @@ "readHeaderTimeout": "30s", "writeTimeout": "120s", "idleTimeout": "120s" + }, + "cors": { + "allowedOrigins": [] } }