From 269a1027495c2b425e87946cf768e7f8a5784d9a Mon Sep 17 00:00:00 2001 From: JCgH4164838Gh792C124B5 <43964333+JCgH4164838Gh792C124B5@users.noreply.github.com> Date: Sun, 28 May 2023 19:43:47 -0400 Subject: [PATCH] Update: - Add a few additional tests to SecurityMemberAccessTest. - Rename some existing tests involving non-static methods to more accurately reflect that. - Add one minor optimization to SecurityMemberAccess. --- .../xwork2/ognl/SecurityMemberAccess.java | 3 +- .../xwork2/ognl/SecurityMemberAccessTest.java | 167 +++++++++++++++++- 2 files changed, 164 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java b/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java index 384d6cf244..c21b5b089e 100644 --- a/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java +++ b/core/src/main/java/com/opensymphony/xwork2/ognl/SecurityMemberAccess.java @@ -129,7 +129,8 @@ public boolean isAccessible(Map context, Object target, Member member, String pr return false; } - if (isClassExcluded(targetClass)) { + if (targetClass != memberClass && isClassExcluded(targetClass)) { + // Optimization: Already checked memberClass exclusion, so if-and-only-if targetClass == memberClass, this check is redundant. LOG.warn("Target class [{}] of target [{}] is excluded!", targetClass, target); return false; } diff --git a/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java b/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java index e0f4ed1839..acf4bbc805 100644 --- a/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java +++ b/core/src/test/java/com/opensymphony/xwork2/ognl/SecurityMemberAccessTest.java @@ -29,11 +29,13 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.regex.Pattern; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class SecurityMemberAccessTest { @@ -383,8 +385,9 @@ public void testAccessStaticField() throws Exception { } @Test - public void testBlockedStaticFieldWhenFlagIsFalse() throws Exception { + public void testBlockedStaticFieldWhenFlagIsTrue() throws Exception { // given + assignNewSma(true); sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); // when @@ -479,6 +482,104 @@ public void testBlockedStaticFieldWhenFlagIsFalse() throws Exception { assertFalse("Access to private final static field is allowed?", actual); } + @Test + public void testBlockedStaticFieldWhenFlagIsFalse() throws Exception { + // given + assignNewSma(false); + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); + + // when + Member method = StaticTester.class.getField("MAX_VALUE"); + boolean actual = sma.isAccessible(context, null, method, null); + + // then + assertFalse("Access to public static field is allowed when flag false?", actual); + + // public static final test + // given + assignNewSma(false); + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); + + // when + method = StaticTester.class.getField("MIN_VALUE"); + actual = sma.isAccessible(context, null, method, null); + + // then + assertFalse("Access to public final static field is allowed when flag is false?", actual); + + // package static test + // given + assignNewSma(false); + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); + + // when + method = StaticTester.getFieldByName("PACKAGE_STRING"); + actual = sma.isAccessible(context, null, method, null); + + // then + assertFalse("Access to package static field is allowed?", actual); + + // package final static test + // given + assignNewSma(false); + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); + + // when + method = StaticTester.getFieldByName("FINAL_PACKAGE_STRING"); + actual = sma.isAccessible(context, null, method, null); + + // then + assertFalse("Access to package final static field is allowed?", actual); + + // protected static test + // given + assignNewSma(false); + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); + + // when + method = StaticTester.getFieldByName("PROTECTED_STRING"); + actual = sma.isAccessible(context, null, method, null); + + // then + assertFalse("Access to protected static field is allowed?", actual); + + // protected final static test + // given + assignNewSma(false); + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); + + // when + method = StaticTester.getFieldByName("FINAL_PROTECTED_STRING"); + actual = sma.isAccessible(context, null, method, null); + + // then + assertFalse("Access to protected final static field is allowed?", actual); + + // private static test + // given + assignNewSma(false); + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); + + // when + method = StaticTester.getFieldByName("PRIVATE_STRING"); + actual = sma.isAccessible(context, null, method, null); + + // then + assertFalse("Access to private static field is allowed?", actual); + + // private final static test + // given + assignNewSma(false); + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); + + // when + method = StaticTester.getFieldByName("FINAL_PRIVATE_STRING"); + actual = sma.isAccessible(context, null, method, null); + + // then + assertFalse("Access to private final static field is allowed?", actual); + } + @Test public void testBlockedStaticFieldWhenClassIsExcluded() throws Exception { // given @@ -506,7 +607,7 @@ public void testBlockStaticMethodAccess() throws Exception { } @Test - public void testBlockStaticAccessIfClassIsExcluded() throws Exception { + public void testBlockAccessIfClassIsExcluded() throws Exception { // given sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); @@ -515,11 +616,25 @@ public void testBlockStaticAccessIfClassIsExcluded() throws Exception { boolean actual = sma.isAccessible(context, Class.class, method, null); // then - assertFalse("Access to static method of excluded class isn't blocked!", actual); + assertFalse("Access to method of excluded class isn't blocked!", actual); + } + + @Test + public void testBlockAccessIfClassIsExcluded_2() throws Exception { + // given + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(ClassLoader.class))); + + // when + Member method = ClassLoader.class.getMethod("loadClass", String.class); + ClassLoader classLoaderTarget = this.getClass().getClassLoader(); + boolean actual = sma.isAccessible(context, classLoaderTarget, method, null); + + // then + assertFalse("Invalid test! Access to method of excluded class isn't blocked!", actual); } @Test - public void testAllowStaticAccessIfClassIsNotExcluded() throws Exception { + public void testAllowAccessIfClassIsNotExcluded() throws Exception { // given sma.setExcludedClasses(new HashSet<>(Collections.singletonList(ClassLoader.class))); @@ -528,7 +643,26 @@ public void testAllowStaticAccessIfClassIsNotExcluded() throws Exception { boolean actual = sma.isAccessible(context, Class.class, method, null); // then - assertTrue("Invalid test! Access to static method of excluded class is blocked!", actual); + assertTrue("Invalid test! Access to method of non-excluded class is blocked!", actual); + } + + @Test + public void testIllegalArgumentExceptionExpectedForTargetMemberMismatch() throws Exception { + // given + sma.setExcludedClasses(new HashSet<>(Collections.singletonList(Class.class))); + + // when + Member method = ClassLoader.class.getMethod("loadClass", String.class); + String mismatchTarget = "misMatchTargetObject"; + try { + boolean actual = sma.isAccessible(context, mismatchTarget, method, null); + + // then + assertFalse("Invalid test! Access to method of excluded class isn't blocked!", actual); + fail("Mismatch between target and member did not cause IllegalArgumentException?"); + } catch (IllegalArgumentException iex) { + // Expected result is this exception + } } @Test @@ -686,10 +820,12 @@ public void setStringField(String stringField) { this.stringField = stringField; } + @Override public String fooLogic() { return "fooLogic"; } + @Override public String barLogic() { return "barLogic"; } @@ -699,6 +835,27 @@ public int hashCode() { return 1; } + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + final FooBar other = (FooBar) obj; + if (this.intField != other.intField) { + return false; + } + if (!Objects.equals(this.stringField, other.stringField)) { + return false; + } + return Objects.equals(this.doubleField, other.doubleField); + } + public int getIntField() { return intField; }