Skip to content

Commit

Permalink
fix(plc4go): ensure discoverer respect context cancel
Browse files Browse the repository at this point in the history
  • Loading branch information
sruehl committed Jan 25, 2024
1 parent b6fbfbc commit fae748c
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 5 deletions.
30 changes: 30 additions & 0 deletions plc4go/internal/ads/Discoverer.go
Expand Up @@ -77,7 +77,13 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
deviceNames := options.FilterDiscoveryOptionsDeviceName(discoveryOptions)
if len(deviceNames) > 0 {
for _, curInterface := range allInterfaces {
if err := ctx.Err(); err != nil {
return err
}
for _, deviceNameOption := range deviceNames {
if err := ctx.Err(); err != nil {
return err
}
if curInterface.Name == deviceNameOption.GetDeviceName() {
interfaces = append(interfaces, curInterface)
break
Expand All @@ -91,6 +97,9 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
// Iterate over all selected network devices and filter out all the devices with IPv4 configured
var discoveryItems []*discovery
for _, interf := range interfaces {
if err := ctx.Err(); err != nil {
return err
}
addrs, err := interf.Addrs()
if err != nil {
return err
Expand All @@ -99,6 +108,9 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
// For ADS we're only interested in IPv4 addresses, as it doesn't
// seem to work with IPv6.
for _, addr := range addrs {
if err := ctx.Err(); err != nil {
return err
}
var ipv4Addr net.IP
switch addr.(type) {
// If the device is configured to communicate with a subnet
Expand Down Expand Up @@ -136,6 +148,9 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.

// Open a listening udp socket for each of the discoveryItems
for _, discoveryItem := range discoveryItems {
if err := ctx.Err(); err != nil {
return err
}
responseAddr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("%s:%d", discoveryItem.localAddress, model.AdsDiscoveryConstants_ADSDISCOVERYUDPDEFAULTPORT))
if err != nil {
return errors.Wrap(err, "error resolving udp")
Expand All @@ -158,6 +173,10 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
}()
buf := make([]byte, 1024)
for {
if err := ctx.Err(); err != nil {
d.log.Debug().Err(ctx.Err()).Msg("ending")
return
}
length, fromAddr, err := socket.ReadFromUDP(buf)
if length == 0 {
continue
Expand All @@ -181,6 +200,10 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
var versionBlock model.AdsDiscoveryBlockVersion
var fingerprintBlock model.AdsDiscoveryBlockFingerprint
for _, block := range discoveryResponse.GetBlocks() {
if err := ctx.Err(); err != nil {
d.log.Debug().Err(err).Msg("ending")
return
}
switch block.GetBlockType() {
case model.AdsDiscoveryBlockType_HOST_NAME:
hostNameBlock = block.(model.AdsDiscoveryBlockHostName)
Expand Down Expand Up @@ -245,6 +268,10 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
}
defer func() {
for _, discoveryItem := range discoveryItems {
if err := ctx.Err(); err != nil {
d.log.Debug().Err(err).Msg("ending")
return
}
if discoveryItem.socket != nil {
if err := discoveryItem.socket.Close(); err != nil {
d.log.Debug().Err(err).Msg("errored")
Expand All @@ -258,6 +285,9 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.

// Iterate over all network devices of this system.
for _, discoveryItem := range discoveryItems {
if err := ctx.Err(); err != nil {
return err
}
// Prepare the discovery packet data
// Create the discovery request message for this device.
amsNetId := model.NewAmsNetId(discoveryItem.localAddress[0], discoveryItem.localAddress[1], discoveryItem.localAddress[2], discoveryItem.localAddress[3], uint8(1), uint8(1))
Expand Down
25 changes: 23 additions & 2 deletions plc4go/internal/bacnetip/Discoverer.go
Expand Up @@ -64,7 +64,7 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
return errors.Wrap(err, "error extracting protocol specific options")
}

communicationChannels, err := buildupCommunicationChannels(interfaces, specificOptions.bacNetPort)
communicationChannels, err := buildupCommunicationChannels(ctx, interfaces, specificOptions.bacNetPort)
if err != nil {
return errors.Wrap(err, "error building communication channels")
}
Expand All @@ -90,6 +90,9 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
func (d *Discoverer) broadcastAndDiscover(ctx context.Context, communicationChannels []communicationChannel, specificOptions *protocolSpecificOptions) (chan receivedBvlcMessage, error) {
incomingBVLCChannel := make(chan receivedBvlcMessage)
for _, communicationChannelInstance := range communicationChannels {
if err := ctx.Err(); err != nil {
return incomingBVLCChannel, err
}
// Prepare the discovery packet data
{
var lowLimit driverModel.BACnetContextTagUnsignedInteger
Expand Down Expand Up @@ -161,6 +164,10 @@ func (d *Discoverer) broadcastAndDiscover(ctx context.Context, communicationChan

go func(communicationChannelInstance communicationChannel) {
for {
if err := ctx.Err(); err != nil {
d.log.Debug().Err(err).Msg("ending")
return
}
blockingReadChan := make(chan bool)
go func() {
buf := make([]byte, 4096)
Expand Down Expand Up @@ -197,6 +204,10 @@ func (d *Discoverer) broadcastAndDiscover(ctx context.Context, communicationChan

go func(communicationChannelInstance communicationChannel) {
for {
if err := ctx.Err(); err != nil {
d.log.Debug().Err(err).Msg("ending")
return
}
blockingReadChan := make(chan bool)
go func() {
buf := make([]byte, 4096)
Expand Down Expand Up @@ -235,6 +246,10 @@ func (d *Discoverer) broadcastAndDiscover(ctx context.Context, communicationChan

func handleIncomingBVLCs(ctx context.Context, callback func(event apiModel.PlcDiscoveryItem), incomingBVLCChannel chan receivedBvlcMessage) {
for {
if err := ctx.Err(); err != nil {
// TODO: maybe we log something, but maybe it is fine
return
}
select {
case receivedBvlc := <-incomingBVLCChannel:
var npdu driverModel.NPDU
Expand Down Expand Up @@ -297,15 +312,21 @@ func handleIncomingBVLCs(ctx context.Context, callback func(event apiModel.PlcDi
}
}

func buildupCommunicationChannels(interfaces []net.Interface, bacNetPort int) (communicationChannels []communicationChannel, err error) {
func buildupCommunicationChannels(ctx context.Context, interfaces []net.Interface, bacNetPort int) (communicationChannels []communicationChannel, err error) {
// Iterate over all network devices of this system.
for _, networkInterface := range interfaces {
if err := ctx.Err(); err != nil {
return nil, err
}
unicastInterfaceAddress, err := networkInterface.Addrs()
if err != nil {
return nil, errors.Wrapf(err, "Error getting Addresses for %v", networkInterface)
}
// Iterate over all addresses the current interface has configured
for _, unicastAddress := range unicastInterfaceAddress {
if err := ctx.Err(); err != nil {
return nil, err
}
var ipAddr net.IP
switch addr := unicastAddress.(type) {
// If the device is configured to communicate with a subnet
Expand Down
17 changes: 16 additions & 1 deletion plc4go/internal/cbus/Discoverer.go
Expand Up @@ -84,6 +84,9 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
tcpTransport := tcp.NewTransport()
// Iterate over all network devices of this system.
for _, netInterface := range interfaces {
if err := ctx.Err(); err != nil {
return err
}
interfaceLog := d.log.With().Stringer("interface", netInterface).Logger()
interfaceLog.Debug().Msg("Scanning")
addrs, err := netInterface.Addrs()
Expand All @@ -105,6 +108,10 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
for _, addr := range addrs {
addressLogger := interfaceLog.With().Stringer("address", addr).Logger()
addressLogger.Debug().Msg("looking into")
if err := ctx.Err(); err != nil {
addressLogger.Debug().Err(err).Msg("ending")
return
}
var ipv4Addr net.IP
switch addr.(type) {
// If the device is configured to communicate with a subnet
Expand Down Expand Up @@ -141,6 +148,10 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
}()
defer func() { wg.Done() }()
for ip := range addresses {
if err := ctx.Err(); err != nil {
addressLogger.Debug().Err(err).Msg("ending")
return
}
addressLogger.Trace().IPAddr("ip", ip).Msg("Handling found ip")
d.transportInstanceCreationQueue.Submit(
ctx,
Expand Down Expand Up @@ -177,12 +188,16 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
}()
deviceScanWg := sync.WaitGroup{}
for transportInstance := range transportInstances {
if err := ctx.Err(); err != nil {
d.log.Debug().Err(err).Msg("ending")
return
}
d.log.Debug().Stringer("transportInstance", transportInstance).Msg("submitting device scan")
completionFuture := d.deviceScanningQueue.Submit(ctx, d.deviceScanningWorkItemId.Add(1), d.createDeviceScanDispatcher(transportInstance.(*tcp.TransportInstance), callback))
deviceScanWg.Add(1)
go func() {
defer deviceScanWg.Done()
if err := completionFuture.AwaitCompletion(context.TODO()); err != nil {
if err := completionFuture.AwaitCompletion(ctx); err != nil {
d.log.Debug().Err(err).Msg("error waiting for completion")
}
}()
Expand Down
21 changes: 19 additions & 2 deletions plc4go/internal/knxnetip/Discoverer.go
Expand Up @@ -85,7 +85,13 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
deviceNames := options.FilterDiscoveryOptionsDeviceName(discoveryOptions)
if len(deviceNames) > 0 {
for _, curInterface := range allInterfaces {
if err := ctx.Err(); err != nil {
return err
}
for _, deviceNameOption := range deviceNames {
if err := ctx.Err(); err != nil {
return err
}
if curInterface.Name == deviceNameOption.GetDeviceName() {
interfaces = append(interfaces, curInterface)
break
Expand All @@ -100,6 +106,9 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
wg := &sync.WaitGroup{}
// Iterate over all network devices of this system.
for _, netInterface := range interfaces {
if err := ctx.Err(); err != nil {
return err
}
addrs, err := netInterface.Addrs()
if err != nil {
return err
Expand All @@ -119,6 +128,10 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
// For KNX we're only interested in IPv4 addresses, as it doesn't
// seem to work with IPv6.
for _, addr := range addrs {
if err := ctx.Err(); err != nil {
d.log.Debug().Err(err).Msg("done")
return
}
var ipv4Addr net.IP
switch addr.(type) {
// If the device is configured to communicate with a subnet
Expand Down Expand Up @@ -156,7 +169,7 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
}
}()
for transportInstance := range transportInstances {
d.deviceScanningQueue.Submit(ctx, d.deviceScanningWorkItemId.Add(1), d.createDeviceScanDispatcher(transportInstance.(*udp.TransportInstance), callback))
d.deviceScanningQueue.Submit(ctx, d.deviceScanningWorkItemId.Add(1), d.createDeviceScanDispatcher(ctx, transportInstance.(*udp.TransportInstance), callback))
}
}()
return nil
Expand Down Expand Up @@ -184,7 +197,7 @@ func (d *Discoverer) createTransportInstanceDispatcher(ctx context.Context, wg *
}
}

func (d *Discoverer) createDeviceScanDispatcher(udpTransportInstance *udp.TransportInstance, callback func(event apiModel.PlcDiscoveryItem)) pool.Runnable {
func (d *Discoverer) createDeviceScanDispatcher(ctx context.Context, udpTransportInstance *udp.TransportInstance, callback func(event apiModel.PlcDiscoveryItem)) pool.Runnable {
return func() {
d.log.Debug().Stringer("udpTransportInstance", udpTransportInstance).Msg("Scanning")
// Create a codec for sending and receiving messages.
Expand Down Expand Up @@ -216,6 +229,10 @@ func (d *Discoverer) createDeviceScanDispatcher(udpTransportInstance *udp.Transp
timeout := time.NewTimer(1 * time.Second)
timeout.Stop()
for start := time.Now(); time.Since(start) < time.Second*5; {
if err := ctx.Err(); err != nil {
d.log.Debug().Err(err).Msg("done")
return
}
timeout.Reset(1 * time.Second)
select {
case message := <-codec.GetDefaultIncomingMessageChannel():
Expand Down

0 comments on commit fae748c

Please sign in to comment.