From 483d3641befba7ea654bf92267478776f845199f Mon Sep 17 00:00:00 2001 From: Pawel Leszczynski Date: Fri, 13 Oct 2023 08:29:27 +0200 Subject: [PATCH] [SPARK] write scala integration test Signed-off-by: Pawel Leszczynski --- CHANGELOG.md | 4 + .../spark/app/integrations/sparkrdd/1.json | 6 +- .../spark/app/integrations/sparkrdd/2.json | 6 +- .../agent/lifecycle/RddExecutionContext.java | 25 ++- .../spark/agent/SparkContainerUtils.java | 4 +- .../spark/agent/SparkScalaContainerTest.java | 173 ++++++++++++++++++ .../spark/agent/lifecycle/LibraryTest.java | 6 +- .../lifecycle/SparkReadWriteIntegTest.java | 14 +- .../app/src/test/resources/log4j.properties | 3 +- .../spark_scala_scripts/rdd_union.scala | 33 ++++ .../lifecycle/plan/ExternalRDDVisitor.java | 5 + .../spark/agent/util/PlanUtils.java | 37 +--- .../spark/agent/util/RddPathUtils.java | 157 ++++++++++++++++ .../spark/agent/util/RddPathUtilsTest.java | 109 +++++++++++ .../plan/column/InputFieldsCollector.java | 5 +- 15 files changed, 531 insertions(+), 56 deletions(-) create mode 100644 integration/spark/app/src/test/java/io/openlineage/spark/agent/SparkScalaContainerTest.java create mode 100644 integration/spark/app/src/test/resources/spark_scala_scripts/rdd_union.scala create mode 100644 integration/spark/shared/src/main/java/io/openlineage/spark/agent/util/RddPathUtils.java create mode 100644 integration/spark/shared/src/test/java/io/openlineage/spark/agent/util/RddPathUtilsTest.java diff --git a/CHANGELOG.md b/CHANGELOG.md index b5d65d64e4..4ff4398699 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## [Unreleased](https://github.com/OpenLineage/OpenLineage/compare/1.4.1...HEAD) +### Added +* **Spark: support `rdd` and `toDF` operations available in Spark Scala API.** [`#2188`](https://github.com/OpenLineage/OpenLineage/pull/2188) [@pawel-big-lebowski](https://github.com/pawel-big-lebowski) + *This PR includes the first Scala integration test, fixes `ExternalRddVisitor` and adds support for extracting inputs from `MapPartitionsRDD` and `ParallelCollectionRDD` plan nodes.* + ### Fixed * **Spark: unify dataset naming for RDD jobs and Spark SQL.** [`2181`](https://github.com/OpenLineage/OpenLineage/pull/2181) [@pawel-big-lebowski](https://github.com/pawel-big-lebowski) *Use the same mechanism for RDD jobs to extract dataset identifier as used for Spark SQL.* diff --git a/integration/spark/app/integrations/sparkrdd/1.json b/integration/spark/app/integrations/sparkrdd/1.json index 72ee0d7466..331fc2fc51 100644 --- a/integration/spark/app/integrations/sparkrdd/1.json +++ b/integration/spark/app/integrations/sparkrdd/1.json @@ -17,11 +17,13 @@ }, "job" : { "namespace" : "ns_name", - "name" : "test_rdd.shuffled_map_partitions_hadoop" + "name" : "test_rdd.map_partitions_shuffled_map_partitions_hadoop" }, "inputs" : [ { "namespace" : "gs.bucket", "name" : "gs://bucket/data.txt" } ], - "outputs" : [ ] + "outputs" : [ { + "namespace" : "file" + }] } \ No newline at end of file diff --git a/integration/spark/app/integrations/sparkrdd/2.json b/integration/spark/app/integrations/sparkrdd/2.json index 4e42272178..186cec18a2 100644 --- a/integration/spark/app/integrations/sparkrdd/2.json +++ b/integration/spark/app/integrations/sparkrdd/2.json @@ -18,11 +18,13 @@ }, "job" : { "namespace" : "ns_name", - "name" : "test_rdd.shuffled_map_partitions_hadoop" + "name" : "test_rdd.map_partitions_shuffled_map_partitions_hadoop" }, "inputs" : [ { "namespace" : "gs.bucket", "name" : "gs://bucket/data.txt" } ], - "outputs" : [ ] + "outputs" : [ { + 
"namespace" : "file" + }] } \ No newline at end of file diff --git a/integration/spark/app/src/main/java/io/openlineage/spark/agent/lifecycle/RddExecutionContext.java b/integration/spark/app/src/main/java/io/openlineage/spark/agent/lifecycle/RddExecutionContext.java index aeac9dd6d4..5e2ab03505 100644 --- a/integration/spark/app/src/main/java/io/openlineage/spark/agent/lifecycle/RddExecutionContext.java +++ b/integration/spark/app/src/main/java/io/openlineage/spark/agent/lifecycle/RddExecutionContext.java @@ -99,9 +99,11 @@ public void end(SparkListenerStageCompleted stageCompleted) {} @Override @SuppressWarnings("PMD") // f.setAccessible(true); public void setActiveJob(ActiveJob activeJob) { + log.debug("setActiveJob within RddExecutionContext {}", activeJob); RDD finalRDD = activeJob.finalStage().rdd(); this.jobSuffix = nameRDD(finalRDD); Set> rdds = Rdds.flattenRDDs(finalRDD); + log.debug("flattenRDDs {}", rdds); this.inputs = findInputs(rdds); Configuration jc = new JobConf(); if (activeJob.finalStage() instanceof ResultStage) { @@ -197,17 +199,23 @@ static String nameRDD(RDD rdd) { @Override public void start(SparkListenerSQLExecutionStart sqlStart) { // do nothing + log.debug("start SparkListenerSQLExecutionStart {}", sqlStart); } @Override public void end(SparkListenerSQLExecutionEnd sqlEnd) { // do nothing + log.debug("start SparkListenerSQLExecutionEnd {}", sqlEnd); } @Override public void start(SparkListenerJobStart jobStart) { - if (inputs.isEmpty() && outputs.isEmpty()) { - log.info("RDDs are empty: skipping sending OpenLineage event"); + log.debug("start SparkListenerJobStart {}", jobStart); + if (outputs.isEmpty()) { + // Oftentimes SparkListener is triggered for actions which do not contain any meaningful + // lineage data and are useless in the context of lineage graph. We assume this occurs + // for RDD operations which have no output dataset + log.info("Output RDDs are empty: skipping sending OpenLineage event"); return; } OpenLineage ol = new OpenLineage(Versions.OPEN_LINEAGE_PRODUCER_URI); @@ -227,8 +235,12 @@ public void start(SparkListenerJobStart jobStart) { @Override public void end(SparkListenerJobEnd jobEnd) { - if (inputs.isEmpty() && outputs.isEmpty() && !(jobEnd.jobResult() instanceof JobFailed)) { - log.info("RDDs are empty: skipping sending OpenLineage event"); + log.debug("end SparkListenerJobEnd {}", jobEnd); + if (outputs.isEmpty() && !(jobEnd.jobResult() instanceof JobFailed)) { + // Oftentimes SparkListener is triggered for actions which do not contain any meaningful + // lineage data and are useless in the context of lineage graph. 
We assume this occurs + // for RDD operations which have no output dataset + log.info("Output RDDs are empty: skipping sending OpenLineage event"); return; } OpenLineage ol = new OpenLineage(Versions.OPEN_LINEAGE_PRODUCER_URI); @@ -346,11 +358,12 @@ protected List findOutputs(RDD rdd, Configuration config) { if (outputPath != null) { return Collections.singletonList(outputPath.toUri()); } + log.debug("Output path is null"); return Collections.emptyList(); } protected List findInputs(Set> rdds) { - log.debug("findInputs within RddExecutionContext"); + log.debug("find Inputs within RddExecutionContext {}", rdds); return PlanUtils.findRDDPaths(rdds.stream().collect(Collectors.toList())).stream() .map(path -> path.toUri()) .collect(Collectors.toList()); @@ -373,10 +386,12 @@ protected static Path getOutputPath(RDD rdd, Configuration config) { } else { jc = new JobConf(config); } + log.debug("JobConf {}", jc); path = org.apache.hadoop.mapred.FileOutputFormat.getOutputPath(jc); if (path == null) { try { // old fashioned mapreduce api + log.debug("Path is null, trying to use old fashioned mapreduce api"); path = org.apache.hadoop.mapreduce.lib.output.FileOutputFormat.getOutputPath(new Job(jc)); } catch (IOException exception) { exception.printStackTrace(System.out); diff --git a/integration/spark/app/src/test/java/io/openlineage/spark/agent/SparkContainerUtils.java b/integration/spark/app/src/test/java/io/openlineage/spark/agent/SparkContainerUtils.java index 64e8cd6839..99e3a54c75 100644 --- a/integration/spark/app/src/test/java/io/openlineage/spark/agent/SparkContainerUtils.java +++ b/integration/spark/app/src/test/java/io/openlineage/spark/agent/SparkContainerUtils.java @@ -164,7 +164,7 @@ static GenericContainer makePysparkContainerWithDefaultConf( network, waitMessage, mockServerContainer, sparkSubmit.toArray(new String[0])); } - static void addSparkConfig(List command, String value) { + public static void addSparkConfig(List command, String value) { command.add("--conf"); command.add(value); } @@ -201,7 +201,7 @@ static void runPysparkContainerWithDefaultConf( } @SuppressWarnings("PMD") - private static void consumeOutput(org.testcontainers.containers.output.OutputFrame of) { + static void consumeOutput(org.testcontainers.containers.output.OutputFrame of) { try { switch (of.getType()) { case STDOUT: diff --git a/integration/spark/app/src/test/java/io/openlineage/spark/agent/SparkScalaContainerTest.java b/integration/spark/app/src/test/java/io/openlineage/spark/agent/SparkScalaContainerTest.java new file mode 100644 index 0000000000..da335f4fb5 --- /dev/null +++ b/integration/spark/app/src/test/java/io/openlineage/spark/agent/SparkScalaContainerTest.java @@ -0,0 +1,173 @@ +/* +/* Copyright 2018-2023 contributors to the OpenLineage project +/* SPDX-License-Identifier: Apache-2.0 +*/ + +package io.openlineage.spark.agent; + +import static io.openlineage.spark.agent.SparkContainerUtils.addSparkConfig; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.mockserver.model.HttpRequest.request; + +import io.openlineage.client.OpenLineage; +import io.openlineage.client.OpenLineage.RunEvent; +import io.openlineage.client.OpenLineageClientUtils; +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.awaitility.Awaitility; +import 
org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.mockserver.client.MockServerClient;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.containers.MockServerContainer;
+import org.testcontainers.containers.Network;
+import org.testcontainers.containers.wait.strategy.Wait;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.utility.DockerImageName;
+
+/**
+ * This class runs an integration test for a Spark job written in Scala; all the other tests run
+ * PySpark scripts instead. Having a Scala job allows us to test the `toDF`/`rdd` methods, which
+ * behave slightly differently for Spark jobs written in Scala.
+ *
+ * The integration test relies on the bitnami/spark Docker image. It requires the `spark.version`
+ * property to specify which Spark version should be tested, as well as the `openlineage.spark.jar`
+ * system property, which is set in `build.gradle`.
+ *
+ * @see <a href="https://hub.docker.com/r/bitnami/spark/">bitnami/spark on Docker Hub</a>
+ */
+@Tag("integration-test")
+@Testcontainers
+@Slf4j
+public class SparkScalaContainerTest {
+
+  private static final Network network = Network.newNetwork();
+
+  @Container
+  private static final MockServerContainer openLineageClientMockContainer =
+      SparkContainerUtils.makeMockServerContainer(network);
+
+  private static GenericContainer<?> spark;
+  private static MockServerClient mockServerClient;
+  private static final Logger logger = LoggerFactory.getLogger(SparkScalaContainerTest.class);
+
+  @BeforeAll
+  public static void setup() {
+    mockServerClient =
+        new MockServerClient(
+            openLineageClientMockContainer.getHost(),
+            openLineageClientMockContainer.getServerPort());
+    mockServerClient
+        .when(request("/api/v1/lineage"))
+        .respond(org.mockserver.model.HttpResponse.response().withStatusCode(201));
+
+    Awaitility.await().until(openLineageClientMockContainer::isRunning);
+  }
+
+  @AfterEach
+  public void cleanupSpark() {
+    mockServerClient.reset();
+    try {
+      if (spark != null) spark.stop();
+    } catch (Exception e) {
+      logger.error("Unable to shut down Spark container", e);
+    }
+  }
+
+  @AfterAll
+  public static void tearDown() {
+    try {
+      openLineageClientMockContainer.stop();
+    } catch (Exception e) {
+      logger.error("Unable to shut down openlineage client container", e);
+    }
+    network.close();
+  }
+
+  private GenericContainer<?> createSparkContainer(String script) {
+    return new GenericContainer<>(
+            DockerImageName.parse("bitnami/spark:" + System.getProperty("spark.version")))
+        .withNetwork(network)
+        .withNetworkAliases("spark")
+        .withFileSystemBind("src/test/resources/spark_scala_scripts", "/opt/spark_scala_scripts")
+        .withFileSystemBind("src/test/resources/log4j.properties", "/opt/log4j.properties")
+        .withFileSystemBind("build/libs", "/opt/libs")
+        .withLogConsumer(SparkContainerUtils::consumeOutput)
+        .waitingFor(Wait.forLogMessage(".*scala> :quit.*", 1))
+        .withStartupTimeout(Duration.of(10, ChronoUnit.MINUTES))
+        .dependsOn(openLineageClientMockContainer)
+        .withReuse(true)
+        .withCommand(
+            sparkShellCommandForScript("/opt/spark_scala_scripts/" + script)
+                .toArray(new String[] {}));
+  }
+
+  private List<String> sparkShellCommandForScript(String script) {
+    List<String> command = new ArrayList<>();
+    addSparkConfig(command, "spark.openlineage.transport.type=http");
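+    // The --conf entries below point the OpenLineage HTTP transport at the mock server
+    // (openlineageclient:1080), register OpenLineageSparkListener, and move the warehouse,
+    // Derby and Ivy directories to writable locations under /tmp inside the container.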
addSparkConfig( + command, + "spark.openlineage.transport.url=http://openlineageclient:1080/api/v1/namespaces/scala-test"); + addSparkConfig(command, "spark.openlineage.debugFacet=enabled"); + addSparkConfig(command, "spark.extraListeners=" + OpenLineageSparkListener.class.getName()); + addSparkConfig(command, "spark.sql.warehouse.dir=/tmp/warehouse"); + addSparkConfig(command, "spark.sql.shuffle.partitions=1"); + addSparkConfig(command, "spark.driver.extraJavaOptions=-Dderby.system.home=/tmp/derby"); + addSparkConfig(command, "spark.sql.warehouse.dir=/tmp/warehouse"); + addSparkConfig(command, "spark.jars.ivy=/tmp/.ivy2/"); + addSparkConfig(command, "spark.openlineage.facets.disabled="); + addSparkConfig( + command, "spark.driver.extraJavaOptions=-Dlog4j.configuration=/opt/log4j.properties"); + + List sparkShell = + new ArrayList(Arrays.asList("./bin/spark-shell", "--master", "local", "-i", script)); + sparkShell.addAll(command); + sparkShell.addAll( + Arrays.asList("--jars", "/opt/libs/" + System.getProperty("openlineage.spark.jar"))); + + log.info("Running spark-shell command: ", String.join(" ", sparkShell)); + + return sparkShell; + } + + @Test + void testScalaUnionRddToParquet() { + spark = createSparkContainer("rdd_union.scala"); + spark.start(); + + await() + .atMost(Duration.ofSeconds(10)) + .pollInterval(Duration.ofMillis(500)) + .untilAsserted( + () -> { + List events = + Arrays.stream( + mockServerClient.retrieveRecordedRequests( + request().withPath("/api/v1/lineage"))) + .map(r -> r.getBodyAsString()) + .map(event -> OpenLineageClientUtils.runEventFromJson(event)) + .collect(Collectors.toList()); + RunEvent lastEvent = events.get(events.size() - 1); + + assertThat(events).isNotEmpty(); + assertThat(lastEvent.getOutputs().get(0)) + .hasFieldOrPropertyWithValue("namespace", "file") + .hasFieldOrPropertyWithValue("name", "/tmp/scala-test/rdd_output"); + + assertThat(lastEvent.getInputs().stream().map(d -> d.getName())) + .contains("/tmp/scala-test/rdd_input1", "/tmp/scala-test/rdd_input2"); + }); + } +} diff --git a/integration/spark/app/src/test/java/io/openlineage/spark/agent/lifecycle/LibraryTest.java b/integration/spark/app/src/test/java/io/openlineage/spark/agent/lifecycle/LibraryTest.java index 368c9354eb..59a826241f 100644 --- a/integration/spark/app/src/test/java/io/openlineage/spark/agent/lifecycle/LibraryTest.java +++ b/integration/spark/app/src/test/java/io/openlineage/spark/agent/lifecycle/LibraryTest.java @@ -20,6 +20,7 @@ import io.openlineage.spark.agent.SparkAgentTestExtension; import java.io.IOException; import java.net.URL; +import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; @@ -36,6 +37,7 @@ import org.apache.spark.sql.SparkSession; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; import scala.Tuple2; @@ -95,7 +97,7 @@ class LibraryTest { // } @Test - void testRdd(SparkSession spark) throws IOException { + void testRdd(@TempDir Path tmpDir, SparkSession spark) throws IOException { when(SparkAgentTestExtension.OPEN_LINEAGE_SPARK_CONTEXT.getJobNamespace()) .thenReturn("ns_name"); when(SparkAgentTestExtension.OPEN_LINEAGE_SPARK_CONTEXT.getParentJobName()) @@ -111,7 +113,7 @@ void testRdd(SparkSession spark) throws IOException { .flatMap(s -> Arrays.asList(s.split(" ")).iterator()) .mapToPair(word -> new Tuple2<>(word, 1)) .reduceByKey(Integer::sum) - .count(); + 
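+        // Writing the result out (instead of calling count()) gives the job an output dataset;
+        // RddExecutionContext now skips runs with empty outputs, so the test needs one to emit events.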
.saveAsTextFile(tmpDir.toString() + "/output"); sc.stop(); diff --git a/integration/spark/app/src/test/java/io/openlineage/spark/agent/lifecycle/SparkReadWriteIntegTest.java b/integration/spark/app/src/test/java/io/openlineage/spark/agent/lifecycle/SparkReadWriteIntegTest.java index 2bf0a83f25..6805e7ceda 100644 --- a/integration/spark/app/src/test/java/io/openlineage/spark/agent/lifecycle/SparkReadWriteIntegTest.java +++ b/integration/spark/app/src/test/java/io/openlineage/spark/agent/lifecycle/SparkReadWriteIntegTest.java @@ -364,7 +364,7 @@ void testInsertIntoDataSourceDirVisitor(@TempDir Path tempDir, SparkSession spar List inputs = event.getInputs(); assertEquals(1, inputs.size()); assertEquals(FILE, inputs.get(0).getNamespace()); - assertEquals(testFile.toAbsolutePath().getParent().toString(), inputs.get(0).getName()); + assertEquals(testFile.toAbsolutePath().toString(), inputs.get(0).getName()); } @Test @@ -383,9 +383,9 @@ void testWithExternalRdd(@TempDir Path tmpDir, SparkSession spark) ArgumentCaptor lineageEvent = ArgumentCaptor.forClass(OpenLineage.RunEvent.class); - Mockito.verify(SparkAgentTestExtension.OPEN_LINEAGE_SPARK_CONTEXT, times(6)) + Mockito.verify(SparkAgentTestExtension.OPEN_LINEAGE_SPARK_CONTEXT, times(4)) .emit(lineageEvent.capture()); - OpenLineage.RunEvent completeEvent = lineageEvent.getAllValues().get(5); + OpenLineage.RunEvent completeEvent = lineageEvent.getAllValues().get(3); assertThat(completeEvent).hasFieldOrPropertyWithValue(EVENT_TYPE, RunEvent.EventType.COMPLETE); assertThat(completeEvent.getInputs()) .first() @@ -481,9 +481,13 @@ void testCreateDataSourceTableAsSelect(@TempDir Path tmpDir, SparkSession spark) ArgumentCaptor lineageEvent = ArgumentCaptor.forClass(OpenLineage.RunEvent.class); - Mockito.verify(SparkAgentTestExtension.OPEN_LINEAGE_SPARK_CONTEXT, atLeast(6)) + Mockito.verify(SparkAgentTestExtension.OPEN_LINEAGE_SPARK_CONTEXT, atLeast(4)) .emit(lineageEvent.capture()); - OpenLineage.RunEvent event = lineageEvent.getAllValues().get(5); + OpenLineage.RunEvent event = + lineageEvent.getAllValues().stream() + .filter(ev -> ev.getInputs() != null && !ev.getInputs().isEmpty()) + .findFirst() + .get(); assertThat(lineageEvent.getAllValues().get(lineageEvent.getAllValues().size() - 1)) .hasFieldOrPropertyWithValue(EVENT_TYPE, RunEvent.EventType.COMPLETE); diff --git a/integration/spark/app/src/test/resources/log4j.properties b/integration/spark/app/src/test/resources/log4j.properties index c9c602b0b2..46e54ab22f 100644 --- a/integration/spark/app/src/test/resources/log4j.properties +++ b/integration/spark/app/src/test/resources/log4j.properties @@ -9,4 +9,5 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: log4j.logger.io.openlineage=DEBUG log4j.logger.io.openlineage.spark.shaded=WARN log4j.logger.org.apache.spark.storage=WARN -log4j.logger.org.apache.spark.scheduler=WARN \ No newline at end of file +log4j.logger.org.apache.spark.scheduler=WARN + diff --git a/integration/spark/app/src/test/resources/spark_scala_scripts/rdd_union.scala b/integration/spark/app/src/test/resources/spark_scala_scripts/rdd_union.scala new file mode 100644 index 0000000000..387d73e3eb --- /dev/null +++ b/integration/spark/app/src/test/resources/spark_scala_scripts/rdd_union.scala @@ -0,0 +1,33 @@ +{ + import spark.implicits._ + import org.apache.spark.sql.SaveMode + + sc + .parallelize(Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + .map(a => a.toString()) + .toDF() + .write + .mode(SaveMode.Overwrite) + .parquet("/tmp/scala-test/rdd_input1") + 
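+  // Second seed dataset, written the same way. Both parquet directories are read back below as
+  // RDDs, unioned, mapped onto the OLC case class and written out, exercising the rdd/toDF round trip.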
+ sc + .parallelize(Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + .map(a => a.toString()) + .toDF() + .write + .mode(SaveMode.Overwrite) + .parquet("/tmp/scala-test/rdd_input2") + + val rdd1 = spark.read.parquet("/tmp/scala-test/rdd_input1").rdd + val rdd2 = spark.read.parquet("/tmp/scala-test/rdd_input2").rdd + + rdd1 + .union(rdd2) + .map(i => OLC(i.toString)) + .toDF() + .write + .mode(SaveMode.Overwrite) + .parquet("/tmp/scala-test/rdd_output") +} + +case class OLC(payload: String) \ No newline at end of file diff --git a/integration/spark/shared/src/main/java/io/openlineage/spark/agent/lifecycle/plan/ExternalRDDVisitor.java b/integration/spark/shared/src/main/java/io/openlineage/spark/agent/lifecycle/plan/ExternalRDDVisitor.java index d0a92a72cb..0dc3a56805 100644 --- a/integration/spark/shared/src/main/java/io/openlineage/spark/agent/lifecycle/plan/ExternalRDDVisitor.java +++ b/integration/spark/shared/src/main/java/io/openlineage/spark/agent/lifecycle/plan/ExternalRDDVisitor.java @@ -21,6 +21,11 @@ public ExternalRDDVisitor(OpenLineageContext context) { super(context, DatasetFactory.input(context)); } + @Override + public boolean isDefinedAt(LogicalPlan x) { + return x instanceof ExternalRDD; + } + @Override public List apply(LogicalPlan x) { ExternalRDD externalRDD = (ExternalRDD) x; diff --git a/integration/spark/shared/src/main/java/io/openlineage/spark/agent/util/PlanUtils.java b/integration/spark/shared/src/main/java/io/openlineage/spark/agent/util/PlanUtils.java index 43048cb06b..0223b088cd 100644 --- a/integration/spark/shared/src/main/java/io/openlineage/spark/agent/util/PlanUtils.java +++ b/integration/spark/shared/src/main/java/io/openlineage/spark/agent/util/PlanUtils.java @@ -10,7 +10,6 @@ import java.io.IOException; import java.net.URI; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -18,16 +17,11 @@ import java.util.Optional; import java.util.UUID; import java.util.stream.Collectors; -import java.util.stream.Stream; import lombok.extern.slf4j.Slf4j; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; -import org.apache.hadoop.mapred.FileInputFormat; -import org.apache.spark.package$; -import org.apache.spark.rdd.HadoopRDD; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.catalyst.expressions.Attribute; -import org.apache.spark.sql.execution.datasources.FileScanRDD; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import scala.PartialFunction; @@ -227,36 +221,7 @@ public static Path getDirectoryPath(Path p, Configuration hadoopConf) { */ public static List findRDDPaths(List> fileRdds) { return fileRdds.stream() - .flatMap( - rdd -> { - if (rdd instanceof HadoopRDD) { - HadoopRDD hadoopRDD = (HadoopRDD) rdd; - Path[] inputPaths = FileInputFormat.getInputPaths(hadoopRDD.getJobConf()); - Configuration hadoopConf = hadoopRDD.getConf(); - return Arrays.stream(inputPaths) - .map(p -> PlanUtils.getDirectoryPath(p, hadoopConf)); - } else if (rdd instanceof FileScanRDD) { - FileScanRDD fileScanRDD = (FileScanRDD) rdd; - return ScalaConversionUtils.fromSeq(fileScanRDD.filePartitions()).stream() - .flatMap(fp -> Arrays.stream(fp.files())) - .map( - f -> { - if (package$.MODULE$.SPARK_VERSION().compareTo("3.4") > 0) { - // filePath returns SparkPath for Spark 3.4 - return ReflectionUtils.tryExecuteMethod(f, "filePath") - .map(o -> ReflectionUtils.tryExecuteMethod(o, "toPath")) - .map(o -> (Path) o.get()) - .get() - 
.getParent(); - } else { - return new Path(f.filePath()).getParent(); - } - }); - } else { - log.warn("Unknown RDD class {}", rdd.getClass().getCanonicalName()); - return Stream.empty(); - } - }) + .flatMap(RddPathUtils::findRDDPaths) .distinct() .collect(Collectors.toList()); } diff --git a/integration/spark/shared/src/main/java/io/openlineage/spark/agent/util/RddPathUtils.java b/integration/spark/shared/src/main/java/io/openlineage/spark/agent/util/RddPathUtils.java new file mode 100644 index 0000000000..5ec3d5dbf7 --- /dev/null +++ b/integration/spark/shared/src/main/java/io/openlineage/spark/agent/util/RddPathUtils.java @@ -0,0 +1,157 @@ +/* +/* Copyright 2018-2023 contributors to the OpenLineage project +/* SPDX-License-Identifier: Apache-2.0 +*/ + +package io.openlineage.spark.agent.util; + +import java.util.Arrays; +import java.util.Objects; +import java.util.stream.Stream; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.reflect.FieldUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapred.FileInputFormat; +import org.apache.spark.package$; +import org.apache.spark.rdd.HadoopRDD; +import org.apache.spark.rdd.MapPartitionsRDD; +import org.apache.spark.rdd.ParallelCollectionRDD; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.execution.datasources.FileScanRDD; +import scala.Tuple2; +import scala.collection.immutable.Seq; + +/** Utility class to extract paths from RDD nodes. */ +@Slf4j +public class RddPathUtils { + + public static Stream findRDDPaths(RDD rdd) { + return Stream.of( + new HadoopRDDExtractor(), + new FileScanRDDExtractor(), + new MapPartitionsRDDExtractor(), + new ParallelCollectionRDDExtractor()) + .filter(e -> e.isDefinedAt(rdd)) + .findFirst() + .orElse(new UnknownRDDExtractor()) + .extract(rdd) + .filter(p -> p != null); + } + + static class UnknownRDDExtractor implements RddPathExtractor { + @Override + public boolean isDefinedAt(Object rdd) { + return true; + } + + @Override + public Stream extract(RDD rdd) { + log.warn("Unknown RDD class {}", rdd); + return Stream.empty(); + } + } + + static class HadoopRDDExtractor implements RddPathExtractor { + @Override + public boolean isDefinedAt(Object rdd) { + return rdd instanceof HadoopRDD; + } + + @Override + public Stream extract(HadoopRDD rdd) { + org.apache.hadoop.fs.Path[] inputPaths = FileInputFormat.getInputPaths(rdd.getJobConf()); + Configuration hadoopConf = rdd.getConf(); + return Arrays.stream(inputPaths).map(p -> PlanUtils.getDirectoryPath(p, hadoopConf)); + } + } + + static class MapPartitionsRDDExtractor implements RddPathExtractor { + + @Override + public boolean isDefinedAt(Object rdd) { + return rdd instanceof MapPartitionsRDD; + } + + @Override + public Stream extract(MapPartitionsRDD rdd) { + return findRDDPaths(rdd.prev()); + } + } + + static class FileScanRDDExtractor implements RddPathExtractor { + @Override + public boolean isDefinedAt(Object rdd) { + return rdd instanceof FileScanRDD; + } + + @Override + public Stream extract(FileScanRDD rdd) { + return ScalaConversionUtils.fromSeq(rdd.filePartitions()).stream() + .flatMap(fp -> Arrays.stream(fp.files())) + .map( + f -> { + if (package$.MODULE$.SPARK_VERSION().compareTo("3.4") > 0) { + // filePath returns SparkPath for Spark 3.4 + return ReflectionUtils.tryExecuteMethod(f, "filePath") + .map(o -> ReflectionUtils.tryExecuteMethod(o, "toPath")) + .map(o -> (Path) o.get()) + .get() + .getParent(); + } else { + return parentOf(f.filePath()); + } + }); + 
} + } + + static class ParallelCollectionRDDExtractor implements RddPathExtractor { + @Override + public boolean isDefinedAt(Object rdd) { + return rdd instanceof ParallelCollectionRDD; + } + + @Override + public Stream extract(ParallelCollectionRDD rdd) { + try { + Object data = FieldUtils.readField(rdd, "data", true); + log.debug("ParallelCollectionRDD data: {} {}", data); + if (data instanceof Seq) { + return ScalaConversionUtils.fromSeq((Seq) data).stream() + .map( + el -> { + Path path = null; + if (el instanceof Tuple2) { + // we're able to extract path + path = parentOf(((Tuple2) el)._1.toString()); + log.debug("Found input {}", path); + } else { + log.warn("unable to extract Path from {}", el.getClass().getCanonicalName()); + } + return path; + }) + .filter(Objects::nonNull); + } else { + log.warn("Cannot extract path from ParallelCollectionRDD {}", data); + } + } catch (IllegalAccessException | IllegalArgumentException e) { + log.warn("Cannot read data field from ParallelCollectionRDD {}", rdd); + } + return Stream.empty(); + } + } + + private static Path parentOf(String path) { + try { + return new Path(path).getParent(); + } catch (Exception e) { + return null; + } + } + + interface RddPathExtractor { + boolean isDefinedAt(Object rdd); + + Stream extract(T rdd); + } +} diff --git a/integration/spark/shared/src/test/java/io/openlineage/spark/agent/util/RddPathUtilsTest.java b/integration/spark/shared/src/test/java/io/openlineage/spark/agent/util/RddPathUtilsTest.java new file mode 100644 index 0000000000..9ec56f5a4f --- /dev/null +++ b/integration/spark/shared/src/test/java/io/openlineage/spark/agent/util/RddPathUtilsTest.java @@ -0,0 +1,109 @@ +/* +/* Copyright 2018-2023 contributors to the OpenLineage project +/* SPDX-License-Identifier: Apache-2.0 +*/ + +package io.openlineage.spark.agent.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.apache.spark.rdd.MapPartitionsRDD; +import org.apache.spark.rdd.ParallelCollectionRDD; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.execution.datasources.FilePartition; +import org.apache.spark.sql.execution.datasources.FileScanRDD; +import org.apache.spark.sql.execution.datasources.PartitionedFile; +import org.junit.jupiter.api.Test; +import org.testcontainers.shaded.org.apache.commons.lang3.reflect.FieldUtils; +import scala.Tuple2; +import scala.collection.JavaConversions; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +public class RddPathUtilsTest { + + @Test + void testFindRDDPathsForMapPartitionsRDD() { + FilePartition filePartition = mock(FilePartition.class); + PartitionedFile partitionedFile = mock(PartitionedFile.class); + FileScanRDD fileScanRDD = mock(FileScanRDD.class); + MapPartitionsRDD mapPartitions = mock(MapPartitionsRDD.class); + + when(mapPartitions.prev()).thenReturn(fileScanRDD); + when(fileScanRDD.filePartitions()) + .thenReturn(JavaConversions.asScalaBuffer(Collections.singletonList(filePartition))); + when(filePartition.files()).thenReturn(new PartitionedFile[] {partitionedFile}); + when(partitionedFile.filePath()).thenReturn("/some-path/sub-path"); + + List rddPaths = PlanUtils.findRDDPaths(Collections.singletonList(mapPartitions)); + + assertThat(rddPaths).hasSize(1); + assertThat(rddPaths.get(0).toString()).isEqualTo("/some-path"); + } + + @Test + void 
testFindRDDPathsForParallelCollectionRDD() throws IllegalAccessException { + ParallelCollectionRDD parallelCollectionRDD = mock(ParallelCollectionRDD.class); + Seq> data = + JavaConverters.asScalaIteratorConverter( + Arrays.asList( + new Tuple2<>("/some-path1/data-file-325342.snappy.parquet", 345), + new Tuple2<>("/some-path2/data-file-654342.snappy.parquet", 345)) + .iterator()) + .asScala() + .toSeq(); + + FieldUtils.writeDeclaredField(parallelCollectionRDD, "data", data, true); + + List rddPaths = PlanUtils.findRDDPaths(Collections.singletonList(parallelCollectionRDD)); + + assertThat(rddPaths).hasSize(2); + assertThat(rddPaths.get(0).toString()).isEqualTo("/some-path1"); + assertThat(rddPaths.get(1).toString()).isEqualTo("/some-path2"); + } + + @Test + void testFindRDDPathsForParallelCollectionRDDWhenNoDataField() throws IllegalAccessException { + ParallelCollectionRDD parallelCollectionRDD = mock(ParallelCollectionRDD.class); + + FieldUtils.writeDeclaredField(parallelCollectionRDD, "data", null, true); + assertThat(PlanUtils.findRDDPaths(Collections.singletonList(parallelCollectionRDD))).hasSize(0); + } + + @Test + void testFindRDDPathsForParallelCollectionRDDWhenDataFieldNotSeqOfTuples() + throws IllegalAccessException { + ParallelCollectionRDD parallelCollectionRDD = mock(ParallelCollectionRDD.class); + Seq data = + JavaConverters.asScalaIteratorConverter(Arrays.asList(333).iterator()).asScala().toSeq(); + + FieldUtils.writeDeclaredField(parallelCollectionRDD, "data", data, true); + assertThat(PlanUtils.findRDDPaths(Collections.singletonList(parallelCollectionRDD))).hasSize(0); + } + + @Test + void testFindRDDPathsEmptyStringPath() { + FilePartition filePartition = mock(FilePartition.class); + FileScanRDD fileScanRDD = mock(FileScanRDD.class); + PartitionedFile partitionedFile = mock(PartitionedFile.class); + + when(filePartition.files()).thenReturn(new PartitionedFile[] {partitionedFile}); + when(partitionedFile.filePath()).thenReturn(""); + when(fileScanRDD.filePartitions()) + .thenReturn(JavaConversions.asScalaBuffer(Collections.singletonList(filePartition))); + + List rddPaths = PlanUtils.findRDDPaths(Collections.singletonList(fileScanRDD)); + + assertThat(rddPaths).hasSize(0); + } + + @Test + void testFindRDDPathsUnknownRdd() { + assertThat(PlanUtils.findRDDPaths(Collections.singletonList(mock(RDD.class)))).isEmpty(); + } +} diff --git a/integration/spark/spark3/src/main/java/io/openlineage/spark3/agent/lifecycle/plan/column/InputFieldsCollector.java b/integration/spark/spark3/src/main/java/io/openlineage/spark3/agent/lifecycle/plan/column/InputFieldsCollector.java index dc8d016f08..12fd623f6d 100644 --- a/integration/spark/spark3/src/main/java/io/openlineage/spark3/agent/lifecycle/plan/column/InputFieldsCollector.java +++ b/integration/spark/spark3/src/main/java/io/openlineage/spark3/agent/lifecycle/plan/column/InputFieldsCollector.java @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation; import org.apache.spark.sql.catalyst.plans.logical.UnaryNode; +import org.apache.spark.sql.execution.ExternalRDD; import org.apache.spark.sql.execution.LogicalRDD; import org.apache.spark.sql.execution.columnar.InMemoryRelation; import org.apache.spark.sql.execution.datasources.HadoopFsRelation; @@ -124,7 +125,9 @@ private static List extractDatasetIdentifier( // implemented in // io.openlineage.spark3.agent.lifecycle.plan.column.ColumnLevelLineageUtils.collectInputsAndExpressionDependencies // 
requires merging multiple LogicalPlans - } else if (node instanceof OneRowRelation || node instanceof LocalRelation) { + } else if (node instanceof OneRowRelation + || node instanceof LocalRelation + || node instanceof ExternalRDD) { // skip without warning } else if (node instanceof LeafNode) { log.warn("Could not extract dataset identifier from {}", node.getClass().getCanonicalName());
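
For context, a minimal spark-shell sketch (not part of the patch) of the RDD chain the new extractors walk; the parquet path is a hypothetical placeholder. Dataset.rdd yields a MapPartitionsRDD whose parent chain ends in a FileScanRDD for file-based reads, which is why MapPartitionsRDDExtractor recurses through prev() until FileScanRDDExtractor can read filePartitions() and report the parent directory as the input dataset.

// Sketch only: print the RDD classes that RddPathUtils walks for a parquet-backed Dataset.
import org.apache.spark.rdd.RDD

val rdd: RDD[org.apache.spark.sql.Row] = spark.read.parquet("/tmp/example/in").rdd

def printChain(r: RDD[_]): Unit = {
  println(r.getClass.getSimpleName) // MapPartitionsRDD at the top, FileScanRDD at the bottom
  r.dependencies.map(_.rdd).foreach(printChain)
}
printChain(rdd)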