diff --git a/lib/go/README.md b/lib/go/README.md index 75d7174912d..8c956048832 100644 --- a/lib/go/README.md +++ b/lib/go/README.md @@ -132,3 +132,27 @@ if this interval is set to a value too low (for example, 1ms), it might cause excessive cpu overhead. This feature is also only enabled on non-oneway endpoints. + +A note about server stop implementations +==== ==================================== + +[TSimpleServer.Stop](https://pkg.go.dev/github.com/apache/thrift/lib/go/thrift#TSimpleServer.Stop) will wait for all client connections to be closed after +the last received request to be handled, as the time spent by Stop + may sometimes be too long: +* When socket timeout is not set, server might be hanged before all active + clients to finish handling the last received or to be received request. +* When the socket timeout is too long (e.g one hour), server will + hang for that duration before all active clients to finish handling the + last received request. + +To prevent Stop from hanging for too long, you can set +thrift.ServerStopTimeout in your main or init function: + + thrift.ServerStopTimeout = + +If it's set to <=0, the feature will be disabled (by default), and server +will wait for all the client connections to be closed gracefully with +zero err time. Otherwise, the stop will wait for all the client +connections to be closed gracefully util thrift.ServerStopTimeout is +reached, and client connections that are not closed after thrift.ServerStopTimeout +will be closed abruptly which may cause some client errors. \ No newline at end of file diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go index 563cbfc694a..20b91198acc 100644 --- a/lib/go/thrift/simple_server.go +++ b/lib/go/thrift/simple_server.go @@ -20,6 +20,7 @@ package thrift import ( + "context" "errors" "fmt" "io" @@ -48,15 +49,26 @@ var ErrAbandonRequest = errors.New("request abandoned") // If it's changed to <=0, the feature will be disabled. var ServerConnectivityCheckInterval = time.Millisecond * 5 +// ServerStopTimeout defines max stop wait duration used by +// server stop to avoid hanging too long to wait for all client connections to be closed gracefully. +// +// It's defined as a variable instead of constant, so that thrift server +// implementations can change its value to control the behavior. +// +// If it's set to <=0, the feature will be disabled(by default), and the server will wait for +// for all the client connections to be closed gracefully. +var ServerStopTimeout = time.Duration(0) + /* * This is not a typical TSimpleServer as it is not blocked after accept a socket. * It is more like a TThreadedServer that can handle different connections in different goroutines. * This will work if golang user implements a conn-pool like thing in client side. */ type TSimpleServer struct { - closed int32 - wg sync.WaitGroup - mu sync.Mutex + closed int32 + wg sync.WaitGroup + mu sync.Mutex + stopChan chan struct{} processorFactory TProcessorFactory serverTransport TServerTransport @@ -121,6 +133,7 @@ func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTranspor outputTransportFactory: outputTransportFactory, inputProtocolFactory: inputProtocolFactory, outputProtocolFactory: outputProtocolFactory, + stopChan: make(chan struct{}), } } @@ -192,13 +205,27 @@ func (p *TSimpleServer) innerAccept() (int32, error) { return 0, err } if client != nil { - p.wg.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + p.wg.Add(2) + go func() { defer p.wg.Done() + defer cancel() if err := p.processRequests(client); err != nil { p.logger(fmt.Sprintf("error processing request: %v", err)) } }() + + go func() { + defer p.wg.Done() + select { + case <-ctx.Done(): + // client exited, do nothing + case <-p.stopChan: + // TSimpleServer.Close called, close the client connection + client.Close() + } + }() } return 0, nil } @@ -229,12 +256,31 @@ func (p *TSimpleServer) Serve() error { func (p *TSimpleServer) Stop() error { p.mu.Lock() defer p.mu.Unlock() + if atomic.LoadInt32(&p.closed) != 0 { return nil } atomic.StoreInt32(&p.closed, 1) p.serverTransport.Interrupt() - p.wg.Wait() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer cancel() + p.wg.Wait() + }() + + if ServerStopTimeout > 0 { + timer := time.NewTimer(ServerStopTimeout) + select { + case <-timer.C: + case <-ctx.Done(): + } + close(p.stopChan) + timer.Stop() + } + + <-ctx.Done() + p.stopChan = make(chan struct{}) return nil } diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go index 58149a8e66d..b92d50f01d4 100644 --- a/lib/go/thrift/simple_server_test.go +++ b/lib/go/thrift/simple_server_test.go @@ -20,11 +20,17 @@ package thrift import ( - "testing" + "context" "errors" + "net" "runtime" + "sync" + "testing" + "time" ) +const networkWaitDuration = 10 * time.Millisecond + type mockServerTransport struct { ListenFunc func() error AcceptFunc func() (TTransport, error) @@ -154,3 +160,130 @@ func TestNoHangDuringStopFromDanglingLockAcquireDuringAcceptLoop(t *testing.T) { runtime.Gosched() serv.Stop() } + +func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) { + ln, err := net.Listen("tcp", "localhost:0") + + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + + proc := &mockProcessor{ + ProcessFunc: func(in, out TProtocol) (bool, TException) { + in.ReadMessageBegin(context.Background()) + return false, nil + }, + } + + trans := &mockServerTransport{ + ListenFunc: func() error { + return nil + }, + AcceptFunc: func() (TTransport, error) { + conn, err := ln.Accept() + if err != nil { + return nil, err + } + + return NewTSocketFromConnConf(conn, nil), nil + }, + CloseFunc: func() error { + return nil + }, + InterruptFunc: func() error { + return ln.Close() + }, + } + + serv := NewTSimpleServer2(proc, trans) + go serv.Serve() + time.Sleep(networkWaitDuration) + + netConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil || netConn == nil { + t.Fatal("error when dial server") + } + time.Sleep(networkWaitDuration) + + serverStopTimeout := 50 * time.Millisecond + backupServerStopTimeout := ServerStopTimeout + t.Cleanup(func() { + ServerStopTimeout = backupServerStopTimeout + }) + ServerStopTimeout = serverStopTimeout + + st := time.Now() + err = serv.Stop() + if err != nil { + t.Errorf("error when stop server:%v", err) + } + + if elapsed := time.Since(st); elapsed < serverStopTimeout { + t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", ServerStopTimeout, elapsed) + } +} + +func TestStopTimeoutWithSocketTimeout(t *testing.T) { + ln, err := net.Listen("tcp", "localhost:0") + + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + + proc := &mockProcessor{ + ProcessFunc: func(in, out TProtocol) (bool, TException) { + in.ReadMessageBegin(context.Background()) + return false, nil + }, + } + + conf := &TConfiguration{SocketTimeout: 5 * time.Millisecond} + wg := &sync.WaitGroup{} + trans := &mockServerTransport{ + ListenFunc: func() error { + return nil + }, + AcceptFunc: func() (TTransport, error) { + conn, err := ln.Accept() + if err != nil { + return nil, err + } + defer wg.Done() + return NewTSocketFromConnConf(conn, conf), nil + }, + CloseFunc: func() error { + return nil + }, + InterruptFunc: func() error { + return ln.Close() + }, + } + + serv := NewTSimpleServer2(proc, trans) + go serv.Serve() + time.Sleep(networkWaitDuration) + + wg.Add(1) + netConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil || netConn == nil { + t.Fatal("error when dial server") + } + wg.Wait() + + expectedStopTimeout := time.Second + backupServerStopTimeout := ServerStopTimeout + t.Cleanup(func() { + ServerStopTimeout = backupServerStopTimeout + }) + ServerStopTimeout = expectedStopTimeout + + st := time.Now() + err = serv.Stop() + if elapsed := time.Since(st); elapsed > expectedStopTimeout/2 { + t.Errorf("stop cost more time than socket timeout, socket timeout:%v,server stop timeout:%v,cost time:%v", conf.SocketTimeout, ServerStopTimeout, elapsed) + } + + if err != nil { + t.Fatalf("error when stop server:%v", err) + } +} diff --git a/test/go/src/common/clientserver_test.go b/test/go/src/common/clientserver_test.go index 609086bad81..64b326a816a 100644 --- a/test/go/src/common/clientserver_test.go +++ b/test/go/src/common/clientserver_test.go @@ -75,7 +75,7 @@ func doUnit(t *testing.T, unit *test_unit) { t.Errorf("Unable to start server: %v", err) return } - go server.AcceptLoop() + go server.Serve() defer server.Stop() client, trans, err := StartClient(unit.host, unit.port, unit.domain_socket, unit.transport, unit.protocol, unit.ssl) if err != nil {