From d936ed896be4951bfd8766906b214af98a000f34 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 1 Nov 2016 15:38:01 -0700 Subject: [PATCH 1/4] Add TypeDescriptor#getTypes --- .../org/apache/beam/sdk/values/TypeDescriptor.java | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java index 6eabf42be962c..14f2cb8eeb5b4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java @@ -288,6 +288,19 @@ public TypeDescriptor resolveType(Type type) { return new SimpleTypeDescriptor<>(token.resolveType(type)); } + /** + * Returns a set of {@link TypeDescriptor TypeDescriptor}, one for each + * superclass as well as each interface implemented by this class. + */ + @SuppressWarnings("rawtypes") + public Iterable getTypes() { + List interfaces = Lists.newArrayList(); + for (TypeToken interfaceToken : token.getTypes()) { + interfaces.add(new SimpleTypeDescriptor<>(interfaceToken)); + } + return interfaces; + } + /** * Returns a set of {@link TypeDescriptor}s, one for each * interface implemented by this class. From 8336b24c97c620fa3edb02301299080bda96379a Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 1 Nov 2016 14:48:54 -0700 Subject: [PATCH 2/4] Switch DoFnSignature, etc, from TypeToken to TypeDescriptor --- .../sdk/transforms/reflect/DoFnInvokers.java | 7 +- .../sdk/transforms/reflect/DoFnSignature.java | 23 ++- .../transforms/reflect/DoFnSignatures.java | 177 ++++++++++-------- .../DoFnSignaturesSplittableDoFnTest.java | 18 +- .../reflect/DoFnSignaturesTest.java | 7 +- .../reflect/DoFnSignaturesTestUtils.java | 8 +- 6 files changed, 124 insertions(+), 116 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java index dd134b76ed454..c5a23dc8862a8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkArgument; -import com.google.common.reflect.TypeToken; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -263,9 +262,9 @@ public static void invokeSplitRestriction( /** Default implementation of {@link DoFn.GetRestrictionCoder}, for delegation by bytebuddy. */ public static class DefaultRestrictionCoder { - private final TypeToken restrictionType; + private final TypeDescriptor restrictionType; - DefaultRestrictionCoder(TypeToken restrictionType) { + DefaultRestrictionCoder(TypeDescriptor restrictionType) { this.restrictionType = restrictionType; } @@ -273,7 +272,7 @@ public static class DefaultRestrictionCoder { @SuppressWarnings({"unused", "unchecked"}) public Coder invokeGetRestrictionCoder(CoderRegistry registry) throws CannotProvideCoderException { - return (Coder) registry.getCoder(TypeDescriptor.of(restrictionType.getType())); + return (Coder) registry.getCoder(restrictionType); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 71f7e530f8b52..6b98805c4f336 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -20,7 +20,6 @@ import com.google.auto.value.AutoValue; import com.google.common.base.Predicates; import com.google.common.collect.Iterables; -import com.google.common.reflect.TypeToken; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Collections; @@ -342,7 +341,7 @@ public abstract static class ProcessElementMethod implements DoFnMethod { /** Concrete type of the {@link RestrictionTracker} parameter, if present. */ @Nullable - abstract TypeToken trackerT(); + abstract TypeDescriptor trackerT(); /** Whether this {@link DoFn} returns a {@link ProcessContinuation} or void. */ public abstract boolean hasReturnValue(); @@ -350,7 +349,7 @@ public abstract static class ProcessElementMethod implements DoFnMethod { static ProcessElementMethod create( Method targetMethod, List extraParameters, - TypeToken trackerT, + TypeDescriptor trackerT, boolean hasReturnValue) { return new AutoValue_DoFnSignature_ProcessElementMethod( targetMethod, Collections.unmodifiableList(extraParameters), trackerT, hasReturnValue); @@ -462,9 +461,9 @@ public abstract static class GetInitialRestrictionMethod implements DoFnMethod { public abstract Method targetMethod(); /** Type of the returned restriction. */ - abstract TypeToken restrictionT(); + abstract TypeDescriptor restrictionT(); - static GetInitialRestrictionMethod create(Method targetMethod, TypeToken restrictionT) { + static GetInitialRestrictionMethod create(Method targetMethod, TypeDescriptor restrictionT) { return new AutoValue_DoFnSignature_GetInitialRestrictionMethod(targetMethod, restrictionT); } } @@ -477,9 +476,9 @@ public abstract static class SplitRestrictionMethod implements DoFnMethod { public abstract Method targetMethod(); /** Type of the restriction taken and returned. */ - abstract TypeToken restrictionT(); + abstract TypeDescriptor restrictionT(); - static SplitRestrictionMethod create(Method targetMethod, TypeToken restrictionT) { + static SplitRestrictionMethod create(Method targetMethod, TypeDescriptor restrictionT) { return new AutoValue_DoFnSignature_SplitRestrictionMethod(targetMethod, restrictionT); } } @@ -492,13 +491,13 @@ public abstract static class NewTrackerMethod implements DoFnMethod { public abstract Method targetMethod(); /** Type of the input restriction. */ - abstract TypeToken restrictionT(); + abstract TypeDescriptor restrictionT(); /** Type of the returned {@link RestrictionTracker}. */ - abstract TypeToken trackerT(); + abstract TypeDescriptor trackerT(); static NewTrackerMethod create( - Method targetMethod, TypeToken restrictionT, TypeToken trackerT) { + Method targetMethod, TypeDescriptor restrictionT, TypeDescriptor trackerT) { return new AutoValue_DoFnSignature_NewTrackerMethod(targetMethod, restrictionT, trackerT); } } @@ -511,9 +510,9 @@ public abstract static class GetRestrictionCoderMethod implements DoFnMethod { public abstract Method targetMethod(); /** Type of the returned {@link Coder}. */ - abstract TypeToken coderT(); + abstract TypeDescriptor coderT(); - static GetRestrictionCoderMethod create(Method targetMethod, TypeToken coderT) { + static GetRestrictionCoderMethod create(Method targetMethod, TypeDescriptor coderT) { return new AutoValue_DoFnSignature_GetRestrictionCoderMethod(targetMethod, coderT); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 5814c0ee3c49a..c690ace53f3e0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -22,8 +22,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; -import com.google.common.reflect.TypeParameter; -import com.google.common.reflect.TypeToken; import java.lang.annotation.Annotation; import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Field; @@ -57,6 +55,7 @@ import org.apache.beam.sdk.util.state.StateSpec; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeParameter; /** * Parses a {@link DoFn} and computes its {@link DoFnSignature}. See {@link #getSignature}. @@ -90,18 +89,18 @@ private static DoFnSignature parseSignature(Class> fnClass) errors.checkArgument(DoFn.class.isAssignableFrom(fnClass), "Must be subtype of DoFn"); builder.setFnClass(fnClass); - TypeToken> fnToken = TypeToken.of(fnClass); + TypeDescriptor> fnT = TypeDescriptor.of(fnClass); // Extract the input and output type, and whether the fn is bounded. - TypeToken inputT = null; - TypeToken outputT = null; - for (TypeToken supertype : fnToken.getTypes()) { + TypeDescriptor inputT = null; + TypeDescriptor outputT = null; + for (TypeDescriptor supertype : fnT.getTypes()) { if (!supertype.getRawType().equals(DoFn.class)) { continue; } Type[] args = ((ParameterizedType) supertype.getType()).getActualTypeArguments(); - inputT = TypeToken.of(args[0]); - outputT = TypeToken.of(args[1]); + inputT = TypeDescriptor.of(args[0]); + outputT = TypeDescriptor.of(args[1]); } errors.checkNotNull(inputT, "Unable to determine input type"); @@ -169,7 +168,7 @@ private static DoFnSignature parseSignature(Class> fnClass) DoFnSignature.ProcessElementMethod processElement = analyzeProcessElementMethod( processElementErrors, - fnToken, + fnT, processElementMethod, inputT, outputT, @@ -180,14 +179,14 @@ private static DoFnSignature parseSignature(Class> fnClass) if (startBundleMethod != null) { ErrorReporter startBundleErrors = errors.forMethod(DoFn.StartBundle.class, startBundleMethod); builder.setStartBundle( - analyzeBundleMethod(startBundleErrors, fnToken, startBundleMethod, inputT, outputT)); + analyzeBundleMethod(startBundleErrors, fnT, startBundleMethod, inputT, outputT)); } if (finishBundleMethod != null) { ErrorReporter finishBundleErrors = errors.forMethod(DoFn.FinishBundle.class, finishBundleMethod); builder.setFinishBundle( - analyzeBundleMethod(finishBundleErrors, fnToken, finishBundleMethod, inputT, outputT)); + analyzeBundleMethod(finishBundleErrors, fnT, finishBundleMethod, inputT, outputT)); } if (setupMethod != null) { @@ -209,7 +208,7 @@ private static DoFnSignature parseSignature(Class> fnClass) builder.setGetInitialRestriction( getInitialRestriction = analyzeGetInitialRestrictionMethod( - getInitialRestrictionErrors, fnToken, getInitialRestrictionMethod, inputT)); + getInitialRestrictionErrors, fnT, getInitialRestrictionMethod, inputT)); } DoFnSignature.SplitRestrictionMethod splitRestriction = null; @@ -219,7 +218,7 @@ private static DoFnSignature parseSignature(Class> fnClass) builder.setSplitRestriction( splitRestriction = analyzeSplitRestrictionMethod( - splitRestrictionErrors, fnToken, splitRestrictionMethod, inputT)); + splitRestrictionErrors, fnT, splitRestrictionMethod, inputT)); } DoFnSignature.GetRestrictionCoderMethod getRestrictionCoder = null; @@ -229,17 +228,17 @@ private static DoFnSignature parseSignature(Class> fnClass) builder.setGetRestrictionCoder( getRestrictionCoder = analyzeGetRestrictionCoderMethod( - getRestrictionCoderErrors, fnToken, getRestrictionCoderMethod)); + getRestrictionCoderErrors, fnT, getRestrictionCoderMethod)); } DoFnSignature.NewTrackerMethod newTracker = null; if (newTrackerMethod != null) { ErrorReporter newTrackerErrors = errors.forMethod(DoFn.NewTracker.class, newTrackerMethod); builder.setNewTracker( - newTracker = analyzeNewTrackerMethod(newTrackerErrors, fnToken, newTrackerMethod)); + newTracker = analyzeNewTrackerMethod(newTrackerErrors, fnT, newTrackerMethod)); } - builder.setIsBoundedPerElement(inferBoundedness(fnToken, processElement, errors)); + builder.setIsBoundedPerElement(inferBoundedness(fnT, processElement, errors)); DoFnSignature signature = builder.build(); @@ -271,11 +270,11 @@ private static DoFnSignature parseSignature(Class> fnClass) * */ private static PCollection.IsBounded inferBoundedness( - TypeToken fnToken, + TypeDescriptor fnT, DoFnSignature.ProcessElementMethod processElement, ErrorReporter errors) { PCollection.IsBounded isBounded = null; - for (TypeToken supertype : fnToken.getTypes()) { + for (TypeDescriptor supertype : fnT.getTypes()) { if (supertype.getRawType().isAnnotationPresent(DoFn.BoundedPerElement.class) || supertype.getRawType().isAnnotationPresent(DoFn.UnboundedPerElement.class)) { errors.checkArgument( @@ -354,7 +353,7 @@ private static void verifySplittableMethods(DoFnSignature signature, ErrorReport ErrorReporter getInitialRestrictionErrors = errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestriction.targetMethod()); - TypeToken restrictionT = getInitialRestriction.restrictionT(); + TypeDescriptor restrictionT = getInitialRestriction.restrictionT(); getInitialRestrictionErrors.checkArgument( restrictionT.equals(newTracker.restrictionT()), @@ -411,49 +410,54 @@ private static void verifyUnsplittableMethods(ErrorReporter errors, DoFnSignatur } /** - * Generates a type token for {@code DoFn.ProcessContext} given {@code InputT} - * and {@code OutputT}. + * Generates a {@link TypeDescriptor} for {@code DoFn.ProcessContext} given + * {@code InputT} and {@code OutputT}. */ private static - TypeToken.ProcessContext> doFnProcessContextTypeOf( - TypeToken inputT, TypeToken outputT) { - return new TypeToken.ProcessContext>() {}.where( + TypeDescriptor.ProcessContext> doFnProcessContextTypeOf( + TypeDescriptor inputT, TypeDescriptor outputT) { + return new TypeDescriptor.ProcessContext>() {}.where( new TypeParameter() {}, inputT) .where(new TypeParameter() {}, outputT); } /** - * Generates a type token for {@code DoFn.Context} given {@code InputT} and - * {@code OutputT}. + * Generates a {@link TypeDescriptor} for {@code DoFn.Context} given {@code + * InputT} and {@code OutputT}. */ - private static TypeToken.Context> doFnContextTypeOf( - TypeToken inputT, TypeToken outputT) { - return new TypeToken.Context>() {}.where( + private static TypeDescriptor.Context> doFnContextTypeOf( + TypeDescriptor inputT, TypeDescriptor outputT) { + return new TypeDescriptor.Context>() {}.where( new TypeParameter() {}, inputT) .where(new TypeParameter() {}, outputT); } - /** Generates a type token for {@code DoFn.InputProvider} given {@code InputT}. */ - private static TypeToken> inputProviderTypeOf( - TypeToken inputT) { - return new TypeToken>() {}.where( + /** + * Generates a {@link TypeDescriptor} for {@code DoFn.InputProvider} given {@code InputT}. + */ + private static TypeDescriptor> inputProviderTypeOf( + TypeDescriptor inputT) { + return new TypeDescriptor>() {}.where( new TypeParameter() {}, inputT); } - /** Generates a type token for {@code DoFn.OutputReceiver} given {@code OutputT}. */ - private static TypeToken> outputReceiverTypeOf( - TypeToken inputT) { - return new TypeToken>() {}.where( + /** + * Generates a {@link TypeDescriptor} for {@code DoFn.OutputReceiver} given {@code + * OutputT}. + */ + private static TypeDescriptor> outputReceiverTypeOf( + TypeDescriptor inputT) { + return new TypeDescriptor>() {}.where( new TypeParameter() {}, inputT); } @VisibleForTesting static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( ErrorReporter errors, - TypeToken> fnClass, + TypeDescriptor> fnClass, Method m, - TypeToken inputT, - TypeToken outputT, + TypeDescriptor inputT, + TypeDescriptor outputT, Map stateDeclarations, Map timerDeclarations) { errors.checkArgument( @@ -462,27 +466,27 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( "Must return void or %s", DoFn.ProcessContinuation.class.getSimpleName()); - TypeToken processContextToken = doFnProcessContextTypeOf(inputT, outputT); + TypeDescriptor processContextT = doFnProcessContextTypeOf(inputT, outputT); Type[] params = m.getGenericParameterTypes(); - TypeToken contextToken = null; + TypeDescriptor contextT = null; if (params.length > 0) { - contextToken = fnClass.resolveType(params[0]); + contextT = fnClass.resolveType(params[0]); } errors.checkArgument( - contextToken != null && contextToken.equals(processContextToken), + contextT != null && contextT.equals(processContextT), "Must take %s as the first argument", - formatType(processContextToken)); + formatType(processContextT)); List extraParameters = new ArrayList<>(); Map stateParameters = new HashMap<>(); Map timerParameters = new HashMap<>(); - TypeToken trackerT = null; + TypeDescriptor trackerT = null; - TypeToken expectedInputProviderT = inputProviderTypeOf(inputT); - TypeToken expectedOutputReceiverT = outputReceiverTypeOf(outputT); + TypeDescriptor expectedInputProviderT = inputProviderTypeOf(inputT); + TypeDescriptor expectedOutputReceiverT = outputReceiverTypeOf(outputT); for (int i = 1; i < params.length; ++i) { - TypeToken paramT = fnClass.resolveType(params[i]); + TypeDescriptor paramT = fnClass.resolveType(params[i]); Class rawType = paramT.getRawType(); if (rawType.equals(BoundedWindow.class)) { errors.checkArgument( @@ -641,8 +645,8 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( } else { List allowedParamTypes = Arrays.asList( - formatType(new TypeToken() {}), - formatType(new TypeToken>() {})); + formatType(new TypeDescriptor() {}), + formatType(new TypeDescriptor>() {})); errors.throwIllegalArgument( "%s is not a valid context parameter. Should be one of %s", formatType(paramT), allowedParamTypes); @@ -665,17 +669,17 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( @VisibleForTesting static DoFnSignature.BundleMethod analyzeBundleMethod( ErrorReporter errors, - TypeToken> fnToken, + TypeDescriptor> fnT, Method m, - TypeToken inputT, - TypeToken outputT) { + TypeDescriptor inputT, + TypeDescriptor outputT) { errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); - TypeToken expectedContextToken = doFnContextTypeOf(inputT, outputT); + TypeDescriptor expectedContextT = doFnContextTypeOf(inputT, outputT); Type[] params = m.getGenericParameterTypes(); errors.checkArgument( - params.length == 1 && fnToken.resolveType(params[0]).equals(expectedContextToken), + params.length == 1 && fnT.resolveType(params[0]).equals(expectedContextT), "Must take a single argument of type %s", - formatType(expectedContextToken)); + formatType(expectedContextT)); return DoFnSignature.BundleMethod.create(m); } @@ -688,27 +692,33 @@ private static DoFnSignature.LifecycleMethod analyzeLifecycleMethod( @VisibleForTesting static DoFnSignature.GetInitialRestrictionMethod analyzeGetInitialRestrictionMethod( - ErrorReporter errors, TypeToken fnToken, Method m, TypeToken inputT) { + ErrorReporter errors, + TypeDescriptor fnT, + Method m, + TypeDescriptor inputT) { // Method is of the form: // @GetInitialRestriction // RestrictionT getInitialRestriction(InputT element); Type[] params = m.getGenericParameterTypes(); errors.checkArgument( - params.length == 1 && fnToken.resolveType(params[0]).equals(inputT), + params.length == 1 && fnT.resolveType(params[0]).equals(inputT), "Must take a single argument of type %s", formatType(inputT)); return DoFnSignature.GetInitialRestrictionMethod.create( - m, fnToken.resolveType(m.getGenericReturnType())); + m, fnT.resolveType(m.getGenericReturnType())); } - /** Generates a type token for {@code List} given {@code T}. */ - private static TypeToken> listTypeOf(TypeToken elementT) { - return new TypeToken>() {}.where(new TypeParameter() {}, elementT); + /** Generates a {@link TypeDescriptor} for {@code List} given {@code T}. */ + private static TypeDescriptor> listTypeOf(TypeDescriptor elementT) { + return new TypeDescriptor>() {}.where(new TypeParameter() {}, elementT); } @VisibleForTesting static DoFnSignature.SplitRestrictionMethod analyzeSplitRestrictionMethod( - ErrorReporter errors, TypeToken fnToken, Method m, TypeToken inputT) { + ErrorReporter errors, + TypeDescriptor fnT, + Method m, + TypeDescriptor inputT) { // Method is of the form: // @SplitRestriction // void splitRestriction(InputT element, RestrictionT restriction); @@ -717,13 +727,13 @@ static DoFnSignature.SplitRestrictionMethod analyzeSplitRestrictionMethod( Type[] params = m.getGenericParameterTypes(); errors.checkArgument(params.length == 3, "Must have exactly 3 arguments"); errors.checkArgument( - fnToken.resolveType(params[0]).equals(inputT), + fnT.resolveType(params[0]).equals(inputT), "First argument must be the element type %s", formatType(inputT)); - TypeToken restrictionT = fnToken.resolveType(params[1]); - TypeToken receiverT = fnToken.resolveType(params[2]); - TypeToken expectedReceiverT = outputReceiverTypeOf(restrictionT); + TypeDescriptor restrictionT = fnT.resolveType(params[1]); + TypeDescriptor receiverT = fnT.resolveType(params[2]); + TypeDescriptor expectedReceiverT = outputReceiverTypeOf(restrictionT); errors.checkArgument( receiverT.equals(expectedReceiverT), "Third argument must be %s, but is %s", @@ -777,45 +787,46 @@ private static void validateTimerField( } } - /** Generates a type token for {@code Coder} given {@code T}. */ - private static TypeToken> coderTypeOf(TypeToken elementT) { - return new TypeToken>() {}.where(new TypeParameter() {}, elementT); + /** Generates a {@link TypeDescriptor} for {@code Coder} given {@code T}. */ + private static TypeDescriptor> coderTypeOf(TypeDescriptor elementT) { + return new TypeDescriptor>() {}.where(new TypeParameter() {}, elementT); } @VisibleForTesting static DoFnSignature.GetRestrictionCoderMethod analyzeGetRestrictionCoderMethod( - ErrorReporter errors, TypeToken fnToken, Method m) { + ErrorReporter errors, TypeDescriptor fnT, Method m) { errors.checkArgument(m.getParameterTypes().length == 0, "Must have zero arguments"); - TypeToken resT = fnToken.resolveType(m.getGenericReturnType()); + TypeDescriptor resT = fnT.resolveType(m.getGenericReturnType()); errors.checkArgument( - resT.isSubtypeOf(TypeToken.of(Coder.class)), + resT.isSubtypeOf(TypeDescriptor.of(Coder.class)), "Must return a Coder, but returns %s", formatType(resT)); return DoFnSignature.GetRestrictionCoderMethod.create(m, resT); } /** - * Generates a type token for {@code RestrictionTracker} given {@code RestrictionT}. + * Generates a {@link TypeDescriptor} for {@code RestrictionTracker} given {@code + * RestrictionT}. */ private static - TypeToken> restrictionTrackerTypeOf( - TypeToken restrictionT) { - return new TypeToken>() {}.where( + TypeDescriptor> restrictionTrackerTypeOf( + TypeDescriptor restrictionT) { + return new TypeDescriptor>() {}.where( new TypeParameter() {}, restrictionT); } @VisibleForTesting static DoFnSignature.NewTrackerMethod analyzeNewTrackerMethod( - ErrorReporter errors, TypeToken fnToken, Method m) { + ErrorReporter errors, TypeDescriptor fnT, Method m) { // Method is of the form: // @NewTracker // TrackerT newTracker(RestrictionT restriction); Type[] params = m.getGenericParameterTypes(); errors.checkArgument(params.length == 1, "Must have a single argument"); - TypeToken restrictionT = fnToken.resolveType(params[0]); - TypeToken trackerT = fnToken.resolveType(m.getGenericReturnType()); - TypeToken expectedTrackerT = restrictionTrackerTypeOf(restrictionT); + TypeDescriptor restrictionT = fnT.resolveType(params[0]); + TypeDescriptor trackerT = fnT.resolveType(m.getGenericReturnType()); + TypeDescriptor expectedTrackerT = restrictionTrackerTypeOf(restrictionT); errors.checkArgument( trackerT.isSubtypeOf(expectedTrackerT), "Returns %s, but must return a subtype of %s", @@ -985,7 +996,7 @@ private static String format(Method method) { return ReflectHelpers.METHOD_FORMATTER.apply(method); } - private static String formatType(TypeToken t) { + private static String formatType(TypeDescriptor t) { return ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(t.getType()); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java index 68278c55bb53e..573701b8d333d 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java @@ -22,7 +22,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import com.google.common.reflect.TypeToken; import java.util.List; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; @@ -34,6 +33,7 @@ import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -403,12 +403,12 @@ public void testSplitRestrictionReturnsWrongType() throws Exception { "Third argument must be OutputReceiver, but is OutputReceiver"); DoFnSignatures.analyzeSplitRestrictionMethod( errors(), - TypeToken.of(FakeDoFn.class), + TypeDescriptor.of(FakeDoFn.class), new AnonymousMethod() { void method( Integer element, SomeRestriction restriction, DoFn.OutputReceiver receiver) {} }.getMethod(), - TypeToken.of(Integer.class)); + TypeDescriptor.of(Integer.class)); } @Test @@ -422,14 +422,14 @@ private List splitRestriction(String element, SomeRestriction r thrown.expectMessage("First argument must be the element type Integer"); DoFnSignatures.analyzeSplitRestrictionMethod( errors(), - TypeToken.of(FakeDoFn.class), + TypeDescriptor.of(FakeDoFn.class), new AnonymousMethod() { void method( String element, SomeRestriction restriction, DoFn.OutputReceiver receiver) {} }.getMethod(), - TypeToken.of(Integer.class)); + TypeDescriptor.of(Integer.class)); } @Test @@ -437,7 +437,7 @@ public void testSplitRestrictionWrongNumArguments() throws Exception { thrown.expectMessage("Must have exactly 3 arguments"); DoFnSignatures.analyzeSplitRestrictionMethod( errors(), - TypeToken.of(FakeDoFn.class), + TypeDescriptor.of(FakeDoFn.class), new AnonymousMethod() { private void method( Integer element, @@ -445,7 +445,7 @@ private void method( DoFn.OutputReceiver receiver, Object extra) {} }.getMethod(), - TypeToken.of(Integer.class)); + TypeDescriptor.of(Integer.class)); } @Test @@ -519,7 +519,7 @@ public void testNewTrackerWrongNumArguments() throws Exception { thrown.expectMessage("Must have a single argument"); DoFnSignatures.analyzeNewTrackerMethod( errors(), - TypeToken.of(FakeDoFn.class), + TypeDescriptor.of(FakeDoFn.class), new AnonymousMethod() { private SomeRestrictionTracker method(SomeRestriction restriction, Object extra) { return null; @@ -533,7 +533,7 @@ public void testNewTrackerInconsistent() throws Exception { "Returns SomeRestrictionTracker, but must return a subtype of RestrictionTracker"); DoFnSignatures.analyzeNewTrackerMethod( errors(), - TypeToken.of(FakeDoFn.class), + TypeDescriptor.of(FakeDoFn.class), new AnonymousMethod() { private SomeRestrictionTracker method(String restriction) { return null; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index fe88c3b6b3b49..52ecb2a2f6cf3 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -25,7 +25,6 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; -import com.google.common.reflect.TypeToken; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; @@ -66,12 +65,12 @@ public void testBadExtraContext() throws Exception { DoFnSignatures.analyzeBundleMethod( errors(), - TypeToken.of(FakeDoFn.class), + TypeDescriptor.of(FakeDoFn.class), new DoFnSignaturesTestUtils.AnonymousMethod() { void method(DoFn.Context c, int n) {} }.getMethod(), - TypeToken.of(Integer.class), - TypeToken.of(String.class)); + TypeDescriptor.of(Integer.class), + TypeDescriptor.of(String.class)); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java index ce00f2d140ee0..49e2ba7691f38 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java @@ -17,11 +17,11 @@ */ package org.apache.beam.sdk.transforms.reflect; -import com.google.common.reflect.TypeToken; import java.lang.reflect.Method; import java.util.Collections; import java.util.NoSuchElementException; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.values.TypeDescriptor; /** Utilities for use in {@link DoFnSignatures} tests. */ class DoFnSignaturesTestUtils { @@ -57,10 +57,10 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod(AnonymousM throws Exception { return DoFnSignatures.analyzeProcessElementMethod( errors(), - TypeToken.of(FakeDoFn.class), + TypeDescriptor.of(FakeDoFn.class), method.getMethod(), - TypeToken.of(Integer.class), - TypeToken.of(String.class), + TypeDescriptor.of(Integer.class), + TypeDescriptor.of(String.class), Collections.EMPTY_MAP, Collections.EMPTY_MAP); } From 71fa7cdf17e9719ae7fd606a5406ebe5821eb41e Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Tue, 1 Nov 2016 14:50:24 -0700 Subject: [PATCH 3/4] DoFnSignature: Make TypeDescriptor-returning methods public --- .../beam/sdk/transforms/reflect/DoFnSignature.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 6b98805c4f336..431de02daec14 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -341,7 +341,7 @@ public abstract static class ProcessElementMethod implements DoFnMethod { /** Concrete type of the {@link RestrictionTracker} parameter, if present. */ @Nullable - abstract TypeDescriptor trackerT(); + public abstract TypeDescriptor trackerT(); /** Whether this {@link DoFn} returns a {@link ProcessContinuation} or void. */ public abstract boolean hasReturnValue(); @@ -461,7 +461,7 @@ public abstract static class GetInitialRestrictionMethod implements DoFnMethod { public abstract Method targetMethod(); /** Type of the returned restriction. */ - abstract TypeDescriptor restrictionT(); + public abstract TypeDescriptor restrictionT(); static GetInitialRestrictionMethod create(Method targetMethod, TypeDescriptor restrictionT) { return new AutoValue_DoFnSignature_GetInitialRestrictionMethod(targetMethod, restrictionT); @@ -476,7 +476,7 @@ public abstract static class SplitRestrictionMethod implements DoFnMethod { public abstract Method targetMethod(); /** Type of the restriction taken and returned. */ - abstract TypeDescriptor restrictionT(); + public abstract TypeDescriptor restrictionT(); static SplitRestrictionMethod create(Method targetMethod, TypeDescriptor restrictionT) { return new AutoValue_DoFnSignature_SplitRestrictionMethod(targetMethod, restrictionT); @@ -491,10 +491,10 @@ public abstract static class NewTrackerMethod implements DoFnMethod { public abstract Method targetMethod(); /** Type of the input restriction. */ - abstract TypeDescriptor restrictionT(); + public abstract TypeDescriptor restrictionT(); /** Type of the returned {@link RestrictionTracker}. */ - abstract TypeDescriptor trackerT(); + public abstract TypeDescriptor trackerT(); static NewTrackerMethod create( Method targetMethod, TypeDescriptor restrictionT, TypeDescriptor trackerT) { @@ -510,7 +510,7 @@ public abstract static class GetRestrictionCoderMethod implements DoFnMethod { public abstract Method targetMethod(); /** Type of the returned {@link Coder}. */ - abstract TypeDescriptor coderT(); + public abstract TypeDescriptor coderT(); static GetRestrictionCoderMethod create(Method targetMethod, TypeDescriptor coderT) { return new AutoValue_DoFnSignature_GetRestrictionCoderMethod(targetMethod, coderT); From 8bf6d92cf35d11f4f3b02dae677a4fe778d34a61 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Mon, 31 Oct 2016 21:30:40 -0700 Subject: [PATCH 4/4] Refactor and reuse parameter analysis in DoFnSignatures --- .../sdk/transforms/reflect/DoFnSignature.java | 21 +- .../transforms/reflect/DoFnSignatures.java | 585 ++++++++++++------ .../DoFnSignaturesProcessElementTest.java | 18 +- .../DoFnSignaturesSplittableDoFnTest.java | 1 - .../reflect/DoFnSignaturesTest.java | 35 +- .../reflect/DoFnSignaturesTestUtils.java | 5 +- 6 files changed, 446 insertions(+), 219 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 431de02daec14..7087efa30bb11 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -126,12 +126,25 @@ abstract static class Builder { abstract DoFnSignature build(); } - /** A method delegated to a annotated method of an underlying {@link DoFn}. */ + /** A method delegated to an annotated method of an underlying {@link DoFn}. */ public interface DoFnMethod { /** The annotated method itself. */ Method targetMethod(); } + /** + * A method delegated to an annotated method of an underlying {@link DoFn} that accepts a dynamic + * list of parameters. + */ + public interface MethodWithExtraParameters extends DoFnMethod { + /** + * Types of optional parameters of the annotated method, in the order they appear. + * + *

Validation that these are allowed is external to this class. + */ + List extraParameters(); + } + /** A descriptor for an optional parameter of the {@link DoFn.ProcessElement} method. */ public abstract static class Parameter { @@ -331,12 +344,13 @@ public abstract static class TimerParameter extends Parameter { /** Describes a {@link DoFn.ProcessElement} method. */ @AutoValue - public abstract static class ProcessElementMethod implements DoFnMethod { + public abstract static class ProcessElementMethod implements MethodWithExtraParameters { /** The annotated method itself. */ @Override public abstract Method targetMethod(); /** Types of optional parameters of the annotated method, in the order they appear. */ + @Override public abstract List extraParameters(); /** Concrete type of the {@link RestrictionTracker} parameter, if present. */ @@ -380,7 +394,7 @@ public boolean isSplittable() { /** Describes a {@link DoFn.OnTimer} method. */ @AutoValue - public abstract static class OnTimerMethod implements DoFnMethod { + public abstract static class OnTimerMethod implements MethodWithExtraParameters { /** The id on the method's {@link DoFn.TimerId} annotation. */ public abstract String id(); @@ -390,6 +404,7 @@ public abstract static class OnTimerMethod implements DoFnMethod { public abstract Method targetMethod(); /** Types of optional parameters of the annotated method, in the order they appear. */ + @Override public abstract List extraParameters(); static OnTimerMethod create(Method targetMethod, String id, List extraParameters) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index c690ace53f3e0..0475404397068 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkState; +import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; @@ -41,9 +42,11 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.StateId; import org.apache.beam.sdk.transforms.DoFn.TimerId; -import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; @@ -81,13 +84,134 @@ private DoFnSignatures() {} return signature; } + /** + * The context for a {@link DoFn} class, for use in analysis. + * + *

It contains much of the information that eventually becomes part of the {@link + * DoFnSignature}, but in an intermediate state. + */ + @VisibleForTesting + static class FnAnalysisContext { + + private final Map stateDeclarations = new HashMap<>(); + private final Map timerDeclarations = new HashMap<>(); + + private FnAnalysisContext() {} + + /** Create an empty context, with no declarations. */ + public static FnAnalysisContext create() { + return new FnAnalysisContext(); + } + + /** State parameters declared in this context, keyed by {@link StateId}. Unmodifiable. */ + public Map getStateDeclarations() { + return Collections.unmodifiableMap(stateDeclarations); + } + + /** Timer parameters declared in this context, keyed by {@link TimerId}. Unmodifiable. */ + public Map getTimerDeclarations() { + return Collections.unmodifiableMap(timerDeclarations); + } + + public void addStateDeclaration(StateDeclaration decl) { + stateDeclarations.put(decl.id(), decl); + } + + public void addStateDeclarations(Iterable decls) { + for (StateDeclaration decl : decls) { + addStateDeclaration(decl); + } + } + + public void addTimerDeclaration(TimerDeclaration decl) { + timerDeclarations.put(decl.id(), decl); + } + + public void addTimerDeclarations(Iterable decls) { + for (TimerDeclaration decl : decls) { + addTimerDeclaration(decl); + } + } + } + + /** + * The context of analysis within a particular method. + * + *

It contains much of the information that eventually becomes part of the {@link + * DoFnSignature.MethodWithExtraParameters}, but in an intermediate state. + */ + private static class MethodAnalysisContext { + + private final Map stateParameters = new HashMap<>(); + private final Map timerParameters = new HashMap<>(); + private final List extraParameters = new ArrayList<>(); + + private MethodAnalysisContext() {} + + /** State parameters declared in this context, keyed by {@link StateId}. */ + public Map getStateParameters() { + return Collections.unmodifiableMap(stateParameters); + } + + /** Timer parameters declared in this context, keyed by {@link TimerId}. */ + public Map getTimerParameters() { + return Collections.unmodifiableMap(timerParameters); + } + + /** Extra parameters in their entirety. Unmodifiable. */ + public List getExtraParameters() { + return Collections.unmodifiableList(extraParameters); + } + + /** + * Returns an {@link MethodAnalysisContext} like this one but including the provided {@link + * StateParameter}. + */ + public void addParameter(Parameter param) { + extraParameters.add(param); + + if (param instanceof StateParameter) { + StateParameter stateParameter = (StateParameter) param; + stateParameters.put(stateParameter.referent().id(), stateParameter); + } + if (param instanceof TimerParameter) { + TimerParameter timerParameter = (TimerParameter) param; + timerParameters.put(timerParameter.referent().id(), timerParameter); + } + } + + /** Create an empty context, with no declarations. */ + public static MethodAnalysisContext create() { + return new MethodAnalysisContext(); + } + } + + @AutoValue + abstract static class ParameterDescription { + public abstract Method getMethod(); + public abstract int getIndex(); + public abstract TypeDescriptor getType(); + public abstract List getAnnotations(); + + public static ParameterDescription of( + Method method, int index, TypeDescriptor type, List annotations) { + return new AutoValue_DoFnSignatures_ParameterDescription(method, index, type, annotations); + } + + public static ParameterDescription of( + Method method, int index, TypeDescriptor type, Annotation[] annotations) { + return new AutoValue_DoFnSignatures_ParameterDescription( + method, index, type, Arrays.asList(annotations)); + } + } + /** Analyzes a given {@link DoFn} class and extracts its {@link DoFnSignature}. */ private static DoFnSignature parseSignature(Class> fnClass) { - DoFnSignature.Builder builder = DoFnSignature.builder(); + DoFnSignature.Builder signatureBuilder = DoFnSignature.builder(); ErrorReporter errors = new ErrorReporter(null, fnClass.getName()); errors.checkArgument(DoFn.class.isAssignableFrom(fnClass), "Must be subtype of DoFn"); - builder.setFnClass(fnClass); + signatureBuilder.setFnClass(fnClass); TypeDescriptor> fnT = TypeDescriptor.of(fnClass); @@ -106,11 +230,9 @@ private static DoFnSignature parseSignature(Class> fnClass) // Find the state and timer declarations in advance of validating // method parameter lists - Map stateDeclarations = analyzeStateDeclarations(errors, fnClass); - builder.setStateDeclarations(stateDeclarations); - - Map timerDeclarations = analyzeTimerDeclarations(errors, fnClass); - builder.setTimerDeclarations(timerDeclarations); + FnAnalysisContext fnContext = FnAnalysisContext.create(); + fnContext.addStateDeclarations(analyzeStateDeclarations(errors, fnClass).values()); + fnContext.addTimerDeclarations(analyzeTimerDeclarations(errors, fnClass).values()); Method processElementMethod = findAnnotatedMethod(errors, DoFn.ProcessElement.class, fnClass, true); @@ -135,12 +257,12 @@ private static DoFnSignature parseSignature(Class> fnClass) for (Method onTimerMethod : onTimerMethods) { String id = onTimerMethod.getAnnotation(DoFn.OnTimer.class).value(); errors.checkArgument( - timerDeclarations.containsKey(id), + fnContext.getTimerDeclarations().containsKey(id), "Callback %s is for for undeclared timer %s", onTimerMethod, id); - TimerDeclaration timerDecl = timerDeclarations.get(id); + TimerDeclaration timerDecl = fnContext.getTimerDeclarations().get(id); errors.checkArgument( timerDecl.field().getDeclaringClass().equals(onTimerMethod.getDeclaringClass()), "Callback %s is for timer %s declared in a different class %s." @@ -149,13 +271,14 @@ private static DoFnSignature parseSignature(Class> fnClass) id, timerDecl.field().getDeclaringClass().getCanonicalName()); - onTimerMethodMap.put(id, OnTimerMethod.create(onTimerMethod, id, Collections.EMPTY_LIST)); + onTimerMethodMap.put( + id, analyzeOnTimerMethod(errors, fnT, onTimerMethod, id, outputT, fnContext)); } - builder.setOnTimerMethods(onTimerMethodMap); + signatureBuilder.setOnTimerMethods(onTimerMethodMap); // Check the converse - that all timers have a callback. This could be relaxed to only // those timers used in methods, once method parameter lists support timers. - for (TimerDeclaration decl : timerDeclarations.values()) { + for (TimerDeclaration decl : fnContext.getTimerDeclarations().values()) { errors.checkArgument( onTimerMethodMap.containsKey(decl.id()), "No callback registered via %s for timer %s", @@ -172,30 +295,29 @@ private static DoFnSignature parseSignature(Class> fnClass) processElementMethod, inputT, outputT, - stateDeclarations, - timerDeclarations); - builder.setProcessElement(processElement); + fnContext); + signatureBuilder.setProcessElement(processElement); if (startBundleMethod != null) { ErrorReporter startBundleErrors = errors.forMethod(DoFn.StartBundle.class, startBundleMethod); - builder.setStartBundle( + signatureBuilder.setStartBundle( analyzeBundleMethod(startBundleErrors, fnT, startBundleMethod, inputT, outputT)); } if (finishBundleMethod != null) { ErrorReporter finishBundleErrors = errors.forMethod(DoFn.FinishBundle.class, finishBundleMethod); - builder.setFinishBundle( + signatureBuilder.setFinishBundle( analyzeBundleMethod(finishBundleErrors, fnT, finishBundleMethod, inputT, outputT)); } if (setupMethod != null) { - builder.setSetup( + signatureBuilder.setSetup( analyzeLifecycleMethod(errors.forMethod(DoFn.Setup.class, setupMethod), setupMethod)); } if (teardownMethod != null) { - builder.setTeardown( + signatureBuilder.setTeardown( analyzeLifecycleMethod( errors.forMethod(DoFn.Teardown.class, teardownMethod), teardownMethod)); } @@ -205,7 +327,7 @@ private static DoFnSignature parseSignature(Class> fnClass) if (getInitialRestrictionMethod != null) { getInitialRestrictionErrors = errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestrictionMethod); - builder.setGetInitialRestriction( + signatureBuilder.setGetInitialRestriction( getInitialRestriction = analyzeGetInitialRestrictionMethod( getInitialRestrictionErrors, fnT, getInitialRestrictionMethod, inputT)); @@ -215,7 +337,7 @@ private static DoFnSignature parseSignature(Class> fnClass) if (splitRestrictionMethod != null) { ErrorReporter splitRestrictionErrors = errors.forMethod(DoFn.SplitRestriction.class, splitRestrictionMethod); - builder.setSplitRestriction( + signatureBuilder.setSplitRestriction( splitRestriction = analyzeSplitRestrictionMethod( splitRestrictionErrors, fnT, splitRestrictionMethod, inputT)); @@ -225,7 +347,7 @@ private static DoFnSignature parseSignature(Class> fnClass) if (getRestrictionCoderMethod != null) { ErrorReporter getRestrictionCoderErrors = errors.forMethod(DoFn.GetRestrictionCoder.class, getRestrictionCoderMethod); - builder.setGetRestrictionCoder( + signatureBuilder.setGetRestrictionCoder( getRestrictionCoder = analyzeGetRestrictionCoderMethod( getRestrictionCoderErrors, fnT, getRestrictionCoderMethod)); @@ -234,13 +356,16 @@ private static DoFnSignature parseSignature(Class> fnClass) DoFnSignature.NewTrackerMethod newTracker = null; if (newTrackerMethod != null) { ErrorReporter newTrackerErrors = errors.forMethod(DoFn.NewTracker.class, newTrackerMethod); - builder.setNewTracker( + signatureBuilder.setNewTracker( newTracker = analyzeNewTrackerMethod(newTrackerErrors, fnT, newTrackerMethod)); } - builder.setIsBoundedPerElement(inferBoundedness(fnT, processElement, errors)); + signatureBuilder.setIsBoundedPerElement(inferBoundedness(fnT, processElement, errors)); - DoFnSignature signature = builder.build(); + signatureBuilder.setStateDeclarations(fnContext.getStateDeclarations()); + signatureBuilder.setTimerDeclarations(fnContext.getTimerDeclarations()); + + DoFnSignature signature = signatureBuilder.build(); // Additional validation for splittable DoFn's. if (processElement.isSplittable()) { @@ -451,6 +576,42 @@ private static TypeDescriptor> outputRece new TypeParameter() {}, inputT); } + @VisibleForTesting + static DoFnSignature.OnTimerMethod analyzeOnTimerMethod( + ErrorReporter errors, + TypeDescriptor> fnClass, + Method m, + String timerId, + TypeDescriptor outputT, + FnAnalysisContext fnContext) { + errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); + + Type[] params = m.getGenericParameterTypes(); + + MethodAnalysisContext methodContext = MethodAnalysisContext.create(); + + List extraParameters = new ArrayList<>(); + TypeDescriptor expectedOutputReceiverT = outputReceiverTypeOf(outputT); + ErrorReporter onTimerErrors = errors.forMethod(DoFn.OnTimer.class, m); + for (int i = 0; i < params.length; ++i) { + extraParameters.add( + analyzeExtraParameter( + onTimerErrors, + fnContext, + methodContext, + fnClass, + ParameterDescription.of( + m, + i, + fnClass.resolveType(params[i]), + Arrays.asList(m.getParameterAnnotations()[i])), + null /* restriction type not applicable */, + expectedOutputReceiverT)); + } + + return DoFnSignature.OnTimerMethod.create(m, timerId, extraParameters); + } + @VisibleForTesting static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( ErrorReporter errors, @@ -458,8 +619,7 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( Method m, TypeDescriptor inputT, TypeDescriptor outputT, - Map stateDeclarations, - Map timerDeclarations) { + FnAnalysisContext fnContext) { errors.checkArgument( void.class.equals(m.getReturnType()) || DoFn.ProcessContinuation.class.equals(m.getReturnType()), @@ -468,6 +628,8 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( TypeDescriptor processContextT = doFnProcessContextTypeOf(inputT, outputT); + MethodAnalysisContext methodContext = MethodAnalysisContext.create(); + Type[] params = m.getGenericParameterTypes(); TypeDescriptor contextT = null; if (params.length > 0) { @@ -478,192 +640,211 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( "Must take %s as the first argument", formatType(processContextT)); - List extraParameters = new ArrayList<>(); - Map stateParameters = new HashMap<>(); - Map timerParameters = new HashMap<>(); - TypeDescriptor trackerT = null; - + TypeDescriptor trackerT = getTrackerType(fnClass, m); TypeDescriptor expectedInputProviderT = inputProviderTypeOf(inputT); TypeDescriptor expectedOutputReceiverT = outputReceiverTypeOf(outputT); for (int i = 1; i < params.length; ++i) { - TypeDescriptor paramT = fnClass.resolveType(params[i]); - Class rawType = paramT.getRawType(); - if (rawType.equals(BoundedWindow.class)) { - errors.checkArgument( - !extraParameters.contains(DoFnSignature.Parameter.boundedWindow()), - "Multiple %s parameters", - BoundedWindow.class.getSimpleName()); - extraParameters.add(DoFnSignature.Parameter.boundedWindow()); - } else if (rawType.equals(DoFn.InputProvider.class)) { - errors.checkArgument( - !extraParameters.contains(DoFnSignature.Parameter.inputProvider()), - "Multiple %s parameters", - DoFn.InputProvider.class.getSimpleName()); - errors.checkArgument( - paramT.equals(expectedInputProviderT), - "Wrong type of %s parameter: %s, should be %s", - DoFn.InputProvider.class.getSimpleName(), - formatType(paramT), - formatType(expectedInputProviderT)); - extraParameters.add(DoFnSignature.Parameter.inputProvider()); - } else if (rawType.equals(DoFn.OutputReceiver.class)) { - errors.checkArgument( - !extraParameters.contains(DoFnSignature.Parameter.outputReceiver()), - "Multiple %s parameters", - DoFn.OutputReceiver.class.getSimpleName()); - errors.checkArgument( - paramT.equals(expectedOutputReceiverT), - "Wrong type of %s parameter: %s, should be %s", - DoFn.OutputReceiver.class.getSimpleName(), - formatType(paramT), - formatType(expectedOutputReceiverT)); - extraParameters.add(DoFnSignature.Parameter.outputReceiver()); - } else if (Timer.class.equals(rawType)) { - // m.getParameters() is not available until Java 8 - Annotation[] annotations = m.getParameterAnnotations()[i]; - String id = null; - for (Annotation anno : annotations) { - if (anno.annotationType().equals(DoFn.TimerId.class)) { - id = ((DoFn.TimerId) anno).value(); - break; - } - } - errors.checkArgument( - id != null, - "%s parameter of type %s at index %s missing %s annotation", - fnClass.getRawType().getName(), - params[i], - i, - DoFn.TimerId.class.getSimpleName()); - errors.checkArgument( - !timerParameters.containsKey(id), - "%s parameter of type %s at index %s duplicates %s(\"%s\") on other parameter", - fnClass.getRawType().getName(), - params[i], - i, - DoFn.TimerId.class.getSimpleName(), - id); - - TimerDeclaration timerDecl = timerDeclarations.get(id); - errors.checkArgument( - timerDecl != null, - "%s parameter of type %s at index %s references undeclared %s \"%s\"", - fnClass.getRawType().getName(), - params[i], - i, - TimerId.class.getSimpleName(), - id); + Parameter extraParam = + analyzeExtraParameter( + errors.forMethod(DoFn.ProcessElement.class, m), + fnContext, + methodContext, + fnClass, + ParameterDescription.of( + m, + i, + fnClass.resolveType(params[i]), + Arrays.asList(m.getParameterAnnotations()[i])), + expectedInputProviderT, + expectedOutputReceiverT); + + methodContext.addParameter(extraParam); + } - errors.checkArgument( - timerDecl.field().getDeclaringClass().equals(m.getDeclaringClass()), - "Method %s has %s parameter at index %s for timer %s" - + " declared in a different class %s." - + " Timers may be referenced only in the lexical scope where they are declared.", - m, - Timer.class.getSimpleName(), - i, - id, - timerDecl.field().getDeclaringClass().getName()); + // A splittable DoFn can not have any other extra context parameters. + if (methodContext.getExtraParameters().contains(DoFnSignature.Parameter.restrictionTracker())) { + errors.checkArgument( + methodContext.getExtraParameters().size() == 1, + "Splittable DoFn must not have any extra arguments, but has: %s", + trackerT, + methodContext.getExtraParameters()); + } - DoFnSignature.Parameter.TimerParameter timerParameter = Parameter.timerParameter(timerDecl); - timerParameters.put(id, timerParameter); - extraParameters.add(timerParameter); + return DoFnSignature.ProcessElementMethod.create( + m, + methodContext.getExtraParameters(), + trackerT, + DoFn.ProcessContinuation.class.equals(m.getReturnType())); + } - } else if (RestrictionTracker.class.isAssignableFrom(rawType)) { - errors.checkArgument( - !extraParameters.contains(DoFnSignature.Parameter.restrictionTracker()), - "Multiple %s parameters", - RestrictionTracker.class.getSimpleName()); - extraParameters.add(DoFnSignature.Parameter.restrictionTracker()); - trackerT = paramT; - } else if (State.class.isAssignableFrom(rawType)) { - // m.getParameters() is not available until Java 8 - Annotation[] annotations = m.getParameterAnnotations()[i]; - String id = null; - for (Annotation anno : annotations) { - if (anno.annotationType().equals(DoFn.StateId.class)) { - id = ((DoFn.StateId) anno).value(); - break; - } - } - errors.checkArgument( - id != null, - "%s parameter of type %s at index %s missing %s annotation", - fnClass.getRawType().getName(), - params[i], - i, - DoFn.StateId.class.getSimpleName()); + private static Parameter analyzeExtraParameter( + ErrorReporter methodErrors, + FnAnalysisContext fnContext, + MethodAnalysisContext methodContext, + TypeDescriptor> fnClass, + ParameterDescription param, + TypeDescriptor expectedInputProviderT, + TypeDescriptor expectedOutputReceiverT) { + TypeDescriptor paramT = param.getType(); + Class rawType = paramT.getRawType(); + + ErrorReporter paramErrors = methodErrors.forParameter(param); + + if (rawType.equals(BoundedWindow.class)) { + methodErrors.checkArgument( + !methodContext.getExtraParameters().contains(Parameter.boundedWindow()), + "Multiple %s parameters", + BoundedWindow.class.getSimpleName()); + return Parameter.boundedWindow(); + } else if (rawType.equals(DoFn.InputProvider.class)) { + methodErrors.checkArgument( + !methodContext.getExtraParameters().contains(Parameter.inputProvider()), + "Multiple %s parameters", + DoFn.InputProvider.class.getSimpleName()); + paramErrors.checkArgument( + paramT.equals(expectedInputProviderT), + "%s is for %s when it should be %s", + DoFn.InputProvider.class.getSimpleName(), + formatType(paramT), + formatType(expectedInputProviderT)); + return Parameter.inputProvider(); + } else if (rawType.equals(DoFn.OutputReceiver.class)) { + methodErrors.checkArgument( + !methodContext.getExtraParameters().contains(Parameter.outputReceiver()), + "Multiple %s parameters", + DoFn.OutputReceiver.class.getSimpleName()); + paramErrors.checkArgument( + paramT.equals(expectedOutputReceiverT), + "%s is for %s when it should be %s", + DoFn.OutputReceiver.class.getSimpleName(), + formatType(paramT), + formatType(expectedOutputReceiverT)); + return Parameter.outputReceiver(); + + } else if (RestrictionTracker.class.isAssignableFrom(rawType)) { + methodErrors.checkArgument( + !methodContext.getExtraParameters().contains(Parameter.restrictionTracker()), + "Multiple %s parameters", + RestrictionTracker.class.getSimpleName()); + return Parameter.restrictionTracker(); + + } else if (rawType.equals(Timer.class)) { + // m.getParameters() is not available until Java 8 + String id = getTimerId(param.getAnnotations()); + + paramErrors.checkArgument( + id != null, + "%s missing %s annotation", + Timer.class.getSimpleName(), + TimerId.class.getSimpleName()); + + paramErrors.checkArgument( + !methodContext.getTimerParameters().containsKey(id), + "duplicate %s: \"%s\"", + TimerId.class.getSimpleName(), + id); - errors.checkArgument( - !stateParameters.containsKey(id), - "%s parameter of type %s at index %s duplicates %s(\"%s\") on other parameter", - fnClass.getRawType().getName(), - params[i], - i, - DoFn.StateId.class.getSimpleName(), - id); + TimerDeclaration timerDecl = fnContext.getTimerDeclarations().get(id); + paramErrors.checkArgument( + timerDecl != null, + "reference to undeclared %s: \"%s\"", + TimerId.class.getSimpleName(), + id); - // By static typing this is already a well-formed State subclass - TypeDescriptor stateType = - (TypeDescriptor) - TypeDescriptor.of(fnClass.getType()) - .resolveType(params[i]); + paramErrors.checkArgument( + timerDecl.field().getDeclaringClass().equals(param.getMethod().getDeclaringClass()), + "%s %s declared in a different class %s." + + " Timers may be referenced only in the lexical scope where they are declared.", + TimerId.class.getSimpleName(), + id, + timerDecl.field().getDeclaringClass().getName()); + + return Parameter.timerParameter(timerDecl); + + } else if (State.class.isAssignableFrom(rawType)) { + // m.getParameters() is not available until Java 8 + String id = getStateId(param.getAnnotations()); + paramErrors.checkArgument( + id != null, + "missing %s annotation", + DoFn.StateId.class.getSimpleName()); + + paramErrors.checkArgument( + !methodContext.getStateParameters().containsKey(id), + "duplicate %s: \"%s\"", + DoFn.StateId.class.getSimpleName(), + id); - StateDeclaration stateDecl = stateDeclarations.get(id); - errors.checkArgument( - stateDecl != null, - "%s parameter of type %s at index %s references undeclared %s \"%s\"", - fnClass.getRawType().getName(), - params[i], - i, - DoFn.StateId.class.getSimpleName(), - id); + // By static typing this is already a well-formed State subclass + TypeDescriptor stateType = (TypeDescriptor) param.getType(); - errors.checkArgument( - stateDecl.stateType().equals(stateType), - "%s parameter at index %s has type %s but is a reference to StateId %s of type %s", - fnClass.getRawType().getName(), - i, - params[i], - id, - stateDecl.stateType()); + StateDeclaration stateDecl = fnContext.getStateDeclarations().get(id); + paramErrors.checkArgument( + stateDecl != null, + "reference to undeclared %s: \"%s\"", + DoFn.StateId.class.getSimpleName(), + id); - errors.checkArgument( - stateDecl.field().getDeclaringClass().equals(m.getDeclaringClass()), - "Method %s has State parameter at index %s for state %s" - + " declared in a different class %s." - + " State may be referenced only in the class where it is declared.", - m, - i, - id, - stateDecl.field().getDeclaringClass().getName()); - - DoFnSignature.Parameter.StateParameter stateParameter = Parameter.stateParameter(stateDecl); - stateParameters.put(id, stateParameter); - extraParameters.add(stateParameter); - } else { - List allowedParamTypes = - Arrays.asList( - formatType(new TypeDescriptor() {}), - formatType(new TypeDescriptor>() {})); - errors.throwIllegalArgument( - "%s is not a valid context parameter. Should be one of %s", - formatType(paramT), allowedParamTypes); + paramErrors.checkArgument( + stateDecl.stateType().equals(stateType), + "reference to %s %s with different type %s", + StateId.class.getSimpleName(), + id, + stateDecl.stateType()); + + paramErrors.checkArgument( + stateDecl.field().getDeclaringClass().equals(param.getMethod().getDeclaringClass()), + "%s %s declared in a different class %s." + + " State may be referenced only in the class where it is declared.", + StateId.class.getSimpleName(), + id, + stateDecl.field().getDeclaringClass().getName()); + + return Parameter.stateParameter(stateDecl); + } else { + List allowedParamTypes = + Arrays.asList( + formatType(new TypeDescriptor() {}), + formatType(new TypeDescriptor>() {})); + paramErrors.throwIllegalArgument( + "%s is not a valid context parameter. Should be one of %s", + formatType(paramT), allowedParamTypes); + // Unreachable + return null; + } + } + + @Nullable + private static String getTimerId(List annotations) { + for (Annotation anno : annotations) { + if (anno.annotationType().equals(DoFn.TimerId.class)) { + return ((DoFn.TimerId) anno).value(); } } + return null; + } - // A splittable DoFn can not have any other extra context parameters. - if (extraParameters.contains(DoFnSignature.Parameter.restrictionTracker())) { - errors.checkArgument( - extraParameters.size() == 1, - "Splittable DoFn must not have any extra arguments apart from BoundedWindow, but has: %s", - trackerT, - extraParameters); + @Nullable + private static String getStateId(List annotations) { + for (Annotation anno : annotations) { + if (anno.annotationType().equals(DoFn.StateId.class)) { + return ((DoFn.StateId) anno).value(); + } } + return null; + } - return DoFnSignature.ProcessElementMethod.create( - m, extraParameters, trackerT, DoFn.ProcessContinuation.class.equals(m.getReturnType())); + @Nullable + private static TypeDescriptor getTrackerType(TypeDescriptor fnClass, Method method) { + Type[] params = method.getGenericParameterTypes(); + for (int i = 0; i < params.length; i++) { + TypeDescriptor paramT = fnClass.resolveType(params[i]); + if (RestrictionTracker.class.isAssignableFrom(paramT.getRawType())) { + return paramT; + } + } + return null; } @VisibleForTesting @@ -905,7 +1086,7 @@ Collection declaredMembersWithAnnotation( return matches; } - private static ImmutableMap analyzeStateDeclarations( + private static Map analyzeStateDeclarations( ErrorReporter errors, Class fnClazz) { @@ -1015,6 +1196,14 @@ ErrorReporter forMethod(Class annotation, Method method) { annotation.getSimpleName(), (method == null) ? "(absent)" : format(method))); } + ErrorReporter forParameter(ParameterDescription param) { + return new ErrorReporter( + this, + String.format( + "parameter of type %s at index %s", + param.getType(), param.getIndex())); + } + void throwIllegalArgument(String message, Object... args) { throw new IllegalArgumentException(label + ": " + String.format(message, args)); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java index 329a09904ed73..6cbc95e32260a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java @@ -96,9 +96,9 @@ private void method( @Test public void testBadGenericsTwoArgs() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage( - "Wrong type of OutputReceiver parameter: " - + "OutputReceiver, should be OutputReceiver"); + thrown.expectMessage("OutputReceiver"); + thrown.expectMessage("should be"); + thrown.expectMessage("OutputReceiver"); analyzeProcessElementMethod( new AnonymousMethod() { @@ -112,9 +112,9 @@ private void method( @Test public void testBadGenericWildCards() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage( - "Wrong type of OutputReceiver parameter: " - + "OutputReceiver, should be OutputReceiver"); + thrown.expectMessage("OutputReceiver"); + thrown.expectMessage("should be"); + thrown.expectMessage("OutputReceiver"); analyzeProcessElementMethod( new AnonymousMethod() { @@ -137,9 +137,9 @@ public void badTypeVariables( @Test public void testBadTypeVariables() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage( - "Wrong type of OutputReceiver parameter: " - + "OutputReceiver, should be OutputReceiver"); + thrown.expectMessage("OutputReceiver"); + thrown.expectMessage("should be"); + thrown.expectMessage("OutputReceiver"); DoFnSignatures.INSTANCE.getSignature(BadTypeVariables.class); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java index 573701b8d333d..0751b594ac20a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java @@ -88,7 +88,6 @@ private void method( public void testSplittableProcessElementMustNotHaveOtherParams() throws Exception { thrown.expect(IllegalArgumentException.class); thrown.expectMessage("must not have any extra arguments"); - thrown.expectMessage("BoundedWindow"); DoFnSignature.ProcessElementMethod signature = analyzeProcessElementMethod( diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index 52ecb2a2f6cf3..4187e0a4ccf1a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; @@ -29,11 +30,11 @@ import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFn.OnTimer; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.FakeDoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.TimeDomain; import org.apache.beam.sdk.util.Timer; import org.apache.beam.sdk.util.TimerSpec; @@ -249,7 +250,7 @@ public void process( @Test public void testTimerParameterDuplicate() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("duplicates"); + thrown.expectMessage("duplicate"); thrown.expectMessage("my-id"); thrown.expectMessage("myProcessElement"); thrown.expectMessage("index 2"); @@ -290,6 +291,28 @@ public void foo(ProcessContext context) {} }.getClass()); } + @Test + public void testWindowParamOnTimer() throws Exception { + final String timerId = "some-timer-id"; + + DoFnSignature sig = + DoFnSignatures.INSTANCE.getSignature(new DoFn() { + @TimerId(timerId) + private final TimerSpec myfield1 = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @ProcessElement + public void process(ProcessContext c) {} + + @OnTimer(timerId) + public void onTimer(BoundedWindow w) {} + }.getClass()); + + assertThat(sig.onTimerMethods().get(timerId).extraParameters().size(), equalTo(1)); + assertThat( + sig.onTimerMethods().get(timerId).extraParameters().get(0), + instanceOf(DoFnSignature.Parameter.BoundedWindowParameter.class)); + } + @Test public void testDeclAndUsageOfTimerInSuperclass() throws Exception { DoFnSignature sig = @@ -525,7 +548,7 @@ public void myProcessElement( @Test public void testStateParameterDuplicate() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("duplicates"); + thrown.expectMessage("duplicate"); thrown.expectMessage("my-id"); thrown.expectMessage("myProcessElement"); thrown.expectMessage("index 2"); @@ -549,7 +572,8 @@ public void myProcessElement( public void testStateParameterWrongStateType() throws Exception { thrown.expect(IllegalArgumentException.class); thrown.expectMessage("WatermarkHoldState"); - thrown.expectMessage("but is a reference to"); + thrown.expectMessage("reference to"); + thrown.expectMessage("different type"); thrown.expectMessage("ValueState"); thrown.expectMessage("my-id"); thrown.expectMessage("myProcessElement"); @@ -572,7 +596,8 @@ public void myProcessElement( public void testStateParameterWrongGenericType() throws Exception { thrown.expect(IllegalArgumentException.class); thrown.expectMessage("ValueState"); - thrown.expectMessage("but is a reference to"); + thrown.expectMessage("reference to"); + thrown.expectMessage("different type"); thrown.expectMessage("ValueState"); thrown.expectMessage("my-id"); thrown.expectMessage("myProcessElement"); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java index 49e2ba7691f38..b7d137a50cb80 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java @@ -18,9 +18,9 @@ package org.apache.beam.sdk.transforms.reflect; import java.lang.reflect.Method; -import java.util.Collections; import java.util.NoSuchElementException; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures.FnAnalysisContext; import org.apache.beam.sdk.values.TypeDescriptor; /** Utilities for use in {@link DoFnSignatures} tests. */ @@ -61,7 +61,6 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod(AnonymousM method.getMethod(), TypeDescriptor.of(Integer.class), TypeDescriptor.of(String.class), - Collections.EMPTY_MAP, - Collections.EMPTY_MAP); + FnAnalysisContext.create()); } }