Skip to content
This repository has been archived by the owner on Dec 1, 2021. It is now read-only.

Commit

Permalink
Merge pull request #5 from XenitAB/key-rotation
Browse files Browse the repository at this point in the history
add and test key rotation
  • Loading branch information
simongottschlag committed Jul 15, 2021
2 parents 0f41c48 + 30ac72e commit 6ccfc7b
Show file tree
Hide file tree
Showing 7 changed files with 386 additions and 83 deletions.
98 changes: 84 additions & 14 deletions key/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,63 +6,133 @@ import (
"crypto/elliptic"
"crypto/rand"
"fmt"
"sync"

"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/xenitab/dispans/models"
)

type handler struct {
privateKey models.PrivateKey
publicKey models.PublicKey
sync.RWMutex
privateKeys []models.PrivateKey
publicKeys []models.PublicKey
}

func NewHandler() (*handler, error) {
h := &handler{
privateKeys: []models.PrivateKey{},
publicKeys: []models.PublicKey{},
}

err := h.AddNewKey()
if err != nil {
return nil, err
}

return h, nil
}

func (h *handler) AddNewKey() error {
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
fmt.Printf("failed to generate new ECDSA privatre key: %s\n", err)
return nil, err
return err
}

key, err := jwk.New(ecdsaKey)
if err != nil {
return nil, err
return err
}

if _, ok := key.(jwk.ECDSAPrivateKey); !ok {
return nil, fmt.Errorf("expected jwk.ECDSAPrivateKey, got %T", key)
return fmt.Errorf("expected jwk.ECDSAPrivateKey, got %T", key)
}

thumbprint, err := key.Thumbprint(crypto.SHA256)
if err != nil {
return nil, err
return err
}

keyID := fmt.Sprintf("%x", thumbprint)
key.Set(jwk.KeyIDKey, keyID)

pubKey, err := jwk.New(ecdsaKey.PublicKey)
if err != nil {
return nil, err
return err
}

if _, ok := pubKey.(jwk.ECDSAPublicKey); !ok {
return nil, fmt.Errorf("expected jwk.ECDSAPublicKey, got %T", key)
return fmt.Errorf("expected jwk.ECDSAPublicKey, got %T", key)
}

pubKey.Set(jwk.KeyIDKey, keyID)
pubKey.Set(jwk.AlgorithmKey, jwa.ES384)

return &handler{
privateKey: key,
publicKey: pubKey,
}, nil
h.Lock()

h.privateKeys = append(h.privateKeys, key)
h.publicKeys = append(h.publicKeys, pubKey)

h.Unlock()

return nil
}

func (h *handler) RemoveOldestKey() error {
h.RLock()
privKeysLen := len(h.privateKeys)
pubKeysLen := len(h.publicKeys)
h.RUnlock()

if privKeysLen != pubKeysLen {
return fmt.Errorf("Private keys length (%d) isn't equal private keys length (%d).", privKeysLen, pubKeysLen)
}

if privKeysLen <= 1 {
return fmt.Errorf("Keys length smaller or equal 1: %d", privKeysLen)
}

h.Lock()
h.privateKeys = h.privateKeys[1:]
h.publicKeys = h.publicKeys[1:]
h.Unlock()

return nil
}

func (h *handler) GetPrivateKey() models.PrivateKey {
return h.privateKey
h.RLock()

lastKeyIndex := len(h.privateKeys) - 1
privKey := h.privateKeys[lastKeyIndex]

h.RUnlock()

return privKey
}

func (h *handler) GetPublicKey() models.PublicKey {
return h.publicKey
h.RLock()

lastKeyIndex := len(h.publicKeys) - 1
pubKey := h.publicKeys[lastKeyIndex]

h.RUnlock()

return pubKey
}

func (h *handler) GetPublicKeySet() models.PublicKeySet {
keySet := jwk.NewSet()

h.RLock()

for _, pubKey := range h.publicKeys {
keySet.Add(pubKey)
}

h.RUnlock()

return keySet
}
119 changes: 119 additions & 0 deletions key/key_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package key

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestNewHandler(t *testing.T) {
keyHandler, err := NewHandler()
require.NoError(t, err)

require.Equal(t, 1, len(keyHandler.privateKeys))
require.Equal(t, 1, len(keyHandler.publicKeys))
}

func TestAddNewKey(t *testing.T) {
keyHandler, err := NewHandler()
require.NoError(t, err)

err = keyHandler.AddNewKey()
require.NoError(t, err)

require.Equal(t, 2, len(keyHandler.privateKeys))
require.Equal(t, 2, len(keyHandler.publicKeys))
}

func TestRemoveOldestKey(t *testing.T) {
keyHandler, err := NewHandler()
require.NoError(t, err)

err = keyHandler.RemoveOldestKey()
require.Error(t, err)

err = keyHandler.AddNewKey()
require.NoError(t, err)

require.Equal(t, 2, len(keyHandler.privateKeys))
require.Equal(t, 2, len(keyHandler.publicKeys))

secondPrivKey := keyHandler.privateKeys[1]
secondPubKey := keyHandler.publicKeys[1]

err = keyHandler.RemoveOldestKey()
require.NoError(t, err)

require.Equal(t, secondPrivKey, keyHandler.privateKeys[0])
require.Equal(t, secondPubKey, keyHandler.publicKeys[0])
}

func TestGetPrivateKey(t *testing.T) {
keyHandler, err := NewHandler()
require.NoError(t, err)

require.Equal(t, keyHandler.privateKeys[0], keyHandler.GetPrivateKey())

err = keyHandler.AddNewKey()
require.NoError(t, err)

require.Equal(t, keyHandler.privateKeys[1], keyHandler.GetPrivateKey())

err = keyHandler.RemoveOldestKey()
require.NoError(t, err)

require.Equal(t, keyHandler.privateKeys[0], keyHandler.GetPrivateKey())
}

func TestGetPublicKey(t *testing.T) {
keyHandler, err := NewHandler()
require.NoError(t, err)

require.Equal(t, keyHandler.publicKeys[0], keyHandler.GetPublicKey())

err = keyHandler.AddNewKey()
require.NoError(t, err)

require.Equal(t, keyHandler.publicKeys[1], keyHandler.GetPublicKey())

err = keyHandler.RemoveOldestKey()
require.NoError(t, err)

require.Equal(t, keyHandler.publicKeys[0], keyHandler.GetPublicKey())
}

func TestGetPublicKeySet(t *testing.T) {
keyHandler, err := NewHandler()
require.NoError(t, err)

keySet := keyHandler.GetPublicKeySet()
key, ok := keySet.Get(0)
require.True(t, ok)

require.Equal(t, keyHandler.publicKeys[0], key)

err = keyHandler.AddNewKey()
require.NoError(t, err)

keySet = keyHandler.GetPublicKeySet()
firstKey, ok := keySet.Get(0)
require.True(t, ok)

secondKey, ok := keySet.Get(1)
require.True(t, ok)

require.Equal(t, keyHandler.publicKeys[0], firstKey)
require.Equal(t, keyHandler.publicKeys[1], secondKey)

err = keyHandler.RemoveOldestKey()
require.NoError(t, err)

keySet = keyHandler.GetPublicKeySet()
key, ok = keySet.Get(0)
require.True(t, ok)

require.Equal(t, keyHandler.publicKeys[0], key)

_, ok = keySet.Get(1)
require.False(t, ok)
}
16 changes: 16 additions & 0 deletions models/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,31 @@ import "github.com/lestrrat-go/jwx/jwk"
type PrivateKey jwk.Key
type PublicKey jwk.Key

type PublicKeySet jwk.Set

type PrivateKeyGetter interface {
GetPrivateKey() PrivateKey
}

type PublicKeyGetter interface {
GetPublicKey() PublicKey
GetPublicKeySet() PublicKeySet
}

type KeysGetter interface {
PrivateKeyGetter
PublicKeyGetter
}

type KeysAdder interface {
AddNewKey() error
}

type KeysRemover interface {
RemoveOldestKey() error
}

type KeysUpdater interface {
KeysAdder
KeysRemover
}
4 changes: 2 additions & 2 deletions route/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ func (h *handler) Test(w http.ResponseWriter, r *http.Request) {
e.Encode(data)
}

func (h *handler) Jwk(w http.ResponseWriter, r *http.Request) {
func (h *handler) Jwks(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")

pubKey := h.publicKeyHandler.GetPublicKey()
pubKey := h.publicKeyHandler.GetPublicKeySet()

e := json.NewEncoder(w)
e.SetIndent("", " ")
Expand Down
4 changes: 2 additions & 2 deletions route/route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func TestTestEndpoint(t *testing.T) {
require.Equal(t, testClientID, data.ClientID)
}

func TestJwk(t *testing.T) {
func TestJwks(t *testing.T) {
routeHandler := testNewRouteHandler(t)

oauthInfo := testGetOAuthInformation(t, "openid profile email")
Expand All @@ -179,7 +179,7 @@ func TestJwk(t *testing.T) {
req := httptest.NewRequest("GET", "/jwk", nil)

w := httptest.NewRecorder()
routeHandler.Jwk(w, req)
routeHandler.Jwks(w, req)
res := w.Result()

require.Equal(t, http.StatusOK, res.StatusCode)
Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func newRouter(as models.AuthorizationServer, issuerHandler models.IssuerGetter,
router.HandleFunc("/oauth/authorize", routeHandler.Authorize)
router.HandleFunc("/oauth/token", routeHandler.Token)
router.HandleFunc("/test", routeHandler.Test)
router.HandleFunc("/jwk", routeHandler.Jwk)
router.HandleFunc("/jwks", routeHandler.Jwks)
router.HandleFunc("/.well-known/openid-configuration", routeHandler.Discovery)

return router, nil
Expand Down
Loading

0 comments on commit 6ccfc7b

Please sign in to comment.