Skip to content

Commit 2a3318c

Browse files
committed
conn.Raw not working
1 parent 90d6148 commit 2a3318c

File tree

1 file changed

+45
-21
lines changed

1 file changed

+45
-21
lines changed

Diff for: go/logic/applier.go

+45-21
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ import (
1414

1515
"github.com/github/gh-ost/go/base"
1616
"github.com/github/gh-ost/go/binlog"
17-
"github.com/github/gh-ost/go/mysql"
1817
"github.com/github/gh-ost/go/sql"
1918

20-
"github.com/openark/golib/log"
19+
"context"
20+
"database/sql/driver"
21+
22+
"github.com/github/gh-ost/go/mysql"
23+
drivermysql "github.com/go-sql-driver/mysql"
2124
"github.com/openark/golib/sqlutils"
2225
)
2326

@@ -1207,13 +1210,19 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB
12071210
// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table
12081211
func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error {
12091212
var totalDelta int64
1213+
ctx := context.TODO()
12101214

12111215
err := func() error {
1212-
tx, err := this.db.Begin()
1216+
conn, err := this.db.Conn(ctx)
12131217
if err != nil {
12141218
return err
12151219
}
1220+
defer conn.Close()
12161221

1222+
tx, err := conn.BeginTx(ctx, nil)
1223+
if err != nil {
1224+
return err
1225+
}
12171226
rollback := func(err error) error {
12181227
tx.Rollback()
12191228
return err
@@ -1225,34 +1234,49 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent))
12251234
if _, err := tx.Exec(sessionQuery); err != nil {
12261235
return rollback(err)
12271236
}
1228-
multiArgs := []interface{}{}
1237+
rowDeltas := make([]int64, 0, len(dmlEvents))
1238+
multiArgs := []driver.NamedValue{}
12291239
var multiQueryBuilder strings.Builder
12301240
for _, dmlEvent := range dmlEvents {
12311241
for _, buildResult := range this.buildDMLEventQuery(dmlEvent) {
12321242
if buildResult.err != nil {
1233-
return buildResult.err
1243+
return rollback(buildResult.err)
12341244
}
1235-
multiArgs = append(multiArgs, buildResult.args...)
1245+
for _, arg := range buildResult.args {
1246+
multiArgs = append(multiArgs, driver.NamedValue{Value: driver.Value(arg)})
1247+
}
1248+
rowDeltas = append(rowDeltas, buildResult.rowsDelta)
12361249
multiQueryBuilder.WriteString(buildResult.query)
12371250
multiQueryBuilder.WriteString(";\n")
12381251
}
12391252
}
1240-
// TODO: get rows affected from each query in multi statement
1241-
log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err)
1242-
_, err = tx.Exec(multiQueryBuilder.String(), multiArgs...)
1243-
if err != nil {
1244-
err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs)
1245-
return rollback(err)
1246-
}
1247-
// rowsAffected, err := result.RowsAffected()
1248-
// if err != nil {
1249-
// log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err)
1250-
// rowsAffected = 1
1251-
// }
1252-
// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
1253-
// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
1254-
// totalDelta += buildResult.rowsDelta * rowsAffected
12551253

1254+
//this.migrationContext.Log.Infof("Executing query: %s, args: %+v", multiQueryBuilder.String(), multiArgs)
1255+
execErr := conn.Raw(func(driverConn any) error {
1256+
ex, ok := driverConn.(driver.ExecerContext)
1257+
if !ok {
1258+
return fmt.Errorf("could not cast driverConn to ExecerContext")
1259+
}
1260+
res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs)
1261+
if err != nil {
1262+
err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs)
1263+
this.migrationContext.Log.Errorf("Error exec: %+v", err)
1264+
return err
1265+
}
1266+
mysqlRes, ok := res.(drivermysql.Result)
1267+
if !ok {
1268+
return fmt.Errorf("Could not cast %+v to mysql.Result", res)
1269+
}
1270+
// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
1271+
// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
1272+
for i, rowsAffected := range mysqlRes.AllRowsAffected() {
1273+
totalDelta += rowDeltas[i] * rowsAffected
1274+
}
1275+
return nil
1276+
})
1277+
if execErr != nil {
1278+
return rollback(execErr)
1279+
}
12561280
if err := tx.Commit(); err != nil {
12571281
return err
12581282
}

0 commit comments

Comments
 (0)