From 13eac8e823da2429b5dad304e8a442e79a53989a Mon Sep 17 00:00:00 2001 From: twalthr Date: Wed, 12 Oct 2016 10:33:47 +0200 Subject: [PATCH] [FLINK-4801] [types] Input type inference is faulty with custom Tuples and RichFunctions --- .../java/typeutils/TypeExtractionUtils.java | 38 ++++++ .../api/java/typeutils/TypeExtractor.java | 128 +++++------------- .../typeutils/runtime/kryo/Serializers.java | 6 +- .../api/java/typeutils/TypeExtractorTest.java | 37 ++++- 4 files changed, 114 insertions(+), 95 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java index 44396123bf07c..0aac257aa84c6 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java @@ -20,7 +20,9 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -81,6 +83,12 @@ public boolean executablesEquals(Constructor c) { } } + /** + * Checks if the given function has been implemented using a Java 8 lambda. If yes, a LambdaExecutable + * is returned describing the method/constructor. Otherwise null. + * + * @throws TypeExtractionException lambda extraction is pretty hacky, it might fail for unknown JVM issues. + */ public static LambdaExecutable checkAndExtractLambda(Function function) throws TypeExtractionException { try { // get serialized lambda @@ -164,4 +172,34 @@ public static List getAllDeclaredMethods(Class clazz) { } return result; } + + /** + * Convert ParameterizedType or Class to a Class. 
+ */ + public static Class typeToClass(Type t) { + if (t instanceof Class) { + return (Class)t; + } + else if (t instanceof ParameterizedType) { + return ((Class)((ParameterizedType) t).getRawType()); + } + throw new IllegalArgumentException("Cannot convert type to class"); + } + + /** + * Checks if a type can be converted to a Class. This is true for ParameterizedType and Class. + */ + public static boolean isClassType(Type t) { + return t instanceof Class || t instanceof ParameterizedType; + } + + /** + * Checks whether two types are type variables describing the same type variable (same name and same generic declaration). + */ + public static boolean sameTypeVars(Type t1, Type t2) { + return t1 instanceof TypeVariable && + t2 instanceof TypeVariable && + ((TypeVariable) t1).getName().equals(((TypeVariable) t2).getName()) && + ((TypeVariable) t1).getGenericDeclaration().equals(((TypeVariable) t2).getGenericDeclaration()); + } } diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java index c1febeaa9ae22..b41bbc1562b41 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java @@ -68,6 +68,9 @@ import org.apache.flink.api.java.typeutils.TypeExtractionUtils.LambdaExecutable; import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.checkAndExtractLambda; import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.getAllDeclaredMethods; +import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.isClassType; +import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.sameTypeVars; +import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.typeToClass; import org.apache.flink.types.Either; import org.apache.flink.types.Value; import org.apache.flink.util.InstantiationUtil; @@ -859,6 +862,14 @@ private TypeInformation 
createTypeInfoFromInputs(TypeVariable r return null; } + /** + * Finds the type information for a type variable. + * + * It solves the following problem: + * + * Return the type information for "returnTypeVar" given that "inType" has type information "inTypeInfo". + * Thus "inType" must contain "returnTypeVar" in an "inputTypeHierarchy", otherwise null is returned. + */ @SuppressWarnings({"unchecked", "rawtypes"}) private TypeInformation createTypeInfoFromInput(TypeVariable returnTypeVar, ArrayList inputTypeHierarchy, Type inType, TypeInformation inTypeInfo) { TypeInformation info = null; @@ -891,9 +902,14 @@ private TypeInformation createTypeInfoFromInput(TypeVariable returnT } } // the input is a type variable + else if (sameTypeVars(inType, returnTypeVar)) { + return inTypeInfo; + } else if (inType instanceof TypeVariable) { - inType = materializeTypeVariable(inputTypeHierarchy, (TypeVariable) inType); - info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo, inputTypeHierarchy); + Type resolvedInType = materializeTypeVariable(inputTypeHierarchy, (TypeVariable) inType); + if (resolvedInType != inType) { + info = createTypeInfoFromInput(returnTypeVar, inputTypeHierarchy, resolvedInType, inTypeInfo); + } } // input is an array else if (inType instanceof GenericArrayType) { @@ -910,8 +926,7 @@ else if (inTypeInfo instanceof ObjectArrayTypeInfo) { info = createTypeInfoFromInput(returnTypeVar, inputTypeHierarchy, ((GenericArrayType) inType).getGenericComponentType(), componentInfo); } // the input is a tuple - else if (inTypeInfo instanceof TupleTypeInfo && isClassType(inType) - && Tuple.class.isAssignableFrom(typeToClass(inType))) { + else if (inTypeInfo instanceof TupleTypeInfo && isClassType(inType) && Tuple.class.isAssignableFrom(typeToClass(inType))) { ParameterizedType tupleBaseClass; // get tuple from possible tuple subclass @@ -935,10 +950,25 @@ else if (inTypeInfo instanceof TupleTypeInfo && isClassType(inType) } } // the input is a pojo - else if (inTypeInfo 
instanceof PojoTypeInfo) { + else if (inTypeInfo instanceof PojoTypeInfo && isClassType(inType)) { // build the entire type hierarchy for the pojo getTypeHierarchy(inputTypeHierarchy, inType, Object.class); - info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo, inputTypeHierarchy); + // determine a field containing the type variable + List fields = getAllDeclaredFields(typeToClass(inType)); + for (Field field : fields) { + Type fieldType = field.getGenericType(); + if (fieldType instanceof TypeVariable && sameTypeVars(returnTypeVar, materializeTypeVariable(inputTypeHierarchy, (TypeVariable) fieldType))) { + return getTypeOfPojoField(inTypeInfo, field); + } + else if (fieldType instanceof ParameterizedType || fieldType instanceof GenericArrayType) { + ArrayList typeHierarchyWithFieldType = new ArrayList<>(inputTypeHierarchy); + typeHierarchyWithFieldType.add(fieldType); + TypeInformation foundInfo = createTypeInfoFromInput(returnTypeVar, typeHierarchyWithFieldType, fieldType, getTypeOfPojoField(inTypeInfo, field)); + if (foundInfo != null) { + return foundInfo; + } + } + } } return info; } @@ -1557,66 +1587,7 @@ else if (primitiveClass == short.class) { } throw new InvalidTypesException(); } - - private static TypeInformation findCorrespondingInfo(TypeVariable typeVar, Type type, TypeInformation corrInfo, ArrayList typeHierarchy) { - if (sameTypeVars(type, typeVar)) { - return corrInfo; - } - else if (type instanceof TypeVariable && sameTypeVars(materializeTypeVariable(typeHierarchy, (TypeVariable) type), typeVar)) { - return corrInfo; - } - else if (type instanceof GenericArrayType) { - TypeInformation componentInfo = null; - if (corrInfo instanceof BasicArrayTypeInfo) { - componentInfo = ((BasicArrayTypeInfo) corrInfo).getComponentInfo(); - } - else if (corrInfo instanceof PrimitiveArrayTypeInfo) { - componentInfo = BasicTypeInfo.getInfoFor(corrInfo.getTypeClass().getComponentType()); - } - else if (corrInfo instanceof ObjectArrayTypeInfo) { - 
componentInfo = ((ObjectArrayTypeInfo) corrInfo).getComponentInfo(); - } - TypeInformation info = findCorrespondingInfo(typeVar, ((GenericArrayType) type).getGenericComponentType(), componentInfo, typeHierarchy); - if (info != null) { - return info; - } - } - else if (corrInfo instanceof TupleTypeInfo - && type instanceof ParameterizedType - && Tuple.class.isAssignableFrom((Class) ((ParameterizedType) type).getRawType())) { - ParameterizedType tuple = (ParameterizedType) type; - Type[] args = tuple.getActualTypeArguments(); - - for (int i = 0; i < args.length; i++) { - TypeInformation info = findCorrespondingInfo(typeVar, args[i], ((TupleTypeInfo) corrInfo).getTypeAt(i), typeHierarchy); - if (info != null) { - return info; - } - } - } - else if (corrInfo instanceof PojoTypeInfo && isClassType(type)) { - // determine a field containing the type variable - List fields = getAllDeclaredFields(typeToClass(type)); - for (Field field : fields) { - Type fieldType = field.getGenericType(); - if (fieldType instanceof TypeVariable - && sameTypeVars(typeVar, materializeTypeVariable(typeHierarchy, (TypeVariable) fieldType))) { - return getTypeOfPojoField(corrInfo, field); - } - else if (fieldType instanceof ParameterizedType - || fieldType instanceof GenericArrayType) { - ArrayList typeHierarchyWithFieldType = new ArrayList(typeHierarchy); - typeHierarchyWithFieldType.add(fieldType); - TypeInformation info = findCorrespondingInfo(typeVar, fieldType, getTypeOfPojoField(corrInfo, field), typeHierarchyWithFieldType); - if (info != null) { - return info; - } - } - } - } - return null; - } - + /** * Tries to find a concrete value (Class, ParameterizedType etc. ) for a TypeVariable by traversing the type hierarchy downwards. * If a value could not be found it will return the most bottom type variable in the hierarchy. 
@@ -1991,30 +1962,6 @@ private static boolean hasFieldWithSameName(String name, List fields) { } return false; } - - @Internal - public static Class typeToClass(Type t) { - if (t instanceof Class) { - return (Class)t; - } - else if (t instanceof ParameterizedType) { - return ((Class)((ParameterizedType) t).getRawType()); - } - throw new IllegalArgumentException("Cannot convert type to class"); - } - - @Internal - public static boolean isClassType(Type t) { - return t instanceof Class || t instanceof ParameterizedType; - } - - private static boolean sameTypeVars(Type t1, Type t2) { - if (!(t1 instanceof TypeVariable) || !(t2 instanceof TypeVariable)) { - return false; - } - return ((TypeVariable) t1).getName().equals(((TypeVariable) t2).getName()) - && ((TypeVariable) t1).getGenericDeclaration().equals(((TypeVariable) t2).getGenericDeclaration()); - } private static TypeInformation getTypeOfPojoField(TypeInformation pojoInfo, Field field) { for (int j = 0; j < pojoInfo.getArity(); j++) { @@ -2026,7 +1973,6 @@ private static TypeInformation getTypeOfPojoField(TypeInformation pojoInfo return null; } - public static TypeInformation getForObject(X value) { return new TypeExtractor().privateGetForObject(value); } diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java index b6e978fdb0e2f..4976d6a78e531 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/Serializers.java @@ -34,7 +34,7 @@ import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; -import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.api.java.typeutils.TypeExtractionUtils; import 
java.io.Serializable; import java.lang.reflect.Field; @@ -113,8 +113,8 @@ private static void recursivelyRegisterGenericType(Type fieldType, ExecutionConf ParameterizedType parameterizedFieldType = (ParameterizedType) fieldType; for (Type t: parameterizedFieldType.getActualTypeArguments()) { - if (TypeExtractor.isClassType(t) ) { - recursivelyRegisterType(TypeExtractor.typeToClass(t), config, alreadySeen); + if (TypeExtractionUtils.isClassType(t) ) { + recursivelyRegisterType(TypeExtractionUtils.typeToClass(t), config, alreadySeen); } } diff --git a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java index 443cbc3fc9599..55cd42debfb66 100644 --- a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/TypeExtractorTest.java @@ -30,12 +30,14 @@ import java.util.Map; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.InvalidTypesException; +import org.apache.flink.api.common.functions.JoinFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.RichCoGroupFunction; import org.apache.flink.api.common.functions.RichCrossFunction; import org.apache.flink.api.common.functions.RichFlatJoinFunction; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.functions.RichJoinFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; @@ -1665,7 +1667,40 @@ public void testInputInference4() { Assert.assertTrue(ti.isBasicType()); Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, ti); } - + + public static class 
CustomTuple2WithArray extends Tuple2 { + + public CustomTuple2WithArray() { + // default constructor + } + } + + public class JoinWithCustomTuple2WithArray extends RichJoinFunction, CustomTuple2WithArray, CustomTuple2WithArray> { + + @Override + public CustomTuple2WithArray join(CustomTuple2WithArray first, CustomTuple2WithArray second) throws Exception { + return null; + } + } + + @Test + public void testInputInferenceWithCustomTupleAndRichFunction() { + JoinFunction, CustomTuple2WithArray, CustomTuple2WithArray> function = new JoinWithCustomTuple2WithArray<>(); + + TypeInformation ti = TypeExtractor.getJoinReturnTypes( + function, + new TypeHint>(){}.getTypeInfo(), + new TypeHint>(){}.getTypeInfo()); + + Assert.assertTrue(ti.isTupleType()); + TupleTypeInfo tti = (TupleTypeInfo) ti; + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, tti.getTypeAt(1)); + + Assert.assertTrue(tti.getTypeAt(0) instanceof ObjectArrayTypeInfo); + ObjectArrayTypeInfo oati = (ObjectArrayTypeInfo) tti.getTypeAt(0); + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, oati.getComponentInfo()); + } + public static enum MyEnum { ONE, TWO, THREE }