Skip to content

Commit

Permalink
AVRO-2357: Allow generic types in reflect protos (#490)
Browse files Browse the repository at this point in the history
Adds support for generic types in ReflectData for
Protocols.
  • Loading branch information
ivangreene authored and Fokko committed Mar 31, 2019
1 parent 392c761 commit de48a0a
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
Expand Down Expand Up @@ -769,13 +770,15 @@ public Protocol getProtocol(Class iface) {
iface.getPackage() == null ? "" : iface.getPackage().getName());
Map<String, Schema> names = new LinkedHashMap<>();
Map<String, Message> messages = protocol.getMessages();
for (Method method : iface.getMethods())
Map<TypeVariable<?>, Type> genericTypeVariableMap = ReflectionUtil.resolveTypeVariables(iface);
for (Method method : iface.getMethods()) {
if ((method.getModifiers() & Modifier.STATIC) == 0) {
String name = method.getName();
if (messages.containsKey(name))
throw new AvroTypeException("Two methods with same name: " + name);
messages.put(name, getMessage(method, protocol, names));
messages.put(name, getMessage(method, protocol, names, genericTypeVariableMap));
}
}

// reverse types, since they were defined in reference order
List<Schema> types = new ArrayList<>(names.values());
Expand All @@ -785,38 +788,29 @@ public Protocol getProtocol(Class iface) {
return protocol;
}

private String[] getParameterNames(Method m) {
Parameter[] parameters = m.getParameters();
String[] paramNames = new String[parameters.length];
for (int i = 0; i < parameters.length; i++) {
paramNames[i] = parameters[i].getName();
}
return paramNames;
}

private Message getMessage(Method method, Protocol protocol, Map<String, Schema> names) {
private Message getMessage(Method method, Protocol protocol, Map<String, Schema> names,
Map<? extends Type, Type> genericTypeMap) {
List<Schema.Field> fields = new ArrayList<>();
String[] paramNames = getParameterNames(method);
Type[] paramTypes = method.getGenericParameterTypes();
Annotation[][] annotations = method.getParameterAnnotations();
for (int i = 0; i < paramTypes.length; i++) {
Schema paramSchema = getSchema(paramTypes[i], names);
for (int j = 0; j < annotations[i].length; j++) {
Annotation annotation = annotations[i][j];
for (Parameter parameter : method.getParameters()) {
Schema paramSchema = getSchema(genericTypeMap.getOrDefault(parameter.getParameterizedType(), parameter.getType()),
names);
for (Annotation annotation : parameter.getAnnotations()) {
if (annotation instanceof AvroSchema) // explicit schema
paramSchema = new Schema.Parser().parse(((AvroSchema) annotation).value());
else if (annotation instanceof Union) // union
paramSchema = getAnnotatedUnion(((Union) annotation), names);
else if (annotation instanceof Nullable) // nullable
paramSchema = makeNullable(paramSchema);
}
String paramName = paramNames.length == paramTypes.length ? paramNames[i] : paramSchema.getName() + i;
fields.add(new Schema.Field(paramName, paramSchema, null /* doc */, null));
fields.add(new Schema.Field(parameter.getName(), paramSchema, null /* doc */, null));
}

Schema request = Schema.createRecord(fields);

Type genericReturnType = method.getGenericReturnType();
Type returnType = genericTypeMap.getOrDefault(genericReturnType, genericReturnType);
Union union = method.getAnnotation(Union.class);
Schema response = union == null ? getSchema(method.getGenericReturnType(), names) : getAnnotatedUnion(union, names);
Schema response = union == null ? getSchema(returnType, names) : getAnnotatedUnion(union, names);
if (method.isAnnotationPresent(Nullable.class)) // nullable
response = makeNullable(response);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@

import org.apache.avro.AvroRuntimeException;

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.util.IdentityHashMap;
import java.util.Map;

/**
* A few utility methods for using @link{java.misc.Unsafe}, mostly for private
* use.
Expand Down Expand Up @@ -119,4 +125,35 @@ private FieldAccessor accessor(FieldAccess access, String name) throws Exception
}
}

/**
* For an interface, get a map of any {@link TypeVariable}s to their actual
* types.
*
* @param iface interface to resolve types for.
* @return a map of {@link TypeVariable}s to actual types.
*/
protected static Map<TypeVariable<?>, Type> resolveTypeVariables(Class<?> iface) {
return resolveTypeVariables(iface, new IdentityHashMap<>());
}

private static Map<TypeVariable<?>, Type> resolveTypeVariables(Class<?> iface, Map<TypeVariable<?>, Type> reuse) {

for (Type type : iface.getGenericInterfaces()) {
if (type instanceof ParameterizedType) {
ParameterizedType parameterizedType = (ParameterizedType) type;
Type rawType = parameterizedType.getRawType();
if (rawType instanceof Class<?>) {
Class<?> classType = (Class<?>) rawType;
TypeVariable<? extends Class<?>>[] typeParameters = classType.getTypeParameters();
Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
for (int i = 0; i < typeParameters.length; i++) {
reuse.putIfAbsent(typeParameters[i], reuse.getOrDefault(actualTypeArguments[i], actualTypeArguments[i]));
}
resolveTypeVariables(classType, reuse);
}
}
}
return reuse;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@

package org.apache.avro.reflect;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.assertThat;

import java.util.Collections;

import org.apache.avro.Protocol;
import org.apache.avro.Schema;
import org.junit.Test;

Expand All @@ -46,4 +50,44 @@ public void testWeakSchemaCaching() throws Exception {

assertThat("ReflectData cache should release references", classData.bySchema.size(), lessThan(numSchemas));
}

@Test
public void testGenericProtocol() {
Protocol protocol = ReflectData.get().getProtocol(FooBarProtocol.class);
Schema recordSchema = ReflectData.get().getSchema(FooBarReflectiveRecord.class);

assertThat(protocol.getTypes(), contains(recordSchema));

assertThat(protocol.getMessages().keySet(), containsInAnyOrder("store", "findById", "exists"));

Schema.Field storeArgument = protocol.getMessages().get("store").getRequest().getFields().get(0);
assertThat(storeArgument.schema(), equalTo(recordSchema));

Schema.Field findByIdArgument = protocol.getMessages().get("findById").getRequest().getFields().get(0);
assertThat(findByIdArgument.schema(), equalTo(Schema.create(Schema.Type.STRING)));

Schema findByIdResponse = protocol.getMessages().get("findById").getResponse();
assertThat(findByIdResponse, equalTo(recordSchema));

Schema.Field existsArgument = protocol.getMessages().get("exists").getRequest().getFields().get(0);
assertThat(existsArgument.schema(), equalTo(Schema.create(Schema.Type.STRING)));
}

private interface CrudProtocol<R, I> extends OtherProtocol<I> {
void store(R record);

R findById(I id);
}

private interface OtherProtocol<G> {
boolean exists(G id);
}

private interface FooBarProtocol extends OtherProtocol<String>, CrudProtocol<FooBarReflectiveRecord, String> {
}

private static class FooBarReflectiveRecord {
private String bar;
private int baz;
}
}

0 comments on commit de48a0a

Please sign in to comment.