
Adding Datasource v2 writing support (#283)
davidrabinowitz committed Jan 28, 2021
1 parent 63321ef commit a3a7bf0
Showing 24 changed files with 2,008 additions and 415 deletions.
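
As context for the file-by-file changes below, a hedged usage sketch of the write path this commit adds: saving a Spark DataFrame to BigQuery through the connector. The option names follow the connector's documented API; the input path, project, dataset, table, and bucket names are hypothetical.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;

public class BigQueryWriteExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("bigquery-write-example").getOrCreate();
    Dataset<Row> df = spark.read().parquet("gs://my-bucket/input/"); // hypothetical input data

    df.write()
        .format("bigquery")
        .option("table", "my_project.my_dataset.my_table")   // hypothetical destination table
        .option("temporaryGcsBucket", "my-temporary-bucket") // staging bucket for intermediate files
        .mode(SaveMode.Append)
        .save();
  }
}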
4 changes: 3 additions & 1 deletion build.sbt
@@ -79,8 +79,10 @@ lazy val connector = (project in file("connector"))
"aopalliance" % "aopalliance" % "1.0" % "provided",
"org.codehaus.jackson" % "jackson-core-asl" % "1.9.13" % "provided",
"org.codehaus.jackson" % "jackson-mapper-asl" % "1.9.13" % "provided",
"org.apache.arrow" % "arrow-vector" % "0.16.0" exclude("org.slf4j", "slf4j-api"),
"com.google.inject" % "guice" % "4.2.3",
"org.apache.arrow" % "arrow-vector" % "0.16.0" exclude("org.slf4j", "slf4j-api"),
"org.apache.parquet" % "parquet-protobuf" % "1.10.0"
exclude("com.hadoop.gplcompression", "hadoop-lzo"),

// Keep com.google.cloud dependencies in sync
"com.google.cloud" % "google-cloud-bigquery" % "1.123.2",
BigQueryClient.java
@@ -19,12 +19,11 @@
import com.google.cloud.http.BaseHttpServiceException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
@@ -40,6 +39,8 @@
// Presto converts the dataset and table names to lower case, while BigQuery is case sensitive;
// the mappings here keep the original names
public class BigQueryClient {
private static final Logger logger = LoggerFactory.getLogger(BigQueryClient.class);

private final BigQuery bigQuery;
private final Optional<String> materializationProject;
private final Optional<String> materializationDataset;
@@ -125,10 +126,29 @@ TableId createDestinationTable(TableId tableId) {
return TableId.of(datasetId.getProject(), datasetId.getDataset(), name);
}

Table update(TableInfo table) {
public Table update(TableInfo table) {
return bigQuery.update(table);
}

public Job createAndWaitFor(JobConfiguration.Builder jobConfiguration) {
return createAndWaitFor(jobConfiguration.build());
}

public Job createAndWaitFor(JobConfiguration jobConfiguration) {
JobInfo jobInfo = JobInfo.of(jobConfiguration);
Job job = bigQuery.create(jobInfo);

logger.info("Submitted job {}. jobId: {}", jobConfiguration, job.getJobId());
// TODO(davidrab): add retry options
try {
return job.waitFor();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new BigQueryException(
BaseHttpServiceException.UNKNOWN_CODE, format("Failed to run the job [%s]", job), e);
}
}

Job create(JobInfo jobInfo) {
return bigQuery.create(jobInfo);
}
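
A minimal sketch (not part of this diff) of how the new createAndWaitFor helper might submit a load job for files staged on GCS; the destination table and GCS path are hypothetical.

import com.google.cloud.bigquery.FormatOptions;
import com.google.cloud.bigquery.Job;
import com.google.cloud.bigquery.LoadJobConfiguration;
import com.google.cloud.bigquery.TableId;

class LoadJobSketch {
  // bigQueryClient is an already constructed BigQueryClient instance.
  static void loadStagedAvro(BigQueryClient bigQueryClient) {
    LoadJobConfiguration.Builder loadConfig =
        LoadJobConfiguration.builder(
                TableId.of("my_dataset", "my_table"),           // hypothetical destination
                "gs://my-temporary-bucket/spark-export/*.avro") // hypothetical staged Avro files
            .setFormatOptions(FormatOptions.avro());

    Job job = bigQueryClient.createAndWaitFor(loadConfig);
    if (job.getStatus().getError() != null) {
      throw new RuntimeException("Load job failed: " + job.getStatus().getError());
    }
  }
}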
SchemaConverters.java
@@ -21,10 +21,8 @@
import com.google.cloud.bigquery.LegacySQLTypeName;
import com.google.cloud.bigquery.Schema;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.spark.ml.linalg.SQLDataTypes;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.util.GenericArrayData;
@@ -40,13 +38,13 @@
public class SchemaConverters {
// Numeric is a fixed precision Decimal Type with 38 digits of precision and 9 digits of scale.
// See https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#numeric-type
private static final int BQ_NUMERIC_PRECISION = 38;
private static final int BQ_NUMERIC_SCALE = 9;
static final int BQ_NUMERIC_PRECISION = 38;
static final int BQ_NUMERIC_SCALE = 9;
private static final DecimalType NUMERIC_SPARK_TYPE =
DataTypes.createDecimalType(BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE);
// The maximum nesting depth of a BigQuery RECORD:
private static final int MAX_BIGQUERY_NESTED_DEPTH = 15;
private static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported.";
static final int MAX_BIGQUERY_NESTED_DEPTH = 15;
static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported.";

/** Convert a BigQuery schema to a Spark schema */
public static StructType toSpark(Schema schema) {
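
For reference, the NUMERIC constants above pin BigQuery's NUMERIC type to a fixed Spark decimal; a one-line illustration using the Spark SQL type API:

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;

// BigQuery NUMERIC: 38 digits of precision, 9 digits of scale.
DecimalType numericSparkType = DataTypes.createDecimalType(38, 9);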
SparkBigQueryConfig.java
@@ -17,7 +17,11 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.auth.Credentials;
import com.google.cloud.bigquery.*;
import com.google.cloud.bigquery.BigQueryOptions;
import com.google.cloud.bigquery.FormatOptions;
import com.google.cloud.bigquery.JobInfo;
import com.google.cloud.bigquery.TableId;
import com.google.cloud.bigquery.TimePartitioning;
import com.google.cloud.bigquery.connector.common.BigQueryConfig;
import com.google.cloud.bigquery.connector.common.BigQueryCredentialsSupplier;
import com.google.cloud.bigquery.connector.common.ReadSessionCreatorConfig;
@@ -35,7 +39,13 @@
import java.io.Serializable;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeParseException;
import java.util.*;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -51,6 +61,7 @@ public class SparkBigQueryConfig implements BigQueryConfig, Serializable {
public static final String VIEWS_ENABLED_OPTION = "viewsEnabled";
public static final String USE_AVRO_LOGICAL_TYPES_OPTION = "useAvroLogicalTypes";
public static final String DATE_PARTITION_PARAM = "datePartition";
public static final String VALIDATE_SPARK_AVRO_PARAM = "validateSparkAvroInternalParam";
@VisibleForTesting static final DataFormat DEFAULT_READ_DATA_FORMAT = DataFormat.ARROW;

@VisibleForTesting
@@ -59,7 +70,7 @@ public class SparkBigQueryConfig implements BigQueryConfig, Serializable {
static final String GCS_CONFIG_CREDENTIALS_FILE_PROPERTY =
"google.cloud.auth.service.account.json.keyfile";
static final String GCS_CONFIG_PROJECT_ID_PROPERTY = "fs.gs.project.id";
private static final String INTERMEDIATE_FORMAT_OPTION = "intermediateFormat";
public static final String INTERMEDIATE_FORMAT_OPTION = "intermediateFormat";
private static final String READ_DATA_FORMAT_OPTION = "readDataFormat";
private static final ImmutableList<String> PERMITTED_READ_DATA_FORMATS =
ImmutableList.of(DataFormat.ARROW.toString(), DataFormat.AVRO.toString());
@@ -159,10 +170,13 @@ public static SparkBigQueryConfig from(
config.temporaryGcsBucket = getAnyOption(globalOptions, options, "temporaryGcsBucket");
config.persistentGcsBucket = getAnyOption(globalOptions, options, "persistentGcsBucket");
config.persistentGcsPath = getOption(options, "persistentGcsPath");
boolean validateSparkAvro =
Boolean.valueOf(getRequiredOption(options, VALIDATE_SPARK_AVRO_PARAM, () -> "true"));
config.intermediateFormat =
getAnyOption(globalOptions, options, INTERMEDIATE_FORMAT_OPTION)
.transform(String::toLowerCase)
.transform(format -> IntermediateFormat.from(format, sparkVersion, sqlConf))
.transform(
format -> IntermediateFormat.from(format, sparkVersion, sqlConf, validateSparkAvro))
.or(DEFAULT_INTERMEDIATE_FORMAT);
String readDataFormatParam =
getAnyOption(globalOptions, options, READ_DATA_FORMAT_OPTION)
@@ -520,14 +534,15 @@ public enum IntermediateFormat {
this.formatOptions = formatOptions;
}

public static IntermediateFormat from(String format, String sparkVersion, SQLConf sqlConf) {
public static IntermediateFormat from(
String format, String sparkVersion, SQLConf sqlConf, boolean validateSparkAvro) {
Preconditions.checkArgument(
PERMITTED_DATA_SOURCES.contains(format.toLowerCase()),
"Data read format '%s' is not supported. Supported formats are %s",
"Data write format '%s' is not supported. Supported formats are %s",
format,
PERMITTED_DATA_SOURCES);

if (format.equalsIgnoreCase("avro")) {
if (validateSparkAvro && format.equalsIgnoreCase("avro")) {
IntermediateFormat intermediateFormat = isSpark24OrAbove(sparkVersion) ? AVRO : AVRO_2_3;

try {
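
The intermediateFormat handling above is driven by a user-facing write option; a hedged sketch of setting it on a write (df is an existing Dataset<Row>, the table and bucket names are hypothetical, and the Avro path additionally requires the spark-avro package, as validated above):

df.write()
    .format("bigquery")
    .option("table", "my_dataset.my_table")              // hypothetical
    .option("temporaryGcsBucket", "my-temporary-bucket") // hypothetical
    .option("intermediateFormat", "avro")                // parsed by IntermediateFormat.from(...)
    .save();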
AvroIntermediateRecordWriter.java (new file)
@@ -0,0 +1,55 @@
/*
* Copyright 2018 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.spark.bigquery.v2;

import org.apache.avro.Schema;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.io.EncoderFactory;

import java.io.IOException;
import java.io.OutputStream;

public class AvroIntermediateRecordWriter implements IntermediateRecordWriter {

private final OutputStream outputStream;
private final DatumWriter<GenericRecord> writer;
private final DataFileWriter<GenericRecord> dataFileWriter;

AvroIntermediateRecordWriter(Schema schema, OutputStream outputStream) throws IOException {
this.outputStream = outputStream;
this.writer = new GenericDatumWriter<>(schema);
this.dataFileWriter = new DataFileWriter<>(writer);
this.dataFileWriter.create(schema, outputStream);
}

@Override
public void write(GenericRecord record) throws IOException {
dataFileWriter.append(record);
}

@Override
public void close() throws IOException {
try {
dataFileWriter.flush();
} finally {
dataFileWriter.close();
}
}
}
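
A usage sketch for the new writer (placed in the same package, since the constructor is package-private); the record schema and field are hypothetical.

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;

import java.io.ByteArrayOutputStream;
import java.io.IOException;

public class AvroIntermediateRecordWriterExample {
  public static void main(String[] args) throws IOException {
    Schema schema =
        SchemaBuilder.record("example").fields().requiredString("name").endRecord();

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    AvroIntermediateRecordWriter writer = new AvroIntermediateRecordWriter(schema, out);

    GenericRecord record = new GenericData.Record(schema);
    record.put("name", "value");
    writer.write(record);
    writer.close(); // flushes and closes the underlying Avro container file

    // 'out' now holds a complete Avro container file with a single record.
  }
}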
BigQueryDataSourceReader.java
@@ -15,8 +15,17 @@
*/
package com.google.cloud.spark.bigquery.v2;

import com.google.cloud.bigquery.*;
import com.google.cloud.bigquery.connector.common.*;
import com.google.cloud.bigquery.Field;
import com.google.cloud.bigquery.Schema;
import com.google.cloud.bigquery.StandardTableDefinition;
import com.google.cloud.bigquery.TableDefinition;
import com.google.cloud.bigquery.TableId;
import com.google.cloud.bigquery.TableInfo;
import com.google.cloud.bigquery.connector.common.BigQueryClient;
import com.google.cloud.bigquery.connector.common.BigQueryReadClientFactory;
import com.google.cloud.bigquery.connector.common.ReadSessionCreator;
import com.google.cloud.bigquery.connector.common.ReadSessionCreatorConfig;
import com.google.cloud.bigquery.connector.common.ReadSessionResponse;
import com.google.cloud.bigquery.storage.v1.DataFormat;
import com.google.cloud.bigquery.storage.v1.ReadSession;
import com.google.cloud.spark.bigquery.ReadRowsResponseToInternalRowIteratorConverter;
@@ -26,14 +35,25 @@
import com.google.common.collect.ImmutableSet;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.Statistics;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters;
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns;
import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics;
import org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import scala.collection.JavaConversions;

import java.util.*;
import java.util.function.Function;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@@ -85,11 +105,13 @@ public BigQueryDataSourceReader(
new ReadSessionCreator(readSessionCreatorConfig, bigQueryClient, bigQueryReadClientFactory);
this.globalFilter = globalFilter;
this.schema = schema;
this.fields =
JavaConversions.asJavaCollection(
SchemaConverters.toSpark(table.getDefinition().getSchema()))
.stream()
.collect(Collectors.toMap(field -> field.name(), Function.identity()));
// We want to keep the key order
this.fields = new LinkedHashMap<>();
for (StructField field :
JavaConversions.seqAsJavaList(
SchemaConverters.toSpark(table.getDefinition().getSchema()))) {
fields.put(field.name(), field);
}
}

@Override
@@ -141,7 +163,7 @@ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
ImmutableList<String> selectedFields =
schema
.map(requiredSchema -> ImmutableList.copyOf(requiredSchema.fieldNames()))
.orElse(ImmutableList.of());
.orElse(ImmutableList.copyOf(fields.keySet()));
Optional<String> filter =
emptyIfNeeded(
SparkFilterUtils.getCompiledFilter(
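
The move from Collectors.toMap to an explicit LinkedHashMap above is what preserves the BigQuery column order; a minimal illustration of the difference:

import java.util.LinkedHashMap;
import java.util.Map;

class FieldOrderSketch {
  public static void main(String[] args) {
    Map<String, Integer> ordered = new LinkedHashMap<>();
    ordered.put("c", 3);
    ordered.put("a", 1);
    ordered.put("b", 2);
    // Prints c, a, b (insertion order), mirroring the schema's field order.
    // Collectors.toMap(...) makes no ordering guarantee (in practice a HashMap),
    // so the previous code could reorder columns when no schema was pushed down.
    ordered.keySet().forEach(System.out::println);
  }
}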
BigQueryDataSourceReaderModule.java (new file)
@@ -0,0 +1,53 @@
/*
* Copyright 2018 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.spark.bigquery.v2;

import com.google.cloud.bigquery.TableInfo;
import com.google.cloud.bigquery.connector.common.BigQueryClient;
import com.google.cloud.bigquery.connector.common.BigQueryReadClientFactory;
import com.google.cloud.spark.bigquery.SparkBigQueryConfig;
import com.google.inject.Binder;
import com.google.inject.Module;
import com.google.inject.Provides;
import com.google.inject.Singleton;
import org.apache.spark.sql.types.StructType;

import java.util.Optional;

public class BigQueryDataSourceReaderModule implements Module {
@Override
public void configure(Binder binder) {
// empty
}

@Singleton
@Provides
public BigQueryDataSourceReader provideDataSourceReader(
BigQueryClient bigQueryClient,
BigQueryReadClientFactory bigQueryReadClientFactory,
SparkBigQueryConfig config) {
TableInfo tableInfo =
bigQueryClient.getSupportedTable(
config.getTableId(), config.isViewsEnabled(), SparkBigQueryConfig.VIEWS_ENABLED_OPTION);
return new BigQueryDataSourceReader(
tableInfo,
bigQueryClient,
bigQueryReadClientFactory,
config.toReadSessionCreatorConfig(),
config.getFilter(),
config.getSchema());
}
}
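
A hedged sketch of how such a module is consumed; the module supplying the BigQueryClient, BigQueryReadClientFactory, and SparkBigQueryConfig bindings is assumed here and not shown in this diff.

import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Module;

class ReaderInjectionSketch {
  static BigQueryDataSourceReader createReader(Module dependenciesModule) {
    // dependenciesModule is a hypothetical module binding BigQueryClient,
    // BigQueryReadClientFactory and SparkBigQueryConfig.
    Injector injector =
        Guice.createInjector(dependenciesModule, new BigQueryDataSourceReaderModule());
    return injector.getInstance(BigQueryDataSourceReader.class);
  }
}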