grpc: Wait until resources finish cleaning up in Stop() and GracefulStop() (grpc#6489)
fho authored and arvindbr8 committed Nov 7, 2023
1 parent eaa6d08 commit 329dd98
Showing 4 changed files with 159 additions and 80 deletions.
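
In practical terms, both shutdown paths now block until the server's listeners, transports, and handler goroutines are fully cleaned up. Below is a minimal sketch, not part of this commit, of a shutdown sequence that relies on that guarantee; the grpc-go calls (grpc.NewServer, Serve, GracefulStop, Stop) are real API, while the surrounding program structure and the timeout value are illustrative assumptions.

package main

import (
	"log"
	"net"
	"time"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	srv := grpc.NewServer()
	// ... register services here ...
	go func() {
		if err := srv.Serve(lis); err != nil {
			log.Printf("Serve returned: %v", err)
		}
	}()

	// ... run until shutdown is requested (signal handling omitted) ...

	done := make(chan struct{})
	go func() {
		// With this commit, GracefulStop does not return until all pending
		// RPCs and their handler goroutines have finished.
		srv.GracefulStop()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(10 * time.Second):
		// Fall back to a hard stop: Stop aborts in-flight RPCs and, after
		// this commit, also waits for the transports to be torn down.
		srv.Stop()
		<-done
	}
}
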
27 changes: 15 additions & 12 deletions internal/transport/http2_server.go
@@ -68,15 +68,15 @@ var serverConnectionCounter uint64

// http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct {
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
done chan struct{}
conn net.Conn
loopy *loopyWriter
readerDone chan struct{} // sync point to enable testing.
writerDone chan struct{} // sync point to enable testing.
peer peer.Peer
inTapHandle tap.ServerInHandle
framer *framer
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
done chan struct{}
conn net.Conn
loopy *loopyWriter
readerDone chan struct{} // sync point to enable testing.
loopyWriterDone chan struct{}
peer peer.Peer
inTapHandle tap.ServerInHandle
framer *framer
// The max number of concurrent streams.
maxStreams uint32
// controlBuf delivers all the control related tasks (e.g., window
@@ -251,7 +251,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
peer: peer,
framer: framer,
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
loopyWriterDone: make(chan struct{}),
maxStreams: config.MaxStreams,
inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)},
@@ -323,7 +323,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
t.loopy.run()
close(t.writerDone)
close(t.loopyWriterDone)
}()
go t.keepalive()
return t, nil
@@ -608,7 +608,10 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
// typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
defer close(t.readerDone)
defer func() {
<-t.loopyWriterDone
close(t.readerDone)
}()
for {
t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame()
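The effect of the HandleStreams change above is an ordering guarantee: the reader goroutine now waits for the loopy writer to exit before closing readerDone, so anything that blocks on readerDone (such as the transport test below) knows both loops have finished. The following is a standalone sketch of that done-channel chaining, using illustrative names rather than the transport's actual types.

package main

import "fmt"

func main() {
	writerDone := make(chan struct{})
	readerDone := make(chan struct{})

	go func() { // writer loop
		defer close(writerDone)
		// ... flush frames ...
	}()

	go func() { // reader loop
		defer func() {
			<-writerDone      // wait for the writer to finish first
			close(readerDone) // only then advertise reader completion
		}()
		// ... read frames ...
	}()

	<-readerDone // both loops have finished their work at this point
	fmt.Println("transport fully drained")
}
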
2 changes: 1 addition & 1 deletion internal/transport/transport_test.go
@@ -1685,7 +1685,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
client.Close(errors.New("closed manually by test"))
st.Close(errors.New("closed manually by test"))
<-st.readerDone
<-st.writerDone
<-st.loopyWriterDone
<-client.readerDone
<-client.writerDone
for _, cstream := range clientStreams {
99 changes: 44 additions & 55 deletions server.go
@@ -996,15 +996,11 @@ func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport,
}
}()

var wg sync.WaitGroup
streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
st.HandleStreams(ctx, func(stream *transport.Stream) {
wg.Add(1)

streamQuota.acquire()
f := func() {
defer streamQuota.release()
defer wg.Done()
s.handleStream(st, stream)
}

@@ -1018,7 +1014,6 @@ func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport,
}
go f()
})
wg.Wait()
}

var _ http.Handler = (*Server)(nil)
@@ -1859,62 +1854,64 @@ func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream
// pending RPCs on the client side will get notified by connection
// errors.
func (s *Server) Stop() {
s.quit.Fire()
s.stop(false)
}

defer func() {
s.serveWG.Wait()
s.done.Fire()
}()
// GracefulStop stops the gRPC server gracefully. It stops the server from
// accepting new connections and RPCs and blocks until all the pending RPCs are
// finished.
func (s *Server) GracefulStop() {
s.stop(true)
}

func (s *Server) stop(graceful bool) {
s.quit.Fire()
defer s.done.Fire()

s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) })

s.mu.Lock()
listeners := s.lis
s.lis = nil
conns := s.conns
s.conns = nil
// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
s.cv.Broadcast()
s.closeListenersLocked()
// Wait for serving threads to be ready to exit. Only then can we be sure no
// new conns will be created.
s.mu.Unlock()
s.serveWG.Wait()

for lis := range listeners {
lis.Close()
}
for _, cs := range conns {
for st := range cs {
st.Close(errors.New("Server.Stop called"))
}
s.mu.Lock()
defer s.mu.Unlock()

if graceful {
s.drainAllServerTransportsLocked()
} else {
s.closeServerTransportsLocked()
}

if s.opts.numServerWorkers > 0 {
s.stopServerWorkers()
}

s.mu.Lock()
for len(s.conns) != 0 {
s.cv.Wait()
}
s.conns = nil

if s.events != nil {
s.events.Finish()
s.events = nil
}
s.mu.Unlock()
}

// GracefulStop stops the gRPC server gracefully. It stops the server from
// accepting new connections and RPCs and blocks until all the pending RPCs are
// finished.
func (s *Server) GracefulStop() {
s.quit.Fire()
defer s.done.Fire()

s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) })
s.mu.Lock()
if s.conns == nil {
s.mu.Unlock()
return
// s.mu must be held by the caller.
func (s *Server) closeServerTransportsLocked() {
for _, conns := range s.conns {
for st := range conns {
st.Close(errors.New("Server.Stop called"))
}
}
}

for lis := range s.lis {
lis.Close()
}
s.lis = nil
// s.mu must be held by the caller.
func (s *Server) drainAllServerTransportsLocked() {
if !s.drain {
for _, conns := range s.conns {
for st := range conns {
@@ -1923,22 +1920,14 @@ func (s *Server) GracefulStop() {
}
s.drain = true
}
}

// Wait for serving threads to be ready to exit. Only then can we be sure no
// new conns will be created.
s.mu.Unlock()
s.serveWG.Wait()
s.mu.Lock()

for len(s.conns) != 0 {
s.cv.Wait()
}
s.conns = nil
if s.events != nil {
s.events.Finish()
s.events = nil
// s.mu must be held by the caller.
func (s *Server) closeListenersLocked() {
for lis := range s.lis {
lis.Close()
}
s.mu.Unlock()
s.lis = nil
}

// contentSubtype must be lowercase
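The consolidated stop(graceful bool) path above ends by looping on s.cv.Wait() until s.conns is empty; connections deregister themselves under s.mu elsewhere in server.go and broadcast on the same condition variable when they are removed. Below is a self-contained sketch of that sync.Cond pattern, with hypothetical names standing in for the server's internals.

package main

import (
	"fmt"
	"sync"
)

type registry struct {
	mu    sync.Mutex
	cv    *sync.Cond
	conns map[int]struct{}
}

func newRegistry() *registry {
	r := &registry{conns: make(map[int]struct{})}
	r.cv = sync.NewCond(&r.mu)
	return r
}

func (r *registry) add(id int) {
	r.mu.Lock()
	r.conns[id] = struct{}{}
	r.mu.Unlock()
}

func (r *registry) remove(id int) {
	r.mu.Lock()
	delete(r.conns, id)
	r.cv.Broadcast() // wake up anyone waiting in stop()
	r.mu.Unlock()
}

func (r *registry) stop() {
	r.mu.Lock()
	for len(r.conns) != 0 {
		r.cv.Wait() // releases mu while waiting, reacquires it on wakeup
	}
	r.mu.Unlock()
	fmt.Println("all connections removed; safe to return")
}

func main() {
	r := newRegistry()
	r.add(1)
	go r.remove(1)
	r.stop()
}
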
111 changes: 99 additions & 12 deletions test/gracefulstop_test.go
@@ -24,6 +24,7 @@ import (
"net"
"sync"
"testing"
"time"

"golang.org/x/net/http2"
"google.golang.org/grpc"
@@ -88,17 +89,17 @@ func (d *delayListener) Dial(ctx context.Context) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, "tcp", d.Listener.Addr().String())
}

// TestGracefulStop ensures GracefulStop causes new connections to fail.
//
// Steps of this test:
// 1. Start Server
// 2. GracefulStop() Server after listener's Accept is called, but don't
// allow Accept() to exit when Close() is called on it.
// 3. Create a new connection to the server after listener.Close() is called.
// Server should close this connection immediately, before handshaking.
// 4. Send an RPC on the new connection. Should see Unavailable error
// because the ClientConn is in transient failure.
func (s) TestGracefulStop(t *testing.T) {
// This test ensures GracefulStop causes new connections to fail.
//
// Steps of this test:
// 1. Start Server
// 2. GracefulStop() Server after listener's Accept is called, but don't
// allow Accept() to exit when Close() is called on it.
// 3. Create a new connection to the server after listener.Close() is called.
// Server should close this connection immediately, before handshaking.
// 4. Send an RPC on the new connection. Should see Unavailable error
// because the ClientConn is in transient failure.
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error listenening: %v", err)
@@ -167,9 +168,10 @@ func (s) TestGracefulStop(t *testing.T) {
wg.Wait()
}

// TestGracefulStopClosesConnAfterLastStream ensures that a server closes the
// connections to its clients when the final stream has completed after
// a GOAWAY.
func (s) TestGracefulStopClosesConnAfterLastStream(t *testing.T) {
// This test ensures that a server closes the connections to its clients
// when the final stream has completed after a GOAWAY.

handlerCalled := make(chan struct{})
gracefulStopCalled := make(chan struct{})
@@ -216,3 +218,88 @@ func (s) TestGracefulStopClosesConnAfterLastStream(t *testing.T) {
<-gracefulStopDone // Wait for GracefulStop to return.
})
}

// TestGracefulStopBlocksUntilGRPCConnectionsTerminate ensures that
// GracefulStop() blocks until all ongoing RPCs have finished.
func (s) TestGracefulStopBlocksUntilGRPCConnectionsTerminate(t *testing.T) {
unblockGRPCCall := make(chan struct{})
grpcCallExecuting := make(chan struct{})
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
close(grpcCallExecuting)
<-unblockGRPCCall
return &testpb.SimpleResponse{}, nil
},
}

err := ss.Start(nil)
if err != nil {
t.Fatalf("StubServer.start failed: %s", err)
}
t.Cleanup(ss.Stop)

grpcClientCallReturned := make(chan struct{})
go func() {
clt := ss.Client
_, err := clt.UnaryCall(context.Background(), &testpb.SimpleRequest{})
if err != nil {
t.Errorf("rpc failed with error: %s", err)
}
close(grpcClientCallReturned)
}()

gracefulStopReturned := make(chan struct{})
<-grpcCallExecuting
go func() {
ss.S.GracefulStop()
close(gracefulStopReturned)
}()

select {
case <-gracefulStopReturned:
t.Error("GracefulStop returned before rpc method call ended")
case <-time.After(defaultTestShortTimeout):
}

unblockGRPCCall <- struct{}{}
<-grpcClientCallReturned
<-gracefulStopReturned
}

// TestStopAbortsBlockingGRPCCall ensures that when Stop() is called while an ongoing RPC
// is blocking that:
// - Stop() returns
// - and the RPC fails with a connection closed error on the client-side
func (s) TestStopAbortsBlockingGRPCCall(t *testing.T) {
unblockGRPCCall := make(chan struct{})
grpcCallExecuting := make(chan struct{})
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
close(grpcCallExecuting)
<-unblockGRPCCall
return &testpb.SimpleResponse{}, nil
},
}

err := ss.Start(nil)
if err != nil {
t.Fatalf("StubServer.start failed: %s", err)
}
t.Cleanup(ss.Stop)

grpcClientCallReturned := make(chan struct{})
go func() {
clt := ss.Client
_, err := clt.UnaryCall(context.Background(), &testpb.SimpleRequest{})
if err == nil || !isConnClosedErr(err) {
t.Errorf("expected rpc to fail with connection closed error, got: %v", err)
}
close(grpcClientCallReturned)
}()

<-grpcCallExecuting
ss.S.Stop()

unblockGRPCCall <- struct{}{}
<-grpcClientCallReturned
}
