Skip to content

Commit

Permalink
fix: enable server error handle middleware (cloudwego#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
joway committed Nov 25, 2021
1 parent 6164cb8 commit 9621edd
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 7 deletions.
19 changes: 19 additions & 0 deletions internal/mocks/serviceinfo.go
Expand Up @@ -18,6 +18,7 @@ package mocks

import (
"context"
"errors"
"fmt"

"github.com/apache/thrift/lib/go/thrift"
Expand All @@ -30,6 +31,7 @@ const (
MockServiceName = "MockService"
MockMethod string = "mock"
MockExceptionMethod string = "mockException"
MockErrorMethod string = "mockError"
MockOnewayMethod string = "mockOneway"
)

Expand All @@ -44,6 +46,7 @@ func newServiceInfo() *serviceinfo.ServiceInfo {
methods := map[string]serviceinfo.MethodInfo{
"mock": serviceinfo.NewMethodInfo(mockHandler, NewMockArgs, NewMockResult, false),
"mockException": serviceinfo.NewMethodInfo(mockExceptionHandler, NewMockArgs, newMockExceptionResult, false),
"mockError": serviceinfo.NewMethodInfo(mockErrorHandler, NewMockArgs, NewMockResult, false),
"mockOneway": serviceinfo.NewMethodInfo(mockOnewayHandler, NewMockArgs, nil, true),
}

Expand Down Expand Up @@ -91,6 +94,17 @@ func mockExceptionHandler(ctx context.Context, handler, args, result interface{}
return nil
}

func mockErrorHandler(ctx context.Context, handler, args, result interface{}) error {
a := args.(*myServiceMockArgs)
r := result.(*myServiceMockResult)
reply, err := handler.(MyService).MockError(ctx, a.Req)
if err != nil {
return err
}
r.Success = reply
return nil
}

func newMockExceptionResult() interface{} {
return &myServiceMockExceptionResult{}
}
Expand All @@ -108,6 +122,7 @@ func mockOnewayHandler(ctx context.Context, handler, args, result interface{}) e
type MyService interface {
Mock(ctx context.Context, req *MyRequest) (r *MyResponse, err error)
MockException(ctx context.Context, req *MyRequest) (r *MyResponse, err error)
MockError(ctx context.Context, req *MyRequest) (r *MyResponse, err error)
MockOneway(ctx context.Context, req *MyRequest) (err error)
}

Expand Down Expand Up @@ -192,6 +207,10 @@ func (h *myServiceHandler) MockException(ctx context.Context, req *MyRequest) (r
return &MyResponse{Name: MockExceptionMethod}, nil
}

func (h *myServiceHandler) MockError(ctx context.Context, req *MyRequest) (r *MyResponse, err error) {
return nil, errors.New(MockErrorMethod)
}

func (h *myServiceHandler) MockOneway(ctx context.Context, req *MyRequest) (err error) {
return nil
}
2 changes: 1 addition & 1 deletion pkg/remote/trans/default_server_handler.go
Expand Up @@ -264,7 +264,7 @@ func getRemoteInfo(ri rpcinfo.RPCInfo, conn net.Conn) (string, net.Addr) {
if ri == nil {
return "", rAddr
}
if rAddr.Network() == "unix" {
if rAddr != nil && rAddr.Network() == "unix" {
if ri.From().Address() != nil {
rAddr = ri.From().Address()
}
Expand Down
26 changes: 21 additions & 5 deletions server/invoke_test.go
Expand Up @@ -17,6 +17,8 @@
package server

import (
"strings"
"sync/atomic"
"testing"

"github.com/apache/thrift/lib/go/thrift"
Expand All @@ -28,9 +30,11 @@ import (
)

func TestInvokerCall(t *testing.T) {
var opts []Option
opts = append(opts, WithMetaHandler(noopMetahandler{}))
invoker := NewInvoker(opts...)
var gotErr atomic.Value
invoker := NewInvoker(WithMetaHandler(noopMetahandler{}), WithErrorHandler(func(err error) error {
gotErr.Store(err)
return err
}))

err := invoker.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
if err != nil {
Expand All @@ -40,13 +44,14 @@ func TestInvokerCall(t *testing.T) {
if err != nil {
t.Fatal(err)
}

args := mocks.NewMockArgs()
codec := utils.NewThriftMessageCodec()
b, _ := codec.Encode("mock", thrift.CALL, 0, args.(thrift.TStruct))

// call success
b, _ := codec.Encode("mock", thrift.CALL, 0, args.(thrift.TStruct))
msg := invoke.NewMessage(nil, nil)
msg.SetRequestBytes(b)

err = invoker.Call(msg)
if err != nil {
t.Fatal(err)
Expand All @@ -56,4 +61,15 @@ func TestInvokerCall(t *testing.T) {
t.Fatal(err)
}
test.Assert(t, len(b) > 0)
test.Assert(t, gotErr.Load() == nil)

// call fails
b, _ = codec.Encode("mockError", thrift.CALL, 0, args.(thrift.TStruct))
msg = invoke.NewMessage(nil, nil)
msg.SetRequestBytes(b)
err = invoker.Call(msg)
if err != nil {
t.Fatal(err)
}
test.Assert(t, strings.Contains(gotErr.Load().(error).Error(), "mockError"))
}
19 changes: 18 additions & 1 deletion server/server.go
Expand Up @@ -77,7 +77,11 @@ func (s *server) init() {
ctx := fillContext(s.opt)
s.mws = richMWsWithBuilder(ctx, s.opt.MWBs, s)
s.mws = append(s.mws, acl.NewACLMiddleware(s.opt.ACLRules))

if s.opt.ErrHandle != nil {
// errorHandleMW must be the last middleware,
// to ensure it only catches the server handler's error.
s.mws = append(s.mws, newErrorHandleMW(s.opt.ErrHandle))
}
if ds := s.opt.DebugService; ds != nil {
ds.RegisterProbeFunc(diagnosis.OptionsKey, diagnosis.WrapAsProbeFunc(s.opt.DebugInfo))
ds.RegisterProbeFunc(diagnosis.ChangeEventsKey, s.opt.Events.Dump)
Expand All @@ -100,6 +104,19 @@ func richMWsWithBuilder(ctx context.Context, mwBs []endpoint.MiddlewareBuilder,
return ks.mws
}

// newErrorHandleMW provides a hook point for server error handling.
func newErrorHandleMW(errHandle func(error) error) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request, response interface{}) error {
err := next(ctx, request, response)
if err == nil {
return nil
}
return errHandle(err)
}
}
}

func (s *server) initRPCInfoFunc() func(context.Context, net.Addr) (rpcinfo.RPCInfo, context.Context) {
return func(ctx context.Context, rAddr net.Addr) (rpcinfo.RPCInfo, context.Context) {
if ctx == nil {
Expand Down

0 comments on commit 9621edd

Please sign in to comment.