From 490a7f7e1540b76fefc432be2857893d988f3b71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E6=A1=90?= Date: Tue, 4 Jan 2022 18:20:24 +0800 Subject: [PATCH] THRIFT-5495: close client when shutdown server in go lib Client: go --- lib/go/thrift/simple_server.go | 33 ++++++++++++++-- lib/go/thrift/simple_server_test.go | 50 ++++++++++++++++++++++++- test/go/src/common/clientserver_test.go | 2 +- 3 files changed, 79 insertions(+), 6 deletions(-) diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go index 563cbfc694a..0fc73a1c350 100644 --- a/lib/go/thrift/simple_server.go +++ b/lib/go/thrift/simple_server.go @@ -20,6 +20,7 @@ package thrift import ( + "context" "errors" "fmt" "io" @@ -47,6 +48,7 @@ var ErrAbandonRequest = errors.New("request abandoned") // // If it's changed to <=0, the feature will be disabled. var ServerConnectivityCheckInterval = time.Millisecond * 5 +var ServerCloseTimeout = time.Duration(0) /* * This is not a typical TSimpleServer as it is not blocked after accept a socket. @@ -54,9 +56,10 @@ var ServerConnectivityCheckInterval = time.Millisecond * 5 * This will work if golang user implements a conn-pool like thing in client side. */ type TSimpleServer struct { - closed int32 - wg sync.WaitGroup - mu sync.Mutex + closed int32 + wg sync.WaitGroup + mu sync.Mutex + stopChan chan struct{} processorFactory TProcessorFactory serverTransport TServerTransport @@ -121,6 +124,7 @@ func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTranspor outputTransportFactory: outputTransportFactory, inputProtocolFactory: inputProtocolFactory, outputProtocolFactory: outputProtocolFactory, + stopChan: make(chan struct{}), } } @@ -192,13 +196,27 @@ func (p *TSimpleServer) innerAccept() (int32, error) { return 0, err } if client != nil { - p.wg.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + p.wg.Add(2) + go func() { defer p.wg.Done() + defer cancel() if err := p.processRequests(client); err != nil { p.logger(fmt.Sprintf("error processing request: %v", err)) } }() + + go func() { + defer p.wg.Done() + select { + case <-ctx.Done(): + // client exited, do nothing + case <-p.stopChan: + // TSimpleServer.Close called, close the client connection + client.Close() + } + }() } return 0, nil } @@ -234,7 +252,14 @@ func (p *TSimpleServer) Stop() error { } atomic.StoreInt32(&p.closed, 1) p.serverTransport.Interrupt() + + if ServerCloseTimeout > 0 { + <-time.After(ServerCloseTimeout) + close(p.stopChan) + } + p.wg.Wait() + p.stopChan = make(chan struct{}) return nil } diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go index 58149a8e66d..cb57a9f33f6 100644 --- a/lib/go/thrift/simple_server_test.go +++ b/lib/go/thrift/simple_server_test.go @@ -20,9 +20,12 @@ package thrift import ( - "testing" + "context" "errors" + "net" "runtime" + "testing" + "time" ) type mockServerTransport struct { @@ -154,3 +157,48 @@ func TestNoHangDuringStopFromDanglingLockAcquireDuringAcceptLoop(t *testing.T) { runtime.Gosched() serv.Stop() } + +func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) { + ln, _ := net.Listen("tcp", "localhost:0") + + proc := &mockProcessor{ + ProcessFunc: func(in, out TProtocol) (bool, TException) { + in.ReadMessageBegin(context.Background()) + return false, nil + }, + } + + trans := &mockServerTransport{ + ListenFunc: func() error { + return nil + }, + AcceptFunc: func() (TTransport, error) { + conn, err := ln.Accept() + if err != nil { + return nil, err + } + return NewTSocketFromConnConf(conn, nil), nil + }, + CloseFunc: func() error { + return nil + }, + InterruptFunc: func() error { + return ln.Close() + }, + } + + serv := NewTSimpleServer2(proc, trans) + go serv.Serve() + time.Sleep(10 * time.Millisecond) + + netConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil || netConn == nil { + t.Fatal("error when dial server") + } + time.Sleep(10 * time.Millisecond) + ServerCloseTimeout = 10 * time.Millisecond + err = serv.Stop() + if err != nil { + t.Fatal("error when stop server") + } +} diff --git a/test/go/src/common/clientserver_test.go b/test/go/src/common/clientserver_test.go index 609086bad81..64b326a816a 100644 --- a/test/go/src/common/clientserver_test.go +++ b/test/go/src/common/clientserver_test.go @@ -75,7 +75,7 @@ func doUnit(t *testing.T, unit *test_unit) { t.Errorf("Unable to start server: %v", err) return } - go server.AcceptLoop() + go server.Serve() defer server.Stop() client, trans, err := StartClient(unit.host, unit.port, unit.domain_socket, unit.transport, unit.protocol, unit.ssl) if err != nil {