diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java index 4fe5844eac0e..85419c659857 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java @@ -131,9 +131,15 @@ public void testRead() throws Exception { } @Test - public void testReadWithQuery() throws Exception { + public void testReadWithQueryString() throws Exception { elasticsearchIOTestCommon.setPipeline(pipeline); - elasticsearchIOTestCommon.testReadWithQuery(); + elasticsearchIOTestCommon.testReadWithQueryString(); + } + + @Test + public void testReadWithQueryValueProvider() throws Exception { + elasticsearchIOTestCommon.setPipeline(pipeline); + elasticsearchIOTestCommon.testReadWithQueryValueProvider(); } @Test diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java index 05686cd605c5..d809cfd96604 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java @@ -122,12 +122,21 @@ public void testRead() throws Exception { } @Test - public void testReadWithQuery() throws Exception { + public void testReadWithQueryString() throws Exception { // need to create the index using the helper method (not create it at first insertion) // for the indexSettings() to be run createIndex(getEsIndex()); elasticsearchIOTestCommon.setPipeline(pipeline); - elasticsearchIOTestCommon.testReadWithQuery(); + elasticsearchIOTestCommon.testReadWithQueryString(); + } + + @Test + public void testReadWithQueryValueProvider() throws Exception { + // need to create the index using the helper method (not create it at first insertion) + // for the indexSettings() to be run + createIndex(getEsIndex()); + elasticsearchIOTestCommon.setPipeline(pipeline); + elasticsearchIOTestCommon.testReadWithQueryValueProvider(); } @Test diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java index 6638b7d894e9..84696e549dd3 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTest.java @@ -122,12 +122,21 @@ public void testRead() throws Exception { } @Test - public void testReadWithQuery() throws Exception { + public void testReadWithQueryString() throws Exception { // need to create the index using the helper method (not create it at first insertion) // for the indexSettings() to be run createIndex(getEsIndex()); elasticsearchIOTestCommon.setPipeline(pipeline); - elasticsearchIOTestCommon.testReadWithQuery(); + elasticsearchIOTestCommon.testReadWithQueryString(); + } + + @Test + public void testReadWithQueryValueProvider() throws Exception { + // need to create the index using the helper method (not create it at first insertion) + // for the indexSettings() to be run + createIndex(getEsIndex()); + elasticsearchIOTestCommon.setPipeline(pipeline); + elasticsearchIOTestCommon.testReadWithQueryValueProvider(); } @Test diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTestCommon.java b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTestCommon.java index 90ca521323bc..386a5184bc43 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTestCommon.java +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/src/test/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIOTestCommon.java @@ -45,11 +45,13 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.function.BiFunction; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.elasticsearch.ElasticsearchIO.RetryConfiguration.DefaultRetryPredicate; import org.apache.beam.sdk.io.elasticsearch.ElasticsearchIO.RetryConfiguration.RetryPredicate; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.SourceTestUtils; import org.apache.beam.sdk.testing.TestPipeline; @@ -196,7 +198,17 @@ void testRead() throws Exception { pipeline.run(); } - void testReadWithQuery() throws Exception { + void testReadWithQueryString() throws Exception { + testReadWithQueryInternal(Read::withQuery); + } + + void testReadWithQueryValueProvider() throws Exception { + testReadWithQueryInternal( + (read, query) -> read.withQuery(ValueProvider.StaticValueProvider.of(query))); + } + + private void testReadWithQueryInternal(BiFunction queryConfigurer) + throws IOException { if (!useAsITests) { ElasticsearchIOTestUtils.insertTestDocuments(connectionConfiguration, numDocs, restClient); } @@ -212,11 +224,12 @@ void testReadWithQuery() throws Exception { + " }\n" + "}"; - PCollection output = - pipeline.apply( - ElasticsearchIO.read() - .withConnectionConfiguration(connectionConfiguration) - .withQuery(query)); + Read read = ElasticsearchIO.read().withConnectionConfiguration(connectionConfiguration); + + read = queryConfigurer.apply(read, query); + + PCollection output = pipeline.apply(read); + PAssert.thatSingleton(output.apply("Count", Count.globally())) .isEqualTo(numDocs / NUM_SCIENTISTS); pipeline.run(); diff --git a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java index 5ced40bb368e..13073a1f8ec0 100644 --- a/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java +++ b/sdks/java/io/elasticsearch/src/main/java/org/apache/beam/sdk/io/elasticsearch/ElasticsearchIO.java @@ -44,6 +44,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.function.Predicate; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import javax.net.ssl.SSLContext; import org.apache.beam.sdk.annotations.Experimental; @@ -51,6 +52,7 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -77,6 +79,7 @@ import org.apache.http.nio.conn.ssl.SSLIOSessionStrategy; import org.apache.http.nio.entity.NStringEntity; import org.apache.http.ssl.SSLContexts; +import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.RestClient; import org.elasticsearch.client.RestClientBuilder; @@ -454,7 +457,7 @@ public abstract static class Read extends PTransform abstract ConnectionConfiguration getConnectionConfiguration(); @Nullable - abstract String getQuery(); + abstract ValueProvider getQuery(); abstract boolean isWithMetadata(); @@ -468,7 +471,7 @@ public abstract static class Read extends PTransform abstract static class Builder { abstract Builder setConnectionConfiguration(ConnectionConfiguration connectionConfiguration); - abstract Builder setQuery(String query); + abstract Builder setQuery(ValueProvider query); abstract Builder setWithMetadata(boolean withMetadata); @@ -502,6 +505,20 @@ public Read withConnectionConfiguration(ConnectionConfiguration connectionConfig public Read withQuery(String query) { checkArgument(query != null, "query can not be null"); checkArgument(!query.isEmpty(), "query can not be empty"); + return withQuery(ValueProvider.StaticValueProvider.of(query)); + } + + /** + * Provide a {@link ValueProvider} that provides the query used while reading from + * Elasticsearch. This is useful for cases when the query must be dynamic. + * + * @param query the query. See Query + * DSL + * @return a {@link PTransform} reading data from Elasticsearch. + */ + public Read withQuery(ValueProvider query) { + checkArgument(query != null, "query can not be null"); return builder().setQuery(query).build(); } @@ -577,6 +594,7 @@ public static class BoundedElasticsearchSource extends BoundedSource { @Nullable private final String shardPreference; @Nullable private final Integer numSlices; @Nullable private final Integer sliceId; + @Nullable private Long estimatedByteSize; // constructor used in split() when we know the backend version private BoundedElasticsearchSource( @@ -584,11 +602,13 @@ private BoundedElasticsearchSource( @Nullable String shardPreference, @Nullable Integer numSlices, @Nullable Integer sliceId, + @Nullable Long estimatedByteSize, int backendVersion) { this.backendVersion = backendVersion; this.spec = spec; this.shardPreference = shardPreference; this.numSlices = numSlices; + this.estimatedByteSize = estimatedByteSize; this.sliceId = sliceId; } @@ -627,11 +647,12 @@ public List> split( while (shards.hasNext()) { Map.Entry shardJson = shards.next(); String shardId = shardJson.getKey(); - sources.add(new BoundedElasticsearchSource(spec, shardId, null, null, backendVersion)); + sources.add( + new BoundedElasticsearchSource(spec, shardId, null, null, null, backendVersion)); } checkArgument(!sources.isEmpty(), "No shard found"); } else if (backendVersion == 5 || backendVersion == 6) { - long indexSize = BoundedElasticsearchSource.estimateIndexSize(connectionConfiguration); + long indexSize = getEstimatedSizeBytes(options); float nbBundlesFloat = (float) indexSize / desiredBundleSizeBytes; int nbBundles = (int) Math.ceil(nbBundlesFloat); // ES slice api imposes that the number of slices is <= 1024 even if it can be overloaded @@ -644,7 +665,10 @@ public List> split( // the slice API allows to split the ES shards // to have bundles closer to desiredBundleSizeBytes for (int i = 0; i < nbBundles; i++) { - sources.add(new BoundedElasticsearchSource(spec, null, nbBundles, i, backendVersion)); + long estimatedByteSizeForBundle = getEstimatedSizeBytes(options) / nbBundles; + sources.add( + new BoundedElasticsearchSource( + spec, null, nbBundles, i, estimatedByteSizeForBundle, backendVersion)); } } return sources; @@ -652,7 +676,54 @@ public List> split( @Override public long getEstimatedSizeBytes(PipelineOptions options) throws IOException { - return estimateIndexSize(spec.getConnectionConfiguration()); + if (estimatedByteSize != null) { + return estimatedByteSize; + } + final ConnectionConfiguration connectionConfiguration = spec.getConnectionConfiguration(); + JsonNode statsJson = getStats(connectionConfiguration, false); + JsonNode indexStats = + statsJson.path("indices").path(connectionConfiguration.getIndex()).path("primaries"); + long indexSize = indexStats.path("store").path("size_in_bytes").asLong(); + LOG.debug("estimate source byte size: total index size " + indexSize); + + String query = spec.getQuery() != null ? spec.getQuery().get() : null; + if (query == null || query.isEmpty()) { // return index size if no query + estimatedByteSize = indexSize; + return estimatedByteSize; + } + + long totalCount = indexStats.path("docs").path("count").asLong(); + LOG.debug("estimate source byte size: total document count " + totalCount); + if (totalCount == 0) { // The min size is 1, because DirectRunner does not like 0 + estimatedByteSize = 1L; + return estimatedByteSize; + } + + String endPoint = + String.format( + "/%s/%s/_count", + connectionConfiguration.getIndex(), connectionConfiguration.getType()); + try (RestClient restClient = connectionConfiguration.createClient()) { + long count = queryCount(restClient, endPoint, query); + LOG.debug("estimate source byte size: query document count " + count); + if (count == 0) { + estimatedByteSize = 1L; + } else { + // We estimate the average byte size for each document is (index/totalCount) + // and then multiply the document count in the index + estimatedByteSize = (indexSize / totalCount) * count; + } + } + return estimatedByteSize; + } + + private long queryCount( + @Nonnull RestClient restClient, @Nonnull String endPoint, @Nonnull String query) + throws IOException { + Request request = new Request("GET", endPoint); + request.setEntity(new NStringEntity(query, ContentType.APPLICATION_JSON)); + JsonNode searchResult = parseResponse(restClient.performRequest(request).getEntity()); + return searchResult.path("count").asLong(); } @VisibleForTesting @@ -726,7 +797,7 @@ private BoundedElasticsearchReader(BoundedElasticsearchSource source) { public boolean start() throws IOException { restClient = source.spec.getConnectionConfiguration().createClient(); - String query = source.spec.getQuery(); + String query = source.spec.getQuery() != null ? source.spec.getQuery().get() : null; if (query == null) { query = "{\"query\": { \"match_all\": {} }}"; }