Skip to content
Permalink
Browse files
Merge pull request #309 from wongoo/master
support non-strict mode to decode object to map when unregistered
  • Loading branch information
LaurenceLiZhixin committed Jan 18, 2022
2 parents a80067e + f3f3a2e commit ab8e18640f757dbfb706dc17d0cf4e5c4356732b
Showing 17 changed files with 465 additions and 121 deletions.
@@ -368,6 +368,20 @@ type Dog struct {
}
```

## Strict Mode

Default, hessian2 will decode an object to map if it's not being registered.
If you don't want that, change the decoder to strict mode as following,
and it will return error when meeting unregistered object.

```go
e := hessian.NewDecoder(bytes)
e.Strict = true // set to strict mode, default is false
// or
e := hessian.NewStrictDecoder(bytes)
```

## Tools

### tools/gen-go-enum
@@ -34,8 +34,23 @@ type Decoder struct {
refs []interface{}
// record type refs, both list and map need it
typeRefs *TypeRefs
classInfoList []*classInfo
classInfoList []*ClassInfo
isSkip bool

// In strict mode, a class data can be decoded only when the class is registered, otherwise error returned.
// In non-strict mode, a class data will be decoded to a map when the class is not registered.
// The default is non-strict mode, user can change it as required.
Strict bool
}

// FindClassInfo find ClassInfo for the given name in decoder class info list.
func (d *Decoder) FindClassInfo(javaName string) *ClassInfo {
for _, info := range d.classInfoList {
if info.javaName == javaName {
return info
}
}
return nil
}

// Error part
@@ -49,6 +64,16 @@ func NewDecoder(b []byte) *Decoder {
return &Decoder{reader: bufio.NewReader(bytes.NewReader(b)), typeRefs: &TypeRefs{records: map[string]bool{}}}
}

// NewStrictDecoder generates a strict mode decoder instance.
// In strict mode, all target class must be registered.
func NewStrictDecoder(b []byte) *Decoder {
return &Decoder{
reader: bufio.NewReader(bytes.NewReader(b)),
typeRefs: &TypeRefs{records: map[string]bool{}},
Strict: true,
}
}

// NewDecoderSize generate a decoder instance.
func NewDecoderSize(b []byte, size int) *Decoder {
return &Decoder{reader: bufio.NewReaderSize(bytes.NewReader(b), size), typeRefs: &TypeRefs{records: map[string]bool{}}}
@@ -145,6 +145,16 @@ func testDecodeFrameworkFunc(t *testing.T, method string, expected func(interfac
expected(r)
}

func mustDecodeObject(t *testing.T, b []byte) interface{} {
d := NewDecoder(b)
res, err := d.Decode()
if err != nil {
t.Error(err)
t.FailNow()
}
return res
}

func TestUserDefindeException(t *testing.T) {
expect := &UnknownException{
DetailMessage: "throw UserDefindException",
@@ -33,11 +33,21 @@ import (

// Encoder struct
type Encoder struct {
classInfoList []*classInfo
classInfoList []*ClassInfo
buffer []byte
refMap map[unsafe.Pointer]_refElem
}

// classIndex find the index of the given java name in encoder class info list.
func (e *Encoder) classIndex(javaName string) int {
for i := range e.classInfoList {
if javaName == e.classInfoList[i].javaName {
return i
}
}
return -1
}

// NewEncoder generate an encoder instance
func NewEncoder() *Encoder {
buffer := make([]byte, 64)
@@ -74,7 +74,7 @@ func NewHessianCodec(reader *bufio.Reader) *HessianCodec {
}
}

// NewHessianCodec generate a new hessian codec instance
// NewHessianCodecCustom generate a new hessian codec instance.
func NewHessianCodecCustom(pkgType PackageType, reader *bufio.Reader, bodyLen int) *HessianCodec {
return &HessianCodec{
pkgType: pkgType,
@@ -129,15 +129,15 @@ func (h *HessianCodec) ReadHeader(header *DubboHeader) error {
return perrors.Errorf("serialization ID:%v", header.SerialID)
}

flag := buf[2] & FLAG_EVENT
if flag != Zero {
headerFlag := buf[2] & FLAG_EVENT
if headerFlag != Zero {
header.Type |= PackageHeartbeat
}
flag = buf[2] & FLAG_REQUEST
if flag != Zero {
headerFlag = buf[2] & FLAG_REQUEST
if headerFlag != Zero {
header.Type |= PackageRequest
flag = buf[2] & FLAG_TWOWAY
if flag != Zero {
headerFlag = buf[2] & FLAG_TWOWAY
if headerFlag != Zero {
header.Type |= PackageRequest_TwoWay
}
} else {
@@ -197,7 +197,7 @@ func (h *HessianCodec) ReadBody(rspObj interface{}) error {
case PackageRequest | PackageHeartbeat, PackageResponse | PackageHeartbeat:
case PackageRequest:
if rspObj != nil {
if err = unpackRequestBody(NewDecoder(buf[:]), rspObj); err != nil {
if err = unpackRequestBody(NewStrictDecoder(buf[:]), rspObj); err != nil {
return perrors.WithStack(err)
}
}
@@ -212,7 +212,7 @@ func (h *HessianCodec) ReadBody(rspObj interface{}) error {
return nil
}

// ignore body, but only read attachments
// ReadAttachments ignore body, but only read attachments
func (h *HessianCodec) ReadAttachments() (map[string]string, error) {
if h.reader.Buffered() < h.bodyLen {
return nil, ErrBodyNotEnough
@@ -114,39 +114,52 @@ func doTestResponse(t *testing.T, packageType PackageType, responseStatus byte,
func TestResponse(t *testing.T) {
caseObj := Case{A: "a", B: 1}
decodedResponse := &Response{}
RegisterPOJO(&caseObj)

arr := []*Case{&caseObj}
var arrRes []interface{}
decodedResponse.RspObj = &arrRes
decodedResponse.RspObj = nil
doTestResponse(t, PackageResponse, Response_OK, arr, decodedResponse, func() {
arrRes, ok := decodedResponse.RspObj.([]*Case)
if !ok {
t.Errorf("expect []*Case, but get %s", reflect.TypeOf(decodedResponse.RspObj).String())
return
}
assert.Equal(t, 1, len(arrRes))
assert.Equal(t, &caseObj, arrRes[0])
})

decodedResponse.RspObj = &Case{}
doTestResponse(t, PackageResponse, Response_OK, &Case{A: "a", B: 1}, decodedResponse, nil)
doTestResponse(t, PackageResponse, Response_OK, &caseObj, decodedResponse, func() {
assert.Equal(t, &caseObj, decodedResponse.RspObj)
})

s := "ok!!!!!"
strObj := ""
decodedResponse.RspObj = &strObj
doTestResponse(t, PackageResponse, Response_OK, s, decodedResponse, nil)
doTestResponse(t, PackageResponse, Response_OK, s, decodedResponse, func() {
assert.Equal(t, s, decodedResponse.RspObj)
})

var intObj int64
decodedResponse.RspObj = &intObj
doTestResponse(t, PackageResponse, Response_OK, int64(3), decodedResponse, nil)
doTestResponse(t, PackageResponse, Response_OK, int64(3), decodedResponse, func() {
assert.Equal(t, int64(3), decodedResponse.RspObj)
})

boolObj := false
decodedResponse.RspObj = &boolObj
doTestResponse(t, PackageResponse, Response_OK, true, decodedResponse, nil)
doTestResponse(t, PackageResponse, Response_OK, true, decodedResponse, func() {
assert.Equal(t, true, decodedResponse.RspObj)
})

strObj = ""
decodedResponse.RspObj = &strObj
doTestResponse(t, PackageResponse, Response_SERVER_ERROR, "error!!!!!", decodedResponse, nil)
errorMsg := "error!!!!!"
decodedResponse.RspObj = nil
doTestResponse(t, PackageResponse, Response_SERVER_ERROR, errorMsg, decodedResponse, func() {
assert.Equal(t, "java exception:error!!!!!", decodedResponse.Exception.Error())
})

decodedResponse.RspObj = nil
decodedResponse.Exception = nil
mapObj := map[string][]*Case{"key": {&caseObj}}
mapRes := map[interface{}]interface{}{}
decodedResponse.RspObj = &mapRes
doTestResponse(t, PackageResponse, Response_OK, mapObj, decodedResponse, func() {
mapRes, ok := decodedResponse.RspObj.(map[interface{}]interface{})
if !ok {
t.Errorf("expect map[string][]*Case, but get %s", reflect.TypeOf(decodedResponse.RspObj).String())
return
}
c, ok := mapRes["key"]
if !ok {
assert.FailNow(t, "no key in decoded response map")
@@ -211,7 +224,8 @@ func TestHessianCodec_ReadAttachments(t *testing.T) {
t.Log(h)

err = codecR1.ReadBody(body)
assert.Equal(t, "can not find go type name com.test.caseb in registry", err.Error())
assert.NoError(t, err)
// assert.Equal(t, "can not find go type name com.test.caseb in registry", err.Error())
attrs, err := codecR2.ReadAttachments()
assert.NoError(t, err)
assert.Equal(t, "2.6.4", attrs[DUBBO_VERSION_KEY])
2 int.go
@@ -110,7 +110,7 @@ func (d *Decoder) decInt32(flag int32) (int32, error) {
}
}

func (d *Encoder) encTypeInt32(b []byte, p interface{}) ([]byte, error) {
func (e *Encoder) encTypeInt32(b []byte, p interface{}) ([]byte, error) {
value := reflect.ValueOf(p)
if PackPtr(value).IsNil() {
return EncNull(b), nil
@@ -85,7 +85,7 @@ func (JavaCollectionSerializer) EncObject(e *Encoder, vv POJO) error {
return nil
}

func (JavaCollectionSerializer) DecObject(d *Decoder, typ reflect.Type, cls *classInfo) (interface{}, error) {
func (JavaCollectionSerializer) DecObject(d *Decoder, typ reflect.Type, cls *ClassInfo) (interface{}, error) {
// for the java impl of hessian encode collections as list, which will not be decoded as object in go impl, this method should not be called
return nil, perrors.New("unexpected collection decode call")
}
@@ -49,10 +49,9 @@ type JavaSqlTimeSerializer struct{}
// nolint
func (JavaSqlTimeSerializer) EncObject(e *Encoder, vv POJO) error {
var (
i int
idx int
err error
clsDef *classInfo
clsDef *ClassInfo
className string
ptrV reflect.Value
)
@@ -78,13 +77,8 @@ func (JavaSqlTimeSerializer) EncObject(e *Encoder, vv POJO) error {
}

// write object definition
idx = -1
for i = range e.classInfoList {
if v.JavaClassName() == e.classInfoList[i].javaName {
idx = i
break
}
}
idx = e.classIndex(v.JavaClassName())

if idx == -1 {
idx, ok = checkPOJORegistry(vv)
if !ok {
@@ -114,7 +108,7 @@ func (JavaSqlTimeSerializer) EncObject(e *Encoder, vv POJO) error {
}

// nolint
func (JavaSqlTimeSerializer) DecObject(d *Decoder, typ reflect.Type, cls *classInfo) (interface{}, error) {
func (JavaSqlTimeSerializer) DecObject(d *Decoder, typ reflect.Type, cls *ClassInfo) (interface{}, error) {
if typ.Kind() != reflect.Struct {
return nil, perrors.Errorf("wrong type expect Struct but get:%s", typ.String())
}
@@ -28,7 +28,7 @@ import (

var exceptionCheckMutex sync.Mutex

func checkAndGetException(cls *classInfo) (*structInfo, bool) {
func checkAndGetException(cls *ClassInfo) (*structInfo, bool) {
if len(cls.fieldNameList) < 4 {
return nil, false
}
@@ -26,7 +26,7 @@ import (
)

func TestCheckAndGetException(t *testing.T) {
clazzInfo1 := &classInfo{
clazzInfo1 := &ClassInfo{
javaName: "com.test.UserDefinedException",
fieldNameList: []string{"detailMessage", "code", "suppressedExceptions", "stackTrace", "cause"},
}
@@ -36,7 +36,7 @@ func TestCheckAndGetException(t *testing.T) {
assert.Equal(t, s.javaName, "com.test.UserDefinedException")
assert.Equal(t, s.goName, "github.com/apache/dubbo-go-hessian2/hessian.UnknownException")

clazzInfo2 := &classInfo{
clazzInfo2 := &ClassInfo{
javaName: "com.test.UserDefinedException",
fieldNameList: []string{"detailMessage", "code", "suppressedExceptions", "cause"},
}
9 map.go
@@ -20,9 +20,7 @@ package hessian
import (
"io"
"reflect"
)

import (
perrors "github.com/pkg/errors"
)

@@ -110,6 +108,13 @@ func (e *Encoder) encMap(m interface{}) error {
return nil
}

// check whether it should encode the map as class.
if mm, ok := m.(map[string]interface{}); ok {
if _, ok = mm[ClassKey]; ok {
return e.EncodeMapClass(mm)
}
}

value = UnpackPtrValue(value)
// check nil map
if value.Kind() == reflect.Ptr && !value.Elem().IsValid() {

0 comments on commit ab8e186

Please sign in to comment.