Skip to content

Commit

Permalink
Implement sort pushdown for PostgreSQL (#3504)
Browse files Browse the repository at this point in the history
  • Loading branch information
noisersup committed Oct 17, 2023
1 parent 3d3eb82 commit a7cc841
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 15 deletions.
9 changes: 9 additions & 0 deletions internal/backends/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,17 @@ func CollectionContract(c Collection) Collection {
}
}

// SortField consists of a field name and a sort order that are used in queries.
type SortField struct {
Key string
Descending bool
}

// QueryParams represents the parameters of Collection.Query method.
type QueryParams struct {
// TODO https://github.com/FerretDB/FerretDB/issues/3235
Filter *types.Document
Sort *SortField
OnlyRecordIDs bool // TODO https://github.com/FerretDB/FerretDB/issues/3490
Comment string // TODO https://github.com/FerretDB/FerretDB/issues/3573
}
Expand Down Expand Up @@ -199,13 +206,15 @@ func (cc *collectionContract) DeleteAll(ctx context.Context, params *DeleteAllPa
type ExplainParams struct {
// TODO https://github.com/FerretDB/FerretDB/issues/3235
Filter *types.Document
Sort *SortField
}

// ExplainResult represents the results of Collection.Explain method.
type ExplainResult struct {
QueryPlanner *types.Document
// TODO https://github.com/FerretDB/FerretDB/issues/3235
QueryPushdown bool
SortPushdown bool
}

// Explain return a backend-specific execution plan for the given query.
Expand Down
28 changes: 28 additions & 0 deletions internal/backends/postgresql/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ func (c *collection) Query(ctx context.Context, params *backends.QueryParams) (*

q += where

if params.Sort != nil {
var sort string
var sortArgs []any

sort, sortArgs, err = prepareOrderByClause(&placeholder, params.Sort.Key, params.Sort.Descending)
if err != nil {
return nil, lazyerrors.Error(err)
}

q += sort
args = append(args, sortArgs...)
}

rows, err := p.Query(ctx, q, args...)
if err != nil {
return nil, lazyerrors.Error(err)
Expand Down Expand Up @@ -296,6 +309,21 @@ func (c *collection) Explain(ctx context.Context, params *backends.ExplainParams

res.QueryPushdown = where != ""

if params.Sort != nil {
var sort string
var sortArgs []any

sort, sortArgs, err = prepareOrderByClause(&placeholder, params.Sort.Key, params.Sort.Descending)
if err != nil {
return nil, lazyerrors.Error(err)
}

q += sort
args = append(args, sortArgs...)

res.SortPushdown = sort != ""
}

q += where

var b []byte
Expand Down
58 changes: 44 additions & 14 deletions internal/backends/postgresql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,23 +115,37 @@ func prepareWhereClause(p *metadata.Placeholder, sqlFilters *types.Document) (st
sql := `NOT ( ` +
// does document contain the key,
// it is necessary, as NOT won't work correctly if the key does not exist.
`_jsonb ? %[1]s AND ` +
`%[1]s ? %[2]s AND ` +
// does the value under the key is equal to filter value
`_jsonb->%[1]s @> %[2]s AND ` +
`%[1]s->%[2]s @> %[3]s AND ` +
// does the value type is equal to the filter's one
`_jsonb->'$s'->'p'->%[1]s->'t' = '"%[3]s"' )`
`%[1]s->'$s'->'p'->%[2]s->'t' = '"%[4]s"' )`

switch v := v.(type) {
case *types.Document, *types.Array, types.Binary,
types.NullType, types.Regex, types.Timestamp:
// type not supported for pushdown

case float64, bool, int32, int64:
filters = append(filters, fmt.Sprintf(sql, p.Next(), p.Next(), sjson.GetTypeOfValue(v)))
filters = append(filters, fmt.Sprintf(
sql,
metadata.DefaultColumn,
p.Next(),
p.Next(),
sjson.GetTypeOfValue(v),
))

args = append(args, rootKey, v)

case string, types.ObjectID, time.Time:
filters = append(filters, fmt.Sprintf(sql, p.Next(), p.Next(), sjson.GetTypeOfValue(v)))
filters = append(filters, fmt.Sprintf(
sql,
metadata.DefaultColumn,
p.Next(),
p.Next(),
sjson.GetTypeOfValue(v),
))

args = append(args, rootKey, string(must.NotFail(sjson.MarshalSingleValue(v))))

default:
Expand Down Expand Up @@ -167,11 +181,27 @@ func prepareWhereClause(p *metadata.Placeholder, sqlFilters *types.Document) (st
return filter, args, nil
}

// prepareOrderByClause adds ORDER BY clause with given sort document and returns the query and arguments.
func prepareOrderByClause(p *metadata.Placeholder, key string, descending bool) (string, []any, error) {
// Skip sorting dot notation
if strings.ContainsRune(key, '.') {
return "", nil, nil
}

sqlOrder := "ASC"

if descending {
sqlOrder = "DESC"
}

return fmt.Sprintf(" ORDER BY %s->%s %s", metadata.DefaultColumn, p.Next(), sqlOrder), []any{key}, nil
}

// filterEqual returns the proper SQL filter with arguments that filters documents
// where the value under k is equal to v.
func filterEqual(p *metadata.Placeholder, k string, v any) (filter string, args []any) {
// Select if value under the key is equal to provided value.
sql := `_jsonb->%[1]s @> %[2]s`
sql := `%[1]s->%[2]s @> %[3]s`

switch v := v.(type) {
case *types.Document, *types.Array, types.Binary,
Expand All @@ -182,27 +212,27 @@ func filterEqual(p *metadata.Placeholder, k string, v any) (filter string, args
// If value is not safe double, fetch all numbers out of safe range.
switch {
case v > types.MaxSafeDouble:
sql = `_jsonb->%[1]s > %[2]s`
sql = `%[1]s->%[2]s > %[3]s`
v = types.MaxSafeDouble

case v < -types.MaxSafeDouble:
sql = `_jsonb->%[1]s < %[2]s`
sql = `%[1]s->%[2]s < %[3]s`
v = -types.MaxSafeDouble
default:
// don't change the default eq query
}

filter = fmt.Sprintf(sql, p.Next(), p.Next())
filter = fmt.Sprintf(sql, metadata.DefaultColumn, p.Next(), p.Next())
args = append(args, k, v)

case string, types.ObjectID, time.Time:
// don't change the default eq query
filter = fmt.Sprintf(sql, p.Next(), p.Next())
filter = fmt.Sprintf(sql, metadata.DefaultColumn, p.Next(), p.Next())
args = append(args, k, string(must.NotFail(sjson.MarshalSingleValue(v))))

case bool, int32:
// don't change the default eq query
filter = fmt.Sprintf(sql, p.Next(), p.Next())
filter = fmt.Sprintf(sql, metadata.DefaultColumn, p.Next(), p.Next())
args = append(args, k, v)

case int64:
Expand All @@ -211,17 +241,17 @@ func filterEqual(p *metadata.Placeholder, k string, v any) (filter string, args
// If value cannot be safe double, fetch all numbers out of the safe range.
switch {
case v > maxSafeDouble:
sql = `_jsonb->%[1]s > %[2]s`
sql = `%[1]s->%[2]s > %[3]s`
v = maxSafeDouble

case v < -maxSafeDouble:
sql = `_jsonb->%[1]s < %[2]s`
sql = `%[1]s->%[2]s < %[3]s`
v = -maxSafeDouble
default:
// don't change the default eq query
}

filter = fmt.Sprintf(sql, p.Next(), p.Next())
filter = fmt.Sprintf(sql, metadata.DefaultColumn, p.Next(), p.Next())
args = append(args, k, v)

default:
Expand Down
20 changes: 19 additions & 1 deletion internal/handlers/sqlite/msg_explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,24 @@ func (h *Handler) MsgExplain(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
qp.Filter = params.Filter
}

// Skip sorting if there are more than one sort parameters
if h.EnableSortPushdown && params.Sort.Len() == 1 {
var order types.SortType

k := params.Sort.Keys()[0]
v := params.Sort.Values()[0]

order, err = common.GetSortType(k, v)
if err != nil {
return nil, err
}

qp.Sort = &backends.SortField{
Key: k,
Descending: order == types.Descending,
}
}

res, err := coll.Explain(ctx, &qp)
if err != nil {
return nil, lazyerrors.Error(err)
Expand All @@ -99,7 +117,7 @@ func (h *Handler) MsgExplain(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
// our extensions
// TODO https://github.com/FerretDB/FerretDB/issues/3235
"pushdown", res.QueryPushdown,
"sortingPushdown", false,
"sortingPushdown", res.SortPushdown,
"limitPushdown", false,

"ok", float64(1),
Expand Down
18 changes: 18 additions & 0 deletions internal/handlers/sqlite/msg_find.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,24 @@ func (h *Handler) MsgFind(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, er
qp.Filter = params.Filter
}

// Skip sorting if there are more than one sort parameters
if h.EnableSortPushdown && params.Sort.Len() == 1 {
var order types.SortType

k := params.Sort.Keys()[0]
v := params.Sort.Values()[0]

order, err = common.GetSortType(k, v)
if err != nil {
return nil, err
}

qp.Sort = &backends.SortField{
Key: k,
Descending: order == types.Descending,
}
}

cancel := func() {}
if params.MaxTimeMS != 0 {
// It is not clear if maxTimeMS affects only find, or both find and getMore (as the current code does).
Expand Down

0 comments on commit a7cc841

Please sign in to comment.