From a3a7bf01e9f184a18f3f1db03f8d044ecc64900b Mon Sep 17 00:00:00 2001 From: David Rabinowitz Date: Thu, 28 Jan 2021 09:40:41 -0800 Subject: [PATCH] Adding Datasource v2 writing support (#283) --- build.sbt | 4 +- .../connector/common/BigQueryClient.java | 28 +- .../spark/bigquery/SchemaConverters.java | 10 +- .../spark/bigquery/SparkBigQueryConfig.java | 29 +- .../v2/AvroIntermediateRecordWriter.java | 55 +++ .../bigquery/v2/BigQueryDataSourceReader.java | 44 +- .../v2/BigQueryDataSourceReaderModule.java | 53 +++ .../bigquery/v2/BigQueryDataSourceV2.java | 83 +++- .../v2/BigQueryDataSourceWriterModule.java | 116 +++++ .../v2/BigQueryIndirectDataSourceWriter.java | 281 +++++++++++ .../v2/BigQueryIndirectDataWriter.java | 78 +++ .../v2/BigQueryIndirectDataWriterFactory.java | 65 +++ .../BigQueryIndirectWriterCommitMessage.java | 31 ++ .../bigquery/v2/IntermediateDataCleaner.java | 64 +++ .../bigquery/v2/IntermediateRecordWriter.java | 26 + .../v2/SparkBigQueryConnectorModule.java | 39 +- .../bigquery/AvroSchemaConverterTest.java | 289 ++++++++++++ .../spark/bigquery/SchemaConverterTest.java | 12 - .../cloud/spark/bigquery/TestUtils.scala | 2 +- .../bigquery/it/IntegrationTestUtils.scala | 6 + ...=> SparkBigQueryEndToEndReadITSuite.scala} | 359 +------------- .../SparkBigQueryEndToEndWriteITSuite.scala | 443 ++++++++++++++++++ .../spark/bigquery/ArrowSchemaConverter.java | 1 - .../spark/bigquery/AvroSchemaConverter.java | 305 ++++++++++++ 24 files changed, 2008 insertions(+), 415 deletions(-) create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/AvroIntermediateRecordWriter.java create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReaderModule.java create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceWriterModule.java create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataSourceWriter.java create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataWriter.java create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataWriterFactory.java create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectWriterCommitMessage.java create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/IntermediateDataCleaner.java create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/IntermediateRecordWriter.java create mode 100644 connector/src/test/java/com/google/cloud/spark/bigquery/AvroSchemaConverterTest.java rename connector/src/test/scala/com/google/cloud/spark/bigquery/it/{SparkBigQueryEndToEndITSuite.scala => SparkBigQueryEndToEndReadITSuite.scala} (58%) create mode 100644 connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndWriteITSuite.scala create mode 100644 connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/AvroSchemaConverter.java diff --git a/build.sbt b/build.sbt index c42231b28..5b8b568e4 100644 --- a/build.sbt +++ b/build.sbt @@ -79,8 +79,10 @@ lazy val connector = (project in file("connector")) "aopalliance" % "aopalliance" % "1.0" % "provided", "org.codehaus.jackson" % "jackson-core-asl" % "1.9.13" % "provided", "org.codehaus.jackson" % "jackson-mapper-asl" % "1.9.13" % "provided", - "org.apache.arrow" % "arrow-vector" % "0.16.0" exclude("org.slf4j", "slf4j-api"), "com.google.inject" % "guice" % "4.2.3", + "org.apache.arrow" % 
"arrow-vector" % "0.16.0" exclude("org.slf4j", "slf4j-api"), + "org.apache.parquet" % "parquet-protobuf" % "1.10.0" + exclude("com.hadoop.gplcompression", "hadoop-lzo"), // Keep com.google.cloud dependencies in sync "com.google.cloud" % "google-cloud-bigquery" % "1.123.2", diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java index 716bcec55..0643cd0cf 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java @@ -19,12 +19,11 @@ import com.google.cloud.http.BaseHttpServiceException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.Optional; import java.util.Set; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.stream.StreamSupport; @@ -40,6 +39,8 @@ // presto converts the dataset and table names to lower case, while BigQuery is case sensitive // the mappings here keep the mappings public class BigQueryClient { + private static final Logger logger = LoggerFactory.getLogger(BigQueryClient.class); + private final BigQuery bigQuery; private final Optional materializationProject; private final Optional materializationDataset; @@ -125,10 +126,29 @@ TableId createDestinationTable(TableId tableId) { return TableId.of(datasetId.getProject(), datasetId.getDataset(), name); } - Table update(TableInfo table) { + public Table update(TableInfo table) { return bigQuery.update(table); } + public Job createAndWaitFor(JobConfiguration.Builder jobConfiguration) { + return createAndWaitFor(jobConfiguration.build()); + } + + public Job createAndWaitFor(JobConfiguration jobConfiguration) { + JobInfo jobInfo = JobInfo.of(jobConfiguration); + Job job = bigQuery.create(jobInfo); + + logger.info("Submitted job {}. 
jobId: {}", jobConfiguration, job.getJobId()); + // TODO(davidrab): add retry options + try { + return job.waitFor(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new BigQueryException( + BaseHttpServiceException.UNKNOWN_CODE, format("Failed to run the job [%s]", job), e); + } + } + Job create(JobInfo jobInfo) { return bigQuery.create(jobInfo); } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java index 8f4791eb5..952461924 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java @@ -21,10 +21,8 @@ import com.google.cloud.bigquery.LegacySQLTypeName; import com.google.cloud.bigquery.Schema; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableSet; import org.apache.avro.generic.GenericRecord; import org.apache.avro.util.Utf8; -import org.apache.spark.ml.linalg.SQLDataTypes; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.GenericArrayData; @@ -40,13 +38,13 @@ public class SchemaConverters { // Numeric is a fixed precision Decimal Type with 38 digits of precision and 9 digits of scale. // See https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#numeric-type - private static final int BQ_NUMERIC_PRECISION = 38; - private static final int BQ_NUMERIC_SCALE = 9; + static final int BQ_NUMERIC_PRECISION = 38; + static final int BQ_NUMERIC_SCALE = 9; private static final DecimalType NUMERIC_SPARK_TYPE = DataTypes.createDecimalType(BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE); // The maximum nesting depth of a BigQuery RECORD: - private static final int MAX_BIGQUERY_NESTED_DEPTH = 15; - private static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported."; + static final int MAX_BIGQUERY_NESTED_DEPTH = 15; + static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported."; /** Convert a BigQuery schema to a Spark schema */ public static StructType toSpark(Schema schema) { diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConfig.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConfig.java index 0e71cfb87..a036bba2a 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConfig.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConfig.java @@ -17,7 +17,11 @@ import com.google.api.gax.retrying.RetrySettings; import com.google.auth.Credentials; -import com.google.cloud.bigquery.*; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.FormatOptions; +import com.google.cloud.bigquery.JobInfo; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.TimePartitioning; import com.google.cloud.bigquery.connector.common.BigQueryConfig; import com.google.cloud.bigquery.connector.common.BigQueryCredentialsSupplier; import com.google.cloud.bigquery.connector.common.ReadSessionCreatorConfig; @@ -35,7 +39,13 @@ import java.io.Serializable; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; -import java.util.*; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import 
java.util.OptionalLong; +import java.util.Set; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -51,6 +61,7 @@ public class SparkBigQueryConfig implements BigQueryConfig, Serializable { public static final String VIEWS_ENABLED_OPTION = "viewsEnabled"; public static final String USE_AVRO_LOGICAL_TYPES_OPTION = "useAvroLogicalTypes"; public static final String DATE_PARTITION_PARAM = "datePartition"; + public static final String VALIDATE_SPARK_AVRO_PARAM = "validateSparkAvroInternalParam"; @VisibleForTesting static final DataFormat DEFAULT_READ_DATA_FORMAT = DataFormat.ARROW; @VisibleForTesting @@ -59,7 +70,7 @@ public class SparkBigQueryConfig implements BigQueryConfig, Serializable { static final String GCS_CONFIG_CREDENTIALS_FILE_PROPERTY = "google.cloud.auth.service.account.json.keyfile"; static final String GCS_CONFIG_PROJECT_ID_PROPERTY = "fs.gs.project.id"; - private static final String INTERMEDIATE_FORMAT_OPTION = "intermediateFormat"; + public static final String INTERMEDIATE_FORMAT_OPTION = "intermediateFormat"; private static final String READ_DATA_FORMAT_OPTION = "readDataFormat"; private static final ImmutableList PERMITTED_READ_DATA_FORMATS = ImmutableList.of(DataFormat.ARROW.toString(), DataFormat.AVRO.toString()); @@ -159,10 +170,13 @@ public static SparkBigQueryConfig from( config.temporaryGcsBucket = getAnyOption(globalOptions, options, "temporaryGcsBucket"); config.persistentGcsBucket = getAnyOption(globalOptions, options, "persistentGcsBucket"); config.persistentGcsPath = getOption(options, "persistentGcsPath"); + boolean validateSparkAvro = + Boolean.valueOf(getRequiredOption(options, VALIDATE_SPARK_AVRO_PARAM, () -> "true")); config.intermediateFormat = getAnyOption(globalOptions, options, INTERMEDIATE_FORMAT_OPTION) .transform(String::toLowerCase) - .transform(format -> IntermediateFormat.from(format, sparkVersion, sqlConf)) + .transform( + format -> IntermediateFormat.from(format, sparkVersion, sqlConf, validateSparkAvro)) .or(DEFAULT_INTERMEDIATE_FORMAT); String readDataFormatParam = getAnyOption(globalOptions, options, READ_DATA_FORMAT_OPTION) @@ -520,14 +534,15 @@ public enum IntermediateFormat { this.formatOptions = formatOptions; } - public static IntermediateFormat from(String format, String sparkVersion, SQLConf sqlConf) { + public static IntermediateFormat from( + String format, String sparkVersion, SQLConf sqlConf, boolean validateSparkAvro) { Preconditions.checkArgument( PERMITTED_DATA_SOURCES.contains(format.toLowerCase()), - "Data read format '%s' is not supported. Supported formats are %s", + "Data write format '%s' is not supported. Supported formats are %s", format, PERMITTED_DATA_SOURCES); - if (format.equalsIgnoreCase("avro")) { + if (validateSparkAvro && format.equalsIgnoreCase("avro")) { IntermediateFormat intermediateFormat = isSpark24OrAbove(sparkVersion) ? AVRO : AVRO_2_3; try { diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/AvroIntermediateRecordWriter.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/AvroIntermediateRecordWriter.java new file mode 100644 index 000000000..7a89c81fb --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/AvroIntermediateRecordWriter.java @@ -0,0 +1,55 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spark.bigquery.v2; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.BinaryEncoder; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.io.EncoderFactory; + +import java.io.IOException; +import java.io.OutputStream; + +public class AvroIntermediateRecordWriter implements IntermediateRecordWriter { + + private final OutputStream outputStream; + private final DatumWriter writer; + private final DataFileWriter dataFileWriter; + + AvroIntermediateRecordWriter(Schema schema, OutputStream outputStream) throws IOException { + this.outputStream = outputStream; + this.writer = new GenericDatumWriter<>(schema); + this.dataFileWriter = new DataFileWriter<>(writer); + this.dataFileWriter.create(schema, outputStream); + } + + @Override + public void write(GenericRecord record) throws IOException { + dataFileWriter.append(record); + } + + @Override + public void close() throws IOException { + try { + dataFileWriter.flush(); + } finally { + dataFileWriter.close(); + } + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java index 90e25ddc4..c45f9add9 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java @@ -15,8 +15,17 @@ */ package com.google.cloud.spark.bigquery.v2; -import com.google.cloud.bigquery.*; -import com.google.cloud.bigquery.connector.common.*; +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardTableDefinition; +import com.google.cloud.bigquery.TableDefinition; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.TableInfo; +import com.google.cloud.bigquery.connector.common.BigQueryClient; +import com.google.cloud.bigquery.connector.common.BigQueryReadClientFactory; +import com.google.cloud.bigquery.connector.common.ReadSessionCreator; +import com.google.cloud.bigquery.connector.common.ReadSessionCreatorConfig; +import com.google.cloud.bigquery.connector.common.ReadSessionResponse; import com.google.cloud.bigquery.storage.v1.DataFormat; import com.google.cloud.bigquery.storage.v1.ReadSession; import com.google.cloud.spark.bigquery.ReadRowsResponseToInternalRowIteratorConverter; @@ -26,14 +35,25 @@ import com.google.common.collect.ImmutableSet; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.Filter; -import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.Statistics; +import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters; +import 
org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns; +import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics; +import org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.vectorized.ColumnarBatch; import scala.collection.JavaConversions; -import java.util.*; -import java.util.function.Function; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -85,11 +105,13 @@ public BigQueryDataSourceReader( new ReadSessionCreator(readSessionCreatorConfig, bigQueryClient, bigQueryReadClientFactory); this.globalFilter = globalFilter; this.schema = schema; - this.fields = - JavaConversions.asJavaCollection( - SchemaConverters.toSpark(table.getDefinition().getSchema())) - .stream() - .collect(Collectors.toMap(field -> field.name(), Function.identity())); + // We want to keep the key order + this.fields = new LinkedHashMap<>(); + for (StructField field : + JavaConversions.seqAsJavaList( + SchemaConverters.toSpark(table.getDefinition().getSchema()))) { + fields.put(field.name(), field); + } } @Override @@ -141,7 +163,7 @@ public List> planBatchInputPartitions() { ImmutableList selectedFields = schema .map(requiredSchema -> ImmutableList.copyOf(requiredSchema.fieldNames())) - .orElse(ImmutableList.of()); + .orElse(ImmutableList.copyOf(fields.keySet())); Optional filter = emptyIfNeeded( SparkFilterUtils.getCompiledFilter( diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReaderModule.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReaderModule.java new file mode 100644 index 000000000..79a30c6b4 --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReaderModule.java @@ -0,0 +1,53 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.spark.bigquery.v2; + +import com.google.cloud.bigquery.TableInfo; +import com.google.cloud.bigquery.connector.common.BigQueryClient; +import com.google.cloud.bigquery.connector.common.BigQueryReadClientFactory; +import com.google.cloud.spark.bigquery.SparkBigQueryConfig; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Provides; +import com.google.inject.Singleton; +import org.apache.spark.sql.types.StructType; + +import java.util.Optional; + +public class BigQueryDataSourceReaderModule implements Module { + @Override + public void configure(Binder binder) { + // empty + } + + @Singleton + @Provides + public BigQueryDataSourceReader provideDataSourceReader( + BigQueryClient bigQueryClient, + BigQueryReadClientFactory bigQueryReadClientFactory, + SparkBigQueryConfig config) { + TableInfo tableInfo = + bigQueryClient.getSupportedTable( + config.getTableId(), config.isViewsEnabled(), SparkBigQueryConfig.VIEWS_ENABLED_OPTION); + return new BigQueryDataSourceReader( + tableInfo, + bigQueryClient, + bigQueryReadClientFactory, + config.toReadSessionCreatorConfig(), + config.getFilter(), + config.getSchema()); + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceV2.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceV2.java index 130bfb3b6..f1304df80 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceV2.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceV2.java @@ -15,35 +15,50 @@ */ package com.google.cloud.spark.bigquery.v2; +import com.google.cloud.bigquery.JobInfo; +import com.google.cloud.bigquery.TableInfo; +import com.google.cloud.bigquery.connector.common.BigQueryClient; import com.google.cloud.bigquery.connector.common.BigQueryClientModule; +import com.google.cloud.bigquery.connector.common.BigQueryUtil; +import com.google.cloud.spark.bigquery.SparkBigQueryConfig; import com.google.inject.Guice; import com.google.inject.Injector; +import com.google.inject.Module; +import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.WriteSupport; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.types.StructType; import java.util.Optional; -public class BigQueryDataSourceV2 implements DataSourceV2, ReadSupport { +/** + * A DataSourceV2 implementation, providing efficient reader and writer for the Google Cloud + * Platform BigQuery. 
+ */ +public class BigQueryDataSourceV2 implements DataSourceV2, ReadSupport, WriteSupport { @Override public DataSourceReader createReader(StructType schema, DataSourceOptions options) { - SparkSession spark = getDefaultSparkSessionOrCreate(); - - Injector injector = - Guice.createInjector( - new BigQueryClientModule(), - new SparkBigQueryConnectorModule(spark, options, Optional.ofNullable(schema))); - + Injector injector = createInjector(schema, options, new BigQueryDataSourceReaderModule()); BigQueryDataSourceReader reader = injector.getInstance(BigQueryDataSourceReader.class); return reader; } + private Injector createInjector(StructType schema, DataSourceOptions options, Module module) { + SparkSession spark = getDefaultSparkSessionOrCreate(); + return Guice.createInjector( + new BigQueryClientModule(), + new SparkBigQueryConnectorModule(spark, options, Optional.ofNullable(schema)), + module); + } + private SparkSession getDefaultSparkSessionOrCreate() { - scala.Option defaultSpareSession = SparkSession.getDefaultSession(); + scala.Option defaultSpareSession = SparkSession.getActiveSession(); if (defaultSpareSession.isDefined()) { return defaultSpareSession.get(); } @@ -54,4 +69,54 @@ private SparkSession getDefaultSparkSessionOrCreate() { public DataSourceReader createReader(DataSourceOptions options) { return createReader(null, options); } + + /** + * Returning a DataSourceWriter for the specified parameters. In case the table already exist and + * the SaveMode is "Ignore", an Optional.empty() is returned. + */ + @Override + public Optional createWriter( + String writeUUID, StructType schema, SaveMode mode, DataSourceOptions options) { + Injector injector = + createInjector( + schema, options, new BigQueryDataSourceWriterModule(writeUUID, schema, mode)); + // first verify if we need to do anything at all, based on the table existence and the save + // mode. + BigQueryClient bigQueryClient = injector.getInstance(BigQueryClient.class); + SparkBigQueryConfig config = injector.getInstance(SparkBigQueryConfig.class); + TableInfo table = bigQueryClient.getTable(config.getTableId()); + if (table != null) { + // table already exists + if (mode == SaveMode.Ignore) { + return Optional.empty(); + } + if (mode == SaveMode.ErrorIfExists) { + throw new IllegalArgumentException( + String.format( + "SaveMode is set to ErrorIfExists and table '%s' already exists. Did you want " + + "to add data to the table by setting the SaveMode to Append? Example: " + + "df.write.format.options.mode(\"append\").save()", + BigQueryUtil.friendlyTableName(table.getTableId()))); + } + } else { + // table does not exist + // If the CreateDisposition is CREATE_NEVER, and the table does not exist, + // there's no point in writing the data to GCS in the first place as it going + // to fail on the BigQuery side. + boolean createNever = + config + .getCreateDisposition() + .map(createDisposition -> createDisposition == JobInfo.CreateDisposition.CREATE_NEVER) + .orElse(false); + if (createNever) { + throw new IllegalArgumentException( + String.format( + "For table %s Create Disposition is CREATE_NEVER and the table does not exists. 
Aborting the insert", + BigQueryUtil.friendlyTableName(config.getTableId()))); + } + } + BigQueryIndirectDataSourceWriter writer = + injector.getInstance(BigQueryIndirectDataSourceWriter.class); + return Optional.of(writer); + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceWriterModule.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceWriterModule.java new file mode 100644 index 000000000..498c339d6 --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceWriterModule.java @@ -0,0 +1,116 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spark.bigquery.v2; + +import com.google.cloud.bigquery.connector.common.BigQueryClient; +import com.google.cloud.spark.bigquery.SparkBigQueryConfig; +import com.google.common.base.Preconditions; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Provides; +import com.google.inject.Singleton; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.StructType; + +import java.io.IOException; +import java.util.Optional; +import java.util.UUID; + +class BigQueryDataSourceWriterModule implements Module { + + private final String writeUUID; + private final StructType sparkSchema; + private final SaveMode mode; + + BigQueryDataSourceWriterModule(String writeUUID, StructType sparkSchema, SaveMode mode) { + this.writeUUID = writeUUID; + this.sparkSchema = sparkSchema; + this.mode = mode; + } + + @Override + public void configure(Binder binder) { + // empty + } + + @Singleton + @Provides + public BigQueryIndirectDataSourceWriter provideDataSourceWriter( + BigQueryClient bigQueryClient, SparkBigQueryConfig config, SparkSession spark) + throws IOException { + Path gcsPath = + createGcsPath( + config, + spark.sparkContext().hadoopConfiguration(), + spark.sparkContext().applicationId()); + Optional intermediateDataCleaner = + config + .getTemporaryGcsBucket() + .map( + ignored -> + new IntermediateDataCleaner( + gcsPath, spark.sparkContext().hadoopConfiguration())); + // based on pmkc's suggestion at https://git.io/JeWRt + intermediateDataCleaner.ifPresent(cleaner -> Runtime.getRuntime().addShutdownHook(cleaner)); + return new BigQueryIndirectDataSourceWriter( + bigQueryClient, + config, + spark.sparkContext().hadoopConfiguration(), + sparkSchema, + writeUUID, + mode, + gcsPath, + intermediateDataCleaner); + } + + Path createGcsPath(SparkBigQueryConfig config, Configuration conf, String applicationId) + throws IOException { + Preconditions.checkArgument( + config.getTemporaryGcsBucket().isPresent() || config.getPersistentGcsBucket().isPresent(), + "Temporary or persistent GCS bucket must be informed."); + boolean needNewPath = true; + 
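+    // Generate a unique GCS staging path: build a candidate under the temporary (or persistent)
+    // bucket and retry with a new random UUID while that candidate already exists on the file system.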
Path gcsPath = null; + while (needNewPath) { + String gcsPathOption = + config + .getTemporaryGcsBucket() + .map( + bucket -> + String.format( + "gs://%s/.spark-bigquery-%s-%s", + bucket, applicationId, UUID.randomUUID())) + .orElseGet( + () -> { + // if we are here it means that the PersistentGcsBucket is set + String path = + config + .getPersistentGcsPath() + .orElse( + String.format( + ".spark-bigquery-%s-%s", applicationId, UUID.randomUUID())); + return String.format("gs://%s/%s", config.getPersistentGcsBucket().get(), path); + }); + gcsPath = new Path(gcsPathOption); + FileSystem fs = gcsPath.getFileSystem(conf); + needNewPath = fs.exists(gcsPath); // if the path exists for some reason, then retry + } + return gcsPath; + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataSourceWriter.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataSourceWriter.java new file mode 100644 index 000000000..8f18ab886 --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataSourceWriter.java @@ -0,0 +1,281 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spark.bigquery.v2; + +import com.google.cloud.bigquery.BigQueryException; +import com.google.cloud.bigquery.Clustering; +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.Job; +import com.google.cloud.bigquery.JobInfo; +import com.google.cloud.bigquery.LoadJobConfiguration; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.TableDefinition; +import com.google.cloud.bigquery.TableInfo; +import com.google.cloud.bigquery.TimePartitioning; +import com.google.cloud.bigquery.connector.common.BigQueryClient; +import com.google.cloud.bigquery.connector.common.BigQueryUtil; +import com.google.cloud.http.BaseHttpServiceException; +import com.google.cloud.spark.bigquery.AvroSchemaConverter; +import com.google.cloud.spark.bigquery.SparkBigQueryConfig; +import com.google.cloud.spark.bigquery.SparkBigQueryUtil; +import com.google.cloud.spark.bigquery.SupportedCustomDataType; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.RemoteIterator; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.DataWriterFactory; +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.SerializableConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.AbstractMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import 
java.util.stream.Stream; + +/** + * A DataSourceWriter implemented by first writing the DataFrame's data into GCS in an intermediate + * format, and then triggering a BigQuery load job on this data. Hence the "indirect" - the data + * goes through an intermediate storage. + */ +public class BigQueryIndirectDataSourceWriter implements DataSourceWriter { + + private static final Logger logger = + LoggerFactory.getLogger(BigQueryIndirectDataSourceWriter.class); + + private final BigQueryClient bigQueryClient; + private final SparkBigQueryConfig config; + private final Configuration hadoopConfiguration; + private final StructType sparkSchema; + private final String writeUUID; + private final SaveMode saveMode; + private final Path gcsPath; + private final Optional intermediateDataCleaner; + + public BigQueryIndirectDataSourceWriter( + BigQueryClient bigQueryClient, + SparkBigQueryConfig config, + Configuration hadoopConfiguration, + StructType sparkSchema, + String writeUUID, + SaveMode saveMode, + Path gcsPath, + Optional intermediateDataCleaner) { + this.bigQueryClient = bigQueryClient; + this.config = config; + this.hadoopConfiguration = hadoopConfiguration; + this.sparkSchema = sparkSchema; + this.writeUUID = writeUUID; + this.saveMode = saveMode; + this.gcsPath = gcsPath; + this.intermediateDataCleaner = intermediateDataCleaner; + } + + static Iterable wrap(final RemoteIterator remoteIterator) { + return () -> + new Iterator() { + @Override + public boolean hasNext() { + try { + return remoteIterator.hasNext(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public T next() { + try { + return remoteIterator.next(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + }; + } + + @Override + public DataWriterFactory createWriterFactory() { + org.apache.avro.Schema avroSchema = AvroSchemaConverter.sparkSchemaToAvroSchema(sparkSchema); + return new BigQueryIndirectDataWriterFactory( + new SerializableConfiguration(hadoopConfiguration), + gcsPath.toString(), + sparkSchema, + avroSchema.toString()); + } + + @Override + public void commit(WriterCommitMessage[] messages) { + logger.info( + "Data has been successfully written to GCS. 
Going to load {} files to BigQuery", + messages.length); + try { + List sourceUris = + Stream.of(messages) + .map(msg -> ((BigQueryIndirectWriterCommitMessage) msg).getUri()) + .collect(Collectors.toList()); + loadDataToBigQuery(sourceUris); + updateMetadataIfNeeded(); + logger.info("Data has been successfully loaded to BigQuery"); + } catch (IOException e) { + throw new UncheckedIOException(e); + } finally { + cleanTemporaryGcsPathIfNeeded(); + } + } + + @Override + public void abort(WriterCommitMessage[] messages) { + try { + logger.warn( + "Aborting write {} for table {}", + writeUUID, + BigQueryUtil.friendlyTableName(config.getTableId())); + } finally { + cleanTemporaryGcsPathIfNeeded(); + } + } + + void loadDataToBigQuery(List sourceUris) throws IOException { + // Solving Issue #248 + List optimizedSourceUris = SparkBigQueryUtil.optimizeLoadUriListForSpark(sourceUris); + + LoadJobConfiguration.Builder jobConfiguration = + LoadJobConfiguration.newBuilder( + config.getTableId(), + optimizedSourceUris, + config.getIntermediateFormat().getFormatOptions()) + .setCreateDisposition(JobInfo.CreateDisposition.CREATE_IF_NEEDED) + .setWriteDisposition(saveModeToWriteDisposition(saveMode)) + .setAutodetect(true); + + config.getCreateDisposition().ifPresent(jobConfiguration::setCreateDisposition); + + if (config.getPartitionField().isPresent() || config.getPartitionType().isPresent()) { + TimePartitioning.Builder timePartitionBuilder = + TimePartitioning.newBuilder(config.getPartitionTypeOrDefault()); + config.getPartitionExpirationMs().ifPresent(timePartitionBuilder::setExpirationMs); + config.getPartitionRequireFilter().ifPresent(timePartitionBuilder::setRequirePartitionFilter); + config.getPartitionField().ifPresent(timePartitionBuilder::setField); + jobConfiguration.setTimePartitioning(timePartitionBuilder.build()); + config + .getClusteredFields() + .ifPresent( + clusteredFields -> { + Clustering clustering = Clustering.newBuilder().setFields(clusteredFields).build(); + jobConfiguration.setClustering(clustering); + }); + } + + if (!config.getLoadSchemaUpdateOptions().isEmpty()) { + jobConfiguration.setSchemaUpdateOptions(config.getLoadSchemaUpdateOptions()); + } + + Job finishedJob = bigQueryClient.createAndWaitFor(jobConfiguration); + + if (finishedJob.getStatus().getError() != null) { + throw new BigQueryException( + BaseHttpServiceException.UNKNOWN_CODE, + String.format( + "Failed to load to %s in job %s. BigQuery error was '%s'", + BigQueryUtil.friendlyTableName(config.getTableId()), + finishedJob.getJobId(), + finishedJob.getStatus().getError().getMessage()), + finishedJob.getStatus().getError()); + } else { + logger.info( + "Done loading to {}. jobId: {}", + BigQueryUtil.friendlyTableName(config.getTableId()), + finishedJob.getJobId()); + } + } + + JobInfo.WriteDisposition saveModeToWriteDisposition(SaveMode saveMode) { + if (saveMode == SaveMode.ErrorIfExists) { + return JobInfo.WriteDisposition.WRITE_EMPTY; + } + // SaveMode.Ignore is handled in the data source level. 
If it has arrived here it means tha + // table does not exist + if (saveMode == SaveMode.Append || saveMode == SaveMode.Ignore) { + return JobInfo.WriteDisposition.WRITE_APPEND; + } + if (saveMode == SaveMode.Overwrite) { + return JobInfo.WriteDisposition.WRITE_TRUNCATE; + } + throw new UnsupportedOperationException( + "SaveMode " + saveMode + " is currently not supported."); + } + + void updateMetadataIfNeeded() { + // TODO: Issue #190 should be solved here + Map> fieldsToUpdate = + Stream.of(sparkSchema.fields()) + .map( + field -> + new AbstractMap.SimpleImmutableEntry>( + field.name(), SupportedCustomDataType.of(field.dataType()))) + .filter(nameAndType -> nameAndType.getValue().isPresent()) + .collect( + Collectors.toMap( + AbstractMap.SimpleImmutableEntry::getKey, + AbstractMap.SimpleImmutableEntry::getValue)); + if (!fieldsToUpdate.isEmpty()) { + logger.debug("updating schema, found fields to update: {}", fieldsToUpdate.keySet()); + TableInfo originalTableInfo = bigQueryClient.getTable(config.getTableId()); + TableDefinition originalTableDefinition = originalTableInfo.getDefinition(); + Schema originalSchema = originalTableDefinition.getSchema(); + Schema updatedSchema = + Schema.of( + originalSchema.getFields().stream() + .map( + field -> + Optional.ofNullable(fieldsToUpdate.get(field.getName())) + .map(dataType -> updatedField(field, dataType.get().getTypeMarker())) + .orElse(field)) + .collect(Collectors.toList())); + TableInfo.Builder updatedTableInfo = + originalTableInfo + .toBuilder() + .setDefinition(originalTableDefinition.toBuilder().setSchema(updatedSchema).build()); + + bigQueryClient.update(updatedTableInfo.build()); + } + } + + Field updatedField(Field field, String marker) { + Field.Builder newField = field.toBuilder(); + String description = field.getDescription(); + if (description == null) { + newField.setDescription(marker); + } else if (!description.endsWith(marker)) { + newField.setDescription(description + " " + marker); + } + return newField.build(); + } + + void cleanTemporaryGcsPathIfNeeded() { + intermediateDataCleaner.ifPresent(cleaner -> cleaner.deletePath()); + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataWriter.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataWriter.java new file mode 100644 index 000000000..136fa7918 --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataWriter.java @@ -0,0 +1,78 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.spark.bigquery.v2; + +import com.google.cloud.spark.bigquery.AvroSchemaConverter; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericRecord; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.writer.DataWriter; +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +class BigQueryIndirectDataWriter implements DataWriter { + + private static final Logger logger = LoggerFactory.getLogger(BigQueryIndirectDataWriter.class); + Path path; + FileSystem fs; + FSDataOutputStream outputStream; + StructType sparkSchema; + Schema avroSchema; + IntermediateRecordWriter intermediateRecordWriter; + private int partitionId; + + protected BigQueryIndirectDataWriter( + int partitionId, + Path path, + FileSystem fs, + StructType sparkSchema, + Schema avroSchema, + IntermediateRecordWriter intermediateRecordWriter) { + this.partitionId = partitionId; + this.path = path; + this.fs = fs; + this.sparkSchema = sparkSchema; + this.avroSchema = avroSchema; + this.intermediateRecordWriter = intermediateRecordWriter; + } + + @Override + public void write(InternalRow record) throws IOException { + GenericRecord avroRecord = + AvroSchemaConverter.sparkRowToAvroGenericData(record, sparkSchema, avroSchema); + intermediateRecordWriter.write(avroRecord); + } + + @Override + public WriterCommitMessage commit() throws IOException { + intermediateRecordWriter.close(); + return new BigQueryIndirectWriterCommitMessage(path.toString()); + } + + @Override + public void abort() throws IOException { + logger.warn( + "Writing of partition {} has been aborted, attempting to delete {}", partitionId, path); + fs.delete(path, false); + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataWriterFactory.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataWriterFactory.java new file mode 100644 index 000000000..823c1fe38 --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectDataWriterFactory.java @@ -0,0 +1,65 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.spark.bigquery.v2; + +import org.apache.avro.Schema; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.writer.DataWriter; +import org.apache.spark.sql.sources.v2.writer.DataWriterFactory; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.SerializableConfiguration; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.UUID; + +class BigQueryIndirectDataWriterFactory implements DataWriterFactory { + + SerializableConfiguration conf; + String gcsDirPath; + StructType sparkSchema; + String avroSchemaJson; + + public BigQueryIndirectDataWriterFactory( + SerializableConfiguration conf, + String gcsDirPath, + StructType sparkSchema, + String avroSchemaJson) { + this.conf = conf; + this.gcsDirPath = gcsDirPath; + this.sparkSchema = sparkSchema; + this.avroSchemaJson = avroSchemaJson; + } + + @Override + public DataWriter createDataWriter(int partitionId, long taskId, long epochId) { + try { + Schema avroSchema = new Schema.Parser().parse(avroSchemaJson); + UUID uuid = new UUID(taskId, epochId); + String uri = String.format("%s/part-%06d-%s.avro", gcsDirPath, partitionId, uuid); + Path path = new Path(uri); + FileSystem fs = path.getFileSystem(conf.value()); + IntermediateRecordWriter intermediateRecordWriter = + new AvroIntermediateRecordWriter(avroSchema, fs.create(path)); + return new BigQueryIndirectDataWriter( + partitionId, path, fs, sparkSchema, avroSchema, intermediateRecordWriter); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectWriterCommitMessage.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectWriterCommitMessage.java new file mode 100644 index 000000000..6aef64f5d --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryIndirectWriterCommitMessage.java @@ -0,0 +1,31 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spark.bigquery.v2; + +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; + +class BigQueryIndirectWriterCommitMessage implements WriterCommitMessage { + + private final String uri; + + public BigQueryIndirectWriterCommitMessage(String uri) { + this.uri = uri; + } + + public String getUri() { + return uri; + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/IntermediateDataCleaner.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/IntermediateDataCleaner.java new file mode 100644 index 000000000..3ffa02c81 --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/IntermediateDataCleaner.java @@ -0,0 +1,64 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spark.bigquery.v2; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Responsible for recursively deleting the intermediate path. Implementing Runnable in order to act + * as shutdown hook. + */ +class IntermediateDataCleaner extends Thread { + private static final Logger logger = LoggerFactory.getLogger(IntermediateDataCleaner.class); + + /** the path to delete */ + private final Path path; + /** the hadoop configuration */ + private final Configuration conf; + + IntermediateDataCleaner(Path path, Configuration conf) { + this.path = path; + this.conf = conf; + } + + @Override + public void run() { + deletePath(); + } + + void deletePath() { + try { + FileSystem fs = path.getFileSystem(conf); + if (pathExists(fs, path)) { + fs.delete(path, true); + } + } catch (Exception e) { + logger.error("Failed to delete path " + path, e); + } + } + // fs.exists can throw exception on missing path + private boolean pathExists(FileSystem fs, Path path) { + try { + return fs.exists(path); + } catch (Exception e) { + return false; + } + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/IntermediateRecordWriter.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/IntermediateRecordWriter.java new file mode 100644 index 000000000..5af4242c3 --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/IntermediateRecordWriter.java @@ -0,0 +1,26 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.spark.bigquery.v2; + +import org.apache.avro.generic.GenericRecord; + +import java.io.Closeable; +import java.io.IOException; + +public interface IntermediateRecordWriter extends Closeable { + + void write(GenericRecord avroRecord) throws IOException; +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/SparkBigQueryConnectorModule.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/SparkBigQueryConnectorModule.java index d426d25b8..da98684cf 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/SparkBigQueryConnectorModule.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/SparkBigQueryConnectorModule.java @@ -15,10 +15,7 @@ */ package com.google.cloud.spark.bigquery.v2; -import com.google.cloud.bigquery.TableInfo; -import com.google.cloud.bigquery.connector.common.BigQueryClient; import com.google.cloud.bigquery.connector.common.BigQueryConfig; -import com.google.cloud.bigquery.connector.common.BigQueryReadClientFactory; import com.google.cloud.bigquery.connector.common.UserAgentProvider; import com.google.cloud.spark.bigquery.SparkBigQueryConfig; import com.google.cloud.spark.bigquery.SparkBigQueryConnectorUserAgentProvider; @@ -31,6 +28,8 @@ import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.types.StructType; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; import static scala.collection.JavaConversions.mapAsJavaMap; @@ -53,34 +52,30 @@ public void configure(Binder binder) { binder.bind(BigQueryConfig.class).toProvider(this::provideSparkBigQueryConfig); } + @Singleton + @Provides + public SparkSession provideSparkSession() { + return spark; + } + @Singleton @Provides public SparkBigQueryConfig provideSparkBigQueryConfig() { + Map optionsMap = new HashMap<>(options.asMap()); + // no need for the spar-avro module, we have an internal copy of avro + optionsMap.put(SparkBigQueryConfig.VALIDATE_SPARK_AVRO_PARAM.toLowerCase(), "false"); + // DataSource V2 implementation uses Java only + optionsMap.put( + SparkBigQueryConfig.INTERMEDIATE_FORMAT_OPTION.toLowerCase(), + SparkBigQueryConfig.IntermediateFormat.AVRO.toString()); + return SparkBigQueryConfig.from( - options.asMap(), + ImmutableMap.copyOf(optionsMap), ImmutableMap.copyOf(mapAsJavaMap(spark.conf().getAll())), spark.sparkContext().hadoopConfiguration(), spark.sparkContext().defaultParallelism(), spark.sqlContext().conf(), spark.version(), - Optional.empty()); - } - - @Singleton - @Provides - public BigQueryDataSourceReader provideDataSourceReader( - BigQueryClient bigQueryClient, - BigQueryReadClientFactory bigQueryReadClientFactory, - SparkBigQueryConfig config) { - TableInfo tableInfo = - bigQueryClient.getSupportedTable( - config.getTableId(), config.isViewsEnabled(), SparkBigQueryConfig.VIEWS_ENABLED_OPTION); - return new BigQueryDataSourceReader( - tableInfo, - bigQueryClient, - bigQueryReadClientFactory, - config.toReadSessionCreatorConfig(), - config.getFilter(), schema); } diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/AvroSchemaConverterTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/AvroSchemaConverterTest.java new file mode 100644 index 000000000..108d612a5 --- /dev/null +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/AvroSchemaConverterTest.java @@ -0,0 +1,289 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spark.bigquery; + +import com.google.cloud.spark.bigquery.it.TestConstants; +import com.google.common.collect.ImmutableList; +import org.apache.avro.Conversions; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.StructType; +import org.junit.Test; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; + +import static com.google.common.truth.Truth.assertThat; + +public class AvroSchemaConverterTest { + + @Test + public void testSchemaConversion() { + StructType sparkSchema = TestConstants.ALL_TYPES_TABLE_SCHEMA(); + Schema avroSchema = AvroSchemaConverter.sparkSchemaToAvroSchema(sparkSchema); + Schema.Field[] fields = + avroSchema.getFields().toArray(new Schema.Field[avroSchema.getFields().size()]); + checkField(fields[0], "int_req", Schema.create(Schema.Type.LONG)); + checkField(fields[1], "int_null", nullable(Schema.Type.LONG)); + checkField(fields[2], "bl", nullable(Schema.Type.BOOLEAN)); + checkField(fields[3], "str", nullable(Schema.Type.STRING)); + checkField( + fields[4], + "day", + nullable(LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType()))); + checkField( + fields[5], + "ts", + nullable(LogicalTypes.timestampMicros().addToSchema(SchemaBuilder.builder().longType()))); + checkField(fields[6], "dt", nullable(Schema.Type.STRING)); + checkField(fields[7], "tm", nullable(Schema.Type.LONG)); + checkField(fields[8], "binary", nullable(Schema.Type.BYTES)); + checkField(fields[9], "float", nullable(Schema.Type.DOUBLE)); + checkField( + fields[10], + "nums", + nullable( + Schema.createRecord( + "nums", + null, + null, + false, + ImmutableList.of( + new Schema.Field("min", nullable(decimal("min")), null, (Object) null), + new Schema.Field("max", nullable(decimal("max")), null, (Object) null), + new Schema.Field("pi", nullable(decimal("pi")), null, (Object) null), + new Schema.Field( + "big_pi", nullable(decimal("big_pi")), null, (Object) null))))); + checkField(fields[11], "int_arr", nullable(Schema.createArray(nullable(Schema.Type.LONG)))); + checkField( + fields[12], + "int_struct_arr", + nullable( + Schema.createArray( + nullable( + Schema.createRecord( + "int_struct_arr", + null, + null, + false, + ImmutableList.of( + new Schema.Field( + "i", nullable(Schema.Type.LONG), null, (Object) null))))))); + } + + @Test + public void testConvertIntegers() { + InternalRow row = + new GenericInternalRow( + new Object[] { + Byte.valueOf("0"), Short.valueOf("1"), Integer.valueOf(2), Long.valueOf(3) + }); + StructType sparkSchema = + DataTypes.createStructType( + ImmutableList.of( + DataTypes.createStructField("byte_f", DataTypes.ByteType, false), + 
DataTypes.createStructField("short_f", DataTypes.ShortType, false), + DataTypes.createStructField("int_f", DataTypes.IntegerType, false), + DataTypes.createStructField("long_f", DataTypes.LongType, false))); + + Schema avroSchema = + SchemaBuilder.record("root") + .fields() // + .name("byte_f") + .type(SchemaBuilder.builder().longType()) + .noDefault() // + .name("short_f") + .type(SchemaBuilder.builder().longType()) + .noDefault() // + .name("int_f") + .type(SchemaBuilder.builder().longType()) + .noDefault() // + .name("long_f") + .type(SchemaBuilder.builder().longType()) + .noDefault() // + .endRecord(); + GenericData.Record result = + AvroSchemaConverter.sparkRowToAvroGenericData(row, sparkSchema, avroSchema); + assertThat(result.getSchema()).isEqualTo(avroSchema); + assertThat(result.get(0)).isEqualTo(Long.valueOf(0)); + assertThat(result.get(1)).isEqualTo(Long.valueOf(1)); + assertThat(result.get(2)).isEqualTo(Long.valueOf(2)); + assertThat(result.get(3)).isEqualTo(Long.valueOf(3)); + } + + @Test + public void testConvertNull() { + InternalRow row = new GenericInternalRow(new Object[] {null}); + StructType sparkSchema = + DataTypes.createStructType( + ImmutableList.of(DataTypes.createStructField("null_f", DataTypes.LongType, true))); + + Schema avroSchema = + SchemaBuilder.record("root") + .fields() // + .name("long_f") + .type( + SchemaBuilder.unionOf() + .type(SchemaBuilder.builder().longType()) + .and() + .nullType() + .endUnion()) + .noDefault() // + .endRecord(); + GenericData.Record result = + AvroSchemaConverter.sparkRowToAvroGenericData(row, sparkSchema, avroSchema); + assertThat(result.getSchema()).isEqualTo(avroSchema); + assertThat(result.get(0)).isNull(); + } + + @Test + public void testConvertNullable() { + InternalRow row = new GenericInternalRow(new Object[] {Long.valueOf(0)}); + StructType sparkSchema = + DataTypes.createStructType( + ImmutableList.of(DataTypes.createStructField("null_f", DataTypes.LongType, true))); + + Schema avroSchema = + SchemaBuilder.record("root") + .fields() // + .name("long_f") + .type( + SchemaBuilder.unionOf() + .type(SchemaBuilder.builder().longType()) + .and() + .nullType() + .endUnion()) + .noDefault() // + .endRecord(); + GenericData.Record result = + AvroSchemaConverter.sparkRowToAvroGenericData(row, sparkSchema, avroSchema); + assertThat(result.getSchema()).isEqualTo(avroSchema); + assertThat(result.get(0)).isEqualTo(Long.valueOf(0)); + } + + @Test + public void testConvertDecimal() { + InternalRow row = + new GenericInternalRow( + new Object[] { + Decimal.apply(BigDecimal.valueOf(123.456), SchemaConverters.BQ_NUMERIC_PRECISION, 3) + }); + StructType sparkSchema = + DataTypes.createStructType( + ImmutableList.of( + DataTypes.createStructField( + "decimal_f", + DataTypes.createDecimalType(SchemaConverters.BQ_NUMERIC_PRECISION, 3), + false))); + + Schema avroSchema = + SchemaBuilder.record("root") + .fields() // + .name("decimal_f") + .type(decimal("decimal_f")) + .noDefault() // + .endRecord(); + GenericData.Record result = + AvroSchemaConverter.sparkRowToAvroGenericData(row, sparkSchema, avroSchema); + assertThat(result.getSchema()).isEqualTo(avroSchema); + Conversions.DecimalConversion decimalConversion = new Conversions.DecimalConversion(); + assertThat( + decimalConversion.fromBytes( + (ByteBuffer) result.get(0), + avroSchema.getField("decimal_f").schema(), + LogicalTypes.decimal(SchemaConverters.BQ_NUMERIC_PRECISION, 3))) + .isEqualTo(BigDecimal.valueOf(123.456)); + } + + @Test + public void testConvertDoubles() { + InternalRow row 
= + new GenericInternalRow(new Object[] {Float.valueOf("0.0"), Double.valueOf("1.1")}); + StructType sparkSchema = + DataTypes.createStructType( + ImmutableList.of( + DataTypes.createStructField("float_f", DataTypes.FloatType, false), + DataTypes.createStructField("double_f", DataTypes.DoubleType, false))); + + Schema avroSchema = + SchemaBuilder.record("root") + .fields() // + .name("float_f") + .type(SchemaBuilder.builder().doubleType()) + .noDefault() // + .name("double_f") + .type(SchemaBuilder.builder().doubleType()) + .noDefault() // + .endRecord(); + GenericData.Record result = + AvroSchemaConverter.sparkRowToAvroGenericData(row, sparkSchema, avroSchema); + assertThat(result.getSchema()).isEqualTo(avroSchema); + assertThat(result.get(0)).isEqualTo(Double.valueOf(0.0)); + assertThat(result.get(1)).isEqualTo(Double.valueOf(1.1)); + } + + @Test + public void testConvertDateTime() { + InternalRow row = + new GenericInternalRow(new Object[] {Integer.valueOf(15261), Long.valueOf(1318608914000L)}); + StructType sparkSchema = + DataTypes.createStructType( + ImmutableList.of( + DataTypes.createStructField("date_f", DataTypes.DateType, false), + DataTypes.createStructField("ts_f", DataTypes.TimestampType, false))); + Schema avroSchema = + SchemaBuilder.record("root") + .fields() // + .name("date_f") + .type(LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType())) + .noDefault() // + .name("ts_f") + .type(LogicalTypes.timestampMicros().addToSchema(SchemaBuilder.builder().longType())) + .noDefault() // + .endRecord(); + GenericData.Record result = + AvroSchemaConverter.sparkRowToAvroGenericData(row, sparkSchema, avroSchema); + assertThat(result.getSchema()).isEqualTo(avroSchema); + assertThat(result.get(0)).isEqualTo(15261); + assertThat(result.get(1)).isEqualTo(1318608914000L); + } + + @Test + public void testComparisonToSparkAvro() {} + + private void checkField(Schema.Field field, String name, Schema schema) { + assertThat(field.name()).isEqualTo(name); + assertThat(field.schema()).isEqualTo(schema); + } + + private Schema decimal(String name) { + return LogicalTypes.decimal(38, 9).addToSchema(SchemaBuilder.builder().bytesType()); + } + + Schema nullable(Schema schema) { + return Schema.createUnion(schema, Schema.create(Schema.Type.NULL)); + } + + Schema nullable(Schema.Type type) { + return nullable(Schema.create(type)); + } +} diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java index 4b891c0b0..da36630c7 100644 --- a/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java @@ -101,8 +101,6 @@ public void testFieldHasDescriptionBigQueryToSpark() throws Exception { */ @Test public void testSparkToBQSchema() throws Exception { - logger.setLevel(Level.DEBUG); - StructType schema = BIG_SPARK_SCHEMA; Schema expected = BIG_BIGQUERY_SCHEMA; @@ -116,8 +114,6 @@ public void testSparkToBQSchema() throws Exception { @Test public void testSparkMapException() throws Exception { - logger.setLevel(Level.DEBUG); - try { createBigQueryColumn(SPARK_MAP_FIELD, 0); fail("Did not throw an error for an unsupported map-type"); @@ -127,8 +123,6 @@ public void testSparkMapException() throws Exception { @Test public void testDecimalTypeConversion() throws Exception { - logger.setLevel(Level.DEBUG); - 
assertThat(toBigQueryType(NUMERIC_SPARK_TYPE)).isEqualTo(LegacySQLTypeName.NUMERIC); try { @@ -141,8 +135,6 @@ public void testDecimalTypeConversion() throws Exception { @Test public void testTimeTypesConversions() throws Exception { - logger.setLevel(Level.DEBUG); - // FIXME: restore this check when the Vortex team adds microsecond precision, and Timestamp // conversion can be fixed. // assertThat(toBigQueryType(DataTypes.TimestampType)).isEqualTo(LegacySQLTypeName.TIMESTAMP); @@ -151,8 +143,6 @@ public void testTimeTypesConversions() throws Exception { @Test public void testDescriptionConversion() throws Exception { - logger.setLevel(Level.DEBUG); - String description = "I love bananas"; Field result = createBigQueryColumn( @@ -168,8 +158,6 @@ public void testDescriptionConversion() throws Exception { @Test public void testMaximumNestingDepthError() throws Exception { - logger.setLevel(Level.DEBUG); - StructType inner = new StructType(); StructType superRecursiveSchema = inner; for (int i = 0; i < MAX_BIGQUERY_NESTED_DEPTH + 1; i++) { diff --git a/connector/src/test/scala/com/google/cloud/spark/bigquery/TestUtils.scala b/connector/src/test/scala/com/google/cloud/spark/bigquery/TestUtils.scala index ff2e075be..097ce46ff 100644 --- a/connector/src/test/scala/com/google/cloud/spark/bigquery/TestUtils.scala +++ b/connector/src/test/scala/com/google/cloud/spark/bigquery/TestUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.SparkSession import org.mockito.Mockito._ object TestUtils { - def getOrCreateSparkSession(applicationName : String): SparkSession = { + def getOrCreateSparkSession(applicationName: String): SparkSession = { SparkSession.builder() .appName(applicationName) .master("local") diff --git a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/IntegrationTestUtils.scala b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/IntegrationTestUtils.scala index acaf847f7..e485b1dc3 100644 --- a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/IntegrationTestUtils.scala +++ b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/IntegrationTestUtils.scala @@ -41,3 +41,9 @@ object IntegrationTestUtils extends Logging { bq.delete(DatasetId.of(dataset), DatasetDeleteOption.deleteContents()) } } + +case class Person(name: String, friends: Seq[Friend]) + +case class Friend(age: Int, links: Seq[Link]) + +case class Link(uri: String) diff --git a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndITSuite.scala b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndReadITSuite.scala similarity index 58% rename from connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndITSuite.scala rename to connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndReadITSuite.scala index 3709e1b9d..e96ce22b3 100644 --- a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndITSuite.scala +++ b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndReadITSuite.scala @@ -19,20 +19,15 @@ import com.google.cloud.bigquery._ import com.google.cloud.spark.bigquery.TestUtils import com.google.cloud.spark.bigquery.direct.DirectBigQueryRelation import com.google.cloud.spark.bigquery.it.TestConstants._ -import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors} -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.execution.streaming.MemoryStream -import 
org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession} +import org.apache.spark.sql.{DataFrame, SparkSession} import org.scalatest.concurrent.TimeLimits import org.scalatest.prop.TableDrivenPropertyChecks import org.scalatest.time.SpanSugar._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite, Matchers} -import scala.collection.JavaConverters._ -class SparkBigQueryEndToEndITSuite extends FunSuite +class SparkBigQueryEndToEndReadITSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll with Matchers @@ -176,22 +171,22 @@ class SparkBigQueryEndToEndITSuite extends FunSuite assert(row(3).isInstanceOf[Long]) } - test("cache data frame in DataSource %s. Data Format %s".format(dataSourceFormat, dataFormat)) { - val allTypesTable = readAllTypesTable("bigquery") - writeToBigQuery(allTypesTable, SaveMode.Overwrite, "avro") - - val df = spark.read.format("bigquery") - .option("dataset", testDataset) - .option("table", testTable) - .option("readDataFormat", "arrow") - .load().cache() - - assert(df.head() == allTypesTable.head()) - - // read from cache - assert(df.head() == allTypesTable.head()) - assert(df.schema == allTypesTable.schema) - } +// test("cache data frame in DataSource %s. Data Format %s".format(dataSourceFormat, dataFormat)) { +// val allTypesTable = readAllTypesTable("bigquery") +// writeToBigQuery(allTypesTable, SaveMode.Overwrite, "avro") +// +// val df = spark.read.format("bigquery") +// .option("dataset", testDataset) +// .option("table", testTable) +// .option("readDataFormat", "arrow") +// .load().cache() +// +// assert(df.head() == allTypesTable.head()) +// +// // read from cache +// assert(df.head() == allTypesTable.head()) +// assert(df.schema == allTypesTable.schema) +// } test("number of partitions. DataSource %s. Data Format %s" .format(dataSourceFormat, dataFormat)) { @@ -433,7 +428,6 @@ class SparkBigQueryEndToEndITSuite extends FunSuite override def afterAll: Unit = { IntegrationTestUtils.deleteDatasetAndTables(testDataset) - spark.stop() } /** Generate a test to verify that the given DataFrame is equal to a known result. 
*/ @@ -451,278 +445,6 @@ class SparkBigQueryEndToEndITSuite extends FunSuite } } - private def initialData = spark.createDataFrame(spark.sparkContext.parallelize( - Seq(Person("Abc", Seq(Friend(10, Seq(Link("www.abc.com"))))), - Person("Def", Seq(Friend(12, Seq(Link("www.def.com")))))))) - - private def additonalData = spark.createDataFrame(spark.sparkContext.parallelize( - Seq(Person("Xyz", Seq(Friend(10, Seq(Link("www.xyz.com"))))), - Person("Pqr", Seq(Friend(12, Seq(Link("www.pqr.com")))))))) - - // getNumRows returns BigInteger, and it messes up the matchers - private def testTableNumberOfRows = bq.getTable(testDataset, testTable).getNumRows.intValue - - private def testPartitionedTableDefinition = bq.getTable(testDataset, testTable + "_partitioned") - .getDefinition[StandardTableDefinition]() - - private def writeToBigQuery(df: DataFrame, mode: SaveMode, format: String = "parquet") = - df.write.format("bigquery") - .mode(mode) - .option("table", fullTableName) - .option("temporaryGcsBucket", temporaryGcsBucket) - .option("intermediateFormat", format) - .save() - - private def initialDataValuesExist = numberOfRowsWith("Abc") == 1 - - private def numberOfRowsWith(name: String) = - bq.query(QueryJobConfiguration.of(s"select name from $fullTableName where name='$name'")) - .getTotalRows - - private def fullTableName = s"$testDataset.$testTable" - - private def fullTableNamePartitioned = s"$testDataset.${ - testTable - }_partitioned" - - private def additionalDataValuesExist = numberOfRowsWith("Xyz") == 1 - - test("write to bq - append save mode") { - // initial write - writeToBigQuery(initialData, SaveMode.Append) - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - // second write - writeToBigQuery(additonalData, SaveMode.Append) - testTableNumberOfRows shouldBe 4 - additionalDataValuesExist shouldBe true - } - - test("write to bq - error if exists save mode") { - // initial write - writeToBigQuery(initialData, SaveMode.ErrorIfExists) - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - // second write - assertThrows[IllegalArgumentException] { - writeToBigQuery(additonalData, SaveMode.ErrorIfExists) - } - } - - test("write to bq - ignore save mode") { - // initial write - writeToBigQuery(initialData, SaveMode.Ignore) - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - // second write - writeToBigQuery(additonalData, SaveMode.Ignore) - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - additionalDataValuesExist shouldBe false - } - - test("write to bq - overwrite save mode") { - // initial write - writeToBigQuery(initialData, SaveMode.Overwrite) - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - // second write - writeToBigQuery(additonalData, SaveMode.Overwrite) - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe false - additionalDataValuesExist shouldBe true - } - - test("write to bq - orc format") { - // required by ORC - spark.conf.set("spark.sql.orc.impl", "native") - writeToBigQuery(initialData, SaveMode.ErrorIfExists, "orc") - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - } - - test("write to bq - avro format") { - writeToBigQuery(initialData, SaveMode.ErrorIfExists, "avro") - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - } - - test("write to bq - parquet format") { - writeToBigQuery(initialData, SaveMode.ErrorIfExists, "parquet") - testTableNumberOfRows shouldBe 2 - initialDataValuesExist 
shouldBe true - } - - test("write to bq - simplified api") { - initialData.write.format("bigquery") - .option("temporaryGcsBucket", temporaryGcsBucket) - .save(fullTableName) - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - } - - test("write to bq - unsupported format") { - assertThrows[IllegalArgumentException] { - writeToBigQuery(initialData, SaveMode.ErrorIfExists, "something else") - } - } - - test("write all types to bq - avro format") { - val allTypesTable = readAllTypesTable("bigquery") - writeToBigQuery(allTypesTable, SaveMode.Overwrite, "avro") - - val df = spark.read.format("bigquery") - .option("dataset", testDataset) - .option("table", testTable) - .load() - - assert(df.head() == allTypesTable.head()) - assert(df.schema == allTypesTable.schema) - } - - test("streaming bq write append") { - failAfter(120 seconds) { - val schema = initialData.schema - val expressionEncoder: ExpressionEncoder[Row] = - RowEncoder(schema).resolveAndBind() - val stream = MemoryStream[Row](expressionEncoder, spark.sqlContext) - var lastBatchId: Long = 0 - val streamingDF = stream.toDF() - val cpLoc: String = "/tmp/%s-%d". - format(fullTableName, System.nanoTime()) - // Start write stream - val writeStream = streamingDF.writeStream. - format("bigquery"). - outputMode(OutputMode.Append()). - option("checkpointLocation", cpLoc). - option("table", fullTableName). - option("temporaryGcsBucket", temporaryGcsBucket). - start - - // Write to stream - stream.addData(initialData.collect()) - while (writeStream.lastProgress.batchId <= lastBatchId) { - Thread.sleep(1000L) - } - lastBatchId = writeStream.lastProgress.batchId - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - // Write to stream - stream.addData(additonalData.collect()) - while (writeStream.lastProgress.batchId <= lastBatchId) { - Thread.sleep(1000L) - } - writeStream.stop() - testTableNumberOfRows shouldBe 4 - additionalDataValuesExist shouldBe true - } - } - - test("query materialized view") { - var df = spark.read.format("bigquery") - .option("table", "bigquery-public-data:ethereum_blockchain.live_logs") - .option("viewsEnabled", "true") - .option("viewMaterializationProject", System.getenv("GOOGLE_CLOUD_PROJECT")) - .option("viewMaterializationDataset", testDataset) - .load() - } - - test("write to bq - adding the settings to spark.conf") { - spark.conf.set("temporaryGcsBucket", temporaryGcsBucket) - val df = initialData - df.write.format("bigquery") - .option("table", fullTableName) - .save() - testTableNumberOfRows shouldBe 2 - initialDataValuesExist shouldBe true - } - - test("write to bq - partitioned and clustered table") { - val df = spark.read.format("com.google.cloud.spark.bigquery") - .option("table", LIBRARIES_PROJECTS_TABLE) - .load() - .where("platform = 'Sublime'") - - df.write.format("bigquery") - .option("table", fullTableNamePartitioned) - .option("temporaryGcsBucket", temporaryGcsBucket) - .option("partitionField", "created_timestamp") - .option("clusteredFields", "platform") - .mode(SaveMode.Overwrite) - .save() - - val tableDefinition = testPartitionedTableDefinition - tableDefinition.getTimePartitioning.getField shouldBe "created_timestamp" - tableDefinition.getClustering.getFields should contain("platform") - } - - test("overwrite single partition") { - // create partitioned table - val tableName = "partitioned_table" - val fullTableName = s"$testDataset.$tableName" - bq.create(TableInfo.of( - TableId.of(testDataset, tableName), - StandardTableDefinition.newBuilder() - 
.setSchema(Schema.of( - Field.of("the_date", LegacySQLTypeName.DATE), - Field.of("some_text", LegacySQLTypeName.STRING) - )) - .setTimePartitioning(TimePartitioning.newBuilder(TimePartitioning.Type.DAY) - .setField("the_date").build()).build())) - // entering the data - bq.query(QueryJobConfiguration.of( - s""" - |insert into `$fullTableName` (the_date, some_text) values - |('2020-07-01', 'foo'), - |('2020-07-02', 'bar') - |""".stripMargin.replace('\n', ' '))) - - // overrding a single partition - val newDataDF = spark.createDataFrame( - List(Row(java.sql.Date.valueOf("2020-07-01"), "baz")).asJava, - StructType(Array( - StructField("the_date", DateType), - StructField("some_text", StringType)))) - - newDataDF.write.format("bigquery") - .option("temporaryGcsBucket", temporaryGcsBucket) - .option("datePartition", "20200701") - .mode("overwrite") - .save(fullTableName) - - val result = spark.read.format("bigquery").load(fullTableName).collect() - - result.size shouldBe 2 - result.filter(row => row(1).equals("bar")).size shouldBe 1 - result.filter(row => row(1).equals("baz")).size shouldBe 1 - - } - - test("support custom data types") { - val table = s"$testDataset.$testTable" - - val originalVectorDF = spark.createDataFrame( - List(Row("row1", 1, Vectors.dense(1, 2, 3))).asJava, - StructType(Seq( - StructField("name", DataTypes.StringType), - StructField("num", DataTypes.IntegerType), - StructField("vector", SQLDataTypes.VectorType)))) - - originalVectorDF.write.format("bigquery") - // must use avro or orc - .option("intermediateFormat", "orc") - .option("temporaryGcsBucket", temporaryGcsBucket) - .save(table) - - val readVectorDF = spark.read.format("bigquery") - .load(table) - - val orig = originalVectorDF.head - val read = readVectorDF.head - - read should equal(orig) - } - def extractWords(df: DataFrame): Set[String] = { df.select("word") .where("corpus_date = 0") @@ -730,50 +452,5 @@ class SparkBigQueryEndToEndITSuite extends FunSuite .map(_.getString(0)) .toSet } - - test("hourly partition") { - testPartition("HOUR") - } - - test("daily partition") { - testPartition("DAY") - } - - test("monthly partition") { - testPartition("MONTH") - } - - test("yearly partition") { - testPartition("YEAR") - } - - def testPartition(partitionType: String): Unit = { - val s = spark // cannot import from a var - import s.implicits._ - val df = spark.createDataset(Seq( - Data("a", java.sql.Timestamp.valueOf("2020-01-01 01:01:01")), - Data("b", java.sql.Timestamp.valueOf("2020-01-02 02:02:02")), - Data("c", java.sql.Timestamp.valueOf("2020-01-03 03:03:03")) - )).toDF() - - val table = s"${testDataset}.${testTable}_${partitionType}" - df.write.format("bigquery") - .option("temporaryGcsBucket", temporaryGcsBucket) - .option("partitionField", "t") - .option("partitionType", partitionType) - .option("partitionRequireFilter", "true") - .option("table", table) - .save() - - val readDF = spark.read.format("bigquery").load(table) - assert(readDF.count == 3) - } } -case class Person(name: String, friends: Seq[Friend]) - -case class Friend(age: Int, links: Seq[Link]) - -case class Link(uri: String) - -case class Data(str: String, t: java.sql.Timestamp) diff --git a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndWriteITSuite.scala b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndWriteITSuite.scala new file mode 100644 index 000000000..2148c0e4d --- /dev/null +++ 
b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndWriteITSuite.scala @@ -0,0 +1,443 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spark.bigquery.it + +import java.util.UUID + +import com.google.cloud.bigquery._ +import com.google.cloud.spark.bigquery.TestUtils +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession} +import org.scalatest.concurrent.TimeLimits +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.time.SpanSugar._ +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite, Matchers} + +import scala.collection.JavaConverters._ + +class SparkBigQueryEndToEndWriteITSuite extends FunSuite + with BeforeAndAfter + with BeforeAndAfterAll + with Matchers + with TimeLimits + with TableDrivenPropertyChecks { + + val temporaryGcsBucket = "davidrab-sandbox" + val bq = BigQueryOptions.getDefaultInstance.getService + private val LIBRARIES_PROJECTS_TABLE = "bigquery-public-data.libraries_io.projects" + private val ALL_TYPES_TABLE_NAME = "all_types" + private var spark: SparkSession = _ + private var testDataset: String = _ + + private def metadata(key: String, value: String): Metadata = metadata(Map(key -> value)) + + private def metadata(map: Map[String, String]): Metadata = { + val metadata = new MetadataBuilder() + for ((key, value) <- map) { + metadata.putString(key, value) + } + metadata.build() + } + + before { + // have a fresh table for each test + testTable = s"test_${System.nanoTime()}" + } + private var testTable: String = _ + + override def beforeAll: Unit = { + spark = TestUtils.getOrCreateSparkSession("SparkBigQueryEndToEndWriteITSuite") + // spark.conf.set("spark.sql.codegen.factoryMode", "NO_CODEGEN") + // System.setProperty("spark.testing", "true") + testDataset = s"spark_bigquery_it_${System.currentTimeMillis()}" + IntegrationTestUtils.createDataset(testDataset) + IntegrationTestUtils.runQuery( + TestConstants.ALL_TYPES_TABLE_QUERY_TEMPLATE.format(s"$testDataset.$ALL_TYPES_TABLE_NAME")) + } + + + // Write tests. We have four save modes: Append, ErrorIfExists, Ignore and + // Overwrite. For each there are two behaviours - the table exists or not. 
+ // See more at http://spark.apache.org/docs/2.3.2/api/java/org/apache/spark/sql/SaveMode.html + + override def afterAll: Unit = { + IntegrationTestUtils.deleteDatasetAndTables(testDataset) + } + + private def initialData = spark.createDataFrame(spark.sparkContext.parallelize( + Seq(Person("Abc", Seq(Friend(10, Seq(Link("www.abc.com"))))), + Person("Def", Seq(Friend(12, Seq(Link("www.def.com")))))))) + + private def additonalData = spark.createDataFrame(spark.sparkContext.parallelize( + Seq(Person("Xyz", Seq(Friend(10, Seq(Link("www.xyz.com"))))), + Person("Pqr", Seq(Friend(12, Seq(Link("www.pqr.com")))))))) + + // getNumRows returns BigInteger, and it messes up the matchers + private def testTableNumberOfRows = bq.getTable(testDataset, testTable).getNumRows.intValue + + private def testPartitionedTableDefinition = bq.getTable(testDataset, testTable + "_partitioned") + .getDefinition[StandardTableDefinition]() + + private def writeToBigQuery( + dataSource: String, + df: DataFrame, + mode: SaveMode, + format: String = "parquet") = + df.write.format(dataSource) + .mode(mode) + .option("table", fullTableName) + .option("temporaryGcsBucket", temporaryGcsBucket) + .option("intermediateFormat", format) + .save() + + private def initialDataValuesExist = numberOfRowsWith("Abc") == 1 + + private def numberOfRowsWith(name: String) = + bq.query(QueryJobConfiguration.of(s"select name from $fullTableName where name='$name'")) + .getTotalRows + + private def fullTableName = s"$testDataset.$testTable" + + private def fullTableNamePartitioned = s"$testDataset.${testTable}_partitioned" + + private def additionalDataValuesExist = numberOfRowsWith("Xyz") == 1 + + def readAllTypesTable(dataSourceFormat: String): DataFrame = + spark.read.format(dataSourceFormat) + .option("dataset", testDataset) + .option("table", ALL_TYPES_TABLE_NAME) + .load() + + + Seq("bigquery", "com.google.cloud.spark.bigquery.v2.BigQueryDataSourceV2") + .foreach(testsWithDataSource) + + def testsWithDataSource(dataSourceFormat: String) { + + test("write to bq - append save mode. DataSource %s".format(dataSourceFormat)) { + // initial write + writeToBigQuery(dataSourceFormat, initialData, SaveMode.Append) + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + // second write + writeToBigQuery(dataSourceFormat, additonalData, SaveMode.Append) + testTableNumberOfRows shouldBe 4 + additionalDataValuesExist shouldBe true + } + + test("write to bq - error if exists save mode. DataSource %s".format(dataSourceFormat)) { + // initial write + writeToBigQuery(dataSourceFormat, initialData, SaveMode.ErrorIfExists) + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + // second write + assertThrows[IllegalArgumentException] { + writeToBigQuery(dataSourceFormat, additonalData, SaveMode.ErrorIfExists) + } + } + + test("write to bq - ignore save mode. DataSource %s".format(dataSourceFormat)) { + // initial write + writeToBigQuery(dataSourceFormat, initialData, SaveMode.Ignore) + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + // second write + writeToBigQuery(dataSourceFormat, additonalData, SaveMode.Ignore) + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + additionalDataValuesExist shouldBe false + } + + test("write to bq - overwrite save mode. 
DataSource %s".format(dataSourceFormat)) { + // initial write + writeToBigQuery(dataSourceFormat, initialData, SaveMode.Overwrite) + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + // second write + writeToBigQuery(dataSourceFormat, additonalData, SaveMode.Overwrite) + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe false + additionalDataValuesExist shouldBe true + } + + test("write to bq - orc format. DataSource %s".format(dataSourceFormat)) { + // v2 does not support ORC + if (dataSourceFormat.equals("bigquery")) { + // required by ORC + spark.conf.set("spark.sql.orc.impl", "native") + writeToBigQuery(dataSourceFormat, initialData, SaveMode.ErrorIfExists, "orc") + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + } + } + + test("write to bq - avro format. DataSource %s".format(dataSourceFormat)) { + writeToBigQuery(dataSourceFormat, initialData, SaveMode.ErrorIfExists, "avro") + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + } + + test("write to bq - parquet format. DataSource %s".format(dataSourceFormat)) { + // v2 does not support parquet + if (dataSourceFormat.equals("bigquery")) { + writeToBigQuery(dataSourceFormat, initialData, SaveMode.ErrorIfExists, "parquet") + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + } + } + + test("write to bq - simplified api. DataSource %s".format(dataSourceFormat)) { + initialData.write.format(dataSourceFormat) + .option("temporaryGcsBucket", temporaryGcsBucket) + .save(fullTableName) + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + } + + test("write to bq - unsupported format. DataSource %s".format(dataSourceFormat)) { + if (dataSourceFormat.equals("bigquery")) { + assertThrows[Exception] { + writeToBigQuery(dataSourceFormat, initialData, SaveMode.ErrorIfExists, "something else") + } + } + } + + test("write all types to bq - avro format. DataSource %s".format(dataSourceFormat)) { + val allTypesTable = readAllTypesTable(dataSourceFormat) + writeToBigQuery(dataSourceFormat, allTypesTable, SaveMode.Overwrite, "avro") + + val df = spark.read.format(dataSourceFormat) + .option("dataset", testDataset) + .option("table", testTable) + .load() + + assert(df.head() == allTypesTable.head()) + assert(df.schema == allTypesTable.schema) + } + + test("query materialized view. DataSource %s".format(dataSourceFormat)) { + var df = spark.read.format(dataSourceFormat) + .option("table", "bigquery-public-data:ethereum_blockchain.live_logs") + .option("viewsEnabled", "true") + .option("viewMaterializationProject", System.getenv("GOOGLE_CLOUD_PROJECT")) + .option("viewMaterializationDataset", testDataset) + .load() + } + + test("write to bq - adding the settings to spark.conf. DataSource %s" + .format(dataSourceFormat)) { + spark.conf.set("temporaryGcsBucket", temporaryGcsBucket) + val df = initialData + df.write.format(dataSourceFormat) + .option("table", fullTableName) + .save() + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + } + + test("write to bq - partitioned and clustered table. 
DataSource %s".format(dataSourceFormat)) { + val df = spark.read.format("com.google.cloud.spark.bigquery") + .option("table", LIBRARIES_PROJECTS_TABLE) + .load() + .where("platform = 'Sublime'") + + df.write.format(dataSourceFormat) + .option("table", fullTableNamePartitioned) + .option("temporaryGcsBucket", temporaryGcsBucket) + .option("partitionField", "created_timestamp") + .option("clusteredFields", "platform") + .mode(SaveMode.Overwrite) + .save() + + val tableDefinition = testPartitionedTableDefinition + tableDefinition.getTimePartitioning.getField shouldBe "created_timestamp" + tableDefinition.getClustering.getFields should contain("platform") + } + + test("overwrite single partition. DataSource %s".format(dataSourceFormat)) { + // create partitioned table + val tableName = s"partitioned_table_$randomSuffix" + val fullTableName = s"$testDataset.$tableName" + bq.create(TableInfo.of( + TableId.of(testDataset, tableName), + StandardTableDefinition.newBuilder() + .setSchema(Schema.of( + Field.of("the_date", LegacySQLTypeName.DATE), + Field.of("some_text", LegacySQLTypeName.STRING) + )) + .setTimePartitioning(TimePartitioning.newBuilder(TimePartitioning.Type.DAY) + .setField("the_date").build()).build())) + // inserting the data + bq.query(QueryJobConfiguration.of( + s""" + |insert into `$fullTableName` (the_date, some_text) values + |('2020-07-01', 'foo'), + |('2020-07-02', 'bar') + |""".stripMargin.replace('\n', ' '))) + + // overwriting a single partition + val newDataDF = spark.createDataFrame( + List(Row(java.sql.Date.valueOf("2020-07-01"), "baz")).asJava, + StructType(Array( + StructField("the_date", DateType), + StructField("some_text", StringType)))) + + newDataDF.write.format(dataSourceFormat) + .option("temporaryGcsBucket", temporaryGcsBucket) + .option("datePartition", "20200701") + .mode("overwrite") + .save(fullTableName) + + val result = spark.read.format(dataSourceFormat).load(fullTableName).collect() + + result.size shouldBe 2 + result.filter(row => row(1).equals("bar")).size shouldBe 1 + result.filter(row => row(1).equals("baz")).size shouldBe 1 + + } + + // test("support custom data types. 
DataSource %s".format(dataSourceFormat)) { + // val table = s"$testDataset.$testTable" + // + // val originalVectorDF = spark.createDataFrame( + // List(Row("row1", 1, Vectors.dense(1, 2, 3))).asJava, + // StructType(Seq( + // StructField("name", DataTypes.StringType), + // StructField("num", DataTypes.IntegerType), + // StructField("vector", SQLDataTypes.VectorType)))) + // + // originalVectorDF.write.format(dataSourceFormat) + // // must use avro or orc + // .option("intermediateFormat", "avro") + // .option("temporaryGcsBucket", temporaryGcsBucket) + // .save(table) + // + // val readVectorDF = spark.read.format(dataSourceFormat) + // .load(table) + // + // val orig = originalVectorDF.head + // val read = readVectorDF.head + // + // read should equal(orig) + // } + + test("compare read formats DataSource %s".format(dataSourceFormat)) { + val allTypesTable = readAllTypesTable(dataSourceFormat) + writeToBigQuery(dataSourceFormat, allTypesTable, SaveMode.Overwrite, "avro") + + val df = spark.read.format(dataSourceFormat) + .option("dataset", testDataset) + .option("table", testTable) + .option("readDataFormat", "arrow") + .load().cache() + + assert(df.head() == allTypesTable.head()) + + // read from cache + assert(df.head() == allTypesTable.head()) + assert(df.schema == allTypesTable.schema) + } + + } + + private def randomSuffix: String = { + val uuid = UUID.randomUUID() + java.lang.Long.toHexString(uuid.getMostSignificantBits) + + java.lang.Long.toHexString(uuid.getLeastSignificantBits) + } + + test("streaming bq write append") { + failAfter(120 seconds) { + val schema = initialData.schema + val expressionEncoder: ExpressionEncoder[Row] = + RowEncoder(schema).resolveAndBind() + val stream = MemoryStream[Row](expressionEncoder, spark.sqlContext) + var lastBatchId: Long = 0 + val streamingDF = stream.toDF() + val cpLoc: String = "/tmp/%s-%d". + format(fullTableName, System.nanoTime()) + // Start write stream + val writeStream = streamingDF.writeStream. + format("bigquery"). + outputMode(OutputMode.Append()). + option("checkpointLocation", cpLoc). + option("table", fullTableName). + option("temporaryGcsBucket", temporaryGcsBucket). 
+ start + + // Write to stream + stream.addData(initialData.collect()) + while (writeStream.lastProgress.batchId <= lastBatchId) { + Thread.sleep(1000L) + } + lastBatchId = writeStream.lastProgress.batchId + testTableNumberOfRows shouldBe 2 + initialDataValuesExist shouldBe true + // Write to stream + stream.addData(additonalData.collect()) + while (writeStream.lastProgress.batchId <= lastBatchId) { + Thread.sleep(1000L) + } + writeStream.stop() + testTableNumberOfRows shouldBe 4 + additionalDataValuesExist shouldBe true + } + } + + test("hourly partition") { + testPartition("HOUR") + } + + test("daily partition") { + testPartition("DAY") + } + + test("monthly partition") { + testPartition("MONTH") + } + + test("yearly partition") { + testPartition("YEAR") + } + + def testPartition(partitionType: String): Unit = { + val s = spark // cannot import from a var + import s.implicits._ + val df = spark.createDataset(Seq( + Data("a", java.sql.Timestamp.valueOf("2020-01-01 01:01:01")), + Data("b", java.sql.Timestamp.valueOf("2020-01-02 02:02:02")), + Data("c", java.sql.Timestamp.valueOf("2020-01-03 03:03:03")) + )).toDF() + + val table = s"${testDataset}.${testTable}_${partitionType}" + df.write.format("bigquery") + .option("temporaryGcsBucket", temporaryGcsBucket) + .option("partitionField", "t") + .option("partitionType", partitionType) + .option("partitionRequireFilter", "true") + .option("table", table) + .save() + + val readDF = spark.read.format("bigquery").load(table) + assert(readDF.count == 3) + } +} + +case class Data(str: String, t: java.sql.Timestamp) + diff --git a/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/ArrowSchemaConverter.java b/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/ArrowSchemaConverter.java index 6967c0d72..8a937810d 100644 --- a/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/ArrowSchemaConverter.java +++ b/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/ArrowSchemaConverter.java @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.google.cloud.spark.bigquery; import io.netty.buffer.ArrowBuf; diff --git a/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/AvroSchemaConverter.java b/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/AvroSchemaConverter.java new file mode 100644 index 000000000..5e113eef7 --- /dev/null +++ b/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/AvroSchemaConverter.java @@ -0,0 +1,305 @@ +/* + * Copyright 2020 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.spark.bigquery; + +import com.google.common.base.Preconditions; +import org.apache.avro.Conversions; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.apache.avro.util.Utf8; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.NullType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.UserDefinedType; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Optional; + +public class AvroSchemaConverter { + + private static final Schema NULL = Schema.create(Schema.Type.NULL); + private static final Conversions.DecimalConversion DECIMAL_CONVERSIONS = + new Conversions.DecimalConversion(); + + public static Schema sparkSchemaToAvroSchema(StructType sparkSchema) { + return sparkTypeToRawAvroType(sparkSchema, false, "root"); + } + + static Schema sparkTypeToRawAvroType(DataType dataType, boolean nullable, String recordName) { + SchemaBuilder.TypeBuilder builder = SchemaBuilder.builder(); + Schema avroType = sparkTypeToRawAvroType(dataType, recordName, builder); + + if (nullable) { + avroType = Schema.createUnion(avroType, NULL); + } + + return avroType; + } + + static Schema sparkTypeToRawAvroType( + DataType dataType, String recordName, SchemaBuilder.TypeBuilder builder) { + if (dataType instanceof BinaryType) { + return builder.bytesType(); + } + if (dataType instanceof ByteType + || dataType instanceof ShortType + || dataType instanceof IntegerType + || dataType instanceof LongType) { + return builder.longType(); + } + if (dataType instanceof BooleanType) { + return builder.booleanType(); + } + if (dataType instanceof FloatType || dataType instanceof DoubleType) { + return builder.doubleType(); + } + if (dataType instanceof DecimalType) { + DecimalType decimalType = (DecimalType) dataType; + if (decimalType.precision() <= SchemaConverters.BQ_NUMERIC_PRECISION + && decimalType.scale() <= SchemaConverters.BQ_NUMERIC_SCALE) { + return LogicalTypes.decimal(decimalType.precision(), decimalType.scale()) + .addToSchema(builder.bytesType()); + } else { + throw new IllegalArgumentException( + "Decimal type is too wide to fit in BigQuery Numeric format"); + } + } + if (dataType instanceof StringType) { + return builder.stringType(); + } + if (dataType instanceof TimestampType) { + // return builder.TIMESTAMP; FIXME: Restore this correct conversion when the Vortex + // team adds microsecond support to their backend + return 
LogicalTypes.timestampMicros().addToSchema(builder.longType()); + } + if (dataType instanceof DateType) { + return LogicalTypes.date().addToSchema(builder.intType()); + } + if (dataType instanceof ArrayType) { + return builder + .array() + .items( + sparkTypeToRawAvroType( + ((ArrayType) dataType).elementType(), + ((ArrayType) dataType).containsNull(), + recordName)); + } + if (dataType instanceof StructType) { + SchemaBuilder.FieldAssembler fieldsAssembler = builder.record(recordName).fields(); + for (StructField field : ((StructType) dataType).fields()) { + Schema avroType = sparkTypeToRawAvroType(field.dataType(), field.nullable(), field.name()); + fieldsAssembler.name(field.name()).type(avroType).noDefault(); + } + return fieldsAssembler.endRecord(); + } + if (dataType instanceof UserDefinedType) { + DataType userDefinedType = ((UserDefinedType) dataType).sqlType(); + return sparkTypeToRawAvroType(userDefinedType, recordName, builder); + } + if (dataType instanceof MapType) { + throw new IllegalArgumentException(SchemaConverters.MAPTYPE_ERROR_MESSAGE); + } else { + throw new IllegalArgumentException("Data type not supported: " + dataType.simpleString()); + } + } + + public static GenericData.Record sparkRowToAvroGenericData( + InternalRow row, StructType sparkSchema, Schema avroSchema) { + StructConverter structConverter = new StructConverter(sparkSchema, avroSchema); + return structConverter.convert(row); + } + + static Schema resolveNullableType(Schema avroType, boolean nullable) { + if (nullable && avroType.getType() != Schema.Type.NULL) { + // avro uses union to represent nullable type. + List fields = avroType.getTypes(); + Preconditions.checkArgument( + fields.size() == 2, "Avro nullable field should be represented by a union of size 2"); + Optional actualType = + fields.stream().filter(field -> field.getType() != Schema.Type.NULL).findFirst(); + return actualType.orElseThrow( + () -> new IllegalArgumentException("No actual type has been found in " + avroType)); + } else { + return avroType; + } + } + + static Converter createConverterFor(DataType sparkType, Schema avroType) { + if (sparkType instanceof NullType && avroType.getType() == Schema.Type.NULL) { + return (getter, ordinal) -> null; + } + if (sparkType instanceof BooleanType && avroType.getType() == Schema.Type.BOOLEAN) { + return (getter, ordinal) -> getter.getBoolean(ordinal); + } + if (sparkType instanceof ByteType && avroType.getType() == Schema.Type.LONG) { + return (getter, ordinal) -> Long.valueOf(getter.getByte(ordinal)); + } + if (sparkType instanceof ShortType && avroType.getType() == Schema.Type.LONG) { + return (getter, ordinal) -> Long.valueOf(getter.getShort(ordinal)); + } + if (sparkType instanceof IntegerType && avroType.getType() == Schema.Type.LONG) { + return (getter, ordinal) -> Long.valueOf(getter.getInt(ordinal)); + } + if (sparkType instanceof LongType && avroType.getType() == Schema.Type.LONG) { + return (getter, ordinal) -> getter.getLong(ordinal); + } + if (sparkType instanceof FloatType && avroType.getType() == Schema.Type.DOUBLE) { + return (getter, ordinal) -> Double.valueOf(getter.getFloat(ordinal)); + } + if (sparkType instanceof DoubleType && avroType.getType() == Schema.Type.DOUBLE) { + return (getter, ordinal) -> getter.getDouble(ordinal); + } + if (sparkType instanceof DecimalType && avroType.getType() == Schema.Type.BYTES) { + DecimalType decimalType = (DecimalType) sparkType; + return (getter, ordinal) -> { + Decimal decimal = getter.getDecimal(ordinal, decimalType.precision(), 
decimalType.scale()); + return DECIMAL_CONVERSIONS.toBytes( + decimal.toJavaBigDecimal(), + avroType, + LogicalTypes.decimal(decimalType.precision(), decimalType.scale())); + }; + } + if (sparkType instanceof StringType && avroType.getType() == Schema.Type.STRING) { + return (getter, ordinal) -> new Utf8(getter.getUTF8String(ordinal).getBytes()); + } + if (sparkType instanceof BinaryType && avroType.getType() == Schema.Type.FIXED) { + int size = avroType.getFixedSize(); + return (getter, ordinal) -> { + byte[] data = getter.getBinary(ordinal); + if (data.length != size) { + throw new IllegalArgumentException( + String.format( + "Cannot write %s bytes of binary data into FIXED Type with size of %s bytes", + data.length, size)); + } + return new GenericData.Fixed(avroType, data); + }; + } + if (sparkType instanceof BinaryType && avroType.getType() == Schema.Type.BYTES) { + return (getter, ordinal) -> ByteBuffer.wrap(getter.getBinary(ordinal)); + } + + if (sparkType instanceof DateType && avroType.getType() == Schema.Type.INT) { + return (getter, ordinal) -> getter.getInt(ordinal); + } + + if (sparkType instanceof TimestampType && avroType.getType() == Schema.Type.LONG) { + return (getter, ordinal) -> getter.getLong(ordinal); + } + + if (sparkType instanceof ArrayType && avroType.getType() == Schema.Type.ARRAY) { + DataType et = ((ArrayType) sparkType).elementType(); + boolean containsNull = ((ArrayType) sparkType).containsNull(); + + Converter elementConverter = + createConverterFor(et, resolveNullableType(avroType.getElementType(), containsNull)); + return (getter, ordinal) -> { + ArrayData arrayData = getter.getArray(ordinal); + int len = arrayData.numElements(); + Object[] result = new Object[len]; + for (int i = 0; i < len; i++) { + if (containsNull && arrayData.isNullAt(i)) { + result[i] = null; + } else { + result[i] = elementConverter.convert(arrayData, i); + } + } + // avro writer is expecting a Java Collection, so we convert it into + // `ArrayList` backed by the specified array without data copying. 
+ return java.util.Arrays.asList(result); + }; + } + if (sparkType instanceof StructType && avroType.getType() == Schema.Type.RECORD) { + StructType sparkStruct = (StructType) sparkType; + + StructConverter structConverter = new StructConverter(sparkStruct, avroType); + int numFields = sparkStruct.length(); + return (getter, ordinal) -> structConverter.convert(getter.getStruct(ordinal, numFields)); + } + if (sparkType instanceof UserDefinedType) { + UserDefinedType userDefinedType = (UserDefinedType) sparkType; + return createConverterFor(userDefinedType.sqlType(), avroType); + } + throw new IllegalArgumentException( + String.format("Cannot convert Catalyst type %s to Avro type %s", sparkType, avroType)); + } + + @FunctionalInterface + interface Converter { + Object convert(SpecializedGetters getters, int ordinal); + } + + static class StructConverter { + private final StructType sparkStruct; + private final Schema avroStruct; + + StructConverter(StructType sparkStruct, Schema avroStruct) { + this.sparkStruct = sparkStruct; + this.avroStruct = avroStruct; + Preconditions.checkArgument( + avroStruct.getType() == Schema.Type.RECORD + && avroStruct.getFields().size() == sparkStruct.length(), + "Cannot convert Catalyst type %s to Avro type %s.", + sparkStruct, + avroStruct); + } + + GenericData.Record convert(InternalRow row) { + int numFields = sparkStruct.length(); + Converter[] fieldConverters = new Converter[numFields]; + StructField[] sparkFields = sparkStruct.fields(); + Schema.Field[] avroFields = avroStruct.getFields().toArray(new Schema.Field[numFields]); + + GenericData.Record result = new GenericData.Record(avroStruct); + + for (int i = 0; i < numFields; i++) { + if (row.isNullAt(i)) { + result.put(i, null); + } else { + Converter fieldConverter = + AvroSchemaConverter.createConverterFor( + sparkFields[i].dataType(), + AvroSchemaConverter.resolveNullableType( + avroFields[i].schema(), sparkFields[i].nullable())); + result.put(i, fieldConverter.convert(row, i)); + } + } + return result; + } + } +}
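
A quick way to sanity-check the new converter outside the test suite is the sketch below, which exercises the two public entry points this patch adds to AvroSchemaConverter: sparkSchemaToAvroSchema and sparkRowToAvroGenericData. It is modeled on the patterns in AvroSchemaConverterTest above; the class name ConverterSmokeCheck and the field names "name"/"age" are illustrative only and are not part of the patch.

import com.google.cloud.spark.bigquery.AvroSchemaConverter;
import com.google.common.collect.ImmutableList;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;

public class ConverterSmokeCheck {
  public static void main(String[] args) {
    // Two required columns: a string and a long.
    StructType sparkSchema =
        DataTypes.createStructType(
            ImmutableList.of(
                DataTypes.createStructField("name", DataTypes.StringType, false),
                DataTypes.createStructField("age", DataTypes.LongType, false)));

    // Derive an Avro schema from the Spark schema.
    Schema avroSchema = AvroSchemaConverter.sparkSchemaToAvroSchema(sparkSchema);

    // Catalyst rows carry strings as UTF8String, which is what the converter expects.
    InternalRow row = new GenericInternalRow(new Object[] {UTF8String.fromString("Abc"), 10L});

    // Convert a single row into an Avro GenericData.Record.
    GenericData.Record record =
        AvroSchemaConverter.sparkRowToAvroGenericData(row, sparkSchema, avroSchema);

    System.out.println(avroSchema);          // record "root" with fields name: string, age: long
    System.out.println(record.get("name"));  // Abc (as an Avro Utf8)
    System.out.println(record.get("age"));   // 10
  }
}

As in testConvertNull above, the row-to-record mapping is positional (StructConverter walks the Spark and Avro fields by index), so the Avro field names only need to line up by position, not by name.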