Skip to content

Commit

Permalink
Support $comment update's argument (#937)
Browse files Browse the repository at this point in the history
Closes #487.
  • Loading branch information
noisersup committed Jul 27, 2022
1 parent 9e5b193 commit 024cf92
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 10 deletions.
45 changes: 45 additions & 0 deletions integration/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,51 @@ func TestFindCommentQuery(t *testing.T) {
assert.Contains(t, databaseNames, name)
}

func TestUpdateCommentMethod(t *testing.T) {
t.Parallel()
ctx, collection := setup.Setup(t, shareddata.Scalars)
name := collection.Name()
databaseNames, err := collection.Database().Client().ListDatabaseNames(ctx, bson.D{})
require.NoError(t, err)

comment := "*/ 1; DROP SCHEMA " + name + " CASCADE -- "
filter := bson.D{{"_id", "string"}}
update := bson.D{{"$set", bson.D{{"v", "bar"}}}}

opts := options.Update().SetComment(comment)
res, err := collection.UpdateOne(ctx, filter, update, opts)
require.NoError(t, err)

expected := &mongo.UpdateResult{
MatchedCount: 1,
ModifiedCount: 1,
}

assert.Contains(t, databaseNames, name)
assert.Equal(t, expected, res)
}

func TestUpdateCommentQuery(t *testing.T) {
t.Parallel()
ctx, collection := setup.Setup(t, shareddata.Scalars)
name := collection.Name()
databaseNames, err := collection.Database().Client().ListDatabaseNames(ctx, bson.D{})
require.NoError(t, err)

comment := "*/ 1; DROP SCHEMA " + name + " CASCADE -- "

res, err := collection.UpdateOne(ctx, bson.M{"_id": "string", "$comment": comment}, bson.M{"$set": bson.M{"v": "bar"}})
require.NoError(t, err)

expected := &mongo.UpdateResult{
MatchedCount: 1,
ModifiedCount: 1,
}

assert.Contains(t, databaseNames, name)
assert.Equal(t, expected, res)
}

func TestCollectionName(t *testing.T) {
t.Parallel()

Expand Down
6 changes: 3 additions & 3 deletions internal/handlers/pg/msg_findandmodify.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.
return nil, err
}

_, err = h.update(ctx, params.sqlParam, upsert)
_, err = h.update(ctx, &params.sqlParam, upsert)
if err != nil {
return nil, err
}
Expand All @@ -159,7 +159,7 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.
must.NoError(upsert.Set("_id", must.NotFail(resDocs[0].Get("_id"))))
}

_, err = h.update(ctx, params.sqlParam, upsert)
_, err = h.update(ctx, &params.sqlParam, upsert)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -277,7 +277,7 @@ func (h *Handler) upsert(ctx context.Context, docs []*types.Document, params *up
}
}

_, err := h.update(ctx, params.sqlParam, upsert)
_, err := h.update(ctx, &params.sqlParam, upsert)
if err != nil {
return nil, false, err
}
Expand Down
19 changes: 15 additions & 4 deletions internal/handlers/pg/msg_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
if err := common.Unimplemented(document, "let"); err != nil {
return nil, err
}
common.Ignored(document, h.l, "ordered", "writeConcern", "bypassDocumentValidation", "comment")
common.Ignored(document, h.l, "ordered", "writeConcern", "bypassDocumentValidation")

var sp pgdb.SQLParam
if sp.DB, err = common.GetRequiredParam[string](document, "$db"); err != nil {
Expand Down Expand Up @@ -103,6 +103,17 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
if u, err = common.GetOptionalParam(update, "u", u); err != nil {
return nil, err
}

// get comment from options.Update().SetComment() method
if sp.Comment, err = common.GetOptionalParam(document, "comment", sp.Comment); err != nil {
return nil, err
}

// get comment from query, e.g. db.collection.UpdateOne({"_id":"string", "$comment: "test"},{$set:{"v":"foo""}})
if sp.Comment, err = common.GetOptionalParam(q, "$comment", sp.Comment); err != nil {
return nil, err
}

if u != nil {
if err = common.ValidateUpdateOperators(u); err != nil {
return nil, err
Expand Down Expand Up @@ -191,7 +202,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
continue
}

rowsChanged, err := h.update(ctx, sp, doc)
rowsChanged, err := h.update(ctx, &sp, doc)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,10 +231,10 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
}

// update updates documents by _id.
func (h *Handler) update(ctx context.Context, sp pgdb.SQLParam, doc *types.Document) (int64, error) {
func (h *Handler) update(ctx context.Context, sp *pgdb.SQLParam, doc *types.Document) (int64, error) {
id := must.NotFail(doc.Get("_id"))

rowsUpdated, err := h.pgPool.SetDocumentByID(ctx, sp.DB, sp.Collection, id, doc)
rowsUpdated, err := h.pgPool.SetDocumentByID(ctx, sp, id, doc)
if err != nil {
return 0, err
}
Expand Down
13 changes: 10 additions & 3 deletions internal/handlers/pg/pgdb/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,22 @@ func (pgPool *Pool) SchemaStats(ctx context.Context, schema, collection string)
}

// SetDocumentByID sets a document by its ID.
func (pgPool *Pool) SetDocumentByID(ctx context.Context, db, collection string, id any, doc *types.Document) (int64, error) {
func (pgPool *Pool) SetDocumentByID(ctx context.Context, sp *SQLParam, id any, doc *types.Document) (int64, error) {
var tag pgconn.CommandTag
err := pgPool.InTransaction(ctx, func(tx pgx.Tx) error {
table, err := getTableName(ctx, tx, db, collection)
table, err := getTableName(ctx, tx, sp.DB, sp.Collection)
if err != nil {
return err
}

sql := "UPDATE " + pgx.Identifier{db, table}.Sanitize() +
sql := "UPDATE "
if sp.Comment != "" {
sp.Comment = strings.ReplaceAll(sp.Comment, "/*", "/ *")
sp.Comment = strings.ReplaceAll(sp.Comment, "*/", "* /")

sql += `/* ` + sp.Comment + ` */ `
}
sql += pgx.Identifier{sp.DB, table}.Sanitize() +
" SET _jsonb = $1 WHERE _jsonb->'_id' = $2"

tag, err = tx.Exec(ctx, sql, must.NotFail(fjson.Marshal(doc)), must.NotFail(fjson.Marshal(id)))
Expand Down

0 comments on commit 024cf92

Please sign in to comment.