Skip to content

Commit c1b6000

Browse files
committed
add arthurscreiber's review suggestions
1 parent 1bd2b0b commit c1b6000

File tree

3 files changed

+12
-28
lines changed

3 files changed

+12
-28
lines changed

Diff for: go/logic/applier.go

+10-25
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier {
8080

8181
func (this *Applier) InitDBConnections() (err error) {
8282
applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
83-
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil {
83+
uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri)
84+
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, uriWithMulti); err != nil {
8485
return err
8586
}
8687
singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri)
@@ -1210,7 +1211,7 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB
12101211
// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table
12111212
func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error {
12121213
var totalDelta int64
1213-
ctx := context.TODO()
1214+
ctx := context.Background()
12141215

12151216
err := func() error {
12161217
conn, err := this.db.Conn(ctx)
@@ -1236,31 +1237,23 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent))
12361237
}
12371238

12381239
buildResults := make([]*dmlBuildResult, 0, len(dmlEvents))
1240+
nArgs := 0
12391241
for _, dmlEvent := range dmlEvents {
12401242
for _, buildResult := range this.buildDMLEventQuery(dmlEvent) {
12411243
if buildResult.err != nil {
12421244
return rollback(buildResult.err)
12431245
}
1244-
1246+
nArgs += len(buildResult.args)
12451247
buildResults = append(buildResults, buildResult)
12461248
}
12471249
}
12481250

12491251
execErr := conn.Raw(func(driverConn any) error {
1250-
ex, ok := driverConn.(driver.ExecerContext)
1251-
if !ok {
1252-
return fmt.Errorf("could not cast driverConn to ExecerContext")
1253-
}
1254-
1255-
nvc, ok := driverConn.(driver.NamedValueChecker)
1256-
if !ok {
1257-
return fmt.Errorf("could not cast driverConn to NamedValueChecker")
1258-
}
1252+
ex := driverConn.(driver.ExecerContext)
1253+
nvc := driverConn.(driver.NamedValueChecker)
12591254

1260-
var multiArgs []driver.NamedValue
1255+
multiArgs := make([]driver.NamedValue, 0, nArgs)
12611256
multiQueryBuilder := strings.Builder{}
1262-
var rowDeltas []int64
1263-
12641257
for _, buildResult := range buildResults {
12651258
for _, arg := range buildResult.args {
12661259
nv := driver.NamedValue{Value: driver.Value(arg)}
@@ -1270,29 +1263,21 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent))
12701263

12711264
multiQueryBuilder.WriteString(buildResult.query)
12721265
multiQueryBuilder.WriteString(";\n")
1273-
1274-
rowDeltas = append(rowDeltas, buildResult.rowsDelta)
12751266
}
12761267

1277-
// this.migrationContext.Log.Infof("Executing query: %s, args: %+v", multiQueryBuilder.String(), multiArgs)
12781268
res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs)
12791269
if err != nil {
12801270
err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs)
1281-
this.migrationContext.Log.Errorf("Error exec: %+v", err)
12821271
return err
12831272
}
12841273

1285-
mysqlRes, ok := res.(drivermysql.Result)
1286-
if !ok {
1287-
return fmt.Errorf("Could not cast %+v to mysql.Result", res)
1288-
}
1274+
mysqlRes := res.(drivermysql.Result)
12891275

12901276
// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
12911277
// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
12921278
for i, rowsAffected := range mysqlRes.AllRowsAffected() {
1293-
totalDelta += rowDeltas[i] * rowsAffected
1279+
totalDelta += buildResults[i].rowsDelta * rowsAffected
12941280
}
1295-
12961281
return nil
12971282
})
12981283

Diff for: go/mysql/connection.go

-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ func (this *ConnectionConfig) GetDBUri(databaseName string) string {
132132
connectionParams := []string{
133133
"autocommit=true",
134134
"interpolateParams=true",
135-
"multiStatements=true",
136135
fmt.Sprintf("charset=%s", this.Charset),
137136
fmt.Sprintf("tls=%s", tlsOption),
138137
fmt.Sprintf("transaction_isolation=%q", this.TransactionIsolation),

Diff for: go/mysql/connection_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func TestGetDBUri(t *testing.T) {
8686
c.Charset = "utf8mb4,utf8,latin1"
8787

8888
uri := c.GetDBUri("test")
89-
require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&multiStatements=true&charset=utf8mb4,utf8,latin1&tls=false&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
89+
require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4,utf8,latin1&tls=false&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
9090
}
9191

9292
func TestGetDBUriWithTLSSetup(t *testing.T) {
@@ -100,5 +100,5 @@ func TestGetDBUriWithTLSSetup(t *testing.T) {
100100
c.Charset = "utf8mb4_general_ci,utf8_general_ci,latin1"
101101

102102
uri := c.GetDBUri("test")
103-
require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&multiStatements=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=ghost&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
103+
require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=ghost&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
104104
}

0 commit comments

Comments
 (0)