From 32cff663531ca332550293444c6981fffa785dbf Mon Sep 17 00:00:00 2001 From: Microck Date: Sun, 12 Apr 2026 14:51:00 +0000 Subject: [PATCH] fix(security): harden gui preset and auth handling Redact sensitive preset fields from the GUI presets API and reject invalid enrollment mode, channel, and duration input before reaching the workflow layer. Harden Tailscale operations by using a bounded HTTP client for device deletion and passing auth keys through a temporary file instead of process argv, with regression tests covering the API and transport behavior. --- internal/gui/server.go | 53 +++++++++++- internal/gui/server_test.go | 132 ++++++++++++++++++++++++++++++ internal/tailscale/client.go | 44 +++++++++- internal/tailscale/client_test.go | 89 +++++++++++++++++--- 4 files changed, 301 insertions(+), 17 deletions(-) create mode 100644 internal/gui/server_test.go diff --git a/internal/gui/server.go b/internal/gui/server.go index 737d494..a541aa0 100644 --- a/internal/gui/server.go +++ b/internal/gui/server.go @@ -39,6 +39,26 @@ type enrollRequest struct { Password string `json:"password"` } +type presetSummary struct { + ID string `json:"id"` + Description string `json:"description"` + Tags []string `json:"tags"` + AcceptRoutes bool `json:"acceptRoutes"` + AllowExitNodeSelection bool `json:"allowExitNodeSelection"` + ApprovedExitNodes []string `json:"approvedExitNodes"` +} + +var validModes = map[string]bool{ + string(model.LeaseModeSession): true, + string(model.LeaseModeTimed): true, + string(model.LeaseModePermanent): true, +} + +var validChannels = map[string]bool{ + string(model.ChannelStable): true, + string(model.ChannelLatest): true, +} + func Run(ctx context.Context, srv *Server, openBrowser bool, host string, port int) error { host = strings.TrimSpace(host) if host == "" { @@ -83,12 +103,27 @@ func Run(ctx context.Context, srv *Server, openBrowser bool, host string, port i } func (s *Server) presets(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } cfg, err := config.Load(s.ConfigPath) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - writeJSON(w, map[string]any{"defaultPreset": cfg.DefaultPreset, "presets": cfg.Presets}) + summaries := make([]presetSummary, len(cfg.Presets)) + for i, p := range cfg.Presets { + summaries[i] = presetSummary{ + ID: p.ID, + Description: p.Description, + Tags: p.Tags, + AcceptRoutes: p.AcceptRoutes, + AllowExitNodeSelection: p.AllowExitNodeSelection, + ApprovedExitNodes: p.ApprovedExitNodes, + } + } + writeJSON(w, map[string]any{"defaultPreset": cfg.DefaultPreset, "presets": summaries}) } func (s *Server) enroll(w http.ResponseWriter, r *http.Request) { @@ -101,6 +136,22 @@ func (s *Server) enroll(w http.ResponseWriter, r *http.Request) { http.Error(w, "invalid json body", http.StatusBadRequest) return } + if req.Mode != "" && !validModes[req.Mode] { + writeJSONCode(w, http.StatusBadRequest, map[string]any{"ok": false, "error": fmt.Sprintf("invalid mode %q: must be session, timed, or permanent", req.Mode)}) + return + } + if req.Channel != "" && !validChannels[req.Channel] { + writeJSONCode(w, http.StatusBadRequest, map[string]any{"ok": false, "error": fmt.Sprintf("invalid channel %q: must be stable or latest", req.Channel)}) + return + } + if req.Days < 0 { + writeJSONCode(w, http.StatusBadRequest, map[string]any{"ok": false, "error": "days must be non-negative"}) + return + } + if req.CustomDays < 0 { + writeJSONCode(w, http.StatusBadRequest, map[string]any{"ok": false, "error": "customDays must be non-negative"}) + return + } password := strings.TrimSpace(req.Password) if password == "" { password = strings.TrimSpace(os.Getenv("TAILSTICK_OPERATOR_PASSWORD")) diff --git a/internal/gui/server_test.go b/internal/gui/server_test.go new file mode 100644 index 0000000..8821403 --- /dev/null +++ b/internal/gui/server_test.go @@ -0,0 +1,132 @@ +package gui + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/tailstick/tailstick/internal/model" +) + +func TestPresetsRedactsSecretsAndOnlyAllowsGet(t *testing.T) { + root := t.TempDir() + configPath := filepath.Join(root, "tailstick.config.json") + configBody := `{ + "defaultPreset": "ops", + "presets": [ + { + "id": "ops", + "description": "Operations", + "authKey": "tskey-auth-secret", + "authKeyEnv": "TAILSTICK_AUTH_KEY", + "ephemeralAuthKey": "tskey-ephemeral-secret", + "ephemeralAuthKeyEnv": "TAILSTICK_EPHEMERAL_AUTH_KEY", + "tags": ["tag:ops"], + "acceptRoutes": true, + "allowExitNodeSelection": true, + "approvedExitNodes": ["100.64.0.1"], + "cleanup": { + "apiKey": "tskey-api-secret", + "apiKeyEnv": "TAILSTICK_API_KEY", + "deviceDeleteEnabled": true + } + } + ] +}` + if err := os.WriteFile(configPath, []byte(configBody), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + srv := &Server{ConfigPath: configPath} + + req := httptest.NewRequest(http.MethodGet, "/api/presets", nil) + rec := httptest.NewRecorder() + srv.presets(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("got status %d want 200", rec.Code) + } + body := rec.Body.String() + for _, forbidden := range []string{ + "authKey", + "authKeyEnv", + "ephemeralAuthKey", + "ephemeralAuthKeyEnv", + "apiKey", + "apiKeyEnv", + "tskey-auth-secret", + "tskey-api-secret", + } { + if strings.Contains(body, forbidden) { + t.Fatalf("response leaked %q: %s", forbidden, body) + } + } + if !strings.Contains(body, `"id":"ops"`) { + t.Fatalf("expected preset id in response, got %s", body) + } + + req = httptest.NewRequest(http.MethodPost, "/api/presets", nil) + rec = httptest.NewRecorder() + srv.presets(rec, req) + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("got status %d want 405", rec.Code) + } +} + +func TestEnrollRejectsInvalidModeAndNegativeDurations(t *testing.T) { + srv := &Server{ + EnrollFn: func(context.Context, model.RuntimeOptions) (model.LeaseRecord, error) { + t.Fatal("enroll should not be called for invalid input") + return model.LeaseRecord{}, nil + }, + } + + for _, tc := range []struct { + name string + body string + want string + }{ + { + name: "invalid mode", + body: `{"mode":"bogus","channel":"stable"}`, + want: `invalid mode "bogus"`, + }, + { + name: "invalid channel", + body: `{"mode":"timed","channel":"bogus"}`, + want: `invalid channel "bogus"`, + }, + { + name: "negative days", + body: `{"mode":"timed","channel":"stable","days":-1}`, + want: `days must be non-negative`, + }, + { + name: "negative custom days", + body: `{"mode":"timed","channel":"stable","customDays":-1}`, + want: `customDays must be non-negative`, + }, + } { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/api/enroll", bytes.NewBufferString(tc.body)) + rec := httptest.NewRecorder() + srv.enroll(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("got status %d want 400", rec.Code) + } + var payload map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if got := payload["error"]; got == nil || !strings.Contains(got.(string), tc.want) { + t.Fatalf("got error %v want substring %q", got, tc.want) + } + }) + } +} diff --git a/internal/tailscale/client.go b/internal/tailscale/client.go index 1ea245d..5e2cb59 100644 --- a/internal/tailscale/client.go +++ b/internal/tailscale/client.go @@ -19,7 +19,7 @@ type Client struct { Runner platform.Runner } -var deleteDeviceHTTPClient = http.DefaultClient +var defaultDeleteDeviceHTTPClient = &http.Client{Timeout: 15 * time.Second} func (c Client) IsInstalled(ctx context.Context) bool { _, err := c.Runner.Run(ctx, []string{"tailscale", "version"}) @@ -61,9 +61,15 @@ func (c Client) Up(ctx context.Context, preset model.Preset, deviceName string, return fmt.Errorf("missing auth key") } + authArg, cleanupAuthKeyFile, err := authKeyArg(auth) + if err != nil { + return err + } + defer cleanupAuthKeyFile() + args := []string{ "tailscale", "up", - "--auth-key=" + auth, + authArg, "--hostname=" + deviceName, "--reset", } @@ -76,7 +82,7 @@ func (c Client) Up(ctx context.Context, preset model.Preset, deviceName string, if exitNode != "" { args = append(args, "--exit-node="+exitNode) } - _, err := c.Runner.Run(ctx, args) + _, err = c.Runner.Run(ctx, args) return err } @@ -132,6 +138,10 @@ func (c Client) Uninstall(ctx context.Context, preset model.Preset) error { } func DeleteDevice(ctx context.Context, apiKey, deviceID string) error { + return deleteDevice(ctx, defaultDeleteDeviceHTTPClient, apiKey, deviceID) +} + +func deleteDevice(ctx context.Context, client *http.Client, apiKey, deviceID string) error { if strings.TrimSpace(apiKey) == "" || strings.TrimSpace(deviceID) == "" { return nil } @@ -140,7 +150,7 @@ func DeleteDevice(ctx context.Context, apiKey, deviceID string) error { return err } req.SetBasicAuth(apiKey, "") - resp, err := deleteDeviceHTTPClient.Do(req) + resp, err := client.Do(req) if err != nil { return err } @@ -156,6 +166,32 @@ func DeleteDevice(ctx context.Context, apiKey, deviceID string) error { return fmt.Errorf("delete device failed: status=%d body=%s", resp.StatusCode, bodyText) } +func authKeyArg(auth string) (string, func(), error) { + f, err := os.CreateTemp("", "tailstick-auth-key-*") + if err != nil { + return "", func() {}, fmt.Errorf("create auth key temp file: %w", err) + } + path := f.Name() + cleanup := func() { + _ = os.Remove(path) + } + if _, err := f.WriteString(auth); err != nil { + cleanup() + _ = f.Close() + return "", func() {}, fmt.Errorf("write auth key temp file: %w", err) + } + if err := f.Chmod(0o600); err != nil && runtime.GOOS != "windows" { + cleanup() + _ = f.Close() + return "", func() {}, fmt.Errorf("chmod auth key temp file: %w", err) + } + if err := f.Close(); err != nil { + cleanup() + return "", func() {}, fmt.Errorf("close auth key temp file: %w", err) + } + return "--auth-key=file:" + path, cleanup, nil +} + func installCommand(preset model.Preset, channel model.Channel) []string { if runtime.GOOS == "windows" { if channel == model.ChannelLatest && len(preset.Install.WindowsLatest) > 0 { diff --git a/internal/tailscale/client_test.go b/internal/tailscale/client_test.go index 2c4b790..8e60522 100644 --- a/internal/tailscale/client_test.go +++ b/internal/tailscale/client_test.go @@ -4,13 +4,17 @@ import ( "context" "io" "net/http" + "os" + "path/filepath" "strings" "testing" + + "github.com/tailstick/tailstick/internal/model" + "github.com/tailstick/tailstick/internal/platform" ) func TestDeleteDeviceTreatsNotFoundAsAlreadyDeleted(t *testing.T) { - originalClient := deleteDeviceHTTPClient - deleteDeviceHTTPClient = &http.Client{ + client := &http.Client{ Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { if req.Method != http.MethodDelete { t.Fatalf("got method %s want DELETE", req.Method) @@ -29,18 +33,14 @@ func TestDeleteDeviceTreatsNotFoundAsAlreadyDeleted(t *testing.T) { }, nil }), } - t.Cleanup(func() { - deleteDeviceHTTPClient = originalClient - }) - if err := DeleteDevice(context.Background(), "tskey-api-example", "device-123"); err != nil { + if err := deleteDevice(context.Background(), client, "tskey-api-example", "device-123"); err != nil { t.Fatalf("expected 404 delete to be treated as success, got %v", err) } } func TestDeleteDeviceReturnsErrorForOtherFailures(t *testing.T) { - originalClient := deleteDeviceHTTPClient - deleteDeviceHTTPClient = &http.Client{ + client := &http.Client{ Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusForbidden, @@ -49,11 +49,8 @@ func TestDeleteDeviceReturnsErrorForOtherFailures(t *testing.T) { }, nil }), } - t.Cleanup(func() { - deleteDeviceHTTPClient = originalClient - }) - err := DeleteDevice(context.Background(), "tskey-api-example", "device-123") + err := deleteDevice(context.Background(), client, "tskey-api-example", "device-123") if err == nil { t.Fatal("expected delete error") } @@ -62,6 +59,74 @@ func TestDeleteDeviceReturnsErrorForOtherFailures(t *testing.T) { } } +func TestDefaultDeleteDeviceClientHasTimeout(t *testing.T) { + if defaultDeleteDeviceHTTPClient.Timeout <= 0 { + t.Fatalf("expected default delete client timeout, got %s", defaultDeleteDeviceHTTPClient.Timeout) + } +} + +func TestClientUpUsesAuthKeyFileAndRemovesIt(t *testing.T) { + if testing.Short() { + t.Skip("skipping filesystem runner test in short mode") + } + + root := t.TempDir() + logPath := filepath.Join(root, "tailscale.log") + scriptPath := filepath.Join(root, "tailscale") + script := `#!/bin/sh +set -eu +for arg in "$@"; do + printf '%s\n' "$arg" >> "` + logPath + `" + case "$arg" in + --auth-key=file:*) + key_path=${arg#--auth-key=file:} + printf 'key=%s\n' "$(cat "$key_path")" >> "` + logPath + `" + ;; + esac +done +` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write fake tailscale: %v", err) + } + t.Setenv("PATH", root+string(os.PathListSeparator)+os.Getenv("PATH")) + + client := Client{Runner: platform.Runner{}} + preset := model.Preset{AuthKey: "tskey-auth-secret"} + + if err := client.Up(context.Background(), preset, "device-name", model.LeaseModeTimed, ""); err != nil { + t.Fatalf("up: %v", err) + } + + bodyBytes, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("read log: %v", err) + } + body := string(bodyBytes) + if strings.Contains(body, "--auth-key=tskey-auth-secret") { + t.Fatalf("raw auth key leaked into argv log: %q", body) + } + if !strings.Contains(body, "--auth-key=file:") { + t.Fatalf("expected file-based auth key flag, got %q", body) + } + if !strings.Contains(body, "key=tskey-auth-secret") { + t.Fatalf("expected helper script to read auth key file, got %q", body) + } + + var keyPath string + for _, line := range strings.Split(strings.TrimSpace(body), "\n") { + if strings.HasPrefix(line, "--auth-key=file:") { + keyPath = strings.TrimPrefix(line, "--auth-key=file:") + break + } + } + if keyPath == "" { + t.Fatal("failed to capture auth key temp path") + } + if _, err := os.Stat(keyPath); !os.IsNotExist(err) { + t.Fatalf("expected auth key temp file to be removed, stat err=%v", err) + } +} + type roundTripFunc func(*http.Request) (*http.Response, error) func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {