diff --git a/controlplane/control.go b/controlplane/control.go index 71d4a3d4..baf7c563 100644 --- a/controlplane/control.go +++ b/controlplane/control.go @@ -23,6 +23,7 @@ import ( "github.com/cloudflare/tableflip" "github.com/posthog/duckgres/controlplane/configstore" "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" "github.com/posthog/duckgres/server/flightsqlingress" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -630,8 +631,8 @@ func createSessionWithRegisteredCancel( srv *server.Server, timeout time.Duration, key server.BackendKey, - createFn func(context.Context) (int32, *server.FlightExecutor, error), -) (int32, *server.FlightExecutor, error) { + createFn func(context.Context) (int32, *flightclient.FlightExecutor, error), +) (int32, *flightclient.FlightExecutor, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -982,7 +983,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) { cp.srv, cp.cfg.WorkerQueueTimeout, server.BackendKey{Pid: pid, SecretKey: secretKey}, - func(ctx context.Context) (int32, *server.FlightExecutor, error) { + func(ctx context.Context) (int32, *flightclient.FlightExecutor, error) { return sessions.CreateSession(ctx, username, pid, memLimit, threads) }, ) diff --git a/controlplane/control_cancel_test.go b/controlplane/control_cancel_test.go index 5c507aa7..2809bbfb 100644 --- a/controlplane/control_cancel_test.go +++ b/controlplane/control_cancel_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" ) func TestCreateSessionWithRegisteredCancel_CancelQueryCancelsWait(t *testing.T) { @@ -24,7 +25,7 @@ func TestCreateSessionWithRegisteredCancel_CancelQueryCancelsWait(t *testing.T) srv, 200*time.Millisecond, key, - func(ctx context.Context) (int32, *server.FlightExecutor, error) { + func(ctx context.Context) (int32, *flightclient.FlightExecutor, error) { close(started) <-ctx.Done() return 0, nil, ctx.Err() diff --git a/controlplane/flight_ingress.go b/controlplane/flight_ingress.go index 9c444b97..12216399 100644 --- a/controlplane/flight_ingress.go +++ b/controlplane/flight_ingress.go @@ -12,6 +12,7 @@ import ( "github.com/posthog/duckgres/controlplane/configstore" "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" "github.com/posthog/duckgres/server/flightsqlingress" ) @@ -45,7 +46,7 @@ type flightSessionProvider struct { sm *SessionManager } -func (p *flightSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) { +func (p *flightSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) { workerPID, executor, err := p.sm.CreateSession(ctx, username, pid, memoryLimit, threads) if err != nil { return 0, nil, err @@ -74,7 +75,7 @@ type orgRoutedSessionProvider struct { userOrg map[string]string // username → orgID (populated during auth) } -func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) { +func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) { p.mu.RLock() orgID := p.userOrg[username] p.mu.RUnlock() @@ -141,7 +142,7 @@ func (p *orgRoutedSessionProvider) DurableSessionMetadata(pid int32, username st }, nil } -func (p *orgRoutedSessionProvider) ReconnectSession(ctx context.Context, record flightsqlingress.DurableSessionRecord) (int32, *server.FlightExecutor, error) { +func (p *orgRoutedSessionProvider) ReconnectSession(ctx context.Context, record flightsqlingress.DurableSessionRecord) (int32, *flightclient.FlightExecutor, error) { _, sessions, _, ok := p.orgRouter.StackForOrg(record.OrgID) if !ok { return 0, nil, fmt.Errorf("no org stack for org %q", record.OrgID) diff --git a/controlplane/k8s_pool.go b/controlplane/k8s_pool.go index 97af46d1..29b7ea7b 100644 --- a/controlplane/k8s_pool.go +++ b/controlplane/k8s_pool.go @@ -21,6 +21,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/posthog/duckgres/controlplane/configstore" "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" "google.golang.org/grpc" "google.golang.org/grpc/credentials" corev1 "k8s.io/api/core/v1" @@ -932,8 +933,8 @@ func waitForWorkerTCPWithMetadata(addr, bearerToken string, serverCertPEM []byte var dialOpts []grpc.DialOption dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) dialOpts = append(dialOpts, grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(server.MaxGRPCMessageSize), - grpc.MaxCallSendMsgSize(server.MaxGRPCMessageSize), + grpc.MaxCallRecvMsgSize(flightclient.MaxGRPCMessageSize), + grpc.MaxCallSendMsgSize(flightclient.MaxGRPCMessageSize), )) dialOpts = append(dialOpts, server.OTELGRPCClientHandler()) if bearerToken != "" { @@ -1640,8 +1641,8 @@ func (p *K8sWorkerPool) connectWorkerDirect(ctx context.Context, podName, podIP, var dialOpts []grpc.DialOption dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) dialOpts = append(dialOpts, grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(server.MaxGRPCMessageSize), - grpc.MaxCallSendMsgSize(server.MaxGRPCMessageSize), + grpc.MaxCallRecvMsgSize(flightclient.MaxGRPCMessageSize), + grpc.MaxCallSendMsgSize(flightclient.MaxGRPCMessageSize), )) dialOpts = append(dialOpts, server.OTELGRPCClientHandler()) if bearerToken != "" { diff --git a/controlplane/session_mgr.go b/controlplane/session_mgr.go index c8c659d5..0f06ab10 100644 --- a/controlplane/session_mgr.go +++ b/controlplane/session_mgr.go @@ -10,6 +10,7 @@ import ( "time" "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" ) // SessionProgress holds cached query progress from a worker health check. @@ -26,7 +27,7 @@ type ManagedSession struct { WorkerID int Protocol string // "postgres" or "flight" SessionToken string - Executor *server.FlightExecutor + Executor *flightclient.FlightExecutor connCloser io.Closer // TCP connection, closed on worker crash to unblock the message loop // Cached query progress from worker health checks. @@ -68,7 +69,7 @@ func (sm *SessionManager) ReservePID() int32 { // CreateSession acquires a worker (reusing an idle one or spawning a new one), // creates a session on it, and rebalances memory/thread limits across all active sessions. // If pid is 0, a new one is generated. -func (sm *SessionManager) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) { +func (sm *SessionManager) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) { memoryLimit, threads = sm.resolveSessionLimits(memoryLimit, threads) // Acquire a worker: reuses idle pre-warmed workers or spawns a new one. @@ -102,7 +103,7 @@ func (sm *SessionManager) resolveSessionLimits(memoryLimit string, threads int) return memoryLimit, threads } -func (sm *SessionManager) ReconnectFlightSession(ctx context.Context, username string, workerID int, ownerEpoch int64) (int32, *server.FlightExecutor, error) { +func (sm *SessionManager) ReconnectFlightSession(ctx context.Context, username string, workerID int, ownerEpoch int64) (int32, *flightclient.FlightExecutor, error) { reconnector, ok := sm.pool.(flightReconnectPool) if !ok { return 0, nil, fmt.Errorf("worker pool does not support flight reconnect") @@ -114,7 +115,7 @@ func (sm *SessionManager) ReconnectFlightSession(ctx context.Context, username s return sm.createSessionOnWorker(ctx, username, 0, "", 0, worker, "flight", false) } -func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username string, pid int32, memoryLimit string, threads int, worker *ManagedWorker, protocol string, retireOnFailure bool) (int32, *server.FlightExecutor, error) { +func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username string, pid int32, memoryLimit string, threads int, worker *ManagedWorker, protocol string, retireOnFailure bool) (int32, *flightclient.FlightExecutor, error) { createStart := time.Now() sessionToken, err := worker.CreateSession(ctx, username, memoryLimit, threads) if err != nil { @@ -124,7 +125,7 @@ func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username st return 0, nil, fmt.Errorf("create session on worker %d: %w", worker.ID, err) } - executor := server.NewFlightExecutorFromClient(worker.client, sessionToken) + executor := flightclient.NewFlightExecutorFromClient(worker.client, sessionToken) executor.SetControlMetadata(worker.ID, worker.OwnerCPInstanceID(), worker.OwnerEpoch()) if pid == 0 { diff --git a/controlplane/session_mgr_test.go b/controlplane/session_mgr_test.go index 1afc55d1..5488a13e 100644 --- a/controlplane/session_mgr_test.go +++ b/controlplane/session_mgr_test.go @@ -8,7 +8,7 @@ import ( "sync/atomic" "testing" - "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" ) // mockCloser tracks whether Close was called. @@ -27,7 +27,7 @@ func TestOnWorkerCrash_MarksExecutorsDead(t *testing.T) { } sm := NewSessionManager(pool, nil) - executor := &server.FlightExecutor{} + executor := &flightclient.FlightExecutor{} pid := int32(1001) sm.mu.Lock() @@ -67,7 +67,7 @@ func TestOnWorkerCrash_ClosesConnections(t *testing.T) { sm := NewSessionManager(pool, nil) conn := &mockCloser{} - executor := &server.FlightExecutor{} + executor := &flightclient.FlightExecutor{} pid := int32(1002) sm.mu.Lock() @@ -93,8 +93,8 @@ func TestOnWorkerCrash_MultipleSessions(t *testing.T) { } sm := NewSessionManager(pool, nil) - exec1 := &server.FlightExecutor{} - exec2 := &server.FlightExecutor{} + exec1 := &flightclient.FlightExecutor{} + exec2 := &flightclient.FlightExecutor{} conn1 := &mockCloser{} conn2 := &mockCloser{} @@ -228,7 +228,7 @@ func TestDestroySessionAfterOnWorkerCrash(t *testing.T) { sm := NewSessionManager(pool, nil) conn := &mockCloser{} - executor := &server.FlightExecutor{} + executor := &flightclient.FlightExecutor{} pid := int32(1010) sm.mu.Lock() diff --git a/controlplane/worker_mgr.go b/controlplane/worker_mgr.go index d4039027..eed7845e 100644 --- a/controlplane/worker_mgr.go +++ b/controlplane/worker_mgr.go @@ -19,6 +19,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -454,8 +455,8 @@ func waitForWorker(socketPath, bearerToken string, timeout time.Duration) (*flig var dialOpts []grpc.DialOption dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) dialOpts = append(dialOpts, grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(server.MaxGRPCMessageSize), - grpc.MaxCallSendMsgSize(server.MaxGRPCMessageSize), + grpc.MaxCallRecvMsgSize(flightclient.MaxGRPCMessageSize), + grpc.MaxCallSendMsgSize(flightclient.MaxGRPCMessageSize), )) if bearerToken != "" { diff --git a/duckdbservice/arrowmap/arrowmap.go b/duckdbservice/arrowmap/arrowmap.go index 845410f5..1580e805 100644 --- a/duckdbservice/arrowmap/arrowmap.go +++ b/duckdbservice/arrowmap/arrowmap.go @@ -305,14 +305,24 @@ func QuoteIdent(ident string) string { // and preserves the source MAP ordering. // // Lives in arrowmap so AppendValue can switch on it without depending on -// the server package (which transitively links libduckdb). The flight -// executor in the server package re-exports it as server.OrderedMapValue -// via a type alias for backward compatibility. +// the server package, and the flight client + result formatters in +// server/ can both reference it without creating an import cycle. type OrderedMapValue struct { Keys []any Values []any } +// IntervalValue is the duckdb-free representation of an Arrow +// MonthDayNanoInterval as decoded by the Flight client. It stores the +// component fields directly (Months/Days/Micros) so result formatters +// in the server package can switch on the type without importing the +// flight subpackage (which would create a cycle). +type IntervalValue struct { + Months int32 + Days int32 + Micros int64 +} + // Appender is a hook that handles append for value types arrowmap doesn't // know about (typically driver-specific types like duckdb.Interval). It // reports whether it handled the value; arrowmap.AppendValue falls back to diff --git a/duckdbservice/service.go b/duckdbservice/service.go index eb95c1c3..549fc76e 100644 --- a/duckdbservice/service.go +++ b/duckdbservice/service.go @@ -22,6 +22,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -289,8 +290,8 @@ func (svc *DuckDBService) Serve(listener net.Listener) error { var opts []grpc.ServerOption opts = append(opts, - grpc.MaxRecvMsgSize(server.MaxGRPCMessageSize), - grpc.MaxSendMsgSize(server.MaxGRPCMessageSize), + grpc.MaxRecvMsgSize(flightclient.MaxGRPCMessageSize), + grpc.MaxSendMsgSize(flightclient.MaxGRPCMessageSize), ) if svc.cfg.BearerToken != "" { opts = append(opts, diff --git a/server/conn.go b/server/conn.go index 540e8bc4..e5b0f8fe 100644 --- a/server/conn.go +++ b/server/conn.go @@ -26,6 +26,7 @@ import ( duckdb "github.com/duckdb/duckdb-go/v2" pg_query "github.com/pganalyze/pg_query_go/v6" + "github.com/posthog/duckgres/duckdbservice/arrowmap" "github.com/posthog/duckgres/server/auth" "github.com/posthog/duckgres/transpiler" "go.opentelemetry.io/otel/attribute" @@ -4308,7 +4309,7 @@ func (c *clientConn) formatCopyValue(v interface{}) string { return formatArrayValue(val) case map[string]any: return formatMapValue(val) - case OrderedMapValue: + case arrowmap.OrderedMapValue: return formatOrderedMapValue(val) default: return fmt.Sprintf("%v", val) @@ -4512,13 +4513,13 @@ func formatValue(v interface{}) string { case duckdb.Interval: // PostgreSQL interval text format: "1 year 2 mons 3 days 04:05:06.123456" return formatInterval(val) - case intervalValue: - // Arrow Flight returns intervalValue instead of duckdb.Interval + case arrowmap.IntervalValue: + // Arrow Flight returns arrowmap.IntervalValue instead of duckdb.Interval return formatInterval(duckdb.Interval{Months: val.Months, Days: val.Days, Micros: val.Micros}) case map[string]any: // STRUCT text format: {"key1": val1, "key2": val2} return formatMapValue(val) - case OrderedMapValue: + case arrowmap.OrderedMapValue: return formatOrderedMapValue(val) default: // For other types, try to convert to string @@ -4647,7 +4648,7 @@ func formatMapValue(m map[string]any) string { // formatOrderedMapValue formats an OrderedMapValue as a key-value text // representation, preserving the original insertion order from the Arrow array. -func formatOrderedMapValue(m OrderedMapValue) string { +func formatOrderedMapValue(m arrowmap.OrderedMapValue) string { var buf strings.Builder buf.WriteByte('{') for i, k := range m.Keys { diff --git a/server/flight_executor.go b/server/flightclient/flight_executor.go similarity index 93% rename from server/flight_executor.go rename to server/flightclient/flight_executor.go index f1846df6..b5e69848 100644 --- a/server/flight_executor.go +++ b/server/flightclient/flight_executor.go @@ -1,4 +1,4 @@ -package server +package flightclient import ( "context" @@ -20,6 +20,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/posthog/duckgres/duckdbservice/arrowmap" + "github.com/posthog/duckgres/server" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -32,12 +33,6 @@ const MaxGRPCMessageSize = 1 << 30 // 1GB // ErrWorkerDead is returned when the backing worker process has crashed. var ErrWorkerDead = errors.New("flight worker is dead") -// OrderedMapValue is an alias for arrowmap.OrderedMapValue. The type was -// moved into arrowmap so AppendValue's MAP branch can switch on it without -// arrowmap depending on the server package. The alias preserves the -// existing server.OrderedMapValue spelling for current call sites. -type OrderedMapValue = arrowmap.OrderedMapValue - // FlightExecutor implements QueryExecutor backed by an Arrow Flight SQL client. // It routes queries to a duckdb-service worker process over a Unix socket. type FlightExecutor struct { @@ -75,7 +70,7 @@ func NewFlightExecutor(addr, bearerToken, sessionToken string) (*FlightExecutor, // Propagate OTEL trace context across gRPC to worker pods. // Filtered to query RPCs only (GetFlightInfo, DoGet). - dialOpts = append(dialOpts, OTELGRPCClientHandler()) + dialOpts = append(dialOpts, server.OTELGRPCClientHandler()) if bearerToken != "" { dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(&bearerCreds{token: bearerToken})) @@ -161,14 +156,14 @@ func recoverClientPanic(err *error) { } } -func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args ...any) (rs RowSet, err error) { +func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args ...any) (rs server.RowSet, err error) { if e.dead.Load() { return nil, ErrWorkerDead } // Return empty results for queries that are only semicolons, whitespace, // and/or comments. These represent PostgreSQL client pings (e.g., pgx sends "-- ping"). - if IsEmptyQuery(query) { + if server.IsEmptyQuery(query) { return &emptyRowSet{}, nil } @@ -226,13 +221,13 @@ func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args .. }, nil } -func (e *FlightExecutor) ExecContext(ctx context.Context, query string, args ...any) (result ExecResult, err error) { +func (e *FlightExecutor) ExecContext(ctx context.Context, query string, args ...any) (result server.ExecResult, err error) { if e.dead.Load() { return nil, ErrWorkerDead } // Return zero rows affected for empty/comment-only queries. - if IsEmptyQuery(query) { + if server.IsEmptyQuery(query) { return &flightExecResult{rowsAffected: 0}, nil } @@ -257,15 +252,15 @@ func (e *FlightExecutor) ExecContext(ctx context.Context, query string, args ... return &flightExecResult{rowsAffected: affected}, nil } -func (e *FlightExecutor) Query(query string, args ...any) (RowSet, error) { +func (e *FlightExecutor) Query(query string, args ...any) (server.RowSet, error) { return e.QueryContext(context.Background(), query, args...) } -func (e *FlightExecutor) Exec(query string, args ...any) (ExecResult, error) { +func (e *FlightExecutor) Exec(query string, args ...any) (server.ExecResult, error) { return e.ExecContext(context.Background(), query, args...) } -func (e *FlightExecutor) ConnContext(ctx context.Context) (RawConn, error) { +func (e *FlightExecutor) ConnContext(ctx context.Context) (server.RawConn, error) { return nil, fmt.Errorf("ConnContext not supported in Flight mode (use batched INSERT for COPY FROM)") } @@ -346,8 +341,8 @@ func (r *FlightRowSet) Columns() ([]string, error) { return names, nil } -func (r *FlightRowSet) ColumnTypes() ([]ColumnTyper, error) { - types := make([]ColumnTyper, r.schema.NumFields()) +func (r *FlightRowSet) ColumnTypes() ([]server.ColumnTyper, error) { + types := make([]server.ColumnTyper, r.schema.NumFields()) for i := 0; i < r.schema.NumFields(); i++ { types[i] = &arrowColumnType{dt: r.schema.Field(i).Type} } @@ -433,7 +428,7 @@ func (r *FlightRowSet) Err() error { type emptyRowSet struct{} func (e *emptyRowSet) Columns() ([]string, error) { return nil, nil } -func (e *emptyRowSet) ColumnTypes() ([]ColumnTyper, error) { return nil, nil } +func (e *emptyRowSet) ColumnTypes() ([]server.ColumnTyper, error) { return nil, nil } func (e *emptyRowSet) Next() bool { return false } func (e *emptyRowSet) Scan(dest ...any) error { return fmt.Errorf("no rows") } func (e *emptyRowSet) Close() error { return nil } @@ -455,8 +450,8 @@ func (e *emptySchemaRowSet) Columns() ([]string, error) { return cols, nil } -func (e *emptySchemaRowSet) ColumnTypes() ([]ColumnTyper, error) { - types := make([]ColumnTyper, e.schema.NumFields()) +func (e *emptySchemaRowSet) ColumnTypes() ([]server.ColumnTyper, error) { + types := make([]server.ColumnTyper, e.schema.NumFields()) for i := 0; i < e.schema.NumFields(); i++ { types[i] = &arrowColumnType{dt: e.schema.Field(i).Type} } @@ -634,7 +629,7 @@ func extractArrowValue(col arrow.Array, row int) interface{} { ks = append(ks, extractArrowValue(keys, i)) vs = append(vs, extractArrowValue(items, i)) } - return OrderedMapValue{Keys: ks, Values: vs} + return arrowmap.OrderedMapValue{Keys: ks, Values: vs} default: // Fallback: use String representation return arr.ValueStr(row) @@ -683,14 +678,8 @@ func decimalToBigInt(val decimal128.Num, dt *arrow.Decimal128Type) interface{} { return result } -type intervalValue struct { - Months int32 - Days int32 - Micros int64 -} - -func monthDayNanoToInterval(val arrow.MonthDayNanoInterval) intervalValue { - return intervalValue{ +func monthDayNanoToInterval(val arrow.MonthDayNanoInterval) arrowmap.IntervalValue { + return arrowmap.IntervalValue{ Months: val.Months, Days: val.Days, Micros: val.Nanoseconds / 1000, diff --git a/server/flight_executor_arrow_test.go b/server/flightclient/flight_executor_arrow_test.go similarity index 92% rename from server/flight_executor_arrow_test.go rename to server/flightclient/flight_executor_arrow_test.go index de4aa8bd..dbfd9dc8 100644 --- a/server/flight_executor_arrow_test.go +++ b/server/flightclient/flight_executor_arrow_test.go @@ -1,4 +1,4 @@ -package server +package flightclient import ( "fmt" @@ -7,6 +7,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/posthog/duckgres/duckdbservice/arrowmap" ) // Tests for arrowTypeToDuckDB — verifies that Arrow types are mapped back to @@ -80,7 +81,7 @@ func TestArrowTypeToDuckDB_ListOfStruct(t *testing.T) { // // Expected return types (matching what conn.go formatValue consumes): // STRUCT → map[string]interface{} -// MAP → OrderedMapValue (keys preserve original Arrow types and insertion order) +// MAP → arrowmap.OrderedMapValue (keys preserve original Arrow types and insertion order) func TestExtractArrowValue_Struct(t *testing.T) { alloc := memory.NewGoAllocator() @@ -211,9 +212,9 @@ func TestExtractArrowValue_Map(t *testing.T) { // Row 0: single-entry map val := extractArrowValue(rec.Column(0), 0) - m, ok := val.(OrderedMapValue) + m, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("extractArrowValue(MAP) row 0 returned %T, want OrderedMapValue", val) + t.Fatalf("extractArrowValue(MAP) row 0 returned %T, want arrowmap.OrderedMapValue", val) } if len(m.Keys) != 1 { t.Fatalf("expected 1 map entry, got %d", len(m.Keys)) @@ -227,9 +228,9 @@ func TestExtractArrowValue_Map(t *testing.T) { // Row 1: two-entry map val = extractArrowValue(rec.Column(0), 1) - m, ok = val.(OrderedMapValue) + m, ok = val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("extractArrowValue(MAP) row 1 returned %T, want OrderedMapValue", val) + t.Fatalf("extractArrowValue(MAP) row 1 returned %T, want arrowmap.OrderedMapValue", val) } if len(m.Keys) != 2 { t.Fatalf("expected 2 map entries, got %d", len(m.Keys)) @@ -378,9 +379,9 @@ func TestExtractArrowValue_MapEmpty(t *testing.T) { defer rec.Release() val := extractArrowValue(rec.Column(0), 0) - m, ok := val.(OrderedMapValue) + m, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("extractArrowValue(empty MAP) returned %T, want OrderedMapValue", val) + t.Fatalf("extractArrowValue(empty MAP) returned %T, want arrowmap.OrderedMapValue", val) } if len(m.Keys) != 0 { t.Errorf("expected empty map, got %d entries", len(m.Keys)) @@ -410,9 +411,9 @@ func TestExtractArrowValue_MapIntegerKeys(t *testing.T) { defer rec.Release() val := extractArrowValue(rec.Column(0), 0) - m, ok := val.(OrderedMapValue) + m, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("extractArrowValue(MAP int keys) returned %T, want OrderedMapValue", val) + t.Fatalf("extractArrowValue(MAP int keys) returned %T, want arrowmap.OrderedMapValue", val) } if m.Keys[0] != int32(1) { t.Errorf("Keys[0] = %v (%T), want int32(1)", m.Keys[0], m.Keys[0]) @@ -450,9 +451,9 @@ func TestExtractArrowValue_MapWithNullValues(t *testing.T) { defer rec.Release() val := extractArrowValue(rec.Column(0), 0) - m, ok := val.(OrderedMapValue) + m, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("extractArrowValue(MAP with null values) returned %T, want OrderedMapValue", val) + t.Fatalf("extractArrowValue(MAP with null values) returned %T, want arrowmap.OrderedMapValue", val) } if m.Values[0] != int32(42) { t.Errorf("m[\"present\"] = %v, want int32(42)", m.Values[0]) @@ -529,9 +530,9 @@ func TestExtractArrowValue_MapMultipleRows(t *testing.T) { // Row 0 val := extractArrowValue(rec.Column(0), 0) - m0, ok := val.(OrderedMapValue) + m0, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("row 0: returned %T, want OrderedMapValue", val) + t.Fatalf("row 0: returned %T, want arrowmap.OrderedMapValue", val) } if len(m0.Keys) != 1 || m0.Values[0] != int32(1) { t.Errorf("row 0: got %v, want {a:1}", m0) @@ -539,9 +540,9 @@ func TestExtractArrowValue_MapMultipleRows(t *testing.T) { // Row 1 (empty) val = extractArrowValue(rec.Column(0), 1) - m1, ok := val.(OrderedMapValue) + m1, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("row 1: returned %T, want OrderedMapValue", val) + t.Fatalf("row 1: returned %T, want arrowmap.OrderedMapValue", val) } if len(m1.Keys) != 0 { t.Errorf("row 1: expected empty map, got %v", m1) @@ -549,9 +550,9 @@ func TestExtractArrowValue_MapMultipleRows(t *testing.T) { // Row 2 val = extractArrowValue(rec.Column(0), 2) - m2, ok := val.(OrderedMapValue) + m2, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("row 2: returned %T, want OrderedMapValue", val) + t.Fatalf("row 2: returned %T, want arrowmap.OrderedMapValue", val) } if len(m2.Keys) != 3 || m2.Values[2] != int32(30) { t.Errorf("row 2: got %v, want {x:10,y:20,z:30}", m2) @@ -631,16 +632,16 @@ func TestExtractArrowValue_ListOfMap(t *testing.T) { if len(elems) != 2 { t.Fatalf("expected 2 elements, got %d", len(elems)) } - m0, ok := elems[0].(OrderedMapValue) + m0, ok := elems[0].(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("element 0 returned %T, want OrderedMapValue", elems[0]) + t.Fatalf("element 0 returned %T, want arrowmap.OrderedMapValue", elems[0]) } if m0.Values[0] != int32(1) { t.Errorf("elem[0][\"a\"] = %v, want int32(1)", m0.Values[0]) } - m1, ok := elems[1].(OrderedMapValue) + m1, ok := elems[1].(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("element 1 returned %T, want OrderedMapValue", elems[1]) + t.Fatalf("element 1 returned %T, want arrowmap.OrderedMapValue", elems[1]) } if len(m1.Keys) != 2 { t.Errorf("elem[1] has %d entries, want 2", len(m1.Keys)) @@ -671,13 +672,13 @@ func TestExtractArrowValue_MapOfMapValues(t *testing.T) { defer rec.Release() val := extractArrowValue(rec.Column(0), 0) - m, ok := val.(OrderedMapValue) + m, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("extractArrowValue(MAP of MAP) returned %T, want OrderedMapValue", val) + t.Fatalf("extractArrowValue(MAP of MAP) returned %T, want arrowmap.OrderedMapValue", val) } - inner_val, ok := m.Values[0].(OrderedMapValue) + inner_val, ok := m.Values[0].(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("inner map returned %T, want OrderedMapValue", m.Values[0]) + t.Fatalf("inner map returned %T, want arrowmap.OrderedMapValue", m.Values[0]) } if inner_val.Values[0] != int32(99) { t.Errorf("inner[\"inner_key\"] = %v, want int32(99)", inner_val.Values[0]) @@ -708,9 +709,9 @@ func TestExtractArrowValue_MapOfListValues(t *testing.T) { defer rec.Release() val := extractArrowValue(rec.Column(0), 0) - m, ok := val.(OrderedMapValue) + m, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("extractArrowValue(MAP of LIST) returned %T, want OrderedMapValue", val) + t.Fatalf("extractArrowValue(MAP of LIST) returned %T, want arrowmap.OrderedMapValue", val) } nums, ok := m.Values[0].([]any) if !ok { @@ -754,9 +755,9 @@ func TestExtractArrowValue_StructContainingMap(t *testing.T) { if m["id"] != int32(1) { t.Errorf("id = %v, want int32(1)", m["id"]) } - meta, ok := m["meta"].(OrderedMapValue) + meta, ok := m["meta"].(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("meta field returned %T, want OrderedMapValue", m["meta"]) + t.Fatalf("meta field returned %T, want arrowmap.OrderedMapValue", m["meta"]) } if meta.Values[0] != "red" { t.Errorf("meta[\"color\"] = %v, want \"red\"", meta.Values[0]) @@ -967,9 +968,9 @@ func TestExtractArrowValue_MapMixedNullNonNullRows(t *testing.T) { defer rec.Release() val := extractArrowValue(rec.Column(0), 0) - m, ok := val.(OrderedMapValue) + m, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("row 0: returned %T, want OrderedMapValue", val) + t.Fatalf("row 0: returned %T, want arrowmap.OrderedMapValue", val) } if m.Values[0] != int32(1) { t.Errorf("row 0: a = %v, want 1", m.Values[0]) @@ -980,9 +981,9 @@ func TestExtractArrowValue_MapMixedNullNonNullRows(t *testing.T) { } val = extractArrowValue(rec.Column(0), 2) - m, ok = val.(OrderedMapValue) + m, ok = val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("row 2: returned %T, want OrderedMapValue", val) + t.Fatalf("row 2: returned %T, want arrowmap.OrderedMapValue", val) } if m.Values[0] != int32(2) { t.Errorf("row 2: b = %v, want 2", m.Values[0]) @@ -1170,9 +1171,9 @@ func TestExtractArrowValue_MapWithStructValues(t *testing.T) { defer rec.Release() val := extractArrowValue(rec.Column(0), 0) - m, ok := val.(OrderedMapValue) + m, ok := val.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("returned %T, want OrderedMapValue", val) + t.Fatalf("returned %T, want arrowmap.OrderedMapValue", val) } // MAP values that are STRUCTs come back as map[string]interface{} from extractArrowValue point, ok := m.Values[0].(map[string]interface{}) @@ -1362,13 +1363,13 @@ func TestExtractThenAppend_MapBasic(t *testing.T) { // Verify the rebuilt MAP is not null and has correct values col := dst.Column(0).(*array.Map) if col.IsNull(0) { - t.Fatal("rebuilt MAP is null — AppendValue did not recognize extractArrowValue's OrderedMapValue") + t.Fatal("rebuilt MAP is null — AppendValue did not recognize extractArrowValue's arrowmap.OrderedMapValue") } // Re-extract from rebuilt record to verify data rebuilt := extractArrowValue(col, 0) - rm, ok := rebuilt.(OrderedMapValue) + rm, ok := rebuilt.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("re-extracted value is %T, want OrderedMapValue", rebuilt) + t.Fatalf("re-extracted value is %T, want arrowmap.OrderedMapValue", rebuilt) } if rm.Keys[0] != "a" || rm.Keys[1] != "b" { t.Errorf("rebuilt Keys = %v, want [a, b]", rm.Keys) @@ -1411,9 +1412,9 @@ func TestExtractThenAppend_MapIntegerKeys(t *testing.T) { t.Fatal("rebuilt MAP(INT,VARCHAR) is null") } rebuilt := extractArrowValue(col, 0) - rm, ok := rebuilt.(OrderedMapValue) + rm, ok := rebuilt.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("re-extracted value is %T, want OrderedMapValue", rebuilt) + t.Fatalf("re-extracted value is %T, want arrowmap.OrderedMapValue", rebuilt) } if rm.Keys[0] != int32(1) { t.Errorf("rebuilt Keys[0] = %v (%T), want int32(1)", rm.Keys[0], rm.Keys[0]) @@ -1483,9 +1484,9 @@ func TestExtractThenAppend_MapEmpty(t *testing.T) { t.Fatal("rebuilt empty MAP should not be null") } rebuilt := extractArrowValue(col, 0) - rm, ok := rebuilt.(OrderedMapValue) + rm, ok := rebuilt.(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("re-extracted value is %T, want OrderedMapValue", rebuilt) + t.Fatalf("re-extracted value is %T, want arrowmap.OrderedMapValue", rebuilt) } if len(rm.Keys) != 0 { t.Errorf("rebuilt MAP has %d entries, want 0", len(rm.Keys)) @@ -1539,7 +1540,7 @@ func TestExtractThenAppend_MapMultipleRows(t *testing.T) { if col.IsNull(0) { t.Fatal("row 0: should not be null") } - r0 := extractArrowValue(col, 0).(OrderedMapValue) + r0 := extractArrowValue(col, 0).(arrowmap.OrderedMapValue) if r0.Values[0] != int32(10) { t.Errorf("row 0: x = %v, want 10", r0.Values[0]) } @@ -1553,7 +1554,7 @@ func TestExtractThenAppend_MapMultipleRows(t *testing.T) { if col.IsNull(2) { t.Fatal("row 2: should not be null") } - r2 := extractArrowValue(col, 2).(OrderedMapValue) + r2 := extractArrowValue(col, 2).(arrowmap.OrderedMapValue) if len(r2.Keys) != 0 { t.Errorf("row 2: expected empty, got %v", r2) } @@ -1562,7 +1563,7 @@ func TestExtractThenAppend_MapMultipleRows(t *testing.T) { if col.IsNull(3) { t.Fatal("row 3: should not be null") } - r3 := extractArrowValue(col, 3).(OrderedMapValue) + r3 := extractArrowValue(col, 3).(arrowmap.OrderedMapValue) if len(r3.Keys) != 3 || r3.Values[0] != int32(1) || r3.Values[1] != int32(2) || r3.Values[2] != int32(3) { t.Errorf("row 3: got %v, want {a:1,b:2,c:3}", r3) } @@ -1600,7 +1601,7 @@ func TestExtractThenAppend_MapWithNullValues(t *testing.T) { if col.IsNull(0) { t.Fatal("rebuilt MAP with null values should not itself be null") } - rebuilt := extractArrowValue(col, 0).(OrderedMapValue) + rebuilt := extractArrowValue(col, 0).(arrowmap.OrderedMapValue) if rebuilt.Values[0] != int32(42) { t.Errorf("present = %v, want 42", rebuilt.Values[0]) } @@ -1643,10 +1644,10 @@ func TestExtractThenAppend_MapOfMap(t *testing.T) { if col.IsNull(0) { t.Fatal("rebuilt nested MAP is null") } - rebuilt := extractArrowValue(col, 0).(OrderedMapValue) - innerMap, ok := rebuilt.Values[0].(OrderedMapValue) + rebuilt := extractArrowValue(col, 0).(arrowmap.OrderedMapValue) + innerMap, ok := rebuilt.Values[0].(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("inner value is %T, want OrderedMapValue", rebuilt.Values[0]) + t.Fatalf("inner value is %T, want arrowmap.OrderedMapValue", rebuilt.Values[0]) } if innerMap.Values[0] != int32(99) { t.Errorf("inner[\"inner\"] = %v, want 99", innerMap.Values[0]) @@ -1693,9 +1694,9 @@ func TestExtractThenAppend_StructContainingMap(t *testing.T) { if rebuilt["id"] != int32(1) { t.Errorf("id = %v, want 1", rebuilt["id"]) } - meta, ok := rebuilt["meta"].(OrderedMapValue) + meta, ok := rebuilt["meta"].(arrowmap.OrderedMapValue) if !ok { - t.Fatalf("meta is %T, want OrderedMapValue", rebuilt["meta"]) + t.Fatalf("meta is %T, want arrowmap.OrderedMapValue", rebuilt["meta"]) } if meta.Values[0] != "red" { t.Errorf("meta[color] = %v, want red", meta.Values[0]) @@ -1739,7 +1740,7 @@ func TestExtractThenAppend_MapWithStructValues(t *testing.T) { if col.IsNull(0) { t.Fatal("rebuilt MAP(VARCHAR,STRUCT) is null") } - rebuilt := extractArrowValue(col, 0).(OrderedMapValue) + rebuilt := extractArrowValue(col, 0).(arrowmap.OrderedMapValue) point, ok := rebuilt.Values[0].(map[string]interface{}) if !ok { t.Fatalf("point value is %T, want map[string]interface{}", rebuilt.Values[0]) @@ -1788,11 +1789,11 @@ func TestExtractThenAppend_ListOfMap(t *testing.T) { if len(rebuilt) != 2 { t.Fatalf("expected 2 elements, got %d", len(rebuilt)) } - e0 := rebuilt[0].(OrderedMapValue) + e0 := rebuilt[0].(arrowmap.OrderedMapValue) if e0.Values[0] != int32(1) { t.Errorf("elem[0] = %v, want {a:1}", e0) } - e1 := rebuilt[1].(OrderedMapValue) + e1 := rebuilt[1].(arrowmap.OrderedMapValue) if e1.Values[0] != int32(2) { t.Errorf("elem[1] = %v, want {b:2}", e1) } @@ -1836,7 +1837,7 @@ func appendValue(builder array.Builder, val interface{}) { } case *array.MapBuilder: switch v := val.(type) { - case OrderedMapValue: + case arrowmap.OrderedMapValue: b.Append(true) for i, k := range v.Keys { appendValue(b.KeyBuilder(), k) @@ -1855,49 +1856,3 @@ func appendValue(builder array.Builder, val interface{}) { } } -// --- formatOrderedMapValue tests --- - -func TestFormatOrderedMapValue_Basic(t *testing.T) { - m := OrderedMapValue{Keys: []any{"a"}, Values: []any{int32(1)}} - got := formatOrderedMapValue(m) - if got != "{a=1}" { - t.Errorf("formatOrderedMapValue = %q, want %q", got, "{a=1}") - } -} - -func TestFormatOrderedMapValue_IntegerKeys(t *testing.T) { - m := OrderedMapValue{Keys: []any{int32(1)}, Values: []any{"one"}} - got := formatOrderedMapValue(m) - if got != "{1=one}" { - t.Errorf("formatOrderedMapValue = %q, want %q", got, "{1=one}") - } -} - -func TestFormatOrderedMapValue_Empty(t *testing.T) { - m := OrderedMapValue{Keys: []any{}, Values: []any{}} - got := formatOrderedMapValue(m) - if got != "{}" { - t.Errorf("formatOrderedMapValue = %q, want %q", got, "{}") - } -} - -func TestFormatOrderedMapValue_NilValue(t *testing.T) { - m := OrderedMapValue{Keys: []any{"k"}, Values: []any{nil}} - got := formatOrderedMapValue(m) - if got != "{k=}" { - t.Errorf("formatOrderedMapValue = %q, want %q", got, "{k=}") - } -} - -func TestFormatOrderedMapValue_PreservesOrder(t *testing.T) { - // Verifies that key order in output matches Keys slice. - m := OrderedMapValue{ - Keys: []any{"z", "a", "m"}, - Values: []any{int32(1), int32(2), int32(3)}, - } - got := formatOrderedMapValue(m) - expected := "{z=1, a=2, m=3}" - if got != expected { - t.Errorf("formatOrderedMapValue = %q, want %q", got, expected) - } -} diff --git a/server/flight_executor_test.go b/server/flightclient/flight_executor_test.go similarity index 97% rename from server/flight_executor_test.go rename to server/flightclient/flight_executor_test.go index a60ddeae..fc2e1705 100644 --- a/server/flight_executor_test.go +++ b/server/flightclient/flight_executor_test.go @@ -1,4 +1,4 @@ -package server +package flightclient import ( "context" diff --git a/server/flightsqlingress/ingress.go b/server/flightsqlingress/ingress.go index 6d0d84ce..a3169f25 100644 --- a/server/flightsqlingress/ingress.go +++ b/server/flightsqlingress/ingress.go @@ -24,6 +24,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" "github.com/posthog/duckgres/duckdbservice/arrowmap" "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -64,7 +65,7 @@ type Config struct { } type SessionProvider interface { - CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) + CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) DestroySession(int32) } @@ -108,7 +109,7 @@ type sessionMetadataProvider interface { } type sessionReconnector interface { - ReconnectSession(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) + ReconnectSession(ctx context.Context, record DurableSessionRecord) (int32, *flightclient.FlightExecutor, error) } type durableSessionStoreProvider interface { @@ -225,8 +226,8 @@ func NewFlightIngressFromListener(baseListener net.Listener, tlsConfig *tls.Conf handler.rateLimiter = opts.RateLimiter grpcOpts := []grpc.ServerOption{ - grpc.MaxRecvMsgSize(server.MaxGRPCMessageSize), - grpc.MaxSendMsgSize(server.MaxGRPCMessageSize), + grpc.MaxRecvMsgSize(flightclient.MaxGRPCMessageSize), + grpc.MaxSendMsgSize(flightclient.MaxGRPCMessageSize), } srv := flight.NewServerWithMiddleware(nil, grpcOpts...) @@ -1118,7 +1119,7 @@ type flightClientSession struct { pid int32 token string username string - executor *server.FlightExecutor + executor *flightclient.FlightExecutor queryFn func(context.Context, string, ...any) (server.RowSet, error) execFn func(context.Context, string, ...any) (server.ExecResult, error) @@ -1141,7 +1142,7 @@ type flightClientSession struct { afterTxnControlExecHook func(string) } -func newFlightClientSession(pid int32, username string, executor *server.FlightExecutor) *flightClientSession { +func newFlightClientSession(pid int32, username string, executor *flightclient.FlightExecutor) *flightClientSession { s := &flightClientSession{ pid: pid, username: username, @@ -1408,7 +1409,7 @@ type flightAuthSessionStore struct { workerQueueTimeout time.Duration hooks Hooks - createSessionFn func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) + createSessionFn func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) destroySessionFn func(int32) metadataProvider sessionMetadataProvider reconnector sessionReconnector @@ -1438,7 +1439,7 @@ func (r *lockedRowSet) Close() error { } func newFlightAuthSessionStore(provider SessionProvider, idleTTL, reapInterval, handleIdleTTL, tokenTTL, workerQueueTimeout time.Duration, opts Options) *flightAuthSessionStore { - createFn := func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createFn := func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 0, nil, fmt.Errorf("session provider is not configured") } destroyFn := func(int32) {} diff --git a/server/flightsqlingress/ingress_test.go b/server/flightsqlingress/ingress_test.go index 2f10e128..4db0d5da 100644 --- a/server/flightsqlingress/ingress_test.go +++ b/server/flightsqlingress/ingress_test.go @@ -15,6 +15,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" "github.com/prometheus/client_golang/prometheus" dto "github.com/prometheus/client_model/go" "google.golang.org/grpc" @@ -163,14 +164,14 @@ func (s *captureDurableSessionStore) CloseSession(sessionToken string, closedAt } type testDurableSessionProvider struct { - createSessionFn func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) + createSessionFn func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) destroySessionFn func(int32) metadataFn func(pid int32, username string) (DurableSessionMetadata, error) - reconnectSessionFn func(context.Context, DurableSessionRecord) (int32, *server.FlightExecutor, error) + reconnectSessionFn func(context.Context, DurableSessionRecord) (int32, *flightclient.FlightExecutor, error) durableStore DurableSessionStore } -func (p *testDurableSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) { +func (p *testDurableSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) { return p.createSessionFn(ctx, username, pid, memoryLimit, threads) } @@ -187,7 +188,7 @@ func (p *testDurableSessionProvider) DurableSessionMetadata(pid int32, username return p.metadataFn(pid, username) } -func (p *testDurableSessionProvider) ReconnectSession(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { +func (p *testDurableSessionProvider) ReconnectSession(ctx context.Context, record DurableSessionRecord) (int32, *flightclient.FlightExecutor, error) { return p.reconnectSessionFn(ctx, record) } @@ -324,7 +325,7 @@ func testFlightHandlerWithStoreAndRateLimiter(t *testing.T, users map[string]str sessions: make(map[string]*flightClientSession), stopCh: make(chan struct{}), doneCh: make(chan struct{}), - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 1234, nil, nil }, destroySessionFn: func(int32) {}, @@ -579,7 +580,7 @@ func TestSessionFromContextAcceptsServerIssuedSessionTokenWithoutBasicAuth(t *te func TestSessionFromContextRejectsUnknownSessionTokenEvenWithBasicAuth(t *testing.T) { store := &flightAuthSessionStore{ - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 9876, nil, nil }, destroySessionFn: func(int32) {}, @@ -682,7 +683,7 @@ func TestSessionFromContextTokenPathDoesNotClearRateLimiterFailures(t *testing.T func TestSessionFromContextWithoutTokenCreatesDistinctSessions(t *testing.T) { var createCalls atomic.Int32 store := &flightAuthSessionStore{ - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return createCalls.Add(1), nil, nil }, destroySessionFn: func(int32) {}, @@ -834,7 +835,7 @@ func TestFlightSessionTokenLifecycleIssueValidateRevokeExpiryMatrix(t *testing.T idleTTL: time.Minute, handleIdleTTL: time.Minute, tokenTTL: time.Hour, - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 1234, nil, nil }, destroySessionFn: func(pid int32) { @@ -947,7 +948,7 @@ func TestFlightAuthSessionStorePersistsDurableSessionRecordOnCreate(t *testing.T durable := &captureDurableSessionStore{} provider := &testDurableSessionProvider{ durableStore: durable, - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 4321, nil, nil }, metadataFn: func(pid int32, username string) (DurableSessionMetadata, error) { @@ -1026,10 +1027,10 @@ func TestFlightAuthSessionStoreReconnectsDurableSessionByToken(t *testing.T) { CPInstanceID: "cp-old:boot-a", }, nil }, - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 0, nil, fmt.Errorf("unexpected create path") }, - reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { + reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *flightclient.FlightExecutor, error) { reconnected = record return 9876, nil, nil }, @@ -1081,10 +1082,10 @@ func TestFlightAuthSessionStoreRejectsClosedDurableSessionToken(t *testing.T) { reconnectCalls := 0 provider := &testDurableSessionProvider{ durableStore: durable, - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 0, nil, fmt.Errorf("unexpected create path") }, - reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { + reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *flightclient.FlightExecutor, error) { reconnectCalls++ return 9876, nil, nil }, @@ -1129,10 +1130,10 @@ func TestFlightAuthSessionStoreReconnectRefreshesDurableSessionMetadata(t *testi CPInstanceID: "cp-new:boot-b", }, nil }, - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 0, nil, fmt.Errorf("unexpected create path") }, - reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { + reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *flightclient.FlightExecutor, error) { return 9876, nil, nil }, } @@ -1203,7 +1204,7 @@ func TestFlightAuthSessionStoreReconnectFailureUpdatesDurableSessionState(t *tes reconnectCalls := 0 provider := &testDurableSessionProvider{ durableStore: durable, - reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *server.FlightExecutor, error) { + reconnectSessionFn: func(ctx context.Context, record DurableSessionRecord) (int32, *flightclient.FlightExecutor, error) { reconnectCalls++ return 0, nil, tt.reconnectErr }, @@ -1237,7 +1238,7 @@ func TestFlightAuthSessionStoreReconnectFailureUpdatesDurableSessionState(t *tes func TestFlightAuthSessionStoreRejectsNewSessionsWhileDraining(t *testing.T) { provider := &testDurableSessionProvider{ - createSessionFn: func(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) { return 321, nil, nil }, } @@ -1265,7 +1266,7 @@ func TestFlightAuthSessionStoreRejectsNewSessionsWhileDraining(t *testing.T) { func TestFlightAuthSessionStoreWaitForZeroSessions(t *testing.T) { provider := &testDurableSessionProvider{ - createSessionFn: func(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) { return 654, nil, nil }, } @@ -1357,7 +1358,7 @@ func TestCloseSessionRevokesTokenAndDestroysWorker(t *testing.T) { func TestCloseSessionMissingTokenDoesNotBootstrap(t *testing.T) { var createCalls atomic.Int32 store := &flightAuthSessionStore{ - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { createCalls.Add(1) return 1234, nil, nil }, @@ -1389,7 +1390,7 @@ func TestCloseSessionTokenOnlyRevokesTokenAndDoesNotBootstrap(t *testing.T) { var createCalls atomic.Int32 var destroyed []int32 store := &flightAuthSessionStore{ - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { createCalls.Add(1) return 9876, nil, nil }, @@ -1631,7 +1632,7 @@ func TestFlightAuthSessionStoreReapHookReceivesTrigger(t *testing.T) { reapedCount = count }, }, - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 0, nil, fmt.Errorf("not used") }, destroySessionFn: func(int32) {}, @@ -1666,7 +1667,7 @@ func TestFlightAuthSessionStoreReapKeepsSessionWithFreshHandle(t *testing.T) { }, stopCh: make(chan struct{}), doneCh: make(chan struct{}), - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 0, nil, fmt.Errorf("not used") }, destroySessionFn: func(pid int32) { @@ -1703,7 +1704,7 @@ func TestFlightAuthSessionStoreReapStaleHandleAllowsSessionReap(t *testing.T) { }, stopCh: make(chan struct{}), doneCh: make(chan struct{}), - createSessionFn: func(context.Context, string, int32, string, int) (int32, *server.FlightExecutor, error) { + createSessionFn: func(context.Context, string, int32, string, int) (int32, *flightclient.FlightExecutor, error) { return 0, nil, fmt.Errorf("not used") }, destroySessionFn: func(pid int32) { diff --git a/server/format_ordered_map_test.go b/server/format_ordered_map_test.go new file mode 100644 index 00000000..9e9d1f2f --- /dev/null +++ b/server/format_ordered_map_test.go @@ -0,0 +1,56 @@ +package server + +import ( + "testing" + + "github.com/posthog/duckgres/duckdbservice/arrowmap" +) + +// formatOrderedMapValue tests. Live here in package server because the +// function under test calls formatValue (defined in conn.go) which switches +// on the duckdb-go driver's value types — it can't move to a duckdb-free +// subpackage. + +func TestFormatOrderedMapValue_Basic(t *testing.T) { + m := arrowmap.OrderedMapValue{Keys: []any{"a"}, Values: []any{int32(1)}} + got := formatOrderedMapValue(m) + if got != "{a=1}" { + t.Errorf("formatOrderedMapValue = %q, want %q", got, "{a=1}") + } +} + +func TestFormatOrderedMapValue_IntegerKeys(t *testing.T) { + m := arrowmap.OrderedMapValue{Keys: []any{int32(1)}, Values: []any{"one"}} + got := formatOrderedMapValue(m) + if got != "{1=one}" { + t.Errorf("formatOrderedMapValue = %q, want %q", got, "{1=one}") + } +} + +func TestFormatOrderedMapValue_Empty(t *testing.T) { + m := arrowmap.OrderedMapValue{Keys: []any{}, Values: []any{}} + got := formatOrderedMapValue(m) + if got != "{}" { + t.Errorf("formatOrderedMapValue = %q, want %q", got, "{}") + } +} + +func TestFormatOrderedMapValue_NilValue(t *testing.T) { + m := arrowmap.OrderedMapValue{Keys: []any{"k"}, Values: []any{nil}} + got := formatOrderedMapValue(m) + if got != "{k=}" { + t.Errorf("formatOrderedMapValue = %q, want %q", got, "{k=}") + } +} + +func TestFormatOrderedMapValue_PreservesOrder(t *testing.T) { + m := arrowmap.OrderedMapValue{ + Keys: []any{"z", "a", "m"}, + Values: []any{int32(1), int32(2), int32(3)}, + } + got := formatOrderedMapValue(m) + expected := "{z=1, a=2, m=3}" + if got != expected { + t.Errorf("formatOrderedMapValue = %q, want %q", got, expected) + } +} diff --git a/server/types.go b/server/types.go index 9d158a6f..f59dd238 100644 --- a/server/types.go +++ b/server/types.go @@ -12,6 +12,7 @@ import ( "time" duckdb "github.com/duckdb/duckdb-go/v2" + "github.com/posthog/duckgres/duckdbservice/arrowmap" ) // PostgreSQL type OIDs @@ -616,8 +617,8 @@ func encodeInterval(v interface{}) []byte { binary.BigEndian.PutUint64(buf[0:8], uint64(val.Micros)) binary.BigEndian.PutUint32(buf[8:12], uint32(val.Days)) binary.BigEndian.PutUint32(buf[12:16], uint32(val.Months)) - case intervalValue: - // Arrow Flight returns intervalValue instead of duckdb.Interval + case arrowmap.IntervalValue: + // Arrow Flight returns arrowmap.IntervalValue instead of duckdb.Interval binary.BigEndian.PutUint64(buf[0:8], uint64(val.Micros)) binary.BigEndian.PutUint32(buf[8:12], uint32(val.Days)) binary.BigEndian.PutUint32(buf[12:16], uint32(val.Months)) diff --git a/tests/controlplane/flight_ingress_test.go b/tests/controlplane/flight_ingress_test.go index 8db8034d..f5ec7990 100644 --- a/tests/controlplane/flight_ingress_test.go +++ b/tests/controlplane/flight_ingress_test.go @@ -17,7 +17,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apache/arrow-go/v18/arrow/flight/flightsql/schema_ref" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" @@ -35,8 +35,8 @@ func newFlightClient(t *testing.T, port int) *flightsql.Client { client, err := flightsql.NewClient(addr, nil, nil, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)), grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(server.MaxGRPCMessageSize), - grpc.MaxCallSendMsgSize(server.MaxGRPCMessageSize), + grpc.MaxCallRecvMsgSize(flightclient.MaxGRPCMessageSize), + grpc.MaxCallSendMsgSize(flightclient.MaxGRPCMessageSize), ), ) if err != nil { diff --git a/tests/perf/drivers/flight/driver.go b/tests/perf/drivers/flight/driver.go index 80caf680..0e535d1e 100644 --- a/tests/perf/drivers/flight/driver.go +++ b/tests/perf/drivers/flight/driver.go @@ -8,7 +8,7 @@ import ( "time" "github.com/apache/arrow-go/v18/arrow/flight/flightsql" - "github.com/posthog/duckgres/server" + "github.com/posthog/duckgres/server/flightclient" "github.com/posthog/duckgres/tests/perf/core" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -36,8 +36,8 @@ func NewFromAddress(addr, username, password string, insecureSkipVerify bool) (* nil, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)), grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(server.MaxGRPCMessageSize), - grpc.MaxCallSendMsgSize(server.MaxGRPCMessageSize), + grpc.MaxCallRecvMsgSize(flightclient.MaxGRPCMessageSize), + grpc.MaxCallSendMsgSize(flightclient.MaxGRPCMessageSize), ), ) if err != nil { @@ -49,7 +49,7 @@ func NewFromAddress(addr, username, password string, insecureSkipVerify bool) (* _ = client.Close() return nil, err } - exec := server.NewFlightExecutorFromClient(client, token) + exec := flightclient.NewFlightExecutorFromClient(client, token) return &Driver{ exec: &flightExecutor{ client: client, @@ -87,7 +87,7 @@ func (d *Driver) Close() error { type flightExecutor struct { client *flightsql.Client - exec *server.FlightExecutor + exec *flightclient.FlightExecutor } func (e *flightExecutor) Execute(ctx context.Context, query string, args []any) (int64, error) {