Skip to content

Commit

Permalink
more typing
Browse files Browse the repository at this point in the history
  • Loading branch information
yosida95 committed Feb 1, 2017
1 parent c465a91 commit 8f29155
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 24 deletions.
31 changes: 24 additions & 7 deletions jwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,29 @@ class JWTException(Exception):
"""


MalformedJWKError = type('MalformedJWKError', (JWTException, ), {})
UnsupportedKeyTypeError = type('UnsupportedKeyTypeError', (JWTException, ), {})
class MalformedJWKError(JWTException):
pass

InvalidKeyTypeError = type('InvalidKeyTypeError', (JWTException, ), {})

JWSEncodeError = type('JWSEncodeError', (JWTException, ), {})
JWSDecodeError = type('JWSDecodeError', (JWTException, ), {})
JWTEncodeError = type('JWTEncodeError', (JWTException, ), {})
JWTDecodeError = type('JWTDecodeError', (JWTException, ), {})
class UnsupportedKeyTypeError(JWTException):
pass


class InvalidKeyTypeError(JWTException):
pass


class JWSEncodeError(JWTException):
pass


class JWSDecodeError(JWTException):
pass


class JWTEncodeError(JWTException):
pass


class JWTDecodeError(JWTException):
pass
5 changes: 3 additions & 2 deletions jwt/jwa.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import hashlib
import hmac
from typing import Callable

from cryptography.hazmat.primitives.hashes import (
SHA256,
Expand Down Expand Up @@ -49,7 +50,7 @@ def verify(self, message: bytes, key: AbstractJWKBase,

class HMACAlgorithm(AbstractSigningAlgorithm):

def __init__(self, hash_fun: object) -> None:
def __init__(self, hash_fun: Callable) -> None:
self.hash_fun = hash_fun

def _check_key(self, key: AbstractJWKBase) -> None:
Expand All @@ -68,7 +69,7 @@ def sign(self, message: bytes, key: AbstractJWKBase) -> bytes:
def verify(self, message: bytes, key: AbstractJWKBase,
signature: bytes) -> bool:
self._check_key(key)
return key.verify(message, signature, self._sign)
return key.verify(message, signature, signer=self._sign)


HS256 = HMACAlgorithm(hashlib.sha256)
Expand Down
26 changes: 13 additions & 13 deletions jwt/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,33 +51,33 @@ def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256',
header = {}
header_b64 = b64encode(json.dumps(header).encode('ascii'))
message_b64 = b64encode(message)
signing_message = header_b64 + b'.' + message_b64
signing_message = header_b64 + '.' + message_b64

signature = alg_impl.sign(signing_message, key)
signature = alg_impl.sign(signing_message.encode('ascii'), key)
signature_b64 = b64encode(signature)

return signing_message + b'.' + signature_b64
return signing_message + '.' + signature_b64

def _decode_segments(self, message: bytes, key: AbstractJWKBase = None
) -> Tuple[dict, bytes, bytes, bytes]:
def _decode_segments(self, message: str) -> Tuple[dict, bytes, bytes, str]:
try:
signing_message, signature_b64 = message.rsplit('.', 1)
header_b64, message_b64 = signing_message.split('.')
except ValueError:
raise JWSDecodeError('malformed JWS payload')

header = json.loads(b64decode(header_b64))
message = b64decode(message_b64)
signature = b64decode(message_b64)
return header, message, signature, signing_message
header = json.loads(b64decode(header_b64).decode('ascii'))
message_bin = b64decode(message_b64)
signature = b64decode(signature_b64)
return header, message_bin, signature, signing_message

def decode(self, message: str, key: AbstractJWKBase = None,
do_verify=True) -> bytes:
header, message, signature, signing_message =\
self._decode_segments(message, key)
header, message_bin, signature, signing_message =\
self._decode_segments(message)

alg_impl = self._retrieve_alg(header['alg'])
if do_verify and not alg_impl.verify(signing_message, key, signature):
if do_verify and not alg_impl.verify(
signing_message.encode('ascii'), key, signature):
raise JWSDecodeError('JWS passed could not be validated')

return message
return message_bin
4 changes: 2 additions & 2 deletions jwt/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def encode(self, payload: dict, key: AbstractJWKBase = None, alg='HS256',
def decode(self, message: str, key: AbstractJWKBase = None,
do_verify=True) -> dict:
try:
message = self._jws.decode(message, key, do_verify)
message_bin = self._jws.decode(message, key, do_verify)
except JWSDecodeError as why:
raise JWTDecodeError('failed to decode JWT') from why
try:
payload = json.loads(message.decode('utf-8'))
payload = json.loads(message_bin.decode('utf-8'))
return payload
except ValueError as why:
raise JWTDecodeError(
Expand Down

0 comments on commit 8f29155

Please sign in to comment.