Skip to content

Commit

Permalink
refactor(plc4go/opcua): restructure secure channel ownership
Browse files Browse the repository at this point in the history
  • Loading branch information
sruehl committed Aug 2, 2023
1 parent 51589ed commit 4cf782b
Show file tree
Hide file tree
Showing 11 changed files with 166 additions and 174 deletions.
70 changes: 51 additions & 19 deletions plc4go/internal/opcua/Connection.go
Expand Up @@ -22,7 +22,7 @@ package opcua
import (
"context"
"runtime/debug"
"sync"
"time"

"github.com/apache/plc4x/plc4go/pkg/api"
apiModel "github.com/apache/plc4x/plc4go/pkg/api/model"
Expand All @@ -45,7 +45,12 @@ type Connection struct {
configuration Configuration `stringer:"true"`
driverContext DriverContext `stringer:"true"`

handlerWaitGroup sync.WaitGroup
channel *SecureChannel

connectEvent chan struct{}
connectTimeout time.Duration `stringer:"true"` // TODO: do we need to have that in general, where to get that from
disconnectEvent chan struct{}
disconnectTimeout time.Duration `stringer:"true"` // TODO: do we need to have that in general, where to get that from

connectionId string
tracer tracer.Tracer
Expand All @@ -57,12 +62,16 @@ type Connection struct {
func NewConnection(messageCodec *MessageCodec, configuration Configuration, driverContext DriverContext, tagHandler spi.PlcTagHandler, connectionOptions map[string][]string, _options ...options.WithOption) *Connection {
customLogger := options.ExtractCustomLoggerOrDefaultToGlobal(_options...)
connection := &Connection{
messageCodec: messageCodec,
configuration: configuration,
driverContext: driverContext,

log: customLogger,
_options: _options,
messageCodec: messageCodec,
configuration: configuration,
driverContext: driverContext,
channel: NewSecureChannel(customLogger, driverContext, configuration),
connectEvent: make(chan struct{}),
connectTimeout: 5 * time.Second,
disconnectEvent: make(chan struct{}),
disconnectTimeout: 5 * time.Second,
log: customLogger,
_options: _options,
}
if traceEnabledOption, ok := connectionOptions["traceEnabled"]; ok {
if len(traceEnabledOption) == 1 {
Expand Down Expand Up @@ -132,10 +141,15 @@ func (c *Connection) Close() <-chan plc4go.PlcConnectionCloseResult {
results := make(chan plc4go.PlcConnectionCloseResult, 1)
go func() {
result := <-c.DefaultConnection.Close()
c.log.Trace().Msg("Waiting for handlers to stop")
c.handlerWaitGroup.Wait()
c.log.Trace().Msg("handlers stopped, dispatching result")
results <- result
c.channel.onDisconnect(context.Background(), c)
disconnectTimeout := time.NewTimer(c.disconnectTimeout)
select {
case <-c.disconnectEvent:
c.log.Info().Msg("disconnected")
results <- result
case <-disconnectTimeout.C:
results <- _default.NewDefaultPlcConnectionCloseResult(c, errors.Errorf("timeout after %s", c.disconnectTimeout))
}
}()
return results
}
Expand All @@ -153,14 +167,14 @@ func (c *Connection) ReadRequestBuilder() apiModel.PlcReadRequestBuilder {
return spiModel.NewDefaultPlcReadRequestBuilder(
c.GetPlcTagHandler(),
NewReader(
c.messageCodec,
c,
append(c._options, options.WithCustomLogger(c.log))...,
),
)
}

func (c *Connection) WriteRequestBuilder() apiModel.PlcWriteRequestBuilder {
return spiModel.NewDefaultPlcWriteRequestBuilder(c.GetPlcTagHandler(), c.GetPlcValueHandler(), NewWriter(c.messageCodec))
return spiModel.NewDefaultPlcWriteRequestBuilder(c.GetPlcTagHandler(), c.GetPlcValueHandler(), NewWriter(c))
}

func (c *Connection) SubscriptionRequestBuilder() apiModel.PlcSubscriptionRequestBuilder {
Expand All @@ -169,7 +183,7 @@ func (c *Connection) SubscriptionRequestBuilder() apiModel.PlcSubscriptionReques
c.GetPlcValueHandler(),
NewSubscriber(
c.addSubscriber,
c.messageCodec,
c,
append(c._options, options.WithCustomLogger(c.log))...,
),
)
Expand All @@ -190,10 +204,25 @@ func (c *Connection) addSubscriber(subscriber *Subscriber) {
c.subscribers = append(c.subscribers, subscriber)
}

func (c *Connection) setupConnection(_ context.Context, ch chan plc4go.PlcConnectionConnectResult) {
c.log.Trace().Msg("Connection setup done")
c.fireConnected(ch)
c.log.Trace().Msg("Connect fired")
func (c *Connection) setupConnection(ctx context.Context, ch chan plc4go.PlcConnectionConnectResult) {
c.log.Trace().Msg("setup connection")

c.log.Debug().Msg("Opcua Driver running in ACTIVE mode.")
c.channel.onConnect(ctx, c, ch)

connectTimeout := time.NewTimer(c.connectTimeout)
select {
case <-c.connectEvent:
c.log.Info().Msg("connected")
c.fireConnected(ch)
c.log.Trace().Msg("Connect fired")
case <-connectTimeout.C:
c.fireConnectionError(errors.Errorf("timeout after %s", c.connectTimeout), ch)
c.log.Trace().Msg("connection error fired")
return
}

c.log.Trace().Msg("connection setup done")
}

func (c *Connection) fireConnectionError(err error, ch chan<- plc4go.PlcConnectionConnectResult) {
Expand All @@ -206,6 +235,8 @@ func (c *Connection) fireConnectionError(err error, ch chan<- plc4go.PlcConnecti
if err := c.messageCodec.Disconnect(); err != nil {
c.log.Debug().Err(err).Msg("Error disconnecting message codec on connection error")
}
c.SetConnected(false)
close(c.disconnectEvent)
}

func (c *Connection) fireConnected(ch chan<- plc4go.PlcConnectionConnectResult) {
Expand All @@ -216,4 +247,5 @@ func (c *Connection) fireConnected(ch chan<- plc4go.PlcConnectionConnectResult)
c.log.Info().Msg("Successfully connected")
}
c.SetConnected(true)
close(c.connectEvent)
}
25 changes: 25 additions & 0 deletions plc4go/internal/opcua/Connection_plc4xgen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion plc4go/internal/opcua/Driver.go
Expand Up @@ -145,7 +145,6 @@ func (d *Driver) GetConnectionWithContext(ctx context.Context, transportUrl url.

codec := NewMessageCodec(
transportInstance,
NewSecureChannel(d.log, driverContext, configuration),
append(d._options, options.WithCustomLogger(d.log))...,
)
d.log.Debug().Stringer("codec", codec).Msg("working with codec")
Expand Down
62 changes: 4 additions & 58 deletions plc4go/internal/opcua/MessageCodec.go
Expand Up @@ -22,14 +22,12 @@ package opcua
import (
"context"
"encoding/binary"
"sync"
"time"

readWriteModel "github.com/apache/plc4x/plc4go/protocols/opcua/readwrite/model"
"github.com/apache/plc4x/plc4go/spi"
"github.com/apache/plc4x/plc4go/spi/default"
"github.com/apache/plc4x/plc4go/spi/options"
"github.com/apache/plc4x/plc4go/spi/transports"
"sync"

"github.com/pkg/errors"
"github.com/rs/zerolog"
Expand All @@ -39,30 +37,18 @@ import (
type MessageCodec struct {
_default.DefaultCodec

channel *SecureChannel

connectEvent chan struct{}
connectTimeout time.Duration `stringer:"true"` // TODO: do we need to have that in general, where to get that from
disconnectEvent chan struct{}
disconnectTimeout time.Duration `stringer:"true"` // TODO: do we need to have that in general, where to get that from

stateChange sync.Mutex

passLogToModel bool `ignore:"true"`
log zerolog.Logger `ignore:"true"`
}

func NewMessageCodec(transportInstance transports.TransportInstance, channel *SecureChannel, _options ...options.WithOption) *MessageCodec {
func NewMessageCodec(transportInstance transports.TransportInstance, _options ...options.WithOption) *MessageCodec {
passLoggerToModel, _ := options.ExtractPassLoggerToModel(_options...)
customLogger := options.ExtractCustomLoggerOrDefaultToGlobal(_options...)
codec := &MessageCodec{
channel: channel,
connectEvent: make(chan struct{}),
connectTimeout: 5 * time.Second,
disconnectEvent: make(chan struct{}),
disconnectTimeout: 5 * time.Second,
passLogToModel: passLoggerToModel,
log: customLogger,
passLogToModel: passLoggerToModel,
log: customLogger,
}
codec.DefaultCodec = _default.NewDefaultCodec(codec, transportInstance, _options...)
return codec
Expand All @@ -76,46 +62,6 @@ func (m *MessageCodec) Connect() error {
return m.ConnectWithContext(context.Background())
}

func (m *MessageCodec) ConnectWithContext(ctx context.Context) error {
m.log.Trace().Msg("connecting")
if err := m.DefaultCodec.ConnectWithContext(ctx); err != nil {
return errors.Wrap(err, "error connecting default codec")
}
m.log.Debug().Msg("Opcua Driver running in ACTIVE mode.")
m.channel.onConnect(ctx, m)

connectTimeout := time.NewTimer(m.connectTimeout)
select {
case <-m.connectEvent:
m.log.Info().Msg("connected")
case <-connectTimeout.C:
return errors.Errorf("timeout after %s", m.connectTimeout)
}
return nil
}

func (m *MessageCodec) fireConnected() {
m.log.Trace().Msg("fire connected event")
close(m.connectEvent)
}

func (m *MessageCodec) Disconnect() error {
m.log.Trace().Msg("disconnecting")
m.channel.onDisconnect(context.Background(), m)
disconnectTimeout := time.NewTimer(m.disconnectTimeout)
select {
case <-m.disconnectEvent:
m.log.Info().Msg("disconnected")
case <-disconnectTimeout.C:
return errors.Errorf("timeout after %s", m.disconnectTimeout)
}
return m.DefaultCodec.Disconnect()
}

func (m *MessageCodec) fireDisconnected() {
close(m.disconnectEvent)
}

func (m *MessageCodec) Send(message spi.Message) error {
m.log.Trace().Stringer("message", message).Msg("Sending message")
// Cast the message to the correct type of struct
Expand Down
25 changes: 0 additions & 25 deletions plc4go/internal/opcua/MessageCodec_plc4xgen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions plc4go/internal/opcua/Reader.go
Expand Up @@ -35,15 +35,15 @@ import (
)

type Reader struct {
messageCodec *MessageCodec
connection *Connection

log zerolog.Logger
}

func NewReader(messageCodec *MessageCodec, _options ...options.WithOption) *Reader {
func NewReader(connection *Connection, _options ...options.WithOption) *Reader {
customLogger := options.ExtractCustomLoggerOrDefaultToGlobal(_options...)
return &Reader{
messageCodec: messageCodec,
connection: connection,

log: customLogger,
}
Expand All @@ -64,9 +64,9 @@ func (m *Reader) readSync(ctx context.Context, readRequest apiModel.PlcReadReque
}()

requestHeader := readWriteModel.NewRequestHeader(
m.messageCodec.channel.getAuthenticationToken(),
m.messageCodec.channel.getCurrentDateTime(),
m.messageCodec.channel.getRequestHandle(),
m.connection.channel.getAuthenticationToken(),
m.connection.channel.getCurrentDateTime(),
m.connection.channel.getRequestHandle(),
0,
NULL_STRING,
REQUEST_TIMEOUT_LONG,
Expand Down Expand Up @@ -149,5 +149,5 @@ func (m *Reader) readSync(ctx context.Context, readRequest apiModel.PlcReadReque
result <- spiModel.NewDefaultPlcReadRequestResult(readRequest, nil, err)
}

m.messageCodec.channel.submit(ctx, m.messageCodec, errorDispatcher, consumer, buffer)
m.connection.channel.submit(ctx, m.connection.messageCodec, errorDispatcher, consumer, buffer)
}

0 comments on commit 4cf782b

Please sign in to comment.