Skip to content

Commit

Permalink
Fix: decode net stream bytes as getty rule (#1820)
Browse files Browse the repository at this point in the history
* return decode value as getty standard

* fix fatal error: do not handle nil package

* fix unit test

Co-authored-by: jason <lvs.pjx@gmail.com>
Co-authored-by: Laurence <45508533+LaurenceLiZhixin@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 15, 2022
1 parent 65d3f38 commit 605764e
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 51 deletions.
49 changes: 32 additions & 17 deletions protocol/dubbo/dubbo_codec.go
Expand Up @@ -34,7 +34,7 @@ import (
"dubbo.apache.org/dubbo-go/v3/common/logger"
"dubbo.apache.org/dubbo-go/v3/protocol"
"dubbo.apache.org/dubbo-go/v3/protocol/dubbo/impl"
"dubbo.apache.org/dubbo-go/v3/protocol/invocation"
invct "dubbo.apache.org/dubbo-go/v3/protocol/invocation"
"dubbo.apache.org/dubbo-go/v3/remoting"
)

Expand Down Expand Up @@ -160,20 +160,30 @@ func (c *DubboCodec) EncodeResponse(response *remoting.Response) (*bytes.Buffer,
}

// Decode data, including request and response.
func (c *DubboCodec) Decode(data []byte) (remoting.DecodeResult, int, error) {
func (c *DubboCodec) Decode(data []byte) (*remoting.DecodeResult, int, error) {
dataLen := len(data)
if dataLen < impl.HEADER_LENGTH { // check whether header bytes is enough or not
return nil, 0, nil
}
if c.isRequest(data) {
req, len, err := c.decodeRequest(data)
req, length, err := c.decodeRequest(data)
if err != nil {
return remoting.DecodeResult{}, len, perrors.WithStack(err)
return nil, length, perrors.WithStack(err)
}
if req == ((*remoting.Request)(nil)) {
return nil, length, err
}
return remoting.DecodeResult{IsRequest: true, Result: req}, len, perrors.WithStack(err)
return &remoting.DecodeResult{IsRequest: true, Result: req}, length, perrors.WithStack(err)
}

resp, len, err := c.decodeResponse(data)
rsp, length, err := c.decodeResponse(data)
if err != nil {
return remoting.DecodeResult{}, len, perrors.WithStack(err)
return nil, length, perrors.WithStack(err)
}
if rsp == ((*remoting.Response)(nil)) {
return nil, length, err
}
return remoting.DecodeResult{IsRequest: false, Result: resp}, len, perrors.WithStack(err)
return &remoting.DecodeResult{IsRequest: false, Result: rsp}, length, perrors.WithStack(err)
}

func (c *DubboCodec) isRequest(data []byte) bool {
Expand All @@ -182,16 +192,18 @@ func (c *DubboCodec) isRequest(data []byte) bool {

// decode request
func (c *DubboCodec) decodeRequest(data []byte) (*remoting.Request, int, error) {
var request *remoting.Request = nil
var request *remoting.Request
buf := bytes.NewBuffer(data)
pkg := impl.NewDubboPackage(buf)
pkg.SetBody(make([]interface{}, 7))
err := pkg.Unmarshal()
if err != nil {
originErr := perrors.Cause(err)
if originErr == hessian.ErrHeaderNotEnough || originErr == hessian.ErrBodyNotEnough {
// FIXME
return nil, 0, originErr
if originErr == hessian.ErrHeaderNotEnough { // this is impossible, as dubbo_codec.go:DubboCodec::Decode() line 167
return nil, 0, nil
}
if originErr == hessian.ErrBodyNotEnough {
return nil, hessian.HEADER_LENGTH + pkg.GetBodyLen(), nil
}
logger.Errorf("pkg.Unmarshal(len(@data):%d) = error:%+v", buf.Len(), err)

Expand Down Expand Up @@ -223,8 +235,8 @@ func (c *DubboCodec) decodeRequest(data []byte) (*remoting.Request, int, error)
methodName = pkg.Service.Method
args = req[impl.ArgsKey].([]interface{})
attachments = req[impl.AttachmentsKey].(map[string]interface{})
invoc := invocation.NewRPCInvocationWithOptions(invocation.WithAttachments(attachments),
invocation.WithArguments(args), invocation.WithMethodName(methodName))
invoc := invct.NewRPCInvocationWithOptions(invct.WithAttachments(attachments),
invct.WithArguments(args), invct.WithMethodName(methodName))
request.Data = invoc

}
Expand All @@ -239,11 +251,14 @@ func (c *DubboCodec) decodeResponse(data []byte) (*remoting.Response, int, error
if err != nil {
originErr := perrors.Cause(err)
// if the data is very big, so the receive need much times.
if originErr == hessian.ErrHeaderNotEnough || originErr == hessian.ErrBodyNotEnough {
return nil, 0, originErr
if originErr == hessian.ErrHeaderNotEnough { // this is impossible, as dubbo_codec.go:DubboCodec::Decode() line 167
return nil, 0, nil
}
if originErr == hessian.ErrBodyNotEnough {
return nil, hessian.HEADER_LENGTH + pkg.GetBodyLen(), nil
}
logger.Errorf("pkg.Unmarshal(len(@data):%d) = error:%+v", buf.Len(), err)

logger.Warnf("pkg.Unmarshal(len(@data):%d) = error:%+v", buf.Len(), err)
return nil, 0, perrors.WithStack(err)
}
response := &remoting.Response{
Expand Down
3 changes: 3 additions & 0 deletions protocol/dubbo/impl/codec.go
Expand Up @@ -157,6 +157,9 @@ func (c *ProtocolCodec) Decode(p *DubboPackage) error {
return err
}
}
if c.reader.Size() < p.GetBodyLen()+HEADER_LENGTH {
return hessian.ErrBodyNotEnough
}
body, err := c.reader.Peek(p.GetBodyLen())
if err != nil {
return err
Expand Down
2 changes: 2 additions & 0 deletions protocol/dubbo/impl/const.go
Expand Up @@ -223,6 +223,8 @@ var (

// Error part
var (
// the following errors has already existed in github.com/apache/dubbo-go-hessian2.
// Please do not use them.
ErrHeaderNotEnough = errors.New("header buffer too short")
ErrBodyNotEnough = errors.New("body buffer too short")
ErrJavaException = errors.New("got java exception")
Expand Down
2 changes: 1 addition & 1 deletion remoting/codec.go
Expand Up @@ -26,7 +26,7 @@ import (
type Codec interface {
EncodeRequest(request *Request) (*bytes.Buffer, error)
EncodeResponse(response *Response) (*bytes.Buffer, error)
Decode(data []byte) (DecodeResult, int, error)
Decode(data []byte) (*DecodeResult, int, error)
}

type DecodeResult struct {
Expand Down
10 changes: 5 additions & 5 deletions remoting/getty/dubbo_codec_for_test.go
Expand Up @@ -155,19 +155,19 @@ func (c *DubboTestCodec) EncodeResponse(response *remoting.Response) (*bytes.Buf
}

// Decode data, including request and response.
func (c *DubboTestCodec) Decode(data []byte) (remoting.DecodeResult, int, error) {
func (c *DubboTestCodec) Decode(data []byte) (*remoting.DecodeResult, int, error) {
if c.isRequest(data) {
req, len, err := c.decodeRequest(data)
if err != nil {
return remoting.DecodeResult{}, len, perrors.WithStack(err)
return &remoting.DecodeResult{}, len, perrors.WithStack(err)
}
return remoting.DecodeResult{IsRequest: true, Result: req}, len, perrors.WithStack(err)
return &remoting.DecodeResult{IsRequest: true, Result: req}, len, perrors.WithStack(err)
} else {
resp, len, err := c.decodeResponse(data)
if err != nil {
return remoting.DecodeResult{}, len, perrors.WithStack(err)
return &remoting.DecodeResult{}, len, perrors.WithStack(err)
}
return remoting.DecodeResult{IsRequest: false, Result: resp}, len, perrors.WithStack(err)
return &remoting.DecodeResult{IsRequest: false, Result: resp}, len, perrors.WithStack(err)
}
}

Expand Down
8 changes: 4 additions & 4 deletions remoting/getty/listener.go
Expand Up @@ -94,8 +94,8 @@ func (h *RpcClientHandler) OnClose(session getty.Session) {

// OnMessage get response from getty server, and update the session to the getty client session list
func (h *RpcClientHandler) OnMessage(session getty.Session, pkg interface{}) {
result, ok := pkg.(remoting.DecodeResult)
if !ok {
result, ok := pkg.(*remoting.DecodeResult)
if !ok || result == ((*remoting.DecodeResult)(nil)) {
logger.Errorf("[RpcClientHandler.OnMessage] getty client gets an unexpected rpc result: %#v", result)
return
}
Expand Down Expand Up @@ -232,8 +232,8 @@ func (h *RpcServerHandler) OnMessage(session getty.Session, pkg interface{}) {
}
h.rwlock.Unlock()

decodeResult, drOK := pkg.(remoting.DecodeResult)
if !drOK {
decodeResult, drOK := pkg.(*remoting.DecodeResult)
if !drOK || decodeResult == ((*remoting.DecodeResult)(nil)) {
logger.Errorf("illegal package{%#v}", pkg)
return
}
Expand Down
37 changes: 15 additions & 22 deletions remoting/getty/readwriter.go
Expand Up @@ -18,15 +18,12 @@
package getty

import (
"errors"
"reflect"
)

import (
"github.com/apache/dubbo-getty"

hessian "github.com/apache/dubbo-go-hessian2"

perrors "github.com/pkg/errors"
)

Expand All @@ -48,19 +45,17 @@ func NewRpcClientPackageHandler(client *Client) *RpcClientPackageHandler {
// Read data from server. if the package size from server is larger than 4096 byte, server will read 4096 byte
// and send to client each time. the Read can assemble it.
func (p *RpcClientPackageHandler) Read(ss getty.Session, data []byte) (interface{}, int, error) {
resp, length, err := (p.client.codec).Decode(data)
// err := pkg.Unmarshal(buf, p.client)
rsp, length, err := (p.client.codec).Decode(data)
if err != nil {
if errors.Is(err, hessian.ErrHeaderNotEnough) || errors.Is(err, hessian.ErrBodyNotEnough) {
return nil, 0, nil
}

logger.Errorf("pkg.Unmarshal(ss:%+v, len(@data):%d) = error:%+v", ss, len(data), err)

err = perrors.WithStack(err)
}
if rsp == ((*remoting.DecodeResult)(nil)) {
return nil, length, err
}

return resp, length, nil
if rsp.Result == ((*remoting.Response)(nil)) || rsp.Result == ((*remoting.Request)(nil)) {
return nil, length, err
}
return rsp, length, err
}

// Write send the data to server
Expand Down Expand Up @@ -102,17 +97,15 @@ func NewRpcServerPackageHandler(server *Server) *RpcServerPackageHandler {
// and send to client each time. the Read can assemble it.
func (p *RpcServerPackageHandler) Read(ss getty.Session, data []byte) (interface{}, int, error) {
req, length, err := (p.server.codec).Decode(data)
// resp,len, err := (*p.).DecodeResponse(buf)
if err != nil {
if errors.Is(err, hessian.ErrHeaderNotEnough) || errors.Is(err, hessian.ErrBodyNotEnough) {
return nil, 0, nil
}

logger.Errorf("pkg.Unmarshal(ss:%+v, len(@data):%d) = error:%+v", ss, len(data), err)

return nil, 0, err
err = perrors.WithStack(err)
}
if req == ((*remoting.DecodeResult)(nil)) {
return nil, length, err
}
if req.Result == ((*remoting.Request)(nil)) || req.Result == ((*remoting.Response)(nil)) {
return nil, length, err // as getty rule
}

return req, length, err
}

Expand Down
4 changes: 2 additions & 2 deletions remoting/getty/readwriter_test.go
Expand Up @@ -75,8 +75,8 @@ func testDecodeTCPPackage(t *testing.T, svr *Server, client *Client) {
assert.True(t, incompletePkgLen >= impl.HEADER_LENGTH, "header buffer too short")
incompletePkg := pkgBytes[0 : incompletePkgLen-1]
pkg, pkgLen, err := pkgReadHandler.Read(nil, incompletePkg)
assert.NoError(t, err)
assert.Equal(t, pkg, nil)
assert.Equal(t, err.Error(), "body buffer too short")
assert.Equal(t, pkg.(*remoting.DecodeResult).Result, nil)
assert.Equal(t, pkgLen, 0)
}

Expand Down

0 comments on commit 605764e

Please sign in to comment.