diff --git a/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/AbstractJwsJsonReaderProvider.java b/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/AbstractJwsJsonReaderProvider.java index 1457867fd49..e55c163fc86 100644 --- a/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/AbstractJwsJsonReaderProvider.java +++ b/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/AbstractJwsJsonReaderProvider.java @@ -73,6 +73,21 @@ protected void validate(JwsJsonConsumer c, JwsSignatureVerifier theSigVerifier) JAXRSUtils.getCurrentMessage().put(JwsJsonConsumer.class, c); } + protected JwsJsonSignatureEntry getValidatedSignatureEntry(JwsJsonConsumer c) { + @SuppressWarnings("unchecked") + List remaining = + (List)JAXRSUtils.getCurrentMessage().get("jws.json.remaining.entries"); + if (remaining != null) { + for (JwsJsonSignatureEntry sigEntry : c.getSignatureEntries()) { + if (!remaining.contains(sigEntry)) { + return sigEntry; + } + } + } + // If there are no recorded remaining entries then the first (or only) entry is valid. + return c.getSignatureEntries().get(0); + } + public Map getEntryProps() { return entryProps; } diff --git a/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonClientResponseFilter.java b/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonClientResponseFilter.java index 2a0888bbe9a..794291aa05b 100644 --- a/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonClientResponseFilter.java +++ b/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonClientResponseFilter.java @@ -55,8 +55,7 @@ public void filter(ClientRequestContext req, ClientResponseContext res) throws I res.setEntityStream(new ByteArrayInputStream(bytes)); res.getHeaders().putSingle("Content-Length", Integer.toString(bytes.length)); - // the list is guaranteed to be non-empty - JwsJsonSignatureEntry sigEntry = c.getSignatureEntries().get(0); + JwsJsonSignatureEntry sigEntry = getValidatedSignatureEntry(c); String ct = JoseUtils.checkContentType(sigEntry.getUnionHeader().getContentType(), getDefaultMediaType()); if (ct != null) { res.getHeaders().putSingle("Content-Type", ct); diff --git a/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonContainerRequestFilter.java b/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonContainerRequestFilter.java index d5b7fe8bb47..2dde1824b77 100644 --- a/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonContainerRequestFilter.java +++ b/rt/rs/security/jose-parent/jose-jaxrs/src/main/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonContainerRequestFilter.java @@ -62,8 +62,7 @@ public void filter(ContainerRequestContext context) throws IOException { context.setEntityStream(new ByteArrayInputStream(bytes)); context.getHeaders().putSingle("Content-Length", Integer.toString(bytes.length)); - // the list is guaranteed to be non-empty - JwsJsonSignatureEntry sigEntry = c.getSignatureEntries().get(0); + JwsJsonSignatureEntry sigEntry = getValidatedSignatureEntry(c); String ct = JoseUtils.checkContentType(sigEntry.getUnionHeader().getContentType(), getDefaultMediaType()); if (ct != null) { context.getHeaders().putSingle("Content-Type", ct); diff --git a/rt/rs/security/jose-parent/jose-jaxrs/src/test/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonContainerRequestFilterTest.java b/rt/rs/security/jose-parent/jose-jaxrs/src/test/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonContainerRequestFilterTest.java new file mode 100644 index 00000000000..f51e0fc397d --- /dev/null +++ b/rt/rs/security/jose-parent/jose-jaxrs/src/test/java/org/apache/cxf/rs/security/jose/jaxrs/JwsJsonContainerRequestFilterTest.java @@ -0,0 +1,158 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.cxf.rs.security.jose.jaxrs; + +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; +import java.util.Collections; + +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.Response; +import org.apache.cxf.bus.managers.PhaseManagerImpl; +import org.apache.cxf.helpers.IOUtils; +import org.apache.cxf.interceptor.Fault; +import org.apache.cxf.jaxrs.impl.ContainerRequestContextImpl; +import org.apache.cxf.jaxrs.impl.MetadataMap; +import org.apache.cxf.message.ExchangeImpl; +import org.apache.cxf.message.Message; +import org.apache.cxf.message.MessageImpl; +import org.apache.cxf.phase.Phase; +import org.apache.cxf.phase.PhaseInterceptor; +import org.apache.cxf.phase.PhaseInterceptorChain; +import org.apache.cxf.rs.security.jose.jwa.SignatureAlgorithm; +import org.apache.cxf.rs.security.jose.jws.HmacJwsSignatureProvider; +import org.apache.cxf.rs.security.jose.jws.HmacJwsSignatureVerifier; +import org.apache.cxf.rs.security.jose.jws.JwsHeaders; +import org.apache.cxf.rs.security.jose.jws.JwsJsonConsumer; +import org.apache.cxf.rs.security.jose.jws.JwsJsonProducer; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +public class JwsJsonContainerRequestFilterTest { + private static final String BAD_KEY = "AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75" + + "aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow"; + private static final String GOOD_KEY = "09Y_RK7l5rAY9QY7EblYQNuYbu9cy1j7ovCbkeIyAKN8LIeRL-3H8g" + + "c8kZSYzAQ1uTRC_egZ_8cgZSZa9T5nmQ"; + + @Test + public void testUsesValidatedSignatureMetadata() throws Exception { + String payload = "{\"role\":\"user\"}"; + String signedDocument = createSignedDocument(payload); + + JwsJsonConsumer consumer = new JwsJsonConsumer(signedDocument); + assertEquals("application/xml", + consumer.getSignatureEntries().get(0).getUnionHeader().getContentType()); + assertEquals("application/json", + consumer.getSignatureEntries().get(1).getUnionHeader().getContentType()); + + Message message = new MessageImpl(); + message.setExchange(new ExchangeImpl()); + message.put(Message.HTTP_REQUEST_METHOD, "POST"); + message.setContent(java.io.InputStream.class, + new ByteArrayInputStream(signedDocument.getBytes(StandardCharsets.UTF_8))); + + MetadataMap headers = new MetadataMap<>(); + headers.putSingle(HttpHeaders.CONTENT_TYPE, "application/xml"); + headers.putSingle(HttpHeaders.CONTENT_LENGTH, + Integer.toString(signedDocument.getBytes(StandardCharsets.UTF_8).length)); + message.put(Message.PROTOCOL_HEADERS, headers); + + JwsJsonContainerRequestFilter filter = new JwsJsonContainerRequestFilter(); + filter.setSignatureVerifier(new HmacJwsSignatureVerifier(GOOD_KEY, SignatureAlgorithm.HS256)); + filter.setValidateHttpHeaders(false); + + runInPhase(message, () -> { + try { + ContainerRequestContext context = new ContainerRequestContextImpl(message, true, false); + filter.filter(context); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + }); + + assertNull(message.getExchange().get(Response.class)); + assertEquals("application/json", + headers.getFirst(HttpHeaders.CONTENT_TYPE)); + assertEquals(payload, + IOUtils.readStringFromStream(message.getContent(java.io.InputStream.class))); + } + + private String createSignedDocument(String payload) { + JwsJsonProducer producer = new JwsJsonProducer(payload); + producer.signWith(new HmacJwsSignatureProvider(BAD_KEY, SignatureAlgorithm.HS256), + createHeaders("application/xml")); + producer.signWith(new HmacJwsSignatureProvider(GOOD_KEY, SignatureAlgorithm.HS256), + createHeaders("application/json")); + return producer.getJwsJsonSignedDocument(); + } + + private JwsHeaders createHeaders(String contentType) { + JwsHeaders headers = new JwsHeaders(); + headers.setSignatureAlgorithm(SignatureAlgorithm.HS256); + headers.setContentType(contentType); + headers.setHeader("http." + HttpHeaders.CONTENT_TYPE, contentType); + return headers; + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void runInPhase(Message message, Runnable action) { + PhaseInterceptorChain chain = new PhaseInterceptorChain(new PhaseManagerImpl().getInPhases()); + chain.add(new PhaseInterceptor() { + @Override + public void handleMessage(Message message) throws Fault { + action.run(); + } + + @Override + public void handleFault(Message message) { + } + + @Override + public java.util.Set getAfter() { + return Collections.emptySet(); + } + + @Override + public java.util.Set getBefore() { + return Collections.emptySet(); + } + + @Override + public String getId() { + return "test-jws-json-request-filter"; + } + + @Override + public String getPhase() { + return Phase.INVOKE; + } + + @Override + public java.util.Collection getAdditionalInterceptors() { + return Collections.emptyList(); + } + }); + message.setInterceptorChain(chain); + chain.doIntercept(message); + } +} \ No newline at end of file