Skip to content

Commit

Permalink
Merge pull request #33 from axant/PSS_support
Browse files Browse the repository at this point in the history
added support for PS256 PS384 PS512
  • Loading branch information
yosida95 committed Jan 7, 2021
2 parents 282af57 + 1597ec7 commit db36813
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 11 deletions.
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ Supported Algorithms

- Asymmetric

- PS256
- PS384
- PS512
- RS256
- RS384
- RS512
Expand Down
77 changes: 70 additions & 7 deletions jwt/jwa.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Optional,
)

from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.hashes import (
SHA256,
SHA384,
Expand Down Expand Up @@ -97,8 +98,11 @@ class RSAAlgorithm(AbstractSigningAlgorithm):
def __init__(self, hash_fun: object) -> None:
self.hash_fun = hash_fun

def _check_key(self, key: Optional[AbstractJWKBase], must_sign_key=False) \
-> AbstractJWKBase:
def _check_key(
self,
key: Optional[AbstractJWKBase],
must_sign_key: bool = False,
) -> AbstractJWKBase:
if not key or key.get_kty() != 'RSA':
raise InvalidKeyTypeError('RSA key is required')
if must_sign_key and not key.is_sign_key():
Expand All @@ -108,19 +112,75 @@ def _check_key(self, key: Optional[AbstractJWKBase], must_sign_key=False) \

def sign(self, message: bytes, key: Optional[AbstractJWKBase]) -> bytes:
key = self._check_key(key, must_sign_key=True)
return key.sign(message, hash_fun=self.hash_fun)

def verify(self, message: bytes, key: Optional[AbstractJWKBase],
signature: bytes) -> bool:
return key.sign(message, hash_fun=self.hash_fun,
padding=padding.PKCS1v15())

def verify(
self,
message: bytes,
key: Optional[AbstractJWKBase],
signature: bytes,
) -> bool:
key = self._check_key(key)
return key.verify(message, signature, hash_fun=self.hash_fun)
return key.verify(message, signature, hash_fun=self.hash_fun,
padding=padding.PKCS1v15())


RS256 = RSAAlgorithm(SHA256)
RS384 = RSAAlgorithm(SHA384)
RS512 = RSAAlgorithm(SHA512)


class PSSRSAAlgorithm(AbstractSigningAlgorithm):
def __init__(self, hash_fun: object) -> None:
self.hash_fun = hash_fun

def _check_key(
self,
key: Optional[AbstractJWKBase],
must_sign_key: bool = False,
) -> AbstractJWKBase:
if not key or key.get_kty() != 'RSA':
raise InvalidKeyTypeError('RSA key is required')
if must_sign_key and not key.is_sign_key():
raise InvalidKeyTypeError(
'a RSA private key is required, but passed is RSA public key')
return key

def sign(self, message: bytes, key: Optional[AbstractJWKBase]) -> bytes:
key = self._check_key(key, must_sign_key=True)
return key.sign(
message,
hash_fun=self.hash_fun,
padding=padding.PSS(
mgf=padding.MGF1(self.hash_fun()),
salt_length=self.hash_fun().digest_size,
),
)

def verify(
self,
message: bytes,
key: Optional[AbstractJWKBase],
signature: bytes
) -> bool:
key = self._check_key(key)
return key.verify(
message,
signature,
hash_fun=self.hash_fun,
padding=padding.PSS(
mgf=padding.MGF1(self.hash_fun()),
salt_length=self.hash_fun().digest_size,
),
)


PS256 = PSSRSAAlgorithm(SHA256)
PS384 = PSSRSAAlgorithm(SHA384)
PS512 = PSSRSAAlgorithm(SHA512)


def supported_signing_algorithms():
# NOTE(yosida95): exclude vulnerable 'none' algorithm by default.
return {
Expand All @@ -130,4 +190,7 @@ def supported_signing_algorithms():
'RS256': RS256,
'RS384': RS384,
'RS512': RS512,
'PS256': PS256,
'PS384': PS384,
'PS512': PS512,
}
16 changes: 12 additions & 4 deletions jwt/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@
RSAPublicKey,
RSAPublicNumbers,
)

from cryptography.hazmat.primitives.serialization import (
load_pem_private_key,
load_pem_public_key,
)

from cryptography.hazmat.primitives.hashes import HashAlgorithm

from .exceptions import (
MalformedJWKError,
UnsupportedKeyTypeError,
Expand Down Expand Up @@ -147,27 +150,32 @@ def __init__(self, keyobj: Union[RSAPrivateKey, RSAPublicKey],
self.keyobj = keyobj

optnames = {'use', 'key_ops', 'alg', 'kid',
'x5u', 'x5c', 'x5t', 'x5t#s256'}
'x5u', 'x5c', 'x5t', 'x5t#s256', }
self.options = {k: v for k, v in options.items() if k in optnames}

def is_sign_key(self) -> bool:
return isinstance(self.keyobj, RSAPrivateKey)

def _get_hash_fun(self, options) -> Callable:
def _get_hash_fun(self, options) -> Callable[[], HashAlgorithm]:
return options['hash_fun']

def _get_padding(self, options) -> padding.AsymmetricPadding:
return options['padding']

def sign(self, message: bytes, **options) -> bytes:
hash_fun = self._get_hash_fun(options)
return self.keyobj.sign(message, padding.PKCS1v15(), hash_fun())
_padding = self._get_padding(options)
return self.keyobj.sign(message, _padding, hash_fun())

def verify(self, message: bytes, signature: bytes, **options) -> bool:
hash_fun = self._get_hash_fun(options)
_padding = self._get_padding(options)
if self.is_sign_key():
pubkey = self.keyobj.public_key()
else:
pubkey = self.keyobj
try:
pubkey.verify(signature, message, padding.PKCS1v15(), hash_fun())
pubkey.verify(signature, message, _padding, hash_fun())
return True
except InvalidSignature:
return False
Expand Down
14 changes: 14 additions & 0 deletions jwt/tests/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,17 @@ def test_no_before_used_after(self):
}
compact_jws = self.inst.encode(message, self.key)
self.assertEqual(self.inst.decode(compact_jws, self.key), message)

def test_encoded_with_rs(self):
message = {'hello': 'there'}
key = jwk_from_dict(
json.loads(load_testdata('rsa_privkey.json', 'r')))
comp = self.inst.encode(message, key, alg='RS256')
assert self.inst.decode(comp, key) == message

def test_encoded_with_pss(self):
message = {'hello': 'there'}
key = jwk_from_dict(
json.loads(load_testdata('rsa_privkey.json', 'r')))
comp = self.inst.encode(message, key, alg='PS256')
assert self.inst.decode(comp, key) == message

0 comments on commit db36813

Please sign in to comment.