Skip to content

Commit

Permalink
More idiomatic RSA testing
Browse files Browse the repository at this point in the history
Skip most of the OAEP tests on SoftHSM, as support is only partial.
  • Loading branch information
Richard Kettlewell authored and Richard Kettlewell committed Aug 3, 2018
1 parent 1b06889 commit d3fcc57
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 68 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ The configuration looks like this:
"Pin" : "password"
}

(At time of writing) PSS and OAEP aren't supported so expect test failures.
(At time of writing) OAEP is only partial, so expect test skips.

Limitations
===========
Expand Down
179 changes: 112 additions & 67 deletions rsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
"fmt"
"github.com/miekg/pkcs11"
"testing"
)

Expand All @@ -37,18 +39,24 @@ var rsaSizes = []int{1024, 2048}
func TestNativeRSA(t *testing.T) {
var err error
var key *rsa.PrivateKey
ConfigureFromFile("config")
for _, nbits := range rsaSizes {
if key, err = rsa.GenerateKey(rand.Reader, nbits); err != nil {
t.Errorf("crypto.rsa.GenerateKey: %v", err)
return
}
if err = key.Validate(); err != nil {
t.Errorf("crypto.rsa.PrivateKey.Validate: %v", err)
return
}
testRsaSigning(t, key, nbits)
testRsaEncryption(t, key, nbits)
t.Run(fmt.Sprintf("%v", nbits), func(t *testing.T) {
t.Run("Generate", func(t *testing.T) {
if key, err = rsa.GenerateKey(rand.Reader, nbits); err != nil {
t.Errorf("crypto.rsa.GenerateKey: %v", err)
return
}
if err = key.Validate(); err != nil {
t.Errorf("crypto.rsa.PrivateKey.Validate: %v", err)
return
}
})
t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, nbits) })
t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, nbits) })
})
}
Close()
}

func TestHardRSA(t *testing.T) {
Expand All @@ -58,52 +66,74 @@ func TestHardRSA(t *testing.T) {
var id, label []byte
ConfigureFromFile("config")
for _, nbits := range rsaSizes {
if key, err = GenerateRSAKeyPair(nbits); err != nil {
t.Errorf("crypto11.GenerateRSAKeyPair: %v", err)
return
}
if key == nil {
t.Errorf("crypto11.dsa.GenerateRSAKeyPair: returned nil but no error")
return
}
if err = key.Validate(); err != nil {
t.Errorf("crypto11.rsa.PKCS11PrivateKeyRSA.Validate: %v", err)
return
}
testRsaSigning(t, key, nbits)
testRsaEncryption(t, key, nbits)
// Get a fresh handle to the key
if id, label, err = key.Identify(); err != nil {
t.Errorf("crypto11.rsa.PKCS11PrivateKeyRSA.Identify: %v", err)
return
}
if key2, err = FindKeyPair(id, nil); err != nil {
t.Errorf("crypto11.rsa.FindRSAKeyPair by id: %v", err)
return
}
testRsaSigning(t, key2.(*PKCS11PrivateKeyRSA), nbits)
if key3, err = FindKeyPair(nil, label); err != nil {
t.Errorf("crypto11.rsa.FindKeyPair by label: %v", err)
return
}
testRsaSigning(t, key3.(crypto.Signer), nbits)
t.Run(fmt.Sprintf("%v", nbits), func(t *testing.T) {
t.Run("Generate", func(t *testing.T) {
if key, err = GenerateRSAKeyPair(nbits); err != nil {
t.Errorf("crypto11.GenerateRSAKeyPair: %v", err)
return
}
if key == nil {
t.Errorf("crypto11.dsa.GenerateRSAKeyPair: returned nil but no error")
return
}
if err = key.Validate(); err != nil {
t.Errorf("crypto11.rsa.PKCS11PrivateKeyRSA.Validate: %v", err)
return
}
})
t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, nbits) })
t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, nbits) })
t.Run("FindId", func(t *testing.T) {
// Get a fresh handle to the key
if id, label, err = key.Identify(); err != nil {
t.Errorf("crypto11.rsa.PKCS11PrivateKeyRSA.Identify: %v", err)
return
}
if key2, err = FindKeyPair(id, nil); err != nil {
t.Errorf("crypto11.rsa.FindRSAKeyPair by id: %v", err)
return
}
})
t.Run("SignId", func(t *testing.T) {
if key2 == nil {
t.SkipNow()
}
testRsaSigning(t, key2.(*PKCS11PrivateKeyRSA), nbits)
})
t.Run("FindLabel", func(t *testing.T) {
if key3, err = FindKeyPair(nil, label); err != nil {
t.Errorf("crypto11.rsa.FindKeyPair by label: %v", err)
return
}
})
t.Run("SignLabel", func(t *testing.T) {
if key3 == nil {
t.SkipNow()
}
testRsaSigning(t, key3.(crypto.Signer), nbits)
})
})
}
Close()
}

func testRsaSigning(t *testing.T, key crypto.Signer, nbits int) {
testRsaSigningPKCS1v15(t, key, crypto.SHA1)
testRsaSigningPKCS1v15(t, key, crypto.SHA224)
testRsaSigningPKCS1v15(t, key, crypto.SHA256)
testRsaSigningPKCS1v15(t, key, crypto.SHA384)
testRsaSigningPKCS1v15(t, key, crypto.SHA512)
testRsaSigningPSS(t, key, crypto.SHA1)
testRsaSigningPSS(t, key, crypto.SHA224)
testRsaSigningPSS(t, key, crypto.SHA256)
testRsaSigningPSS(t, key, crypto.SHA384)
if nbits > 1024 { // key too smol for SHA512 with sLen=hLen
testRsaSigningPSS(t, key, crypto.SHA512)
}
t.Run("SHA1", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA1) })
t.Run("SHA224", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA224) })
t.Run("SHA256", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA256) })
t.Run("SHA384", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA384) })
t.Run("SHA512", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA512) })
t.Run("PSSSHA1", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA1) })
t.Run("PSSSHA224", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA224) })
t.Run("PSSSHA256", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA256) })
t.Run("PSSSHA384", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA384) })
t.Run("PSSSHA512", func(t *testing.T) {
if nbits > 1024 {
testRsaSigningPSS(t, key, crypto.SHA512)
} else {
t.Skipf("key too smol for SHA512 with sLen=hLen")
}
})
}

func testRsaSigningPKCS1v15(t *testing.T, key crypto.Signer, hashFunction crypto.Hash) {
Expand Down Expand Up @@ -147,21 +177,29 @@ func testRsaSigningPSS(t *testing.T, key crypto.Signer, hashFunction crypto.Hash
}

func testRsaEncryption(t *testing.T, key crypto.Decrypter, nbits int) {
testRsaEncryptionPKCS1v15(t, key)
testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{})
testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{})
testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{})
testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{})
if nbits > 1024 { // key too smol for SHA512
testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{})
}
testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{1, 2, 3, 4})
testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{5, 6, 7, 8})
testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{9})
testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{10, 11, 12, 13, 14, 15})
if nbits > 1024 {
testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{16, 17, 18})
}
t.Run("PKCS1v15", func(t *testing.T) { testRsaEncryptionPKCS1v15(t, key) })
t.Run("OAEPSHA1", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{}) })
t.Run("OAEPSHA224", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{}) })
t.Run("OAEPSHA256", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{}) })
t.Run("OAEPSHA384", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{}) })
t.Run("OAEPSHA512", func(t *testing.T) {
if nbits > 1024 {
testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{})
} else {
t.Skipf("key too smol for SHA512")
}
})
t.Run("OAEPSHA1Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{1, 2, 3, 4}) })
t.Run("OAEPSHA224Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{5, 6, 7, 8}) })
t.Run("OAEPSHA256Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{9}) })
t.Run("OAEPSHA384Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{10, 11, 12, 13, 14, 15}) })
t.Run("OAEPSHA512Label", func(t *testing.T) {
if nbits > 1024 {
testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{16, 17, 18})
} else {
t.Skipf("key too smol for SHA512")
}
})
}

func testRsaEncryptionPKCS1v15(t *testing.T, key crypto.Decrypter) {
Expand Down Expand Up @@ -198,7 +236,14 @@ func testRsaEncryptionPKCS1v15(t *testing.T, key crypto.Decrypter) {
func testRsaEncryptionOAEP(t *testing.T, key crypto.Decrypter, hashFunction crypto.Hash, label []byte) {
var err error
var ciphertext, decrypted []byte

var info pkcs11.Info
if info, err = libHandle.GetInfo(); err != nil {
t.Errorf("GetInfo: %v", err)
return
}
if info.ManufacturerID == "SoftHSM" && (hashFunction != crypto.SHA1 || len(label) > 0) {
t.Skipf("SoftHSM OAEP only supports SHA-1 with no label")
}
plaintext := []byte("encrypt me with new hotness")
h := hashFunction.New()
rsaPubkey := key.Public().(crypto.PublicKey).(*rsa.PublicKey)
Expand Down

0 comments on commit d3fcc57

Please sign in to comment.