From 3db0e28e3a8cf1003b67a05f414ff803a3cbc4be Mon Sep 17 00:00:00 2001 From: Samuel Corsi-House Date: Fri, 10 May 2024 07:52:47 -0400 Subject: [PATCH] ssh/tailssh: add support for unix sockets Updates #6232 Signed-off-by: Samuel Corsi-House --- ssh/tailssh/tailssh.go | 44 +++- tailcfg/tailcfg.go | 11 +- tailcfg/tailcfg_clone.go | 2 + tailcfg/tailcfg_view.go | 4 + tempfork/gliderlabs/ssh/options_test.go | 2 +- tempfork/gliderlabs/ssh/server.go | 2 + tempfork/gliderlabs/ssh/server_test.go | 4 +- tempfork/gliderlabs/ssh/session_test.go | 19 +- tempfork/gliderlabs/ssh/ssh.go | 8 + tempfork/gliderlabs/ssh/streamlocal.go | 230 ++++++++++++++++++++ tempfork/gliderlabs/ssh/streamlocal_test.go | 205 +++++++++++++++++ tempfork/gliderlabs/ssh/tcpip.go | 69 ++++-- tempfork/gliderlabs/ssh/tcpip_test.go | 98 ++++++++- 13 files changed, 663 insertions(+), 35 deletions(-) create mode 100644 tempfork/gliderlabs/ssh/streamlocal.go create mode 100644 tempfork/gliderlabs/ssh/streamlocal_test.go diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 2bfb645f38161..00b7e7ba381cd 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -440,7 +440,8 @@ func (srv *server) newConn() (*conn, error) { c := &conn{srv: srv} now := srv.now() c.connID = fmt.Sprintf("ssh-conn-%s-%02x", now.UTC().Format("20060102T150405"), randBytes(5)) - fwdHandler := &ssh.ForwardedTCPHandler{} + fwdHandlerTCP := &ssh.ForwardedTCPHandler{} + fwdHandlerUnix := &ssh.ForwardedUnixHandler{} c.Server = &ssh.Server{ Version: "Tailscale", ServerConfigCallback: c.ServerConfig, @@ -452,6 +453,8 @@ func (srv *server) newConn() (*conn, error) { Handler: c.handleSessionPostSSHAuth, LocalPortForwardingCallback: c.mayForwardLocalPortTo, ReversePortForwardingCallback: c.mayReversePortForwardTo, + LocalUnixForwardingCallback: c.mayForwardLocalUnixTo, + ReverseUnixForwardingCallback: c.mayReverseUnixForwardTo, SubsystemHandlers: map[string]ssh.SubsystemHandler{ "sftp": c.handleSessionPostSSHAuth, }, @@ -459,11 +462,14 @@ func (srv *server) newConn() (*conn, error) { // only adds support for forwarding ports from the local machine. // TODO(maisem/bradfitz): add remote port forwarding support. ChannelHandlers: map[string]ssh.ChannelHandler{ - "direct-tcpip": ssh.DirectTCPIPHandler, + "direct-tcpip": ssh.DirectTCPIPHandler, + "direct-streamlocal@openssh.com": ssh.DirectStreamLocalHandler, }, RequestHandlers: map[string]ssh.RequestHandler{ - "tcpip-forward": fwdHandler.HandleSSHRequest, - "cancel-tcpip-forward": fwdHandler.HandleSSHRequest, + "tcpip-forward": fwdHandlerTCP.HandleSSHRequest, + "cancel-tcpip-forward": fwdHandlerTCP.HandleSSHRequest, + "streamlocal-forward@openssh.com": fwdHandlerUnix.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": fwdHandlerUnix.HandleSSHRequest, }, } ss := c.Server @@ -514,6 +520,34 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de return false } +// mayReverseUnixForwardTo reports whether the ctx should be allowed to unix forward +// to the specified host. +func (c *conn) mayReverseUnixForwardTo(ctx ssh.Context, socketPath string) bool { + if sshDisableForwarding() { + return false + } + if c.finalAction != nil && c.finalAction.AllowRemoteUnixForwarding { + metricRemoteUnixForward.Add(1) + return true + } + // TODO(Xenfo): undo + return true +} + +// mayForwardLocalUnixTo reports whether the ctx should be allowed to unix forward +// to the specified host. +func (c *conn) mayForwardLocalUnixTo(ctx ssh.Context, socketPath string) bool { + if sshDisableForwarding() { + return false + } + if c.finalAction != nil && c.finalAction.AllowLocalUnixForwarding { + metricLocalUnixForward.Add(1) + return true + } + // TODO(Xenfo): undo + return true +} + // havePubKeyPolicy reports whether any policy rule may provide access by means // of a ssh.PublicKey. func (c *conn) havePubKeyPolicy() bool { @@ -1928,6 +1962,8 @@ var ( metricSFTP = clientmetric.NewCounter("ssh_sftp_sessions") metricLocalPortForward = clientmetric.NewCounter("ssh_local_port_forward_requests") metricRemotePortForward = clientmetric.NewCounter("ssh_remote_port_forward_requests") + metricLocalUnixForward = clientmetric.NewCounter("ssh_local_unix_forward_requests") + metricRemoteUnixForward = clientmetric.NewCounter("ssh_remote_unix_forward_requests") ) // userVisibleError is a wrapper around an error that implements diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index c842b88d3876c..732c66b44676d 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -136,7 +136,8 @@ type CapabilityVersion int // - 93: 2024-05-06: added support for stateful firewalling. // - 94: 2024-05-06: Client understands Node.IsJailed. // - 95: 2024-05-06: Client uses NodeAttrUserDialUseRoutes to change DNS dialing behavior. -const CurrentCapabilityVersion CapabilityVersion = 95 +// - 96: 2023-06-08: Client understands SSHAction.AllowLocalUnixForwarding and SSHAction.AllowRemoteUnixForwarding. +const CurrentCapabilityVersion CapabilityVersion = 96 type StableID string @@ -2456,6 +2457,14 @@ type SSHAction struct { // to use remote port forwarding if requested. AllowRemotePortForwarding bool `json:"allowRemotePortForwarding,omitempty"` + // AllowLocalUnixForwarding, if true, allows accepted connections + // to use local unix forwarding if requested. + AllowLocalUnixForwarding bool `json:"allowLocalUnixForwarding,omitempty"` + + // AllowRemoteUnixForwarding, if true, allows accepted connections + // to use remote unix forwarding if requested. + AllowRemoteUnixForwarding bool `json:"allowRemoteUnixForwarding,omitempty"` + // Recorders defines the destinations of the SSH session recorders. // The recording will be uploaded to http://addr:port/record. Recorders []netip.AddrPort `json:"recorders,omitempty"` diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 823fe681000f1..9c60c8443c23a 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -517,6 +517,8 @@ var _SSHActionCloneNeedsRegeneration = SSHAction(struct { HoldAndDelegate string AllowLocalPortForwarding bool AllowRemotePortForwarding bool + AllowLocalUnixForwarding bool + AllowRemoteUnixForwarding bool Recorders []netip.AddrPort OnRecordingFailure *SSHRecorderFailureAction }{}) diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index b5e1c9e802476..8c39c5ecf6441 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -1190,6 +1190,8 @@ func (v SSHActionView) AllowAgentForwarding() bool { return v.ж.All func (v SSHActionView) HoldAndDelegate() string { return v.ж.HoldAndDelegate } func (v SSHActionView) AllowLocalPortForwarding() bool { return v.ж.AllowLocalPortForwarding } func (v SSHActionView) AllowRemotePortForwarding() bool { return v.ж.AllowRemotePortForwarding } +func (v SSHActionView) AllowLocalUnixForwarding() bool { return v.ж.AllowLocalUnixForwarding } +func (v SSHActionView) AllowRemoteUnixForwarding() bool { return v.ж.AllowRemoteUnixForwarding } func (v SSHActionView) Recorders() views.Slice[netip.AddrPort] { return views.SliceOf(v.ж.Recorders) } func (v SSHActionView) OnRecordingFailure() *SSHRecorderFailureAction { if v.ж.OnRecordingFailure == nil { @@ -1209,6 +1211,8 @@ var _SSHActionViewNeedsRegeneration = SSHAction(struct { HoldAndDelegate string AllowLocalPortForwarding bool AllowRemotePortForwarding bool + AllowLocalUnixForwarding bool + AllowRemoteUnixForwarding bool Recorders []netip.AddrPort OnRecordingFailure *SSHRecorderFailureAction }{}) diff --git a/tempfork/gliderlabs/ssh/options_test.go b/tempfork/gliderlabs/ssh/options_test.go index 7cf6f376c6a88..50ea827d8e5f7 100644 --- a/tempfork/gliderlabs/ssh/options_test.go +++ b/tempfork/gliderlabs/ssh/options_test.go @@ -51,7 +51,7 @@ func TestPasswordAuth(t *testing.T) { func TestPasswordAuthBadPass(t *testing.T) { t.Parallel() - l := newLocalListener() + l := newLocalTCPListener() srv := &Server{Handler: func(s Session) {}} srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { return false diff --git a/tempfork/gliderlabs/ssh/server.go b/tempfork/gliderlabs/ssh/server.go index 1086a72caf0e5..285651f142717 100644 --- a/tempfork/gliderlabs/ssh/server.go +++ b/tempfork/gliderlabs/ssh/server.go @@ -45,7 +45,9 @@ type Server struct { PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil + LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding (direct-streamlocal@openssh.com), denies all if nil ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil + ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding (streamlocal-forward@openssh.com), denies all if nil ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions diff --git a/tempfork/gliderlabs/ssh/server_test.go b/tempfork/gliderlabs/ssh/server_test.go index 177c071170c4e..5d29c22a5e393 100644 --- a/tempfork/gliderlabs/ssh/server_test.go +++ b/tempfork/gliderlabs/ssh/server_test.go @@ -31,7 +31,7 @@ func TestAddHostKey(t *testing.T) { } func TestServerShutdown(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() testBytes := []byte("Hello world\n") s := &Server{ Handler: func(s Session) { @@ -82,7 +82,7 @@ func TestServerShutdown(t *testing.T) { } func TestServerClose(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() s := &Server{ Handler: func(s Session) { time.Sleep(5 * time.Second) diff --git a/tempfork/gliderlabs/ssh/session_test.go b/tempfork/gliderlabs/ssh/session_test.go index a60be5ec12d4e..42fe5f4dec0e7 100644 --- a/tempfork/gliderlabs/ssh/session_test.go +++ b/tempfork/gliderlabs/ssh/session_test.go @@ -22,14 +22,25 @@ func (srv *Server) serveOnce(l net.Listener) error { return e } srv.ChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, - "direct-tcpip": DirectTCPIPHandler, + "session": DefaultSessionHandler, + "direct-tcpip": DirectTCPIPHandler, + "direct-streamlocal@openssh.com": DirectStreamLocalHandler, } + + forwardedTCPHandler := &ForwardedTCPHandler{} + forwardedUnixHandler := &ForwardedUnixHandler{} + srv.RequestHandlers = map[string]RequestHandler{ + "tcpip-forward": forwardedTCPHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest, + "streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest, + } + srv.HandleConn(conn) return nil } -func newLocalListener() net.Listener { +func newLocalTCPListener() net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { @@ -66,7 +77,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g } func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - l := newLocalListener() + l := newLocalTCPListener() go srv.serveOnce(l) return newClientSession(t, l.Addr().String(), cfg) } diff --git a/tempfork/gliderlabs/ssh/ssh.go b/tempfork/gliderlabs/ssh/ssh.go index 644cb257d9afa..bd2467b833db4 100644 --- a/tempfork/gliderlabs/ssh/ssh.go +++ b/tempfork/gliderlabs/ssh/ssh.go @@ -62,9 +62,17 @@ type ConnCallback func(ctx Context, conn net.Conn) net.Conn // LocalPortForwardingCallback is a hook for allowing port forwarding type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool +// LocalUnixForwardingCallback is a hook for allowing unix forwarding +// (direct-streamlocal@openssh.com) +type LocalUnixForwardingCallback func(ctx Context, socketPath string) bool + // ReversePortForwardingCallback is a hook for allowing reverse port forwarding type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool +// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding +// (streamlocal-forward@openssh.com). +type ReverseUnixForwardingCallback func(ctx Context, socketPath string) bool + // ServerConfigCallback is a hook for creating custom default server configs type ServerConfigCallback func(ctx Context) *gossh.ServerConfig diff --git a/tempfork/gliderlabs/ssh/streamlocal.go b/tempfork/gliderlabs/ssh/streamlocal.go new file mode 100644 index 0000000000000..be5f2cba5ee5c --- /dev/null +++ b/tempfork/gliderlabs/ssh/streamlocal.go @@ -0,0 +1,230 @@ +package ssh + +import ( + "context" + "errors" + "fmt" + "io/fs" + "net" + "os" + "path/filepath" + "sync" + "syscall" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +const ( + forwardedUnixChannelType = "forwarded-streamlocal@openssh.com" +) + +// directStreamLocalChannelData data struct as specified in OpenSSH's protocol +// extensions document, Section 2.4. +// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL?annotate=HEAD +type directStreamLocalChannelData struct { + SocketPath string + + Reserved1 string + Reserved2 uint32 +} + +// DirectStreamLocalHandler provides Unix forwarding from client -> server. It +// can be enabled by adding it to the server's ChannelHandlers under +// `direct-streamlocal@openssh.com`. +// +// Unix socket support on Windows is not widely available, so this handler may +// not work on all Windows installations and is not tested on Windows. +func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + var d directStreamLocalChannelData + err := gossh.Unmarshal(newChan.ExtraData(), &d) + if err != nil { + _ = newChan.Reject(gossh.ConnectionFailed, "error parsing direct-streamlocal data: "+err.Error()) + return + } + + if srv.LocalUnixForwardingCallback == nil || !srv.LocalUnixForwardingCallback(ctx, d.SocketPath) { + newChan.Reject(gossh.Prohibited, "unix forwarding is disabled") + return + } + + var dialer net.Dialer + dconn, err := dialer.DialContext(ctx, "unix", d.SocketPath) + if err != nil { + _ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", d.SocketPath, err.Error())) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + _ = dconn.Close() + return + } + go gossh.DiscardRequests(reqs) + + bicopy(ctx, ch, dconn) +} + +// remoteUnixForwardRequest describes the extra data sent in a +// streamlocal-forward@openssh.com containing the socket path to bind to. +type remoteUnixForwardRequest struct { + SocketPath string +} + +// remoteUnixForwardChannelData describes the data sent as the payload in the new +// channel request when a Unix connection is accepted by the listener. +type remoteUnixForwardChannelData struct { + SocketPath string + Reserved uint32 +} + +// ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and +// adding the HandleSSHRequest callback to the server's RequestHandlers under +// `streamlocal-forward@openssh.com` and +// `cancel-streamlocal-forward@openssh.com` +// +// Unix socket support on Windows is not widely available, so this handler may +// not work on all Windows installations and is not tested on Windows. +type ForwardedUnixHandler struct { + sync.Mutex + forwards map[string]net.Listener +} + +func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { + h.Lock() + if h.forwards == nil { + h.forwards = make(map[string]net.Listener) + } + h.Unlock() + conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn) + + switch req.Type { + case "streamlocal-forward@openssh.com": + var reqPayload remoteUnixForwardRequest + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + // TODO: log parse failure + return false, nil + } + + if srv.ReverseUnixForwardingCallback == nil || !srv.ReverseUnixForwardingCallback(ctx, reqPayload.SocketPath) { + return false, []byte("unix forwarding is disabled") + } + + addr := reqPayload.SocketPath + h.Lock() + _, ok := h.forwards[addr] + h.Unlock() + if ok { + // TODO: log failure + return false, nil + } + + // Create socket parent dir if not exists. + parentDir := filepath.Dir(addr) + err = os.MkdirAll(parentDir, 0700) + if err != nil { + // TODO: log mkdir failure + return false, nil + } + + // Remove existing socket if it exists. We do not use os.Remove() here + // so that directories are kept. Note that it's possible that we will + // overwrite a regular file here. Both of these behaviors match OpenSSH, + // however, which is why we unlink. + err = unlink(addr) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + // TODO: log + return false, nil + } + + ln, err := net.Listen("unix", addr) + if err != nil { + // TODO: log unix listen failure + return false, nil + } + + if err := os.Chmod(addr, os.FileMode(0777)); err != nil { + // TODO: log permission change failure + return false, nil + } + + // The listener needs to successfully start before it can be added to + // the map, so we don't have to worry about checking for an existing + // listener as you can't listen on the same socket twice. + // + // This is also what the TCP version of this code does. + h.Lock() + h.forwards[addr] = ln + h.Unlock() + + ctx, cancel := context.WithCancel(ctx) + go func() { + <-ctx.Done() + _ = ln.Close() + }() + go func() { + defer cancel() + + for { + c, err := ln.Accept() + if err != nil { + // closed below + break + } + payload := gossh.Marshal(&remoteUnixForwardChannelData{ + SocketPath: addr, + }) + + go func() { + ch, reqs, err := conn.OpenChannel(forwardedUnixChannelType, payload) + if err != nil { + _ = c.Close() + return + } + go gossh.DiscardRequests(reqs) + bicopy(ctx, ch, c) + }() + } + + h.Lock() + ln2, ok := h.forwards[addr] + if ok && ln2 == ln { + delete(h.forwards, addr) + } + h.Unlock() + _ = ln.Close() + }() + + return true, nil + + case "cancel-streamlocal-forward@openssh.com": + var reqPayload remoteUnixForwardRequest + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + // TODO: log parse failure + return false, nil + } + h.Lock() + ln, ok := h.forwards[reqPayload.SocketPath] + h.Unlock() + if ok { + _ = ln.Close() + } + return true, nil + + default: + return false, nil + } +} + +// unlink removes files and unlike os.Remove, directories are kept. +func unlink(path string) error { + // Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go + // for more details. + for { + err := syscall.Unlink(path) + if !errors.Is(err, syscall.EINTR) { + return err + } + } +} diff --git a/tempfork/gliderlabs/ssh/streamlocal_test.go b/tempfork/gliderlabs/ssh/streamlocal_test.go new file mode 100644 index 0000000000000..6e3c357a27322 --- /dev/null +++ b/tempfork/gliderlabs/ssh/streamlocal_test.go @@ -0,0 +1,205 @@ +//go:build glidertests + +package ssh + +import ( + "bytes" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strings" + "sync/atomic" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// tempDirUnixSocket returns a temporary directory that can safely hold unix +// sockets. +// +// On all platforms other than darwin this just returns t.TempDir(). On darwin +// we manually make a temporary directory in /tmp because t.TempDir() returns a +// very long directory name, and the path length limit for Unix sockets on +// darwin is 104 characters. +func tempDirUnixSocket(t *testing.T) string { + t.Helper() + if runtime.GOOS == "darwin" { + testName := strings.ReplaceAll(t.Name(), "/", "_") + dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("gliderlabs-ssh-test-%s-", testName)) + if err != nil { + t.Fatalf("create temp dir for test: %v", err) + } + + t.Cleanup(func() { + err := os.RemoveAll(dir) + if err != nil { + t.Errorf("remove temp dir %s: %v", dir, err) + } + }) + return dir + } + + return t.TempDir() +} + +func newLocalUnixListener(t *testing.T) net.Listener { + path := filepath.Join(tempDirUnixSocket(t), "socket.sock") + l, err := net.Listen("unix", path) + if err != nil { + t.Fatalf("failed to listen on a unix socket %q: %v", path, err) + } + return l +} + +func sampleUnixSocketServer(t *testing.T) net.Listener { + l := newLocalUnixListener(t) + + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + return l +} + +func newTestSessionWithUnixForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { + l := sampleUnixSocketServer(t) + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + LocalUnixForwardingCallback: func(ctx Context, socketPath string) bool { + if socketPath != l.Addr().String() { + panic("unexpected socket path: " + socketPath) + } + return forwardingEnabled + }, + }, nil) + + return l, client, func() { + cleanup() + l.Close() + } +} + +func TestLocalUnixForwardingWorks(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithUnixForwarding(t, true) + defer cleanup() + + conn, err := client.Dial("unix", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } +} + +func TestLocalUnixForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithUnixForwarding(t, false) + defer cleanup() + + _, err := client.Dial("unix", l.Addr().String()) + if err == nil { + t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) + } + if !strings.Contains(err.Error(), "unix forwarding is disabled") { + t.Fatalf("Expected permission error but got %#v", err) + } +} + +func TestReverseUnixForwardingWorks(t *testing.T) { + t.Parallel() + + remoteSocketPath := filepath.Join(tempDirUnixSocket(t), "remote.sock") + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReverseUnixForwardingCallback: func(ctx Context, socketPath string) bool { + if socketPath != remoteSocketPath { + panic("unexpected socket path: " + socketPath) + } + return true + }, + }, nil) + defer cleanup() + + l, err := client.ListenUnix(remoteSocketPath) + if err != nil { + t.Fatalf("failed to listen on a unix socket over SSH %q: %v", remoteSocketPath, err) + } + defer l.Close() + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + // Dial the listener that should've been created by the server. + conn, err := net.Dial("unix", remoteSocketPath) + if err != nil { + t.Fatalf("Error connecting to %v: %v", remoteSocketPath, err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } + + // Close the listener and make sure that the Unix socket is gone. + err = l.Close() + if err != nil { + t.Fatalf("failed to close remote listener: %v", err) + } + _, err = os.Stat(remoteSocketPath) + if err == nil && !os.IsNotExist(err) { + t.Fatalf("expected remote socket to be gone but it still exists: %v", err) + } +} + +func TestReverseUnixForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + remoteSocketPath := filepath.Join(tempDirUnixSocket(t), "remote.sock") + + var called int64 + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReverseUnixForwardingCallback: func(ctx Context, socketPath string) bool { + atomic.AddInt64(&called, 1) + if socketPath != remoteSocketPath { + panic("unexpected socket path: " + socketPath) + } + return false + }, + }, nil) + defer cleanup() + + _, err := client.ListenUnix(remoteSocketPath) + if err == nil { + t.Fatalf("Expected error listening on %q but it succeeded", remoteSocketPath) + } + + if atomic.LoadInt64(&called) != 1 { + t.Fatalf("Expected callback to be called once but it was called %d times", called) + } +} diff --git a/tempfork/gliderlabs/ssh/tcpip.go b/tempfork/gliderlabs/ssh/tcpip.go index 056a0c7343daf..8ebe467cfb135 100644 --- a/tempfork/gliderlabs/ssh/tcpip.go +++ b/tempfork/gliderlabs/ssh/tcpip.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "io" "log" "net" @@ -53,16 +54,7 @@ func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewCh } go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(ch, dconn) - }() - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(dconn, ch) - }() + bicopy(ctx, ch, dconn) } type remoteForwardRequest struct { @@ -117,8 +109,14 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go // TODO: log listen failure return false, []byte{} } + + // If the bind port was port 0, we need to use the actual port in the + // listener map. _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) destPort, _ := strconv.Atoi(destPortStr) + if reqPayload.BindPort == 0 { + addr = net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(destPort)) + } h.Lock() h.forwards[addr] = ln h.Unlock() @@ -155,16 +153,7 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go return } go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer c.Close() - io.Copy(ch, c) - }() - go func() { - defer ch.Close() - defer c.Close() - io.Copy(c, ch) - }() + bicopy(ctx, ch, c) }() } h.Lock() @@ -191,3 +180,43 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go return false, nil } } + +// bicopy copies all of the data between the two connections and will close them +// after one or both of them are done writing. If the context is canceled, both +// of the connections will be closed. +func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + defer func() { + _ = c1.Close() + _ = c2.Close() + }() + + var wg sync.WaitGroup + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer func() { + wg.Done() + // If one side of the copy fails, ensure the other one exits as + // well. + cancel() + }() + _, _ = io.Copy(dst, src) + } + + wg.Add(2) + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + // Convert waitgroup to a channel so we can also wait on the context. + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-ctx.Done(): + case <-done: + } +} diff --git a/tempfork/gliderlabs/ssh/tcpip_test.go b/tempfork/gliderlabs/ssh/tcpip_test.go index 118b5d53ac4a1..7a57f5e07a861 100644 --- a/tempfork/gliderlabs/ssh/tcpip_test.go +++ b/tempfork/gliderlabs/ssh/tcpip_test.go @@ -4,19 +4,22 @@ package ssh import ( "bytes" + "context" "io" "net" "strconv" "strings" + "sync/atomic" "testing" + "time" gossh "github.com/tailscale/golang-x-crypto/ssh" ) var sampleServerResponse = []byte("Hello world") -func sampleSocketServer() net.Listener { - l := newLocalListener() +func sampleTCPSocketServer() net.Listener { + l := newLocalTCPListener() go func() { conn, err := l.Accept() @@ -31,7 +34,7 @@ func sampleSocketServer() net.Listener { } func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { - l := sampleSocketServer() + l := sampleTCPSocketServer() _, client, cleanup := newTestSession(t, &Server{ Handler: func(s Session) {}, @@ -83,3 +86,92 @@ func TestLocalPortForwardingRespectsCallback(t *testing.T) { t.Fatalf("Expected permission error but got %#v", err) } } + +func TestReverseTCPForwardingWorks(t *testing.T) { + t.Parallel() + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReversePortForwardingCallback: func(ctx Context, bindHost string, bindPort uint32) bool { + if bindHost != "127.0.0.1" { + panic("unexpected bindHost: " + bindHost) + } + if bindPort != 0 { + panic("unexpected bindPort: " + strconv.Itoa(int(bindPort))) + } + return true + }, + }, nil) + defer cleanup() + + l, err := client.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on a random TCP port over SSH: %v", err) + } + defer l.Close() + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + // Dial the listener that should've been created by the server. + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } + + // Close the listener and make sure that the port is no longer in use. + err = l.Close() + if err != nil { + t.Fatalf("failed to close remote listener: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + var d net.Dialer + _, err = d.DialContext(ctx, "tcp", l.Addr().String()) + if err == nil { + t.Fatalf("expected error connecting to %v but it succeeded", l.Addr().String()) + } +} + +func TestReverseTCPForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + var called int64 + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReversePortForwardingCallback: func(ctx Context, bindHost string, bindPort uint32) bool { + atomic.AddInt64(&called, 1) + if bindHost != "127.0.0.1" { + panic("unexpected bindHost: " + bindHost) + } + if bindPort != 0 { + panic("unexpected bindPort: " + strconv.Itoa(int(bindPort))) + } + return false + }, + }, nil) + defer cleanup() + + _, err := client.Listen("tcp", "127.0.0.1:0") + if err == nil { + t.Fatalf("Expected error listening on random port but it succeeded") + } + + if atomic.LoadInt64(&called) != 1 { + t.Fatalf("Expected callback to be called once but it was called %d times", called) + } +}