Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Scala] support scala collection jit serialization #1077

Merged
merged 2 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@
import io.fury.serializer.Serializers;
import io.fury.serializer.StringSerializer;
import io.fury.serializer.collection.AbstractCollectionSerializer;
import io.fury.serializer.collection.CollectionSerializer;
import io.fury.serializer.collection.MapSerializer;
import io.fury.serializer.collection.AbstractMapSerializer;
import io.fury.type.ScalaTypes;
import io.fury.type.TypeUtils;
import io.fury.util.ReflectionUtils;
import io.fury.util.StringUtils;
Expand Down Expand Up @@ -114,8 +114,8 @@ public abstract class BaseObjectCodecBuilder extends CodecBuilder {
TypeToken.of(StringSerializer.class);
private static final TypeToken<?> SERIALIZER_TYPE = TypeToken.of(Serializer.class);
private static final TypeToken<?> COLLECTION_SERIALIZER_TYPE =
TypeToken.of(CollectionSerializer.class);
private static final TypeToken<?> MAP_SERIALIZER_TYPE = TypeToken.of(MapSerializer.class);
TypeToken.of(AbstractCollectionSerializer.class);
private static final TypeToken<?> MAP_SERIALIZER_TYPE = TypeToken.of(AbstractMapSerializer.class);

protected final Reference refResolverRef;
protected final Reference classResolverRef =
Expand Down Expand Up @@ -281,7 +281,7 @@ protected void addCommonImports() {
ctx.addImports(LazyInitBeanSerializer.class, Serializers.EnumSerializer.class);
ctx.addImports(Serializer.class, StringSerializer.class);
ctx.addImports(ObjectSerializer.class, CompatibleSerializer.class);
ctx.addImports(CollectionSerializer.class, MapSerializer.class, ObjectSerializer.class);
ctx.addImports(AbstractCollectionSerializer.class, AbstractMapSerializer.class);
}

protected Expression serializeFor(
Expand Down Expand Up @@ -402,11 +402,16 @@ private Expression serializeForNotNull(
}

protected boolean useCollectionSerialization(TypeToken<?> typeToken) {
return COLLECTION_TYPE.isSupertypeOf(typeToken);
return COLLECTION_TYPE.isSupertypeOf(typeToken)
|| (fury.getConfig().isScalaOptimizationEnabled()
&& (!ScalaTypes.getScalaMapType().isAssignableFrom(typeToken.getRawType())
&& ScalaTypes.getScalaIterableType().isAssignableFrom(typeToken.getRawType())));
}

protected boolean useMapSerialization(TypeToken<?> typeToken) {
return MAP_TYPE.isSupertypeOf(typeToken);
return MAP_TYPE.isSupertypeOf(typeToken)
|| (fury.getConfig().isScalaOptimizationEnabled()
&& ScalaTypes.getScalaMapType().isAssignableFrom(typeToken.getRawType()));
}

/**
Expand Down Expand Up @@ -639,7 +644,7 @@ protected Expression castSerializer(Expression serializer, TypeToken<?> objType)
serializer = new Cast(serializer, COLLECTION_SERIALIZER_TYPE, "colSerializer");
} else if (MAP_TYPE.isSupertypeOf(objType)
&& !MAP_SERIALIZER_TYPE.isSupertypeOf(serializer.type())) {
serializer = new Cast(serializer, TypeToken.of(MapSerializer.class), "mapSerializer");
serializer = new Cast(serializer, TypeToken.of(AbstractMapSerializer.class), "mapSerializer");
}
return serializer;
}
Expand Down Expand Up @@ -677,7 +682,7 @@ protected Expression serializeForCollection(
writeClassAction.add(
fury.getClassResolver().writeClassExpr(classResolverRef, buffer, classInfo));
serializer = new Invoke(classInfo, "getSerializer", "serializer", SERIALIZER_TYPE, false);
serializer = new Cast(serializer, TypeToken.of(CollectionSerializer.class));
serializer = new Cast(serializer, TypeToken.of(AbstractCollectionSerializer.class));
writeClassAction.add(serializer, new Return(serializer));
// Spit this into a separate method to avoid method too big to inline.
serializer =
Expand All @@ -688,8 +693,9 @@ protected Expression serializeForCollection(
"writeCollectionClassInfo",
false);
}
} else if (!TypeToken.of(CollectionSerializer.class).isSupertypeOf(serializer.type())) {
serializer = new Cast(serializer, TypeToken.of(CollectionSerializer.class), "colSerializer");
} else if (!TypeToken.of(AbstractCollectionSerializer.class).isSupertypeOf(serializer.type())) {
serializer =
new Cast(serializer, TypeToken.of(AbstractCollectionSerializer.class), "colSerializer");
}
// write collection data.
ListExpression actions = new ListExpression();
Expand Down Expand Up @@ -802,7 +808,7 @@ protected Expression writeCollectionData(

/**
* Write collection elements header: flags and maybe elements classinfo. Keep this consistent with
* `CollectionSerializer#writeElementsHeader`.
* `AbstractCollectionSerializer#writeElementsHeader`.
*
* @return Tuple(flags, Nullable ( element serializer))
*/
Expand Down Expand Up @@ -984,15 +990,15 @@ protected Expression serializeForMap(
writeClassAction.add(
fury.getClassResolver().writeClassExpr(classResolverRef, buffer, classInfo));
serializer = new Invoke(classInfo, "getSerializer", "serializer", SERIALIZER_TYPE, false);
serializer = new Cast(serializer, TypeToken.of(MapSerializer.class));
serializer = new Cast(serializer, TypeToken.of(AbstractMapSerializer.class));
writeClassAction.add(serializer, new Return(serializer));
// Spit this into a separate method to avoid method too big to inline.
serializer =
invokeGenerated(
ctx, ImmutableSet.of(buffer, map), writeClassAction, "writeMapClassInfo", false);
}
} else if (!MapSerializer.class.isAssignableFrom(serializer.type().getRawType())) {
serializer = new Cast(serializer, TypeToken.of(MapSerializer.class), "mapSerializer");
} else if (!AbstractMapSerializer.class.isAssignableFrom(serializer.type().getRawType())) {
serializer = new Cast(serializer, TypeToken.of(AbstractMapSerializer.class), "mapSerializer");
}
Expression write =
new If(
Expand Down Expand Up @@ -1200,12 +1206,15 @@ protected Expression deserializeForCollection(
Expression classInfo = readClassInfo(cls, buffer);
serializer = new Invoke(classInfo, "getSerializer", "serializer", SERIALIZER_TYPE, false);
serializer =
new Cast(serializer, TypeToken.of(CollectionSerializer.class), "collectionSerializer");
new Cast(
serializer,
TypeToken.of(AbstractCollectionSerializer.class),
"collectionSerializer");
}
} else {
checkArgument(
CollectionSerializer.class.isAssignableFrom(serializer.type().getRawType()),
"Expected CollectionSerializer but got %s",
AbstractCollectionSerializer.class.isAssignableFrom(serializer.type().getRawType()),
"Expected AbstractCollectionSerializer but got %s",
serializer.type());
}
Invoke supportHook = inlineInvoke(serializer, "supportCodegenHook", PRIMITIVE_BOOLEAN_TYPE);
Expand All @@ -1214,12 +1223,12 @@ protected Expression deserializeForCollection(
// if add branch by `ArrayList`, generated code will be > 325 bytes.
// and List#add is more likely be inlined if there is only one subclass.
Expression hookRead = readCollectionCodegen(buffer, collection, size, elementType);
hookRead = new Invoke(serializer, "onCollectionRead", COLLECTION_TYPE, hookRead);
hookRead = new Invoke(serializer, "onCollectionRead", OBJECT_TYPE, hookRead);
Expression action =
new If(
supportHook,
new ListExpression(collection, hookRead),
new Invoke(serializer, "read", COLLECTION_TYPE, buffer),
new Invoke(serializer, "read", OBJECT_TYPE, buffer),
false);
if (cutPoint != null && cutPoint.genNewMethod) {
cutPoint.add(buffer);
Expand Down Expand Up @@ -1428,12 +1437,13 @@ protected Expression deserializeForMap(
} else {
Expression classInfo = readClassInfo(cls, buffer);
serializer = new Invoke(classInfo, "getSerializer", SERIALIZER_TYPE);
serializer = new Cast(serializer, TypeToken.of(MapSerializer.class), "mapSerializer");
serializer =
new Cast(serializer, TypeToken.of(AbstractMapSerializer.class), "mapSerializer");
}
} else {
checkArgument(
MapSerializer.class.isAssignableFrom(serializer.type().getRawType()),
"Expected MapSerializer but got %s",
AbstractMapSerializer.class.isAssignableFrom(serializer.type().getRawType()),
"Expected AbstractMapSerializer but got %s",
serializer.type());
}
Invoke supportHook = inlineInvoke(serializer, "supportCodegenHook", PRIMITIVE_BOOLEAN_TYPE);
Expand Down Expand Up @@ -1466,9 +1476,9 @@ protected Expression deserializeForMap(
});
// first newMap to create map, last newMap as expr value
Expression hookRead = new ListExpression(newMap, size, readKeyValues, newMap);
hookRead = new Invoke(serializer, "onMapRead", MAP_TYPE, hookRead);
hookRead = new Invoke(serializer, "onMapRead", OBJECT_TYPE, hookRead);
Expression action =
new If(supportHook, hookRead, new Invoke(serializer, "read", MAP_TYPE, buffer), false);
new If(supportHook, hookRead, new Invoke(serializer, "read", OBJECT_TYPE, buffer), false);
if (cutPoint != null && cutPoint.genNewMethod) {
cutPoint.add(buffer);
return invokeGenerated(
Expand Down
76 changes: 76 additions & 0 deletions java/fury-core/src/main/java/io/fury/type/ScalaTypes.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2023 The Fury Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.fury.type;

import com.google.common.reflect.TypeToken;
import io.fury.collection.Tuple2;
import io.fury.util.ReflectionUtils;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;

/**
* Scala types utils using reflection without dependency on scala library.
*
* @author chaokunyang
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public class ScalaTypes {
private static final Class<?> SCALA_MAP_TYPE;
private static final Class<?> SCALA_SEQ_TYPE;
private static final Class<?> SCALA_ITERABLE_TYPE;
private static final Class<?> SCALA_ITERATOR_TYPE;
private static final java.lang.reflect.Type SCALA_ITERATOR_RETURN_TYPE;
private static final java.lang.reflect.Type SCALA_NEXT_RETURN_TYPE;

static {
try {
SCALA_ITERABLE_TYPE = ReflectionUtils.loadClass("scala.collection.Iterable");
SCALA_ITERATOR_TYPE = ReflectionUtils.loadClass("scala.collection.Iterator");
SCALA_MAP_TYPE = ReflectionUtils.loadClass("scala.collection.Map");
SCALA_SEQ_TYPE = ReflectionUtils.loadClass("scala.collection.Seq");
SCALA_ITERATOR_RETURN_TYPE = SCALA_ITERABLE_TYPE.getMethod("iterator").getGenericReturnType();
SCALA_NEXT_RETURN_TYPE = SCALA_ITERATOR_TYPE.getMethod("next").getGenericReturnType();
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
}

public static Class<?> getScalaMapType() {
return SCALA_MAP_TYPE;
}

public static Class<?> getScalaSeqType() {
return SCALA_SEQ_TYPE;
}

public static Class<?> getScalaIterableType() {
return SCALA_ITERABLE_TYPE;
}

public static TypeToken<?> getElementType(TypeToken typeToken) {
TypeToken<?> supertype = typeToken.getSupertype(getScalaIterableType());
return supertype.resolveType(SCALA_ITERATOR_RETURN_TYPE).resolveType(SCALA_NEXT_RETURN_TYPE);
}

/** Returns key/value type of scala map. */
public static Tuple2<TypeToken<?>, TypeToken<?>> getMapKeyValueType(TypeToken typeToken) {
TypeToken<?> kvTupleType = getElementType(typeToken);
ParameterizedType type = (ParameterizedType) kvTupleType.getType();
Type[] types = type.getActualTypeArguments();
return Tuple2.of(TypeToken.of(types[0]), TypeToken.of(types[1]));
}
}
34 changes: 7 additions & 27 deletions java/fury-core/src/main/java/io/fury/type/TypeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
*
* @author chaokunyang
*/
@SuppressWarnings("UnstableApiUsage")
@SuppressWarnings({"UnstableApiUsage", "unchecked"})
public class TypeUtils {
public static final String JAVA_BOOLEAN = "boolean";
public static final String JAVA_BYTE = "byte";
Expand Down Expand Up @@ -419,7 +419,9 @@ public static TypeToken<?> getElementType(TypeToken<?> typeToken) {
}
}
}
@SuppressWarnings("unchecked")
if (typeToken.getType().getTypeName().startsWith("scala.collection")) {
return ScalaTypes.getElementType(typeToken);
}
TypeToken<?> supertype =
((TypeToken<? extends Iterable<?>>) typeToken).getSupertype(Iterable.class);
return supertype.resolveType(ITERATOR_RETURN_TYPE).resolveType(NEXT_RETURN_TYPE);
Expand Down Expand Up @@ -448,6 +450,9 @@ public static Tuple2<TypeToken<?>, TypeToken<?>> getMapKeyValueType(TypeToken<?>
}
}
}
if (typeToken.getType().getTypeName().startsWith("scala.collection")) {
return ScalaTypes.getMapKeyValueType(typeToken);
}
@SuppressWarnings("unchecked")
TypeToken<?> supertype = ((TypeToken<? extends Map<?, ?>>) typeToken).getSupertype(Map.class);
TypeToken<?> keyType = getElementType(supertype.resolveType(KEY_SET_RETURN_TYPE));
Expand Down Expand Up @@ -691,29 +696,4 @@ public static List<TypeToken<?>> getAllTypeArguments(TypeToken typeToken) {

return new ArrayList<>(allTypeArguments);
}

private static volatile TypeToken<?> scalaMapType;
private static volatile TypeToken<?> scalaSeqType;
private static volatile TypeToken<?> scalaIterableType;

public static TypeToken<?> getScalaMapType() {
if (scalaMapType == null) {
scalaMapType = TypeToken.of(ReflectionUtils.loadClass("scala.collection.Map"));
}
return scalaMapType;
}

public static TypeToken<?> getScalaSeqType() {
if (scalaSeqType == null) {
scalaSeqType = TypeToken.of(ReflectionUtils.loadClass("scala.collection.Seq"));
}
return scalaSeqType;
}

public static TypeToken<?> getScalaIterableType() {
if (scalaIterableType == null) {
scalaIterableType = TypeToken.of(ReflectionUtils.loadClass("scala.collection.Iterable"));
}
return scalaIterableType;
}
}
22 changes: 11 additions & 11 deletions scala/src/main/java/io/fury/serializer/scala/ScalaDispatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,20 @@ public class ScalaDispatcher implements SerializerFactory {
*/
@Override
public Serializer createSerializer(Fury fury, Class<?> clz) {
// Many map/seq/set types doesn't extends DefaultSerializable.
if (scala.collection.SortedMap.class.isAssignableFrom(clz)) {
return new ScalaSortedMapSerializer(fury, clz);
} else if (scala.collection.Map.class.isAssignableFrom(clz)) {
return new ScalaMapSerializer(fury, clz);
} else if (scala.collection.SortedSet.class.isAssignableFrom(clz)) {
return new ScalaSortedSetSerializer(fury, clz);
} else if (scala.collection.Seq.class.isAssignableFrom(clz)) {
return new ScalaSeqSerializer(fury, clz);
}
if (DefaultSerializable.class.isAssignableFrom(clz)) {
Method replaceMethod = JavaSerializer.getWriteReplaceMethod(clz);
Preconditions.checkNotNull(replaceMethod);
if (scala.collection.SortedMap.class.isAssignableFrom(clz)) {
return new ScalaSortedMapSerializer(fury, clz);
} else if (scala.collection.Map.class.isAssignableFrom(clz)) {
return new ScalaMapSerializer(fury, clz);
} else if (scala.collection.SortedSet.class.isAssignableFrom(clz)) {
return new ScalaSortedSetSerializer(fury, clz);
} else if (scala.collection.Seq.class.isAssignableFrom(clz)) {
return new ScalaSeqSerializer(fury, clz);
} else {
return new ScalaCollectionSerializer(fury, clz);
}
return new ScalaCollectionSerializer(fury, clz);
}
return null;
}
Expand Down
Loading