diff --git a/src/main/java/net/openhft/chronicle/wire/GenerateMethodReader.java b/src/main/java/net/openhft/chronicle/wire/GenerateMethodReader.java index 848163534..68722763b 100644 --- a/src/main/java/net/openhft/chronicle/wire/GenerateMethodReader.java +++ b/src/main/java/net/openhft/chronicle/wire/GenerateMethodReader.java @@ -68,6 +68,7 @@ public class GenerateMethodReader { MethodWriter.class, SourceContext.class ); + System.out.println("patched Wire"); } private final WireType wireType; @@ -160,6 +161,7 @@ private void generateSourceCode() { handledInterfaces.clear(); handledMethodNames.clear(); handledMethodSignatures.clear(); + final List>[] interfacesForInstance = new List[instances.length]; for (int i = 0; i < instances.length; i++) { final Class aClass = instances[i].getClass(); @@ -167,7 +169,8 @@ private void generateSourceCode() { boolean methodFilter = instances[i] instanceof MethodFilterOnFirstArg; methodFilterPresent |= methodFilter; - for (Class anInterface : ReflectionUtil.interfaces(aClass)) { + interfacesForInstance[i] = ReflectionUtil.interfaces(aClass); + for (Class anInterface : interfacesForInstance[i]) { if (IGNORED_INTERFACES.contains(anInterface)) continue; handleInterface(anInterface, "instance" + i, methodFilter, eventNameSwitchBlock, eventIdSwitchBlock); @@ -198,6 +201,9 @@ private void generateSourceCode() { } for (int i = 0; i < instances.length; i++) { sourceCode.append(format("private final Object instance%d;\n", i)); + for (Class iface : interfacesForInstance[i]) { + sourceCode.append(format("private final %s instance%d_%s;\n", iface.getCanonicalName(), i, sanitise(iface))); + } } sourceCode.append("private final WireParselet defaultParselet;\n"); sourceCode.append("\n"); @@ -242,10 +248,14 @@ private void generateSourceCode() { for (int i = 0; metaDataHandler != null && i < metaDataHandler.length; i++) sourceCode.append(format("metaInstance%d = metaInstances[%d];\n", i, i)); - for (int i = 0; i < instances.length - 1; i++) + for (int i = 0; i < instances.length; i++) { sourceCode.append(format("instance%d = instances[%d];\n", i, i)); + for (Class iface : interfacesForInstance[i]) { + sourceCode.append(format("instance%d_%s = (%s) instances[%d];\n", i, sanitise(iface), iface.getCanonicalName(), i)); + } + } - sourceCode.append(format("instance%d = instances[%d];\n}\n\n", instances.length - 1, instances.length - 1)); + sourceCode.append("}\n\n"); if (hasChainedCalls) { sourceCode.append("" + @@ -347,6 +357,10 @@ private void generateSourceCode() { System.out.println(sourceCode.toString()); } + private String sanitise(Class iface) { + return iface.getCanonicalName().replaceAll("\\.", "_"); + } + /** * Generates code for handling all method calls of passed interface. * Called recursively for chained methods. @@ -460,18 +474,18 @@ private void handleMethod(Method m, Class anInterface, String instanceFieldNa eventNameSwitchBlock.append(format("case \"%s\":\n", m.getName())); if (parameterTypes.length == 0) { eventNameSwitchBlock.append("valueIn.skipValue();\n"); - eventNameSwitchBlock.append(methodCall(m, instanceFieldName, chainedCallPrefix, chainReturnType)); + eventNameSwitchBlock.append(methodCall(anInterface, m, instanceFieldName, chainedCallPrefix, chainReturnType)); } else if (parameterTypes.length == 1) { eventNameSwitchBlock.append(argumentRead(m, 0, false, parameterTypes)); - eventNameSwitchBlock.append(methodCall(m, instanceFieldName, chainedCallPrefix, chainReturnType)); + eventNameSwitchBlock.append(methodCall(anInterface, m, instanceFieldName, chainedCallPrefix, chainReturnType)); } else { if (methodFilter) { eventNameSwitchBlock.append("ignored = false;\n"); eventNameSwitchBlock.append("valueIn.sequence(this, (f, v) -> {\n"); eventNameSwitchBlock.append(argumentRead(m, 0, true, parameterTypes)); - eventNameSwitchBlock.append(format("if (((MethodFilterOnFirstArg) f.%s)." + + eventNameSwitchBlock.append(format("if ((f.%s)." + "ignoreMethodBasedOnFirstArg(\"%s\", f.%sarg%d)) {\n", - instanceFieldName, m.getName(), m.getName(), 0)); + instanceFieldName + "_" + sanitise(MethodFilterOnFirstArg.class), m.getName(), m.getName(), 0)); eventNameSwitchBlock.append("f.ignored = true;\n"); for (int i = 1; i < parameterTypes.length; i++) @@ -496,7 +510,7 @@ private void handleMethod(Method m, Class anInterface, String instanceFieldNa eventNameSwitchBlock.append("});\n"); } - eventNameSwitchBlock.append(methodCall(m, instanceFieldName, chainedCallPrefix, chainReturnType)); + eventNameSwitchBlock.append(methodCall(anInterface, m, instanceFieldName, chainedCallPrefix, chainReturnType)); if (methodFilter) eventNameSwitchBlock.append("}\n"); @@ -528,12 +542,13 @@ private void addMethodIdSwitch(String methodName, int methodId, SourceCodeFormat * Generates code that invokes passed method, saves method return value (in case it's a chained call) * and handles {@link MethodReaderInterceptorReturns} if it's specified. * + * @param anInterface * @param m Method that is being processed. * @param instanceFieldName In generated code, method is executed on field with this name. * @param chainedCallPrefix Prefix for method call statement, passed in order to save method result for chaining. * @return Code that performs a method call. */ - private String methodCall(Method m, String instanceFieldName, String chainedCallPrefix, @Nullable Class returnType) { + private String methodCall(Class anInterface, Method m, String instanceFieldName, String chainedCallPrefix, @Nullable Class returnType) { StringBuilder res = new StringBuilder(); Class[] parameterTypes = m.getParameterTypes(); @@ -557,9 +572,15 @@ private String methodCall(Method m, String instanceFieldName, String chainedCall res.append(codeBefore).append("\n"); } - res.append(format("%s((%s) %s).%s(%s);\n", - chainedCallPrefix, m.getDeclaringClass().getCanonicalName(), instanceFieldName, m.getName(), - String.join(", ", args))); + if (instanceFieldName.startsWith("instance") && chainedCallPrefix.isEmpty() && anInterface == m.getDeclaringClass()) { + res.append(format("%s.%s(%s);\n", + instanceFieldName + "_" + sanitise(m.getDeclaringClass()), m.getName(), + String.join(", ", args))); + } else { + res.append(format("%s((%s) %s).%s(%s);\n", + chainedCallPrefix, m.getDeclaringClass().getCanonicalName(), instanceFieldName, m.getName(), + String.join(", ", args))); + } if (generatingInterceptor != null) { final String codeAfter = generatingInterceptor.codeAfterCall(m, instanceFieldName, args);