Skip to content

Commit

Permalink
optimize decode
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidCai1111 committed Dec 12, 2016
1 parent 9729749 commit b8ca1c0
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 31 deletions.
28 changes: 15 additions & 13 deletions decode.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,39 @@
package jwt

import (
"bytes"
"encoding/base64"
"encoding/json"
"strings"
)

func decode(token string) (dh map[string]interface{}, dp map[string]interface{}, err error) {
splited := strings.Split(token, ".")
func decode(token []byte) (header map[string]interface{}, payload map[string]interface{}, err error) {
segments := bytes.Split(token, periodBytes)

if len(splited) != 3 {
if len(segments) != 3 {
return nil, nil, ErrInvalidToken
}

h, err := base64.StdEncoding.DecodeString(splited[0])

if err != nil {
if header, err = decodeSegment(segments[0]); err != nil {
return nil, nil, err
}

if err := json.Unmarshal(h, dh); err != nil {
if payload, err = decodeSegment(segments[1]); err != nil {
return nil, nil, err
}

p, err := base64.StdEncoding.DecodeString(splited[1])
return header, payload, nil
}

func decodeSegment(segment []byte) (m map[string]interface{}, err error) {
s, err := base64.StdEncoding.DecodeString(string(segment))

if err != nil {
return nil, nil, err
return nil, err
}

if err := json.Unmarshal(p, dp); err != nil {
return nil, nil, err
if err := json.Unmarshal(s, &m); err != nil {
return nil, err
}

return dh, dp, nil
return
}
49 changes: 49 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package jwt

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestDecode(t *testing.T) {
assert := assert.New(t)

t.Run("Should return ErrInvalidToken when token is invalid", func(t *testing.T) {
_, _, err := decode([]byte("a.b"))

assert.Equal(ErrInvalidToken, err)
})

t.Run("Should return origin header and payload", func(t *testing.T) {
custom := map[string]interface{}{
"test1k": "test1v",
"test2k": float64(234),
}

opt := &SignOption{
Algorithm: HS256,
Issuer: "testIssuer",
Subject: "tsetSubject",
Audience: "testAudience",
ExpiresIn: time.Minute,
}

signed, err := Sign(custom, "key", opt)

header, payload, err := decode(signed)

assert.Nil(err)
assert.Equal(2, len(header))
assert.Equal(6, len(payload))
assert.Equal(string(HS256), header["alg"])
assert.Equal("JWT", header["typ"])
assert.Equal(opt.Subject, payload["sub"])
assert.Equal(opt.Issuer, payload["iss"])
assert.Equal(opt.Audience, payload["aud"])
assert.Equal(float64(60), payload["exp"])
assert.Equal(custom["test1k"], payload["test1k"])
assert.Equal(custom["test2k"], payload["test2k"])
})
}
6 changes: 2 additions & 4 deletions hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ func (ha hmacAlgImp) sign(content []byte, secret interface{}) ([]byte, error) {

h := hmac.New(ha.hashFunc, s)

if _, err := h.Write(content); err != nil {
return nil, err
}
h.Write(content)

return h.Sum(nil), nil
}

func (ha hmacAlgImp) verify(signing []byte, key interface{}) error {
func (ha hmacAlgImp) verify(signing []byte, secret interface{}) error {
return nil
}
18 changes: 10 additions & 8 deletions rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,30 @@ import (
"crypto"
"crypto/rand"
"crypto/rsa"
"hash"
)

func init() {
algImpMap[RS256] = rsaAlgImp{ch: crypto.SHA256}
algImpMap[RS384] = rsaAlgImp{ch: crypto.SHA384}
algImpMap[RS512] = rsaAlgImp{ch: crypto.SHA512}
algImpMap[RS256] = rsaAlgImp{hash: crypto.SHA256}
algImpMap[RS384] = rsaAlgImp{hash: crypto.SHA384}
algImpMap[RS512] = rsaAlgImp{hash: crypto.SHA512}
}

type rsaAlgImp struct {
ch crypto.Hash
hh hash.Hash
hash crypto.Hash
}

func (ra rsaAlgImp) sign(content []byte, privateKey interface{}) ([]byte, error) {
pk, ok := privateKey.(*rsa.PrivateKey)
key, ok := privateKey.(*rsa.PrivateKey)

if !ok {
return nil, ErrInvalidKeyType
}

return rsa.SignPKCS1v15(rand.Reader, pk, ra.ch, ra.hh.Sum(content))
h := ra.hash.New()

h.Write(content)

return rsa.SignPKCS1v15(rand.Reader, key, ra.hash, h.Sum(nil))
}

func (ra rsaAlgImp) verify(signing []byte, key interface{}) error {
Expand Down
28 changes: 27 additions & 1 deletion sign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package jwt

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"testing"
"time"
Expand Down Expand Up @@ -124,7 +126,7 @@ func TestSign(t *testing.T) {
assert.Equal(ErrEmptySecretOrPrivateKey, err)
})

t.Run("Should return with three parts", func(t *testing.T) {
t.Run("Should return with three parts and using HMAC", func(t *testing.T) {
custom := map[string]interface{}{
"test1k": "test1v",
"test2k": float64(234),
Expand All @@ -143,4 +145,28 @@ func TestSign(t *testing.T) {
assert.Nil(err)
assert.Equal(3, len(bytes.Split(signed, periodBytes)))
})

t.Run("Should return with three parts and using RSA", func(t *testing.T) {
custom := map[string]interface{}{
"test1k": "test1v",
"test2k": float64(234),
}

opt := &SignOption{
Algorithm: RS256,
Issuer: "testIssuer",
Subject: "tsetSubject",
Audience: "testAudience",
ExpiresIn: time.Minute,
}

key, err := rsa.GenerateKey(rand.Reader, 1024)

assert.Nil(err)

signed, err := Sign(custom, key, opt)

assert.Nil(err)
assert.Equal(3, len(bytes.Split(signed, periodBytes)))
})
}
10 changes: 5 additions & 5 deletions verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type VerifyOption struct {

// Verify decodes the given token and check whether the token is valid.
func Verify(token []byte, secretOrPrivateKey interface{}, opt *VerifyOption) (map[string]interface{}, map[string]interface{}, error) {
header, payload, err := decode(string(token))
header, payload, err := decode(token)

if err != nil {
return nil, nil, err
Expand All @@ -28,6 +28,10 @@ func Verify(token []byte, secretOrPrivateKey interface{}, opt *VerifyOption) (ma
algImp algorithmImplementation
)

if err := algImp.verify(token, secretOrPrivateKey); err != nil {
return nil, nil, ErrInvalidSignature
}

if typ, ok = header["typ"]; !ok {
return nil, nil, ErrInvalidHeaderType
}
Expand All @@ -44,9 +48,5 @@ func Verify(token []byte, secretOrPrivateKey interface{}, opt *VerifyOption) (ma
return nil, nil, ErrInvalidAlgorithm
}

if err := algImp.verify(token, secretOrPrivateKey); err != nil {
return nil, nil, err
}

return header, payload, nil
}

0 comments on commit b8ca1c0

Please sign in to comment.