Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing Inter-Service Timeout Propagation using context.Context #508

Merged
merged 1 commit into from Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 11 additions & 3 deletions tars/servant.go
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/TarsCloud/TarsGo/tars/protocol/res/requestf"
"github.com/TarsCloud/TarsGo/tars/util/current"
"github.com/TarsCloud/TarsGo/tars/util/endpoint"
"github.com/TarsCloud/TarsGo/tars/util/rtimer"
"github.com/TarsCloud/TarsGo/tars/util/tools"
)

Expand Down Expand Up @@ -159,17 +158,26 @@ func (s *ServantProxy) TarsInvoke(ctx context.Context, cType byte,
msg := &Message{Req: &req, Ser: s, Resp: resp}
msg.Init()

timeout := time.Duration(s.timeout) * time.Millisecond
if ok, hashType, hashCode, isHash := current.GetClientHash(ctx); ok {
msg.isHash = isHash
msg.hashType = HashType(hashType)
msg.hashCode = hashCode
}

timeout := time.Duration(s.timeout) * time.Millisecond
if ok, to, isTimeout := current.GetClientTimeout(ctx); ok && isTimeout {
timeout = time.Duration(to) * time.Millisecond
req.ITimeout = int32(to)
}
// timeout delivery
if dl, ok := ctx.Deadline(); ok {
timeout = time.Until(dl)
req.ITimeout = int32(timeout / time.Millisecond)
} else {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}

var err error
s.manager.preInvoke()
Expand Down Expand Up @@ -253,7 +261,7 @@ func (s *ServantProxy) doInvoke(ctx context.Context, msg *Message, timeout time.
return nil
}
select {
case <-rtimer.After(timeout):
case <-ctx.Done():
msg.Status = basef.TARSINVOKETIMEOUT
adp.failAdd()
msg.End()
Expand Down
99 changes: 56 additions & 43 deletions tars/tarsprotocol.go
Expand Up @@ -44,6 +44,21 @@ func (s *Protocol) Invoke(ctx context.Context, req []byte) (rsp []byte) {
is := codec.NewReader(req[4:])
reqPackage.ReadFrom(is)

recvPkgTs, ok := current.GetRecvPkgTsFromContext(ctx)
if !ok {
recvPkgTs = time.Now().UnixNano() / 1e6
}

// timeout delivery
now := time.Now().UnixNano() / 1e6
if reqPackage.ITimeout > 0 {
sub := now - recvPkgTs // coroutine scheduling time difference
timeout := int64(reqPackage.ITimeout) - sub
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond)
defer cancel()
}

if reqPackage.HasMessageType(basef.TARSMESSAGETYPEDYED) {
if dyeingKey, ok := reqPackage.Status[current.StatusDyedKey]; ok {
if ok = current.SetDyeingKey(ctx, dyeingKey); !ok {
Expand All @@ -62,10 +77,6 @@ func (s *Protocol) Invoke(ctx context.Context, req []byte) (rsp []byte) {
}
}

recvPkgTs, ok := current.GetRecvPkgTsFromContext(ctx)
if !ok {
recvPkgTs = time.Now().UnixNano() / 1e6
}
if reqPackage.CPacketType == basef.TARSONEWAY {
defer func() {
endTime := time.Now().UnixNano() / 1e6
Expand All @@ -81,52 +92,54 @@ func (s *Protocol) Invoke(ctx context.Context, req []byte) (rsp []byte) {
rspPackage.IVersion = reqPackage.IVersion
rspPackage.IRequestId = reqPackage.IRequestId

// Improve server timeout handling
now := time.Now().UnixNano() / 1e6
if ok && reqPackage.ITimeout > 0 && now-recvPkgTs > int64(reqPackage.ITimeout) {
select {
case <-ctx.Done():
rspPackage.IRet = basef.TARSSERVERQUEUETIMEOUT
rspPackage.SResultDesc = "server invoke timeout"
ip, _ := current.GetClientIPFromContext(ctx)
port, _ := current.GetClientPortFromContext(ctx)
TLOG.Errorf("handle queue timeout, obj:%s, func:%s, recv time:%d, now:%d, timeout:%d, cost:%d, addr:(%s:%s), reqId:%d",
reqPackage.SServantName, reqPackage.SFuncName, recvPkgTs, now, reqPackage.ITimeout, now-recvPkgTs, ip, port, reqPackage.IRequestId)
} else if reqPackage.SFuncName != "tars_ping" { // not tars_ping, normal business call branch
if s.withContext {
if ok = current.SetRequestStatus(ctx, reqPackage.Status); !ok {
TLOG.Error("Set request status in context fail!")
}
if ok = current.SetRequestContext(ctx, reqPackage.Context); !ok {
TLOG.Error("Set request context in context fail!")
}
}
var err error
if s.app.allFilters.sf != nil {
err = s.app.allFilters.sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
} else if sf := s.app.getMiddlewareServerFilter(); sf != nil {
err = sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
} else {
// execute pre server filters
for i, v := range s.app.allFilters.preSfs {
err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
if err != nil {
TLOG.Errorf("Pre filter error, No.%v, err: %v", i, err)
TLOG.Errorf("handle queue timeout, obj:%s, func:%s, recv time:%d, now:%d, timeout:%d, cost:%d, addr:(%s:%s), reqId:%d, err: %v",
reqPackage.SServantName, reqPackage.SFuncName, recvPkgTs, now, reqPackage.ITimeout, now-recvPkgTs, ip, port, reqPackage.IRequestId, ctx.Err())
default:
if reqPackage.SFuncName != "tars_ping" { // not tars_ping, normal business call branch
if s.withContext {
if ok = current.SetRequestStatus(ctx, reqPackage.Status); !ok {
TLOG.Error("Set request status in context fail!")
}
if ok = current.SetRequestContext(ctx, reqPackage.Context); !ok {
TLOG.Error("Set request context in context fail!")
}
}
// execute business server
err = s.dispatcher.Dispatch(ctx, s.serverImp, &reqPackage, &rspPackage, s.withContext)
// execute post server filters
for i, v := range s.app.allFilters.postSfs {
err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
if err != nil {
TLOG.Errorf("Post filter error, No.%v, err: %v", i, err)
var err error
if s.app.allFilters.sf != nil {
err = s.app.allFilters.sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
} else if sf := s.app.getMiddlewareServerFilter(); sf != nil {
err = sf(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
} else {
// execute pre server filters
for i, v := range s.app.allFilters.preSfs {
err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
if err != nil {
TLOG.Errorf("Pre filter error, No.%v, err: %v", i, err)
}
}
// execute business server
err = s.dispatcher.Dispatch(ctx, s.serverImp, &reqPackage, &rspPackage, s.withContext)
// execute post server filters
for i, v := range s.app.allFilters.postSfs {
err = v(ctx, s.dispatcher.Dispatch, s.serverImp, &reqPackage, &rspPackage, s.withContext)
if err != nil {
TLOG.Errorf("Post filter error, No.%v, err: %v", i, err)
}
}
}
}
if err != nil {
TLOG.Errorf("RequestID:%d, Found err: %v", reqPackage.IRequestId, err)
rspPackage.IRet = 1
rspPackage.SResultDesc = err.Error()
if tarsErr, ok := err.(*Error); ok {
rspPackage.IRet = tarsErr.Code
if err != nil {
TLOG.Errorf("RequestID:%d, Found err: %v", reqPackage.IRequestId, err)
rspPackage.IRet = 1
rspPackage.SResultDesc = err.Error()
if tarsErr, ok := err.(*Error); ok {
rspPackage.IRet = tarsErr.Code
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tars/transport/tcphandler.go
Expand Up @@ -71,9 +71,9 @@ func (t *tcpHandler) getConnContext(connSt *connInfo) context.Context {

func (t *tcpHandler) handleConn(connSt *connInfo, pkg []byte) {
// recvPkgTs are more accurate
ctx := t.getConnContext(connSt)
handler := func() {
defer atomic.AddInt32(&connSt.numInvoke, -1)
ctx := t.getConnContext(connSt)
rsp := t.server.invoke(ctx, pkg)

cPacketType, ok := current.GetPacketTypeFromContext(ctx)
Expand Down
2 changes: 1 addition & 1 deletion tars/transport/udphandler.go
Expand Up @@ -67,10 +67,10 @@ func (u *udpHandler) Handle() error {
}
pkg := make([]byte, n)
copy(pkg, buffer[0:n])
ctx := u.getConnContext(udpAddr)
go func() {
atomic.AddInt32(&u.server.numInvoke, 1)
defer atomic.AddInt32(&u.server.numInvoke, -1)
ctx := u.getConnContext(udpAddr)
rsp := u.server.invoke(ctx, pkg) // no need to check package

cPacketType, ok := current.GetPacketTypeFromContext(ctx)
Expand Down