Fixed Spark 3 compatibility (#242)
PR #240 unified the use of `SparkBigQueryConfig` so that it also serves the original DataSource v1 implementation, which is the one used with Spark 3. That implementation relied on `org.apache.spark.sql.sources.v2.DataSourceOptions`, which does not exist in Spark 3 because of the major DataSource API rework in that release.

This PR addresses that by:
* Removing `DataSourceOptions` from `SparkBigQueryConfig` and accepting a plain `Map<String, String>` of options instead (see the sketch below)
* Adding an acceptance test for Dataproc image 2.0 (currently in preview), which runs Spark 3
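
For illustration only (this sketch is not part of the commit), here is what a caller looks like after the change, assuming the trailing parameters of `from` keep the shape of the removed `fromV1` helper (`SQLConf`, Spark version, optional schema). The table reference and option values are hypothetical, and the keys are lower-cased because the new lookups expect the lower-case keys that `DataSourceOptions` used to produce:

import com.google.cloud.spark.bigquery.SparkBigQueryConfig;
import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.sql.internal.SQLConf;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public class SparkBigQueryConfigSketch {
  public static void main(String[] args) {
    // Options now travel as a plain Map<String, String>; DataSourceOptions is gone.
    Map<String, String> parameters = new HashMap<>();
    parameters.put("table", "dataset.table"); // hypothetical table reference
    parameters.put("viewsenabled", "true");   // lower-case key, matching the new lookup

    SparkBigQueryConfig config =
        SparkBigQueryConfig.from(
            parameters,          // plain Map instead of DataSourceOptions
            ImmutableMap.of(),   // session-wide options, normally from spark.conf
            new Configuration(), // Hadoop configuration
            10,                  // default parallelism
            new SQLConf(),       // stand-in; real callers pass the session's SQLConf
            "3.0.1",             // Spark version (assumed trailing parameter, as in fromV1)
            Optional.empty());   // no user-supplied schema

    System.out.println(config.getTableId());
  }
}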
davidrabinowitz committed Sep 11, 2020
1 parent ea5a202 commit d467538
Showing 5 changed files with 107 additions and 120 deletions.
@@ -31,7 +31,6 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.sql.execution.datasources.DataSource;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.types.StructType;

import java.io.Serializable;
@@ -51,6 +50,7 @@
public class SparkBigQueryConfig implements BigQueryConfig, Serializable {

public static final String VIEWS_ENABLED_OPTION = "viewsEnabled";
public static final String DATE_PARTITION_PARAM = "datePartition";
@VisibleForTesting static final DataFormat DEFAULT_READ_DATA_FORMAT = DataFormat.ARROW;

@VisibleForTesting
@@ -63,13 +63,9 @@ public class SparkBigQueryConfig implements BigQueryConfig, Serializable {
private static final String READ_DATA_FORMAT_OPTION = "readDataFormat";
private static final ImmutableList<String> PERMITTED_READ_DATA_FORMATS =
ImmutableList.of(DataFormat.ARROW.toString(), DataFormat.AVRO.toString());

private static final Supplier<com.google.common.base.Optional<String>> DEFAULT_FALLBACK =
() -> empty();

private static final String CONF_PREFIX = "spark.datasource.bigquery.";
public static final String DATE_PARTITION_PARAM = "datePartition";

TableId tableId;
String parentProjectId;
com.google.common.base.Optional<String> credentialsKey;
@@ -106,26 +102,8 @@
}

@VisibleForTesting
public static SparkBigQueryConfig fromV1(
Map<String, String> parameters,
Map<String, String> globalOptions,
Configuration hadoopConfiguration,
int defaultParallelism,
SQLConf sqlConf,
String sparkVersion,
Optional<StructType> schema) {
return from(
new DataSourceOptions(parameters),
ImmutableMap.copyOf(globalOptions),
hadoopConfiguration,
defaultParallelism,
sqlConf,
sparkVersion,
schema);
}

public static SparkBigQueryConfig from(
DataSourceOptions options,
Map<String, String> options,
ImmutableMap<String, String> globalOptions,
Configuration hadoopConfiguration,
int defaultParallelism,
@@ -245,41 +223,37 @@ private static void validateDateFormat(String date, String optionName) {
}
}

public Credentials createCredentials() {
return new BigQueryCredentialsSupplier(
accessToken.toJavaUtil(), credentialsKey.toJavaUtil(), credentialsFile.toJavaUtil())
.getCredentials();
}

private static com.google.common.base.Supplier<String> defaultBilledProject() {
return () -> BigQueryOptions.getDefaultInstance().getProjectId();
}

private static String getRequiredOption(DataSourceOptions options, String name) {
private static String getRequiredOption(Map<String, String> options, String name) {
return getOption(options, name, DEFAULT_FALLBACK)
.toJavaUtil()
.orElseThrow(() -> new IllegalArgumentException(format("Option %s required.", name)));
}

private static String getRequiredOption(
DataSourceOptions options, String name, com.google.common.base.Supplier<String> fallback) {
Map<String, String> options, String name, com.google.common.base.Supplier<String> fallback) {
return getOption(options, name, DEFAULT_FALLBACK).or(fallback);
}

private static com.google.common.base.Optional<String> getOption(
DataSourceOptions options, String name) {
Map<String, String> options, String name) {
return getOption(options, name, DEFAULT_FALLBACK);
}

private static com.google.common.base.Optional<String> getOption(
DataSourceOptions options,
Map<String, String> options,
String name,
Supplier<com.google.common.base.Optional<String>> fallback) {
return fromJavaUtil(firstPresent(options.get(name), fallback.get().toJavaUtil()));
return fromJavaUtil(
firstPresent(
Optional.ofNullable(options.get(name.toLowerCase())), fallback.get().toJavaUtil()));
}

private static com.google.common.base.Optional<String> getOptionFromMultipleParams(
DataSourceOptions options,
Map<String, String> options,
Collection<String> names,
Supplier<com.google.common.base.Optional<String>> fallback) {
return names.stream()
@@ -290,16 +264,16 @@ private static com.google.common.base.Optional<String> getOptionFromMultipleParams(
}

private static com.google.common.base.Optional<String> getAnyOption(
ImmutableMap<String, String> globalOptions, DataSourceOptions options, String name) {
return com.google.common.base.Optional.fromNullable(
options.get(name).orElse(globalOptions.get(name)));
ImmutableMap<String, String> globalOptions, Map<String, String> options, String name) {
return com.google.common.base.Optional.fromNullable(options.get(name.toLowerCase()))
.or(com.google.common.base.Optional.fromNullable(globalOptions.get(name)));
}

// gives the option to support old configurations as fallback
// Used to provide backward compatibility
private static com.google.common.base.Optional<String> getAnyOption(
ImmutableMap<String, String> globalOptions,
DataSourceOptions options,
Map<String, String> options,
Collection<String> names) {
return names.stream()
.map(name -> getAnyOption(globalOptions, options, name))
@@ -310,12 +284,38 @@ private static com.google.common.base.Optional<String> getAnyOption(

private static boolean getAnyBooleanOption(
ImmutableMap<String, String> globalOptions,
DataSourceOptions options,
Map<String, String> options,
String name,
boolean defaultValue) {
return getAnyOption(globalOptions, options, name).transform(Boolean::valueOf).or(defaultValue);
}

static Map<String, String> normalizeConf(Map<String, String> conf) {
Map<String, String> normalizeConf =
conf.entrySet().stream()
.filter(e -> e.getKey().startsWith(CONF_PREFIX))
.collect(
Collectors.toMap(
e -> e.getKey().substring(CONF_PREFIX.length()), e -> e.getValue()));
Map<String, String> result = new HashMap<>(conf);
result.putAll(normalizeConf);
return ImmutableMap.copyOf(result);
}

private static com.google.common.base.Optional empty() {
return com.google.common.base.Optional.absent();
}

private static com.google.common.base.Optional fromJavaUtil(Optional o) {
return com.google.common.base.Optional.fromJavaUtil(o);
}

public Credentials createCredentials() {
return new BigQueryCredentialsSupplier(
accessToken.toJavaUtil(), credentialsKey.toJavaUtil(), credentialsFile.toJavaUtil())
.getCredentials();
}

public TableId getTableId() {
return tableId;
}
@@ -449,18 +449,6 @@ public ReadSessionCreatorConfig toReadSessionCreatorConfig() {
defaultParallelism);
}

static Map<String, String> normalizeConf(Map<String, String> conf) {
Map<String, String> normalizeConf =
conf.entrySet().stream()
.filter(e -> e.getKey().startsWith(CONF_PREFIX))
.collect(
Collectors.toMap(
e -> e.getKey().substring(CONF_PREFIX.length()), e -> e.getValue()));
Map<String, String> result = new HashMap<>(conf);
result.putAll(normalizeConf);
return ImmutableMap.copyOf(result);
}

enum IntermediateFormat {
AVRO("avro", FormatOptions.avro()),
AVRO_2_3("com.databricks.spark.avro", FormatOptions.avro()),
@@ -481,14 +469,6 @@ enum IntermediateFormat {
this.formatOptions = formatOptions;
}

public String getDataSource() {
return dataSource;
}

public FormatOptions getFormatOptions() {
return formatOptions;
}

public static IntermediateFormat from(String format, String sparkVersion, SQLConf sqlConf) {
Preconditions.checkArgument(
PERMITTED_DATA_SOURCES.contains(format.toLowerCase()),
@@ -540,13 +520,13 @@ private static IllegalStateException missingAvroException(

return new IllegalStateException(message, cause);
}
}

private static com.google.common.base.Optional empty() {
return com.google.common.base.Optional.absent();
}
public String getDataSource() {
return dataSource;
}

private static com.google.common.base.Optional fromJavaUtil(Optional o) {
return com.google.common.base.Optional.fromJavaUtil(o);
public FormatOptions getFormatOptions() {
return formatOptions;
}
}
}
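
Before moving to the next file, one illustrative aside (not connector code): the resolution order that replaces `DataSourceOptions` can be restated in self-contained form. Per-read options are consulted first under a lower-cased key, then the session-wide options, which `normalizeConf` has already rewritten by copying `spark.datasource.bigquery.`-prefixed entries with the prefix stripped:

import java.util.HashMap;
import java.util.Map;

public class OptionLookupSketch {
  private static final String CONF_PREFIX = "spark.datasource.bigquery.";

  // Mirrors getAnyOption: the per-read map wins, looked up under a lower-cased
  // name; otherwise fall back to the session-wide map under the exact name.
  static String getAnyOption(
      Map<String, String> globalOptions, Map<String, String> options, String name) {
    String value = options.get(name.toLowerCase());
    return value != null ? value : globalOptions.get(name);
  }

  // Mirrors normalizeConf: prefixed entries are copied over the rest of the
  // configuration, with the prefix stripped; the originals are kept as well.
  static Map<String, String> normalizeConf(Map<String, String> conf) {
    Map<String, String> result = new HashMap<>(conf);
    conf.forEach(
        (key, value) -> {
          if (key.startsWith(CONF_PREFIX)) {
            result.put(key.substring(CONF_PREFIX.length()), value);
          }
        });
    return result;
  }

  public static void main(String[] args) {
    Map<String, String> options = new HashMap<>();
    options.put("viewsenabled", "true"); // lower-cased, as DataSourceOptions produced

    Map<String, String> sessionConf = new HashMap<>();
    sessionConf.put("spark.datasource.bigquery.materializationProject", "my-project");
    Map<String, String> globalOptions = normalizeConf(sessionConf);

    System.out.println(getAnyOption(globalOptions, options, "viewsEnabled"));           // true
    System.out.println(getAnyOption(globalOptions, options, "materializationProject")); // my-project
  }
}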
@@ -57,7 +57,7 @@ public void configure(Binder binder) {
@Provides
public SparkBigQueryConfig provideSparkBigQueryConfig() {
return SparkBigQueryConfig.from(
options,
options.asMap(),
ImmutableMap.copyOf(mapAsJavaMap(spark.conf().getAll())),
spark.sparkContext().hadoopConfiguration(),
spark.sparkContext().defaultParallelism(),
@@ -21,6 +21,7 @@ import com.google.cloud.bigquery.TableDefinition.Type.{MATERIALIZED_VIEW, TABLE,
import com.google.cloud.bigquery.connector.common.BigQueryUtil
import com.google.cloud.bigquery.{BigQuery, TableDefinition}
import com.google.cloud.spark.bigquery.direct.DirectBigQueryRelation
import com.google.common.collect.ImmutableMap
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
@@ -129,8 +130,8 @@ class BigQueryRelationProvider(
def createSparkBigQueryConfig(sqlContext: SQLContext,
parameters: Map[String, String],
schema: Option[StructType] = None): SparkBigQueryConfig = {
SparkBigQueryConfig.fromV1(parameters.asJava,
sqlContext.getAllConfs.asJava,
SparkBigQueryConfig.from(parameters.asJava,
ImmutableMap.copyOf(sqlContext.getAllConfs.asJava),
sqlContext.sparkContext.hadoopConfiguration,
sqlContext.sparkContext.defaultParallelism,
sqlContext.sparkSession.sessionState.conf,
@@ -45,7 +45,7 @@ public void testDefaults() {
DataSourceOptions options = new DataSourceOptions(defaultOptions);
SparkBigQueryConfig config =
SparkBigQueryConfig.from(
options,
options.asMap(),
ImmutableMap.of(),
hadoopConfiguration,
DEFAULT_PARALLELISM,
@@ -102,7 +102,7 @@ public void testConfigFromOptions() {
.build());
SparkBigQueryConfig config =
SparkBigQueryConfig.from(
options,
options.asMap(),
ImmutableMap.of(),
hadoopConfiguration,
DEFAULT_PARALLELISM,
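Finally, a note on the test changes above: the tests still construct a `DataSourceOptions` (the class presumably remains on the Spark 2.x test classpath) and bridge into the new signature with `asMap()`. A small sketch of why that bridge works; the option key here is hypothetical:

import com.google.common.collect.ImmutableMap;
import java.util.Map;
import org.apache.spark.sql.sources.v2.DataSourceOptions;

public class OptionsBridgeSketch {
  public static void main(String[] args) {
    // DataSourceOptions is case-insensitive and stores its keys lower-cased,
    // so asMap() yields exactly the map shape the new lookups expect.
    DataSourceOptions options =
        new DataSourceOptions(ImmutableMap.of("Table", "dataset.table"));
    Map<String, String> asMap = options.asMap();
    System.out.println(asMap); // {table=dataset.table}
  }
}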
