From b8ca1c0e69c7be6f93fa4a31d4f0f2f1cc2d19ec Mon Sep 17 00:00:00 2001 From: DavidCai Date: Mon, 12 Dec 2016 13:26:47 +0800 Subject: [PATCH] optimize decode --- decode.go | 28 +++++++++++++++------------- decode_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ hmac.go | 6 ++---- rsa.go | 18 ++++++++++-------- sign_test.go | 28 +++++++++++++++++++++++++++- verify.go | 10 +++++----- 6 files changed, 108 insertions(+), 31 deletions(-) create mode 100644 decode_test.go diff --git a/decode.go b/decode.go index 8a74975..4a15a84 100644 --- a/decode.go +++ b/decode.go @@ -1,37 +1,39 @@ package jwt import ( + "bytes" "encoding/base64" "encoding/json" - "strings" ) -func decode(token string) (dh map[string]interface{}, dp map[string]interface{}, err error) { - splited := strings.Split(token, ".") +func decode(token []byte) (header map[string]interface{}, payload map[string]interface{}, err error) { + segments := bytes.Split(token, periodBytes) - if len(splited) != 3 { + if len(segments) != 3 { return nil, nil, ErrInvalidToken } - h, err := base64.StdEncoding.DecodeString(splited[0]) - - if err != nil { + if header, err = decodeSegment(segments[0]); err != nil { return nil, nil, err } - if err := json.Unmarshal(h, dh); err != nil { + if payload, err = decodeSegment(segments[1]); err != nil { return nil, nil, err } - p, err := base64.StdEncoding.DecodeString(splited[1]) + return header, payload, nil +} + +func decodeSegment(segment []byte) (m map[string]interface{}, err error) { + s, err := base64.StdEncoding.DecodeString(string(segment)) if err != nil { - return nil, nil, err + return nil, err } - if err := json.Unmarshal(p, dp); err != nil { - return nil, nil, err + if err := json.Unmarshal(s, &m); err != nil { + return nil, err } - return dh, dp, nil + return } diff --git a/decode_test.go b/decode_test.go new file mode 100644 index 0000000..432211c --- /dev/null +++ b/decode_test.go @@ -0,0 +1,49 @@ +package jwt + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDecode(t *testing.T) { + assert := assert.New(t) + + t.Run("Should return ErrInvalidToken when token is invalid", func(t *testing.T) { + _, _, err := decode([]byte("a.b")) + + assert.Equal(ErrInvalidToken, err) + }) + + t.Run("Should return origin header and payload", func(t *testing.T) { + custom := map[string]interface{}{ + "test1k": "test1v", + "test2k": float64(234), + } + + opt := &SignOption{ + Algorithm: HS256, + Issuer: "testIssuer", + Subject: "tsetSubject", + Audience: "testAudience", + ExpiresIn: time.Minute, + } + + signed, err := Sign(custom, "key", opt) + + header, payload, err := decode(signed) + + assert.Nil(err) + assert.Equal(2, len(header)) + assert.Equal(6, len(payload)) + assert.Equal(string(HS256), header["alg"]) + assert.Equal("JWT", header["typ"]) + assert.Equal(opt.Subject, payload["sub"]) + assert.Equal(opt.Issuer, payload["iss"]) + assert.Equal(opt.Audience, payload["aud"]) + assert.Equal(float64(60), payload["exp"]) + assert.Equal(custom["test1k"], payload["test1k"]) + assert.Equal(custom["test2k"], payload["test2k"]) + }) +} diff --git a/hmac.go b/hmac.go index 0ab7293..f27be7c 100644 --- a/hmac.go +++ b/hmac.go @@ -30,13 +30,11 @@ func (ha hmacAlgImp) sign(content []byte, secret interface{}) ([]byte, error) { h := hmac.New(ha.hashFunc, s) - if _, err := h.Write(content); err != nil { - return nil, err - } + h.Write(content) return h.Sum(nil), nil } -func (ha hmacAlgImp) verify(signing []byte, key interface{}) error { +func (ha hmacAlgImp) verify(signing []byte, secret interface{}) error { return nil } diff --git a/rsa.go b/rsa.go index ba65f4f..ceaed97 100644 --- a/rsa.go +++ b/rsa.go @@ -4,28 +4,30 @@ import ( "crypto" "crypto/rand" "crypto/rsa" - "hash" ) func init() { - algImpMap[RS256] = rsaAlgImp{ch: crypto.SHA256} - algImpMap[RS384] = rsaAlgImp{ch: crypto.SHA384} - algImpMap[RS512] = rsaAlgImp{ch: crypto.SHA512} + algImpMap[RS256] = rsaAlgImp{hash: crypto.SHA256} + algImpMap[RS384] = rsaAlgImp{hash: crypto.SHA384} + algImpMap[RS512] = rsaAlgImp{hash: crypto.SHA512} } type rsaAlgImp struct { - ch crypto.Hash - hh hash.Hash + hash crypto.Hash } func (ra rsaAlgImp) sign(content []byte, privateKey interface{}) ([]byte, error) { - pk, ok := privateKey.(*rsa.PrivateKey) + key, ok := privateKey.(*rsa.PrivateKey) if !ok { return nil, ErrInvalidKeyType } - return rsa.SignPKCS1v15(rand.Reader, pk, ra.ch, ra.hh.Sum(content)) + h := ra.hash.New() + + h.Write(content) + + return rsa.SignPKCS1v15(rand.Reader, key, ra.hash, h.Sum(nil)) } func (ra rsaAlgImp) verify(signing []byte, key interface{}) error { diff --git a/sign_test.go b/sign_test.go index fce1524..74ea2ed 100644 --- a/sign_test.go +++ b/sign_test.go @@ -2,6 +2,8 @@ package jwt import ( "bytes" + "crypto/rand" + "crypto/rsa" "encoding/json" "testing" "time" @@ -124,7 +126,7 @@ func TestSign(t *testing.T) { assert.Equal(ErrEmptySecretOrPrivateKey, err) }) - t.Run("Should return with three parts", func(t *testing.T) { + t.Run("Should return with three parts and using HMAC", func(t *testing.T) { custom := map[string]interface{}{ "test1k": "test1v", "test2k": float64(234), @@ -143,4 +145,28 @@ func TestSign(t *testing.T) { assert.Nil(err) assert.Equal(3, len(bytes.Split(signed, periodBytes))) }) + + t.Run("Should return with three parts and using RSA", func(t *testing.T) { + custom := map[string]interface{}{ + "test1k": "test1v", + "test2k": float64(234), + } + + opt := &SignOption{ + Algorithm: RS256, + Issuer: "testIssuer", + Subject: "tsetSubject", + Audience: "testAudience", + ExpiresIn: time.Minute, + } + + key, err := rsa.GenerateKey(rand.Reader, 1024) + + assert.Nil(err) + + signed, err := Sign(custom, key, opt) + + assert.Nil(err) + assert.Equal(3, len(bytes.Split(signed, periodBytes))) + }) } diff --git a/verify.go b/verify.go index e1ada57..d154beb 100644 --- a/verify.go +++ b/verify.go @@ -15,7 +15,7 @@ type VerifyOption struct { // Verify decodes the given token and check whether the token is valid. func Verify(token []byte, secretOrPrivateKey interface{}, opt *VerifyOption) (map[string]interface{}, map[string]interface{}, error) { - header, payload, err := decode(string(token)) + header, payload, err := decode(token) if err != nil { return nil, nil, err @@ -28,6 +28,10 @@ func Verify(token []byte, secretOrPrivateKey interface{}, opt *VerifyOption) (ma algImp algorithmImplementation ) + if err := algImp.verify(token, secretOrPrivateKey); err != nil { + return nil, nil, ErrInvalidSignature + } + if typ, ok = header["typ"]; !ok { return nil, nil, ErrInvalidHeaderType } @@ -44,9 +48,5 @@ func Verify(token []byte, secretOrPrivateKey interface{}, opt *VerifyOption) (ma return nil, nil, ErrInvalidAlgorithm } - if err := algImp.verify(token, secretOrPrivateKey); err != nil { - return nil, nil, err - } - return header, payload, nil }