From ce06834c8e6925fe8f009150bfa56395d2ebb497 Mon Sep 17 00:00:00 2001 From: James Greenhill Date: Wed, 3 Dec 2025 17:12:50 -0800 Subject: [PATCH] Add COPY protocol, graceful shutdown, and enhanced pg_catalog support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Features: - COPY protocol: Support COPY TO STDOUT and COPY FROM STDIN for bulk data - Graceful shutdown: Wait for in-flight queries with configurable timeout - Enhanced pg_catalog: Full support for psql \d tablename command pg_catalog additions: - pg_class_full view with relforcerowsecurity column - pg_collation, pg_policy, pg_roles, pg_statistic_ext views - pg_publication, pg_publication_rel, pg_publication_tables views - pg_inherits with inhdetachpending column - pg_get_expr macro with default pretty parameter - Query rewrites for ::regclass and ::regnamespace casts Also adds CLAUDE.md for AI assistant context. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- CLAUDE.md | 124 ++++++++++++++++++++ README.md | 45 +++++++- TODO.md | 10 +- server/catalog.go | 248 ++++++++++++++++++++++++++++++++++++++++ server/conn.go | 278 +++++++++++++++++++++++++++++++++++++++++++++ server/protocol.go | 78 +++++++++++++ server/server.go | 85 +++++++++++++- 7 files changed, 861 insertions(+), 7 deletions(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..e1fdc98c --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,124 @@ +# Claude Code Context for Duckgres + +This file provides context for Claude Code sessions working on this codebase. + +## Project Overview + +Duckgres is a PostgreSQL wire protocol server backed by DuckDB. It allows any PostgreSQL client (psql, pgAdmin, lib/pq, psycopg2, JDBC, etc.) to connect and execute queries against DuckDB databases. + +## Architecture + +``` +PostgreSQL Client → TLS → Duckgres Server → DuckDB (per-user database) +``` + +### Key Components + +- **main.go**: Entry point, configuration loading (CLI flags, env vars, YAML) +- **server/server.go**: Server struct, connection handling, graceful shutdown +- **server/conn.go**: Client connection handling, query execution, COPY protocol +- **server/protocol.go**: PostgreSQL wire protocol message encoding/decoding +- **server/catalog.go**: pg_catalog compatibility (views, functions, query rewriting) +- **server/types.go**: Type OID mapping between DuckDB and PostgreSQL +- **server/ratelimit.go**: Rate limiting for brute-force protection +- **server/tls.go**: Auto-generation of self-signed TLS certificates + +## PostgreSQL Wire Protocol + +The server implements the PostgreSQL v3 protocol: + +### Message Types (server/protocol.go) +- **Frontend (client→server)**: Query, Parse, Bind, Describe, Execute, Sync, Close, CopyData, CopyDone +- **Backend (server→client)**: AuthOK, RowDescription, DataRow, CommandComplete, ReadyForQuery, CopyInResponse, CopyOutResponse + +### Query Flow +1. Client sends Query message ('Q') +2. Server parses SQL, rewrites pg_catalog references +3. Server executes via DuckDB's database/sql driver +4. Server sends RowDescription + DataRow messages +5. Server sends CommandComplete + ReadyForQuery + +### Extended Query Protocol +Supports prepared statements (Parse/Bind/Execute) for parameterized queries and binary result formats. + +## pg_catalog Compatibility (server/catalog.go) + +psql and other clients expect PostgreSQL system catalogs. We provide compatibility by: + +1. 
**Creating views** in main schema that mirror pg_catalog tables:
+   - `pg_database`, `pg_class_full`, `pg_collation`, `pg_policy`, `pg_roles`
+   - `pg_statistic_ext`, `pg_publication`, `pg_publication_rel`, `pg_inherits`, etc.
+
+2. **Creating macros** for PostgreSQL functions:
+   - `pg_get_userbyid`, `pg_table_is_visible`, `format_type`, `pg_get_expr`
+   - `obj_description`, `col_description`, `pg_get_indexdef`, etc.
+
+3. **Query rewriting** to replace PostgreSQL-specific syntax:
+   - `pg_catalog.pg_class` → `pg_class_full`
+   - `OPERATOR(pg_catalog.~)` → `~`
+   - `::pg_catalog.regtype` → `::VARCHAR`
+
+## COPY Protocol (server/conn.go)
+
+Supports bulk data transfer:
+- **COPY TO STDOUT**: Streams query results to client
+- **COPY FROM STDIN**: Receives data from client, inserts row by row
+- Supports CSV format with HEADER option
+
+## Configuration
+
+Four-tier configuration (highest to lowest priority):
+1. CLI flags (`--port`, `--config`, etc.)
+2. Environment variables (`DUCKGRES_PORT`, etc.)
+3. YAML config file
+4. Built-in defaults
+
+## Testing
+
+```bash
+# Build
+go build -o duckgres .
+
+# Run on non-standard port
+./duckgres --port 35437
+
+# Connect with psql
+PGPASSWORD=postgres psql "host=127.0.0.1 port=35437 user=postgres sslmode=require"
+
+# Test commands
+\dt           # List tables
+\d tablename  # Describe table
+\l            # List databases
+```
+
+## Common Development Tasks
+
+### Adding a new pg_catalog view
+1. Add view creation SQL in `initPgCatalog()` in `catalog.go`
+2. Add regex pattern to rewrite `pg_catalog.viewname` to `viewname`
+3. Add the replacement in `rewritePgCatalogQuery()`
+
+### Adding a new PostgreSQL function
+1. Add `CREATE MACRO` in the `functions` slice in `initPgCatalog()`
+2. Add function name to `pgCatalogFunctions` slice for query rewriting
+
+### Adding protocol support
+1. Add message type constant in `protocol.go`
+2. Add write function (e.g., `writeCopyData()`)
+3. Handle in message loop in `conn.go`
+
+## Dependencies
+
+- `github.com/duckdb/duckdb-go/v2` - DuckDB Go driver
+- `gopkg.in/yaml.v3` - YAML config parsing
+
+## Known Limitations
+
+- Single process (all users share one process)
+- No replication
+- Some pg_catalog tables are stubs (return empty)
+- Type OID mapping is incomplete (some types show as "unknown")
+
+## TODO Reference
+
+See `TODO.md` for the full feature roadmap and known issues.
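
The query-rewriting step described above is plain regex substitution over the incoming SQL text. A minimal, self-contained sketch of the mechanism (just two of the rules, reusing the pattern names from the `server/catalog.go` hunks later in this patch):

```go
package main

import (
	"fmt"
	"regexp"
)

// Case-insensitive patterns: swap pg_catalog references for the local
// compatibility views, and map PostgreSQL-only casts to DuckDB types.
var (
	pgClassRegex     = regexp.MustCompile(`(?i)pg_catalog\.pg_class\b`)
	regtypeCastRegex = regexp.MustCompile(`(?i)::pg_catalog\.regtype`)
)

func rewritePgCatalogQuery(query string) string {
	query = pgClassRegex.ReplaceAllString(query, "pg_class_full")
	query = regtypeCastRegex.ReplaceAllString(query, "::VARCHAR")
	return query
}

func main() {
	q := "SELECT a.atttypid::pg_catalog.regtype FROM pg_catalog.pg_class c " +
		"JOIN pg_attribute a ON a.attrelid = c.oid"
	// Prints the query with pg_class_full and ::VARCHAR substituted in.
	fmt.Println(rewritePgCatalogQuery(q))
}
```
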
Connect with any
 - **PostgreSQL Wire Protocol**: Full compatibility with PostgreSQL clients
 - **TLS Encryption**: Required TLS connections with auto-generated self-signed certificates
 - **Per-User Databases**: Each authenticated user gets their own isolated DuckDB database file
-- **Password Authentication**: MD5 password authentication
+- **Password Authentication**: Cleartext password authentication over TLS
 - **Extended Query Protocol**: Support for prepared statements, binary format, and parameterized queries
+- **COPY Protocol**: Bulk data import/export with `COPY FROM STDIN` and `COPY TO STDOUT`
 - **DuckDB Extensions**: Configurable extension loading (ducklake enabled by default)
 - **DuckLake Integration**: Auto-attach DuckLake catalogs for lakehouse workflows
 - **Rate Limiting**: Built-in protection against brute-force attacks
+- **Graceful Shutdown**: Waits for in-flight queries before exiting
 - **Flexible Configuration**: YAML config files, environment variables, and CLI flags
 
 ## Quick Start
@@ -144,6 +146,46 @@ ATTACH 'ducklake:postgres:host=localhost dbname=ducklake' (DATA_PATH 's3://my-bu
 
 See [DuckLake documentation](https://ducklake.select/docs/stable/duckdb/usage/connecting) for more details.
 
+## COPY Protocol
+
+Duckgres supports PostgreSQL's COPY protocol for efficient bulk data import and export:
+
+```sql
+-- Export data to stdout (tab-separated)
+COPY tablename TO STDOUT;
+
+-- Export as CSV with headers
+COPY tablename TO STDOUT WITH CSV HEADER;
+
+-- Export query results
+COPY (SELECT * FROM tablename WHERE id > 100) TO STDOUT WITH CSV;
+
+-- Import data from stdin
+COPY tablename FROM STDIN;
+
+-- Import CSV with headers
+COPY tablename FROM STDIN WITH CSV HEADER;
+```
+
+This works with psql's `\copy` command and programmatic COPY operations from PostgreSQL drivers.
+
+## Graceful Shutdown
+
+Duckgres handles shutdown signals (SIGINT, SIGTERM) gracefully:
+
+- Stops accepting new connections immediately
+- Waits for in-flight queries to complete (default 30s timeout)
+- Logs active connection count during shutdown
+- Closes all database connections cleanly
+
+The shutdown timeout can be configured:
+
+```go
+cfg := server.Config{
+    ShutdownTimeout: 60 * time.Second,
+}
+```
+
 ## Rate Limiting
 
 Built-in rate limiting protects against brute-force authentication attacks:
@@ -219,6 +261,7 @@ GROUP BY name;
 - `DROP TABLE/INDEX/VIEW`
 - `ALTER TABLE`
 - `BEGIN/COMMIT/ROLLBACK` (DuckDB transaction support)
+- `COPY` - Bulk data loading and export (see above)
 
 ### PostgreSQL Compatibility
 - Extended query protocol (prepared statements)
diff --git a/TODO.md b/TODO.md
index b091f222..c222a179 100644
--- a/TODO.md
+++ b/TODO.md
@@ -16,26 +16,26 @@
 
 ### Protocol Compatibility
 - [ ] **Binary Format Support**: Encode results in binary format for better performance with some clients
-- [ ] **COPY Protocol**: Support `COPY FROM`/`COPY TO` for bulk data loading
+- [x] **COPY Protocol**: Support `COPY FROM`/`COPY TO` for bulk data loading
 - [ ] **Cancel Request Handling**: Properly cancel long-running queries
 
 ### Compatibility
 - [x] **System Catalog Emulation**: Basic `pg_catalog` compatibility for psql
   - [x] `\dt` (list tables) - working
   - [x] `\l` (list databases) - working
-  - [ ] `\d tablename` (describe table) - needs more pg_class columns
+  - [x] `\d tablename
` (describe table) - working - [ ] **Information Schema**: Emulate PostgreSQL's `information_schema` - [ ] **Session Variables**: Support `SET` commands (timezone, search_path, etc.) - [ ] **Type OID Mapping**: Proper PostgreSQL OID mapping for all DuckDB types ### Features -- [ ] **Extensions**: Load DuckDB extensions on startup +- [x] **Extensions**: Load DuckDB extensions on startup ### Operations - [ ] **Hot Reload**: Reload config without restart - [ ] **Admin Commands**: `\duckgres status`, `\duckgres users`, etc. - [ ] **Docker Image**: Official container image -- [ ] **Graceful Shutdown**: Finish in-flight queries before shutdown +- [x] **Graceful Shutdown**: Finish in-flight queries before shutdown ## Medium Priority @@ -73,7 +73,7 @@ ## Known Issues - [ ] Some PostgreSQL drivers may fail with unsupported OIDs -- [ ] `\d` commands in psql don't work (need system catalog) +- [x] `\d` commands in psql don't work (need system catalog) - fixed - [ ] Transaction isolation may differ from PostgreSQL behavior - [ ] Large result sets may cause memory issues (no streaming) diff --git a/server/catalog.go b/server/catalog.go index 8ab46074..fcf597c5 100644 --- a/server/catalog.go +++ b/server/catalog.go @@ -27,6 +27,184 @@ func initPgCatalog(db *sql.DB) error { ` db.Exec(pgDatabaseSQL) + // Create pg_class wrapper that adds missing columns psql expects + // DuckDB's pg_catalog.pg_class is missing relforcerowsecurity + pgClassSQL := ` + CREATE OR REPLACE VIEW pg_class_full AS + SELECT + oid, + relname, + relnamespace, + reltype, + reloftype, + relowner, + relam, + relfilenode, + reltablespace, + relpages, + reltuples, + relallvisible, + reltoastrelid, + reltoastidxid, + relhasindex, + relisshared, + relpersistence, + relkind, + relnatts, + relchecks, + relhasoids, + relhaspkey, + relhasrules, + relhastriggers, + relhassubclass, + relrowsecurity, + false AS relforcerowsecurity, + relispopulated, + relreplident, + relispartition, + relrewrite, + relfrozenxid, + relminmxid, + relacl, + reloptions, + relpartbound + FROM pg_catalog.pg_class + ` + db.Exec(pgClassSQL) + + // Create pg_collation view (DuckDB doesn't have this) + pgCollationSQL := ` + CREATE OR REPLACE VIEW pg_collation AS + SELECT + 0::BIGINT AS oid, + 'default' AS collname, + 0::BIGINT AS collnamespace, + 0::INTEGER AS collowner, + 'c' AS collprovider, + true AS collisdeterministic, + 0::INTEGER AS collencoding, + 'C' AS collcollate, + 'C' AS collctype, + NULL AS collversion + WHERE false + ` + db.Exec(pgCollationSQL) + + // Create pg_policy view for row-level security (empty, DuckDB doesn't support RLS) + pgPolicySQL := ` + CREATE OR REPLACE VIEW pg_policy AS + SELECT + 0::BIGINT AS oid, + '' AS polname, + 0::BIGINT AS polrelid, + '*' AS polcmd, + true AS polpermissive, + ARRAY[]::BIGINT[] AS polroles, + NULL AS polqual, + NULL AS polwithcheck + WHERE false + ` + db.Exec(pgPolicySQL) + + // Create pg_roles view (minimal for psql compatibility) + pgRolesSQL := ` + CREATE OR REPLACE VIEW pg_roles AS + SELECT + 0::BIGINT AS oid, + 'duckdb' AS rolname, + true AS rolsuper, + true AS rolinherit, + true AS rolcreaterole, + true AS rolcreatedb, + true AS rolcanlogin, + false AS rolreplication, + false AS rolbypassrls, + -1::INTEGER AS rolconnlimit, + NULL AS rolpassword, + NULL AS rolvaliduntil, + ARRAY[]::VARCHAR[] AS rolconfig + ` + db.Exec(pgRolesSQL) + + // Create pg_statistic_ext view (extended statistics, empty) + pgStatisticExtSQL := ` + CREATE OR REPLACE VIEW pg_statistic_ext AS + SELECT + 0::BIGINT AS oid, + 0::BIGINT AS stxrelid, 
+ 0::BIGINT AS stxnamespace, + '' AS stxname, + 0::INTEGER AS stxowner, + 0::INTEGER AS stxstattarget, + ARRAY[]::VARCHAR[] AS stxkeys, + ARRAY[]::VARCHAR[] AS stxkind + WHERE false + ` + db.Exec(pgStatisticExtSQL) + + // Create pg_publication_tables view (logical replication, empty) + pgPublicationTablesSQL := ` + CREATE OR REPLACE VIEW pg_publication_tables AS + SELECT + '' AS pubname, + '' AS schemaname, + '' AS tablename + WHERE false + ` + db.Exec(pgPublicationTablesSQL) + + // Create pg_rules view (empty, DuckDB doesn't have rules) + pgRulesSQL := ` + CREATE OR REPLACE VIEW pg_rules AS + SELECT + '' AS schemaname, + '' AS tablename, + '' AS rulename, + '' AS definition + WHERE false + ` + db.Exec(pgRulesSQL) + + // Create pg_publication view (logical replication, empty) + pgPublicationSQL := ` + CREATE OR REPLACE VIEW pg_publication AS + SELECT + 0::BIGINT AS oid, + '' AS pubname, + 0::INTEGER AS pubowner, + false AS puballtables, + false AS pubinsert, + false AS pubupdate, + false AS pubdelete, + false AS pubtruncate, + false AS pubviaroot + WHERE false + ` + db.Exec(pgPublicationSQL) + + // Create pg_publication_rel view (publication-relation mapping, empty) + pgPublicationRelSQL := ` + CREATE OR REPLACE VIEW pg_publication_rel AS + SELECT + 0::BIGINT AS oid, + 0::BIGINT AS prpubid, + 0::BIGINT AS prrelid + WHERE false + ` + db.Exec(pgPublicationRelSQL) + + // Create pg_inherits view (table inheritance, empty - DuckDB doesn't support inheritance) + pgInheritsSQL := ` + CREATE OR REPLACE VIEW pg_inherits AS + SELECT + 0::BIGINT AS inhrelid, + 0::BIGINT AS inhparent, + 0::INTEGER AS inhseqno, + false AS inhdetachpending + WHERE false + ` + db.Exec(pgInheritsSQL) + // Create helper macros/functions that psql expects but DuckDB doesn't have // These need to be created without schema prefix so DuckDB finds them functions := []string{ @@ -69,12 +247,20 @@ func initPgCatalog(db *sql.DB) error { `CREATE OR REPLACE MACRO col_description(table_oid, col_num) AS NULL`, // shobj_description - get shared object comment `CREATE OR REPLACE MACRO shobj_description(oid, catalog) AS NULL`, + // pg_get_expr - deparse an expression (used for defaults, etc.) 
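+ // PostgreSQL defines both a two-argument and a three-argument (pretty) form.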
+ // Use default parameter so it works with both 2 and 3 args + `DROP MACRO IF EXISTS pg_get_expr`, + `CREATE MACRO pg_get_expr(expr, relid, pretty := false) AS NULL`, // pg_get_indexdef - get index definition `CREATE OR REPLACE MACRO pg_get_indexdef(index_oid) AS ''`, `CREATE OR REPLACE MACRO pg_get_indexdef(index_oid, col, pretty) AS ''`, // pg_get_constraintdef - get constraint definition `CREATE OR REPLACE MACRO pg_get_constraintdef(constraint_oid) AS ''`, `CREATE OR REPLACE MACRO pg_get_constraintdef(constraint_oid, pretty) AS ''`, + // pg_get_statisticsobjdef_columns - get column list for extended statistics + `CREATE OR REPLACE MACRO pg_get_statisticsobjdef_columns(stat_oid) AS ''`, + // pg_relation_is_publishable - check if relation can be published + `CREATE OR REPLACE MACRO pg_relation_is_publishable(rel_oid) AS false`, // current_setting - get config setting `CREATE OR REPLACE MACRO current_setting(name) AS CASE name @@ -108,6 +294,9 @@ var pgCatalogFunctions = []string{ "shobj_description", "pg_get_indexdef", "pg_get_constraintdef", + "pg_get_partkeydef", + "pg_get_statisticsobjdef_columns", + "pg_relation_is_publishable", "current_setting", "pg_is_in_recovery", "has_schema_privilege", @@ -132,10 +321,36 @@ var ( collateRegex = regexp.MustCompile(`(?i)\s+COLLATE\s+pg_catalog\."?default"?`) // pg_catalog.pg_database -> pg_database (use our view) pgDatabaseRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_database`) + // pg_catalog.pg_class -> pg_class_full (use our wrapper view with extra columns) + pgClassRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_class\b`) + // pg_catalog.pg_collation -> pg_collation (use our view) + pgCollationRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_collation\b`) + // pg_catalog.pg_policy -> pg_policy (use our view) + pgPolicyRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_policy\b`) + // pg_catalog.pg_roles -> pg_roles (use our view) + pgRolesRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_roles\b`) + // pg_catalog.pg_statistic_ext -> pg_statistic_ext (use our view) + pgStatisticExtRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_statistic_ext\b`) + // pg_catalog.pg_publication_tables -> pg_publication_tables (use our view) + pgPublicationTablesRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_publication_tables\b`) + // pg_catalog.pg_rules -> pg_rules (use our view) + pgRulesRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_rules\b`) + // pg_catalog.pg_publication -> pg_publication (use our view) + pgPublicationRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_publication\b`) + // pg_catalog.pg_publication_rel -> pg_publication_rel (use our view) + pgPublicationRelRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_publication_rel\b`) + // pg_catalog.pg_inherits -> pg_inherits (use our view) + pgInheritsRegex = regexp.MustCompile(`(?i)pg_catalog\.pg_inherits\b`) // ::pg_catalog.regtype::pg_catalog.text -> ::VARCHAR (PostgreSQL type cast) regtypeTextCastRegex = regexp.MustCompile(`(?i)::pg_catalog\.regtype::pg_catalog\.text`) // ::pg_catalog.regtype -> ::VARCHAR regtypeCastRegex = regexp.MustCompile(`(?i)::pg_catalog\.regtype`) + // ::pg_catalog.regclass -> ::VARCHAR + regclassCastRegex = regexp.MustCompile(`(?i)::pg_catalog\.regclass`) + // ::pg_catalog.regnamespace::pg_catalog.text -> ::VARCHAR + regnamespaceTextCastRegex = regexp.MustCompile(`(?i)::pg_catalog\.regnamespace::pg_catalog\.text`) + // ::pg_catalog.regnamespace -> ::VARCHAR + regnamespaceCastRegex = regexp.MustCompile(`(?i)::pg_catalog\.regnamespace`) // ::pg_catalog.text -> ::VARCHAR textCastRegex = 
regexp.MustCompile(`(?i)::pg_catalog\.text`)
 )
 
@@ -154,9 +369,42 @@ func rewritePgCatalogQuery(query string) string {
 	// Replace pg_catalog.pg_database with pg_database (our view in main schema)
 	query = pgDatabaseRegex.ReplaceAllString(query, "pg_database")
 
+	// Replace pg_catalog.pg_class with pg_class_full (our wrapper view with extra columns)
+	query = pgClassRegex.ReplaceAllString(query, "pg_class_full")
+
+	// Replace pg_catalog.pg_collation with pg_collation (our empty view)
+	query = pgCollationRegex.ReplaceAllString(query, "pg_collation")
+
+	// Replace pg_catalog.pg_policy with pg_policy (our empty view)
+	query = pgPolicyRegex.ReplaceAllString(query, "pg_policy")
+
+	// Replace pg_catalog.pg_roles with pg_roles (our view)
+	query = pgRolesRegex.ReplaceAllString(query, "pg_roles")
+
+	// Replace pg_catalog.pg_statistic_ext with pg_statistic_ext (our view)
+	query = pgStatisticExtRegex.ReplaceAllString(query, "pg_statistic_ext")
+
+	// Replace pg_catalog.pg_publication_tables with pg_publication_tables (our view)
+	query = pgPublicationTablesRegex.ReplaceAllString(query, "pg_publication_tables")
+
+	// Replace pg_catalog.pg_rules with pg_rules (our view)
+	query = pgRulesRegex.ReplaceAllString(query, "pg_rules")
+
+	// Replace pg_catalog.pg_publication with pg_publication (our view)
+	query = pgPublicationRegex.ReplaceAllString(query, "pg_publication")
+
+	// Replace pg_catalog.pg_publication_rel with pg_publication_rel (our view)
+	query = pgPublicationRelRegex.ReplaceAllString(query, "pg_publication_rel")
+
+	// Replace pg_catalog.pg_inherits with pg_inherits (our view)
+	query = pgInheritsRegex.ReplaceAllString(query, "pg_inherits")
+
 	// Replace PostgreSQL type casts (order matters - most specific first)
 	query = regtypeTextCastRegex.ReplaceAllString(query, "::VARCHAR")
 	query = regtypeCastRegex.ReplaceAllString(query, "::VARCHAR")
+	query = regclassCastRegex.ReplaceAllString(query, "::VARCHAR")
+	query = regnamespaceTextCastRegex.ReplaceAllString(query, "::VARCHAR")
+	query = regnamespaceCastRegex.ReplaceAllString(query, "::VARCHAR")
 	query = textCastRegex.ReplaceAllString(query, "::VARCHAR")
 
 	return query
diff --git a/server/conn.go b/server/conn.go
index c502472b..2985f795 100644
--- a/server/conn.go
+++ b/server/conn.go
@@ -11,6 +11,7 @@ import (
 	"log"
 	"net"
 	"os"
+	"regexp"
 	"strings"
 )
 
@@ -266,6 +267,11 @@ func (c *clientConn) handleQuery(body []byte) error {
 	upperQuery := strings.ToUpper(query)
 	cmdType := c.getCommandType(upperQuery)
 
+	// Handle COPY commands specially
+	if cmdType == "COPY" {
+		return c.handleCopy(query, upperQuery)
+	}
+
 	// For non-SELECT queries, use Exec
 	if cmdType != "SELECT" {
 		result, err := c.db.Exec(query)
@@ -405,6 +411,278 @@ func (c *clientConn) buildCommandTag(cmdType string, result sql.Result) string {
 	}
 }
 
+// Regular expressions for parsing COPY commands
+var (
+	copyToStdoutRegex   = regexp.MustCompile(`(?i)COPY\s+(.+?)\s+TO\s+STDOUT`)
+	copyFromStdinRegex  = regexp.MustCompile(`(?i)COPY\s+(\S+)\s+(?:\(([^)]+)\)\s+)?FROM\s+STDIN`)
+	copyWithCSVRegex    = regexp.MustCompile(`(?i)\bCSV\b`)
+	copyWithHeaderRegex = regexp.MustCompile(`(?i)\bHEADER\b`)
+	// No trailing \b after the closing quote: a quote followed by whitespace
+	// or end of statement is never a word boundary, so it would never match.
+	copyDelimiterRegex = regexp.MustCompile(`(?i)\bDELIMITER\s+['"](.)['"]`)
+)
+
+// handleCopy handles COPY TO STDOUT and COPY FROM STDIN commands
+func (c *clientConn) handleCopy(query, upperQuery string) error {
+	// Check if it's COPY TO STDOUT
+	if copyToStdoutRegex.MatchString(upperQuery) {
+		return c.handleCopyOut(query, upperQuery)
+	}
+
+	// Check if it's COPY FROM STDIN
+	if 
copyFromStdinRegex.MatchString(upperQuery) { + return c.handleCopyIn(query, upperQuery) + } + + // For other COPY commands (e.g., COPY TO file), pass through to DuckDB + result, err := c.db.Exec(query) + if err != nil { + c.sendError("ERROR", "42000", err.Error()) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + } + + rowsAffected, _ := result.RowsAffected() + writeCommandComplete(c.writer, fmt.Sprintf("COPY %d", rowsAffected)) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil +} + +// handleCopyOut handles COPY ... TO STDOUT +func (c *clientConn) handleCopyOut(query, upperQuery string) error { + matches := copyToStdoutRegex.FindStringSubmatch(query) + if len(matches) < 2 { + c.sendError("ERROR", "42601", "Invalid COPY TO STDOUT syntax") + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + } + + // Parse options + delimiter := "\t" + if m := copyDelimiterRegex.FindStringSubmatch(query); len(m) > 1 { + delimiter = m[1] + } else if copyWithCSVRegex.MatchString(upperQuery) { + delimiter = "," + } + + // The source can be a table name or a query in parentheses + source := strings.TrimSpace(matches[1]) + var selectQuery string + if strings.HasPrefix(source, "(") && strings.HasSuffix(source, ")") { + selectQuery = source[1 : len(source)-1] + } else { + selectQuery = fmt.Sprintf("SELECT * FROM %s", source) + } + + // Execute the query + rows, err := c.db.Query(selectQuery) + if err != nil { + c.sendError("ERROR", "42000", err.Error()) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + c.sendError("ERROR", "42000", err.Error()) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + } + + // Send CopyOutResponse + if err := writeCopyOutResponse(c.writer, int16(len(cols)), true); err != nil { + return err + } + c.writer.Flush() + + // Send header if CSV with HEADER + if copyWithCSVRegex.MatchString(upperQuery) && copyWithHeaderRegex.MatchString(upperQuery) { + header := strings.Join(cols, delimiter) + "\n" + if err := writeCopyData(c.writer, []byte(header)); err != nil { + return err + } + } + + // Send data rows + rowCount := 0 + for rows.Next() { + values := make([]interface{}, len(cols)) + valuePtrs := make([]interface{}, len(cols)) + for i := range values { + valuePtrs[i] = &values[i] + } + + if err := rows.Scan(valuePtrs...); err != nil { + c.sendError("ERROR", "42000", err.Error()) + break + } + + // Format row as tab/comma separated values + var rowData []string + for _, v := range values { + rowData = append(rowData, c.formatCopyValue(v)) + } + line := strings.Join(rowData, delimiter) + "\n" + if err := writeCopyData(c.writer, []byte(line)); err != nil { + return err + } + rowCount++ + } + + // Send CopyDone + if err := writeCopyDone(c.writer); err != nil { + return err + } + + writeCommandComplete(c.writer, fmt.Sprintf("COPY %d", rowCount)) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil +} + +// handleCopyIn handles COPY ... 
FROM STDIN +func (c *clientConn) handleCopyIn(query, upperQuery string) error { + matches := copyFromStdinRegex.FindStringSubmatch(query) + if len(matches) < 2 { + c.sendError("ERROR", "42601", "Invalid COPY FROM STDIN syntax") + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + } + + tableName := matches[1] + columnList := "" + if len(matches) > 2 && matches[2] != "" { + columnList = fmt.Sprintf("(%s)", matches[2]) + } + + // Parse options + delimiter := "\t" + if m := copyDelimiterRegex.FindStringSubmatch(query); len(m) > 1 { + delimiter = m[1] + } else if copyWithCSVRegex.MatchString(upperQuery) { + delimiter = "," + } + hasHeader := copyWithCSVRegex.MatchString(upperQuery) && copyWithHeaderRegex.MatchString(upperQuery) + + // Get column count for the table + colQuery := fmt.Sprintf("SELECT * FROM %s LIMIT 0", tableName) + testRows, err := c.db.Query(colQuery) + if err != nil { + c.sendError("ERROR", "42P01", fmt.Sprintf("relation \"%s\" does not exist", tableName)) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + } + cols, _ := testRows.Columns() + testRows.Close() + + // Send CopyInResponse + if err := writeCopyInResponse(c.writer, int16(len(cols)), true); err != nil { + return err + } + c.writer.Flush() + + // Read COPY data from client + var allData bytes.Buffer + rowCount := 0 + headerSkipped := false + + for { + msgType, body, err := readMessage(c.reader) + if err != nil { + return err + } + + switch msgType { + case msgCopyData: + allData.Write(body) + + case msgCopyDone: + // Process all data + lines := strings.Split(allData.String(), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Skip header if needed + if hasHeader && !headerSkipped { + headerSkipped = true + continue + } + + // Parse values and insert + values := c.parseCopyLine(line, delimiter) + if len(values) == 0 { + continue + } + + // Build INSERT statement + placeholders := make([]string, len(values)) + args := make([]interface{}, len(values)) + for i, v := range values { + placeholders[i] = "?" 
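+ // "\N" is PostgreSQL's text-format NULL marker; treating the empty
+ // string as NULL as well is a simplification (text-format COPY
+ // normally preserves empty strings).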
+ if v == "\\N" || v == "" { + args[i] = nil + } else { + args[i] = v + } + } + + insertSQL := fmt.Sprintf("INSERT INTO %s %s VALUES (%s)", + tableName, columnList, strings.Join(placeholders, ", ")) + + if _, err := c.db.Exec(insertSQL, args...); err != nil { + c.sendError("ERROR", "22P02", fmt.Sprintf("invalid input: %v", err)) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + } + rowCount++ + } + + writeCommandComplete(c.writer, fmt.Sprintf("COPY %d", rowCount)) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + + case msgCopyFail: + // Client cancelled COPY + errMsg := string(bytes.TrimRight(body, "\x00")) + c.sendError("ERROR", "57014", fmt.Sprintf("COPY failed: %s", errMsg)) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + + default: + c.sendError("ERROR", "08P01", fmt.Sprintf("unexpected message type during COPY: %c", msgType)) + writeReadyForQuery(c.writer, 'I') + c.writer.Flush() + return nil + } + } +} + +// formatCopyValue formats a value for COPY output +func (c *clientConn) formatCopyValue(v interface{}) string { + if v == nil { + return "\\N" + } + return fmt.Sprintf("%v", v) +} + +// parseCopyLine parses a line of COPY input +func (c *clientConn) parseCopyLine(line, delimiter string) []string { + // Simple split - doesn't handle quoted values yet + return strings.Split(line, delimiter) +} + func (c *clientConn) sendRowDescription(cols []string, colTypes []*sql.ColumnType) error { var buf bytes.Buffer diff --git a/server/protocol.go b/server/protocol.go index 170c8f14..4e2415cd 100644 --- a/server/protocol.go +++ b/server/protocol.go @@ -35,6 +35,13 @@ const ( msgBindComplete = '2' msgCloseComplete = '3' msgNoData = 'n' + + // COPY messages (both directions) + msgCopyData = 'd' // Contains COPY data + msgCopyDone = 'c' // COPY completed + msgCopyFail = 'f' // COPY failed (frontend only) + msgCopyInResponse = 'G' // Server ready to receive COPY data + msgCopyOutResponse = 'H' // Server about to send COPY data ) // Authentication types @@ -247,3 +254,74 @@ func writeCloseComplete(w io.Writer) error { func writeNoData(w io.Writer) error { return writeMessage(w, msgNoData, nil) } + +// writeCopyOutResponse tells client we're about to send COPY data +// Format: overall format (0=text, 1=binary), num columns, format per column +func writeCopyOutResponse(w io.Writer, numColumns int16, textFormat bool) error { + var data []byte + + // Overall format (0=text) + if textFormat { + data = append(data, 0) + } else { + data = append(data, 1) + } + + // Number of columns + colBytes := make([]byte, 2) + binary.BigEndian.PutUint16(colBytes, uint16(numColumns)) + data = append(data, colBytes...) + + // Format for each column (0=text) + for i := int16(0); i < numColumns; i++ { + formatBytes := make([]byte, 2) + if textFormat { + binary.BigEndian.PutUint16(formatBytes, 0) + } else { + binary.BigEndian.PutUint16(formatBytes, 1) + } + data = append(data, formatBytes...) + } + + return writeMessage(w, msgCopyOutResponse, data) +} + +// writeCopyInResponse tells client to send COPY data +func writeCopyInResponse(w io.Writer, numColumns int16, textFormat bool) error { + var data []byte + + // Overall format (0=text) + if textFormat { + data = append(data, 0) + } else { + data = append(data, 1) + } + + // Number of columns + colBytes := make([]byte, 2) + binary.BigEndian.PutUint16(colBytes, uint16(numColumns)) + data = append(data, colBytes...) 
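+ // CopyInResponse mirrors CopyOutResponse: an Int8 overall format, an
+ // Int16 column count, then one Int16 format code per column.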
+ + // Format for each column (0=text) + for i := int16(0); i < numColumns; i++ { + formatBytes := make([]byte, 2) + if textFormat { + binary.BigEndian.PutUint16(formatBytes, 0) + } else { + binary.BigEndian.PutUint16(formatBytes, 1) + } + data = append(data, formatBytes...) + } + + return writeMessage(w, msgCopyInResponse, data) +} + +// writeCopyData sends a row of COPY data +func writeCopyData(w io.Writer, data []byte) error { + return writeMessage(w, msgCopyData, data) +} + +// writeCopyDone signals the end of COPY data +func writeCopyDone(w io.Writer) error { + return writeMessage(w, msgCopyDone, nil) +} diff --git a/server/server.go b/server/server.go index adafa031..22ece042 100644 --- a/server/server.go +++ b/server/server.go @@ -1,12 +1,15 @@ package server import ( + "context" "crypto/tls" "database/sql" "fmt" "log" "net" "sync" + "sync/atomic" + "time" _ "github.com/duckdb/duckdb-go/v2" ) @@ -29,6 +32,9 @@ type Config struct { // DuckLake configuration DuckLake DuckLakeConfig + + // Graceful shutdown timeout (default: 30s) + ShutdownTimeout time.Duration } // DuckLakeConfig configures DuckLake metadata store and data path @@ -47,6 +53,7 @@ type Server struct { wg sync.WaitGroup closed bool closeMu sync.Mutex + activeConns int64 // atomic counter for active connections } func New(cfg Config) (*Server, error) { @@ -65,6 +72,11 @@ func New(cfg Config) (*Server, error) { cfg.RateLimit = DefaultRateLimitConfig() } + // Use default shutdown timeout if not specified + if cfg.ShutdownTimeout == 0 { + cfg.ShutdownTimeout = 30 * time.Second + } + s := &Server{ cfg: cfg, dbs: make(map[string]*sql.DB), @@ -114,20 +126,87 @@ func (s *Server) Close() error { s.closed = true s.closeMu.Unlock() + // Stop accepting new connections if s.listener != nil { s.listener.Close() } - s.wg.Wait() + // Check if there are active connections + activeConns := atomic.LoadInt64(&s.activeConns) + if activeConns > 0 { + log.Printf("Waiting for %d active connection(s) to finish...", activeConns) + } + // Wait for connections with timeout + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Println("All connections closed gracefully") + case <-time.After(s.cfg.ShutdownTimeout): + log.Printf("Shutdown timeout (%v) exceeded, force closing remaining connections", s.cfg.ShutdownTimeout) + } + + // Close all database connections s.dbsMu.Lock() defer s.dbsMu.Unlock() for _, db := range s.dbs { db.Close() } + log.Println("Shutdown complete") return nil } +// Shutdown performs a graceful shutdown with the given context +func (s *Server) Shutdown(ctx context.Context) error { + s.closeMu.Lock() + s.closed = true + s.closeMu.Unlock() + + // Stop accepting new connections + if s.listener != nil { + s.listener.Close() + } + + // Check if there are active connections + activeConns := atomic.LoadInt64(&s.activeConns) + if activeConns > 0 { + log.Printf("Waiting for %d active connection(s) to finish...", activeConns) + } + + // Wait for connections with context + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Println("All connections closed gracefully") + case <-ctx.Done(): + log.Printf("Shutdown context cancelled, force closing remaining connections") + } + + // Close all database connections + s.dbsMu.Lock() + defer s.dbsMu.Unlock() + for _, db := range s.dbs { + db.Close() + } + log.Println("Shutdown complete") + return nil +} + +// ActiveConnections returns the number of active connections +func (s 
*Server) ActiveConnections() int64 { + return atomic.LoadInt64(&s.activeConns) +} + func (s *Server) getOrCreateDB(username string) (*sql.DB, error) { s.dbsMu.RLock() db, ok := s.dbs[username] @@ -231,6 +310,10 @@ func (s *Server) attachDuckLake(db *sql.DB) error { func (s *Server) handleConnection(conn net.Conn) { remoteAddr := conn.RemoteAddr() + // Track active connections + atomic.AddInt64(&s.activeConns, 1) + defer atomic.AddInt64(&s.activeConns, -1) + // Check rate limiting before doing anything if msg := s.rateLimiter.CheckConnection(remoteAddr); msg != "" { // Send PostgreSQL error and close
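
For context on how the new shutdown pieces are meant to compose, here is a sketch of the signal wiring a `main.go` could use. The `ListenAndServe` name and the import path are placeholders; the server's actual run method is not shown in this patch, while `server.New`, `Config.ShutdownTimeout`, `Shutdown(ctx)`, and `ActiveConnections()` all appear in the `server/server.go` hunks above.

```go
package main

import (
	"context"
	"log"
	"os"
	"os/signal"
	"syscall"
	"time"

	"duckgres/server" // placeholder import path
)

func main() {
	srv, err := server.New(server.Config{
		// Used by Close(); Shutdown(ctx) takes a deadline via its context instead.
		ShutdownTimeout: 60 * time.Second,
	})
	if err != nil {
		log.Fatal(err)
	}

	// Placeholder for the server's actual accept-loop entry point.
	go func() {
		if err := srv.ListenAndServe(); err != nil {
			log.Printf("server stopped: %v", err)
		}
	}()

	// Block until SIGINT/SIGTERM, then drain in-flight queries.
	sigc := make(chan os.Signal, 1)
	signal.Notify(sigc, syscall.SIGINT, syscall.SIGTERM)
	<-sigc

	log.Printf("draining %d active connection(s)", srv.ActiveConnections())
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		log.Printf("shutdown error: %v", err)
	}
}
```
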