Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flexible address params #53

Merged
merged 19 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 0 additions & 3 deletions pkg/device/genericcli/genericcli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func TestQuestionWithoutAnswer(t *testing.T) {
require.Empty(t, cmdRes)
require.NoError(t, err)
require.NoError(t, serverErr)
require.NoError(t, err)
}

func TestQuestionCmdOverlap(t *testing.T) {
Expand Down Expand Up @@ -215,7 +214,6 @@ func TestEscTermInEcho(t *testing.T) {
require.Equal(t, cmdRes, []cmd.CmdRes{cmd.NewCmdRes([]byte("olo")), cmd.NewCmdRes(nil)})
require.NoError(t, err)
require.NoError(t, serverErr)
require.NoError(t, err)
}

func TestEscTermInEchoEmptyCmd(t *testing.T) {
Expand Down Expand Up @@ -245,5 +243,4 @@ func TestEscTermInEchoEmptyCmd(t *testing.T) {
require.Equal(t, cmdRes, []cmd.CmdRes{cmd.NewCmdRes(nil), cmd.NewCmdRes(nil)})
require.NoError(t, err)
require.NoError(t, serverErr)
require.NoError(t, err)
}
130 changes: 102 additions & 28 deletions pkg/streamer/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"net"
"os"
"path/filepath"
"strconv"
"strings"
"time"

Expand All @@ -35,6 +34,14 @@ import (
"github.com/annetutil/gnetcli/pkg/trace"
)

type Network string

const (
TCP Network = "tcp"
TCPv4 Network = "tcp4"
TCPv6 Network = "tcp6"
)

const (
defaultPort = 22
defaultReadTimeout = 20 * time.Second
Expand Down Expand Up @@ -83,9 +90,32 @@ type terminalParams struct {
h int
}

type Endpoint struct {
Host string
Port int
Network Network
}

func (endpoint Endpoint) String() string {
return fmt.Sprintf("{host: %s, port: %d, network: %s}", endpoint.Host, endpoint.Port, endpoint.Network)
}

func (endpoint *Endpoint) Addr() string {
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
}

func NewEndpoint(host string, port int, network Network) Endpoint {
res := Endpoint{
Host: host,
Port: port,
Network: network,
}
return res
}

type Streamer struct {
host string
port int
endpoint Endpoint
additionalEndpoints []Endpoint
credentials credentials.Credentials
logger *zap.Logger
conn *ssh.Client
Expand Down Expand Up @@ -136,8 +166,8 @@ func (m *Streamer) SetTerminalSize(w, h int) {

func NewStreamer(host string, credentials credentials.Credentials, opts ...StreamerOption) *Streamer {
h := &Streamer{
host: host,
port: defaultPort,
endpoint: NewEndpoint(host, defaultPort, TCP),
additionalEndpoints: []Endpoint{},
credentials: credentials,
logger: nil,
conn: nil,
Expand Down Expand Up @@ -340,9 +370,17 @@ func WithLogger(log *zap.Logger) StreamerOption {
}
}

// WithPort sets port for default endpoint
func WithPort(port int) StreamerOption {
return func(h *Streamer) {
h.port = port
h.endpoint.Port = port
}
}

// WithNetwork sets network for default endpoint
func WithNetwork(network Network) StreamerOption {
return func(h *Streamer) {
h.endpoint.Network = network
}
}

Expand All @@ -365,6 +403,14 @@ func WithEnv(key, value string) StreamerOption {
}
}

// WithAdditionalEndpoints adds slice of endpoints that Streamer will sequentially try to connect to untill success of dial,
// if original host dial fails
func WithAdditionalEndpoints(endpoints []Endpoint) StreamerOption {
return func(h *Streamer) {
h.additionalEndpoints = endpoints
}
}

func (m *Streamer) Close() {
m.forwardAgent = nil
if m.session != nil && m.session.session != nil {
Expand Down Expand Up @@ -541,31 +587,43 @@ func (m *Streamer) openConnect(ctx context.Context) (*ssh.Client, error) {
if err != nil {
return nil, err
}
remote := m.host + ":" + strconv.Itoa(m.port)
m.logger.Debug("open connection", zap.String("remote", remote))
var conn *ssh.Client
if m.tunnel != nil {
if !m.tunnel.IsConnected() {
err := m.tunnel.CreateConnect(ctx)
if err != nil {
return nil, err
}
}
conn, err = m.dialTunnel(ctx, conf)
} else {
conn, err = DialCtx(ctx, m.endpoint, m.additionalEndpoints, conf, m.logger)
}

tunConn, err := m.tunnel.StartForward(remote)
if err != nil {
return nil, err
}
return conn, err
}

conn, err = DialConnCtx(ctx, tunConn, remote, conf)
func (m *Streamer) dialTunnel(ctx context.Context, conf *ssh.ClientConfig) (*ssh.Client, error) {
if !m.tunnel.IsConnected() {
err := m.tunnel.CreateConnect(ctx)
if err != nil {
return nil, err
}
} else {
conn, err = DialCtx(ctx, "tcp", remote, conf)
}

return conn, err
var tunConn net.Conn
var err error
var connectedEndpoint Endpoint
endpoints := append([]Endpoint{m.endpoint}, m.additionalEndpoints...)
for _, endpoint := range endpoints {
connectedEndpoint = endpoint
tunConn, err = m.tunnel.StartForward(string(endpoint.Network), endpoint.Addr())
if err == nil {
break
}
m.logger.Debug("failed to open tunnel for endpoint", zap.String("address", endpoint.String()))
}
if err != nil {
return nil, fmt.Errorf("failed to open tunnel for any of given hosts: %v, last error: %w", m.endpoint, err)
}
res, err := DialConnCtx(ctx, tunConn, connectedEndpoint.Addr(), conf)
if err != nil {
return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint.String(), err)
}
return res, err
}

func (m *Streamer) onSessionOpen(sess *ssh.Session) error {
Expand Down Expand Up @@ -677,7 +735,7 @@ func (m *Streamer) Init(ctx context.Context) error {
return fmt.Errorf("already inited")
}
m.inited = true
m.logger.Debug("open connection", zap.String("host", m.host))
m.logger.Debug("open connection", zap.Stringer("endpoint", m.endpoint), zap.Stringers("additional endpoints", m.additionalEndpoints))

conn, err := m.openConnect(ctx)
if err != nil {
Expand Down Expand Up @@ -985,12 +1043,28 @@ func (m *Streamer) uploadSftp(filePaths map[string]streamer.File, useSudo bool)
}

// DialCtx ssh.Dial version with context arg
func DialCtx(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
conn, err := streamer.TCPDialCtx(ctx, network, addr)
func DialCtx(ctx context.Context, endpoint Endpoint, additionalEndpoints []Endpoint, config *ssh.ClientConfig, logger *zap.Logger) (*ssh.Client, error) {
var err error
var conn net.Conn
var connectedEndpoint Endpoint
endpoints := append([]Endpoint{endpoint}, additionalEndpoints...)
for _, endpoint := range endpoints {
connectedEndpoint = endpoint
conn, err = streamer.TCPDialCtx(ctx, string(endpoint.Network), endpoint.Addr())
if err == nil {
break
}
// always continue attempts to connect in case of dial failure
logger.Debug("dial failed for endpoint", zap.String("endpoint", endpoint.String()), zap.Error(err))
}
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to dial any of given endpoints: %v, last error: %w", endpoint, err)
}
res, err := DialConnCtx(ctx, conn, connectedEndpoint.Addr(), config)
if err != nil {
return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint.String(), err)
}
return DialConnCtx(ctx, conn, addr, config)
return res, err
}

func DialConnCtx(ctx context.Context, conn net.Conn, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
Expand Down
50 changes: 6 additions & 44 deletions pkg/streamer/ssh/ssh_tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@ package ssh
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"syscall"

Expand All @@ -23,11 +20,11 @@ type Tunnel interface {
Close()
IsConnected() bool
CreateConnect(context.Context) error
StartForward(addr string) (net.Conn, error)
StartForward(network, addr string) (net.Conn, error)
}

type SSHTunnel struct {
Server *Endpoint
Server Endpoint
Config *ssh.ClientConfig
svrConn *ssh.Client
isOpen bool
Expand All @@ -36,12 +33,7 @@ type SSHTunnel struct {
mu sync.Mutex
}

func NewSSHTunnel(tunnel string, credentials credentials.Credentials, opts ...SSHTunnelOption) *SSHTunnel {
server := NewEndpoint(tunnel)
if server.Port == 0 {
server.Port = defaultPort
}

func NewSSHTunnel(server Endpoint, credentials credentials.Credentials, opts ...SSHTunnelOption) *SSHTunnel {
h := &SSHTunnel{
Server: server,
Config: nil,
Expand Down Expand Up @@ -78,7 +70,7 @@ func (m *SSHTunnel) CreateConnect(ctx context.Context) error {

m.Config = conf

serverConn, err := DialCtx(ctx, "tcp", m.Server.String(), m.Config)
serverConn, err := DialCtx(ctx, m.Server, nil, m.Config, m.logger)
if err != nil {
if !errors.Is(err, context.Canceled) {
m.logger.Error(err.Error())
Expand All @@ -91,15 +83,15 @@ func (m *SSHTunnel) CreateConnect(ctx context.Context) error {
return nil
}

func (m *SSHTunnel) StartForward(remoteAddr string) (net.Conn, error) {
func (m *SSHTunnel) StartForward(network, remoteAddr string) (net.Conn, error) {
if !m.isOpen {
return nil, errors.New("connection is closed")
}
lconn, rconn, err := m.makeSocketFromSocketPair()
if err != nil {
return nil, err
}
remoteConn, err := m.svrConn.Dial("tcp", remoteAddr)
remoteConn, err := m.svrConn.Dial(network, remoteAddr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -175,33 +167,3 @@ func (m *SSHTunnel) makeSocketFromSocketPair() (net.Conn, net.Conn, error) {

return c0, c1, nil
}

type Endpoint struct {
Host string
Port int
User string
}

func NewEndpoint(s string) *Endpoint {
endpoint := &Endpoint{
Host: s,
Port: 0,
User: "",
}

if parts := strings.Split(endpoint.Host, "@"); len(parts) > 1 {
endpoint.User = parts[0]
endpoint.Host = parts[1]
}

if parts := strings.Split(endpoint.Host, ":"); len(parts) > 1 {
endpoint.Host = parts[0]
endpoint.Port, _ = strconv.Atoi(parts[1])
}

return endpoint
}

func (endpoint *Endpoint) String() string {
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
}
14 changes: 10 additions & 4 deletions pkg/testutils/mock/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,17 @@ func NewMockSSHServer(dialog []Action, opts ...MockSSHServerOption) (*MockSSHSer
}

func (m *MockSSHServer) GetAddress() (string, int) {
address := strings.Split(m.listener.Addr().String(), ":")

portNum, _ := strconv.Atoi(address[1])
// ipv6 case
if v6EndIdx := strings.Index(m.listener.Addr().String(), "]"); v6EndIdx != -1 {
address := m.listener.Addr().String()[0 : v6EndIdx+1]
portNum, _ := strconv.Atoi(m.listener.Addr().String()[v6EndIdx+2:])
return address, portNum
}

return address[0], portNum
parts := strings.Split(m.listener.Addr().String(), ":")
address := parts[0]
portNum, _ := strconv.Atoi(parts[1])
return address, portNum
}

func (m *MockSSHServer) Run(ctx context.Context) error {
Expand Down