diff --git a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericData.java b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericData.java index c7931c7362c..40525aca2ac 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericData.java +++ b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericData.java @@ -105,6 +105,8 @@ public GenericData(ClassLoader classLoader) { private Map, Map>> conversionsByClass = new IdentityHashMap, Map>>(); + private Map, Class[]> classUnionTypes = new HashMap, Class[]>(); + /** * Registers the given conversion to be used when reading and writing with * this data model. @@ -124,6 +126,20 @@ public void addLogicalTypeConversion(Conversion conversion) { } } + /** + * add Union types for a given class. This is equivalent to adding an @Union annotation + * with values. But this can help in cases where you cannot mutate the class to annotate it + * @param clazz to add Union types for, typically the base classes + * @param unionTypes the class types to add to the union, typically the implementation classes + */ + public void addUnionTypes(Class clazz, Class[] unionTypes) { + classUnionTypes.put(clazz, unionTypes); + } + + protected Class[] getUnionTypes(Class clazz) { + return classUnionTypes.get(clazz); + } + /** * Returns the first conversion found for the given class. * diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java index 60095ad7851..bcbb463c0ee 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java @@ -577,8 +577,11 @@ protected Schema createSchema(Type type, Map names) { if (c.getEnclosingClass() != null) // nested class space = c.getEnclosingClass().getName() + "$"; Union union = c.getAnnotation(Union.class); + Schema unionSchema = getUnionSchema(c, names); if (union != null) { // union annotated return getAnnotatedUnion(union, names); + } else if(unionSchema != null) { + return unionSchema; } else if (isStringable(c)) { // Stringable Schema result = Schema.create(Schema.Type.STRING); result.addProp(CLASS_PROP, c.getName()); @@ -669,8 +672,21 @@ private void setElement(Schema schema, Type element) { // construct a schema from a union annotation private Schema getAnnotatedUnion(Union union, Map names) { + Class[] unionTypes = union.value(); + return getUnionSchema(unionTypes, names); + } + + private Schema getUnionSchema(Class clazz, Map names) { + Class[] unionTypes = getUnionTypes(clazz); + if(unionTypes != null) { + return getUnionSchema(unionTypes, names); + } + return null; + } + + private Schema getUnionSchema(Class[] unionTypes, Map names) { List branches = new ArrayList(); - for (Class branch : union.value()) + for (Class branch : unionTypes) branches.add(createSchema(branch, names)); return Schema.createUnion(branches); } @@ -788,7 +804,7 @@ private Message getMessage(Method method, Protocol protocol, Type[] paramTypes = method.getGenericParameterTypes(); Annotation[][] annotations = method.getParameterAnnotations(); for (int i = 0; i < paramTypes.length; i++) { - Schema paramSchema = getSchema(paramTypes[i], names); + Schema paramSchema = getUnionSchema(paramTypes[i], names); for (int j = 0; j < annotations[i].length; j++) { Annotation annotation = annotations[i][j]; if (annotation instanceof AvroSchema) // explicit schema @@ -808,7 +824,7 @@ else if (annotation instanceof Nullable) // nullable Union union = method.getAnnotation(Union.class); Schema response = union == null - ? getSchema(method.getGenericReturnType(), names) + ? getUnionSchema(method.getGenericReturnType(), names) : getAnnotatedUnion(union, names); if (method.isAnnotationPresent(Nullable.class)) // nullable response = makeNullable(response); @@ -821,12 +837,12 @@ else if (annotation instanceof Nullable) // nullable errs.add(Protocol.SYSTEM_ERROR); // every method can throw for (Type err : method.getGenericExceptionTypes()) if (err != AvroRemoteException.class) - errs.add(getSchema(err, names)); + errs.add(getUnionSchema(err, names)); Schema errors = Schema.createUnion(errs); return protocol.createMessage(method.getName(), null /* doc */, request, response, errors); } - private Schema getSchema(Type type, Map names) { + private Schema getUnionSchema(Type type, Map names) { try { return createSchema(type, names); } catch (AvroTypeException e) { // friendly exception diff --git a/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflect.java b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflect.java index 8b2373083df..58acfbb4b74 100644 --- a/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflect.java +++ b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflect.java @@ -284,6 +284,34 @@ public boolean equals(Object o) { } } + public static class R6_1 {} + + public static class R7_1 extends R6_1 { + public int value; + @Override + public boolean equals(Object o) { + if (!(o instanceof R7_1)) return false; + return this.value == ((R7_1)o).value; + } + } + public static class R8_1 extends R6_1 { + public float value; + @Override + public boolean equals(Object o) { + if (!(o instanceof R8_1)) return false; + return this.value == ((R8_1)o).value; + } + } + + public static class R9_1 { + public R6_1[] r6_1s; + @Override + public boolean equals(Object o) { + if (!(o instanceof R9_1)) return false; + return Arrays.equals(this.r6_1s, ((R9_1)o).r6_1s); + } + } + @Test public void testR6() throws Exception { R7 r7 = new R7(); r7.value = 1; @@ -296,6 +324,21 @@ public boolean equals(Object o) { checkReadWrite(r9, ReflectData.get().getSchema(R9.class)); } + @Test public void testUnionSubTypesWithoutAnnotations() throws Exception { + Class[] unionTypes = new Class[]{R7_1.class, R8_1.class}; + ReflectData.get().addUnionTypes(R6_1.class, unionTypes); + + R7_1 r7_1 = new R7_1(); + r7_1.value = 101; + checkReadWrite(r7_1, ReflectData.get().getSchema(R6_1.class)); + R8_1 r8_1 = new R8_1(); + r8_1.value = 102; + checkReadWrite(r8_1, ReflectData.get().getSchema(R6_1.class)); + R9_1 r9_1 = new R9_1(); + r9_1.r6_1s = new R6_1[] {r7_1, r8_1}; + checkReadWrite(r9_1, ReflectData.get().getSchema(R9_1.class)); + } + // test union annotation on methods and parameters public static interface P0 { @Union({Void.class,String.class})