diff --git a/proto/httproto/httproto.go b/proto/httproto/httproto.go index 15ff26d2..4b333cb7 100644 --- a/proto/httproto/httproto.go +++ b/proto/httproto/httproto.go @@ -26,12 +26,13 @@ import ( "strings" "sync" + "github.com/henrylee2cn/goutil" + "github.com/henrylee2cn/erpc/v6" "github.com/henrylee2cn/erpc/v6/codec" "github.com/henrylee2cn/erpc/v6/utils" "github.com/henrylee2cn/erpc/v6/xfer" "github.com/henrylee2cn/erpc/v6/xfer/gzip" - "github.com/henrylee2cn/goutil" ) var ( @@ -394,7 +395,7 @@ func (h *httproto) unpack(m erpc.Message, bb *utils.ByteBuffer) (size int, msg [ } m.Meta().SetBytesKV(a[0], a[1]) } - if bodySize == 0 { + if bodySize <= 0 { return size, msg, nil } bb.ChangeLen(bodySize) diff --git a/socket/message_test.go b/socket/message_test.go index ecfa85d7..d3762d91 100644 --- a/socket/message_test.go +++ b/socket/message_test.go @@ -3,6 +3,8 @@ package socket import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/henrylee2cn/erpc/v6/xfer/gzip" ) @@ -26,3 +28,19 @@ func TestMessageString(t *testing.T) { t.Logf("%%#v:%#v", m) t.Logf("%%+v:%+v", m) } + +func TestUint32Minus(t *testing.T) { + a := 1 + a, err := minus(a, 4) + assert.EqualError(t, err, "raw proto: bad package") + assert.Equal(t, int(1), a) + a, err = minus(a, 0) + assert.NoError(t, err) + assert.Equal(t, int(1), a) + a, err = minus(a, 1) + assert.NoError(t, err) + assert.Equal(t, int(0), a) + a, err = minus(a, 1) + assert.EqualError(t, err, "raw proto: bad package") + assert.Equal(t, int(0), a) +} diff --git a/socket/protocol.go b/socket/protocol.go index 33395a30..f5d6fa08 100644 --- a/socket/protocol.go +++ b/socket/protocol.go @@ -22,8 +22,9 @@ import ( "strconv" "sync" - "github.com/henrylee2cn/erpc/v6/utils" "github.com/henrylee2cn/goutil" + + "github.com/henrylee2cn/erpc/v6/utils" ) type ( @@ -229,12 +230,16 @@ func (r *rawProto) readMessage(bb *utils.ByteBuffer, m Message) error { if err != nil { return err } - lastSize := binary.BigEndian.Uint32(bb.B) - if err = m.SetSize(lastSize); err != nil { + _lastSize := binary.BigEndian.Uint32(bb.B) + if err = m.SetSize(_lastSize); err != nil { return err } - lastSize -= 4 - bb.ChangeLen(int(lastSize)) + lastSize := int(_lastSize) + lastSize, err = minus(lastSize, 4) + if err != nil { + return err + } + bb.ChangeLen(lastSize) // transfer pipe _, err = io.ReadFull(r.r, bb.B[:1]) @@ -252,14 +257,24 @@ func (r *rawProto) readMessage(bb *utils.ByteBuffer, m Message) error { return err } } - lastSize -= (1 + uint32(xferLen)) - + lastSize, err = minus(lastSize, 1+int(xferLen)) + if err != nil { + return err + } // read last all - bb.ChangeLen(int(lastSize)) + bb.ChangeLen(lastSize) _, err = io.ReadFull(r.r, bb.B) return err } +func minus(a int, b int) (int, error) { + r := a - b + if r < 0 || b < 0 { + return a, errors.New("raw proto: bad package") + } + return r, nil +} + func (r *rawProto) readHeader(data []byte, m Message) ([]byte, error) { // seq seqLen := data[0] diff --git a/utils/bytebuffer.go b/utils/bytebuffer.go index f5ff7abb..d3d8b332 100644 --- a/utils/bytebuffer.go +++ b/utils/bytebuffer.go @@ -113,7 +113,7 @@ func (b *ByteBuffer) Reset() { func (b *ByteBuffer) ChangeLen(newLen int) { if cap(b.B) < newLen { b.B = make([]byte, newLen) - } else { + } else if newLen >= 0 { b.B = b.B[:newLen] } }