diff --git a/duckdbservice/appender_init.go b/duckdbservice/appender_init.go new file mode 100644 index 00000000..cc50ed8c --- /dev/null +++ b/duckdbservice/appender_init.go @@ -0,0 +1,82 @@ +package duckdbservice + +import ( + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/decimal128" + duckdb "github.com/duckdb/duckdb-go/v2" + "github.com/posthog/duckgres/duckdbservice/arrowmap" +) + +// init registers handlers for the duckdb-go driver value types so that any +// caller (including arrowmap.AppendValue and the wrapper duckdbservice.AppendValue) +// gets full type coverage when this package is linked into the binary. +// +// Binaries that don't link duckdbservice (e.g., a future control-plane-only +// binary) won't see these registrations — which is correct, because they +// also won't be the ones scanning rows from a duckdb-go driver connection. +func init() { + arrowmap.RegisterAppender(handleDuckDBValue) +} + +// handleDuckDBValue implements arrowmap.Appender for duckdb-go's driver +// value types. Returns true when it claimed the value, false to fall +// through to arrowmap's built-in handling. +func handleDuckDBValue(builder array.Builder, val any) bool { + switch b := builder.(type) { + case *array.MonthDayNanoIntervalBuilder: + v, ok := val.(duckdb.Interval) + if !ok { + return false + } + b.Append(arrow.MonthDayNanoInterval{ + Months: v.Months, + Days: v.Days, + Nanoseconds: v.Micros * 1000, + }) + return true + case *array.Decimal128Builder: + v, ok := val.(duckdb.Decimal) + if !ok { + return false + } + b.Append(decimal128.FromBigInt(v.Value)) + return true + case *array.FixedSizeBinaryBuilder: + v, ok := val.(duckdb.UUID) + if !ok { + return false + } + b.Append(v[:]) + return true + case *array.MapBuilder: + switch v := val.(type) { + case duckdb.OrderedMap: + b.Append(true) + kb, ib := b.KeyBuilder(), b.ItemBuilder() + keys, values := v.Keys(), v.Values() + for i, k := range keys { + arrowmap.AppendValue(kb, k) + arrowmap.AppendValue(ib, values[i]) + } + return true + case duckdb.Map: + b.Append(true) + kb, ib := b.KeyBuilder(), b.ItemBuilder() + for k, item := range v { + arrowmap.AppendValue(kb, k) + arrowmap.AppendValue(ib, item) + } + return true + } + return false + case *array.StringBuilder: + v, ok := val.(duckdb.UUID) + if !ok { + return false + } + b.Append(v.String()) + return true + } + return false +} diff --git a/duckdbservice/arrow_helpers.go b/duckdbservice/arrow_helpers.go index 703720e1..a267e542 100644 --- a/duckdbservice/arrow_helpers.go +++ b/duckdbservice/arrow_helpers.go @@ -3,21 +3,13 @@ package duckdbservice import ( "context" "database/sql" - "encoding/hex" - "fmt" - "math/big" "reflect" "strings" - "time" - "unicode/utf8" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" - "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/memory" - duckdb "github.com/duckdb/duckdb-go/v2" "github.com/posthog/duckgres/duckdbservice/arrowmap" - "github.com/posthog/duckgres/server" ) // DuckDBTypeToArrow re-exports arrowmap.DuckDBTypeToArrow for backward @@ -25,12 +17,18 @@ import ( var DuckDBTypeToArrow = arrowmap.DuckDBTypeToArrow // QualifyTableName re-exports arrowmap.QualifyTableName for backward -// compatibility with existing callers in this package. +// compatibility with existing callers. var QualifyTableName = arrowmap.QualifyTableName -// AppendValue is the only helper in this package that still needs to live with -// the duckdb-go import (because it switches on duckdb.Interval / Decimal / -// UUID / OrderedMap / Map types). +// QuoteIdent re-exports arrowmap.QuoteIdent for backward compatibility. +var QuoteIdent = arrowmap.QuoteIdent + +// AppendValue re-exports arrowmap.AppendValue for backward compatibility. +// The duckdb-go-specific value types (duckdb.Interval, Decimal, UUID, +// OrderedMap, Map) are handled via an arrowmap.Appender registered from +// duckdbservice/appender_init.go, so callers that import duckdbservice get +// full type coverage automatically. +var AppendValue = arrowmap.AppendValue // RowsToRecord converts sql.Rows into an Arrow RecordBatch of up to batchSize rows. // Returns nil when there are no more rows. @@ -66,267 +64,6 @@ func RowsToRecord(alloc memory.Allocator, rows *sql.Rows, schema *arrow.Schema, return builder.NewRecordBatch(), nil } -// AppendValue appends a value to an Arrow array builder with type coercion. -func AppendValue(builder array.Builder, val interface{}) { - if val == nil { - builder.AppendNull() - return - } - - switch b := builder.(type) { - case *array.Int64Builder: - switch v := val.(type) { - case int64: - b.Append(v) - case int32: - b.Append(int64(v)) - case int: - b.Append(int64(v)) - default: - b.AppendNull() - } - case *array.Int32Builder: - switch v := val.(type) { - case int32: - b.Append(v) - case int64: - b.Append(int32(v)) - case int: - b.Append(int32(v)) - default: - b.AppendNull() - } - case *array.Int16Builder: - switch v := val.(type) { - case int16: - b.Append(v) - case int32: - b.Append(int16(v)) - default: - b.AppendNull() - } - case *array.Int8Builder: - switch v := val.(type) { - case int8: - b.Append(v) - case int32: - b.Append(int8(v)) - default: - b.AppendNull() - } - case *array.Uint8Builder: - switch v := val.(type) { - case uint8: - b.Append(v) - case uint16: - b.Append(uint8(v)) - default: - b.AppendNull() - } - case *array.Uint16Builder: - switch v := val.(type) { - case uint16: - b.Append(v) - case uint32: - b.Append(uint16(v)) - default: - b.AppendNull() - } - case *array.Uint32Builder: - switch v := val.(type) { - case uint32: - b.Append(v) - case uint64: - b.Append(uint32(v)) - default: - b.AppendNull() - } - case *array.Uint64Builder: - switch v := val.(type) { - case uint64: - b.Append(v) - default: - b.AppendNull() - } - case *array.Float64Builder: - switch v := val.(type) { - case float64: - b.Append(v) - case float32: - b.Append(float64(v)) - default: - b.AppendNull() - } - case *array.Float32Builder: - switch v := val.(type) { - case float32: - b.Append(v) - case float64: - b.Append(float32(v)) - default: - b.AppendNull() - } - case *array.BooleanBuilder: - if v, ok := val.(bool); ok { - b.Append(v) - } else { - b.AppendNull() - } - case *array.Date32Builder: - switch v := val.(type) { - case time.Time: - // Floor division to handle pre-epoch dates correctly. - // Go's integer division truncates toward zero, but Date32 - // needs days since epoch rounded toward negative infinity. - unix := v.Unix() - days := unix / 86400 - if unix%86400 < 0 { - days-- - } - b.Append(arrow.Date32(days)) - default: - b.AppendNull() - } - case *array.TimestampBuilder: - switch v := val.(type) { - case time.Time: - b.AppendTime(v) - default: - b.AppendNull() - } - case *array.Time64Builder: - switch v := val.(type) { - case time.Time: - micros := int64(v.Hour())*3600000000 + int64(v.Minute())*60000000 + - int64(v.Second())*1000000 + int64(v.Nanosecond())/1000 - b.Append(arrow.Time64(micros)) - default: - b.AppendNull() - } - case *array.MonthDayNanoIntervalBuilder: - switch v := val.(type) { - case duckdb.Interval: - b.Append(arrow.MonthDayNanoInterval{ - Months: v.Months, - Days: v.Days, - Nanoseconds: v.Micros * 1000, - }) - default: - b.AppendNull() - } - case *array.Decimal128Builder: - switch v := val.(type) { - case duckdb.Decimal: - b.Append(decimal128.FromBigInt(v.Value)) - case *big.Int: - b.Append(decimal128.FromBigInt(v)) - default: - b.AppendNull() - } - case *array.FixedSizeBinaryBuilder: - switch v := val.(type) { - case duckdb.UUID: - b.Append(v[:]) - case []byte: - b.Append(v) - default: - b.AppendNull() - } - case *array.ListBuilder: - switch v := val.(type) { - case []any: - b.Append(true) - vb := b.ValueBuilder() - for _, elem := range v { - AppendValue(vb, elem) - } - default: - b.AppendNull() - } - case *array.StructBuilder: - switch v := val.(type) { - case map[string]any: - b.Append(true) - st := b.Type().(*arrow.StructType) - for i := 0; i < st.NumFields(); i++ { - fieldVal, ok := v[st.Field(i).Name] - if !ok { - b.FieldBuilder(i).AppendNull() - } else { - AppendValue(b.FieldBuilder(i), fieldVal) - } - } - default: - b.AppendNull() - } - case *array.MapBuilder: - switch v := val.(type) { - case duckdb.OrderedMap: - b.Append(true) - kb, ib := b.KeyBuilder(), b.ItemBuilder() - keys, values := v.Keys(), v.Values() - for i, k := range keys { - AppendValue(kb, k) - AppendValue(ib, values[i]) - } - case duckdb.Map: - b.Append(true) - kb, ib := b.KeyBuilder(), b.ItemBuilder() - for k, item := range v { - AppendValue(kb, k) - AppendValue(ib, item) - } - case map[any]any: - b.Append(true) - kb, ib := b.KeyBuilder(), b.ItemBuilder() - for k, item := range v { - AppendValue(kb, k) - AppendValue(ib, item) - } - case server.OrderedMapValue: - b.Append(true) - kb, ib := b.KeyBuilder(), b.ItemBuilder() - for i, k := range v.Keys { - AppendValue(kb, k) - AppendValue(ib, v.Values[i]) - } - default: - b.AppendNull() - } - case *array.StringBuilder: - switch v := val.(type) { - case string: - b.Append(v) - case duckdb.UUID: - b.Append(v.String()) - case []byte: - // TODO: This heuristic (16 bytes + invalid UTF-8 → UUID) is coupled to - // DuckDBTypeToArrow("UUID") returning String. If UUID mapping changes, - // update this branch accordingly. The Go driver returns []byte (not - // duckdb.UUID) when scanning UUID columns into interface{}. - if len(v) == 16 && !utf8.Valid(v) { - s := hex.EncodeToString(v) - b.Append(s[0:8] + "-" + s[8:12] + "-" + s[12:16] + "-" + s[16:20] + "-" + s[20:32]) - } else { - b.Append(string(v)) - } - default: - b.Append(fmt.Sprintf("%v", v)) - } - case *array.BinaryBuilder: - switch v := val.(type) { - case []byte: - b.Append(v) - case string: - b.Append([]byte(v)) - default: - b.AppendNull() - } - default: - builder.AppendNull() - } -} - type contextQueryer interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } @@ -374,6 +111,3 @@ func GetQuerySchema(ctx context.Context, db contextQueryer, query string, tx con } return arrow.NewSchema(fields, nil), nil } - -// QuoteIdent re-exports arrowmap.QuoteIdent for backward compatibility. -var QuoteIdent = arrowmap.QuoteIdent diff --git a/duckdbservice/arrowmap/arrowmap.go b/duckdbservice/arrowmap/arrowmap.go index 900e49ec..845410f5 100644 --- a/duckdbservice/arrowmap/arrowmap.go +++ b/duckdbservice/arrowmap/arrowmap.go @@ -1,17 +1,28 @@ // Package arrowmap provides DuckDB-free helpers for translating DuckDB type -// strings into Arrow types and for quoting/qualifying SQL identifiers. +// strings into Arrow types, quoting/qualifying SQL identifiers, and appending +// scanned values into Arrow array builders. // // These helpers are kept in their own package (with no dependency on // github.com/duckdb/duckdb-go) so that the control plane can use them -// without linking libduckdb. +// without linking libduckdb. The DuckDB driver-specific value types +// (duckdb.Interval, duckdb.Decimal, duckdb.UUID, duckdb.OrderedMap, +// duckdb.Map) are handled via the RegisterAppender hook so duckdbservice +// can register them at init time without arrowmap depending on duckdb-go. package arrowmap import ( "database/sql" + "encoding/hex" "fmt" + "math/big" "strings" + "sync/atomic" + "time" + "unicode/utf8" "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/decimal128" ) // DuckDBTypeToArrow maps a DuckDB type name to an Arrow DataType. @@ -287,3 +298,299 @@ func QuoteIdent(ident string) string { escaped := strings.ReplaceAll(ident, `"`, `""`) return `"` + escaped + `"` } + +// OrderedMapValue represents an Arrow MAP as parallel key/value slices, +// preserving insertion order. Using parallel slices instead of a Go map +// avoids panics on non-comparable key types (e.g., []byte from BLOB keys) +// and preserves the source MAP ordering. +// +// Lives in arrowmap so AppendValue can switch on it without depending on +// the server package (which transitively links libduckdb). The flight +// executor in the server package re-exports it as server.OrderedMapValue +// via a type alias for backward compatibility. +type OrderedMapValue struct { + Keys []any + Values []any +} + +// Appender is a hook that handles append for value types arrowmap doesn't +// know about (typically driver-specific types like duckdb.Interval). It +// reports whether it handled the value; arrowmap.AppendValue falls back to +// its built-in handling when no registered Appender claims the value. +// +// Hooks must be safe to call concurrently and must not panic. They run in +// registration order; the first one to return true wins. +type Appender func(builder array.Builder, val any) (handled bool) + +// appenders is loaded once into an atomic.Value as []Appender. Reads on the +// hot path (AppendValue) are lock-free; registrations rebuild the slice. +// Registrations are expected to happen at init time so contention is rare. +var appenders atomic.Value // []Appender + +// RegisterAppender adds a hook that AppendValue will consult before falling +// back to its built-in value-type handling. Intended for use from package +// init() functions in importers that own driver-specific value types +// (e.g., duckdbservice registers handlers for duckdb.Interval, Decimal, +// UUID, OrderedMap, and Map). +func RegisterAppender(a Appender) { + if a == nil { + return + } + cur, _ := appenders.Load().([]Appender) + next := make([]Appender, 0, len(cur)+1) + next = append(next, cur...) + next = append(next, a) + appenders.Store(next) +} + +// AppendValue appends a value to an Arrow array builder with type coercion. +// It first asks any registered Appender hooks (see RegisterAppender), then +// falls back to handling the standard Arrow / Go value types itself. +func AppendValue(builder array.Builder, val any) { + if val == nil { + builder.AppendNull() + return + } + if hooks, _ := appenders.Load().([]Appender); len(hooks) > 0 { + for _, h := range hooks { + if h(builder, val) { + return + } + } + } + appendBuiltin(builder, val) +} + +// appendBuiltin handles the value types arrowmap knows about natively +// (everything that doesn't depend on a database driver package). +func appendBuiltin(builder array.Builder, val any) { + switch b := builder.(type) { + case *array.Int64Builder: + switch v := val.(type) { + case int64: + b.Append(v) + case int32: + b.Append(int64(v)) + case int: + b.Append(int64(v)) + default: + b.AppendNull() + } + case *array.Int32Builder: + switch v := val.(type) { + case int32: + b.Append(v) + case int64: + b.Append(int32(v)) + case int: + b.Append(int32(v)) + default: + b.AppendNull() + } + case *array.Int16Builder: + switch v := val.(type) { + case int16: + b.Append(v) + case int32: + b.Append(int16(v)) + default: + b.AppendNull() + } + case *array.Int8Builder: + switch v := val.(type) { + case int8: + b.Append(v) + case int32: + b.Append(int8(v)) + default: + b.AppendNull() + } + case *array.Uint8Builder: + switch v := val.(type) { + case uint8: + b.Append(v) + case uint16: + b.Append(uint8(v)) + default: + b.AppendNull() + } + case *array.Uint16Builder: + switch v := val.(type) { + case uint16: + b.Append(v) + case uint32: + b.Append(uint16(v)) + default: + b.AppendNull() + } + case *array.Uint32Builder: + switch v := val.(type) { + case uint32: + b.Append(v) + case uint64: + b.Append(uint32(v)) + default: + b.AppendNull() + } + case *array.Uint64Builder: + switch v := val.(type) { + case uint64: + b.Append(v) + default: + b.AppendNull() + } + case *array.Float64Builder: + switch v := val.(type) { + case float64: + b.Append(v) + case float32: + b.Append(float64(v)) + default: + b.AppendNull() + } + case *array.Float32Builder: + switch v := val.(type) { + case float32: + b.Append(v) + case float64: + b.Append(float32(v)) + default: + b.AppendNull() + } + case *array.BooleanBuilder: + if v, ok := val.(bool); ok { + b.Append(v) + } else { + b.AppendNull() + } + case *array.Date32Builder: + switch v := val.(type) { + case time.Time: + // Floor division to handle pre-epoch dates correctly. + // Go's integer division truncates toward zero, but Date32 + // needs days since epoch rounded toward negative infinity. + unix := v.Unix() + days := unix / 86400 + if unix%86400 < 0 { + days-- + } + b.Append(arrow.Date32(days)) + default: + b.AppendNull() + } + case *array.TimestampBuilder: + switch v := val.(type) { + case time.Time: + b.AppendTime(v) + default: + b.AppendNull() + } + case *array.Time64Builder: + switch v := val.(type) { + case time.Time: + micros := int64(v.Hour())*3600000000 + int64(v.Minute())*60000000 + + int64(v.Second())*1000000 + int64(v.Nanosecond())/1000 + b.Append(arrow.Time64(micros)) + default: + b.AppendNull() + } + case *array.MonthDayNanoIntervalBuilder: + // arrowmap natively handles arrow.MonthDayNanoInterval; driver-specific + // interval types (e.g., duckdb.Interval) come in via Appender hooks. + switch v := val.(type) { + case arrow.MonthDayNanoInterval: + b.Append(v) + default: + b.AppendNull() + } + case *array.Decimal128Builder: + switch v := val.(type) { + case *big.Int: + b.Append(decimal128.FromBigInt(v)) + default: + b.AppendNull() + } + case *array.FixedSizeBinaryBuilder: + switch v := val.(type) { + case []byte: + b.Append(v) + default: + b.AppendNull() + } + case *array.ListBuilder: + switch v := val.(type) { + case []any: + b.Append(true) + vb := b.ValueBuilder() + for _, elem := range v { + AppendValue(vb, elem) + } + default: + b.AppendNull() + } + case *array.StructBuilder: + switch v := val.(type) { + case map[string]any: + b.Append(true) + st := b.Type().(*arrow.StructType) + for i := 0; i < st.NumFields(); i++ { + fieldVal, ok := v[st.Field(i).Name] + if !ok { + b.FieldBuilder(i).AppendNull() + } else { + AppendValue(b.FieldBuilder(i), fieldVal) + } + } + default: + b.AppendNull() + } + case *array.MapBuilder: + switch v := val.(type) { + case OrderedMapValue: + b.Append(true) + kb, ib := b.KeyBuilder(), b.ItemBuilder() + for i, k := range v.Keys { + AppendValue(kb, k) + AppendValue(ib, v.Values[i]) + } + case map[any]any: + b.Append(true) + kb, ib := b.KeyBuilder(), b.ItemBuilder() + for k, item := range v { + AppendValue(kb, k) + AppendValue(ib, item) + } + default: + b.AppendNull() + } + case *array.StringBuilder: + switch v := val.(type) { + case string: + b.Append(v) + case []byte: + // 16-byte non-UTF-8 input is heuristically formatted as a UUID + // string. This pairs with DuckDBTypeToArrow("UUID") returning + // String — duckdb's Go driver returns []byte (not duckdb.UUID) + // when scanning UUID columns into interface{}. + if len(v) == 16 && !utf8.Valid(v) { + s := hex.EncodeToString(v) + b.Append(s[0:8] + "-" + s[8:12] + "-" + s[12:16] + "-" + s[16:20] + "-" + s[20:32]) + } else { + b.Append(string(v)) + } + default: + b.Append(fmt.Sprintf("%v", v)) + } + case *array.BinaryBuilder: + switch v := val.(type) { + case []byte: + b.Append(v) + case string: + b.Append([]byte(v)) + default: + b.AppendNull() + } + default: + builder.AppendNull() + } +} diff --git a/server/auth/metrics.go b/server/auth/metrics.go new file mode 100644 index 00000000..30794e1f --- /dev/null +++ b/server/auth/metrics.go @@ -0,0 +1,30 @@ +package auth + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Auth-related Prometheus metrics. Defined here (rather than in the larger +// server package's metrics block) so this package is self-contained — it can +// be imported and built without pulling in the rest of server/. + +// AuthFailuresCounter is exported so the wire-protocol code in server can +// also bump it when sending FATAL/Class-28 error responses to the client. +var AuthFailuresCounter = promauto.NewCounter(prometheus.CounterOpts{ + Name: "duckgres_auth_failures_total", + Help: "Total number of authentication failures", +}) + +// RateLimitRejectsCounter is exported so the connection-handling code in +// server can also bump it when rejecting a connection at the rate-limit +// gate (before any auth attempt happens). +var RateLimitRejectsCounter = promauto.NewCounter(prometheus.CounterOpts{ + Name: "duckgres_rate_limit_rejects_total", + Help: "Total number of connections rejected due to rate limiting", +}) + +var rateLimitedIPsGauge = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "duckgres_rate_limited_ips", + Help: "Number of currently rate-limited IP addresses", +}) diff --git a/server/auth_policy.go b/server/auth/policy.go similarity index 95% rename from server/auth_policy.go rename to server/auth/policy.go index 9c56f160..a7ed7e39 100644 --- a/server/auth_policy.go +++ b/server/auth/policy.go @@ -1,4 +1,4 @@ -package server +package auth import ( "crypto/subtle" @@ -16,11 +16,11 @@ func BeginRateLimitedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) } if msg := rateLimiter.CheckConnection(remoteAddr); msg != "" { - rateLimitRejectsCounter.Inc() + RateLimitRejectsCounter.Inc() return release, msg } if !rateLimiter.RegisterConnection(remoteAddr) { - rateLimitRejectsCounter.Inc() + RateLimitRejectsCounter.Inc() if msg := rateLimiter.CheckConnection(remoteAddr); msg != "" { return release, msg } @@ -35,7 +35,7 @@ func BeginRateLimitedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) // RecordFailedAuthAttempt records auth telemetry and updates rate-limit state. // Returns true when this failure causes the source IP to be banned. func RecordFailedAuthAttempt(rateLimiter *RateLimiter, remoteAddr net.Addr) bool { - authFailuresCounter.Inc() + AuthFailuresCounter.Inc() if rateLimiter == nil { return false } diff --git a/server/auth_policy_test.go b/server/auth/policy_test.go similarity index 99% rename from server/auth_policy_test.go rename to server/auth/policy_test.go index 5739a581..c5430ae2 100644 --- a/server/auth_policy_test.go +++ b/server/auth/policy_test.go @@ -1,4 +1,4 @@ -package server +package auth import ( "net" diff --git a/server/ratelimit.go b/server/auth/ratelimit.go similarity index 81% rename from server/ratelimit.go rename to server/auth/ratelimit.go index 92865a0c..1be83552 100644 --- a/server/ratelimit.go +++ b/server/auth/ratelimit.go @@ -1,4 +1,7 @@ -package server +// Package auth holds duckgres' connection rate-limiting and password +// validation policy. It has no dependency on github.com/duckdb/duckdb-go, +// so the control plane can use it without linking libduckdb. +package auth import ( "net" @@ -6,21 +9,21 @@ import ( "time" ) -// RateLimitConfig configures rate limiting behavior +// RateLimitConfig configures rate limiting behavior. type RateLimitConfig struct { - // MaxFailedAttempts is the maximum number of failed auth attempts before banning + // MaxFailedAttempts is the maximum number of failed auth attempts before banning. MaxFailedAttempts int - // FailedAttemptWindow is the time window for counting failed attempts + // FailedAttemptWindow is the time window for counting failed attempts. FailedAttemptWindow time.Duration - // BanDuration is how long to ban an IP after exceeding max failed attempts + // BanDuration is how long to ban an IP after exceeding max failed attempts. BanDuration time.Duration - // MaxConnectionsPerIP is the max concurrent connections from a single IP (0 = unlimited) + // MaxConnectionsPerIP is the max concurrent connections from a single IP (0 = unlimited). MaxConnectionsPerIP int - // MaxConnections is the total max concurrent connections (0 = unlimited) + // MaxConnections is the total max concurrent connections (0 = unlimited). MaxConnections int } -// DefaultRateLimitConfig returns sensible defaults for rate limiting +// DefaultRateLimitConfig returns sensible defaults for rate limiting. func DefaultRateLimitConfig() RateLimitConfig { return RateLimitConfig{ MaxFailedAttempts: 5, @@ -31,14 +34,14 @@ func DefaultRateLimitConfig() RateLimitConfig { } } -// ipRecord tracks connection and authentication attempts from an IP +// ipRecord tracks connection and authentication attempts from an IP. type ipRecord struct { failedAttempts []time.Time // timestamps of failed auth attempts bannedUntil time.Time // when the ban expires (zero if not banned) activeConns int // current active connections from this IP } -// RateLimiter tracks and limits connections per IP +// RateLimiter tracks and limits connections per IP. type RateLimiter struct { mu sync.Mutex config RateLimitConfig @@ -46,18 +49,17 @@ type RateLimiter struct { totalActiveConns int } -// NewRateLimiter creates a new rate limiter with the given config +// NewRateLimiter creates a new rate limiter with the given config. func NewRateLimiter(cfg RateLimitConfig) *RateLimiter { rl := &RateLimiter{ config: cfg, records: make(map[string]*ipRecord), } - // Start cleanup goroutine go rl.cleanupLoop() return rl } -// extractIP gets the IP address from a net.Addr (strips port) +// extractIP gets the IP address from a net.Addr (strips port). func extractIP(addr net.Addr) string { if addr == nil { return "" @@ -69,8 +71,8 @@ func extractIP(addr net.Addr) string { return host } -// CheckConnection checks if a connection from the given address should be allowed -// Returns an error message if the connection should be rejected, or empty string if allowed +// CheckConnection checks if a connection from the given address should be allowed. +// Returns an error message if the connection should be rejected, or empty string if allowed. func (rl *RateLimiter) CheckConnection(addr net.Addr) string { ip := extractIP(addr) if ip == "" { @@ -80,20 +82,17 @@ func (rl *RateLimiter) CheckConnection(addr net.Addr) string { rl.mu.Lock() defer rl.mu.Unlock() - // Check global connection limit if rl.config.MaxConnections > 0 && rl.totalActiveConns >= rl.config.MaxConnections { return "too many concurrent connections" } record := rl.getOrCreateRecord(ip) - // Check if IP is banned if !record.bannedUntil.IsZero() && time.Now().Before(record.bannedUntil) { remaining := time.Until(record.bannedUntil).Round(time.Second) return "too many failed authentication attempts, try again in " + remaining.String() } - // Check concurrent connection limit if rl.config.MaxConnectionsPerIP > 0 && record.activeConns >= rl.config.MaxConnectionsPerIP { return "too many connections from your IP address" } @@ -101,8 +100,8 @@ func (rl *RateLimiter) CheckConnection(addr net.Addr) string { return "" } -// RegisterConnection records a new connection from the given address -// Returns true if the connection is allowed, false otherwise +// RegisterConnection records a new connection from the given address. +// Returns true if the connection is allowed, false otherwise. func (rl *RateLimiter) RegisterConnection(addr net.Addr) bool { ip := extractIP(addr) if ip == "" { @@ -112,19 +111,16 @@ func (rl *RateLimiter) RegisterConnection(addr net.Addr) bool { rl.mu.Lock() defer rl.mu.Unlock() - // Check global connection limit if rl.config.MaxConnections > 0 && rl.totalActiveConns >= rl.config.MaxConnections { return false } record := rl.getOrCreateRecord(ip) - // Check if banned if !record.bannedUntil.IsZero() && time.Now().Before(record.bannedUntil) { return false } - // Check concurrent connection limit if rl.config.MaxConnectionsPerIP > 0 && record.activeConns >= rl.config.MaxConnectionsPerIP { return false } @@ -134,7 +130,7 @@ func (rl *RateLimiter) RegisterConnection(addr net.Addr) bool { return true } -// UnregisterConnection decrements the active connection count for an IP +// UnregisterConnection decrements the active connection count for an IP. func (rl *RateLimiter) UnregisterConnection(addr net.Addr) { ip := extractIP(addr) if ip == "" { @@ -157,8 +153,8 @@ func (rl *RateLimiter) UnregisterConnection(addr net.Addr) { } } -// RecordFailedAuth records a failed authentication attempt -// Returns true if the IP is now banned +// RecordFailedAuth records a failed authentication attempt. +// Returns true if the IP is now banned. func (rl *RateLimiter) RecordFailedAuth(addr net.Addr) bool { ip := extractIP(addr) if ip == "" { @@ -171,10 +167,8 @@ func (rl *RateLimiter) RecordFailedAuth(addr net.Addr) bool { record := rl.getOrCreateRecord(ip) now := time.Now() - // Add this failed attempt record.failedAttempts = append(record.failedAttempts, now) - // Count recent failed attempts within the window windowStart := now.Add(-rl.config.FailedAttemptWindow) recentAttempts := 0 for _, t := range record.failedAttempts { @@ -183,21 +177,22 @@ func (rl *RateLimiter) RecordFailedAuth(addr net.Addr) bool { } } - // Ban if exceeded threshold if recentAttempts >= rl.config.MaxFailedAttempts { - // Decrement if replacing an expired ban that cleanup hasn't cleared yet - if !record.bannedUntil.IsZero() && now.After(record.bannedUntil) { - rateLimitedIPsGauge.Dec() - } + // Three cases: never banned (gauge++), already banned (no change — + // already counted), or expired ban that cleanup hasn't cleared yet + // (no change — still counted, will be decremented when cleanup runs). + alreadyCounted := !record.bannedUntil.IsZero() record.bannedUntil = now.Add(rl.config.BanDuration) - rateLimitedIPsGauge.Inc() + if !alreadyCounted { + rateLimitedIPsGauge.Inc() + } return true } return false } -// RecordSuccessfulAuth clears failed attempts for an IP after successful auth +// RecordSuccessfulAuth clears failed attempts for an IP after successful auth. func (rl *RateLimiter) RecordSuccessfulAuth(addr net.Addr) { ip := extractIP(addr) if ip == "" { @@ -216,7 +211,7 @@ func (rl *RateLimiter) RecordSuccessfulAuth(addr net.Addr) { } } -// IsBanned checks if an IP is currently banned +// IsBanned checks if an IP is currently banned. func (rl *RateLimiter) IsBanned(addr net.Addr) bool { ip := extractIP(addr) if ip == "" { @@ -234,7 +229,7 @@ func (rl *RateLimiter) IsBanned(addr net.Addr) bool { return !record.bannedUntil.IsZero() && time.Now().Before(record.bannedUntil) } -// getOrCreateRecord gets or creates a record for an IP (must hold lock) +// getOrCreateRecord gets or creates a record for an IP (must hold lock). func (rl *RateLimiter) getOrCreateRecord(ip string) *ipRecord { record, ok := rl.records[ip] if !ok { @@ -244,7 +239,7 @@ func (rl *RateLimiter) getOrCreateRecord(ip string) *ipRecord { return record } -// cleanupLoop periodically cleans up expired records +// cleanupLoop periodically cleans up expired records. func (rl *RateLimiter) cleanupLoop() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() @@ -254,7 +249,7 @@ func (rl *RateLimiter) cleanupLoop() { } } -// cleanup removes expired records to prevent memory growth +// cleanup removes expired records to prevent memory growth. func (rl *RateLimiter) cleanup() { rl.mu.Lock() defer rl.mu.Unlock() @@ -263,7 +258,6 @@ func (rl *RateLimiter) cleanup() { windowStart := now.Add(-rl.config.FailedAttemptWindow) for ip, record := range rl.records { - // Remove expired failed attempts var validAttempts []time.Time for _, t := range record.failedAttempts { if t.After(windowStart) { @@ -272,13 +266,11 @@ func (rl *RateLimiter) cleanup() { } record.failedAttempts = validAttempts - // Clear expired bans if !record.bannedUntil.IsZero() && now.After(record.bannedUntil) { record.bannedUntil = time.Time{} rateLimitedIPsGauge.Dec() } - // Remove record if it's empty and has no active connections if len(record.failedAttempts) == 0 && record.bannedUntil.IsZero() && record.activeConns == 0 { diff --git a/server/ratelimit_test.go b/server/auth/ratelimit_test.go similarity index 99% rename from server/ratelimit_test.go rename to server/auth/ratelimit_test.go index e3f95e61..3a5afab6 100644 --- a/server/ratelimit_test.go +++ b/server/auth/ratelimit_test.go @@ -1,4 +1,4 @@ -package server +package auth import ( "net" diff --git a/server/auth_aliases.go b/server/auth_aliases.go new file mode 100644 index 00000000..ed4c3fec --- /dev/null +++ b/server/auth_aliases.go @@ -0,0 +1,22 @@ +package server + +import "github.com/posthog/duckgres/server/auth" + +// Type aliases and re-exports kept here so existing references to +// server.RateLimiter / server.RateLimitConfig / server.NewRateLimiter etc. +// continue to compile after the rate-limit and auth-policy code moved into +// server/auth. New code should import server/auth and use auth.X directly. + +type ( + RateLimiter = auth.RateLimiter + RateLimitConfig = auth.RateLimitConfig +) + +var ( + NewRateLimiter = auth.NewRateLimiter + DefaultRateLimitConfig = auth.DefaultRateLimitConfig + BeginRateLimitedAuthAttempt = auth.BeginRateLimitedAuthAttempt + RecordFailedAuthAttempt = auth.RecordFailedAuthAttempt + RecordSuccessfulAuthAttempt = auth.RecordSuccessfulAuthAttempt + ValidateUserPassword = auth.ValidateUserPassword +) diff --git a/server/conn.go b/server/conn.go index 00745f59..540e8bc4 100644 --- a/server/conn.go +++ b/server/conn.go @@ -26,6 +26,7 @@ import ( duckdb "github.com/duckdb/duckdb-go/v2" pg_query "github.com/pganalyze/pg_query_go/v6" + "github.com/posthog/duckgres/server/auth" "github.com/posthog/duckgres/transpiler" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -4668,7 +4669,7 @@ func (c *clientConn) sendError(severity, code, message string) { // NOTE: If one adds a FATAL error with a non-28 code, be sure to add // a metric for it here. if strings.HasPrefix(code, "28") { - authFailuresCounter.Inc() + auth.AuthFailuresCounter.Inc() } else if severity == "ERROR" { queryErrorsCounter.WithLabelValues(c.orgID).Inc() } diff --git a/server/flight_executor.go b/server/flight_executor.go index 1c9503cb..f1846df6 100644 --- a/server/flight_executor.go +++ b/server/flight_executor.go @@ -19,6 +19,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/posthog/duckgres/duckdbservice/arrowmap" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -31,14 +32,11 @@ const MaxGRPCMessageSize = 1 << 30 // 1GB // ErrWorkerDead is returned when the backing worker process has crashed. var ErrWorkerDead = errors.New("flight worker is dead") -// OrderedMapValue represents a DuckDB MAP as parallel key/value slices, -// preserving insertion order from Arrow MAP arrays. Using parallel slices -// instead of a Go map avoids panics on non-comparable key types (e.g., -// []byte from BLOB keys) and preserves DuckDB's MAP ordering. -type OrderedMapValue struct { - Keys []any - Values []any -} +// OrderedMapValue is an alias for arrowmap.OrderedMapValue. The type was +// moved into arrowmap so AppendValue's MAP branch can switch on it without +// arrowmap depending on the server package. The alias preserves the +// existing server.OrderedMapValue spelling for current call sites. +type OrderedMapValue = arrowmap.OrderedMapValue // FlightExecutor implements QueryExecutor backed by an Arrow Flight SQL client. // It routes queries to a duckdb-service worker process over a Unix socket. diff --git a/server/flightsqlingress/ingress.go b/server/flightsqlingress/ingress.go index 911afbd1..6d0d84ce 100644 --- a/server/flightsqlingress/ingress.go +++ b/server/flightsqlingress/ingress.go @@ -22,7 +22,6 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apache/arrow-go/v18/arrow/flight/flightsql/schema_ref" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/posthog/duckgres/duckdbservice" "github.com/posthog/duckgres/duckdbservice/arrowmap" "github.com/posthog/duckgres/server" "google.golang.org/grpc" @@ -2151,7 +2150,7 @@ func rowSetToRecord(alloc memory.Allocator, rows server.RowSet, schema *arrow.Sc } for i, val := range values { - duckdbservice.AppendValue(builder.Field(i), val) + arrowmap.AppendValue(builder.Field(i), val) } count++ } diff --git a/server/server.go b/server/server.go index 1c3faeab..cd93cba2 100644 --- a/server/server.go +++ b/server/server.go @@ -23,7 +23,9 @@ import ( awsconfig "github.com/aws/aws-sdk-go-v2/config" _ "github.com/duckdb/duckdb-go/v2" _ "github.com/jackc/pgx/v5/stdlib" // registers "pgx" driver for direct PostgreSQL connections + "github.com/posthog/duckgres/server/auth" "github.com/posthog/duckgres/server/ducklake" + "github.com/posthog/duckgres/server/sysinfo" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) @@ -130,21 +132,6 @@ var queryErrorsCounter = promauto.NewCounterVec(prometheus.CounterOpts{ Help: "Total number of failed queries", }, []string{"org"}) -var authFailuresCounter = promauto.NewCounter(prometheus.CounterOpts{ - Name: "duckgres_auth_failures_total", - Help: "Total number of authentication failures", -}) - -var rateLimitRejectsCounter = promauto.NewCounter(prometheus.CounterOpts{ - Name: "duckgres_rate_limit_rejects_total", - Help: "Total number of connections rejected due to rate limiting", -}) - -var rateLimitedIPsGauge = promauto.NewGauge(prometheus.GaugeOpts{ - Name: "duckgres_rate_limited_ips", - Help: "Number of currently rate-limited IP addresses", -}) - var queryCancellationsCounter = promauto.NewCounter(prometheus.CounterOpts{ Name: "duckgres_query_cancellations_total", Help: "Total number of queries cancelled via cancel request", @@ -850,7 +837,7 @@ func openBaseDB(cfg Config, username string) (*sql.DB, error) { // Set DuckDB memory limit memLimit := cfg.MemoryLimit if memLimit == "" { - memLimit = autoMemoryLimit() + memLimit = sysinfo.AutoMemoryLimit() } if _, err := db.Exec(fmt.Sprintf("SET memory_limit = '%s'", memLimit)); err != nil { slog.Warn("Failed to set DuckDB memory_limit.", "memory_limit", memLimit, "error", err) @@ -2085,7 +2072,7 @@ func (s *Server) handleConnection(conn net.Conn) { if msg := s.rateLimiter.CheckConnection(remoteAddr); msg != "" { // Send PostgreSQL error and close slog.Warn("Connection rejected.", "remote_addr", remoteAddr, "reason", msg) - rateLimitRejectsCounter.Inc() + auth.RateLimitRejectsCounter.Inc() _ = conn.Close() return } @@ -2093,7 +2080,7 @@ func (s *Server) handleConnection(conn net.Conn) { // Register this connection if !s.rateLimiter.RegisterConnection(remoteAddr) { slog.Warn("Connection rejected: rate limit exceeded.", "remote_addr", remoteAddr) - rateLimitRejectsCounter.Inc() + auth.RateLimitRejectsCounter.Inc() _ = conn.Close() return } diff --git a/server/sysinfo.go b/server/sysinfo/sysinfo.go similarity index 88% rename from server/sysinfo.go rename to server/sysinfo/sysinfo.go index d1ad47f6..22df4c8b 100644 --- a/server/sysinfo.go +++ b/server/sysinfo/sysinfo.go @@ -1,4 +1,7 @@ -package server +// Package sysinfo holds duckgres' system-memory detection helpers and the +// memory-limit string parser shared between the server, the control plane, +// and config resolution. No dependency on github.com/duckdb/duckdb-go. +package sysinfo import ( "bufio" @@ -52,13 +55,13 @@ var ( autoMemoryLimitValue string ) -// autoMemoryLimit computes a DuckDB memory_limit based on system memory. +// AutoMemoryLimit computes a DuckDB memory_limit based on system memory. // Formula: totalMem * 0.75, with a floor of 256MB. // Every session gets the full budget — DuckDB will spill to disk/swap if // aggregate usage exceeds physical RAM. // Returns "4GB" as a safe default if system memory cannot be detected. // The result is computed once and cached since system memory doesn't change. -func autoMemoryLimit() string { +func AutoMemoryLimit() string { autoMemoryLimitOnce.Do(func() { totalBytes := SystemMemoryBytes() if totalBytes == 0 { @@ -74,7 +77,7 @@ func autoMemoryLimit() string { limitBytes = 256 * mb } - // Format as human-readable: use GB if >= 1GB, else MB + // Format as human-readable: use GB if >= 1GB, else MB. if limitBytes >= gb { limitGB := limitBytes / gb autoMemoryLimitValue = fmt.Sprintf("%dGB", limitGB) diff --git a/server/sysinfo_test.go b/server/sysinfo/sysinfo_test.go similarity index 89% rename from server/sysinfo_test.go rename to server/sysinfo/sysinfo_test.go index 38b8dbd3..00bde27f 100644 --- a/server/sysinfo_test.go +++ b/server/sysinfo/sysinfo_test.go @@ -1,4 +1,4 @@ -package server +package sysinfo import ( "runtime" @@ -11,9 +11,9 @@ func TestAutoMemoryLimit(t *testing.T) { autoMemoryLimitOnce = sync.Once{} autoMemoryLimitValue = "" - result := autoMemoryLimit() + result := AutoMemoryLimit() if result == "" { - t.Fatal("autoMemoryLimit returned empty string") + t.Fatal("AutoMemoryLimit returned empty string") } // On Linux (CI and production), we should detect system memory @@ -31,7 +31,7 @@ func TestAutoMemoryLimitFormat(t *testing.T) { autoMemoryLimitOnce = sync.Once{} autoMemoryLimitValue = "" - result := autoMemoryLimit() + result := AutoMemoryLimit() // Should end with GB or MB validSuffix := false for _, suffix := range []string{"GB", "MB"} { @@ -41,7 +41,7 @@ func TestAutoMemoryLimitFormat(t *testing.T) { } } if !validSuffix { - t.Fatalf("autoMemoryLimit returned %q, expected suffix GB or MB", result) + t.Fatalf("AutoMemoryLimit returned %q, expected suffix GB or MB", result) } } @@ -50,10 +50,10 @@ func TestAutoMemoryLimitCached(t *testing.T) { autoMemoryLimitOnce = sync.Once{} autoMemoryLimitValue = "" - first := autoMemoryLimit() - second := autoMemoryLimit() + first := AutoMemoryLimit() + second := AutoMemoryLimit() if first != second { - t.Fatalf("autoMemoryLimit not stable: %q vs %q", first, second) + t.Fatalf("AutoMemoryLimit not stable: %q vs %q", first, second) } } diff --git a/server/sysinfo_aliases.go b/server/sysinfo_aliases.go new file mode 100644 index 00000000..094fd00e --- /dev/null +++ b/server/sysinfo_aliases.go @@ -0,0 +1,14 @@ +package server + +import "github.com/posthog/duckgres/server/sysinfo" + +// Re-exports kept here so existing references to server.SystemMemoryBytes, +// server.ValidateMemoryLimit, and server.ParseMemoryBytes continue to compile +// after the helpers moved into server/sysinfo. New code should import +// server/sysinfo and use sysinfo.X directly. + +var ( + SystemMemoryBytes = sysinfo.SystemMemoryBytes + ValidateMemoryLimit = sysinfo.ValidateMemoryLimit + ParseMemoryBytes = sysinfo.ParseMemoryBytes +) diff --git a/server/worker.go b/server/worker.go index 69c6a6d9..b63990a3 100644 --- a/server/worker.go +++ b/server/worker.go @@ -17,6 +17,8 @@ import ( "sync" "syscall" "time" + + "github.com/posthog/duckgres/server/auth" ) // Exit codes for child processes @@ -220,7 +222,7 @@ func runChildWorker(tcpConn *net.TCPConn, cfg *ChildConfig) int { expectedPassword, ok := cfg.Users[username] if !ok { slog.Warn("Unknown user", "user", username, "remote_addr", cfg.RemoteAddr) - authFailuresCounter.Inc() + auth.AuthFailuresCounter.Inc() _ = writeErrorResponse(writer, "FATAL", "28P01", "password authentication failed") _ = writer.Flush() return ExitAuthFailure @@ -254,7 +256,7 @@ func runChildWorker(tcpConn *net.TCPConn, cfg *ChildConfig) int { password := string(bytes.TrimRight(body, "\x00")) if password != expectedPassword { slog.Warn("Authentication failed", "user", username, "remote_addr", cfg.RemoteAddr) - authFailuresCounter.Inc() + auth.AuthFailuresCounter.Inc() _ = writeErrorResponse(writer, "FATAL", "28P01", "password authentication failed") _ = writer.Flush() return ExitAuthFailure