[Spark] fix databricks environment
Signed-off-by: Pawel Leszczynski <leszczynski.pawel@gmail.com>
pawel-big-lebowski committed Mar 28, 2024
1 parent 2d1ee64 commit 350dde6
Showing 10 changed files with 213 additions and 85 deletions.
@@ -7,6 +7,7 @@

import static io.openlineage.spark.agent.DatabricksUtils.DBFS_EVENTS_FILE;
import static io.openlineage.spark.agent.DatabricksUtils.init;
import static io.openlineage.spark.agent.DatabricksUtils.platformVersion;
import static io.openlineage.spark.agent.DatabricksUtils.runScript;
import static org.assertj.core.api.Assertions.assertThat;

@@ -17,6 +18,7 @@
import io.openlineage.client.OpenLineage.OutputDataset;
import io.openlineage.client.OpenLineage.RunEvent;
import io.openlineage.client.OpenLineage.RunEvent.EventType;
import io.openlineage.client.OpenLineage.RunFacet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -77,10 +79,45 @@ public void testCreateTableAsSelect() {
InputDataset inputDataset = lastEvent.getInputs().get(0);

assertThat(outputDataset.getNamespace()).isEqualTo("dbfs");
assertThat(outputDataset.getName()).isEqualTo("/user/hive/warehouse/ctas");
assertThat(outputDataset.getName()).isEqualTo("/user/hive/warehouse/ctas_" + platformVersion());

assertThat(inputDataset.getNamespace()).isEqualTo("dbfs");
assertThat(inputDataset.getName()).isEqualTo("/user/hive/warehouse/temp");
assertThat(inputDataset.getName()).isEqualTo("/user/hive/warehouse/temp_" + platformVersion());

// test DatabricksEnvironmentFacetBuilder handler
RunEvent eventWithDatabricksProperties =
runEvents.stream()
.filter(
r ->
r.getRun()
.getFacets()
.getAdditionalProperties()
.containsKey("environment-properties"))
.findFirst()
.get();

RunFacet environmentFacet =
eventWithDatabricksProperties
.getRun()
.getFacets()
.getAdditionalProperties()
.get("environment-properties");

Map<String, Object> properties =
(Map<String, Object>)
environmentFacet.getAdditionalProperties().get("environment-properties");

assertThat(properties.get("spark.databricks.job.type")).isEqualTo("python");

List<Object> mounts = (List<Object>) properties.get("mountPoints");

assertThat(mounts).isNotEmpty();
Map<String, String> mountInfo = (Map<String, String>) mounts.get(0);

assertThat(mountInfo).containsKeys("mountPoint", "source");

assertThat(mountInfo.get("mountPoint")).startsWith("/databricks");
assertThat(mountInfo.get("source")).startsWith("databricks");
}

@Test
@@ -117,7 +154,7 @@ public void testNarrowTransformation() {

assertThat(completeEvent).isPresent();
assertThat(completeEvent.get().getOutputs().get(0).getName())
.isEqualTo("/data/path/to/output/narrow_transformation");
.isEqualTo("/data/path/to/output/narrow_transformation_" + platformVersion());
}

@Test
@@ -148,7 +185,7 @@ public void testWideTransformation() {

assertThat(completeEvent).isPresent();
assertThat(completeEvent.get().getOutputs().get(0).getName())
.isEqualTo("/data/output/wide_transformation/result");
.isEqualTo("/data/output/wide_transformation/result_" + platformVersion());
}

@Test
@@ -174,8 +211,9 @@ public void testWriteReadFromTableWithLocation() {
.get();

// assert input and output are the same
assertThat(outputDataset.getNamespace()).isEqualTo(outputDataset.getNamespace());
assertThat(inputDataset.getName()).isEqualTo(inputDataset.getName());
// TODO: these assertions are not working - ISSUE
// assertThat(inputDataset.getNamespace()).isEqualTo(outputDataset.getNamespace());
// assertThat(inputDataset.getName()).isEqualTo(outputDataset.getName());
assertThat(runEvents.size()).isLessThan(20);
}

@@ -202,37 +240,41 @@ void testMergeInto() {
.getAdditionalProperties();

assertThat(event.getOutputs()).hasSize(1);
assertThat(event.getOutputs().get(0).getName()).endsWith("events");
assertThat(event.getOutputs().get(0).getName()).endsWith("events_" + platformVersion());

assertThat(event.getInputs()).hasSize(2);
assertThat(event.getInputs().stream().map(d -> d.getName()).collect(Collectors.toList()))
.containsExactlyInAnyOrder(
"/user/hive/warehouse/test_db.db/updates", "/user/hive/warehouse/test_db.db/events");
"/user/hive/warehouse/test_db.db/updates_" + platformVersion(),
"/user/hive/warehouse/test_db.db/events_" + platformVersion());

assertThat(fields).hasSize(2);
assertThat(fields.get("last_updated_at").getInputFields()).hasSize(1);
assertThat(fields.get("last_updated_at").getInputFields().get(0))
.hasFieldOrPropertyWithValue("namespace", "dbfs")
.hasFieldOrPropertyWithValue("name", "/user/hive/warehouse/test_db.db/updates")
.hasFieldOrPropertyWithValue(
"name", "/user/hive/warehouse/test_db.db/updates_" + platformVersion())
.hasFieldOrPropertyWithValue("field", "updated_at");

assertThat(fields.get("event_id").getInputFields()).hasSize(2);
assertThat(
fields.get("event_id").getInputFields().stream()
.filter(e -> e.getName().endsWith("updates"))
.filter(e -> e.getName().endsWith("updates_" + platformVersion()))
.findFirst()
.get())
.hasFieldOrPropertyWithValue("namespace", "dbfs")
.hasFieldOrPropertyWithValue("name", "/user/hive/warehouse/test_db.db/updates")
.hasFieldOrPropertyWithValue(
"name", "/user/hive/warehouse/test_db.db/updates_" + platformVersion())
.hasFieldOrPropertyWithValue("field", "event_id");

assertThat(
fields.get("event_id").getInputFields().stream()
.filter(e -> e.getName().endsWith("events"))
.filter(e -> e.getName().endsWith("events_" + platformVersion()))
.findFirst()
.get())
.hasFieldOrPropertyWithValue("namespace", "dbfs")
.hasFieldOrPropertyWithValue("name", "/user/hive/warehouse/test_db.db/events")
.hasFieldOrPropertyWithValue(
"name", "/user/hive/warehouse/test_db.db/events_" + platformVersion())
.hasFieldOrPropertyWithValue("field", "event_id");
}
}
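For orientation, a minimal standalone sketch of the nested structure that the new environment-properties assertions above traverse. The keys mirror the test; the concrete values (job type aside, the mount names) are illustrative assumptions only, not what a given cluster will actually report:

import java.util.List;
import java.util.Map;

// Illustrative shape only: keys mirror the assertions in testCreateTableAsSelect,
// values are hypothetical examples of what a Databricks cluster might report.
public class EnvironmentFacetSketch {
  public static void main(String[] args) {
    Map<String, Object> properties =
        Map.of(
            "spark.databricks.job.type", "python",
            "mountPoints",
            List.of(Map.of("mountPoint", "/databricks-datasets", "source", "databricks-datasets")));

    System.out.println(properties.get("spark.databricks.job.type")); // "python"
    System.out.println(((List<?>) properties.get("mountPoints")).get(0));
  }
}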
@@ -50,16 +50,26 @@
public class DatabricksUtils {

public static final String CLUSTER_NAME = "openlineage-test-cluster";
public static final Map<String, String> PLATFORM_VERSIONS =
public static final Map<String, String> PLATFORM_VERSIONS_NAMES =
Stream.of(
new AbstractMap.SimpleEntry<>("3.4.1", "13.3.x-scala2.12"),
new AbstractMap.SimpleEntry<>("3.4.2", "13.3.x-scala2.12"),
new AbstractMap.SimpleEntry<>("3.5.0", "14.2.x-scala2.12"))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
public static final Map<String, String> PLATFORM_VERSIONS =
Stream.of(
new AbstractMap.SimpleEntry<>("3.4.2", "13.3"),
new AbstractMap.SimpleEntry<>("3.5.0", "14.2"))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
public static final String NODE_TYPE = "Standard_DS3_v2";
public static final String DBFS_EVENTS_FILE = "dbfs:/databricks/openlineage/events.log";
public static final String INIT_SCRIPT_FILE = "/Shared/open-lineage-init-script.sh";
public static final String DBFS_CLUSTER_LOGS = "dbfs:/databricks/openlineage/cluster-logs";
private static final String SPARK_VERSION = "spark.version";
public static final String DBFS_EVENTS_FILE =
"dbfs:/databricks/openlineage/events_" + platformVersion() + ".log";

public static String platformVersion() {
return PLATFORM_VERSIONS.get(System.getProperty(SPARK_VERSION)).replace(".", "_");
}

@SneakyThrows
static String init(WorkspaceClient workspace) {
@@ -151,13 +161,15 @@ static List<RunEvent> runScript(WorkspaceClient workspace, String clusterId, Str
runSubmitTaskSettings.setTaskKey(taskName);
runSubmitTaskSettings.setExistingClusterId(clusterId);
runSubmitTaskSettings.setSparkPythonTask(task);
runSubmitTaskSettings.setTimeoutSeconds(60L);

SubmitRun submitRun = new SubmitRun();
submitRun.setRunName(taskName);
submitRun.setTimeoutSeconds(60L);
submitRun.setTasks(Collections.singletonList(runSubmitTaskSettings));

// trigger one time job
workspace.jobs().submit(submitRun).get();
workspace.jobs().submit(submitRun).get(Duration.ofMinutes(3));

return fetchEventsEmitted(workspace);
}
@@ -203,18 +215,20 @@ private static String getClusterName() {
}

private static String getSparkPlatformVersion() {
if (!PLATFORM_VERSIONS.containsKey(System.getProperty(SPARK_VERSION))) {
log.error("Unsupported spark_version for databricks test");
if (!PLATFORM_VERSIONS_NAMES.containsKey(System.getProperty(SPARK_VERSION))) {
log.error("Unsupported spark_version for databricks test {}", SPARK_VERSION);
}

return PLATFORM_VERSIONS.get(System.getProperty(SPARK_VERSION));
log.info(
"Databricks version {}", PLATFORM_VERSIONS_NAMES.get(System.getProperty(SPARK_VERSION)));
return PLATFORM_VERSIONS_NAMES.get(System.getProperty(SPARK_VERSION));
}

@SneakyThrows
private static void uploadOpenlineageJar(WorkspaceClient workspace) {
Path jarFile =
Files.list(Paths.get("../build/libs/"))
.filter(p -> p.getFileName().toString().startsWith("openlineage-spark-"))
.filter(p -> p.getFileName().toString().startsWith("openlineage-spark_"))
.filter(p -> p.getFileName().toString().endsWith("jar"))
.findAny()
.orElseThrow(() -> new RuntimeException("openlineage-spark jar not found"));
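Taken together, the new constants and platformVersion() give every test artifact a runtime-specific suffix. A minimal standalone sketch of the mapping (assuming the harness runs with -Dspark.version=3.5.0, which PLATFORM_VERSIONS maps to "14.2"):

// Standalone sketch of the suffixing scheme; values assume -Dspark.version=3.5.0.
public class PlatformVersionSketch {
  public static void main(String[] args) {
    String runtime = "14.2";                    // PLATFORM_VERSIONS.get("3.5.0")
    String suffix = runtime.replace(".", "_");  // platformVersion() -> "14_2"
    System.out.println("dbfs:/databricks/openlineage/events_" + suffix + ".log"); // DBFS_EVENTS_FILE
    System.out.println("/user/hive/warehouse/ctas_" + suffix);                    // expected dataset name
  }
}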
@@ -7,15 +7,23 @@
if os.path.exists("/tmp/events.log"):
os.remove("/tmp/events.log")

spark.sql("DROP TABLE IF EXISTS default.temp")
spark.sql("DROP TABLE IF EXISTS default.ctas")
runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION", None).replace(".", "_")

spark.sql("DROP TABLE IF EXISTS default.temp_{}".format(runtime_version))
spark.sql("DROP TABLE IF EXISTS default.ctas_{}".format(runtime_version))

spark.createDataFrame([{"a": 1, "b": 2}, {"a": 3, "b": 4}]).repartition(1).write.mode(
"overwrite"
).saveAsTable("default.temp")
).saveAsTable("default.temp_{}".format(runtime_version))

spark.sql("CREATE TABLE default.ctas AS SELECT a, b FROM default.temp")
spark.sql(
"CREATE TABLE default.ctas_{} AS SELECT a, b FROM default.temp_{}".format(
runtime_version, runtime_version
)
)

time.sleep(3)

dbutils.fs.cp("file:/tmp/events.log", "dbfs:/databricks/openlineage/events.log")
event_file = "dbfs:/databricks/openlineage/events_{}.log".format(runtime_version)
dbutils.fs.rm(event_file, True)
dbutils.fs.cp("file:/tmp/events.log", event_file)
@@ -10,16 +10,26 @@
if os.path.exists("/tmp/events.log"):
os.remove("/tmp/events.log")

runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION", None).replace(".", "_")

spark.sql("CREATE DATABASE IF NOT EXISTS test_db")
spark.sql("DROP TABLE IF EXISTS test_db.t1")
spark.sql("DROP TABLE IF EXISTS test_db.t2")
spark.sql("CREATE TABLE test_db.t1 (id long, value string) LOCATION '/mnt/openlineage/test_db/t1'")
spark.sql("DROP TABLE IF EXISTS test_db.t1_{}".format(runtime_version))
spark.sql("DROP TABLE IF EXISTS test_db.t2_{}".format(runtime_version))
spark.sql(
"CREATE TABLE test_db.t1_{} (id long, value string) LOCATION '/mnt/openlineage/test_db/t1_{}'".format(
runtime_version, runtime_version
)
)

df = spark.sparkContext.parallelize([(1, "a"), (2, "b"), (3, "c")]).toDF(["id", "value"])
df.write.mode("overwrite").saveAsTable("test_db.t1")
df.write.mode("overwrite").saveAsTable("test_db.t1_{}".format(runtime_version))

spark.sql("CREATE TABLE test_db.t2 AS select * from test_db.t1")
spark.sql(
"CREATE TABLE test_db.t2_{} AS select * from test_db.t1_{}".format(runtime_version, runtime_version)
)

time.sleep(3)

dbutils.fs.cp("file:/tmp/events.log", "dbfs:/databricks/openlineage/events.log")
event_file = "dbfs:/databricks/openlineage/events_{}.log".format(runtime_version)
dbutils.fs.rm(event_file, True)
dbutils.fs.cp("file:/tmp/events.log", event_file)
@@ -4,34 +4,44 @@
import os
import time

dbutils.fs.rm("dbfs:/mnt/openlineage/test/updates", True)
dbutils.fs.rm("dbfs:/mnt/openlineage/test/events", True)
runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION", None).replace(".", "_")

dbutils.fs.mkdirs("dbfs:/mnt/openlineage/test/updates")
dbutils.fs.mkdirs("dbfs:/mnt/openlineage/test/events")
dbutils.fs.rm("dbfs:/user/hive/warehouse/test_db.db/updates_{}".format(runtime_version), True)
dbutils.fs.rm("dbfs:/user/hive/warehouse/test_db.db/events_{}".format(runtime_version), True)

dbutils.fs.mkdirs("dbfs:/user/hive/warehouse/test_db.db/updates_{}".format(runtime_version))
dbutils.fs.mkdirs("dbfs:/user/hive/warehouse/test_db.db/events_{}".format(runtime_version))

if os.path.exists("/tmp/events.log"):
os.remove("/tmp/events.log")

spark.sql("CREATE DATABASE IF NOT EXISTS test_db")
spark.sql("DROP TABLE IF EXISTS test_db.events")
spark.sql("DROP TABLE IF EXISTS test_db.updates")
spark.sql("DROP TABLE IF EXISTS test_db.events_{}".format(runtime_version))
spark.sql("DROP TABLE IF EXISTS test_db.updates_{}".format(runtime_version))

spark.sql("CREATE TABLE test_db.events (event_id long, last_updated_at long) USING delta")
spark.sql("CREATE TABLE test_db.updates (event_id long, updated_at long) USING delta")
spark.sql(
"CREATE TABLE test_db.events_{} (event_id long, last_updated_at long) USING delta".format(runtime_version)
)
spark.sql(
"CREATE TABLE test_db.updates_{} (event_id long, updated_at long) USING delta".format(runtime_version)
)

spark.sql("INSERT INTO test_db.events VALUES (1, 1641290276);")
spark.sql("INSERT INTO test_db.updates VALUES (1, 1641290277);")
spark.sql("INSERT INTO test_db.updates VALUES (2, 1641290277);")
spark.sql("INSERT INTO test_db.events_{} VALUES (1, 1641290276);".format(runtime_version))
spark.sql("INSERT INTO test_db.updates_{} VALUES (1, 1641290277);".format(runtime_version))
spark.sql("INSERT INTO test_db.updates_{} VALUES (2, 1641290277);".format(runtime_version))

spark.sql(
"MERGE INTO test_db.events target USING test_db.updates "
+ " ON target.event_id = test_db.updates.event_id"
+ " WHEN MATCHED THEN UPDATE SET target.last_updated_at = test_db.updates.updated_at"
+ " WHEN NOT MATCHED THEN INSERT (event_id, last_updated_at) "
+ "VALUES (event_id, updated_at)"
(
"MERGE INTO test_db.events_{} target USING test_db.updates_{} "
+ " ON target.event_id = test_db.updates_{}.event_id"
+ " WHEN MATCHED THEN UPDATE SET target.last_updated_at = test_db.updates_{}.updated_at"
+ " WHEN NOT MATCHED THEN INSERT (event_id, last_updated_at) "
+ "VALUES (event_id, updated_at)"
).format(runtime_version, runtime_version, runtime_version, runtime_version)
)

time.sleep(3)

dbutils.fs.cp("file:/tmp/events.log", "dbfs:/databricks/openlineage/events.log")
event_file = "dbfs:/databricks/openlineage/events_{}.log".format(runtime_version)
dbutils.fs.rm(event_file, True)
dbutils.fs.cp("file:/tmp/events.log", event_file)
@@ -9,6 +9,8 @@
if os.path.exists("/tmp/events.log"):
os.remove("/tmp/events.log")

runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION", None).replace(".", "_")

spark = SparkSession.builder.appName("narrow_transformation").getOrCreate()

data = [(1, "a"), (2, "b"), (3, "c")]
@@ -17,8 +19,10 @@

df = df.withColumn("id_plus_one", df["id"] + 1)

(df.write.mode("overwrite").parquet("data/path/to/output/narrow_transformation/"))
df.write.mode("overwrite").parquet("data/path/to/output/narrow_transformation_{}/".format(runtime_version))

time.sleep(3)

dbutils.fs.cp("file:/tmp/events.log", "dbfs:/databricks/openlineage/events.log")
event_file = "dbfs:/databricks/openlineage/events_{}.log".format(runtime_version)
dbutils.fs.rm(event_file, True)
dbutils.fs.cp("file:/tmp/events.log", event_file)
@@ -11,12 +11,18 @@

spark = SparkSession.builder.appName("wide_transformation").getOrCreate()

runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION", None).replace(".", "_")

data = [(1, "a"), (2, "b"), (3, "c")]
rdd = spark.sparkContext.parallelize(data)
df = rdd.toDF(["id", "value"])

(df.groupBy("id").count().write.mode("overwrite").parquet("data/output/wide_transformation/result/"))
df.groupBy("id").count().write.mode("overwrite").parquet(
"data/output/wide_transformation/result_{}/".format(runtime_version)
)

time.sleep(3)

dbutils.fs.cp("file:/tmp/events.log", "dbfs:/databricks/openlineage/events.log")
event_file = "dbfs:/databricks/openlineage/events_{}.log".format(runtime_version)
dbutils.fs.rm(event_file, True)
dbutils.fs.cp("file:/tmp/events.log", event_file)
2 changes: 1 addition & 1 deletion integration/spark/databricks/open-lineage-init-script.sh
@@ -6,7 +6,7 @@
STAGE_DIR="/dbfs/databricks/openlineage"

echo "BEGIN: Upload Spark Listener JARs"
cp -f $STAGE_DIR/openlineage-spark-*.jar /mnt/driver-daemon/jars || { echo "Error copying Spark Listener library file"; exit 1;}
cp -f $STAGE_DIR/openlineage-spark_*.jar /mnt/driver-daemon/jars || { echo "Error copying Spark Listener library file"; exit 1;}
echo "END: Upload Spark Listener JARs"

echo "BEGIN: Modify Spark config settings"