Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
Expand Down Expand Up @@ -769,13 +770,15 @@ public Protocol getProtocol(Class iface) {
iface.getPackage() == null ? "" : iface.getPackage().getName());
Map<String, Schema> names = new LinkedHashMap<>();
Map<String, Message> messages = protocol.getMessages();
for (Method method : iface.getMethods())
Map<TypeVariable<?>, Type> genericTypeVariableMap = ReflectionUtil.resolveTypeVariables(iface);
for (Method method : iface.getMethods()) {
if ((method.getModifiers() & Modifier.STATIC) == 0) {
String name = method.getName();
if (messages.containsKey(name))
throw new AvroTypeException("Two methods with same name: " + name);
messages.put(name, getMessage(method, protocol, names));
messages.put(name, getMessage(method, protocol, names, genericTypeVariableMap));
}
}

// reverse types, since they were defined in reference order
List<Schema> types = new ArrayList<>(names.values());
Expand All @@ -785,38 +788,29 @@ public Protocol getProtocol(Class iface) {
return protocol;
}

private String[] getParameterNames(Method m) {
Parameter[] parameters = m.getParameters();
String[] paramNames = new String[parameters.length];
for (int i = 0; i < parameters.length; i++) {
paramNames[i] = parameters[i].getName();
}
return paramNames;
}

private Message getMessage(Method method, Protocol protocol, Map<String, Schema> names) {
private Message getMessage(Method method, Protocol protocol, Map<String, Schema> names,
Map<? extends Type, Type> genericTypeMap) {
List<Schema.Field> fields = new ArrayList<>();
String[] paramNames = getParameterNames(method);
Type[] paramTypes = method.getGenericParameterTypes();
Annotation[][] annotations = method.getParameterAnnotations();
for (int i = 0; i < paramTypes.length; i++) {
Schema paramSchema = getSchema(paramTypes[i], names);
for (int j = 0; j < annotations[i].length; j++) {
Annotation annotation = annotations[i][j];
for (Parameter parameter : method.getParameters()) {
Schema paramSchema = getSchema(genericTypeMap.getOrDefault(parameter.getParameterizedType(), parameter.getType()),
names);
for (Annotation annotation : parameter.getAnnotations()) {
if (annotation instanceof AvroSchema) // explicit schema
paramSchema = new Schema.Parser().parse(((AvroSchema) annotation).value());
else if (annotation instanceof Union) // union
paramSchema = getAnnotatedUnion(((Union) annotation), names);
else if (annotation instanceof Nullable) // nullable
paramSchema = makeNullable(paramSchema);
}
String paramName = paramNames.length == paramTypes.length ? paramNames[i] : paramSchema.getName() + i;
fields.add(new Schema.Field(paramName, paramSchema, null /* doc */, null));
fields.add(new Schema.Field(parameter.getName(), paramSchema, null /* doc */, null));
}

Schema request = Schema.createRecord(fields);

Type genericReturnType = method.getGenericReturnType();
Type returnType = genericTypeMap.getOrDefault(genericReturnType, genericReturnType);
Union union = method.getAnnotation(Union.class);
Schema response = union == null ? getSchema(method.getGenericReturnType(), names) : getAnnotatedUnion(union, names);
Schema response = union == null ? getSchema(returnType, names) : getAnnotatedUnion(union, names);
if (method.isAnnotationPresent(Nullable.class)) // nullable
response = makeNullable(response);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@

import org.apache.avro.AvroRuntimeException;

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.util.IdentityHashMap;
import java.util.Map;

/**
* A few utility methods for using @link{java.misc.Unsafe}, mostly for private
* use.
Expand Down Expand Up @@ -119,4 +125,35 @@ private FieldAccessor accessor(FieldAccess access, String name) throws Exception
}
}

/**
* For an interface, get a map of any {@link TypeVariable}s to their actual
* types.
*
* @param iface interface to resolve types for.
* @return a map of {@link TypeVariable}s to actual types.
*/
protected static Map<TypeVariable<?>, Type> resolveTypeVariables(Class<?> iface) {
return resolveTypeVariables(iface, new IdentityHashMap<>());
}

private static Map<TypeVariable<?>, Type> resolveTypeVariables(Class<?> iface, Map<TypeVariable<?>, Type> reuse) {

for (Type type : iface.getGenericInterfaces()) {
if (type instanceof ParameterizedType) {
ParameterizedType parameterizedType = (ParameterizedType) type;
Type rawType = parameterizedType.getRawType();
if (rawType instanceof Class<?>) {
Class<?> classType = (Class<?>) rawType;
TypeVariable<? extends Class<?>>[] typeParameters = classType.getTypeParameters();
Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
for (int i = 0; i < typeParameters.length; i++) {
reuse.putIfAbsent(typeParameters[i], reuse.getOrDefault(actualTypeArguments[i], actualTypeArguments[i]));
}
resolveTypeVariables(classType, reuse);
}
}
}
return reuse;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@

package org.apache.avro.reflect;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.assertThat;

import java.util.Collections;

import org.apache.avro.Protocol;
import org.apache.avro.Schema;
import org.junit.Test;

Expand All @@ -46,4 +50,44 @@ public void testWeakSchemaCaching() throws Exception {

assertThat("ReflectData cache should release references", classData.bySchema.size(), lessThan(numSchemas));
}

@Test
public void testGenericProtocol() {
Protocol protocol = ReflectData.get().getProtocol(FooBarProtocol.class);
Schema recordSchema = ReflectData.get().getSchema(FooBarReflectiveRecord.class);

assertThat(protocol.getTypes(), contains(recordSchema));

assertThat(protocol.getMessages().keySet(), containsInAnyOrder("store", "findById", "exists"));

Schema.Field storeArgument = protocol.getMessages().get("store").getRequest().getFields().get(0);
assertThat(storeArgument.schema(), equalTo(recordSchema));

Schema.Field findByIdArgument = protocol.getMessages().get("findById").getRequest().getFields().get(0);
assertThat(findByIdArgument.schema(), equalTo(Schema.create(Schema.Type.STRING)));

Schema findByIdResponse = protocol.getMessages().get("findById").getResponse();
assertThat(findByIdResponse, equalTo(recordSchema));

Schema.Field existsArgument = protocol.getMessages().get("exists").getRequest().getFields().get(0);
assertThat(existsArgument.schema(), equalTo(Schema.create(Schema.Type.STRING)));
}

private interface CrudProtocol<R, I> extends OtherProtocol<I> {
void store(R record);

R findById(I id);
}

private interface OtherProtocol<G> {
boolean exists(G id);
}

private interface FooBarProtocol extends OtherProtocol<String>, CrudProtocol<FooBarReflectiveRecord, String> {
}

private static class FooBarReflectiveRecord {
private String bar;
private int baz;
}
}