Skip to content

Commit

Permalink
feat(jws): add RFC7797 implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
lepture committed Jul 9, 2023
1 parent 11efa0c commit ddf490f
Show file tree
Hide file tree
Showing 9 changed files with 491 additions and 35 deletions.
35 changes: 8 additions & 27 deletions src/joserfc/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
detach_compact_content,
)
from .rfc7515.json import (
construct_json_signature,
sign_json,
verify_json,
extract_json,
Expand All @@ -38,6 +39,7 @@
"types",
"JWSAlgModel",
"JWSRegistry",
"HeaderMember",
"CompactSignature",
"JSONSignature",
"serialize_compact",
Expand Down Expand Up @@ -119,9 +121,10 @@ def validate_compact(
registry = construct_registry(algorithms)

headers = obj.headers()
alg: JWSAlgModel = registry.get_alg(headers["alg"])
registry.check_header(headers)
key: Key = guess_key(public_key, obj)
key.check_use("sig")
alg: JWSAlgModel = registry.get_alg(headers["alg"])
if not verify_compact(obj, alg, key):
raise BadSignatureError()

Expand Down Expand Up @@ -190,23 +193,10 @@ def serialize_json(
if registry is None:
registry = construct_registry(algorithms)

if isinstance(members, dict):
flatten = True
__check_member(registry, members)
members = [members]
else:
flatten = False
for member in members:
__check_member(registry, member)

members = [HeaderMember(**member) for member in members]
payload = to_bytes(payload)
obj = JSONSignature(members, payload)
obj.segments["payload"] = urlsafe_b64encode(payload)
obj.flattened = flatten

obj = construct_json_signature(members, payload, registry)
obj.segments["payload"] = urlsafe_b64encode(obj.payload)
find_key = lambda d: guess_key(private_key, d)
return sign_json(obj, registry.get_alg, find_key)
return sign_json(obj, registry, find_key)


def validate_json(
Expand All @@ -226,7 +216,7 @@ def validate_json(
if registry is None:
registry = construct_registry(algorithms)
find_key = lambda d: guess_key(public_key, d)
if not verify_json(obj, registry.get_alg, find_key):
if not verify_json(obj, registry, find_key):
raise BadSignatureError()


Expand Down Expand Up @@ -273,12 +263,3 @@ def detach_content(value: t.Union[str, JSONSerialization]):
if isinstance(value, str):
return detach_compact_content(value)
return detach_json_content(value)


def __check_member(registry: JWSRegistry, member: HeaderDict):
header = {}
if "protected" in member:
header.update(member["protected"])
if "header" in member:
header.update(member["header"])
registry.check_header(header)
46 changes: 39 additions & 7 deletions src/joserfc/rfc7515/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,63 @@
import binascii
from .model import JWSAlgModel, HeaderMember, JSONSignature
from .types import (
HeaderDict,
JSONSignatureDict,
JSONSerialization,
GeneralJSONSerialization,
FlattenedJSONSerialization,
)
from .registry import JWSRegistry
from ..util import (
to_bytes,
json_b64encode,
json_b64decode,
urlsafe_b64encode,
urlsafe_b64decode,
)
from ..errors import DecodeError

FindAlgorithm = t.Callable[[str], JWSAlgModel]

def construct_json_signature(
members: t.Union[HeaderDict, t.List[HeaderDict]],
payload: t.AnyStr,
registry: JWSRegistry) -> JSONSignature:
if isinstance(members, dict):
flattened = True
__check_member(registry, members)
members = [members]
else:
flattened = False
for member in members:
__check_member(registry, member)

members = [HeaderMember(**member) for member in members]
payload = to_bytes(payload)
obj = JSONSignature(members, payload)
obj.flattened = flattened
return obj


def __check_member(registry: JWSRegistry, member: HeaderDict):
header = {}
if "protected" in member:
header.update(member["protected"])
if "header" in member:
header.update(member["header"])
registry.check_header(header)


def sign_json(obj: JSONSignature, find_alg: FindAlgorithm, find_key) -> JSONSerialization:
def sign_json(obj: JSONSignature, registry: JWSRegistry, find_key) -> JSONSerialization:
signatures: t.List[JSONSignatureDict] = []

payload_segment = obj.segments["payload"]
for member in obj.members:
headers = member.headers()
alg = find_alg(headers["alg"])
registry.check_header(headers)
alg = registry.get_alg(headers["alg"])
key = find_key(member)
key.check_use("sig")
signature = _sign_member(payload_segment, member, alg, key)
signature = __sign_member(payload_segment, member, alg, key)
signatures.append(signature)

rv = {"payload": payload_segment.decode("utf-8")}
Expand All @@ -40,7 +71,7 @@ def sign_json(obj: JSONSignature, find_alg: FindAlgorithm, find_key) -> JSONSeri
return rv


def _sign_member(payload_segment, member: HeaderMember, alg: JWSAlgModel, key) -> JSONSignatureDict:
def __sign_member(payload_segment, member: HeaderMember, alg: JWSAlgModel, key) -> JSONSignatureDict:
if member.protected:
protected_segment = json_b64encode(member.protected)
else:
Expand Down Expand Up @@ -99,7 +130,7 @@ def extract_json(value: JSONSerialization) -> JSONSignature:
return obj


def verify_json(obj: JSONSignature, find_alg: FindAlgorithm, find_key) -> bool:
def verify_json(obj: JSONSignature, registry: JWSRegistry, find_key) -> bool:
"""Verify the signature of this JSON serialization with the given
algorithm and key.
Expand All @@ -111,7 +142,8 @@ def verify_json(obj: JSONSignature, find_alg: FindAlgorithm, find_key) -> bool:
for index, signature in enumerate(obj.signatures):
member = obj.members[index]
headers = member.headers()
alg = find_alg(headers["alg"])
registry.check_header(headers)
alg = registry.get_alg(headers["alg"])
key = find_key(member)
key.check_use("sig")
if not _verify_signature(signature, payload_segment, alg, key):
Expand Down
3 changes: 2 additions & 1 deletion src/joserfc/rfc7515/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class JWSRegistry(object):
:param algorithms: allowed algorithms to be used
:param strict_check_header: only allow header key in the registry to be used
"""
default_header_registry: HeaderRegistryDict = JWS_HEADER_REGISTRY
algorithms: Dict[str, JWSAlgModel] = {}
recommended: List[str] = []

Expand All @@ -28,7 +29,7 @@ def __init__(
algorithms: Optional[List[str]] = None,
strict_check_header: bool = True):
self.header_registry: HeaderRegistryDict = {}
self.header_registry.update(JWS_HEADER_REGISTRY)
self.header_registry.update(self.default_header_registry)
if header_registry is not None:
self.header_registry.update(header_registry)
self.allowed = algorithms
Expand Down
11 changes: 11 additions & 0 deletions src/joserfc/rfc7797/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .registry import JWSRegistry
from .compact import serialize_compact, deserialize_compact
from .json import serialize_json, deserialize_json

__all__ = [
"JWSRegistry",
"serialize_compact",
"deserialize_compact",
"serialize_json",
"deserialize_json",
]
126 changes: 126 additions & 0 deletions src/joserfc/rfc7797/compact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import typing as t
import re
import binascii
from ..registry import Header
from ..jwk import KeyFlexible, guess_key
from ..jws import (
CompactSignature,
JWSRegistry as _JWSRegistry,
serialize_compact as _serialize_compact,
deserialize_compact as _deserialize_compact,
)
from ..util import (
to_bytes,
to_unicode,
json_b64encode,
json_b64decode,
urlsafe_b64encode,
urlsafe_b64decode,
)
from ..errors import BadSignatureError, MissingAlgorithmError, DecodeError
from .registry import JWSRegistry


def serialize_compact(
protected: Header,
payload: t.AnyStr,
private_key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[_JWSRegistry] = None) -> str:

if "b64" not in protected:
return _serialize_compact(protected, payload, private_key, algorithms, registry)

if registry is None:
registry = JWSRegistry(algorithms=algorithms)

if protected["b64"] is True:
return _serialize_compact(protected, payload, private_key, registry=registry)

registry.check_header(protected)
obj = CompactSignature(protected, to_bytes(payload))
alg = registry.get_alg(protected["alg"])
key = guess_key(private_key, obj)
key.check_use("sig")

header_segment = json_b64encode(protected)
signing_input = header_segment + b"." + obj.payload
signature = urlsafe_b64encode(alg.sign(signing_input, key))

# if need to detach payload
if __is_urlsafe_characters(payload):
out = signing_input + b"." + signature
else:
out = header_segment + b".." + signature
return out.decode("utf-8")


def deserialize_compact(
value: t.AnyStr,
public_key: KeyFlexible,
payload: t.Optional[t.AnyStr] = None,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWSRegistry] = None) -> CompactSignature:
obj = _extract_compact(to_bytes(value), payload)
if obj is None:
return _deserialize_compact(value, public_key, algorithms, registry)

if registry is None:
registry = JWSRegistry(algorithms=algorithms)

if obj is True:
return _deserialize_compact(value, public_key, registry=registry)

headers = obj.headers()
registry.check_header(headers)
key = guess_key(public_key, obj)
key.check_use("sig")
alg = registry.get_alg(headers["alg"])

signing_input = obj.segments["header"] + b"." + obj.payload
sig = urlsafe_b64decode(obj.segments["signature"])
if not alg.verify(signing_input, sig, key):
raise BadSignatureError()
return obj


# https://datatracker.ietf.org/doc/html/rfc7797#section-5.2
# the application MUST ensure that the payload contains only the URL-safe
# characters 'a'-'z', 'A'-'Z', '0'-'9', dash ('-'), underscore ('_'),
# and tilde ('~')
_re_urlsafe = re.compile("^[a-zA-Z0-9-_~]+$")


def __is_urlsafe_characters(s: t.AnyStr) -> bool:
return bool(_re_urlsafe.match(to_unicode(s)))


def _extract_compact(value: bytes, payload: t.Optional[t.AnyStr] = None):
parts = value.split(b".")
if len(parts) != 3:
raise ValueError("Invalid JSON Web Signature")

header_segment, payload_segment, signature_segment = parts
try:
protected = json_b64decode(header_segment)
if "alg" not in protected:
raise MissingAlgorithmError()
except (TypeError, ValueError, binascii.Error):
raise DecodeError("Invalid header")

if "b64" not in protected:
return None

if protected["b64"] is True:
return True

if payload:
obj = CompactSignature(protected, to_bytes(payload))
else:
obj = CompactSignature(protected, payload_segment)
obj.segments.update({
"header": header_segment,
"payload": payload_segment,
"signature": signature_segment,
})
return obj
Loading

0 comments on commit ddf490f

Please sign in to comment.