Skip to content

Commit

Permalink
Modifications to accomodate server-side changes for version compatibi…
Browse files Browse the repository at this point in the history
…lity check between client and server.
  • Loading branch information
aaa3schavan committed Mar 21, 2019
1 parent f4001a6 commit 13ad36d
Show file tree
Hide file tree
Showing 11 changed files with 535 additions and 24 deletions.
23 changes: 17 additions & 6 deletions client/goAPI/channel/AbstractChannel.go
Expand Up @@ -78,22 +78,23 @@ var ConnectionsToChannel int32
type AbstractChannel struct {
AuthToken int64
ChannelProperties *utils.SortedProperties
ChannelUrl *LinkUrl
ClientId string
ConnectionIndex int
InboxAddress string
NeedsPing bool
NumOfConnections int32
LastActiveTime time.Time
LinkState types.LinkState
ChannelUrl *LinkUrl
PrimaryUrl *LinkUrl
Reader *ChannelReader
RequestId int64
Responses map[int64]types.TGChannelResponse
SessionId int64
Reader *ChannelReader
sendLock sync.Mutex // reentrant-lock for synchronizing sending/receiving messages over the wire
exceptionLock sync.Mutex // reentrant-lock for synchronizing sending/receiving messages over the wire
exceptionCond *sync.Cond // Condition for lock
sendLock sync.Mutex // reentrant-lock for synchronizing sending/receiving messages over the wire
//tracer types.Tracer // Used for tracing the information flow during the execution
}

func DefaultAbstractChannel() *AbstractChannel {
Expand Down Expand Up @@ -124,6 +125,17 @@ func NewAbstractChannel(linkUrl *LinkUrl, props *utils.SortedProperties) *Abstra
newChannel.ChannelUrl = linkUrl
newChannel.PrimaryUrl = linkUrl
newChannel.ChannelProperties = props
//enableTraceFlag := newChannel.ChannelProperties.GetPropertyAsBoolean(utils.GetConfigFromKey(utils.EnableConnectionTrace))
//if enableTraceFlag {
// traceDir := newChannel.ChannelProperties.GetProperty(utils.GetConfigFromKey(utils.ConnectionTraceDir), ".")
// clientId := newChannel.ChannelProperties.GetProperty(utils.GetConfigFromKey(utils.ChannelClientId), "")
// tracer, err := NewChannelTracer(clientId, traceDir)
// if err != nil {
// enableTraceFlag = false
// } else {
// newChannel.tracer = tracer
// }
//}
return newChannel
}

Expand Down Expand Up @@ -523,9 +535,8 @@ func channelSendRequest(obj types.TGChannel, msg types.TGMessage, channelRespons
errMsg := fmt.Sprintf("AbstractChannel:channelSendRequest - Channel is closed")
return nil, exception.GetErrorByType(types.TGErrorGeneralException, types.TGDB_CHANNEL_ERROR, errMsg, "")
}
//if ! channelResponse.IsBlocking() {
// logger.Error(fmt.Sprint("ERROR: Returning AbstractChannel:channelSendRequest as channel response is NOT blocking"))
// return nil, nil
//if obj.tracer != nil {
// obj.tracer.Trace(msg)
//}
obj.ChannelLock()
logger.Log(fmt.Sprintf("Inside AbstractChannel:channelSendRequest about to set channel response '%+v' in map '%+v'", channelResponse, obj.GetResponses()))
Expand Down
37 changes: 35 additions & 2 deletions client/goAPI/channel/SslChannel.go
Expand Up @@ -279,8 +279,16 @@ func (obj *SSLChannel) performHandshake(sslMode bool) types.TGError {
return exception.NewTGGeneralException(types.TGDB_HNDSHKRESP_ERROR, types.TGErrorGeneralException, errMsg, "")
}

challenge := response.GetChallenge()
challenge = challenge * 2 / 3
// Validate the version specific information on the response object
serverVersion := response.GetChallenge()
clientVersion := utils.GetClientVersion()
err = obj.validateHandshakeResponseVersion(serverVersion, clientVersion)
if err != nil {
logger.Error(fmt.Sprintf("ERROR: Returning TCPChannel::performHandshake validateHandshakeResponseVersion failed w/ '%+v'", err.Error()))
return err
}

challenge := clientVersion.GetVersionAsLong()

// Ignore Error Handling
_ = msgRequest.(*pdu.HandShakeRequestMessage).UpdateSequenceAndTimeStamp(-1)
Expand Down Expand Up @@ -380,6 +388,23 @@ func (obj *SSLChannel) tryRead() (types.TGMessage, types.TGError) {
return obj.ReadWireMsg()
}

func (obj *SSLChannel) validateHandshakeResponseVersion(sVersion int64, cVersion *utils.TGClientVersion) types.TGError {
serverVersion := utils.NewTGServerVersion(sVersion)
sStrVer := serverVersion.GetVersionString()

cStrVer := cVersion.GetVersionString()

if serverVersion.GetMajor() == cVersion.GetMajor() &&
serverVersion.GetMajor() == cVersion.GetMajor() &&
serverVersion.GetMajor() == cVersion.GetMajor() {
return nil
}

errMsg := fmt.Sprintf("======> Inside SSLChannel:validateHandshakeResponseVersion - Version mismatch between client(%s) & server(%s)", cStrVer, sStrVer)
logger.Log(errMsg)
return exception.GetErrorByType(types.TGErrorVersionMismatchException, "", errMsg, "")
}

func (obj *SSLChannel) writeLoop(done chan bool) {
logger.Log(fmt.Sprint("======> Entering SSLChannel:writeLoop"))
for {
Expand Down Expand Up @@ -822,6 +847,14 @@ func (obj *SSLChannel) ReadWireMsg() (types.TGMessage, types.TGError) {
errMsg := msg.(*pdu.ExceptionMessage).GetExceptionMsg()
return nil, exception.GetErrorByType(types.TGErrorGeneralException, "", errMsg, "")
}

//if msg.GetVerbId() == pdu.VerbHandShakeResponse {
// if msg.GetResponseStatus() == pdu.ResponseChallengeFailed {
// errMsg := msg.GetErrorMessage()
// logger.Error(fmt.Sprintf("ERROR: Returning TCPChannel::ReadWireMsg msg.GetVerbId() == pdu.VerbHandShakeResponse w/ '%+v'", errMsg))
// return nil, exception.GetErrorByType(types.TGErrorVersionMismatchException, "", errMsg, "")
// }
//}
logger.Log(fmt.Sprintf("======> Returning SSLChannel:ReadWireMsg w/ Message as '%+v'", msg.String()))
return msg, nil
}
Expand Down
37 changes: 35 additions & 2 deletions client/goAPI/channel/TcpChannel.go
Expand Up @@ -145,8 +145,16 @@ func (obj *TCPChannel) performHandshake(sslMode bool) types.TGError {
return exception.NewTGGeneralException(types.TGDB_HNDSHKRESP_ERROR, types.TGErrorGeneralException, errMsg, "")
}

challenge := response.GetChallenge()
challenge = challenge * 2 / 3
// Validate the version specific information on the response object
serverVersion := response.GetChallenge()
clientVersion := utils.GetClientVersion()
err = obj.validateHandshakeResponseVersion(serverVersion, clientVersion)
if err != nil {
logger.Error(fmt.Sprintf("ERROR: Returning TCPChannel::performHandshake validateHandshakeResponseVersion failed w/ '%+v'", err.Error()))
return err
}

challenge := clientVersion.GetVersionAsLong()

// Ignore Error Handling
_ = msgRequest.(*pdu.HandShakeRequestMessage).UpdateSequenceAndTimeStamp(-1)
Expand Down Expand Up @@ -253,6 +261,23 @@ func (obj *TCPChannel) tryRead() (types.TGMessage, types.TGError) {
return obj.ReadWireMsg()
}

func (obj *TCPChannel) validateHandshakeResponseVersion(sVersion int64, cVersion *utils.TGClientVersion) types.TGError {
serverVersion := utils.NewTGServerVersion(sVersion)
sStrVer := serverVersion.GetVersionString()

cStrVer := cVersion.GetVersionString()

if serverVersion.GetMajor() == cVersion.GetMajor() &&
serverVersion.GetMinor() == cVersion.GetMinor() &&
serverVersion.GetUpdate() == cVersion.GetUpdate() {
return nil
}

errMsg := fmt.Sprintf("======> Inside SSLChannel:validateHandshakeResponseVersion - Version mismatch between client(%s) & server(%s)", cStrVer, sStrVer)
logger.Log(errMsg)
return exception.GetErrorByType(types.TGErrorVersionMismatchException, "", errMsg, "")
}

func (obj *TCPChannel) writeLoop(done chan bool) {
logger.Log(fmt.Sprintf("======> Entering TCPChannel:writeLoop"))
for {
Expand Down Expand Up @@ -676,6 +701,14 @@ func (obj *TCPChannel) ReadWireMsg() (types.TGMessage, types.TGError) {
errMsg := msg.(*pdu.ExceptionMessage).GetExceptionMsg()
return nil, exception.GetErrorByType(types.TGErrorGeneralException, "", errMsg, "")
}

//if msg.GetVerbId() == pdu.VerbHandShakeResponse {
// if msg.GetResponseStatus() == pdu.ResponseChallengeFailed {
// errMsg := msg.GetErrorMessage()
// logger.Error(fmt.Sprintf("ERROR: Returning TCPChannel::ReadWireMsg msg.GetVerbId() == pdu.VerbHandShakeResponse w/ '%+v'", errMsg))
// return nil, exception.GetErrorByType(types.TGErrorVersionMismatchException, "", errMsg, "")
// }
//}
logger.Log(fmt.Sprintf("======> Returning TCPChannel:ReadWireMsg w/ Socket '%+v' and Message as '%+v'", obj.socket, msg.String()))
return msg, nil
}
Expand Down
4 changes: 4 additions & 0 deletions client/goAPI/exception/ExceptionFactory.go
Expand Up @@ -64,6 +64,8 @@ func CreateExceptionByType(excpTypeId int) types.TGError {
return DefaultTGTypeCoercionNotSupported()
case types.TGErrorTypeNotSupported:
return DefaultTGTypeNotSupported()
case types.TGErrorVersionMismatchException:
return DefaultTGVersionMismatchException()

case types.TGErrorInvalidErrorCode:
fallthrough
Expand Down Expand Up @@ -116,6 +118,8 @@ func GetErrorByType(excpTypeId int, errorCode, errorMsg, errorDetails string) ty
return NewTGTypeCoercionNotSupported(errorCode, excpTypeId, errorMsg, errorDetails)
case types.TGErrorTypeNotSupported:
return NewTGTypeNotSupported(errorCode, excpTypeId, errorMsg, errorDetails)
case types.TGErrorVersionMismatchException:
return NewTGVersionMismatchException(errorCode, excpTypeId, errorMsg, errorDetails)

case types.TGErrorInvalidErrorCode:
fallthrough
Expand Down
91 changes: 91 additions & 0 deletions client/goAPI/exception/TGVersionMismatchException.go
@@ -0,0 +1,91 @@
package exception

import (
"fmt"
"github.com/TIBCOSoftware/tgdb-client/client/goAPI/types"
)

/**
* Copyright 2018-19 TIBCO Software Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); You may not use this file except
* in compliance with the License.
* A copy of the License is included in the distribution package with this file.
* You also may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF DirectionAny KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* File name: TGErrorVersionMismatchException.go
* Created on: Feb 20, 2019
* Created by: achavan
* SVN id: $id: $
*
*/

type VersionMismatchException struct {
*types.TGDBError
}

// Create New VersionMismatchException Instance
func DefaultTGVersionMismatchException() *VersionMismatchException {
newException := VersionMismatchException{
TGDBError: types.DefaultTGDBError(),
}
newException.ErrorType = types.TGErrorVersionMismatchException
return &newException
}

func NewTGVersionMismatchException(eCode string, eType int, eMsg, eDetails string) *VersionMismatchException {
newException := DefaultTGVersionMismatchException()
newException.ErrorCode = eCode
newException.ErrorType = eType
newException.ErrorMsg = eMsg
newException.ErrorDetails = eDetails
return newException
}

func NewTGVersionMismatchExceptionAttr(attrTypeName string) *VersionMismatchException {
newException := DefaultTGVersionMismatchException()
newException.ErrorMsg = fmt.Sprintf("Attribute descriptor: '%s' not supported", attrTypeName)
return newException
}

func NewTGVersionMismatchExceptionWithMsg(msg string) *VersionMismatchException {
newException := DefaultTGVersionMismatchException()
newException.ErrorMsg = msg
return newException
}

/////////////////////////////////////////////////////////////////
// Implement functions from Interface ==> TGError
/////////////////////////////////////////////////////////////////

func (e *VersionMismatchException) GetErrorCode() string {
return e.ErrorCode
}

func (e *VersionMismatchException) GetErrorType() int {
return e.ErrorType
}

func (e *VersionMismatchException) GetErrorMsg() string {
return e.ErrorMsg
}

func (e *VersionMismatchException) GetErrorDetails() string {
return e.ErrorDetails
}

/////////////////////////////////////////////////////////////////
// Implement functions from Interface ==> error
/////////////////////////////////////////////////////////////////

func (e *VersionMismatchException) Error() string {
errMsg := fmt.Sprintf("ErrorCode: %s, ErrorType: %d, ErrorMessage: %s, ErrorDetails: %s", e.ErrorCode, e.ErrorType, e.ErrorMsg, e.ErrorDetails)
return errMsg
}

25 changes: 18 additions & 7 deletions client/goAPI/pdu/HandshakeRequest.go
Expand Up @@ -43,8 +43,9 @@ const (
type HandShakeRequestMessage struct {
*AbstractProtocolMessage
sslMode bool
challenge int
challenge int64
handshakeType int
version int64
}

func DefaultHandShakeRequestMessage() *HandShakeRequestMessage {
Expand All @@ -60,6 +61,7 @@ func DefaultHandShakeRequestMessage() *HandShakeRequestMessage {
newMsg.VerbId = VerbHandShakeRequest
newMsg.sslMode = false
newMsg.challenge = 0
newMsg.version = 0
newMsg.handshakeType = InvalidRequest
newMsg.BufLength = int(reflect.TypeOf(newMsg).Size())
return &newMsg
Expand All @@ -82,26 +84,34 @@ func (msg *HandShakeRequestMessage) GetSslMode() bool {
return msg.sslMode
}

func (msg *HandShakeRequestMessage) GetChallenge() int {
func (msg *HandShakeRequestMessage) GetChallenge() int64 {
return msg.challenge
}

func (msg *HandShakeRequestMessage) GetRequestType() int {
return msg.handshakeType
}

func (msg *HandShakeRequestMessage) GetVersion() int64 {
return msg.version
}

func (msg *HandShakeRequestMessage) SetSslMode(mode bool) {
msg.sslMode = mode
}

func (msg *HandShakeRequestMessage) SetChallenge(challenge int) {
func (msg *HandShakeRequestMessage) SetChallenge(challenge int64) {
msg.challenge = challenge
}

func (msg *HandShakeRequestMessage) SetRequestType(rType int) {
msg.handshakeType = rType
}

func (msg *HandShakeRequestMessage) SetVersion(version int64) {
msg.version = version
}

/////////////////////////////////////////////////////////////////
// Implement functions from Interface ==> TGMessage
/////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -240,6 +250,7 @@ func (msg *HandShakeRequestMessage) String() string {
buffer.WriteString(fmt.Sprintf("SslMode: %+v", msg.sslMode))
buffer.WriteString(fmt.Sprintf(", Challenge: %d", msg.challenge))
buffer.WriteString(fmt.Sprintf(", HandshakeType: %d", msg.handshakeType))
buffer.WriteString(fmt.Sprintf(", Version: %d", msg.version))
buffer.WriteString(fmt.Sprintf(", BufLength: %d", msg.BufLength))
strArray := []string{buffer.String(), msg.messageToString()+"}"}
msgStr := strings.Join(strArray, ", ")
Expand Down Expand Up @@ -281,7 +292,7 @@ func (msg *HandShakeRequestMessage) ReadPayload(is types.TGInputStream) types.TG
}
logger.Log(fmt.Sprintf("Inside HandShakeRequestMessage:ReadPayload read mode as '%+v'", mode))

challenge, err := is.(*iostream.ProtocolDataInputStream).ReadInt()
challenge, err := is.(*iostream.ProtocolDataInputStream).ReadLong()
if err != nil {
logger.Error(fmt.Sprint("ERROR: Returning HandShakeRequestMessage:ReadPayload w/ Error in reading challenge from message buffer"))
return err
Expand All @@ -301,7 +312,7 @@ func (msg *HandShakeRequestMessage) WritePayload(os types.TGOutputStream) types.
logger.Log(fmt.Sprintf("Entering HandShakeRequestMessage:WritePayload at output buffer position: '%d'", startPos))
os.(*iostream.ProtocolDataOutputStream).WriteByte(msg.GetRequestType())
os.(*iostream.ProtocolDataOutputStream).WriteBoolean(msg.GetSslMode())
os.(*iostream.ProtocolDataOutputStream).WriteInt(msg.GetChallenge())
os.(*iostream.ProtocolDataOutputStream).WriteLong(msg.GetChallenge())
currPos := os.GetPosition()
length := currPos - startPos
logger.Log(fmt.Sprintf("Returning HandShakeRequestMessage::WritePayload at output buffer position at: %d after writing %d payload bytes", currPos, length))
Expand All @@ -316,7 +327,7 @@ func (msg *HandShakeRequestMessage) MarshalBinary() ([]byte, error) {
// A simple encoding: plain text.
var b bytes.Buffer
_, err := fmt.Fprintln(&b, msg.BufLength, msg.VerbId, msg.SequenceNo, msg.Timestamp,
msg.RequestId, msg.DataOffset, msg.AuthToken, msg.SessionId, msg.IsUpdatable, msg.sslMode, msg.challenge, msg.handshakeType)
msg.RequestId, msg.DataOffset, msg.AuthToken, msg.SessionId, msg.IsUpdatable, msg.sslMode, msg.challenge, msg.handshakeType, msg.version)
if err != nil {
logger.Error(fmt.Sprintf("ERROR: Returning HandShakeRequestMessage:MarshalBinary w/ Error: '%+v'", err.Error()))
return nil, err
Expand All @@ -334,7 +345,7 @@ func (msg *HandShakeRequestMessage) UnmarshalBinary(data []byte) error {
b := bytes.NewBuffer(data)
_, err := fmt.Fscanln(b, &msg.BufLength, &msg.VerbId, &msg.SequenceNo,
&msg.Timestamp, &msg.RequestId, &msg.DataOffset, &msg.AuthToken, &msg.SessionId, &msg.IsUpdatable,
&msg.sslMode, &msg.challenge, &msg.handshakeType)
&msg.sslMode, &msg.challenge, &msg.handshakeType, &msg.version)
if err != nil {
logger.Error(fmt.Sprintf("ERROR: Returning HandShakeRequestMessage:UnmarshalBinary w/ Error: '%+v'", err.Error()))
return err
Expand Down

0 comments on commit 13ad36d

Please sign in to comment.