Skip to content
Permalink
Browse files Browse the repository at this point in the history
Refactor utility methods into SAMLTools, add tests.
  • Loading branch information
robotdan committed Feb 7, 2021
1 parent 15fa48d commit c66fb68
Show file tree
Hide file tree
Showing 8 changed files with 580 additions and 210 deletions.
5 changes: 2 additions & 3 deletions fusionauth-samlv2.iml
Expand Up @@ -6,7 +6,7 @@
<exclude-output />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/src/main/resources" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/src/main/resources" type="java-resource" />
<sourceFolder url="file://$MODULE_DIR$/src/test/java" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/src/test/resources" isTestSource="true" />
<excludeFolder url="file://$MODULE_DIR$/.gradle" />
Expand Down Expand Up @@ -136,5 +136,4 @@
</library>
</orderEntry>
</component>
</module>

</module>
209 changes: 17 additions & 192 deletions src/main/java/io/fusionauth/samlv2/service/DefaultSAMLv2Service.java
Expand Up @@ -15,11 +15,7 @@
*/
package io.fusionauth.samlv2.service;

import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBElement;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.crypto.KeySelector;
import javax.xml.crypto.MarshalException;
import javax.xml.crypto.dsig.CanonicalizationMethod;
Expand All @@ -37,19 +33,7 @@
import javax.xml.crypto.dsig.keyinfo.X509Data;
import javax.xml.crypto.dsig.spec.C14NMethodParameterSpec;
import javax.xml.crypto.dsig.spec.TransformParameterSpec;
import javax.xml.datatype.XMLGregorianCalendar;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.StringWriter;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
Expand All @@ -59,8 +43,6 @@
import java.security.PrivateKey;
import java.security.Signature;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
Expand All @@ -75,9 +57,6 @@
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;

import com.sun.org.apache.xerces.internal.jaxp.datatype.XMLGregorianCalendarImpl;
import io.fusionauth.samlv2.domain.Algorithm;
Expand All @@ -89,7 +68,6 @@
import io.fusionauth.samlv2.domain.MetaData;
import io.fusionauth.samlv2.domain.MetaData.IDPMetaData;
import io.fusionauth.samlv2.domain.MetaData.SPMetaData;
import io.fusionauth.samlv2.domain.NameID;
import io.fusionauth.samlv2.domain.NameIDFormat;
import io.fusionauth.samlv2.domain.ResponseStatus;
import io.fusionauth.samlv2.domain.SAMLException;
Expand Down Expand Up @@ -128,6 +106,7 @@
import io.fusionauth.samlv2.domain.jaxb.oasis.protocol.StatusType;
import io.fusionauth.samlv2.domain.jaxb.w3c.xmldsig.KeyInfoType;
import io.fusionauth.samlv2.domain.jaxb.w3c.xmldsig.X509DataType;
import io.fusionauth.samlv2.util.SAMLTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Attr;
Expand All @@ -136,7 +115,16 @@
import org.w3c.dom.NamedNodeMap;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;
import static io.fusionauth.samlv2.util.SAMLTools.convertToZonedDateTime;
import static io.fusionauth.samlv2.util.SAMLTools.decodeAndInflate;
import static io.fusionauth.samlv2.util.SAMLTools.marshallToBytes;
import static io.fusionauth.samlv2.util.SAMLTools.marshallToDocument;
import static io.fusionauth.samlv2.util.SAMLTools.marshallToString;
import static io.fusionauth.samlv2.util.SAMLTools.newDocumentFromBytes;
import static io.fusionauth.samlv2.util.SAMLTools.parseNameId;
import static io.fusionauth.samlv2.util.SAMLTools.toXMLGregorianCalendar;
import static io.fusionauth.samlv2.util.SAMLTools.toZonedDateTime;
import static io.fusionauth.samlv2.util.SAMLTools.unmarshallFromDocument;

/**
* Default implementation of the SAML service.
Expand Down Expand Up @@ -411,7 +399,7 @@ public String buildRedirectAuthnRequest(AuthenticationRequest request, String re

@Override
public MetaData parseMetaData(String metaDataXML) throws SAMLException {
Document document = parseFromBytes(metaDataXML.getBytes(StandardCharsets.UTF_8));
Document document = newDocumentFromBytes(metaDataXML.getBytes(StandardCharsets.UTF_8));
EntityDescriptorType root = unmarshallFromDocument(document, EntityDescriptorType.class);
MetaData metaData = new MetaData();
metaData.id = root.getID();
Expand Down Expand Up @@ -450,7 +438,7 @@ public MetaData parseMetaData(String metaDataXML) throws SAMLException {
metaData.idp.certificates = idp.getKeyDescriptor()
.stream()
.filter(kd -> kd.getUse() == KeyTypes.SIGNING)
.map(this::toCertificate)
.map(SAMLTools::toCertificate)
.filter(Objects::nonNull)
.collect(Collectors.toList());
} catch (IllegalArgumentException e) {
Expand Down Expand Up @@ -532,7 +520,7 @@ public AuthenticationResponse parseResponse(String encodedResponse, boolean veri
byte[] decodedResponse = Base64.getMimeDecoder().decode(encodedResponse);
response.rawResponse = new String(decodedResponse, StandardCharsets.UTF_8);

Document document = parseFromBytes(decodedResponse);
Document document = newDocumentFromBytes(decodedResponse);
if (verifySignature) {
verifySignature(document, keySelector);
}
Expand Down Expand Up @@ -613,7 +601,7 @@ public AuthenticationResponse parseResponse(String encodedResponse, boolean veri
AttributeType attributeType = (AttributeType) attributeObject;
String name = attributeType.getName();
List<Object> attributeValues = attributeType.getAttributeValue();
List<String> values = attributeValues.stream().map(this::attributeToString).collect(Collectors.toList());
List<String> values = attributeValues.stream().map(SAMLTools::attributeToString).collect(Collectors.toList());
response.assertion.attributes.computeIfAbsent(name, k -> new ArrayList<>()).addAll(values);
} else {
throw new SAMLException("This library currently doesn't support encrypted attributes");
Expand Down Expand Up @@ -646,24 +634,6 @@ private void addKeyDescriptors(SSODescriptorType descriptor, List<Certificate> c
});
}

private String attributeToString(Object attribute) {
if (attribute == null) {
return null;
}

if (attribute instanceof Number) {
return attribute.toString();
} else if (attribute instanceof String) {
return (String) attribute;
} else if (attribute instanceof Element) {
return ((Element) attribute).getTextContent();
} else {
logger.warn("This library currently doesn't handle attributes of type [" + attribute.getClass() + "]");
}

return null;
}

private String buildPostAuthnRequest(AuthnRequestType authnRequest, boolean sign, PrivateKey privateKey,
X509Certificate certificate,
Algorithm algorithm, String xmlSignatureC14nMethod) throws SAMLException {
Expand All @@ -687,7 +657,7 @@ private String buildRedirectAuthnRequest(AuthnRequestType authnRequest, String r
PrivateKey key, Algorithm algorithm) throws SAMLException {
try {
byte[] xml = marshallToBytes(PROTOCOL_OBJECT_FACTORY.createAuthnRequest(authnRequest), AuthnRequestType.class);
String encodedResult = deflateAndEncode(xml);
String encodedResult = SAMLTools.deflateAndEncode(xml);
String parameters = "SAMLRequest=" + URLEncoder.encode(encodedResult, "UTF-8");
if (relayState != null) {
parameters += "&RelayState=" + URLEncoder.encode(relayState, "UTF-8");
Expand All @@ -711,43 +681,6 @@ private String buildRedirectAuthnRequest(AuthnRequestType authnRequest, String r
}
}

private ZonedDateTime convertToZonedDateTime(XMLGregorianCalendar cal) {
return cal != null ? cal.toGregorianCalendar().toZonedDateTime() : null;
}

private byte[] decodeAndInflate(String encodedRequest) throws SAMLException {
byte[] bytes = Base64.getMimeDecoder().decode(encodedRequest);
Inflater inflater = new Inflater(true);
inflater.setInput(bytes);
inflater.finished();

try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] result = new byte[bytes.length];
while (!inflater.finished()) {
int length = inflater.inflate(result);
if (length > 0) {
baos.write(result, 0, length);
}
}

return baos.toByteArray();
} catch (DataFormatException e) {
throw new SAMLException("Invalid AuthnRequest. Inflating the bytes failed.", e);
}
}

private String deflateAndEncode(byte[] result) {
Deflater deflater = new Deflater(Deflater.DEFLATED, true);
deflater.setInput(result);
deflater.finish();
byte[] deflatedResult = new byte[result.length];
int length = deflater.deflate(deflatedResult);
deflater.end();
byte[] src = Arrays.copyOf(deflatedResult, length);
return Base64.getEncoder().encodeToString(src);
}

private void fixIDs(Element element) {
NamedNodeMap attributes = element.getAttributes();
for (int i = 0; i < attributes.getLength(); i++) {
Expand All @@ -766,42 +699,6 @@ private void fixIDs(Element element) {
}
}

private <T> byte[] marshallToBytes(JAXBElement<T> object, Class<T> type) throws SAMLException {
try {
JAXBContext context = JAXBContext.newInstance(type);
Marshaller marshaller = context.createMarshaller();
ByteArrayOutputStream baos = new ByteArrayOutputStream();
marshaller.marshal(object, baos);
return baos.toByteArray();
} catch (JAXBException e) {
throw new SAMLException("Unable to marshallRequest JAXB SAML object to bytes.", e);
}
}

@SuppressWarnings("SameParameterValue")
private <T> Document marshallToDocument(JAXBElement<T> object, Class<T> type) throws SAMLException {
try {
JAXBContext context = JAXBContext.newInstance(type);
Marshaller marshaller = context.createMarshaller();
DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
dbf.setNamespaceAware(true);
DocumentBuilder db = dbf.newDocumentBuilder();
Document document = db.newDocument();
marshaller.marshal(object, document);
return document;
} catch (JAXBException | ParserConfigurationException e) {
throw new SAMLException("Unable to marshallRequest JAXB SAML object to DOM.", e);
}
}

private String marshallToString(Document document) throws TransformerException {
StringWriter sw = new StringWriter();
TransformerFactory tf = TransformerFactory.newInstance();
Transformer transformer = tf.newTransformer();
transformer.transform(new DOMSource(document), new StreamResult(sw));
return sw.toString();
}

private SubjectConfirmation parseConfirmation(SubjectConfirmationType subjectConfirmationType) {
SubjectConfirmation subjectConfirmation = new SubjectConfirmation();
SubjectConfirmationDataType data = subjectConfirmationType.getSubjectConfirmationData();
Expand All @@ -818,32 +715,14 @@ private SubjectConfirmation parseConfirmation(SubjectConfirmationType subjectCon
return subjectConfirmation;
}

private Document parseFromBytes(byte[] bytes) throws SAMLException {
DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance();
documentBuilderFactory.setNamespaceAware(true);
try {
DocumentBuilder builder = documentBuilderFactory.newDocumentBuilder();
return builder.parse(new ByteArrayInputStream(bytes));
} catch (ParserConfigurationException | SAXException | IOException e) {
throw new SAMLException("Unable to parse SAML v2.0 authentication response", e);
}
}

private NameID parseNameId(NameIDType element) {
NameID nameId = new NameID();
nameId.format = NameIDFormat.fromSAMLFormat(element.getFormat());
nameId.id = element.getValue();
return nameId;
}

private AuthnRequestParseResult parseRequest(byte[] xmlBytes) throws SAMLException {
String xml = new String(xmlBytes, StandardCharsets.UTF_8);
if (logger.isDebugEnabled()) {
logger.debug("SAMLRequest XML is\n{}", xml);
}

AuthnRequestParseResult result = new AuthnRequestParseResult();
result.document = parseFromBytes(xmlBytes);
result.document = newDocumentFromBytes(xmlBytes);
result.authnRequest = unmarshallFromDocument(result.document, AuthnRequestType.class);
result.request = new AuthenticationRequest();
result.request.xml = xml;
Expand Down Expand Up @@ -907,60 +786,6 @@ private AuthnRequestType toAuthnRequest(AuthenticationRequest request, String ve
return authnRequest;
}

private Certificate toCertificate(KeyDescriptorType keyDescriptorType) {
try {
List<Object> keyData = keyDescriptorType.getKeyInfo().getContent();
for (Object keyDatum : keyData) {
if (keyDatum instanceof JAXBElement<?>) {
JAXBElement<?> element = (JAXBElement<?>) keyDatum;
if (element.getDeclaredType() == X509DataType.class) {
X509DataType cert = (X509DataType) element.getValue();
List<Object> certData = cert.getX509IssuerSerialOrX509SKIOrX509SubjectName();
for (Object certDatum : certData) {
element = (JAXBElement<?>) certDatum;
if (element.getName().getLocalPart().equals("X509Certificate")) {
byte[] certBytes = (byte[]) element.getValue();
CertificateFactory cf = CertificateFactory.getInstance("X.509");
return cf.generateCertificate(new ByteArrayInputStream(certBytes));
}
}
}
}
}

return null;
} catch (CertificateException e) {
throw new IllegalArgumentException(e);
}
}

private XMLGregorianCalendar toXMLGregorianCalendar(ZonedDateTime instant) {
if (instant == null) {
return null;
}

return new XMLGregorianCalendarImpl(GregorianCalendar.from(instant));
}

private ZonedDateTime toZonedDateTime(XMLGregorianCalendar instant) {
if (instant == null) {
return null;
}

return instant.toGregorianCalendar().toZonedDateTime();
}

private <T> T unmarshallFromDocument(Document document, Class<T> type) throws SAMLException {
try {
JAXBContext context = JAXBContext.newInstance(type);
Unmarshaller unmarshaller = context.createUnmarshaller();
JAXBElement<T> element = unmarshaller.unmarshal(document, type);
return element.getValue();
} catch (JAXBException e) {
throw new SAMLException("Unable to unmarshall SAML response", e);
}
}

private void verifySignature(Document document, KeySelector keySelector) throws SAMLException {
// Fix the IDs in the entire document per the suggestions at http://stackoverflow.com/questions/17331187/xml-dig-sig-error-after-upgrade-to-java7u25
fixIDs(document.getDocumentElement());
Expand Down

0 comments on commit c66fb68

Please sign in to comment.