Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion internal/gui/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ type enrollRequest struct {
Password string `json:"password"`
}

type presetSummary struct {
ID string `json:"id"`
Description string `json:"description"`
Tags []string `json:"tags"`
AcceptRoutes bool `json:"acceptRoutes"`
AllowExitNodeSelection bool `json:"allowExitNodeSelection"`
ApprovedExitNodes []string `json:"approvedExitNodes"`
}

var validModes = map[string]bool{
string(model.LeaseModeSession): true,
string(model.LeaseModeTimed): true,
string(model.LeaseModePermanent): true,
}

var validChannels = map[string]bool{
string(model.ChannelStable): true,
string(model.ChannelLatest): true,
}

func Run(ctx context.Context, srv *Server, openBrowser bool, host string, port int) error {
host = strings.TrimSpace(host)
if host == "" {
Expand Down Expand Up @@ -83,12 +103,27 @@ func Run(ctx context.Context, srv *Server, openBrowser bool, host string, port i
}

func (s *Server) presets(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
cfg, err := config.Load(s.ConfigPath)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]any{"defaultPreset": cfg.DefaultPreset, "presets": cfg.Presets})
summaries := make([]presetSummary, len(cfg.Presets))
for i, p := range cfg.Presets {
summaries[i] = presetSummary{
ID: p.ID,
Description: p.Description,
Tags: p.Tags,
AcceptRoutes: p.AcceptRoutes,
AllowExitNodeSelection: p.AllowExitNodeSelection,
ApprovedExitNodes: p.ApprovedExitNodes,
}
}
writeJSON(w, map[string]any{"defaultPreset": cfg.DefaultPreset, "presets": summaries})
}

func (s *Server) enroll(w http.ResponseWriter, r *http.Request) {
Expand All @@ -101,6 +136,22 @@ func (s *Server) enroll(w http.ResponseWriter, r *http.Request) {
http.Error(w, "invalid json body", http.StatusBadRequest)
return
}
if req.Mode != "" && !validModes[req.Mode] {
writeJSONCode(w, http.StatusBadRequest, map[string]any{"ok": false, "error": fmt.Sprintf("invalid mode %q: must be session, timed, or permanent", req.Mode)})
return
}
if req.Channel != "" && !validChannels[req.Channel] {
writeJSONCode(w, http.StatusBadRequest, map[string]any{"ok": false, "error": fmt.Sprintf("invalid channel %q: must be stable or latest", req.Channel)})
return
}
if req.Days < 0 {
writeJSONCode(w, http.StatusBadRequest, map[string]any{"ok": false, "error": "days must be non-negative"})
return
}
if req.CustomDays < 0 {
writeJSONCode(w, http.StatusBadRequest, map[string]any{"ok": false, "error": "customDays must be non-negative"})
return
}
password := strings.TrimSpace(req.Password)
if password == "" {
password = strings.TrimSpace(os.Getenv("TAILSTICK_OPERATOR_PASSWORD"))
Expand Down
132 changes: 132 additions & 0 deletions internal/gui/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package gui

import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"

"github.com/tailstick/tailstick/internal/model"
)

func TestPresetsRedactsSecretsAndOnlyAllowsGet(t *testing.T) {
root := t.TempDir()
configPath := filepath.Join(root, "tailstick.config.json")
configBody := `{
"defaultPreset": "ops",
"presets": [
{
"id": "ops",
"description": "Operations",
"authKey": "tskey-auth-secret",
"authKeyEnv": "TAILSTICK_AUTH_KEY",
"ephemeralAuthKey": "tskey-ephemeral-secret",
"ephemeralAuthKeyEnv": "TAILSTICK_EPHEMERAL_AUTH_KEY",
"tags": ["tag:ops"],
"acceptRoutes": true,
"allowExitNodeSelection": true,
"approvedExitNodes": ["100.64.0.1"],
"cleanup": {
"apiKey": "tskey-api-secret",
"apiKeyEnv": "TAILSTICK_API_KEY",
"deviceDeleteEnabled": true
}
}
]
}`
if err := os.WriteFile(configPath, []byte(configBody), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}

srv := &Server{ConfigPath: configPath}

req := httptest.NewRequest(http.MethodGet, "/api/presets", nil)
rec := httptest.NewRecorder()
srv.presets(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("got status %d want 200", rec.Code)
}
body := rec.Body.String()
for _, forbidden := range []string{
"authKey",
"authKeyEnv",
"ephemeralAuthKey",
"ephemeralAuthKeyEnv",
"apiKey",
"apiKeyEnv",
"tskey-auth-secret",
"tskey-api-secret",
} {
if strings.Contains(body, forbidden) {
t.Fatalf("response leaked %q: %s", forbidden, body)
}
}
if !strings.Contains(body, `"id":"ops"`) {
t.Fatalf("expected preset id in response, got %s", body)
}

req = httptest.NewRequest(http.MethodPost, "/api/presets", nil)
rec = httptest.NewRecorder()
srv.presets(rec, req)
if rec.Code != http.StatusMethodNotAllowed {
t.Fatalf("got status %d want 405", rec.Code)
}
}

func TestEnrollRejectsInvalidModeAndNegativeDurations(t *testing.T) {
srv := &Server{
EnrollFn: func(context.Context, model.RuntimeOptions) (model.LeaseRecord, error) {
t.Fatal("enroll should not be called for invalid input")
return model.LeaseRecord{}, nil
},
}

for _, tc := range []struct {
name string
body string
want string
}{
{
name: "invalid mode",
body: `{"mode":"bogus","channel":"stable"}`,
want: `invalid mode "bogus"`,
},
{
name: "invalid channel",
body: `{"mode":"timed","channel":"bogus"}`,
want: `invalid channel "bogus"`,
},
{
name: "negative days",
body: `{"mode":"timed","channel":"stable","days":-1}`,
want: `days must be non-negative`,
},
{
name: "negative custom days",
body: `{"mode":"timed","channel":"stable","customDays":-1}`,
want: `customDays must be non-negative`,
},
} {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/api/enroll", bytes.NewBufferString(tc.body))
rec := httptest.NewRecorder()
srv.enroll(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("got status %d want 400", rec.Code)
}
var payload map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if got := payload["error"]; got == nil || !strings.Contains(got.(string), tc.want) {
t.Fatalf("got error %v want substring %q", got, tc.want)
}
})
}
}
44 changes: 40 additions & 4 deletions internal/tailscale/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Client struct {
Runner platform.Runner
}

var deleteDeviceHTTPClient = http.DefaultClient
var defaultDeleteDeviceHTTPClient = &http.Client{Timeout: 15 * time.Second}

func (c Client) IsInstalled(ctx context.Context) bool {
_, err := c.Runner.Run(ctx, []string{"tailscale", "version"})
Expand Down Expand Up @@ -61,9 +61,15 @@ func (c Client) Up(ctx context.Context, preset model.Preset, deviceName string,
return fmt.Errorf("missing auth key")
}

authArg, cleanupAuthKeyFile, err := authKeyArg(auth)
if err != nil {
return err
}
defer cleanupAuthKeyFile()

args := []string{
"tailscale", "up",
"--auth-key=" + auth,
authArg,
"--hostname=" + deviceName,
"--reset",
}
Expand All @@ -76,7 +82,7 @@ func (c Client) Up(ctx context.Context, preset model.Preset, deviceName string,
if exitNode != "" {
args = append(args, "--exit-node="+exitNode)
}
_, err := c.Runner.Run(ctx, args)
_, err = c.Runner.Run(ctx, args)
return err
}

Expand Down Expand Up @@ -132,6 +138,10 @@ func (c Client) Uninstall(ctx context.Context, preset model.Preset) error {
}

func DeleteDevice(ctx context.Context, apiKey, deviceID string) error {
return deleteDevice(ctx, defaultDeleteDeviceHTTPClient, apiKey, deviceID)
}

func deleteDevice(ctx context.Context, client *http.Client, apiKey, deviceID string) error {
if strings.TrimSpace(apiKey) == "" || strings.TrimSpace(deviceID) == "" {
return nil
}
Expand All @@ -140,7 +150,7 @@ func DeleteDevice(ctx context.Context, apiKey, deviceID string) error {
return err
}
req.SetBasicAuth(apiKey, "")
resp, err := deleteDeviceHTTPClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return err
}
Expand All @@ -156,6 +166,32 @@ func DeleteDevice(ctx context.Context, apiKey, deviceID string) error {
return fmt.Errorf("delete device failed: status=%d body=%s", resp.StatusCode, bodyText)
}

func authKeyArg(auth string) (string, func(), error) {
f, err := os.CreateTemp("", "tailstick-auth-key-*")
if err != nil {
return "", func() {}, fmt.Errorf("create auth key temp file: %w", err)
}
path := f.Name()
cleanup := func() {
_ = os.Remove(path)
}
if _, err := f.WriteString(auth); err != nil {
cleanup()
_ = f.Close()
return "", func() {}, fmt.Errorf("write auth key temp file: %w", err)
}
if err := f.Chmod(0o600); err != nil && runtime.GOOS != "windows" {
cleanup()
_ = f.Close()
return "", func() {}, fmt.Errorf("chmod auth key temp file: %w", err)
}
if err := f.Close(); err != nil {
cleanup()
return "", func() {}, fmt.Errorf("close auth key temp file: %w", err)
}
return "--auth-key=file:" + path, cleanup, nil
}

func installCommand(preset model.Preset, channel model.Channel) []string {
if runtime.GOOS == "windows" {
if channel == model.ChannelLatest && len(preset.Install.WindowsLatest) > 0 {
Expand Down
Loading
Loading