Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ public GenericData(ClassLoader classLoader) {
private Map<Class<?>, Map<String, Conversion<?>>> conversionsByClass =
new IdentityHashMap<Class<?>, Map<String, Conversion<?>>>();

private Map<Class<?>, Class[]> classUnionTypes = new HashMap<Class<?>, Class[]>();

/**
* Registers the given conversion to be used when reading and writing with
* this data model.
Expand All @@ -124,6 +126,20 @@ public void addLogicalTypeConversion(Conversion<?> conversion) {
}
}

/**
* add Union types for a given class. This is equivalent to adding an @Union annotation
* with values. But this can help in cases where you cannot mutate the class to annotate it
* @param clazz to add Union types for, typically the base classes
* @param unionTypes the class types to add to the union, typically the implementation classes
*/
public void addUnionTypes(Class<?> clazz, Class[] unionTypes) {
classUnionTypes.put(clazz, unionTypes);
}

protected Class[] getUnionTypes(Class<?> clazz) {
return classUnionTypes.get(clazz);
}

/**
* Returns the first conversion found for the given class.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,11 @@ protected Schema createSchema(Type type, Map<String,Schema> names) {
if (c.getEnclosingClass() != null) // nested class
space = c.getEnclosingClass().getName() + "$";
Union union = c.getAnnotation(Union.class);
Schema unionSchema = getUnionSchema(c, names);
if (union != null) { // union annotated
return getAnnotatedUnion(union, names);
} else if(unionSchema != null) {
return unionSchema;
} else if (isStringable(c)) { // Stringable
Schema result = Schema.create(Schema.Type.STRING);
result.addProp(CLASS_PROP, c.getName());
Expand Down Expand Up @@ -669,8 +672,21 @@ private void setElement(Schema schema, Type element) {

// construct a schema from a union annotation
private Schema getAnnotatedUnion(Union union, Map<String,Schema> names) {
Class[] unionTypes = union.value();
return getUnionSchema(unionTypes, names);
}

private Schema getUnionSchema(Class<?> clazz, Map<String, Schema> names) {
Class[] unionTypes = getUnionTypes(clazz);
if(unionTypes != null) {
return getUnionSchema(unionTypes, names);
}
return null;
}

private Schema getUnionSchema(Class[] unionTypes, Map<String, Schema> names) {
List<Schema> branches = new ArrayList<Schema>();
for (Class branch : union.value())
for (Class branch : unionTypes)
branches.add(createSchema(branch, names));
return Schema.createUnion(branches);
}
Expand Down Expand Up @@ -788,7 +804,7 @@ private Message getMessage(Method method, Protocol protocol,
Type[] paramTypes = method.getGenericParameterTypes();
Annotation[][] annotations = method.getParameterAnnotations();
for (int i = 0; i < paramTypes.length; i++) {
Schema paramSchema = getSchema(paramTypes[i], names);
Schema paramSchema = getUnionSchema(paramTypes[i], names);
for (int j = 0; j < annotations[i].length; j++) {
Annotation annotation = annotations[i][j];
if (annotation instanceof AvroSchema) // explicit schema
Expand All @@ -808,7 +824,7 @@ else if (annotation instanceof Nullable) // nullable

Union union = method.getAnnotation(Union.class);
Schema response = union == null
? getSchema(method.getGenericReturnType(), names)
? getUnionSchema(method.getGenericReturnType(), names)
: getAnnotatedUnion(union, names);
if (method.isAnnotationPresent(Nullable.class)) // nullable
response = makeNullable(response);
Expand All @@ -821,12 +837,12 @@ else if (annotation instanceof Nullable) // nullable
errs.add(Protocol.SYSTEM_ERROR); // every method can throw
for (Type err : method.getGenericExceptionTypes())
if (err != AvroRemoteException.class)
errs.add(getSchema(err, names));
errs.add(getUnionSchema(err, names));
Schema errors = Schema.createUnion(errs);
return protocol.createMessage(method.getName(), null /* doc */, request, response, errors);
}

private Schema getSchema(Type type, Map<String,Schema> names) {
private Schema getUnionSchema(Type type, Map<String,Schema> names) {
try {
return createSchema(type, names);
} catch (AvroTypeException e) { // friendly exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,34 @@ public boolean equals(Object o) {
}
}

public static class R6_1 {}

public static class R7_1 extends R6_1 {
public int value;
@Override
public boolean equals(Object o) {
if (!(o instanceof R7_1)) return false;
return this.value == ((R7_1)o).value;
}
}
public static class R8_1 extends R6_1 {
public float value;
@Override
public boolean equals(Object o) {
if (!(o instanceof R8_1)) return false;
return this.value == ((R8_1)o).value;
}
}

public static class R9_1 {
public R6_1[] r6_1s;
@Override
public boolean equals(Object o) {
if (!(o instanceof R9_1)) return false;
return Arrays.equals(this.r6_1s, ((R9_1)o).r6_1s);
}
}

@Test public void testR6() throws Exception {
R7 r7 = new R7();
r7.value = 1;
Expand All @@ -296,6 +324,21 @@ public boolean equals(Object o) {
checkReadWrite(r9, ReflectData.get().getSchema(R9.class));
}

@Test public void testUnionSubTypesWithoutAnnotations() throws Exception {
Class[] unionTypes = new Class[]{R7_1.class, R8_1.class};
ReflectData.get().addUnionTypes(R6_1.class, unionTypes);

R7_1 r7_1 = new R7_1();
r7_1.value = 101;
checkReadWrite(r7_1, ReflectData.get().getSchema(R6_1.class));
R8_1 r8_1 = new R8_1();
r8_1.value = 102;
checkReadWrite(r8_1, ReflectData.get().getSchema(R6_1.class));
R9_1 r9_1 = new R9_1();
r9_1.r6_1s = new R6_1[] {r7_1, r8_1};
checkReadWrite(r9_1, ReflectData.get().getSchema(R9_1.class));
}

// test union annotation on methods and parameters
public static interface P0 {
@Union({Void.class,String.class})
Expand Down