Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ISSIUE #75] support shutdown mehtod #97

Merged
merged 9 commits into from
Jul 9, 2019
29 changes: 22 additions & 7 deletions internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const (
// Pulling topic information interval from the named server
_PullNameServerInterval = 30 * time.Second

// Pulling topic information interval from the named server
// Sending heart beat interval to all broker
_HeartbeatBrokerInterval = 30 * time.Second

// Offset persistent interval for consumer
Expand All @@ -54,7 +54,7 @@ const (
)

var (
ErrServiceState = errors.New("service state is not running, please check")
ErrServiceState = errors.New("service close is not running, please check")

_VIPChannelEnable = false
)
Expand Down Expand Up @@ -129,6 +129,7 @@ type RMQClient struct {

remoteClient *remote.RemotingClient
hbMutex sync.Mutex
close bool
}

var clientMap sync.Map
Expand All @@ -150,6 +151,9 @@ func GetOrNewRocketMQClient(option ClientOptions) *RMQClient {
}

func (c *RMQClient) Start() {
//ctx, cancel := context.WithCancel(context.Background())
//c.cancel = cancel
c.close = false
c.once.Do(func() {
// TODO fetchNameServerAddr
go func() {}()
Expand All @@ -158,15 +162,15 @@ func (c *RMQClient) Start() {
go func() {
// delay
time.Sleep(50 * time.Millisecond)
for {
for !c.close{
c.UpdateTopicRouteInfo()
time.Sleep(_PullNameServerInterval)
}
}()

// TODO cleanOfflineBroker & sendHeartbeatToAllBrokerWithLock
go func() {
for {
for !c.close{
cleanOfflineBroker()
c.SendHeartbeatToAllBrokerWithLock()
time.Sleep(_HeartbeatBrokerInterval)
Expand All @@ -176,7 +180,7 @@ func (c *RMQClient) Start() {
// schedule persist offset
go func() {
//time.Sleep(10 * time.Second)
for {
for !c.close{
c.consumerMap.Range(func(key, value interface{}) bool {
consumer := value.(InnerConsumer)
consumer.PersistConsumerOffset()
Expand All @@ -187,7 +191,7 @@ func (c *RMQClient) Start() {
}()

go func() {
for {
for !c.close{
c.RebalanceImmediately()
time.Sleep(_RebalanceInterval)
}
Expand All @@ -196,7 +200,8 @@ func (c *RMQClient) Start() {
}

func (c *RMQClient) Shutdown() {
// TODO
c.remoteClient.ShutDown()
c.close = true
}

func (c *RMQClient) ClientID() string {
Expand All @@ -209,18 +214,28 @@ func (c *RMQClient) ClientID() string {

func (c *RMQClient) InvokeSync(addr string, request *remote.RemotingCommand,
timeoutMillis time.Duration) (*remote.RemotingCommand, error) {
if c.close {
return nil, ErrServiceState
}
return c.remoteClient.InvokeSync(addr, request, timeoutMillis)
}

func (c *RMQClient) InvokeAsync(addr string, request *remote.RemotingCommand,
timeoutMillis time.Duration, f func(*remote.RemotingCommand, error)) error {
if c.close {
return ErrServiceState
}
return c.remoteClient.InvokeAsync(addr, request, timeoutMillis, func(future *remote.ResponseFuture) {
f(future.ResponseCommand, future.Err)
})

}

func (c *RMQClient) InvokeOneWay(addr string, request *remote.RemotingCommand,
timeoutMillis time.Duration) error {
if c.close {
return ErrServiceState
}
return c.remoteClient.InvokeOneWay(addr, request, timeoutMillis)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/remote/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func NewRemotingCommand(code int16, header CustomHeader, body []byte) *RemotingC
}

func (command *RemotingCommand) String() string {
return fmt.Sprintf("Code: %d, Opaque: %d, Remark: %s, ExtFields: %v",
return fmt.Sprintf("Code: %d, opaque: %d, Remark: %s, ExtFields: %v",
command.Code, command.Opaque, command.Remark, command.ExtFields)
}

Expand Down
6 changes: 3 additions & 3 deletions internal/remote/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func Test_decode(t *testing.T) {
t.Fatalf("wrong Version. want=%d, got=%d", rc.Version, decodedRc.Version)
}
if rc.Opaque != decodedRc.Opaque {
t.Fatalf("wrong Opaque. want=%d, got=%d", rc.Opaque, decodedRc.Opaque)
t.Fatalf("wrong opaque. want=%d, got=%d", rc.Opaque, decodedRc.Opaque)
}
if rc.Remark != decodedRc.Remark {
t.Fatalf("wrong remark. want=%s, got=%s", rc.Remark, decodedRc.Remark)
Expand Down Expand Up @@ -167,7 +167,7 @@ func Test_jsonCodec_decodeHeader(t *testing.T) {
t.Fatalf("wrong Version. want=%d, got=%d", rc.Version, decodedRc.Version)
}
if rc.Opaque != decodedRc.Opaque {
t.Fatalf("wrong Opaque. want=%d, got=%d", rc.Opaque, decodedRc.Opaque)
t.Fatalf("wrong opaque. want=%d, got=%d", rc.Opaque, decodedRc.Opaque)
}
if rc.Remark != decodedRc.Remark {
t.Fatalf("wrong remark. want=%s, got=%s", rc.Remark, decodedRc.Remark)
Expand Down Expand Up @@ -237,7 +237,7 @@ func Test_rmqCodec_decodeHeader(t *testing.T) {
t.Fatalf("wrong Version. want=%d, got=%d", rc.Version, decodedRc.Version)
}
if rc.Opaque != decodedRc.Opaque {
t.Fatalf("wrong Opaque. want=%d, got=%d", rc.Opaque, decodedRc.Opaque)
t.Fatalf("wrong opaque. want=%d, got=%d", rc.Opaque, decodedRc.Opaque)
}
if rc.Remark != decodedRc.Remark {
t.Fatalf("wrong remark. want=%s, got=%s", rc.Remark, decodedRc.Remark)
Expand Down
16 changes: 8 additions & 8 deletions internal/remote/future.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,21 @@ type ResponseFuture struct {
SendRequestOK bool
Err error
Opaque int32
TimeoutMillis time.Duration
Timeout time.Duration
callback func(*ResponseFuture)
BeginTimestamp int64
BeginTimestamp time.Duration
Done chan bool
callbackOnce sync.Once
}

// NewResponseFuture create ResponseFuture with opaque, timeout and callback
func NewResponseFuture(opaque int32, timeoutMillis time.Duration, callback func(*ResponseFuture)) *ResponseFuture {
func NewResponseFuture(opaque int32, timeout time.Duration, callback func(*ResponseFuture)) *ResponseFuture {
return &ResponseFuture{
Opaque: opaque,
Done: make(chan bool),
TimeoutMillis: timeoutMillis,
Timeout: timeout,
callback: callback,
BeginTimestamp: time.Now().Unix() * 1000,
BeginTimestamp: time.Duration(time.Now().Unix()) * time.Second,
}
}

Expand All @@ -55,16 +55,16 @@ func (r *ResponseFuture) executeInvokeCallback() {
}

func (r *ResponseFuture) isTimeout() bool {
diff := time.Now().Unix()*1000 - r.BeginTimestamp
return diff > int64(r.TimeoutMillis)
elapse := time.Duration(time.Now().Unix())*time.Second - r.BeginTimestamp
return elapse > r.Timeout
}

func (r *ResponseFuture) waitResponse() (*RemotingCommand, error) {
var (
cmd *RemotingCommand
err error
)
timer := time.NewTimer(r.TimeoutMillis * time.Millisecond)
timer := time.NewTimer(r.Timeout)
for {
select {
case <-r.Done:
Expand Down
57 changes: 26 additions & 31 deletions internal/remote/remote_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
var (
//ErrRequestTimeout for request timeout error
ErrRequestTimeout = errors.New("request timeout")
connectionLocker sync.Mutex
)

type ClientRequestFunc func(*RemotingCommand) *RemotingCommand
Expand All @@ -42,10 +41,11 @@ type TcpOption struct {
}

type RemotingClient struct {
responseTable sync.Map
connectionTable sync.Map
option TcpOption
processors map[int16]ClientRequestFunc
responseTable sync.Map
connectionTable sync.Map
option TcpOption
processors map[int16]ClientRequestFunc
connectionLocker sync.Mutex
}

func NewRemotingClient() *RemotingClient {
Expand All @@ -59,29 +59,29 @@ func (c *RemotingClient) RegisterRequestFunc(code int16, f ClientRequestFunc) {
}

// TODO: merge sync and async model. sync should run on async model by blocking on chan
func (c *RemotingClient) InvokeSync(addr string, request *RemotingCommand, timeoutMillis time.Duration) (*RemotingCommand, error) {
func (c *RemotingClient) InvokeSync(addr string, request *RemotingCommand, timeout time.Duration) (*RemotingCommand, error) {
conn, err := c.connect(addr)
if err != nil {
return nil, err
}
resp := NewResponseFuture(request.Opaque, timeoutMillis, nil)
resp := NewResponseFuture(request.Opaque, timeout, nil)
c.responseTable.Store(resp.Opaque, resp)
err = c.sendRequest(conn, request)
defer c.responseTable.Delete(request.Opaque)
err = c.sendRequest(conn, request)
if err != nil {
return nil, err
}
resp.SendRequestOK = true
return resp.waitResponse()
}

// InvokeAsync send request witout blocking, just return immediately.
func (c *RemotingClient) InvokeAsync(addr string, request *RemotingCommand, timeoutMillis time.Duration, callback func(*ResponseFuture)) error {
// InvokeAsync send request without blocking, just return immediately.
func (c *RemotingClient) InvokeAsync(addr string, request *RemotingCommand, timeout time.Duration, callback func(*ResponseFuture)) error {
conn, err := c.connect(addr)
if err != nil {
return err
}
resp := NewResponseFuture(request.Opaque, timeoutMillis, callback)
resp := NewResponseFuture(request.Opaque, timeout, callback)
c.responseTable.Store(resp.Opaque, resp)
err = c.sendRequest(conn, request)
if err != nil {
Expand All @@ -107,27 +107,10 @@ func (c *RemotingClient) InvokeOneWay(addr string, request *RemotingCommand, tim
return c.sendRequest(conn, request)
}

func (c *RemotingClient) ScanResponseTable() {
rfs := make([]*ResponseFuture, 0)
c.responseTable.Range(func(key, value interface{}) bool {
if resp, ok := value.(*ResponseFuture); ok {
if (resp.BeginTimestamp + int64(resp.TimeoutMillis) + 1000) <= time.Now().Unix()*1000 {
rfs = append(rfs, resp)
c.responseTable.Delete(key)
}
}
return true
})
for _, rf := range rfs {
rf.Err = ErrRequestTimeout
rf.executeInvokeCallback()
}
}

func (c *RemotingClient) connect(addr string) (net.Conn, error) {
//it needs additional locker.
connectionLocker.Lock()
defer connectionLocker.Unlock()
c.connectionLocker.Lock()
defer c.connectionLocker.Unlock()
conn, ok := c.connectionTable.Load(addr)
if ok {
return conn.(net.Conn), nil
Expand Down Expand Up @@ -181,7 +164,7 @@ func (c *RemotingClient) receiveResponse(r net.Conn) {
}
}
if scanner.Err() != nil {
rlog.Errorf("net: %s scanner exit, err: %s.", r.RemoteAddr().String(), scanner.Err())
rlog.Errorf("net: %s scanner exit, Err: %s.", r.RemoteAddr().String(), scanner.Err())
} else {
rlog.Infof("net: %s scanner exit.", r.RemoteAddr().String())
}
Expand Down Expand Up @@ -237,3 +220,15 @@ func (c *RemotingClient) closeConnection(toCloseConn net.Conn) {
}
})
}

func (c *RemotingClient) ShutDown() {
c.responseTable.Range(func(key, value interface{}) bool {
c.responseTable.Delete(key)
return true
})
c.connectionTable.Range(func(key, value interface{}) bool {
conn := value.(net.Conn)
conn.Close()
return true
})
}
Loading