From 391cc8fde82a38797c0355031f156f8914845f1b Mon Sep 17 00:00:00 2001 From: twalthr Date: Tue, 13 Jan 2015 23:59:35 +0100 Subject: [PATCH] [FLINK-1147][Java API] TypeInference on POJOs --- .../api/java/typeutils/TypeExtractor.java | 189 +++++++++++++---- .../api/java/typeutils/TypeInfoParser.java | 7 +- .../extractor/PojoTypeExtractionTest.java | 196 +++++++++++++++++- .../type/extractor/TypeExtractorTest.java | 20 +- 4 files changed, 366 insertions(+), 46 deletions(-) diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java index edff09c3332b5..a1f5dd6fbcd93 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java @@ -422,7 +422,12 @@ private TypeInformation createTypeInfoWithTypeHierarchy(Arr int fieldCount = countFieldsInClass(tAsClass); if(fieldCount != tupleSubTypes.length) { // the class is not a real tuple because it contains additional fields. treat as a pojo - return (TypeInformation) analyzePojo(tAsClass, new ArrayList(typeHierarchy), null); // the typeHierarchy here should be sufficient, even though it stops at the Tuple.class. + if (t instanceof ParameterizedType) { + return (TypeInformation) analyzePojo(tAsClass, new ArrayList(typeHierarchy), (ParameterizedType) t, in1Type, in2Type); + } + else { + return (TypeInformation) analyzePojo(tAsClass, new ArrayList(typeHierarchy), null, in1Type, in2Type); + } } return new TupleTypeInfo(tAsClass, tupleSubTypes); @@ -482,9 +487,9 @@ else if (t instanceof GenericArrayType) { in1Type, in2Type); return ObjectArrayTypeInfo.getInfoFor(t, componentInfo); } - // objects with generics are treated as raw type - else if (t instanceof ParameterizedType) { //TODO - return privateGetForClass((Class) ((ParameterizedType) t).getRawType(), typeHierarchy, (ParameterizedType) t); + // objects with generics are treated as Class first + else if (t instanceof ParameterizedType) { + return (TypeInformation) privateGetForClass(typeToClass(t), typeHierarchy, (ParameterizedType) t, in1Type, in2Type); } // no tuple, no TypeVariable, no generic type else if (t instanceof Class) { @@ -553,10 +558,25 @@ private TypeInformation createTypeInfoFromInput(TypeVariable returnT // the input is a type variable if (inType instanceof TypeVariable) { inType = materializeTypeVariable(inputTypeHierarchy, (TypeVariable) inType); - info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo); + info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo, inputTypeHierarchy); + } + // input is an array + else if (inType instanceof GenericArrayType) { + TypeInformation componentInfo = null; + if (inTypeInfo instanceof BasicArrayTypeInfo) { + componentInfo = ((BasicArrayTypeInfo) inTypeInfo).getComponentInfo(); + } + else if (inTypeInfo instanceof PrimitiveArrayTypeInfo) { + componentInfo = BasicTypeInfo.getInfoFor(inTypeInfo.getTypeClass().getComponentType()); + } + else if (inTypeInfo instanceof ObjectArrayTypeInfo) { + componentInfo = ((ObjectArrayTypeInfo) inTypeInfo).getComponentInfo(); + } + info = createTypeInfoFromInput(returnTypeVar, inputTypeHierarchy, ((GenericArrayType) inType).getGenericComponentType(), componentInfo); } - // the input is a tuple that may contains type variables - else if (isClassType(inType) && Tuple.class.isAssignableFrom(typeToClass(inType))) { + // the input is a tuple + else if (inTypeInfo instanceof TupleTypeInfo && isClassType(inType) + && Tuple.class.isAssignableFrom(typeToClass(inType))) { ParameterizedType tupleBaseClass = null; // get tuple from possible tuple subclass @@ -579,6 +599,12 @@ else if (isClassType(inType) && Tuple.class.isAssignableFrom(typeToClass(inType) } } } + // the input is a pojo + else if (inTypeInfo instanceof PojoTypeInfo) { + // build the entire type hierarchy for the pojo + getTypeHierarchy(inputTypeHierarchy, inType, Object.class); + info = findCorrespondingInfo(returnTypeVar, inType, inTypeInfo, inputTypeHierarchy); + } return info; } @@ -841,7 +867,7 @@ else if (typeInfo instanceof GenericTypeInfo) { * @param curT : start type * @return Type The immediate child of the top class */ - private Type getTypeHierarchy(ArrayList typeHierarchy, Type curT, Class stopAtClass) { + private static Type getTypeHierarchy(ArrayList typeHierarchy, Type curT, Class stopAtClass) { // skip first one if (typeHierarchy.size() > 0 && typeHierarchy.get(0) == curT && isClassType(curT)) { curT = typeToClass(curT).getGenericSuperclass(); @@ -926,26 +952,69 @@ else if (primitiveClass == short.class) { throw new InvalidTypesException(); } - private static TypeInformation findCorrespondingInfo(TypeVariable typeVar, Type type, TypeInformation corrInfo) { - if (type instanceof TypeVariable) { - TypeVariable variable = (TypeVariable) type; - if (variable.getName().equals(typeVar.getName()) && variable.getGenericDeclaration().equals(typeVar.getGenericDeclaration())) { - return corrInfo; + private static TypeInformation findCorrespondingInfo(TypeVariable typeVar, Type type, TypeInformation corrInfo, ArrayList typeHierarchy) { + if (sameTypeVars(type, typeVar)) { + return corrInfo; + } + else if (type instanceof TypeVariable && sameTypeVars(materializeTypeVariable(typeHierarchy, (TypeVariable) type), typeVar)) { + return corrInfo; + } + else if (type instanceof GenericArrayType) { + TypeInformation componentInfo = null; + if (corrInfo instanceof BasicArrayTypeInfo) { + componentInfo = ((BasicArrayTypeInfo) corrInfo).getComponentInfo(); + } + else if (corrInfo instanceof PrimitiveArrayTypeInfo) { + componentInfo = BasicTypeInfo.getInfoFor(corrInfo.getTypeClass().getComponentType()); + } + else if (corrInfo instanceof ObjectArrayTypeInfo) { + componentInfo = ((ObjectArrayTypeInfo) corrInfo).getComponentInfo(); } - } else if (type instanceof ParameterizedType && Tuple.class.isAssignableFrom((Class) ((ParameterizedType) type).getRawType())) { + TypeInformation info = findCorrespondingInfo(typeVar, ((GenericArrayType) type).getGenericComponentType(), componentInfo, typeHierarchy); + if (info != null) { + return info; + } + } + else if (corrInfo instanceof TupleTypeInfo + && type instanceof ParameterizedType + && Tuple.class.isAssignableFrom((Class) ((ParameterizedType) type).getRawType())) { ParameterizedType tuple = (ParameterizedType) type; Type[] args = tuple.getActualTypeArguments(); for (int i = 0; i < args.length; i++) { - TypeInformation info = findCorrespondingInfo(typeVar, args[i], ((TupleTypeInfo) corrInfo).getTypeAt(i)); + TypeInformation info = findCorrespondingInfo(typeVar, args[i], ((TupleTypeInfo) corrInfo).getTypeAt(i), typeHierarchy); if (info != null) { return info; } } } + else if (corrInfo instanceof PojoTypeInfo && isClassType(type)) { + // determine a field containing the type variable + List fields = getAllDeclaredFields(typeToClass(type)); + for (Field field : fields) { + Type fieldType = field.getGenericType(); + if (fieldType instanceof TypeVariable + && sameTypeVars(typeVar, materializeTypeVariable(typeHierarchy, (TypeVariable) fieldType))) { + return getTypeOfPojoField(corrInfo, field); + } + else if (fieldType instanceof ParameterizedType + || fieldType instanceof GenericArrayType) { + ArrayList typeHierarchyWithFieldType = new ArrayList(typeHierarchy); + typeHierarchyWithFieldType.add(fieldType); + TypeInformation info = findCorrespondingInfo(typeVar, fieldType, getTypeOfPojoField(corrInfo, field), typeHierarchyWithFieldType); + if (info != null) { + return info; + } + } + } + } return null; } + /** + * Tries to find a concrete value (Class, ParameterizedType etc. ) for a TypeVariable by traversing the type hierarchy downwards. + * If a value could not be found it will return the most bottom type variable in the hierarchy. + */ private static Type materializeTypeVariable(ArrayList typeHierarchy, TypeVariable typeVar) { TypeVariable inTypeTypeVar = typeVar; // iterate thru hierarchy from top to bottom until type variable gets a class assigned @@ -961,8 +1030,7 @@ private static Type materializeTypeVariable(ArrayList typeHierarchy, TypeV TypeVariable curVarOfCurT = rawType.getTypeParameters()[paramIndex]; // check if variable names match - if (curVarOfCurT.getName().equals(inTypeTypeVar.getName()) - && curVarOfCurT.getGenericDeclaration().equals(inTypeTypeVar.getGenericDeclaration())) { + if (sameTypeVars(curVarOfCurT, inTypeTypeVar)) { Type curVarType = ((ParameterizedType) curT).getActualTypeArguments()[paramIndex]; // another type variable level @@ -982,15 +1050,26 @@ private static Type materializeTypeVariable(ArrayList typeHierarchy, TypeV return inTypeTypeVar; } + /** + * Creates type information from a given Class such as Integer, String[] or POJOs. + * + * This method does not support ParameterizedTypes such as Tuples or complex type hierarchies. + * In most cases {@link TypeExtractor#createTypeInfo(Type)} is the recommended method for type extraction + * (a Class is a child of Type). + * + * @param clazz a Class to create TypeInformation for + * @return TypeInformation that describes the passed Class + */ public static TypeInformation getForClass(Class clazz) { return new TypeExtractor().privateGetForClass(clazz, new ArrayList()); } private TypeInformation privateGetForClass(Class clazz, ArrayList typeHierarchy) { - return privateGetForClass(clazz, typeHierarchy, null); + return privateGetForClass(clazz, typeHierarchy, null, null, null); } @SuppressWarnings({ "unchecked", "rawtypes" }) - private TypeInformation privateGetForClass(Class clazz, ArrayList typeHierarchy, ParameterizedType clazzTypeHint) { + private TypeInformation privateGetForClass(Class clazz, ArrayList typeHierarchy, + ParameterizedType parameterizedType, TypeInformation in1Type, TypeInformation in2Type) { Validate.notNull(clazz); // check for abstract classes or interfaces @@ -999,20 +1078,20 @@ private TypeInformation privateGetForClass(Class clazz, ArrayList(clazz); + return new GenericTypeInfo(clazz); } // check for arrays if (clazz.isArray()) { // primitive arrays: int[], byte[], ... - PrimitiveArrayTypeInfo primitiveArrayInfo = PrimitiveArrayTypeInfo.getInfoFor(clazz); + PrimitiveArrayTypeInfo primitiveArrayInfo = PrimitiveArrayTypeInfo.getInfoFor(clazz); if (primitiveArrayInfo != null) { return primitiveArrayInfo; } // basic type arrays: String[], Integer[], Double[] - BasicArrayTypeInfo basicArrayInfo = BasicArrayTypeInfo.getInfoFor(clazz); + BasicArrayTypeInfo basicArrayInfo = BasicArrayTypeInfo.getInfoFor(clazz); if (basicArrayInfo != null) { return basicArrayInfo; } @@ -1025,11 +1104,11 @@ private TypeInformation privateGetForClass(Class clazz, ArrayList) WritableTypeInfo.getWritableTypeInfo((Class) clazz); + return (TypeInformation) WritableTypeInfo.getWritableTypeInfo((Class) clazz); } // check for basic types - TypeInformation basicTypeInfo = BasicTypeInfo.getInfoFor(clazz); + TypeInformation basicTypeInfo = BasicTypeInfo.getInfoFor(clazz); if (basicTypeInfo != null) { return basicTypeInfo; } @@ -1037,7 +1116,7 @@ private TypeInformation privateGetForClass(Class clazz, ArrayList valueClass = clazz.asSubclass(Value.class); - return (TypeInformation) ValueTypeInfo.getValueTypeInfo(valueClass); + return (TypeInformation) ValueTypeInfo.getValueTypeInfo(valueClass); } // check for subclasses of Tuple @@ -1047,22 +1126,22 @@ private TypeInformation privateGetForClass(Class clazz, ArrayList) new EnumTypeInfo(clazz); + return (TypeInformation) new EnumTypeInfo(clazz); } if (alreadySeen.contains(clazz)) { - return new GenericTypeInfo(clazz); + return new GenericTypeInfo(clazz); } alreadySeen.add(clazz); if (clazz.equals(Class.class)) { // special case handling for Class, this should not be handled by the POJO logic - return new GenericTypeInfo(clazz); + return new GenericTypeInfo(clazz); } try { - TypeInformation pojoType = analyzePojo(clazz, new ArrayList(typeHierarchy), clazzTypeHint); + TypeInformation pojoType = analyzePojo(clazz, new ArrayList(typeHierarchy), parameterizedType, in1Type, in2Type); if (pojoType != null) { return pojoType; } @@ -1074,7 +1153,7 @@ private TypeInformation privateGetForClass(Class clazz, ArrayList(clazz); + return new GenericTypeInfo(clazz); } /** @@ -1142,14 +1221,16 @@ private boolean isValidPojoField(Field f, Class clazz, ArrayList typeHi } @SuppressWarnings("unchecked") - private TypeInformation analyzePojo(Class clazz, ArrayList typeHierarchy, ParameterizedType clazzTypeHint) { - // try to create Type hierarchy, if the incoming only contains the most bottom one or none. - if(typeHierarchy.size() <= 1) { + private TypeInformation analyzePojo(Class clazz, ArrayList typeHierarchy, + ParameterizedType parameterizedType, TypeInformation in1Type, TypeInformation in2Type) { + // add the hierarchy of the POJO itself if it is generic + if (parameterizedType != null) { + getTypeHierarchy(typeHierarchy, parameterizedType, Object.class); + } + // create a type hierarchy, if the incoming only contains the most bottom one or none. + else if(typeHierarchy.size() <= 1) { getTypeHierarchy(typeHierarchy, clazz, Object.class); } - if(clazzTypeHint != null) { - getTypeHierarchy(typeHierarchy, clazzTypeHint, Object.class); - } List fields = getAllDeclaredFields(clazz); List pojoFields = new ArrayList(); @@ -1162,17 +1243,18 @@ private TypeInformation analyzePojo(Class clazz, ArrayList typeH try { ArrayList fieldTypeHierarchy = new ArrayList(typeHierarchy); fieldTypeHierarchy.add(fieldType); - pojoFields.add(new PojoField(field, createTypeInfoWithTypeHierarchy(fieldTypeHierarchy, fieldType, null, null) )); + TypeInformation ti = createTypeInfoWithTypeHierarchy(fieldTypeHierarchy, fieldType, in1Type, in2Type); + pojoFields.add(new PojoField(field, ti)); } catch (InvalidTypesException e) { Class genericClass = Object.class; if(isClassType(fieldType)) { genericClass = typeToClass(fieldType); } - pojoFields.add(new PojoField(field, new GenericTypeInfo( (Class) genericClass ))); + pojoFields.add(new PojoField(field, new GenericTypeInfo((Class) genericClass))); } } - CompositeType pojoType = new PojoTypeInfo(clazz, pojoFields); + CompositeType pojoType = new PojoTypeInfo(clazz, pojoFields); // // Validate the correctness of the pojo. @@ -1223,6 +1305,15 @@ public static List getAllDeclaredFields(Class clazz) { return result; } + public static Field getDeclaredField(Class clazz, String name) { + for (Field field : getAllDeclaredFields(clazz)) { + if (field.getName().equals(name)) { + return field; + } + } + return null; + } + private static boolean hasFieldWithSameName(String name, List fields) { for(Field field : fields) { if(name.equals(field.getName())) { @@ -1260,6 +1351,24 @@ private static boolean isClassType(Type t) { return t instanceof Class || t instanceof ParameterizedType; } + private static boolean sameTypeVars(Type t1, Type t2) { + if (!(t1 instanceof TypeVariable) || !(t2 instanceof TypeVariable)) { + return false; + } + return ((TypeVariable) t1).getName().equals(((TypeVariable)t2).getName()) + && ((TypeVariable) t1).getGenericDeclaration().equals(((TypeVariable)t2).getGenericDeclaration()); + } + + private static TypeInformation getTypeOfPojoField(TypeInformation pojoInfo, Field field) { + for (int j = 0; j < pojoInfo.getArity(); j++) { + PojoField pf = ((PojoTypeInfo) pojoInfo).getPojoFieldAt(j); + if (pf.field.getName().equals(field.getName())) { + return pf.type; + } + } + return null; + } + public static TypeInformation getForObject(X value) { return new TypeExtractor().privateGetForObject(value); @@ -1275,7 +1384,7 @@ private TypeInformation privateGetForObject(X value) { int numFields = t.getArity(); if(numFields != countFieldsInClass(value.getClass())) { // not a tuple since it has more fields. - return analyzePojo((Class) value.getClass(), new ArrayList(), null); // we immediately call analyze Pojo here, because + return analyzePojo((Class) value.getClass(), new ArrayList(), null, null, null); // we immediately call analyze Pojo here, because // there is currently no other type that can handle such a class. } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java index e9d5dac52a757..33f041d16baff 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java @@ -157,6 +157,7 @@ private static TypeInformation parse(StringBuilder sb) throws ClassNotFoundEx } else { arrayClazz = Class.forName("[L" + TUPLE_PACKAGE + "." + className + ";"); } + sb.delete(0, 2); returnType = ObjectArrayTypeInfo.getInfoFor(arrayClazz, new TupleTypeInfo(clazz, types)); } else if (sb.length() < 1 || sb.charAt(0) != '[') { returnType = new TupleTypeInfo(clazz, types); @@ -291,10 +292,8 @@ else if (pojoGenericMatcher.find()) { String fieldName = fieldMatcher.group(1); sb.delete(0, fieldName.length() + 1); - Field field = null; - try { - field = clazz.getDeclaredField(fieldName); - } catch (Exception e) { + Field field = TypeExtractor.getDeclaredField(clazz, fieldName); + if (field == null) { throw new IllegalArgumentException("Field '" + fieldName + "'could not be accessed."); } fields.add(new PojoField(field, parse(sb))); diff --git a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java index 39d6e10e76f2a..27db31da86714 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeExtractionTest.java @@ -23,18 +23,21 @@ import java.util.Date; import java.util.List; +import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.tuple.Tuple1; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.api.java.typeutils.PojoField; import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.api.java.typeutils.TypeInfoParser; import org.apache.flink.api.java.typeutils.TypeInfoParserTest.MyWritable; import org.apache.flink.api.java.typeutils.WritableTypeInfo; import org.junit.Assert; @@ -208,6 +211,7 @@ public void testPojoWC() { checkWCPojoAsserts(typeForObject); } + @SuppressWarnings({ "unchecked", "rawtypes" }) private void checkWCPojoAsserts(TypeInformation typeInfo) { Assert.assertFalse(typeInfo.isBasicType()); Assert.assertFalse(typeInfo.isTupleType()); @@ -406,7 +410,6 @@ private void checkWCPojoAsserts(TypeInformation typeInfo) { Assert.assertEquals(typeInfo.getArity(), 2); } - // Kryo is required for this, so disable for now. @Test public void testPojoAllPublic() { TypeInformation typeForClass = TypeExtractor.createTypeInfo(AllPublic.class); @@ -616,4 +619,195 @@ public void testGetterSetterWithVertex() { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.fromElements(new VertexTyped(0L, 3.0), new VertexTyped(1L, 1.0)); } + + public static class MyMapper implements MapFunction, PojoWithGenerics> { + private static final long serialVersionUID = 1L; + + @Override + public PojoWithGenerics map(PojoWithGenerics value) + throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference1() { + MapFunction function = new MyMapper(); + + TypeInformation ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithGenerics")); + Assert.assertTrue(ti instanceof PojoTypeInfo); + PojoTypeInfo pti = (PojoTypeInfo) ti; + for(int i = 0; i < pti.getArity(); i++) { + PojoField field = pti.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("field1")) { + Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, field.type); + } else if (name.equals("field2")) { + Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, field.type); + } else if (name.equals("key")) { + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, field.type); + } else { + Assert.fail("Unexpected field "+field); + } + } + } + + public static class PojoTuple extends Tuple3 { + private static final long serialVersionUID = 1L; + + public A extraField; + } + + public static class MyMapper2 implements MapFunction, PojoTuple> { + private static final long serialVersionUID = 1L; + + @Override + public PojoTuple map(Tuple2 value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference2() { + MapFunction function = new MyMapper2(); + + TypeInformation ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("Tuple2")); + Assert.assertTrue(ti instanceof PojoTypeInfo); + PojoTypeInfo pti = (PojoTypeInfo) ti; + for(int i = 0; i < pti.getArity(); i++) { + PojoField field = pti.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("extraField")) { + Assert.assertEquals(BasicTypeInfo.CHAR_TYPE_INFO, field.type); + } else if (name.equals("f0")) { + Assert.assertEquals(BasicTypeInfo.BOOLEAN_TYPE_INFO, field.type); + } else if (name.equals("f1")) { + Assert.assertEquals(BasicTypeInfo.BOOLEAN_TYPE_INFO, field.type); + } else if (name.equals("f2")) { + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, field.type); + } else { + Assert.fail("Unexpected field "+field); + } + } + } + + public static class MyMapper3 implements MapFunction, Tuple2> { + private static final long serialVersionUID = 1L; + + @Override + public Tuple2 map(PojoTuple value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference3() { + MapFunction function = new MyMapper3(); + + TypeInformation ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoTuple")); + Assert.assertTrue(ti instanceof TupleTypeInfo); + TupleTypeInfo tti = (TupleTypeInfo) ti; + Assert.assertEquals(BasicTypeInfo.CHAR_TYPE_INFO, tti.getTypeAt(0)); + Assert.assertEquals(BasicTypeInfo.BOOLEAN_TYPE_INFO, tti.getTypeAt(1)); + } + + public static class PojoWithParameterizedFields1 { + public Tuple2 field; + } + + public static class MyMapper4 implements MapFunction, A> { + private static final long serialVersionUID = 1L; + @Override + public A map(PojoWithParameterizedFields1 value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference4() { + MapFunction function = new MyMapper4(); + + TypeInformation ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithParameterizedFields1>")); + Assert.assertEquals(BasicTypeInfo.BYTE_TYPE_INFO, ti); + } + + public static class PojoWithParameterizedFields2 { + public PojoWithGenerics field; + } + + public static class MyMapper5 implements MapFunction, A> { + private static final long serialVersionUID = 1L; + @Override + public A map(PojoWithParameterizedFields2 value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference5() { + MapFunction function = new MyMapper5(); + + TypeInformation ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithParameterizedFields2<" + + "field=org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithGenerics" + + ">")); + Assert.assertEquals(BasicTypeInfo.BYTE_TYPE_INFO, ti); + } + + public static class PojoWithParameterizedFields3 { + public Z[] field; + } + + public static class MyMapper6 implements MapFunction, A> { + private static final long serialVersionUID = 1L; + @Override + public A map(PojoWithParameterizedFields3 value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference6() { + MapFunction function = new MyMapper6(); + + TypeInformation ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithParameterizedFields3<" + + "field=int[]" + + ">")); + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, ti); + } + + public static class MyMapper7 implements MapFunction, A> { + private static final long serialVersionUID = 1L; + @Override + public A map(PojoWithParameterizedFields4 value) throws Exception { + return null; + } + } + + public static class PojoWithParameterizedFields4 { + public Tuple1[] field; + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testGenericPojoTypeInference7() { + MapFunction function = new MyMapper7(); + + TypeInformation ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) + TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.PojoTypeExtractionTest$PojoWithParameterizedFields4<" + + "field=Tuple1[]" + + ">")); + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, ti); + } } diff --git a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java index 1364a2f7738bd..8a2d67592ca80 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java @@ -1260,7 +1260,7 @@ public static class MyObject { public static class InType extends MyObject {} @SuppressWarnings({ "rawtypes", "unchecked" }) @Test - public void testParamertizedCustomObject() { + public void testParameterizedPojo() { RichMapFunction function = new RichMapFunction>() { private static final long serialVersionUID = 1L; @@ -1622,6 +1622,24 @@ public void testInputInference3() { Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, ti); } + public static class EdgeMapper4 implements MapFunction[], V> { + private static final long serialVersionUID = 1L; + + @Override + public V map(Edge[] value) throws Exception { + return null; + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + @Test + public void testInputInference4() { + EdgeMapper4 em = new EdgeMapper4(); + TypeInformation ti = TypeExtractor.getMapReturnTypes((MapFunction) em, TypeInfoParser.parse("Tuple3[]")); + Assert.assertTrue(ti.isBasicType()); + Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, ti); + } + public static enum MyEnum { ONE, TWO, THREE }