diff --git a/internal/dashboardsvc/server_test.go b/internal/dashboardsvc/server_test.go index 830f3f5..66af29c 100644 --- a/internal/dashboardsvc/server_test.go +++ b/internal/dashboardsvc/server_test.go @@ -74,9 +74,13 @@ func grpcAuthCtx(t *testing.T, teamID, userID uuid.UUID) context.Context { return metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorization", "Bearer "+tok)) } +// resourceSelectColumns mirrors models.resourceColumns. Keep in sync — +// out-of-date column list here surfaces as "sql: expected N destination +// arguments in Scan, not M" at runtime. func resourceSelectColumns() *sqlmock.Rows { return sqlmock.NewRows([]string{ "id", "team_id", "token", "resource_type", "name", "connection_url", "key_prefix", "tier", + "env", "fingerprint", "cloud_vendor", "country_code", "status", "migration_status", "expires_at", "storage_bytes", "provider_resource_id", "created_request_id", "created_at", }) @@ -232,6 +236,7 @@ func TestDeleteResource_Success(t *testing.T) { rows := resourceSelectColumns().AddRow( resID, teamID, tok, "webhook", nil, nil, nil, "hobby", + "production", // env (added by migration 009; column at position 9 in resourceSelectColumns) nil, nil, nil, "active", nil, nil, int64(0), nil, nil, created, ) @@ -294,6 +299,7 @@ func TestRotateCredentials_Success(t *testing.T) { rows := resourceSelectColumns().AddRow( resID, teamID, tok, "queue", nil, enc, nil, "hobby", + "production", // env nil, nil, nil, "active", nil, nil, int64(0), nil, nil, created, ) @@ -335,6 +341,7 @@ func TestRotateCredentials_NoConnectionURL(t *testing.T) { rows := resourceSelectColumns().AddRow( resID, teamID, tok, "queue", nil, nil, nil, "hobby", + "production", // env nil, nil, nil, "active", nil, nil, int64(0), nil, nil, created, ) diff --git a/internal/handlers/billing_test.go b/internal/handlers/billing_test.go index b9928b5..df3eee6 100644 --- a/internal/handlers/billing_test.go +++ b/internal/handlers/billing_test.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -36,6 +37,9 @@ func billingTestApp(t *testing.T) *fiber.App { app := fiber.New(fiber.Config{ ErrorHandler: func(c *fiber.Ctx, err error) error { + if errors.Is(err, handlers.ErrResponseWritten) { + return nil + } code := fiber.StatusInternalServerError if e, ok := err.(*fiber.Error); ok { code = e.Code diff --git a/internal/handlers/deploy_env_vars_test.go b/internal/handlers/deploy_env_vars_test.go index fd20ad0..fe9f129 100644 --- a/internal/handlers/deploy_env_vars_test.go +++ b/internal/handlers/deploy_env_vars_test.go @@ -67,10 +67,15 @@ func TestDeployNew_EnvVarsJSON_Parsed_Into_InitEnv(t *testing.T) { require.NoError(t, err) defer resp.Body.Close() + // Read the body once — require.NotEqual's message arg is evaluated + // unconditionally, so passing readBody(t, resp) there would consume the + // body before the success-path Decode can read it. + bodyBytes, _ := io.ReadAll(resp.Body) + // 202 (noop compute provider succeeds), 503 (service disabled) — both prove the // parse path executed without 400. A 400 here is the regression we guard. require.NotEqual(t, http.StatusBadRequest, resp.StatusCode, - "valid env_vars JSON must NOT return 400; got body: %s", readBody(t, resp)) + "valid env_vars JSON must NOT return 400; got body: %s", string(bodyBytes)) if resp.StatusCode == http.StatusAccepted { var created struct { @@ -79,7 +84,7 @@ func TestDeployNew_EnvVarsJSON_Parsed_Into_InitEnv(t *testing.T) { Env map[string]string `json:"env"` } `json:"item"` } - require.NoError(t, json.NewDecoder(resp.Body).Decode(&created)) + require.NoError(t, json.Unmarshal(bodyBytes, &created)) assert.Contains(t, created.Item.Env, "DATABASE_URL", "env_vars key must land in the deployment's env") assert.Equal(t, "postgres://x/y", created.Item.Env["DATABASE_URL"]) assert.Contains(t, created.Item.Env, "CUSTOM") diff --git a/internal/handlers/helpers.go b/internal/handlers/helpers.go index 5ec179b..c7ce618 100644 --- a/internal/handlers/helpers.go +++ b/internal/handlers/helpers.go @@ -1,12 +1,41 @@ package handlers -import "github.com/gofiber/fiber/v2" +import ( + "errors" -// respondError returns a structured JSON error response. + "github.com/gofiber/fiber/v2" +) + +// ErrResponseWritten is the sentinel respondError returns to signal "I +// already wrote the response body — propagate me up but DO NOT let Fiber's +// generic ErrorHandler overwrite the response." +// +// Callers that do `return ..., respondError(...)` from a helper get a +// non-nil error and short-circuit correctly even when the underlying +// c.Status().JSON() returned nil (the normal success case for body write). +// +// The router and test ErrorHandlers both detect this sentinel and return +// nil without writing — preserving the 400/403/etc. response respondError +// already committed. See router/router.go and testhelpers/testhelpers.go. +var ErrResponseWritten = errors.New("response already written by respondError") + +// respondError writes a structured JSON error and returns ErrResponseWritten. +// +// Always returns a non-nil error so multi-return helpers compose safely: +// +// teamID, err := h.requireTeamMatch(c) +// if err != nil { return err } +// +// The caller's `if err != nil` branch fires correctly even when the +// underlying response-write succeeded. Before this change, respondError +// returned c.Status().JSON()'s result (nil on success), so the caller's +// check was false and execution continued past the validation gate — +// producing 500s and silent provisioning of invalid input. func respondError(c *fiber.Ctx, status int, code, message string) error { - return c.Status(status).JSON(fiber.Map{ + _ = c.Status(status).JSON(fiber.Map{ "ok": false, "error": code, "message": message, }) + return ErrResponseWritten } diff --git a/internal/handlers/provision_helper.go b/internal/handlers/provision_helper.go index ba93dc6..e580979 100644 --- a/internal/handlers/provision_helper.go +++ b/internal/handlers/provision_helper.go @@ -246,6 +246,16 @@ func sanitizeName(name string) string { // // Empty input is treated as "production" — this preserves backwards compatibility // for every caller that pre-dates the env feature. +// resolveEnv validates the env scope from the URL query (preferred) or +// request body (fallback). On success returns (env, nil). On failure it +// writes the 400 response via respondError and returns (\"\", ErrResponseWritten). +// Callers use the standard pattern: +// +// env, err := resolveEnv(c, body.Env) +// if err != nil { return err } +// +// The ErrResponseWritten sentinel propagates up; the ErrorHandler +// recognises it and does not overwrite the response. func resolveEnv(c *fiber.Ctx, bodyEnv string) (string, error) { raw := c.Query("env") if raw == "" { diff --git a/internal/handlers/stack_test.go b/internal/handlers/stack_test.go index d345d86..ef1e5fe 100644 --- a/internal/handlers/stack_test.go +++ b/internal/handlers/stack_test.go @@ -17,6 +17,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "io" "mime/multipart" "net/http" @@ -105,6 +106,9 @@ func newStackTestApp(t *testing.T, db *sql.DB) *fiber.App { app := fiber.New(fiber.Config{ ErrorHandler: func(c *fiber.Ctx, err error) error { + if errors.Is(err, handlers.ErrResponseWritten) { + return nil + } code := fiber.StatusInternalServerError if e, ok := err.(*fiber.Error); ok { code = e.Code @@ -455,7 +459,7 @@ func TestStackNew_Anonymous_Returns202(t *testing.T) { assert.NotEmpty(t, body.StackID) assert.Equal(t, "anonymous", body.Tier) assert.Equal(t, "24h", body.ExpiresIn) - assert.Contains(t, body.Note, "instant.dev/start", "upgrade URL must appear in note") + assert.Contains(t, body.Note, "instanode.dev/start", "upgrade URL must appear in note") // Verify DB: stack has nil team_id and non-nil expires_at. var teamIDNull sql.NullString diff --git a/internal/handlers/teams_test.go b/internal/handlers/teams_test.go index e7805e6..4791c85 100644 --- a/internal/handlers/teams_test.go +++ b/internal/handlers/teams_test.go @@ -5,6 +5,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -44,7 +45,20 @@ func teamsApp(t *testing.T, db *sql.DB, actorUserID, actorTeamID, actorRole stri } mail := email.New("") // noop client — never actually sends - app := fiber.New() + app := fiber.New(fiber.Config{ + // respondError already wrote the body — short-circuit so the + // generic ErrorHandler does not overwrite 4xx with 500. + ErrorHandler: func(c *fiber.Ctx, err error) error { + if errors.Is(err, handlers.ErrResponseWritten) { + return nil + } + code := fiber.StatusInternalServerError + if e, ok := err.(*fiber.Error); ok { + code = e.Code + } + return c.Status(code).JSON(fiber.Map{"ok": false, "error": "internal_error", "message": err.Error()}) + }, + }) // Fake auth: inject user/team/role into Locals so RequireRole can decide. fakeAuth := func(c *fiber.Ctx) error { diff --git a/internal/handlers/vault.go b/internal/handlers/vault.go index 9a8db21..e3de73e 100644 --- a/internal/handlers/vault.go +++ b/internal/handlers/vault.go @@ -239,11 +239,34 @@ func (h *VaultHandler) upsertSecret(c *fiber.Ctx, action string) error { "error", terr, "team_id", teamID, "request_id", middleware.GetRequestID(c)) } else if team != nil { - // On rotate, we already require an existing key (rotate of a missing - // key is rejected by upsertSecret semantics). For PUT/set we must - // allow updating an existing key without burning a quota slot. + // Tier checks run in this order (most-restrictive first) so the + // reported error tells the caller what to upgrade: + // 1. env allowlist (403 vault_env_not_allowed) + // 2. quota cap (402 vault_quota_exceeded) + // 3. availability (403 vault_not_available) — handled inside quota // - // Tier check 1: vault availability + quota (skip on rotate — count + // Pre-fix the env check ran second; a hobby-tier caller at quota + // who PUT to staging got 402 quota_exceeded instead of 403 + // env_not_allowed — misleading, since adding seats wouldn't help. + + // Tier check 1: env restriction (applies to both PUT and rotate). + allowed := h.plans.VaultEnvsAllowed(team.PlanTier) + if len(allowed) > 0 { + envOK := false + for _, a := range allowed { + if a == env { + envOK = true + break + } + } + if !envOK { + return respondError(c, fiber.StatusForbidden, vaultErrEnvNotAllowed, + fmt.Sprintf("Plan %q only allows vault env %v; got %q. Upgrade to Pro for multi-env vault.", + team.PlanTier, allowed, env)) + } + } + + // Tier check 2: vault availability + quota (skip on rotate — count // can only stay flat or shrink). if action != "rotate" { maxEntries := h.plans.VaultMaxEntries(team.PlanTier) @@ -269,23 +292,6 @@ func (h *VaultHandler) upsertSecret(c *fiber.Ctx, action string) error { } } } - - // Tier check 2: env restriction (applies to both PUT and rotate). - allowed := h.plans.VaultEnvsAllowed(team.PlanTier) - if len(allowed) > 0 { - envOK := false - for _, a := range allowed { - if a == env { - envOK = true - break - } - } - if !envOK { - return respondError(c, fiber.StatusForbidden, vaultErrEnvNotAllowed, - fmt.Sprintf("Plan %q only allows vault env %v; got %q. Upgrade to Pro for multi-env vault.", - team.PlanTier, allowed, env)) - } - } } } diff --git a/internal/handlers/vault_test.go b/internal/handlers/vault_test.go index 972b863..0b9a94a 100644 --- a/internal/handlers/vault_test.go +++ b/internal/handlers/vault_test.go @@ -18,6 +18,7 @@ import ( "database/sql" "encoding/base64" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -112,6 +113,9 @@ func vaultTestApp(t *testing.T, db *sql.DB) *fiber.App { } app := fiber.New(fiber.Config{ ErrorHandler: func(c *fiber.Ctx, err error) error { + if errors.Is(err, handlers.ErrResponseWritten) { + return nil + } code := fiber.StatusInternalServerError if e, ok := err.(*fiber.Error); ok { code = e.Code @@ -471,7 +475,11 @@ func TestVault_Validation(t *testing.T) { want int }{ // Path params can't be empty in fiber routes; use illegal characters instead. - {"bad-key-with-slash", "/api/v1/vault/production/foo bar", http.StatusBadRequest}, + // Pre-encode the space (%20) so httptest.NewRequest accepts the URL — + // Go 1.26+ panics on unescaped spaces (older Go silently encoded them). + // The fiber handler decodes back to "foo bar" and the validator rejects + // the space — exactly what we want to assert. + {"bad-key-with-space", "/api/v1/vault/production/foo%20bar", http.StatusBadRequest}, {"bad-key-too-long", "/api/v1/vault/production/" + longString(300), http.StatusBadRequest}, {"bad-env-with-special", "/api/v1/vault/prod!ction/X", http.StatusBadRequest}, } diff --git a/internal/router/router.go b/internal/router/router.go index 5263f4b..45c2178 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -2,6 +2,7 @@ package router import ( "database/sql" + "errors" "log/slog" "github.com/gofiber/contrib/otelfiber/v2" @@ -27,6 +28,12 @@ func New(cfg *config.Config, db *sql.DB, rdb *redis.Client, geoDbs *middleware.G app := fiber.New(fiber.Config{ // Disable default error handler — we write our own JSON errors ErrorHandler: func(c *fiber.Ctx, err error) error { + // respondError already wrote the body — must not overwrite, or + // every 400/403/etc. becomes a 500 "internal_error" via the + // generic path below. See handlers.ErrResponseWritten. + if errors.Is(err, handlers.ErrResponseWritten) { + return nil + } code := fiber.StatusInternalServerError if e, ok := err.(*fiber.Error); ok { code = e.Code diff --git a/internal/testhelpers/testhelpers.go b/internal/testhelpers/testhelpers.go index ac2387b..d66944a 100644 --- a/internal/testhelpers/testhelpers.go +++ b/internal/testhelpers/testhelpers.go @@ -7,6 +7,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -247,6 +248,12 @@ func NewTestAppWithServices(t *testing.T, db *sql.DB, rdb *redis.Client, service app := fiber.New(fiber.Config{ ErrorHandler: func(c *fiber.Ctx, err error) error { + // respondError already wrote the body — short-circuit so we + // don't overwrite. Matches the production ErrorHandler in + // router/router.go. + if errors.Is(err, handlers.ErrResponseWritten) { + return nil + } code := fiber.StatusInternalServerError if e, ok := err.(*fiber.Error); ok { code = e.Code