Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/pgbouncer/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ func sqlAuthenticationQuery(sqlFunctionName string) string {
// No replicators.
`NOT pg_authid.rolreplication`,
// Not the PgBouncer role itself.
`pg_authid.rolname <> ` + util.SQLQuoteLiteral(postgresqlUser),
`pg_authid.rolname <> ` + postgres.QuoteLiteral(postgresqlUser),
// Those without a password expiration or an expiration in the future.
`(pg_authid.rolvaliduntil IS NULL OR pg_authid.rolvaliduntil >= CURRENT_TIMESTAMP)`,
}, "\n AND ")

return strings.TrimSpace(`
CREATE OR REPLACE FUNCTION ` + sqlFunctionName + `(username TEXT)
RETURNS TABLE(username TEXT, password TEXT) AS ` + util.SQLQuoteLiteral(`
RETURNS TABLE(username TEXT, password TEXT) AS ` + postgres.QuoteLiteral(`
SELECT rolname::TEXT, rolpassword::TEXT
FROM pg_catalog.pg_authid
WHERE pg_authid.rolname = $1
Expand Down
8 changes: 4 additions & 4 deletions internal/pgbouncer/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ import (
func TestSQLAuthenticationQuery(t *testing.T) {
assert.Equal(t, sqlAuthenticationQuery("some.fn_name"),
`CREATE OR REPLACE FUNCTION some.fn_name(username TEXT)
RETURNS TABLE(username TEXT, password TEXT) AS '
RETURNS TABLE(username TEXT, password TEXT) AS E'
SELECT rolname::TEXT, rolpassword::TEXT
FROM pg_catalog.pg_authid
WHERE pg_authid.rolname = $1
AND pg_authid.rolcanlogin
AND NOT pg_authid.rolsuper
AND NOT pg_authid.rolreplication
AND pg_authid.rolname <> ''_crunchypgbouncer''
AND pg_authid.rolname <> E''_crunchypgbouncer''
AND (pg_authid.rolvaliduntil IS NULL OR pg_authid.rolvaliduntil >= CURRENT_TIMESTAMP)'
LANGUAGE SQL STABLE SECURITY DEFINER;`)
}
Expand Down Expand Up @@ -150,14 +150,14 @@ REVOKE ALL PRIVILEGES
GRANT USAGE
ON SCHEMA :"namespace" TO :"username";
CREATE OR REPLACE FUNCTION :"namespace".get_auth(username TEXT)
RETURNS TABLE(username TEXT, password TEXT) AS '
RETURNS TABLE(username TEXT, password TEXT) AS E'
SELECT rolname::TEXT, rolpassword::TEXT
FROM pg_catalog.pg_authid
WHERE pg_authid.rolname = $1
AND pg_authid.rolcanlogin
AND NOT pg_authid.rolsuper
AND NOT pg_authid.rolreplication
AND pg_authid.rolname <> ''_crunchypgbouncer''
AND pg_authid.rolname <> E''_crunchypgbouncer''
AND (pg_authid.rolvaliduntil IS NULL OR pg_authid.rolvaliduntil >= CURRENT_TIMESTAMP)'
LANGUAGE SQL STABLE SECURITY DEFINER;
REVOKE ALL PRIVILEGES
Expand Down
22 changes: 22 additions & 0 deletions internal/postgres/sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2021 - 2024 Crunchy Data Solutions, Inc.
//
// SPDX-License-Identifier: Apache-2.0

package postgres

import "strings"

// escapeLiteral is called by QuoteLiteral to add backslashes before special
// characters of the "escape" string syntax. Double quote marks to escape them
// regardless of the "backslash_quote" parameter.
var escapeLiteral = strings.NewReplacer(`'`, `''`, `\`, `\\`).Replace

// QuoteLiteral escapes v so it can be safely used as a literal (or constant)
// in an SQL statement.
func QuoteLiteral(v string) string {
// Use the "escape" syntax to ensure that backslashes behave consistently regardless
// of the "standard_conforming_strings" parameter. Include a space before so
// the "E" cannot change the meaning of an adjacent SQL keyword or identifier.
// - https://www.postgresql.org/docs/current/sql-syntax-lexical.html
return ` E'` + escapeLiteral(v) + `'`
}
16 changes: 16 additions & 0 deletions internal/postgres/sql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright 2021 - 2024 Crunchy Data Solutions, Inc.
//
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"testing"

"gotest.tools/v3/assert"
)

func TestQuoteLiteral(t *testing.T) {
assert.Equal(t, QuoteLiteral(``), ` E''`)
assert.Equal(t, QuoteLiteral(`ab"cd\ef'gh`), ` E'ab"cd\\ef''gh'`)
}
62 changes: 0 additions & 62 deletions internal/util/util.go

This file was deleted.