Adds implementation for supporting columnar batch reads from Spark. #198

Merged 2 commits on Jul 7, 2020
@@ -15,6 +15,7 @@
*/
package com.google.cloud.bigquery.connector.common;

import com.google.api.gax.rpc.ServerStream;
import com.google.cloud.bigquery.storage.v1.BigQueryReadClient;
import com.google.cloud.bigquery.storage.v1.ReadRowsRequest;
import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
@@ -29,6 +30,7 @@ public class ReadRowsHelper {
private ReadRowsRequest.Builder request;
private int maxReadRowsRetries;
private BigQueryReadClient client;
private ServerStream<ReadRowsResponse> incomingStream;

public ReadRowsHelper(
BigQueryReadClientFactory bigQueryReadClientFactory,
@@ -51,7 +53,13 @@ public Iterator<ReadRowsResponse> readRows() {

// In order to enable testing
protected Iterator<ReadRowsResponse> fetchResponses(ReadRowsRequest.Builder readRowsRequest) {
return client.readRowsCallable().call(readRowsRequest.build()).iterator();
incomingStream = client.readRowsCallable().call(readRowsRequest.build());
return incomingStream.iterator();
}

@Override
public String toString() {
return request.toString();
}

// Ported from https://github.com/GoogleCloudDataproc/spark-bigquery-connector/pull/150
@@ -89,7 +97,7 @@ public ReadRowsResponse next() {
serverResponses = helper.fetchResponses(helper.request.setOffset(readRowsCount));
retries++;
} else {
helper.client.close();
helper.close();
throw e;
}
}
@@ -100,6 +108,10 @@ public ReadRowsResponse next() {
}

public void close() {
if (incomingStream != null) {
incomingStream.cancel();
incomingStream = null;
}
if (!client.isShutdown()) {
client.close();
}
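For context, a minimal caller sketch (not part of this diff) of how ReadRowsHelper is typically driven; clientFactory and streamName are placeholder names, and the point is that close() now cancels the server stream before shutting down the client.

// Illustrative sketch only (not part of this PR); clientFactory and streamName are placeholders.
void readStream(BigQueryReadClientFactory clientFactory, String streamName) {
  ReadRowsRequest.Builder request = ReadRowsRequest.newBuilder().setReadStream(streamName);
  ReadRowsHelper helper = new ReadRowsHelper(clientFactory, request, /* maxReadRowsRetries= */ 3);
  try {
    Iterator<ReadRowsResponse> responses = helper.readRows();
    while (responses.hasNext()) {
      ReadRowsResponse response = responses.next();
      // consume response.getArrowRecordBatch() or response.getAvroRows() here
    }
  } finally {
    // With this change, close() first cancels the underlying ServerStream, if one
    // was opened, and only then shuts down the BigQueryReadClient.
    helper.close();
  }
}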
@@ -0,0 +1,143 @@
/*
* Copyright 2018 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.spark.bigquery.v2;

import com.google.cloud.bigquery.connector.common.ReadRowsHelper;
import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
import com.google.cloud.spark.bigquery.ArrowSchemaConverter;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;

class ArrowColumnBatchPartitionColumnBatchReader implements InputPartitionReader<ColumnarBatch> {
private static final long maxAllocation = 500 * 1024 * 1024;

private final ReadRowsHelper readRowsHelper;
private final ArrowStreamReader reader;
private final BufferAllocator allocator;
private final List<String> namesInOrder;
private ColumnarBatch currentBatch;
private boolean closed = false;

static class ReadRowsResponseInputStreamEnumeration
implements java.util.Enumeration<InputStream> {
private Iterator<ReadRowsResponse> responses;
private ReadRowsResponse currentResponse;

ReadRowsResponseInputStreamEnumeration(Iterator<ReadRowsResponse> responses) {
this.responses = responses;
loadNextResponse();
}

public boolean hasMoreElements() {
return currentResponse != null;
}

public InputStream nextElement() {
if (!hasMoreElements()) {
throw new NoSuchElementException("No more responses");
}
ReadRowsResponse ret = currentResponse;
loadNextResponse();
return ret.getArrowRecordBatch().getSerializedRecordBatch().newInput();
}

void loadNextResponse() {
if (responses.hasNext()) {
currentResponse = responses.next();
} else {
currentResponse = null;
}
}
}

ArrowColumnBatchPartitionColumnBatchReader(
Iterator<ReadRowsResponse> readRowsResponses,
ByteString schema,
ReadRowsHelper readRowsHelper,
List<String> namesInOrder) {
this.allocator =
(new RootAllocator(maxAllocation))
.newChildAllocator("ArrowBinaryIterator", 0, maxAllocation);
this.readRowsHelper = readRowsHelper;
this.namesInOrder = namesInOrder;

InputStream batchStream =
new SequenceInputStream(new ReadRowsResponseInputStreamEnumeration(readRowsResponses));
InputStream fullStream = new SequenceInputStream(schema.newInput(), batchStream);

reader = new ArrowStreamReader(fullStream, allocator);
}

@Override
public boolean next() throws IOException {
if (closed) {
return false;
}
closed = !reader.loadNextBatch();
if (closed) {
return false;
}
VectorSchemaRoot root = reader.getVectorSchemaRoot();
if (currentBatch == null) {
// Still verifying with dev@spark, but this object should only need to be
// created once: the underlying vectors stay the same across batches.
ColumnVector[] columns =
namesInOrder.stream()
.map(root::getVector)
.map(ArrowSchemaConverter::new)
.toArray(ColumnVector[]::new);

currentBatch = new ColumnarBatch(columns);
}
currentBatch.setNumRows(root.getRowCount());
return true;
}

@Override
public ColumnarBatch get() {
return currentBatch;
}

@Override
public void close() throws IOException {
closed = true;
try {
readRowsHelper.close();
} catch (Exception e) {
throw new IOException("Failure closing stream: " + readRowsHelper, e);
} finally {
try {
AutoCloseables.close(reader, allocator);
} catch (Exception e) {
throw new IOException("Failure closing arrow components. stream: " + readRowsHelper, e);
}
}
}
}
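A rough sketch (illustrative only, not part of this diff) of how Spark's DataSource V2 runtime consumes an InputPartitionReader<ColumnarBatch> such as the reader above.

// Illustrative only: roughly how an InputPartitionReader<ColumnarBatch> is driven.
void consume(InputPartition<ColumnarBatch> partition) throws IOException {
  InputPartitionReader<ColumnarBatch> reader = partition.createPartitionReader();
  try {
    while (reader.next()) {                // loads the next Arrow batch into the shared VectorSchemaRoot
      ColumnarBatch batch = reader.get();  // same ColumnarBatch instance, row count refreshed by next()
      int rows = batch.numRows();
      // vectorized operators read the ColumnVectors directly; row-based code can use batch.rowIterator()
    }
  } finally {
    reader.close();                        // closes the ReadRowsHelper, ArrowStreamReader, and allocator
  }
}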
@@ -0,0 +1,62 @@
/*
* Copyright 2018 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.spark.bigquery.v2;

import com.google.cloud.bigquery.connector.common.BigQueryReadClientFactory;
import com.google.cloud.bigquery.connector.common.ReadRowsHelper;
import com.google.cloud.bigquery.connector.common.ReadSessionResponse;
import com.google.cloud.bigquery.storage.v1.ReadRowsRequest;
import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.util.Iterator;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class ArrowInputPartition implements InputPartition<ColumnarBatch> {

private final BigQueryReadClientFactory bigQueryReadClientFactory;
private final String streamName;
private final int maxReadRowsRetries;
private final ImmutableList<String> selectedFields;
private final ByteString serializedArrowSchema;

public ArrowInputPartition(
BigQueryReadClientFactory bigQueryReadClientFactory,
String name,
int maxReadRowsRetries,
ImmutableList<String> selectedFields,
ReadSessionResponse readSessionResponse) {
this.bigQueryReadClientFactory = bigQueryReadClientFactory;
this.streamName = name;
this.maxReadRowsRetries = maxReadRowsRetries;
this.selectedFields = selectedFields;
this.serializedArrowSchema =
readSessionResponse.getReadSession().getArrowSchema().getSerializedSchema();
}

@Override
public InputPartitionReader<ColumnarBatch> createPartitionReader() {
ReadRowsRequest.Builder readRowsRequest =
ReadRowsRequest.newBuilder().setReadStream(streamName);
ReadRowsHelper readRowsHelper =
new ReadRowsHelper(bigQueryReadClientFactory, readRowsRequest, maxReadRowsRetries);
Iterator<ReadRowsResponse> readRowsResponses = readRowsHelper.readRows();
return new ArrowColumnBatchPartitionColumnBatchReader(
readRowsResponses, serializedArrowSchema, readRowsHelper, selectedFields);
}
}
@@ -32,12 +32,14 @@
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class BigQueryDataSourceReader
implements DataSourceReader,
SupportsPushDownRequiredColumns,
SupportsPushDownFilters,
SupportsReportStatistics {
SupportsReportStatistics,
SupportsScanColumnarBatch {

private static Statistics UNKNOWN_STATISTICS =
new Statistics() {
@@ -87,9 +89,14 @@ public StructType readSchema() {
return schema.orElse(SchemaConverters.toSpark(table.getDefinition().getSchema()));
}

@Override
public boolean enableBatchRead() {
return readSessionCreatorConfig.getReadDataFormat() == DataFormat.ARROW && !isEmptySchema();
}

@Override
public List<InputPartition<InternalRow>> planInputPartitions() {
if (schema.map(StructType::isEmpty).orElse(false)) {
if (isEmptySchema()) {
// create empty projection
return createEmptyProjectionPartitions();
}
@@ -117,10 +124,44 @@ public List<InputPartition<InternalRow>> planInputPartitions() {
.collect(Collectors.toList());
}

@Override
public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
if (!enableBatchRead()) {
throw new IllegalStateException("Batch reads should not be enabled");
}
ImmutableList<String> selectedFields =
schema
.map(requiredSchema -> ImmutableList.copyOf(requiredSchema.fieldNames()))
.orElse(ImmutableList.of());
Optional<String> filter =
emptyIfNeeded(
SparkFilterUtils.getCompiledFilter(
readSessionCreatorConfig.getReadDataFormat(), globalFilter, pushedFilters));
ReadSessionResponse readSessionResponse =
readSessionCreator.create(
tableId, selectedFields, filter, readSessionCreatorConfig.getMaxParallelism());
ReadSession readSession = readSessionResponse.getReadSession();
return readSession.getStreamsList().stream()
.map(
stream ->
new ArrowInputPartition(
bigQueryReadClientFactory,
stream.getName(),
readSessionCreatorConfig.getMaxReadRowsRetries(),
selectedFields,
readSessionResponse))
.collect(Collectors.toList());
}

private boolean isEmptySchema() {
return schema.map(StructType::isEmpty).orElse(false);
}

private ReadRowsResponseToInternalRowIteratorConverter createConverter(
ImmutableList<String> selectedFields, ReadSessionResponse readSessionResponse) {
ReadRowsResponseToInternalRowIteratorConverter converter;
if (readSessionCreatorConfig.getReadDataFormat() == DataFormat.AVRO) {
DataFormat format = readSessionCreatorConfig.getReadDataFormat();
if (format == DataFormat.AVRO) {
Schema schema = readSessionResponse.getReadTableInfo().getDefinition().getSchema();
if (selectedFields.isEmpty()) {
// means select *
@@ -138,11 +179,9 @@ private ReadRowsResponseToInternalRowIteratorConverter createConverter(
}
return ReadRowsResponseToInternalRowIteratorConverter.avro(
schema, selectedFields, readSessionResponse.getReadSession().getAvroSchema().getSchema());
} else {
return ReadRowsResponseToInternalRowIteratorConverter.arrow(
selectedFields,
readSessionResponse.getReadSession().getArrowSchema().getSerializedSchema());
}
throw new IllegalArgumentException(
"No known converted for " + readSessionCreatorConfig.getReadDataFormat());
}

List<InputPartition<InternalRow>> createEmptyProjectionPartitions() {
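Usage-wise, the new path is exercised when the connector reads with the Arrow format. A hedged example follows; the option names reflect my understanding of this connector's documented options, and the table is only an illustration.

// Hypothetical usage sketch: with readDataFormat=ARROW (and a non-empty schema)
// the reader reports enableBatchRead() == true, so Spark plans ColumnarBatch partitions.
Dataset<Row> df = spark.read()
    .format("bigquery")
    .option("table", "bigquery-public-data.samples.shakespeare")  // illustrative public table
    .option("readDataFormat", "ARROW")
    .load();
df.show();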