Skip to content

Commit

Permalink
THRIFT-5495: close client when shutdown server in go lib
Browse files Browse the repository at this point in the history
Client: go
  • Loading branch information
buptubuntu committed Jan 23, 2022
1 parent 39d7278 commit 490a7f7
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 6 deletions.
33 changes: 29 additions & 4 deletions lib/go/thrift/simple_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package thrift

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -47,16 +48,18 @@ var ErrAbandonRequest = errors.New("request abandoned")
//
// If it's changed to <=0, the feature will be disabled.
var ServerConnectivityCheckInterval = time.Millisecond * 5
var ServerCloseTimeout = time.Duration(0)

/*
* This is not a typical TSimpleServer as it is not blocked after accept a socket.
* It is more like a TThreadedServer that can handle different connections in different goroutines.
* This will work if golang user implements a conn-pool like thing in client side.
*/
type TSimpleServer struct {
closed int32
wg sync.WaitGroup
mu sync.Mutex
closed int32
wg sync.WaitGroup
mu sync.Mutex
stopChan chan struct{}

processorFactory TProcessorFactory
serverTransport TServerTransport
Expand Down Expand Up @@ -121,6 +124,7 @@ func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTranspor
outputTransportFactory: outputTransportFactory,
inputProtocolFactory: inputProtocolFactory,
outputProtocolFactory: outputProtocolFactory,
stopChan: make(chan struct{}),
}
}

Expand Down Expand Up @@ -192,13 +196,27 @@ func (p *TSimpleServer) innerAccept() (int32, error) {
return 0, err
}
if client != nil {
p.wg.Add(1)
ctx, cancel := context.WithCancel(context.Background())
p.wg.Add(2)

go func() {
defer p.wg.Done()
defer cancel()
if err := p.processRequests(client); err != nil {
p.logger(fmt.Sprintf("error processing request: %v", err))
}
}()

go func() {
defer p.wg.Done()
select {
case <-ctx.Done():
// client exited, do nothing
case <-p.stopChan:
// TSimpleServer.Close called, close the client connection
client.Close()
}
}()
}
return 0, nil
}
Expand Down Expand Up @@ -234,7 +252,14 @@ func (p *TSimpleServer) Stop() error {
}
atomic.StoreInt32(&p.closed, 1)
p.serverTransport.Interrupt()

if ServerCloseTimeout > 0 {
<-time.After(ServerCloseTimeout)
close(p.stopChan)
}

p.wg.Wait()
p.stopChan = make(chan struct{})
return nil
}

Expand Down
50 changes: 49 additions & 1 deletion lib/go/thrift/simple_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
package thrift

import (
"testing"
"context"
"errors"
"net"
"runtime"
"testing"
"time"
)

type mockServerTransport struct {
Expand Down Expand Up @@ -154,3 +157,48 @@ func TestNoHangDuringStopFromDanglingLockAcquireDuringAcceptLoop(t *testing.T) {
runtime.Gosched()
serv.Stop()
}

func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) {
ln, _ := net.Listen("tcp", "localhost:0")

proc := &mockProcessor{
ProcessFunc: func(in, out TProtocol) (bool, TException) {
in.ReadMessageBegin(context.Background())
return false, nil
},
}

trans := &mockServerTransport{
ListenFunc: func() error {
return nil
},
AcceptFunc: func() (TTransport, error) {
conn, err := ln.Accept()
if err != nil {
return nil, err
}
return NewTSocketFromConnConf(conn, nil), nil
},
CloseFunc: func() error {
return nil
},
InterruptFunc: func() error {
return ln.Close()
},
}

serv := NewTSimpleServer2(proc, trans)
go serv.Serve()
time.Sleep(10 * time.Millisecond)

netConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil || netConn == nil {
t.Fatal("error when dial server")
}
time.Sleep(10 * time.Millisecond)
ServerCloseTimeout = 10 * time.Millisecond
err = serv.Stop()
if err != nil {
t.Fatal("error when stop server")
}
}
2 changes: 1 addition & 1 deletion test/go/src/common/clientserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func doUnit(t *testing.T, unit *test_unit) {
t.Errorf("Unable to start server: %v", err)
return
}
go server.AcceptLoop()
go server.Serve()
defer server.Stop()
client, trans, err := StartClient(unit.host, unit.port, unit.domain_socket, unit.transport, unit.protocol, unit.ssl)
if err != nil {
Expand Down

0 comments on commit 490a7f7

Please sign in to comment.