/
transaction.go
126 lines (112 loc) · 2.46 KB
/
transaction.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package databases
import (
"context"
"database/sql"
"github.com/aacfactory/fns/commons/bytex"
)
type (
	// TransactionOptions collects the settings for a transaction. The code
	// that consumes it (e.g. a Begin call) is not in this file — presumably it
	// is read when the transaction is started; confirm at the call site.
	TransactionOptions struct {
		// Isolation is the isolation setting to request.
		Isolation Isolation
		// Readonly requests a read-only transaction when true.
		Readonly bool
	}

	// TransactionOption is a functional option: each one mutates the
	// TransactionOptions value it is handed.
	TransactionOption func(options *TransactionOptions)
)
// Transaction is a database transaction handle: queries and statements are
// issued through Query and Execute, and the transaction is finished with
// either Commit or Rollback.
type Transaction interface {
	// Query runs query with args inside the transaction and returns the
	// resulting rows.
	Query(ctx context.Context, query []byte, args []any) (rows Rows, err error)
	// Execute runs query with args inside the transaction and returns a
	// Result summarizing its effect.
	Execute(ctx context.Context, query []byte, args []any) (result Result, err error)
	// Commit commits the transaction.
	Commit() error
	// Rollback rolls the transaction back.
	Rollback() error
}
// NewTransactionWithStatements wraps tx in a DefaultTransaction whose Query and
// Execute look statements up in the given cache. A nil statements disables that
// path, and every call then sends the raw query text to tx.
func NewTransactionWithStatements(tx *sql.Tx, statements *Statements) Transaction {
	wrapped := new(DefaultTransaction)
	wrapped.core = tx
	wrapped.statements = statements
	wrapped.prepare = statements != nil
	return wrapped
}
// NewTransaction wraps tx in a DefaultTransaction without a statement cache, so
// every Query and Execute sends the raw query text straight to tx.
func NewTransaction(tx *sql.Tx) Transaction {
	// A nil cache makes the constructor leave prepare=false and statements=nil.
	return NewTransactionWithStatements(tx, nil)
}
// DefaultTransaction is the Transaction returned by NewTransaction and
// NewTransactionWithStatements. It delegates to a *sql.Tx and, when a
// statement cache is configured, routes Query/Execute through cached
// prepared statements rebound to this transaction.
type DefaultTransaction struct {
	// core is the database/sql transaction every operation delegates to.
	core *sql.Tx
	// prepare selects the cached-statement path; the constructors set it to
	// statements != nil.
	prepare bool
	// statements is the cache consulted (via Get with the query text) when
	// prepare is true; nil otherwise.
	statements *Statements
}
// Commit commits the underlying *sql.Tx and returns whatever error it reports.
func (tx *DefaultTransaction) Commit() (err error) {
	err = tx.core.Commit()
	return
}
// Rollback rolls back the underlying *sql.Tx and returns whatever error it
// reports.
func (tx *DefaultTransaction) Rollback() (err error) {
	err = tx.core.Rollback()
	return
}
// Query runs query with args inside the transaction and wraps the resulting
// *sql.Rows in a DefaultRows.
//
// When the transaction has a statement cache (tx.prepare), the cached prepared
// statement for query is rebound to this transaction with StmtContext. If the
// cache reports that statement as closed, the query text is executed directly
// on the transaction instead. The original code re-entered Query at that point,
// which never terminates if the cache keeps handing back a closed statement.
// Without a cache the query text always goes straight to the transaction.
//
// ctx is passed to every database call so cancellation and deadlines are
// honored.
//
// Returns any error from the cache lookup or from the query itself; rows is
// nil whenever err is non-nil.
func (tx *DefaultTransaction) Query(ctx context.Context, query []byte, args []any) (rows Rows, err error) {
	var r *sql.Rows
	if tx.prepare {
		stmt, prepareErr := tx.statements.Get(query)
		if prepareErr != nil {
			err = prepareErr
			return
		}
		st, release, closed := stmt.Stmt()
		if closed {
			// The cached statement is unusable; run the same SQL unprepared
			// rather than retrying the cache indefinitely.
			r, err = tx.core.QueryContext(ctx, bytex.ToString(query), args...)
		} else {
			// The transaction-bound statement is not closed here because the
			// returned rows may still depend on it; database/sql closes
			// transaction statements on Commit/Rollback.
			r, err = tx.core.StmtContext(ctx, st).QueryContext(ctx, args...)
			release()
		}
	} else {
		r, err = tx.core.QueryContext(ctx, bytex.ToString(query), args...)
	}
	if err != nil {
		return
	}
	rows = &DefaultRows{
		core: r,
	}
	return
}
// Execute runs query with args inside the transaction via Exec and reports the
// number of affected rows and the last insert id.
//
// When the transaction has a statement cache (tx.prepare), the cached prepared
// statement for query is rebound to this transaction with StmtContext. That
// transaction-bound statement is closed right after Exec, so any driver-side
// statement it prepared is released now rather than at Commit/Rollback. If the
// cache reports the statement as closed, the query text is executed directly
// on the transaction instead of re-entering Execute. Without a cache the query
// text always goes straight to the transaction.
//
// ctx is passed to every database call so cancellation and deadlines are
// honored.
//
// An error from RowsAffected is returned. An error from LastInsertId is not:
// LastInsertId is reported as -1 instead.
func (tx *DefaultTransaction) Execute(ctx context.Context, query []byte, args []any) (result Result, err error) {
	var r sql.Result
	if tx.prepare {
		stmt, prepareErr := tx.statements.Get(query)
		if prepareErr != nil {
			err = prepareErr
			return
		}
		st, release, closed := stmt.Stmt()
		if closed {
			// The cached statement is unusable; run the same SQL unprepared
			// rather than retrying the cache indefinitely.
			r, err = tx.core.ExecContext(ctx, bytex.ToString(query), args...)
		} else {
			txStmt := tx.core.StmtContext(ctx, st)
			r, err = txStmt.ExecContext(ctx, args...)
			// Exec has fully completed, so the transaction-bound statement can
			// be closed now. Its Close error carries nothing actionable beyond
			// the Exec error already captured.
			_ = txStmt.Close()
			release()
		}
	} else {
		r, err = tx.core.ExecContext(ctx, bytex.ToString(query), args...)
	}
	if err != nil {
		return
	}
	rowsAffected, rowsAffectedErr := r.RowsAffected()
	if rowsAffectedErr != nil {
		err = rowsAffectedErr
		return
	}
	lastInsertId, lastInsertIdErr := r.LastInsertId()
	if lastInsertIdErr != nil {
		// Deliberately best-effort: some drivers cannot report a last insert
		// id, which should not fail an otherwise successful statement.
		lastInsertId = -1
	}
	result = Result{
		LastInsertId: lastInsertId,
		RowsAffected: rowsAffected,
	}
	return
}