diff --git a/server.go b/server.go index 0697b8d..50c365d 100644 --- a/server.go +++ b/server.go @@ -1,13 +1,35 @@ package ssh import ( + "errors" "fmt" "net" + "sync" "time" gossh "golang.org/x/crypto/ssh" ) +var ( + // ErrInvalidState will be returned from some functions when asked to do + // something but the server is either already running and shouldn't be or + // vice versa. + ErrInvalidState = errors.New("Invalid server state") + + // ErrDraining is returned from Serve in some cases when cleanly shutting + // down. Note that this will not always be returned if the server was asked + // to shut down. + ErrDraining = errors.New("Server was asked to shut down") +) + +type serverState int + +const ( + stateStopped serverState = iota + stateStarted + stateDraining +) + // Server defines parameters for running an SSH server. The zero value for // Server is a valid configuration. When both PasswordHandler and // PublicKeyHandler are nil, no client authentication is performed. @@ -21,6 +43,13 @@ type Server struct { PublicKeyHandler PublicKeyHandler // public key authentication handler PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil PermissionsCallback PermissionsCallback // optional callback for setting up permissions + + // Internal fields. Note that the zero value for these should be a state we + // can detect so the Server can still be instantiated using &Server{}. + stateLock sync.Mutex + stateChan chan struct{} + state serverState + listener net.Listener } func (srv *Server) makeConfig() (*gossh.ServerConfig, error) { @@ -72,17 +101,54 @@ func (srv *Server) makeConfig() (*gossh.ServerConfig, error) { } // Handle sets the Handler for the server. -func (srv *Server) Handle(fn Handler) { +func (srv *Server) Handle(fn Handler) error { + srv.stateLock.Lock() + defer srv.stateLock.Unlock() + + if srv.state != stateStopped { + return ErrInvalidState + } + srv.Handler = fn + + return nil } // Serve accepts incoming connections on the Listener l, creating a new // connection goroutine for each. The connection goroutines read requests and then -// calls srv.Handler to handle sessions. +// calls srv.Handler to handle sessions. Note that this connection will wait // // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { - defer l.Close() + // Ensure we're just starting the server and set up any values which need to + // be set up. + srv.stateLock.Lock() + if srv.state != stateStopped { + l.Close() + srv.stateLock.Unlock() + return ErrInvalidState + } + srv.state = stateStarted + srv.stateChan = make(chan struct{}, 1) + srv.listener = l + srv.stateLock.Unlock() + + wg := &sync.WaitGroup{} + + defer func() { + srv.stateLock.Lock() + defer srv.stateLock.Unlock() + + srv.state = stateStopped + srv.stateChan = nil + + // If there's still a listener around, we need to close it + if srv.listener != nil { + srv.listener.Close() + } + srv.listener = nil + }() + config, err := srv.makeConfig() if err != nil { return err @@ -90,6 +156,9 @@ func (srv *Server) Serve(l net.Listener) error { if srv.Handler == nil { srv.Handler = DefaultHandler } + + defer wg.Wait() + var tempDelay time.Duration for { conn, e := l.Accept() @@ -106,10 +175,45 @@ func (srv *Server) Serve(l net.Listener) error { time.Sleep(tempDelay) continue } + return e } - go srv.handleConn(conn, config) + + // Add one to the wg and start up the connection + wg.Add(1) + go func() { + defer wg.Done() + srv.handleConn(conn, config) + }() + + // If there was a message left for us on the stateChan, we're draining + // and can safely return. + _, ok := <-srv.stateChan + if ok { + return ErrDraining + } + } +} + +// Drain will signal for the server to drain connections and shut down. +func (srv *Server) Drain() error { + srv.stateLock.Lock() + defer srv.stateLock.Unlock() + + if srv.state != stateStarted { + return ErrInvalidState } + + // Update the state to draining, close the listener and send notify Serve + // that we're shutting down. Calling Close will force Accept to return with + // an error which should be acceptable as long as we wait for the + // connections to exit. + srv.state = stateDraining + srv.listener.Close() + srv.listener = nil + srv.stateChan <- struct{}{} + + return nil } func (srv *Server) handleConn(conn net.Conn, conf *gossh.ServerConfig) { @@ -151,6 +255,7 @@ func (srv *Server) newSession(conn *gossh.ServerConn, ch gossh.Channel) *session // Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. // ListenAndServe always returns a non-nil error. func (srv *Server) ListenAndServe() error { + // NOTE: There's a complication here addr := srv.Addr if addr == "" { addr = ":22" @@ -173,5 +278,12 @@ func (srv *Server) AddHostKey(key Signer) { // SetOption runs a functional option against the server. func (srv *Server) SetOption(option Option) error { + srv.stateLock.Lock() + defer srv.stateLock.Unlock() + + if srv.state != stateStopped { + return ErrInvalidState + } + return option(srv) }