Skip to content

Commit

Permalink
Improve signature checks
Browse files Browse the repository at this point in the history
- Enforce allowed canonicalization methods
- Enforce allowed transform aglorithms
- Ensure the Object element is absent

Signed-off-by: Ivan Kanakarakis <ivan.kanak@gmail.com>
  • Loading branch information
c00kiemon5ter committed Jun 19, 2021
1 parent 4d2dcce commit 1e59eaa
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 34 deletions.
55 changes: 50 additions & 5 deletions src/saml2/sigver.py
Expand Up @@ -42,6 +42,10 @@
from saml2.s_utils import Unsupported
from saml2.time_util import instant
from saml2.time_util import str_to_time
from saml2.xmldsig import ALLOWED_CANONICALIZATIONS
from saml2.xmldsig import ALLOWED_TRANSFORMS
from saml2.xmldsig import TRANSFORM_C14N
from saml2.xmldsig import TRANSFORM_ENVELOPED
from saml2.xmldsig import SIG_RSA_SHA1
from saml2.xmldsig import SIG_RSA_SHA224
from saml2.xmldsig import SIG_RSA_SHA256
Expand Down Expand Up @@ -1503,7 +1507,8 @@ def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, origdoc=None,
# * the Reference element must have a URI attribute
# * the URI attribute contains an anchor
# * the anchor points to the enclosing element's ID attribute
references = item.signature.signed_info.reference
signed_info = item.signature.signed_info
references = signed_info.reference
signatures_must_have_a_single_reference_element = len(references) == 1
the_Reference_element_must_have_a_URI_attribute = (
signatures_must_have_a_single_reference_element
Expand All @@ -1518,6 +1523,41 @@ def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, origdoc=None,
the_URI_attribute_contains_an_anchor
and references[0].uri == "#{id}".format(id=item.id)
)

# SAML implementations SHOULD use Exclusive Canonicalization,
# with or without comments
canonicalization_method_is_c14n = (
signed_info.canonicalization_method.algorithm in ALLOWED_CANONICALIZATIONS
)

# Signatures in SAML messages SHOULD NOT contain transforms other than the
# - enveloped signature transform
# (with the identifier http://www.w3.org/2000/09/xmldsig#enveloped-signature)
# - or the exclusive canonicalization transforms
# (with the identifier http://www.w3.org/2001/10/xml-exc-c14n#
# or http://www.w3.org/2001/10/xml-exc-c14n#WithComments).
transform_alogs = [
transform.algorithm
for transform in references[0].transforms.transform
]
transform_alogs_n = len(transform_alogs)
only_up_to_two_transforms_are_defined = (
signatures_must_have_a_single_reference_element
and 1 <= transform_alogs_n <= 2
)
all_transform_algs_are_allowed = (
only_up_to_two_transforms_are_defined
and transform_alogs_n == len(
ALLOWED_TRANSFORMS.intersection(transform_alogs)
)
)

# The <ds:Object> element is not defined for use with SAML signatures,
# and SHOULD NOT be present.
# Since it can be used in service of an attacker by carrying unsigned data,
# verifiers SHOULD reject signatures that contain a <ds:Object> element.
object_element_is_not_present = not item.signature.object

validators = {
"signatures must have a single reference element": (
signatures_must_have_a_single_reference_element
Expand All @@ -1531,6 +1571,12 @@ def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, origdoc=None,
"the anchor points to the enclosing element ID attribute": (
the_anchor_points_to_the_enclosing_element_ID_attribute
),
"canonicalization method is c14n": canonicalization_method_is_c14n,
"only up to two transforms are defined": (
only_up_to_two_transforms_are_defined
),
"all transform algs are allowed": all_transform_algs_are_allowed,
"object element is not present": object_element_is_not_present,
}
if not all(validators.values()):
error_context = {
Expand Down Expand Up @@ -1818,10 +1864,9 @@ def pre_signature_part(
sign_alg = ds.DefaultSignature().get_sign_alg()

signature_method = ds.SignatureMethod(algorithm=sign_alg)
canonicalization_method = ds.CanonicalizationMethod(
algorithm=ds.ALG_EXC_C14N)
trans0 = ds.Transform(algorithm=ds.TRANSFORM_ENVELOPED)
trans1 = ds.Transform(algorithm=ds.ALG_EXC_C14N)
canonicalization_method = ds.CanonicalizationMethod(algorithm=TRANSFORM_C14N)
trans0 = ds.Transform(algorithm=TRANSFORM_ENVELOPED)
trans1 = ds.Transform(algorithm=TRANSFORM_C14N)
transforms = ds.Transforms(transform=[trans0, trans1])
digest_method = ds.DigestMethod(algorithm=digest_alg)

Expand Down
15 changes: 11 additions & 4 deletions src/saml2/xmldsig/__init__.py
Expand Up @@ -53,14 +53,21 @@

MAC_SHA1 = 'http://www.w3.org/2000/09/xmldsig#hmac-sha1'

C14N = 'http://www.w3.org/TR/2001/REC-xml-c14n-20010315'
C14N_WITH_C = 'http://www.w3.org/TR/2001/REC-xml-c14n-20010315#WithComments'
ALG_EXC_C14N = 'http://www.w3.org/2001/10/xml-exc-c14n#'

TRANSFORM_XSLT = 'http://www.w3.org/TR/1999/REC-xslt-19991116'
TRANSFORM_XPATH = 'http://www.w3.org/TR/1999/REC-xpath-19991116'
TRANSFORM_ENVELOPED = 'http://www.w3.org/2000/09/xmldsig#enveloped-signature'
TRANSFORM_C14N = 'http://www.w3.org/2001/10/xml-exc-c14n#'
TRANSFORM_C14N_WITH_COMMENTS = 'http://www.w3.org/2001/10/xml-exc-c14n#WithComments'

ALLOWED_CANONICALIZATIONS = {
TRANSFORM_C14N,
TRANSFORM_C14N_WITH_COMMENTS,
}
ALLOWED_TRANSFORMS = {
TRANSFORM_ENVELOPED,
TRANSFORM_C14N,
TRANSFORM_C14N_WITH_COMMENTS,
}

class DefaultSignature(object):
class _DefaultSignature(object):
Expand Down
49 changes: 24 additions & 25 deletions tests/test_00_xmldsig.py
Expand Up @@ -33,12 +33,10 @@ def testUsingTestData(self):
new_object = ds.object_from_string(ds_data.TEST_OBJECT)
assert new_object.id == "object_id"
assert new_object.encoding == ds.ENCODING_BASE64
assert new_object.text.strip() == \
"V2VkIEp1biAgNCAxMjoxMTowMyBFRFQgMjAwMwo"

assert new_object.text.strip() == "V2VkIEp1biAgNCAxMjoxMTowMyBFRFQgMjAwMwo"

class TestMgmtData:

class TestMgmtData:
def setup_class(self):
self.mgmt_data = ds.MgmtData()

Expand Down Expand Up @@ -156,7 +154,7 @@ def testAccessors(self):
self.x509_data.x509_certificate = ds.X509Certificate(
text="x509 certificate")
self.x509_data.x509_crl = ds.X509CRL(text="x509 crl")

new_x509_data = ds.x509_data_from_string(self.x509_data.to_string())
print(new_x509_data.keyswv())
print(new_x509_data.__dict__.keys())
Expand Down Expand Up @@ -231,7 +229,7 @@ def testAccessors(self):
ds.TRANSFORM_ENVELOPED
assert new_transforms.transform[0].x_path[0].text.strip() == "xpath"
assert new_transforms.transform[1].x_path[0].text.strip() == "xpath"

def testUsingTestData(self):
"""Test for transform_from_string() using test data"""
new_transforms = ds.transforms_from_string(ds_data.TEST_TRANSFORMS)
Expand Down Expand Up @@ -261,7 +259,7 @@ def testAccessors(self):
assert new_retrieval_method.uri == "http://www.example.com/URI"
assert new_retrieval_method.type == "http://www.example.com/Type"
assert isinstance(new_retrieval_method.transforms, ds.Transforms)

def testUsingTestData(self):
"""Test for retrieval_method_from_string() using test data"""
new_retrieval_method = ds.retrieval_method_from_string(
Expand All @@ -285,7 +283,7 @@ def testAccessors(self):
assert isinstance(new_rsa_key_value.exponent, ds.Exponent)
assert new_rsa_key_value.modulus.text.strip() == "modulus"
assert new_rsa_key_value.exponent.text.strip() == "exponent"

def testUsingTestData(self):
"""Test for rsa_key_value_from_string() using test data"""
new_rsa_key_value = ds.rsa_key_value_from_string(
Expand Down Expand Up @@ -325,7 +323,7 @@ def testAccessors(self):
assert new_dsa_key_value.j.text.strip() == "j"
assert new_dsa_key_value.seed.text.strip() == "seed"
assert new_dsa_key_value.pgen_counter.text.strip() == "pgen counter"

def testUsingTestData(self):
"""Test for dsa_key_value_from_string() using test data"""
new_dsa_key_value = ds.dsa_key_value_from_string(
Expand Down Expand Up @@ -362,7 +360,7 @@ def testAccessors(self):
ds_data.TEST_RSA_KEY_VALUE)
new_key_value = ds.key_value_from_string(self.key_value.to_string())
assert isinstance(new_key_value.rsa_key_value, ds.RSAKeyValue)

def testUsingTestData(self):
"""Test for key_value_from_string() using test data"""
new_key_value = ds.key_value_from_string(ds_data.TEST_KEY_VALUE1)
Expand All @@ -384,7 +382,7 @@ def testAccessors(self):
self.key_name.text = "key name"
new_key_name = ds.key_name_from_string(self.key_name.to_string())
assert new_key_name.text.strip() == "key name"

def testUsingTestData(self):
"""Test for key_name_from_string() using test data"""
new_key_name = ds.key_name_from_string(ds_data.TEST_KEY_NAME)
Expand Down Expand Up @@ -423,7 +421,7 @@ def testAccessors(self):
assert isinstance(new_key_info.spki_data[0], ds.SPKIData)
assert isinstance(new_key_info.mgmt_data[0], ds.MgmtData)
assert new_key_info.id == "id"

def testUsingTestData(self):
"""Test for key_info_from_string() using test data"""
new_key_info = ds.key_info_from_string(ds_data.TEST_KEY_INFO)
Expand All @@ -436,7 +434,7 @@ def testUsingTestData(self):
assert isinstance(new_key_info.spki_data[0], ds.SPKIData)
assert isinstance(new_key_info.mgmt_data[0], ds.MgmtData)
assert new_key_info.id == "id"


class TestDigestValue:

Expand All @@ -448,7 +446,7 @@ def testAccessors(self):
self.digest_value.text = "digest value"
new_digest_value = ds.digest_value_from_string(self.digest_value.to_string())
assert new_digest_value.text.strip() == "digest value"

def testUsingTestData(self):
"""Test for digest_value_from_string() using test data"""
new_digest_value = ds.digest_value_from_string(ds_data.TEST_DIGEST_VALUE)
Expand All @@ -466,7 +464,7 @@ def testAccessors(self):
new_digest_method = ds.digest_method_from_string(
self.digest_method.to_string())
assert new_digest_method.algorithm == ds.DIGEST_SHA1

def testUsingTestData(self):
"""Test for digest_method_from_string() using test data"""
new_digest_method = ds.digest_method_from_string(
Expand Down Expand Up @@ -497,7 +495,7 @@ def testAccessors(self):
assert new_reference.id == "id"
assert new_reference.uri == "http://www.example.com/URI"
assert new_reference.type == "http://www.example.com/Type"

def testUsingTestData(self):
"""Test for reference_from_string() using test data"""
new_reference = ds.reference_from_string(ds_data.TEST_REFERENCE)
Expand All @@ -524,7 +522,7 @@ def testAccessors(self):
ds.HMACOutputLength)
assert new_signature_method.hmac_output_length.text.strip() == "8"
assert new_signature_method.algorithm == ds.SIG_RSA_SHA1

def testUsingTestData(self):
"""Test for signature_method_from_string() using test data"""
new_signature_method = ds.signature_method_from_string(
Expand All @@ -542,16 +540,17 @@ def setup_class(self):

def testAccessors(self):
"""Test for CanonicalizationMethod accessors"""
self.canonicalization_method.algorithm = ds.C14N_WITH_C
self.canonicalization_method.algorithm = ds.TRANSFORM_C14N_WITH_COMMENTS
new_canonicalization_method = ds.canonicalization_method_from_string(
self.canonicalization_method.to_string())
assert new_canonicalization_method.algorithm == ds.C14N_WITH_C
assert new_canonicalization_method.algorithm == ds.TRANSFORM_C14N_WITH_COMMENTS

def testUsingTestData(self):
"""Test for canonicalization_method_from_string() using test data"""
new_canonicalization_method = ds.canonicalization_method_from_string(
ds_data.TEST_CANONICALIZATION_METHOD)
assert new_canonicalization_method.algorithm == ds.C14N_WITH_C
ds_data.TEST_CANONICALIZATION_METHOD
)
assert new_canonicalization_method.algorithm == "http://www.w3.org/TR/2001/REC-xml-c14n-20010315#WithComments"


class TestSignedInfo:
Expand All @@ -574,7 +573,7 @@ def testAccessors(self):
ds.CanonicalizationMethod)
assert isinstance(new_si.signature_method, ds.SignatureMethod)
assert isinstance(new_si.reference[0], ds.Reference)

def testUsingTestData(self):
"""Test for signed_info_from_string() using test data"""
new_si = ds.signed_info_from_string(ds_data.TEST_SIGNED_INFO)
Expand All @@ -597,7 +596,7 @@ def testAccessors(self):
self.signature_value.to_string())
assert new_signature_value.id == "id"
assert new_signature_value.text.strip() == "signature value"

def testUsingTestData(self):
"""Test for signature_value_from_string() using test data"""
new_signature_value = ds.signature_value_from_string(
Expand Down Expand Up @@ -627,7 +626,7 @@ def testAccessors(self):
assert isinstance(new_signature.signature_value, ds.SignatureValue)
assert isinstance(new_signature.key_info, ds.KeyInfo)
assert isinstance(new_signature.object[0], ds.Object)

def testUsingTestData(self):
"""Test for signature_value_from_string() using test data"""
new_signature = ds.signature_from_string(ds_data.TEST_SIGNATURE)
Expand Down

0 comments on commit 1e59eaa

Please sign in to comment.