Skip to content

Commit

Permalink
Add support for cleanly shutting down a server
Browse files Browse the repository at this point in the history
This is related to gliderlabs#22 and gliderlabs#20
  • Loading branch information
belak committed Feb 6, 2017
1 parent f313115 commit 6cd286b
Showing 1 changed file with 115 additions and 4 deletions.
119 changes: 115 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
package ssh

import (
"errors"
"fmt"
"net"
"sync"
"time"

gossh "golang.org/x/crypto/ssh"
)

var (
// ErrInvalidState will be returned from some functions when asked to do
// something but the server is either already running and shouldn't be or
// vice versa.
ErrInvalidState = errors.New("Invalid server state")

// ErrDraining is returned from Serve in some cases when cleanly shutting
// down. Note that this will not always be returned if the server was asked
// to shut down.
ErrDraining = errors.New("Server was asked to shut down")
)

type serverState int

const (
stateStopped serverState = iota
stateStarted
stateDraining
)

// Server defines parameters for running an SSH server. The zero value for
// Server is a valid configuration. When both PasswordHandler and
// PublicKeyHandler are nil, no client authentication is performed.
Expand All @@ -21,6 +43,13 @@ type Server struct {
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
PermissionsCallback PermissionsCallback // optional callback for setting up permissions

// Internal fields. Note that the zero value for these should be a state we
// can detect so the Server can still be instantiated using &Server{}.
stateLock sync.Mutex
stateChan chan struct{}
state serverState
listener net.Listener
}

func (srv *Server) makeConfig() (*gossh.ServerConfig, error) {
Expand Down Expand Up @@ -72,24 +101,64 @@ func (srv *Server) makeConfig() (*gossh.ServerConfig, error) {
}

// Handle sets the Handler for the server.
func (srv *Server) Handle(fn Handler) {
func (srv *Server) Handle(fn Handler) error {
srv.stateLock.Lock()
defer srv.stateLock.Unlock()

if srv.state != stateStopped {
return ErrInvalidState
}

srv.Handler = fn

return nil
}

// Serve accepts incoming connections on the Listener l, creating a new
// connection goroutine for each. The connection goroutines read requests and then
// calls srv.Handler to handle sessions.
// calls srv.Handler to handle sessions. Note that this connection will wait
//
// Serve always returns a non-nil error.
func (srv *Server) Serve(l net.Listener) error {
defer l.Close()
// Ensure we're just starting the server and set up any values which need to
// be set up.
srv.stateLock.Lock()
if srv.state != stateStopped {
l.Close()
srv.stateLock.Unlock()
return ErrInvalidState
}
srv.state = stateStarted
srv.stateChan = make(chan struct{}, 1)
srv.listener = l
srv.stateLock.Unlock()

wg := &sync.WaitGroup{}

defer func() {
srv.stateLock.Lock()
defer srv.stateLock.Unlock()

srv.state = stateStopped
srv.stateChan = nil

// If there's still a listener around, we need to close it
if srv.listener != nil {
srv.listener.Close()
}
srv.listener = nil
}()

config, err := srv.makeConfig()
if err != nil {
return err
}
if srv.Handler == nil {
srv.Handler = DefaultHandler
}

defer wg.Wait()

var tempDelay time.Duration
for {
conn, e := l.Accept()
Expand All @@ -106,10 +175,45 @@ func (srv *Server) Serve(l net.Listener) error {
time.Sleep(tempDelay)
continue
}

return e
}
go srv.handleConn(conn, config)

// Add one to the wg and start up the connection
wg.Add(1)
go func() {
defer wg.Done()
srv.handleConn(conn, config)
}()

// If there was a message left for us on the stateChan, we're draining
// and can safely return.
_, ok := <-srv.stateChan
if ok {
return ErrDraining
}
}
}

// Drain will signal for the server to drain connections and shut down.
func (srv *Server) Drain() error {
srv.stateLock.Lock()
defer srv.stateLock.Unlock()

if srv.state != stateStarted {
return ErrInvalidState
}

// Update the state to draining, close the listener and send notify Serve
// that we're shutting down. Calling Close will force Accept to return with
// an error which should be acceptable as long as we wait for the
// connections to exit.
srv.state = stateDraining
srv.listener.Close()
srv.listener = nil
srv.stateChan <- struct{}{}

return nil
}

func (srv *Server) handleConn(conn net.Conn, conf *gossh.ServerConfig) {
Expand Down Expand Up @@ -173,5 +277,12 @@ func (srv *Server) AddHostKey(key Signer) {

// SetOption runs a functional option against the server.
func (srv *Server) SetOption(option Option) error {
srv.stateLock.Lock()
defer srv.stateLock.Unlock()

if srv.state != stateStopped {
return ErrInvalidState
}

return option(srv)
}

0 comments on commit 6cd286b

Please sign in to comment.