Commit 89dca31: [Spark] fix databricks environment

Signed-off-by: Pawel Leszczynski <leszczynski.pawel@gmail.com>
pawel-big-lebowski committed Apr 2, 2024
1 parent 2d1ee64 commit 89dca31
Showing 11 changed files with 224 additions and 95 deletions.
CHANGELOG.md: 2 additions & 0 deletions

@@ -10,6 +10,8 @@
*Fixes the issue where the generated flink module contains internal libs that are not published.*
* **Spark: fix access to active Spark session.** [`#2535`](https://github.com/OpenLineage/OpenLineage/pull/2535) [@pawel-big-lebowski](https://github.com/pawel-big-lebowski)
*Always catch `IllegalStateException` when accessing `SparkSession`.*
+* **Spark: fix Databricks environment.** [`#2537`](https://github.com/OpenLineage/OpenLineage/pull/2537) [@pawel-big-lebowski](https://github.com/pawel-big-lebowski)
+*Fix `ClassNotFoundError` occurring on the Databricks runtime; extend the integration test to verify `DatabricksEnvironmentFacet`.*

## [1.10.2](https://github.com/OpenLineage/OpenLineage/compare/1.9.1...1.10.2) - 2024-03-15

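The `ClassNotFoundError` fix itself is not visible in this excerpt; what the diff below shows is the new integration-test coverage (the test exercises the `DatabricksEnvironmentFacetBuilder` handler) and the per-runtime suffixing of test tables and event files. For background, the usual way to avoid `ClassNotFoundError` on runtimes where platform-specific classes may be absent is to probe for them reflectively instead of importing them. A minimal sketch, not the commit's code, with a purely hypothetical marker class:

    // Sketch only: probing reflectively means this class never hard-links
    // against a Databricks-only type, so merely loading it cannot throw
    // ClassNotFoundError on a plain Spark runtime.
    private static boolean isDatabricksRuntime() {
        try {
            // Hypothetical marker class; any Databricks-only class works as a probe.
            Class.forName("com.databricks.backend.daemon.dbutils.DBUtilsCore");
            return true;
        } catch (ClassNotFoundException | NoClassDefFoundError e) {
            return false;
        }
    }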

@@ -7,6 +7,7 @@

import static io.openlineage.spark.agent.DatabricksUtils.DBFS_EVENTS_FILE;
import static io.openlineage.spark.agent.DatabricksUtils.init;
+import static io.openlineage.spark.agent.DatabricksUtils.platformVersion;
import static io.openlineage.spark.agent.DatabricksUtils.runScript;
import static org.assertj.core.api.Assertions.assertThat;

@@ -17,6 +18,7 @@
import io.openlineage.client.OpenLineage.OutputDataset;
import io.openlineage.client.OpenLineage.RunEvent;
import io.openlineage.client.OpenLineage.RunEvent.EventType;
+import io.openlineage.client.OpenLineage.RunFacet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -77,10 +79,45 @@ public void testCreateTableAsSelect() {
InputDataset inputDataset = lastEvent.getInputs().get(0);

assertThat(outputDataset.getNamespace()).isEqualTo("dbfs");
assertThat(outputDataset.getName()).isEqualTo("/user/hive/warehouse/ctas");
assertThat(outputDataset.getName()).isEqualTo("/user/hive/warehouse/ctas_" + platformVersion());

assertThat(inputDataset.getNamespace()).isEqualTo("dbfs");
assertThat(inputDataset.getName()).isEqualTo("/user/hive/warehouse/temp");
assertThat(inputDataset.getName()).isEqualTo("/user/hive/warehouse/temp_" + platformVersion());

+// test DatabricksEnvironmentFacetBuilder handler
+RunEvent eventWithDatabricksProperties =
+    runEvents.stream()
+        .filter(
+            r ->
+                r.getRun()
+                    .getFacets()
+                    .getAdditionalProperties()
+                    .containsKey("environment-properties"))
+        .findFirst()
+        .get();

+RunFacet environmentFacet =
+    eventWithDatabricksProperties
+        .getRun()
+        .getFacets()
+        .getAdditionalProperties()
+        .get("environment-properties");

+Map<String, Object> properties =
+    (Map<String, Object>)
+        environmentFacet.getAdditionalProperties().get("environment-properties");

+assertThat(properties.get("spark.databricks.job.type")).isEqualTo("python");

+List<Object> mounts = (List<Object>) properties.get("mountPoints");

+assertThat(mounts).isNotEmpty();
+Map<String, String> mountInfo = (Map<String, String>) mounts.get(0);

+assertThat(mountInfo).containsKeys("mountPoint", "source");

+assertThat(mountInfo.get("mountPoint")).startsWith("/databricks");
+assertThat(mountInfo.get("source")).startsWith("databricks");
}

@Test
@@ -117,7 +154,7 @@ public void testNarrowTransformation() {

assertThat(completeEvent).isPresent();
assertThat(completeEvent.get().getOutputs().get(0).getName())
.isEqualTo("/data/path/to/output/narrow_transformation");
.isEqualTo("/data/path/to/output/narrow_transformation_" + platformVersion());
}

@Test
@@ -148,7 +185,7 @@ public void testWideTransformation() {

assertThat(completeEvent).isPresent();
assertThat(completeEvent.get().getOutputs().get(0).getName())
.isEqualTo("/data/output/wide_transformation/result");
.isEqualTo("/data/output/wide_transformation/result_" + platformVersion());
}

@Test
@@ -174,8 +211,10 @@ public void testWriteReadFromTableWithLocation() {
.get();

// assert input and output are the same
-assertThat(outputDataset.getNamespace()).isEqualTo(outputDataset.getNamespace());
-assertThat(inputDataset.getName()).isEqualTo(inputDataset.getName());
+// TODO: these assertions are not working
+// https://github.com/OpenLineage/OpenLineage/issues/2543
+// assertThat(inputDataset.getNamespace()).isEqualTo(outputDataset.getNamespace());
+// assertThat(inputDataset.getName()).isEqualTo(outputDataset.getName());
assertThat(runEvents.size()).isLessThan(20);
}

@@ -202,37 +241,41 @@ void testMergeInto() {
.getAdditionalProperties();

assertThat(event.getOutputs()).hasSize(1);
assertThat(event.getOutputs().get(0).getName()).endsWith("events");
assertThat(event.getOutputs().get(0).getName()).endsWith("events_" + platformVersion());

assertThat(event.getInputs()).hasSize(2);
assertThat(event.getInputs().stream().map(d -> d.getName()).collect(Collectors.toList()))
.containsExactlyInAnyOrder(
"/user/hive/warehouse/test_db.db/updates", "/user/hive/warehouse/test_db.db/events");
"/user/hive/warehouse/test_db.db/updates_" + platformVersion(),
"/user/hive/warehouse/test_db.db/events_" + platformVersion());

assertThat(fields).hasSize(2);
assertThat(fields.get("last_updated_at").getInputFields()).hasSize(1);
assertThat(fields.get("last_updated_at").getInputFields().get(0))
.hasFieldOrPropertyWithValue("namespace", "dbfs")
.hasFieldOrPropertyWithValue("name", "/user/hive/warehouse/test_db.db/updates")
.hasFieldOrPropertyWithValue(
"name", "/user/hive/warehouse/test_db.db/updates_" + platformVersion())
.hasFieldOrPropertyWithValue("field", "updated_at");

assertThat(fields.get("event_id").getInputFields()).hasSize(2);
assertThat(
fields.get("event_id").getInputFields().stream()
.filter(e -> e.getName().endsWith("updates"))
.filter(e -> e.getName().endsWith("updates_" + platformVersion()))
.findFirst()
.get())
.hasFieldOrPropertyWithValue("namespace", "dbfs")
.hasFieldOrPropertyWithValue("name", "/user/hive/warehouse/test_db.db/updates")
.hasFieldOrPropertyWithValue(
"name", "/user/hive/warehouse/test_db.db/updates_" + platformVersion())
.hasFieldOrPropertyWithValue("field", "event_id");

assertThat(
fields.get("event_id").getInputFields().stream()
.filter(e -> e.getName().endsWith("events"))
.filter(e -> e.getName().endsWith("events_" + platformVersion()))
.findFirst()
.get())
.hasFieldOrPropertyWithValue("namespace", "dbfs")
.hasFieldOrPropertyWithValue("name", "/user/hive/warehouse/test_db.db/events")
.hasFieldOrPropertyWithValue(
"name", "/user/hive/warehouse/test_db.db/events_" + platformVersion())
.hasFieldOrPropertyWithValue("field", "event_id");
}
}
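A note on the unchecked casts above: the run facet is keyed "environment-properties", and the facet body nests the same key again, so the payload is unwrapped twice. Roughly, as an illustrative fragment shaped after the assertions (mount values invented, not captured output):

    // Illustrative only; a Java text block holding the facet fragment the test walks.
    String environmentPropertiesFacet = """
        "environment-properties": {
          "environment-properties": {
            "spark.databricks.job.type": "python",
            "mountPoints": [
              {"mountPoint": "/databricks-datasets", "source": "databricks-datasets"}
            ]
          }
        }
        """;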

DatabricksUtils.java

@@ -50,16 +50,26 @@
public class DatabricksUtils {

public static final String CLUSTER_NAME = "openlineage-test-cluster";
-public static final Map<String, String> PLATFORM_VERSIONS =
+public static final Map<String, String> PLATFORM_VERSIONS_NAMES =
Stream.of(
new AbstractMap.SimpleEntry<>("3.4.1", "13.3.x-scala2.12"),
new AbstractMap.SimpleEntry<>("3.4.2", "13.3.x-scala2.12"),
new AbstractMap.SimpleEntry<>("3.5.0", "14.2.x-scala2.12"))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+public static final Map<String, String> PLATFORM_VERSIONS =
+    Stream.of(
+            new AbstractMap.SimpleEntry<>("3.4.2", "13.3"),
+            new AbstractMap.SimpleEntry<>("3.5.0", "14.2"))
+        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
public static final String NODE_TYPE = "Standard_DS3_v2";
public static final String DBFS_EVENTS_FILE = "dbfs:/databricks/openlineage/events.log";
public static final String INIT_SCRIPT_FILE = "/Shared/open-lineage-init-script.sh";
public static final String DBFS_CLUSTER_LOGS = "dbfs:/databricks/openlineage/cluster-logs";
private static final String SPARK_VERSION = "spark.version";
+public static final String DBFS_EVENTS_FILE =
+    "dbfs:/databricks/openlineage/events_" + platformVersion() + ".log";

+public static String platformVersion() {
+    return PLATFORM_VERSIONS.get(System.getProperty(SPARK_VERSION)).replace(".", "_");
+}

@SneakyThrows
static String init(WorkspaceClient workspace) {
@@ -72,18 +82,18 @@ static String init(WorkspaceClient workspace) {
workspace.dbfs().delete(deleteClusterLogs);

// check if cluster is available
-String clusterId;
Iterable<ClusterDetails> clusterDetails = workspace.clusters().list(new ListClustersRequest());
if (clusterDetails != null) {
-    clusterId =
-        StreamSupport.stream(clusterDetails.spliterator(), false)
-            .filter(cl -> cl.getClusterName().equals(getClusterName()))
-            .map(cl -> cl.getClusterId())
-            .findAny()
-            .orElseGet(() -> createCluster(workspace));
-} else {
-    clusterId = createCluster(workspace);
+    log.info("Encountered clusters to delete.");
+    StreamSupport.stream(clusterDetails.spliterator(), false)
+        .filter(cl -> cl.getClusterName().equals(getClusterName()))
+        .forEach(
+            cl -> {
+                log.info("Deleting a cluster {}-{}.", cl.getClusterName(), cl.getClusterId());
+                workspace.clusters().permanentDelete(cl.getClusterId());
+            });
}
+String clusterId = createCluster(workspace);

log.info("Ensuring cluster is running");
workspace.clusters().ensureClusterIsRunning(clusterId);
@@ -203,18 +213,20 @@ private static String getClusterName() {
}

private static String getSparkPlatformVersion() {
-if (!PLATFORM_VERSIONS.containsKey(System.getProperty(SPARK_VERSION))) {
-    log.error("Unsupported spark_version for databricks test");
+if (!PLATFORM_VERSIONS_NAMES.containsKey(System.getProperty(SPARK_VERSION))) {
+    log.error("Unsupported spark_version for databricks test {}", System.getProperty(SPARK_VERSION));
}

-return PLATFORM_VERSIONS.get(System.getProperty(SPARK_VERSION));
+log.info(
+    "Databricks version {}", PLATFORM_VERSIONS_NAMES.get(System.getProperty(SPARK_VERSION)));
+return PLATFORM_VERSIONS_NAMES.get(System.getProperty(SPARK_VERSION));
}

@SneakyThrows
private static void uploadOpenlineageJar(WorkspaceClient workspace) {
Path jarFile =
Files.list(Paths.get("../build/libs/"))
.filter(p -> p.getFileName().toString().startsWith("openlineage-spark-"))
.filter(p -> p.getFileName().toString().startsWith("openlineage-spark_"))
.filter(p -> p.getFileName().toString().endsWith("jar"))
.findAny()
.orElseThrow(() -> new RuntimeException("openlineage-spark jar not found"));
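To make the two maps above concrete: `PLATFORM_VERSIONS_NAMES` selects the Databricks cluster image for the Spark version under test, while `PLATFORM_VERSIONS` yields the short runtime version that `platformVersion()` turns into the suffix used in table names and the events file. For example, when the test JVM is started with `-Dspark.version=3.5.0` (the property must be set at startup, since `DBFS_EVENTS_FILE` is initialized at class load):

    // Resolution follows directly from the maps in the diff above.
    String image = DatabricksUtils.PLATFORM_VERSIONS_NAMES.get("3.5.0"); // "14.2.x-scala2.12"
    String suffix = DatabricksUtils.platformVersion();                   // "14.2" -> "14_2"
    String events = DatabricksUtils.DBFS_EVENTS_FILE;                    // "dbfs:/databricks/openlineage/events_14_2.log"

The same suffix is what the notebooks below derive from `DATABRICKS_RUNTIME_VERSION`, which keeps the Java-side assertions and the notebook-created tables aligned.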

@@ -7,15 +7,23 @@
if os.path.exists("/tmp/events.log"):
os.remove("/tmp/events.log")

spark.sql("DROP TABLE IF EXISTS default.temp")
spark.sql("DROP TABLE IF EXISTS default.ctas")
runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION", None).replace(".", "_")

spark.sql("DROP TABLE IF EXISTS default.temp_{}".format(runtime_version))
spark.sql("DROP TABLE IF EXISTS default.ctas_{}".format(runtime_version))

spark.createDataFrame([{"a": 1, "b": 2}, {"a": 3, "b": 4}]).repartition(1).write.mode(
"overwrite"
).saveAsTable("default.temp")
).saveAsTable("default.temp_{}".format(runtime_version))

spark.sql("CREATE TABLE default.ctas AS SELECT a, b FROM default.temp")
spark.sql(
"CREATE TABLE default.ctas_{} AS SELECT a, b FROM default.temp_{}".format(
runtime_version, runtime_version
)
)

time.sleep(3)

dbutils.fs.cp("file:/tmp/events.log", "dbfs:/databricks/openlineage/events.log")
event_file = "dbfs:/databricks/openlineage/events_{}.log".format(runtime_version)
dbutils.fs.rm(event_file, True)
dbutils.fs.cp("file:/tmp/events.log", event_file)

@@ -4,22 +4,32 @@
import os
import time

dbutils.fs.rm("dbfs:/mnt/openlineage/test/t1", True)
dbutils.fs.mkdirs("dbfs:/mnt/openlineage/test/t1")
runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION", None).replace(".", "_")

dbutils.fs.rm("dbfs:/mnt/openlineage/test_db/t1_{}'".format(runtime_version), True)
dbutils.fs.mkdirs("dbfs:/mnt/openlineage/test_db/t1_{}'".format(runtime_version))

if os.path.exists("/tmp/events.log"):
os.remove("/tmp/events.log")

spark.sql("CREATE DATABASE IF NOT EXISTS test_db")
spark.sql("DROP TABLE IF EXISTS test_db.t1")
spark.sql("DROP TABLE IF EXISTS test_db.t2")
spark.sql("CREATE TABLE test_db.t1 (id long, value string) LOCATION '/mnt/openlineage/test_db/t1'")
spark.sql("DROP TABLE IF EXISTS test_db.t1_{}".format(runtime_version))
spark.sql("DROP TABLE IF EXISTS test_db.t2_{}".format(runtime_version))
spark.sql(
"CREATE TABLE test_db.t1_{} (id long, value string) LOCATION '/mnt/openlineage/test_db/t1_{}'".format(
runtime_version, runtime_version
)
)

df = spark.sparkContext.parallelize([(1, "a"), (2, "b"), (3, "c")]).toDF(["id", "value"])
df.write.mode("overwrite").saveAsTable("test_db.t1")
df.write.mode("overwrite").saveAsTable("test_db.t1_{}".format(runtime_version))

spark.sql("CREATE TABLE test_db.t2 AS select * from test_db.t1")
spark.sql(
"CREATE TABLE test_db.t2_{} AS select * from test_db.t1_{}".format(runtime_version, runtime_version)
)

time.sleep(3)

dbutils.fs.cp("file:/tmp/events.log", "dbfs:/databricks/openlineage/events.log")
event_file = "dbfs:/databricks/openlineage/events_{}.log".format(runtime_version)
dbutils.fs.rm(event_file, True)
dbutils.fs.cp("file:/tmp/events.log", event_file)

@@ -4,34 +4,44 @@
import os
import time

dbutils.fs.rm("dbfs:/mnt/openlineage/test/updates", True)
dbutils.fs.rm("dbfs:/mnt/openlineage/test/events", True)
runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION", None).replace(".", "_")

dbutils.fs.mkdirs("dbfs:/mnt/openlineage/test/updates")
dbutils.fs.mkdirs("dbfs:/mnt/openlineage/test/events")
dbutils.fs.rm("dbfs:/user/hive/warehouse/test_db.db/updates_{}".format(runtime_version), True)
dbutils.fs.rm("dbfs:/user/hive/warehouse/test_db.db/events_{}".format(runtime_version), True)

dbutils.fs.mkdirs("dbfs:/user/hive/warehouse/test_db.db/updates_{}".format(runtime_version))
dbutils.fs.mkdirs("dbfs:/user/hive/warehouse/test_db.db/events_{}".format(runtime_version))

if os.path.exists("/tmp/events.log"):
os.remove("/tmp/events.log")

spark.sql("CREATE DATABASE IF NOT EXISTS test_db")
spark.sql("DROP TABLE IF EXISTS test_db.events")
spark.sql("DROP TABLE IF EXISTS test_db.updates")
spark.sql("DROP TABLE IF EXISTS test_db.events_{}".format(runtime_version))
spark.sql("DROP TABLE IF EXISTS test_db.updates_{}".format(runtime_version))

spark.sql("CREATE TABLE test_db.events (event_id long, last_updated_at long) USING delta")
spark.sql("CREATE TABLE test_db.updates (event_id long, updated_at long) USING delta")
spark.sql(
"CREATE TABLE test_db.events_{} (event_id long, last_updated_at long) USING delta".format(runtime_version)
)
spark.sql(
"CREATE TABLE test_db.updates_{} (event_id long, updated_at long) USING delta".format(runtime_version)
)

spark.sql("INSERT INTO test_db.events VALUES (1, 1641290276);")
spark.sql("INSERT INTO test_db.updates VALUES (1, 1641290277);")
spark.sql("INSERT INTO test_db.updates VALUES (2, 1641290277);")
spark.sql("INSERT INTO test_db.events_{} VALUES (1, 1641290276);".format(runtime_version))
spark.sql("INSERT INTO test_db.updates_{} VALUES (1, 1641290277);".format(runtime_version))
spark.sql("INSERT INTO test_db.updates_{} VALUES (2, 1641290277);".format(runtime_version))

spark.sql(
"MERGE INTO test_db.events target USING test_db.updates "
+ " ON target.event_id = test_db.updates.event_id"
+ " WHEN MATCHED THEN UPDATE SET target.last_updated_at = test_db.updates.updated_at"
+ " WHEN NOT MATCHED THEN INSERT (event_id, last_updated_at) "
+ "VALUES (event_id, updated_at)"
(
"MERGE INTO test_db.events_{} target USING test_db.updates_{} "
+ " ON target.event_id = test_db.updates_{}.event_id"
+ " WHEN MATCHED THEN UPDATE SET target.last_updated_at = test_db.updates_{}.updated_at"
+ " WHEN NOT MATCHED THEN INSERT (event_id, last_updated_at) "
+ "VALUES (event_id, updated_at)"
).format(runtime_version, runtime_version, runtime_version, runtime_version)
)

time.sleep(3)

dbutils.fs.cp("file:/tmp/events.log", "dbfs:/databricks/openlineage/events.log")
event_file = "dbfs:/databricks/openlineage/events_{}.log".format(runtime_version)
dbutils.fs.rm(event_file, True)
dbutils.fs.cp("file:/tmp/events.log", event_file)

@@ -9,6 +9,8 @@
if os.path.exists("/tmp/events.log"):
os.remove("/tmp/events.log")

+runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION", None).replace(".", "_")

spark = SparkSession.builder.appName("narrow_transformation").getOrCreate()

data = [(1, "a"), (2, "b"), (3, "c")]
@@ -17,8 +19,10 @@

df = df.withColumn("id_plus_one", df["id"] + 1)

(df.write.mode("overwrite").parquet("data/path/to/output/narrow_transformation/"))
df.write.mode("overwrite").parquet("data/path/to/output/narrow_transformation_{}/".format(runtime_version))

time.sleep(3)

dbutils.fs.cp("file:/tmp/events.log", "dbfs:/databricks/openlineage/events.log")
event_file = "dbfs:/databricks/openlineage/events_{}.log".format(runtime_version)
dbutils.fs.rm(event_file, True)
dbutils.fs.cp("file:/tmp/events.log", event_file)
