From 15e85b01d51dbbd1981d1f311cd7eff4add1c67e Mon Sep 17 00:00:00 2001 From: Albumen Kevin Date: Mon, 2 Aug 2021 13:50:33 +0800 Subject: [PATCH] Add super class check --- .../com/caucho/hessian/io/ClassFactory.java | 61 +++++++++++++++---- .../com/caucho/hessian/io/DenyListTest.java | 11 ++++ .../com/caucho/hessian/io/TestClass.java | 4 ++ .../com/caucho/hessian/io/TestInterface.java | 24 ++++++++ 4 files changed, 89 insertions(+), 11 deletions(-) create mode 100644 src/test/java/com/alibaba/com/caucho/hessian/io/TestInterface.java diff --git a/src/main/java/com/alibaba/com/caucho/hessian/io/ClassFactory.java b/src/main/java/com/alibaba/com/caucho/hessian/io/ClassFactory.java index d5976ba23..d64b6fe5e 100644 --- a/src/main/java/com/alibaba/com/caucho/hessian/io/ClassFactory.java +++ b/src/main/java/com/alibaba/com/caucho/hessian/io/ClassFactory.java @@ -54,6 +54,7 @@ import java.io.InputStreamReader; import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -71,6 +72,7 @@ protected static final Logger log = Logger.getLogger(ClassFactory.class.getName()); private static final ArrayList _staticAllowList; + private static final Map _allowSubClassSet = new ConcurrentHashMap<>(); private static final Map _allowClassSet = new ConcurrentHashMap<>(); private ClassLoader _loader; @@ -88,10 +90,43 @@ public Class load(String className) throws ClassNotFoundException { if (isAllow(className)) { - return Class.forName(className, false, _loader); + Class aClass = Class.forName(className, false, _loader); + + if (_allowClassSet.containsKey(className)) { + return aClass; + } + + if (aClass.getInterfaces().length > 0) { + for (Class anInterface : aClass.getInterfaces()) { + if(!isAllow(anInterface.getName())) { + log.log(Level.SEVERE, className + "'s interfaces: " + anInterface.getName() + " in blacklist or not in whitelist, deserialization with type 'HashMap' instead."); + return HashMap.class; + } + } + } + + List> allSuperClasses = new LinkedList<>(); + + Class superClass = aClass.getSuperclass(); + while (superClass != null) { + // add current super class + allSuperClasses.add(superClass); + superClass = superClass.getSuperclass(); + } + + for (Class aSuperClass : allSuperClasses) { + if(!isAllow(aSuperClass.getName())) { + log.log(Level.SEVERE, className + "'s superClass: " + aSuperClass.getName() + " in blacklist or not in whitelist, deserialization with type 'HashMap' instead."); + return HashMap.class; + } + + } + + _allowClassSet.put(className, className); + return aClass; } else { - log.log(Level.SEVERE, className + " in blacklist or not in whitelist, deserialization with type 'HashMap' instead."); + log.log(Level.SEVERE, className + " in blacklist or not in whitelist, deserialization with type 'HashMap' instead."); return HashMap.class; } } @@ -104,19 +139,16 @@ private boolean isAllow(String className) return true; } - if (_allowClassSet.containsKey(className)) { + if (_allowSubClassSet.containsKey(className)) { return true; } - int size = allowList.size(); - for (int i = 0; i < size; i++) { - Allow allow = allowList.get(i); - + for (Allow allow : allowList) { Boolean isAllow = allow.allow(className); if (isAllow != null) { if (isAllow) { - _allowClassSet.put(className, className); + _allowSubClassSet.put(className, className); } return isAllow; } @@ -126,13 +158,14 @@ private boolean isAllow(String className) return false; } - _allowClassSet.put(className, className); + _allowSubClassSet.put(className, className); return true; } public void setWhitelist(boolean isWhitelist) { _allowClassSet.clear(); + _allowSubClassSet.clear(); _isWhitelist = isWhitelist; initAllow(); @@ -141,6 +174,7 @@ public void setWhitelist(boolean isWhitelist) public void allow(String pattern) { _allowClassSet.clear(); + _allowSubClassSet.clear(); initAllow(); synchronized (this) { @@ -151,6 +185,7 @@ public void allow(String pattern) public void deny(String pattern) { _allowClassSet.clear(); + _allowSubClassSet.clear(); initAllow(); synchronized (this) { @@ -158,7 +193,7 @@ public void deny(String pattern) } } - private String toPattern(String pattern) + private static String toPattern(String pattern) { pattern = pattern.replace(".", "\\."); pattern = pattern.replace("*", ".*"); @@ -233,7 +268,11 @@ Boolean allow(String className) if (denyClass.startsWith("#")) { continue; } - _staticAllowList.add(new AllowPrefix(denyClass, false)); + if (denyClass.endsWith(".")) { + _staticAllowList.add(new AllowPrefix(denyClass, false)); + } else { + _staticAllowList.add(new Allow(toPattern(denyClass), false)); + } } } catch (IOException ignore) { diff --git a/src/test/java/com/alibaba/com/caucho/hessian/io/DenyListTest.java b/src/test/java/com/alibaba/com/caucho/hessian/io/DenyListTest.java index d05a0c755..3c8375dff 100644 --- a/src/test/java/com/alibaba/com/caucho/hessian/io/DenyListTest.java +++ b/src/test/java/com/alibaba/com/caucho/hessian/io/DenyListTest.java @@ -18,6 +18,7 @@ import org.junit.Assert; import org.junit.Test; +import sun.rmi.transport.StreamRemoteCall; import java.lang.reflect.Array; import java.util.HashMap; @@ -35,6 +36,15 @@ public void testDeny() throws ClassNotFoundException { Assert.assertEquals(HashMap.class, classFactory.load("java.beans.C")); Assert.assertEquals(HashMap.class, classFactory.load("java.beans.D")); Assert.assertEquals(HashMap.class, classFactory.load("java.beans.E")); + Assert.assertEquals(HashMap.class, classFactory.load("sun.rmi.transport.StreamRemoteCall")); + + classFactory.deny(TestClass.class.getName()); + Assert.assertEquals(HashMap.class, classFactory.load(TestClass.class.getName())); + Assert.assertEquals(HashMap.class, classFactory.load(TestClass1.class.getName())); + + classFactory.deny(TestInterface.class.getName()); + Assert.assertEquals(HashMap.class, classFactory.load(TestInterface.class.getName())); + Assert.assertEquals(HashMap.class, classFactory.load(TestImpl.class.getName())); } @Test @@ -47,5 +57,6 @@ public void testAllow() throws ClassNotFoundException { Assert.assertEquals(List.class, classFactory.load(List.class.getName())); Assert.assertEquals(Array.class, classFactory.load(Array.class.getName())); Assert.assertEquals(LinkedList.class, classFactory.load(LinkedList.class.getName())); + Assert.assertEquals(RuntimeException.class, classFactory.load(RuntimeException.class.getName())); } } diff --git a/src/test/java/com/alibaba/com/caucho/hessian/io/TestClass.java b/src/test/java/com/alibaba/com/caucho/hessian/io/TestClass.java index 43907b381..0420962bc 100644 --- a/src/test/java/com/alibaba/com/caucho/hessian/io/TestClass.java +++ b/src/test/java/com/alibaba/com/caucho/hessian/io/TestClass.java @@ -20,3 +20,7 @@ public class TestClass implements Serializable { } + +class TestClass1 extends TestClass { + +} \ No newline at end of file diff --git a/src/test/java/com/alibaba/com/caucho/hessian/io/TestInterface.java b/src/test/java/com/alibaba/com/caucho/hessian/io/TestInterface.java new file mode 100644 index 000000000..72d6d3c1c --- /dev/null +++ b/src/test/java/com/alibaba/com/caucho/hessian/io/TestInterface.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.com.caucho.hessian.io; + +public interface TestInterface { +} + +class TestImpl implements TestInterface { + +} \ No newline at end of file