Skip to content

Commit

Permalink
Merge pull request #18 from 0x-buidl/fix/transaction_write_operations
Browse files Browse the repository at this point in the history
fix: acknowledge write operations coming from transactions
  • Loading branch information
lxnre-codes committed Sep 22, 2023
2 parents 8acad27 + 868e80e commit ca9733d
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 20 deletions.
4 changes: 2 additions & 2 deletions document.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (doc *Document[T, P]) Save(ctx context.Context) error {
return nil, err
}

_, err := withTransaction(ctx, doc.Collection(), callback)
_, err := withAtomicity(ctx, doc.Collection(), callback)
return err
}

Expand All @@ -168,7 +168,7 @@ func (doc *Document[T, P]) Delete(ctx context.Context) error {
return nil, err
}

_, err = withTransaction(ctx, doc.Collection(), callback)
_, err = withAtomicity(ctx, doc.Collection(), callback)
if err != nil {
return err
}
Expand Down
60 changes: 42 additions & 18 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
)

type sessFn func(sessCtx mongo.SessionContext) (interface{}, error)
type SessionFunc func(sessCtx mongo.SessionContext) (interface{}, error)

type SessionLike interface {
*mongo.Database | *mongo.Collection | *mongo.SessionContext | *mongo.Client | *mongo.Session
}

type Model[T Schema, P IDefaultSchema] struct {
collection *mongo.Collection
Expand Down Expand Up @@ -91,7 +95,7 @@ func (model *Model[T, P]) CreateOne(ctx context.Context, doc T, opts ...*mopt.In
return nil, err
}

_, err := withTransaction(ctx, model.collection, callback)
_, err := withAtomicity(ctx, model.collection, callback)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -123,7 +127,7 @@ func (model *Model[T, P]) CreateMany(ctx context.Context, docs []T, opts ...*mop
return newDocs, err
}

newDocs, err := withTransaction(ctx, model.collection, callback)
newDocs, err := withAtomicity(ctx, model.collection, callback)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -153,7 +157,7 @@ func (model *Model[T, P]) DeleteOne(ctx context.Context, query bson.M, opts ...*
return res, err
}

res, err := withTransaction(ctx, model.collection, callback)
res, err := withAtomicity(ctx, model.collection, callback)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -183,7 +187,7 @@ func (model *Model[T, P]) DeleteMany(ctx context.Context, query bson.M, opts ...
return res, err
}

res, err := withTransaction(ctx, model.collection, callback)
res, err := withAtomicity(ctx, model.collection, callback)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -375,7 +379,7 @@ func (model *Model[T, P]) UpdateOne(ctx context.Context, query bson.M, update bs
err = runAfterUpdateHooks(sessCtx, ds, newHookArg[T](res, UpdateOne))
return res, err
}
res, err := withTransaction(ctx, model.collection, callback)
res, err := withAtomicity(ctx, model.collection, callback)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -411,13 +415,44 @@ func (model *Model[T, P]) UpdateMany(ctx context.Context, query bson.M, update b
return res, err
}

res, err := withTransaction(ctx, model.collection, callback)
res, err := withAtomicity(ctx, model.collection, callback)
if err != nil {
return nil, err
}
return res.(*mongo.UpdateResult), nil
}

// WithTransaction executes the callback function in a transaction.
// When a transaction is started with [mongo.SessionContext] options are ignored because the session is already created.
func WithTransaction[T SessionLike](ctx context.Context, sess T, fn SessionFunc, opts ...*options.TransactionOptions) (any, error) {
var session mongo.Session
var err error
switch sess := any(sess).(type) {
case *mongo.SessionContext:
return fn(*sess)
case *mongo.Session:
return (*sess).WithTransaction(ctx, fn, opts...)
case *mongo.Client:
session, err = sess.StartSession()
case *mongo.Database:
session, err = sess.Client().StartSession()
case *mongo.Collection:
session, err = sess.Database().Client().StartSession()
}
if err != nil {
return nil, err
}
defer session.EndSession(ctx)
return session.WithTransaction(ctx, fn, opts...)
}

func withAtomicity(ctx context.Context, coll *mongo.Collection, callback SessionFunc) (interface{}, error) {
if ctx, ok := ctx.(mongo.SessionContext); ok {
return callback(ctx)
}
return WithTransaction(ctx, coll, callback)
}

// func (model *Model[T, P]) CountDocuments(ctx context.Context,
// query bson.M, opts ...*options.CountOptions,
// ) (int64, error) {
Expand Down Expand Up @@ -472,17 +507,6 @@ func findWithPopulate[U int.UnionFindOpts, T Schema, P IDefaultSchema](ctx conte
return docs, nil
}

func withTransaction(ctx context.Context, coll *mongo.Collection, fn sessFn, opts ...*options.TransactionOptions) (interface{}, error) {
session, err := coll.Database().Client().StartSession()
if err != nil {
return nil, err
}
defer session.EndSession(ctx)

res, err := session.WithTransaction(ctx, fn, opts...)
return res, err
}

func getObjectId(id any) (*primitive.ObjectID, error) {
var oid primitive.ObjectID
switch id := id.(type) {
Expand Down
90 changes: 90 additions & 0 deletions model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mgs_test

import (
"context"
"errors"
"fmt"
"math/rand"
"testing"
Expand Down Expand Up @@ -287,6 +288,95 @@ func TestModel_Populate(t *testing.T) {
})
}

func TestWithTransaction(t *testing.T) {
ctx := context.Background()
db, cleanup := getDb(ctx)
defer cleanup(ctx)

bookModel := mgs.NewModel[Book, *mgs.DefaultSchema](db.Collection("books"))
genBooks := generateBooks(ctx, db)

t.Run("Should run transaction with mongo.Client", func(t *testing.T) {
callback := func(sessCtx mongo.SessionContext) (interface{}, error) {
err := genBooks[0].Delete(sessCtx)
return nil, err
}
_, err := mgs.WithTransaction(ctx, db.Client(), callback)
assert.NoError(t, err, "WithTransaction should not return error")
})

t.Run("Should run transaction with mongo.Database", func(t *testing.T) {
callback := func(sessCtx mongo.SessionContext) (interface{}, error) {
err := genBooks[1].Delete(sessCtx)
return nil, err
}

_, err := mgs.WithTransaction(ctx, db, callback)
assert.NoError(t, err, "WithTransaction should not return error")
})

t.Run("Should run transaction with mongo.Collection", func(t *testing.T) {
callback := func(sessCtx mongo.SessionContext) (interface{}, error) {
err := genBooks[2].Delete(sessCtx)
return nil, err
}
_, err := mgs.WithTransaction(ctx, db.Collection("books"), callback)
assert.NoError(t, err, "WithTransaction should not return error")
book, err := bookModel.FindById(ctx, genBooks[2].GetID())
assert.Error(t, err, "FindById return error")
assert.Nil(t, book, "book should be nil")
})

t.Run("Should run transaction with mongo.Session", func(t *testing.T) {
callback := func(sessCtx mongo.SessionContext) (interface{}, error) {
err := genBooks[3].Delete(sessCtx)
return nil, err
}

sess, err := db.Client().StartSession()
if err != nil {
t.Fatal(err)
}
defer sess.EndSession(ctx)
_, err = mgs.WithTransaction(ctx, &sess, callback)
assert.NoError(t, err, "WithTransaction should not return error")
book, err := bookModel.FindById(ctx, genBooks[3].GetID())
assert.Error(t, err, "FindById return error")
assert.Nil(t, book, "book should be nil")
})

t.Run("Should run transaction with mongo.SessionContext", func(t *testing.T) {
callback := func(sessCtx mongo.SessionContext) (interface{}, error) {
callback2 := func(sCtx mongo.SessionContext) (interface{}, error) {
err := genBooks[4].Delete(sCtx)
return nil, err
}

if _, err := mgs.WithTransaction(context.TODO(), &sessCtx, callback2); err != nil {
return nil, err
}

genBooks[5].Doc.Title = "This is a test title"
if err = genBooks[5].Save(sessCtx); err != nil {
return nil, err
}

return nil, errors.New("this is a test error")
}
_, err := mgs.WithTransaction(ctx, db, callback)
assert.Error(t, err, "WithTransaction should return error")

book, err := bookModel.FindById(ctx, genBooks[5].GetID())
if err != nil {
t.Fatal(err)
}
assert.True(t, book.Doc.Title != "This is a test title", "WithTransaction should rollback changes")

_, err = bookModel.FindById(ctx, genBooks[4].GetID())
assert.NoError(t, err, "WithTransaction should rollback changes")
})
}

func TestModelMongodbErrors(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
Expand Down

0 comments on commit ca9733d

Please sign in to comment.