Skip to content

Commit

Permalink
Prevent a race condition with jsonKey.precomputed and remove intermit…
Browse files Browse the repository at this point in the history
…tent curve variable
  • Loading branch information
Micah Parks committed Sep 17, 2021
1 parent 79eea7e commit f783be7
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 12 deletions.
13 changes: 8 additions & 5 deletions ecdsa.go
Expand Up @@ -33,12 +33,15 @@ const (
func (j *jsonKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) {

// Check if the key has already been computed.
j.precomputedMux.RLock()
if j.precomputed != nil {
var ok bool
if publicKey, ok = j.precomputed.(*ecdsa.PublicKey); ok {
j.precomputedMux.RUnlock()
return publicKey, nil
}
}
j.precomputedMux.RUnlock()

// Confirm everything needed is present.
if j.X == "" || j.Y == "" || j.Curve == "" {
Expand All @@ -64,16 +67,14 @@ func (j *jsonKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) {
publicKey = &ecdsa.PublicKey{}

// Set the curve type.
var curve elliptic.Curve
switch j.Curve {
case p256:
curve = elliptic.P256()
publicKey.Curve = elliptic.P256()
case p384:
curve = elliptic.P384()
publicKey.Curve = elliptic.P384()
case p521:
curve = elliptic.P521()
publicKey.Curve = elliptic.P521()
}
publicKey.Curve = curve

// Turn the X coordinate into *big.Int.
//
Expand All @@ -85,7 +86,9 @@ func (j *jsonKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) {
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)

// Keep the public key so it won't have to be computed every time.
j.precomputedMux.Lock()
j.precomputed = publicKey
j.precomputedMux.Unlock()

return publicKey, nil
}
2 changes: 2 additions & 0 deletions hmac.go
Expand Up @@ -22,6 +22,8 @@ const (
func (j *jsonKey) HMAC() (key []byte, err error) {

// Confirm the key is already present as expected.
j.precomputedMux.RLock()
defer j.precomputedMux.RUnlock()
if j.precomputed != nil {
var ok bool
if key, ok = j.precomputed.([]byte); ok {
Expand Down
15 changes: 8 additions & 7 deletions jwks.go
Expand Up @@ -23,13 +23,14 @@ type ErrorHandler func(err error)

// jsonKey represents a raw key inside a JWKs.
type jsonKey struct {
Curve string `json:"crv"`
Exponent string `json:"e"`
ID string `json:"kid"`
Modulus string `json:"n"`
X string `json:"x"`
Y string `json:"y"`
precomputed interface{}
Curve string `json:"crv"`
Exponent string `json:"e"`
ID string `json:"kid"`
Modulus string `json:"n"`
X string `json:"x"`
Y string `json:"y"`
precomputed interface{}
precomputedMux sync.RWMutex
}

// JWKs represents a JSON Web Key Set.
Expand Down
2 changes: 2 additions & 0 deletions keyfunc.go
Expand Up @@ -49,6 +49,8 @@ func (j *JWKs) Keyfunc(token *jwt.Token) (interface{}, error) {
default:

// Assume there's a given key for a custom algorithm.
key.precomputedMux.RLock()
defer key.precomputedMux.RUnlock()
if key.precomputed != nil {
return key.precomputed, nil
}
Expand Down
5 changes: 5 additions & 0 deletions rsa.go
Expand Up @@ -32,12 +32,15 @@ const (
func (j *jsonKey) RSA() (publicKey *rsa.PublicKey, err error) {

// Check if the key has already been computed.
j.precomputedMux.RLock()
if j.precomputed != nil {
var ok bool
if publicKey, ok = j.precomputed.(*rsa.PublicKey); ok {
j.precomputedMux.RUnlock()
return publicKey, nil
}
}
j.precomputedMux.RUnlock()

// Confirm everything needed is present.
if j.Exponent == "" || j.Modulus == "" {
Expand Down Expand Up @@ -72,7 +75,9 @@ func (j *jsonKey) RSA() (publicKey *rsa.PublicKey, err error) {
publicKey.N = big.NewInt(0).SetBytes(modulus)

// Keep the public key so it won't have to be computed every time.
j.precomputedMux.Lock()
j.precomputed = publicKey
j.precomputedMux.Unlock()

return publicKey, nil
}

0 comments on commit f783be7

Please sign in to comment.