diff --git a/go.mod b/go.mod index 74f08a2..334e63d 100644 --- a/go.mod +++ b/go.mod @@ -3,31 +3,22 @@ module github.com/ruelala/arconn go 1.21 require ( - github.com/aws/aws-sdk-go v1.49.13 github.com/aws/aws-sdk-go-v2 v1.24.0 github.com/aws/aws-sdk-go-v2/config v1.26.2 github.com/aws/aws-sdk-go-v2/credentials v1.16.13 github.com/aws/aws-sdk-go-v2/service/ec2 v1.142.0 github.com/aws/aws-sdk-go-v2/service/ecs v1.35.6 github.com/aws/aws-sdk-go-v2/service/ssm v1.44.6 + github.com/aws/session-manager-plugin v1.2.536 github.com/aws/smithy-go v1.19.0 github.com/buger/jsonparser v1.1.1 - github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 - github.com/eiannone/keyboard v0.0.0-20220611211555-0d226195f203 - github.com/fsnotify/fsnotify v1.7.0 - github.com/gorilla/websocket v1.5.1 github.com/integrii/flaggy v1.5.2 github.com/manifoldco/promptui v0.9.0 - github.com/stretchr/testify v1.8.4 - github.com/twinj/uuid v1.0.0 - github.com/xtaci/smux v1.5.24 - golang.org/x/crypto v0.17.0 golang.org/x/mod v0.14.0 - golang.org/x/sync v0.5.0 - golang.org/x/sys v0.15.0 ) require ( + github.com/aws/aws-sdk-go v1.49.14 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9 // indirect @@ -38,14 +29,26 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.26.6 // indirect github.com/chzyer/readline v1.5.1 // indirect + github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/eiannone/keyboard v0.0.0-20220611211555-0d226195f203 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gorilla/websocket v1.5.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/myesui/uuid v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.1 // indirect + github.com/stretchr/testify v1.8.4 // indirect + github.com/twinj/uuid v1.0.0 // indirect + github.com/xtaci/smux v1.5.24 // indirect + golang.org/x/crypto v0.17.0 // indirect golang.org/x/net v0.19.0 // indirect + golang.org/x/sync v0.5.0 // indirect + golang.org/x/sys v0.15.0 // indirect golang.org/x/term v0.15.0 // indirect gopkg.in/stretchr/testify.v1 v1.2.2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/aws/session-manager-plugin => github.com/RueLaLa/session-manager-plugin v1.3.1 diff --git a/go.sum b/go.sum index 482222d..bfc77a8 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ -github.com/aws/aws-sdk-go v1.49.13 h1:f4mGztsgnx2dR9r8FQYa9YW/RsKb+N7bgef4UGrOW1Y= -github.com/aws/aws-sdk-go v1.49.13/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/RueLaLa/session-manager-plugin v1.3.1 h1:O85kF/AZcwcjJeEbg9m0YIuDJqOmQA0AN1oleaZ68eY= +github.com/RueLaLa/session-manager-plugin v1.3.1/go.mod h1:kpIkvJv/BH88I5YN4Qh9Geq9FbCa1D9WtiRZsF37PuU= +github.com/aws/aws-sdk-go v1.49.14 h1:AZ7wfESxXuqQElXRnDCaohJSUSaf2s7c2uPB7g5js/w= +github.com/aws/aws-sdk-go v1.49.14/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/aws/aws-sdk-go-v2 v1.24.0 h1:890+mqQ+hTpNuw0gGP6/4akolQkSToDJgHfQE7AwGuk= github.com/aws/aws-sdk-go-v2 v1.24.0/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= github.com/aws/aws-sdk-go-v2/config v1.26.2 h1:+RWLEIWQIGgrz2pBPAUoGgNGs1TOyF4Hml7hCnYj2jc= diff --git a/pkg/awsClients/ssm/ssm.go b/pkg/awsClients/ssm/ssm.go index 4511d95..8754ee2 100644 --- a/pkg/awsClients/ssm/ssm.go +++ b/pkg/awsClients/ssm/ssm.go @@ -9,11 +9,11 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session" + _ "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session/portsession" + _ "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session/shellsession" "github.com/manifoldco/promptui" "github.com/ruelala/arconn/pkg/awsClients/AwsConfig" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session" - _ "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession" - _ "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession" "github.com/ruelala/arconn/pkg/utils" ) diff --git a/pkg/session-manager-plugin/.gitignore b/pkg/session-manager-plugin/.gitignore deleted file mode 100644 index d614dde..0000000 --- a/pkg/session-manager-plugin/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -.DS_Store -.idea/ -*.iml -/Tools/bin/ -/Tools/pkg/ -/Tools/src/github.com/ -/Tools/src/golang.org/ -bin/ -build -/vendor/bin/ -/vendor/pkg/ -.vscode/ \ No newline at end of file diff --git a/pkg/session-manager-plugin/LICENSE b/pkg/session-manager-plugin/LICENSE deleted file mode 100644 index b45f513..0000000 --- a/pkg/session-manager-plugin/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/pkg/session-manager-plugin/README.md b/pkg/session-manager-plugin/README.md deleted file mode 100644 index bb94e79..0000000 --- a/pkg/session-manager-plugin/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# session-manager-plugin - -This package is a clone of the AWS [session-manager-plugin](https://github.com/aws/session-manager-plugin) with some unnecessary code stripped out and some manual fixes applied. diff --git a/pkg/session-manager-plugin/communicator/mocks/IWebSocketChannel.go b/pkg/session-manager-plugin/communicator/mocks/IWebSocketChannel.go deleted file mode 100644 index 2d89744..0000000 --- a/pkg/session-manager-plugin/communicator/mocks/IWebSocketChannel.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -// Code generated by mockery v1.0.0. DO NOT EDIT. -package mocks - -import ( - log "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - mock "github.com/stretchr/testify/mock" - - time "time" -) - -// IWebSocketChannel is an autogenerated mock type for the IWebSocketChannel type -type IWebSocketChannel struct { - mock.Mock -} - -// Close provides a mock function with given fields: _a0 -func (_m *IWebSocketChannel) Close(_a0 log.T) error { - ret := _m.Called(_a0) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T) error); ok { - r0 = rf(_a0) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// GetChannelToken provides a mock function with given fields: -func (_m *IWebSocketChannel) GetChannelToken() string { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - return r0 -} - -// GetStreamUrl provides a mock function with given fields: -func (_m *IWebSocketChannel) GetStreamUrl() string { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - return r0 -} - -// Initialize provides a mock function with given fields: _a0, channelUrl, channelToken -func (_m *IWebSocketChannel) Initialize(_a0 log.T, channelUrl string, channelToken string) { - _m.Called(_a0, channelUrl, channelToken) -} - -// Open provides a mock function with given fields: _a0 -func (_m *IWebSocketChannel) Open(_a0 log.T) error { - ret := _m.Called(_a0) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T) error); ok { - r0 = rf(_a0) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SendMessage provides a mock function with given fields: _a0, input, inputType -func (_m *IWebSocketChannel) SendMessage(_a0 log.T, input []byte, inputType int) error { - ret := _m.Called(_a0, input, inputType) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T, []byte, int) error); ok { - r0 = rf(_a0, input, inputType) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SetChannelToken provides a mock function with given fields: _a0 -func (_m *IWebSocketChannel) SetChannelToken(_a0 string) { - _m.Called(_a0) -} - -// SetOnError provides a mock function with given fields: onErrorHandler -func (_m *IWebSocketChannel) SetOnError(onErrorHandler func(error)) { - _m.Called(onErrorHandler) -} - -// SetOnMessage provides a mock function with given fields: onMessageHandler -func (_m *IWebSocketChannel) SetOnMessage(onMessageHandler func([]byte)) { - _m.Called(onMessageHandler) -} - -// StartPings provides a mock function with given fields: _a0, pingInterval -func (_m *IWebSocketChannel) StartPings(_a0 log.T, pingInterval time.Duration) { - _m.Called(_a0, pingInterval) -} diff --git a/pkg/session-manager-plugin/communicator/websocketchannel.go b/pkg/session-manager-plugin/communicator/websocketchannel.go deleted file mode 100644 index 085d735..0000000 --- a/pkg/session-manager-plugin/communicator/websocketchannel.go +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// this package implement base communicator for network connections. -package communicator - -import ( - "errors" - "sync" - "time" - - "github.com/gorilla/websocket" - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/websocketutil" -) - -// IWebSocketChannel is the interface for DataChannel. -type IWebSocketChannel interface { - Initialize(log log.T, channelUrl string, channelToken string) - Open(log log.T) error - Close(log log.T) error - SendMessage(log log.T, input []byte, inputType int) error - StartPings(log log.T, pingInterval time.Duration) - GetChannelToken() string - GetStreamUrl() string - SetChannelToken(string) - SetOnError(onErrorHandler func(error)) - SetOnMessage(onMessageHandler func([]byte)) -} - -// WebSocketChannel parent class for DataChannel. -type WebSocketChannel struct { - IWebSocketChannel - Url string - OnMessage func([]byte) - OnError func(error) - IsOpen bool - writeLock *sync.Mutex - Connection *websocket.Conn - ChannelToken string -} - -// GetChannelToken gets the channel token -func (webSocketChannel *WebSocketChannel) GetChannelToken() string { - return webSocketChannel.ChannelToken -} - -// SetChannelToken sets the channel token -func (webSocketChannel *WebSocketChannel) SetChannelToken(channelToken string) { - webSocketChannel.ChannelToken = channelToken -} - -// GetStreamUrl gets stream url -func (webSocketChannel *WebSocketChannel) GetStreamUrl() string { - return webSocketChannel.Url -} - -// SetOnError sets OnError field of websocket channel -func (webSocketChannel *WebSocketChannel) SetOnError(onErrorHandler func(error)) { - webSocketChannel.OnError = onErrorHandler -} - -// SetOnMessage sets OnMessage field of websocket channel -func (webSocketChannel *WebSocketChannel) SetOnMessage(onMessageHandler func([]byte)) { - webSocketChannel.OnMessage = onMessageHandler -} - -// Initialize initializes websocket channel fields -func (webSocketChannel *WebSocketChannel) Initialize(log log.T, channelUrl string, channelToken string) { - webSocketChannel.ChannelToken = channelToken - webSocketChannel.Url = channelUrl -} - -// StartPings starts the pinging process to keep the websocket channel alive. -func (webSocketChannel *WebSocketChannel) StartPings(log log.T, pingInterval time.Duration) { - - go func() { - for { - if webSocketChannel.IsOpen == false { - return - } - - log.Debug("WebsocketChannel: Send ping. Message.") - webSocketChannel.writeLock.Lock() - err := webSocketChannel.Connection.WriteMessage(websocket.PingMessage, []byte("keepalive")) - webSocketChannel.writeLock.Unlock() - if err != nil { - log.Errorf("Error while sending websocket ping: %v", err) - return - } - time.Sleep(pingInterval) - } - }() -} - -// SendMessage sends a byte message through the websocket connection. -// Examples of message type are websocket.TextMessage or websocket.Binary -func (webSocketChannel *WebSocketChannel) SendMessage(log log.T, input []byte, inputType int) error { - if webSocketChannel.IsOpen == false { - return errors.New("Can't send message: Connection is closed.") - } - - if len(input) < 1 { - return errors.New("Can't send message: Empty input.") - } - - webSocketChannel.writeLock.Lock() - err := webSocketChannel.Connection.WriteMessage(inputType, input) - webSocketChannel.writeLock.Unlock() - return err -} - -// Close closes the corresponding connection. -func (webSocketChannel *WebSocketChannel) Close(log log.T) error { - - log.Info("Closing websocket channel connection to: " + webSocketChannel.Url) - if webSocketChannel.IsOpen == true { - // Send signal to stop receiving message - webSocketChannel.IsOpen = false - return websocketutil.NewWebsocketUtil(log, nil).CloseConnection(webSocketChannel.Connection) - } - - log.Info("Websocket channel connection to: " + webSocketChannel.Url + " is already Closed!") - return nil -} - -// Open upgrades the http connection to a websocket connection. -func (webSocketChannel *WebSocketChannel) Open(log log.T) error { - // initialize the write mutex - webSocketChannel.writeLock = &sync.Mutex{} - - ws, err := websocketutil.NewWebsocketUtil(log, nil).OpenConnection(webSocketChannel.Url) - if err != nil { - return err - } - webSocketChannel.Connection = ws - webSocketChannel.IsOpen = true - webSocketChannel.StartPings(log, config.PingTimeInterval) - - // spin up a different routine to listen to the incoming traffic - go func() { - defer func() { - if msg := recover(); msg != nil { - log.Errorf("WebsocketChannel listener run panic: %v", msg) - } - }() - - retryCount := 0 - for { - if webSocketChannel.IsOpen == false { - log.Debugf("Ending the channel listening routine since the channel is closed: %s", - webSocketChannel.Url) - break - } - - messageType, rawMessage, err := webSocketChannel.Connection.ReadMessage() - if err != nil { - retryCount++ - if retryCount >= config.RetryAttempt { - log.Errorf("Reach the retry limit %v for receive messages.", config.RetryAttempt) - webSocketChannel.OnError(err) - break - } - log.Debugf("An error happened when receiving the message. Retried times: %v, Error: %v, Messagetype: %v", - retryCount, - err.Error(), - messageType) - } else if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage { - // We only accept text messages which are interpreted as UTF-8 or binary encoded text. - log.Errorf("Invalid message type. We only accept UTF-8 or binary encoded text. Message type: %v", messageType) - - } else { - retryCount = 0 - webSocketChannel.OnMessage(rawMessage) - } - } - }() - return nil -} diff --git a/pkg/session-manager-plugin/config/config.go b/pkg/session-manager-plugin/config/config.go deleted file mode 100644 index 08a2454..0000000 --- a/pkg/session-manager-plugin/config/config.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// config package implement configuration retrieval for session manager apis -package config - -import "time" - -const ( - RolePublishSubscribe = "publish_subscribe" - MessageSchemaVersion = "1.0" - DefaultTransmissionTimeout = 200 * time.Millisecond - DefaultRoundTripTime = 100 * time.Millisecond - DefaultRoundTripTimeVariation = 0 - ResendSleepInterval = 100 * time.Millisecond - ResendMaxAttempt = 3000 // 5 minutes / ResendSleepInterval - StreamDataPayloadSize = 1024 - OutgoingMessageBufferCapacity = 10000 - IncomingMessageBufferCapacity = 10000 - RTTConstant = 1.0 / 8.0 // Round trip time constant - RTTVConstant = 1.0 / 4.0 // Round trip time variation constant - ClockGranularity = 10 * time.Millisecond - MaxTransmissionTimeout = 1 * time.Second - RetryBase = 2 - DataChannelNumMaxRetries = 5 - DataChannelRetryInitialDelayMillis = 100 - DataChannelRetryMaxIntervalMillis = 5000 - RetryAttempt = 5 - PingTimeInterval = 5 * time.Minute - - // Plugin names - ShellPluginName = "Standard_Stream" - PortPluginName = "Port" - InteractiveCommandsPluginName = "InteractiveCommands" - NonInteractiveCommandsPluginName = "NonInteractiveCommands" - - //Agent Versions - TerminateSessionFlagSupportedAfterThisAgentVersion = "2.3.722.0" - TCPMultiplexingSupportedAfterThisAgentVersion = "3.0.196.0" - TCPMultiplexingWithSmuxKeepAliveDisabledAfterThisAgentVersion = "3.1.1511.0" -) diff --git a/pkg/session-manager-plugin/datachannel/mocks/IDataChannel.go b/pkg/session-manager-plugin/datachannel/mocks/IDataChannel.go deleted file mode 100644 index c9b14ea..0000000 --- a/pkg/session-manager-plugin/datachannel/mocks/IDataChannel.go +++ /dev/null @@ -1,339 +0,0 @@ -// Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -// Code generated by mockery 2.7.4. DO NOT EDIT. -package mocks - -import ( - list "container/list" - - communicator "github.com/ruelala/arconn/pkg/session-manager-plugin/communicator" - datachannel "github.com/ruelala/arconn/pkg/session-manager-plugin/datachannel" - log "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - message "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - mock "github.com/stretchr/testify/mock" -) - -// IDataChannel is an autogenerated mock type for the IDataChannel type -type IDataChannel struct { - mock.Mock -} - -// AddDataToIncomingMessageBuffer provides a mock function with given fields: streamMessage -func (_m *IDataChannel) AddDataToIncomingMessageBuffer(streamMessage datachannel.StreamingMessage) { - _m.Called(streamMessage) -} - -// AddDataToOutgoingMessageBuffer provides a mock function with given fields: streamMessage -func (_m *IDataChannel) AddDataToOutgoingMessageBuffer(streamMessage datachannel.StreamingMessage) { - _m.Called(streamMessage) -} - -// CalculateRetransmissionTimeout provides a mock function with given fields: _a0, streamingMessage -func (_m *IDataChannel) CalculateRetransmissionTimeout(_a0 log.T, streamingMessage datachannel.StreamingMessage) { - _m.Called(_a0, streamingMessage) -} - -// Close provides a mock function with given fields: _a0 -func (_m *IDataChannel) Close(_a0 log.T) error { - ret := _m.Called(_a0) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T) error); ok { - r0 = rf(_a0) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// DeregisterOutputStreamHandler provides a mock function with given fields: handler -func (_m *IDataChannel) DeregisterOutputStreamHandler(handler datachannel.OutputStreamDataMessageHandler) { - _m.Called(handler) -} - -// FinalizeDataChannelHandshake provides a mock function with given fields: _a0, tokenValue -func (_m *IDataChannel) FinalizeDataChannelHandshake(_a0 log.T, tokenValue string) error { - ret := _m.Called(_a0, tokenValue) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T, string) error); ok { - r0 = rf(_a0, tokenValue) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// GetAgentVersion provides a mock function with given fields: -func (_m *IDataChannel) GetAgentVersion() string { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - return r0 -} - -// GetSessionProperties provides a mock function with given fields: -func (_m *IDataChannel) GetSessionProperties() interface{} { - ret := _m.Called() - - var r0 interface{} - if rf, ok := ret.Get(0).(func() interface{}); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(interface{}) - } - } - - return r0 -} - -// GetSessionType provides a mock function with given fields: -func (_m *IDataChannel) GetSessionType() string { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - return r0 -} - -// GetStreamDataSequenceNumber provides a mock function with given fields: -func (_m *IDataChannel) GetStreamDataSequenceNumber() int64 { - ret := _m.Called() - - var r0 int64 - if rf, ok := ret.Get(0).(func() int64); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(int64) - } - - return r0 -} - -// GetWsChannel provides a mock function with given fields: -func (_m *IDataChannel) GetWsChannel() communicator.IWebSocketChannel { - ret := _m.Called() - - var r0 communicator.IWebSocketChannel - if rf, ok := ret.Get(0).(func() communicator.IWebSocketChannel); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(communicator.IWebSocketChannel) - } - } - - return r0 -} - -// Initialize provides a mock function with given fields: _a0, clientId, sessionId, targetId, isAwsCliUpgradeNeeded -func (_m *IDataChannel) Initialize(_a0 log.T, clientId string, sessionId string, targetId string, isAwsCliUpgradeNeeded bool) { - _m.Called(_a0, clientId, sessionId, targetId, isAwsCliUpgradeNeeded) -} - -// IsSessionTypeSet provides a mock function with given fields: -func (_m *IDataChannel) IsSessionTypeSet() chan bool { - ret := _m.Called() - - var r0 chan bool - if rf, ok := ret.Get(0).(func() chan bool); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(chan bool) - } - } - - return r0 -} - -// IsStreamMessageResendTimeout checks if resending a streaming message reaches timeout -func (_m *IDataChannel) IsStreamMessageResendTimeout() chan bool { - ret := _m.Called() - - var r0 chan bool - if rf, ok := ret.Get(0).(func() chan bool); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(chan bool) - } - } - - return r0 -} - -// Open provides a mock function with given fields: _a0 -func (_m *IDataChannel) Open(_a0 log.T) error { - ret := _m.Called(_a0) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T) error); ok { - r0 = rf(_a0) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// OutputMessageHandler provides a mock function with given fields: _a0, stopHandler, sessionID, rawMessage -func (_m *IDataChannel) OutputMessageHandler(_a0 log.T, stopHandler datachannel.Stop, sessionID string, rawMessage []byte) error { - ret := _m.Called(_a0, stopHandler, sessionID, rawMessage) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T, datachannel.Stop, string, []byte) error); ok { - r0 = rf(_a0, stopHandler, sessionID, rawMessage) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ProcessAcknowledgedMessage provides a mock function with given fields: _a0, acknowledgeMessageContent -func (_m *IDataChannel) ProcessAcknowledgedMessage(_a0 log.T, acknowledgeMessageContent message.AcknowledgeContent) error { - ret := _m.Called(_a0, acknowledgeMessageContent) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T, message.AcknowledgeContent) error); ok { - r0 = rf(_a0, acknowledgeMessageContent) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Reconnect provides a mock function with given fields: _a0 -func (_m *IDataChannel) Reconnect(_a0 log.T) error { - ret := _m.Called(_a0) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T) error); ok { - r0 = rf(_a0) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// RegisterOutputStreamHandler provides a mock function with given fields: handler, isSessionSpecificHandler -func (_m *IDataChannel) RegisterOutputStreamHandler(handler datachannel.OutputStreamDataMessageHandler, isSessionSpecificHandler bool) { - _m.Called(handler, isSessionSpecificHandler) -} - -// RemoveDataFromIncomingMessageBuffer provides a mock function with given fields: sequenceNumber -func (_m *IDataChannel) RemoveDataFromIncomingMessageBuffer(sequenceNumber int64) { - _m.Called(sequenceNumber) -} - -// RemoveDataFromOutgoingMessageBuffer provides a mock function with given fields: streamMessageElement -func (_m *IDataChannel) RemoveDataFromOutgoingMessageBuffer(streamMessageElement *list.Element) { - _m.Called(streamMessageElement) -} - -// ResendStreamDataMessageScheduler provides a mock function with given fields: _a0 -func (_m *IDataChannel) ResendStreamDataMessageScheduler(_a0 log.T) error { - ret := _m.Called(_a0) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T) error); ok { - r0 = rf(_a0) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SendAcknowledgeMessage provides a mock function with given fields: _a0, clientMessage -func (_m *IDataChannel) SendAcknowledgeMessage(_a0 log.T, clientMessage message.ClientMessage) error { - ret := _m.Called(_a0, clientMessage) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T, message.ClientMessage) error); ok { - r0 = rf(_a0, clientMessage) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SendFlag provides a mock function with given fields: _a0, flagType -func (_m *IDataChannel) SendFlag(_a0 log.T, flagType message.PayloadTypeFlag) error { - ret := _m.Called(_a0, flagType) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T, message.PayloadTypeFlag) error); ok { - r0 = rf(_a0, flagType) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SendInputDataMessage provides a mock function with given fields: _a0, payloadType, inputData -func (_m *IDataChannel) SendInputDataMessage(_a0 log.T, payloadType message.PayloadType, inputData []byte) error { - ret := _m.Called(_a0, payloadType, inputData) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T, message.PayloadType, []byte) error); ok { - r0 = rf(_a0, payloadType, inputData) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SendMessage provides a mock function with given fields: _a0, input, inputType -func (_m *IDataChannel) SendMessage(_a0 log.T, input []byte, inputType int) error { - ret := _m.Called(_a0, input, inputType) - - var r0 error - if rf, ok := ret.Get(0).(func(log.T, []byte, int) error); ok { - r0 = rf(_a0, input, inputType) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SetAgentVersion provides a mock function with given fields: agentVersion -func (_m *IDataChannel) SetAgentVersion(agentVersion string) { - _m.Called(agentVersion) -} - -// SetSessionType provides a mock function with given fields: sessionType -func (_m *IDataChannel) SetSessionType(sessionType string) { - _m.Called(sessionType) -} - -// SetWebsocket provides a mock function with given fields: _a0, streamUrl, tokenValue -func (_m *IDataChannel) SetWebsocket(_a0 log.T, streamUrl string, tokenValue string) { - _m.Called(_a0, streamUrl, tokenValue) -} - -// SetWsChannel provides a mock function with given fields: wsChannel -func (_m *IDataChannel) SetWsChannel(wsChannel communicator.IWebSocketChannel) { - _m.Called(wsChannel) -} diff --git a/pkg/session-manager-plugin/datachannel/streaming.go b/pkg/session-manager-plugin/datachannel/streaming.go deleted file mode 100644 index 88ad723..0000000 --- a/pkg/session-manager-plugin/datachannel/streaming.go +++ /dev/null @@ -1,955 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// datachannel package implement data channel for interactive sessions. -package datachannel - -import ( - "bytes" - "container/list" - "encoding/binary" - "encoding/json" - "errors" - "fmt" - "math" - "os" - "reflect" - "sync" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/kms/kmsiface" - "github.com/gorilla/websocket" - "github.com/ruelala/arconn/pkg/session-manager-plugin/communicator" - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - "github.com/ruelala/arconn/pkg/session-manager-plugin/encryption" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - "github.com/ruelala/arconn/pkg/session-manager-plugin/service" - "github.com/ruelala/arconn/pkg/session-manager-plugin/version" - "github.com/twinj/uuid" -) - -type IDataChannel interface { - Initialize(log log.T, clientId string, sessionId string, targetId string, isAwsCliUpgradeNeeded bool) - SetWebsocket(log log.T, streamUrl string, tokenValue string) - Reconnect(log log.T) error - SendFlag(log log.T, flagType message.PayloadTypeFlag) error - Open(log log.T) error - Close(log log.T) error - FinalizeDataChannelHandshake(log log.T, tokenValue string) error - SendInputDataMessage(log log.T, payloadType message.PayloadType, inputData []byte) error - ResendStreamDataMessageScheduler(log log.T) error - ProcessAcknowledgedMessage(log log.T, acknowledgeMessageContent message.AcknowledgeContent) error - OutputMessageHandler(log log.T, stopHandler Stop, sessionID string, rawMessage []byte) error - SendAcknowledgeMessage(log log.T, clientMessage message.ClientMessage) error - AddDataToOutgoingMessageBuffer(streamMessage StreamingMessage) - RemoveDataFromOutgoingMessageBuffer(streamMessageElement *list.Element) - AddDataToIncomingMessageBuffer(streamMessage StreamingMessage) - RemoveDataFromIncomingMessageBuffer(sequenceNumber int64) - CalculateRetransmissionTimeout(log log.T, streamingMessage StreamingMessage) - SendMessage(log log.T, input []byte, inputType int) error - RegisterOutputStreamHandler(handler OutputStreamDataMessageHandler, isSessionSpecificHandler bool) - DeregisterOutputStreamHandler(handler OutputStreamDataMessageHandler) - IsSessionTypeSet() chan bool - EndSession() error - IsSessionEnded() bool - IsStreamMessageResendTimeout() chan bool - GetSessionType() string - SetSessionType(sessionType string) - GetSessionProperties() interface{} - GetWsChannel() communicator.IWebSocketChannel - SetWsChannel(wsChannel communicator.IWebSocketChannel) - GetStreamDataSequenceNumber() int64 - GetAgentVersion() string - SetAgentVersion(agentVersion string) -} - -// DataChannel used for communication between the mgs and the cli. -type DataChannel struct { - wsChannel communicator.IWebSocketChannel - Role string - ClientId string - SessionId string - TargetId string - IsAwsCliUpgradeNeeded bool - //records sequence number of last acknowledged message received over data channel - ExpectedSequenceNumber int64 - //records sequence number of last stream data message sent over data channel - StreamDataSequenceNumber int64 - //buffer to store outgoing stream messages until acknowledged - //using linked list for this buffer as access to oldest message is required and it support faster deletion from any position of list - OutgoingMessageBuffer ListMessageBuffer - //buffer to store incoming stream messages if received out of sequence - //using map for this buffer as incoming messages can be out of order and retrieval would be faster by sequenceId - IncomingMessageBuffer MapMessageBuffer - //round trip time of latest acknowledged message - RoundTripTime float64 - //round trip time variation of latest acknowledged message - RoundTripTimeVariation float64 - //timeout used for resending unacknowledged message - RetransmissionTimeout time.Duration - // Encrypter to encrypt/decrypt if agent requests encryption - encryption encryption.IEncrypter - encryptionEnabled bool - - // SessionType - sessionType string - isSessionTypeSet chan bool - sessionProperties interface{} - - isSessionEnded bool - - // Used to detect if resending a streaming message reaches timeout - isStreamMessageResendTimeout chan bool - - // Handles data on output stream. Output stream is data outputted by the SSM agent and received here. - outputStreamHandlers []OutputStreamDataMessageHandler - isSessionSpecificHandlerSet bool - - // AgentVersion received during handshake - agentVersion string -} - -type ListMessageBuffer struct { - Messages *list.List - Capacity int - Mutex *sync.Mutex -} - -type MapMessageBuffer struct { - Messages map[int64]StreamingMessage - Capacity int - Mutex *sync.Mutex -} - -type StreamingMessage struct { - Content []byte - SequenceNumber int64 - LastSentTime time.Time - ResendAttempt *int -} - -type OutputStreamDataMessageHandler func(log log.T, streamDataMessage message.ClientMessage) (bool, error) - -type Stop func() - -var SendAcknowledgeMessageCall = func(log log.T, dataChannel *DataChannel, streamDataMessage message.ClientMessage) error { - return dataChannel.SendAcknowledgeMessage(log, streamDataMessage) -} - -var ProcessAcknowledgedMessageCall = func(log log.T, dataChannel *DataChannel, acknowledgeMessage message.AcknowledgeContent) error { - return dataChannel.ProcessAcknowledgedMessage(log, acknowledgeMessage) -} - -var SendMessageCall = func(log log.T, dataChannel *DataChannel, input []byte, inputType int) error { - return dataChannel.SendMessage(log, input, inputType) -} - -var GetRoundTripTime = func(streamingMessage StreamingMessage) time.Duration { - return time.Since(streamingMessage.LastSentTime) -} - -var newEncrypter = func(log log.T, kmsKeyId string, encryptionConext map[string]*string, kmsService kmsiface.KMSAPI) (encryption.IEncrypter, error) { - return encryption.NewEncrypter(log, kmsKeyId, encryptionConext, kmsService) -} - -// Initialize populates the data channel object with the correct values. -func (dataChannel *DataChannel) Initialize(log log.T, clientId string, sessionId string, targetId string, isAwsCliUpgradeNeeded bool) { - //open data channel as publish_subscribe - log.Debugf("Calling Initialize Datachannel for role: %s", config.RolePublishSubscribe) - - dataChannel.Role = config.RolePublishSubscribe - dataChannel.ClientId = clientId - dataChannel.SessionId = sessionId - dataChannel.TargetId = targetId - dataChannel.ExpectedSequenceNumber = 0 - dataChannel.StreamDataSequenceNumber = 0 - dataChannel.OutgoingMessageBuffer = ListMessageBuffer{ - list.New(), - config.OutgoingMessageBufferCapacity, - &sync.Mutex{}, - } - dataChannel.IncomingMessageBuffer = MapMessageBuffer{ - make(map[int64]StreamingMessage), - config.IncomingMessageBufferCapacity, - &sync.Mutex{}, - } - dataChannel.RoundTripTime = float64(config.DefaultRoundTripTime) - dataChannel.RoundTripTimeVariation = config.DefaultRoundTripTimeVariation - dataChannel.RetransmissionTimeout = config.DefaultTransmissionTimeout - dataChannel.wsChannel = &communicator.WebSocketChannel{} - dataChannel.encryptionEnabled = false - dataChannel.isSessionTypeSet = make(chan bool, 1) - dataChannel.isSessionEnded = false - dataChannel.isStreamMessageResendTimeout = make(chan bool, 1) - dataChannel.sessionType = "" - dataChannel.IsAwsCliUpgradeNeeded = isAwsCliUpgradeNeeded -} - -// SetWebsocket function populates websocket channel object -func (dataChannel *DataChannel) SetWebsocket(log log.T, channelUrl string, channelToken string) { - dataChannel.wsChannel.Initialize(log, channelUrl, channelToken) -} - -// FinalizeHandshake sends the token for service to acknowledge the connection. -func (dataChannel *DataChannel) FinalizeDataChannelHandshake(log log.T, tokenValue string) (err error) { - uuid.SwitchFormat(uuid.FormatCanonical) - uid := uuid.NewV4().String() - - log.Infof("Sending token through data channel %s to acknowledge connection", dataChannel.wsChannel.GetStreamUrl()) - openDataChannelInput := service.OpenDataChannelInput{ - MessageSchemaVersion: aws.String(config.MessageSchemaVersion), - RequestId: aws.String(uid), - TokenValue: aws.String(tokenValue), - ClientId: aws.String(dataChannel.ClientId), - } - - var openDataChannelInputBytes []byte - - if openDataChannelInputBytes, err = json.Marshal(openDataChannelInput); err != nil { - log.Errorf("Error serializing openDataChannelInput: %s", err) - return - } - return dataChannel.SendMessage(log, openDataChannelInputBytes, websocket.TextMessage) -} - -// SendMessage sends a message to the service through datachannel -func (dataChannel *DataChannel) SendMessage(log log.T, input []byte, inputType int) error { - return dataChannel.wsChannel.SendMessage(log, input, inputType) -} - -// Open opens websocket connects and does final handshake to acknowledge connection -func (dataChannel *DataChannel) Open(log log.T) (err error) { - if err = dataChannel.wsChannel.Open(log); err != nil { - return fmt.Errorf("failed to open data channel with error: %v", err) - } - - if err = dataChannel.FinalizeDataChannelHandshake(log, dataChannel.wsChannel.GetChannelToken()); err != nil { - return fmt.Errorf("error sending token for handshake: %v", err) - } - return -} - -// Close closes datachannel - its web socket connection -func (dataChannel *DataChannel) Close(log log.T) error { - log.Infof("Closing datachannel with url %s", dataChannel.wsChannel.GetStreamUrl()) - return dataChannel.wsChannel.Close(log) -} - -// Reconnect calls ResumeSession API to reconnect datachannel when connection is lost -func (dataChannel *DataChannel) Reconnect(log log.T) (err error) { - - if err = dataChannel.Close(log); err != nil { - log.Debugf("Closing datachannel failed with error: %v", err) - } - - if err = dataChannel.Open(log); err != nil { - return fmt.Errorf("failed to reconnect data channel %s with error: %v", dataChannel.wsChannel.GetStreamUrl(), err) - } - - log.Infof("Successfully reconnected to data channel: %s", dataChannel.wsChannel.GetStreamUrl()) - return -} - -// SendFlag sends a data message with PayloadType as given flag. -func (dataChannel *DataChannel) SendFlag( - log log.T, - flagType message.PayloadTypeFlag) (err error) { - - flagBuf := new(bytes.Buffer) - binary.Write(flagBuf, binary.BigEndian, flagType) - return dataChannel.SendInputDataMessage(log, message.Flag, flagBuf.Bytes()) -} - -// SendInputDataMessage sends a data message in a form of ClientMessage. -func (dataChannel *DataChannel) SendInputDataMessage( - log log.T, - payloadType message.PayloadType, - inputData []byte) (err error) { - - var ( - flag uint64 = 0 - msg []byte - ) - - messageId := uuid.NewV4() - - // today 'enter' is taken as 'next line' in winpty shell. so hardcoding 'next line' byte to actual 'enter' byte - if bytes.Equal(inputData, []byte{10}) { - inputData = []byte{13} - } - - // Encrypt if encryption is enabled and payload type is Output - if dataChannel.encryptionEnabled && payloadType == message.Output { - inputData, err = dataChannel.encryption.Encrypt(log, inputData) - if err != nil { - return err - } - } - - clientMessage := message.ClientMessage{ - MessageType: message.InputStreamMessage, - SchemaVersion: 1, - CreatedDate: uint64(time.Now().UnixNano() / 1000000), - Flags: flag, - MessageId: messageId, - PayloadType: uint32(payloadType), - Payload: inputData, - SequenceNumber: dataChannel.StreamDataSequenceNumber, - } - - if msg, err = clientMessage.SerializeClientMessage(log); err != nil { - log.Errorf("Cannot serialize StreamData message with error: %v", err) - return - } - - log.Tracef("Sending message with seq number: %d", dataChannel.StreamDataSequenceNumber) - if err = SendMessageCall(log, dataChannel, msg, websocket.BinaryMessage); err != nil { - log.Errorf("Error sending stream data message %v", err) - return - } - - streamingMessage := StreamingMessage{ - msg, - dataChannel.StreamDataSequenceNumber, - time.Now(), - new(int), - } - dataChannel.AddDataToOutgoingMessageBuffer(streamingMessage) - dataChannel.StreamDataSequenceNumber = dataChannel.StreamDataSequenceNumber + 1 - - return -} - -// ResendStreamDataMessageScheduler spawns a separate go thread which keeps checking OutgoingMessageBuffer at fixed interval -// and resends first message if time elapsed since lastSentTime of the message is more than acknowledge wait time -func (dataChannel *DataChannel) ResendStreamDataMessageScheduler(log log.T) (err error) { - go func() { - for { - time.Sleep(config.ResendSleepInterval) - dataChannel.OutgoingMessageBuffer.Mutex.Lock() - streamMessageElement := dataChannel.OutgoingMessageBuffer.Messages.Front() - dataChannel.OutgoingMessageBuffer.Mutex.Unlock() - - if streamMessageElement == nil { - continue - } - - streamMessage := streamMessageElement.Value.(StreamingMessage) - if time.Since(streamMessage.LastSentTime) > dataChannel.RetransmissionTimeout { - log.Debugf("Resend stream data message %d for the %d attempt.", streamMessage.SequenceNumber, *streamMessage.ResendAttempt) - if *streamMessage.ResendAttempt >= config.ResendMaxAttempt { - log.Warnf("Message %d was resent over %d times.", streamMessage.SequenceNumber, config.ResendMaxAttempt) - dataChannel.isStreamMessageResendTimeout <- true - } - *streamMessage.ResendAttempt++ - if err = SendMessageCall(log, dataChannel, streamMessage.Content, websocket.BinaryMessage); err != nil { - log.Errorf("Unable to send stream data message: %s", err) - } - streamMessage.LastSentTime = time.Now() - } - } - }() - - return -} - -// ProcessAcknowledgedMessage processes acknowledge messages by deleting them from OutgoingMessageBuffer -func (dataChannel *DataChannel) ProcessAcknowledgedMessage(log log.T, acknowledgeMessageContent message.AcknowledgeContent) error { - acknowledgeSequenceNumber := acknowledgeMessageContent.SequenceNumber - for streamMessageElement := dataChannel.OutgoingMessageBuffer.Messages.Front(); streamMessageElement != nil; streamMessageElement = streamMessageElement.Next() { - streamMessage := streamMessageElement.Value.(StreamingMessage) - if streamMessage.SequenceNumber == acknowledgeSequenceNumber { - - //Calculate retransmission timeout based on latest round trip time of message - dataChannel.CalculateRetransmissionTimeout(log, streamMessage) - - dataChannel.RemoveDataFromOutgoingMessageBuffer(streamMessageElement) - break - } - } - return nil -} - -// SendAcknowledgeMessage sends acknowledge message for stream data over data channel -func (dataChannel *DataChannel) SendAcknowledgeMessage(log log.T, streamDataMessage message.ClientMessage) (err error) { - dataStreamAcknowledgeContent := message.AcknowledgeContent{ - MessageType: streamDataMessage.MessageType, - MessageId: streamDataMessage.MessageId.String(), - SequenceNumber: streamDataMessage.SequenceNumber, - IsSequentialMessage: true, - } - - var msg []byte - if msg, err = message.SerializeClientMessageWithAcknowledgeContent(log, dataStreamAcknowledgeContent); err != nil { - log.Errorf("Cannot serialize Acknowledge message err: %v", err) - return - } - - if err = SendMessageCall(log, dataChannel, msg, websocket.BinaryMessage); err != nil { - log.Errorf("Error sending acknowledge message %v", err) - return - } - return -} - -// OutputMessageHandler gets output on the data channel -func (dataChannel *DataChannel) OutputMessageHandler(log log.T, stopHandler Stop, sessionID string, rawMessage []byte) error { - outputMessage := &message.ClientMessage{} - err := outputMessage.DeserializeClientMessage(log, rawMessage) - if err != nil { - log.Errorf("Cannot deserialize raw message: %s, err: %v.", string(rawMessage), err) - return err - } - if err = outputMessage.Validate(); err != nil { - log.Errorf("Invalid outputMessage: %v, err: %v.", *outputMessage, err) - return err - } - - log.Tracef("Processing stream data message of type: %s", outputMessage.MessageType) - switch outputMessage.MessageType { - case message.OutputStreamMessage: - return dataChannel.HandleOutputMessage(log, *outputMessage, rawMessage) - case message.AcknowledgeMessage: - return dataChannel.HandleAcknowledgeMessage(log, *outputMessage) - case message.ChannelClosedMessage: - dataChannel.HandleChannelClosedMessage(log, stopHandler, sessionID, *outputMessage) - case message.StartPublicationMessage, message.PausePublicationMessage: - return nil - default: - log.Warn("Invalid message type received: %s", outputMessage.MessageType) - } - - return nil -} - -// handleHandshakeRequest is the handler for payloads of type HandshakeRequest -func (dataChannel *DataChannel) handleHandshakeRequest(log log.T, clientMessage message.ClientMessage) error { - - handshakeRequest, err := clientMessage.DeserializeHandshakeRequest(log) - if err != nil { - log.Errorf("Deserialize Handshake Request failed: %s", err) - return err - } - - dataChannel.agentVersion = handshakeRequest.AgentVersion - - var errorList []error - var handshakeResponse message.HandshakeResponsePayload - handshakeResponse.ClientVersion = version.Version - handshakeResponse.ProcessedClientActions = []message.ProcessedClientAction{} - for _, action := range handshakeRequest.RequestedClientActions { - processedAction := message.ProcessedClientAction{} - switch action.ActionType { - case message.KMSEncryption: - processedAction.ActionType = action.ActionType - err := dataChannel.ProcessKMSEncryptionHandshakeAction(log, action.ActionParameters) - if err != nil { - processedAction.ActionStatus = message.Failed - processedAction.Error = fmt.Sprintf("Failed to process action %s: %s", - message.KMSEncryption, err) - errorList = append(errorList, err) - } else { - processedAction.ActionStatus = message.Success - processedAction.ActionResult = message.KMSEncryptionResponse{ - KMSCipherTextKey: dataChannel.encryption.GetEncryptedDataKey(), - } - dataChannel.encryptionEnabled = true - } - case message.SessionType: - processedAction.ActionType = action.ActionType - err := dataChannel.ProcessSessionTypeHandshakeAction(action.ActionParameters) - if err != nil { - processedAction.ActionStatus = message.Failed - processedAction.Error = fmt.Sprintf("Failed to process action %s: %s", - message.SessionType, err) - errorList = append(errorList, err) - } else { - processedAction.ActionStatus = message.Success - } - - default: - processedAction.ActionType = action.ActionType - processedAction.ActionResult = message.Unsupported - processedAction.Error = fmt.Sprintf("Unsupported action %s", action.ActionType) - errorList = append(errorList, errors.New(processedAction.Error)) - } - handshakeResponse.ProcessedClientActions = append(handshakeResponse.ProcessedClientActions, processedAction) - } - for _, x := range errorList { - handshakeResponse.Errors = append(handshakeResponse.Errors, x.Error()) - } - err = dataChannel.sendHandshakeResponse(log, handshakeResponse) - return err -} - -// handleHandshakeComplete is the handler for when the payload type is HandshakeComplete. This will trigger -// the plugin to start. -func (dataChannel *DataChannel) handleHandshakeComplete(log log.T, clientMessage message.ClientMessage) error { - var err error - var handshakeComplete message.HandshakeCompletePayload - handshakeComplete, err = clientMessage.DeserializeHandshakeComplete(log) - if err != nil { - return err - } - - // SessionType would be set when handshake request is received - if dataChannel.sessionType != "" { - dataChannel.isSessionTypeSet <- true - } else { - dataChannel.isSessionTypeSet <- false - } - - log.Debugf("Handshake Complete. Handshake time to complete is: %s seconds", - handshakeComplete.HandshakeTimeToComplete.Seconds()) - - if handshakeComplete.CustomerMessage != "" { - fmt.Fprintln(os.Stdout, handshakeComplete.CustomerMessage) - } - - return err -} - -// handleEncryptionChallengeRequest receives EncryptionChallenge and responds. -func (dataChannel *DataChannel) handleEncryptionChallengeRequest(log log.T, clientMessage message.ClientMessage) error { - var err error - var encChallengeReq message.EncryptionChallengeRequest - err = json.Unmarshal(clientMessage.Payload, &encChallengeReq) - if err != nil { - return fmt.Errorf("Could not deserialize rawMessage, %s : %s", clientMessage.Payload, err) - } - challenge := encChallengeReq.Challenge - challenge, err = dataChannel.encryption.Decrypt(log, challenge) - if err != nil { - return err - } - challenge, err = dataChannel.encryption.Encrypt(log, challenge) - if err != nil { - return err - } - encChallengeResp := message.EncryptionChallengeResponse{ - Challenge: challenge, - } - - err = dataChannel.sendEncryptionChallengeResponse(log, encChallengeResp) - return err -} - -// sendEncryptionChallengeResponse sends EncryptionChallengeResponse -func (dataChannel *DataChannel) sendEncryptionChallengeResponse(log log.T, response message.EncryptionChallengeResponse) error { - var resultBytes, err = json.Marshal(response) - if err != nil { - return fmt.Errorf("Could not serialize EncChallengeResponse message: %v, err: %s", response, err) - } - - log.Tracef("Sending EncChallengeResponse message.") - if err := dataChannel.SendInputDataMessage(log, message.EncChallengeResponse, resultBytes); err != nil { - return err - } - return nil - -} - -// sendHandshakeResponse sends HandshakeResponse -func (dataChannel *DataChannel) sendHandshakeResponse(log log.T, response message.HandshakeResponsePayload) error { - - var resultBytes, err = json.Marshal(response) - if err != nil { - log.Errorf("Could not serialize HandshakeResponse message: %v, err: %s", response, err) - } - - log.Tracef("Sending HandshakeResponse message.") - if err := dataChannel.SendInputDataMessage(log, message.HandshakeResponsePayloadType, resultBytes); err != nil { - return err - } - return nil -} - -// RegisterOutputStreamHandler register a handler for messages of type OutputStream. This is usually called by the plugin. -func (dataChannel *DataChannel) RegisterOutputStreamHandler(handler OutputStreamDataMessageHandler, isSessionSpecificHandler bool) { - dataChannel.isSessionSpecificHandlerSet = isSessionSpecificHandler - dataChannel.outputStreamHandlers = append(dataChannel.outputStreamHandlers, handler) -} - -// DeregisterOutputStreamHandler deregisters a handler previously registered using RegisterOutputStreamHandler -func (dataChannel *DataChannel) DeregisterOutputStreamHandler(handler OutputStreamDataMessageHandler) { - // Find and remove "handler" - for i, v := range dataChannel.outputStreamHandlers { - if reflect.ValueOf(v).Pointer() == reflect.ValueOf(handler).Pointer() { - dataChannel.outputStreamHandlers = append(dataChannel.outputStreamHandlers[:i], dataChannel.outputStreamHandlers[i+1:]...) - break - } - } -} - -func (dataChannel *DataChannel) processOutputMessageWithHandlers(log log.T, message message.ClientMessage) (isHandlerReady bool, err error) { - // Return false if sessionType is known but session specific handler is not set - if dataChannel.sessionType != "" && !dataChannel.isSessionSpecificHandlerSet { - return false, nil - } - for _, handler := range dataChannel.outputStreamHandlers { - isHandlerReady, err = handler(log, message) - // Break the processing of message and return if session specific handler is not ready - if err != nil || !isHandlerReady { - break - } - } - return isHandlerReady, err -} - -// handleOutputMessage handles incoming stream data message by processing the payload and updating expectedSequenceNumber -func (dataChannel *DataChannel) HandleOutputMessage( - log log.T, - outputMessage message.ClientMessage, - rawMessage []byte) (err error) { - - // On receiving expected stream data message, send acknowledgement, process it and increment expected sequence number by 1. - // Further process messages from IncomingMessageBuffer - if outputMessage.SequenceNumber == dataChannel.ExpectedSequenceNumber { - - switch message.PayloadType(outputMessage.PayloadType) { - case message.HandshakeRequestPayloadType: - { - if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { - return err - } - - // PayloadType is HandshakeRequest so we call our own handler instead of the provided handler - log.Debugf("Processing HandshakeRequest message %s", outputMessage) - if err = dataChannel.handleHandshakeRequest(log, outputMessage); err != nil { - log.Errorf("Unable to process incoming data payload, MessageType %s, "+ - "PayloadType HandshakeRequestPayloadType, err: %s.", outputMessage.MessageType, err) - return err - } - } - case message.HandshakeCompletePayloadType: - { - if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { - return err - } - - if err = dataChannel.handleHandshakeComplete(log, outputMessage); err != nil { - log.Errorf("Unable to process incoming data payload, MessageType %s, "+ - "PayloadType HandshakeCompletePayloadType, err: %s.", outputMessage.MessageType, err) - return err - } - } - case message.EncChallengeRequest: - { - if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { - return err - } - - if err = dataChannel.handleEncryptionChallengeRequest(log, outputMessage); err != nil { - log.Errorf("Unable to process incoming data payload, MessageType %s, "+ - "PayloadType EncChallengeRequest, err: %s.", outputMessage.MessageType, err) - return err - } - } - default: - - log.Tracef("Process new incoming stream data message. Sequence Number: %d", outputMessage.SequenceNumber) - - // Decrypt if encryption is enabled and payload type is output - if dataChannel.encryptionEnabled && - (outputMessage.PayloadType == uint32(message.Output) || - outputMessage.PayloadType == uint32(message.StdErr) || - outputMessage.PayloadType == uint32(message.ExitCode)) { - outputMessage.Payload, err = dataChannel.encryption.Decrypt(log, outputMessage.Payload) - if err != nil { - log.Errorf("Unable to decrypt incoming data payload, MessageType %s, "+ - "PayloadType %d, err: %s.", outputMessage.MessageType, outputMessage.PayloadType, err) - return err - } - } - - isHandlerReady, err := dataChannel.processOutputMessageWithHandlers(log, outputMessage) - if err != nil { - log.Error("Failed to process stream data message: %s", err.Error()) - return err - } - if !isHandlerReady { - log.Warnf("Stream data message with sequence number %d is not processed as session handler is not ready.", outputMessage.SequenceNumber) - return nil - } else { - // Acknowledge outputMessage only if session specific handler is ready - if err := SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { - return err - } - } - } - dataChannel.ExpectedSequenceNumber = dataChannel.ExpectedSequenceNumber + 1 - return dataChannel.ProcessIncomingMessageBufferItems(log, outputMessage) - } else { - log.Debugf("Unexpected sequence message received. Received Sequence Number: %d. Expected Sequence Number: %d", - outputMessage.SequenceNumber, dataChannel.ExpectedSequenceNumber) - - // If incoming message sequence number is greater then expected sequence number and IncomingMessageBuffer has capacity, - // add message to IncomingMessageBuffer and send acknowledgement - if outputMessage.SequenceNumber > dataChannel.ExpectedSequenceNumber { - log.Debugf("Received Sequence Number %d is higher than Expected Sequence Number %d, adding to IncomingMessageBuffer", - outputMessage.SequenceNumber, dataChannel.ExpectedSequenceNumber) - if len(dataChannel.IncomingMessageBuffer.Messages) < dataChannel.IncomingMessageBuffer.Capacity { - if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { - return err - } - - streamingMessage := StreamingMessage{ - rawMessage, - outputMessage.SequenceNumber, - time.Now(), - new(int), - } - - //Add message to buffer for future processing - dataChannel.AddDataToIncomingMessageBuffer(streamingMessage) - } - } - } - return nil -} - -// processIncomingMessageBufferItems check if new expected sequence stream data is present in IncomingMessageBuffer. -// If so process it and increment expected sequence number. -// Repeat until expected sequence stream data is not found in IncomingMessageBuffer. -func (dataChannel *DataChannel) ProcessIncomingMessageBufferItems(log log.T, - outputMessage message.ClientMessage) (err error) { - - for { - bufferedStreamMessage := dataChannel.IncomingMessageBuffer.Messages[dataChannel.ExpectedSequenceNumber] - if bufferedStreamMessage.Content != nil { - log.Debugf("Process stream data message from IncomingMessageBuffer. "+ - "Sequence Number: %d", bufferedStreamMessage.SequenceNumber) - - if err := outputMessage.DeserializeClientMessage(log, bufferedStreamMessage.Content); err != nil { - log.Errorf("Cannot deserialize raw message with err: %v.", err) - return err - } - - // Decrypt if encryption is enabled and payload type is output - if dataChannel.encryptionEnabled && - (outputMessage.PayloadType == uint32(message.Output) || - outputMessage.PayloadType == uint32(message.StdErr) || - outputMessage.PayloadType == uint32(message.ExitCode)) { - outputMessage.Payload, err = dataChannel.encryption.Decrypt(log, outputMessage.Payload) - if err != nil { - log.Errorf("Unable to decrypt buffered message data payload, MessageType %s, "+ - "PayloadType %d, err: %s.", outputMessage.MessageType, outputMessage.PayloadType, err) - return err - } - } - - dataChannel.processOutputMessageWithHandlers(log, outputMessage) - - dataChannel.ExpectedSequenceNumber = dataChannel.ExpectedSequenceNumber + 1 - dataChannel.RemoveDataFromIncomingMessageBuffer(bufferedStreamMessage.SequenceNumber) - } else { - break - } - } - return -} - -// handleAcknowledgeMessage deserialize acknowledge content and process it -func (dataChannel *DataChannel) HandleAcknowledgeMessage( - log log.T, - outputMessage message.ClientMessage) (err error) { - - var acknowledgeMessage message.AcknowledgeContent - if acknowledgeMessage, err = outputMessage.DeserializeDataStreamAcknowledgeContent(log); err != nil { - log.Errorf("Cannot deserialize payload to AcknowledgeMessage with error: %v.", err) - return err - } - - err = ProcessAcknowledgedMessageCall(log, dataChannel, acknowledgeMessage) - return err -} - -// handleChannelClosedMessage exits the shell -func (dataChannel *DataChannel) HandleChannelClosedMessage(log log.T, stopHandler Stop, sessionId string, outputMessage message.ClientMessage) { - var ( - channelClosedMessage message.ChannelClosed - err error - ) - if channelClosedMessage, err = outputMessage.DeserializeChannelClosedMessage(log); err != nil { - log.Errorf("Cannot deserialize payload to ChannelClosedMessage: %v.", err) - } - - log.Infof("Exiting session with sessionId: %s with output: %s", sessionId, channelClosedMessage.Output) - if channelClosedMessage.Output == "" { - fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", sessionId) - } else { - fmt.Fprintf(os.Stdout, "\n\nSessionId: %s : %s\n\n", sessionId, channelClosedMessage.Output) - } - dataChannel.EndSession() - dataChannel.Close(log) - - stopHandler() -} - -// AddDataToOutgoingMessageBuffer removes first message from OutgoingMessageBuffer if capacity is full and adds given message at the end -func (dataChannel *DataChannel) AddDataToOutgoingMessageBuffer(streamMessage StreamingMessage) { - if dataChannel.OutgoingMessageBuffer.Messages.Len() == dataChannel.OutgoingMessageBuffer.Capacity { - dataChannel.RemoveDataFromOutgoingMessageBuffer(dataChannel.OutgoingMessageBuffer.Messages.Front()) - } - dataChannel.OutgoingMessageBuffer.Mutex.Lock() - dataChannel.OutgoingMessageBuffer.Messages.PushBack(streamMessage) - dataChannel.OutgoingMessageBuffer.Mutex.Unlock() -} - -// RemoveDataFromOutgoingMessageBuffer removes given element from OutgoingMessageBuffer -func (dataChannel *DataChannel) RemoveDataFromOutgoingMessageBuffer(streamMessageElement *list.Element) { - dataChannel.OutgoingMessageBuffer.Mutex.Lock() - dataChannel.OutgoingMessageBuffer.Messages.Remove(streamMessageElement) - dataChannel.OutgoingMessageBuffer.Mutex.Unlock() -} - -// AddDataToIncomingMessageBuffer adds given message to IncomingMessageBuffer if it has capacity -func (dataChannel *DataChannel) AddDataToIncomingMessageBuffer(streamMessage StreamingMessage) { - if len(dataChannel.IncomingMessageBuffer.Messages) == dataChannel.IncomingMessageBuffer.Capacity { - return - } - dataChannel.IncomingMessageBuffer.Mutex.Lock() - dataChannel.IncomingMessageBuffer.Messages[streamMessage.SequenceNumber] = streamMessage - dataChannel.IncomingMessageBuffer.Mutex.Unlock() -} - -// RemoveDataFromIncomingMessageBuffer removes given sequence number message from IncomingMessageBuffer -func (dataChannel *DataChannel) RemoveDataFromIncomingMessageBuffer(sequenceNumber int64) { - dataChannel.IncomingMessageBuffer.Mutex.Lock() - delete(dataChannel.IncomingMessageBuffer.Messages, sequenceNumber) - dataChannel.IncomingMessageBuffer.Mutex.Unlock() -} - -// CalculateRetransmissionTimeout calculates message retransmission timeout value based on round trip time on given message -func (dataChannel *DataChannel) CalculateRetransmissionTimeout(log log.T, streamingMessage StreamingMessage) { - newRoundTripTime := float64(GetRoundTripTime(streamingMessage)) - - dataChannel.RoundTripTimeVariation = ((1 - config.RTTVConstant) * dataChannel.RoundTripTimeVariation) + - (config.RTTVConstant * math.Abs(dataChannel.RoundTripTime-newRoundTripTime)) - - dataChannel.RoundTripTime = ((1 - config.RTTConstant) * dataChannel.RoundTripTime) + - (config.RTTConstant * newRoundTripTime) - - dataChannel.RetransmissionTimeout = time.Duration(dataChannel.RoundTripTime + - math.Max(float64(config.ClockGranularity), float64(4*dataChannel.RoundTripTimeVariation))) - - // Ensure RetransmissionTimeout do not exceed maximum timeout defined - if dataChannel.RetransmissionTimeout > config.MaxTransmissionTimeout { - dataChannel.RetransmissionTimeout = config.MaxTransmissionTimeout - } -} - -// ProcessKMSEncryptionHandshakeAction sets up the encrypter and calls KMS to generate a new data key. This is triggered -// when encryption is specified in HandshakeRequest -func (dataChannel *DataChannel) ProcessKMSEncryptionHandshakeAction(log log.T, actionParams json.RawMessage) (err error) { - - if dataChannel.IsAwsCliUpgradeNeeded { - return errors.New("installed version of CLI does not support Session Manager encryption feature. Please upgrade to the latest version of your CLI (e.g., AWS CLI)") - } - kmsEncRequest := message.KMSEncryptionRequest{} - json.Unmarshal(actionParams, &kmsEncRequest) - log.Info(actionParams) - kmsKeyId := kmsEncRequest.KMSKeyID - - kmsService, err := encryption.NewKMSService(log) - if err != nil { - return fmt.Errorf("error while creating new KMS service, %v", err) - } - - encryptionContext := map[string]*string{"aws:ssm:SessionId": &dataChannel.SessionId, "aws:ssm:TargetId": &dataChannel.TargetId} - dataChannel.encryption, err = newEncrypter(log, kmsKeyId, encryptionContext, kmsService) - return -} - -// ProcessSessionTypeHandshakeAction processes session type action in HandshakeRequest. This sets the session type in the datachannel. -func (dataChannel *DataChannel) ProcessSessionTypeHandshakeAction(actionParams json.RawMessage) (err error) { - sessTypeReq := message.SessionTypeRequest{} - json.Unmarshal(actionParams, &sessTypeReq) - switch sessTypeReq.SessionType { - // This switch-case is just so that we can fail early if an unknown session type is passed in. - case config.ShellPluginName, config.InteractiveCommandsPluginName, config.NonInteractiveCommandsPluginName: - dataChannel.sessionType = config.ShellPluginName - dataChannel.sessionProperties = sessTypeReq.Properties - return nil - case config.PortPluginName: - dataChannel.sessionType = sessTypeReq.SessionType - dataChannel.sessionProperties = sessTypeReq.Properties - return nil - default: - return fmt.Errorf("Unknown session type %s", sessTypeReq.SessionType) - } -} - -// IsSessionTypeSet check has data channel sessionType been set -func (dataChannel *DataChannel) IsSessionTypeSet() chan bool { - return dataChannel.isSessionTypeSet -} - -// IsSessionEnded check if session has ended -func (dataChannel *DataChannel) IsSessionEnded() bool { - return dataChannel.isSessionEnded -} - -// IsSessionEnded check if session has ended -func (dataChannel *DataChannel) EndSession() error { - dataChannel.isSessionEnded = true - return nil -} - -// IsStreamMessageResendTimeout checks if resending a streaming message reaches timeout -func (dataChannel *DataChannel) IsStreamMessageResendTimeout() chan bool { - return dataChannel.isStreamMessageResendTimeout -} - -// SetSessionType set session type -func (dataChannel *DataChannel) SetSessionType(sessionType string) { - dataChannel.sessionType = sessionType - dataChannel.isSessionTypeSet <- true -} - -// GetSessionType returns SessionType of the dataChannel -func (dataChannel *DataChannel) GetSessionType() string { - return dataChannel.sessionType -} - -// GetSessionProperties returns SessionProperties of the dataChannel -func (dataChannel *DataChannel) GetSessionProperties() interface{} { - return dataChannel.sessionProperties -} - -// GetWsChannel returns WsChannel of the dataChannel -func (dataChannel *DataChannel) GetWsChannel() communicator.IWebSocketChannel { - return dataChannel.wsChannel -} - -// SetWsChannel set WsChannel of the dataChannel -func (dataChannel *DataChannel) SetWsChannel(wsChannel communicator.IWebSocketChannel) { - dataChannel.wsChannel = wsChannel -} - -// GetStreamDataSequenceNumber returns StreamDataSequenceNumber of the dataChannel -func (dataChannel *DataChannel) GetStreamDataSequenceNumber() int64 { - return dataChannel.StreamDataSequenceNumber -} - -// GetAgentVersion returns agent version of the target instance -func (dataChannel *DataChannel) GetAgentVersion() string { - return dataChannel.agentVersion -} - -// SetAgentVersion set agent version of the target instance -func (dataChannel *DataChannel) SetAgentVersion(agentVersion string) { - dataChannel.agentVersion = agentVersion -} diff --git a/pkg/session-manager-plugin/encryption/encrypter.go b/pkg/session-manager-plugin/encryption/encrypter.go deleted file mode 100644 index a34e0b8..0000000 --- a/pkg/session-manager-plugin/encryption/encrypter.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package encryption - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "fmt" - "io" - - "github.com/aws/aws-sdk-go/service/kms/kmsiface" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" -) - -const ( - nonceSize = 12 -) - -type KMSKeyProvider interface { - GenerateDataKey() -} - -type IEncrypter interface { - Encrypt(log log.T, plainText []byte) (cipherText []byte, err error) - Decrypt(log log.T, cipherText []byte) (plainText []byte, err error) - GetEncryptedDataKey() (ciptherTextBlob []byte) -} - -type Encrypter struct { - KMSService kmsiface.KMSAPI - - kmsKeyId string - cipherTextKey []byte - encryptionKey []byte - decryptionKey []byte -} - -var NewEncrypter = func(log log.T, kmsKeyId string, context map[string]*string, KMSService kmsiface.KMSAPI) (*Encrypter, error) { - encrypter := Encrypter{kmsKeyId: kmsKeyId, KMSService: KMSService} - err := encrypter.generateEncryptionKey(log, kmsKeyId, context) - return &encrypter, err -} - -// generateEncryptionKey calls KMS to generate a new encryption key -func (encrypter *Encrypter) generateEncryptionKey(log log.T, kmsKeyId string, context map[string]*string) error { - cipherTextKey, plainTextKey, err := KMSGenerateDataKey(kmsKeyId, encrypter.KMSService, context) - if err != nil { - log.Errorf("Error generating data key from KMS: %s,", err) - return err - } - keySize := len(plainTextKey) / 2 - encrypter.decryptionKey = plainTextKey[:keySize] - encrypter.encryptionKey = plainTextKey[keySize:] - encrypter.cipherTextKey = cipherTextKey - return nil -} - -// GetEncryptedDataKey returns the cipherText that was pulled from KMS -func (encrypter *Encrypter) GetEncryptedDataKey() (ciptherTextBlob []byte) { - return encrypter.cipherTextKey -} - -// GetKMSKeyId gets the KMS key id that is used to generate the encryption key -func (encrypter *Encrypter) GetKMSKeyId() (kmsKey string) { - return encrypter.kmsKeyId -} - -// getAEAD gets AEAD which is a GCM cipher mode providing authenticated encryption with associated data -func getAEAD(plainTextKey []byte) (aesgcm cipher.AEAD, err error) { - var block cipher.Block - if block, err = aes.NewCipher(plainTextKey); err != nil { - return nil, fmt.Errorf("error creating NewCipher, %v", err) - } - - if aesgcm, err = cipher.NewGCM(block); err != nil { - return nil, fmt.Errorf("error creating NewGCM, %v", err) - } - - return aesgcm, nil -} - -// Encrypt encrypts a byte slice and returns the encrypted slice -func (encrypter *Encrypter) Encrypt(log log.T, plainText []byte) (cipherText []byte, err error) { - var aesgcm cipher.AEAD - - if aesgcm, err = getAEAD(encrypter.encryptionKey); err != nil { - err = fmt.Errorf("%v", err) - return - } - - cipherText = make([]byte, nonceSize+len(plainText)) - nonce := make([]byte, nonceSize) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { - err = fmt.Errorf("error when generating nonce for encryption, %v", err) - return - } - - // Encrypt plain text using given key and newly generated nonce - cipherTextWithoutNonce := aesgcm.Seal(nil, nonce, plainText, nil) - - // Append nonce to the beginning of the cipher text to be used while decrypting - cipherText = append(cipherText[:nonceSize], nonce...) - cipherText = append(cipherText[nonceSize:], cipherTextWithoutNonce...) - return cipherText, nil -} - -// Decrypt decrypts a byte slice and returns the decrypted slice -func (encrypter *Encrypter) Decrypt(log log.T, cipherText []byte) (plainText []byte, err error) { - var aesgcm cipher.AEAD - if aesgcm, err = getAEAD(encrypter.decryptionKey); err != nil { - err = fmt.Errorf("%v", err) - return - } - - // Pull the nonce out of the cipherText - nonce := cipherText[:nonceSize] - cipherTextWithoutNonce := cipherText[nonceSize:] - - // Decrypt just the actual cipherText using nonce extracted above - if plainText, err = aesgcm.Open(nil, nonce, cipherTextWithoutNonce, nil); err != nil { - err = fmt.Errorf("error decrypting encrypted test, %v", err) - return - } - return plainText, nil -} diff --git a/pkg/session-manager-plugin/encryption/kmsservice.go b/pkg/session-manager-plugin/encryption/kmsservice.go deleted file mode 100644 index 133c299..0000000 --- a/pkg/session-manager-plugin/encryption/kmsservice.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package encryption - -import ( - "fmt" - - sdkSession "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/kms/kmsiface" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sdkutil" -) - -// KMSKeySizeInBytes is the key size that is fetched from KMS. 64 bytes key is split into two halves. -// First half 32 bytes key is used by agent for encryption and second half 32 bytes by clients like cli/console -const KMSKeySizeInBytes int64 = 64 - -func NewKMSService(log log.T) (kmsService *kms.KMS, err error) { - var session *sdkSession.Session - if session, err = sdkutil.GetDefaultSession(); err != nil { - return nil, err - } - - kmsService = kms.New(session) - return kmsService, nil -} - -func KMSDecrypt(log log.T, svc kmsiface.KMSAPI, ciptherTextBlob []byte, encryptionContext map[string]*string) (plainText []byte, err error) { - output, err := svc.Decrypt(&kms.DecryptInput{ - CiphertextBlob: ciptherTextBlob, - EncryptionContext: encryptionContext}) - if err != nil { - log.Error("Error when decrypting data key", err) - return nil, err - } - return output.Plaintext, nil -} - -// GenerateDataKey gets cipher text and plain text keys from KMS service -func KMSGenerateDataKey(kmsKeyId string, svc kmsiface.KMSAPI, context map[string]*string) (cipherTextKey []byte, plainTextKey []byte, err error) { - kmsKeySize := KMSKeySizeInBytes - generateDataKeyInput := kms.GenerateDataKeyInput{ - KeyId: &kmsKeyId, - NumberOfBytes: &kmsKeySize, - EncryptionContext: context, - } - - var generateDataKeyOutput *kms.GenerateDataKeyOutput - if generateDataKeyOutput, err = svc.GenerateDataKey(&generateDataKeyInput); err != nil { - return nil, nil, fmt.Errorf("Error calling KMS GenerateDataKey API: %s", err) - } - - return generateDataKeyOutput.CiphertextBlob, generateDataKeyOutput.Plaintext, nil -} diff --git a/pkg/session-manager-plugin/encryption/mocks/IEncrypter.go b/pkg/session-manager-plugin/encryption/mocks/IEncrypter.go deleted file mode 100644 index daee0e8..0000000 --- a/pkg/session-manager-plugin/encryption/mocks/IEncrypter.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. - -// Code generated by mockery v1.0.0. DO NOT EDIT. - -package mocks - -import ( - log "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - mock "github.com/stretchr/testify/mock" -) - -// IEncrypter is an autogenerated mock type for the IEncrypter type -type IEncrypter struct { - mock.Mock -} - -// Decrypt provides a mock function with given fields: _a0, cipherText -func (_m *IEncrypter) Decrypt(_a0 log.T, cipherText []byte) ([]byte, error) { - ret := _m.Called(_a0, cipherText) - - var r0 []byte - if rf, ok := ret.Get(0).(func(log.T, []byte) []byte); ok { - r0 = rf(_a0, cipherText) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(log.T, []byte) error); ok { - r1 = rf(_a0, cipherText) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Encrypt provides a mock function with given fields: _a0, plainText -func (_m *IEncrypter) Encrypt(_a0 log.T, plainText []byte) ([]byte, error) { - ret := _m.Called(_a0, plainText) - - var r0 []byte - if rf, ok := ret.Get(0).(func(log.T, []byte) []byte); ok { - r0 = rf(_a0, plainText) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(log.T, []byte) error); ok { - r1 = rf(_a0, plainText) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GetEncryptedDataKey provides a mock function with given fields: -func (_m *IEncrypter) GetEncryptedDataKey() []byte { - ret := _m.Called() - - var r0 []byte - if rf, ok := ret.Get(0).(func() []byte); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } - } - - return r0 -} diff --git a/pkg/session-manager-plugin/jsonutil/ioutil_deps.go b/pkg/session-manager-plugin/jsonutil/ioutil_deps.go deleted file mode 100644 index a5e815b..0000000 --- a/pkg/session-manager-plugin/jsonutil/ioutil_deps.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package jsonutil contains various utilities for dealing with json data. -package jsonutil - -import "io/ioutil" - -// dependency -var ioUtil ioUtility = ioU{} - -type ioUtility interface { - ReadFile(filename string) ([]byte, error) -} - -type ioU struct{} - -// ioU implements io/ioutil. -func (ioU) ReadFile(filename string) ([]byte, error) { return ioutil.ReadFile(filename) } diff --git a/pkg/session-manager-plugin/jsonutil/jsonutil.go b/pkg/session-manager-plugin/jsonutil/jsonutil.go deleted file mode 100644 index 9b29f0c..0000000 --- a/pkg/session-manager-plugin/jsonutil/jsonutil.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package jsonutil contains various utilities for dealing with json data. -package jsonutil - -import ( - "bytes" - "encoding/json" -) - -// jsonFormat json formatIndent -const jsonFormat = " " - -// Indent indents a json string. -func Indent(jsonStr string) string { - var dst bytes.Buffer - json.Indent(&dst, []byte(jsonStr), "", jsonFormat) - return string(dst.Bytes()) -} - -// Remarshal marshals an object to Json then parses it back to another object. -// This is useful for example when we want to go from map[string]interface{} -// to a more specific struct type or if we want a deep copy of the object. -func Remarshal(obj interface{}, remarshalledObj interface{}) (err error) { - b, err := json.Marshal(obj) - if err != nil { - return - } - err = json.Unmarshal(b, remarshalledObj) - if err != nil { - return - } - return nil -} - -// Marshal marshals an object to a json string. -// Returns empty string if marshal fails. -func Marshal(obj interface{}) (result string, err error) { - var resultB []byte - resultB, err = json.Marshal(obj) - if err != nil { - return - } - result = string(resultB) - return -} - -// UnmarshalFile reads the content of a file then Unmarshals the content to an object. -func UnmarshalFile(filePath string, dest interface{}) (err error) { - content, err := ioUtil.ReadFile(filePath) - if err != nil { - return - } - err = json.Unmarshal(content, dest) - return -} - -// Unmarshal unmarshals the content in string format to an object. -func Unmarshal(jsonContent string, dest interface{}) (err error) { - content := []byte(jsonContent) - err = json.Unmarshal(content, dest) - return -} - -// MarshalIndent is like Marshal but applies Indent to format the output. -// Returns empty string if marshal fails -func MarshalIndent(obj interface{}) (result string, err error) { - var resultsByte []byte - // Make sure the output file keeps formal json format - resultsByte, err = json.MarshalIndent(obj, "", jsonFormat) - if err != nil { - return - } - result = string(resultsByte) - return -} diff --git a/pkg/session-manager-plugin/log/config_watcher.go b/pkg/session-manager-plugin/log/config_watcher.go deleted file mode 100644 index bc11ed2..0000000 --- a/pkg/session-manager-plugin/log/config_watcher.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package log is used to initialize logger -package log - -import ( - "path/filepath" - - "github.com/fsnotify/fsnotify" -) - -// IFileWatcher interface for FileWatcher with functions to initialize, start and stop the watcher -type IFileWatcher interface { - Init(log T, configFilePath string, replaceLogger func()) - Start() - Stop() -} - -// FileWatcher implements the IFileWatcher by using fileChangeWatcher and fileExistsWatcher -type FileWatcher struct { - configFilePath string - replaceLogger func() - log T - watcher *fsnotify.Watcher -} - -// Init initializes the data and channels for the filewatcher -func (fileWatcher *FileWatcher) Init(log T, configFilePath string, replaceLogger func()) { - fileWatcher.replaceLogger = replaceLogger - fileWatcher.configFilePath = configFilePath - fileWatcher.log = log -} - -// Start creates and starts the go routines for filewatcher -func (fileWatcher *FileWatcher) Start() { - - fileWatcher.log.Debugf("Start File Watcher On: %v", fileWatcher.configFilePath) - - // Since the filewatcher fails if the file does not exist, need to watch the parent directory for any changes - dirPath := filepath.Dir(fileWatcher.configFilePath) - fileWatcher.log.Debugf("Start Watcher on directory: %v", dirPath) - - // Creating Watcher - watcher, err := fsnotify.NewWatcher() - if err != nil { - // Error initializing the watcher - fileWatcher.log.Errorf("Error initializing the watcher: %v", err) - return - } - - fileWatcher.watcher = watcher - - // Starting the goroutine for event handler - go fileWatcher.fileEventHandler() - - // Add the directory to watcher - err = fileWatcher.watcher.Add(dirPath) - if err != nil { - // Error adding the file to watcher - fileWatcher.log.Errorf("Error adding the directory to watcher: %v", err) - return - } -} - -// fileEventHandler implements handling of the events triggered by the OS -func (fileWatcher *FileWatcher) fileEventHandler() { - - // Waiting on signals from OS - for event := range fileWatcher.watcher.Events { - // Event signalled by OS on file - fileWatcher.log.Debugf("Event on file %v : %v", event.Name, event) - if event.Name == fileWatcher.configFilePath { - // Event on the file being watched - if event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create || event.Op&fsnotify.Rename == fsnotify.Rename { - // One of Write or Create or Rename Event - fileWatcher.log.Debugf("File Watcher Triggers Function Execution: %v", fileWatcher.configFilePath) - // Execute the function - fileWatcher.replaceLogger() - } - } - } -} - -// Stop stops the filewatcher -func (fileWatcher *FileWatcher) Stop() { - fileWatcher.log.Infof("Stop the filewatcher on :%v", fileWatcher.configFilePath) - // Check if watcher instance is set - if fileWatcher.watcher != nil { - err := fileWatcher.watcher.Close() - if err != nil { - // Error closing the filewatcher. Logging the error - fileWatcher.log.Debugf("Error Closing the filewatcher :%v", err) - } - } -} diff --git a/pkg/session-manager-plugin/log/defaultconfig.go b/pkg/session-manager-plugin/log/defaultconfig.go deleted file mode 100644 index 73ec7d9..0000000 --- a/pkg/session-manager-plugin/log/defaultconfig.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package log is used to initialize the logger. -package log - -import ( - "path/filepath" -) - -func DefaultConfig() []byte { - return LoadLog(DefaultLogDir, ApplicationLogFile, ErrorLogFile) -} - -func LoadLog(defaultLogDir string, logFile string, errorFile string) []byte { - var logFilePath, errorFilePath string - - logFilePath = filepath.Join(defaultLogDir, logFile) - errorFilePath = filepath.Join(defaultLogDir, errorFile) - - logConfig := ` - - - - - - ` - logConfig += `` - logConfig += ` - - ` - logConfig += `` - logConfig += ` - - - - - - - - -` - return []byte(logConfig) -} diff --git a/pkg/session-manager-plugin/log/interface.go b/pkg/session-manager-plugin/log/interface.go deleted file mode 100644 index 1346986..0000000 --- a/pkg/session-manager-plugin/log/interface.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package log is used to initialize the logger. -package log - -// BasicT represents structs capable of logging messages. -// This interface matches seelog.LoggerInterface. -type BasicT interface { - // Tracef formats message according to format specifier - // and writes to log with level Trace. - Tracef(format string, params ...interface{}) - - // Debugf formats message according to format specifier - // and writes to log with level Debug. - Debugf(format string, params ...interface{}) - - // Infof formats message according to format specifier - // and writes to log with level Info. - Infof(format string, params ...interface{}) - - // Warnf formats message according to format specifier - // and writes to log with level Warn. - Warnf(format string, params ...interface{}) error - - // Errorf formats message according to format specifier - // and writes to log with level Error. - Errorf(format string, params ...interface{}) error - - // Criticalf formats message according to format specifier - // and writes to log with level Critical. - Criticalf(format string, params ...interface{}) error - - // Trace formats message using the default formats for its operands - // and writes to log with level Trace. - Trace(v ...interface{}) - - // Debug formats message using the default formats for its operands - // and writes to log with level Debug. - Debug(v ...interface{}) - - // Info formats message using the default formats for its operands - // and writes to log with level Info. - Info(v ...interface{}) - - // Warn formats message using the default formats for its operands - // and writes to log with level Warn. - Warn(v ...interface{}) error - - // Error formats message using the default formats for its operands - // and writes to log with level Error. - Error(v ...interface{}) error - - // Critical formats message using the default formats for its operands - // and writes to log with level Critical. - Critical(v ...interface{}) error - - // Flush flushes all the messages in the logger. - Flush() - - // Close flushes all the messages in the logger and closes it. The logger cannot be used after this operation. - Close() -} - -// T represents structs capable of logging messages, and context management. -type T interface { - BasicT - WithContext(context ...string) (contextLogger T) -} diff --git a/pkg/session-manager-plugin/log/log.go b/pkg/session-manager-plugin/log/log.go deleted file mode 100644 index 99aab12..0000000 --- a/pkg/session-manager-plugin/log/log.go +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package log is used to initialize the logger. -package log - -import ( - "sync" - - "github.com/cihub/seelog" -) - -const ( - LogFileExtension = ".log" - SeelogConfigFileName = "seelog.xml" - ErrorLogFileSuffix = "errors" -) - -var ( - err error - DefaultSeelogConfigFilePath string // DefaultSeelogConfigFilePath specifies the default seelog location - DefaultLogDir string // DefaultLogDir specifies default log location - ApplicationLogFile string // ApplicationLogFile specifies name of application log file - ErrorLogFile string // ErrorLogFile specifies name of error log file - loadedLogger *T - lock sync.RWMutex -) - -// pkgMutex is the lock used to serialize calls to the logger. -var pkgMutex = new(sync.Mutex) - -// loggerInstance is the delegate logger in the wrapper -var loggerInstance = &DelegateLogger{} - -// ContextFormatFilter is a filter that can add a context to the parameters of a log message. -type ContextFormatFilter struct { - Context []string -} - -type LogConfig struct { - ClientName string -} - -// Filter adds the context at the beginning of the parameter slice. -func (f ContextFormatFilter) Filter(params ...interface{}) (newParams []interface{}) { - newParams = make([]interface{}, len(f.Context)+len(params)) - for i, param := range f.Context { - newParams[i] = param + " " - } - ctxLen := len(f.Context) - for i, param := range params { - newParams[ctxLen+i] = param - } - return newParams -} - -// Filterf adds the context in from of the format string. -func (f ContextFormatFilter) Filterf(format string, params ...interface{}) (newFormat string, newParams []interface{}) { - newFormat = "" - for _, param := range f.Context { - newFormat += param + " " - } - newFormat += format - newParams = params - return -} - -// Logger is the starting point to initialize with client name. -func Logger(useWatcher bool, clientName string) T { - logConfig := LogConfig{ - ClientName: clientName, - } - if !isLoaded() { - logger := logConfig.InitLogger(useWatcher) - cache(logger) - } - return getCached() -} - -// initLogger initializes a new logger based on current configurations and starts file watcher on the configurations file -func (config *LogConfig) InitLogger(useWatcher bool) (logger T) { - // Read the current configurations or get the default configurations - logConfigBytes := config.GetLogConfigBytes() - // Initialize the base seelog logger - baseLogger, _ := initBaseLoggerFromBytes(logConfigBytes) - // Create the wrapper logger - logger = withContext(baseLogger) - if useWatcher { - // Start the config file watcher - config.startWatcher(logger) - } - return -} - -// check if a logger has be loaded -func isLoaded() bool { - lock.RLock() - defer lock.RUnlock() - return loadedLogger != nil -} - -// cache the loaded logger -func cache(logger T) { - lock.Lock() - defer lock.Unlock() - loadedLogger = &logger -} - -// return the cached logger -func getCached() T { - lock.RLock() - defer lock.RUnlock() - return *loadedLogger -} - -// startWatcher starts the file watcher on the seelog configurations file path -func (config *LogConfig) startWatcher(logger T) { - defer func() { - // In case the creation of watcher panics, let the current logger continue - if msg := recover(); msg != nil { - logger.Errorf("Seelog File Watcher Initilization Failed. Any updates on config file will be ignored unless agent is restarted: %v", msg) - } - }() - fileWatcher := &FileWatcher{} - fileWatcher.Init(logger, DefaultSeelogConfigFilePath, config.replaceLogger) - // Start the file watcher - fileWatcher.Start() -} - -// ReplaceLogger replaces the current logger with a new logger initialized from the current configurations file -func (config *LogConfig) replaceLogger() { - - // Get the current logger - logger := getCached() - - //Create new logger - logConfigBytes := config.GetLogConfigBytes() - baseLogger, err := initBaseLoggerFromBytes(logConfigBytes) - - // If err in creating logger, do not replace logger - if err != nil { - logger.Error("New logger creation failed") - return - } - - setStackDepth(baseLogger) - baseLogger.Debug("New Logger Successfully Created") - - // Safe conversion to *Wrapper - wrapper, ok := logger.(*Wrapper) - if !ok { - logger.Errorf("Logger replace failed. The logger is not a wrapper") - return - } - - // Replace the underlying base logger in wrapper - wrapper.ReplaceDelegate(baseLogger) -} - -func (config *LogConfig) GetLogConfigBytes() []byte { - return getLogConfigBytes(config.ClientName) -} - -// initBaseLoggerFromBytes initializes the base logger using the specified configuration as bytes. -func initBaseLoggerFromBytes(seelogConfig []byte) (seelogger seelog.LoggerInterface, err error) { - seelogger, err = seelog.LoggerFromConfigAsBytes(seelogConfig) - if err != nil { - // Create logger with default config - seelogger, _ = seelog.LoggerFromConfigAsBytes(DefaultConfig()) - } - return -} - -// withContext creates a wrapper logger on the base logger passed with context is passed -func withContext(logger seelog.LoggerInterface, context ...string) (contextLogger T) { - loggerInstance.BaseLoggerInstance = logger - formatFilter := &ContextFormatFilter{Context: context} - contextLogger = &Wrapper{Format: formatFilter, M: pkgMutex, Delegate: loggerInstance} - - setStackDepth(logger) - return contextLogger -} - -// setStackDepth sets the stack depth of the logger passed -func setStackDepth(logger seelog.LoggerInterface) { - // additional stack depth so that we print the calling function correctly - // stack depth 0 would print the function in the wrapper (e.g. wrapper.Debug) - // stack depth 1 prints the function calling the logger (wrapper), which is what we want. - logger.SetAdditionalStackDepth(1) -} diff --git a/pkg/session-manager-plugin/log/log_unix.go b/pkg/session-manager-plugin/log/log_unix.go deleted file mode 100644 index d1aa7ef..0000000 --- a/pkg/session-manager-plugin/log/log_unix.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:build darwin || freebsd || linux || netbsd || openbsd -// +build darwin freebsd linux netbsd openbsd - -// Package log is used to initialize logger -package log - -import ( - "fmt" - "io/ioutil" - "path/filepath" -) - -const ( - LogsDirectory = "logs" - DefaultInstallLocationPrefix = "/tmp" -) - -func getApplicationName(clientName string) string { - var applicationName string - if clientName == "ssmcli" { - applicationName = "SSMCLI" - } else if clientName == "session-manager-plugin" { - applicationName = "sessionmanagerplugin" - } - - return applicationName -} - -// getLogConfigBytes reads and returns the seelog configs from the config file path if present -// otherwise returns the seelog default configurations -// Linux uses seelog.xml file as configuration by default. -func getLogConfigBytes(clientName string) (logConfigBytes []byte) { - - applicationName := getApplicationName(clientName) - DefaultSeelogConfigFilePath = filepath.Join(DefaultInstallLocationPrefix, applicationName, SeelogConfigFileName) - DefaultLogDir = filepath.Join(DefaultInstallLocationPrefix, applicationName, LogsDirectory) - ApplicationLogFile = fmt.Sprintf("%s%s", clientName, LogFileExtension) - ErrorLogFile = fmt.Sprintf("%s%s", ErrorLogFileSuffix, LogFileExtension) - if logConfigBytes, err = ioutil.ReadFile(DefaultSeelogConfigFilePath); err != nil { - logConfigBytes = DefaultConfig() - } - return -} diff --git a/pkg/session-manager-plugin/log/log_windows.go b/pkg/session-manager-plugin/log/log_windows.go deleted file mode 100644 index 3fd882b..0000000 --- a/pkg/session-manager-plugin/log/log_windows.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:build windows -// +build windows - -// Package log is used to initialize logger -package log - -import ( - "fmt" - "io/ioutil" - "os" - "path/filepath" -) - -const ( - // ApplicationFolder is the path under local app data. - ApplicationFolderPrefix = "Amazon\\" - LogsDirectory = "Logs" -) - -var EnvProgramFiles = os.Getenv("ProgramFiles") // Windows environment variable %ProgramFiles% - -func getApplicationName(clientName string) string { - var applicationName string - if clientName == "ssmcli" { - applicationName = "SSMCLI" - } else if clientName == "session-manager-plugin" { - applicationName = "SessionManagerPlugin" - } - - return applicationName -} - -// getLogConfigBytes reads and returns the seelog configs from the config file path if present -// otherwise returns the seelog default configurations -// Windows uses default log configuration if there is no seelog.xml override provided. -func getLogConfigBytes(clientName string) (logConfigBytes []byte) { - DefaultProgramFolder := filepath.Join( - EnvProgramFiles, - ApplicationFolderPrefix, - getApplicationName(clientName)) - DefaultSeelogConfigFilePath = filepath.Join(DefaultProgramFolder, SeelogConfigFileName) - DefaultLogDir = filepath.Join( - DefaultProgramFolder, - LogsDirectory) - ApplicationLogFile = fmt.Sprintf("%s%s", clientName, LogFileExtension) - ErrorLogFile = fmt.Sprintf("%s%s", ErrorLogFileSuffix, LogFileExtension) - - if logConfigBytes, err = ioutil.ReadFile(DefaultSeelogConfigFilePath); err != nil { - logConfigBytes = DefaultConfig() - } - return -} diff --git a/pkg/session-manager-plugin/log/wrapper.go b/pkg/session-manager-plugin/log/wrapper.go deleted file mode 100644 index f5595e8..0000000 --- a/pkg/session-manager-plugin/log/wrapper.go +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package log is used to initialize the logger. -package log - -import ( - "sync" -) - -// DelegateLogger holds the base logger for logging -type DelegateLogger struct { - BaseLoggerInstance BasicT -} - -// Wrapper is a logger that can modify the format of a log message before delegating to another logger. -type Wrapper struct { - Format FormatFilter - M *sync.Mutex - Delegate *DelegateLogger -} - -// FormatFilter can modify the format and or parameters to be passed to a logger. -type FormatFilter interface { - - // Filter modifies parameters that will be passed to log.Debug, log.Info, etc. - Filter(params ...interface{}) (newParams []interface{}) - - // Filter modifies format and/or parameter strings that will be passed to log.Debugf, log.Infof, etc. - Filterf(format string, params ...interface{}) (newFormat string, newParams []interface{}) -} - -// WithContext creates a wrapper logger with context -func (w *Wrapper) WithContext(context ...string) (contextLogger T) { - formatFilter := &ContextFormatFilter{Context: context} - contextLogger = &Wrapper{Format: formatFilter, M: w.M, Delegate: w.Delegate} - return contextLogger -} - -// Tracef formats message according to format specifier -// and writes to log with level = Trace. -func (w *Wrapper) Tracef(format string, params ...interface{}) { - format, params = w.Format.Filterf(format, params...) - - w.M.Lock() - defer w.M.Unlock() - w.Delegate.BaseLoggerInstance.Tracef(format, params...) -} - -// Debugf formats message according to format specifier -// and writes to log with level = Debug. -func (w *Wrapper) Debugf(format string, params ...interface{}) { - format, params = w.Format.Filterf(format, params...) - - w.M.Lock() - defer w.M.Unlock() - w.Delegate.BaseLoggerInstance.Debugf(format, params...) -} - -// Infof formats message according to format specifier -// and writes to log with level = Info. -func (w *Wrapper) Infof(format string, params ...interface{}) { - format, params = w.Format.Filterf(format, params...) - - w.M.Lock() - defer w.M.Unlock() - w.Delegate.BaseLoggerInstance.Infof(format, params...) -} - -// Warnf formats message according to format specifier -// and writes to log with level = Warn. -func (w *Wrapper) Warnf(format string, params ...interface{}) error { - format, params = w.Format.Filterf(format, params...) - - w.M.Lock() - defer w.M.Unlock() - return w.Delegate.BaseLoggerInstance.Warnf(format, params...) -} - -// Errorf formats message according to format specifier -// and writes to log with level = Error. -func (w *Wrapper) Errorf(format string, params ...interface{}) error { - format, params = w.Format.Filterf(format, params...) - - w.M.Lock() - defer w.M.Unlock() - return w.Delegate.BaseLoggerInstance.Errorf(format, params...) -} - -// Criticalf formats message according to format specifier -// and writes to log with level = Critical. -func (w *Wrapper) Criticalf(format string, params ...interface{}) error { - format, params = w.Format.Filterf(format, params...) - - w.M.Lock() - defer w.M.Unlock() - return w.Delegate.BaseLoggerInstance.Criticalf(format, params...) -} - -// Trace formats message using the default formats for its operands -// and writes to log with level = Trace -func (w *Wrapper) Trace(v ...interface{}) { - v = w.Format.Filter(v...) - w.M.Lock() - defer w.M.Unlock() - w.Delegate.BaseLoggerInstance.Trace(v...) -} - -// Debug formats message using the default formats for its operands -// and writes to log with level = Debug -func (w *Wrapper) Debug(v ...interface{}) { - v = w.Format.Filter(v...) - - w.M.Lock() - defer w.M.Unlock() - w.Delegate.BaseLoggerInstance.Debug(v...) -} - -// Info formats message using the default formats for its operands -// and writes to log with level = Info -func (w *Wrapper) Info(v ...interface{}) { - v = w.Format.Filter(v...) - - w.M.Lock() - defer w.M.Unlock() - w.Delegate.BaseLoggerInstance.Info(v...) -} - -// Warn formats message using the default formats for its operands -// and writes to log with level = Warn -func (w *Wrapper) Warn(v ...interface{}) error { - v = w.Format.Filter(v...) - - w.M.Lock() - defer w.M.Unlock() - return w.Delegate.BaseLoggerInstance.Warn(v...) -} - -// Error formats message using the default formats for its operands -// and writes to log with level = Error -func (w *Wrapper) Error(v ...interface{}) error { - v = w.Format.Filter(v...) - - w.M.Lock() - defer w.M.Unlock() - return w.Delegate.BaseLoggerInstance.Error(v...) -} - -// Critical formats message using the default formats for its operands -// and writes to log with level = Critical -func (w *Wrapper) Critical(v ...interface{}) error { - v = w.Format.Filter(v...) - - w.M.Lock() - defer w.M.Unlock() - return w.Delegate.BaseLoggerInstance.Critical(v...) -} - -// Flush flushes all the messages in the logger. -func (w *Wrapper) Flush() { - w.M.Lock() - defer w.M.Unlock() - w.Delegate.BaseLoggerInstance.Flush() -} - -// Close flushes all the messages in the logger and closes it. It cannot be used after this operation. -func (w *Wrapper) Close() { - w.M.Lock() - defer w.M.Unlock() - w.Delegate.BaseLoggerInstance.Close() -} - -// ReplaceDelegate replaces the delegate logger with a new logger -func (w *Wrapper) ReplaceDelegate(newLogger BasicT) { - w.M.Lock() - defer w.M.Unlock() - w.Delegate.BaseLoggerInstance.Flush() - w.Delegate.BaseLoggerInstance = newLogger - w.Delegate.BaseLoggerInstance.Info("Logger Replaced. New Logger Used to log the message") -} diff --git a/pkg/session-manager-plugin/message/clientmessage.go b/pkg/session-manager-plugin/message/clientmessage.go deleted file mode 100644 index de10944..0000000 --- a/pkg/session-manager-plugin/message/clientmessage.go +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// message package defines data channel messages structure. -package message - -import ( - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/twinj/uuid" -) - -const ( - // InputStreamMessage represents message type for input data - InputStreamMessage = "input_stream_data" - - // OutputStreamMessage represents message type for output data - OutputStreamMessage = "output_stream_data" - - // AcknowledgeMessage represents message type for acknowledge - AcknowledgeMessage = "acknowledge" - - // ChannelClosedMessage represents message type for ChannelClosed - ChannelClosedMessage = "channel_closed" - - // StartPublicationMessage represents the message type that notifies the CLI to start sending stream messages - StartPublicationMessage = "start_publication" - - // PausePublicationMessage represents the message type that notifies the CLI to pause sending stream messages - // as the remote data channel is inactive - PausePublicationMessage = "pause_publication" -) - -// AcknowledgeContent is used to inform the sender of an acknowledge message that the message has been received. -// * MessageType is a 32 byte UTF-8 string containing the message type. -// * MessageId is a 40 byte UTF-8 string containing the UUID identifying this message being acknowledged. -// * SequenceNumber is an 8 byte integer containing the message sequence number for serialized message. -// * IsSequentialMessage is a boolean field representing whether the acknowledged message is part of a sequence -type AcknowledgeContent struct { - MessageType string `json:"AcknowledgedMessageType"` - MessageId string `json:"AcknowledgedMessageId"` - SequenceNumber int64 `json:"AcknowledgedMessageSequenceNumber"` - IsSequentialMessage bool `json:"IsSequentialMessage"` -} - -// ChannelClosed is used to inform the client to close the channel -// * MessageId is a 40 byte UTF-8 string containing the UUID identifying this message. -// * CreatedDate is a string field containing the message create epoch millis in UTC. -// * DestinationId is a string field containing the session target. -// * SessionId is a string field representing which session to close. -// * MessageType is a 32 byte UTF-8 string containing the message type. -// * SchemaVersion is a 4 byte integer containing the message schema version number. -// * Output is a string field containing the error message for channel close. -type ChannelClosed struct { - MessageId string `json:"MessageId"` - CreatedDate string `json:"CreatedDate"` - DestinationId string `json:"DestinationId"` - SessionId string `json:"SessionId"` - MessageType string `json:"MessageType"` - SchemaVersion int `json:"SchemaVersion"` - Output string `json:"Output"` -} - -type PayloadType uint32 - -const ( - Output PayloadType = 1 - Error PayloadType = 2 - Size PayloadType = 3 - Parameter PayloadType = 4 - HandshakeRequestPayloadType PayloadType = 5 - HandshakeResponsePayloadType PayloadType = 6 - HandshakeCompletePayloadType PayloadType = 7 - EncChallengeRequest PayloadType = 8 - EncChallengeResponse PayloadType = 9 - Flag PayloadType = 10 - StdErr PayloadType = 11 - ExitCode PayloadType = 12 -) - -type PayloadTypeFlag uint32 - -const ( - DisconnectToPort PayloadTypeFlag = 1 - TerminateSession PayloadTypeFlag = 2 - ConnectToPortError PayloadTypeFlag = 3 -) - -type SizeData struct { - Cols uint32 `json:"cols"` - Rows uint32 `json:"rows"` -} - -type IClientMessage interface { - Validate() error - DeserializeClientMessage(log log.T, input []byte) (err error) - SerializeClientMessage(log log.T) (result []byte, err error) - DeserializeDataStreamAcknowledgeContent(log log.T) (dataStreamAcknowledge AcknowledgeContent, err error) - DeserializeChannelClosedMessage(log log.T) (channelClosed ChannelClosed, err error) - DeserializeHandshakeRequest(log log.T) (handshakeRequest HandshakeRequestPayload, err error) - DeserializeHandshakeComplete(log log.T) (handshakeComplete HandshakeCompletePayload, err error) -} - -// ClientMessage represents a message for client to send/receive. ClientMessage Message in MGS is equivalent to MDS' InstanceMessage. -// All client messages are sent in this form to the MGS service. -type ClientMessage struct { - HeaderLength uint32 - MessageType string - SchemaVersion uint32 - CreatedDate uint64 - SequenceNumber int64 - Flags uint64 - MessageId uuid.UUID - PayloadDigest []byte - PayloadType uint32 - PayloadLength uint32 - Payload []byte -} - -// * HL - HeaderLength is a 4 byte integer that represents the header length. -// * MessageType is a 32 byte UTF-8 string containing the message type. -// * SchemaVersion is a 4 byte integer containing the message schema version number. -// * CreatedDate is an 8 byte integer containing the message create epoch millis in UTC. -// * SequenceNumber is an 8 byte integer containing the message sequence number for serialized message streams. -// * Flags is an 8 byte unsigned integer containing a packed array of control flags: -// * Bit 0 is SYN - SYN is set (1) when the recipient should consider Seq to be the first message number in the stream -// * Bit 1 is FIN - FIN is set (1) when this message is the final message in the sequence. -// * MessageId is a 40 byte UTF-8 string containing a random UUID identifying this message. -// * Payload digest is a 32 byte containing the SHA-256 hash of the payload. -// * Payload length is an 4 byte unsigned integer containing the byte length of data in the Payload field. -// * Payload is a variable length byte data. -// -// * | HL| MessageType |Ver| CD | Seq | Flags | -// * | MessageId | Digest | PayType | PayLen| -// * | Payload | - -const ( - ClientMessage_HLLength = 4 - ClientMessage_MessageTypeLength = 32 - ClientMessage_SchemaVersionLength = 4 - ClientMessage_CreatedDateLength = 8 - ClientMessage_SequenceNumberLength = 8 - ClientMessage_FlagsLength = 8 - ClientMessage_MessageIdLength = 16 - ClientMessage_PayloadDigestLength = 32 - ClientMessage_PayloadTypeLength = 4 - ClientMessage_PayloadLengthLength = 4 -) - -const ( - ClientMessage_HLOffset = 0 - ClientMessage_MessageTypeOffset = ClientMessage_HLOffset + ClientMessage_HLLength - ClientMessage_SchemaVersionOffset = ClientMessage_MessageTypeOffset + ClientMessage_MessageTypeLength - ClientMessage_CreatedDateOffset = ClientMessage_SchemaVersionOffset + ClientMessage_SchemaVersionLength - ClientMessage_SequenceNumberOffset = ClientMessage_CreatedDateOffset + ClientMessage_CreatedDateLength - ClientMessage_FlagsOffset = ClientMessage_SequenceNumberOffset + ClientMessage_SequenceNumberLength - ClientMessage_MessageIdOffset = ClientMessage_FlagsOffset + ClientMessage_FlagsLength - ClientMessage_PayloadDigestOffset = ClientMessage_MessageIdOffset + ClientMessage_MessageIdLength - ClientMessage_PayloadTypeOffset = ClientMessage_PayloadDigestOffset + ClientMessage_PayloadDigestLength - ClientMessage_PayloadLengthOffset = ClientMessage_PayloadTypeOffset + ClientMessage_PayloadTypeLength - ClientMessage_PayloadOffset = ClientMessage_PayloadLengthOffset + ClientMessage_PayloadLengthLength -) diff --git a/pkg/session-manager-plugin/message/handshakemessage.go b/pkg/session-manager-plugin/message/handshakemessage.go deleted file mode 100644 index c8abcf2..0000000 --- a/pkg/session-manager-plugin/message/handshakemessage.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// message package defines data channel messages structure. -package message - -import ( - "encoding/json" - "time" -) - -// ActionType used in Handshake to determine action requested by the agent -type ActionType string - -const ( - KMSEncryption ActionType = "KMSEncryption" - SessionType ActionType = "SessionType" -) - -type ActionStatus int - -const ( - Success ActionStatus = 1 - Failed ActionStatus = 2 - Unsupported ActionStatus = 3 -) - -// This is sent by the agent to initialize KMS encryption -type KMSEncryptionRequest struct { - KMSKeyID string `json:"KMSKeyId"` -} - -// This is received by the agent to set up KMS encryption -type KMSEncryptionResponse struct { - KMSCipherTextKey []byte `json:"KMSCipherTextKey"` - KMSCipherTextHash []byte `json:"KMSCipherTextHash"` -} - -// SessionType request contains type of the session that needs to be launched and properties for plugin -type SessionTypeRequest struct { - SessionType string `json:"SessionType"` - Properties interface{} `json:"Properties"` -} - -// Handshake payload sent by the agent to the session manager plugin -type HandshakeRequestPayload struct { - AgentVersion string `json:"AgentVersion"` - RequestedClientActions []RequestedClientAction `json:"RequestedClientActions"` -} - -// An action requested by the agent to the plugin -type RequestedClientAction struct { - ActionType ActionType `json:"ActionType"` - ActionParameters json.RawMessage `json:"ActionParameters"` -} - -// The result of processing the action by the plugin -type ProcessedClientAction struct { - ActionType ActionType `json:"ActionType"` - ActionStatus ActionStatus `json:"ActionStatus"` - ActionResult interface{} `json:"ActionResult"` - Error string `json:"Error"` -} - -// Handshake Response sent by the plugin in response to the handshake request -type HandshakeResponsePayload struct { - ClientVersion string `json:"ClientVersion"` - ProcessedClientActions []ProcessedClientAction `json:"ProcessedClientActions"` - Errors []string `json:"Errors"` -} - -// This is sent by the agent as a challenge to the client. The challenge field -// is some data that was encrypted by the agent. The client must be able to decrypt -// this and in turn encrypt it with its own key. -type EncryptionChallengeRequest struct { - Challenge []byte `json:"Challenge"` -} - -// This is received by the agent from the client. The challenge field contains -// some data received, decrypted and then encrypted by the client. Agent must -// be able to decrypt this and verify it matches the original plaintext challenge. -type EncryptionChallengeResponse struct { - Challenge []byte `json:"Challenge"` -} - -// Handshake Complete indicates to client that handshake is complete. -// This signals the client to start the plugin and display a customer message where appropriate. -type HandshakeCompletePayload struct { - HandshakeTimeToComplete time.Duration `json:"HandshakeTimeToComplete"` - CustomerMessage string `json:"CustomerMessage"` -} diff --git a/pkg/session-manager-plugin/message/messageparser.go b/pkg/session-manager-plugin/message/messageparser.go deleted file mode 100644 index 1c9371e..0000000 --- a/pkg/session-manager-plugin/message/messageparser.go +++ /dev/null @@ -1,571 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// message package defines data channel messages structure. -package message - -import ( - "bytes" - "crypto/sha256" - "encoding/binary" - "encoding/json" - "errors" - "fmt" - "strings" - "time" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/twinj/uuid" -) - -// DeserializeClientMessage deserializes the byte array into an ClientMessage message. -// * Payload is a variable length byte data. -// * | HL| MessageType |Ver| CD | Seq | Flags | -// * | MessageId | Digest | PayType | PayLen| -// * | Payload | -func (clientMessage *ClientMessage) DeserializeClientMessage(log log.T, input []byte) (err error) { - clientMessage.MessageType, err = getString(log, input, ClientMessage_MessageTypeOffset, ClientMessage_MessageTypeLength) - if err != nil { - log.Errorf("Could not deserialize field MessageType with error: %v", err) - return err - } - clientMessage.SchemaVersion, err = getUInteger(log, input, ClientMessage_SchemaVersionOffset) - if err != nil { - log.Errorf("Could not deserialize field SchemaVersion with error: %v", err) - return err - } - clientMessage.CreatedDate, err = getULong(log, input, ClientMessage_CreatedDateOffset) - if err != nil { - log.Errorf("Could not deserialize field CreatedDate with error: %v", err) - return err - } - clientMessage.SequenceNumber, err = getLong(log, input, ClientMessage_SequenceNumberOffset) - if err != nil { - log.Errorf("Could not deserialize field SequenceNumber with error: %v", err) - return err - } - clientMessage.Flags, err = getULong(log, input, ClientMessage_FlagsOffset) - if err != nil { - log.Errorf("Could not deserialize field Flags with error: %v", err) - return err - } - clientMessage.MessageId, err = getUuid(log, input, ClientMessage_MessageIdOffset) - if err != nil { - log.Errorf("Could not deserialize field MessageId with error: %v", err) - return err - } - clientMessage.PayloadDigest, err = getBytes(log, input, ClientMessage_PayloadDigestOffset, ClientMessage_PayloadDigestLength) - if err != nil { - log.Errorf("Could not deserialize field PayloadDigest with error: %v", err) - return err - } - clientMessage.PayloadType, err = getUInteger(log, input, ClientMessage_PayloadTypeOffset) - if err != nil { - log.Errorf("Could not deserialize field PayloadType with error: %v", err) - return err - } - clientMessage.PayloadLength, err = getUInteger(log, input, ClientMessage_PayloadLengthOffset) - - headerLength, herr := getUInteger(log, input, ClientMessage_HLOffset) - if herr != nil { - log.Errorf("Could not deserialize field HeaderLength with error: %v", err) - return err - } - - clientMessage.HeaderLength = headerLength - clientMessage.Payload = input[headerLength+ClientMessage_PayloadLengthLength:] - - return err -} - -// getString get a string value from the byte array starting from the specified offset to the defined length. -func getString(log log.T, byteArray []byte, offset int, stringLength int) (result string, err error) { - byteArrayLength := len(byteArray) - if offset > byteArrayLength-1 || offset+stringLength-1 > byteArrayLength-1 || offset < 0 { - log.Error("getString failed: Offset is invalid.") - return "", errors.New("Offset is outside the byte array.") - } - - //remove nulls from the bytes array - b := bytes.Trim(byteArray[offset:offset+stringLength], "\x00") - - return strings.TrimSpace(string(b)), nil -} - -// getUInteger gets an unsigned integer -func getUInteger(log log.T, byteArray []byte, offset int) (result uint32, err error) { - var temp int32 - temp, err = getInteger(log, byteArray, offset) - return uint32(temp), err -} - -// getInteger gets an integer value from a byte array starting from the specified offset. -func getInteger(log log.T, byteArray []byte, offset int) (result int32, err error) { - byteArrayLength := len(byteArray) - if offset > byteArrayLength-1 || offset+4 > byteArrayLength || offset < 0 { - log.Error("getInteger failed: Offset is invalid.") - return 0, errors.New("Offset is bigger than the byte array.") - } - return bytesToInteger(log, byteArray[offset:offset+4]) -} - -// bytesToInteger gets an integer from a byte array. -func bytesToInteger(log log.T, input []byte) (result int32, err error) { - var res int32 - inputLength := len(input) - if inputLength != 4 { - log.Error("bytesToInteger failed: input array size is not equal to 4.") - return 0, errors.New("Input array size is not equal to 4.") - } - buf := bytes.NewBuffer(input) - binary.Read(buf, binary.BigEndian, &res) - return res, nil -} - -// getULong gets an unsigned long integer -func getULong(log log.T, byteArray []byte, offset int) (result uint64, err error) { - var temp int64 - temp, err = getLong(log, byteArray, offset) - return uint64(temp), err -} - -// getLong gets a long integer value from a byte array starting from the specified offset. 64 bit. -func getLong(log log.T, byteArray []byte, offset int) (result int64, err error) { - byteArrayLength := len(byteArray) - if offset > byteArrayLength-1 || offset+8 > byteArrayLength || offset < 0 { - log.Error("getLong failed: Offset is invalid.") - return 0, errors.New("Offset is outside the byte array.") - } - return bytesToLong(log, byteArray[offset:offset+8]) -} - -// bytesToLong gets a Long integer from a byte array. -func bytesToLong(log log.T, input []byte) (result int64, err error) { - var res int64 - inputLength := len(input) - if inputLength != 8 { - log.Error("bytesToLong failed: input array size is not equal to 8.") - return 0, errors.New("Input array size is not equal to 8.") - } - buf := bytes.NewBuffer(input) - binary.Read(buf, binary.BigEndian, &res) - return res, nil -} - -// getUuid gets the 128bit uuid from an array of bytes starting from the offset. -func getUuid(log log.T, byteArray []byte, offset int) (result uuid.UUID, err error) { - byteArrayLength := len(byteArray) - if offset > byteArrayLength-1 || offset+16-1 > byteArrayLength-1 || offset < 0 { - log.Error("getUuid failed: Offset is invalid.") - return uuid.Nil.UUID(), errors.New("Offset is outside the byte array.") - } - - leastSignificantLong, err := getLong(log, byteArray, offset) - if err != nil { - log.Error("getUuid failed: failed to get uuid LSBs Long value.") - return uuid.Nil.UUID(), errors.New("Failed to get uuid LSBs long value.") - } - - leastSignificantBytes, err := longToBytes(log, leastSignificantLong) - if err != nil { - log.Error("getUuid failed: failed to get uuid LSBs bytes value.") - return uuid.Nil.UUID(), errors.New("Failed to get uuid LSBs bytes value.") - } - - mostSignificantLong, err := getLong(log, byteArray, offset+8) - if err != nil { - log.Error("getUuid failed: failed to get uuid MSBs Long value.") - return uuid.Nil.UUID(), errors.New("Failed to get uuid MSBs long value.") - } - - mostSignificantBytes, err := longToBytes(log, mostSignificantLong) - if err != nil { - log.Error("getUuid failed: failed to get uuid MSBs bytes value.") - return uuid.Nil.UUID(), errors.New("Failed to get uuid MSBs bytes value.") - } - - uuidBytes := append(mostSignificantBytes, leastSignificantBytes...) - - return uuid.New(uuidBytes), nil -} - -// longToBytes gets bytes array from a long integer. -func longToBytes(log log.T, input int64) (result []byte, err error) { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, input) - if buf.Len() != 8 { - log.Error("longToBytes failed: buffer output length is not equal to 8.") - return make([]byte, 8), errors.New("Input array size is not equal to 8.") - } - - return buf.Bytes(), nil -} - -// getBytes gets an array of bytes starting from the offset. -func getBytes(log log.T, byteArray []byte, offset int, byteLength int) (result []byte, err error) { - byteArrayLength := len(byteArray) - if offset > byteArrayLength-1 || offset+byteLength-1 > byteArrayLength-1 || offset < 0 { - log.Error("getBytes failed: Offset is invalid.") - return make([]byte, byteLength), errors.New("Offset is outside the byte array.") - } - return byteArray[offset : offset+byteLength], nil -} - -// Validate returns error if the message is invalid -func (clientMessage *ClientMessage) Validate() error { - if StartPublicationMessage == clientMessage.MessageType || - PausePublicationMessage == clientMessage.MessageType { - return nil - } - if clientMessage.HeaderLength == 0 { - return errors.New("HeaderLength cannot be zero") - } - if clientMessage.MessageType == "" { - return errors.New("MessageType is missing") - } - if clientMessage.CreatedDate == 0 { - return errors.New("CreatedDate is missing") - } - if clientMessage.PayloadLength != 0 { - hasher := sha256.New() - hasher.Write(clientMessage.Payload) - if !bytes.Equal(hasher.Sum(nil), clientMessage.PayloadDigest) { - return errors.New("payload Hash is not valid") - } - } - return nil -} - -// SerializeClientMessage serializes ClientMessage message into a byte array. -// * Payload is a variable length byte data. -// * | HL| MessageType |Ver| CD | Seq | Flags | -// * | MessageId | Digest |PayType| PayLen| -// * | Payload | -func (clientMessage *ClientMessage) SerializeClientMessage(log log.T) (result []byte, err error) { - payloadLength := uint32(len(clientMessage.Payload)) - headerLength := uint32(ClientMessage_PayloadLengthOffset) - // Set payload length - clientMessage.PayloadLength = payloadLength - - totalMessageLength := headerLength + ClientMessage_PayloadLengthLength + payloadLength - result = make([]byte, totalMessageLength) - - err = putUInteger(log, result, ClientMessage_HLOffset, headerLength) - if err != nil { - log.Errorf("Could not serialize HeaderLength with error: %v", err) - return make([]byte, 1), err - } - - startPosition := ClientMessage_MessageTypeOffset - endPosition := ClientMessage_MessageTypeOffset + ClientMessage_MessageTypeLength - 1 - err = putString(log, result, startPosition, endPosition, clientMessage.MessageType) - if err != nil { - log.Errorf("Could not serialize MessageType with error: %v", err) - return make([]byte, 1), err - } - - err = putUInteger(log, result, ClientMessage_SchemaVersionOffset, clientMessage.SchemaVersion) - if err != nil { - log.Errorf("Could not serialize SchemaVersion with error: %v", err) - return make([]byte, 1), err - } - - err = putULong(log, result, ClientMessage_CreatedDateOffset, clientMessage.CreatedDate) - if err != nil { - log.Errorf("Could not serialize CreatedDate with error: %v", err) - return make([]byte, 1), err - } - - err = putLong(log, result, ClientMessage_SequenceNumberOffset, clientMessage.SequenceNumber) - if err != nil { - log.Errorf("Could not serialize SequenceNumber with error: %v", err) - return make([]byte, 1), err - } - - err = putULong(log, result, ClientMessage_FlagsOffset, clientMessage.Flags) - if err != nil { - log.Errorf("Could not serialize Flags with error: %v", err) - return make([]byte, 1), err - } - - err = putUuid(log, result, ClientMessage_MessageIdOffset, clientMessage.MessageId) - if err != nil { - log.Errorf("Could not serialize MessageId with error: %v", err) - return make([]byte, 1), err - } - - hasher := sha256.New() - hasher.Write(clientMessage.Payload) - - startPosition = ClientMessage_PayloadDigestOffset - endPosition = ClientMessage_PayloadDigestOffset + ClientMessage_PayloadDigestLength - 1 - err = putBytes(log, result, startPosition, endPosition, hasher.Sum(nil)) - if err != nil { - log.Errorf("Could not serialize PayloadDigest with error: %v", err) - return make([]byte, 1), err - } - - err = putUInteger(log, result, ClientMessage_PayloadTypeOffset, clientMessage.PayloadType) - if err != nil { - log.Errorf("Could not serialize PayloadType with error: %v", err) - return make([]byte, 1), err - } - - err = putUInteger(log, result, ClientMessage_PayloadLengthOffset, clientMessage.PayloadLength) - if err != nil { - log.Errorf("Could not serialize PayloadLength with error: %v", err) - return make([]byte, 1), err - } - - startPosition = ClientMessage_PayloadOffset - endPosition = ClientMessage_PayloadOffset + int(payloadLength) - 1 - err = putBytes(log, result, startPosition, endPosition, clientMessage.Payload) - if err != nil { - log.Errorf("Could not serialize Payload with error: %v", err) - return make([]byte, 1), err - } - - return result, nil -} - -// putUInteger puts an unsigned integer -func putUInteger(log log.T, byteArray []byte, offset int, value uint32) (err error) { - return putInteger(log, byteArray, offset, int32(value)) -} - -// putInteger puts an integer value to a byte array starting from the specified offset. -func putInteger(log log.T, byteArray []byte, offset int, value int32) (err error) { - byteArrayLength := len(byteArray) - if offset > byteArrayLength-1 || offset+4 > byteArrayLength || offset < 0 { - log.Error("putInteger failed: Offset is invalid.") - return errors.New("Offset is outside the byte array.") - } - - bytes, err := integerToBytes(log, value) - if err != nil { - log.Error("putInteger failed: getBytesFromInteger Failed.") - return err - } - - copy(byteArray[offset:offset+4], bytes) - return nil -} - -// integerToBytes gets bytes array from an integer. -func integerToBytes(log log.T, input int32) (result []byte, err error) { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, input) - if buf.Len() != 4 { - log.Error("integerToBytes failed: buffer output length is not equal to 4.") - return make([]byte, 4), errors.New("Input array size is not equal to 4.") - } - - return buf.Bytes(), nil -} - -// putString puts a string value to a byte array starting from the specified offset. -func putString(log log.T, byteArray []byte, offsetStart int, offsetEnd int, inputString string) (err error) { - byteArrayLength := len(byteArray) - if offsetStart > byteArrayLength-1 || offsetEnd > byteArrayLength-1 || offsetStart > offsetEnd || offsetStart < 0 { - log.Error("putString failed: Offset is invalid.") - return errors.New("Offset is outside the byte array.") - } - - if offsetEnd-offsetStart+1 < len(inputString) { - log.Error("putString failed: Not enough space to save the string.") - return errors.New("Not enough space to save the string.") - } - - // wipe out the array location first and then insert the new value. - for i := offsetStart; i <= offsetEnd; i++ { - byteArray[i] = ' ' - } - - copy(byteArray[offsetStart:offsetEnd+1], inputString) - return nil -} - -// putBytes puts bytes into the array at the correct offset. -func putBytes(log log.T, byteArray []byte, offsetStart int, offsetEnd int, inputBytes []byte) (err error) { - byteArrayLength := len(byteArray) - if offsetStart > byteArrayLength-1 || offsetEnd > byteArrayLength-1 || offsetStart > offsetEnd || offsetStart < 0 { - log.Error("putBytes failed: Offset is invalid.") - return errors.New("Offset is outside the byte array.") - } - - if offsetEnd-offsetStart+1 != len(inputBytes) { - log.Error("putBytes failed: Not enough space to save the bytes.") - return errors.New("Not enough space to save the bytes.") - } - - copy(byteArray[offsetStart:offsetEnd+1], inputBytes) - return nil -} - -// putUuid puts the 128 bit uuid to an array of bytes starting from the offset. -func putUuid(log log.T, byteArray []byte, offset int, input uuid.UUID) (err error) { - if uuid.IsNil(input) { - log.Error("putUuid failed: input is null.") - return errors.New("putUuid failed: input is null.") - } - - byteArrayLength := len(byteArray) - if offset > byteArrayLength-1 || offset+16-1 > byteArrayLength-1 || offset < 0 { - log.Error("putUuid failed: Offset is invalid.") - return errors.New("Offset is outside the byte array.") - } - - leastSignificantLong, err := bytesToLong(log, input.Bytes()[8:16]) - if err != nil { - log.Error("putUuid failed: Failed to get leastSignificant Long value.") - return errors.New("Failed to get leastSignificant Long value.") - } - - mostSignificantLong, err := bytesToLong(log, input.Bytes()[0:8]) - if err != nil { - log.Error("putUuid failed: Failed to get mostSignificantLong Long value.") - return errors.New("Failed to get mostSignificantLong Long value.") - } - - err = putLong(log, byteArray, offset, leastSignificantLong) - if err != nil { - log.Error("putUuid failed: Failed to put leastSignificantLong Long value.") - return errors.New("Failed to put leastSignificantLong Long value.") - } - - err = putLong(log, byteArray, offset+8, mostSignificantLong) - if err != nil { - log.Error("putUuid failed: Failed to put mostSignificantLong Long value.") - return errors.New("Failed to put mostSignificantLong Long value.") - } - - return nil -} - -// putLong puts a long integer value to a byte array starting from the specified offset. -func putLong(log log.T, byteArray []byte, offset int, value int64) (err error) { - byteArrayLength := len(byteArray) - if offset > byteArrayLength-1 || offset+8 > byteArrayLength || offset < 0 { - log.Error("putInteger failed: Offset is invalid.") - return errors.New("Offset is outside the byte array.") - } - - mbytes, err := longToBytes(log, value) - if err != nil { - log.Error("putInteger failed: getBytesFromInteger Failed.") - return err - } - - copy(byteArray[offset:offset+8], mbytes) - return nil -} - -// putULong puts an unsigned long integer. -func putULong(log log.T, byteArray []byte, offset int, value uint64) (err error) { - return putLong(log, byteArray, offset, int64(value)) -} - -// SerializeClientMessagePayload marshals payloads for all session specific messages into bytes. -func SerializeClientMessagePayload(log log.T, obj interface{}) (reply []byte, err error) { - reply, err = json.Marshal(obj) - if err != nil { - log.Errorf("Could not serialize message with err: %s", err) - } - return -} - -// SerializeClientMessageWithAcknowledgeContent marshals client message with payloads of acknowledge contents into bytes. -func SerializeClientMessageWithAcknowledgeContent(log log.T, acknowledgeContent AcknowledgeContent) (reply []byte, err error) { - - acknowledgeContentBytes, err := SerializeClientMessagePayload(log, acknowledgeContent) - if err != nil { - // should not happen - log.Errorf("Cannot marshal acknowledge content to json string: %v", acknowledgeContentBytes) - return - } - - uuid.SwitchFormat(uuid.FormatCanonical) - messageId := uuid.NewV4() - clientMessage := ClientMessage{ - MessageType: AcknowledgeMessage, - SchemaVersion: 1, - CreatedDate: uint64(time.Now().UnixNano() / 1000000), - SequenceNumber: 0, - Flags: 3, - MessageId: messageId, - Payload: acknowledgeContentBytes, - } - - reply, err = clientMessage.SerializeClientMessage(log) - if err != nil { - log.Errorf("Error serializing client message with acknowledge content err: %v", err) - } - - return -} - -// DeserializeDataStreamAcknowledgeContent parses acknowledge content from payload of ClientMessage. -func (clientMessage *ClientMessage) DeserializeDataStreamAcknowledgeContent(log log.T) (dataStreamAcknowledge AcknowledgeContent, err error) { - if clientMessage.MessageType != AcknowledgeMessage { - err = fmt.Errorf("ClientMessage is not of type AcknowledgeMessage. Found message type: %s", clientMessage.MessageType) - return - } - - err = json.Unmarshal(clientMessage.Payload, &dataStreamAcknowledge) - if err != nil { - log.Errorf("Could not deserialize rawMessage: %s", err) - } - return -} - -// DeserializeChannelClosedMessage parses channelClosed message from payload of ClientMessage. -func (clientMessage *ClientMessage) DeserializeChannelClosedMessage(log log.T) (channelClosed ChannelClosed, err error) { - if clientMessage.MessageType != ChannelClosedMessage { - err = fmt.Errorf("ClientMessage is not of type ChannelClosed. Found message type: %s", clientMessage.MessageType) - return - } - - err = json.Unmarshal(clientMessage.Payload, &channelClosed) - if err != nil { - log.Errorf("Could not deserialize rawMessage: %s", err) - } - return -} - -func (clientMessage *ClientMessage) DeserializeHandshakeRequest(log log.T) (handshakeRequest HandshakeRequestPayload, err error) { - if clientMessage.PayloadType != uint32(HandshakeRequestPayloadType) { - err = log.Errorf("ClientMessage PayloadType is not of type HandshakeRequestPayloadType. Found payload type: %d", - clientMessage.PayloadType) - return - } - - err = json.Unmarshal(clientMessage.Payload, &handshakeRequest) - if err != nil { - log.Errorf("Could not deserialize rawMessage: %s", err) - } - return -} - -func (clientMessage *ClientMessage) DeserializeHandshakeComplete(log log.T) (handshakeComplete HandshakeCompletePayload, err error) { - if clientMessage.PayloadType != uint32(HandshakeCompletePayloadType) { - err = log.Errorf("ClientMessage PayloadType is not of type HandshakeCompletePayloadType. Found payload type: %d", - clientMessage.PayloadType) - return - } - - err = json.Unmarshal(clientMessage.Payload, &handshakeComplete) - if err != nil { - log.Errorf("Could not deserialize rawMessage, %s : %s", clientMessage.Payload, err) - } - return -} diff --git a/pkg/session-manager-plugin/retry/retry.go b/pkg/session-manager-plugin/retry/retry.go deleted file mode 100644 index 000377b..0000000 --- a/pkg/session-manager-plugin/retry/retry.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// retry implements back off retry strategy for reconnect web socket connection. -package retry - -import ( - "time" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" -) - -const sleepConstant = 2 - -// Retry implements back off retry strategy for reconnect web socket connection. -func Retry(log log.T, attempts int, sleep time.Duration, fn func() error) (err error) { - - log.Info("Retrying connection to channel") - for attempts > 0 { - attempts-- - if err = fn(); err != nil { - time.Sleep(sleep) - sleep = sleep * sleepConstant - log.Debugf("%v attempts to connect web socket connection.", attempts) - continue - } - return nil - } - return err -} diff --git a/pkg/session-manager-plugin/retry/retryer.go b/pkg/session-manager-plugin/retry/retryer.go deleted file mode 100644 index 245e402..0000000 --- a/pkg/session-manager-plugin/retry/retryer.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// retry implements back off retry strategy for reconnect web socket connection. -package retry - -import ( - "math" - "time" -) - -type Retryer interface { - Call() error - NextSleepTime(int32) time.Duration -} - -type RepeatableExponentialRetryer struct { - CallableFunc func() error - GeometricRatio float64 - InitialDelayInMilli int - MaxDelayInMilli int - MaxAttempts int -} - -// NextSleepTime calculates the next delay of retry. -func (retryer *RepeatableExponentialRetryer) NextSleepTime(attempt int) time.Duration { - return time.Duration(float64(retryer.InitialDelayInMilli)*math.Pow(retryer.GeometricRatio, float64(attempt))) * time.Millisecond -} - -// Call calls the operation and does exponential retry if error happens. -func (retryer *RepeatableExponentialRetryer) Call() (err error) { - attempt := 0 - failedAttemptsSoFar := 0 - for { - err := retryer.CallableFunc() - if err == nil || failedAttemptsSoFar == retryer.MaxAttempts { - return err - } - sleep := retryer.NextSleepTime(attempt) - if int(sleep/time.Millisecond) > retryer.MaxDelayInMilli { - attempt = 0 - sleep = retryer.NextSleepTime(attempt) - } - time.Sleep(sleep) - attempt++ - failedAttemptsSoFar++ - } -} diff --git a/pkg/session-manager-plugin/sdkutil/awsconfig.go b/pkg/session-manager-plugin/sdkutil/awsconfig.go deleted file mode 100644 index 3ddf038..0000000 --- a/pkg/session-manager-plugin/sdkutil/awsconfig.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package sdkutil provides utilities used to call awssdk. -package sdkutil - -import ( - "fmt" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sdkutil/retryer" -) - -var defaultRegion string -var defaultProfile string - -// GetNewSessionWithEndpoint creates aws sdk session with given profile, region and endpoint -func GetNewSessionWithEndpoint(endpoint string) (sess *session.Session, err error) { - if sess, err = session.NewSessionWithOptions(session.Options{ - Config: aws.Config{ - Retryer: newRetryer(), - SleepDelay: sleepDelay, - Region: aws.String(defaultRegion), - Endpoint: aws.String(endpoint), - }, - SharedConfigState: session.SharedConfigEnable, - Profile: defaultProfile, - }); err != nil { - return nil, fmt.Errorf("Error creating new aws sdk session %s", err) - } - return sess, nil -} - -// GetDefaultSession creates aws sdk session with given profile and region -func GetDefaultSession() (sess *session.Session, err error) { - return GetNewSessionWithEndpoint("") -} - -// Sets the region and profile for default aws sessions -func SetRegionAndProfile(region string, profile string) { - defaultRegion = region - defaultProfile = profile -} - -var newRetryer = func() aws.RequestRetryer { - r := retryer.SsmCliRetryer{} - r.NumMaxRetries = 3 - return r -} - -var sleepDelay = func(d time.Duration) { - time.Sleep(d) -} diff --git a/pkg/session-manager-plugin/sdkutil/retryer/retryer.go b/pkg/session-manager-plugin/sdkutil/retryer/retryer.go deleted file mode 100644 index 73aa5bd..0000000 --- a/pkg/session-manager-plugin/sdkutil/retryer/retryer.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package retryer overrides the default aws sdk retryer delay logic to better suit the mds needs -package retryer - -import ( - "math" - "math/rand" - "strings" - "time" - - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/request" -) - -type SsmCliRetryer struct { - client.DefaultRetryer -} - -// RetryRules returns the delay duration before retrying this request again -func (s SsmCliRetryer) RetryRules(r *request.Request) time.Duration { - // Handle GetMessages Client.Timeout error - if r.Operation.Name == "GetMessages" && r.Error != nil && strings.Contains(r.Error.Error(), "Client.Timeout") { - // expected error. we will retry with a short 100 ms delay - return time.Duration(100 * time.Millisecond) - } - - // retry after a > 1 sec timeout, increasing exponentially with each retry - rand.Seed(time.Now().UnixNano()) - delay := int(math.Pow(2, float64(r.RetryCount))) * (rand.Intn(500) + 1000) - return time.Duration(delay) * time.Millisecond -} diff --git a/pkg/session-manager-plugin/service/service.go b/pkg/session-manager-plugin/service/service.go deleted file mode 100644 index 2457311..0000000 --- a/pkg/session-manager-plugin/service/service.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package service is a wrapper for the new Service -package service - -// OpenDataChannelInput -type OpenDataChannelInput struct { - _ struct{} `type:"structure"` - - // MessageSchemaVersion is a required field - MessageSchemaVersion *string `json:"MessageSchemaVersion" min:"1" type:"string" required:"true"` - - // RequestId is a required field - RequestId *string `json:"RequestId" min:"16" type:"string" required:"true"` - - // TokenValue is a required field - TokenValue *string `json:"TokenValue" min:"1" type:"string" required:"true"` - - // ClientId is a required field - ClientId *string `json:"ClientId" min:"1" type:"string" required:"true"` -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin-main/main.go b/pkg/session-manager-plugin/sessionmanagerplugin-main/main.go deleted file mode 100644 index 6ba5928..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin-main/main.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package main represents the entry point to session manager plugin. -package main - -import ( - "os" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session" - _ "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession" - _ "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession" -) - -func main() { - session.ValidateInputAndStartSession(os.Args, os.Stdout) -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/basicportforwarding.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/basicportforwarding.go deleted file mode 100644 index 39ab73a..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/basicportforwarding.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package portsession starts port session. -package portsession - -import ( - "fmt" - "net" - "os" - "os/signal" - "strconv" - "time" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil" - "github.com/ruelala/arconn/pkg/session-manager-plugin/version" -) - -// BasicPortForwarding is type of port session -// accepts one client connection at a time -type BasicPortForwarding struct { - port IPortSession - stream net.Conn - listener net.Listener - sessionId string - portParameters PortParameters - session session.Session -} - -// IsStreamNotSet checks if stream is not set -func (p *BasicPortForwarding) IsStreamNotSet() (status bool) { - return p.stream == nil -} - -// Stop closes the stream -func (p *BasicPortForwarding) Stop() { - p.listener.Close() - if p.stream != nil { - p.stream.Close() - } - return -} - -// InitializeStreams establishes connection and initializes the stream -func (p *BasicPortForwarding) InitializeStreams(log log.T, agentVersion string) (err error) { - p.handleControlSignals(log) - if err = p.startLocalConn(log); err != nil { - return - } - return -} - -// ReadStream reads data from the stream -func (p *BasicPortForwarding) ReadStream(log log.T) (err error) { - msg := make([]byte, config.StreamDataPayloadSize) - for { - numBytes, err := p.stream.Read(msg) - if err != nil { - log.Debugf("Reading from port %s failed with error: %v. Close this connection, listen and accept new one.", - p.portParameters.PortNumber, err) - - // Send DisconnectToPort flag to agent when client tcp connection drops to ensure agent closes tcp connection too with server port - if err = p.session.DataChannel.SendFlag(log, message.DisconnectToPort); err != nil { - log.Errorf("Failed to send packet: %v", err) - return err - } - - if err = p.reconnect(log); err != nil { - return err - } - - // continue to read from connection as it has been re-established - continue - } - - log.Tracef("Received message of size %d from stdin.", numBytes) - if err = p.session.DataChannel.SendInputDataMessage(log, message.Output, msg[:numBytes]); err != nil { - log.Errorf("Failed to send packet: %v", err) - return err - } - // Sleep to process more data - time.Sleep(time.Millisecond) - } -} - -// WriteStream writes data to stream -func (p *BasicPortForwarding) WriteStream(outputMessage message.ClientMessage) error { - _, err := p.stream.Write(outputMessage.Payload) - return err -} - -// startLocalConn establishes a new local connection to forward remote server packets to -func (p *BasicPortForwarding) startLocalConn(log log.T) (err error) { - // When localPortNumber is not specified, set port number to 0 to let net.conn choose an open port at random - localPortNumber := p.portParameters.LocalPortNumber - if p.portParameters.LocalPortNumber == "" { - localPortNumber = "0" - } - - if err = p.startLocalListener(log, localPortNumber); err != nil { - log.Errorf("Unable to open tcp connection to port. %v", err) - return err - } - - if p.stream, err = p.listener.Accept(); err != nil { - if p.session.DataChannel.IsSessionEnded() == false { - log.Errorf("Failed to accept connection with error. %v", err) - return err - } - } - if p.session.DataChannel.IsSessionEnded() == false { - log.Infof("Connection accepted for session %s.", p.sessionId) - fmt.Printf("Connection accepted for session %s.\n", p.sessionId) - } - - return -} - -// startLocalListener starts a local listener to given address -func (p *BasicPortForwarding) startLocalListener(log log.T, portNumber string) (err error) { - var displayMessage string - switch p.portParameters.LocalConnectionType { - case "unix": - if p.listener, err = net.Listen(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil { - return - } - displayMessage = fmt.Sprintf("Unix socket %s opened for sessionId %s.", p.portParameters.LocalUnixSocket, p.sessionId) - default: - if p.listener, err = net.Listen("tcp", "localhost:"+portNumber); err != nil { - return - } - // get port number the TCP listener opened - p.portParameters.LocalPortNumber = strconv.Itoa(p.listener.Addr().(*net.TCPAddr).Port) - displayMessage = fmt.Sprintf("Port %s opened for sessionId %s.", p.portParameters.LocalPortNumber, p.sessionId) - } - - log.Info(displayMessage) - fmt.Println(displayMessage) - return -} - -// handleControlSignals handles terminate signals -func (p *BasicPortForwarding) handleControlSignals(log log.T) { - c := make(chan os.Signal, 1) - signal.Notify(c, sessionutil.ControlSignals...) - go func() { - <-c - fmt.Println("Terminate signal received, exiting.") - - p.session.DataChannel.EndSession() - if version.DoesAgentSupportTerminateSessionFlag(log, p.session.DataChannel.GetAgentVersion()) { - if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil { - log.Errorf("Failed to send TerminateSession flag: %v", err) - } - fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId) - } else { - p.session.TerminateSession(log) - } - p.Stop() - }() -} - -// reconnect closes existing connection, listens to new connection and accept it -func (p *BasicPortForwarding) reconnect(log log.T) (err error) { - // close existing connection as it is in a state from which data cannot be read - p.stream.Close() - - // wait for new connection on listener and accept it - if p.stream, err = p.listener.Accept(); err != nil { - if p.session.DataChannel.IsSessionEnded() == false { - log.Errorf("Failed to accept connection with error. %v", err) - return err - } - } - - return -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/muxportforwarding.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/muxportforwarding.go deleted file mode 100644 index 7b3734f..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/muxportforwarding.go +++ /dev/null @@ -1,322 +0,0 @@ -// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package portsession starts port session. -package portsession - -import ( - "bytes" - "context" - "encoding/binary" - "fmt" - "hash/fnv" - "io" - "net" - "os" - "os/signal" - "path/filepath" - "strconv" - "sync" - "time" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil" - "github.com/ruelala/arconn/pkg/session-manager-plugin/version" - "github.com/xtaci/smux" - "golang.org/x/sync/errgroup" -) - -// MuxClient contains smux client session and corresponding network connection -type MuxClient struct { - conn net.Conn - localListener net.Listener - session *smux.Session -} - -// MgsConn contains local server and corresponding connection to smux client -type MgsConn struct { - listener net.Listener - conn net.Conn -} - -// MuxPortForwarding is type of port session -// accepts multiple client connections through multiplexing -type MuxPortForwarding struct { - port IPortSession - sessionId string - socketFile string - portParameters PortParameters - session session.Session - muxClient *MuxClient - mgsConn *MgsConn -} - -func (c *MgsConn) close() { - c.listener.Close() - c.conn.Close() -} - -func (c *MuxClient) close() { - c.session.Close() - c.conn.Close() - c.localListener.Close() -} - -// IsStreamNotSet checks if stream is not set -func (p *MuxPortForwarding) IsStreamNotSet() (status bool) { - return p.muxClient.conn == nil -} - -// Stop closes all open stream -func (p *MuxPortForwarding) Stop() { - if p.mgsConn != nil { - p.mgsConn.close() - } - if p.muxClient != nil { - p.muxClient.close() - } - p.cleanUp() - return -} - -// InitializeStreams initializes i/o streams -func (p *MuxPortForwarding) InitializeStreams(log log.T, agentVersion string) (err error) { - - p.handleControlSignals(log) - p.socketFile = getUnixSocketPath(p.sessionId, os.TempDir(), "session_manager_plugin_mux.sock") - - if err = p.initialize(log, agentVersion); err != nil { - p.cleanUp() - } - return -} - -// ReadStream reads data from different connections -func (p *MuxPortForwarding) ReadStream(log log.T) (err error) { - g, ctx := errgroup.WithContext(context.Background()) - - // reads data from smux client and transfers to server over datachannel - g.Go(func() error { - return p.transferDataToServer(log, ctx) - }) - - // set up network listener on SSM port and handle client connections - g.Go(func() error { - return p.handleClientConnections(log, ctx) - }) - - g.Go(func() error { - for { - time.Sleep(50 * time.Millisecond) - if p.session.DataChannel.IsSessionEnded() == true { - p.Stop() - return nil - } - } - }) - - return g.Wait() -} - -// WriteStream writes data to stream -func (p *MuxPortForwarding) WriteStream(outputMessage message.ClientMessage) error { - switch message.PayloadType(outputMessage.PayloadType) { - case message.Output: - _, err := p.mgsConn.conn.Write(outputMessage.Payload) - return err - case message.Flag: - var flag message.PayloadTypeFlag - buf := bytes.NewBuffer(outputMessage.Payload) - binary.Read(buf, binary.BigEndian, &flag) - - if message.ConnectToPortError == flag { - fmt.Printf("\nConnection to destination port failed, check SSM Agent logs.\n") - } - } - return nil -} - -// cleanUp deletes unix socket file -func (p *MuxPortForwarding) cleanUp() { - os.Remove(p.socketFile) -} - -// initialize opens a network connection that acts as smux client -func (p *MuxPortForwarding) initialize(log log.T, agentVersion string) (err error) { - - // open a network listener - var listener net.Listener - if listener, err = sessionutil.NewListener(log, p.socketFile); err != nil { - return - } - - var g errgroup.Group - // start a go routine to accept connections on the network listener - g.Go(func() error { - if conn, err := listener.Accept(); err != nil { - return err - } else { - p.mgsConn = &MgsConn{listener, conn} - } - return nil - }) - - // start a connection to the local network listener and set up client side of mux - g.Go(func() error { - if muxConn, err := net.Dial(listener.Addr().Network(), listener.Addr().String()); err != nil { - return err - } else { - smuxConfig := smux.DefaultConfig() - if version.DoesAgentSupportDisableSmuxKeepAlive(log, agentVersion) { - // Disable smux KeepAlive or else it breaks Session Manager idle timeout. - smuxConfig.KeepAliveDisabled = false - } - if muxSession, err := smux.Client(muxConn, smuxConfig); err != nil { - return err - } else { - var localListener net.Listener - p.muxClient = &MuxClient{muxConn, localListener, muxSession} - } - } - return nil - }) - - return g.Wait() -} - -// handleControlSignals handles terminate signals -func (p *MuxPortForwarding) handleControlSignals(log log.T) { - c := make(chan os.Signal, 1) - signal.Notify(c, sessionutil.ControlSignals...) - go func() { - <-c - fmt.Println("Terminate signal received, exiting.") - - if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil { - log.Errorf("Failed to send TerminateSession flag: %v", err) - } - p.Stop() - }() -} - -// transferDataToServer reads from smux client connection and sends on data channel -func (p *MuxPortForwarding) transferDataToServer(log log.T, ctx context.Context) (err error) { - msg := make([]byte, config.StreamDataPayloadSize) - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - var numBytes int - if numBytes, err = p.mgsConn.conn.Read(msg); err != nil { - log.Debugf("Reading from port failed with error: %v.", err) - return - } - - log.Tracef("Received message of size %d from mux client.", numBytes) - if err = p.session.DataChannel.SendInputDataMessage(log, message.Output, msg[:numBytes]); err != nil { - log.Errorf("Failed to send packet on data channel: %v", err) - return - } - // sleep to process more data - time.Sleep(time.Millisecond) - } - } -} - -// handleClientConnections sets up network server on local ssm port to accept connections from clients (browser/terminal) -func (p *MuxPortForwarding) handleClientConnections(log log.T, ctx context.Context) (err error) { - var ( - displayMsg string - ) - - if p.portParameters.LocalConnectionType == "unix" { - if p.muxClient.localListener, err = net.Listen(p.portParameters.LocalConnectionType, p.portParameters.LocalUnixSocket); err != nil { - return err - } - displayMsg = fmt.Sprintf("Unix socket %s opened for sessionId %s.", p.portParameters.LocalUnixSocket, p.sessionId) - } else { - localPortNumber := p.portParameters.LocalPortNumber - if p.portParameters.LocalPortNumber == "" { - localPortNumber = "0" - } - if p.muxClient.localListener, err = net.Listen("tcp", "localhost:"+localPortNumber); err != nil { - return err - } - p.portParameters.LocalPortNumber = strconv.Itoa(p.muxClient.localListener.Addr().(*net.TCPAddr).Port) - displayMsg = fmt.Sprintf("Port %s opened for sessionId %s.", p.portParameters.LocalPortNumber, p.sessionId) - } - - defer p.muxClient.localListener.Close() - - log.Infof(displayMsg) - fmt.Printf(displayMsg) - - log.Infof("Waiting for connections...\n") - fmt.Printf("\nWaiting for connections...\n") - - var once sync.Once - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - if conn, err := p.muxClient.localListener.Accept(); err != nil { - log.Errorf("Error while accepting connection: %v", err) - } else { - log.Infof("Connection accepted from %s\n for session [%s]", conn.RemoteAddr(), p.sessionId) - - once.Do(func() { - fmt.Printf("\nConnection accepted for session [%s]\n", p.sessionId) - }) - - stream, err := p.muxClient.session.OpenStream() - if err != nil { - continue - } - log.Debugf("Client stream opened %d\n", stream.ID()) - go handleDataTransfer(stream, conn) - } - } - } -} - -// handleDataTransfer launches routines to transfer data between source and destination -func handleDataTransfer(dst io.ReadWriteCloser, src io.ReadWriteCloser) { - var wait sync.WaitGroup - wait.Add(2) - - go func() { - io.Copy(dst, src) - dst.Close() - wait.Done() - }() - - go func() { - io.Copy(src, dst) - src.Close() - wait.Done() - }() - - wait.Wait() -} - -// getUnixSocketPath generates the unix socket file name based on sessionId and returns the path. -func getUnixSocketPath(sessionId string, dir string, suffix string) string { - hash := fnv.New32a() - hash.Write([]byte(sessionId)) - return filepath.Join(dir, fmt.Sprintf("%d_%s", hash.Sum32(), suffix)) -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/portsession.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/portsession.go deleted file mode 100644 index bb970ea..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/portsession.go +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package portsession starts port session. -package portsession - -import ( - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - "github.com/ruelala/arconn/pkg/session-manager-plugin/jsonutil" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session" - "github.com/ruelala/arconn/pkg/session-manager-plugin/version" -) - -const ( - LocalPortForwardingType = "LocalPortForwarding" -) - -type PortSession struct { - session.Session - portParameters PortParameters - portSessionType IPortSession -} - -type IPortSession interface { - IsStreamNotSet() (status bool) - InitializeStreams(log log.T, agentVersion string) (err error) - ReadStream(log log.T) (err error) - WriteStream(outputMessage message.ClientMessage) (err error) - Stop() -} - -type PortParameters struct { - PortNumber string `json:"portNumber"` - LocalPortNumber string `json:"localPortNumber"` - LocalUnixSocket string `json:"localUnixSocket"` - LocalConnectionType string `json:"localConnectionType"` - Type string `json:"type"` -} - -func init() { - session.Register(&PortSession{}) -} - -// Name is the session name used inputStream the plugin -func (PortSession) Name() string { - return config.PortPluginName -} - -func (s *PortSession) Initialize(log log.T, sessionVar *session.Session) { - s.Session = *sessionVar - if err := jsonutil.Remarshal(s.SessionProperties, &s.portParameters); err != nil { - log.Errorf("Invalid format: %v", err) - } - - if s.portParameters.Type == LocalPortForwardingType { - if version.DoesAgentSupportTCPMultiplexing(log, s.DataChannel.GetAgentVersion()) { - s.portSessionType = &MuxPortForwarding{ - sessionId: s.SessionId, - portParameters: s.portParameters, - session: s.Session, - } - } else { - s.portSessionType = &BasicPortForwarding{ - sessionId: s.SessionId, - portParameters: s.portParameters, - session: s.Session, - } - } - } else { - s.portSessionType = &StandardStreamForwarding{ - portParameters: s.portParameters, - session: s.Session, - } - } - - s.DataChannel.RegisterOutputStreamHandler(s.ProcessStreamMessagePayload, true) - s.DataChannel.GetWsChannel().SetOnMessage(func(input []byte) { - if s.portSessionType.IsStreamNotSet() { - outputMessage := &message.ClientMessage{} - if err := outputMessage.DeserializeClientMessage(log, input); err != nil { - log.Debugf("Ignore message deserialize error while stream connection had not set.") - return - } - if outputMessage.MessageType == message.OutputStreamMessage { - log.Debugf("Waiting for user to establish connection before processing incoming messages.") - return - } else { - log.Infof("Received %s message while establishing connection", outputMessage.MessageType) - } - } - s.DataChannel.OutputMessageHandler(log, s.Stop, s.SessionId, input) - }) - log.Infof("Connected to instance[%s] on port: %s", sessionVar.TargetId, s.portParameters.PortNumber) -} - -func (s *PortSession) Stop() { - s.portSessionType.Stop() -} - -// StartSession redirects inputStream/outputStream data to datachannel. -func (s *PortSession) SetSessionHandlers(log log.T) (err error) { - if err = s.portSessionType.InitializeStreams(log, s.DataChannel.GetAgentVersion()); err != nil { - return err - } - - if err = s.portSessionType.ReadStream(log); err != nil { - return err - } - return -} - -// ProcessStreamMessagePayload writes messages received on datachannel to stdout -func (s *PortSession) ProcessStreamMessagePayload(log log.T, outputMessage message.ClientMessage) (isHandlerReady bool, err error) { - if s.portSessionType.IsStreamNotSet() { - log.Debugf("Waiting for streams to be established before processing incoming messages.") - return false, nil - } - log.Tracef("Received payload of size %d from datachannel.", outputMessage.PayloadLength) - err = s.portSessionType.WriteStream(outputMessage) - return true, err -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/standardstreamforwarding.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/standardstreamforwarding.go deleted file mode 100644 index ce2fe54..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/portsession/standardstreamforwarding.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package portsession starts port session. -package portsession - -import ( - "fmt" - "io" - "os" - "os/signal" - "time" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil" -) - -type StandardStreamForwarding struct { - port IPortSession - inputStream *os.File - outputStream *os.File - portParameters PortParameters - session session.Session -} - -// IsStreamNotSet checks if streams are not set -func (p *StandardStreamForwarding) IsStreamNotSet() (status bool) { - return p.inputStream == nil || p.outputStream == nil -} - -// Stop closes the streams -func (p *StandardStreamForwarding) Stop() { - p.inputStream.Close() - p.outputStream.Close() - return -} - -// InitializeStreams initializes the streams with its file descriptors -func (p *StandardStreamForwarding) InitializeStreams(log log.T, agentVersion string) (err error) { - p.handleControlSignals(log) - p.inputStream = os.Stdin - p.outputStream = os.Stdout - return -} - -// handleControlSignals handles terminate signals -func (p *StandardStreamForwarding) handleControlSignals(log log.T) { - c := make(chan os.Signal, 1) - signal.Notify(c, sessionutil.ControlSignals...) - go func() { - <-c - fmt.Println("Terminate signal received, exiting.") - - p.session.DataChannel.EndSession() - p.Stop() - }() -} - -// ReadStream reads data from the input stream -func (p *StandardStreamForwarding) ReadStream(log log.T) (err error) { - msg := make([]byte, config.StreamDataPayloadSize) - for { - numBytes, err := p.inputStream.Read(msg) - if err != nil { - return p.handleReadError(log, err) - } - - log.Tracef("Received message of size %d from stdin.", numBytes) - if err = p.session.DataChannel.SendInputDataMessage(log, message.Output, msg[:numBytes]); err != nil { - log.Errorf("Failed to send packet: %v", err) - return err - } - // Sleep to process more data - time.Sleep(time.Millisecond) - } -} - -// WriteStream writes data to output stream -func (p *StandardStreamForwarding) WriteStream(outputMessage message.ClientMessage) error { - _, err := p.outputStream.Write(outputMessage.Payload) - return err -} - -// handleReadError handles read error -func (p *StandardStreamForwarding) handleReadError(log log.T, err error) error { - if err == io.EOF { - log.Infof("Session to instance[%s] on port[%s] was closed.", p.session.TargetId, p.portParameters.PortNumber) - return nil - } else { - log.Errorf("Reading input failed with error: %v", err) - return err - } -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/session.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/session.go deleted file mode 100644 index 7a324e8..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/session.go +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package session starts the session. -package session - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "os" - "time" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/ruelala/arconn/pkg/session-manager-plugin/datachannel" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - "github.com/ruelala/arconn/pkg/session-manager-plugin/retry" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sdkutil" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil" - "github.com/ruelala/arconn/pkg/session-manager-plugin/version" - "github.com/twinj/uuid" -) - -const ( - LegacyArgumentLength = 4 - ArgumentLength = 7 - StartSessionOperation = "StartSession" - VersionFile = "VERSION" -) - -var SessionRegistry = map[string]ISessionPlugin{} - -type ISessionPlugin interface { - SetSessionHandlers(log.T) error - ProcessStreamMessagePayload(log log.T, streamDataMessage message.ClientMessage) (isHandlerReady bool, err error) - Initialize(log log.T, sessionVar *Session) - Stop() - Name() string -} - -type ISession interface { - Execute(log.T) error - OpenDataChannel(log.T) error - ProcessFirstMessage(log log.T, outputMessage message.ClientMessage) (isHandlerReady bool, err error) - Stop() - GetResumeSessionParams(log.T) (string, error) - ResumeSessionHandler(log.T) error - TerminateSession(log.T) error -} - -func init() { - SessionRegistry = make(map[string]ISessionPlugin) -} - -func Register(session ISessionPlugin) { - SessionRegistry[session.Name()] = session -} - -type Session struct { - DataChannel datachannel.IDataChannel - SessionId string - StreamUrl string - TokenValue string - IsAwsCliUpgradeNeeded bool - Endpoint string - ClientId string - TargetId string - sdk *ssm.SSM - retryParams retry.RepeatableExponentialRetryer - SessionType string - SessionProperties interface{} - DisplayMode sessionutil.DisplayMode -} - -// startSession create the datachannel for session -var startSession = func(session *Session, log log.T) error { - return session.Execute(log) -} - -// setSessionHandlersWithSessionType set session handlers based on session subtype -var setSessionHandlersWithSessionType = func(session *Session, log log.T) error { - // SessionType is set inside DataChannel - sessionSubType := SessionRegistry[session.SessionType] - sessionSubType.Initialize(log, session) - return sessionSubType.SetSessionHandlers(log) -} - -// Set up a scheduler to listen on stream data resend timeout event -var handleStreamMessageResendTimeout = func(session *Session, log log.T) { - log.Tracef("Setting up scheduler to listen on IsStreamMessageResendTimeout event.") - go func() { - for { - // Repeat this loop for every 200ms - time.Sleep(config.ResendSleepInterval) - if <-session.DataChannel.IsStreamMessageResendTimeout() { - log.Errorf("Terminating session %s as the stream data was not processed before timeout.", session.SessionId) - if err := session.TerminateSession(log); err != nil { - log.Errorf("Unable to terminate session upon stream data timeout. %v", err) - } - return - } - } - }() -} - -// ValidateInputAndStartSession validates input sent from AWS CLI and starts a session if validation is successful. -// AWS CLI sends input in the order of -// args[0] will be path of executable (ignored) -// args[1] is session response -// args[2] is client region -// args[3] is operation name -// args[4] is profile name from aws credentials/config files -// args[5] is parameters input to aws cli for StartSession api -// args[6] is endpoint for ssm service -func ValidateInputAndStartSession(args []string, out io.Writer) { - var ( - err error - session Session - startSessionOutput ssm.StartSessionOutput - response []byte - region string - operationName string - profile string - ssmEndpoint string - target string - ) - log := log.Logger(true, "session-manager-plugin") - uuid.SwitchFormat(uuid.FormatCanonical) - - if len(args) == 1 { - fmt.Fprint(out, "\nThe Session Manager plugin was installed successfully. "+ - "Use the AWS CLI to start a session.\n\n") - return - } else if len(args) == 2 && args[1] == "--version" { - fmt.Fprintf(out, "%s\n", string(version.Version)) - return - } else if len(args) >= 2 && len(args) < LegacyArgumentLength { - fmt.Fprintf(out, "\nUnknown operation %s. \nUse "+ - "session-manager-plugin --version to check the version.\n\n", string(args[1])) - return - - } else if len(args) == LegacyArgumentLength { - // If arguments do not have Profile passed from AWS CLI to Session-Manager-Plugin then - // should be upgraded to use Session Manager encryption feature - session.IsAwsCliUpgradeNeeded = true - } - - for argsIndex := 1; argsIndex < len(args); argsIndex++ { - switch argsIndex { - case 1: - response = []byte(args[1]) - case 2: - region = args[2] - case 3: - operationName = args[3] - case 4: - profile = args[4] - case 5: - // args[5] is parameters input to aws cli for StartSession api call - startSessionRequest := make(map[string]interface{}) - json.Unmarshal([]byte(args[5]), &startSessionRequest) - target = startSessionRequest["Target"].(string) - case 6: - ssmEndpoint = args[6] - } - } - sdkutil.SetRegionAndProfile(region, profile) - clientId := uuid.NewV4().String() - - switch operationName { - case StartSessionOperation: - if err = json.Unmarshal(response, &startSessionOutput); err != nil { - log.Errorf("Cannot perform start session: %v", err) - fmt.Fprintf(out, "Cannot perform start session: %v\n", err) - return - } - - session.SessionId = *startSessionOutput.SessionId - session.StreamUrl = *startSessionOutput.StreamUrl - session.TokenValue = *startSessionOutput.TokenValue - session.Endpoint = ssmEndpoint - session.ClientId = clientId - session.TargetId = target - session.DataChannel = &datachannel.DataChannel{} - - default: - fmt.Fprint(out, "Invalid Operation") - return - } - - if err = startSession(&session, log); err != nil { - if session.DataChannel.IsSessionEnded() == false { - log.Errorf("Cannot perform start session: %v", err) - fmt.Fprintf(out, "Cannot perform start session: %v\n", err) - } - return - } -} - -// Execute create data channel and start the session -func (s *Session) Execute(log log.T) (err error) { - fmt.Fprintf(os.Stdout, "\nStarting session with SessionId: %s\n", s.SessionId) - - // sets the display mode - s.DisplayMode = sessionutil.NewDisplayMode(log) - - if err = s.OpenDataChannel(log); err != nil { - log.Errorf("Error in Opening data channel: %v", err) - return - } - - handleStreamMessageResendTimeout(s, log) - - // The session type is set either by handshake or the first packet received. - if !<-s.DataChannel.IsSessionTypeSet() { - log.Errorf("unable to set SessionType for session %s", s.SessionId) - return errors.New("unable to determine SessionType") - } else { - s.SessionType = s.DataChannel.GetSessionType() - s.SessionProperties = s.DataChannel.GetSessionProperties() - if err = setSessionHandlersWithSessionType(s, log); err != nil { - if s.DataChannel.IsSessionEnded() == false { - log.Errorf("Session ending with error: %v", err) - } - return - } - } - - return -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionhandler.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionhandler.go deleted file mode 100644 index 23d038a..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionhandler.go +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package session starts the session. -package session - -import ( - "fmt" - "math/rand" - "os" - - sdkSession "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - "github.com/ruelala/arconn/pkg/session-manager-plugin/retry" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sdkutil" -) - -// OpenDataChannel initializes datachannel -func (s *Session) OpenDataChannel(log log.T) (err error) { - s.retryParams = retry.RepeatableExponentialRetryer{ - GeometricRatio: config.RetryBase, - InitialDelayInMilli: rand.Intn(config.DataChannelRetryInitialDelayMillis) + config.DataChannelRetryInitialDelayMillis, - MaxDelayInMilli: config.DataChannelRetryMaxIntervalMillis, - MaxAttempts: config.DataChannelNumMaxRetries, - } - - s.DataChannel.Initialize(log, s.ClientId, s.SessionId, s.TargetId, s.IsAwsCliUpgradeNeeded) - s.DataChannel.SetWebsocket(log, s.StreamUrl, s.TokenValue) - s.DataChannel.GetWsChannel().SetOnMessage( - func(input []byte) { - s.DataChannel.OutputMessageHandler(log, s.Stop, s.SessionId, input) - }) - s.DataChannel.RegisterOutputStreamHandler(s.ProcessFirstMessage, false) - - if err = s.DataChannel.Open(log); err != nil { - log.Errorf("Retrying connection for data channel id: %s failed with error: %s", s.SessionId, err) - s.retryParams.CallableFunc = func() (err error) { return s.DataChannel.Reconnect(log) } - if err = s.retryParams.Call(); err != nil { - log.Error(err) - } - } - - s.DataChannel.GetWsChannel().SetOnError( - func(err error) { - log.Errorf("Trying to reconnect the session: %v with seq num: %d", s.StreamUrl, s.DataChannel.GetStreamDataSequenceNumber()) - s.retryParams.CallableFunc = func() (err error) { return s.ResumeSessionHandler(log) } - if err = s.retryParams.Call(); err != nil { - log.Error(err) - } - }) - - // Scheduler for resending of data - s.DataChannel.ResendStreamDataMessageScheduler(log) - - return nil -} - -// ProcessFirstMessage only processes messages with PayloadType Output to determine the -// sessionType of the session to be launched. This is a fallback for agent versions that do not support handshake, they -// immediately start sending shell output. -func (s *Session) ProcessFirstMessage(log log.T, outputMessage message.ClientMessage) (isHandlerReady bool, err error) { - // Immediately deregister self so that this handler is only called once, for the first message - s.DataChannel.DeregisterOutputStreamHandler(s.ProcessFirstMessage) - // Only set session type if the session type has not already been set. Usually session type will be set - // by handshake protocol which would be the first message but older agents may not perform handshake - if s.SessionType == "" { - if outputMessage.PayloadType == uint32(message.Output) { - log.Warn("Setting session type to shell based on PayloadType!") - s.DataChannel.SetSessionType(config.ShellPluginName) - s.DisplayMode.DisplayMessage(log, outputMessage) - } - } - return true, nil -} - -// Stop will end the session -func (s *Session) Stop() {} - -// GetResumeSessionParams calls ResumeSession API and gets tokenvalue for reconnecting -func (s *Session) GetResumeSessionParams(log log.T) (string, error) { - var ( - resumeSessionOutput *ssm.ResumeSessionOutput - err error - sdkSession *sdkSession.Session - ) - - if sdkSession, err = sdkutil.GetNewSessionWithEndpoint(s.Endpoint); err != nil { - return "", err - } - s.sdk = ssm.New(sdkSession) - - resumeSessionInput := ssm.ResumeSessionInput{ - SessionId: &s.SessionId, - } - - log.Debugf("Resume Session input parameters: %v", resumeSessionInput) - if resumeSessionOutput, err = s.sdk.ResumeSession(&resumeSessionInput); err != nil { - log.Errorf("Resume Session failed: %v", err) - return "", err - } - - if resumeSessionOutput.TokenValue == nil { - return "", nil - } - - return *resumeSessionOutput.TokenValue, nil -} - -// ResumeSessionHandler gets token value and tries to Reconnect to datachannel -func (s *Session) ResumeSessionHandler(log log.T) (err error) { - s.TokenValue, err = s.GetResumeSessionParams(log) - if err != nil { - log.Errorf("Failed to get token: %v", err) - return - } else if s.TokenValue == "" { - log.Debugf("Session: %s timed out", s.SessionId) - fmt.Fprintf(os.Stdout, "Session: %s timed out.\n", s.SessionId) - return - } - s.DataChannel.GetWsChannel().SetChannelToken(s.TokenValue) - err = s.DataChannel.Reconnect(log) - return -} - -// TerminateSession calls TerminateSession API -func (s *Session) TerminateSession(log log.T) error { - var ( - err error - newSession *sdkSession.Session - ) - - if newSession, err = sdkutil.GetNewSessionWithEndpoint(s.Endpoint); err != nil { - log.Errorf("Terminate Session failed: %v", err) - return err - } - s.sdk = ssm.New(newSession) - - terminateSessionInput := ssm.TerminateSessionInput{ - SessionId: &s.SessionId, - } - - log.Debugf("Terminate Session input parameters: %v", terminateSessionInput) - if _, err = s.sdk.TerminateSession(&terminateSessionInput); err != nil { - log.Errorf("Terminate Session failed: %v", err) - return err - } - return nil -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/control_signals_unix.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/control_signals_unix.go deleted file mode 100644 index 71fd02c..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/control_signals_unix.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:build darwin || freebsd || linux || netbsd || openbsd -// +build darwin freebsd linux netbsd openbsd - -// Package sessionutil contains utility methods required to start session. -package sessionutil - -import ( - "os" - "syscall" -) - -// All the signals to handles interrupt -// SIGINT captures Ctrl+C -// SIGQUIT captures Ctrl+\ -// SIGTSTP captures Ctrl+Z -var SignalsByteMap = map[os.Signal]byte{ - syscall.SIGINT: '\003', - syscall.SIGQUIT: '\x1c', - syscall.SIGTSTP: '\032', -} - -var ControlSignals = []os.Signal{syscall.SIGINT, syscall.SIGTSTP, syscall.SIGQUIT} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/control_signals_windows.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/control_signals_windows.go deleted file mode 100644 index 9cf4a85..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/control_signals_windows.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:build windows -// +build windows - -// Package sessionutil contains utility methods required to start session. -package sessionutil - -import ( - "os" - "syscall" -) - -// All the signals to handles interrupt -// SIGINT captures Ctrl+C -// SIGQUIT captures Ctrl+Z -var SignalsByteMap = map[os.Signal]byte{ - syscall.SIGINT: '\003', - syscall.SIGQUIT: '\x1c', -} - -var ControlSignals = []os.Signal{syscall.SIGINT, syscall.SIGQUIT} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/sessionutil.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/sessionutil.go deleted file mode 100644 index e9054d3..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/sessionutil.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package sessionutil provides utility for sessions. -package sessionutil - -import "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - -func NewDisplayMode(log log.T) DisplayMode { - displayMode := DisplayMode{} - displayMode.InitDisplayMode(log) - return displayMode -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/sessionutil_unix.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/sessionutil_unix.go deleted file mode 100644 index 9aa6cc3..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/sessionutil_unix.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:build darwin || freebsd || linux || netbsd || openbsd -// +build darwin freebsd linux netbsd openbsd - -// Package sessionutil provides utility for sessions. -package sessionutil - -import ( - "fmt" - "io" - "net" - "os" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" -) - -type DisplayMode struct { -} - -func (d *DisplayMode) InitDisplayMode(log log.T) { -} - -// DisplayMessage function displays the output on the screen -func (d *DisplayMode) DisplayMessage(log log.T, message message.ClientMessage) { - var out io.Writer = os.Stdout - fmt.Fprint(out, string(message.Payload)) -} - -// NewListener starts a new socket listener on the address. -func NewListener(log log.T, address string) (net.Listener, error) { - return net.Listen("unix", address) -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go deleted file mode 100644 index f33a0d9..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil/sessionutil_windows.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:build windows -// +build windows - -// Package sessionutil provides utility for sessions. -package sessionutil - -import ( - "fmt" - "net" - "os" - "syscall" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - "golang.org/x/sys/windows" -) - -var EnvProgramFiles = os.Getenv("ProgramFiles") - -type DisplayMode struct { - handle windows.Handle -} - -func (d *DisplayMode) InitDisplayMode(log log.T) { - var ( - state uint32 - fileDescriptor int - err error - ) - - // gets handler for Stdout - fileDescriptor = int(syscall.Stdout) - d.handle = windows.Handle(fileDescriptor) - - // gets current console mode i.e. current console settings - if err = windows.GetConsoleMode(d.handle, &state); err != nil { - log.Errorf("error getting console mode: %v", err) - } - - // this flag is set in order to support control character sequences - // that control cursor movement, color/font mode - // refer - https://docs.microsoft.com/en-us/windows/console/setconsolemode - state |= windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING - // sets the console with new flag - if err = windows.SetConsoleMode(d.handle, state); err != nil { - log.Errorf("error setting console mode: %v", err) - } -} - -// DisplayMessage function displays the output on the screen -func (d *DisplayMode) DisplayMessage(log log.T, message message.ClientMessage) { - var ( - done *uint32 - err error - ) - - // writes data to the specified file or input/output (I/O) device - // refer - https://docs.microsoft.com/en-us/windows/desktop/api/fileapi/nf-fileapi-writefile - if err = windows.WriteFile(d.handle, message.Payload, done, nil); err != nil { - log.Errorf("error occurred while writing to file: %v", err) - fmt.Fprintf(os.Stdout, "\nError getting the output. %s\n", err.Error()) - return - } -} - -// NewListener starts a new socket listener on the address. -// unix sockets are not supported in older windows versions, start tcp loopback server in such cases -func NewListener(log log.T, address string) (net.Listener, error) { - if listener, err := net.Listen("unix", address); err != nil { - log.Infof("Failed to open unix socket listener, %v. Starting TCP listener.", err) - return net.Listen("tcp", "localhost:0") - } else { - return listener, err - } -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession/shellsession.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession/shellsession.go deleted file mode 100644 index 990ac8f..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession/shellsession.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package shellsession starts shell session. -package shellsession - -import ( - "bytes" - "encoding/json" - "os" - "os/signal" - "time" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session" - "github.com/ruelala/arconn/pkg/session-manager-plugin/sessionmanagerplugin/session/sessionutil" - "golang.org/x/crypto/ssh/terminal" -) - -const ( - ResizeSleepInterval = time.Millisecond * 500 - StdinBufferLimit = 1024 -) - -type ShellSession struct { - session.Session - - // SizeData is used to store size data at session level to compare with new size. - SizeData message.SizeData - originalSttyState bytes.Buffer -} - -var GetTerminalSizeCall = func(fd int) (width int, height int, err error) { - return terminal.GetSize(fd) -} - -func init() { - session.Register(&ShellSession{}) -} - -// Name is the session name used in the plugin -func (ShellSession) Name() string { - return config.ShellPluginName -} - -func (s *ShellSession) Initialize(log log.T, sessionVar *session.Session) { - s.Session = *sessionVar - s.DataChannel.RegisterOutputStreamHandler(s.ProcessStreamMessagePayload, true) - s.DataChannel.GetWsChannel().SetOnMessage( - func(input []byte) { - s.DataChannel.OutputMessageHandler(log, s.Stop, s.SessionId, input) - }) -} - -// StartSession takes input and write it to data channel -func (s *ShellSession) SetSessionHandlers(log log.T) (err error) { - - // handle re-size - s.handleTerminalResize(log) - - // handle control signals - s.handleControlSignals(log) - - //handles keyboard input - err = s.handleKeyboardInput(log) - - return -} - -// handleControlSignals handles control signals when given by user -func (s *ShellSession) handleControlSignals(log log.T) { - go func() { - signals := make(chan os.Signal, 1) - signal.Notify(signals, sessionutil.ControlSignals...) - for { - sig := <-signals - if b, ok := sessionutil.SignalsByteMap[sig]; ok { - if err := s.DataChannel.SendInputDataMessage(log, message.Output, []byte{b}); err != nil { - log.Errorf("Failed to send control signals: %v", err) - } - } - } - }() -} - -// handleTerminalResize checks size of terminal every 500ms and sends size data. -func (s *ShellSession) handleTerminalResize(log log.T) { - var ( - width int - height int - inputSizeData []byte - err error - ) - go func() { - for { - // If running from IDE GetTerminalSizeCall will not work. Supply a fixed width and height value. - if width, height, err = GetTerminalSizeCall(int(os.Stdout.Fd())); err != nil { - width = 300 - height = 100 - log.Errorf("Could not get size of the terminal: %s, using width %d height %d", err, width, height) - } - - if s.SizeData.Rows != uint32(height) || s.SizeData.Cols != uint32(width) { - sizeData := message.SizeData{ - Cols: uint32(width), - Rows: uint32(height), - } - s.SizeData = sizeData - - if inputSizeData, err = json.Marshal(sizeData); err != nil { - log.Errorf("Cannot marshall size data: %v", err) - } - log.Debugf("Sending input size data: %s", inputSizeData) - if err = s.DataChannel.SendInputDataMessage(log, message.Size, inputSizeData); err != nil { - log.Errorf("Failed to Send size data: %v", err) - } - } - // repeating this loop for every 500ms - time.Sleep(ResizeSleepInterval) - } - }() -} - -// ProcessStreamMessagePayload prints payload received on datachannel to console -func (s ShellSession) ProcessStreamMessagePayload(log log.T, outputMessage message.ClientMessage) (isHandlerReady bool, err error) { - s.DisplayMode.DisplayMessage(log, outputMessage) - return true, nil -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession/shellsession_unix.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession/shellsession_unix.go deleted file mode 100644 index 8563a73..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession/shellsession_unix.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:build darwin || freebsd || linux || netbsd || openbsd -// +build darwin freebsd linux netbsd openbsd - -// Package shellsession starts shell session. -package shellsession - -import ( - "bufio" - "bytes" - "os" - "os/exec" - "time" - - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" -) - -// disableEchoAndInputBuffering disables echo to avoid double echo and disable input buffering -func (s *ShellSession) disableEchoAndInputBuffering() { - getState(&s.originalSttyState) - setState(bytes.NewBufferString("cbreak")) - setState(bytes.NewBufferString("-echo")) -} - -// getState gets current state of terminal -func getState(state *bytes.Buffer) error { - cmd := exec.Command("stty", "-g") - cmd.Stdin = os.Stdin - cmd.Stdout = state - return cmd.Run() -} - -// setState sets the new settings to terminal -func setState(state *bytes.Buffer) error { - cmd := exec.Command("stty", state.String()) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - return cmd.Run() -} - -// stop restores the terminal settings and exits -func (s *ShellSession) Stop() { - setState(&s.originalSttyState) - setState(bytes.NewBufferString("echo")) // for linux and ubuntu -} - -// handleKeyboardInput handles input entered by customer on terminal -func (s *ShellSession) handleKeyboardInput(log log.T) (err error) { - var ( - stdinBytesLen int - ) - - s.disableEchoAndInputBuffering() - ch := make(chan []byte) - go func(ch chan []byte) { - reader := bufio.NewReader(os.Stdin) - for { - stdinBytes := make([]byte, StdinBufferLimit) - stdinBytesLen, _ = reader.Read(stdinBytes) - ch <- stdinBytes - } - }(ch) - - for { - select { - case <-time.After(time.Second): - if s.Session.DataChannel.IsSessionEnded() { - return - } - case stdinBytes := <-ch: - if err = s.Session.DataChannel.SendInputDataMessage(log, message.Output, stdinBytes[:stdinBytesLen]); err != nil { - return - } - } - } -} diff --git a/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession/shellsession_windows.go b/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession/shellsession_windows.go deleted file mode 100644 index 0644848..0000000 --- a/pkg/session-manager-plugin/sessionmanagerplugin/session/shellsession/shellsession_windows.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:build windows -// +build windows - -// Package shellsession starts shell session. -package shellsession - -import ( - "time" - - "github.com/eiannone/keyboard" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" - "github.com/ruelala/arconn/pkg/session-manager-plugin/message" -) - -// Byte array for key inputs -// Note: F11 cannot be converted to byte array -var specialKeysInputMap = map[keyboard.Key][]byte{ - keyboard.KeyEsc: {27}, - keyboard.KeyArrowUp: {27, 79, 65}, - keyboard.KeyArrowDown: {27, 79, 66}, - keyboard.KeyArrowRight: {27, 79, 67}, - keyboard.KeyArrowLeft: {27, 79, 68}, - keyboard.KeyF1: {27, 79, 80}, - keyboard.KeyF2: {27, 79, 81}, - keyboard.KeyF3: {27, 79, 82}, - keyboard.KeyF4: {27, 79, 83}, - keyboard.KeyF5: {27, 91, 49, 53, 126}, - keyboard.KeyF6: {27, 91, 49, 55, 126}, - keyboard.KeyF7: {27, 91, 49, 56, 126}, - keyboard.KeyF8: {27, 91, 49, 57, 126}, - keyboard.KeyF9: {27, 91, 50, 48, 126}, - keyboard.KeyF10: {27, 91, 50, 49, 126}, - keyboard.KeyF12: {27, 91, 50, 52, 126}, - keyboard.KeyHome: {27, 91, 72}, - keyboard.KeyEnd: {27, 91, 70}, - keyboard.KeyInsert: {27, 91, 50, 126}, - keyboard.KeyDelete: {27, 91, 51, 126}, - keyboard.KeyPgup: {27, 91, 53, 126}, - keyboard.KeyPgdn: {27, 91, 54, 126}, -} - -// stop restores the terminal settings and exits -func (s *ShellSession) Stop() { - keyboard.Close() -} - -// handleKeyboardInput handles input entered by customer on terminal -func (s *ShellSession) handleKeyboardInput(log log.T) (err error) { - var ( - character rune //character input from keyboard - key keyboard.Key //special keys like arrows and function keys - ) - - charCH := make(chan rune) - keyCH := make(chan keyboard.Key) - go func(charCH chan rune, keyCH chan keyboard.Key) { - if err = keyboard.Open(); err != nil { - log.Errorf("Failed to load Keyboard: %v", err) - return - } - for { - if character, key, err = keyboard.GetKey(); err != nil { - log.Errorf("Failed to get the key stroke: %v", err) - return - } - if character != 0 { - charCH <- character - } else if key != 0 { - keyCH <- key - } - } - }(charCH, keyCH) - - for { - select { - case <-time.After(time.Second): - if s.Session.DataChannel.IsSessionEnded() == true { - s.Stop() - return - } - case charStr := <-charCH: - charBytes := []byte(string(charStr)) - if err = s.Session.DataChannel.SendInputDataMessage(log, message.Output, charBytes); err != nil { - log.Errorf("Failed to send UTF8 char: %v", err) - return - } - case keyStr := <-keyCH: - keyBytes := []byte(string(keyStr)) - if byteValue, ok := specialKeysInputMap[key]; ok { - keyBytes = byteValue - } - if err = s.Session.DataChannel.SendInputDataMessage(log, message.Output, keyBytes); err != nil { - log.Errorf("Failed to send UTF8 char: %v", err) - return - } - } - } - return -} diff --git a/pkg/session-manager-plugin/version/version.go b/pkg/session-manager-plugin/version/version.go deleted file mode 100644 index 601c241..0000000 --- a/pkg/session-manager-plugin/version/version.go +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -// This is an autogenerated file. -// Changes made to this file will be overwritten during the build process. - -// Package version contains CLI version constant and utilities. - -package version - -// Version is the version of the CLI -const Version = "1.2.0.0" diff --git a/pkg/session-manager-plugin/version/versiongenerator/version-gen.go b/pkg/session-manager-plugin/version/versiongenerator/version-gen.go deleted file mode 100644 index 25fe748..0000000 --- a/pkg/session-manager-plugin/version/versiongenerator/version-gen.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package main represents the entry point to generate version. -package main - -import ( - "fmt" - "io/ioutil" - "log" - "os" - "path/filepath" - "text/template" -) - -const ( - ReadWriteAccess = 0600 - LicenseString = "// Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n// SPDX-License-Identifier: Apache-2.0\n\n" - VersiongoTemplate = `// This is an autogenerated file. -// Changes made to this file will be overwritten during the build process. - -// Package version contains CLI version constant and utilities. - -package version - -// Version is the version of the CLI -const Version = "{{.Version}}" -` -) - -// version-gen is a simple program that generates the plugin's version file, -// containing information about the plugin's version -func main() { - - versionContent, err := ioutil.ReadFile(filepath.Join("VERSION")) - if err != nil { - log.Fatalf("Error reading VERSION file. %v", err) - } - versionStr := string(versionContent) - - fmt.Printf("Session Manager Plugin Version: %v\n", versionStr) - - if err := ioutil.WriteFile(filepath.Join("VERSION"), []byte(versionStr), ReadWriteAccess); err != nil { - log.Fatalf("Error writing to VERSION file. %v", err) - } - - versionFilePath := filepath.Join("src", "version", "version.go") - - // Generate version.go - type versionInfo struct { - Version string - } - info := versionInfo{ - Version: versionStr, - } - t := template.Must(template.New("version").Parse(string(LicenseString) + VersiongoTemplate)) - outFile, err := os.Create(versionFilePath) - if err != nil { - log.Fatalf("Unable to create output version file: %v", err) - } - defer outFile.Close() - - err = t.Execute(outFile, info) - if err != nil { - log.Fatalf("Error applying template: %v", err) - } -} diff --git a/pkg/session-manager-plugin/version/versionutil.go b/pkg/session-manager-plugin/version/versionutil.go deleted file mode 100644 index 938bcc3..0000000 --- a/pkg/session-manager-plugin/version/versionutil.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package version contains CLI version constant and utilities. -package version - -import ( - "fmt" - "strconv" - "strings" -) - -type version struct { - version []string -} - -// NewVersion initializes version struct by splitting given version string into string list using separator "." -func NewVersion(versionString string) (version, error) { - if versionString == "" { - return version{}, fmt.Errorf("invalid version %s", versionString) - } - - return version{ - strings.Split(versionString, "."), - }, nil -} - -// compare returns 0 if thisVersion is equal to otherVersion, 1 if thisVersion is greater than otherVersion, -1 otherwise -func (thisVersion version) compare(otherVersion version) (int, error) { - if len(thisVersion.version) != len(otherVersion.version) { - return -1, fmt.Errorf("length mismatch for versions %s and %s", thisVersion.version, otherVersion.version) - } - - var ( - thisVersionSlice int - otherVersionSlice int - err error - ) - for i := range thisVersion.version { - if thisVersionSlice, err = strconv.Atoi(thisVersion.version[i]); err != nil { - return -1, err - } - if otherVersionSlice, err = strconv.Atoi(otherVersion.version[i]); err != nil { - return -1, err - } - - if thisVersionSlice > otherVersionSlice { - return 1, nil - } else if thisVersionSlice < otherVersionSlice { - return -1, nil - } - } - return 0, nil -} diff --git a/pkg/session-manager-plugin/version/versionvalidator.go b/pkg/session-manager-plugin/version/versionvalidator.go deleted file mode 100644 index a975ba2..0000000 --- a/pkg/session-manager-plugin/version/versionvalidator.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package version contains version constants and utilities. -package version - -import ( - "github.com/ruelala/arconn/pkg/session-manager-plugin/config" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" -) - -// DoesAgentSupportTCPMultiplexing returns true if given agentVersion supports TCP multiplexing in port plugin, false otherwise -func DoesAgentSupportTCPMultiplexing(log log.T, agentVersion string) (supported bool) { - return isAgentVersionGreaterThanSupportedVersion(log, agentVersion, config.TCPMultiplexingSupportedAfterThisAgentVersion) -} - -// DoesAgentSupportDisableSmuxKeepAlive returns true if given agentVersion disables smux KeepAlive in TCP multiplexing in port plugin, false otherwise -func DoesAgentSupportDisableSmuxKeepAlive(log log.T, agentVersion string) (supported bool) { - return isAgentVersionGreaterThanSupportedVersion(log, agentVersion, config.TCPMultiplexingWithSmuxKeepAliveDisabledAfterThisAgentVersion) -} - -// DoesAgentSupportTerminateSessionFlag returns true if given agentVersion supports TerminateSession flag, false otherwise -func DoesAgentSupportTerminateSessionFlag(log log.T, agentVersion string) (supported bool) { - return isAgentVersionGreaterThanSupportedVersion(log, agentVersion, config.TerminateSessionFlagSupportedAfterThisAgentVersion) -} - -// isAgentVersionGreaterThanSupportedVersion returns true if agentVersion is greater than supportedVersion, -// false in case of any error and agentVersion is equalTo or less than supportedVersion -func isAgentVersionGreaterThanSupportedVersion(log log.T, agentVersionString string, supportedVersionString string) (supported bool) { - var ( - supportedVersion version - agentVersion version - compareResult int - err error - ) - if supportedVersion, err = NewVersion(supportedVersionString); err != nil { - log.Debugf("supportedVersion initialization failed, %v", err) - return - } - - if agentVersion, err = NewVersion(agentVersionString); err != nil { - log.Debugf("agentVersion initialization failed, %v", err) - return - } - - if compareResult, err = agentVersion.compare(supportedVersion); err != nil { - log.Debugf("version comparison failed, %v", err) - return - } - - if compareResult == 1 { - supported = true - } - return -} diff --git a/pkg/session-manager-plugin/websocketutil/websocketutil.go b/pkg/session-manager-plugin/websocketutil/websocketutil.go deleted file mode 100644 index 49dad36..0000000 --- a/pkg/session-manager-plugin/websocketutil/websocketutil.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, -// either express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package websocketutil contains methods for interacting with websocket connections. -package websocketutil - -import ( - "errors" - - "github.com/gorilla/websocket" - "github.com/ruelala/arconn/pkg/session-manager-plugin/log" -) - -// IWebsocketUtil is the interface for the websocketutil. -type IWebsocketUtil interface { - OpenConnection(url string) (*websocket.Conn, error) - CloseConnection(ws websocket.Conn) error -} - -// WebsocketUtil struct provides functionality around creating and maintaining websockets. -type WebsocketUtil struct { - dialer *websocket.Dialer - log log.T -} - -// NewWebsocketUtil is the factory function for websocketutil. -func NewWebsocketUtil(logger log.T, dialerInput *websocket.Dialer) *WebsocketUtil { - - var websocketUtil *WebsocketUtil - - if dialerInput == nil { - websocketUtil = &WebsocketUtil{ - dialer: websocket.DefaultDialer, - log: logger, - } - } else { - websocketUtil = &WebsocketUtil{ - dialer: dialerInput, - log: logger, - } - } - - return websocketUtil -} - -// OpenConnection opens a websocket connection provided an input url. -func (u *WebsocketUtil) OpenConnection(url string) (*websocket.Conn, error) { - - u.log.Infof("Opening websocket connection to: ", url) - - conn, _, err := u.dialer.Dial(url, nil) - if err != nil { - u.log.Errorf("Failed to dial websocket: %s", err.Error()) - return nil, err - } - - u.log.Infof("Successfully opened websocket connection to: ", url) - - return conn, err -} - -// CloseConnection closes a websocket connection given the Conn object as input. -func (u *WebsocketUtil) CloseConnection(ws *websocket.Conn) error { - - if ws == nil { - return errors.New("websocket conn object is nil") - } - - u.log.Debugf("Closing websocket connection to:", ws.RemoteAddr().String()) - - err := ws.Close() - if err != nil { - u.log.Errorf("Failed to close websocket: %s", err.Error()) - return err - } - - u.log.Debugf("Successfully closed websocket connection to:", ws.RemoteAddr().String()) - - return nil -}