spark, bigquery: support query option on table read (#2556)
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
mobuchowski committed Apr 3, 2024
1 parent 0a35048 commit 9cb16d0
Showing 11 changed files with 222 additions and 82 deletions.
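
For context: the Spark BigQuery connector can read through a SQL query instead of a table name, and this commit teaches the OpenLineage Spark integration to resolve the tables referenced in that query into input datasets. Below is a minimal sketch of such a read; the SparkSession setup and the project, dataset, and table names are hypothetical, and the options mirror the new integration test further down.

// Minimal sketch of a query-based BigQuery read; mirrors the options used by the new
// testReadAndWriteFromBigqueryUsingQuery test. Project, dataset, and table names are hypothetical.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

class BigQueryQueryReadSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("bigquery-query-read").getOrCreate();
    Dataset<Row> result =
        spark
            .read()
            .format("bigquery")
            .option("viewsEnabled", "true")                     // the connector expects these for query-based reads
            .option("viewMaterializationProject", "my_project") // hypothetical
            .option("viewMaterializationDataset", "my_dataset") // hypothetical
            .option("query", "SELECT * FROM my_project.my_dataset.source_table")
            .load();
    result.show();
    // With this change, OpenLineage reports my_project.my_dataset.source_table
    // as the input dataset of this read.
    spark.stop();
  }
}
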
2 changes: 1 addition & 1 deletion integration/spark/app/build.gradle
@@ -22,7 +22,7 @@ ext {
activeRuntimeElementsConfiguration = "scala" + scala.replace(".", "") + "RuntimeElements"

assertjVersion = '3.25.1'
bigqueryVersion = '0.29.0'
bigqueryVersion = '0.35.1'
junit5Version = '5.10.1'
mockitoVersion = '4.11.0'
postgresqlVersion = '42.7.1'
@@ -0,0 +1,14 @@
{
"eventType": "COMPLETE",
"job": {
"namespace": "{NAMESPACE}"
},
"inputs": [{
"namespace": "bigquery",
"name": "{PROJECT_ID}.{DATASET_ID}.{SPARK_VERSION}_{SCALA_VERSION}_source_query_test"
}],
"outputs": [{
"namespace": "bigquery",
"name": "{PROJECT_ID}.{DATASET_ID}.{SPARK_VERSION}_{SCALA_VERSION}_target_query_test"
}]
}
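
The placeholders in this expected event are substituted by the test's replacement map. As a small illustration, here is how the input dataset name above would expand; the project ID is hypothetical, while the dataset ID and version suffix follow the constants defined in GoogleCloudIntegrationTest.

// Sketch of expanding the templated dataset name; the project ID is hypothetical.
class ExpectedEventNameSketch {
  public static void main(String[] args) {
    String projectId = "my_gcp_project";          // hypothetical {PROJECT_ID}
    String datasetId = "airflow_integration";     // {DATASET_ID} used by the test
    String versionName = "spark_3_5_scala_2_12";  // {SPARK_VERSION}_{SCALA_VERSION}, e.g. Spark 3.5 / Scala 2.12
    System.out.println(
        String.format("%s.%s.%s_source_query_test", projectId, datasetId, versionName));
    // => my_gcp_project.airflow_integration.spark_3_5_scala_2_12_source_query_test
  }
}
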
@@ -6,6 +6,7 @@
package com.google.cloud.bigquery.connector.common;

import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.BigQuery;
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.QueryJobConfiguration;
import com.google.cloud.spark.bigquery.repackaged.com.google.inject.Binder;
import com.google.cloud.spark.bigquery.repackaged.com.google.inject.Module;
import com.google.cloud.spark.bigquery.repackaged.com.google.inject.Provides;
@@ -32,6 +33,12 @@ public BigQueryCredentialsSupplier provideBigQueryCredentialsSupplier(BigQueryCo
Optional.empty(),
Optional.empty(),
Optional.empty(),
"",
Collections.emptySet(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty());
}
@@ -44,6 +51,8 @@ public BigQueryClient provideBigQueryClient() {
Optional.of("materializationProject"),
Optional.of("materializationDataset"),
null,
Collections.emptyMap());
Collections.emptyMap(),
QueryJobConfiguration.Priority.BATCH,
Optional.empty());
}
}
@@ -5,6 +5,7 @@

package io.openlineage.spark.agent;

import static io.openlineage.spark.agent.MockServerUtils.getEventsEmitted;
import static io.openlineage.spark.agent.MockServerUtils.verifyEvents;
import static org.assertj.core.api.Assertions.assertThat;

@@ -58,8 +59,16 @@ public class GoogleCloudIntegrationTest {
private static final String LOCAL_IP = "127.0.0.1";
private static final String SPARK_3 = "(3.*)";
private static final String SPARK_3_3 = "(3\\.[3-9].*)";
private static final String SPARK_VERSION = "spark.version";
private static final String SPARK_VERSION_PROPERTY = "spark.version";
private static final String CREDENTIALS_FILE = "build/gcloud/gcloud-service-key.json";

private static final String DATASET_ID = "airflow_integration";
private static final String SPARK_VERSION =
String.format("spark_%s", SparkContainerProperties.SPARK_VERSION).replace(".", "_");
private static final String SCALA_VERSION =
String.format("scala_%s", SparkContainerProperties.SCALA_BINARY_VERSION).replace(".", "_");
private static final String VERSION_NAME = String.format("%s_%s", SPARK_VERSION, SCALA_VERSION);

private static SparkSession spark;
private static ClientAndServer mockServer;

@@ -114,30 +123,16 @@ public void beforeEach() {
}

@Test
@EnabledIfSystemProperty(named = SPARK_VERSION, matches = SPARK_3_3) // Spark version >= 3.*
@EnabledIfSystemProperty(
named = SPARK_VERSION_PROPERTY,
matches = SPARK_3_3) // Spark version >= 3.3
void testReadAndWriteFromBigquery() {
String DATASET_ID = "airflow_integration";
String sparkVersion =
String.format("spark_%s", SparkContainerProperties.SPARK_VERSION).replace(".", "_");
String scalaVersion =
String.format("scala_%s", SparkContainerProperties.SCALA_BINARY_VERSION).replace(".", "_");
String versionName = String.format("%s_%s", sparkVersion, scalaVersion);
String source_table = String.format("%s.%s.%s_source", PROJECT_ID, DATASET_ID, versionName);
String target_table = String.format("%s.%s.%s_target", PROJECT_ID, DATASET_ID, versionName);
String source_table = String.format("%s.%s.%s_source", PROJECT_ID, DATASET_ID, VERSION_NAME);
String target_table = String.format("%s.%s.%s_target", PROJECT_ID, DATASET_ID, VERSION_NAME);
log.info("Source Table: {}", source_table);
log.info("Target Table: {}", target_table);

Dataset<Row> dataset =
spark
.createDataFrame(
ImmutableList.of(RowFactory.create(1L, 2L), RowFactory.create(3L, 4L)),
new StructType(
new StructField[] {
new StructField("a", LongType$.MODULE$, false, Metadata.empty()),
new StructField("b", LongType$.MODULE$, false, Metadata.empty())
}))
.repartition(1);

Dataset<Row> dataset = getTestDataset();
dataset.write().format("bigquery").option("table", source_table).mode("overwrite").save();

Dataset<Row> first = spark.read().format("bigquery").option("table", source_table).load();
@@ -149,8 +144,8 @@ void testReadAndWriteFromBigquery() {
replacements.put("{PROJECT_ID}", PROJECT_ID);
replacements.put("{DATASET_ID}", DATASET_ID);
replacements.put("{BUCKET_NAME}", BUCKET_NAME);
replacements.put("{SPARK_VERSION}", sparkVersion);
replacements.put("{SCALA_VERSION}", scalaVersion);
replacements.put("{SPARK_VERSION}", SPARK_VERSION);
replacements.put("{SCALA_VERSION}", SCALA_VERSION);

if (log.isDebugEnabled()) {
logRunEvents();
@@ -165,6 +160,51 @@
"pysparkBigquerySaveEnd.json");
}

@Test
@EnabledIfSystemProperty(
named = SPARK_VERSION_PROPERTY,
matches = SPARK_3_3) // Spark version >= 3.3
void testReadAndWriteFromBigqueryUsingQuery() {
String source_table =
String.format("%s.%s.%s_source_query_test", PROJECT_ID, DATASET_ID, VERSION_NAME);
String target_table =
String.format("%s.%s.%s_target_query_test", PROJECT_ID, DATASET_ID, VERSION_NAME);
String source_query = String.format("SELECT * FROM %s", source_table);
log.info("Source Query: {}", source_query);
log.info("Target Table: {}", target_table);

Dataset<Row> dataset = getTestDataset();
dataset.write().format("bigquery").option("table", source_table).mode("overwrite").save();

Dataset<Row> first =
spark
.read()
.format("bigquery")
.option("viewMaterializationProject", PROJECT_ID)
.option("viewMaterializationDataset", DATASET_ID)
.option("viewsEnabled", "true")
.option("query", source_query)
.load();

first.write().format("bigquery").option("table", target_table).mode("overwrite").save();

HashMap<String, String> replacements = new HashMap<>();
replacements.put("{NAMESPACE}", NAMESPACE);
replacements.put("{PROJECT_ID}", PROJECT_ID);
replacements.put("{DATASET_ID}", DATASET_ID);
replacements.put("{BUCKET_NAME}", BUCKET_NAME);
replacements.put("{SPARK_VERSION}", SPARK_VERSION);
replacements.put("{SCALA_VERSION}", SCALA_VERSION);

if (log.isDebugEnabled()) {
logRunEvents();
}

List<RunEvent> events = getEventsEmitted(mockServer);

verifyEvents(mockServer, replacements, "pysparkBigqueryQueryEnd.json");
}

private static void logRunEvents() {
List<RunEvent> eventsEmitted = MockServerUtils.getEventsEmitted(mockServer);
ObjectMapper om = new ObjectMapper().findAndRegisterModules();
@@ -181,7 +221,9 @@ private static void logRunEvents() {
}

@Test
@EnabledIfSystemProperty(named = SPARK_VERSION, matches = SPARK_3) // Spark version >= 3.*
@EnabledIfSystemProperty(
named = SPARK_VERSION_PROPERTY,
matches = SPARK_3) // Spark version >= 3.*
void testRddWriteToBucket() throws IOException {
String sparkVersion = String.format("spark-%s", SparkContainerProperties.SPARK_VERSION);
String scalaVersion = String.format("scala-%s", SparkContainerProperties.SCALA_BINARY_VERSION);
@@ -233,4 +275,16 @@ void testRddWriteToBucket() throws IOException {
.hasFieldOrPropertyWithValue(
"namespace", BUCKET_URI.getScheme() + "://" + BUCKET_URI.getHost());
}

private static Dataset<Row> getTestDataset() {
return spark
.createDataFrame(
ImmutableList.of(RowFactory.create(1L, 2L), RowFactory.create(3L, 4L)),
new StructType(
new StructField[] {
new StructField("a", LongType$.MODULE$, false, Metadata.empty()),
new StructField("b", LongType$.MODULE$, false, Metadata.empty())
}))
.repartition(1);
}
}
@@ -98,7 +98,7 @@ static void verifyEvents(
Path eventFolder = Paths.get("integrations/container/");

await()
.atMost(Duration.ofSeconds(20))
.atMost(Duration.ofSeconds(30))
.untilAsserted(
() ->
mockServerClient.verify(
2 changes: 1 addition & 1 deletion integration/spark/shared/build.gradle
@@ -23,7 +23,7 @@ idea {
ext {
assertjVersion = "3.25.1"
awaitilityVersion = "4.2.0"
bigqueryVersion = "0.29.0"
bigqueryVersion = "0.35.1"
databricksVersion = "0.1.4"
junit5Version = "5.10.1"
kafkaClientsVersion = "3.6.1"
@@ -7,7 +7,10 @@

import com.google.cloud.spark.bigquery.BigQueryRelation;
import com.google.cloud.spark.bigquery.BigQueryRelationProvider;
import com.google.cloud.spark.bigquery.SparkBigQueryConfig;
import com.google.cloud.spark.bigquery.direct.DirectBigQueryRelation;
import io.openlineage.client.OpenLineage;
import io.openlineage.spark.agent.util.SqlUtils;
import io.openlineage.spark.api.DatasetFactory;
import io.openlineage.spark.api.OpenLineageContext;
import io.openlineage.spark.api.QueryPlanVisitor;
@@ -18,6 +21,7 @@
import java.util.Optional;
import java.util.function.Supplier;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.reflect.FieldUtils;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.execution.datasources.LogicalRelation;
@@ -63,15 +67,33 @@ private Optional<Supplier<BigQueryRelation>> bigQuerySupplier(LogicalPlan plan)
public List<OpenLineage.InputDataset> apply(LogicalPlan plan) {
return bigQuerySupplier(plan)
.map(s -> s.get())
.filter(relation -> getBigQueryTableName(relation).isPresent())
.map(
relation ->
Collections.singletonList(
factory.getDataset(
getBigQueryTableName(relation).get(),
BIGQUERY_NAMESPACE,
relation.schema())))
.orElse(null);
relation -> {
if (relation instanceof DirectBigQueryRelation) {
List<OpenLineage.InputDataset> datasets =
tryGetFromQuery((DirectBigQueryRelation) relation);
if (!datasets.isEmpty()) {
return datasets;
}
}
return Collections.singletonList(
factory.getDataset(
getBigQueryTableName(relation).get(), BIGQUERY_NAMESPACE, relation.schema()));
})
.orElse(Collections.emptyList());
}

private List<OpenLineage.InputDataset> tryGetFromQuery(DirectBigQueryRelation relation) {
try {
SparkBigQueryConfig config =
(SparkBigQueryConfig) FieldUtils.readField(relation, "options", true);
if (config.getQuery().isPresent()) {
return SqlUtils.getDatasets(factory, config.getQuery().get(), "bigquery", "bigquery");
}
} catch (IllegalAccessException | IllegalArgumentException | NullPointerException e) {
log.error("Could not invoke method", e);
}
return Collections.emptyList();
}

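
A note on the resolution step above: tryGetFromQuery passes the SQL from the query option to SqlUtils.getDatasets, which is expected to parse it and map each referenced table to an input dataset in the bigquery namespace. The sketch below illustrates that idea; it assumes the openlineage-sql Java binding exposes OpenLineageSql.parse(List<String>, String dialect) returning Optional<SqlMeta>, since only SqlMeta and DbTableMeta appear in this diff, and the table name is hypothetical.

// Sketch of resolving the tables referenced by a query string, roughly the job that
// SqlUtils.getDatasets performs for the new query-based read path. The OpenLineageSql.parse
// call is an assumption; only SqlMeta and DbTableMeta appear in this diff.
import io.openlineage.sql.DbTableMeta;
import io.openlineage.sql.OpenLineageSql;
import io.openlineage.sql.SqlMeta;
import java.util.Collections;
import java.util.Optional;

class QueryToInputDatasetsSketch {
  public static void main(String[] args) {
    String query = "SELECT a, b FROM my_project.my_dataset.source_query_test"; // hypothetical
    Optional<SqlMeta> meta = OpenLineageSql.parse(Collections.singletonList(query), "bigquery");
    meta.ifPresent(
        m -> {
          for (DbTableMeta table : m.inTables()) {
            // Each referenced table should surface as an input dataset in the "bigquery"
            // namespace, matching the expected-event JSON added in this commit.
            System.out.println("bigquery " + table.qualifiedName());
          }
        });
  }
}
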
@@ -6,22 +6,15 @@
package io.openlineage.spark.agent.lifecycle.plan.handlers;

import io.openlineage.client.OpenLineage;
import io.openlineage.client.utils.DatasetIdentifier;
import io.openlineage.client.utils.JdbcUtils;
import io.openlineage.spark.agent.util.JdbcSparkUtils;
import io.openlineage.spark.api.DatasetFactory;
import io.openlineage.sql.ColumnMeta;
import io.openlineage.sql.DbTableMeta;
import io.openlineage.sql.SqlMeta;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.sql.execution.datasources.LogicalRelation;
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

@Slf4j
public class JdbcRelationHandler<D extends OpenLineage.Dataset> {
@@ -50,43 +43,6 @@ public List<D> getDatasets(JDBCRelation relation, String url) {
if (!sqlMeta.isPresent()) {
return Collections.emptyList();
}
if (sqlMeta.get().columnLineage().isEmpty()) {
DatasetIdentifier di =
JdbcUtils.getDatasetIdentifierFromJdbcUrl(
url, sqlMeta.get().inTables().get(0).qualifiedName());
return Collections.singletonList(
datasetFactory.getDataset(di.getName(), di.getNamespace(), relation.schema()));
}
return sqlMeta.get().inTables().stream()
.map(
dbtm -> {
DatasetIdentifier di =
JdbcUtils.getDatasetIdentifierFromJdbcUrl(url, dbtm.qualifiedName());
return datasetFactory.getDataset(
di.getName(),
di.getNamespace(),
generateJDBCSchema(dbtm, relation.schema(), sqlMeta.get()));
})
.collect(Collectors.toList());
}

private static StructType generateJDBCSchema(
DbTableMeta origin, StructType schema, SqlMeta sqlMeta) {
StructType originSchema = new StructType();
for (StructField f : schema.fields()) {
List<ColumnMeta> fields =
sqlMeta.columnLineage().stream()
.filter(cl -> cl.descendant().name().equals(f.name()))
.flatMap(
cl ->
cl.lineage().stream()
.filter(
cm -> cm.origin().isPresent() && cm.origin().get().equals(origin)))
.collect(Collectors.toList());
for (ColumnMeta cm : fields) {
originSchema = originSchema.add(cm.name(), f.dataType());
}
}
return originSchema;
return JdbcSparkUtils.getDatasets(datasetFactory, sqlMeta.get(), relation.schema(), url);
}
}
