diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8658e0b0a2..1f7ce2daa6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -593,7 +593,7 @@ jobs: key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | ${{ runner.os }}-maven- - - uses: sbt/setup-sbt@1cad58d595b729a71ca2254cdf5b43dd6f42d4bb # v1.1.18 + - uses: sbt/setup-sbt@2e222825582620cc38d2a54e674f3c01b7c14f5d # v1.1.24 - name: Install fory java run: cd java && mvn -T10 --no-transfer-progress clean install -DskipTests -Dmaven.javadoc.skip=true -Dmaven.source.skip=true && cd - - name: Test @@ -622,7 +622,7 @@ jobs: key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | ${{ runner.os }}-maven- - - uses: sbt/setup-sbt@1cad58d595b729a71ca2254cdf5b43dd6f42d4bb # v1.1.18 + - uses: sbt/setup-sbt@2e222825582620cc38d2a54e674f3c01b7c14f5d # v1.1.24 - name: Run Scala Xlang Test env: FORY_SCALA_JAVA_CI: "1" @@ -655,7 +655,7 @@ jobs: key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | ${{ runner.os }}-maven- - - uses: sbt/setup-sbt@1cad58d595b729a71ca2254cdf5b43dd6f42d4bb # v1.1.18 + - uses: sbt/setup-sbt@2e222825582620cc38d2a54e674f3c01b7c14f5d # v1.1.24 - name: Install Fory Java run: | cd java diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index 0b03800dd6..6dca2e503a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -470,6 +470,8 @@ public MemoryBuffer readBufferObject() { if (size < 0) { throw new IllegalArgumentException("Buffer object size must be non-negative: " + size); } + // This returns a zero-copy slice. Allocation limits belong to serializers which allocate + // objects from the slice, not to the buffer-object transport itself. buffer.checkReadableBytes(size); int readerIndex = buffer.readerIndex(); MemoryBuffer slice = buffer.slice(readerIndex, size); diff --git a/java/fory-core/src/main/java/org/apache/fory/io/ClassLoaderObjectInputStream.java b/java/fory-core/src/main/java/org/apache/fory/io/ClassLoaderObjectInputStream.java index 38c70634fd..9e9f7fab5d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/io/ClassLoaderObjectInputStream.java +++ b/java/fory-core/src/main/java/org/apache/fory/io/ClassLoaderObjectInputStream.java @@ -19,10 +19,12 @@ import java.io.IOException; import java.io.InputStream; +import java.io.InvalidClassException; import java.io.ObjectInputStream; import java.io.ObjectStreamClass; import java.io.StreamCorruptedException; import java.lang.reflect.Proxy; +import org.apache.fory.resolver.TypeResolver; // Derived from // https://github.com/apache/commons-io/blob/5168fa5e9de9dd2ff6ace3f34226397a4faebc14/src/main/java/org/apache/commons/io/input/ClassLoaderObjectInputStream.java. @@ -38,6 +40,8 @@ public class ClassLoaderObjectInputStream extends ObjectInputStream { /** The class loader to use. */ private final ClassLoader classLoader; + private final TypeResolver typeResolver; + /** * Constructs a new ClassLoaderObjectInputStream. * @@ -50,6 +54,14 @@ public ClassLoaderObjectInputStream(ClassLoader classLoader, InputStream inputSt throws IOException, StreamCorruptedException { super(inputStream); this.classLoader = classLoader; + typeResolver = null; + } + + public ClassLoaderObjectInputStream(TypeResolver typeResolver, InputStream inputStream) + throws IOException, StreamCorruptedException { + super(inputStream); + this.typeResolver = typeResolver; + classLoader = typeResolver.getClassLoader(); } /** @@ -67,10 +79,13 @@ protected Class resolveClass(ObjectStreamClass objectStreamClass) Class clazz = Class.forName(objectStreamClass.getName(), false, classLoader); if (clazz != null) { // the classloader knows of the class + checkClass(clazz); return clazz; } else { // classloader knows not of class, let the super classloader do it - return super.resolveClass(objectStreamClass); + Class superClass = super.resolveClass(objectStreamClass); + checkClass(superClass); + return superClass; } } @@ -91,11 +106,27 @@ protected Class resolveProxyClass(String[] interfaces) Class[] interfaceClasses = new Class[interfaces.length]; for (int i = 0; i < interfaces.length; i++) { interfaceClasses[i] = Class.forName(interfaces[i], false, classLoader); + checkClass(interfaceClasses[i]); } try { return Proxy.getProxyClass(classLoader, interfaceClasses); } catch (IllegalArgumentException e) { - return super.resolveProxyClass(interfaces); + Class proxyClass = super.resolveProxyClass(interfaces); + checkClass(proxyClass); + return proxyClass; + } + } + + private void checkClass(Class cls) throws InvalidClassException { + if (typeResolver == null) { + return; + } + try { + typeResolver.checkClassForDeserialization(cls); + } catch (RuntimeException e) { + InvalidClassException exception = new InvalidClassException(cls.getName(), e.getMessage()); + exception.initCause(e); + throw exception; } } } diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java index c4d1779911..074fcc6e36 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java @@ -46,6 +46,7 @@ import org.apache.fory.collection.UInt32List; import org.apache.fory.collection.UInt64List; import org.apache.fory.collection.UInt8List; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.logging.Logger; import org.apache.fory.logging.LoggerFactory; import org.apache.fory.memory.MemoryBuffer; @@ -67,6 +68,7 @@ public class FieldTypes { private static final Logger LOG = LoggerFactory.getLogger(FieldTypes.class); + private static final int MAX_ARRAY_DIMS = 255; /** Returns true if can use current field type. */ static boolean useFieldType(Class parsedType, Descriptor descriptor) { @@ -525,6 +527,9 @@ public static FieldType read( return new CollectionFieldType(-1, nullable, trackingRef, read(buffer, resolver)); } else if (kind == 3) { int dims = buffer.readVarUInt32Small7(); + if (dims <= 0 || dims > MAX_ARRAY_DIMS) { + throw new DeserializationException("Invalid array dimensions in TypeDef: " + dims); + } return new ArrayFieldType(-1, nullable, trackingRef, read(buffer, resolver), dims); } else if (kind == 4) { return new EnumFieldType(nullable, -1, -1); diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java index a4d1256017..7115414abe 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java @@ -83,7 +83,11 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, int rootTypeId = nativeTypeId(bodyHeader >>> 4); int numClasses = bodyHeader & NUM_CLASS_THRESHOLD; if (numClasses == NUM_CLASS_THRESHOLD) { - numClasses += typeDefBuf.readVarUInt32Small7(); + int extraClasses = typeDefBuf.readVarUInt32Small7(); + if (extraClasses < 0 || extraClasses > Integer.MAX_VALUE - NUM_CLASS_THRESHOLD - 1) { + throw new DeserializationException("Invalid TypeDef class count"); + } + numClasses += extraClasses; } numClasses += 1; String className; @@ -95,6 +99,9 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, // | num fields + register flag | header + package name | header + class name // | header + type id + field name | next field info | ... | int currentClassHeader = typeDefBuf.readVarUInt32Small7(); + if (currentClassHeader < 0) { + throw new DeserializationException("Invalid TypeDef field count"); + } boolean isRegistered = (currentClassHeader & 0b1) != 0; int numFields = currentClassHeader >>> 1; Class currentClass = null; @@ -269,7 +276,7 @@ static void validateParsedTypeDefHash(long id, byte[] encoded) { private static List readFieldsInfo( MemoryBuffer buffer, ClassResolver resolver, String className, int numFields) { - List fieldInfos = new ArrayList<>(numFields); + List fieldInfos = new ArrayList<>(); for (int i = 0; i < numFields; i++) { int header = buffer.readByte() & 0xff; // `3 bits size + 2 bits field name encoding + nullability flag + ref tracking flag` diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java index b184656c47..83ca2287bc 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java @@ -76,7 +76,11 @@ public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBu } numFields = header & SMALL_NUM_FIELDS_THRESHOLD; if (numFields == SMALL_NUM_FIELDS_THRESHOLD) { - numFields += buffer.readVarUInt32Small7(); + int extraFields = buffer.readVarUInt32Small7(); + if (extraFields < 0 || extraFields > Integer.MAX_VALUE - SMALL_NUM_FIELDS_THRESHOLD) { + throw new DeserializationException("Invalid TypeDef field count"); + } + numFields += extraFields; } if (named) { String namespace = readPkgName(buffer); @@ -201,7 +205,7 @@ static int nonStructTypeId(int kindCode) { // | header + type info + field name | ... | header + type info + field name | private static List readFieldsInfo( MemoryBuffer buffer, XtypeResolver resolver, String className, int numFields) { - List fieldInfos = new ArrayList<>(numFields); + List fieldInfos = new ArrayList<>(); for (int i = 0; i < numFields; i++) { // header: 2 bits field name encoding + 4 bits size + nullability flag + ref tracking flag byte header = buffer.readByte(); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index 49c4a6143f..3a94bc805b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.io.Serializable; import java.lang.invoke.SerializedLambda; +import java.lang.reflect.Method; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -57,6 +58,9 @@ import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; import java.util.TimeZone; import java.util.TreeMap; import java.util.TreeSet; @@ -1908,6 +1912,7 @@ private boolean isSecure(Class cls) { if (config.requireClassRegistration()) { return Functions.isLambda(cls) || ReflectionUtils.isJdkProxy(cls) + || isDefaultSafeClassToken(cls) || extRegistry.registeredClassIdMap.get(cls) != null || shimDispatcher.contains(cls); } else { @@ -1915,6 +1920,31 @@ private boolean isSecure(Class cls) { } } + private static boolean isDefaultSafeClassToken(Class cls) { + return cls == Serializable.class || cls == Externalizable.class || isDefaultSafeInterface(cls); + } + + private static boolean isDefaultSafeInterface(Class cls) { + return cls.isInterface() + && (cls == Collection.class + || cls == List.class + || cls == Set.class + || cls == Map.class + || cls == SortedMap.class + || cls == SortedSet.class + || cls.getName().startsWith("java.util.function.") + || !hasDefaultMethods(cls)); + } + + private static boolean hasDefaultMethods(Class cls) { + for (Method method : cls.getMethods()) { + if (method.isDefault()) { + return true; + } + } + return false; + } + /** * Write class info to buffer. TODO(chaokunyang): The method should try to write * aligned data to reduce cpu instruction overhead. `writeTypeInfo` is the last step before @@ -2025,11 +2055,22 @@ private TypeInfo buildClassInfo(Class cls) { } /** - * Read serialized java classname. Note that the object of the class can be non-serializable. For - * serializable object, {@link #readTypeInfo(ReadContext)} or {@link #readTypeInfo(ReadContext, - * TypeInfoHolder)} should be invoked. + * Read a serialized Java class token. + * + *

For named-class tokens, this method enforces deserialization class policy before returning + * the class, including the disallowed list and registration or TypeChecker checks. For registered + * type-id tokens, it returns the registered class whose admission was already checked during + * registration. + * + *

Note that the object of the class can be non-serializable. For serializable object, {@link + * #readTypeInfo(ReadContext)} or {@link #readTypeInfo(ReadContext, TypeInfoHolder)} should be + * invoked. */ public Class readClassInternal(ReadContext readContext) { + return readClassInternal(readContext, true); + } + + private Class readClassInternal(ReadContext readContext, boolean checkNamedClass) { MemoryBuffer buffer = readContext.getBuffer(); int header = buffer.readVarUInt32Small14(); if ((header & 0b1) != 0) { @@ -2038,7 +2079,7 @@ public Class readClassInternal(ReadContext readContext) { MetaStringReader metaStringReader = readContext.getMetaStringReader(); EncodedMetaString packageBytes = metaStringReader.readMetaStringWithFlag(buffer, header); EncodedMetaString simpleClassNameBytes = metaStringReader.readMetaString(buffer); - return loadBytesToTypeInfo(packageBytes, simpleClassNameBytes).type; + return loadBytesToTypeInfo(packageBytes, simpleClassNameBytes, checkNamedClass).type; } int typeId = header >>> 1; switch (typeId) { @@ -2055,6 +2096,11 @@ public Class readClassInternal(ReadContext readContext) { } } + @Internal + public Class readClassInternalUnchecked(ReadContext readContext) { + return readClassInternal(readContext, false); + } + private TypeInfo getTypeInfoByTypeIdForReadClassInternal(int typeId, int userTypeId) { TypeInfo typeInfo; if (userTypeId != INVALID_USER_TYPE_ID) { @@ -2069,10 +2115,16 @@ private TypeInfo getTypeInfoByTypeIdForReadClassInternal(int typeId, int userTyp @Override protected TypeInfo loadBytesToTypeInfo( EncodedMetaString packageBytes, EncodedMetaString simpleClassNameBytes) { + return loadBytesToTypeInfo(packageBytes, simpleClassNameBytes, true); + } + + private TypeInfo loadBytesToTypeInfo( + EncodedMetaString packageBytes, EncodedMetaString simpleClassNameBytes, boolean checkClass) { TypeNameBytes typeNameBytes = new TypeNameBytes(packageBytes, simpleClassNameBytes); TypeInfo typeInfo = compositeNameBytes2TypeInfo.get(typeNameBytes); if (typeInfo == null) { - typeInfo = populateBytesToTypeInfo(typeNameBytes, packageBytes, simpleClassNameBytes); + typeInfo = + populateBytesToTypeInfo(typeNameBytes, packageBytes, simpleClassNameBytes, checkClass); } // Note: Don't create serializer here - this method is used by both readTypeInfo // (which needs serializer) and readClassInternal (which doesn't need serializer). @@ -2102,28 +2154,47 @@ protected TypeInfo ensureSerializerForTypeInfo(TypeInfo typeInfo) { private TypeInfo populateBytesToTypeInfo( TypeNameBytes typeNameBytes, EncodedMetaString packageBytes, - EncodedMetaString simpleClassNameBytes) { + EncodedMetaString simpleClassNameBytes, + boolean checkClass) { String packageName = packageBytes.decode(PACKAGE_DECODER); String className = simpleClassNameBytes.decode(TYPE_NAME_DECODER); ClassSpec classSpec = Encoders.decodePkgAndClass(packageName, className); Class cls = loadClass(classSpec.entireClassName, classSpec.isEnum, classSpec.dimension); + boolean unknownClass = UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls)); + if (checkClass && !unknownClass) { + checkClassForDeserialization(cls); + } int typeId = buildUnregisteredTypeId(cls, null); TypeInfo typeInfo = new TypeInfo(cls, packageBytes, simpleClassNameBytes, null, typeId, INVALID_USER_TYPE_ID); - if (UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls))) { + if (unknownClass) { typeInfo.serializer = UnknownClassSerializers.getSerializer(this, classSpec.entireClassName, cls); - } else { + } else if (checkClass) { // don't create serializer here, if the class is an interface, // there won't be serializer since interface has no instance. if (!classInfoMap.containsKey(cls)) { classInfoMap.put(cls, typeInfo); } } - compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo); + if (checkClass) { + compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo); + } return typeInfo; } + @Internal + @Override + public void checkClassForDeserialization(Class cls) { + if (UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls))) { + return; + } + DisallowedList.checkNotInDisallowedList(cls.getName()); + if (!isSecure(cls)) { + throw new InsecureException(generateSecurityMsg(cls)); + } + } + public Class loadClassForMeta(String className, boolean isEnum, int arrayDims) { String pkg = ReflectionUtils.getPackage(className); String typeName = ReflectionUtils.getClassNameWithoutPackage(className); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index f249884d07..d60a72ce3c 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -176,6 +176,9 @@ public final ClassLoader getClassLoader() { return extRegistry.classLoader; } + @Internal + public void checkClassForDeserialization(Class cls) {} + public final SharedRegistry getSharedRegistry() { return sharedRegistry; } @@ -1011,6 +1014,8 @@ final TypeInfo buildMetaSharedTypeInfo(TypeDef typeDef) { return typeInfo; } Class cls = loadClass(typeDef.getClassSpec()); + // A wire TypeDef may create a compatible serializer; admit the class before caching it by id. + checkClassForDeserialization(cls); if (!typeDef.isStructSchemaKind() && !UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls))) { typeInfo = getTypeInfo(cls); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java index dc22b47edb..b0f5f99e04 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java @@ -485,6 +485,14 @@ private static void readAndSkipLayerClassMeta(ReadContext readContext) { private static List readSuppressedExceptions(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int numSuppressedExceptions = buffer.readVarUInt32(); + int maxCollectionSize = readContext.getConfig().maxCollectionSize(); + if (numSuppressedExceptions < 0 || numSuppressedExceptions > maxCollectionSize) { + throw new ForyException( + "Throwable suppressed exception count " + + numSuppressedExceptions + + " exceeds max collection size " + + maxCollectionSize); + } List suppressedExceptions = new ArrayList<>(numSuppressedExceptions); for (int i = 0; i < numSuppressedExceptions; i++) { suppressedExceptions.add((Throwable) readContext.readRef()); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/JavaSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/JavaSerializer.java index eb0497dc28..397c3ac151 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/JavaSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/JavaSerializer.java @@ -101,8 +101,7 @@ public Object read(ReadContext readContext) { ObjectInputStream objectInputStream = (ObjectInputStream) readContext.getContextObject(objectInput); if (objectInputStream == null) { - objectInputStream = - new ClassLoaderObjectInputStream(typeResolver.getClassLoader(), objectInput); + objectInputStream = new ClassLoaderObjectInputStream(typeResolver, objectInput); readContext.putContextObject(objectInput, objectInputStream); } return objectInputStream.readObject(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java index fabaa64afd..2e7a8df801 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java @@ -291,7 +291,7 @@ public Object read(ReadContext readContext) { ClassResolver classResolver = (ClassResolver) typeResolver; TreeMap callbacks = new TreeMap<>(Collections.reverseOrder()); for (int i = 0; i < numClasses; i++) { - Class currentClass = classResolver.readClassInternal(readContext); + Class currentClass = classResolver.readClassInternalUnchecked(readContext); // Find the matching local slot for sender's class SlotInfo matchedSlot = null; @@ -324,6 +324,7 @@ public Object read(ReadContext readContext) { if (matchedSlot == null) { // Sender has a layer that receiver doesn't have - read TypeDef and skip the data + classResolver.checkClassForDeserialization(currentClass); skipUnknownLayerData(readContext, currentClass); continue; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/SerializedLambdaSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/SerializedLambdaSerializer.java index e359979a08..0962e0710b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/SerializedLambdaSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/SerializedLambdaSerializer.java @@ -25,6 +25,7 @@ import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.exception.ForyException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.platform.AndroidSupport; @@ -42,6 +43,7 @@ public class SerializedLambdaSerializer extends Serializer { static final Class SERIALIZED_LAMBDA = SerializedLambda.class; private static final MethodHandle READ_RESOLVE_HANDLE; private final TypeResolver typeResolver; + private final int maxCollectionSize; static { if (AndroidSupport.IS_ANDROID) { @@ -62,6 +64,7 @@ public class SerializedLambdaSerializer extends Serializer { public SerializedLambdaSerializer(TypeResolver typeResolver, Class cls) { super(typeResolver.getConfig(), cls); this.typeResolver = typeResolver; + maxCollectionSize = typeResolver.getConfig().maxCollectionSize(); Preconditions.checkArgument(cls == SERIALIZED_LAMBDA); } @@ -105,7 +108,8 @@ public Object copy(CopyContext copyContext, Object value) { serializedLambda.getImplMethodName(), serializedLambda.getImplMethodSignature(), serializedLambda.getInstantiatedMethodType(), - capturedArgs); + capturedArgs, + false); } @Override @@ -127,6 +131,9 @@ Object readUnresolved(ReadContext readContext) { int implMethodKind = buffer.readVarInt32(); String instantiatedMethodType = readContext.readStringRef(); int capturedArgCount = buffer.readVarUInt32Small7(); + if (capturedArgCount < 0 || capturedArgCount > maxCollectionSize) { + throwInvalidCapturedArgCount(capturedArgCount); + } Object[] capturedArgs = new Object[capturedArgCount]; for (int i = 0; i < capturedArgCount; i++) { capturedArgs[i] = readContext.readRef(); @@ -141,7 +148,20 @@ Object readUnresolved(ReadContext readContext) { implMethodName, implMethodSignature, instantiatedMethodType, - capturedArgs); + capturedArgs, + true); + } + + private void throwInvalidCapturedArgCount(int capturedArgCount) { + if (capturedArgCount < 0) { + throw new DeserializationException( + "SerializedLambda captured arg count must be non-negative: " + capturedArgCount); + } + throw new DeserializationException( + "SerializedLambda captured arg count " + + capturedArgCount + + " exceeds max collection size " + + maxCollectionSize); } static Object readResolve(Object replacement) { @@ -170,9 +190,10 @@ private SerializedLambda newSerializedLambda( String implMethodName, String implMethodSignature, String instantiatedMethodType, - Object[] capturedArgs) { + Object[] capturedArgs, + boolean checkCapturingClass) { return new SerializedLambda( - loadCapturingClass(capturingClass), + loadCapturingClass(capturingClass, checkCapturingClass), functionalInterfaceClass, functionalInterfaceMethodName, functionalInterfaceMethodSignature, @@ -184,17 +205,27 @@ private SerializedLambda newSerializedLambda( capturedArgs); } - private Class loadCapturingClass(String className) { + private Class loadCapturingClass(String className, boolean checkClass) { String binaryClassName = className.replace('/', '.'); try { - return Class.forName(binaryClassName, false, typeResolver.getClassLoader()); + return loadCapturingClass(binaryClassName, typeResolver.getClassLoader(), checkClass); } catch (ClassNotFoundException e) { try { - return Class.forName( - binaryClassName, false, Thread.currentThread().getContextClassLoader()); + return loadCapturingClass( + binaryClassName, Thread.currentThread().getContextClassLoader(), checkClass); } catch (ClassNotFoundException ex) { throw new RuntimeException("Can't load capturing class " + binaryClassName, ex); } } } + + private Class loadCapturingClass(String className, ClassLoader classLoader, boolean checkClass) + throws ClassNotFoundException { + Class cls = Class.forName(className, false, classLoader); + if (checkClass) { + // JDK SerializedLambda readResolve invokes restoration code on the capturing class. + typeResolver.checkClassForDeserialization(cls); + } + return cls; + } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java index 8da69a667b..4bff2392f7 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java @@ -699,6 +699,8 @@ public void write(WriteContext writeContext, Class value) { @Override public Class read(ReadContext readContext) { + // A wire-provided Class value can later drive reflection, proxy creation, or serializer + // selection, so class literals must stay under the same registration/TypeChecker boundary. return ((ClassResolver) readContext.getTypeResolver()).readClassInternal(readContext); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StringSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StringSerializer.java index bacdc6b29b..939ee9aad6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/StringSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/StringSerializer.java @@ -40,6 +40,7 @@ import org.apache.fory.config.Config; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.LittleEndian; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.NativeByteOrder; @@ -139,6 +140,7 @@ private static class Offset { private final boolean compressString; private final boolean writeNumUtf16BytesForUtf8Encoding; private final boolean xlang; + private final long maxBinarySize; // set default length to 0, since char array and bytes array won't be used at the same time. private static final byte[] EMPTY_BYTES_STUB = new byte[0]; @@ -157,6 +159,7 @@ public StringSerializer(Config config) { Preconditions.checkArgument(compressString, "compress string muse be enabled for xlang mode"); } writeNumUtf16BytesForUtf8Encoding = config.writeNumUtf16BytesForUtf8Encoding(); + maxBinarySize = config.maxBinarySize(); } @Override @@ -221,7 +224,7 @@ public static Expression readStringExpr( public String readBytesString(MemoryBuffer buffer) { long header = buffer.readVarUint36Small(); byte coder = (byte) (header & 0b11); - int numBytes = (int) (header >>> 2); + int numBytes = readStringSize(header); byte[] bytes; if (!NativeByteOrder.IS_LITTLE_ENDIAN && coder == UTF16) { bytes = readBytesUTF16BE(buffer, numBytes); @@ -239,7 +242,7 @@ public String readBytesString(MemoryBuffer buffer) { public String readCharsString(MemoryBuffer buffer) { long header = buffer.readVarUint36Small(); byte coder = (byte) (header & 0b11); - int numBytes = (int) (header >>> 2); + int numBytes = readStringSize(header); char[] chars; if (coder == LATIN1) { chars = readCharsLatin1(buffer, numBytes); @@ -255,7 +258,7 @@ public String readCharsString(MemoryBuffer buffer) { public String readCompressedBytesString(MemoryBuffer buffer) { long header = buffer.readVarUint36Small(); byte coder = (byte) (header & 0b11); - int numBytes = (int) (header >>> 2); + int numBytes = readStringSize(header); if (coder == UTF8) { byte[] data; if (writeNumUtf16BytesForUtf8Encoding) { @@ -345,7 +348,7 @@ String readBytesUTF8ForXlang(MemoryBuffer buffer, int numBytes) { public String readCompressedCharsString(MemoryBuffer buffer) { long header = buffer.readVarUint36Small(); byte coder = (byte) (header & 0b11); - int numBytes = (int) (header >>> 2); + int numBytes = readStringSize(header); char[] chars; if (coder == LATIN1) { chars = readCharsLatin1(buffer, numBytes); @@ -436,13 +439,14 @@ private void writeStringSlow(MemoryBuffer buffer, String value) { private String readStringSlow(MemoryBuffer buffer) { long header = buffer.readVarUint36Small(); byte coder = (byte) (header & 0b11); - int numBytes = (int) (header >>> 2); + int numBytes = readStringSize(header); if (coder == LATIN1) { return new String(readBytesUnCompressedUTF16(buffer, numBytes), StandardCharsets.ISO_8859_1); } else if (coder == UTF16) { return new String(readCharsUTF16(buffer, numBytes)); } else if (coder == UTF8) { int utf8Bytes = writeNumUtf16BytesForUtf8Encoding ? buffer.readInt32() : numBytes; + checkStringSize(utf8Bytes); return new String(buffer.readBytes(utf8Bytes), StandardCharsets.UTF_8); } else { throw new RuntimeException("Unknown coder type " + coder); @@ -628,6 +632,7 @@ public byte[] readBytesUTF8(MemoryBuffer buffer, int numBytes) { private byte[] readBytesUTF8PerfOptimized(MemoryBuffer buffer, int numBytes) { int udf8Bytes = buffer.readInt32(); + checkStringSize(udf8Bytes); byte[] bytes = new byte[numBytes]; // noinspection Duplicates buffer.checkReadableBytes(udf8Bytes); @@ -690,8 +695,10 @@ public String readCharsUTF8(MemoryBuffer buffer, int numBytes) { } public String readCharsUTF8PerfOptimized(MemoryBuffer buffer, int numBytes) { + checkStringSize(numBytes); int udf16Chars = numBytes >> 1; int udf8Bytes = buffer.readInt32(); + checkStringSize(udf8Bytes); char[] chars = new char[udf16Chars]; // noinspection Duplicates buffer.checkReadableBytes(udf8Bytes); @@ -710,6 +717,25 @@ public String readCharsUTF8PerfOptimized(MemoryBuffer buffer, int numBytes) { return newCharsStringZeroCopy(chars); } + private int readStringSize(long header) { + long size = header >>> 2; + if (size > maxBinarySize) { + throwStringSizeOutOfBounds(size); + } + return (int) size; + } + + private void checkStringSize(int size) { + if (size < 0 || size > maxBinarySize) { + throwStringSizeOutOfBounds(size); + } + } + + private void throwStringSizeOutOfBounds(long size) { + throw new DeserializationException( + "String payload size " + size + " is outside allowed range [0, " + maxBinarySize + "]"); + } + public void writeCharsLatin1(MemoryBuffer buffer, char[] chars, int numBytes) { int writerIndex = buffer.writerIndex(); long header = ((long) numBytes << 2) | LATIN1; diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java index 953314a03d..43a464b12f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java @@ -19,7 +19,15 @@ package org.apache.fory.serializer.collection; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.Externalizable; +import java.io.IOException; +import java.io.InputStream; +import java.io.InvalidObjectException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; import java.lang.invoke.MethodHandle; import java.util.Collection; import java.util.Collections; @@ -29,6 +37,7 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.SortedMap; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; @@ -39,17 +48,15 @@ import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; import org.apache.fory.memory.MemoryBuffer; -import org.apache.fory.platform.AndroidSupport; -import org.apache.fory.platform.UnsafeOps; import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.resolver.ClassResolver; import org.apache.fory.resolver.TypeInfo; import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.ExternalizableSerializer; -import org.apache.fory.serializer.JavaSerializer; import org.apache.fory.serializer.ReplaceResolveSerializer; import org.apache.fory.serializer.Serializer; import org.apache.fory.serializer.Serializers; +import org.apache.fory.util.ExceptionUtils; import org.apache.fory.util.Preconditions; /** @@ -333,24 +340,35 @@ public Map newMap(CopyContext copyContext, Map originMap) { } public static class EnumMapSerializer extends MapSerializer { - private static final byte NORMAL_ENUM_MAP = 0; - private static final byte JAVA_SERIALIZED_EMPTY_ENUM_MAP = 1; - - private static final class KeyTypeFieldOffset { - // Make offset compatible with graalvm native image. - private static final long VALUE; + private static final class CapturingObjectInputStream extends ObjectInputStream { + private final ClassLoader fallbackLoader; + private Class enumClass; + + private CapturingObjectInputStream(InputStream in, ClassLoader fallbackLoader) + throws IOException { + super(in); + this.fallbackLoader = fallbackLoader; + } - static { + @Override + protected Class resolveClass(ObjectStreamClass desc) + throws IOException, ClassNotFoundException { + Class cls; try { - VALUE = UnsafeOps.objectFieldOffset(EnumMap.class.getDeclaredField("keyType")); - } catch (final Exception e) { - throw new RuntimeException(e); + cls = super.resolveClass(desc); + } catch (ClassNotFoundException e) { + if (fallbackLoader == null) { + throw e; + } + cls = Class.forName(desc.getName(), false, fallbackLoader); } + if (enumClass == null && cls != Enum.class && Enum.class.isAssignableFrom(cls)) { + enumClass = cls; + } + return cls; } } - private JavaSerializer javaSerializer; - public EnumMapSerializer(TypeResolver typeResolver) { // getMapKeyValueType(EnumMap.class) will be `K, V` without Enum as key bound. // so no need to infer key generics in init. @@ -360,12 +378,6 @@ public EnumMapSerializer(TypeResolver typeResolver) { @Override public Map onMapWrite(WriteContext writeContext, EnumMap value) { MemoryBuffer buffer = writeContext.getBuffer(); - if (AndroidSupport.IS_ANDROID && value.isEmpty()) { - buffer.writeByte(JAVA_SERIALIZED_EMPTY_ENUM_MAP); - getJavaSerializer().write(writeContext, value); - return value; - } - buffer.writeByte(NORMAL_ENUM_MAP); buffer.writeVarUInt32Small7(value.size()); Class keyType = getKeyType(value); ((ClassResolver) typeResolver).writeClassAndUpdateCache(writeContext, keyType); @@ -375,16 +387,6 @@ public Map onMapWrite(WriteContext writeContext, EnumMap value) { @Override public EnumMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - byte payloadMode = buffer.readByte(); - if (payloadMode == JAVA_SERIALIZED_EMPTY_ENUM_MAP) { - EnumMap map = (EnumMap) getJavaSerializer().read(readContext); - setNumElements(0); - readContext.reference(map); - return map; - } - if (payloadMode != NORMAL_ENUM_MAP) { - throw new IllegalArgumentException("Unknown EnumMap payload mode: " + payloadMode); - } setNumElements(readMapSize(buffer)); Class keyType = typeResolver.readTypeInfo(readContext).getType(); EnumMap map = new EnumMap(keyType); @@ -397,20 +399,38 @@ public EnumMap copy(CopyContext copyContext, EnumMap originMap) { return new EnumMap(originMap); } - private static Class getKeyType(EnumMap value) { + private Class getKeyType(EnumMap value) { + Objects.requireNonNull(value, "value"); if (!value.isEmpty()) { Enum key = (Enum) value.keySet().iterator().next(); return key.getDeclaringClass(); } - return (Class) UnsafeOps.getObject(value, KeyTypeFieldOffset.VALUE); + try { + return keyTypeBySerialization(value, typeResolver.getClassLoader()); + } catch (IOException | ClassNotFoundException e) { + throw ExceptionUtils.throwException(e); + } } - private JavaSerializer getJavaSerializer() { - JavaSerializer javaSerializer = this.javaSerializer; - if (javaSerializer == null) { - javaSerializer = this.javaSerializer = new JavaSerializer(typeResolver, EnumMap.class); + private static Class keyTypeBySerialization(EnumMap value, ClassLoader fallbackLoader) + throws IOException, ClassNotFoundException { + // This JDK stream is local-only key-type discovery for an already-owned EnumMap; remote Fory + // payloads must keep using the normal class metadata path in newMap. + EnumMap copy = value.clone(); + copy.clear(); + ByteArrayOutputStream bytes = new ByteArrayOutputStream(128); + try (ObjectOutputStream out = new ObjectOutputStream(bytes)) { + out.writeObject(copy); + } + try (CapturingObjectInputStream in = + new CapturingObjectInputStream( + new ByteArrayInputStream(bytes.toByteArray()), fallbackLoader)) { + in.readObject(); + if (in.enumClass == null) { + throw new InvalidObjectException("Cannot determine EnumMap key type"); + } + return in.enumClass; } - return javaSerializer; } } diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java index 036a9c6dad..4a0e39c372 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java @@ -72,6 +72,32 @@ public void testBigMetaEncoding() { } } + @Test + public void testTypeDefCountIgnoresLimit() { + Fory writer = Fory.builder().withXlang(false).withMetaShare(true).build(); + Fory reader = + Fory.builder().withXlang(false).withMetaShare(true).withMaxCollectionSize(1).build(); + TypeDef typeDef = + TypeDef.buildTypeDef(writer.getTypeResolver(), TypeDefTest.TestFieldsOrderClass1.class); + + TypeDef decoded = + TypeDef.readTypeDef( + reader.getTypeResolver(), MemoryBuffer.fromByteArray(typeDef.getEncoded())); + Assert.assertEquals(decoded, typeDef); + } + + @Test + public void testTypeDefArrayDimensionLimit() { + Fory fory = Fory.builder().withXlang(false).build(); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(16); + buffer.writeByte(3 << 2); + buffer.writeVarUInt32Small7(256); + + Assert.assertThrows( + DeserializationException.class, + () -> FieldTypes.FieldType.read(buffer, fory.getTypeResolver())); + } + @Data public static class Foo1 { private int f1; diff --git a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java index 2f723a8b78..7c8a4b6e89 100644 --- a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java @@ -53,6 +53,7 @@ import org.apache.fory.config.ForyBuilder; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.InsecureException; import org.apache.fory.logging.Logger; import org.apache.fory.logging.LoggerFactory; import org.apache.fory.memory.MemoryBuffer; @@ -449,6 +450,19 @@ public void testSerializeClasses(boolean referenceTracking) { Arrays.asList(Interface1.class, Interface1.class, Interface2.class, Interface2.class)); } + @Test + public void testClassLiteralRegistration() { + Fory writer = Fory.builder().withXlang(false).requireClassRegistration(false).build(); + byte[] serialized = writer.serialize(Foo.class); + + Fory reader = Fory.builder().withXlang(false).build(); + Assert.assertThrows(InsecureException.class, () -> reader.deserialize(serialized)); + + Fory registeredReader = Fory.builder().withXlang(false).build(); + registeredReader.register(Foo.class); + Assert.assertSame(registeredReader.deserialize(serialized), Foo.class); + } + @Test public void testWriteClassName() { { diff --git a/java/fory-core/src/test/java/org/apache/fory/resolver/MetaShareContextTest.java b/java/fory-core/src/test/java/org/apache/fory/resolver/MetaShareContextTest.java index 7eec8145cd..0280473b1b 100644 --- a/java/fory-core/src/test/java/org/apache/fory/resolver/MetaShareContextTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/resolver/MetaShareContextTest.java @@ -28,6 +28,7 @@ import org.apache.fory.ForyTestBase; import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.MetaWriteContext; +import org.apache.fory.exception.InsecureException; import org.apache.fory.meta.TypeDef; import org.apache.fory.test.bean.BeanA; import org.apache.fory.test.bean.BeanB; @@ -40,6 +41,12 @@ public interface InterfacePrice { int cents(); } + public static class MetaSharedPayload { + public int value; + + public MetaSharedPayload() {} + } + @Test public void testShareClassName() { Fory fory = @@ -89,6 +96,21 @@ public void testMetaSharedInterfaceDoesNotBuildInstantiatingSerializer(boolean e Assert.assertNull(typeInfo.getSerializer()); } + @Test + public void testMetaTypeDefAdmission() { + Fory writer = + Fory.builder() + .withXlang(false) + .withMetaShare(true) + .withCompatible(true) + .requireClassRegistration(false) + .build(); + Fory reader = Fory.builder().withXlang(false).withMetaShare(true).withCompatible(true).build(); + TypeDef typeDef = TypeDef.buildTypeDef(writer.getTypeResolver(), MetaSharedPayload.class); + Assert.assertThrows( + InsecureException.class, () -> reader.getTypeResolver().buildMetaSharedTypeInfo(typeDef)); + } + private void checkMetaShare(Fory fory, Object o) { MetaWriteContext metaWriteContext = new MetaWriteContext(); MetaReadContext metaReadContext = new MetaReadContext(); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java index 2b0505e077..97899c4e4f 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java @@ -60,6 +60,23 @@ public void testBuiltInThrowableRoundTrip(Fory fory) { Assert.assertEquals(copy.getSuppressed()[1].getMessage(), "suppressed-2"); } + @Test + public void testSuppressedCountLimit() { + Fory writer = Fory.builder().withXlang(false).build(); + Fory reader = Fory.builder().withXlang(false).withMaxCollectionSize(1).build(); + RuntimeException value = new RuntimeException("outer"); + RuntimeException suppressed1 = new RuntimeException("suppressed-1"); + RuntimeException suppressed2 = new RuntimeException("suppressed-2"); + value.setStackTrace(new StackTraceElement[0]); + suppressed1.setStackTrace(new StackTraceElement[0]); + suppressed2.setStackTrace(new StackTraceElement[0]); + value.addSuppressed(suppressed1); + value.addSuppressed(suppressed2); + byte[] bytes = writer.serialize(value); + + Assert.assertThrows(ForyException.class, () -> reader.deserialize(bytes)); + } + @Test(dataProvider = "javaFory") public void testStackTraceElementRoundTrip(Fory fory) { StackTraceElement value = new Exception().getStackTrace()[0]; diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/JavaSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/JavaSerializerTest.java index f69260eeb1..9afcd8b1c0 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/JavaSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/JavaSerializerTest.java @@ -21,6 +21,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InvalidClassException; import java.io.ObjectOutputStream; import java.io.ObjectStreamConstants; import java.io.Serializable; @@ -31,6 +32,7 @@ import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; import org.apache.fory.memory.BigEndian; +import org.apache.fory.memory.MemoryBuffer; import org.testng.Assert; import org.testng.annotations.Test; @@ -52,6 +54,16 @@ private void readObject(java.io.ObjectInputStream s) throws Exception { } } + public static class JavaBox implements Serializable { + Object value; + + JavaBox(Object value) { + this.value = value; + } + } + + public static class NestedValue implements Serializable {} + @Test public void testWriteObject() { Fory fory = @@ -85,4 +97,16 @@ public void testJdkSerializationCopy(Fory fory) throws MalformedURLException { fory.registerSerializer(URL.class, JavaSerializer.class); copyCheck(fory, url); } + + @Test + public void testJdkStreamChecksNestedClass() { + Fory fory = Fory.builder().withXlang(false).build(); + Serializer serializer = new JavaSerializer(fory.getTypeResolver(), JavaBox.class); + fory.registerSerializer(JavaBox.class, serializer); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(128); + writeSerializer(fory, serializer, buffer, new JavaBox(new NestedValue())); + + Assert.assertThrows( + InvalidClassException.class, () -> readSerializer(fory, serializer, buffer)); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/JdkProxySerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/JdkProxySerializerTest.java index 01f664a732..b3ebe2724e 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/JdkProxySerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/JdkProxySerializerTest.java @@ -88,6 +88,54 @@ public void testJdkProxyInterfaceClassHonorsTypeCheckerFalse() { assertThrows(InsecureException.class, () -> reader.deserialize(bytes)); } + @Test + public void testJdkProxyStrictInterfaces() { + Fory fory = Fory.builder().withXlang(false).requireClassRegistration(true).build(); + fory.register(TestInvocationHandler.class); + Function function = + (Function) + Proxy.newProxyInstance( + fory.getClassLoader(), + new Class[] {Function.class, Serializable.class}, + new TestInvocationHandler()); + + Function deserializedFunction = (Function) fory.deserialize(fory.serialize(function)); + assertEquals(deserializedFunction.apply(null), 1); + } + + @Test + public void testJdkProxyStrictNoDefaultInterface() { + Fory writer = Fory.builder().withXlang(false).requireClassRegistration(false).build(); + TestInterface function = + (TestInterface) + Proxy.newProxyInstance( + writer.getClassLoader(), + new Class[] {TestInterface.class}, + new TestInvocationHandler()); + byte[] bytes = writer.serialize(function); + + Fory reader = Fory.builder().withXlang(false).requireClassRegistration(true).build(); + reader.register(TestInvocationHandler.class); + TestInterface deserializedFunction = (TestInterface) reader.deserialize(bytes); + assertEquals(deserializedFunction.test(), 1); + } + + @Test + public void testJdkProxyStrictRejectsDefaultInterface() { + Fory writer = Fory.builder().withXlang(false).requireClassRegistration(false).build(); + TestDefaultInterface function = + (TestDefaultInterface) + Proxy.newProxyInstance( + writer.getClassLoader(), + new Class[] {TestDefaultInterface.class}, + new TestInvocationHandler()); + byte[] bytes = writer.serialize(function); + + Fory reader = Fory.builder().withXlang(false).requireClassRegistration(true).build(); + reader.register(TestInvocationHandler.class); + assertThrows(InsecureException.class, () -> reader.deserialize(bytes)); + } + @Test(dataProvider = "foryCopyConfig") public void testJdkProxy(Fory fory) { Function function = @@ -206,7 +254,13 @@ public void testSerializeProxyWriteReplace() { } interface TestInterface { - void test(); + int test(); + } + + interface TestDefaultInterface { + default int test() { + return 1; + } } static class ProxyFactory { diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/LambdaSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/LambdaSerializerTest.java index 099540b6fc..a508d0843c 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/LambdaSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/LambdaSerializerTest.java @@ -33,6 +33,9 @@ import java.util.function.Function; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.exception.DeserializationException; +import org.apache.fory.exception.InsecureException; +import org.testng.Assert; import org.testng.annotations.Test; @SuppressWarnings("unchecked") @@ -105,6 +108,33 @@ public void testSerializedLambda(Fory fory) throws Exception { assertEquals(newFunc.apply(10), Integer.valueOf(17)); } + @Test + public void testSerializedLambdaAdmission() throws Exception { + int delta = 7; + Function function = + (Serializable & Function) (x) -> x + delta; + Fory writer = Fory.builder().withXlang(false).requireClassRegistration(false).build(); + Fory reader = Fory.builder().withXlang(false).build(); + byte[] bytes = writer.serialize(extractSerializedLambda(function)); + Assert.assertThrows(InsecureException.class, () -> reader.deserialize(bytes)); + } + + @Test + public void testSerializedLambdaArgLimit() throws Exception { + int delta = 7; + Function function = + (Serializable & Function) (x) -> x + delta; + Fory writer = Fory.builder().withXlang(false).requireClassRegistration(false).build(); + Fory reader = + Fory.builder() + .withXlang(false) + .requireClassRegistration(false) + .withMaxCollectionSize(0) + .build(); + byte[] bytes = writer.serialize(extractSerializedLambda(function)); + Assert.assertThrows(DeserializationException.class, () -> reader.deserialize(bytes)); + } + @Test(dataProvider = "foryCopyConfig") public void testSerializedLambdaCopy(Fory fory) throws Exception { int delta = 7; diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java index d623d609c7..66a93e71ca 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java @@ -48,6 +48,7 @@ import org.apache.fory.config.ForyBuilder; import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.MetaWriteContext; +import org.apache.fory.exception.InsecureException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.SharedRegistry; import org.apache.fory.serializer.collection.CollectionSerializers; @@ -1029,6 +1030,54 @@ public void testHierarchyMixedSerialization(boolean compatible) { assertEquals(result.childValue, 42); } + @Test + public void testObjectStreamExpectedParentLayer() { + Fory writerFory = + Fory.builder().withXlang(false).withRefTracking(true).withCodegen(false).build(); + writerFory.register(HierarchyChildDefault.class); + writerFory.registerSerializer( + HierarchyChildDefault.class, + new ObjectStreamSerializer(writerFory.getTypeResolver(), HierarchyChildDefault.class)); + + Fory readerFory = + Fory.builder().withXlang(false).withRefTracking(true).withCodegen(false).build(); + readerFory.register(HierarchyChildDefault.class); + readerFory.registerSerializer( + HierarchyChildDefault.class, + new ObjectStreamSerializer(readerFory.getTypeResolver(), HierarchyChildDefault.class)); + + HierarchyChildDefault obj = new HierarchyChildDefault("parent", "child", 42); + HierarchyChildDefault result = + (HierarchyChildDefault) readerFory.deserialize(writerFory.serialize(obj)); + assertEquals(result.parentData, "parent"); + assertEquals(result.childData, "child"); + assertEquals(result.childValue, 42); + } + + @Test + public void testObjectStreamRejectsParentRoot() { + Fory writerFory = + Fory.builder() + .withXlang(false) + .requireClassRegistration(false) + .withRefTracking(true) + .withCodegen(false) + .build(); + writerFory.registerSerializer( + HierarchyParentPutFields.class, + new ObjectStreamSerializer(writerFory.getTypeResolver(), HierarchyParentPutFields.class)); + + Fory readerFory = + Fory.builder().withXlang(false).withRefTracking(true).withCodegen(false).build(); + readerFory.register(HierarchyChildDefault.class); + readerFory.registerSerializer( + HierarchyChildDefault.class, + new ObjectStreamSerializer(readerFory.getTypeResolver(), HierarchyChildDefault.class)); + + byte[] bytes = writerFory.serialize(new HierarchyParentPutFields("parent")); + Assert.assertThrows(InsecureException.class, () -> readerFory.deserialize(bytes)); + } + // ==================== Cross-Fory Instance Schema Tests ==================== @Test(dataProvider = "compatibleModeProvider") diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ReplaceResolveSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ReplaceResolveSerializerTest.java index 027f006174..143a3e983e 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ReplaceResolveSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ReplaceResolveSerializerTest.java @@ -20,6 +20,7 @@ package org.apache.fory.serializer; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotSame; import static org.testng.Assert.assertSame; import static org.testng.Assert.assertThrows; @@ -40,6 +41,7 @@ import lombok.Data; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.exception.InsecureException; import org.apache.fory.util.Preconditions; import org.testng.annotations.Test; @@ -617,6 +619,56 @@ public void testWriteReplaceExternalizable() { assertEquals(o.f1, 10); } + static class ReplaceProtectedExternalizable implements Externalizable { + static boolean readExternalCalled; + private int f1; + + public ReplaceProtectedExternalizable() {} + + ReplaceProtectedExternalizable(int f1) { + this.f1 = f1; + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeInt(f1); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + readExternalCalled = true; + f1 = in.readInt(); + } + + private Object writeReplace() { + return this; + } + } + + @Test + public void testRejectExternalizableReplace() { + Fory writer = Fory.builder().withXlang(false).requireClassRegistration(false).build(); + byte[] bytes = writer.serialize(new ReplaceProtectedExternalizable(10)); + ReplaceProtectedExternalizable.readExternalCalled = false; + + Fory reader = Fory.builder().withXlang(false).build(); + assertThrows(InsecureException.class, () -> reader.deserialize(bytes)); + assertFalse(ReplaceProtectedExternalizable.readExternalCalled); + } + + @Test + public void testRegisteredExternalizableReplace() { + Fory writer = Fory.builder().withXlang(false).requireClassRegistration(false).build(); + byte[] bytes = writer.serialize(new ReplaceProtectedExternalizable(10)); + ReplaceProtectedExternalizable.readExternalCalled = false; + + Fory reader = Fory.builder().withXlang(false).build(); + reader.register(ReplaceProtectedExternalizable.class); + ReplaceProtectedExternalizable o = (ReplaceProtectedExternalizable) reader.deserialize(bytes); + assertTrue(ReplaceProtectedExternalizable.readExternalCalled); + assertEquals(o.f1, 10); + } + static class ReplaceSelfExternalizable implements Externalizable { private transient int f1; private transient boolean newInstance; diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java index f25ed88c6c..89bb3e3909 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java @@ -24,6 +24,8 @@ import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; +import java.io.Externalizable; +import java.io.Serializable; import java.math.BigDecimal; import java.math.BigInteger; import java.math.MathContext; @@ -31,18 +33,25 @@ import java.net.URISyntaxException; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Collection; import java.util.Currency; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; import java.util.UUID; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import java.util.regex.Pattern; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; import org.apache.fory.config.ForyBuilder; import org.apache.fory.exception.DeserializationException; +import org.apache.fory.exception.InsecureException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import org.testng.Assert; @@ -277,6 +286,14 @@ public void testUUID() { private static class TestClassSerialization {} + private interface TestClassTokenInterface { + void test(); + } + + private interface TestDefaultClassTokenInterface { + default void test() {} + } + private static class TestReplaceClassSerialization { private Object writeReplace() { return 1; @@ -300,6 +317,22 @@ public void testSerializeClass() { serDe(fory, new TestReplaceClassSerialization()); } + @Test + public void testDefaultSafeClassTokens() { + Fory fory = Fory.builder().withXlang(false).requireClassRegistration(true).build(); + assertSame(serDe(fory, Serializable.class), Serializable.class); + assertSame(serDe(fory, Externalizable.class), Externalizable.class); + assertSame(serDe(fory, Function.class), Function.class); + assertSame(serDe(fory, Collection.class), Collection.class); + assertSame(serDe(fory, List.class), List.class); + assertSame(serDe(fory, Set.class), Set.class); + assertSame(serDe(fory, Map.class), Map.class); + assertSame(serDe(fory, SortedMap.class), SortedMap.class); + assertSame(serDe(fory, SortedSet.class), SortedSet.class); + assertSame(serDe(fory, TestClassTokenInterface.class), TestClassTokenInterface.class); + assertThrows(InsecureException.class, () -> serDe(fory, TestDefaultClassTokenInterface.class)); + } + @Test public void testEmptyObject() { Fory fory = Fory.builder().withXlang(false).requireClassRegistration(true).build(); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/StringSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/StringSerializerTest.java index 95c41e6714..d45f3271a4 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/StringSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/StringSerializerTest.java @@ -33,6 +33,7 @@ import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; import org.apache.fory.collection.Tuple2; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import org.apache.fory.platform.JdkVersion; @@ -168,6 +169,18 @@ public void testJavaStringSimple() { } } + @Test + public void testStringSizeLimit() { + Fory writer = Fory.builder().withXlang(false).build(); + Fory reader = Fory.builder().withXlang(false).withMaxBinarySize(2).build(); + MemoryBuffer buffer = MemoryUtils.buffer(32); + new StringSerializer(writer.getConfig()).writeString(buffer, "abcd"); + + Assert.assertThrows( + DeserializationException.class, + () -> new StringSerializer(reader.getConfig()).readString(buffer)); + } + @Data public static class Simple { private String str; diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java index 2db9d5c842..25c3b61866 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java @@ -716,6 +716,33 @@ public void testEnumMap() { MapSerializers.EnumMapSerializer.class); } + @Test + public void testEmptyEnumMap() { + Fory fory = getJavaFory(); + EnumMap enumMap = new EnumMap<>(TestEnum.class); + Serializer serializer = fory.getSerializer(EnumMap.class); + MemoryBuffer buffer = MemoryUtils.buffer(64); + writeSerializer(fory, serializer, buffer, enumMap); + Assert.assertEquals(buffer.getByte(0), (byte) 0); + Assert.assertEquals(buffer.readerIndex(), 0); + EnumMap restored = readSerializer(fory, serializer, buffer); + Assert.assertEquals(restored, enumMap); + restored.put(TestEnum.A, "value"); + Assert.assertEquals(restored.get(TestEnum.A), "value"); + } + + @Test + public void testEnumMapHasNoPayloadMode() { + Fory fory = getJavaFory(); + Serializer serializer = fory.getSerializer(EnumMap.class); + EnumMap enumMap = new EnumMap<>(TestEnum.class); + enumMap.put(TestEnum.A, "value"); + MemoryBuffer buffer = MemoryUtils.buffer(64); + writeSerializer(fory, serializer, buffer, enumMap); + Assert.assertEquals(buffer.getByte(0), (byte) 1); + Assert.assertEquals(readSerializer(fory, serializer, buffer), enumMap); + } + @Test(dataProvider = "foryCopyConfig") public void testEnumMap(Fory fory) { EnumMap enumMap = new EnumMap<>(TestEnum.class);