Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ jobs:
key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
restore-keys: |
${{ runner.os }}-maven-
- uses: sbt/setup-sbt@1cad58d595b729a71ca2254cdf5b43dd6f42d4bb # v1.1.18
- uses: sbt/setup-sbt@2e222825582620cc38d2a54e674f3c01b7c14f5d # v1.1.24
- name: Install fory java
run: cd java && mvn -T10 --no-transfer-progress clean install -DskipTests -Dmaven.javadoc.skip=true -Dmaven.source.skip=true && cd -
- name: Test
Expand Down Expand Up @@ -622,7 +622,7 @@ jobs:
key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
restore-keys: |
${{ runner.os }}-maven-
- uses: sbt/setup-sbt@1cad58d595b729a71ca2254cdf5b43dd6f42d4bb # v1.1.18
- uses: sbt/setup-sbt@2e222825582620cc38d2a54e674f3c01b7c14f5d # v1.1.24
- name: Run Scala Xlang Test
env:
FORY_SCALA_JAVA_CI: "1"
Expand Down Expand Up @@ -655,7 +655,7 @@ jobs:
key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }}
restore-keys: |
${{ runner.os }}-maven-
- uses: sbt/setup-sbt@1cad58d595b729a71ca2254cdf5b43dd6f42d4bb # v1.1.18
- uses: sbt/setup-sbt@2e222825582620cc38d2a54e674f3c01b7c14f5d # v1.1.24
- name: Install Fory Java
run: |
cd java
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ public MemoryBuffer readBufferObject() {
if (size < 0) {
throw new IllegalArgumentException("Buffer object size must be non-negative: " + size);
}
// This returns a zero-copy slice. Allocation limits belong to serializers which allocate
// objects from the slice, not to the buffer-object transport itself.
buffer.checkReadableBytes(size);
int readerIndex = buffer.readerIndex();
MemoryBuffer slice = buffer.slice(readerIndex, size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.io.StreamCorruptedException;
import java.lang.reflect.Proxy;
import org.apache.fory.resolver.TypeResolver;

// Derived from
// https://github.com/apache/commons-io/blob/5168fa5e9de9dd2ff6ace3f34226397a4faebc14/src/main/java/org/apache/commons/io/input/ClassLoaderObjectInputStream.java.
Expand All @@ -38,6 +40,8 @@ public class ClassLoaderObjectInputStream extends ObjectInputStream {
/** The class loader to use. */
private final ClassLoader classLoader;

private final TypeResolver typeResolver;

/**
* Constructs a new ClassLoaderObjectInputStream.
*
Expand All @@ -50,6 +54,14 @@ public ClassLoaderObjectInputStream(ClassLoader classLoader, InputStream inputSt
throws IOException, StreamCorruptedException {
super(inputStream);
this.classLoader = classLoader;
typeResolver = null;
}

public ClassLoaderObjectInputStream(TypeResolver typeResolver, InputStream inputStream)
throws IOException, StreamCorruptedException {
super(inputStream);
this.typeResolver = typeResolver;
classLoader = typeResolver.getClassLoader();
}

/**
Expand All @@ -67,10 +79,13 @@ protected Class<?> resolveClass(ObjectStreamClass objectStreamClass)
Class<?> clazz = Class.forName(objectStreamClass.getName(), false, classLoader);
if (clazz != null) {
// the classloader knows of the class
checkClass(clazz);
return clazz;
} else {
// classloader knows not of class, let the super classloader do it
return super.resolveClass(objectStreamClass);
Class<?> superClass = super.resolveClass(objectStreamClass);
checkClass(superClass);
return superClass;
}
}

Expand All @@ -91,11 +106,27 @@ protected Class<?> resolveProxyClass(String[] interfaces)
Class<?>[] interfaceClasses = new Class[interfaces.length];
for (int i = 0; i < interfaces.length; i++) {
interfaceClasses[i] = Class.forName(interfaces[i], false, classLoader);
checkClass(interfaceClasses[i]);
}
try {
return Proxy.getProxyClass(classLoader, interfaceClasses);
} catch (IllegalArgumentException e) {
return super.resolveProxyClass(interfaces);
Class<?> proxyClass = super.resolveProxyClass(interfaces);
checkClass(proxyClass);
return proxyClass;
}
}

private void checkClass(Class<?> cls) throws InvalidClassException {
if (typeResolver == null) {
return;
}
try {
typeResolver.checkClassForDeserialization(cls);
} catch (RuntimeException e) {
InvalidClassException exception = new InvalidClassException(cls.getName(), e.getMessage());
exception.initCause(e);
throw exception;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.fory.collection.UInt32List;
import org.apache.fory.collection.UInt64List;
import org.apache.fory.collection.UInt8List;
import org.apache.fory.exception.DeserializationException;
import org.apache.fory.logging.Logger;
import org.apache.fory.logging.LoggerFactory;
import org.apache.fory.memory.MemoryBuffer;
Expand All @@ -67,6 +68,7 @@

public class FieldTypes {
private static final Logger LOG = LoggerFactory.getLogger(FieldTypes.class);
private static final int MAX_ARRAY_DIMS = 255;

/** Returns true if can use current field type. */
static boolean useFieldType(Class<?> parsedType, Descriptor descriptor) {
Expand Down Expand Up @@ -525,6 +527,9 @@ public static FieldType read(
return new CollectionFieldType(-1, nullable, trackingRef, read(buffer, resolver));
} else if (kind == 3) {
int dims = buffer.readVarUInt32Small7();
if (dims <= 0 || dims > MAX_ARRAY_DIMS) {
throw new DeserializationException("Invalid array dimensions in TypeDef: " + dims);
}
return new ArrayFieldType(-1, nullable, trackingRef, read(buffer, resolver), dims);
} else if (kind == 4) {
return new EnumFieldType(nullable, -1, -1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer,
int rootTypeId = nativeTypeId(bodyHeader >>> 4);
int numClasses = bodyHeader & NUM_CLASS_THRESHOLD;
if (numClasses == NUM_CLASS_THRESHOLD) {
numClasses += typeDefBuf.readVarUInt32Small7();
int extraClasses = typeDefBuf.readVarUInt32Small7();
if (extraClasses < 0 || extraClasses > Integer.MAX_VALUE - NUM_CLASS_THRESHOLD - 1) {
throw new DeserializationException("Invalid TypeDef class count");
}
numClasses += extraClasses;
}
numClasses += 1;
String className;
Expand All @@ -95,6 +99,9 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer,
// | num fields + register flag | header + package name | header + class name
// | header + type id + field name | next field info | ... |
int currentClassHeader = typeDefBuf.readVarUInt32Small7();
if (currentClassHeader < 0) {
throw new DeserializationException("Invalid TypeDef field count");
}
boolean isRegistered = (currentClassHeader & 0b1) != 0;
int numFields = currentClassHeader >>> 1;
Class<?> currentClass = null;
Expand Down Expand Up @@ -269,7 +276,7 @@ static void validateParsedTypeDefHash(long id, byte[] encoded) {

private static List<FieldInfo> readFieldsInfo(
MemoryBuffer buffer, ClassResolver resolver, String className, int numFields) {
List<FieldInfo> fieldInfos = new ArrayList<>(numFields);
List<FieldInfo> fieldInfos = new ArrayList<>();
for (int i = 0; i < numFields; i++) {
int header = buffer.readByte() & 0xff;
// `3 bits size + 2 bits field name encoding + nullability flag + ref tracking flag`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBu
}
numFields = header & SMALL_NUM_FIELDS_THRESHOLD;
if (numFields == SMALL_NUM_FIELDS_THRESHOLD) {
numFields += buffer.readVarUInt32Small7();
int extraFields = buffer.readVarUInt32Small7();
if (extraFields < 0 || extraFields > Integer.MAX_VALUE - SMALL_NUM_FIELDS_THRESHOLD) {
throw new DeserializationException("Invalid TypeDef field count");
}
numFields += extraFields;
}
if (named) {
String namespace = readPkgName(buffer);
Expand Down Expand Up @@ -201,7 +205,7 @@ static int nonStructTypeId(int kindCode) {
// | header + type info + field name | ... | header + type info + field name |
private static List<FieldInfo> readFieldsInfo(
MemoryBuffer buffer, XtypeResolver resolver, String className, int numFields) {
List<FieldInfo> fieldInfos = new ArrayList<>(numFields);
List<FieldInfo> fieldInfos = new ArrayList<>();
for (int i = 0; i < numFields; i++) {
// header: 2 bits field name encoding + 4 bits size + nullability flag + ref tracking flag
byte header = buffer.readByte();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.io.IOException;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
Expand All @@ -57,6 +58,9 @@
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TimeZone;
import java.util.TreeMap;
import java.util.TreeSet;
Expand Down Expand Up @@ -1908,13 +1912,39 @@ private boolean isSecure(Class<?> cls) {
if (config.requireClassRegistration()) {
return Functions.isLambda(cls)
|| ReflectionUtils.isJdkProxy(cls)
|| isDefaultSafeClassToken(cls)
|| extRegistry.registeredClassIdMap.get(cls) != null
|| shimDispatcher.contains(cls);
} else {
return extRegistry.typeChecker.checkType(this, cls.getName());
}
}

private static boolean isDefaultSafeClassToken(Class<?> cls) {
return cls == Serializable.class || cls == Externalizable.class || isDefaultSafeInterface(cls);
}

private static boolean isDefaultSafeInterface(Class<?> cls) {
return cls.isInterface()
&& (cls == Collection.class
|| cls == List.class
|| cls == Set.class
|| cls == Map.class
|| cls == SortedMap.class
|| cls == SortedSet.class
|| cls.getName().startsWith("java.util.function.")
|| !hasDefaultMethods(cls));
}

private static boolean hasDefaultMethods(Class<?> cls) {
for (Method method : cls.getMethods()) {
if (method.isDefault()) {
return true;
}
}
return false;
}

/**
* Write class info to <code>buffer</code>. TODO(chaokunyang): The method should try to write
* aligned data to reduce cpu instruction overhead. `writeTypeInfo` is the last step before
Expand Down Expand Up @@ -2025,11 +2055,22 @@ private TypeInfo buildClassInfo(Class<?> cls) {
}

/**
* Read serialized java classname. Note that the object of the class can be non-serializable. For
* serializable object, {@link #readTypeInfo(ReadContext)} or {@link #readTypeInfo(ReadContext,
* TypeInfoHolder)} should be invoked.
* Read a serialized Java class token.
*
* <p>For named-class tokens, this method enforces deserialization class policy before returning
* the class, including the disallowed list and registration or TypeChecker checks. For registered
* type-id tokens, it returns the registered class whose admission was already checked during
* registration.
*
* <p>Note that the object of the class can be non-serializable. For serializable object, {@link
* #readTypeInfo(ReadContext)} or {@link #readTypeInfo(ReadContext, TypeInfoHolder)} should be
* invoked.
*/
public Class<?> readClassInternal(ReadContext readContext) {
return readClassInternal(readContext, true);
}

private Class<?> readClassInternal(ReadContext readContext, boolean checkNamedClass) {
MemoryBuffer buffer = readContext.getBuffer();
int header = buffer.readVarUInt32Small14();
if ((header & 0b1) != 0) {
Expand All @@ -2038,7 +2079,7 @@ public Class<?> readClassInternal(ReadContext readContext) {
MetaStringReader metaStringReader = readContext.getMetaStringReader();
EncodedMetaString packageBytes = metaStringReader.readMetaStringWithFlag(buffer, header);
EncodedMetaString simpleClassNameBytes = metaStringReader.readMetaString(buffer);
return loadBytesToTypeInfo(packageBytes, simpleClassNameBytes).type;
return loadBytesToTypeInfo(packageBytes, simpleClassNameBytes, checkNamedClass).type;
}
int typeId = header >>> 1;
switch (typeId) {
Expand All @@ -2055,6 +2096,11 @@ public Class<?> readClassInternal(ReadContext readContext) {
}
}

@Internal
public Class<?> readClassInternalUnchecked(ReadContext readContext) {
return readClassInternal(readContext, false);
}

private TypeInfo getTypeInfoByTypeIdForReadClassInternal(int typeId, int userTypeId) {
TypeInfo typeInfo;
if (userTypeId != INVALID_USER_TYPE_ID) {
Expand All @@ -2069,10 +2115,16 @@ private TypeInfo getTypeInfoByTypeIdForReadClassInternal(int typeId, int userTyp
@Override
protected TypeInfo loadBytesToTypeInfo(
EncodedMetaString packageBytes, EncodedMetaString simpleClassNameBytes) {
return loadBytesToTypeInfo(packageBytes, simpleClassNameBytes, true);
}

private TypeInfo loadBytesToTypeInfo(
EncodedMetaString packageBytes, EncodedMetaString simpleClassNameBytes, boolean checkClass) {
TypeNameBytes typeNameBytes = new TypeNameBytes(packageBytes, simpleClassNameBytes);
TypeInfo typeInfo = compositeNameBytes2TypeInfo.get(typeNameBytes);
if (typeInfo == null) {
typeInfo = populateBytesToTypeInfo(typeNameBytes, packageBytes, simpleClassNameBytes);
typeInfo =
populateBytesToTypeInfo(typeNameBytes, packageBytes, simpleClassNameBytes, checkClass);
}
// Note: Don't create serializer here - this method is used by both readTypeInfo
// (which needs serializer) and readClassInternal (which doesn't need serializer).
Expand Down Expand Up @@ -2102,28 +2154,47 @@ protected TypeInfo ensureSerializerForTypeInfo(TypeInfo typeInfo) {
private TypeInfo populateBytesToTypeInfo(
TypeNameBytes typeNameBytes,
EncodedMetaString packageBytes,
EncodedMetaString simpleClassNameBytes) {
EncodedMetaString simpleClassNameBytes,
boolean checkClass) {
String packageName = packageBytes.decode(PACKAGE_DECODER);
String className = simpleClassNameBytes.decode(TYPE_NAME_DECODER);
ClassSpec classSpec = Encoders.decodePkgAndClass(packageName, className);
Class<?> cls = loadClass(classSpec.entireClassName, classSpec.isEnum, classSpec.dimension);
boolean unknownClass = UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls));
if (checkClass && !unknownClass) {
checkClassForDeserialization(cls);
}
int typeId = buildUnregisteredTypeId(cls, null);
TypeInfo typeInfo =
new TypeInfo(cls, packageBytes, simpleClassNameBytes, null, typeId, INVALID_USER_TYPE_ID);
if (UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls))) {
if (unknownClass) {
typeInfo.serializer =
UnknownClassSerializers.getSerializer(this, classSpec.entireClassName, cls);
} else {
} else if (checkClass) {
// don't create serializer here, if the class is an interface,
// there won't be serializer since interface has no instance.
if (!classInfoMap.containsKey(cls)) {
classInfoMap.put(cls, typeInfo);
}
}
compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo);
if (checkClass) {
compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo);
}
return typeInfo;
}

@Internal
@Override
public void checkClassForDeserialization(Class<?> cls) {
if (UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls))) {
return;
}
DisallowedList.checkNotInDisallowedList(cls.getName());
if (!isSecure(cls)) {
throw new InsecureException(generateSecurityMsg(cls));
}
}

public Class<?> loadClassForMeta(String className, boolean isEnum, int arrayDims) {
String pkg = ReflectionUtils.getPackage(className);
String typeName = ReflectionUtils.getClassNameWithoutPackage(className);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ public final ClassLoader getClassLoader() {
return extRegistry.classLoader;
}

@Internal
public void checkClassForDeserialization(Class<?> cls) {}

public final SharedRegistry getSharedRegistry() {
return sharedRegistry;
}
Expand Down Expand Up @@ -1011,6 +1014,8 @@ final TypeInfo buildMetaSharedTypeInfo(TypeDef typeDef) {
return typeInfo;
}
Class<?> cls = loadClass(typeDef.getClassSpec());
// A wire TypeDef may create a compatible serializer; admit the class before caching it by id.
checkClassForDeserialization(cls);
if (!typeDef.isStructSchemaKind()
&& !UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls))) {
typeInfo = getTypeInfo(cls);
Expand Down
Loading