From 30302ce9166fdd8e7ade29e45660c18c36f5c5f7 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Wed, 5 Oct 2016 16:11:21 -0700 Subject: [PATCH 1/9] Perform initial splitting in the DirectRunner This allows sources to be read from in parallel and generates initial splits. --- .../direct/BoundedReadEvaluatorFactory.java | 47 ++++++++++++---- .../beam/runners/direct/DirectOptions.java | 22 ++++++++ .../beam/runners/direct/DirectRunner.java | 11 +++- .../runners/direct/EmptyInputProvider.java | 3 +- .../ExecutorServiceParallelExecutor.java | 18 ++++--- .../runners/direct/RootInputProvider.java | 7 ++- .../runners/direct/RootProviderRegistry.java | 5 +- .../direct/TestStreamEvaluatorFactory.java | 4 +- .../direct/TransformEvaluatorRegistry.java | 10 ++-- .../direct/UnboundedReadEvaluatorFactory.java | 42 +++++++++++---- .../BoundedReadEvaluatorFactoryTest.java | 41 +++++++++++++- .../direct/FlattenEvaluatorFactoryTest.java | 2 +- .../TestStreamEvaluatorFactoryTest.java | 2 +- .../UnboundedReadEvaluatorFactoryTest.java | 53 ++++++++++++++++++- 14 files changed, 221 insertions(+), 46 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java index 326a535e51a5..f81cd5bc21fe 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java @@ -18,28 +18,32 @@ package org.apache.beam.runners.direct; import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; import java.io.IOException; import java.util.Collection; -import java.util.Collections; +import java.util.List; import javax.annotation.Nullable; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; import org.apache.beam.runners.direct.StepTransformResult.Builder; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.BoundedSource.BoundedReader; -import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.Read.Bounded; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A {@link TransformEvaluatorFactory} that produces {@link TransformEvaluator TransformEvaluators} * for the {@link Bounded Read.Bounded} primitive {@link PTransform}. */ final class BoundedReadEvaluatorFactory implements TransformEvaluatorFactory { + private static final Logger LOG = LoggerFactory.getLogger(BoundedReadEvaluatorFactory.class); private final EvaluationContext evaluationContext; BoundedReadEvaluatorFactory(EvaluationContext evaluationContext) { @@ -126,18 +130,39 @@ static class InputProvider implements RootInputProvider { } @Override - public Collection> getInitialInputs(AppliedPTransform transform) { - return createInitialSplits((AppliedPTransform) transform); + public Collection> getInitialInputs( + AppliedPTransform transform, int targetParallelism) throws Exception { + return createInitialSplits((AppliedPTransform) transform, targetParallelism); } - private Collection> createInitialSplits( - AppliedPTransform> transform) { + private + Collection>> createInitialSplits( + AppliedPTransform> transform, int targetParallelism) + throws Exception { BoundedSource source = transform.getTransform().getSource(); - return Collections.>singleton( - evaluationContext - .>createRootBundle() - .add(WindowedValue.valueInGlobalWindow(BoundedSourceShard.of(source))) - .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + PipelineOptions options = evaluationContext.getPipelineOptions(); + long estimatedBytes = source.getEstimatedSizeBytes(options); + long bytesPerBundle = estimatedBytes / targetParallelism; + List> bundles = + source.splitIntoBundles(bytesPerBundle, options); + ImmutableList.Builder>> shards = + ImmutableList.builder(); + if (bundles.isEmpty()) { + LOG.debug("Splits of source {} were empty, using empty split"); + shards.add( + evaluationContext + .>createRootBundle() + .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + } + for (BoundedSource bundle : bundles) { + CommittedBundle> inputShard = + evaluationContext + .>createRootBundle() + .add(WindowedValue.valueInGlobalWindow(BoundedSourceShard.of(bundle))) + .commit(BoundedWindow.TIMESTAMP_MAX_VALUE); + shards.add(inputShard); + } + return shards.build(); } } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java index 89e1bb805bcc..f31f84430f77 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java @@ -19,6 +19,7 @@ import org.apache.beam.sdk.options.ApplicationNameOptions; import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.DefaultValueFactory; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptions; @@ -62,4 +63,25 @@ public interface DirectOptions extends PipelineOptions, ApplicationNameOptions { + "PCollection are encodable. All elements in a PCollection must be encodable.") boolean isEnforceEncodability(); void setEnforceEncodability(boolean test); + + @Default.InstanceFactory(AvailableParallelismFactory.class) + @Description("Controls the amount of target parallelism the DirectRunner will use. Defaults to" + + " the greater of the number of available processors as returned by the Runtime and 3. Must" + + " be a value greater than zero.") + int getTargetParallelism(); + void setTargetParallelism(int target); + + /** + * A {@link DefaultValueFactory} that returns the result of {@link Runtime#availableProcessors()} + * from the {@link #create(PipelineOptions)} method. Uses {@link Runtime#getRuntime()} to obtain + * the {@link Runtime}. + */ + class AvailableParallelismFactory implements DefaultValueFactory { + private static final int MIN_PARALLELISM = 3; + + @Override + public Integer create(PipelineOptions options) { + return Math.max(Runtime.getRuntime().availableProcessors(), MIN_PARALLELISM); + } + } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 224101ae59a2..e8b0cd5aad07 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -179,7 +179,7 @@ public static interface PCollectionViewWriter { //////////////////////////////////////////////////////////////////////////////////////////////// private final DirectOptions options; - private Supplier executorServiceSupplier = new FixedThreadPoolSupplier(); + private Supplier executorServiceSupplier; private Supplier clockSupplier = new NanosOffsetClockSupplier(); public static DirectRunner fromOptions(PipelineOptions options) { @@ -188,6 +188,7 @@ public static DirectRunner fromOptions(PipelineOptions options) { private DirectRunner(DirectOptions options) { this.options = options; + this.executorServiceSupplier = new FixedThreadPoolSupplier(options); } /** @@ -428,9 +429,15 @@ public State waitUntilFinish(Duration duration) throws IOException { * {@link Executors#newFixedThreadPool(int)}. */ private static class FixedThreadPoolSupplier implements Supplier { + private final DirectOptions options; + + private FixedThreadPoolSupplier(DirectOptions options) { + this.options = options; + } + @Override public ExecutorService get() { - return Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); + return Executors.newFixedThreadPool(options.getTargetParallelism()); } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java index 10d63e95dd09..fda2b287bc96 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java @@ -42,7 +42,8 @@ class EmptyInputProvider implements RootInputProvider { * as appropriate. */ @Override - public Collection> getInitialInputs(AppliedPTransform transform) { + public Collection> getInitialInputs( + AppliedPTransform transform, int ignored) { return Collections.>singleton( evaluationContext.createRootBundle().commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java index 567def2c5626..576bd26a3a93 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java @@ -51,6 +51,7 @@ import org.apache.beam.sdk.util.KeyedWorkItems; import org.apache.beam.sdk.util.TimeDomain; import org.apache.beam.sdk.util.TimerInternals.TimerData; +import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -166,12 +167,16 @@ public TransformExecutorService load(StepAndKey stepAndKey) throws Exception { public void start(Collection> roots) { for (AppliedPTransform root : roots) { ConcurrentLinkedQueue> pending = new ConcurrentLinkedQueue<>(); - Collection> initialInputs = rootInputProvider.getInitialInputs(root); - checkState( - !initialInputs.isEmpty(), - "All root transforms must have initial inputs. Got 0 for %s", - root.getFullName()); - pending.addAll(initialInputs); + try { + Collection> initialInputs = rootInputProvider.getInitialInputs(root, 1); + checkState( + !initialInputs.isEmpty(), + "All root transforms must have initial inputs. Got 0 for %s", + root.getFullName()); + pending.addAll(initialInputs); + } catch (Exception e) { + throw UserCodeException.wrap(e); + } pendingRootBundles.put(root, pending); } evaluationContext.initialize(pendingRootBundles); @@ -453,7 +458,6 @@ private void fireTimers() throws Exception { } KeyedWorkItem work = KeyedWorkItems.timersWorkItem(keyTimers.getKey().getKey(), delivery); - LOG.warn("Delivering {} timers for {}", delivery.size(), keyTimers.getKey().getKey()); @SuppressWarnings({"unchecked", "rawtypes"}) CommittedBundle bundle = evaluationContext diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java index 40c7301129fe..19d00406bdd7 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java @@ -36,6 +36,11 @@ interface RootInputProvider { *

For source transforms, these should be sufficient that, when provided to the evaluators * produced by {@link TransformEvaluatorFactory#forApplication(AppliedPTransform, * CommittedBundle)}, all of the elements contained in the source are eventually produced. + * + * @param transform the {@link AppliedPTransform} to get initial inputs for. + * @param targetParallelism the target amount of parallelism to obtain from the source. Must be + * greater than or equal to 1. */ - Collection> getInitialInputs(AppliedPTransform transform); + Collection> getInitialInputs( + AppliedPTransform transform, int targetParallelism) throws Exception; } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java index f6335fd01a29..bb5fcd2715f8 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java @@ -52,7 +52,8 @@ private RootProviderRegistry(Map, RootInputProvider> } @Override - public Collection> getInitialInputs(AppliedPTransform transform) { + public Collection> getInitialInputs( + AppliedPTransform transform, int targetParallelism) throws Exception { Class transformClass = transform.getTransform().getClass(); RootInputProvider provider = checkNotNull( @@ -60,6 +61,6 @@ public Collection> getInitialInputs(AppliedPTransform> getInitialInputs(AppliedPTransform transform) { + public Collection> getInitialInputs( + AppliedPTransform transform, int targetParallelism) { return createInputBundle((AppliedPTransform) transform); } @@ -213,6 +214,7 @@ private Collection> createInputBundle( return Collections.>singleton(initialBundle); } } + @AutoValue abstract static class TestStreamIndex { static TestStreamIndex of(TestStream stream) { diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java index 4b495e68f4d7..3dd44a71087d 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.direct; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.collect.ImmutableMap; @@ -81,14 +82,13 @@ public TransformEvaluator forApplication( throws Exception { checkState( !finished.get(), "Tried to get an evaluator for a finished TransformEvaluatorRegistry"); - TransformEvaluatorFactory factory = getFactory(application); + Class transformClass = application.getTransform().getClass(); + TransformEvaluatorFactory factory = + checkNotNull( + factories.get(transformClass), "No evaluator for PTransform type %s", transformClass); return factory.forApplication(application, inputBundle); } - private TransformEvaluatorFactory getFactory(AppliedPTransform application) { - return factories.get(application.getTransform().getClass()); - } - @Override public void cleanup() throws Exception { Collection thrownInCleanup = new ArrayList<>(); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java index 08dc286d48bf..6324bf138b0c 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java @@ -19,9 +19,11 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import java.io.IOException; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.concurrent.ThreadLocalRandom; import javax.annotation.Nullable; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; @@ -33,20 +35,22 @@ import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A {@link TransformEvaluatorFactory} that produces {@link TransformEvaluator TransformEvaluators} * for the {@link Unbounded Read.Unbounded} primitive {@link PTransform}. */ class UnboundedReadEvaluatorFactory implements TransformEvaluatorFactory { + private static final Logger LOG = LoggerFactory.getLogger(UnboundedReadEvaluatorFactory.class); // Occasionally close an existing reader and resume from checkpoint, to exercise close-and-resume - @VisibleForTesting static final double DEFAULT_READER_REUSE_CHANCE = 0.95; + private static final double DEFAULT_READER_REUSE_CHANCE = 0.95; private final EvaluationContext evaluationContext; private final double readerReuseChance; @@ -253,24 +257,40 @@ static class InputProvider implements RootInputProvider { } @Override - public Collection> getInitialInputs(AppliedPTransform transform) { - return createInitialSplits((AppliedPTransform) transform); + public Collection> getInitialInputs( + AppliedPTransform transform, int targetParallelism) throws Exception { + return createInitialSplits((AppliedPTransform) transform, targetParallelism); } private Collection> createInitialSplits( - AppliedPTransform> transform) { + AppliedPTransform> transform, int targetParallelism) + throws Exception { UnboundedSource source = transform.getTransform().getSource(); + List> splits = + source.generateInitialSplits(targetParallelism, evaluationContext.getPipelineOptions()); UnboundedReadDeduplicator deduplicator = source.requiresDeduping() ? UnboundedReadDeduplicator.CachedIdDeduplicator.create() : NeverDeduplicator.create(); - UnboundedSourceShard shard = UnboundedSourceShard.unstarted(source, deduplicator); - return Collections.>singleton( - evaluationContext - .>createRootBundle() - .add(WindowedValue.>valueInGlobalWindow(shard)) - .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + ImmutableList.Builder> initialShards = ImmutableList.builder(); + if (splits.isEmpty()) { + LOG.debug("Splits of source {} were empty, using empty split"); + initialShards.add( + evaluationContext + .>createRootBundle() + .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + } + for (UnboundedSource split : splits) { + UnboundedSourceShard shard = + UnboundedSourceShard.unstarted(split, deduplicator); + initialShards.add( + evaluationContext + .>createRootBundle() + .add(WindowedValue.>valueInGlobalWindow(shard)) + .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + } + return initialShards.build(); } } } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java index ee17eaefd10c..7f2f93c6d570 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java @@ -17,10 +17,14 @@ */ package org.apache.beam.runners.direct; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.when; @@ -43,9 +47,11 @@ import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.SourceTestUtils; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.hamcrest.Matchers; @@ -56,6 +62,8 @@ import org.junit.runners.JUnit4; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; /** * Tests for {@link BoundedReadEvaluatorFactory}. @@ -87,7 +95,7 @@ public void boundedSourceInMemoryTransformEvaluatorProducesElements() throws Exc Collection> initialInputs = new BoundedReadEvaluatorFactory.InputProvider(context) - .getInitialInputs(longs.getProducingTransformInternal()); + .getInitialInputs(longs.getProducingTransformInternal(), 1); List> outputs = new ArrayList<>(); for (CommittedBundle shardBundle : initialInputs) { TransformEvaluator evaluator = @@ -114,6 +122,37 @@ public void boundedSourceInMemoryTransformEvaluatorProducesElements() throws Exc gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L))); } + @Test + public void getInitialInputsSplitsIntoBundles() throws Exception { + when(context.createRootBundle()) + .thenAnswer( + new Answer>() { + @Override + public UncommittedBundle answer(InvocationOnMock invocation) throws Throwable { + return bundleFactory.createRootBundle(); + } + }); + Collection> initialInputs = + new BoundedReadEvaluatorFactory.InputProvider(context) + .getInitialInputs(longs.getProducingTransformInternal(), 3); + + assertThat(initialInputs, hasSize(allOf(greaterThanOrEqualTo(2), lessThanOrEqualTo(4)))); + + Collection> sources = new ArrayList<>(); + for (CommittedBundle initialInput : initialInputs) { + Iterable>> shards = + (Iterable) initialInput.getElements(); + WindowedValue> shard = Iterables.getOnlyElement(shards); + assertThat(shard.getWindows(), Matchers.contains(GlobalWindow.INSTANCE)); + assertThat(shard.getTimestamp(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + sources.add(shard.getValue().getSource()); + } + + SourceTestUtils.assertSourcesEqualReferenceSource(source, + (List>) sources, + PipelineOptionsFactory.create()); + } + @Test public void boundedSourceInMemoryTransformEvaluatorShardsOfSource() throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java index aa7b1783bfbe..03bea5e38f98 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java @@ -129,7 +129,7 @@ public void testFlattenInMemoryEvaluatorWithEmptyPCollectionList() throws Except FlattenEvaluatorFactory factory = new FlattenEvaluatorFactory(evaluationContext); Collection> initialInputs = new EmptyInputProvider(evaluationContext) - .getInitialInputs(flattened.getProducingTransformInternal()); + .getInitialInputs(flattened.getProducingTransformInternal(), 1); TransformEvaluator emptyEvaluator = factory.forApplication( flattened.getProducingTransformInternal(), Iterables.getOnlyElement(initialInputs)); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java index 60b9c79daa13..94a0d41fd3c2 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java @@ -82,7 +82,7 @@ public void producesElementsInSequence() throws Exception { Collection> initialInputs = new TestStreamEvaluatorFactory.InputProvider(context) - .getInitialInputs(streamVals.getProducingTransformInternal()); + .getInitialInputs(streamVals.getProducingTransformInternal(), 1); @SuppressWarnings("unchecked") CommittedBundle> initialBundle = (CommittedBundle>) Iterables.getOnlyElement(initialInputs); diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java index b78fbeb00820..a7a1f5f50562 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java @@ -19,6 +19,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; @@ -33,10 +34,12 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.NoSuchElementException; +import java.util.Set; import javax.annotation.Nullable; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; @@ -51,9 +54,12 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.SourceTestUtils; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.WindowedValue; @@ -66,6 +72,9 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + /** * Tests for {@link UnboundedReadEvaluatorFactory}. */ @@ -92,13 +101,50 @@ public void setup() { when(context.createBundle(longs)).thenReturn(output); } + @Test + public void generatesInitialSplits() throws Exception { + when(context.createRootBundle()).thenAnswer(new Answer>() { + @Override + public UncommittedBundle answer(InvocationOnMock invocation) throws Throwable { + return bundleFactory.createRootBundle(); + } + }); + + int numSplits = 5; + Collection> initialInputs = + new UnboundedReadEvaluatorFactory.InputProvider(context) + .getInitialInputs(longs.getProducingTransformInternal(), numSplits); + // CountingSource.unbounded has very good splitting behavior + assertThat(initialInputs, hasSize(numSplits)); + + int readPerSplit = 100; + int totalSize = numSplits * readPerSplit; + Set expectedOutputs = + ContiguousSet.create(Range.closedOpen(0L, (long) totalSize), DiscreteDomain.longs()); + + Collection readItems = new ArrayList<>(totalSize); + for (CommittedBundle initialInput : initialInputs) { + CommittedBundle> shardBundle = + (CommittedBundle>) initialInput; + WindowedValue> shard = + Iterables.getOnlyElement(shardBundle.getElements()); + assertThat(shard.getTimestamp(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + assertThat(shard.getWindows(), Matchers.contains(GlobalWindow.INSTANCE)); + UnboundedSource shardSource = shard.getValue().getSource(); + readItems.addAll(SourceTestUtils.readNItemsFromUnstartedReader(shardSource.createReader( + PipelineOptionsFactory.create(), + null), readPerSplit)); + } + assertThat(readItems, containsInAnyOrder(expectedOutputs.toArray(new Long[0]))); + } + @Test public void unboundedSourceInMemoryTransformEvaluatorProducesElements() throws Exception { when(context.createRootBundle()).thenReturn(bundleFactory.createRootBundle()); Collection> initialInputs = new UnboundedReadEvaluatorFactory.InputProvider(context) - .getInitialInputs(longs.getProducingTransformInternal()); + .getInitialInputs(longs.getProducingTransformInternal(), 1); CommittedBundle inputShards = Iterables.getOnlyElement(initialInputs); UnboundedSourceShard inputShard = @@ -143,7 +189,8 @@ public void unboundedSourceWithDuplicatesMultipleCalls() throws Exception { when(context.createRootBundle()).thenReturn(bundleFactory.createRootBundle()); Collection> initialInputs = - new UnboundedReadEvaluatorFactory.InputProvider(context).getInitialInputs(sourceTransform); + new UnboundedReadEvaluatorFactory.InputProvider(context) + .getInitialInputs(sourceTransform, 1); UncommittedBundle output = bundleFactory.createBundle(pcollection); when(context.createBundle(pcollection)).thenReturn(output); @@ -198,6 +245,8 @@ public void evaluatorReusesReader() throws Exception { .commit(Instant.now()); UnboundedReadEvaluatorFactory factory = new UnboundedReadEvaluatorFactory(context, 1.0 /* Always reuse */); + new UnboundedReadEvaluatorFactory.InputProvider(context) + .getInitialInputs(pcollection.getProducingTransformInternal(), 1); TransformEvaluator> evaluator = factory.forApplication(sourceTransform, inputBundle); evaluator.processElement(shard); From d6596284bc959864205614817b9b601ed09577ca Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 10 Oct 2016 13:13:24 -0700 Subject: [PATCH 2/9] fixup! Perform initial splitting in the DirectRunner --- .../runners/direct/BoundedReadEvaluatorFactory.java | 7 ------- .../org/apache/beam/runners/direct/DirectOptions.java | 7 ++++--- .../beam/runners/direct/EmptyInputProvider.java | 11 +++-------- .../direct/ExecutorServiceParallelExecutor.java | 6 ------ .../runners/direct/UnboundedReadEvaluatorFactory.java | 7 ------- .../apache/beam/runners/direct/WatermarkManager.java | 1 + .../direct/BoundedReadEvaluatorFactoryTest.java | 2 +- .../runners/direct/FlattenEvaluatorFactoryTest.java | 9 +++------ 8 files changed, 12 insertions(+), 38 deletions(-) diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java index f81cd5bc21fe..843dcd66d5d9 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java @@ -147,13 +147,6 @@ Collection>> createInitialSplits( source.splitIntoBundles(bytesPerBundle, options); ImmutableList.Builder>> shards = ImmutableList.builder(); - if (bundles.isEmpty()) { - LOG.debug("Splits of source {} were empty, using empty split"); - shards.add( - evaluationContext - .>createRootBundle() - .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); - } for (BoundedSource bundle : bundles) { CommittedBundle> inputShard = evaluationContext diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java index f31f84430f77..b2c4f4762194 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java @@ -65,9 +65,10 @@ public interface DirectOptions extends PipelineOptions, ApplicationNameOptions { void setEnforceEncodability(boolean test); @Default.InstanceFactory(AvailableParallelismFactory.class) - @Description("Controls the amount of target parallelism the DirectRunner will use. Defaults to" - + " the greater of the number of available processors as returned by the Runtime and 3. Must" - + " be a value greater than zero.") + @Description( + "Controls the amount of target parallelism the DirectRunner will use. Defaults to" + + " the greater of the number of available processors and 3. Must be a value greater" + + " than zero.") int getTargetParallelism(); void setTargetParallelism(int target); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java index fda2b287bc96..10589435203f 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java @@ -21,8 +21,6 @@ import java.util.Collections; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; /** * A {@link RootInputProvider} that provides a singleton empty bundle. @@ -37,14 +35,11 @@ class EmptyInputProvider implements RootInputProvider { /** * {@inheritDoc}. * - *

Returns a single empty bundle. This bundle ensures that any {@link PTransform PTransforms} - * that consume from the output of the provided {@link AppliedPTransform} have watermarks updated - * as appropriate. + *

Returns an empty collection. */ @Override public Collection> getInitialInputs( - AppliedPTransform transform, int ignored) { - return Collections.>singleton( - evaluationContext.createRootBundle().commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + AppliedPTransform transform, int targetParallelism) { + return Collections.emptyList(); } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java index 576bd26a3a93..c4b3dfa207cf 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java @@ -17,8 +17,6 @@ */ package org.apache.beam.runners.direct; -import static com.google.common.base.Preconditions.checkState; - import com.google.auto.value.AutoValue; import com.google.common.base.MoreObjects; import com.google.common.base.Optional; @@ -169,10 +167,6 @@ public void start(Collection> roots) { ConcurrentLinkedQueue> pending = new ConcurrentLinkedQueue<>(); try { Collection> initialInputs = rootInputProvider.getInitialInputs(root, 1); - checkState( - !initialInputs.isEmpty(), - "All root transforms must have initial inputs. Got 0 for %s", - root.getFullName()); pending.addAll(initialInputs); } catch (Exception e) { throw UserCodeException.wrap(e); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java index 6324bf138b0c..18d3d0aa7ecc 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java @@ -274,13 +274,6 @@ private Collection> createInitialSplits( : NeverDeduplicator.create(); ImmutableList.Builder> initialShards = ImmutableList.builder(); - if (splits.isEmpty()) { - LOG.debug("Splits of source {} were empty, using empty split"); - initialShards.add( - evaluationContext - .>createRootBundle() - .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); - } for (UnboundedSource split : splits) { UnboundedSourceShard shard = UnboundedSourceShard.unstarted(split, deduplicator); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index 82a6e4f3fbe3..c55a036a676e 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -799,6 +799,7 @@ public void initialize( for (CommittedBundle initialBundle : rootEntry.getValue()) { rootWms.addPending(initialBundle); } + pendingRefreshes.offer(rootEntry.getKey()); } } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java index 7f2f93c6d570..8a76a5362243 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java @@ -136,7 +136,7 @@ public UncommittedBundle answer(InvocationOnMock invocation) throws Throwable new BoundedReadEvaluatorFactory.InputProvider(context) .getInitialInputs(longs.getProducingTransformInternal(), 3); - assertThat(initialInputs, hasSize(allOf(greaterThanOrEqualTo(2), lessThanOrEqualTo(4)))); + assertThat(initialInputs, hasSize(allOf(greaterThanOrEqualTo(3), lessThanOrEqualTo(4)))); Collection> sources = new ArrayList<>(); for (CommittedBundle initialInput : initialInputs) { diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java index 03bea5e38f98..417aa6406ca6 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java @@ -24,13 +24,13 @@ import static org.mockito.Mockito.when; import com.google.common.collect.Iterables; -import java.util.Collection; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; @@ -122,17 +122,14 @@ public void testFlattenInMemoryEvaluatorWithEmptyPCollectionList() throws Except PCollection flattened = list.apply(Flatten.pCollections()); EvaluationContext evaluationContext = mock(EvaluationContext.class); - when(evaluationContext.createRootBundle()).thenReturn(bundleFactory.createRootBundle()); when(evaluationContext.createBundle(flattened)) .thenReturn(bundleFactory.createBundle(flattened)); FlattenEvaluatorFactory factory = new FlattenEvaluatorFactory(evaluationContext); - Collection> initialInputs = - new EmptyInputProvider(evaluationContext) - .getInitialInputs(flattened.getProducingTransformInternal(), 1); TransformEvaluator emptyEvaluator = factory.forApplication( - flattened.getProducingTransformInternal(), Iterables.getOnlyElement(initialInputs)); + flattened.getProducingTransformInternal(), + bundleFactory.createRootBundle().commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); TransformResult leftSideResult = emptyEvaluator.finishBundle(); From 9e885f56b40504015b53092197c9a03cc10b9525 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 10 Oct 2016 13:57:23 -0700 Subject: [PATCH 3/9] fixup! Perform initial splitting in the DirectRunner --- .../apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java index f21e6c00fc6d..3ca2b6458657 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.bigtable; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Verify.verifyNotNull; import static org.apache.beam.sdk.testing.SourceTestUtils.assertSourcesEqualReferenceSource; import static org.apache.beam.sdk.testing.SourceTestUtils.assertSplitAtFractionExhaustive; @@ -222,6 +223,7 @@ public void testReadingFailsTableDoesNotExist() throws Exception { public void testReadingEmptyTable() throws Exception { final String table = "TEST-EMPTY-TABLE"; service.createTable(table); + service.setupSampleRowKeys(table, 1, 1L); runReadTest(defaultRead.withTableId(table), new ArrayList()); logged.verifyInfo("Closing reader after reading 0 records."); @@ -234,8 +236,9 @@ public void testReading() throws Exception { final int numRows = 1001; List testRows = makeTableData(table, numRows); + service.setupSampleRowKeys(table, 3, 1000L); runReadTest(defaultRead.withTableId(table), testRows); - logged.verifyInfo(String.format("Closing reader after reading %d records.", numRows)); + logged.verifyInfo(String.format("Closing reader after reading %d records.", numRows / 3)); } /** A {@link Predicate} that a {@link Row Row's} key matches the given regex. */ @@ -284,6 +287,7 @@ public void testReadingWithKeyRange() throws Exception { ByteKey startKey = ByteKey.copyFrom("key000000100".getBytes()); ByteKey endKey = ByteKey.copyFrom("key000000300".getBytes()); + service.setupSampleRowKeys(table, numRows / 10, "key000000100".length()); // Test prefix: [beginning, startKey). final ByteKeyRange prefixRange = ByteKeyRange.ALL_KEYS.withEndKey(startKey); List prefixRows = filterToRange(testRows, prefixRange); @@ -336,6 +340,7 @@ public boolean apply(@Nullable Row input) { RowFilter filter = RowFilter.newBuilder().setRowKeyRegexFilter(ByteString.copyFromUtf8(regex)).build(); + service.setupSampleRowKeys(table, 5, 10L); runReadTest( defaultRead.withTableId(table).withRowFilter(filter), Lists.newArrayList(filteredRows)); @@ -743,7 +748,7 @@ public FakeBigtableWriter openForWriting(String tableId) { @Override public List getSampleRowKeys(BigtableSource source) { List samples = sampleRowKeys.get(source.getTableId()); - checkArgument(samples != null, "No samples found for table %s", source.getTableId()); + checkNotNull(samples, "No samples found for table %s", source.getTableId()); return samples; } From f7416ad235ef450f9155d93b2a2fa18cbef6cd61 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 10 Oct 2016 16:15:28 -0700 Subject: [PATCH 4/9] fixup! Perform initial splitting in the DirectRunner Write an "Export Job" file. TODO: clean up --- .../sdk/io/gcp/bigquery/BigQueryIOTest.java | 235 +++++++++++++++++- 1 file changed, 227 insertions(+), 8 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java index 05a7c5c0509e..96bc45707bca 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.bigquery; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.fromJsonString; import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.toJsonString; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; @@ -32,7 +33,9 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.when; +import com.google.api.client.json.GenericJson; import com.google.api.client.util.Data; +import com.google.api.services.bigquery.model.Dataset; import com.google.api.services.bigquery.model.ErrorProto; import com.google.api.services.bigquery.model.Job; import com.google.api.services.bigquery.model.JobConfigurationExtract; @@ -50,21 +53,35 @@ import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; import com.google.common.base.Strings; +import com.google.common.collect.HashBasedTable; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileFilter; import java.io.IOException; import java.io.Serializable; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; import javax.annotation.Nullable; +import org.apache.avro.Schema; +import org.apache.avro.Schema.Field; +import org.apache.avro.Schema.Type; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.KvCoder; @@ -110,7 +127,9 @@ import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.IOChannelFactory; import org.apache.beam.sdk.util.IOChannelUtils; +import org.apache.beam.sdk.util.MimeTypes; import org.apache.beam.sdk.util.PCollectionViews; +import org.apache.beam.sdk.util.Transport; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; @@ -144,6 +163,7 @@ public class BigQueryIOTest implements Serializable { Status.SUCCEEDED, new Job().setStatus(new JobStatus()), Status.FAILED, new Job().setStatus(new JobStatus().setErrorResult(new ErrorProto()))); + private static class FakeBigQueryServices implements BigQueryServices { private String[] jsonTableRowReturns = new String[0]; @@ -276,6 +296,18 @@ public FakeJobService getJobReturns(Object... getJobReturns) { */ public FakeJobService pollJobReturns(Object... pollJobReturns) { this.pollJobReturns = pollJobReturns; + for (int i = 0; i < pollJobReturns.length; i++) { + if (pollJobReturns[i] instanceof Job) { + try { + // Job is not serializable, so encode the job as a byte array. + pollJobReturns[i] = Transport.getJsonFactory().toByteArray(pollJobReturns[i]); + } catch (IOException e) { + throw new IllegalArgumentException( + String.format( + "Could not encode Job %s via available JSON factory", pollJobReturns[i])); + } + } + } return this; } @@ -290,25 +322,25 @@ public FakeJobService verifyExecutingProject(String executingProject) { @Override public void startLoadJob(JobReference jobRef, JobConfigurationLoad loadConfig) throws InterruptedException, IOException { - startJob(jobRef); + startJob(jobRef, loadConfig); } @Override public void startExtractJob(JobReference jobRef, JobConfigurationExtract extractConfig) throws InterruptedException, IOException { - startJob(jobRef); + startJob(jobRef, extractConfig); } @Override public void startQueryJob(JobReference jobRef, JobConfigurationQuery query) throws IOException, InterruptedException { - startJob(jobRef); + startJob(jobRef, query); } @Override public void startCopyJob(JobReference jobRef, JobConfigurationTableCopy copyConfig) throws IOException, InterruptedException { - startJob(jobRef); + startJob(jobRef, copyConfig); } @Override @@ -323,8 +355,14 @@ public Job pollJob(JobReference jobRef, int maxAttempts) if (pollJobStatusCallsCount < pollJobReturns.length) { Object ret = pollJobReturns[pollJobStatusCallsCount++]; - if (ret instanceof Job) { - return (Job) ret; + if (ret instanceof byte[]) { + try { + return Transport.getJsonFactory() + .createJsonParser(new ByteArrayInputStream((byte[]) ret)) + .parse(Job.class); + } catch (IOException e) { + throw new RuntimeException("Couldn't parse encoded Job", e); + } } else if (ret instanceof Status) { return JOB_STATUS_MAP.get(ret); } else if (ret instanceof InterruptedException) { @@ -338,7 +376,8 @@ public Job pollJob(JobReference jobRef, int maxAttempts) } } - private void startJob(JobReference jobRef) throws IOException, InterruptedException { + private void startJob(JobReference jobRef, GenericJson config) + throws IOException, InterruptedException { if (!Strings.isNullOrEmpty(executingProject)) { checkArgument( jobRef.getProjectId().equals(executingProject), @@ -352,6 +391,11 @@ private void startJob(JobReference jobRef) throws IOException, InterruptedExcept throw (IOException) ret; } else if (ret instanceof InterruptedException) { throw (InterruptedException) ret; + } else if (ret instanceof SerializableFunction) { + SerializableFunction fn = + (SerializableFunction) ret; + fn.apply(config); + return; } else { return; } @@ -394,6 +438,90 @@ public Job getJob(JobReference jobRef) throws InterruptedException { } } + /** A fake dataset service that can be serialized, for use in testReadFromTable. */ + private static class FakeDatasetService implements DatasetService, Serializable { + private final com.google.common.collect.Table> tables = + HashBasedTable.create(); + + public FakeDatasetService withTable( + String projectId, String datasetId, String tableId, Table table) throws IOException { + Map dataset = tables.get(projectId, datasetId); + if (dataset == null) { + dataset = new HashMap<>(); + tables.put(projectId, datasetId, dataset); + } + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + dataset.put(tableId, Transport.getJsonFactory().toByteArray(table)); + return this; + } + + @Override + public Table getTable(String projectId, String datasetId, String tableId) + throws InterruptedException, IOException { + Table table = deserTable(projectId, datasetId, tableId); + return table; + } + + private Table deserTable(String projectId, String datasetId, String tableId) + throws IOException { + Map dataset = + checkNotNull( + tables.get(projectId, datasetId), + "Tried to get a table %s:%s.%s from %s, but no such table was set", + projectId, + datasetId, + tableId, + FakeDatasetService.class.getSimpleName()); + byte[] tableBytes = checkNotNull(dataset.get(tableId), + "Tried to get a table %s:%s.%s from %s, but no such table was set", + projectId, + datasetId, + tableId, + FakeDatasetService.class.getSimpleName()); + return Transport.getJsonFactory() + .createJsonParser(new ByteArrayInputStream(tableBytes)) + .parse(Table.class); + } + + @Override + public void deleteTable(String projectId, String datasetId, String tableId) + throws IOException, InterruptedException { + } + + @Override + public boolean isTableEmpty(String projectId, String datasetId, String tableId) + throws IOException, InterruptedException { + Long numBytes = deserTable(projectId, datasetId, tableId).getNumBytes(); + return numBytes == null || numBytes == 0L; + } + + @Override + public Dataset getDataset( + String projectId, String datasetId) throws IOException, InterruptedException { + return null; + } + + @Override + public void createDataset( + String projectId, String datasetId, String location, String description) + throws IOException, InterruptedException { + + } + + @Override + public void deleteDataset(String projectId, String datasetId) + throws IOException, InterruptedException { + + } + + @Override + public long insertAll( + TableReference ref, List rowList, @Nullable List insertIdList) + throws IOException, InterruptedException { + return 0; + } + } + @Rule public transient ExpectedException thrown = ExpectedException.none(); @Rule public transient ExpectedLogs logged = ExpectedLogs.none(BigQueryIO.class); @Rule public transient TemporaryFolder testFolder = new TemporaryFolder(); @@ -627,11 +755,65 @@ public void testReadFromTable() throws IOException { bqOptions.setProject("defaultProject"); bqOptions.setTempLocation(testFolder.newFolder("BigQueryIOTest").getAbsolutePath()); + Job job = new Job(); + JobStatus status = new JobStatus(); + job.setStatus(status); + JobStatistics jobStats = new JobStatistics(); + job.setStatistics(jobStats); + JobStatistics4 extract = new JobStatistics4(); + jobStats.setExtract(extract); + extract.setDestinationUriFileCounts(ImmutableList.of(1L)); + + Table sometable = new Table(); + sometable.setSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("number").setType("INTEGER")))); + sometable.setNumBytes(1024L * 1024L); + FakeDatasetService fakeDatasetService = + new FakeDatasetService() + .withTable("non-executing-project", "somedataset", "sometable", sometable); + SerializableFunction schemaGenerator = + new SerializableFunction() { + @Override + public Schema apply(Void input) { + return Schema.createRecord( + "TestTableRow", + "Table Rows in BigQueryIOTest", + "org.apache.beam.sdk.io.gcp.bigquery", + false, + ImmutableList.of( + new Field( + "name", + Schema.createUnion(Schema.create(Type.STRING), Schema.create(Type.NULL)), + "The name field", + ""), + new Field( + "number", + Schema.createUnion(Schema.create(Type.LONG), Schema.create(Type.NULL)), + "The number field", + 0))); + } + }; + Collection> records = + ImmutableList.>builder() + .add(ImmutableMap.builder().put("name", "a").put("number", 1L).build()) + .add(ImmutableMap.builder().put("name", "b").put("number", 2L).build()) + .add(ImmutableMap.builder().put("name", "c").put("number", 3L).build()) + .build(); + + SerializableFunction onStartJob = + new WriteExtractFiles(schemaGenerator, records); + FakeBigQueryServices fakeBqServices = new FakeBigQueryServices() .withJobService(new FakeJobService() - .startJobReturns("done", "done") + .startJobReturns(onStartJob, "done") + .pollJobReturns(job) .getJobReturns((Job) null) .verifyExecutingProject(bqOptions.getProject())) + .withDatasetService(fakeDatasetService) .readerReturns( toJsonString(new TableRow().set("name", "a").set("number", 1)), toJsonString(new TableRow().set("name", "b").set("number", 2)), @@ -1701,4 +1883,41 @@ public boolean accept(File pathname) { return pathname.isFile(); }}).length); } + + private class WriteExtractFiles implements SerializableFunction { + private final SerializableFunction schemaGenerator; + private final Collection> records; + + private WriteExtractFiles( + SerializableFunction schemaGenerator, + Collection> records) { + this.schemaGenerator = schemaGenerator; + this.records = records; + } + + @Override + public Void apply(GenericJson input) { + List destinations = (List) input.get("destinationUris"); + for (String destination : destinations) { + String newDest = destination.replace("*", "000000000000"); + Schema schema = schemaGenerator.apply(null); + try (WritableByteChannel channel = IOChannelUtils.create(newDest, MimeTypes.BINARY); + DataFileWriter tableRowWriter = + new DataFileWriter<>(new GenericDatumWriter(schema)) + .create(schema, Channels.newOutputStream(channel))) { + for (Map record : records) { + GenericRecordBuilder genericRecordBuilder = new GenericRecordBuilder(schema); + for (Map.Entry field : record.entrySet()) { + genericRecordBuilder.set(field.getKey(), field.getValue()); + } + tableRowWriter.append(genericRecordBuilder.build()); + } + } catch (IOException e) { + throw new IllegalStateException( + String.format("Could not create destination for extract job %s", destination), e); + } + } + return null; + } + } } From 20d49a8a09615aec447a041eb29bc900453a7fcc Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Tue, 11 Oct 2016 12:54:51 -0700 Subject: [PATCH 5/9] fixup! fixup! Perform initial splitting in the DirectRunner --- .../io/gcp/bigquery/BigQueryAvroUtils.java | 69 +++++++-- .../gcp/bigquery/BigQueryAvroUtilsTest.java | 132 ++++++++++++++---- .../sdk/io/gcp/bigquery/BigQueryIOTest.java | 21 +-- 3 files changed, 162 insertions(+), 60 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java index 6a9ea6bc1c39..20dd2d0568d1 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java @@ -28,8 +28,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.BaseEncoding; - import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.List; import javax.annotation.Nullable; import org.apache.avro.Schema; @@ -47,6 +47,19 @@ */ class BigQueryAvroUtils { + public static final ImmutableMap BIG_QUERY_TO_AVRO_TYPES = + ImmutableMap.builder() + .put("STRING", Type.STRING) + .put("BYTES", Type.BYTES) + .put("INTEGER", Type.LONG) + .put("FLOAT", Type.DOUBLE) + .put("BOOLEAN", Type.BOOLEAN) + .put("TIMESTAMP", Type.LONG) + .put("RECORD", Type.RECORD) + .put("DATE", Type.STRING) + .put("DATETIME", Type.STRING) + .put("TIME", Type.STRING) + .build(); /** * Formats BigQuery seconds-since-epoch into String matching JSON export. Thread-safe and * immutable. @@ -154,23 +167,10 @@ private static Object convertRequiredField( // REQUIRED fields are represented as the corresponding Avro types. For example, a BigQuery // INTEGER type maps to an Avro LONG type. checkNotNull(v, "REQUIRED field %s should not be null", fieldSchema.getName()); - ImmutableMap fieldMap = - ImmutableMap.builder() - .put("STRING", Type.STRING) - .put("BYTES", Type.BYTES) - .put("INTEGER", Type.LONG) - .put("FLOAT", Type.DOUBLE) - .put("BOOLEAN", Type.BOOLEAN) - .put("TIMESTAMP", Type.LONG) - .put("RECORD", Type.RECORD) - .put("DATE", Type.STRING) - .put("DATETIME", Type.STRING) - .put("TIME", Type.STRING) - .build(); // Per https://cloud.google.com/bigquery/docs/reference/v2/tables#schema, the type field // is required, so it may not be null. String bqType = fieldSchema.getType(); - Type expectedAvroType = fieldMap.get(bqType); + Type expectedAvroType = BIG_QUERY_TO_AVRO_TYPES.get(bqType); verifyNotNull(expectedAvroType, "Unsupported BigQuery type: %s", bqType); verify( avroType == expectedAvroType, @@ -248,4 +248,43 @@ private static Object convertNullableField( } return convertRequiredField(unionTypes.get(1).getType(), fieldSchema, v); } + + static Schema toGenericAvroSchema(String schemaName, List fieldSchemas) { + List avroFields = new ArrayList<>(); + for (TableFieldSchema bigQueryField : fieldSchemas) { + avroFields.add(convertField(bigQueryField)); + } + return Schema.createRecord( + schemaName, + "org.apache.beam.sdk.io.gcp.bigquery", + "Translated Avro Schema for " + schemaName, + false, + avroFields); + } + + private static Field convertField(TableFieldSchema bigQueryField) { + Type avroType = BIG_QUERY_TO_AVRO_TYPES.get(bigQueryField.getType()); + Schema elementSchema; + if (avroType == Type.RECORD) { + elementSchema = toGenericAvroSchema(bigQueryField.getName(), bigQueryField.getFields()); + } else { + elementSchema = Schema.create(avroType); + } + Schema fieldSchema; + if (bigQueryField.getMode() == null || bigQueryField.getMode().equals("NULLABLE")) { + fieldSchema = Schema.createUnion(Schema.create(Type.NULL), elementSchema); + } else if (bigQueryField.getMode().equals("REQUIRED")) { + fieldSchema = elementSchema; + } else if (bigQueryField.getMode().equals("REPEATED")) { + fieldSchema = Schema.createArray(elementSchema); + } else { + throw new IllegalArgumentException( + String.format("Unknown BigQuery Field Mode: %s", bigQueryField.getMode())); + } + return new Field( + bigQueryField.getName(), + fieldSchema, + bigQueryField.getDescription(), + (Object) null /* Cast to avoid deprecated JsonNode constructor. */); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java index 1d3ea812e8bb..644c545f1dc7 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java @@ -17,18 +17,22 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; import com.google.api.services.bigquery.model.TableFieldSchema; import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.io.BaseEncoding; - import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import org.apache.avro.Schema; +import org.apache.avro.Schema.Field; +import org.apache.avro.Schema.Type; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; import org.apache.avro.reflect.Nullable; @@ -44,36 +48,37 @@ */ @RunWith(JUnit4.class) public class BigQueryAvroUtilsTest { + private List subFields = Lists.newArrayList( + new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE")); + /* + * Note that the quality and quantity fields do not have their mode set, so they should default + * to NULLABLE. This is an important test of BigQuery semantics. + * + * All the other fields we set in this function are required on the Schema response. + * + * See https://cloud.google.com/bigquery/docs/reference/v2/tables#schema + */ + private List fields = + Lists.newArrayList( + new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"), + new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE"), + new TableFieldSchema().setName("quality").setType("FLOAT") /* default to NULLABLE */, + new TableFieldSchema().setName("quantity").setType("INTEGER") /* default to NULLABLE */, + new TableFieldSchema().setName("birthday").setType("TIMESTAMP").setMode("NULLABLE"), + new TableFieldSchema().setName("flighted").setType("BOOLEAN").setMode("NULLABLE"), + new TableFieldSchema().setName("sound").setType("BYTES").setMode("NULLABLE"), + new TableFieldSchema().setName("anniversaryDate").setType("DATE").setMode("NULLABLE"), + new TableFieldSchema().setName("anniversaryDatetime") + .setType("DATETIME").setMode("NULLABLE"), + new TableFieldSchema().setName("anniversaryTime").setType("TIME").setMode("NULLABLE"), + new TableFieldSchema().setName("scion").setType("RECORD").setMode("NULLABLE") + .setFields(subFields), + new TableFieldSchema().setName("associates").setType("RECORD").setMode("REPEATED") + .setFields(subFields)); + @Test public void testConvertGenericRecordToTableRow() throws Exception { TableSchema tableSchema = new TableSchema(); - List subFields = Lists.newArrayList( - new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE")); - /* - * Note that the quality and quantity fields do not have their mode set, so they should default - * to NULLABLE. This is an important test of BigQuery semantics. - * - * All the other fields we set in this function are required on the Schema response. - * - * See https://cloud.google.com/bigquery/docs/reference/v2/tables#schema - */ - List fields = - Lists.newArrayList( - new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"), - new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE"), - new TableFieldSchema().setName("quality").setType("FLOAT") /* default to NULLABLE */, - new TableFieldSchema().setName("quantity").setType("INTEGER") /* default to NULLABLE */, - new TableFieldSchema().setName("birthday").setType("TIMESTAMP").setMode("NULLABLE"), - new TableFieldSchema().setName("flighted").setType("BOOLEAN").setMode("NULLABLE"), - new TableFieldSchema().setName("sound").setType("BYTES").setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryDate").setType("DATE").setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryDatetime") - .setType("DATETIME").setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryTime").setType("TIME").setMode("NULLABLE"), - new TableFieldSchema().setName("scion").setType("RECORD").setMode("NULLABLE") - .setFields(subFields), - new TableFieldSchema().setName("associates").setType("RECORD").setMode("REPEATED") - .setFields(subFields)); tableSchema.setFields(fields); Schema avroSchema = AvroCoder.of(Bird.class).getSchema(); @@ -132,6 +137,77 @@ public void testConvertGenericRecordToTableRow() throws Exception { } } + @Test + public void testConvertBigQuerySchemaToAvroSchema() { + TableSchema tableSchema = new TableSchema(); + tableSchema.setFields(fields); + Schema avroSchema = + BigQueryAvroUtils.toGenericAvroSchema("testSchema", tableSchema.getFields()); + + assertThat(avroSchema.getField("number").schema(), equalTo(Schema.create(Type.LONG))); + assertThat( + avroSchema.getField("species").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); + assertThat( + avroSchema.getField("quality").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.DOUBLE)))); + assertThat( + avroSchema.getField("quantity").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.LONG)))); + assertThat( + avroSchema.getField("birthday").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.LONG)))); + assertThat( + avroSchema.getField("flighted").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.BOOLEAN)))); + assertThat( + avroSchema.getField("sound").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.BYTES)))); + assertThat( + avroSchema.getField("anniversaryDate").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); + assertThat( + avroSchema.getField("anniversaryDatetime").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); + assertThat( + avroSchema.getField("anniversaryTime").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); + + assertThat( + avroSchema.getField("scion").schema(), + equalTo( + Schema.createUnion( + Schema.create(Type.NULL), + Schema.createRecord( + "scion", + "org.apache.beam.sdk.io.gcp.bigquery", + "Translated Avro Schema for scion", + false, + ImmutableList.of( + new Field( + "species", + Schema.createUnion( + Schema.create(Type.NULL), Schema.create(Type.STRING)), + null, + (Object) null)))))); + assertThat( + avroSchema.getField("associates").schema(), + equalTo( + Schema.createArray( + Schema.createRecord( + "associates", + "org.apache.beam.sdk.io.gcp.bigquery", + "Translated Avro Schema for associates", + false, + ImmutableList.of( + new Field( + "species", + Schema.createUnion( + Schema.create(Type.NULL), Schema.create(Type.STRING)), + null, + (Object) null)))))); + } + /** * Pojo class used as the record type in tests. */ diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java index 96bc45707bca..e00e80e257be 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java @@ -76,8 +76,6 @@ import java.util.Set; import javax.annotation.Nullable; import org.apache.avro.Schema; -import org.apache.avro.Schema.Field; -import org.apache.avro.Schema.Type; import org.apache.avro.file.DataFileWriter; import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.generic.GenericRecord; @@ -779,22 +777,11 @@ public void testReadFromTable() throws IOException { new SerializableFunction() { @Override public Schema apply(Void input) { - return Schema.createRecord( - "TestTableRow", - "Table Rows in BigQueryIOTest", - "org.apache.beam.sdk.io.gcp.bigquery", - false, + return BigQueryAvroUtils.toGenericAvroSchema( + "sometable", ImmutableList.of( - new Field( - "name", - Schema.createUnion(Schema.create(Type.STRING), Schema.create(Type.NULL)), - "The name field", - ""), - new Field( - "number", - Schema.createUnion(Schema.create(Type.LONG), Schema.create(Type.NULL)), - "The number field", - 0))); + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("number").setType("INTEGER"))); } }; Collection> records = From fe6ffd981ef9f902e1cb65ff4f471e789a84fb70 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Wed, 12 Oct 2016 09:32:52 -0700 Subject: [PATCH 6/9] fixup! fixup! Perform initial splitting in the DirectRunner --- .../apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java index e00e80e257be..08141a627334 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java @@ -484,6 +484,7 @@ private Table deserTable(String projectId, String datasetId, String tableId) @Override public void deleteTable(String projectId, String datasetId, String tableId) throws IOException, InterruptedException { + throw new UnsupportedOperationException("Unsupported"); } @Override @@ -496,27 +497,27 @@ public boolean isTableEmpty(String projectId, String datasetId, String tableId) @Override public Dataset getDataset( String projectId, String datasetId) throws IOException, InterruptedException { - return null; + throw new UnsupportedOperationException("Unsupported"); } @Override public void createDataset( String projectId, String datasetId, String location, String description) throws IOException, InterruptedException { - + throw new UnsupportedOperationException("Unsupported"); } @Override public void deleteDataset(String projectId, String datasetId) throws IOException, InterruptedException { - + throw new UnsupportedOperationException("Unsupported"); } @Override public long insertAll( TableReference ref, List rowList, @Nullable List insertIdList) throws IOException, InterruptedException { - return 0; + throw new UnsupportedOperationException("Unsupported"); } } From 71e0010545141889a4686b2582338c1f6a5a5db0 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Thu, 13 Oct 2016 10:41:48 -0700 Subject: [PATCH 7/9] fixup! fixup! fixup! Perform initial splitting in the DirectRunner --- .../runners/direct/UnboundedReadEvaluatorFactoryTest.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java index a7a1f5f50562..76acb031f727 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java @@ -131,9 +131,11 @@ public UncommittedBundle answer(InvocationOnMock invocation) throws Throwable assertThat(shard.getTimestamp(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); assertThat(shard.getWindows(), Matchers.contains(GlobalWindow.INSTANCE)); UnboundedSource shardSource = shard.getValue().getSource(); - readItems.addAll(SourceTestUtils.readNItemsFromUnstartedReader(shardSource.createReader( - PipelineOptionsFactory.create(), - null), readPerSplit)); + readItems.addAll( + SourceTestUtils.readNItemsFromUnstartedReader( + shardSource.createReader( + PipelineOptionsFactory.create(), null /* No starting checkpoint */), + readPerSplit)); } assertThat(readItems, containsInAnyOrder(expectedOutputs.toArray(new Long[0]))); } From 0611a86b5b0e62b9377fb1711574ac255d403cd3 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Thu, 13 Oct 2016 13:38:09 -0700 Subject: [PATCH 8/9] fixup! fixup! fixup! fixup! Perform initial splitting in the DirectRunner --- .../sdk/io/gcp/bigquery/BigQueryIOTest.java | 155 +++++++++++++----- 1 file changed, 118 insertions(+), 37 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java index 08141a627334..3c9b79cfe58d 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.fromJsonString; import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.toJsonString; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; @@ -57,16 +58,20 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; +import com.google.common.collect.Table.Cell; import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileFilter; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamException; import java.io.Serializable; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -294,18 +299,6 @@ public FakeJobService getJobReturns(Object... getJobReturns) { */ public FakeJobService pollJobReturns(Object... pollJobReturns) { this.pollJobReturns = pollJobReturns; - for (int i = 0; i < pollJobReturns.length; i++) { - if (pollJobReturns[i] instanceof Job) { - try { - // Job is not serializable, so encode the job as a byte array. - pollJobReturns[i] = Transport.getJsonFactory().toByteArray(pollJobReturns[i]); - } catch (IOException e) { - throw new IllegalArgumentException( - String.format( - "Could not encode Job %s via available JSON factory", pollJobReturns[i])); - } - } - } return this; } @@ -353,14 +346,8 @@ public Job pollJob(JobReference jobRef, int maxAttempts) if (pollJobStatusCallsCount < pollJobReturns.length) { Object ret = pollJobReturns[pollJobStatusCallsCount++]; - if (ret instanceof byte[]) { - try { - return Transport.getJsonFactory() - .createJsonParser(new ByteArrayInputStream((byte[]) ret)) - .parse(Job.class); - } catch (IOException e) { - throw new RuntimeException("Couldn't parse encoded Job", e); - } + if (ret instanceof Job) { + return (Job) ret; } else if (ret instanceof Status) { return JOB_STATUS_MAP.get(ret); } else if (ret instanceof InterruptedException) { @@ -434,35 +421,81 @@ public Job getJob(JobReference jobRef) throws InterruptedException { "Exceeded expected number of calls: " + getJobReturns.length); } } + + ////////////////////////////////// SERIALIZATION METHODS //////////////////////////////////// + private void writeObject(ObjectOutputStream out) throws IOException { + out.writeObject(replaceJobsWithBytes(startJobReturns)); + out.writeObject(replaceJobsWithBytes(pollJobReturns)); + out.writeObject(replaceJobsWithBytes(getJobReturns)); + out.writeObject(executingProject); + } + + private Object[] replaceJobsWithBytes(Object[] objs) { + Object[] copy = Arrays.copyOf(objs, objs.length); + for (int i = 0; i < copy.length; i++) { + checkArgument( + copy[i] == null || copy[i] instanceof Serializable || copy[i] instanceof Job, + "Only serializable elements and jobs can be added add to Job Returns"); + if (copy[i] instanceof Job) { + try { + // Job is not serializable, so encode the job as a byte array. + copy[i] = Transport.getJsonFactory().toByteArray(copy[i]); + } catch (IOException e) { + throw new IllegalArgumentException( + String.format("Could not encode Job %s via available JSON factory", copy[i])); + } + } + } + return copy; + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + this.startJobReturns = replaceBytesWithJobs(in.readObject()); + this.pollJobReturns = replaceBytesWithJobs(in.readObject()); + this.getJobReturns = replaceBytesWithJobs(in.readObject()); + this.executingProject = (String) in.readObject(); + } + + private Object[] replaceBytesWithJobs(Object obj) throws IOException { + checkState(obj instanceof Object[]); + Object[] objs = (Object[]) obj; + Object[] copy = Arrays.copyOf(objs, objs.length); + for (int i = 0; i < copy.length; i++) { + if (copy[i] instanceof byte[]) { + Job job = Transport.getJsonFactory() + .createJsonParser(new ByteArrayInputStream((byte[]) copy[i])) + .parse(Job.class); + copy[i] = job; + } + } + return copy; + } + + private void readObjectNoData() throws ObjectStreamException { + } + } /** A fake dataset service that can be serialized, for use in testReadFromTable. */ private static class FakeDatasetService implements DatasetService, Serializable { - private final com.google.common.collect.Table> tables = + private com.google.common.collect.Table> tables = HashBasedTable.create(); public FakeDatasetService withTable( String projectId, String datasetId, String tableId, Table table) throws IOException { - Map dataset = tables.get(projectId, datasetId); + Map dataset = tables.get(projectId, datasetId); if (dataset == null) { dataset = new HashMap<>(); tables.put(projectId, datasetId, dataset); } - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - dataset.put(tableId, Transport.getJsonFactory().toByteArray(table)); + dataset.put(tableId, table); return this; } @Override public Table getTable(String projectId, String datasetId, String tableId) throws InterruptedException, IOException { - Table table = deserTable(projectId, datasetId, tableId); - return table; - } - - private Table deserTable(String projectId, String datasetId, String tableId) - throws IOException { - Map dataset = + Map dataset = checkNotNull( tables.get(projectId, datasetId), "Tried to get a table %s:%s.%s from %s, but no such table was set", @@ -470,15 +503,12 @@ private Table deserTable(String projectId, String datasetId, String tableId) datasetId, tableId, FakeDatasetService.class.getSimpleName()); - byte[] tableBytes = checkNotNull(dataset.get(tableId), + return checkNotNull(dataset.get(tableId), "Tried to get a table %s:%s.%s from %s, but no such table was set", projectId, datasetId, tableId, FakeDatasetService.class.getSimpleName()); - return Transport.getJsonFactory() - .createJsonParser(new ByteArrayInputStream(tableBytes)) - .parse(Table.class); } @Override @@ -490,7 +520,7 @@ public void deleteTable(String projectId, String datasetId, String tableId) @Override public boolean isTableEmpty(String projectId, String datasetId, String tableId) throws IOException, InterruptedException { - Long numBytes = deserTable(projectId, datasetId, tableId).getNumBytes(); + Long numBytes = getTable(projectId, datasetId, tableId).getNumBytes(); return numBytes == null || numBytes == 0L; } @@ -519,6 +549,57 @@ public long insertAll( throws IOException, InterruptedException { throw new UnsupportedOperationException("Unsupported"); } + + ////////////////////////////////// SERIALIZATION METHODS //////////////////////////////////// + private void writeObject(ObjectOutputStream out) throws IOException { + out.writeObject(replaceTablesWithBytes(this.tables)); + } + + private com.google.common.collect.Table> + replaceTablesWithBytes( + com.google.common.collect.Table> toCopy) + throws IOException { + com.google.common.collect.Table> copy = + HashBasedTable.create(); + for (Cell> cell : toCopy.cellSet()) { + HashMap dataset = new HashMap<>(); + copy.put(cell.getRowKey(), cell.getColumnKey(), dataset); + for (Map.Entry dsTables : cell.getValue().entrySet()) { + dataset.put( + dsTables.getKey(), Transport.getJsonFactory().toByteArray(dsTables.getValue())); + } + } + return copy; + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + com.google.common.collect.Table> tablesTable = + (com.google.common.collect.Table>) in.readObject(); + this.tables = replaceBytesWithTables(tablesTable); + } + + private com.google.common.collect.Table> + replaceBytesWithTables( + com.google.common.collect.Table> tablesTable) + throws IOException { + com.google.common.collect.Table> copy = + HashBasedTable.create(); + for (Cell> cell : tablesTable.cellSet()) { + HashMap dataset = new HashMap<>(); + copy.put(cell.getRowKey(), cell.getColumnKey(), dataset); + for (Map.Entry dsTables : cell.getValue().entrySet()) { + Table table = + Transport.getJsonFactory() + .createJsonParser(new ByteArrayInputStream(dsTables.getValue())) + .parse(Table.class); + dataset.put(dsTables.getKey(), table); + } + } + return copy; + } + + private void readObjectNoData() throws ObjectStreamException {} + } @Rule public transient ExpectedException thrown = ExpectedException.none(); From 73b3bb36e5bfd5c9f0132a6b5f8196fa82ce316b Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Fri, 14 Oct 2016 13:02:03 -0700 Subject: [PATCH 9/9] fixup! fixup! fixup! fixup! fixup! Perform initial splitting in the DirectRunner --- .../apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java index 3c9b79cfe58d..9d636114bc4a 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java @@ -65,7 +65,6 @@ import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; -import java.io.ObjectStreamException; import java.io.Serializable; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; @@ -470,10 +469,6 @@ private Object[] replaceBytesWithJobs(Object obj) throws IOException { } return copy; } - - private void readObjectNoData() throws ObjectStreamException { - } - } /** A fake dataset service that can be serialized, for use in testReadFromTable. */ @@ -597,9 +592,6 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE } return copy; } - - private void readObjectNoData() throws ObjectStreamException {} - } @Rule public transient ExpectedException thrown = ExpectedException.none();