diff --git a/binlog_streamer.go b/binlog_streamer.go index fcdf3e0f..93b6cc6b 100644 --- a/binlog_streamer.go +++ b/binlog_streamer.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" sqlorig "database/sql" + "errors" "fmt" "time" @@ -281,6 +282,7 @@ func (s *BinlogStreamer) Run() { es.isEventPositionResumable = false es.isEventPositionValid = true + // if there is a handler associated with this eventType, call it eventTypeString := ev.Header.EventType.String() if handler, ok := s.eventHandlers[eventTypeString]; ok { query, err = handler(ev, query, &es) @@ -289,6 +291,7 @@ func (s *BinlogStreamer) Run() { s.ErrorHandler.Fatal("binlog_streamer", err) } } else { + // call the default event handler for everything else query, err = s.defaultEventHandler(ev, query, &es) } @@ -300,11 +303,22 @@ func (s *BinlogStreamer) Run() { } } -func (s *BinlogStreamer) AddBinlogEventHandler(ev string, eh func(*replication.BinlogEvent, []byte, *BinlogEventState) ([]byte, error)) { +// Attach an event handler to a replication BinLogEvent +// We only support attaching events to any of the events defined in +// https://github.com/go-mysql-org/go-mysql/blob/master/replication/const.go +func (s *BinlogStreamer) AddBinlogEventHandler(evType replication.EventType, eh func(*replication.BinlogEvent, []byte, *BinlogEventState) ([]byte, error)) error { + // verify that event-type is valid + // if eventTypeString is unrecognized, bail + eventTypeString := evType.String() + if eventTypeString == "UnknownEvent" { + return errors.New("Unknown event type") + } + if s.eventHandlers == nil { s.eventHandlers = make(map[string]func(*replication.BinlogEvent, []byte, *BinlogEventState) ([]byte, error)) } - s.eventHandlers[ev] = eh + s.eventHandlers[eventTypeString] = eh + return nil } func (s *BinlogStreamer) AddEventListener(listener func([]DMLEvent) error) { diff --git a/test/go/binlog_streamer_test.go b/test/go/binlog_streamer_test.go index e1a204f7..b16a4088 100644 --- a/test/go/binlog_streamer_test.go +++ b/test/go/binlog_streamer_test.go @@ -11,6 +11,7 @@ import ( "github.com/Shopify/ghostferry" "github.com/Shopify/ghostferry/testhelpers" + "github.com/go-mysql-org/go-mysql/replication" "github.com/stretchr/testify/suite" ) @@ -195,6 +196,20 @@ func (this *BinlogStreamerTestSuite) TestBinlogStreamerSetsQueryEventOnRowsEvent this.Require().True(eventAsserted) } +func (this *BinlogStreamerTestSuite) TestBinlogStreamerAddEventHandlerEventTypes() { + qe := func(ev *replication.BinlogEvent, query []byte, es *ghostferry.BinlogEventState) ([]byte, error) { + return query, nil + } + + // try attaching a handler to a valid event type + err := this.binlogStreamer.AddBinlogEventHandler(replication.TABLE_MAP_EVENT, qe) + this.Require().Nil(err) + + // try attaching a handler to an invalid event type + err = this.binlogStreamer.AddBinlogEventHandler(replication.EventType(byte(0)), qe) + this.Require().NotNil(err) +} + func TestBinlogStreamerTestSuite(t *testing.T) { testhelpers.SetupTest() suite.Run(t, &BinlogStreamerTestSuite{GhostferryUnitTestSuite: &testhelpers.GhostferryUnitTestSuite{}}) diff --git a/test/lib/go/ddl_ghostferry/main.go b/test/lib/go/ddl_ghostferry/main.go index 3dca411c..7b49f108 100644 --- a/test/lib/go/ddl_ghostferry/main.go +++ b/test/lib/go/ddl_ghostferry/main.go @@ -14,7 +14,7 @@ func queryEventHandler(ev *replication.BinlogEvent, query []byte, es *ghostferry } func AfterInitialize(f *tf.IntegrationFerry) error { - f.Ferry.BinlogStreamer.AddBinlogEventHandler("QueryEvent", queryEventHandler) + _ := f.Ferry.BinlogStreamer.AddBinlogEventHandler(replication.QUERY_EVENT, queryEventHandler) return nil }