diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcEndpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcEndpoint.java index 4e5e49a527708..79961f7305db9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcEndpoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/RpcEndpoint.java @@ -85,9 +85,9 @@ protected RpcEndpoint(final RpcService rpcService) { // IMPORTANT: Don't change order of selfGatewayType and self because rpcService.startServer // requires that selfGatewayType has been initialized - this.selfGatewayType = ReflectionUtil.getTemplateType1(getClass()); + this.selfGatewayType = determineSelfGatewayType(); this.self = rpcService.startServer(this); - + this.mainThreadExecutor = new MainThreadExecutor((MainThreadExecutable) self); } @@ -255,4 +255,23 @@ public void execute(Runnable runnable) { gateway.runAsync(runnable); } } + + /** + * Determines the self gateway type specified in one of the subclasses which extend this class. + * May traverse multiple class hierarchies until a Gateway type is found as a first type argument. + * @return Class The determined self gateway type + */ + private Class determineSelfGatewayType() { + + // determine self gateway type + Class c = getClass(); + Class determinedSelfGatewayType; + do { + determinedSelfGatewayType = ReflectionUtil.getTemplateType1(c); + // check if super class contains self gateway type in next loop + c = c.getSuperclass(); + } while (!RpcGateway.class.isAssignableFrom(determinedSelfGatewayType)); + + return determinedSelfGatewayType; + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/RpcCompletenessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/RpcCompletenessTest.java index 53355e805e0af..e7143aea3e708 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/RpcCompletenessTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/RpcCompletenessTest.java @@ -26,9 +26,14 @@ import org.apache.flink.util.TestLogger; import org.junit.Test; import org.reflections.Reflections; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.lang.annotation.Annotation; import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -41,8 +46,33 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +/** + * Test which ensures that all classes of subtype {@link RpcEndpoint} implement + * the methods specified in the generic gateway type argument. + * + * {@code + * RpcEndpoint + * } + * + * Note, that the class hierarchy can also be nested. In this case the type argument + * always has to be the first argument, e.g. {@code + * + * // RpcClass needs to implement RpcGatewayClass' methods + * RpcClass extends RpcEndpoint + * + * // RpcClass2 or its subclass needs to implement RpcGatewayClass' methods + * RpcClass extends RpcEndpoint + * RpcClass2 extends RpcClass + * + * // needless to say, this can even be nested further + * ... + * } + * + */ public class RpcCompletenessTest extends TestLogger { + private static Logger LOG = LoggerFactory.getLogger(RpcCompletenessTest.class); + private static final Class futureClass = Future.class; private static final Class timeoutClass = Time.class; @@ -55,16 +85,52 @@ public void testRpcCompleteness() { Class c; - for (Class rpcEndpoint :classes){ + mainloop: + for (Class rpcEndpoint : classes) { c = rpcEndpoint; - Class rpcGatewayType = ReflectionUtil.getTemplateType1(c); + LOG.debug("-------------"); + LOG.debug("c: {}", c); - if (rpcGatewayType != null) { - checkCompleteness(rpcEndpoint, (Class) rpcGatewayType); - } else { - fail("Could not retrieve the rpc gateway class for the given rpc endpoint class " + rpcEndpoint.getName()); + // skip abstract classes + if (Modifier.isAbstract(c.getModifiers())) { + LOG.debug("Skipping abstract class"); + continue; } + + // check for type parameter bound to RpcGateway + // skip if one is found because a subclass will provide the concrete argument + TypeVariable>[] typeParameters = c.getTypeParameters(); + LOG.debug("Checking {} parameters.", typeParameters.length); + for (int i = 0; i < typeParameters.length; i++) { + for (Type bound : typeParameters[i].getBounds()) { + LOG.debug("checking bound {} of type parameter {}", bound, typeParameters[i]); + if (bound.toString().equals("interface " + RpcGateway.class.getName())) { + if (i > 0) { + fail("Type parameter for RpcGateway should come first in " + c); + } + LOG.debug("Skipping class with type parameter bound to RpcGateway."); + // Type parameter is bound to RpcGateway which a subclass will provide + continue mainloop; + } + } + } + + // check if this class or any super class contains the RpcGateway argument + Class rpcGatewayType; + do { + LOG.debug("checking type argument of class: {}", c); + rpcGatewayType = ReflectionUtil.getTemplateType1(c); + LOG.debug("type argument is: {}", rpcGatewayType); + + c = (Class) c.getSuperclass(); + + } while (!RpcGateway.class.isAssignableFrom(rpcGatewayType)); + + LOG.debug("Checking RRC completeness of endpoint '{}' with gateway '{}'", + rpcEndpoint.getSimpleName(), rpcGatewayType.getSimpleName()); + + checkCompleteness(rpcEndpoint, (Class) rpcGatewayType); } } @@ -352,7 +418,7 @@ private static Class resolvePrimitiveType(Class primitveType) { */ private List getRpcMethodsFromGateway(Class interfaceClass) { if(!interfaceClass.isInterface()) { - fail(interfaceClass.getName() + "is not a interface"); + fail(interfaceClass.getName() + " is not a interface"); } ArrayList allMethods = new ArrayList<>();