Skip to content

Commit

Permalink
feat(plc4go/opcua): implement keepalive and connect event
Browse files Browse the repository at this point in the history
  • Loading branch information
sruehl committed Jul 28, 2023
1 parent 36673bd commit fb1a6d6
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 15 deletions.
31 changes: 30 additions & 1 deletion plc4go/internal/opcua/MessageCodec.go
Expand Up @@ -22,12 +22,14 @@ package opcua
import (
"context"
"encoding/binary"
"sync"
"time"

readWriteModel "github.com/apache/plc4x/plc4go/protocols/opcua/readwrite/model"
"github.com/apache/plc4x/plc4go/spi"
"github.com/apache/plc4x/plc4go/spi/default"
"github.com/apache/plc4x/plc4go/spi/options"
"github.com/apache/plc4x/plc4go/spi/transports"
"sync"

"github.com/pkg/errors"
"github.com/rs/zerolog"
Expand All @@ -39,6 +41,9 @@ type MessageCodec struct {

channel *SecureChannel

connectEvent chan struct{}
connectTimeout time.Duration // TODO: do we need to have that in general, where to get that from

stateChange sync.Mutex

passLogToModel bool `ignore:"true"`
Expand All @@ -50,6 +55,8 @@ func NewMessageCodec(transportInstance transports.TransportInstance, channel *Se
customLogger := options.ExtractCustomLoggerOrDefaultToGlobal(_options...)
codec := &MessageCodec{
channel: channel,
connectEvent: make(chan struct{}),
connectTimeout: 5 * time.Second,
passLogToModel: passLoggerToModel,
log: customLogger,
}
Expand All @@ -71,9 +78,31 @@ func (m *MessageCodec) ConnectWithContext(ctx context.Context) error {
}
m.log.Debug().Msg("Opcua Driver running in ACTIVE mode.")
m.channel.onConnect(ctx, m)

connectTimeout := time.NewTimer(m.connectTimeout)
select {
case <-m.connectEvent:
m.log.Info().Msg("connected")
case <-connectTimeout.C:
return errors.Errorf("timeout after %s", m.connectTimeout)
}
return nil
}

func (m *MessageCodec) fireConnected() {
close(m.connectEvent)
}

func (m *MessageCodec) Disconnect() error {
// TODO: implement me, e.g. wait group above or something
// TODO: on Disconecct
return m.DefaultCodec.Disconnect()
}

func (m *MessageCodec) fireDisconnected() {
// TODO: implement me, e.g. wait group above or something
}

func (m *MessageCodec) Send(message spi.Message) error {
m.log.Trace().Msgf("Sending message\n%s", message)
// Cast the message to the correct type of struct
Expand Down
182 changes: 168 additions & 14 deletions plc4go/internal/opcua/SecureChannel.go
Expand Up @@ -23,23 +23,25 @@ import (
"bytes"
"context"
"encoding/binary"
"github.com/apache/plc4x/plc4go/spi"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"golang.org/x/exp/slices"
"math"
"math/rand"
"net"
"net/url"
"regexp"
"strconv"
"sync"
"sync/atomic"
"time"

apiModel "github.com/apache/plc4x/plc4go/pkg/api/model"
readWriteModel "github.com/apache/plc4x/plc4go/protocols/opcua/readwrite/model"
"github.com/apache/plc4x/plc4go/spi"
"github.com/apache/plc4x/plc4go/spi/utils"

"github.com/dchest/uniuri"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"golang.org/x/exp/slices"
)

const (
Expand Down Expand Up @@ -83,7 +85,7 @@ var (
APPLICATION_URI = readWriteModel.NewPascalString("urn:apache:plc4x:client")
PRODUCT_URI = readWriteModel.NewPascalString("urn:apache:plc4x:client")
APPLICATION_TEXT = readWriteModel.NewPascalString("OPCUA client for the Apache PLC4X:PLC4J project")
DEFAULT_CONNECTION_LIFETIME = 36000000
DEFAULT_CONNECTION_LIFETIME = uint32(36000000)
)

type SecureChannel struct {
Expand Down Expand Up @@ -114,8 +116,10 @@ type SecureChannel struct {
authenticationToken readWriteModel.NodeIdTypeDefinition
codec *MessageCodec
channelTransactionManager *SecureChannelTransactionManager
lifetime int
keepAlive func()
lifetime uint32
keepAliveStateChange sync.Mutex
keepAliveIndicator atomic.Bool
keepAliveWg sync.WaitGroup
sendBufferSize int
maxMessageSize int
endpoints []string
Expand Down Expand Up @@ -226,8 +230,7 @@ func (s *SecureChannel) submit(ctx context.Context, codec *MessageCodec, errorDi
messageBuffer = opcuaResponse.GetMessage()
if !(s.senderSequenceNumber.Add(1) == (opcuaResponse.GetSequenceNumber())) {
s.log.Error().Msgf("Sequence number isn't as expected, we might have missed a packet. - %d != %d", s.senderSequenceNumber.Add(1), opcuaResponse.GetSequenceNumber())
// TODO: where to dispatch the disconnect too
// codec.fireDisconnected()
codec.fireDisconnected()
}
}
return true
Expand Down Expand Up @@ -654,8 +657,8 @@ func (s *SecureChannel) onConnectActivateSessionRequest(ctx context.Context, cod
}

// Send an event that connection setup is complete.
s.keepAlive = s.createKeepAlive()
// codec.fireConnected()// TODO: how to do that
s.keepAlive()
codec.fireConnected()
} else {
serviceFault := unknownExtensionObject.(readWriteModel.ServiceFault)
header := serviceFault.GetResponseHeader().(readWriteModel.ResponseHeader)
Expand All @@ -676,9 +679,160 @@ func (s *SecureChannel) getRequestHandle() uint32 {
return s.requestHandleGenerator.Add(1)
}

func (s *SecureChannel) createKeepAlive() func() {
//TODO big wip: look for keepalive method not sure how to implement that properly
return nil
func (s *SecureChannel) keepAlive() {
s.keepAliveStateChange.Lock()
defer s.keepAliveStateChange.Unlock()
if s.keepAliveIndicator.Load() {
s.log.Warn().Msg("keepalive already running")
return
}
s.keepAliveWg.Add(1)
go func() {
defer s.keepAliveWg.Done()
s.keepAliveIndicator.Store(true)
defer s.keepAliveIndicator.Store(false)
ctx := context.Background()
for s.codec == nil || s.codec.IsRunning() {
sleepTime := time.Duration(math.Ceil(float64(s.lifetime)*0.75)) * time.Millisecond
s.log.Trace().Dur("sleepTime", sleepTime).Msg("Sleeping")
time.Sleep(sleepTime)

transactionId := s.channelTransactionManager.getTransactionIdentifier()

requestHeader := readWriteModel.NewRequestHeader(readWriteModel.NewNodeId(s.authenticationToken),
s.getCurrentDateTime(),
0, //RequestHandle
0,
NULL_STRING,
REQUEST_TIMEOUT_LONG,
NULL_EXTENSION_OBJECT)

var openSecureChannelRequest readWriteModel.OpenSecureChannelRequest
if s.isEncrypted {
openSecureChannelRequest = readWriteModel.NewOpenSecureChannelRequest(
requestHeader,
VERSION,
readWriteModel.SecurityTokenRequestType_securityTokenRequestTypeIssue,
readWriteModel.MessageSecurityMode_messageSecurityModeSignAndEncrypt,
readWriteModel.NewPascalByteString(int32(len(s.clientNonce)), s.clientNonce),
uint32(s.lifetime))
} else {
openSecureChannelRequest = readWriteModel.NewOpenSecureChannelRequest(
requestHeader,
VERSION,
readWriteModel.SecurityTokenRequestType_securityTokenRequestTypeIssue,
readWriteModel.MessageSecurityMode_messageSecurityModeNone,
NULL_BYTE_STRING,
uint32(s.lifetime))
}
identifier, err := strconv.ParseUint(openSecureChannelRequest.GetIdentifier(), 10, 16)
if err != nil {
s.log.Error().Err(err).Msg("error parsing identifier")
return
}

expandedNodeId := readWriteModel.NewExpandedNodeId(false, //Namespace Uri Specified
false, //Server Index Specified
readWriteModel.NewNodeIdFourByte(0, uint16(identifier)),
nil,
nil)

extObject := readWriteModel.NewExtensionObject(
expandedNodeId,
nil,
openSecureChannelRequest,
false,
)

buffer := utils.NewWriteBufferByteBased(utils.WithByteOrderForByteBasedBuffer(binary.LittleEndian))
if err := extObject.SerializeWithWriteBuffer(ctx, buffer); err != nil {
s.log.Error().Err(err).Msg("error serializing")
return
}

openRequest := readWriteModel.NewOpcuaOpenRequest(
FINAL_CHUNK,
0,
readWriteModel.NewPascalString(s.securityPolicy),
s.publicCertificate,
s.thumbprint,
transactionId,
transactionId,
buffer.GetBytes(),
)

var apu readWriteModel.OpcuaAPU

if s.isEncrypted {
apu, err = readWriteModel.OpcuaAPUParse(ctx, s.encryptionHandler.encodeMessage(openRequest, buffer.GetBytes()), false)
if err != nil {
s.log.Error().Err(err).Msg("error parsing")
return
}
} else {
apu = readWriteModel.NewOpcuaAPU(openRequest, false)
}

requestConsumer := func(transactionId int32) {
if err := s.codec.SendRequest(
ctx,
apu,
func(message spi.Message) bool {
opcuaAPU, ok := message.(readWriteModel.OpcuaAPUExactly)
if !ok {
s.log.Debug().Type("type", message).Msg("Not relevant")
return false
}
messagePDU := opcuaAPU.GetMessage()
openResponse, ok := messagePDU.(readWriteModel.OpcuaOpenResponseExactly)
if !ok {
s.log.Debug().Type("type", messagePDU).Msg("Not relevant")
return false
}
return openResponse.GetRequestId() == transactionId
},
func(message spi.Message) error {
opcuaAPU := message.(readWriteModel.OpcuaAPU)
messagePDU := opcuaAPU.GetMessage()
opcuaOpenResponse := messagePDU.(readWriteModel.OpcuaOpenResponse)
readBuffer := utils.NewReadBufferByteBased(opcuaOpenResponse.GetMessage(), utils.WithByteOrderForReadBufferByteBased(binary.LittleEndian))
extensionObject, err := readWriteModel.ExtensionObjectParseWithBuffer(ctx, readBuffer, false)
if err != nil {
return errors.Wrap(err, "error parsing")
}

if fault, ok := extensionObject.GetBody().(readWriteModel.ServiceFaultExactly); ok {
statusCode := fault.GetResponseHeader().(readWriteModel.ResponseHeader).GetServiceResult().GetStatusCode()
statusCodeByValue, _ := readWriteModel.OpcuaStatusCodeByValue(statusCode)
s.log.Error().Msgf("Failed to connect to opc ua server for the following reason:- %v, %v",
statusCode,
statusCodeByValue)
} else {
s.log.Debug().Msg("Got Secure Response Connection Response")
openSecureChannelResponse := extensionObject.GetBody().(readWriteModel.OpenSecureChannelResponse)
token := openSecureChannelResponse.GetSecurityToken().(readWriteModel.ChannelSecurityToken)
s.tokenId.Store(int32(token.GetTokenId())) // TODO: strange that int32 and uint32 missmatch
s.channelId.Store(int32(token.GetChannelId()))
s.lifetime = token.GetRevisedLifetime()
}
return nil
},
func(err error) error {
s.log.Debug().Err(err).Msg("error submitting")
return nil
},
REQUEST_TIMEOUT,
); err != nil {
s.log.Debug().Err(err).Msg("a error")
}
}
s.log.Debug().Msgf("Submitting OpenSecureChannel with id of %d", transactionId)
if err := s.channelTransactionManager.submit(requestConsumer, transactionId); err != nil {
s.log.Debug().Err(err).Msg("error submitting")
}
}
}()
return
}

func (s *SecureChannel) selectEndpoint(sessionResponse readWriteModel.CreateSessionResponse) {
Expand Down

0 comments on commit fb1a6d6

Please sign in to comment.