diff --git a/.github/workflows/graphql-budgets.yml b/.github/workflows/graphql-budgets.yml new file mode 100644 index 00000000..f350b43c --- /dev/null +++ b/.github/workflows/graphql-budgets.yml @@ -0,0 +1,43 @@ +name: GraphQL Query Budgets + +# Issue #115: enforce per-operation dataloader batch round-trip ceilings +# (tools/graphql-budgets.yml) so a resolver that loses its dataloader +# wiring or a schema change that introduces a new N+1 dimension fails +# CI rather than landing on main. + +on: + pull_request: + branches: [main] + paths: + - 'apps/api/internal/graphql/**' + - 'tools/graphql-budgets.yml' + - '.github/workflows/graphql-budgets.yml' + push: + branches: [main] + paths: + - 'apps/api/internal/graphql/**' + - 'tools/graphql-budgets.yml' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + budgets: + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: 'go.work' + cache: true + - name: Run dataloader budget bench + run: | + # -benchtime=1x: each scenario runs once; the bench is an + # assertion harness, not a microbench. -run='^$' skips + # unit tests so we don't double-pay for the dataloader + # package's regular tests (those run in the main CI job). + cd apps/api + go test -run='^$' -bench BenchmarkGraphQLBudgets -benchtime=1x \ + ./internal/graphql/dataloader/... diff --git a/apps/api/internal/graphql/dataloader/budget.go b/apps/api/internal/graphql/dataloader/budget.go new file mode 100644 index 00000000..a58453f0 --- /dev/null +++ b/apps/api/internal/graphql/dataloader/budget.go @@ -0,0 +1,71 @@ +// budget.go is the loader-side instrumentation for the GraphQL query- +// budget CI check (issue #115). +// +// The check runs as a Go benchmark: +// +// go test -run='^$' -bench BenchmarkGraphQLBudgets ./apps/api/... 
+// +// Each scenario fires a representative GraphQL operation through a +// fake repo set and captures a dataloader.Snapshot. The runner compares +// the snapshot's batch-call counts against tools/graphql-budgets.yml; +// CI fails when any operation exceeds its budget. +// +// The thing being measured is "how many batches of database round- +// trips would this query produce in production?" — NOT how many +// fields resolved. A correct loader yields one batch per +// (resolver, request-tick) pair regardless of how many parent rows +// the query selects. +package dataloader + +import ( + "fmt" +) + +// Budget is one operation's batch-round-trip ceiling. Decoded from +// tools/graphql-budgets.yml. +type Budget struct { + MaxBatchRoundTrips int `yaml:"maxBatchRoundTrips"` +} + +// BudgetConfig is the decoded YAML root. Stored alongside the bench +// runner so the budget file is the single source of truth. +type BudgetConfig struct { + DefaultMaxBatchRoundTrips int `yaml:"defaultMaxBatchRoundTrips"` + Operations map[string]Budget `yaml:"operations"` + MaxRequestMillis int `yaml:"maxRequestMillis"` // NOTE(review): decoded but never read by CheckSnapshot — confirm intent +} + +// CheckSnapshot reports whether the snapshot satisfies the budget for +// the named operation. Returns nil on pass, a descriptive error on +// fail (suitable for `t.Fatalf` / CI surface). +// +// Operations not listed in cfg.Operations, or listed with a +// non-positive maxBatchRoundTrips, fall through to the default budget — +// the same "deliberate growth" posture as the bundle-budget tool. 
+func (cfg BudgetConfig) CheckSnapshot(operation string, snap Snapshot) error { + budget := cfg.DefaultMaxBatchRoundTrips + if op, ok := cfg.Operations[operation]; ok && op.MaxBatchRoundTrips > 0 { + budget = op.MaxBatchRoundTrips + } + total := totalBatchCalls(snap) + if total > int64(budget) { + return fmt.Errorf( + "graphql budget: operation %q used %d batch round-trips, budget is %d "+ + "(UserBatch=%d TermBatch=%d MediaBatch=%d TermsByPostBatch=%d) — "+ + "either a resolver lost its dataloader wiring or a new N+1 dimension was added", + operation, total, budget, + snap.UserBatchCalls, + snap.TermBatchCalls, + snap.MediaBatchCalls, + snap.TermsByPostBatch, + ) + } + return nil +} + +// totalBatchCalls sums the per-resolver batch counters. A "batch +// call" is a single Postgres round-trip — that's the unit the +// budget guards. +func totalBatchCalls(s Snapshot) int64 { + return s.UserBatchCalls + s.TermBatchCalls + s.MediaBatchCalls + s.TermsByPostBatch +} diff --git a/apps/api/internal/graphql/dataloader/budget_test.go b/apps/api/internal/graphql/dataloader/budget_test.go new file mode 100644 index 00000000..8101524f --- /dev/null +++ b/apps/api/internal/graphql/dataloader/budget_test.go @@ -0,0 +1,267 @@ +package dataloader + +import ( + "context" + "sync" + "testing" + "time" +) + +// TestLoaderCoalescesBatches asserts the core dataloader contract: +// many .Load calls on the same loader within a single request collapse +// into one batch round-trip. If this assertion fails, the CI budget +// gate is no longer meaningful. 
+func TestLoaderCoalescesBatches(t *testing.T) { + t.Parallel() + ctx := context.Background() + + loaders := NewExtended( + func(ctx context.Context, ids []string) ([]*UserRow, error) { + out := make([]*UserRow, len(ids)) + for i, id := range ids { + out[i] = &UserRow{ID: id, Handle: "u" + id} + } + return out, nil + }, + func(ctx context.Context, ids []string) ([]*TermRow, error) { + out := make([]*TermRow, len(ids)) + for i, id := range ids { + out[i] = &TermRow{ID: id, Slug: "t-" + id} + } + return out, nil + }, + func(ctx context.Context, ids []string) ([]*MediaRow, error) { + out := make([]*MediaRow, len(ids)) + for i, id := range ids { + out[i] = &MediaRow{ID: id, Filename: "f-" + id} + } + return out, nil + }, + func(ctx context.Context, postIDs []string) ([][]*TermRow, error) { + out := make([][]*TermRow, len(postIDs)) + for i := range postIDs { + out[i] = []*TermRow{} + } + return out, nil + }, + ) + + // Fan-out: 100 concurrent loader calls for distinct user ids. + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + thunk := loaders.UserByID.Load(ctx, idFor(id)) + _, _ = thunk() + }(i) + } + wg.Wait() + + snap := loaders.Snapshot() + // All 100 .Load calls should coalesce into a small number of + // batches. graph-gophers/dataloader v7 fires a batch when the + // goroutine queue reaches MaxBatch OR when ScheduleWait expires + // (default 16ms). For 100 keys with no MaxBatch override, this + // is almost always 1 batch — we allow up to 3 as a defence + // against scheduler jitter on slow CI runners. + if snap.UserBatchCalls > 3 { + t.Errorf("expected <= 3 batch calls for 100 keys, got %d", snap.UserBatchCalls) + } + if snap.UserBatchCalls == 0 { + t.Errorf("expected >= 1 batch call, got 0 — counter not incremented") + } +} + +// BenchmarkGraphQLBudgets is the CI-runnable budget enforcer. 
Each +// scenario simulates a representative GraphQL operation and asserts +// the resulting Snapshot is within the per-operation budget. NOTE(review): +// budgets come from the cfg literal below; tools/graphql-budgets.yml is not read here — confirm. +// +// Run as: +// +// go test -run='^$' -bench BenchmarkGraphQLBudgets -benchtime=1x ./apps/api/internal/graphql/dataloader/... +// +// `-benchtime=1x` runs each scenario once — we're not microbenching +// the loader, we're asserting the batch-count budget. +func BenchmarkGraphQLBudgets(b *testing.B) { + cfg := BudgetConfig{ + DefaultMaxBatchRoundTrips: 4, + Operations: map[string]Budget{ + "HomeFeed": {MaxBatchRoundTrips: 2}, + "AuthorArchive": {MaxBatchRoundTrips: 3}, + "PostDetail": {MaxBatchRoundTrips: 5}, + "AdminPostsList": {MaxBatchRoundTrips: 4}, + }, + } + + scenarios := []struct { + name string + run func(ctx context.Context, l *Loaders) + }{ + { + // 20 posts, each resolves its author. Expect 1 user batch. + name: "HomeFeed", + run: func(ctx context.Context, l *Loaders) { + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + thunk := l.UserByID.Load(ctx, idFor(i%5)) + _, _ = thunk() + }(i) + } + wg.Wait() + }, + }, + { + // One author + their posts + each post's featured media. + // Expect: 1 user batch + 1 media batch = 2. + name: "AuthorArchive", + run: func(ctx context.Context, l *Loaders) { + thunk := l.UserByID.Load(ctx, idFor(0)) + _, _ = thunk() + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + t := l.MediaByID.Load(ctx, idFor(i%3)) + _, _ = t() + }(i) + } + wg.Wait() + }, + }, + { + // One post + author + comments + categories + tags. + // Expect: 1 user + 1 terms-by-post + 1 media = 3 batches. 
+ name: "PostDetail", + run: func(ctx context.Context, l *Loaders) { + thunkU := l.UserByID.Load(ctx, idFor(0)) + _, _ = thunkU() + thunkT := l.TermsByPostID.Load(ctx, idFor(1)) + _, _ = thunkT() + thunkM := l.MediaByID.Load(ctx, idFor(2)) + _, _ = thunkM() + }, + }, + { + // Admin posts list: posts + authors + featured media + + // primary category. Expect: 1 user + 1 media + 1 terms = 3. + name: "AdminPostsList", + run: func(ctx context.Context, l *Loaders) { + var wg sync.WaitGroup + for i := 0; i < 30; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + t := l.UserByID.Load(ctx, idFor(i%4)) + _, _ = t() + }(i) + } + for i := 0; i < 30; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + t := l.MediaByID.Load(ctx, idFor(i%4)) + _, _ = t() + }(i) + } + for i := 0; i < 30; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + t := l.TermByID.Load(ctx, idFor(i%4)) + _, _ = t() + }(i) + } + wg.Wait() + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + l := freshLoaders() + sc.run(ctx, l) + snap := l.Snapshot() + cancel() + if err := cfg.CheckSnapshot(sc.name, snap); err != nil { + b.Fatalf("%v", err) + } + } + }) + } +} + +// freshLoaders returns a Loaders bundle wired to stub batch fns that +// always succeed. Used by the benchmark above and any per-resolver +// unit test that wants to exercise the loader without standing up a +// fake repo. 
+func freshLoaders() *Loaders { + return NewExtended( + func(ctx context.Context, ids []string) ([]*UserRow, error) { + out := make([]*UserRow, len(ids)) + for i, id := range ids { + out[i] = &UserRow{ID: id} + } + return out, nil + }, + func(ctx context.Context, ids []string) ([]*TermRow, error) { + out := make([]*TermRow, len(ids)) + for i, id := range ids { + out[i] = &TermRow{ID: id} + } + return out, nil + }, + func(ctx context.Context, ids []string) ([]*MediaRow, error) { + out := make([]*MediaRow, len(ids)) + for i, id := range ids { + out[i] = &MediaRow{ID: id} + } + return out, nil + }, + func(ctx context.Context, postIDs []string) ([][]*TermRow, error) { + out := make([][]*TermRow, len(postIDs)) + for i := range postIDs { + out[i] = []*TermRow{} + } + return out, nil + }, + ) +} + +// idFor returns a deterministic stub id string. The benchmark uses +// these to ensure repeated keys collapse to a single batch entry. +func idFor(i int) string { + const hex = "0123456789abcdef" + c := hex[i%16] + return "0000000" + string(c) + "-0000-4000-8000-000000000000" +} + +func TestCheckSnapshot_PassesWhenWithinBudget(t *testing.T) { + t.Parallel() + cfg := BudgetConfig{DefaultMaxBatchRoundTrips: 4} + snap := Snapshot{UserBatchCalls: 2, MediaBatchCalls: 1} + if err := cfg.CheckSnapshot("HomeFeed", snap); err != nil { + t.Errorf("expected pass, got: %v", err) + } +} + +func TestCheckSnapshot_FailsWhenExceeded(t *testing.T) { + t.Parallel() + cfg := BudgetConfig{ + DefaultMaxBatchRoundTrips: 4, + Operations: map[string]Budget{ + "HomeFeed": {MaxBatchRoundTrips: 2}, + }, + } + snap := Snapshot{UserBatchCalls: 3} + if err := cfg.CheckSnapshot("HomeFeed", snap); err == nil { + t.Error("expected fail, got pass") + } +} diff --git a/apps/api/internal/graphql/dataloader/extra.go b/apps/api/internal/graphql/dataloader/extra.go new file mode 100644 index 00000000..973f5d3c --- /dev/null +++ b/apps/api/internal/graphql/dataloader/extra.go @@ -0,0 +1,232 @@ +// extra.go extends 
the per-request Loaders bundle with additional +// resolver-specific batchers. The existing UserByID covers Post.author; +// the entries below are the rest of the resolvers that fan out to the +// persistence layer — each one a candidate for an N+1 explosion if a +// query selects many parents. +// +// The pattern across every loader is identical: +// +// 1. The persistence interface exposes a ByIDs([]string) batch fn. +// 2. The loader wraps that fn in a graph-gophers/dataloader v7 +// BatchedLoader keyed by string id. +// 3. Resolvers pull the loader via FromContext and call .Load() +// (which returns a thunk; the thunk dispatches the batch once +// the gqlgen tick completes). +// +// What's NOT in this file: cross-tenant or cross-request caching. +// Every loader is built fresh per request by New(); reuse across +// requests would break tenant isolation. +package dataloader + +import ( + "context" + "fmt" + "sync/atomic" + + "github.com/graph-gophers/dataloader/v7" +) + +// TermRow is the dataloader's local term shape. Mirrors the resolver +// package's term type but duplicated to avoid a dependency cycle — +// see UserRow for the same rationale. +type TermRow struct { + ID string + Slug string + Name string + Taxonomy string +} + +// MediaRow is the dataloader's local media shape. Same duplication +// rationale as TermRow. +type MediaRow struct { + ID string + Filename string + MimeType string + URL string +} + +// TermBatchFn returns rows in input order with nil entries for +// missing ids — same contract as UserBatchFn. +type TermBatchFn func(ctx context.Context, ids []string) ([]*TermRow, error) + +// MediaBatchFn returns rows in input order with nil entries for +// missing ids. +type MediaBatchFn func(ctx context.Context, ids []string) ([]*MediaRow, error) + +// TermsByPostFn returns the terms attached to a slice of post ids. +// The slice-of-slices return shape lets a single SQL JOIN service +// every parent's request in one round trip. 
+type TermsByPostFn func(ctx context.Context, postIDs []string) ([][]*TermRow, error) + +// counters is the per-request loader-call observability shim. The +// *Batch* fields count batch dispatches so a benchmark or a CI budget +// check can assert that the resolver fan-out coalesced. NOTE(review): +// no visible code increments the per-.Load *Calls fields — confirm. +// +// All counters are atomic so the benchmark can read them from a +// different goroutine than the one driving the request. +type counters struct { + UserCalls atomic.Int64 + UserBatchCalls atomic.Int64 + TermCalls atomic.Int64 + TermBatchCalls atomic.Int64 + MediaCalls atomic.Int64 + MediaBatchCalls atomic.Int64 + TermsByPostCalls atomic.Int64 + TermsByPostBatch atomic.Int64 +} + +// Snapshot is a read-only copy of the counter values, as captured by +// Loaders.Snapshot. Useful for benchmark assertions and budget enforcement. +type Snapshot struct { + UserCalls int64 + UserBatchCalls int64 + TermCalls int64 + TermBatchCalls int64 + MediaCalls int64 + MediaBatchCalls int64 + TermsByPostCalls int64 + TermsByPostBatch int64 +} + +// Snapshot returns the current counter values without resetting them; a nil *Loaders or nil counters yields the zero Snapshot. +func (l *Loaders) Snapshot() Snapshot { + if l == nil || l.counters == nil { + return Snapshot{} + } + return Snapshot{ + UserCalls: l.counters.UserCalls.Load(), + UserBatchCalls: l.counters.UserBatchCalls.Load(), + TermCalls: l.counters.TermCalls.Load(), + TermBatchCalls: l.counters.TermBatchCalls.Load(), + MediaCalls: l.counters.MediaCalls.Load(), + MediaBatchCalls: l.counters.MediaBatchCalls.Load(), + TermsByPostCalls: l.counters.TermsByPostCalls.Load(), + TermsByPostBatch: l.counters.TermsByPostBatch.Load(), + } +} + +// NewExtended builds a Loaders bundle with every batcher wired. The +// existing New() builds only UserByID for backward compat; callers +// that want the full set call NewExtended. 
+func NewExtended(loadUsers UserBatchFn, loadTerms TermBatchFn, loadMedia MediaBatchFn, loadTermsByPost TermsByPostFn) *Loaders { + c := &counters{} + + userLoader := dataloader.NewBatchedLoader[string, *UserRow](func(ctx context.Context, ids []string) []*dataloader.Result[*UserRow] { + c.UserBatchCalls.Add(1) + return userBatcherResults(ctx, ids, loadUsers) + }) + termLoader := dataloader.NewBatchedLoader[string, *TermRow](func(ctx context.Context, ids []string) []*dataloader.Result[*TermRow] { + c.TermBatchCalls.Add(1) + return termBatcherResults(ctx, ids, loadTerms) + }) + mediaLoader := dataloader.NewBatchedLoader[string, *MediaRow](func(ctx context.Context, ids []string) []*dataloader.Result[*MediaRow] { + c.MediaBatchCalls.Add(1) + return mediaBatcherResults(ctx, ids, loadMedia) + }) + termsByPostLoader := dataloader.NewBatchedLoader[string, []*TermRow](func(ctx context.Context, ids []string) []*dataloader.Result[[]*TermRow] { + c.TermsByPostBatch.Add(1) + return termsByPostBatcherResults(ctx, ids, loadTermsByPost) + }) + + return &Loaders{ + UserByID: userLoader, + TermByID: termLoader, + MediaByID: mediaLoader, + TermsByPostID: termsByPostLoader, + counters: c, + } +} + +// --- batch result builders. Each one validates the upstream contract +// (input/output slice length match) so a programmer error fails +// loudly inside one request rather than scrambling across all of +// them. 
+ +func userBatcherResults(ctx context.Context, ids []string, fn UserBatchFn) []*dataloader.Result[*UserRow] { + rows, err := fn(ctx, ids) + out := make([]*dataloader.Result[*UserRow], len(ids)) + if err != nil { + for i := range ids { + out[i] = &dataloader.Result[*UserRow]{Error: err} + } + return out + } + if len(rows) != len(ids) { + err = fmt.Errorf("user batch: expected %d rows, got %d", len(ids), len(rows)) + for i := range ids { + out[i] = &dataloader.Result[*UserRow]{Error: err} + } + return out + } + for i, row := range rows { + out[i] = &dataloader.Result[*UserRow]{Data: row} + } + return out +} + +func termBatcherResults(ctx context.Context, ids []string, fn TermBatchFn) []*dataloader.Result[*TermRow] { + rows, err := fn(ctx, ids) + out := make([]*dataloader.Result[*TermRow], len(ids)) + if err != nil { + for i := range ids { + out[i] = &dataloader.Result[*TermRow]{Error: err} + } + return out + } + if len(rows) != len(ids) { + err = fmt.Errorf("term batch: expected %d rows, got %d", len(ids), len(rows)) + for i := range ids { + out[i] = &dataloader.Result[*TermRow]{Error: err} + } + return out + } + for i, row := range rows { + out[i] = &dataloader.Result[*TermRow]{Data: row} + } + return out +} + +func mediaBatcherResults(ctx context.Context, ids []string, fn MediaBatchFn) []*dataloader.Result[*MediaRow] { + rows, err := fn(ctx, ids) + out := make([]*dataloader.Result[*MediaRow], len(ids)) + if err != nil { + for i := range ids { + out[i] = &dataloader.Result[*MediaRow]{Error: err} + } + return out + } + if len(rows) != len(ids) { + err = fmt.Errorf("media batch: expected %d rows, got %d", len(ids), len(rows)) + for i := range ids { + out[i] = &dataloader.Result[*MediaRow]{Error: err} + } + return out + } + for i, row := range rows { + out[i] = &dataloader.Result[*MediaRow]{Data: row} + } + return out +} + +func termsByPostBatcherResults(ctx context.Context, postIDs []string, fn TermsByPostFn) []*dataloader.Result[[]*TermRow] { + rows, err := fn(ctx, 
postIDs) + out := make([]*dataloader.Result[[]*TermRow], len(postIDs)) + if err != nil { + for i := range postIDs { + out[i] = &dataloader.Result[[]*TermRow]{Error: err} + } + return out + } + if len(rows) != len(postIDs) { + err = fmt.Errorf("terms-by-post batch: expected %d rows, got %d", len(postIDs), len(rows)) + for i := range postIDs { + out[i] = &dataloader.Result[[]*TermRow]{Error: err} + } + return out + } + for i, row := range rows { + out[i] = &dataloader.Result[[]*TermRow]{Data: row} + } + return out +} diff --git a/apps/api/internal/graphql/dataloader/loader.go b/apps/api/internal/graphql/dataloader/loader.go index 2f6b0726..3b9ec2e0 100644 --- a/apps/api/internal/graphql/dataloader/loader.go +++ b/apps/api/internal/graphql/dataloader/loader.go @@ -42,20 +42,33 @@ type UserRow struct { // dereferencing). type UserBatchFn func(ctx context.Context, ids []string) ([]*UserRow, error) -// Loaders holds every per-request dataloader. Right now there is one; -// adding more (e.g., TermsByPostID, MediaByID) follows the same -// pattern. The struct is intentionally a value type (no pointer -// receivers needed) so it's safe to compare against a sentinel zero -// value in FromContext. +// Loaders holds every per-request dataloader. New() builds the +// minimal bundle (UserByID); NewExtended() builds the full set. Both +// attach batch-call counters. The struct is exported so callers can +// hand-build a partial bundle in tests. type Loaders struct { - UserByID *dataloader.Loader[string, *UserRow] + UserByID *dataloader.Loader[string, *UserRow] + TermByID *dataloader.Loader[string, *TermRow] + MediaByID *dataloader.Loader[string, *MediaRow] + TermsByPostID *dataloader.Loader[string, []*TermRow] + + // counters is set by both New() and NewExtended(); it is nil only + // for a hand-built Loaders literal. The Snapshot() helper returns a + // zero Snapshot when nil, so callers don't have to nil-check. + counters *counters } // New builds a fresh Loaders bundle wired to the given batch // functions. 
Call once per request from the GraphQL middleware. +// +// New only wires UserByID — callers that need the full per-resolver +// bundle (terms, media, terms-by-post) should use NewExtended. func New(loadUsers UserBatchFn) *Loaders { + c := &counters{} return &Loaders{ + counters: c, UserByID: dataloader.NewBatchedLoader[string, *UserRow](func(ctx context.Context, ids []string) []*dataloader.Result[*UserRow] { + c.UserBatchCalls.Add(1) rows, err := loadUsers(ctx, ids) out := make([]*dataloader.Result[*UserRow], len(ids)) if err != nil { diff --git a/apps/api/internal/graphql/dataloader/loadyaml.go b/apps/api/internal/graphql/dataloader/loadyaml.go new file mode 100644 index 00000000..7cc637f0 --- /dev/null +++ b/apps/api/internal/graphql/dataloader/loadyaml.go @@ -0,0 +1,29 @@ +package dataloader + +import ( + "fmt" + "os" + + "gopkg.in/yaml.v3" +) + +// LoadBudgetConfig parses tools/graphql-budgets.yml from path. Returns +// a BudgetConfig populated with defaults applied — callers can pass +// the returned config straight to CheckSnapshot. +// +// Path is resolved verbatim; callers running from the repo root +// typically pass "tools/graphql-budgets.yml". 
+func LoadBudgetConfig(path string) (BudgetConfig, error) { + raw, err := os.ReadFile(path) + if err != nil { + return BudgetConfig{}, fmt.Errorf("graphql-budgets: read %s: %w", path, err) + } + var cfg BudgetConfig + if err := yaml.Unmarshal(raw, &cfg); err != nil { + return BudgetConfig{}, fmt.Errorf("graphql-budgets: parse %s: %w", path, err) + } + if cfg.DefaultMaxBatchRoundTrips <= 0 { + cfg.DefaultMaxBatchRoundTrips = 4 + } + return cfg, nil +} diff --git a/apps/api/internal/rest/comments/global.go b/apps/api/internal/rest/comments/global.go new file mode 100644 index 00000000..ed051b50 --- /dev/null +++ b/apps/api/internal/rest/comments/global.go @@ -0,0 +1,153 @@ +package comments + +import ( + "log/slog" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Singleton-Solution/GoNext/apps/api/internal/rest/router" +) + +// MountGlobal wires the top-level read-only comments routes onto mux +// under base (typically "/api/v1/comments"). Unlike the per-post +// surface (POST/GET /api/v1/posts/{id}/comments) this is the broad +// "give me approved comments across the whole site" view that mirrors +// the posts/users/media REST contract. +// +// Two routes: +// +// GET {base} — list approved comments (with optional post_id filter) +// GET {base}/{id} — fetch a single approved comment by id +// +// The list path accepts ?post_id= as an optional filter when a +// client wants the per-post view via the global surface. Without it, +// the response spans every approved comment on the site (cursor- +// paginated by ltree path + id). +// +// No write surface here — comment submission goes through the +// per-post POST endpoint that exists already. 
+func MountGlobal(mux *http.ServeMux, base string, deps Deps) error { + if err := deps.validate(); err != nil { + return err + } + if deps.Logger == nil { + deps.Logger = slog.Default() + } + if deps.Now == nil { + deps.Now = nowOrDefault(deps.Now) + } + h := &handlers{ + store: deps.Store, + logger: deps.Logger, + now: deps.Now, + hooks: deps.Hooks, + dup: deps.DupChecker, + } + base = strings.TrimRight(base, "/") + mux.Handle("GET "+base, http.HandlerFunc(h.globalList)) + mux.Handle("GET "+base+"/{id}", http.HandlerFunc(h.globalGet)) + return nil +} + +func nowOrDefault(n func() time.Time) func() time.Time { + if n == nil { + return func() time.Time { return time.Now() } + } + return n +} + +func (h *handlers) globalList(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + limit := defaultListLimit + if raw := q.Get("limit"); raw != "" { + n, err := strconv.Atoi(raw) + if err != nil || n < 1 { + router.WriteError(w, http.StatusBadRequest, "invalid_limit", + "limit must be a positive integer") + return + } + if n > maxListLimit { + n = maxListLimit + } + limit = n + } + + var afterPath, afterID string + if raw := q.Get("after"); raw != "" { + decoded, err := router.ParseCursor(raw) + if err != nil { + router.WriteError(w, http.StatusBadRequest, "invalid_cursor", + "after must be a valid cursor") + return + } + if i := strings.IndexByte(decoded, ':'); i >= 0 { + afterPath = decoded[:i] + afterID = decoded[i+1:] + } else { + afterPath = decoded + } + } + + res, err := h.store.List(r.Context(), ListFilter{ + PostID: strings.TrimSpace(q.Get("post_id")), + AfterPath: afterPath, + AfterID: afterID, + Limit: limit, + }) + if err != nil { + h.logger.ErrorContext(r.Context(), "rest/comments: global list failed", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", + "failed to list comments") + return + } + + var nextCursor string + if res.HasNext && len(res.Comments) > 0 { + last := res.Comments[len(res.Comments)-1] 
+		nextCursor = router.EncodeCursor(last.Path + ":" + last.ID)
+	}
+	out := res.Comments
+	if out == nil {
+		out = []Comment{}
+	}
+	router.WriteJSON(w, http.StatusOK, listResponse{
+		Data: out,
+		Pagination: router.PageInfo{
+			NextCursor: nextCursor,
+		},
+	})
+}
+
+// globalGet serves GET {base}/{id}: it writes the comment whose ID equals
+// the {id} path value as JSON, or a 404 problem when the store's listing
+// does not contain it. An empty id yields 400; a store error is logged and
+// yields 500.
+//
+// ListFilter carries no id predicate, so the handler walks the listing
+// page by page (maxListLimit rows per store call), resuming each page from
+// the last row's Path/ID — the same keyset globalList encodes into its
+// cursor. Checking only the first page would report 404 for every comment
+// that sorts after the first maxListLimit rows. The walk is O(total rows)
+// in the worst case; a dedicated by-id store lookup would be the cheaper
+// long-term fix.
+func (h *handlers) globalGet(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	if id == "" {
+		router.WriteError(w, http.StatusBadRequest, "missing_id", "id is required")
+		return
+	}
+	filter := ListFilter{Limit: maxListLimit}
+	for {
+		res, err := h.store.List(r.Context(), filter)
+		if err != nil {
+			h.logger.ErrorContext(r.Context(), "rest/comments: global get failed",
+				slog.String("id", id),
+				slog.Any("err", err),
+			)
+			router.WriteError(w, http.StatusInternalServerError, "internal_error",
+				"failed to fetch comment")
+			return
+		}
+		for _, c := range res.Comments {
+			if c.ID == id {
+				router.WriteJSON(w, http.StatusOK, c)
+				return
+			}
+		}
+		if !res.HasNext || len(res.Comments) == 0 {
+			break
+		}
+		last := res.Comments[len(res.Comments)-1]
+		if last.Path == filter.AfterPath && last.ID == filter.AfterID {
+			// The store did not advance past the previous cursor; stop
+			// rather than loop forever on the same page.
+			break
+		}
+		filter.AfterPath, filter.AfterID = last.Path, last.ID
+	}
+	router.WriteError(w, http.StatusNotFound, "not_found", "comment not found")
+}
diff --git a/apps/api/internal/rest/comments/handler.go b/apps/api/internal/rest/comments/handler.go
index abff6a1c..9708d20c 100644
--- a/apps/api/internal/rest/comments/handler.go
+++ b/apps/api/internal/rest/comments/handler.go
@@ -79,6 +79,18 @@ type Deps struct {
 	// CurrentDisplayName, when non-nil, returns the logged-in user's
 	// display name. Mirrors the admin package's wiring point.
 	CurrentDisplayName func(*http.Request) string
+
+	// Hooks, when non-nil, is the filter-bus the submit handler
+	// fires the pre_submit hook through. 
Plugins register on + // rest.comments.pre_submit to mutate, reject, or stamp a + // moderation verdict on a candidate row before it lands. + // nil disables the hook (the default code path runs). + Hooks HookBus + + // DupChecker, when non-nil, is consulted to drop duplicate + // content from the same IP inside a short window. Typically + // the same object as Store (the MemoryStore implements both). + DupChecker DupChecker } func (d Deps) validate() error { @@ -97,6 +109,14 @@ type handlers struct { currentUID func(*http.Request) string currentDisplay func(*http.Request) string + // hooks is the optional filter bus invoked from submit() before + // the row hits the store. nil disables the chain. + hooks HookBus + + // dup is the optional duplicate-content gate. nil falls through + // to the legacy code path (rate-limit + classify only). + dup DupChecker + // ipMu guards the in-process IP rate-limit table. The table is // non-authoritative — it's a best-effort throttle for the case // where the Postgres backend hasn't seen the most recent burst @@ -151,6 +171,8 @@ func mountForTest(mux *http.ServeMux, base string, deps Deps) (*handlers, error) now: deps.Now, currentUID: deps.CurrentUserID, currentDisplay: deps.CurrentDisplayName, + hooks: deps.Hooks, + dup: deps.DupChecker, ips: make(map[string][]time.Time), } diff --git a/apps/api/internal/rest/comments/hooks.go b/apps/api/internal/rest/comments/hooks.go new file mode 100644 index 00000000..e19574ab --- /dev/null +++ b/apps/api/internal/rest/comments/hooks.go @@ -0,0 +1,181 @@ +// hooks.go wires the comment-submission moderation pipeline: +// +// pre_comment filter — plugins may inspect the candidate submission +// and (a) mutate it, (b) reject it, or (c) stamp +// a verdict that overrides the default classifier. +// duplicate-content — same content from the same IP within a short +// window is dropped at the store boundary. 
+// IP redaction cron — last octet of every comment's IP is zeroed +// once the row is older than 30 days, so the +// moderation queue isn't a perpetual PII store. +// +// Each piece is independently engageable: the hook bus is optional +// (a deps.Hooks of nil disables the filter chain), duplicate detection +// is a store-level method, and the redaction cron is a separate +// scheduler entry the operator wires through packages/go/jobs/cron. +// +// The interaction with the existing classify() function is: +// +// 1. The handler still runs sanitiseContent. +// 2. The handler still runs hard-rate-limit. +// 3. NEW: the handler fires HookName ("rest.comments.pre_submit") via ApplyFilters with a +// *PreSubmitPayload wrapping the decoded body. A plugin returns either: +// - the payload + nil → continue (a Verdict stamped on the payload is still honoured). +// - hooks.ErrShortCircuit + a Verdict stamped on the payload → use the verdict +// as the initial Status (bypassing classify()). +// - any other error → 400 (ErrCommentRejected is documented as 422), with the error message +// surfaced as the "detail" of the response problem. +// 4. Default behaviour unchanged when no plugin handles the hook. +package comments + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "strings" + "time" +) + +// HookName is the pre-submit filter hook name. Plugins register on +// this name; the bus invokes the chain before the comment row hits +// the store. +const HookName = "rest.comments.pre_submit" + +// CommentVerdict is what a plugin filter stamps on the payload to override the +// default classifier. The zero value means "no override" (let the +// classifier decide); a non-zero Status means "use this exact +// status". +type CommentVerdict struct { + // Status, when non-empty, becomes the initial Status of the + // persisted row. "approved" → row lands live; "spam" → row is + // invisible to the public list; "pending" → moderation queue. + Status Status + + // Reason is a human-readable note recorded alongside the row for + // moderator triage. Optional; the bus does not validate. 
+ Reason string +} + +// PreSubmitPayload is the value passed through the filter chain. The +// fields mirror the decoded SubmitInput plus the verdict slot a +// plugin may stamp. +// +// The struct is exported so plugin authors can type-assert the +// any-typed value the bus hands them. Mutating any of the input +// fields modifies the eventual persisted row. +type PreSubmitPayload struct { + // Input is the validated submit payload. Filter handlers MAY + // mutate Content / AuthorName / AuthorEmail (the canonical use + // case for "rewrite this URL into a tracking link" or "strip + // markdown the operator banned"). PostID/ParentID are + // load-bearing for routing; mutating them is a programmer error. + Input *SubmitInput + + // Verdict is the slot a moderation plugin fills to override the + // default classifier. Leave zero to let classify() run. + Verdict CommentVerdict +} + +// HookBus is the minimal hook-dispatching surface this package +// requires. Defined as an interface so the comments package does +// not import packages/go/hooks directly — the API server wires the +// real bus through Deps.Hooks at boot time. +// +// The shape mirrors hooks.Bus.ApplyFilters; tests pass a stub. +type HookBus interface { + ApplyFilters(ctx context.Context, name string, value any, args ...any) (any, error) +} + +// ErrCommentRejected is returned by the filter chain when a plugin +// wants to drop the submission entirely. The handler maps this to a +// 422 Unprocessable Entity rather than a 400, because the request +// was syntactically valid — semantically a policy says no. +var ErrCommentRejected = errors.New("rest/comments: rejected by pre_submit hook") + +// runPreSubmit applies the pre_submit filter chain. Returns the +// resolved verdict (possibly zero) and a non-nil error if the chain +// reported a hard reject. The error path bubbles up to the handler; +// the verdict path is used by submit() to override the classifier. 
+// +// When deps.Hooks is nil the function is a no-op — returns the zero +// verdict and a nil error, leaving the handler in its default code +// path. +func (h *handlers) runPreSubmit(ctx context.Context, in *SubmitInput) (CommentVerdict, error) { + if h.hooks == nil { + return CommentVerdict{}, nil + } + payload := &PreSubmitPayload{Input: in} + out, err := h.hooks.ApplyFilters(ctx, HookName, payload) + if err != nil { + // hooks.ErrShortCircuit is not an exported sentinel from this + // package (we can't import hooks.* without a cycle); test + // instead for the well-known error message. The handler's + // fallback is to treat any non-ErrCommentRejected error as + // the reject signal — plugins that want to surface "I'm not + // sure" should return a verdict with Status = "pending" + // instead of an error. + if errors.Is(err, ErrCommentRejected) { + return CommentVerdict{}, ErrCommentRejected + } + // Short-circuit (hooks.ErrShortCircuit) with a verdict + // attached: the bus returns (value, ErrShortCircuit) and we + // pluck the verdict from the value. + if isShortCircuit(err) { + if p, ok := out.(*PreSubmitPayload); ok { + return p.Verdict, nil + } + return CommentVerdict{}, nil + } + return CommentVerdict{}, fmt.Errorf("pre_submit hook: %w", err) + } + if p, ok := out.(*PreSubmitPayload); ok { + return p.Verdict, nil + } + return CommentVerdict{}, nil +} + +// isShortCircuit checks whether err is the bus's short-circuit +// sentinel. We can't import packages/go/hooks here (the comments +// package would gain a dep on the whole bus package); we recognise +// the sentinel by its Error() string, which the hooks package +// documents as stable. +func isShortCircuit(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "short-circuit filter chain") +} + +// contentFingerprint is the duplicate-content key. SHA-256 of the +// normalised content, lowercased + whitespace-collapsed. 
// contentFingerprint derives the duplicate-content key for a comment body.
//
// The text is split on whitespace, re-joined with single spaces and
// lowercased, so inputs that differ only in case or spacing produce the
// same key. The result is the lowercase hex encoding of the SHA-256 digest
// of that normalised text. No salt is applied.
func contentFingerprint(content string) string {
	words := strings.Fields(content)
	key := strings.ToLower(strings.Join(words, " "))

	digest := sha256.New()
	digest.Write([]byte(key)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
+type stubBus struct { + handler func(ctx context.Context, value any) (any, error) +} + +func (b *stubBus) ApplyFilters(ctx context.Context, name string, value any, args ...any) (any, error) { + if b.handler == nil { + return value, nil + } + return b.handler(ctx, value) +} + +func newStore(t *testing.T) *MemoryStore { + t.Helper() + s := NewMemoryStore() + s.SeedPost("post-1") + return s +} + +func makeMux(t *testing.T, deps Deps) *http.ServeMux { + t.Helper() + mux := http.NewServeMux() + if err := Mount(mux, "/api/v1/posts", deps); err != nil { + t.Fatalf("mount: %v", err) + } + return mux +} + +func TestPreSubmit_HookRejects(t *testing.T) { + t.Parallel() + store := newStore(t) + bus := &stubBus{handler: func(ctx context.Context, v any) (any, error) { + return v, ErrCommentRejected + }} + mux := makeMux(t, Deps{Store: store, Hooks: bus}) + + body := strings.NewReader(`{"author_name":"alice","content":"hi"}`) + req := httptest.NewRequest("POST", "/api/v1/posts/post-1/comments", body) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnprocessableEntity { + t.Errorf("status = %d, want 422; body=%s", rr.Code, rr.Body.String()) + } +} + +func TestPreSubmit_HookStampsVerdict(t *testing.T) { + t.Parallel() + store := newStore(t) + bus := &stubBus{handler: func(ctx context.Context, v any) (any, error) { + p, ok := v.(*PreSubmitPayload) + if !ok { + t.Errorf("payload type = %T, want *PreSubmitPayload", v) + return v, nil + } + // Auto-approve everything from this stub plugin. + p.Verdict = CommentVerdict{Status: StatusApproved, Reason: "trusted source"} + // Return short-circuit so the chain stops here. 
+ return p, errors.New("hooks: short-circuit filter chain") + }} + mux := makeMux(t, Deps{Store: store, Hooks: bus}) + + body := bytes.NewBufferString(`{"author_name":"alice","content":"this is a comment"}`) + req := httptest.NewRequest("POST", "/api/v1/posts/post-1/comments", body) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201; body=%s", rr.Code, rr.Body.String()) + } + var out Created + _ = json.Unmarshal(rr.Body.Bytes(), &out) + if out.Pending { + t.Errorf("Pending = true, want false (verdict should auto-approve)") + } +} + +func TestDuplicateContent(t *testing.T) { + t.Parallel() + store := newStore(t) + mux := makeMux(t, Deps{Store: store, DupChecker: store}) + + for i, want := range []int{http.StatusCreated, http.StatusUnprocessableEntity} { + body := bytes.NewBufferString(`{"author_name":"alice","content":"buy cheap pills now"}`) + req := httptest.NewRequest("POST", "/api/v1/posts/post-1/comments", body) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "192.0.2.1:55001" + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != want { + t.Errorf("attempt %d: status = %d, want %d; body=%s", i, rr.Code, want, rr.Body.String()) + } + } +} + +func TestRedactIP_V4(t *testing.T) { + t.Parallel() + cases := map[string]string{ + "192.0.2.42": "192.0.2.0", + "203.0.113.255": "203.0.113.0", + "127.0.0.1": "127.0.0.0", + } + for in, want := range cases { + if got := redactIP(in); got != want { + t.Errorf("redactIP(%q) = %q, want %q", in, got, want) + } + } +} + +func TestRedactIP_V6(t *testing.T) { + t.Parallel() + cases := map[string]string{ + "2001:db8::1": "2001:db8::", + "fe80::1234:5678:9abc": "fe80::", + } + for in, want := range cases { + if got := redactIP(in); got != want { + t.Errorf("redactIP(%q) = %q, want %q", in, got, want) + } + } +} + +func TestRunRedactionCron(t *testing.T) { + 
t.Parallel() + clock := time.Date(2026, 5, 26, 12, 0, 0, 0, time.UTC) + store := NewMemoryStoreWithClock(func() time.Time { return clock }) + store.SeedPost("post-1") + // Recent row — should NOT be redacted. + store.Seed(Comment{ID: "c1", PostID: "post-1", Content: "fresh", CreatedAt: clock.Add(-24 * time.Hour)}, StatusApproved) + // Old row — should be redacted. + store.Seed(Comment{ID: "c2", PostID: "post-1", Content: "stale", CreatedAt: clock.Add(-40 * 24 * time.Hour)}, StatusApproved) + // Attach IPs by walking the rows directly. + row := store.rows["c1"] + row.AuthorIP = "192.0.2.10" + store.rows["c1"] = row + row = store.rows["c2"] + row.AuthorIP = "192.0.2.20" + store.rows["c2"] = row + + n, err := RunRedactionCron(context.Background(), store, DefaultRedactionAge, func() time.Time { return clock }) + if err != nil { + t.Fatalf("cron: %v", err) + } + if n != 1 { + t.Errorf("redacted = %d, want 1", n) + } + if got := store.rows["c1"].AuthorIP; got != "192.0.2.10" { + t.Errorf("c1.ip = %q, want untouched", got) + } + if got := store.rows["c2"].AuthorIP; got != "192.0.2.0" { + t.Errorf("c2.ip = %q, want 192.0.2.0", got) + } +} + +func TestContentFingerprint_Normalised(t *testing.T) { + t.Parallel() + a := contentFingerprint("Hello World!") + b := contentFingerprint(" hello world! ") + if a != b { + t.Errorf("expected match after whitespace + case normalisation") + } +} diff --git a/apps/api/internal/rest/comments/redact.go b/apps/api/internal/rest/comments/redact.go new file mode 100644 index 00000000..f64ae422 --- /dev/null +++ b/apps/api/internal/rest/comments/redact.go @@ -0,0 +1,136 @@ +// redact.go is the IP-redaction cron's data-access side. The +// scheduler entry (in packages/go/jobs/cron, wired by the operator at +// boot) fires this function once a day; it zeroes the last octet of +// every comments.author_ip row older than the redaction threshold. 
+// +// Why redact instead of delete: +// +// - We still need to count comments-per-/24 for moderator triage +// and for the rate limiter's longer windows. Zeroing the last +// octet preserves the /24 signal while removing the identifying +// detail of the source. +// - The audit trail of "who said what when" is preserved; the +// missing detail is which exact device — at 30 days that's +// functionally PII rather than a forensics signal. +// +// IPv6 redaction zeroes the last 80 bits (the bottom 5 groups in the +// canonical text representation), keeping the /48 prefix that most +// abuse reports key off. +package comments + +import ( + "context" + "net" + "strings" + "time" +) + +// DefaultRedactionAge is the threshold past which a comment's IP is +// truncated to its prefix. 30 days matches GoNext's general PII +// retention policy (see docs/06-auth-permissions.md §5.3). +const DefaultRedactionAge = 30 * 24 * time.Hour + +// IPRedactor is the store-facing interface the redaction cron drives. +// MemoryStore implements it; the Postgres store provides a SQL- +// backed variant. +type IPRedactor interface { + // RedactIPsBefore zeroes the last octet (IPv4) or last 80 bits + // (IPv6) of every author_ip stored on a comment older than the + // given cutoff. Returns the count of rows updated. + RedactIPsBefore(ctx context.Context, cutoff time.Time) (int, error) +} + +// RunRedactionCron is the cron-callback the scheduler invokes. It +// translates the configured age into a cutoff and delegates to the +// store. Designed to be the literal value passed as the +// taskspec.HandlerFunc for a "comments.redact_ip.daily" schedule. +// +// `age` is the lookback; pass DefaultRedactionAge unless the operator +// has a policy reason to vary it. `now` is the wall-clock function; +// tests inject a deterministic clock. 
// redactIP truncates a textual IP address to its network prefix.
//
//   - IPv4, including IPv4-mapped IPv6 such as "::ffff:192.0.2.1": the last
//     octet is zeroed and the dotted form is returned.
//   - IPv6: the low 80 bits (bytes 6..15) are zeroed, leaving the leading
//     48 bits.
//
// Input that net.ParseIP rejects is given a second chance before giving up:
// surrounding whitespace is trimmed, a "host:port" / "[host]:port" wrapper
// is unwrapped, and an IPv6 zone ("%eth0") is dropped. Without this, such
// values would be returned untouched with the full address still in them.
// When a second-chance parse succeeds the result is the redacted address
// alone (no port, no zone).
//
// Input that carries no parseable address at all is returned unchanged.
func redactIP(in string) string {
	host := strings.TrimSpace(in)
	ip := net.ParseIP(host)
	if ip == nil {
		// "192.0.2.1:55001" or "[2001:db8::1]:443".
		if h, _, err := net.SplitHostPort(host); err == nil {
			host = h
		}
		// "fe80::1%eth0": the zone identifies an interface, not a network.
		if bare, _, found := strings.Cut(host, "%"); found {
			host = bare
		}
		if ip = net.ParseIP(host); ip == nil {
			return in
		}
	}

	// To4 also matches IPv4-mapped IPv6, so both spellings of the same
	// address get the same last-octet treatment.
	if v4 := ip.To4(); v4 != nil {
		v4[3] = 0
		return v4.String()
	}

	// net.ParseIP always yields the 16-byte form for IPv6; keep the first
	// 6 bytes and clear the rest.
	for i := 6; i < len(ip); i++ {
		ip[i] = 0
	}
	return ip.String()
}
O(n) where n is +// the per-IP row count; the production Postgres store will swap in a +// (ip, content_fingerprint) index for the same lookup. +func (s *MemoryStore) DuplicateContent(_ context.Context, ip, fingerprint string, window time.Duration) (bool, error) { + if ip == "" || fingerprint == "" { + return false, nil + } + s.mu.RLock() + defer s.mu.RUnlock() + since := s.now().Add(-window) + for _, row := range s.rows { + if row.AuthorIP != ip { + continue + } + if row.CreatedAt.Before(since) { + continue + } + if contentFingerprint(row.Content) == fingerprint { + return true, nil + } + } + return false, nil +} diff --git a/apps/api/internal/rest/comments/submit.go b/apps/api/internal/rest/comments/submit.go index 45c1ecea..2a82666f 100644 --- a/apps/api/internal/rest/comments/submit.go +++ b/apps/api/internal/rest/comments/submit.go @@ -143,6 +143,39 @@ func (h *handlers) submit(w http.ResponseWriter, r *http.Request) { AuthorUserAgent: r.Header.Get("User-Agent"), } + // Duplicate-content gate. The fingerprint is normalised content + // SHA-256; same IP + same fingerprint within dupWindow drops the + // row at 422. Anonymous submissions only — logged-in users are + // trusted to know whether they meant to repost. + if h.dup != nil && !loggedIn && ip != "" { + fp := contentFingerprint(content) + if dup, err := h.dup.DuplicateContent(r.Context(), ip, fp, dupWindow); err == nil && dup { + router.WriteError(w, http.StatusUnprocessableEntity, "duplicate_content", + "identical content was submitted from this IP recently") + return + } + } + + // pre_submit hook chain. Plugins may mutate the input, reject + // the row, or stamp a moderation verdict that overrides the + // classifier. The handler treats: + // verdict.Status non-empty → use it as initialStatus. + // ErrCommentRejected → 422. + // any other error → 400 with the error message. 
+ verdict, hookErr := h.runPreSubmit(r.Context(), &in) + if hookErr != nil { + if errors.Is(hookErr, ErrCommentRejected) { + router.WriteError(w, http.StatusUnprocessableEntity, "rejected", + "comment was rejected by a moderation plugin") + return + } + router.WriteError(w, http.StatusBadRequest, "pre_submit_error", hookErr.Error()) + return + } + if verdict.Status != "" { + initialStatus = verdict.Status + } + // Best-effort post existence check before we record an IP // submission. We want the IP rate limiter to count // successful-ish submissions, not 404s. diff --git a/apps/api/internal/rest/customfields/handler.go b/apps/api/internal/rest/customfields/handler.go new file mode 100644 index 00000000..fbde70ae --- /dev/null +++ b/apps/api/internal/rest/customfields/handler.go @@ -0,0 +1,314 @@ +// Package customfields wires packages/go/customfields onto the REST +// router. Two mount points: +// +// customfields.MountGroups(mux, "/api/v1/custom-fields/groups", deps) +// customfields.MountMeta(mux, "/api/v1/posts", deps) +// +// Groups CRUD is admin-only (writes need edit_field_groups; reads +// surface anonymously per the public-API posture). +// Meta-value CRUD inherits the post's policy — reading post meta +// requires read_post; writing requires edit_post and the value is +// schema-validated against the field group before persistence. +package customfields + +import ( + "encoding/json" + "errors" + "io" + "log/slog" + "net/http" + "strconv" + "strings" + + "github.com/Singleton-Solution/GoNext/apps/api/internal/rest/router" + cf "github.com/Singleton-Solution/GoNext/packages/go/customfields" +) + +// maxBodyBytes caps the request body size for group + meta writes. +// 256 KiB covers the largest realistic schema + the deepest realistic +// meta blob; anything larger is almost always a fuzzer probe. +const maxBodyBytes = 256 * 1024 + +// Deps is the dependency bag for both Mount entry points. 
+type Deps struct { + Store cf.Store + Logger *slog.Logger +} + +func (d Deps) validate() error { + if d.Store == nil { + return errors.New("rest/customfields: Store is required") + } + return nil +} + +type handlers struct { + store cf.Store + logger *slog.Logger +} + +func newHandlers(deps Deps) (*handlers, error) { + if err := deps.validate(); err != nil { + return nil, err + } + if deps.Logger == nil { + deps.Logger = slog.Default() + } + return &handlers{store: deps.Store, logger: deps.Logger}, nil +} + +// MountGroups wires the field-group CRUD routes onto mux. +func MountGroups(mux *http.ServeMux, base string, deps Deps) error { + h, err := newHandlers(deps) + if err != nil { + return err + } + base = strings.TrimRight(base, "/") + mux.Handle("GET "+base, http.HandlerFunc(h.listGroups)) + mux.Handle("POST "+base, http.HandlerFunc(h.createGroup)) + mux.Handle("GET "+base+"/{id}", http.HandlerFunc(h.getGroup)) + mux.Handle("PATCH "+base+"/{id}", http.HandlerFunc(h.updateGroup)) + mux.Handle("DELETE "+base+"/{id}", http.HandlerFunc(h.deleteGroup)) + return nil +} + +// MountMeta wires the per-post meta routes onto mux. base is +// typically "/api/v1/posts" so the resulting routes are +// "/api/v1/posts/{post_id}/meta[/...]". 
+func MountMeta(mux *http.ServeMux, base string, deps Deps) error { + h, err := newHandlers(deps) + if err != nil { + return err + } + base = strings.TrimRight(base, "/") + mux.Handle("GET "+base+"/{post_id}/meta", http.HandlerFunc(h.listMeta)) + mux.Handle("GET "+base+"/{post_id}/meta/{group_id}", http.HandlerFunc(h.getMeta)) + mux.Handle("PUT "+base+"/{post_id}/meta/{group_id}", http.HandlerFunc(h.putMeta)) + return nil +} + +// ---- groups ---------------------------------------------------------------- + +func (h *handlers) listGroups(w http.ResponseWriter, r *http.Request) { + rows, err := h.store.ListGroups(r.Context()) + if err != nil { + h.logger.ErrorContext(r.Context(), "rest/customfields: list groups", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", "failed to list groups") + return + } + router.WriteJSON(w, http.StatusOK, router.Page[cf.FieldGroup]{Data: rows}) +} + +type createGroupInput struct { + Slug string `json:"slug"` + Title string `json:"title"` + PostTypes []string `json:"post_types,omitempty"` + Schema json.RawMessage `json:"schema"` +} + +func (h *handlers) createGroup(w http.ResponseWriter, r *http.Request) { + var in createGroupInput + if err := decodeBody(r, &in); err != nil { + router.WriteError(w, http.StatusBadRequest, "invalid_body", err.Error()) + return + } + if strings.TrimSpace(in.Slug) == "" { + router.WriteError(w, http.StatusBadRequest, "missing_slug", "slug is required") + return + } + if strings.TrimSpace(in.Title) == "" { + router.WriteError(w, http.StatusBadRequest, "missing_title", "title is required") + return + } + if len(in.Schema) == 0 { + router.WriteError(w, http.StatusBadRequest, "missing_schema", "schema is required") + return + } + // Validate the schema parses; we don't compile here (the validator + // in customfields.Validate handles compilation lazily) but a + // payload that isn't valid JSON should fail at creation time. 
+ var probe map[string]any + if err := json.Unmarshal(in.Schema, &probe); err != nil { + router.WriteError(w, http.StatusBadRequest, "invalid_schema", + "schema must be a JSON object: "+err.Error()) + return + } + + g, err := h.store.InsertGroup(r.Context(), cf.FieldGroupCreate{ + Slug: in.Slug, + Title: in.Title, + PostTypes: in.PostTypes, + Schema: in.Schema, + }) + if err != nil { + if errors.Is(err, cf.ErrDuplicateSlug) { + router.WriteError(w, http.StatusConflict, "duplicate_slug", "a group with this slug already exists") + return + } + h.logger.ErrorContext(r.Context(), "rest/customfields: insert group", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", "failed to create group") + return + } + w.Header().Set("X-Version", strconv.Itoa(g.Version)) + router.WriteJSON(w, http.StatusCreated, g) +} + +func (h *handlers) getGroup(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + g, err := h.store.GetGroup(r.Context(), id) + if err != nil { + if errors.Is(err, cf.ErrNotFound) { + router.WriteError(w, http.StatusNotFound, "not_found", "group not found") + return + } + h.logger.ErrorContext(r.Context(), "rest/customfields: get group", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", "failed to fetch group") + return + } + w.Header().Set("X-Version", strconv.Itoa(g.Version)) + router.WriteJSON(w, http.StatusOK, g) +} + +type updateGroupInput struct { + Title *string `json:"title,omitempty"` + PostTypes *[]string `json:"post_types,omitempty"` + Schema *json.RawMessage `json:"schema,omitempty"` +} + +func (h *handlers) updateGroup(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + version, present, err := router.ParseIfMatchVersion(r) + if err != nil { + router.WriteError(w, http.StatusBadRequest, "invalid_if_match", "If-Match header is malformed") + return + } + if !present { + router.WriteError(w, http.StatusPreconditionRequired, 
"if_match_required", "If-Match header is required") + return + } + var in updateGroupInput + if err := decodeBody(r, &in); err != nil { + router.WriteError(w, http.StatusBadRequest, "invalid_body", err.Error()) + return + } + g, err := h.store.UpdateGroup(r.Context(), id, version, cf.FieldGroupUpdate{ + Title: in.Title, + PostTypes: in.PostTypes, + Schema: in.Schema, + }) + if err != nil { + switch { + case errors.Is(err, cf.ErrNotFound): + router.WriteError(w, http.StatusNotFound, "not_found", "group not found") + case errors.Is(err, cf.ErrVersionConflict): + router.WriteError(w, http.StatusPreconditionFailed, "version_mismatch", "If-Match version does not match") + default: + h.logger.ErrorContext(r.Context(), "rest/customfields: update group", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", "failed to update group") + } + return + } + w.Header().Set("X-Version", strconv.Itoa(g.Version)) + router.WriteJSON(w, http.StatusOK, g) +} + +func (h *handlers) deleteGroup(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if err := h.store.DeleteGroup(r.Context(), id); err != nil { + if errors.Is(err, cf.ErrNotFound) { + router.WriteError(w, http.StatusNotFound, "not_found", "group not found") + return + } + h.logger.ErrorContext(r.Context(), "rest/customfields: delete group", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", "failed to delete group") + return + } + w.WriteHeader(http.StatusNoContent) +} + +// ---- meta ----------------------------------------------------------------- + +func (h *handlers) listMeta(w http.ResponseWriter, r *http.Request) { + postID := r.PathValue("post_id") + rows, err := h.store.ListMeta(r.Context(), postID) + if err != nil { + h.logger.ErrorContext(r.Context(), "rest/customfields: list meta", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", "failed to list meta") + return + } + if 
rows == nil { + rows = []cf.MetaValue{} + } + router.WriteJSON(w, http.StatusOK, router.Page[cf.MetaValue]{Data: rows}) +} + +func (h *handlers) getMeta(w http.ResponseWriter, r *http.Request) { + postID := r.PathValue("post_id") + groupID := r.PathValue("group_id") + v, err := h.store.GetMeta(r.Context(), postID, groupID) + if err != nil { + if errors.Is(err, cf.ErrNotFound) { + router.WriteError(w, http.StatusNotFound, "not_found", "meta not found") + return + } + h.logger.ErrorContext(r.Context(), "rest/customfields: get meta", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", "failed to fetch meta") + return + } + router.WriteJSON(w, http.StatusOK, v) +} + +func (h *handlers) putMeta(w http.ResponseWriter, r *http.Request) { + postID := r.PathValue("post_id") + groupID := r.PathValue("group_id") + + // Resolve the group + validate the incoming payload against its + // schema. ErrNotFound on the group is a 404; schema-violation is + // a 422 (the payload was well-formed JSON but didn't match the + // shape the group requires). 
+ group, err := h.store.GetGroup(r.Context(), groupID) + if err != nil { + if errors.Is(err, cf.ErrNotFound) { + router.WriteError(w, http.StatusNotFound, "not_found", "group not found") + return + } + h.logger.ErrorContext(r.Context(), "rest/customfields: get group for put", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", "failed to fetch group") + return + } + + body, err := io.ReadAll(http.MaxBytesReader(nil, r.Body, maxBodyBytes)) + if err != nil { + router.WriteError(w, http.StatusRequestEntityTooLarge, + "body_too_large", "request body exceeds the size limit") + return + } + if err := cf.Validate(group, body); err != nil { + router.WriteError(w, http.StatusUnprocessableEntity, + "schema_violation", err.Error()) + return + } + + v, err := h.store.PutMeta(r.Context(), postID, groupID, body) + if err != nil { + h.logger.ErrorContext(r.Context(), "rest/customfields: put meta", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", "failed to persist meta") + return + } + router.WriteJSON(w, http.StatusOK, v) +} + +// decodeBody is the shared decode helper. Limits the body to +// maxBodyBytes and rejects unknown fields so client typos surface. 
+func decodeBody(r *http.Request, out any) error { + r.Body = http.MaxBytesReader(nil, r.Body, maxBodyBytes) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(out); err != nil { + return errors.New("request body could not be parsed: " + err.Error()) + } + if dec.More() { + return errors.New("request body must contain a single JSON value") + } + return nil +} diff --git a/apps/api/internal/rest/customfields/handler_test.go b/apps/api/internal/rest/customfields/handler_test.go new file mode 100644 index 00000000..d9e86cf0 --- /dev/null +++ b/apps/api/internal/rest/customfields/handler_test.go @@ -0,0 +1,131 @@ +package customfields + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + cf "github.com/Singleton-Solution/GoNext/packages/go/customfields" +) + +func newMux(t *testing.T) (*http.ServeMux, cf.Store) { + t.Helper() + store := cf.NewMemoryStore() + mux := http.NewServeMux() + if err := MountGroups(mux, "/api/v1/custom-fields/groups", Deps{Store: store}); err != nil { + t.Fatalf("mount groups: %v", err) + } + if err := MountMeta(mux, "/api/v1/posts", Deps{Store: store}); err != nil { + t.Fatalf("mount meta: %v", err) + } + return mux, store +} + +func TestCreateAndGetGroup(t *testing.T) { + t.Parallel() + mux, _ := newMux(t) + + in := `{ + "slug": "product", + "title": "Product", + "schema": {"type":"object","properties":{"price":{"type":"number"}}} + }` + req := httptest.NewRequest("POST", "/api/v1/custom-fields/groups", strings.NewReader(in)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 201 { + t.Fatalf("create status = %d, want 201; body=%s", rr.Code, rr.Body.String()) + } + var g cf.FieldGroup + _ = json.Unmarshal(rr.Body.Bytes(), &g) + if g.ID == "" || g.Slug != "product" { + t.Errorf("unexpected group: %+v", g) + } + + // GET by id. 
+ req = httptest.NewRequest("GET", "/api/v1/custom-fields/groups/"+g.ID, nil) + rr = httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 200 { + t.Errorf("get status = %d, want 200", rr.Code) + } +} + +func TestPutMeta_ValidatesAgainstSchema(t *testing.T) { + t.Parallel() + mux, store := newMux(t) + g, err := store.InsertGroup(nil, cf.FieldGroupCreate{ + Slug: "product", + Title: "Product", + Schema: json.RawMessage(`{"type":"object","required":["price"],"properties":{"price":{"type":"number"}}}`), + }) + if err != nil { + t.Fatalf("seed: %v", err) + } + + // Happy path. + good := bytes.NewBufferString(`{"price": 9.99}`) + req := httptest.NewRequest("PUT", "/api/v1/posts/post-1/meta/"+g.ID, good) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 200 { + t.Errorf("put status = %d, want 200; body=%s", rr.Code, rr.Body.String()) + } + + // Bad payload — missing required field. + bad := bytes.NewBufferString(`{}`) + req = httptest.NewRequest("PUT", "/api/v1/posts/post-1/meta/"+g.ID, bad) + req.Header.Set("Content-Type", "application/json") + rr = httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 422 { + t.Errorf("bad-put status = %d, want 422", rr.Code) + } +} + +func TestUpdateGroup_RequiresIfMatch(t *testing.T) { + t.Parallel() + mux, store := newMux(t) + g, _ := store.InsertGroup(nil, cf.FieldGroupCreate{ + Slug: "product", + Title: "Product", + Schema: json.RawMessage(`{"type":"object"}`), + }) + + req := httptest.NewRequest("PATCH", "/api/v1/custom-fields/groups/"+g.ID, + strings.NewReader(`{"title":"renamed"}`)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != http.StatusPreconditionRequired { + t.Errorf("missing if-match status = %d, want 428", rr.Code) + } +} + +func TestDeleteGroup_CascadesMeta(t *testing.T) { + t.Parallel() + mux, store := newMux(t) + g, _ := store.InsertGroup(nil, 
cf.FieldGroupCreate{ + Slug: "product", + Title: "Product", + Schema: json.RawMessage(`{"type":"object"}`), + }) + _, _ = store.PutMeta(nil, "post-1", g.ID, json.RawMessage(`{}`)) + + req := httptest.NewRequest("DELETE", "/api/v1/custom-fields/groups/"+g.ID, nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 204 { + t.Errorf("delete status = %d, want 204", rr.Code) + } + + // Meta should be gone. + if rows, _ := store.ListMeta(nil, "post-1"); len(rows) != 0 { + t.Errorf("meta survived group delete: %v", rows) + } +} diff --git a/apps/api/internal/rest/media/doc.go b/apps/api/internal/rest/media/doc.go new file mode 100644 index 00000000..5bceef93 --- /dev/null +++ b/apps/api/internal/rest/media/doc.go @@ -0,0 +1,23 @@ +// Package media serves the public read-only `/api/v1/media` REST +// surface. The admin/media package owns the write surface (uploads, +// metadata edits, soft-delete); this one only exposes the read paths +// a public-facing site needs to render images. +// +// Endpoints: +// +// GET /api/v1/media — list public assets (cursor pagination) +// GET /api/v1/media/{id} — fetch a single asset's metadata + URLs +// +// The wire shape matches admin/media.Asset on the read side so a +// shared admin/public asset card can be rendered from either response. +// What's filtered out from the public surface: +// +// - UploaderID — privacy. The author of an upload isn't part of the +// asset's public identity (the post's author is what readers care +// about; uploader is an audit field). +// - SHA256 — internal dedupe key. Surfacing it would only be useful +// to a CDN cache key forger. +// +// Variants and PublicURL ARE surfaced — they're what enables a +// responsive image renderer on the public site. 
+package media diff --git a/apps/api/internal/rest/media/handler.go b/apps/api/internal/rest/media/handler.go new file mode 100644 index 00000000..c4a1bc7d --- /dev/null +++ b/apps/api/internal/rest/media/handler.go @@ -0,0 +1,133 @@ +package media + +import ( + "errors" + "log/slog" + "net/http" + "strconv" + "strings" + + "github.com/Singleton-Solution/GoNext/apps/api/internal/rest/router" +) + +// Deps is the dependency bag for Mount. +type Deps struct { + Store Store + Logger *slog.Logger +} + +func (d Deps) validate() error { + if d.Store == nil { + return errors.New("rest/media: Store is required") + } + return nil +} + +type handlers struct { + store Store + logger *slog.Logger +} + +// Mount wires the public media routes onto mux under base (typically +// "/api/v1/media"). +func Mount(mux *http.ServeMux, base string, deps Deps) error { + if err := deps.validate(); err != nil { + return err + } + if deps.Logger == nil { + deps.Logger = slog.Default() + } + h := &handlers{store: deps.Store, logger: deps.Logger} + base = strings.TrimRight(base, "/") + mux.Handle("GET "+base, http.HandlerFunc(h.list)) + mux.Handle("GET "+base+"/{id}", http.HandlerFunc(h.get)) + return nil +} + +func (h *handlers) list(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + limit := DefaultListLimit + if raw := q.Get("limit"); raw != "" { + n, err := strconv.Atoi(raw) + if err != nil || n < 1 { + router.WriteError(w, http.StatusBadRequest, "invalid_limit", + "limit must be a positive integer") + return + } + if n > MaxListLimit { + n = MaxListLimit + } + limit = n + } + + var after string + if raw := q.Get("after"); raw != "" { + decoded, err := router.ParseCursor(raw) + if err != nil { + router.WriteError(w, http.StatusBadRequest, "invalid_cursor", + "after must be a valid cursor") + return + } + after = decoded + } + + mimeClass := strings.TrimSpace(q.Get("mime_class")) + switch mimeClass { + case "", "image", "video", "document": + // ok + default: + 
router.WriteError(w, http.StatusBadRequest, "invalid_mime_class", + "mime_class must be one of image|video|document") + return + } + + rows, err := h.store.List(r.Context(), ListFilter{ + MimeClass: mimeClass, + Limit: limit, + After: after, + }) + if err != nil { + h.logger.ErrorContext(r.Context(), "rest/media: list failed", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", + "failed to list media") + return + } + + var next string + if len(rows) > limit { + rows = rows[:limit] + last := rows[len(rows)-1] + next = router.EncodeCursor(last.CreatedAt.Format("2006-01-02T15:04:05.999999999Z07:00") + ":" + last.ID) + } + + router.WriteJSON(w, http.StatusOK, router.Page[Asset]{ + Data: rows, + Pagination: router.PageInfo{ + NextCursor: next, + }, + }) +} + +func (h *handlers) get(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + router.WriteError(w, http.StatusBadRequest, "missing_id", "id is required") + return + } + a, err := h.store.GetByID(r.Context(), id) + if err != nil { + if errors.Is(err, ErrNotFound) { + router.WriteError(w, http.StatusNotFound, "not_found", "media not found") + return + } + h.logger.ErrorContext(r.Context(), "rest/media: get failed", + slog.String("id", id), + slog.Any("err", err), + ) + router.WriteError(w, http.StatusInternalServerError, "internal_error", + "failed to fetch media") + return + } + router.WriteJSON(w, http.StatusOK, a) +} diff --git a/apps/api/internal/rest/media/handler_test.go b/apps/api/internal/rest/media/handler_test.go new file mode 100644 index 00000000..3f000230 --- /dev/null +++ b/apps/api/internal/rest/media/handler_test.go @@ -0,0 +1,58 @@ +package media + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestMedia_List_FilterByClass(t *testing.T) { + t.Parallel() + store := NewMemoryStore() + store.Insert(Asset{ID: "a1", MimeType: "image/png", CreatedAt: time.Now()}) + 
store.Insert(Asset{ID: "a2", MimeType: "video/mp4", CreatedAt: time.Now().Add(-time.Hour)}) + store.Insert(Asset{ID: "a3", MimeType: "application/pdf", CreatedAt: time.Now().Add(-2 * time.Hour)}) + + mux := http.NewServeMux() + if err := Mount(mux, "/api/v1/media", Deps{Store: store}); err != nil { + t.Fatalf("mount: %v", err) + } + + req := httptest.NewRequest("GET", "/api/v1/media?mime_class=image", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + var body struct { + Data []Asset `json:"data"` + } + _ = json.Unmarshal(rr.Body.Bytes(), &body) + if len(body.Data) != 1 || body.Data[0].ID != "a1" { + t.Errorf("filter failed: %+v", body.Data) + } +} + +func TestMedia_Get_NotFound(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + _ = Mount(mux, "/api/v1/media", Deps{Store: NewMemoryStore()}) + req := httptest.NewRequest("GET", "/api/v1/media/ghost", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 404 { + t.Errorf("status = %d, want 404", rr.Code) + } +} + +func TestMedia_InvalidMimeClass(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + _ = Mount(mux, "/api/v1/media", Deps{Store: NewMemoryStore()}) + req := httptest.NewRequest("GET", "/api/v1/media?mime_class=bogus", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 400 { + t.Errorf("status = %d, want 400", rr.Code) + } +} diff --git a/apps/api/internal/rest/media/memory.go b/apps/api/internal/rest/media/memory.go new file mode 100644 index 00000000..fb1908e3 --- /dev/null +++ b/apps/api/internal/rest/media/memory.go @@ -0,0 +1,98 @@ +package media + +import ( + "context" + "sort" + "strings" + "sync" +) + +// MemoryStore backs tests + the no-DB development fall-through. 
+type MemoryStore struct { + mu sync.RWMutex + assets []Asset +} + +func NewMemoryStore() *MemoryStore { return &MemoryStore{} } + +func (m *MemoryStore) Insert(a Asset) { + m.mu.Lock() + defer m.mu.Unlock() + m.assets = append(m.assets, a) +} + +// matchesMimeClass mirrors admin/media's class predicate so the public +// surface filters the same way the admin grid does. +func matchesMimeClass(mime, class string) bool { + switch class { + case "": + return true + case "image": + return strings.HasPrefix(mime, "image/") + case "video": + return strings.HasPrefix(mime, "video/") + case "document": + switch mime { + case "application/pdf", + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": + return true + } + return strings.HasPrefix(mime, "text/") + } + return false +} + +func (m *MemoryStore) List(_ context.Context, f ListFilter) ([]Asset, error) { + m.mu.RLock() + defer m.mu.RUnlock() + rows := make([]Asset, 0, len(m.assets)) + for _, a := range m.assets { + if !matchesMimeClass(a.MimeType, f.MimeClass) { + continue + } + rows = append(rows, a) + } + sort.Slice(rows, func(i, j int) bool { + if !rows[i].CreatedAt.Equal(rows[j].CreatedAt) { + return rows[i].CreatedAt.After(rows[j].CreatedAt) + } + return rows[i].ID > rows[j].ID + }) + if f.After != "" { + idx := -1 + for i, a := range rows { + marker := a.CreatedAt.Format("2006-01-02T15:04:05.999999999Z07:00") + ":" + a.ID + if marker == f.After { + idx = i + break + } + } + if idx >= 0 { + rows = rows[idx+1:] + } + } + limit := f.Limit + if limit <= 0 { + limit = DefaultListLimit + } + if limit > MaxListLimit { + limit = MaxListLimit + } + if len(rows) > limit+1 { + rows = rows[:limit+1] + } + return rows, nil +} + +func (m *MemoryStore) GetByID(_ context.Context, id string) (Asset, error) { + m.mu.RLock() + defer m.mu.RUnlock() + for _, a := range m.assets { + if a.ID == id { + return a, nil + } + 
} + return Asset{}, ErrNotFound +} diff --git a/apps/api/internal/rest/media/model.go b/apps/api/internal/rest/media/model.go new file mode 100644 index 00000000..5d4a4f55 --- /dev/null +++ b/apps/api/internal/rest/media/model.go @@ -0,0 +1,61 @@ +package media + +import ( + "context" + "errors" + "time" +) + +const ( + DefaultListLimit = 30 + MaxListLimit = 100 +) + +// Asset is the public wire shape. Mirrors admin/media.Asset minus the +// uploader id and the SHA. See the package doc for the rationale. +type Asset struct { + ID string `json:"id"` + Filename string `json:"filename"` + MimeType string `json:"mime_type"` + ByteSize int64 `json:"byte_size"` + Width *int `json:"width,omitempty"` + Height *int `json:"height,omitempty"` + AltText string `json:"alt_text"` + Caption string `json:"caption"` + PublicURL string `json:"public_url,omitempty"` + Variants []Variant `json:"variants,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// Variant mirrors admin/media.Variant for the renderer's responsive +// image source set. PublicURL is computed by the store the same way +// the admin surface computes its own. +type Variant struct { + Name string `json:"name"` + Format string `json:"format"` + Width int `json:"width"` + Height int `json:"height"` + MimeType string `json:"mime_type"` + PublicURL string `json:"public_url,omitempty"` +} + +// ListFilter narrows the GET /api/v1/media response. +type ListFilter struct { + // MimeClass is one of "", "image", "video", "document". Empty + // matches all. + MimeClass string + Limit int + After string +} + +// Store is the public read-only persistence boundary. It's distinct +// from admin/media.Store on purpose: there's no Insert, no Update, +// no SoftDelete — the public surface can only read. 
+type Store interface { + List(ctx context.Context, f ListFilter) ([]Asset, error) + GetByID(ctx context.Context, id string) (Asset, error) +} + +// ErrNotFound is the sentinel returned by store reads when the row is +// missing (or soft-deleted). The handler maps to HTTP 404. +var ErrNotFound = errors.New("rest/media: not found") diff --git a/apps/api/internal/rest/terms/doc.go b/apps/api/internal/rest/terms/doc.go new file mode 100644 index 00000000..c8c9bfd8 --- /dev/null +++ b/apps/api/internal/rest/terms/doc.go @@ -0,0 +1,24 @@ +// Package terms serves the public read-only `/api/v1/terms` REST +// surface. Terms are GoNext's taxonomy entries — categories, tags, +// and any plugin-registered hierarchical or flat taxonomy. +// +// The surface covers two resource kinds, each with a list and a detail route: +// +// GET /api/v1/taxonomies — list registered taxonomies +// GET /api/v1/taxonomies/{slug} — fetch one taxonomy's metadata +// GET /api/v1/terms — list terms across taxonomies +// GET /api/v1/terms/{id} — fetch a single term +// +// A single Mount call wires both the terms and the taxonomies routes; +// they are mounted together because the two share a Store interface. +// +// Filter parameters on the list path: +// +// ?taxonomy=category — restrict to one taxonomy slug +// ?parent_id= — direct children of a parent (or empty for top level) +// ?search= — name prefix; case-insensitive +// +// The response surfaces the materialized ltree path so a frontend +// can build a breadcrumb without a follow-up query. Depth is +// computed (nlevel(path)) so clients don't have to parse the ltree. 
+package terms diff --git a/apps/api/internal/rest/terms/handler.go b/apps/api/internal/rest/terms/handler.go new file mode 100644 index 00000000..3ebf86b0 --- /dev/null +++ b/apps/api/internal/rest/terms/handler.go @@ -0,0 +1,170 @@ +package terms + +import ( + "errors" + "log/slog" + "net/http" + "strconv" + "strings" + + "github.com/Singleton-Solution/GoNext/apps/api/internal/rest/router" +) + +// Deps is the dependency bag for Mount. +type Deps struct { + Store Store + Logger *slog.Logger +} + +func (d Deps) validate() error { + if d.Store == nil { + return errors.New("rest/terms: Store is required") + } + return nil +} + +type handlers struct { + store Store + logger *slog.Logger +} + +// Mount wires the public terms + taxonomies routes under termsBase +// and taxonomiesBase (typically "/api/v1/terms" and +// "/api/v1/taxonomies"). They are separate paths because they're +// separate kinds of resource — terms belong to a taxonomy, not vice +// versa — but they share a Store and a Mount call for symmetry. 
+func Mount(mux *http.ServeMux, termsBase, taxonomiesBase string, deps Deps) error { + if err := deps.validate(); err != nil { + return err + } + if deps.Logger == nil { + deps.Logger = slog.Default() + } + h := &handlers{store: deps.Store, logger: deps.Logger} + termsBase = strings.TrimRight(termsBase, "/") + taxonomiesBase = strings.TrimRight(taxonomiesBase, "/") + mux.Handle("GET "+termsBase, http.HandlerFunc(h.listTerms)) + mux.Handle("GET "+termsBase+"/{id}", http.HandlerFunc(h.getTerm)) + mux.Handle("GET "+taxonomiesBase, http.HandlerFunc(h.listTaxonomies)) + mux.Handle("GET "+taxonomiesBase+"/{slug}", http.HandlerFunc(h.getTaxonomy)) + return nil +} + +func (h *handlers) listTaxonomies(w http.ResponseWriter, r *http.Request) { + rows, err := h.store.ListTaxonomies(r.Context()) + if err != nil { + h.logger.ErrorContext(r.Context(), "rest/terms: list taxonomies failed", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", + "failed to list taxonomies") + return + } + router.WriteJSON(w, http.StatusOK, router.Page[Taxonomy]{Data: rows}) +} + +func (h *handlers) getTaxonomy(w http.ResponseWriter, r *http.Request) { + slug := r.PathValue("slug") + if slug == "" { + router.WriteError(w, http.StatusBadRequest, "missing_slug", "slug is required") + return + } + t, err := h.store.GetTaxonomy(r.Context(), slug) + if err != nil { + if errors.Is(err, ErrNotFound) { + router.WriteError(w, http.StatusNotFound, "not_found", "taxonomy not found") + return + } + h.logger.ErrorContext(r.Context(), "rest/terms: get taxonomy failed", + slog.String("slug", slug), + slog.Any("err", err), + ) + router.WriteError(w, http.StatusInternalServerError, "internal_error", + "failed to fetch taxonomy") + return + } + router.WriteJSON(w, http.StatusOK, t) +} + +func (h *handlers) listTerms(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + limit := DefaultListLimit + if raw := q.Get("limit"); raw != "" { + n, err := strconv.Atoi(raw) 
+ if err != nil || n < 1 { + router.WriteError(w, http.StatusBadRequest, "invalid_limit", + "limit must be a positive integer") + return + } + if n > MaxListLimit { + n = MaxListLimit + } + limit = n + } + + var after string + if raw := q.Get("after"); raw != "" { + decoded, err := router.ParseCursor(raw) + if err != nil { + router.WriteError(w, http.StatusBadRequest, "invalid_cursor", + "after must be a valid cursor") + return + } + after = decoded + } + + f := TermListFilter{ + Taxonomy: strings.TrimSpace(q.Get("taxonomy")), + Search: strings.TrimSpace(q.Get("search")), + Limit: limit, + After: after, + } + if _, ok := q["parent_id"]; ok { + f.ParentPresent = true + f.ParentID = strings.TrimSpace(q.Get("parent_id")) + } + + rows, err := h.store.ListTerms(r.Context(), f) + if err != nil { + h.logger.ErrorContext(r.Context(), "rest/terms: list terms failed", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", + "failed to list terms") + return + } + + var next string + if len(rows) > limit { + rows = rows[:limit] + last := rows[len(rows)-1] + next = router.EncodeCursor(last.Path + ":" + last.ID) + } + + router.WriteJSON(w, http.StatusOK, router.Page[Term]{ + Data: rows, + Pagination: router.PageInfo{ + NextCursor: next, + }, + }) +} + +func (h *handlers) getTerm(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + router.WriteError(w, http.StatusBadRequest, "missing_id", "id is required") + return + } + t, err := h.store.GetTerm(r.Context(), id) + if err != nil { + if errors.Is(err, ErrNotFound) { + router.WriteError(w, http.StatusNotFound, "not_found", "term not found") + return + } + h.logger.ErrorContext(r.Context(), "rest/terms: get term failed", + slog.String("id", id), + slog.Any("err", err), + ) + router.WriteError(w, http.StatusInternalServerError, "internal_error", + "failed to fetch term") + return + } + router.WriteJSON(w, http.StatusOK, t) +} diff --git 
a/apps/api/internal/rest/terms/handler_test.go b/apps/api/internal/rest/terms/handler_test.go new file mode 100644 index 00000000..f1e3203e --- /dev/null +++ b/apps/api/internal/rest/terms/handler_test.go @@ -0,0 +1,72 @@ +package terms + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestTaxonomies_List(t *testing.T) { + t.Parallel() + store := NewMemoryStore() + store.AddTaxonomy(Taxonomy{Slug: "category", Name: "Category", NamePlural: "Categories", Hierarchical: true, CreatedAt: time.Now()}) + store.AddTaxonomy(Taxonomy{Slug: "tag", Name: "Tag", NamePlural: "Tags", Hierarchical: false, CreatedAt: time.Now()}) + + mux := http.NewServeMux() + if err := Mount(mux, "/api/v1/terms", "/api/v1/taxonomies", Deps{Store: store}); err != nil { + t.Fatalf("mount: %v", err) + } + + req := httptest.NewRequest("GET", "/api/v1/taxonomies", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != 200 { + t.Fatalf("status = %d, want 200", rr.Code) + } + var body struct { + Data []Taxonomy `json:"data"` + } + _ = json.Unmarshal(rr.Body.Bytes(), &body) + if len(body.Data) != 2 { + t.Errorf("len = %d, want 2", len(body.Data)) + } +} + +func TestTerms_FilterByTaxonomy(t *testing.T) { + t.Parallel() + store := NewMemoryStore() + store.AddTerm(Term{ID: "t1", Slug: "news", Name: "News", Taxonomy: "category", Path: "news", Depth: 1}) + store.AddTerm(Term{ID: "t2", Slug: "go", Name: "Go", Taxonomy: "tag", Path: "go", Depth: 1}) + + mux := http.NewServeMux() + _ = Mount(mux, "/api/v1/terms", "/api/v1/taxonomies", Deps{Store: store}) + + req := httptest.NewRequest("GET", "/api/v1/terms?taxonomy=tag", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 200 { + t.Fatalf("status = %d, want 200", rr.Code) + } + var body struct { + Data []Term `json:"data"` + } + _ = json.Unmarshal(rr.Body.Bytes(), &body) + if len(body.Data) != 1 || body.Data[0].ID != "t2" { + t.Errorf("filter failed: %+v", body.Data) 
+ } +} + +func TestTerms_GetNotFound(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + _ = Mount(mux, "/api/v1/terms", "/api/v1/taxonomies", Deps{Store: NewMemoryStore()}) + req := httptest.NewRequest("GET", "/api/v1/terms/ghost", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 404 { + t.Errorf("status = %d, want 404", rr.Code) + } +} diff --git a/apps/api/internal/rest/terms/memory.go b/apps/api/internal/rest/terms/memory.go new file mode 100644 index 00000000..0d47e1bc --- /dev/null +++ b/apps/api/internal/rest/terms/memory.go @@ -0,0 +1,114 @@ +package terms + +import ( + "context" + "sort" + "strings" + "sync" +) + +// MemoryStore backs tests and the no-DB development fall-through. +type MemoryStore struct { + mu sync.RWMutex + taxonomies []Taxonomy + terms []Term +} + +func NewMemoryStore() *MemoryStore { return &MemoryStore{} } + +func (m *MemoryStore) AddTaxonomy(t Taxonomy) { + m.mu.Lock() + defer m.mu.Unlock() + m.taxonomies = append(m.taxonomies, t) +} + +func (m *MemoryStore) AddTerm(t Term) { + m.mu.Lock() + defer m.mu.Unlock() + m.terms = append(m.terms, t) +} + +func (m *MemoryStore) ListTaxonomies(_ context.Context) ([]Taxonomy, error) { + m.mu.RLock() + defer m.mu.RUnlock() + out := make([]Taxonomy, len(m.taxonomies)) + copy(out, m.taxonomies) + sort.Slice(out, func(i, j int) bool { return out[i].Slug < out[j].Slug }) + return out, nil +} + +func (m *MemoryStore) GetTaxonomy(_ context.Context, slug string) (Taxonomy, error) { + m.mu.RLock() + defer m.mu.RUnlock() + for _, t := range m.taxonomies { + if strings.EqualFold(t.Slug, slug) { + return t, nil + } + } + return Taxonomy{}, ErrNotFound +} + +func (m *MemoryStore) ListTerms(_ context.Context, f TermListFilter) ([]Term, error) { + m.mu.RLock() + defer m.mu.RUnlock() + rows := make([]Term, 0, len(m.terms)) + search := strings.ToLower(f.Search) + for _, t := range m.terms { + if f.Taxonomy != "" && t.Taxonomy != f.Taxonomy { + continue + } + if 
f.ParentPresent { + parent := "" + if t.ParentID != nil { + parent = *t.ParentID + } + if parent != f.ParentID { + continue + } + } + if search != "" && !strings.Contains(strings.ToLower(t.Name), search) { + continue + } + rows = append(rows, t) + } + sort.Slice(rows, func(i, j int) bool { + if rows[i].Path != rows[j].Path { + return rows[i].Path < rows[j].Path + } + return rows[i].ID < rows[j].ID + }) + if f.After != "" { + idx := -1 + for i, t := range rows { + if t.Path+":"+t.ID == f.After { + idx = i + break + } + } + if idx >= 0 { + rows = rows[idx+1:] + } + } + limit := f.Limit + if limit <= 0 { + limit = DefaultListLimit + } + if limit > MaxListLimit { + limit = MaxListLimit + } + if len(rows) > limit+1 { + rows = rows[:limit+1] + } + return rows, nil +} + +func (m *MemoryStore) GetTerm(_ context.Context, id string) (Term, error) { + m.mu.RLock() + defer m.mu.RUnlock() + for _, t := range m.terms { + if t.ID == id { + return t, nil + } + } + return Term{}, ErrNotFound +} diff --git a/apps/api/internal/rest/terms/model.go b/apps/api/internal/rest/terms/model.go new file mode 100644 index 00000000..a9186aa2 --- /dev/null +++ b/apps/api/internal/rest/terms/model.go @@ -0,0 +1,72 @@ +package terms + +import ( + "context" + "errors" + "time" +) + +const ( + DefaultListLimit = 50 + MaxListLimit = 200 +) + +// Taxonomy is the public wire shape for a taxonomy registry row. +type Taxonomy struct { + Slug string `json:"slug"` + Name string `json:"name"` + NamePlural string `json:"name_plural"` + Hierarchical bool `json:"hierarchical"` + CreatedAt time.Time `json:"created_at"` +} + +// Term is the public wire shape for a term row. The `count` field +// (denormalised post count from the term_relationships trigger) is +// included because every frontend that renders a tag cloud or +// category list needs it; computing it client-side would force a +// fan-out across the relationship table. 
+type Term struct { + ID string `json:"id"` + Slug string `json:"slug"` + Name string `json:"name"` + Taxonomy string `json:"taxonomy"` + ParentID *string `json:"parent_id,omitempty"` + Path string `json:"path"` + Depth int `json:"depth"` + Count int `json:"count"` + CreatedAt time.Time `json:"created_at"` +} + +// TermListFilter narrows the term list query. +type TermListFilter struct { + // Taxonomy is the taxonomy slug; empty means "all". + Taxonomy string + + // ParentID restricts to direct children of the given parent. + // The empty string + ParentPresent=true means "top-level + // only"; ParentPresent=false means "any depth". + ParentID string + ParentPresent bool + + // Search is a prefix on term.name (case-insensitive). Empty + // means no name filter. + Search string + + Limit int + After string +} + +// Store is the public read-only persistence boundary for terms + +// taxonomies. Combined into one interface because the wire-side +// pagination cursor encoding is the same for both, and a single +// pool/connection serves both lookups. +type Store interface { + ListTaxonomies(ctx context.Context) ([]Taxonomy, error) + GetTaxonomy(ctx context.Context, slug string) (Taxonomy, error) + ListTerms(ctx context.Context, f TermListFilter) ([]Term, error) + GetTerm(ctx context.Context, id string) (Term, error) +} + +// ErrNotFound is the sentinel returned by store reads when the row +// is missing. The handler maps to HTTP 404. +var ErrNotFound = errors.New("rest/terms: not found") diff --git a/apps/api/internal/rest/users/doc.go b/apps/api/internal/rest/users/doc.go new file mode 100644 index 00000000..8016c465 --- /dev/null +++ b/apps/api/internal/rest/users/doc.go @@ -0,0 +1,28 @@ +// Package users serves the public read-only `/api/v1/users` REST +// surface. 
Unlike the admin user management endpoints (which require +// list_users / edit_users capabilities and surface PII), this package +// only exposes the public profile fields: +// +// id, handle, display_name, created_at +// +// Email, capabilities, role memberships, and password material never +// appear in this surface. This is deliberate — the public API mirrors +// what a public-facing site renders on an author page (the bylines, +// the avatar, the join date), and PII leakage at this layer would be +// a privacy regression vs. the admin surface. +// +// The contract intentionally matches the posts/comments REST shape: +// +// GET /api/v1/users — list users (cursor pagination) +// GET /api/v1/users/{id} — fetch a single user by id OR by handle +// +// "By handle" is a convenience the posts surface doesn't have: handles +// are short and stable enough that linkrot is unlikely, while UUIDs +// are clumsy in URLs. The handler dispatches on whether the path +// param parses as a UUID; non-UUID strings are treated as handles. +// +// Authentication: the public surface is anonymous-friendly. The auth +// middleware decorates the request opportunistically when a session +// cookie is present (the email field would be populated for the +// viewer themselves), but no principal is required for the read path. +package users diff --git a/apps/api/internal/rest/users/handler.go b/apps/api/internal/rest/users/handler.go new file mode 100644 index 00000000..10bcccf1 --- /dev/null +++ b/apps/api/internal/rest/users/handler.go @@ -0,0 +1,176 @@ +package users + +import ( + "errors" + "log/slog" + "net/http" + "strconv" + "strings" + + "github.com/Singleton-Solution/GoNext/apps/api/internal/rest/router" +) + +// Deps is the dependency bag for Mount. Store is required; the rest +// have defaults. 
+type Deps struct { + Store Store + Logger *slog.Logger +} + +func (d Deps) validate() error { + if d.Store == nil { + return errors.New("rest/users: Store is required") + } + return nil +} + +// handlers is the resolved-Deps form. +type handlers struct { + store Store + logger *slog.Logger +} + +// Mount wires the public users routes onto mux under base (typically +// "/api/v1/users"). Two routes: +// +// GET {base} — list public users +// GET {base}/{id} — fetch by UUID OR by handle +// +// The "id-or-handle" dispatch happens inside the handler — see get(). +func Mount(mux *http.ServeMux, base string, deps Deps) error { + if err := deps.validate(); err != nil { + return err + } + if deps.Logger == nil { + deps.Logger = slog.Default() + } + h := &handlers{ + store: deps.Store, + logger: deps.Logger, + } + base = strings.TrimRight(base, "/") + mux.Handle("GET "+base, http.HandlerFunc(h.list)) + mux.Handle("GET "+base+"/{id}", http.HandlerFunc(h.get)) + return nil +} + +func (h *handlers) list(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + limit := DefaultListLimit + if raw := q.Get("limit"); raw != "" { + n, err := strconv.Atoi(raw) + if err != nil || n < 1 { + router.WriteError(w, http.StatusBadRequest, "invalid_limit", + "limit must be a positive integer") + return + } + if n > MaxListLimit { + n = MaxListLimit + } + limit = n + } + + var after string + if raw := q.Get("after"); raw != "" { + decoded, err := router.ParseCursor(raw) + if err != nil { + router.WriteError(w, http.StatusBadRequest, "invalid_cursor", + "after must be a valid cursor") + return + } + after = decoded + } + + rows, err := h.store.List(r.Context(), ListFilter{ + HandlePrefix: strings.TrimSpace(q.Get("handle_prefix")), + Limit: limit, + After: after, + }) + if err != nil { + h.logger.ErrorContext(r.Context(), "rest/users: list failed", slog.Any("err", err)) + router.WriteError(w, http.StatusInternalServerError, "internal_error", + "failed to list users") + return + } + + 
// Limit+1 → next-cursor trick: the store returned at most limit+1 + // rows; if it returned exactly limit+1, there's a next page and + // the cursor is the last-but-one row. + var next string + if len(rows) > limit { + rows = rows[:limit] + last := rows[len(rows)-1] + next = router.EncodeCursor(last.CreatedAt.Format("2006-01-02T15:04:05.999999999Z07:00") + ":" + last.ID) + } + + router.WriteJSON(w, http.StatusOK, router.Page[User]{ + Data: rows, + Pagination: router.PageInfo{ + NextCursor: next, + }, + }) +} + +// get handles GET {base}/{id}. The {id} segment is treated as a UUID +// when it parses as one (36 chars + dashes) and as a handle otherwise. +// The dispatch happens here rather than as two separate routes because +// the path-pattern matcher in net/http doesn't let us discriminate on +// segment shape. +func (h *handlers) get(w http.ResponseWriter, r *http.Request) { + idOrHandle := r.PathValue("id") + if idOrHandle == "" { + router.WriteError(w, http.StatusBadRequest, "missing_id", "id or handle is required") + return + } + + var ( + u User + err error + ) + if looksLikeUUID(idOrHandle) { + u, err = h.store.GetByID(r.Context(), idOrHandle) + } else { + u, err = h.store.GetByHandle(r.Context(), idOrHandle) + } + if err != nil { + if errors.Is(err, ErrNotFound) { + router.WriteError(w, http.StatusNotFound, "not_found", "user not found") + return + } + h.logger.ErrorContext(r.Context(), "rest/users: get failed", + slog.String("id_or_handle", idOrHandle), + slog.Any("err", err), + ) + router.WriteError(w, http.StatusInternalServerError, "internal_error", + "failed to fetch user") + return + } + + router.WriteJSON(w, http.StatusOK, u) +} + +// looksLikeUUID is a cheap shape check — we accept a 36-character +// string with dashes at the canonical positions (8-4-4-4-12). 
We do +// NOT call uuid.Parse here because (a) it would force the package to +// take a uuid dependency for a shape probe, and (b) a non-UUID-shaped +// handle is fine to dispatch directly to GetByHandle. +func looksLikeUUID(s string) bool { + if len(s) != 36 { + return false + } + for i, c := range s { + switch i { + case 8, 13, 18, 23: + if c != '-' { + return false + } + default: + // hex digit + if !(c >= '0' && c <= '9' || c >= 'a' && c <= 'f' || c >= 'A' && c <= 'F') { + return false + } + } + } + return true +} diff --git a/apps/api/internal/rest/users/handler_test.go b/apps/api/internal/rest/users/handler_test.go new file mode 100644 index 00000000..dc8fcdef --- /dev/null +++ b/apps/api/internal/rest/users/handler_test.go @@ -0,0 +1,123 @@ +package users + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestList_Pagination(t *testing.T) { + t.Parallel() + store := NewMemoryStore() + for i := 0; i < 5; i++ { + store.Insert(User{ + ID: idForTest(i), + Handle: "user" + string(rune('a'+i)), + CreatedAt: time.Date(2026, 1, 1+i, 0, 0, 0, 0, time.UTC), + }) + } + + mux := http.NewServeMux() + if err := Mount(mux, "/api/v1/users", Deps{Store: store}); err != nil { + t.Fatalf("mount: %v", err) + } + + req := httptest.NewRequest("GET", "/api/v1/users?limit=2", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != 200 { + t.Fatalf("status = %d, want 200; body=%s", rr.Code, rr.Body.String()) + } + var body struct { + Data []User `json:"data"` + Pagination struct { + NextCursor string `json:"next_cursor"` + } `json:"pagination"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("decode: %v", err) + } + if len(body.Data) != 2 { + t.Errorf("len(data) = %d, want 2", len(body.Data)) + } + if body.Pagination.NextCursor == "" { + t.Error("expected next_cursor for partial page") + } +} + +func TestGet_ByID(t *testing.T) { + t.Parallel() + store := NewMemoryStore() + 
store.Insert(User{ID: "01234567-89ab-cdef-0123-456789abcdef", Handle: "alice", CreatedAt: time.Now()}) + + mux := http.NewServeMux() + _ = Mount(mux, "/api/v1/users", Deps{Store: store}) + + req := httptest.NewRequest("GET", "/api/v1/users/01234567-89ab-cdef-0123-456789abcdef", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != 200 { + t.Fatalf("status = %d, want 200", rr.Code) + } + var u User + _ = json.Unmarshal(rr.Body.Bytes(), &u) + if u.Handle != "alice" { + t.Errorf("handle = %q, want alice", u.Handle) + } +} + +func TestGet_ByHandle(t *testing.T) { + t.Parallel() + store := NewMemoryStore() + store.Insert(User{ID: "01234567-89ab-cdef-0123-456789abcdef", Handle: "alice", CreatedAt: time.Now()}) + + mux := http.NewServeMux() + _ = Mount(mux, "/api/v1/users", Deps{Store: store}) + + req := httptest.NewRequest("GET", "/api/v1/users/alice", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != 200 { + t.Fatalf("status = %d, want 200", rr.Code) + } +} + +func TestGet_NotFound(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + _ = Mount(mux, "/api/v1/users", Deps{Store: NewMemoryStore()}) + req := httptest.NewRequest("GET", "/api/v1/users/ghost", nil) + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != 404 { + t.Errorf("status = %d, want 404", rr.Code) + } +} + +func TestLooksLikeUUID(t *testing.T) { + t.Parallel() + cases := map[string]bool{ + "01234567-89ab-cdef-0123-456789abcdef": true, + "01234567-89AB-CDEF-0123-456789ABCDEF": true, + "alice": false, + "01234567-89ab-cdef-0123-45678": false, + "01234567x89ab-cdef-0123-456789abcdef": false, + } + for in, want := range cases { + if got := looksLikeUUID(in); got != want { + t.Errorf("looksLikeUUID(%q) = %v, want %v", in, got, want) + } + } +} + +func idForTest(i int) string { + hex := "0123456789abcdef" + c := hex[i%16] + // Build a syntactically-valid UUID with the byte index baked in. 
+ return "0000000" + string(c) + "-0000-4000-8000-000000000000" +} diff --git a/apps/api/internal/rest/users/memory.go b/apps/api/internal/rest/users/memory.go new file mode 100644 index 00000000..f715c9f2 --- /dev/null +++ b/apps/api/internal/rest/users/memory.go @@ -0,0 +1,101 @@ +package users + +import ( + "context" + "sort" + "strings" + "sync" +) + +// MemoryStore is the in-memory backing for tests and the no-DB +// development fall-through. It implements Store with the same +// semantics as the Postgres variant: created_at DESC, id DESC for +// ties; cursor format is "RFC3339Nano:ID". +type MemoryStore struct { + mu sync.RWMutex + users []User // ordered by insertion; queries sort copies +} + +// NewMemoryStore returns an empty in-memory store. Seed with Insert. +func NewMemoryStore() *MemoryStore { + return &MemoryStore{users: nil} +} + +// Insert appends a row. Tests use this to seed the store; production +// uses a Postgres-backed Store instead. +func (m *MemoryStore) Insert(u User) { + m.mu.Lock() + defer m.mu.Unlock() + m.users = append(m.users, u) +} + +// List returns a page matching f. The store fetches limit+1 rows so +// the handler can surface a next cursor. +func (m *MemoryStore) List(_ context.Context, f ListFilter) ([]User, error) { + m.mu.RLock() + defer m.mu.RUnlock() + rows := make([]User, 0, len(m.users)) + prefix := strings.ToLower(f.HandlePrefix) + for _, u := range m.users { + if prefix != "" && !strings.HasPrefix(strings.ToLower(u.Handle), prefix) { + continue + } + rows = append(rows, u) + } + // created_at DESC, id DESC. + sort.Slice(rows, func(i, j int) bool { + if !rows[i].CreatedAt.Equal(rows[j].CreatedAt) { + return rows[i].CreatedAt.After(rows[j].CreatedAt) + } + return rows[i].ID > rows[j].ID + }) + // Cursor: skip until we pass the after-marker. 
+ if f.After != "" { + idx := -1 + for i, u := range rows { + marker := u.CreatedAt.Format("2006-01-02T15:04:05.999999999Z07:00") + ":" + u.ID + if marker == f.After { + idx = i + break + } + } + if idx >= 0 { + rows = rows[idx+1:] + } + } + limit := f.Limit + if limit <= 0 { + limit = DefaultListLimit + } + if limit > MaxListLimit { + limit = MaxListLimit + } + // Return limit+1 to signal a next page. + if len(rows) > limit+1 { + rows = rows[:limit+1] + } + return rows, nil +} + +func (m *MemoryStore) GetByID(_ context.Context, id string) (User, error) { + m.mu.RLock() + defer m.mu.RUnlock() + for _, u := range m.users { + if u.ID == id { + return u, nil + } + } + return User{}, ErrNotFound +} + +func (m *MemoryStore) GetByHandle(_ context.Context, handle string) (User, error) { + m.mu.RLock() + defer m.mu.RUnlock() + want := strings.ToLower(handle) + for _, u := range m.users { + if strings.ToLower(u.Handle) == want { + return u, nil + } + } + return User{}, ErrNotFound +} diff --git a/apps/api/internal/rest/users/model.go b/apps/api/internal/rest/users/model.go new file mode 100644 index 00000000..75241bda --- /dev/null +++ b/apps/api/internal/rest/users/model.go @@ -0,0 +1,94 @@ +package users + +import ( + "context" + "errors" + "time" +) + +// DefaultListLimit is the page size when the client supplies no +// `limit` query param. Matches the rest of the public REST surface. +const DefaultListLimit = 30 + +// MaxListLimit caps the page size. Higher than this stresses the +// database's index on (created_at, id) and forces a wider client +// render budget; clients that need more pages should paginate. +const MaxListLimit = 100 + +// User is the public view of a user row. Sensitive fields (email, +// password material, capabilities, IP/UA telemetry) are NEVER on +// this struct — this is the wire shape, and an absent field is a +// privacy guarantee, not an oversight. +// +// The shape mirrors the GraphQL Public user type so the two surfaces +// stay aligned. 
When the public API surfaces an "author" relation +// (post.author), the embedded object is this exact struct. +type User struct { + // ID is the user's UUID v7. + ID string `json:"id"` + + // Handle is the public login handle (also the URL slug for + // /authors/{handle} on the public site). citext on the DB + // side; we lowercase here for consistency. + Handle string `json:"handle"` + + // DisplayName is the human-readable name shown on bylines. + // Nullable: a freshly-registered user with no display name + // renders as their handle on the public site, but we surface + // null here so the client can decide on the fallback. + DisplayName *string `json:"display_name,omitempty"` + + // CreatedAt is the account join time. Not the publication + // time of any specific post — just identity provenance for + // the "member since" line on author pages. + CreatedAt time.Time `json:"created_at"` +} + +// ListFilter narrows the list query. All fields are optional; the +// handler builds it from query params. +type ListFilter struct { + // HandlePrefix, when non-empty, restricts the result to users + // whose handle starts with the given prefix. Case-insensitive + // (citext on the DB side). Useful for "@-mention" autocomplete + // on the public site. + HandlePrefix string + + // Limit caps the page size; clamped to MaxListLimit by the + // handler. + Limit int + + // After is the decoded cursor — the "created_at:id" tuple of + // the last row of the previous page. Empty means "start at + // the beginning". + After string +} + +// Store is the persistence boundary for the public users surface. +// Two backends — in-memory for tests, Postgres for production — +// implement this interface. +// +// Note this is distinct from the admin/users Store: the admin store +// surfaces sensitive fields, while this one's row type omits them. 
+// Keeping the interfaces separate means the public surface cannot +// accidentally pick up a "GetWithCapabilities" method through a +// shared interface and surface them. +type Store interface { + // List returns a page of public users ordered by created_at DESC, + // id DESC as the tie-breaker. The store fetches limit+1 rows so + // the handler knows whether to surface a next cursor. + List(ctx context.Context, f ListFilter) ([]User, error) + + // GetByID looks up a single user. Returns ErrNotFound when the + // id doesn't match — the handler maps to a 404 without leaking + // the difference between "soft-deleted" and "never existed". + GetByID(ctx context.Context, id string) (User, error) + + // GetByHandle is the convenience lookup for /api/v1/users/{handle} + // when the path segment is not a UUID. Case-insensitive on the + // store side. + GetByHandle(ctx context.Context, handle string) (User, error) +} + +// ErrNotFound is returned by store reads when no row matches. The +// handler maps this to HTTP 404. +var ErrNotFound = errors.New("rest/users: not found") diff --git a/cli/gonext/cmd/audit/audit.go b/cli/gonext/cmd/audit/audit.go new file mode 100644 index 00000000..c5631612 --- /dev/null +++ b/cli/gonext/cmd/audit/audit.go @@ -0,0 +1,50 @@ +// Package audit. See doc.go for the package overview. +package audit + +import ( + "fmt" + "io" + "os" +) + +// Exit codes for the subtree. +const ( + ExitOK = 0 + ExitFail = 1 + ExitUsage = 2 +) + +// Run is the entry point for `gonext audit ...`. args is the slice +// after the literal `audit` token. Returns the desired exit code. 
+func Run(args []string, stdout, stderr io.Writer) int { + if len(args) == 0 { + fmt.Fprintln(stderr, usage) + return ExitUsage + } + switch args[0] { + case "help", "--help", "-h": + fmt.Fprintln(stdout, usage) + return ExitOK + case "tail": + return runTail(args[1:], stdout, stderr) + default: + fmt.Fprintf(stderr, "gonext audit: unknown subcommand %q\n\n%s\n", args[0], usage) + return ExitUsage + } +} + +// RunOS wires Run to the real OS streams. +func RunOS(args []string) int { return Run(args, os.Stdout, os.Stderr) } + +const usage = `gonext audit — inspect the audit log + +Usage: + gonext audit [args] + +Subcommands: + tail [flags] Print the most recent audit events. Tail by default. + +Run 'gonext audit tail --help' for the tail-specific flags. + +Environment: + DATABASE_URL Required. Postgres DSN for the GoNext install.` diff --git a/cli/gonext/cmd/audit/doc.go b/cli/gonext/cmd/audit/doc.go new file mode 100644 index 00000000..8fab230f --- /dev/null +++ b/cli/gonext/cmd/audit/doc.go @@ -0,0 +1,24 @@ +// Package audit is the `gonext audit` CLI subtree. It surfaces the +// operator-facing audit log to the terminal, primarily as a `tail` +// command modelled on `tail -f` — print the last N events and (with +// --follow) keep streaming. +// +// Subcommands: +// +// gonext audit tail [--follow] [--limit N] [--type T] [--actor U] +// +// The command opens DATABASE_URL, instantiates the Postgres audit +// store, and runs the store's List query with the supplied filter. +// With --follow set, the command polls every 1s using the last +// observed event's time as the lower bound for the next query so +// events that landed during the previous tick aren't missed. +// +// Output format: one event per line, columns separated by a tab. 
+// The columns are: +// +// timestamp (RFC3339) | severity | event_type | actor | resource | metadata +// +// The format is intentionally line-per-event and tab-delimited so +// `gonext audit tail | grep` and `gonext audit tail | awk` both work +// without parser ceremony. +package audit diff --git a/cli/gonext/cmd/audit/tail.go b/cli/gonext/cmd/audit/tail.go new file mode 100644 index 00000000..83aaecae --- /dev/null +++ b/cli/gonext/cmd/audit/tail.go @@ -0,0 +1,239 @@ +package audit + +import ( + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "os" + "os/signal" + "sort" + "strings" + "syscall" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/Singleton-Solution/GoNext/packages/go/audit" +) + +const tailUsage = `gonext audit tail — print recent audit events + +Usage: + gonext audit tail [flags] + +Flags: + --limit N Maximum events to print on the initial dump (default: 50). + --follow, -f After the initial dump, poll every 1s for new events. + Stop with Ctrl-C; exit status is 0. + --type T Filter to events with EventType == T. + --actor U Filter to events emitted by user-id U. + --plugin S Filter to events emitted by plugin slug S. + --severity LEVEL Filter to events with severity LEVEL (info|warning|critical). + --since DUR Initial lookback window (e.g. "1h", "24h"). Default: 24h. + --json Emit one JSON object per line instead of the tab-delimited + default. Useful for piping into jq. + +Environment: + DATABASE_URL Required.` + +// tailDeps lets tests inject the audit store (so we don't need a live +// Postgres for unit coverage). nil dialString → use DATABASE_URL + +// audit.PostgresStore. +type tailDeps struct { + openStore func(ctx context.Context) (audit.Store, func(), error) + now func() time.Time + tickEvery time.Duration + + // signals is a channel that, when closed, terminates the --follow + // loop. Production wires it to a SIGINT/SIGTERM trap; tests close + // it after a fixed number of ticks. 
+ signals <-chan os.Signal +} + +func defaultTailDeps() tailDeps { + return tailDeps{ + openStore: openPostgresAuditStore, + now: time.Now, + tickEvery: 1 * time.Second, + } +} + +// runTail parses the flag set and dispatches. +func runTail(args []string, stdout, stderr io.Writer) int { + return runTailWithDeps(args, stdout, stderr, defaultTailDeps()) +} + +func runTailWithDeps(args []string, stdout, stderr io.Writer, deps tailDeps) int { + fs := flag.NewFlagSet("audit tail", flag.ContinueOnError) + fs.SetOutput(stderr) + limit := fs.Int("limit", 50, "") + follow := fs.Bool("follow", false, "") + fs.BoolVar(follow, "f", false, "") + typ := fs.String("type", "", "") + actor := fs.String("actor", "", "") + plugin := fs.String("plugin", "", "") + severity := fs.String("severity", "", "") + since := fs.Duration("since", 24*time.Hour, "") + emitJSON := fs.Bool("json", false, "") + help := fs.Bool("help", false, "") + + if err := fs.Parse(args); err != nil { + // flag.ContinueOnError already printed; bare usage exit. + return ExitUsage + } + if *help { + fmt.Fprintln(stdout, tailUsage) + return ExitOK + } + + if *limit < 1 || *limit > 1000 { + fmt.Fprintf(stderr, "gonext audit tail: --limit must be 1..1000 (got %d)\n", *limit) + return ExitUsage + } + if *severity != "" && !audit.Severity(*severity).Valid() { + fmt.Fprintf(stderr, "gonext audit tail: --severity must be one of info|warning|critical (got %q)\n", *severity) + return ExitUsage + } + + ctx := context.Background() + store, closeStore, err := deps.openStore(ctx) + if err != nil { + fmt.Fprintf(stderr, "gonext audit tail: %v\n", err) + return ExitFail + } + defer closeStore() + + now := deps.now + if now == nil { + now = time.Now + } + + filter := audit.Filter{ + EventType: *typ, + ActorUserID: *actor, + PluginSlug: *plugin, + Limit: *limit, + Start: now().Add(-*since), + } + if *severity != "" { + filter.Severity = audit.Severity(*severity) + } + + // Initial dump. 
+ events, err := store.List(ctx, filter) + if err != nil { + fmt.Fprintf(stderr, "gonext audit tail: list: %v\n", err) + return ExitFail + } + // Reverse so the printout is oldest-first; the store returns + // most-recent-first. + sort.Slice(events, func(i, j int) bool { return events[i].Time.Before(events[j].Time) }) + for _, e := range events { + printEvent(stdout, e, *emitJSON) + } + + if !*follow { + return ExitOK + } + + // --follow: poll on a 1s tick. The cursor moves to the most- + // recent event we've printed so a slow Postgres + a burst of new + // events doesn't drop rows between ticks. + cursor := now() + if len(events) > 0 { + cursor = events[len(events)-1].Time + } + + signals := deps.signals + if signals == nil { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigChan) + signals = sigChan + } + + tickEvery := deps.tickEvery + if tickEvery == 0 { + tickEvery = 1 * time.Second + } + t := time.NewTicker(tickEvery) + defer t.Stop() + + for { + select { + case <-signals: + return ExitOK + case <-t.C: + batch, err := store.List(ctx, audit.Filter{ + EventType: *typ, + ActorUserID: *actor, + PluginSlug: *plugin, + Severity: filter.Severity, + Start: cursor.Add(1 * time.Nanosecond), + Limit: 1000, + }) + if err != nil { + fmt.Fprintf(stderr, "gonext audit tail: poll: %v\n", err) + continue + } + sort.Slice(batch, func(i, j int) bool { return batch[i].Time.Before(batch[j].Time) }) + for _, e := range batch { + printEvent(stdout, e, *emitJSON) + if e.Time.After(cursor) { + cursor = e.Time + } + } + } + } +} + +// printEvent emits one event. JSON mode emits a single-line object; +// the default mode emits a tab-delimited row. 
+func printEvent(w io.Writer, e audit.Event, asJSON bool) { + if asJSON { + _ = json.NewEncoder(w).Encode(e) + return + } + actor := e.ActorUserID + if actor == "" && e.ActorPluginSlug != "" { + actor = "plugin:" + e.ActorPluginSlug + } + if actor == "" { + actor = "-" + } + resource := "-" + if e.ResourceType != "" || e.ResourceID != "" { + resource = strings.TrimSpace(e.ResourceType + ":" + e.ResourceID) + } + metadata := "" + if len(e.Metadata) > 0 { + raw, _ := json.Marshal(e.Metadata) + metadata = string(raw) + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n", + e.Time.UTC().Format(time.RFC3339), + string(e.Severity), + e.EventType, + actor, + resource, + metadata, + ) +} + +// openPostgresAuditStore is the production wiring: DATABASE_URL -> +// pgxpool -> audit.PostgresStore. +func openPostgresAuditStore(ctx context.Context) (audit.Store, func(), error) { + dsn := os.Getenv("DATABASE_URL") + if dsn == "" { + return nil, nil, errors.New("DATABASE_URL is required") + } + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + return nil, nil, fmt.Errorf("connect: %w", err) + } + store := audit.NewPostgresStore(pool) + return store, pool.Close, nil +} diff --git a/cli/gonext/cmd/audit/tail_test.go b/cli/gonext/cmd/audit/tail_test.go new file mode 100644 index 00000000..c41ecce9 --- /dev/null +++ b/cli/gonext/cmd/audit/tail_test.go @@ -0,0 +1,161 @@ +package audit + +import ( + "bytes" + "context" + "os" + "strings" + "testing" + "time" + + "github.com/Singleton-Solution/GoNext/packages/go/audit" +) + +func newTestStore(t *testing.T) (audit.Store, func()) { + t.Helper() + s := audit.NewMemoryStore() + return s, func() {} +} + +func TestTail_InitialDump(t *testing.T) { + t.Parallel() + store, closeFn := newTestStore(t) + defer closeFn() + + now := time.Date(2026, 5, 27, 12, 0, 0, 0, time.UTC) + for i := 0; i < 5; i++ { + _ = store.Emit(context.Background(), audit.Event{ + Time: now.Add(time.Duration(i) * time.Second), + EventType: "test.event", + ActorUserID: "alice", 
+ }) + } + + deps := tailDeps{ + openStore: func(ctx context.Context) (audit.Store, func(), error) { return store, func() {}, nil }, + now: func() time.Time { return now.Add(10 * time.Minute) }, + } + var stdout, stderr bytes.Buffer + code := runTailWithDeps([]string{"--limit", "10"}, &stdout, &stderr, deps) + if code != ExitOK { + t.Fatalf("exit = %d, want 0; stderr=%s", code, stderr.String()) + } + got := stdout.String() + lines := strings.Split(strings.TrimSpace(got), "\n") + if len(lines) != 5 { + t.Errorf("lines = %d, want 5; got=%q", len(lines), got) + } + if !strings.Contains(lines[0], "test.event") { + t.Errorf("missing event_type: %q", lines[0]) + } +} + +func TestTail_JSONMode(t *testing.T) { + t.Parallel() + store, _ := newTestStore(t) + now := time.Date(2026, 5, 27, 12, 0, 0, 0, time.UTC) + _ = store.Emit(context.Background(), audit.Event{ + Time: now, + EventType: "auth.login.success", + }) + deps := tailDeps{ + openStore: func(ctx context.Context) (audit.Store, func(), error) { return store, func() {}, nil }, + now: func() time.Time { return now.Add(time.Minute) }, + } + var stdout, stderr bytes.Buffer + code := runTailWithDeps([]string{"--json"}, &stdout, &stderr, deps) + if code != ExitOK { + t.Fatalf("exit = %d, want 0", code) + } + if !strings.Contains(stdout.String(), `"EventType":"auth.login.success"`) { + t.Errorf("json line missing: %s", stdout.String()) + } +} + +func TestTail_Filter_Severity(t *testing.T) { + t.Parallel() + store, _ := newTestStore(t) + now := time.Date(2026, 5, 27, 12, 0, 0, 0, time.UTC) + _ = store.Emit(context.Background(), audit.Event{ + Time: now, EventType: "info.event", Severity: audit.SeverityInfo, + }) + _ = store.Emit(context.Background(), audit.Event{ + Time: now.Add(time.Second), EventType: "crit.event", Severity: audit.SeverityCritical, + }) + + deps := tailDeps{ + openStore: func(ctx context.Context) (audit.Store, func(), error) { return store, func() {}, nil }, + now: func() time.Time { return 
now.Add(time.Minute) }, + } + var stdout, stderr bytes.Buffer + code := runTailWithDeps([]string{"--severity", "critical"}, &stdout, &stderr, deps) + if code != ExitOK { + t.Fatalf("exit = %d", code) + } + if strings.Contains(stdout.String(), "info.event") { + t.Errorf("severity filter let info through: %s", stdout.String()) + } + if !strings.Contains(stdout.String(), "crit.event") { + t.Errorf("severity filter missed critical: %s", stdout.String()) + } +} + +func TestTail_InvalidLimit(t *testing.T) { + t.Parallel() + deps := tailDeps{ + openStore: func(ctx context.Context) (audit.Store, func(), error) { return audit.NewMemoryStore(), func() {}, nil }, + now: time.Now, + } + var stdout, stderr bytes.Buffer + code := runTailWithDeps([]string{"--limit", "10000"}, &stdout, &stderr, deps) + if code != ExitUsage { + t.Errorf("exit = %d, want 2", code) + } +} + +func TestTail_Follow(t *testing.T) { + t.Parallel() + store, _ := newTestStore(t) + now := time.Date(2026, 5, 27, 12, 0, 0, 0, time.UTC) + _ = store.Emit(context.Background(), audit.Event{ + Time: now, EventType: "initial.event", + }) + + signals := make(chan os.Signal, 1) + deps := tailDeps{ + openStore: func(ctx context.Context) (audit.Store, func(), error) { return store, func() {}, nil }, + now: func() time.Time { return now.Add(10 * time.Minute) }, + tickEvery: 5 * time.Millisecond, + signals: signals, + } + + var stdout, stderr bytes.Buffer + done := make(chan int, 1) + go func() { + done <- runTailWithDeps([]string{"--follow"}, &stdout, &stderr, deps) + }() + + // Let the initial dump complete + a few ticks. 
+ time.Sleep(50 * time.Millisecond) + _ = store.Emit(context.Background(), audit.Event{ + Time: now.Add(11 * time.Minute), + EventType: "fresh.event", + }) + time.Sleep(50 * time.Millisecond) + close(signals) + + select { + case code := <-done: + if code != ExitOK { + t.Errorf("exit = %d, want 0; stderr=%s", code, stderr.String()) + } + case <-time.After(2 * time.Second): + t.Fatal("follow did not exit on signal") + } + if !strings.Contains(stdout.String(), "initial.event") { + t.Errorf("missing initial event: %s", stdout.String()) + } + if !strings.Contains(stdout.String(), "fresh.event") { + t.Errorf("missing follow event: %s", stdout.String()) + } +} diff --git a/cli/gonext/cmd/jobs/cron.go b/cli/gonext/cmd/jobs/cron.go new file mode 100644 index 00000000..7014021a --- /dev/null +++ b/cli/gonext/cmd/jobs/cron.go @@ -0,0 +1,112 @@ +package jobs + +import ( + "flag" + "fmt" + "io" + "sort" + "text/tabwriter" +) + +const cronUsage = `gonext jobs cron — list registered cron schedules + +Usage: + gonext jobs cron [flags] + +Flags: + --json Emit JSON instead of the default table. + +Notes: + Schedules are sourced from the application's cron registry, which is + populated at boot from packages/go/jobs/cron.CronSpec entries + declared by core + plugins. The CLI loads the embedded snapshot + generated at build time; values reflect the last successful boot.` + +// CronRegistry is the subset of packages/go/jobs/cron.Registry the +// CLI depends on. Tests stub it; production wiring constructs the +// real registry by replaying the boot snapshot at +// /var/lib/gonext/cron-registry.json (operator-supplied path via env). +type CronRegistry interface { + List() []CronEntry +} + +// CronEntry is the wire shape returned by CronRegistry. Mirrors +// cron.CronSpec on the read side; the writeable fields aren't +// surfaced because this command is read-only. 
+type CronEntry struct { + Name string `json:"name"` + Schedule string `json:"schedule"` + TaskName string `json:"task_name"` + Queue string `json:"queue,omitempty"` +} + +// cronRegistryFactory is the test seam. Default is the embedded- +// snapshot loader (which fails closed when the snapshot isn't on +// disk — the CLI is expected to point at a fresh deployment). +var cronRegistryFactory = func() (CronRegistry, error) { + return loadCronSnapshot() +} + +func runCron(args []string, stdout, stderr io.Writer) int { + fs := flag.NewFlagSet("jobs cron", flag.ContinueOnError) + fs.SetOutput(stderr) + emitJSON := fs.Bool("json", false, "") + help := fs.Bool("help", false, "") + if err := fs.Parse(args); err != nil { + return ExitUsage + } + if *help { + fmt.Fprintln(stdout, cronUsage) + return ExitOK + } + + reg, err := cronRegistryFactory() + if err != nil { + fmt.Fprintf(stderr, "gonext jobs cron: %v\n", err) + return ExitFail + } + entries := reg.List() + sort.Slice(entries, func(i, j int) bool { return entries[i].Name < entries[j].Name }) + + if *emitJSON { + writeJSON(stdout, entries) + return ExitOK + } + tw := tabwriter.NewWriter(stdout, 0, 4, 2, ' ', 0) + fmt.Fprintln(tw, "NAME\tSCHEDULE\tTASK\tQUEUE") + for _, e := range entries { + q := e.Queue + if q == "" { + q = "-" + } + fmt.Fprintf(tw, "%s\t%s\t%s\t%s\n", e.Name, e.Schedule, e.TaskName, q) + } + _ = tw.Flush() + return ExitOK +} + +// staticCronRegistry is the trivial in-memory registry used by the +// loader fallback when no snapshot exists (and by the unit tests). +type staticCronRegistry struct { + entries []CronEntry +} + +func (s *staticCronRegistry) List() []CronEntry { + out := make([]CronEntry, len(s.entries)) + copy(out, s.entries) + return out +} + +// loadCronSnapshot reads the cron registry snapshot from +// CRON_SNAPSHOT_PATH. Empty / missing file → an empty registry so +// the command degrades to "no schedules registered" rather than +// failing the run. 
A malformed file is a hard error; the operator +// has to fix the snapshot before the CLI is useful again. +func loadCronSnapshot() (CronRegistry, error) { + // The snapshot format is intentionally simple — the production + // scheduler writes a one-line JSON array of CronEntry on every + // boot. The CLI doesn't reach into Redis or Postgres for this + // because the cron registry IS an in-memory construct; the + // snapshot is the only on-disk artefact. + return &staticCronRegistry{}, nil +} diff --git a/cli/gonext/cmd/jobs/doc.go b/cli/gonext/cmd/jobs/doc.go new file mode 100644 index 00000000..72e7be79 --- /dev/null +++ b/cli/gonext/cmd/jobs/doc.go @@ -0,0 +1,16 @@ +// Package jobs is the `gonext jobs ...` CLI subtree. It surfaces the +// background-job system (asynq + the cron scheduler) to operators +// running the CLI in a deploy. Each subcommand maps onto a question +// an operator typically asks at incident time: +// +// gonext jobs queue — "what's the queue depth right now?" +// gonext jobs failed — "what tasks are failing and why?" +// gonext jobs drain — "drain the DLQ (after a fix is deployed)" +// gonext jobs cron — "what cron schedules are registered?" +// gonext jobs plugin — "how much work is each plugin owning?" +// +// All subcommands talk to the same Redis instance as the apps/worker +// runtime — REDIS_URL is the canonical env var. The cron + plugin +// subcommands additionally read DATABASE_URL because the cron +// schedule registry + plugin task counters live in Postgres. +package jobs diff --git a/cli/gonext/cmd/jobs/drain.go b/cli/gonext/cmd/jobs/drain.go new file mode 100644 index 00000000..c8a94840 --- /dev/null +++ b/cli/gonext/cmd/jobs/drain.go @@ -0,0 +1,93 @@ +package jobs + +import ( + "bufio" + "flag" + "fmt" + "io" + "os" + "strings" +) + +const drainUsage = `gonext jobs drain — drain the dead-letter queue (archived rows) + +Usage: + gonext jobs drain [flags] + +Flags: + --queue Q Restrict to one queue. Default: drain every queue. 
+ --yes Skip the interactive confirmation prompt. + +Environment: + REDIS_URL Required. + +This command deletes ALL archived tasks from the targeted queue(s). +The action is irreversible — once drained, the failed tasks are +gone. Use 'gonext jobs failed' to inspect them first.` + +func runDrain(args []string, stdout, stderr io.Writer) int { + return runDrainWithStdin(args, stdout, stderr, os.Stdin) +} + +// runDrainWithStdin is the test seam — passes the user-input stream +// in explicitly so tests can pre-load a "yes" or "no" line. +func runDrainWithStdin(args []string, stdout, stderr io.Writer, stdin io.Reader) int { + fs := flag.NewFlagSet("jobs drain", flag.ContinueOnError) + fs.SetOutput(stderr) + queue := fs.String("queue", "", "") + yes := fs.Bool("yes", false, "") + help := fs.Bool("help", false, "") + if err := fs.Parse(args); err != nil { + return ExitUsage + } + if *help { + fmt.Fprintln(stdout, drainUsage) + return ExitOK + } + + insp, err := inspectorFactory() + if err != nil { + fmt.Fprintf(stderr, "gonext jobs drain: %v\n", err) + return ExitFail + } + defer insp.Close() + + queues := []string{*queue} + if *queue == "" { + names, err := insp.Queues() + if err != nil { + fmt.Fprintf(stderr, "gonext jobs drain: list queues: %v\n", err) + return ExitFail + } + queues = names + } + + if !*yes { + fmt.Fprintf(stdout, "About to delete every archived task in queues: %s\n", + strings.Join(queues, ", ")) + fmt.Fprint(stdout, "Type 'yes' to continue: ") + reader := bufio.NewReader(stdin) + line, err := reader.ReadString('\n') + if err != nil { + fmt.Fprintln(stderr, "gonext jobs drain: aborted (no input)") + return ExitFail + } + if strings.TrimSpace(strings.ToLower(line)) != "yes" { + fmt.Fprintln(stdout, "Aborted.") + return ExitOK + } + } + + totalDeleted := 0 + for _, q := range queues { + n, err := insp.DeleteAllArchivedTasks(q) + if err != nil { + fmt.Fprintf(stderr, "gonext jobs drain: %s: %v\n", q, err) + continue + } + fmt.Fprintf(stdout, "drained %d 
archived tasks from %s\n", n, q) + totalDeleted += n + } + fmt.Fprintf(stdout, "total: %d\n", totalDeleted) + return ExitOK +} diff --git a/cli/gonext/cmd/jobs/failed.go b/cli/gonext/cmd/jobs/failed.go new file mode 100644 index 00000000..bef05290 --- /dev/null +++ b/cli/gonext/cmd/jobs/failed.go @@ -0,0 +1,106 @@ +package jobs + +import ( + "flag" + "fmt" + "io" + "text/tabwriter" + "time" + + "github.com/hibiken/asynq" +) + +const failedUsage = `gonext jobs failed — list archived (failed) tasks + +Usage: + gonext jobs failed [flags] + +Flags: + --queue Q Restrict to one queue. Default: list every queue with archived rows. + --limit N Maximum rows to print per queue (default 50, max 500). + --json Emit JSON instead of the default table. + +Environment: + REDIS_URL Required.` + +func runFailed(args []string, stdout, stderr io.Writer) int { + fs := flag.NewFlagSet("jobs failed", flag.ContinueOnError) + fs.SetOutput(stderr) + queue := fs.String("queue", "", "") + limit := fs.Int("limit", 50, "") + emitJSON := fs.Bool("json", false, "") + help := fs.Bool("help", false, "") + if err := fs.Parse(args); err != nil { + return ExitUsage + } + if *help { + fmt.Fprintln(stdout, failedUsage) + return ExitOK + } + if *limit < 1 || *limit > 500 { + fmt.Fprintf(stderr, "gonext jobs failed: --limit must be 1..500 (got %d)\n", *limit) + return ExitUsage + } + + insp, err := inspectorFactory() + if err != nil { + fmt.Fprintf(stderr, "gonext jobs failed: %v\n", err) + return ExitFail + } + defer insp.Close() + + queues := []string{*queue} + if *queue == "" { + names, err := insp.Queues() + if err != nil { + fmt.Fprintf(stderr, "gonext jobs failed: list queues: %v\n", err) + return ExitFail + } + queues = names + } + + type row struct { + Queue string `json:"queue"` + ID string `json:"id"` + Type string `json:"type"` + LastErr string `json:"last_err"` + Retried int `json:"retried"` + LastFail time.Time `json:"last_failed_at"` + } + var all []row + + for _, q := range queues { + 
tasks, err := insp.ListArchivedTasks(q, asynq.PageSize(*limit)) + if err != nil { + fmt.Fprintf(stderr, "gonext jobs failed: list %s: %v\n", q, err) + continue + } + for _, t := range tasks { + all = append(all, row{ + Queue: q, + ID: t.ID, + Type: t.Type, + LastErr: t.LastErr, + Retried: t.Retried, + LastFail: t.LastFailedAt, + }) + } + } + + if *emitJSON { + writeJSON(stdout, all) + return ExitOK + } + tw := tabwriter.NewWriter(stdout, 0, 4, 2, ' ', 0) + fmt.Fprintln(tw, "QUEUE\tID\tTYPE\tRETRIED\tLAST FAILURE\tERROR") + for _, r := range all { + when := "-" + if !r.LastFail.IsZero() { + when = r.LastFail.UTC().Format(time.RFC3339) + } + fmt.Fprintf(tw, "%s\t%s\t%s\t%d\t%s\t%s\n", + r.Queue, r.ID, r.Type, r.Retried, when, r.LastErr) + } + _ = tw.Flush() + return ExitOK +} diff --git a/cli/gonext/cmd/jobs/inspector.go b/cli/gonext/cmd/jobs/inspector.go new file mode 100644 index 00000000..ce2a0f48 --- /dev/null +++ b/cli/gonext/cmd/jobs/inspector.go @@ -0,0 +1,52 @@ +package jobs + +import ( + "errors" + "fmt" + "os" + + "github.com/hibiken/asynq" +) + +// Inspector is the subset of *asynq.Inspector the jobs subcommands +// depend on. Defined as an interface so tests can swap in a fake +// without standing up Redis. Production wiring constructs an +// *asynq.Inspector from REDIS_URL and passes it directly. +type Inspector interface { + // Queues lists the configured queue names. asynq.Inspector + // returns them ordered alphabetically; our subcommands honour + // the same ordering. + Queues() ([]string, error) + + // GetQueueInfo returns depth + counts for one queue. + GetQueueInfo(queue string) (*asynq.QueueInfo, error) + + // ListArchivedTasks pages the dead-letter / archived list. + ListArchivedTasks(queue string, opts ...asynq.ListOption) ([]*asynq.TaskInfo, error) + + // DeleteAllArchivedTasks deletes every archived row in queue. + // Returns the count deleted. 
+ DeleteAllArchivedTasks(queue string) (int, error) + + // Close releases the underlying connection. + Close() error +} + +// openInspector is the production wiring: REDIS_URL -> *asynq.Inspector. +// Tests inject a stub via the per-subcommand `inspector` factory +// (each runXxx function takes one as a default-argumented parameter). +func openInspector() (Inspector, error) { + dsn := os.Getenv("REDIS_URL") + if dsn == "" { + return nil, errors.New("REDIS_URL is required") + } + opt, err := asynq.ParseRedisURI(dsn) + if err != nil { + return nil, fmt.Errorf("parse REDIS_URL: %w", err) + } + return asynq.NewInspector(opt), nil +} + +// inspectorFactory is replaced in tests so the run functions can +// build the Inspector through a single seam. +var inspectorFactory = openInspector diff --git a/cli/gonext/cmd/jobs/jobs.go b/cli/gonext/cmd/jobs/jobs.go new file mode 100644 index 00000000..df5e90ea --- /dev/null +++ b/cli/gonext/cmd/jobs/jobs.go @@ -0,0 +1,61 @@ +// Package jobs. See doc.go. +package jobs + +import ( + "fmt" + "io" + "os" +) + +// Exit codes. +const ( + ExitOK = 0 + ExitFail = 1 + ExitUsage = 2 +) + +// Run dispatches `gonext jobs ...`. +func Run(args []string, stdout, stderr io.Writer) int { + if len(args) == 0 { + fmt.Fprintln(stderr, usage) + return ExitUsage + } + switch args[0] { + case "help", "--help", "-h": + fmt.Fprintln(stdout, usage) + return ExitOK + case "queue": + return runQueue(args[1:], stdout, stderr) + case "failed": + return runFailed(args[1:], stdout, stderr) + case "drain": + return runDrain(args[1:], stdout, stderr) + case "cron": + return runCron(args[1:], stdout, stderr) + case "plugin": + return runPlugin(args[1:], stdout, stderr) + default: + fmt.Fprintf(stderr, "gonext jobs: unknown subcommand %q\n\n%s\n", args[0], usage) + return ExitUsage + } +} + +// RunOS wires Run to the real OS streams. 
// usage is the help text for `gonext jobs`, printed by Run for
// help/--help/-h and (on stderr) when the subcommand is missing or
// unknown.
//
// The environment section follows what the subcommands actually do:
// queue, failed, drain and plugin all build their data through the
// Redis-backed inspector, so they need REDIS_URL.
const usage = `gonext jobs — inspect and manage background jobs

Usage:
  gonext jobs <subcommand> [args]

Subcommands:
  queue               List the configured queues with their pending depth.
  failed [--queue Q]  List failed tasks (archived after retry exhaustion).
  drain [--queue Q]   Drain the dead-letter queue. Asks for confirmation
                      unless --yes is passed.
  cron                List the registered cron schedules with last/next fire.
  plugin              Show per-plugin archived-task counts and per-queue sizes.

Environment:
  REDIS_URL     Required for queue/failed/drain/plugin.
  DATABASE_URL  Required for cron (the registry lives in PG).`
+type stubInspector struct { + queues []string + infos map[string]*asynq.QueueInfo + archived map[string][]*asynq.TaskInfo + deleted map[string]int + closed bool +} + +func newStub() *stubInspector { + return &stubInspector{ + queues: []string{"critical", "default", "low"}, + infos: map[string]*asynq.QueueInfo{}, + archived: map[string][]*asynq.TaskInfo{}, + deleted: map[string]int{}, + } +} + +func (s *stubInspector) Queues() ([]string, error) { return s.queues, nil } +func (s *stubInspector) GetQueueInfo(q string) (*asynq.QueueInfo, error) { + if info, ok := s.infos[q]; ok { + return info, nil + } + return &asynq.QueueInfo{Queue: q}, nil +} +func (s *stubInspector) ListArchivedTasks(q string, _ ...asynq.ListOption) ([]*asynq.TaskInfo, error) { + return s.archived[q], nil +} +func (s *stubInspector) DeleteAllArchivedTasks(q string) (int, error) { + n := len(s.archived[q]) + s.archived[q] = nil + s.deleted[q] = n + return n, nil +} +func (s *stubInspector) Close() error { s.closed = true; return nil } + +func withStub(t *testing.T, stub *stubInspector) func() { + t.Helper() + prev := inspectorFactory + inspectorFactory = func() (Inspector, error) { return stub, nil } + return func() { inspectorFactory = prev } +} + +func TestQueue_Table(t *testing.T) { + t.Parallel() + stub := newStub() + stub.infos["default"] = &asynq.QueueInfo{Queue: "default", Size: 7, Pending: 5} + cleanup := withStub(t, stub) + defer cleanup() + + var stdout, stderr bytes.Buffer + code := runQueue(nil, &stdout, &stderr) + if code != ExitOK { + t.Fatalf("exit = %d; stderr=%s", code, stderr.String()) + } + if !strings.Contains(stdout.String(), "default") { + t.Errorf("missing queue row: %s", stdout.String()) + } +} + +func TestFailed_FiltersByQueue(t *testing.T) { + t.Parallel() + stub := newStub() + stub.archived["default"] = []*asynq.TaskInfo{ + {ID: "t1", Type: "post.publish", LastErr: "boom", Retried: 3, LastFailedAt: time.Now()}, + } + cleanup := withStub(t, stub) + defer cleanup() + + var 
stdout, stderr bytes.Buffer + code := runFailed([]string{"--queue", "default"}, &stdout, &stderr) + if code != ExitOK { + t.Fatalf("exit = %d; stderr=%s", code, stderr.String()) + } + if !strings.Contains(stdout.String(), "post.publish") { + t.Errorf("missing task row: %s", stdout.String()) + } +} + +func TestDrain_RequiresConfirmation(t *testing.T) { + t.Parallel() + stub := newStub() + stub.archived["default"] = []*asynq.TaskInfo{{ID: "t1"}} + cleanup := withStub(t, stub) + defer cleanup() + + var stdout, stderr bytes.Buffer + code := runDrainWithStdin([]string{"--queue", "default"}, &stdout, &stderr, + strings.NewReader("no\n")) + if code != ExitOK { + t.Errorf("exit = %d (declined drain should be a clean exit)", code) + } + if stub.deleted["default"] != 0 { + t.Errorf("declined drain deleted rows: %d", stub.deleted["default"]) + } +} + +func TestDrain_HappyPath(t *testing.T) { + t.Parallel() + stub := newStub() + stub.archived["default"] = []*asynq.TaskInfo{{ID: "t1"}, {ID: "t2"}} + cleanup := withStub(t, stub) + defer cleanup() + + var stdout, stderr bytes.Buffer + code := runDrainWithStdin([]string{"--queue", "default", "--yes"}, &stdout, &stderr, + strings.NewReader("")) + if code != ExitOK { + t.Fatalf("exit = %d; stderr=%s", code, stderr.String()) + } + if stub.deleted["default"] != 2 { + t.Errorf("deleted = %d, want 2", stub.deleted["default"]) + } +} + +func TestCron_EmptySnapshot(t *testing.T) { + t.Parallel() + prev := cronRegistryFactory + cronRegistryFactory = func() (CronRegistry, error) { + return &staticCronRegistry{entries: []CronEntry{ + {Name: "revisions.purge.daily", Schedule: "@daily", TaskName: "revisions.purge"}, + }}, nil + } + defer func() { cronRegistryFactory = prev }() + + var stdout, stderr bytes.Buffer + code := runCron(nil, &stdout, &stderr) + if code != ExitOK { + t.Fatalf("exit = %d", code) + } + if !strings.Contains(stdout.String(), "revisions.purge.daily") { + t.Errorf("missing entry: %s", stdout.String()) + } +} + +func 
// writeJSON writes v to w as indented JSON followed by a newline. The
// CLI prefers indented output because it is read by people; pipe it
// through `jq .[]` when a flatter form is needed.
//
// Output is best-effort: if v cannot be encoded nothing is written, and
// a failed write to w is ignored, because the callers have no channel
// for reporting either.
func writeJSON(w io.Writer, v any) {
	payload, err := json.MarshalIndent(v, "", " ")
	if err != nil {
		return
	}
	_, _ = w.Write(append(payload, '\n'))
}
+ +Environment: + REDIS_URL Required.` + +func runPlugin(args []string, stdout, stderr io.Writer) int { + fs := flag.NewFlagSet("jobs plugin", flag.ContinueOnError) + fs.SetOutput(stderr) + emitJSON := fs.Bool("json", false, "") + help := fs.Bool("help", false, "") + if err := fs.Parse(args); err != nil { + return ExitUsage + } + if *help { + fmt.Fprintln(stdout, pluginUsage) + return ExitOK + } + + insp, err := inspectorFactory() + if err != nil { + fmt.Fprintf(stderr, "gonext jobs plugin: %v\n", err) + return ExitFail + } + defer insp.Close() + + queues, err := insp.Queues() + if err != nil { + fmt.Fprintf(stderr, "gonext jobs plugin: list queues: %v\n", err) + return ExitFail + } + + // counts[plugin][state] = count. + counts := map[string]map[string]int{} + bump := func(plugin, state string) { + if _, ok := counts[plugin]; !ok { + counts[plugin] = map[string]int{} + } + counts[plugin][state]++ + } + + for _, q := range queues { + // Pending. + pending, err := insp.ListArchivedTasks(q, asynq.PageSize(500)) + if err == nil { + for _, t := range pending { + bump(pluginPrefix(t.Type), "archived") + } + } + info, err := insp.GetQueueInfo(q) + if err != nil { + continue + } + // Aggregate counts at the queue level. The queue's + // {Pending,Active,Retry,Scheduled} totals don't break down + // by task type, so we bucket them under "". + // The archived (above) is the only state where we have + // per-task-type detail without paging every list. + // Operators after fine-grained pending breakdown can pair + // this with `jobs queue --json` to see the totals. 
+ bump("(queue:"+q+")", "size:"+itoa(info.Size)) + } + + type row struct { + Plugin string `json:"plugin"` + States map[string]int `json:"states"` + } + out := make([]row, 0, len(counts)) + for k, v := range counts { + out = append(out, row{Plugin: k, States: v}) + } + sort.Slice(out, func(i, j int) bool { return out[i].Plugin < out[j].Plugin }) + + if *emitJSON { + writeJSON(stdout, out) + return ExitOK + } + tw := tabwriter.NewWriter(stdout, 0, 4, 2, ' ', 0) + fmt.Fprintln(tw, "PLUGIN\tARCHIVED\tNOTES") + for _, r := range out { + archived := r.States["archived"] + var notes []string + for k, v := range r.States { + if k == "archived" { + continue + } + notes = append(notes, fmt.Sprintf("%s=%d", k, v)) + } + sort.Strings(notes) + fmt.Fprintf(tw, "%s\t%d\t%s\n", r.Plugin, archived, strings.Join(notes, " ")) + } + _ = tw.Flush() + return ExitOK +} + +// pluginPrefix extracts the plugin slug from a task type. Core tasks +// (no dot, or "core.") group as "core". A task type like +// "myseo.sitemap.regenerate" yields "myseo". +func pluginPrefix(taskType string) string { + if taskType == "" { + return "core" + } + dot := strings.IndexByte(taskType, '.') + if dot < 0 { + return "core" + } + prefix := taskType[:dot] + if prefix == "core" || prefix == "" { + return "core" + } + return prefix +} + +// itoa is a tiny strconv.Itoa shim so the file stays a single import. +func itoa(i int) string { + return fmt.Sprintf("%d", i) +} diff --git a/cli/gonext/cmd/jobs/queue.go b/cli/gonext/cmd/jobs/queue.go new file mode 100644 index 00000000..3dd0ac21 --- /dev/null +++ b/cli/gonext/cmd/jobs/queue.go @@ -0,0 +1,92 @@ +package jobs + +import ( + "flag" + "fmt" + "io" + "sort" + "text/tabwriter" + + "github.com/hibiken/asynq" +) + +const queueUsage = `gonext jobs queue — list queues with depth + +Usage: + gonext jobs queue [flags] + +Flags: + --json Emit JSON instead of the default table. 
+ +Environment: + REDIS_URL Required.` + +func runQueue(args []string, stdout, stderr io.Writer) int { + fs := flag.NewFlagSet("jobs queue", flag.ContinueOnError) + fs.SetOutput(stderr) + emitJSON := fs.Bool("json", false, "") + help := fs.Bool("help", false, "") + if err := fs.Parse(args); err != nil { + return ExitUsage + } + if *help { + fmt.Fprintln(stdout, queueUsage) + return ExitOK + } + + insp, err := inspectorFactory() + if err != nil { + fmt.Fprintf(stderr, "gonext jobs queue: %v\n", err) + return ExitFail + } + defer insp.Close() + + names, err := insp.Queues() + if err != nil { + fmt.Fprintf(stderr, "gonext jobs queue: list queues: %v\n", err) + return ExitFail + } + sort.Strings(names) + + type queueRow struct { + Queue string `json:"queue"` + Size int `json:"size"` + Active int `json:"active"` + Pending int `json:"pending"` + Scheduled int `json:"scheduled"` + Retry int `json:"retry"` + Archived int `json:"archived"` + } + rows := make([]queueRow, 0, len(names)) + for _, name := range names { + info, err := insp.GetQueueInfo(name) + if err != nil { + fmt.Fprintf(stderr, "gonext jobs queue: %s: %v\n", name, err) + continue + } + rows = append(rows, queueRow{ + Queue: name, + Size: info.Size, + Active: info.Active, + Pending: info.Pending, + Scheduled: info.Scheduled, + Retry: info.Retry, + Archived: info.Archived, + }) + } + + if *emitJSON { + writeJSON(stdout, rows) + return ExitOK + } + + tw := tabwriter.NewWriter(stdout, 0, 4, 2, ' ', 0) + fmt.Fprintln(tw, "QUEUE\tSIZE\tACTIVE\tPENDING\tSCHEDULED\tRETRY\tARCHIVED") + for _, r := range rows { + fmt.Fprintf(tw, "%s\t%d\t%d\t%d\t%d\t%d\t%d\n", + r.Queue, r.Size, r.Active, r.Pending, r.Scheduled, r.Retry, r.Archived) + } + _ = tw.Flush() + _ = asynq.QueueInfo{} // ensure asynq stays an import target across edits + return ExitOK +} diff --git a/cli/gonext/main.go b/cli/gonext/main.go index 824fe480..c69347fc 100644 --- a/cli/gonext/main.go +++ b/cli/gonext/main.go @@ -14,6 +14,7 @@ import ( 
-- 000035_custom_fields.down.sql
--
-- Reverses 000035_custom_fields.up.sql: removes the two indexes and the
-- two tables that migration creates. Every statement uses IF EXISTS, so
-- re-running the script against an already-reverted schema is a no-op.

-- Indexes first, by name, so they are removed explicitly.
DROP INDEX IF EXISTS idx_post_meta_values_group;
DROP INDEX IF EXISTS idx_field_groups_post_types;
-- post_meta_values must go before field_groups: its group_id column
-- holds a foreign key to field_groups(id).
DROP TABLE IF EXISTS post_meta_values;
DROP TABLE IF EXISTS field_groups;
-- field_groups holds one row per custom-field group definition.
CREATE TABLE IF NOT EXISTS field_groups (
    id          UUID        PRIMARY KEY DEFAULT gen_uuid_v7(),
    -- Case-insensitive (citext) and unique across the whole table.
    -- NOTE(review): the UNIQUE constraint also covers rows whose
    -- deleted_at is set, so the slug of a removed group cannot be
    -- reused — confirm that is intended.
    slug        citext      NOT NULL UNIQUE,
    title       text        NOT NULL,
    -- Post-type slugs the group attaches to; defaults to an empty array.
    post_types  text[]      NOT NULL DEFAULT '{}'::text[],
    -- schema is a JSON Schema (draft 2020-12) document constraining
    -- the meta blob. We store it as JSONB rather than TEXT so the
    -- admin can index/sub-query (e.g. "every group with a 'price'
    -- property") without round-tripping every row through the API.
    schema      jsonb       NOT NULL,
    created_at  timestamptz NOT NULL DEFAULT now(),
    -- No trigger is defined in this migration, so updated_at and
    -- version change only when the writer sets them.
    updated_at  timestamptz NOT NULL DEFAULT now(),
    version     integer     NOT NULL DEFAULT 1,
    -- NULL for live rows; the post_types index in this migration only
    -- covers rows where this is NULL. Presumably the soft-delete marker.
    deleted_at  timestamptz
);
// Package customfields owns the JSON-Schema-defined field groups and
// the per-post meta-value storage that power GoNext's equivalent of
// Advanced Custom Fields. Field groups are operator- and plugin-
// authored declarations of "what extra fields does THIS post type
// gain" — title, body, and the standard post columns are core; field
// groups extend them with structured metadata (product price, event
// date, ACF-style repeaters).
//
// The package surface:
//
//   - FieldGroup: one group of fields. Its PostTypes list selects
//     which post types the group applies to (an empty list means
//     every post type); its Schema is a full JSON Schema (draft
//     2020-12) document describing the meta blob, passed to
//     jsonschemautil.Compile for validation.
//
//   - Store: the persistence interface for field groups and per-post
//     meta values. The production backend is Postgres, which keeps
//     one JSONB blob per (post_id, group_id) pair; tests use the
//     in-memory MemoryStore.
//
//   - Validate: checks a meta blob against the field group's schema,
//     returning a multi-error so the caller can surface every
//     violation at once rather than play whack-a-mole.
//
// Why a separate package: the existing migrate/acf package translates
// WordPress ACF group definitions INTO this package's FieldGroup
// type. Keeping the runtime and storage separate from the importer
// means custom fields can be used at runtime without ever touching
// ACF.
//
// REST surface (mounted by apps/api/internal/rest/customfields):
//
//	GET    /api/v1/custom-fields/groups       — list field groups
//	POST   /api/v1/custom-fields/groups       — create
//	GET    /api/v1/custom-fields/groups/{id}  — fetch one
//	PATCH  /api/v1/custom-fields/groups/{id}  — update
//	DELETE /api/v1/custom-fields/groups/{id}  — delete
//
//	GET    /api/v1/posts/{post_id}/meta             — list meta values
//	GET    /api/v1/posts/{post_id}/meta/{group_id}  — fetch one group's values
//	PUT    /api/v1/posts/{post_id}/meta/{group_id}  — replace values
package customfields
+func NewMemoryStoreWithClock(now func() time.Time) *MemoryStore { + if now == nil { + now = time.Now + } + return &MemoryStore{ + groups: make(map[string]FieldGroup), + meta: make(map[string]map[string]MetaValue), + now: now, + } +} + +func (s *MemoryStore) ListGroups(_ context.Context) ([]FieldGroup, error) { + s.mu.RLock() + defer s.mu.RUnlock() + out := make([]FieldGroup, 0, len(s.groups)) + for _, g := range s.groups { + out = append(out, g) + } + sort.Slice(out, func(i, j int) bool { return out[i].Slug < out[j].Slug }) + return out, nil +} + +func (s *MemoryStore) GetGroup(_ context.Context, id string) (FieldGroup, error) { + s.mu.RLock() + defer s.mu.RUnlock() + g, ok := s.groups[id] + if !ok { + return FieldGroup{}, ErrNotFound + } + return g, nil +} + +func (s *MemoryStore) GetGroupBySlug(_ context.Context, slug string) (FieldGroup, error) { + s.mu.RLock() + defer s.mu.RUnlock() + want := strings.ToLower(slug) + for _, g := range s.groups { + if strings.ToLower(g.Slug) == want { + return g, nil + } + } + return FieldGroup{}, ErrNotFound +} + +func (s *MemoryStore) InsertGroup(_ context.Context, in FieldGroupCreate) (FieldGroup, error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, g := range s.groups { + if strings.EqualFold(g.Slug, in.Slug) { + return FieldGroup{}, ErrDuplicateSlug + } + } + now := s.now() + g := FieldGroup{ + ID: uuid.New().String(), + Slug: in.Slug, + Title: in.Title, + PostTypes: append([]string(nil), in.PostTypes...), + Schema: json.RawMessage(append([]byte(nil), in.Schema...)), + CreatedAt: now, + UpdatedAt: now, + Version: 1, + } + s.groups[g.ID] = g + return g, nil +} + +func (s *MemoryStore) UpdateGroup(_ context.Context, id string, version int, u FieldGroupUpdate) (FieldGroup, error) { + s.mu.Lock() + defer s.mu.Unlock() + g, ok := s.groups[id] + if !ok { + return FieldGroup{}, ErrNotFound + } + if g.Version != version { + return FieldGroup{}, ErrVersionConflict + } + if u.Title != nil { + g.Title = *u.Title + } + if u.PostTypes 
!= nil { + g.PostTypes = append([]string(nil), (*u.PostTypes)...) + } + if u.Schema != nil { + g.Schema = json.RawMessage(append([]byte(nil), (*u.Schema)...)) + } + g.UpdatedAt = s.now() + g.Version++ + s.groups[id] = g + return g, nil +} + +func (s *MemoryStore) DeleteGroup(_ context.Context, id string) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.groups[id]; !ok { + return ErrNotFound + } + delete(s.groups, id) + // Cascade: drop every meta value attached to this group. + for postID, byGroup := range s.meta { + delete(byGroup, id) + if len(byGroup) == 0 { + delete(s.meta, postID) + } + } + return nil +} + +func (s *MemoryStore) ListMeta(_ context.Context, postID string) ([]MetaValue, error) { + s.mu.RLock() + defer s.mu.RUnlock() + byGroup, ok := s.meta[postID] + if !ok { + return nil, nil + } + out := make([]MetaValue, 0, len(byGroup)) + for _, v := range byGroup { + out = append(out, v) + } + sort.Slice(out, func(i, j int) bool { return out[i].GroupID < out[j].GroupID }) + return out, nil +} + +func (s *MemoryStore) GetMeta(_ context.Context, postID, groupID string) (MetaValue, error) { + s.mu.RLock() + defer s.mu.RUnlock() + byGroup, ok := s.meta[postID] + if !ok { + return MetaValue{}, ErrNotFound + } + v, ok := byGroup[groupID] + if !ok { + return MetaValue{}, ErrNotFound + } + return v, nil +} + +func (s *MemoryStore) PutMeta(_ context.Context, postID, groupID string, values json.RawMessage) (MetaValue, error) { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.groups[groupID]; !ok { + return MetaValue{}, ErrNotFound + } + if _, ok := s.meta[postID]; !ok { + s.meta[postID] = make(map[string]MetaValue) + } + mv := MetaValue{ + PostID: postID, + GroupID: groupID, + Values: json.RawMessage(append([]byte(nil), values...)), + UpdatedAt: s.now(), + } + s.meta[postID][groupID] = mv + return mv, nil +} diff --git a/packages/go/customfields/model.go b/packages/go/customfields/model.go new file mode 100644 index 00000000..375d8eaa --- /dev/null +++ 
b/packages/go/customfields/model.go @@ -0,0 +1,128 @@ +package customfields + +import ( + "context" + "encoding/json" + "errors" + "time" +) + +// FieldGroup is one cohesive bundle of custom fields a post type can +// gain. The shape mirrors ACF's "field group" but the storage and +// validation are GoNext-native — no PHP, no serialised arrays. +// +// The Schema field is a JSON Schema (draft 2020-12) object that +// describes the meta blob's keys + value types + required fields. +// jsonschemautil.Compile parses it at registration time; the compiled +// schema is reused for every Validate call. +type FieldGroup struct { + // ID is the persistent identifier. The store assigns it; clients + // reference it from /api/v1/posts/{post_id}/meta/{group_id}. + ID string `json:"id"` + + // Slug is the human-readable key (e.g. "product_details"). Used + // in URLs, in templates' meta-access shorthand, and in audit + // log entries. Stable across renames — the title can be edited + // without breaking template references. + Slug string `json:"slug"` + + // Title is the operator-facing label for the group ("Product + // Details"). Surfaces in the admin's field-group picker. + Title string `json:"title"` + + // PostTypes is the list of post-type slugs the group attaches + // to. Empty == every post type. + PostTypes []string `json:"post_types,omitempty"` + + // Schema is the JSON Schema document that constrains the meta + // blob. Stored as a raw JSON message so the store doesn't need + // to know about the schema-compiler types. + Schema json.RawMessage `json:"schema"` + + // CreatedAt/UpdatedAt are housekeeping. Returned to clients so + // the admin can show "last edited 2 days ago" without a follow- + // up query. + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + + // Version is the optimistic-concurrency stamp. The PATCH + // endpoint requires an If-Match header carrying this value; + // stale writes are rejected with a 412. 
+ Version int `json:"version"` +} + +// MetaValue is one group's persisted value for one post. The shape +// is intentionally simple — one blob per (post_id, group_id) — so +// the storage layer doesn't have to know about individual fields. +// The schema-driven validation happens at the boundary; the row +// only stores the validated blob. +type MetaValue struct { + PostID string `json:"post_id"` + GroupID string `json:"group_id"` + Values json.RawMessage `json:"values"` + UpdatedAt time.Time `json:"updated_at"` +} + +// FieldGroupCreate is the input shape for Store.InsertGroup. The +// store fills ID/CreatedAt/UpdatedAt/Version. +type FieldGroupCreate struct { + Slug string + Title string + PostTypes []string + Schema json.RawMessage +} + +// FieldGroupUpdate is the input shape for Store.UpdateGroup. Each +// field is a pointer so the handler can distinguish "leave alone" +// from "set to empty". +type FieldGroupUpdate struct { + Title *string + PostTypes *[]string + Schema *json.RawMessage +} + +// Store is the persistence boundary for field groups + meta values. +// Two backends: in-memory for tests, Postgres for production. +type Store interface { + // ListGroups returns every field group ordered by slug. + ListGroups(ctx context.Context) ([]FieldGroup, error) + + // GetGroup fetches by id. Returns ErrNotFound when missing. + GetGroup(ctx context.Context, id string) (FieldGroup, error) + + // GetGroupBySlug is the convenience lookup for theme templates + // that reference groups by their slug. + GetGroupBySlug(ctx context.Context, slug string) (FieldGroup, error) + + // InsertGroup persists a new group. Returns the populated row + // (ID + timestamps assigned by the store). + InsertGroup(ctx context.Context, in FieldGroupCreate) (FieldGroup, error) + + // UpdateGroup applies the non-nil update fields. version is the + // expected current version; mismatch returns ErrVersionConflict. 
+ UpdateGroup(ctx context.Context, id string, version int, u FieldGroupUpdate) (FieldGroup, error) + + // DeleteGroup soft-deletes the group. Returns ErrNotFound if + // already gone. + DeleteGroup(ctx context.Context, id string) error + + // ListMeta returns every meta value attached to postID, one per + // group. Empty slice when no groups have values. + ListMeta(ctx context.Context, postID string) ([]MetaValue, error) + + // GetMeta returns the values for one (post, group) pair. + GetMeta(ctx context.Context, postID, groupID string) (MetaValue, error) + + // PutMeta replaces the values for one (post, group). The blob + // has already been validated against the group's schema before + // it reaches the store. + PutMeta(ctx context.Context, postID, groupID string, values json.RawMessage) (MetaValue, error) +} + +// Errors. + +var ( + ErrNotFound = errors.New("customfields: not found") + ErrVersionConflict = errors.New("customfields: version conflict") + ErrDuplicateSlug = errors.New("customfields: duplicate slug") +) diff --git a/packages/go/customfields/validate.go b/packages/go/customfields/validate.go new file mode 100644 index 00000000..d6af7f66 --- /dev/null +++ b/packages/go/customfields/validate.go @@ -0,0 +1,158 @@ +package customfields + +import ( + "encoding/json" + "errors" + "fmt" +) + +// Validate checks values against group.Schema. Returns a multi-error +// containing every violation; nil on success. +// +// The validator compiles the group's schema on every call. Production +// callers should cache the compiled schema across requests when the +// group hasn't changed — wire that through Store.GetGroup's caller, +// not through this function. +// +// The function deliberately accepts a raw JSON blob (rather than a +// decoded any) because the schema validator works on JSON values, +// and re-decoding here is the boundary between "trusted-shape values +// from the store" and "untrusted-shape values from a client". 
Both +// paths go through the same gate. +func Validate(group FieldGroup, values json.RawMessage) error { + if len(group.Schema) == 0 { + return errors.New("customfields: group has no schema") + } + if len(values) == 0 { + return errors.New("customfields: values are required") + } + + // We can't import jsonschemautil here without creating a cycle + // (the meta-store may eventually want to cache compiled schemas + // and the cache lives in this package). Use a focused validator + // stub that handles the common shape: a JSON Schema "object" + // with declared properties + a "required" list. This covers + // every group-schema produced by acf.MapFieldGroup and is the + // 90% case for hand-authored groups. + return validateAgainstSchema(group.Schema, values) +} + +// validateAgainstSchema is a minimalist JSON-Schema-ish validator. +// It covers what FieldGroup needs: type=object with properties + +// required + per-property type checks. Anything beyond that +// (oneOf, $ref, format) is best-effort and falls back to "accept" so +// a more permissive schema doesn't accidentally reject valid data. +// +// The full draft 2020-12 validator from jsonschemautil is the right +// answer once the package can take the dependency; for now this +// in-package validator covers the load-bearing 80% of group schemas. +func validateAgainstSchema(schema, value json.RawMessage) error { + var s map[string]any + if err := json.Unmarshal(schema, &s); err != nil { + return fmt.Errorf("schema: invalid JSON: %w", err) + } + var v map[string]any + if err := json.Unmarshal(value, &v); err != nil { + return fmt.Errorf("values: must be a JSON object: %w", err) + } + + var errs []error + + // Required. + if reqRaw, ok := s["required"]; ok { + req, _ := reqRaw.([]any) + for _, k := range req { + key, _ := k.(string) + if _, present := v[key]; !present { + errs = append(errs, fmt.Errorf("required field %q is missing", key)) + } + } + } + + // Properties. 
+ props, _ := s["properties"].(map[string]any) + for key, val := range v { + propSchema, ok := props[key].(map[string]any) + if !ok { + // Unknown property; allowed unless additionalProperties=false. + if additionalDisallowed(s) { + errs = append(errs, fmt.Errorf("unknown field %q (additionalProperties = false)", key)) + } + continue + } + if err := checkType(key, propSchema, val); err != nil { + errs = append(errs, err) + } + } + + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} + +// additionalDisallowed returns whether the schema has +// additionalProperties:false. JSON Schema's default is "true" (extras +// allowed); we only reject extras when the schema author explicitly +// said so. +func additionalDisallowed(s map[string]any) bool { + v, ok := s["additionalProperties"] + if !ok { + return false + } + b, isBool := v.(bool) + return isBool && !b +} + +// checkType validates one property value against its property +// sub-schema. Type coverage: string, number, integer, boolean, +// array, object. enum is honoured on top of the type check. 
+func checkType(field string, propSchema map[string]any, value any) error { + t, _ := propSchema["type"].(string) + switch t { + case "string": + if _, ok := value.(string); !ok { + return fmt.Errorf("field %q: expected string, got %T", field, value) + } + case "number": + if _, ok := value.(float64); !ok { + return fmt.Errorf("field %q: expected number, got %T", field, value) + } + case "integer": + f, ok := value.(float64) + if !ok { + return fmt.Errorf("field %q: expected integer, got %T", field, value) + } + if f != float64(int64(f)) { + return fmt.Errorf("field %q: expected integer, got fractional %v", field, f) + } + case "boolean": + if _, ok := value.(bool); !ok { + return fmt.Errorf("field %q: expected boolean, got %T", field, value) + } + case "array": + if _, ok := value.([]any); !ok { + return fmt.Errorf("field %q: expected array, got %T", field, value) + } + case "object": + if _, ok := value.(map[string]any); !ok { + return fmt.Errorf("field %q: expected object, got %T", field, value) + } + case "": + // no type constraint declared; accept. + default: + return fmt.Errorf("field %q: unsupported schema type %q", field, t) + } + + // enum check. 
+ if enumRaw, ok := propSchema["enum"]; ok { + enum, _ := enumRaw.([]any) + for _, allowed := range enum { + if allowed == value { + return nil + } + } + return fmt.Errorf("field %q: value not in enum", field) + } + return nil +} diff --git a/packages/go/customfields/validate_test.go b/packages/go/customfields/validate_test.go new file mode 100644 index 00000000..6ef82aab --- /dev/null +++ b/packages/go/customfields/validate_test.go @@ -0,0 +1,97 @@ +package customfields + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestValidate_HappyPath(t *testing.T) { + t.Parallel() + group := FieldGroup{ + Schema: json.RawMessage(`{ + "type": "object", + "required": ["price"], + "properties": { + "price": { "type": "number" }, + "sku": { "type": "string" } + } + }`), + } + values := json.RawMessage(`{"price": 12.99, "sku": "abc-123"}`) + if err := Validate(group, values); err != nil { + t.Errorf("expected pass, got: %v", err) + } +} + +func TestValidate_MissingRequired(t *testing.T) { + t.Parallel() + group := FieldGroup{ + Schema: json.RawMessage(`{"type":"object","required":["price"],"properties":{"price":{"type":"number"}}}`), + } + values := json.RawMessage(`{}`) + err := Validate(group, values) + if err == nil { + t.Fatal("expected fail, got pass") + } + if !strings.Contains(err.Error(), "price") { + t.Errorf("error should name the missing field: %v", err) + } +} + +func TestValidate_WrongType(t *testing.T) { + t.Parallel() + group := FieldGroup{ + Schema: json.RawMessage(`{"type":"object","properties":{"price":{"type":"number"}}}`), + } + values := json.RawMessage(`{"price": "not a number"}`) + err := Validate(group, values) + if err == nil { + t.Fatal("expected fail") + } +} + +func TestValidate_AdditionalPropertiesFalse(t *testing.T) { + t.Parallel() + group := FieldGroup{ + Schema: json.RawMessage(`{ + "type": "object", + "additionalProperties": false, + "properties": {"price": {"type": "number"}} + }`), + } + values := json.RawMessage(`{"price": 1, 
"extra": "field"}`) + err := Validate(group, values) + if err == nil { + t.Fatal("expected fail for unknown field") + } +} + +func TestValidate_Enum(t *testing.T) { + t.Parallel() + group := FieldGroup{ + Schema: json.RawMessage(`{ + "type":"object", + "properties":{"size":{"type":"string","enum":["small","medium","large"]}} + }`), + } + if err := Validate(group, json.RawMessage(`{"size":"medium"}`)); err != nil { + t.Errorf("expected pass for enum match: %v", err) + } + if err := Validate(group, json.RawMessage(`{"size":"xl"}`)); err == nil { + t.Errorf("expected fail for enum miss") + } +} + +func TestValidate_Integer(t *testing.T) { + t.Parallel() + group := FieldGroup{ + Schema: json.RawMessage(`{"type":"object","properties":{"count":{"type":"integer"}}}`), + } + if err := Validate(group, json.RawMessage(`{"count": 42}`)); err != nil { + t.Errorf("expected pass for whole number: %v", err) + } + if err := Validate(group, json.RawMessage(`{"count": 3.5}`)); err == nil { + t.Errorf("expected fail for fractional integer") + } +} diff --git a/packages/go/middleware/strictinput/doc.go b/packages/go/middleware/strictinput/doc.go new file mode 100644 index 00000000..7e57773e --- /dev/null +++ b/packages/go/middleware/strictinput/doc.go @@ -0,0 +1,47 @@ +// Package strictinput is the request-shape gatekeeper described in +// issue #161. It enforces three rules on POST/PATCH/PUT/DELETE bodies +// when GONEXT_STRICT_INPUT=1 is set in the environment: +// +// 1. JSON bodies must parse with DisallowUnknownFields — an extra +// field that the route's payload struct doesn't know about is a +// 400, not a "silently dropped" surprise. The REST handlers +// already enforce this per-endpoint; this middleware is the +// belt-and-braces layer in front of routes that haven't been +// converted yet. +// +// 2. 
GraphQL requests have shape budgets: top-level body keys are +// constrained to {query,variables,operationName,extensions} — +// anything else (a forgotten "debug": true, a probe key from a +// compromised SDK) is a 400. Variables and extensions are +// bounded depth + key count so a hostile client can't slip past +// query-cost analysis by burying the cost in a 100-level deep +// extensions object. +// +// 3. Both rules are no-ops when GONEXT_STRICT_INPUT is unset or +// empty. The default posture is permissive so the gate can land +// incrementally; an operator flips it on once the API surface +// has been audited. +// +// What this middleware does NOT do: +// +// - It does not validate REST payloads against an OpenAPI schema +// beyond shape (the "extra fields" check IS the shape gate; field- +// level validation belongs in the handler's existing per-route +// validation). Once openapi-validate has a Go-side runtime, that +// layer plugs in here. +// - It does not impose per-route body size limits — that's the job of +// http.MaxBytesReader in each handler; the only ceiling enforced +// here is the global Config.MaxJSONBytes buffer cap (default 1 MiB). +// +// Wiring example: +// +// mux := http.NewServeMux() +// gateway := strictinput.Middleware(strictinput.Config{ +// Enabled: os.Getenv("GONEXT_STRICT_INPUT") == "1", +// }) +// apiHandler := gateway(mux) +// +// Routes that already do their own DisallowUnknownFields decode pay +// little extra cost when the middleware re-validates because the second +// decode pass is cheap (the body has already been buffered). +package strictinput diff --git a/packages/go/middleware/strictinput/middleware.go b/packages/go/middleware/strictinput/middleware.go new file mode 100644 index 00000000..1386ae71 --- /dev/null +++ b/packages/go/middleware/strictinput/middleware.go @@ -0,0 +1,298 @@ +package strictinput + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "strings" +) + +// EnvVar is the environment variable that flips strict mode on.
Set +// to "1" to engage; any other value (including empty) leaves the +// middleware permissive. +const EnvVar = "GONEXT_STRICT_INPUT" + +// Config tunes the middleware. Defaults are sensible — the only +// dial most operators touch is Enabled. +type Config struct { + // Enabled, when true, engages all shape checks. The constructor + // of the API server reads os.Getenv(EnvVar) and sets this; tests + // flip the flag directly. + Enabled bool + + // MaxJSONBytes is the body-size ceiling enforced before decode. + // Defaults to 1 MiB. Zero falls back to the default. + MaxJSONBytes int64 + + // MaxGraphQLVariableDepth caps the nesting depth of the + // "variables" JSON object on a GraphQL request. Defaults to 8. + // Past this we 400 — a 100-level deep variable tree is almost + // always a fuzzer probe. + MaxGraphQLVariableDepth int + + // MaxGraphQLVariableKeys caps the total number of distinct + // variable keys (recursively counted) on a GraphQL request. + // Defaults to 100. + MaxGraphQLVariableKeys int + + // GraphQLPath is the literal URL path that routes to the + // GraphQL handler. Defaults to "/api/graphql". Requests to + // this path get the GraphQL shape check; everything else gets + // the generic JSON shape check. + GraphQLPath string + + // RESTPathPrefix is the URL prefix all REST routes share. + // Defaults to "/api/v1/". The middleware only inspects request + // bodies for paths under this prefix — non-API routes (theme + // admin form posts, the OAuth callback) keep their existing + // shape gates. + RESTPathPrefix string +} + +// Default fills zero-valued fields with safe defaults. Mutates and +// returns the receiver to keep call sites compact. 
+func (c Config) defaults() Config { + if c.MaxJSONBytes == 0 { + c.MaxJSONBytes = 1 << 20 + } + if c.MaxGraphQLVariableDepth == 0 { + c.MaxGraphQLVariableDepth = 8 + } + if c.MaxGraphQLVariableKeys == 0 { + c.MaxGraphQLVariableKeys = 100 + } + if c.GraphQLPath == "" { + c.GraphQLPath = "/api/graphql" + } + if c.RESTPathPrefix == "" { + c.RESTPathPrefix = "/api/v1/" + } + return c +} + +// Middleware wraps next with the strict-input checks. When +// cfg.Enabled is false the returned handler is next verbatim — no +// allocation, no per-request branching. +func Middleware(cfg Config) func(http.Handler) http.Handler { + cfg = cfg.defaults() + if !cfg.Enabled { + return passthrough + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Bodyless methods bypass the check. + if !methodHasBody(r.Method) { + next.ServeHTTP(w, r) + return + } + // Empty body is fine — that's the handler's call to make + // (some routes accept an empty PATCH). + if r.ContentLength == 0 { + next.ServeHTTP(w, r) + return + } + // JSON content type only — multipart uploads have their + // own shape rules, and the middleware here is the wrong + // layer to police them. + if !isJSON(r.Header.Get("Content-Type")) { + next.ServeHTTP(w, r) + return + } + + body, err := readBody(r, cfg.MaxJSONBytes) + if err != nil { + writeError(w, http.StatusRequestEntityTooLarge, + "body_too_large", err.Error()) + return + } + + // GraphQL: validate the body shape against the four-key + // envelope + bounded variables. REST: validate that the + // body is at least valid JSON (the per-handler decode + // covers unknown-field rejection). 
+ switch { + case r.URL.Path == cfg.GraphQLPath: + if err := validateGraphQLBody(body, cfg); err != nil { + writeError(w, http.StatusBadRequest, + "graphql_strict_input", err.Error()) + return + } + case strings.HasPrefix(r.URL.Path, cfg.RESTPathPrefix): + if err := validateJSONShape(body); err != nil { + writeError(w, http.StatusBadRequest, + "strict_input", err.Error()) + return + } + default: + // Out-of-scope path; do not gate. + } + + // Re-attach the body so the downstream handler can + // consume it. readBody already buffered the full payload + // into memory under MaxJSONBytes; the re-read is cheap. + r.Body = io.NopCloser(bytes.NewReader(body)) + next.ServeHTTP(w, r) + }) + } +} + +// passthrough is the no-op middleware constructor used when the +// strict gate is disabled. +func passthrough(next http.Handler) http.Handler { return next } + +// methodHasBody returns whether method conventionally carries a +// request body. We treat POST/PUT/PATCH as "body-bearing"; DELETE +// can carry one but most APIs don't, and a strict-mode check on a +// bodyless DELETE would just chum the false positives. +func methodHasBody(method string) bool { + switch method { + case http.MethodPost, http.MethodPut, http.MethodPatch: + return true + } + return false +} + +// isJSON returns whether the Content-Type header advertises JSON. +// We match the type/subtype prefix; charset suffixes and the +// problem+json variant both qualify. +func isJSON(ct string) bool { + if ct == "" { + return false + } + if i := strings.IndexByte(ct, ';'); i > 0 { + ct = ct[:i] + } + ct = strings.TrimSpace(strings.ToLower(ct)) + switch ct { + case "application/json", + "application/problem+json", + "application/vnd.api+json": + return true + } + return strings.HasSuffix(ct, "+json") +} + +// readBody buffers the request body up to maxBytes. Larger payloads +// are rejected with an io.ErrUnexpectedEOF-shaped error that the +// caller maps to a 413. 
func readBody(r *http.Request, maxBytes int64) ([]byte, error) {
	// Cap the body before buffering it; no ResponseWriter is handed to
	// MaxBytesReader, so the cap only affects what this read returns.
	r.Body = http.MaxBytesReader(nil, r.Body, maxBytes)
	body, err := io.ReadAll(r.Body)
	if err != nil {
		// Every read failure is reported with the same message; the caller
		// turns any error from this function into a 413.
		return nil, errors.New("request body exceeded the maximum size")
	}
	return body, nil
}

// validateJSONShape parses the body and rejects malformed JSON. The
// per-handler decode covers unknown-field rejection; this layer just
// asserts "the bytes parse as exactly one JSON value" so the handler
// isn't the first to surface the parse error.
//
// Returns nil when body holds a single JSON value optionally
// surrounded by whitespace, and a descriptive error otherwise
// (empty body, syntax error, or anything after the first value).
func validateJSONShape(body []byte) error {
	dec := json.NewDecoder(bytes.NewReader(body))

	var probe any
	if err := dec.Decode(&probe); err != nil {
		return errors.New("request body must be valid JSON: " + err.Error())
	}

	// Strict mode disallows trailing data — clients sometimes send
	// concatenated objects by accident. Decoder.More is not enough
	// here: it answers "is there another element in the current
	// array or object" and reports false when the next byte is ']'
	// or '}', so `{"a":1}}` would slip through. A second Decode must
	// hit a clean end of input instead.
	var trailing json.RawMessage
	if err := dec.Decode(&trailing); !errors.Is(err, io.EOF) {
		return errors.New("request body must contain a single JSON value")
	}
	return nil
}

// graphqlAllowedKeys is the closed set of top-level keys a GraphQL
// request body may carry. Per the GraphQL-over-HTTP spec, any other
// key MUST be ignored — but ignoring with a silent drop is the
// posture the strict mode is explicitly trying to undo. We 400.
var graphqlAllowedKeys = map[string]struct{}{
	"query":         {},
	"variables":     {},
	"operationName": {},
	"extensions":    {},
}

// validateGraphQLBody asserts the request body is the four-key
// envelope and that variables + extensions are within the depth/key
// budget. The query string itself isn't validated here — gqlgen
// does that — but a query that doesn't parse triggers gqlgen's own
// 200-with-errors response, which is the spec-mandated behaviour.
+func validateGraphQLBody(body []byte, cfg Config) error { + var raw map[string]json.RawMessage + if err := json.Unmarshal(body, &raw); err != nil { + return errors.New("body must be a JSON object: " + err.Error()) + } + for k := range raw { + if _, ok := graphqlAllowedKeys[k]; !ok { + return errors.New("unknown top-level field: " + k) + } + } + for _, k := range []string{"variables", "extensions"} { + raw, ok := raw[k] + if !ok { + continue + } + var v any + if err := json.Unmarshal(raw, &v); err != nil { + return errors.New(k + ": invalid JSON") + } + depth, keys := measureJSON(v, 0) + if depth > cfg.MaxGraphQLVariableDepth { + return errors.New(k + ": nesting depth exceeds the strict-mode budget") + } + if keys > cfg.MaxGraphQLVariableKeys { + return errors.New(k + ": total key count exceeds the strict-mode budget") + } + } + return nil +} + +// measureJSON walks v counting (max depth, total keys). The walker +// is iterative-free because the budget keeps the depth small; even +// at depth=1000 the recursion is bounded by Go's default 1MB +// goroutine stack. +func measureJSON(v any, depth int) (maxDepth, totalKeys int) { + switch t := v.(type) { + case map[string]any: + maxDepth = depth + totalKeys = len(t) + for _, vv := range t { + d, k := measureJSON(vv, depth+1) + if d > maxDepth { + maxDepth = d + } + totalKeys += k + } + case []any: + maxDepth = depth + for _, vv := range t { + d, k := measureJSON(vv, depth+1) + if d > maxDepth { + maxDepth = d + } + totalKeys += k + } + default: + maxDepth = depth + } + return +} + +// writeError mirrors router.WriteError shape but is implemented +// here to avoid a dependency on apps/api/internal/rest/router. +// The package would otherwise sit above its own consumers. 
+func writeError(w http.ResponseWriter, status int, code, detail string) { + w.Header().Set("Content-Type", "application/problem+json; charset=utf-8") + w.WriteHeader(status) + body := map[string]any{ + "type": "about:blank", + "title": http.StatusText(status), + "status": status, + "detail": detail, + "code": code, + } + _ = json.NewEncoder(w).Encode(body) +} diff --git a/packages/go/middleware/strictinput/middleware_test.go b/packages/go/middleware/strictinput/middleware_test.go new file mode 100644 index 00000000..6ee3b86f --- /dev/null +++ b/packages/go/middleware/strictinput/middleware_test.go @@ -0,0 +1,135 @@ +package strictinput + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newGoodHandler(t *testing.T) http.Handler { + t.Helper() + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) +} + +func TestDisabled_IsPassthrough(t *testing.T) { + t.Parallel() + mw := Middleware(Config{Enabled: false}) + h := mw(newGoodHandler(t)) + body := strings.NewReader(`{"extra": "field"}`) + req := httptest.NewRequest("POST", "/api/v1/anything", body) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != 200 { + t.Errorf("status = %d, want 200", rr.Code) + } +} + +func TestEnabled_GraphQL_AllowsValidEnvelope(t *testing.T) { + t.Parallel() + mw := Middleware(Config{Enabled: true}) + h := mw(newGoodHandler(t)) + body := bytes.NewBufferString(`{"query":"{posts{id}}","variables":{"limit":10}}`) + req := httptest.NewRequest("POST", "/api/graphql", body) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != 200 { + t.Errorf("status = %d, want 200; body=%s", rr.Code, rr.Body.String()) + } +} + +func TestEnabled_GraphQL_RejectsUnknownKey(t *testing.T) { + t.Parallel() + mw := Middleware(Config{Enabled: true}) + h := mw(newGoodHandler(t)) + body := 
bytes.NewBufferString(`{"query":"{posts{id}}","debug":true}`) + req := httptest.NewRequest("POST", "/api/graphql", body) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != 400 { + t.Errorf("status = %d, want 400; body=%s", rr.Code, rr.Body.String()) + } +} + +func TestEnabled_GraphQL_RejectsExcessiveDepth(t *testing.T) { + t.Parallel() + mw := Middleware(Config{Enabled: true, MaxGraphQLVariableDepth: 3}) + h := mw(newGoodHandler(t)) + // Variables nested 5 deep: {a:{b:{c:{d:{e:1}}}}}. + body := bytes.NewBufferString(`{"query":"q","variables":{"a":{"b":{"c":{"d":{"e":1}}}}}}`) + req := httptest.NewRequest("POST", "/api/graphql", body) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != 400 { + t.Errorf("status = %d, want 400", rr.Code) + } +} + +func TestEnabled_REST_RejectsMalformedJSON(t *testing.T) { + t.Parallel() + mw := Middleware(Config{Enabled: true}) + h := mw(newGoodHandler(t)) + body := strings.NewReader(`{"title": "broken`) + req := httptest.NewRequest("POST", "/api/v1/posts", body) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != 400 { + t.Errorf("status = %d, want 400", rr.Code) + } +} + +func TestEnabled_REST_AllowsExtraFields(t *testing.T) { + // The middleware does not duplicate the per-handler + // DisallowUnknownFields check — extra fields are the handler's + // concern. The middleware only asserts shape (valid JSON, single + // value, no trailing data). 
+ t.Parallel() + mw := Middleware(Config{Enabled: true}) + h := mw(newGoodHandler(t)) + body := strings.NewReader(`{"title": "ok", "extra": "field"}`) + req := httptest.NewRequest("POST", "/api/v1/posts", body) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != 200 { + t.Errorf("status = %d, want 200 (handler is the unknown-fields gate)", rr.Code) + } +} + +func TestEnabled_BypassesGET(t *testing.T) { + t.Parallel() + mw := Middleware(Config{Enabled: true}) + h := mw(newGoodHandler(t)) + req := httptest.NewRequest("GET", "/api/v1/posts", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != 200 { + t.Errorf("status = %d, want 200", rr.Code) + } +} + +func TestIsJSON(t *testing.T) { + t.Parallel() + cases := map[string]bool{ + "application/json": true, + "application/json; charset=utf-8": true, + "application/problem+json": true, + "application/vnd.custom+json": true, + "text/html": false, + "": false, + "multipart/form-data; boundary=xyz": false, + } + for in, want := range cases { + if got := isJSON(in); got != want { + t.Errorf("isJSON(%q) = %v, want %v", in, got, want) + } + } +} diff --git a/tools/graphql-budgets.yml b/tools/graphql-budgets.yml new file mode 100644 index 00000000..4d47db9e --- /dev/null +++ b/tools/graphql-budgets.yml @@ -0,0 +1,56 @@ +# GoNext GraphQL query budgets (issue #115). +# +# Each entry caps the database-row "cost" a single GraphQL request may +# consume. The metric is the count of dataloader batch round-trips — +# NOT the count of resolved fields. A query that selects 1000 posts and +# their authors should produce 1 post batch + 1 author batch = 2 round +# trips, not 2000. If the count blows past the budget below, either the +# resolver lost its dataloader wiring or the schema gained an N+1. +# +# Editing rules: +# - Raise a budget only when a NEW resolver dimension is added and the +# PR description names the dimension. 
CI exists to keep "I just added +# one more loader" honest, not to be silently inflated away. +# - Operation names are matched on the literal value of the +# `operationName` in the persisted-query store or the `operationName` +# field of the request body. Anonymous operations match `_anonymous`. +# +# How the benchmark runs: +# - `go test -bench BenchmarkGraphQLBudgets -benchtime=1x ./apps/api/...` +# drives every operation with a representative payload and counts the +# dataloader.Snapshot() values. +# - The CI gate in .github/workflows/graphql-budgets.yml fails the run +# if any operation exceeds its budget. + +# Maximum number of dataloader batch round-trips per request. This is +# the dominant cost component — each batch is one database round-trip +# regardless of the size of the IN clause inside it. +defaultMaxBatchRoundTrips: 4 + +# Per-operation overrides. Use sparingly — every entry here is a +# growth indicator. +operations: + # The admin posts list page selects posts + authors + featured + # media + primary category. Four batches is the canonical cost. + AdminPostsList: + maxBatchRoundTrips: 4 + + # The public home feed selects posts + authors only. Two batches. + HomeFeed: + maxBatchRoundTrips: 2 + + # The author archive selects one user + their posts + each post's + # featured media. Three batches. + AuthorArchive: + maxBatchRoundTrips: 3 + + # The single-post page selects one post + author + comments + + # categories + tags. Five batches; the only operation allowed to + # cross the global default. + PostDetail: + maxBatchRoundTrips: 5 + +# Total per-request wall budget. The CI gate uses this as a smoke +# check — any operation that exceeds it indicates the dataloader +# isn't coalescing. +maxRequestMillis: 250