From 7a3a5498d28a573ebe545c2213281915e264faef Mon Sep 17 00:00:00 2001 From: "lukas.nalezenec" Date: Wed, 13 Mar 2019 14:40:02 +0100 Subject: [PATCH] spline-132 JDBC write support --- .../spline/harvester/DataLineageBuilder.scala | 33 +++-- .../spline/harvester/operationBuilders.scala | 25 +++- .../harvester/DataLineageBuilderTest.scala | 4 + integration-tests/pom.xml | 17 +++ .../spline/fixture/DerbyDatabaseFixture.scala | 72 +++++++++++ .../absa/spline/fixture/SplineFixture.scala | 24 +++- .../absa/spline/BasicIntegrationTests.scala | 106 --------------- .../za/co/absa/spline/JDBCWriteTest.scala | 63 +++++++++ .../spline/SaveAsTableIntegrationTests.scala | 121 ++++++++++++++++++ .../WriteCommandParserImpl.scala | 28 +++- .../WriteCommandParserImpl.scala | 27 +++- .../WriteCommandParserImpl.scala | 30 ++++- .../sparkadapterapi/WriteCommandParser.scala | 6 +- .../co/absa/spline/fixture/SparkFixture.scala | 2 +- 14 files changed, 415 insertions(+), 143 deletions(-) create mode 100644 integration-tests/src/main/scala/za/co/absa/spline/fixture/DerbyDatabaseFixture.scala delete mode 100644 integration-tests/src/test/scala/za/co/absa/spline/BasicIntegrationTests.scala create mode 100644 integration-tests/src/test/scala/za/co/absa/spline/JDBCWriteTest.scala create mode 100644 integration-tests/src/test/scala/za/co/absa/spline/SaveAsTableIntegrationTests.scala diff --git a/harvester/src/main/scala/za/co/absa/spline/harvester/DataLineageBuilder.scala b/harvester/src/main/scala/za/co/absa/spline/harvester/DataLineageBuilder.scala index 7d9f41c5f..6a5cba13d 100644 --- a/harvester/src/main/scala/za/co/absa/spline/harvester/DataLineageBuilder.scala +++ b/harvester/src/main/scala/za/co/absa/spline/harvester/DataLineageBuilder.scala @@ -36,6 +36,7 @@ class DataLineageBuilder(logicalPlan: LogicalPlan, executedPlanOpt: Option[Spark private val writeCommandParser = writeCommandParserFactory.writeParser() private val clusterUrl: Option[String] = sparkContext.getConf.getOption("spark.master") private val tableCommandParser = writeCommandParserFactory.saveAsTableParser(clusterUrl) + private val jdbcCommandParser = writeCommandParserFactory.jdbcParser() def buildLineage(): Option[DataLineage] = { val builders = getOperations(logicalPlan) @@ -82,19 +83,19 @@ class DataLineageBuilder(logicalPlan: LogicalPlan, executedPlanOpt: Option[Spark if (maybeExistingBuilder.isEmpty) { - //try to find all possible commands for traversing- save to filesystem, saveAsTable, JDBC - val writes = writeCommandParser. - asWriteCommandIfPossible(curOpNode). - map(wc => Seq(wc.query)). - getOrElse(Nil) + val parsers = Array(jdbcCommandParser, writeCommandParser, tableCommandParser) - val tables = tableCommandParser. - asWriteCommandIfPossible(curOpNode). - map(wc => Seq(wc.query)). - getOrElse(Nil) + val maybePlan: Option[LogicalPlan] = parsers. + map(_.asWriteCommandIfPossible(curOpNode)). 
+ collectFirst { + case Some(wc) => wc.query + } - var newNodesToProcess: Seq[LogicalPlan] = writes ++ tables - if (newNodesToProcess.isEmpty) newNodesToProcess = curOpNode.children + val newNodesToProcess: Seq[LogicalPlan] = + maybePlan match { + case Some(q) => Seq(q) + case None => curOpNode.children + } traverseAndCollect( curBuilder +: accBuilders, @@ -122,19 +123,23 @@ class DataLineageBuilder(logicalPlan: LogicalPlan, executedPlanOpt: Option[Spark case a: SubqueryAlias => new AliasNodeBuilder(a) case lr: LogicalRelation => new BatchReadNodeBuilder(lr) with HDFSAwareBuilder case StreamingRelationVersionAgnostic(dataSourceInfo) => new StreamReadNodeBuilder(op) + case wc if jdbcCommandParser.matches(op) => + val (readMetrics: Metrics, writeMetrics: Metrics) = getMetrics() + val tableCmd = jdbcCommandParser.asWriteCommand(wc).asInstanceOf[SaveJDBCCommand] + new SaveJDBCCommandNodeBuilder(tableCmd, writeMetrics, readMetrics) case wc if writeCommandParser.matches(op) => - val (readMetrics: Metrics, writeMetrics: Metrics) = makeMetrics() + val (readMetrics: Metrics, writeMetrics: Metrics) = getMetrics() val writeCmd = writeCommandParser.asWriteCommand(wc).asInstanceOf[WriteCommand] new BatchWriteNodeBuilder(writeCmd, writeMetrics, readMetrics) with HDFSAwareBuilder case wc if tableCommandParser.matches(op) => - val (readMetrics: Metrics, writeMetrics: Metrics) = makeMetrics() + val (readMetrics: Metrics, writeMetrics: Metrics) = getMetrics() val tableCmd = tableCommandParser.asWriteCommand(wc).asInstanceOf[SaveAsTableCommand] new SaveAsTableNodeBuilder(tableCmd, writeMetrics, readMetrics) case x => new GenericNodeBuilder(x) } } - private def makeMetrics(): (Metrics, Metrics) = { + private def getMetrics(): (Metrics, Metrics) = { executedPlanOpt. map(getExecutedReadWriteMetrics). 
getOrElse((Map.empty, Map.empty)) diff --git a/harvester/src/main/scala/za/co/absa/spline/harvester/operationBuilders.scala b/harvester/src/main/scala/za/co/absa/spline/harvester/operationBuilders.scala index 6141d75d9..51ededc12 100644 --- a/harvester/src/main/scala/za/co/absa/spline/harvester/operationBuilders.scala +++ b/harvester/src/main/scala/za/co/absa/spline/harvester/operationBuilders.scala @@ -24,13 +24,11 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.{DataSource, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.{JDBCRelation, SaveMode} -import za.co.absa.spline.sparkadapterapi.WriteCommand +import za.co.absa.spline.sparkadapterapi.{DataSourceInfo, SaveAsTableCommand, SaveJDBCCommand, WriteCommand} import za.co.absa.spline.model.endpoint._ import za.co.absa.spline.model.{op, _} -import za.co.absa.spline.sparkadapterapi.{DataSourceInfo, SaveAsTableCommand} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import za.co.absa.spline.sparkadapterapi.StreamingRelationAdapter.instance.extractDataSourceInfo -import za.co.absa.spline.sparkadapterapi.{DataSourceInfo, WriteCommand} sealed trait OperationNodeBuilder { @@ -186,6 +184,27 @@ class SaveAsTableNodeBuilder } } +class SaveJDBCCommandNodeBuilder +(val operation: SaveJDBCCommand, val writeMetrics: Map[String, Long], val readMetrics: Map[String, Long]) +(implicit val componentCreatorFactory: ComponentCreatorFactory) + extends OperationNodeBuilder with RootNode { + + override val output: AttrGroup = new AttrGroup(operation.query.output) + + override def build() = op.BatchWrite( + operationProps, + operation.format, + operation.tableName, + append = operation.mode == SaveMode.Append, + writeMetrics = writeMetrics, + readMetrics = readMetrics + ) + + override def ignoreLineageWrite:Boolean = { + false + } +} + trait RootNode { def ignoreLineageWrite:Boolean } diff --git a/harvester/src/test/scala/za/co/absa/spline/harvester/DataLineageBuilderTest.scala b/harvester/src/test/scala/za/co/absa/spline/harvester/DataLineageBuilderTest.scala index cc95aefb4..c46e0a441 100644 --- a/harvester/src/test/scala/za/co/absa/spline/harvester/DataLineageBuilderTest.scala +++ b/harvester/src/test/scala/za/co/absa/spline/harvester/DataLineageBuilderTest.scala @@ -58,12 +58,16 @@ object DataLineageBuilderTest extends MockitoSugar { private def lineageBuilderFor(df: DataFrame)(implicit sparkContext: SparkContext): DataLineageBuilder = { val plan = df.queryExecution.analyzed val mockWriteCommandParser = mock[WriteCommandParser[LogicalPlan]] + val mockJdbcCommandParser = mock[WriteCommandParser[LogicalPlan]] val factory = mock[WriteCommandParserFactory] when(mockWriteCommandParser asWriteCommandIfPossible any()) thenReturn None + when(mockJdbcCommandParser asWriteCommandIfPossible any()) thenReturn None + when(factory writeParser()) thenReturn mockWriteCommandParser when(factory saveAsTableParser(any())) thenReturn mockWriteCommandParser + when(factory jdbcParser()) thenReturn mockJdbcCommandParser new DataLineageBuilder(plan, None, sparkContext)(mock[Configuration], factory) } diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 1cdfaa0c6..3e1a7abe0 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -64,6 +64,16 @@ migrator-tool ${project.version} + + za.co.absa.spline.shadow + persistence + ${project.version} + + + za.co.absa.spline.shadow + migrator-tool + 
${project.version} + org.scalatest scalatest_${scala.compat.version} @@ -72,6 +82,13 @@ org.mockito mockito-core + + + org.apache.derby + derby + 10.14.2.0 + + diff --git a/integration-tests/src/main/scala/za/co/absa/spline/fixture/DerbyDatabaseFixture.scala b/integration-tests/src/main/scala/za/co/absa/spline/fixture/DerbyDatabaseFixture.scala new file mode 100644 index 000000000..2a160e3ed --- /dev/null +++ b/integration-tests/src/main/scala/za/co/absa/spline/fixture/DerbyDatabaseFixture.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spline.fixture + +import java.sql.{Connection, DriverManager, ResultSet} + +import org.scalactic.source.Position +import org.scalatest.{BeforeAndAfter, Suite} +import za.co.absa.spline.common.TempDirectory + +/** + * Runs and wraps embedded Apache Derby DB. + **/ +trait DerbyDatabaseFixture extends BeforeAndAfter{ + + this: Suite => + + private val dbName = "splineTestDb" + val connectionString = s"jdbc:derby:memory:$dbName ;create=true" + + var connection : Connection = null + + private def execute(sql: String): ResultSet = { + val statement = connection.createStatement + statement.execute(sql) + statement.getResultSet + } + + private def closeDatabase() : Unit = { + def closeCommand(cmd: String) = util.Try({DriverManager.getConnection(cmd)}) + + val connectionString = "jdbc:derby:memory:" + dbName + closeCommand(connectionString + ";drop=true") + closeCommand(connectionString + ";shutdown=true") + } + + private def createTable(table: String): ResultSet = { + execute("Create table " + table + " (id int, name varchar(30))") + } + + private def dropTable(table: String): ResultSet = { + execute("drop table " + table) + } + + override protected def after(fun: => Any)(implicit pos: Position): Unit = try super.after(fun) finally closeDatabase() + + override protected def before(fun: => Any)(implicit pos: Position): Unit = { + val tempPath = TempDirectory("derbyUnitTest", "database").deleteOnExit().path + System.setProperty("derby.system.home", tempPath.toString) + DriverManager.registerDriver(new org.apache.derby.jdbc.EmbeddedDriver) + Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + connection = DriverManager.getConnection(connectionString) + } +} + + + diff --git a/integration-tests/src/main/scala/za/co/absa/spline/fixture/SplineFixture.scala b/integration-tests/src/main/scala/za/co/absa/spline/fixture/SplineFixture.scala index aa8c9fa5d..d5e1088f7 100644 --- a/integration-tests/src/main/scala/za/co/absa/spline/fixture/SplineFixture.scala +++ b/integration-tests/src/main/scala/za/co/absa/spline/fixture/SplineFixture.scala @@ -16,12 +16,13 @@ package za.co.absa.spline.fixture +import java.util.Properties import java.{util => ju} import 
com.mongodb.casbah.MongoDB import com.mongodb.{DBCollection, DBObject} import org.apache.commons.configuration.Configuration -import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} import org.bson.BSON import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -51,6 +52,10 @@ trait AbstractSplineFixture AbstractSplineFixture.touch() +// def withNewSession[T >: AnyRef](testBody: SparkSession => T): T = { +// testBody(spark.newSession) +// } + abstract override protected def beforeAll(): Unit = { import za.co.absa.spline.harvester.SparkLineageInitializer._ spark.enableLineageTracking() @@ -79,13 +84,19 @@ trait AsyncSplineFixture extends AbstractSplineFixture with AsyncTestSuiteMixin abstract override def withFixture(test: NoArgAsyncTest): FutureOutcome = exec { super.withFixture(test) } + + def withSplineEnabled[T](session:SparkSession)(testBody: => T) = { + import za.co.absa.spline.harvester.SparkLineageInitializer._ + session.enableLineageTracking() + testBody + } } object AbstractSplineFixture { import scala.concurrent.{ExecutionContext, Future} -// System.getProperties.setProperty(PERSISTENCE_FACTORY, classOf[TestPersistenceFactory].getName) + //System.getProperties.setProperty(PERSISTENCE_FACTORY, classOf[TestPersistenceFactory].getName) private var justCapturedLineage: DataLineage = _ @@ -121,6 +132,15 @@ object AbstractSplineFixture { df.write.mode(mode).saveAsTable(tableName) AbstractSplineFixture.justCapturedLineage } + + /** Writes dataframe to table and returns captured lineage*/ + def jdbcLineage(connectionString:String, + tableName:String, + properties:Properties = new Properties(), + mode: SaveMode = SaveMode.ErrorIfExists): DataLineage = { + df.write.mode(mode).jdbc(connectionString, tableName, properties) + AbstractSplineFixture.justCapturedLineage + } } implicit class LineageComponentSizeVerifier(lineage: DataLineage)(implicit ec: ExecutionContext) diff --git a/integration-tests/src/test/scala/za/co/absa/spline/BasicIntegrationTests.scala b/integration-tests/src/test/scala/za/co/absa/spline/BasicIntegrationTests.scala deleted file mode 100644 index 41befdd41..000000000 --- a/integration-tests/src/test/scala/za/co/absa/spline/BasicIntegrationTests.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright 2017 ABSA Group Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package za.co.absa.spline - -import org.apache.hadoop.fs.Path -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -import org.apache.spark.sql.{Row, SaveMode} -import org.scalatest._ -import za.co.absa.spline.common.TempDirectory -import za.co.absa.spline.fixture.{AbstractSplineFixture, AsyncSparkFixture, AsyncSplineFixture} -import za.co.absa.spline.model.DataLineage -import za.co.absa.spline.model.op.{Write} - -/** Contains smoke tests for basic operations.*/ -//Ignored because we cannot have two AsyncSplineFixture based tests in -// one project. 
This will work in release 4 -@Ignore class BasicIntegrationTests - extends AsyncFlatSpec - with Matchers - with AsyncSparkFixture - with AsyncSplineFixture { - - import spark.implicits._ - - "saveAsTable" should "process all operations" in { - - val df = Seq((1, 2), (3, 4)).toDF().agg(concat(sum('_1), min('_2)) as "forty_two") - val saveAsTable: DataLineage = df.saveAsTableLineage("someTable") - - spark.sql("drop table someTable") - saveAsTable.operations.length shouldBe 3 - } - - "save_to_fs" should "process all operations" in { - - val df = Seq((1, 2), (3, 4)).toDF().agg(concat(sum('_1), min('_2)) as "forty_two") - val saveToFS: DataLineage = df.writtenLineage() - - saveToFS.operations.length shouldBe 3 - } - - "saveAsTable" should "use URIS compatible with filesystem write" in { - - //When I write something to table and then read it again, Spline have to use matching URI. - - val tableName = "externalTable" - val dir = TempDirectory ("sparkunit", "table", false).deleteOnExit() - val path = dir.path.toString.replace("\\", "/") - val sql = "create table " + tableName + " (num int) using parquet location '" + path + "' " - spark.sql(sql) - - val schema: StructType = StructType(List(StructField("num", IntegerType, true))) - val data = spark.sparkContext.parallelize(Seq(Row(1), Row(3))) - val inputDf = spark.sqlContext.createDataFrame(data, schema) - - val writeToTable: DataLineage = inputDf.saveAsTableLineage(tableName, SaveMode.Append) - - val write1: Write = writeToTable.operations.filter(_.isInstanceOf[Write]).head.asInstanceOf[Write] - val saveAsTablePath = write1.path - - AbstractSplineFixture.resetCapturedLineage - spark.sql("drop table " + tableName) - - val readFromTable: DataLineage = inputDf.writtenLineage(path, SaveMode.Overwrite) - - val writeOperation = readFromTable.operations.filter(_.isInstanceOf[Write]).head.asInstanceOf[Write] - val write2 = writeOperation.path - - saveAsTablePath shouldBe write2 - } - - "saveAsTable" should "use table path as identifier when writing to external table" in { - val dir = TempDirectory ("sparkunit", "table", true).deleteOnExit() - val expectedPath = dir.path.toUri.toURL - val path = dir.path.toString.replace("\\", "/") - val sql = "create table e_table(num int) using parquet location '" + path + "' " - spark.sql(sql) - - val schema: StructType = StructType(List(StructField("num", IntegerType, true))) - val data = spark.sparkContext.parallelize(Seq(Row(1), Row(3))) - val df = spark.sqlContext.createDataFrame(data, schema) - - val writeToTable: DataLineage = df.saveAsTableLineage("e_table", SaveMode.Append) - - val writeOperation: Write = writeToTable.operations.filter(_.isInstanceOf[Write]).head.asInstanceOf[Write] - - new Path(writeOperation.path).toUri.toURL shouldBe expectedPath - } - - -} diff --git a/integration-tests/src/test/scala/za/co/absa/spline/JDBCWriteTest.scala b/integration-tests/src/test/scala/za/co/absa/spline/JDBCWriteTest.scala new file mode 100644 index 000000000..1448c9a40 --- /dev/null +++ b/integration-tests/src/test/scala/za/co/absa/spline/JDBCWriteTest.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spline + +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Row, SaveMode} +import org.scalatest._ +import za.co.absa.spline.fixture.{AsyncSparkFixture, AsyncSplineFixture, DerbyDatabaseFixture} +import za.co.absa.spline.model.DataLineage +import za.co.absa.spline.model.op.{BatchWrite, Write} + + +@Ignore class JDBCWriteTest extends AsyncFlatSpec + with Matchers + with AsyncSparkFixture + with AsyncSplineFixture + with DerbyDatabaseFixture { + + val tableName = "testTable" + + val testData: DataFrame = + withNewSession( session => { + withSplineEnabled(session) { + val schema = StructType(StructField("ID", IntegerType, false) :: StructField("NAME", StringType, false) :: Nil) + val rdd = spark.sparkContext.parallelize(Row(1014, "Warsaw") :: Row(1002, "Corte") :: Nil) + spark.sqlContext.createDataFrame(rdd, schema) + } + }) + + + "save_to_fs" should "process all operations" in + withNewSession( session => { + withSplineEnabled(session) { + val tableName = "someTable" + System.currentTimeMillis() + + val lineage: DataLineage = testData.jdbcLineage(connectionString, tableName, mode = SaveMode.Overwrite) + + val producedWrites = lineage.operations.filter(_.isInstanceOf[Write]).map(_.asInstanceOf[BatchWrite]) + producedWrites.size shouldBe 1 + val write = producedWrites.head + + write.path shouldBe "jdbc://" + connectionString + ":" + tableName + write.append shouldBe false + } + }) +} + + diff --git a/integration-tests/src/test/scala/za/co/absa/spline/SaveAsTableIntegrationTests.scala b/integration-tests/src/test/scala/za/co/absa/spline/SaveAsTableIntegrationTests.scala new file mode 100644 index 000000000..9f10c5858 --- /dev/null +++ b/integration-tests/src/test/scala/za/co/absa/spline/SaveAsTableIntegrationTests.scala @@ -0,0 +1,121 @@ +/* + * Copyright 2017 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spline + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.{Row, SaveMode} +import org.scalatest._ +import za.co.absa.spline.common.TempDirectory +import za.co.absa.spline.fixture.{AbstractSplineFixture, AsyncSparkFixture, AsyncSplineFixture} +import za.co.absa.spline.model.DataLineage +import za.co.absa.spline.model.op.Write + +/** Contains smoke tests for basic operations. */ +//Ignored because we cannot have two AsyncSplineFixture based tests in +// one project. 
This will work in release 4 +@Ignore class SaveAsTableIntegrationTests + extends AsyncFlatSpec + with Matchers + with AsyncSparkFixture + with AsyncSplineFixture { + + import spark.implicits._ + + "saveAsTable" should "process all operations" in + withNewSession( session => { + withSplineEnabled(session) { + val df = Seq((1, 2), (3, 4)).toDF().agg(concat(sum('_1), min('_2)) as "forty_two") + val saveAsTable: DataLineage = df.saveAsTableLineage("someTable") + + session.sql("drop table someTable") + saveAsTable.operations.length shouldBe 3 + } + } + ) + + "save_to_fs" should "process all operations" in + withNewSession( session => { + withSplineEnabled(session) { + + val df = Seq((1, 2), (3, 4)).toDF().agg(concat(sum('_1), min('_2)) as "forty_two") + val saveToFS: DataLineage = df.writtenLineage() + + saveToFS.operations.length shouldBe 3 + } + } + ) + + "saveAsTable" should "use URIS compatible with filesystem write" in + withNewSession( session => { + withSplineEnabled(session) { + + //When I write something to table and then read it again, Spline have to use matching URI. + + val tableName = "externalTable" + val dir = TempDirectory("sparkunit", "table", false).deleteOnExit() + val path = dir.path.toString.replace("\\", "/") + val sql = "create table " + tableName + " (num int) using parquet location '" + path + "' " + spark.sql(sql) + + val schema: StructType = StructType(List(StructField("num", IntegerType, true))) + val data = spark.sparkContext.parallelize(Seq(Row(1), Row(3))) + val inputDf = spark.sqlContext.createDataFrame(data, schema) + + val writeToTable: DataLineage = inputDf.saveAsTableLineage(tableName, SaveMode.Append) + + val write1: Write = writeToTable.operations.filter(_.isInstanceOf[Write]).head.asInstanceOf[Write] + val saveAsTablePath = write1.path + + AbstractSplineFixture.resetCapturedLineage + spark.sql("drop table " + tableName) + + val readFromTable: DataLineage = inputDf.writtenLineage(path, SaveMode.Overwrite) + + val writeOperation = readFromTable.operations.filter(_.isInstanceOf[Write]).head.asInstanceOf[Write] + val write2 = writeOperation.path + + saveAsTablePath shouldBe write2 + } + } + ) + + "saveAsTable" should "use table path as identifier when writing to external table" in + withNewSession( session => { + withSplineEnabled(session) { + val dir = TempDirectory("sparkunit", "table", true).deleteOnExit() + val expectedPath = dir.path.toUri.toURL + val path = dir.path.toString.replace("\\", "/") + val sql = "create table e_table(num int) using parquet location '" + path + "' " + spark.sql(sql) + + val schema: StructType = StructType(List(StructField("num", IntegerType, true))) + val data = spark.sparkContext.parallelize(Seq(Row(1), Row(3))) + val df = spark.sqlContext.createDataFrame(data, schema) + + val writeToTable: DataLineage = df.saveAsTableLineage("e_table", SaveMode.Append) + + val writeOperation: Write = writeToTable.operations.filter(_.isInstanceOf[Write]).head.asInstanceOf[Write] + + new Path(writeOperation.path).toUri.toURL shouldBe expectedPath + } + } + ) + + +} diff --git a/spark-adapter-2.2/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala b/spark-adapter-2.2/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala index 2deb05140..3c5ed7c56 100644 --- a/spark-adapter-2.2/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala +++ b/spark-adapter-2.2/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala @@ -25,6 +25,8 @@ class 
WriteCommandParserFactoryImpl extends WriteCommandParserFactory { override def writeParser(): WriteCommandParser[LogicalPlan] = new WriteCommandParserImpl() override def saveAsTableParser(clusterUrl: Option[String]): WriteCommandParser[LogicalPlan] = new SaveAsTableCommandParserImpl(clusterUrl) + + override def jdbcParser(): WriteCommandParser[LogicalPlan] = new SaveJdbcCommandParserImpl() } class SaveAsTableCommandParserImpl(clusterUrl: Option[String]) extends WriteCommandParser[LogicalPlan] { @@ -35,18 +37,34 @@ class SaveAsTableCommandParserImpl(clusterUrl: Option[String]) extends WriteComm val identifier = op.table.storage.locationUri match { case Some(location) => location.toURL.toString - case _ => { + case _ => val codec = new URLCodec() - SaveAsTableCommand.protocolPrefix + - codec.encode(clusterUrl.getOrElse("default")) + ":" + - codec.encode(op.table.identifier.database.getOrElse("default")) +":" + + URIPrefixes.managedTablePrefix + + codec.encode(clusterUrl.getOrElse(throw new IllegalArgumentException("Unknown cluster name."))) + ":" + + codec.encode(op.table.identifier.database.getOrElse("default")) + ":" + codec.encode(op.table.identifier.table) - } } SaveAsTableCommand(identifier, op.mode, "table", op.query) } } +class SaveJdbcCommandParserImpl extends WriteCommandParser[LogicalPlan] { + override def matches(operation: LogicalPlan): Boolean = { + operation.isInstanceOf[SaveIntoDataSourceCommand] && + operation.asInstanceOf[SaveIntoDataSourceCommand].provider == "jdbc" + } + + override def asWriteCommand(operation: LogicalPlan): AbstractWriteCommand = { + operation match { + case op:SaveIntoDataSourceCommand => + val url = op.options.getOrElse("url", throw new NoSuchElementException("Cannot get name of JDBC connection string.")) + val table = op.options.getOrElse("dbtable", throw new NoSuchElementException("Cannot get name of JDBC table.")) + val identifier = s"${URIPrefixes.jdbcTablePrefix}$url:$table" + SaveJDBCCommand(identifier, op.mode, "jdbc", op.query) + } + } +} + class WriteCommandParserImpl extends WriteCommandParser[LogicalPlan] { override def matches(operation: LogicalPlan): Boolean = { diff --git a/spark-adapter-2.3/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala b/spark-adapter-2.3/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala index 7b2f5f66e..0cb824b6f 100644 --- a/spark-adapter-2.3/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala +++ b/spark-adapter-2.3/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala @@ -19,12 +19,14 @@ package za.co.absa.spline.sparkadapterapi import org.apache.commons.codec.net.URLCodec import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand -import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider +import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, SaveIntoDataSourceCommand} class WriteCommandParserFactoryImpl extends WriteCommandParserFactory { override def writeParser(): WriteCommandParser[LogicalPlan] = new WriteCommandParserImpl() override def saveAsTableParser(clusterUrl: Option[String]): WriteCommandParser[LogicalPlan] = new SaveAsTableCommandParserImpl(clusterUrl) + override def jdbcParser(): WriteCommandParser[LogicalPlan] = new SaveJdbcCommandParserImpl } class 
SaveAsTableCommandParserImpl(clusterUrl: Option[String]) extends WriteCommandParser[LogicalPlan] { @@ -34,17 +36,32 @@ class SaveAsTableCommandParserImpl(clusterUrl: Option[String]) extends WriteComm val op = operation.asInstanceOf[CreateDataSourceTableAsSelectCommand] val identifier = op.table.storage.locationUri match { case Some(location) => location.toURL.toString - case _ => { + case _ => val codec = new URLCodec() - SaveAsTableCommand.protocolPrefix + - codec.encode(clusterUrl.getOrElse("default")) + ":" + + URIPrefixes.managedTablePrefix + + codec.encode(clusterUrl.getOrElse(throw new IllegalArgumentException("Unknown cluster name."))) + ":" + codec.encode(op.table.identifier.database.getOrElse("default")) + ":" + codec.encode(op.table.identifier.table) - } } SaveAsTableCommand(identifier, op.mode, "table", op.query) } } +class SaveJdbcCommandParserImpl extends WriteCommandParser[LogicalPlan] { + override def matches(operation: LogicalPlan): Boolean = { + operation.isInstanceOf[SaveIntoDataSourceCommand] && + operation.asInstanceOf[SaveIntoDataSourceCommand].dataSource.isInstanceOf[JdbcRelationProvider] + } + + override def asWriteCommand(operation: LogicalPlan): AbstractWriteCommand = { + operation match { + case op:SaveIntoDataSourceCommand => + val url = op.options.getOrElse("url", throw new NoSuchElementException("Cannot get name of JDBC connection string.")) + val table = op.options.getOrElse("dbtable", throw new NoSuchElementException("Cannot get name of JDBC table.")) + val identifier = s"${URIPrefixes.jdbcTablePrefix}$url:$table" + SaveJDBCCommand(identifier, op.mode, "jdbc", op.query) + } + } +} class WriteCommandParserImpl extends WriteCommandParser[LogicalPlan] { diff --git a/spark-adapter-2.4/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala b/spark-adapter-2.4/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala index 91ad9487f..f676d11ef 100644 --- a/spark-adapter-2.4/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala +++ b/spark-adapter-2.4/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParserImpl.scala @@ -19,13 +19,16 @@ package za.co.absa.spline.sparkadapterapi import org.apache.commons.codec.net.URLCodec import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand -import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider +import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, SaveIntoDataSourceCommand} class WriteCommandParserFactoryImpl extends WriteCommandParserFactory { override def writeParser(): WriteCommandParser[LogicalPlan] = new WriteCommandParserImpl() override def saveAsTableParser(clusterUrl: Option[String]): WriteCommandParser[LogicalPlan] = new SaveAsTableCommandParserImpl(clusterUrl) + + override def jdbcParser(): WriteCommandParser[LogicalPlan] = new SaveJdbcCommandParserImpl() } class SaveAsTableCommandParserImpl(clusterUrl: Option[String]) extends WriteCommandParser[LogicalPlan] { @@ -36,19 +39,36 @@ class SaveAsTableCommandParserImpl(clusterUrl: Option[String]) extends WriteComm val identifier = op.table.storage.locationUri match { case Some(location) => location.toURL.toString - case _ => { + case _ => val codec = new URLCodec() - SaveAsTableCommand.protocolPrefix + - codec.encode(clusterUrl.getOrElse("default")) + ":" + + 
URIPrefixes.managedTablePrefix + + codec.encode(clusterUrl.getOrElse(throw new IllegalArgumentException("Unknown cluster name."))) + ":" + codec.encode(op.table.identifier.database.getOrElse("default")) + ":" + codec.encode(op.table.identifier.table) - } } SaveAsTableCommand(identifier, op.mode, "table", op.query) } } +class SaveJdbcCommandParserImpl extends WriteCommandParser[LogicalPlan] { + override def matches(operation: LogicalPlan): Boolean = { + operation.isInstanceOf[SaveIntoDataSourceCommand] && + operation.asInstanceOf[SaveIntoDataSourceCommand].dataSource.isInstanceOf[JdbcRelationProvider] + } + + override def asWriteCommand(operation: LogicalPlan): AbstractWriteCommand = { + operation match { + case op:SaveIntoDataSourceCommand => + val url = op.options.getOrElse("url", throw new NoSuchElementException("Cannot get name of JDBC connection string.")) + val table = op.options.getOrElse("dbtable", throw new NoSuchElementException("Cannot get name of JDBC table.")) + val identifier = s"${URIPrefixes.jdbcTablePrefix}$url:$table" + SaveJDBCCommand(identifier, op.mode, "jdbc", op.query) + } + } +} + + class WriteCommandParserImpl extends WriteCommandParser[LogicalPlan] { override def matches(operation: LogicalPlan): Boolean = { diff --git a/spark-adapter-api/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParser.scala b/spark-adapter-api/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParser.scala index 8acc4b315..16aa8ac82 100644 --- a/spark-adapter-api/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParser.scala +++ b/spark-adapter-api/src/main/scala/za/co/absa/spline/sparkadapterapi/WriteCommandParser.scala @@ -35,6 +35,7 @@ abstract class WriteCommandParser[T <: LogicalPlan](implicit tag: ClassTag[T]) { abstract class WriteCommandParserFactory { def writeParser(): WriteCommandParser[LogicalPlan] def saveAsTableParser(clusterUrl: Option[String]) : WriteCommandParser[LogicalPlan] + def jdbcParser(): WriteCommandParser[LogicalPlan] } object WriteCommandParserFactory extends AdapterFactory[WriteCommandParserFactory] @@ -47,9 +48,10 @@ case class WriteCommand(path:String, mode: SaveMode, format: String, query: Logi case class SaveAsTableCommand(tableName:String, mode: SaveMode, format: String, query: LogicalPlan) extends AbstractWriteCommand -object SaveAsTableCommand { +object URIPrefixes { //prefix used in identifiers for saveAsTable writes - val protocolPrefix = "table://" + val managedTablePrefix = "table://" + val jdbcTablePrefix = "jdbc://" } case class SaveJDBCCommand(tableName:String, mode: SaveMode, format: String, query: LogicalPlan) extends AbstractWriteCommand \ No newline at end of file diff --git a/test-commons/src/main/scala/za/co/absa/spline/fixture/SparkFixture.scala b/test-commons/src/main/scala/za/co/absa/spline/fixture/SparkFixture.scala index 720eb0666..16c5fb988 100644 --- a/test-commons/src/main/scala/za/co/absa/spline/fixture/SparkFixture.scala +++ b/test-commons/src/main/scala/za/co/absa/spline/fixture/SparkFixture.scala @@ -30,7 +30,7 @@ trait AbstractSparkFixture { protected implicit lazy val sparkContext: SparkContext = spark.sparkContext protected implicit lazy val sqlContext: SQLContext = spark.sqlContext - def withNewSession[T >: AnyRef](testBody: SparkSession => T): T = { + def withNewSession[T](testBody: SparkSession => T): T = { testBody(spark.newSession) } }
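
A minimal, self-contained sketch (not part of the patch) of the write pattern the new jdbcParser() is meant to capture. It assumes an embedded Derby driver on the classpath and, as in the tests above, Spline lineage tracking enabled on the session via SparkLineageInitializer; the object name and JDBC URL below are illustrative only. Per the adapters above, the captured write identifier is expected to be URIPrefixes.jdbcTablePrefix + url + ":" + table.

import java.util.Properties
import org.apache.spark.sql.SparkSession

// Illustrative sketch, not part of the patch.
object JdbcLineageSketch extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("jdbc-lineage-sketch")
    .getOrCreate()
  import spark.implicits._

  // Hypothetical in-memory Derby connection string and table name,
  // mirroring DerbyDatabaseFixture / JDBCWriteTest above.
  val url = "jdbc:derby:memory:splineTestDb;create=true"
  val table = "testTable"

  // A df.write.jdbc(...) call is planned as a SaveIntoDataSourceCommand
  // backed by JdbcRelationProvider, which SaveJdbcCommandParserImpl.matches
  // accepts and turns into a SaveJDBCCommand.
  Seq((1014, "Warsaw"), (1002, "Corte"))
    .toDF("ID", "NAME")
    .write
    .jdbc(url, table, new Properties())

  // Expected identifier format of the captured BatchWrite operation:
  // "jdbc://" + url + ":" + table
  println(s"jdbc://$url:$table")

  spark.stop()
}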