/
client.go
65 lines (53 loc) · 1.65 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
package jmongo
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
// Client is a thin wrapper around the official driver's *mongo.Client.
type Client struct {
// client is the wrapped driver client used by the methods in this file.
client *mongo.Client
}
// NewClient builds a Client around a driver client created by mongo.NewClient
// with the supplied options.
//
// Nothing here calls Connect; the returned Client exposes a Connect method for
// that. If mongo.NewClient reports an error, that error is returned unchanged
// together with a nil *Client.
//
// NOTE(review): newer 1.x releases of the driver mark mongo.NewClient as
// deprecated in favor of mongo.Connect — confirm against the pinned driver
// version before changing this, since it would alter the connect-later flow.
func NewClient(opts ...*options.ClientOptions) (*Client, error) {
	driverClient, err := mongo.NewClient(opts...)
	if err != nil {
		return nil, err
	}

	wrapped := &Client{client: driverClient}
	return wrapped, nil
}
// Client returns the underlying *mongo.Client held by this wrapper.
func (c *Client) Client() *mongo.Client {
return c.client
}
// Connect calls Connect on the wrapped driver client with ctx and returns its
// error unchanged.
func (c *Client) Connect(ctx context.Context) error {
return c.client.Connect(ctx)
}
// Ping calls Ping on the wrapped driver client, forwarding ctx and the read
// preference rp as given, and returns its error unchanged.
func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {
return c.client.Ping(ctx, rp)
}
// Database returns a *Database wrapper for the database with the given name.
//
// The driver handle is obtained from the wrapped client's Database method with
// name and opts, then passed to NewDatabase together with this Client.
func (c *Client) Database(name string, opts ...*options.DatabaseOptions) *Database {
return NewDatabase(c.client.Database(name, opts...), c)
}
// WithTransaction runs fn inside a transaction.
//
// A session is obtained through the wrapped client's UseSession, and fn is
// executed through that session's WithTransaction. fn receives the
// transaction's session context as its ctx argument. It produces no result
// value, so only the error reported by the driver is returned.
//
// NOTE(review): operations inside fn presumably must use the ctx passed to fn
// (not the outer ctx) to take part in the transaction — see the driver docs.
func (c *Client) WithTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
	// Adapt fn to the callback shape the driver expects; there is never a
	// result value, so nil is returned in its place.
	txnBody := func(sessCtx mongo.SessionContext) (any, error) {
		return nil, fn(sessCtx)
	}

	runInSession := func(sess mongo.SessionContext) error {
		_, txnErr := sess.WithTransaction(ctx, txnBody)
		return txnErr
	}

	return c.client.UseSession(ctx, runInSession)
}
// WithTransaction runs fn inside a transaction on a session obtained from c's
// wrapped driver client and returns the value fn produced.
//
// It is the value-returning counterpart of (*Client).WithTransaction. fn
// receives the transaction's session context as its ctx argument, and whatever
// it returns is handed to the driver's Session.WithTransaction.
//
// The returned T is the value the driver passed back, or T's zero value when
// the driver passed back nil. It is set whenever the driver returns a non-nil
// value, even alongside a non-nil error, so callers should ignore it when the
// returned error is non-nil.
func WithTransaction[T any](ctx context.Context, c *Client, fn func(ctx context.Context) (T, error)) (T, error) {
	var res T

	err := c.client.UseSession(ctx, func(sessionContext mongo.SessionContext) error {
		out, txnErr := sessionContext.WithTransaction(ctx, func(sessCtx mongo.SessionContext) (any, error) {
			return fn(sessCtx)
		})

		// The driver returns the callback's value boxed as any. A nil
		// interface never satisfies a type assertion, so the comma-ok form
		// covers the "no value" case and cannot panic; res then keeps T's
		// zero value.
		if v, ok := out.(T); ok {
			res = v
		}
		return txnErr
	})

	return res, err
}