From 0d662800d60677eee0ad08f7f863a3e3b2ad034d Mon Sep 17 00:00:00 2001 From: lbbniu Date: Fri, 12 Jan 2024 13:45:34 +0800 Subject: [PATCH] feat(tars): implementing inter-service timeout propagation using context.Context --- tars/servant.go | 14 +++-- tars/tarsprotocol.go | 99 ++++++++++++++++++++---------------- tars/transport/tcphandler.go | 2 +- tars/transport/udphandler.go | 2 +- 4 files changed, 69 insertions(+), 48 deletions(-) diff --git a/tars/servant.go b/tars/servant.go index 006acf90..8911db4a 100755 --- a/tars/servant.go +++ b/tars/servant.go @@ -14,7 +14,6 @@ import ( "github.com/TarsCloud/TarsGo/tars/protocol/res/requestf" "github.com/TarsCloud/TarsGo/tars/util/current" "github.com/TarsCloud/TarsGo/tars/util/endpoint" - "github.com/TarsCloud/TarsGo/tars/util/rtimer" "github.com/TarsCloud/TarsGo/tars/util/tools" ) @@ -159,17 +158,26 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte, msg := &Message{Req: &req, Ser: s, Resp: resp} msg.Init() - timeout := time.Duration(s.timeout) * time.Millisecond if ok, hashType, hashCode, isHash := current.GetClientHash(ctx); ok { msg.isHash = isHash msg.hashType = HashType(hashType) msg.hashCode = hashCode } + timeout := time.Duration(s.timeout) * time.Millisecond if ok, to, isTimeout := current.GetClientTimeout(ctx); ok && isTimeout { timeout = time.Duration(to) * time.Millisecond req.ITimeout = int32(to) } + // timeout delivery + if dl, ok := ctx.Deadline(); ok { + timeout = time.Until(dl) + req.ITimeout = int32(timeout / time.Millisecond) + } else { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } var err error s.manager.preInvoke() @@ -253,7 +261,7 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time. return nil } select { - case <-rtimer.After(timeout): + case <-ctx.Done(): msg.Status = basef.TARSINVOKETIMEOUT adp.failAdd() msg.End() diff --git a/tars/tarsprotocol.go b/tars/tarsprotocol.go index abb55859..76f9d8ee 100755 --- a/tars/tarsprotocol.go +++ b/tars/tarsprotocol.go @@ -44,6 +44,21 @@ func (s *Protocol) Invoke(ctx context.Context, req []byte) (rsp []byte) { is := codec.NewReader(req[4:]) reqPackage.ReadFrom(is) + recvPkgTs, ok := current.GetRecvPkgTsFromContext(ctx) + if !ok { + recvPkgTs = time.Now().UnixNano() / 1e6 + } + + // timeout delivery + now := time.Now().UnixNano() / 1e6 + if reqPackage.ITimeout > 0 { + sub := now - recvPkgTs // coroutine scheduling time difference + timeout := int64(reqPackage.ITimeout) - sub + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond) + defer cancel() + } + if reqPackage.HasMessageType(basef.TARSMESSAGETYPEDYED) { if dyeingKey, ok := reqPackage.Status[current.StatusDyedKey]; ok { if ok = current.SetDyeingKey(ctx, dyeingKey); !ok { @@ -62,10 +77,6 @@ func (s *Protocol) Invoke(ctx context.Context, req []byte) (rsp []byte) { } } - recvPkgTs, ok := current.GetRecvPkgTsFromContext(ctx) - if !ok { - recvPkgTs = time.Now().UnixNano() / 1e6 - } if reqPackage.CPacketType == basef.TARSONEWAY { defer func() { endTime := time.Now().UnixNano() / 1e6 @@ -81,52 +92,54 @@ func (s *Protocol) Invoke(ctx context.Context, req []byte) (rsp []byte) { rspPackage.IVersion = reqPackage.IVersion rspPackage.IRequestId = reqPackage.IRequestId - // Improve server timeout handling - now := time.Now().UnixNano() / 1e6 - if ok && reqPackage.ITimeout > 0 && now-recvPkgTs > int64(reqPackage.ITimeout) { + select { + case <-ctx.Done(): rspPackage.IRet = basef.TARSSERVERQUEUETIMEOUT + rspPackage.SResultDesc = "server invoke timeout" ip, _ := current.GetClientIPFromContext(ctx) port, _ := current.GetClientPortFromContext(ctx) - TLOG.Errorf("handle queue timeout, obj:%s, func:%s, recv time:%d, now:%d, timeout:%d, cost:%d, addr:(%s:%s), reqId:%d", - reqPackage.SServantName, reqPackage.SFuncName, recvPkgTs, now, reqPackage.ITimeout, now-recvPkgTs, ip, port, reqPackage.IRequestId) - } else if reqPackage.SFuncName != "tars_ping" { // not tars_ping, normal business call branch - if s.withContext { - if ok = current.SetRequestStatus(ctx, reqPackage.Status); !ok { - TLOG.Error("Set request status in context fail!") - } - if ok = current.SetRequestContext(ctx, reqPackage.Context); !ok { - TLOG.Error("Set request context in context fail!") - } - } - var err error - if s.app.allFilters.sf != nil { - err = s.app.allFilters.sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext) - } else if sf := s.app.getMiddlewareServerFilter(); sf != nil { - err = sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext) - } else { - // execute pre server filters - for i, v := range s.app.allFilters.preSfs { - err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext) - if err != nil { - TLOG.Errorf("Pre filter error, No.%v, err: %v", i, err) + TLOG.Errorf("handle queue timeout, obj:%s, func:%s, recv time:%d, now:%d, timeout:%d, cost:%d, addr:(%s:%s), reqId:%d, err: %v", + reqPackage.SServantName, reqPackage.SFuncName, recvPkgTs, now, reqPackage.ITimeout, now-recvPkgTs, ip, port, reqPackage.IRequestId, ctx.Err()) + default: + if reqPackage.SFuncName != "tars_ping" { // not tars_ping, normal business call branch + if s.withContext { + if ok = current.SetRequestStatus(ctx, reqPackage.Status); !ok { + TLOG.Error("Set request status in context fail!") + } + if ok = current.SetRequestContext(ctx, reqPackage.Context); !ok { + TLOG.Error("Set request context in context fail!") } } - // execute business server - err = s.dispatcher.Dispatch(ctx, s.serverImp, &reqPackage, &rspPackage, s.withContext) - // execute post server filters - for i, v := range s.app.allFilters.postSfs { - err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext) - if err != nil { - TLOG.Errorf("Post filter error, No.%v, err: %v", i, err) + var err error + if s.app.allFilters.sf != nil { + err = s.app.allFilters.sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext) + } else if sf := s.app.getMiddlewareServerFilter(); sf != nil { + err = sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext) + } else { + // execute pre server filters + for i, v := range s.app.allFilters.preSfs { + err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext) + if err != nil { + TLOG.Errorf("Pre filter error, No.%v, err: %v", i, err) + } + } + // execute business server + err = s.dispatcher.Dispatch(ctx, s.serverImp, &reqPackage, &rspPackage, s.withContext) + // execute post server filters + for i, v := range s.app.allFilters.postSfs { + err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext) + if err != nil { + TLOG.Errorf("Post filter error, No.%v, err: %v", i, err) + } } } - } - if err != nil { - TLOG.Errorf("RequestID:%d, Found err: %v", reqPackage.IRequestId, err) - rspPackage.IRet = 1 - rspPackage.SResultDesc = err.Error() - if tarsErr, ok := err.(*Error); ok { - rspPackage.IRet = tarsErr.Code + if err != nil { + TLOG.Errorf("RequestID:%d, Found err: %v", reqPackage.IRequestId, err) + rspPackage.IRet = 1 + rspPackage.SResultDesc = err.Error() + if tarsErr, ok := err.(*Error); ok { + rspPackage.IRet = tarsErr.Code + } } } } diff --git a/tars/transport/tcphandler.go b/tars/transport/tcphandler.go index 343e89bc..11f8f873 100755 --- a/tars/transport/tcphandler.go +++ b/tars/transport/tcphandler.go @@ -71,9 +71,9 @@ func (t *tcpHandler) getConnContext(connSt *connInfo) context.Context { func (t *tcpHandler) handleConn(connSt *connInfo, pkg []byte) { // recvPkgTs are more accurate + ctx := t.getConnContext(connSt) handler := func() { defer atomic.AddInt32(&connSt.numInvoke, -1) - ctx := t.getConnContext(connSt) rsp := t.server.invoke(ctx, pkg) cPacketType, ok := current.GetPacketTypeFromContext(ctx) diff --git a/tars/transport/udphandler.go b/tars/transport/udphandler.go index e4d5ecb0..a14ee2a6 100755 --- a/tars/transport/udphandler.go +++ b/tars/transport/udphandler.go @@ -67,10 +67,10 @@ func (u *udpHandler) Handle() error { } pkg := make([]byte, n) copy(pkg, buffer[0:n]) + ctx := u.getConnContext(udpAddr) go func() { atomic.AddInt32(&u.server.numInvoke, 1) defer atomic.AddInt32(&u.server.numInvoke, -1) - ctx := u.getConnContext(udpAddr) rsp := u.server.invoke(ctx, pkg) // no need to check package cPacketType, ok := current.GetPacketTypeFromContext(ctx)