Skip to content

Commit

Permalink
sql: switch PlaceholderTypes to a slice
Browse files Browse the repository at this point in the history
We switch `PlaceholderTypes` to a slice instead of a map.

We also fix the handling of cases where some placeholders are unused
(e.g. `SELECT $2:::int`) which now error out (before they would crash
during execution). Note that PG also errors out in this case.

Fixes cockroachdb#30086.

Release note (bug fix): Preparing queries with missing placeholders
(e.g. `SELECT $2::int`) now results in an error.
  • Loading branch information
RaduBerinde committed Jan 15, 2019
1 parent 7548754 commit 651f11b
Show file tree
Hide file tree
Showing 24 changed files with 380 additions and 255 deletions.
8 changes: 5 additions & 3 deletions pkg/ccl/logictestccl/testdata/logic_test/partitioning
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,11 @@ CREATE TABLE t (a INT, b INT, c INT, PRIMARY KEY (a, b)) PARTITION BY RANGE (a,
PARTITION p1 VALUES FROM (0, 0) TO (0, (DEFAULT))
)

statement error unimplemented: placeholders are not supported in PARTITION BY
CREATE TABLE t (a INT, b INT, c INT, PRIMARY KEY (a, b)) PARTITION BY LIST (a) (
PARTITION p1 VALUES IN ($1)
# TODO(radu): we are not properly walking the expressions when
# walking CREATE TABLE.
statement error pq: could not determine data type of placeholder \$1
PREPARE a AS CREATE TABLE t (a INT, b INT, c INT, PRIMARY KEY (a, b)) PARTITION BY LIST (a) (
PARTITION p1 VALUES IN ($1:::int)
)

statement error syntax error
Expand Down
15 changes: 11 additions & 4 deletions pkg/sql/conn_executor_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,16 @@ func (ex *connExecutor) execStmtInOpenState(
)
return makeErrEvent(err)
}
typeHints := make(tree.PlaceholderTypes, len(s.Types))
for i, t := range s.Types {
typeHints[types.PlaceholderIdx(i)] = coltypes.CastTargetToDatumType(t)
var typeHints tree.PlaceholderTypes
if len(s.Types) > 0 {
if len(s.Types) > stmt.NumPlaceholders {
err := pgerror.NewErrorf(pgerror.CodeSyntaxError, "too many types provided")
return makeErrEvent(err)
}
typeHints = make(tree.PlaceholderTypes, stmt.NumPlaceholders)
for i, t := range s.Types {
typeHints[i] = coltypes.CastTargetToDatumType(t)
}
}
if _, err := ex.addPreparedStmt(
ctx, name,
Expand Down Expand Up @@ -391,7 +398,7 @@ func (ex *connExecutor) execStmtInOpenState(
}
}

p.semaCtx.Placeholders.Assign(pinfo)
p.semaCtx.Placeholders.Assign(pinfo, stmt.NumPlaceholders)
p.extendedEvalCtx.Placeholders = &p.semaCtx.Placeholders
ex.phaseTimes[plannerStartExecStmt] = timeutil.Now()
p.stmt = &stmt
Expand Down
51 changes: 19 additions & 32 deletions pkg/sql/conn_executor_prepare.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (ex *connExecutor) execPrepare(
return eventNonRetriableErr{IsCommit: fsm.False}, eventNonRetriableErrPayload{err: err}
}

// The anonymous statement can be overwritter.
// The anonymous statement can be overwritten.
if parseCmd.Name != "" {
if _, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[parseCmd.Name]; ok {
err := pgerror.NewErrorf(
Expand All @@ -61,44 +61,27 @@ func (ex *connExecutor) execPrepare(
}

// Convert the inferred SQL types back to an array of pgwire Oids.
inTypes := make([]oid.Oid, 0, len(ps.Types))
if len(ps.TypeHints) > pgwirebase.MaxPreparedStatementArgs {
return retErr(
pgwirebase.NewProtocolViolationErrorf(
"more than %d arguments to prepared statement: %d",
pgwirebase.MaxPreparedStatementArgs, len(ps.TypeHints)))
}
for k := range ps.Types {
// Placeholder names are 1-indexed; the arrays in the protocol are
// 0-indexed.
i := int(k)
// Grow inTypes to be at least as large as i. Prepopulate all
// slots with the hints provided, if any.
for j := len(inTypes); j <= i; j++ {
inTypes = append(inTypes, 0)
if j < len(parseCmd.RawTypeHints) {
inTypes[j] = parseCmd.RawTypeHints[j]
}
}
inferredTypes := make([]oid.Oid, len(ps.Types))
copy(inferredTypes, parseCmd.RawTypeHints)

for i := range ps.Types {
// OID to Datum is not a 1-1 mapping (for example, int4 and int8
// both map to TypeInt), so we need to maintain the types sent by
// the client.
if inTypes[i] != 0 {
continue
}
t, _ := ps.ValueType(k)
inTypes[i] = t.Oid()
}
for i, t := range inTypes {
if t == 0 {
return retErr(pgerror.NewErrorf(
pgerror.CodeIndeterminateDatatypeError,
"could not determine data type of placeholder %s", types.PlaceholderIdx(i)))
if inferredTypes[i] == 0 {
t, _ := ps.ValueType(types.PlaceholderIdx(i))
inferredTypes[i] = t.Oid()
}
}
// Remember the inferred placeholder types so they can be reported on
// Describe.
ps.InTypes = inTypes
ps.InferredTypes = inferredTypes
return nil, nil
}

Expand Down Expand Up @@ -139,7 +122,7 @@ func (ex *connExecutor) prepare(
ctx context.Context, stmt Statement, placeholderHints tree.PlaceholderTypes,
) (*PreparedStatement, error) {
if placeholderHints == nil {
placeholderHints = make(tree.PlaceholderTypes)
placeholderHints = make(tree.PlaceholderTypes, stmt.NumPlaceholders)
}

prepared := &PreparedStatement{
Expand Down Expand Up @@ -204,7 +187,7 @@ func (ex *connExecutor) populatePrepared(
p *planner,
) (planFlags, error) {
prepared := stmt.Prepared
p.semaCtx.Placeholders.Reset(placeholderHints)
p.semaCtx.Placeholders.Init(stmt.NumPlaceholders, placeholderHints)
p.extendedEvalCtx.PrepareOnly = true
p.extendedEvalCtx.ActiveMemAcc = &prepared.memAcc
// constantMemAcc accounts for all constant folded values that are computed
Expand Down Expand Up @@ -271,6 +254,10 @@ func (ex *connExecutor) populatePrepared(
return 0, err
}
}
// Verify that all placeholder types have been set.
if err := p.semaCtx.Placeholders.Types.AssertAllSet(); err != nil {
return 0, err
}
prepared.Types = p.semaCtx.Placeholders.Types
return flags, nil
}
Expand Down Expand Up @@ -302,7 +289,7 @@ func (ex *connExecutor) execBind(
"unknown prepared statement %q", bindCmd.PreparedStatementName))
}

numQArgs := uint16(len(ps.InTypes))
numQArgs := uint16(len(ps.InferredTypes))

// Decode the arguments, except for internal queries for which we just verify
// that the arguments match what's expected.
Expand All @@ -314,7 +301,7 @@ func (ex *connExecutor) execBind(
"expected %d arguments, got %d", numQArgs, len(bindCmd.internalArgs)))
}
for i, datum := range bindCmd.internalArgs {
t := ps.InTypes[i]
t := ps.InferredTypes[i]
if oid := datum.ResolvedType().Oid(); datum != tree.DNull && oid != t {
return retErr(
pgwirebase.NewProtocolViolationErrorf(
Expand Down Expand Up @@ -351,7 +338,7 @@ func (ex *connExecutor) execBind(

for i, arg := range bindCmd.Args {
k := types.PlaceholderIdx(i)
t := ps.InTypes[i]
t := ps.InferredTypes[i]
if arg == nil {
// nil indicates a NULL argument value.
qargs[k] = tree.DNull
Expand Down Expand Up @@ -487,7 +474,7 @@ func (ex *connExecutor) execDescribe(
"unknown prepared statement %q", descCmd.Name))
}

res.SetInTypes(ps.InTypes)
res.SetInferredTypes(ps.InferredTypes)

if stmtHasNoData(ps.AST) {
res.SetNoDataRowDescription()
Expand Down
8 changes: 4 additions & 4 deletions pkg/sql/conn_io.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,8 +718,8 @@ type RestrictedCommandResult interface {
type DescribeResult interface {
ResultBase

// SetInTypes tells the client about the inferred placeholder types.
SetInTypes([]oid.Oid)
// SetInferredTypes tells the client about the inferred placeholder types.
SetInferredTypes([]oid.Oid)
// SetNoDataDescription is used to tell the client that the prepared statement
// or portal produces no rows.
SetNoDataRowDescription()
Expand Down Expand Up @@ -933,8 +933,8 @@ func (r *bufferedCommandResult) Discard() {
}
}

// SetInTypes is part of the DescribeResult interface.
func (r *bufferedCommandResult) SetInTypes([]oid.Oid) {}
// SetInferredTypes is part of the DescribeResult interface.
func (r *bufferedCommandResult) SetInferredTypes([]oid.Oid) {}

// SetNoDataRowDescription is part of the DescribeResult interface.
func (r *bufferedCommandResult) SetNoDataRowDescription() {}
Expand Down
3 changes: 0 additions & 3 deletions pkg/sql/logictest/testdata/logic_test/alter_table
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,6 @@ INSERT INTO add_default (a) VALUES (3)
statement error could not parse "foo" as type int
ALTER TABLE add_default ALTER COLUMN b SET DEFAULT 'foo'

statement error variable sub-expressions are not allowed in DEFAULT
ALTER TABLE add_default ALTER COLUMN b SET DEFAULT $1

statement error variable sub-expressions are not allowed in DEFAULT
ALTER TABLE add_default ALTER COLUMN b SET DEFAULT c

Expand Down
3 changes: 0 additions & 3 deletions pkg/sql/logictest/testdata/logic_test/check_constraints
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,6 @@ SELECT * FROM t5
1 10 9
2 11 9

statement error variable sub-expressions are not allowed in CHECK
CREATE TABLE t6 (x INT CHECK (x = $1))

statement error variable sub-expressions are not allowed in CHECK
CREATE TABLE t6 (x INT CHECK (x = (SELECT 1)))

Expand Down
3 changes: 0 additions & 3 deletions pkg/sql/logictest/testdata/logic_test/default
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
statement error expected DEFAULT expression to have type int, but 'false' has type bool
CREATE TABLE t (a INT PRIMARY KEY DEFAULT false)

statement error variable sub-expressions are not allowed in DEFAULT
CREATE TABLE t (a INT PRIMARY KEY DEFAULT $1)

statement error variable sub-expressions are not allowed in DEFAULT
CREATE TABLE t (a INT PRIMARY KEY DEFAULT (SELECT 1))

Expand Down
11 changes: 10 additions & 1 deletion pkg/sql/logictest/testdata/logic_test/prepare
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,13 @@ PREPARE a as ()
statement error could not determine data type of placeholder \$1
PREPARE a AS SELECT $1

statement
statement error could not determine data type of placeholder \$1
PREPARE a AS SELECT $2:::int

statement error could not determine data type of placeholder \$2
PREPARE a AS SELECT $1:::int, $3:::int

statement ok
PREPARE a AS SELECT $1:::int + $2

query I
Expand Down Expand Up @@ -812,3 +818,6 @@ a d

statement ok
ROLLBACK TRANSACTION

statement error no value provided for placeholder: \$1
SELECT $1:::int
4 changes: 3 additions & 1 deletion pkg/sql/opt/bench/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ func (h *harness) prepareUsingAPI(tb testing.TB) {
}
}

h.semaCtx.Placeholders.Init(len(h.query.args), nil /* typeHints */)
if h.query.prepare {
// Prepare the query by normalizing it (if it has placeholders) or exploring
// it (if it doesn't have placeholders), and cache the resulting memo so that
Expand All @@ -502,6 +503,7 @@ func (h *harness) prepareUsingAPI(tb testing.TB) {
}

// Construct placeholder values.
h.semaCtx.Placeholders.Values = make(tree.QueryArguments, len(h.query.args))
for i, arg := range h.query.args {
var parg tree.Expr
parg, err := parser.ParseExpr(fmt.Sprintf("%v", arg))
Expand All @@ -523,7 +525,7 @@ func (h *harness) prepareUsingAPI(tb testing.TB) {
tb.Fatalf("%v", err)
}

h.semaCtx.Placeholders.Values[id] = texpr
h.semaCtx.Placeholders.Values[i] = texpr
}
h.evalCtx.Placeholders = &h.semaCtx.Placeholders
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/opt/memo/memo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func TestMemoInit(t *testing.T) {

ctx := context.Background()
semaCtx := tree.MakeSemaContext(false /* privileged */)
semaCtx.Placeholders.Init(stmt.NumPlaceholders, nil /* typeHints */)
evalCtx := tree.MakeTestingEvalContext(cluster.MakeTestingClusterSettings())

var o xform.Optimizer
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/testutils/opt_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ func (ot *OptTester) buildExpr(factory *norm.Factory) error {
if err != nil {
return err
}

ot.semaCtx.Placeholders.Init(stmt.NumPlaceholders, nil /* typeHints */)
b := optbuilder.New(ot.ctx, &ot.semaCtx, &ot.evalCtx, ot.catalog, factory, stmt.AST)
b.AllowUnsupportedExpr = ot.Flags.AllowUnsupportedExpr
if ot.Flags.FullyQualifyNames {
Expand Down
4 changes: 3 additions & 1 deletion pkg/sql/opt/xform/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func TestDetachMemo(t *testing.T) {

ctx := context.Background()
semaCtx := tree.MakeSemaContext(false /* privileged */)
semaCtx.Placeholders.Init(stmt.NumPlaceholders, nil /* typeHints */)
evalCtx := tree.MakeTestingEvalContext(cluster.MakeTestingClusterSettings())

var o xform.Optimizer
Expand All @@ -62,14 +63,15 @@ func TestDetachMemo(t *testing.T) {
t.Error("memo expression should be reinitialized by DetachMemo")
}

semaCtx.Placeholders.Clear()
semaCtx.Placeholders.Init(1 /* numPlaceholders */, nil /* typeHints */)
o.Init(&evalCtx)

stmt2, err := parser.ParseOne("SELECT a=$1 FROM abc")
if err != nil {
t.Fatal(err)
}

semaCtx.Placeholders.Init(stmt2.NumPlaceholders, nil /* typeHints */)
err = optbuilder.New(ctx, &semaCtx, &evalCtx, catalog, o.Factory(), stmt2.AST).Build()
if err != nil {
t.Fatal(err)
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/pgwire/command_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ func (r *commandResult) SetColumns(ctx context.Context, cols sqlbase.ResultColum
}
}

// SetInTypes is part of the DescribeResult interface.
func (r *commandResult) SetInTypes(types []oid.Oid) {
// SetInferredTypes is part of the DescribeResult interface.
func (r *commandResult) SetInferredTypes(types []oid.Oid) {
r.conn.writerState.fi.registerCmd(r.pos)
r.conn.bufferParamDesc(types)
}
Expand Down
50 changes: 32 additions & 18 deletions pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,34 +540,48 @@ func (c *conn) handleParse(
}
inTypeHints[i] = oid.Oid(typ)
}
// Prepare the mapping of SQL placeholder names to types. Pre-populate it with
// the type hints received from the client, if any.
sqlTypeHints := make(tree.PlaceholderTypes)
for i, t := range inTypeHints {
if t == 0 {
continue
}
v, ok := types.OidToType[t]
if !ok {
err := pgwirebase.NewProtocolViolationErrorf("unknown oid type: %v", t)
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
sqlTypeHints[types.PlaceholderIdx(i)] = v
}

startParse := timeutil.Now()
var stmt parser.Statement
stmts, err := c.parser.ParseWithInt(query, ch.GetDefaultIntSize())
if err != nil {
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
if len(stmts) > 1 {
err = pgerror.NewWrongNumberOfPreparedStatements(len(stmts))
} else if len(stmts) == 1 {
err := pgerror.NewWrongNumberOfPreparedStatements(len(stmts))
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
var stmt parser.Statement
if len(stmts) == 1 {
stmt = stmts[0]
}
// len(stmts) == 0 results in a nil (empty) statement.

if err != nil {
if len(inTypeHints) > stmt.NumPlaceholders {
err := pgwirebase.NewProtocolViolationErrorf(
"received too many type hints: %d vs %d placeholders in query",
len(inTypeHints), stmt.NumPlaceholders,
)
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}

var sqlTypeHints tree.PlaceholderTypes
if len(inTypeHints) > 0 {
// Prepare the mapping of SQL placeholder names to types. Pre-populate it with
// the type hints received from the client, if any.
sqlTypeHints = make(tree.PlaceholderTypes, stmt.NumPlaceholders)
for i, t := range inTypeHints {
if t == 0 {
continue
}
v, ok := types.OidToType[t]
if !ok {
err := pgwirebase.NewProtocolViolationErrorf("unknown oid type: %v", t)
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
}
sqlTypeHints[i] = v
}
}

endParse := timeutil.Now()

if _, ok := stmt.AST.(*tree.CopyFrom); ok {
Expand Down
Loading

0 comments on commit 651f11b

Please sign in to comment.