From 329dd986e62e2987e58ce5547f52f69f01a2c636 Mon Sep 17 00:00:00 2001 From: Fabian Holler Date: Tue, 31 Oct 2023 18:12:43 +0100 Subject: [PATCH] grpc: Wait until resources finish cleaning up in Stop() and GracefulStop() (#6489) --- internal/transport/http2_server.go | 27 ++++--- internal/transport/transport_test.go | 2 +- server.go | 99 +++++++++++------------- test/gracefulstop_test.go | 111 ++++++++++++++++++++++++--- 4 files changed, 159 insertions(+), 80 deletions(-) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 8bc506875a2f..680c9eba0b17 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -68,15 +68,15 @@ var serverConnectionCounter uint64 // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { - lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. - done chan struct{} - conn net.Conn - loopy *loopyWriter - readerDone chan struct{} // sync point to enable testing. - writerDone chan struct{} // sync point to enable testing. - peer peer.Peer - inTapHandle tap.ServerInHandle - framer *framer + lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. + done chan struct{} + conn net.Conn + loopy *loopyWriter + readerDone chan struct{} // sync point to enable testing. + loopyWriterDone chan struct{} + peer peer.Peer + inTapHandle tap.ServerInHandle + framer *framer // The max number of concurrent streams. 
maxStreams uint32 // controlBuf delivers all the control related tasks (e.g., window @@ -251,7 +251,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, peer: peer, framer: framer, readerDone: make(chan struct{}), - writerDone: make(chan struct{}), + loopyWriterDone: make(chan struct{}), maxStreams: config.MaxStreams, inTapHandle: config.InTapHandle, fc: &trInFlow{limit: uint32(icwz)}, @@ -323,7 +323,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger) t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler t.loopy.run() - close(t.writerDone) + close(t.loopyWriterDone) }() go t.keepalive() return t, nil @@ -608,7 +608,10 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade // typically run in a separate goroutine. // traceCtx attaches trace to ctx and returns the new context. func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) { - defer close(t.readerDone) + defer func() { + <-t.loopyWriterDone + close(t.readerDone) + }() for { t.controlBuf.throttle() frame, err := t.framer.fr.ReadFrame() diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 20292bd9f0f8..21aff27db1df 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -1685,7 +1685,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) client.Close(errors.New("closed manually by test")) st.Close(errors.New("closed manually by test")) <-st.readerDone - <-st.writerDone + <-st.loopyWriterDone <-client.readerDone <-client.writerDone for _, cstream := range clientStreams { diff --git a/server.go b/server.go index 547a0cd7f343..6ac97a6b0d49 100644 --- a/server.go +++ b/server.go @@ -996,15 +996,11 @@ func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, } }() - var wg 
sync.WaitGroup streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams) st.HandleStreams(ctx, func(stream *transport.Stream) { - wg.Add(1) - streamQuota.acquire() f := func() { defer streamQuota.release() - defer wg.Done() s.handleStream(st, stream) } @@ -1018,7 +1014,6 @@ func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, } go f() }) - wg.Wait() } var _ http.Handler = (*Server)(nil) @@ -1859,62 +1854,64 @@ func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream // pending RPCs on the client side will get notified by connection // errors. func (s *Server) Stop() { - s.quit.Fire() + s.stop(false) +} - defer func() { - s.serveWG.Wait() - s.done.Fire() - }() +// GracefulStop stops the gRPC server gracefully. It stops the server from +// accepting new connections and RPCs and blocks until all the pending RPCs are +// finished. +func (s *Server) GracefulStop() { + s.stop(true) +} + +func (s *Server) stop(graceful bool) { + s.quit.Fire() + defer s.done.Fire() s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) }) s.mu.Lock() - listeners := s.lis - s.lis = nil - conns := s.conns - s.conns = nil - // interrupt GracefulStop if Stop and GracefulStop are called concurrently. - s.cv.Broadcast() + s.closeListenersLocked() + // Wait for serving threads to be ready to exit. Only then can we be sure no + // new conns will be created. s.mu.Unlock() + s.serveWG.Wait() - for lis := range listeners { - lis.Close() - } - for _, cs := range conns { - for st := range cs { - st.Close(errors.New("Server.Stop called")) - } + s.mu.Lock() + defer s.mu.Unlock() + + if graceful { + s.drainAllServerTransportsLocked() + } else { + s.closeServerTransportsLocked() } + if s.opts.numServerWorkers > 0 { s.stopServerWorkers() } - s.mu.Lock() + for len(s.conns) != 0 { + s.cv.Wait() + } + s.conns = nil + if s.events != nil { s.events.Finish() s.events = nil } - s.mu.Unlock() } -// GracefulStop stops the gRPC server gracefully. 
It stops the server from -// accepting new connections and RPCs and blocks until all the pending RPCs are -// finished. -func (s *Server) GracefulStop() { - s.quit.Fire() - defer s.done.Fire() - - s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) }) - s.mu.Lock() - if s.conns == nil { - s.mu.Unlock() - return +// s.mu must be held by the caller. +func (s *Server) closeServerTransportsLocked() { + for _, conns := range s.conns { + for st := range conns { + st.Close(errors.New("Server.Stop called")) + } } +} - for lis := range s.lis { - lis.Close() - } - s.lis = nil +// s.mu must be held by the caller. +func (s *Server) drainAllServerTransportsLocked() { if !s.drain { for _, conns := range s.conns { for st := range conns { @@ -1923,22 +1920,14 @@ func (s *Server) GracefulStop() { } s.drain = true } +} - // Wait for serving threads to be ready to exit. Only then can we be sure no - // new conns will be created. - s.mu.Unlock() - s.serveWG.Wait() - s.mu.Lock() - - for len(s.conns) != 0 { - s.cv.Wait() - } - s.conns = nil - if s.events != nil { - s.events.Finish() - s.events = nil +// s.mu must be held by the caller. +func (s *Server) closeListenersLocked() { + for lis := range s.lis { + lis.Close() } - s.mu.Unlock() + s.lis = nil } // contentSubtype must be lowercase diff --git a/test/gracefulstop_test.go b/test/gracefulstop_test.go index f0697e7e328b..ecf07d984359 100644 --- a/test/gracefulstop_test.go +++ b/test/gracefulstop_test.go @@ -24,6 +24,7 @@ import ( "net" "sync" "testing" + "time" "golang.org/x/net/http2" "google.golang.org/grpc" @@ -88,17 +89,17 @@ func (d *delayListener) Dial(ctx context.Context) (net.Conn, error) { return (&net.Dialer{}).DialContext(ctx, "tcp", d.Listener.Addr().String()) } +// TestGracefulStop ensures GracefulStop causes new connections to fail. +// +// Steps of this test: +// 1. Start Server +// 2. 
GracefulStop() Server after listener's Accept is called, but don't +// allow Accept() to exit when Close() is called on it. +// 3. Create a new connection to the server after listener.Close() is called. +// Server should close this connection immediately, before handshaking. +// 4. Send an RPC on the new connection. Should see Unavailable error +// because the ClientConn is in transient failure. func (s) TestGracefulStop(t *testing.T) { - // This test ensures GracefulStop causes new connections to fail. - // - // Steps of this test: - // 1. Start Server - // 2. GracefulStop() Server after listener's Accept is called, but don't - // allow Accept() to exit when Close() is called on it. - // 3. Create a new connection to the server after listener.Close() is called. - // Server should close this connection immediately, before handshaking. - // 4. Send an RPC on the new connection. Should see Unavailable error - // because the ClientConn is in transient failure. lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Error listenening: %v", err) @@ -167,9 +168,10 @@ func (s) TestGracefulStop(t *testing.T) { wg.Wait() } +// TestGracefulStopClosesConnAfterLastStream ensures that a server closes the +// connections to its clients when the final stream has completed after +// a GOAWAY. func (s) TestGracefulStopClosesConnAfterLastStream(t *testing.T) { - // This test ensures that a server closes the connections to its clients - // when the final stream has completed after a GOAWAY. handlerCalled := make(chan struct{}) gracefulStopCalled := make(chan struct{}) @@ -216,3 +218,88 @@ func (s) TestGracefulStopClosesConnAfterLastStream(t *testing.T) { <-gracefulStopDone // Wait for GracefulStop to return. }) } + +// TestGracefulStopBlocksUntilGRPCConnectionsTerminate ensures that +// GracefulStop() blocks until all ongoing RPCs have finished.
+func (s) TestGracefulStopBlocksUntilGRPCConnectionsTerminate(t *testing.T) { + unblockGRPCCall := make(chan struct{}) + grpcCallExecuting := make(chan struct{}) + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + close(grpcCallExecuting) + <-unblockGRPCCall + return &testpb.SimpleResponse{}, nil + }, + } + + err := ss.Start(nil) + if err != nil { + t.Fatalf("StubServer.start failed: %s", err) + } + t.Cleanup(ss.Stop) + + grpcClientCallReturned := make(chan struct{}) + go func() { + clt := ss.Client + _, err := clt.UnaryCall(context.Background(), &testpb.SimpleRequest{}) + if err != nil { + t.Errorf("rpc failed with error: %s", err) + } + close(grpcClientCallReturned) + }() + + gracefulStopReturned := make(chan struct{}) + <-grpcCallExecuting + go func() { + ss.S.GracefulStop() + close(gracefulStopReturned) + }() + + select { + case <-gracefulStopReturned: + t.Error("GracefulStop returned before rpc method call ended") + case <-time.After(defaultTestShortTimeout): + } + + unblockGRPCCall <- struct{}{} + <-grpcClientCallReturned + <-gracefulStopReturned +} + +// TestStopAbortsBlockingGRPCCall ensures that when Stop() is called while an ongoing RPC +// is blocking that: +// - Stop() returns +// - and the RPC fails with a connection closed error on the client-side +func (s) TestStopAbortsBlockingGRPCCall(t *testing.T) { + unblockGRPCCall := make(chan struct{}) + grpcCallExecuting := make(chan struct{}) + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + close(grpcCallExecuting) + <-unblockGRPCCall + return &testpb.SimpleResponse{}, nil + }, + } + + err := ss.Start(nil) + if err != nil { + t.Fatalf("StubServer.start failed: %s", err) + } + t.Cleanup(ss.Stop) + + grpcClientCallReturned := make(chan struct{}) + go func() { + clt := ss.Client + _, err := clt.UnaryCall(context.Background(),
&testpb.SimpleRequest{}) + if err == nil || !isConnClosedErr(err) { + t.Errorf("expected rpc to fail with connection closed error, got: %v", err) + } + close(grpcClientCallReturned) + }() + + <-grpcCallExecuting + ss.S.Stop() + + unblockGRPCCall <- struct{}{} + <-grpcClientCallReturned +}