diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java index 007bba5c6cdf..d3d33a296c53 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java @@ -23,6 +23,7 @@ import com.google.cloud.bigquery.storage.v1.AppendRowsRequest; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import java.util.function.Predicate; @@ -38,6 +39,7 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Redistribute; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.errorhandling.BadRecord; import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.ThrowingBadRecordRouter; @@ -50,6 +52,7 @@ import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.hash.Hashing; import org.joda.time.Duration; /** This {@link PTransform} manages loads into BigQuery using the Storage API. 
*/ @@ -379,12 +382,21 @@ public WriteResult expandUntriggered( PCollection> successfulConvertedRows = convertMessagesResult.get(successfulConvertedRowsTag); - if (numShards > 0) { + /* Unbounded input with a positive shard count: apply Redistribute.arbitrarily with numShards buckets. */ if (numShards > 0 && input.isBounded() == PCollection.IsBounded.UNBOUNDED) { successfulConvertedRows = successfulConvertedRows.apply( /* NOTE(review): this step name is spelled ResdistibuteNumShards while the bounded branch below uses RedistributeNumShards - confirm whether anything depends on the existing name before changing it. */ "ResdistibuteNumShards", Redistribute.>arbitrarily() .withNumBuckets(numShards)); + /* Bounded input with a positive shard count: wrap each element under an integer shard key (AddShardKeyFn), redistribute by that key, then drop the key again with Values.create() before assigning back to successfulConvertedRows. */ } else if (numShards > 0 && input.isBounded() == PCollection.IsBounded.BOUNDED) { + successfulConvertedRows = + successfulConvertedRows + .apply( + "AddKeyWithSideInputs", + ParDo.of(new AddShardKeyFn<>(dynamicDestinations, numShards)) + .withSideInputs(dynamicDestinations.getSideInputs())) + .apply("RedistributeNumShards", Redistribute.byKey()) + .apply("Remove shard", Values.create()); } PCollectionTuple writeRecordsResult = @@ -457,6 +469,37 @@ private void addErrorCollections( } } + /** DoFn that wraps each input KV in an outer KV whose key is an int shard number. The shard number is produced by Math.floorMod(..., numShards), so it lies in [0, numShards) when numShards is positive; the only construction site visible in this change is guarded by numShards > 0. NOTE(review): generic type arguments appear to have been lost from this text (the angle brackets in the extends clause and parameter types are unbalanced) - restore them from the original file. */ private static class AddShardKeyFn + extends DoFn< + KV, KV>> { + + private final StorageApiDynamicDestinations dynamicDestinations; + private final int numShards; + + /** Stores the dynamic destinations and the shard count that processElement uses. */ public AddShardKeyFn( + StorageApiDynamicDestinations dynamicDestinations, int numShards) { + this.dynamicDestinations = dynamicDestinations; + this.numShards = numShards; + } + + /** Looks up the table for the element key, derives a shard key from its short table URN, and emits KV.of(shardKey, element). */ @ProcessElement + public void processElement( + ProcessContext c, + @Element KV element, + OutputReceiver>> outputReceiver) { + /* Hand the process context to dynamicDestinations before calling getTable; presumably getTable may read the side inputs registered via withSideInputs at the apply site - TODO confirm. */ dynamicDestinations.setSideInputAccessorFromProcessContext(c); + + /* Hash the UTF-8 short table URN with 32-bit murmur3; the resulting int may be negative. */ String tableUrn = dynamicDestinations.getTable(element.getKey()).getShortTableUrn(); + + int hash = Hashing.murmur3_32_fixed().hashString(tableUrn, StandardCharsets.UTF_8).asInt(); + + /* floorMod (rather than the remainder operator) keeps the key non-negative even if hash is negative or hash plus the offset wraps around. NOTE(review): the random offset drawn from [0, numShards) already covers every residue, so each table is spread over all numShards keys and the hash only shifts the starting point - confirm this is the intended distribution. */ int shardKey = + Math.floorMod(hash + ThreadLocalRandom.current().nextInt(numShards), numShards); + + outputReceiver.output(KV.of(shardKey, element)); + } + } + private static class ConvertInsertErrorToBadRecord extends DoFn {