Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions controlplane/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/posthog/duckgres/controlplane/configstore"
"github.com/posthog/duckgres/server"
"github.com/posthog/duckgres/server/ducklake"
"github.com/posthog/duckgres/server/flightclient"
"github.com/posthog/duckgres/server/flightsqlingress"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
Expand Down Expand Up @@ -631,8 +632,8 @@ func createSessionWithRegisteredCancel(
srv *server.Server,
timeout time.Duration,
key server.BackendKey,
createFn func(context.Context) (int32, *server.FlightExecutor, error),
) (int32, *server.FlightExecutor, error) {
createFn func(context.Context) (int32, *flightclient.FlightExecutor, error),
) (int32, *flightclient.FlightExecutor, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

Expand Down Expand Up @@ -983,7 +984,7 @@ func (cp *ControlPlane) handleConnection(conn net.Conn) {
cp.srv,
cp.cfg.WorkerQueueTimeout,
server.BackendKey{Pid: pid, SecretKey: secretKey},
func(ctx context.Context) (int32, *server.FlightExecutor, error) {
func(ctx context.Context) (int32, *flightclient.FlightExecutor, error) {
return sessions.CreateSession(ctx, username, pid, memLimit, threads)
},
)
Expand Down
3 changes: 2 additions & 1 deletion controlplane/control_cancel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/posthog/duckgres/server"
"github.com/posthog/duckgres/server/flightclient"
)

func TestCreateSessionWithRegisteredCancel_CancelQueryCancelsWait(t *testing.T) {
Expand All @@ -24,7 +25,7 @@ func TestCreateSessionWithRegisteredCancel_CancelQueryCancelsWait(t *testing.T)
srv,
200*time.Millisecond,
key,
func(ctx context.Context) (int32, *server.FlightExecutor, error) {
func(ctx context.Context) (int32, *flightclient.FlightExecutor, error) {
close(started)
<-ctx.Done()
return 0, nil, ctx.Err()
Expand Down
7 changes: 4 additions & 3 deletions controlplane/flight_ingress.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/posthog/duckgres/controlplane/configstore"
"github.com/posthog/duckgres/server"
"github.com/posthog/duckgres/server/flightclient"
"github.com/posthog/duckgres/server/flightsqlingress"
)

Expand Down Expand Up @@ -45,7 +46,7 @@ type flightSessionProvider struct {
sm *SessionManager
}

func (p *flightSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) {
func (p *flightSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) {
workerPID, executor, err := p.sm.CreateSession(ctx, username, pid, memoryLimit, threads)
if err != nil {
return 0, nil, err
Expand Down Expand Up @@ -74,7 +75,7 @@ type orgRoutedSessionProvider struct {
userOrg map[string]string // username → orgID (populated during auth)
}

func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) {
func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) {
p.mu.RLock()
orgID := p.userOrg[username]
p.mu.RUnlock()
Expand Down Expand Up @@ -141,7 +142,7 @@ func (p *orgRoutedSessionProvider) DurableSessionMetadata(pid int32, username st
}, nil
}

func (p *orgRoutedSessionProvider) ReconnectSession(ctx context.Context, record flightsqlingress.DurableSessionRecord) (int32, *server.FlightExecutor, error) {
func (p *orgRoutedSessionProvider) ReconnectSession(ctx context.Context, record flightsqlingress.DurableSessionRecord) (int32, *flightclient.FlightExecutor, error) {
_, sessions, _, ok := p.orgRouter.StackForOrg(record.OrgID)
if !ok {
return 0, nil, fmt.Errorf("no org stack for org %q", record.OrgID)
Expand Down
9 changes: 5 additions & 4 deletions controlplane/k8s_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/flight/flightsql"
"github.com/posthog/duckgres/controlplane/configstore"
"github.com/posthog/duckgres/server"
"github.com/posthog/duckgres/server/flightclient"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -932,8 +933,8 @@ func waitForWorkerTCPWithMetadata(addr, bearerToken string, serverCertPEM []byte
var dialOpts []grpc.DialOption
dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
dialOpts = append(dialOpts, grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(server.MaxGRPCMessageSize),
grpc.MaxCallSendMsgSize(server.MaxGRPCMessageSize),
grpc.MaxCallRecvMsgSize(flightclient.MaxGRPCMessageSize),
grpc.MaxCallSendMsgSize(flightclient.MaxGRPCMessageSize),
))
dialOpts = append(dialOpts, server.OTELGRPCClientHandler())
if bearerToken != "" {
Expand Down Expand Up @@ -1640,8 +1641,8 @@ func (p *K8sWorkerPool) connectWorkerDirect(ctx context.Context, podName, podIP,
var dialOpts []grpc.DialOption
dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
dialOpts = append(dialOpts, grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(server.MaxGRPCMessageSize),
grpc.MaxCallSendMsgSize(server.MaxGRPCMessageSize),
grpc.MaxCallRecvMsgSize(flightclient.MaxGRPCMessageSize),
grpc.MaxCallSendMsgSize(flightclient.MaxGRPCMessageSize),
))
dialOpts = append(dialOpts, server.OTELGRPCClientHandler())
if bearerToken != "" {
Expand Down
11 changes: 6 additions & 5 deletions controlplane/session_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/posthog/duckgres/server"
"github.com/posthog/duckgres/server/flightclient"
)

// SessionProgress holds cached query progress from a worker health check.
Expand All @@ -26,7 +27,7 @@ type ManagedSession struct {
WorkerID int
Protocol string // "postgres" or "flight"
SessionToken string
Executor *server.FlightExecutor
Executor *flightclient.FlightExecutor
connCloser io.Closer // TCP connection, closed on worker crash to unblock the message loop

// Cached query progress from worker health checks.
Expand Down Expand Up @@ -68,7 +69,7 @@ func (sm *SessionManager) ReservePID() int32 {
// CreateSession acquires a worker (reusing an idle one or spawning a new one),
// creates a session on it, and rebalances memory/thread limits across all active sessions.
// If pid is 0, a new one is generated.
func (sm *SessionManager) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *server.FlightExecutor, error) {
func (sm *SessionManager) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) {
memoryLimit, threads = sm.resolveSessionLimits(memoryLimit, threads)

// Acquire a worker: reuses idle pre-warmed workers or spawns a new one.
Expand Down Expand Up @@ -102,7 +103,7 @@ func (sm *SessionManager) resolveSessionLimits(memoryLimit string, threads int)
return memoryLimit, threads
}

func (sm *SessionManager) ReconnectFlightSession(ctx context.Context, username string, workerID int, ownerEpoch int64) (int32, *server.FlightExecutor, error) {
func (sm *SessionManager) ReconnectFlightSession(ctx context.Context, username string, workerID int, ownerEpoch int64) (int32, *flightclient.FlightExecutor, error) {
reconnector, ok := sm.pool.(flightReconnectPool)
if !ok {
return 0, nil, fmt.Errorf("worker pool does not support flight reconnect")
Expand All @@ -114,7 +115,7 @@ func (sm *SessionManager) ReconnectFlightSession(ctx context.Context, username s
return sm.createSessionOnWorker(ctx, username, 0, "", 0, worker, "flight", false)
}

func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username string, pid int32, memoryLimit string, threads int, worker *ManagedWorker, protocol string, retireOnFailure bool) (int32, *server.FlightExecutor, error) {
func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username string, pid int32, memoryLimit string, threads int, worker *ManagedWorker, protocol string, retireOnFailure bool) (int32, *flightclient.FlightExecutor, error) {
createStart := time.Now()
sessionToken, err := worker.CreateSession(ctx, username, memoryLimit, threads)
if err != nil {
Expand All @@ -124,7 +125,7 @@ func (sm *SessionManager) createSessionOnWorker(ctx context.Context, username st
return 0, nil, fmt.Errorf("create session on worker %d: %w", worker.ID, err)
}

executor := server.NewFlightExecutorFromClient(worker.client, sessionToken)
executor := flightclient.NewFlightExecutorFromClient(worker.client, sessionToken)
executor.SetControlMetadata(worker.ID, worker.OwnerCPInstanceID(), worker.OwnerEpoch())

if pid == 0 {
Expand Down
12 changes: 6 additions & 6 deletions controlplane/session_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"sync/atomic"
"testing"

"github.com/posthog/duckgres/server"
"github.com/posthog/duckgres/server/flightclient"
)

// mockCloser tracks whether Close was called.
Expand All @@ -27,7 +27,7 @@ func TestOnWorkerCrash_MarksExecutorsDead(t *testing.T) {
}
sm := NewSessionManager(pool, nil)

executor := &server.FlightExecutor{}
executor := &flightclient.FlightExecutor{}
pid := int32(1001)

sm.mu.Lock()
Expand Down Expand Up @@ -67,7 +67,7 @@ func TestOnWorkerCrash_ClosesConnections(t *testing.T) {
sm := NewSessionManager(pool, nil)

conn := &mockCloser{}
executor := &server.FlightExecutor{}
executor := &flightclient.FlightExecutor{}
pid := int32(1002)

sm.mu.Lock()
Expand All @@ -93,8 +93,8 @@ func TestOnWorkerCrash_MultipleSessions(t *testing.T) {
}
sm := NewSessionManager(pool, nil)

exec1 := &server.FlightExecutor{}
exec2 := &server.FlightExecutor{}
exec1 := &flightclient.FlightExecutor{}
exec2 := &flightclient.FlightExecutor{}
conn1 := &mockCloser{}
conn2 := &mockCloser{}

Expand Down Expand Up @@ -228,7 +228,7 @@ func TestDestroySessionAfterOnWorkerCrash(t *testing.T) {
sm := NewSessionManager(pool, nil)

conn := &mockCloser{}
executor := &server.FlightExecutor{}
executor := &flightclient.FlightExecutor{}
pid := int32(1010)

sm.mu.Lock()
Expand Down
5 changes: 3 additions & 2 deletions controlplane/worker_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/flight"
"github.com/apache/arrow-go/v18/arrow/flight/flightsql"
"github.com/posthog/duckgres/server"
"github.com/posthog/duckgres/server/flightclient"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
Expand Down Expand Up @@ -454,8 +455,8 @@ func waitForWorker(socketPath, bearerToken string, timeout time.Duration) (*flig
var dialOpts []grpc.DialOption
dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
dialOpts = append(dialOpts, grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(server.MaxGRPCMessageSize),
grpc.MaxCallSendMsgSize(server.MaxGRPCMessageSize),
grpc.MaxCallRecvMsgSize(flightclient.MaxGRPCMessageSize),
grpc.MaxCallSendMsgSize(flightclient.MaxGRPCMessageSize),
))

if bearerToken != "" {
Expand Down
16 changes: 13 additions & 3 deletions duckdbservice/arrowmap/arrowmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,14 +305,24 @@ func QuoteIdent(ident string) string {
// and preserves the source MAP ordering.
//
// Lives in arrowmap so AppendValue can switch on it without depending on
// the server package (which transitively links libduckdb). The flight
// executor in the server package re-exports it as server.OrderedMapValue
// via a type alias for backward compatibility.
// the server package, and the flight client + result formatters in
// server/ can both reference it without creating an import cycle.
type OrderedMapValue struct {
Keys []any
Values []any
}

// IntervalValue is the duckdb-free representation of an Arrow
// MonthDayNanoInterval as decoded by the Flight client. It stores the
// component fields directly (Months/Days/Micros) so result formatters
// in the server package can switch on the type without importing the
// flight subpackage (which would create a cycle).
type IntervalValue struct {
Months int32
Days int32
Micros int64
}

// Appender is a hook that handles append for value types arrowmap doesn't
// know about (typically driver-specific types like duckdb.Interval). It
// reports whether it handled the value; arrowmap.AppendValue falls back to
Expand Down
5 changes: 3 additions & 2 deletions duckdbservice/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/flight/flightsql"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/posthog/duckgres/server"
"github.com/posthog/duckgres/server/flightclient"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
Expand Down Expand Up @@ -289,8 +290,8 @@ func (svc *DuckDBService) Serve(listener net.Listener) error {

var opts []grpc.ServerOption
opts = append(opts,
grpc.MaxRecvMsgSize(server.MaxGRPCMessageSize),
grpc.MaxSendMsgSize(server.MaxGRPCMessageSize),
grpc.MaxRecvMsgSize(flightclient.MaxGRPCMessageSize),
grpc.MaxSendMsgSize(flightclient.MaxGRPCMessageSize),
)
if svc.cfg.BearerToken != "" {
opts = append(opts,
Expand Down
52 changes: 12 additions & 40 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (

duckdb "github.com/duckdb/duckdb-go/v2"
pg_query "github.com/pganalyze/pg_query_go/v6"
"github.com/posthog/duckgres/duckdbservice/arrowmap"
"github.com/posthog/duckgres/server/auth"
"github.com/posthog/duckgres/server/sqlcore"
"github.com/posthog/duckgres/transpiler"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -2475,41 +2477,11 @@ func countDollarParams(query string) int {
return max
}

// isEmptyQuery checks if a query contains only semicolons, whitespace, and/or comments.
// PostgreSQL returns EmptyQueryResponse for queries like ";", ";;;", "-- ping", etc.
func isEmptyQuery(query string) bool {
// Strip SQL comments first (e.g., pgx sends "-- ping" for Ping())
stripped := stripLeadingComments(query)
for _, r := range stripped {
if r != ';' && r != ' ' && r != '\t' && r != '\n' && r != '\r' {
return false
}
}
return true
}

// stripLeadingComments removes leading SQL comments from a query.
// Handles both block comments /* ... */ and line comments -- ...
func stripLeadingComments(query string) string {
for {
query = strings.TrimSpace(query)
if strings.HasPrefix(query, "/*") {
end := strings.Index(query, "*/")
if end == -1 {
return query
}
query = query[end+2:]
} else if strings.HasPrefix(query, "--") {
end := strings.Index(query, "\n")
if end == -1 {
return ""
}
query = query[end+1:]
} else {
return query
}
}
}
// isEmptyQuery and stripLeadingComments moved to server/sqlcore so the
// Flight client can call them without importing server. Local thin wrappers
// preserve the unexported call-site spellings used throughout this file.
func isEmptyQuery(query string) bool { return sqlcore.IsEmptyQuery(query) }
func stripLeadingComments(query string) string { return sqlcore.StripLeadingComments(query) }

// stripLeadingNoise strips leading whitespace, comments, and parentheses from
// a query string in a loop until none remain. This handles cases like
Expand Down Expand Up @@ -4308,7 +4280,7 @@ func (c *clientConn) formatCopyValue(v interface{}) string {
return formatArrayValue(val)
case map[string]any:
return formatMapValue(val)
case OrderedMapValue:
case arrowmap.OrderedMapValue:
return formatOrderedMapValue(val)
default:
return fmt.Sprintf("%v", val)
Expand Down Expand Up @@ -4512,13 +4484,13 @@ func formatValue(v interface{}) string {
case duckdb.Interval:
// PostgreSQL interval text format: "1 year 2 mons 3 days 04:05:06.123456"
return formatInterval(val)
case intervalValue:
// Arrow Flight returns intervalValue instead of duckdb.Interval
case arrowmap.IntervalValue:
// Arrow Flight returns arrowmap.IntervalValue instead of duckdb.Interval
return formatInterval(duckdb.Interval{Months: val.Months, Days: val.Days, Micros: val.Micros})
case map[string]any:
// STRUCT text format: {"key1": val1, "key2": val2}
return formatMapValue(val)
case OrderedMapValue:
case arrowmap.OrderedMapValue:
return formatOrderedMapValue(val)
default:
// For other types, try to convert to string
Expand Down Expand Up @@ -4647,7 +4619,7 @@ func formatMapValue(m map[string]any) string {

// formatOrderedMapValue formats an OrderedMapValue as a key-value text
// representation, preserving the original insertion order from the Arrow array.
func formatOrderedMapValue(m OrderedMapValue) string {
func formatOrderedMapValue(m arrowmap.OrderedMapValue) string {
var buf strings.Builder
buf.WriteByte('{')
for i, k := range m.Keys {
Expand Down
Loading
Loading