Skip to content

Commit

Permalink
The InvalidRequestFilter is more flexible
Browse files Browse the repository at this point in the history
Allowing encoded periods and forward slashes can now be independently enabled

----
Add tests for SavedRequest redirects
  • Loading branch information
lprimak committed Oct 29, 2023
1 parent 44a2d87 commit 8400d08
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;

@SuppressWarnings("checkstyle:LineLength")
/**
Expand Down Expand Up @@ -65,6 +66,12 @@ public class InvalidRequestFilter extends AccessControlFilter {

private boolean blockTraversal = true;

private boolean blockEncodedPeriod = true;

private boolean blockEncodedForwardSlash = true;

private boolean blockRewriteTraversal = true;

@Override
protected boolean isAccessAllowed(ServletRequest req, ServletResponse response, Object mappedValue) throws Exception {
HttpServletRequest request = WebUtils.toHttp(req);
Expand All @@ -77,12 +84,15 @@ && isValid(request.getServletPath())
&& isValid(request.getPathInfo());
}

@SuppressWarnings("checkstyle:BooleanExpressionComplexity")
private boolean isValid(String uri) {
return !StringUtils.hasText(uri)
|| (!containsSemicolon(uri)
&& !containsBackslash(uri)
&& !containsNonAsciiCharacters(uri))
&& !containsTraversal(uri);
|| (!containsSemicolon(uri)
&& !containsBackslash(uri)
&& !containsNonAsciiCharacters(uri)
&& !containsTraversal(uri)
&& !containsEncodedPeriods(uri)
&& !containsEncodedForwardSlash(uri));
}

@Override
Expand Down Expand Up @@ -125,9 +135,22 @@ private static boolean containsOnlyPrintableAsciiCharacters(String uri) {

private boolean containsTraversal(String uri) {
if (isBlockTraversal()) {
return !(isNormalized(uri)
&& PERIOD.stream().noneMatch(uri::contains)
&& FORWARDSLASH.stream().noneMatch(uri::contains));
return !isNormalized(uri)
|| (isBlockRewriteTraversal() && Stream.of("/..;", "/.;").anyMatch(uri::contains));
}
return false;
}

private boolean containsEncodedPeriods(String uri) {
if (isBlockEncodedPeriod()) {
return PERIOD.stream().anyMatch(uri::contains);
}
return false;
}

private boolean containsEncodedForwardSlash(String uri) {
if (isBlockEncodedForwardSlash()) {
return FORWARDSLASH.stream().anyMatch(uri::contains);
}
return false;
}
Expand Down Expand Up @@ -189,4 +212,28 @@ public boolean isBlockTraversal() {
public void setBlockTraversal(boolean blockTraversal) {
this.blockTraversal = blockTraversal;
}

public boolean isBlockEncodedPeriod() {
return blockEncodedPeriod;
}

public void setBlockEncodedPeriod(boolean blockEncodedPeriod) {
this.blockEncodedPeriod = blockEncodedPeriod;
}

public boolean isBlockEncodedForwardSlash() {
return blockEncodedForwardSlash;
}

public void setBlockEncodedForwardSlash(boolean blockEncodedForwardSlash) {
this.blockEncodedForwardSlash = blockEncodedForwardSlash;
}

public boolean isBlockRewriteTraversal() {
return blockRewriteTraversal;
}

public void setBlockRewriteTraversal(boolean blockRewriteTraversal) {
this.blockRewriteTraversal = blockRewriteTraversal;
}
}
6 changes: 6 additions & 0 deletions web/src/main/java/org/apache/shiro/web/util/SavedRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ public String getRequestURI() {

public String getRequestUrl() {
StringBuilder requestUrl = new StringBuilder(getRequestURI());

// remove duplicate leading slashes
while (requestUrl.length() > 1 && requestUrl.charAt(1) == '/') {
requestUrl.deleteCharAt(0);
}

if (getQueryString() != null) {
requestUrl.append("?").append(getQueryString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class InvalidRequestFilterTest {
assertThat "filter.blockNonAscii expected to be true", filter.isBlockNonAscii()
assertThat "filter.blockSemicolon expected to be true", filter.isBlockSemicolon()
assertThat "filter.blockTraversal expected to be true", filter.isBlockTraversal()
assertThat "filter.blockRewriteTraversal expected to be true", filter.isBlockRewriteTraversal()
assertThat "filter.blockEncodedPeriod expected to be true", filter.isBlockEncodedPeriod()
assertThat "filter.blockEncodedForwardSlash expected to be true", filter.isBlockEncodedForwardSlash()
}

@Test
Expand All @@ -58,7 +61,6 @@ class InvalidRequestFilterTest {
}
}


@Test
void testFilterBlocks() {
InvalidRequestFilter filter = new InvalidRequestFilter()
Expand All @@ -72,6 +74,7 @@ class InvalidRequestFilterTest {

assertPathBlocked(filter, "/something", "/;something")
assertPathBlocked(filter, "/something", "/something", "/;")
assertPathBlocked(filter, "/something", "/something", "/.;")
}

@Test
Expand All @@ -80,23 +83,81 @@ class InvalidRequestFilterTest {
assertPathBlocked(filter, "/something/../")
assertPathBlocked(filter, "/something/../bar")
assertPathBlocked(filter, "/something/../bar/")
assertPathBlocked(filter, "/something/%2e%2E/bar/")
assertPathBlocked(filter, "/something/..")
assertPathBlocked(filter, "/..")
assertPathBlocked(filter, "..")
assertPathBlocked(filter, "../")
assertPathBlocked(filter, "%2E./")
assertPathBlocked(filter, "%2F./")
assertPathBlocked(filter, "/something/./")
assertPathBlocked(filter, "/something/./bar")
assertPathBlocked(filter, "/something/\u002e/bar")
assertPathBlocked(filter, "/something/./bar/")
assertPathBlocked(filter, "/something/%2e/bar/")
assertPathBlocked(filter, "/something/%2f/bar/")
assertPathBlocked(filter, "/something/.")
assertPathBlocked(filter, "/.")
assertPathBlocked(filter, "/something/../something/.")
assertPathBlocked(filter, "/something/../something/.")
assertPathBlocked(filter, "/something/.;")
assertPathBlocked(filter, "/something/%2e%3b")

assertPathAllowed(filter, "/something/.bar")
assertPathAllowed(filter, "/.something")
assertPathAllowed(filter, ".something")
}

@Test
void testBlocksEncodedPeriod() {
InvalidRequestFilter filter = new InvalidRequestFilter()
assertPathBlocked(filter, "/%2esomething")
assertPathBlocked(filter, "%2esomething")
assertPathBlocked(filter, "%2E./")
assertPathBlocked(filter, "%2F./")
assertPathBlocked(filter, "/something/%2e;")
assertPathBlocked(filter, "/something/%2e%3b")
assertPathBlocked(filter, "/something/%2e%2E/bar/")
assertPathBlocked(filter, "/something/%2e/bar/")
}

@Test
void testAllowsEncodedPeriod() {
InvalidRequestFilter filter = new InvalidRequestFilter()
filter.setBlockEncodedPeriod(false)
assertPathAllowed(filter, "/%2esomething")
assertPathAllowed(filter, "%2esomething")
assertPathAllowed(filter, "%2E./")
assertPathAllowed(filter, "/something/%2e%2E/bar/")
assertPathAllowed(filter, "/something/%2e/bar/")
}

@Test
void testBlocksEncodedForwardSlash() {
InvalidRequestFilter filter = new InvalidRequestFilter()
assertPathBlocked(filter, "%2F./")
assertPathBlocked(filter, "/something/%2f/bar/")
}

@Test
void testAllowsEncodedForwardSlash() {
InvalidRequestFilter filter = new InvalidRequestFilter()
filter.setBlockEncodedForwardSlash(false)
assertPathAllowed(filter, "%2F./")
assertPathAllowed(filter, "/something/%2f/bar/")
}

@Test
void testBlocksRewriteTraversal() {
InvalidRequestFilter filter = new InvalidRequestFilter()
filter.setBlockSemicolon(false)
assertPathBlocked(filter, "/something/..;jsessionid=foobar")
assertPathBlocked(filter, "/something/.;jsessionid=foobar")
}

@Test
void testAllowRewriteTraversal() {
InvalidRequestFilter filter = new InvalidRequestFilter()
filter.setBlockSemicolon(false)
filter.setBlockRewriteTraversal(false)
assertPathAllowed(filter, "/something/..;jsessionid=foobar")
assertPathAllowed(filter, "/something/.;jsessionid=foobar")
}

@Test
Expand Down Expand Up @@ -159,15 +220,11 @@ class InvalidRequestFilterTest {
assertPathAllowed(filter, "/..")
assertPathAllowed(filter, "..")
assertPathAllowed(filter, "../")
assertPathAllowed(filter, "%2E./")
assertPathAllowed(filter, "%2F./")
assertPathAllowed(filter, "/something/./")
assertPathAllowed(filter, "/something/./bar")
assertPathAllowed(filter, "/something/\u002e/bar")
assertPathAllowed(filter, "/something\u002fbar")
assertPathAllowed(filter, "/something/./bar/")
assertPathAllowed(filter, "/something/%2e/bar/")
assertPathAllowed(filter, "/something/%2f/bar/")
assertPathAllowed(filter, "/something/.")
assertPathAllowed(filter, "/.")
assertPathAllowed(filter, "/something/../something/.")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.shiro.web.util

import org.junit.jupiter.api.Test

import javax.servlet.http.HttpServletRequest
import static org.hamcrest.MatcherAssert.assertThat
import static org.hamcrest.Matchers.equalTo
import static org.easymock.EasyMock.niceMock
import static org.easymock.EasyMock.expect
import static org.easymock.EasyMock.replay
import static org.easymock.EasyMock.verify

class SavedRequestTest {

@Test
void testGetRequestUrl() {
doTestGetRequestUrl("/foo//bar", "one=two&three=four", "/foo//bar?one=two&three=four")
doTestGetRequestUrl("///foo//bar", "one=two&three=four", "/foo//bar?one=two&three=four")
doTestGetRequestUrl("///foo//bar", "/foo//bar")
doTestGetRequestUrl("/foo", "/foo")
doTestGetRequestUrl("/", "one=two&three=four", "/?one=two&three=four")
doTestGetRequestUrl("/", "/")
doTestGetRequestUrl("//////", "/")
doTestGetRequestUrl("", "")
}

private static void doTestGetRequestUrl(String requestURI, String expected) {
doTestGetRequestUrl(requestURI, null, expected)
}

private static void doTestGetRequestUrl(String requestURI, String query, String expected) {
HttpServletRequest request = niceMock(HttpServletRequest)
expect(request.getRequestURI()).andReturn(requestURI)
expect(request.getQueryString()).andReturn(query)
replay request
assertThat new SavedRequest(request).getRequestUrl(), equalTo(expected)
verify request
}
}

0 comments on commit 8400d08

Please sign in to comment.