Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new/manipmongo: better sharding interface #87

Merged
merged 10 commits into from
Jun 12, 2019
46 changes: 43 additions & 3 deletions manipmongo/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"context"
"fmt"
"strconv"
"strings"
"time"

"github.com/globalsign/mgo"
Expand Down Expand Up @@ -57,7 +58,7 @@ func DropDatabase(manipulator manipulate.Manipulator) error {

m, ok := manipulator.(*mongoManipulator)
if !ok {
panic("you can only pass a mongo manipulator to CreateIndex")
panic("you can only pass a mongo manipulator to DropDatabase")
}

session := m.rootSession.Copy()
Expand Down Expand Up @@ -91,6 +92,45 @@ func CreateIndex(manipulator manipulate.Manipulator, identity elemental.Identity
return nil
}

// EnsureIndex works like CreateIndex, except that if an index with the same
// name already exists with different options, the stale index is dropped and
// recreated using the new definition. It panics if the given manipulator is
// not a mongo manipulator.
func EnsureIndex(manipulator manipulate.Manipulator, identity elemental.Identity, indexes ...mgo.Index) error {

	m, ok := manipulator.(*mongoManipulator)
	if !ok {
		panic("you can only pass a mongo manipulator to EnsureIndex")
	}

	session := m.rootSession.Copy()
	defer session.Close()

	collection := session.DB(m.dbName).C(identity.Name)

	for i, index := range indexes {

		// Give unnamed indexes a deterministic name so a later call can
		// find and drop them if their options change.
		if index.Name == "" {
			index.Name = "index_" + identity.Name + "_" + strconv.Itoa(i)
		}

		if err := collection.EnsureIndex(index); err != nil {

			// mgo reports an options mismatch with this message suffix;
			// in that case, drop the old index and recreate it.
			if strings.HasSuffix(err.Error(), "already exists with different options") {

				if err := collection.DropIndexName(index.Name); err != nil {
					return fmt.Errorf("cannot delete previous index: %s", err)
				}

				if err := collection.EnsureIndex(index); err != nil {
					return fmt.Errorf("unable to ensure index after dropping old one '%s': %s", index.Name, err)
				}

				continue
			}

			return fmt.Errorf("unable to ensure index '%s': %s", index.Name, err)
		}
	}

	return nil
}

// DeleteIndex deletes multiple mgo.Index for the collection.
func DeleteIndex(manipulator manipulate.Manipulator, identity elemental.Identity, indexes ...string) error {

Expand Down Expand Up @@ -135,7 +175,7 @@ func GetDatabase(manipulator manipulate.Manipulator) (*mgo.Database, func(), err

m, ok := manipulator.(*mongoManipulator)
if !ok {
panic("you can only pass a mongo manipulator to GetSession")
panic("you can only pass a mongo manipulator to GetDatabase")
}

session := m.rootSession.Copy()
Expand All @@ -148,7 +188,7 @@ func SetConsistencyMode(manipulator manipulate.Manipulator, mode mgo.Mode, refre

m, ok := manipulator.(*mongoManipulator)
if !ok {
panic("you can only pass a Mongo Manipulator to SetConsistencyMode")
panic("you can only pass a mongo manipulator to SetConsistencyMode")
}

if m.rootSession == nil {
Expand Down
132 changes: 132 additions & 0 deletions manipmongo/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,143 @@ import (
"testing"
"time"

"github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson"
. "github.com/smartystreets/goconvey/convey"
"go.aporeto.io/elemental"
"go.aporeto.io/manipulate"
"go.aporeto.io/manipulate/maniptest"
)

// TestCompileFilter checks that an elemental filter compiles to the expected
// bson query document.
func TestCompileFilter(t *testing.T) {

	Convey("Given I have a filter", t, func() {

		f := elemental.NewFilter().WithKey("a").Equals("b").Done()

		Convey("When I call CompileFilter", func() {

			cf := CompileFilter(f)

			Convey("Then cf should be correct", func() {
				// gofmt -s: the element type is implied by the slice literal.
				So(cf, ShouldResemble, bson.M{"$and": []bson.M{{"a": bson.M{"$eq": "b"}}}})
			})
		})
	})
}

// TestDoesDatabaseExists checks that DoesDatabaseExist panics when given a
// non-mongo manipulator.
func TestDoesDatabaseExists(t *testing.T) {

	Convey("Given I have a test manipulator", t, func() {

		m := maniptest.NewTestManipulator()

		Convey("When I call DoesDatabaseExist", func() {
			Convey("Then it should panic", func() {
				So(func() { _, _ = DoesDatabaseExist(m) }, ShouldPanicWith, "you can only pass a mongo manipulator to DoesDatabaseExist")
			})
		})
	})
}

// TestDropDatabase checks that DropDatabase panics when given a non-mongo
// manipulator.
func TestDropDatabase(t *testing.T) {

	Convey("Given I have a test manipulator", t, func() {

		m := maniptest.NewTestManipulator()

		Convey("When I call DropDatabase", func() {
			Convey("Then it should panic", func() {
				So(func() { _ = DropDatabase(m) }, ShouldPanicWith, "you can only pass a mongo manipulator to DropDatabase")
			})
		})
	})
}

// TestCreateIndex checks that CreateIndex panics when given a non-mongo
// manipulator.
func TestCreateIndex(t *testing.T) {

	Convey("Given I have a test manipulator", t, func() {

		m := maniptest.NewTestManipulator()

		Convey("When I call CreateIndex", func() {
			Convey("Then it should panic", func() {
				So(func() { _ = CreateIndex(m, elemental.MakeIdentity("a", "a")) }, ShouldPanicWith, "you can only pass a mongo manipulator to CreateIndex")
			})
		})
	})
}

// TestEnsureIndex checks that EnsureIndex panics when given a non-mongo
// manipulator.
func TestEnsureIndex(t *testing.T) {

	Convey("Given I have a test manipulator", t, func() {

		m := maniptest.NewTestManipulator()

		Convey("When I call EnsureIndex", func() {
			Convey("Then it should panic", func() {
				// NOTE(review): the current panic message in EnsureIndex is a
				// copy-paste of CreateIndex's. Only assert that it panics so
				// this test stays valid once the message is corrected.
				So(func() { _ = EnsureIndex(m, elemental.MakeIdentity("a", "a")) }, ShouldPanic)
			})
		})
	})
}

// TestDeleteIndex checks that DeleteIndex panics when given a non-mongo
// manipulator.
func TestDeleteIndex(t *testing.T) {

	Convey("Given I have a test manipulator", t, func() {

		m := maniptest.NewTestManipulator()

		Convey("When I call DeleteIndex", func() {
			Convey("Then it should panic", func() {
				So(func() { _ = DeleteIndex(m, elemental.MakeIdentity("a", "a")) }, ShouldPanicWith, "you can only pass a mongo manipulator to DeleteIndex")
			})
		})
	})
}

// TestCreateCollection checks that CreateCollection panics when given a
// non-mongo manipulator.
func TestCreateCollection(t *testing.T) {

	Convey("Given I have a test manipulator", t, func() {

		m := maniptest.NewTestManipulator()

		Convey("When I call CreateCollection", func() {
			Convey("Then it should panic", func() {
				So(func() { _ = CreateCollection(m, elemental.MakeIdentity("a", "a"), nil) }, ShouldPanicWith, "you can only pass a mongo manipulator to CreateCollection")
			})
		})
	})
}

// TestGetDatabase checks that GetDatabase panics when given a non-mongo
// manipulator.
func TestGetDatabase(t *testing.T) {

	Convey("Given I have a test manipulator", t, func() {

		m := maniptest.NewTestManipulator()

		Convey("When I call GetDatabase", func() {
			Convey("Then it should panic", func() {
				So(func() { _, _, _ = GetDatabase(m) }, ShouldPanicWith, "you can only pass a mongo manipulator to GetDatabase")
			})
		})
	})
}

// TestSetConsistencyMode checks that SetConsistencyMode panics when given a
// non-mongo manipulator.
func TestSetConsistencyMode(t *testing.T) {

	Convey("Given I have a test manipulator", t, func() {

		m := maniptest.NewTestManipulator()

		Convey("When I call SetConsistencyMode", func() {
			Convey("Then it should panic", func() {
				So(func() { SetConsistencyMode(m, mgo.Strong, true) }, ShouldPanicWith, "you can only pass a mongo manipulator to SetConsistencyMode")
			})
		})
	})
}

func TestRunQuery(t *testing.T) {

testIdentity := elemental.MakeIdentity("test", "tests")
Expand Down
70 changes: 63 additions & 7 deletions manipmongo/manipulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type mongoManipulator struct {
dbName string
sharder Sharder
defaultRetryFunc manipulate.RetryFunc
forcedReadFilter bson.M
}

// New returns a new manipulator backed by MongoDB.
Expand Down Expand Up @@ -81,6 +82,7 @@ func New(url string, db string, options ...Option) (manipulate.TransactionalMani
rootSession: session,
sharder: cfg.sharder,
defaultRetryFunc: cfg.defaultRetryFunc,
forcedReadFilter: cfg.forcedReadFilter,
}, nil
}

Expand All @@ -105,12 +107,19 @@ func (s *mongoManipulator) RetrieveMany(mctx manipulate.Context, dest elemental.
}

if s.sharder != nil {
sq := s.sharder.FilterMany(dest.Identity())
sq, err := s.sharder.FilterMany(s, mctx, dest.Identity())
if err != nil {
return manipulate.NewErrCannotBuildQuery(fmt.Sprintf("cannot compute sharding filter: %s", err))
}
if sq != nil {
filter = bson.M{"$and": []bson.M{sq, filter}}
}
}

if s.forcedReadFilter != nil {
filter = bson.M{"$and": []bson.M{s.forcedReadFilter, filter}}
}

query := c.Find(filter)

// This makes squall returning a 500 error.
Expand Down Expand Up @@ -186,12 +195,19 @@ func (s *mongoManipulator) Retrieve(mctx manipulate.Context, object elemental.Id
filter["_id"] = object.Identifier()

if s.sharder != nil {
sq := s.sharder.FilterOne(object)
sq, err := s.sharder.FilterOne(s, mctx, object)
if err != nil {
return manipulate.NewErrCannotBuildQuery(fmt.Sprintf("cannot compute sharding filter: %s", err))
}
if sq != nil {
filter = bson.M{"$and": []bson.M{sq, filter}}
}
}

if s.forcedReadFilter != nil {
filter = bson.M{"$and": []bson.M{s.forcedReadFilter, filter}}
}

sp := tracing.StartTrace(mctx, fmt.Sprintf("manipmongo.retrieve.object.%s", object.Identity().Name))
sp.LogFields(log.String("object_id", object.Identifier()), log.Object("filter", filter))
defer sp.Finish()
Expand Down Expand Up @@ -249,7 +265,13 @@ func (s *mongoManipulator) Create(mctx manipulate.Context, object elemental.Iden
}

if s.sharder != nil {
s.sharder.Shard(object)
if err := s.sharder.Shard(s, mctx, object); err != nil {
return manipulate.NewErrCannotBuildQuery(fmt.Sprintf("unable to execute sharder.Shard: %s", err))
}

if err := s.sharder.OnShardedWrite(s, mctx, elemental.OperationCreate, object); err != nil {
return manipulate.NewErrCannotBuildQuery(fmt.Sprintf("unable to execute sharder.OnShardedWrite on create: %s", err))
}
}

if _, err := RunQuery(
Expand Down Expand Up @@ -288,12 +310,19 @@ func (s *mongoManipulator) Update(mctx manipulate.Context, object elemental.Iden

filter = bson.M{"_id": object.Identifier()}
if s.sharder != nil {
sq := s.sharder.FilterOne(object)
sq, err := s.sharder.FilterOne(s, mctx, object)
if err != nil {
return manipulate.NewErrCannotBuildQuery(fmt.Sprintf("cannot compute sharding filter: %s", err))
}
if sq != nil {
filter = bson.M{"$and": []bson.M{sq, filter}}
}
}

if s.forcedReadFilter != nil {
filter = bson.M{"$and": []bson.M{s.forcedReadFilter, filter}}
}

if _, err := RunQuery(
mctx,
func() (interface{}, error) { return nil, c.Update(filter, bson.M{"$set": object}) },
Expand Down Expand Up @@ -330,12 +359,19 @@ func (s *mongoManipulator) Delete(mctx manipulate.Context, object elemental.Iden

filter = bson.M{"_id": object.Identifier()}
if s.sharder != nil {
sq := s.sharder.FilterOne(object)
sq, err := s.sharder.FilterOne(s, mctx, object)
if err != nil {
return manipulate.NewErrCannotBuildQuery(fmt.Sprintf("cannot compute sharding filter: %s", err))
}
if sq != nil {
filter = bson.M{"$and": []bson.M{sq, filter}}
}
}

if s.forcedReadFilter != nil {
filter = bson.M{"$and": []bson.M{s.forcedReadFilter, filter}}
}

if _, err := RunQuery(
mctx,
func() (interface{}, error) { return nil, c.Remove(filter) },
Expand All @@ -350,6 +386,12 @@ func (s *mongoManipulator) Delete(mctx manipulate.Context, object elemental.Iden
return err
}

if s.sharder != nil {
if err := s.sharder.OnShardedWrite(s, mctx, elemental.OperationDelete, object); err != nil {
return manipulate.NewErrCannotBuildQuery(fmt.Sprintf("unable to execute sharder.OnShardedWrite for delete: %s", err))
}
}

// backport all default values that are empty.
if a, ok := object.(elemental.AttributeSpecifiable); ok {
elemental.ResetDefaultForZeroValues(a)
Expand All @@ -374,12 +416,19 @@ func (s *mongoManipulator) DeleteMany(mctx manipulate.Context, identity elementa

filter := compiler.CompileFilter(mctx.Filter())
if s.sharder != nil {
sq := s.sharder.FilterMany(identity)
sq, err := s.sharder.FilterMany(s, mctx, identity)
if err != nil {
return manipulate.NewErrCannotBuildQuery(fmt.Sprintf("cannot compute sharding filter: %s", err))
}
if sq != nil {
filter = bson.M{"$and": []bson.M{sq, filter}}
}
}

if s.forcedReadFilter != nil {
filter = bson.M{"$and": []bson.M{s.forcedReadFilter, filter}}
}

if _, err := RunQuery(
mctx,
func() (interface{}, error) { return c.RemoveAll(filter) },
Expand Down Expand Up @@ -415,12 +464,19 @@ func (s *mongoManipulator) Count(mctx manipulate.Context, identity elemental.Ide
}

if s.sharder != nil {
sq := s.sharder.FilterMany(identity)
sq, err := s.sharder.FilterMany(s, mctx, identity)
if err != nil {
return 0, manipulate.NewErrCannotBuildQuery(fmt.Sprintf("cannot compute sharding filter: %s", err))
}
if sq != nil {
filter = bson.M{"$and": []bson.M{sq, filter}}
}
}

if s.forcedReadFilter != nil {
filter = bson.M{"$and": []bson.M{s.forcedReadFilter, filter}}
}

sp := tracing.StartTrace(mctx, fmt.Sprintf("manipmongo.count.%s", identity.Category))
defer sp.Finish()

Expand Down
Loading