From 39391e07f075d5e1bd2f791da4c610d766968b94 Mon Sep 17 00:00:00 2001 From: Greg Hogan Date: Thu, 26 May 2016 14:45:00 -0400 Subject: [PATCH] [FLINK-3978] [core] Add hasBroadcastVariable method to RuntimeContext New method RuntimeContext.hasBroadcastVariable(String). --- .../api/common/functions/RuntimeContext.java | 9 +++++++++ .../functions/util/RuntimeUDFContext.java | 5 +++++ .../functions/util/RuntimeUDFContextTest.java | 9 +++++++-- .../util/DistributedRuntimeUDFContext.java | 18 +++++++++--------- .../api/operators/StreamingRuntimeContext.java | 5 +++++ 5 files changed, 35 insertions(+), 11 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java index ed2f613987fd4..da813dcaf4e56 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java @@ -162,6 +162,15 @@ public interface RuntimeContext { // -------------------------------------------------------------------------------------------- + /** + * Tests for the existence of the broadcast variable identified by the + * given {@code name}. + * + * @param name The name under which the broadcast variable is registered; + * @return Whether a broadcast variable exists for the given name. + */ + boolean hasBroadcastVariable(String name); + /** * Returns the result bound to the broadcast variable identified by the * given {@code name}. diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/RuntimeUDFContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/RuntimeUDFContext.java index 6571d0d44686e..ba3f85e068f50 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/RuntimeUDFContext.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/RuntimeUDFContext.java @@ -48,6 +48,11 @@ public RuntimeUDFContext(TaskInfo taskInfo, ClassLoader userCodeClassLoader, Exe super(taskInfo, userCodeClassLoader, executionConfig, accumulators, cpTasks, metrics); } + @Override + public boolean hasBroadcastVariable(String name) { + return this.initializedBroadcastVars.containsKey(name) || this.uninitializedBroadcastVars.containsKey(name); + } + @Override @SuppressWarnings("unchecked") public List getBroadcastVariable(String name) { diff --git a/flink-core/src/test/java/org/apache/flink/api/common/functions/util/RuntimeUDFContextTest.java b/flink-core/src/test/java/org/apache/flink/api/common/functions/util/RuntimeUDFContextTest.java index 83c88cc3608ae..b3c7f917a8021 100644 --- a/flink-core/src/test/java/org/apache/flink/api/common/functions/util/RuntimeUDFContextTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/common/functions/util/RuntimeUDFContextTest.java @@ -43,7 +43,9 @@ public class RuntimeUDFContextTest { public void testBroadcastVariableNotFound() { try { RuntimeUDFContext ctx = new RuntimeUDFContext(taskInfo, getClass().getClassLoader(), new ExecutionConfig(), new HashMap>(),new HashMap>(), new DummyMetricGroup()); - + + assertFalse(ctx.hasBroadcastVariable("some name")); + try { ctx.getBroadcastVariable("some name"); fail("should throw an exception"); @@ -76,7 +78,10 @@ public void testBroadcastVariableSimple() { ctx.setBroadcastVariable("name1", Arrays.asList(1, 2, 3, 4)); ctx.setBroadcastVariable("name2", Arrays.asList(1.0, 2.0, 3.0, 4.0)); - + + assertTrue(ctx.hasBroadcastVariable("name1")); + assertTrue(ctx.hasBroadcastVariable("name2")); + List list1 = ctx.getBroadcastVariable("name1"); List list2 = ctx.getBroadcastVariable("name2"); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/DistributedRuntimeUDFContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/DistributedRuntimeUDFContext.java index 293d34f44a283..6c7f5f33021f5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/DistributedRuntimeUDFContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/DistributedRuntimeUDFContext.java @@ -47,11 +47,15 @@ public DistributedRuntimeUDFContext(TaskInfo taskInfo, ClassLoader userCodeClass Map> cpTasks, Map> accumulators, MetricGroup metrics) { super(taskInfo, userCodeClassLoader, executionConfig, accumulators, cpTasks, metrics); } - + + @Override + public boolean hasBroadcastVariable(String name) { + return this.broadcastVars.containsKey(name); + } @Override public List getBroadcastVariable(String name) { - Preconditions.checkNotNull(name); + Preconditions.checkNotNull(name, "The broadcast variable name must not be null."); // check if we have an initialized version @SuppressWarnings("unchecked") @@ -71,13 +75,9 @@ public List getBroadcastVariable(String name) { @Override public C getBroadcastVariableWithInitializer(String name, BroadcastVariableInitializer initializer) { - if (name == null) { - throw new NullPointerException("Thw broadcast variable name must not be null."); - } - if (initializer == null) { - throw new NullPointerException("Thw broadcast variable initializer must not be null."); - } - + Preconditions.checkNotNull(name, "The broadcast variable name must not be null."); + Preconditions.checkNotNull(initializer, "The broadcast variable initializer must not be null."); + // check if we have an initialized version @SuppressWarnings("unchecked") BroadcastVariableMaterialization variable = (BroadcastVariableMaterialization) this.broadcastVars.get(name); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java index 5fef3c797dbd5..a858b4c4d97a4 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java @@ -98,6 +98,11 @@ public ScheduledFuture registerTimer(long time, Triggerable target) { // broadcast variables // ------------------------------------------------------------------------ + @Override + public boolean hasBroadcastVariable(String name) { + throw new UnsupportedOperationException("Broadcast variables can only be used in DataSet programs"); + } + @Override public List getBroadcastVariable(String name) { throw new UnsupportedOperationException("Broadcast variables can only be used in DataSet programs");