@@ -14,10 +14,13 @@ import (
14
14
15
15
"github.com/github/gh-ost/go/base"
16
16
"github.com/github/gh-ost/go/binlog"
17
- "github.com/github/gh-ost/go/mysql"
18
17
"github.com/github/gh-ost/go/sql"
19
18
20
- "github.com/openark/golib/log"
19
+ "context"
20
+ "database/sql/driver"
21
+
22
+ "github.com/github/gh-ost/go/mysql"
23
+ drivermysql "github.com/go-sql-driver/mysql"
21
24
"github.com/openark/golib/sqlutils"
22
25
)
23
26
@@ -1207,13 +1210,19 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB
1207
1210
// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table
1208
1211
func (this * Applier ) ApplyDMLEventQueries (dmlEvents [](* binlog.BinlogDMLEvent )) error {
1209
1212
var totalDelta int64
1213
+ ctx := context .TODO ()
1210
1214
1211
1215
err := func () error {
1212
- tx , err := this .db .Begin ( )
1216
+ conn , err := this .db .Conn ( ctx )
1213
1217
if err != nil {
1214
1218
return err
1215
1219
}
1220
+ defer conn .Close ()
1216
1221
1222
+ tx , err := conn .BeginTx (ctx , nil )
1223
+ if err != nil {
1224
+ return err
1225
+ }
1217
1226
rollback := func (err error ) error {
1218
1227
tx .Rollback ()
1219
1228
return err
@@ -1225,34 +1234,49 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent))
1225
1234
if _ , err := tx .Exec (sessionQuery ); err != nil {
1226
1235
return rollback (err )
1227
1236
}
1228
- multiArgs := []interface {}{}
1237
+ rowDeltas := make ([]int64 , 0 , len (dmlEvents ))
1238
+ multiArgs := []driver.NamedValue {}
1229
1239
var multiQueryBuilder strings.Builder
1230
1240
for _ , dmlEvent := range dmlEvents {
1231
1241
for _ , buildResult := range this .buildDMLEventQuery (dmlEvent ) {
1232
1242
if buildResult .err != nil {
1233
- return buildResult .err
1243
+ return rollback ( buildResult .err )
1234
1244
}
1235
- multiArgs = append (multiArgs , buildResult .args ... )
1245
+ for _ , arg := range buildResult .args {
1246
+ multiArgs = append (multiArgs , driver.NamedValue {Value : driver .Value (arg )})
1247
+ }
1248
+ rowDeltas = append (rowDeltas , buildResult .rowsDelta )
1236
1249
multiQueryBuilder .WriteString (buildResult .query )
1237
1250
multiQueryBuilder .WriteString (";\n " )
1238
1251
}
1239
1252
}
1240
- // TODO: get rows affected from each query in multi statement
1241
- log .Warningf ("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics" , err )
1242
- _ , err = tx .Exec (multiQueryBuilder .String (), multiArgs ... )
1243
- if err != nil {
1244
- err = fmt .Errorf ("%w; query=%s; args=%+v" , err , multiQueryBuilder .String (), multiArgs )
1245
- return rollback (err )
1246
- }
1247
- // rowsAffected, err := result.RowsAffected()
1248
- // if err != nil {
1249
- // log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err)
1250
- // rowsAffected = 1
1251
- // }
1252
- // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
1253
- // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
1254
- // totalDelta += buildResult.rowsDelta * rowsAffected
1255
1253
1254
+ //this.migrationContext.Log.Infof("Executing query: %s, args: %+v", multiQueryBuilder.String(), multiArgs)
1255
+ execErr := conn .Raw (func (driverConn any ) error {
1256
+ ex , ok := driverConn .(driver.ExecerContext )
1257
+ if ! ok {
1258
+ return fmt .Errorf ("could not cast driverConn to ExecerContext" )
1259
+ }
1260
+ res , err := ex .ExecContext (ctx , multiQueryBuilder .String (), multiArgs )
1261
+ if err != nil {
1262
+ err = fmt .Errorf ("%w; query=%s; args=%+v" , err , multiQueryBuilder .String (), multiArgs )
1263
+ this .migrationContext .Log .Errorf ("Error exec: %+v" , err )
1264
+ return err
1265
+ }
1266
+ mysqlRes , ok := res .(drivermysql.Result )
1267
+ if ! ok {
1268
+ return fmt .Errorf ("Could not cast %+v to mysql.Result" , res )
1269
+ }
1270
+ // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
1271
+ // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
1272
+ for i , rowsAffected := range mysqlRes .AllRowsAffected () {
1273
+ totalDelta += rowDeltas [i ] * rowsAffected
1274
+ }
1275
+ return nil
1276
+ })
1277
+ if execErr != nil {
1278
+ return rollback (execErr )
1279
+ }
1256
1280
if err := tx .Commit (); err != nil {
1257
1281
return err
1258
1282
}
0 commit comments