Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support comment and $comment update's arguments #937

Merged
merged 14 commits into from
Jul 27, 2022
45 changes: 45 additions & 0 deletions integration/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,51 @@ func TestFindCommentQuery(t *testing.T) {
assert.Contains(t, databaseNames, name)
}

func TestUpdateCommentMethod(t *testing.T) {
AlekSi marked this conversation as resolved.
Show resolved Hide resolved
t.Parallel()
ctx, collection := setup.Setup(t, shareddata.Scalars)
name := collection.Name()
databaseNames, err := collection.Database().Client().ListDatabaseNames(ctx, bson.D{})
require.NoError(t, err)

comment := "*/ 1; DROP SCHEMA " + name + " CASCADE -- "
filter := bson.D{{"_id", "string"}}
update := bson.D{{"$set", bson.D{{"v", "bar"}}}}

opts := options.Update().SetComment(comment)
res, err := collection.UpdateOne(ctx, filter, update, opts)
require.NoError(t, err)

expected := &mongo.UpdateResult{
MatchedCount: 1,
ModifiedCount: 1,
}

assert.Contains(t, databaseNames, name)
assert.Equal(t, expected, res)
}

func TestUpdateCommentQuery(t *testing.T) {
t.Parallel()
ctx, collection := setup.Setup(t, shareddata.Scalars)
name := collection.Name()
databaseNames, err := collection.Database().Client().ListDatabaseNames(ctx, bson.D{})
require.NoError(t, err)

comment := "*/ 1; DROP SCHEMA " + name + " CASCADE -- "

res, err := collection.UpdateOne(ctx, bson.M{"_id": "string", "$comment": comment}, bson.M{"$set": bson.M{"v": "bar"}})
require.NoError(t, err)

expected := &mongo.UpdateResult{
MatchedCount: 1,
ModifiedCount: 1,
}

assert.Contains(t, databaseNames, name)
assert.Equal(t, expected, res)
}

func TestCollectionName(t *testing.T) {
t.Parallel()

Expand Down
6 changes: 3 additions & 3 deletions internal/handlers/pg/msg_findandmodify.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.
return nil, err
}

_, err = h.update(ctx, params.sqlParam, upsert)
_, err = h.update(ctx, &params.sqlParam, upsert)
if err != nil {
return nil, err
}
Expand All @@ -159,7 +159,7 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.
must.NoError(upsert.Set("_id", must.NotFail(resDocs[0].Get("_id"))))
}

_, err = h.update(ctx, params.sqlParam, upsert)
_, err = h.update(ctx, &params.sqlParam, upsert)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -277,7 +277,7 @@ func (h *Handler) upsert(ctx context.Context, docs []*types.Document, params *up
}
}

_, err := h.update(ctx, params.sqlParam, upsert)
_, err := h.update(ctx, &params.sqlParam, upsert)
if err != nil {
return nil, false, err
}
Expand Down
19 changes: 15 additions & 4 deletions internal/handlers/pg/msg_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
if err := common.Unimplemented(document, "let"); err != nil {
return nil, err
}
common.Ignored(document, h.l, "ordered", "writeConcern", "bypassDocumentValidation", "comment")
common.Ignored(document, h.l, "ordered", "writeConcern", "bypassDocumentValidation")

var sp pgdb.SQLParam
if sp.DB, err = common.GetRequiredParam[string](document, "$db"); err != nil {
Expand Down Expand Up @@ -103,6 +103,17 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
if u, err = common.GetOptionalParam(update, "u", u); err != nil {
return nil, err
}

// get comment from options.Update().SetComment() method
if sp.Comment, err = common.GetOptionalParam(document, "comment", sp.Comment); err != nil {
return nil, err
}

// get comment from query, e.g. db.collection.UpdateOne({"_id":"string", "$comment: "test"},{$set:{"v":"foo""}})
if sp.Comment, err = common.GetOptionalParam(q, "$comment", sp.Comment); err != nil {
AlekSi marked this conversation as resolved.
Show resolved Hide resolved
return nil, err
}

if u != nil {
if err = common.ValidateUpdateOperators(u); err != nil {
return nil, err
Expand Down Expand Up @@ -191,7 +202,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
continue
}

rowsChanged, err := h.update(ctx, sp, doc)
rowsChanged, err := h.update(ctx, &sp, doc)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,10 +231,10 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
}

// update updates documents by _id.
func (h *Handler) update(ctx context.Context, sp pgdb.SQLParam, doc *types.Document) (int64, error) {
func (h *Handler) update(ctx context.Context, sp *pgdb.SQLParam, doc *types.Document) (int64, error) {
id := must.NotFail(doc.Get("_id"))

rowsUpdated, err := h.pgPool.SetDocumentByID(ctx, sp.DB, sp.Collection, id, doc)
rowsUpdated, err := h.pgPool.SetDocumentByID(ctx, sp, id, doc)
if err != nil {
return 0, err
}
Expand Down
13 changes: 10 additions & 3 deletions internal/handlers/pg/pgdb/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,22 @@ func (pgPool *Pool) SchemaStats(ctx context.Context, schema, collection string)
}

// SetDocumentByID sets a document by its ID.
func (pgPool *Pool) SetDocumentByID(ctx context.Context, db, collection string, id any, doc *types.Document) (int64, error) {
func (pgPool *Pool) SetDocumentByID(ctx context.Context, sp *SQLParam, id any, doc *types.Document) (int64, error) {
var tag pgconn.CommandTag
err := pgPool.InTransaction(ctx, func(tx pgx.Tx) error {
table, err := getTableName(ctx, tx, db, collection)
table, err := getTableName(ctx, tx, sp.DB, sp.Collection)
if err != nil {
return err
}

sql := "UPDATE " + pgx.Identifier{db, table}.Sanitize() +
sql := "UPDATE "
if sp.Comment != "" {
sp.Comment = strings.ReplaceAll(sp.Comment, "/*", "/ *")
sp.Comment = strings.ReplaceAll(sp.Comment, "*/", "* /")

sql += `/* ` + sp.Comment + ` */ `
}
sql += pgx.Identifier{sp.DB, table}.Sanitize() +
" SET _jsonb = $1 WHERE _jsonb->'_id' = $2"

tag, err = tx.Exec(ctx, sql, must.NotFail(fjson.Marshal(doc)), must.NotFail(fjson.Marshal(id)))
Expand Down