/
unit-of-work.go
99 lines (86 loc) · 2.22 KB
/
unit-of-work.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package mongosvc
import (
underscore "github.com/ahl5esoft/golang-underscore"
"github.com/ahl5esoft/lite-go/model/contract"
"github.com/ahl5esoft/lite-go/service/dbsvc"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)
// unitOfWork collects MongoDB write models per collection and exposes them
// through the embedded dbsvc.UnitOfWork, whose CommitAction (built in
// newUnitOfWork) flushes them with one BulkWrite per collection.
type unitOfWork struct {
	dbsvc.UnitOfWork
	// writeModel maps a collection (table) name to the write models queued
	// for it, in the order they were registered.
	writeModel map[string][]mongo.WriteModel
}
// RegisterAdd queues an insert of entry. The inserted document maps the
// column name of every field returned by the model metadata's FindFields to
// that field's value on entry.
func (m *unitOfWork) RegisterAdd(entry contract.IDbModel) {
	fields := getModelMetadata(entry).FindFields()
	document := bson.M{}
	underscore.Chain(fields).Each(func(field *fieldMetadata, _ int) {
		document[field.GetColumnName()] = field.GetValue(entry)
	})

	insert := mongo.NewInsertOneModel().SetDocument(document)
	m.appendWriteModel(entry, insert)
}
// RegisterRemove queues deletion of the single document whose `_id` equals
// entry.GetID().
func (m *unitOfWork) RegisterRemove(entry contract.IDbModel) {
	byID := bson.M{"_id": entry.GetID()}
	m.appendWriteModel(entry, mongo.NewDeleteOneModel().SetFilter(byID))
}
// RegisterSave queues an update of the document whose `_id` equals
// entry.GetID(), matching the filter used by RegisterRemove.
//
// Every field whose GetTableName() is empty is written through `$set` under
// its column name. Fields that do carry a table name are skipped.
// NOTE(review): those are presumably the model's key field, excluded so the
// immutable `_id` is not rewritten — confirm against fieldMetadata.
func (m *unitOfWork) RegisterSave(entry contract.IDbModel) {
	doc := make(bson.M)
	underscore.Chain(
		getModelMetadata(entry).FindFields(),
	).Each(func(r *fieldMetadata, _ int) {
		if r.GetTableName() != "" {
			return
		}
		doc[r.GetColumnName()] = r.GetValue(entry)
	})

	// The filter depends only on entry, so it is set unconditionally. Setting
	// it only when a table-tagged field was encountered could leave the model
	// with a nil filter, which BulkWrite rejects.
	m.appendWriteModel(
		entry,
		mongo.NewUpdateOneModel().
			SetFilter(bson.M{"_id": entry.GetID()}).
			SetUpdate(bson.M{"$set": doc}),
	)
}
// appendWriteModel queues writeModel under the collection name reported by
// entry's model metadata. The queue is the map shared with the CommitAction
// built in newUnitOfWork.
func (m *unitOfWork) appendWriteModel(entry contract.IDbModel, writeModel mongo.WriteModel) {
	// The second result of GetTableName is discarded because the Register*
	// callers have no error return to propagate it through.
	// NOTE(review): if it signals a lookup failure, table is presumably "" and
	// the problem only surfaces when the batch is written at commit — confirm.
	table, _ := getModelMetadata(entry).GetTableName()

	// Appending to a missing key starts from a nil slice, so no explicit
	// initialisation of the map entry is needed.
	m.writeModel[table] = append(m.writeModel[table], writeModel)
}
// newUnitOfWork creates a unitOfWork whose commit flushes all queued write
// models through dbPool.
//
// The returned value and its CommitAction share one map: the Register*
// methods fill it and CommitAction empties it. On commit:
//   - if nothing is queued, it returns nil without touching the database;
//   - if dbPool.GetClientAndDb fails, the error is returned and every queued
//     write stays queued, since nothing has been sent;
//   - otherwise, inside client.UseSession, every queued batch is detached
//     from the map before the first write, then sent with one BulkWrite per
//     collection (in unspecified map order). The first BulkWrite error aborts
//     the commit, and the batches not yet sent are dropped rather than being
//     replayed by a later commit.
//
// NOTE(review): no transaction is started inside the session, so a failure
// part-way leaves the collections written before it already modified.
func newUnitOfWork(dbPool *dbPool) *unitOfWork {
	writeModel := make(map[string][]mongo.WriteModel)
	return &unitOfWork{
		UnitOfWork: dbsvc.UnitOfWork{
			CommitAction: func() error {
				if len(writeModel) == 0 {
					return nil
				}

				client, db, err := dbPool.GetClientAndDb()
				if err != nil {
					return err
				}

				return client.UseSession(dbPool.Ctx, func(ctx mongo.SessionContext) error {
					// Take ownership of every batch up front. Otherwise a failed
					// commit would leave a map-order-dependent subset queued for
					// the next commit.
					tables := make([]string, 0, len(writeModel))
					batches := make([][]mongo.WriteModel, 0, len(writeModel))
					for table, models := range writeModel {
						tables = append(tables, table)
						batches = append(batches, models)
						// Deleting the current key during range is permitted by the Go spec.
						delete(writeModel, table)
					}

					for i, table := range tables {
						if _, err := db.Collection(table).BulkWrite(ctx, batches[i]); err != nil {
							return err
						}
					}
					return nil
				})
			},
		},
		writeModel: writeModel,
	}
}