[SPARK-24991][SQL] use InternalRow in DataSourceWriter
## What changes were proposed in this pull request?

A follow-up of #21118

Since we use `InternalRow` in the read API of data source v2, we should do the same thing for the write API.
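
For implementers, the practical effect is that `DataSourceWriter.createWriterFactory()` now returns a `DataWriterFactory[InternalRow]` directly, and the `SupportsWriteInternalRow` mix-in (and the `Row`-to-`InternalRow` bridge) goes away. A minimal sketch of a v2 writer against the API as it stands after this patch — the `Noop*` classes and `RowCountMessage` are hypothetical, not part of this change:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}

// Hypothetical commit message carrying the number of rows a task wrote.
case class RowCountMessage(count: Long) extends WriterCommitMessage

// The writer hands Spark an InternalRow-based factory directly; no
// SupportsWriteInternalRow mix-in and no Row-to-InternalRow wrapper any more.
class NoopDataSourceWriter extends DataSourceWriter {
  override def createWriterFactory(): DataWriterFactory[InternalRow] = new NoopWriterFactory

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    // All tasks committed; a real sink would make the written data visible here.
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    // Roll back whatever the tasks may have written.
  }
}

class NoopWriterFactory extends DataWriterFactory[InternalRow] {
  override def createDataWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = new NoopDataWriter
}

class NoopDataWriter extends DataWriter[InternalRow] {
  private var count = 0L

  // Rows now arrive as InternalRow; read fields with the typed getters, e.g. row.getLong(0).
  override def write(row: InternalRow): Unit = count += 1

  override def commit(): WriterCommitMessage = RowCountMessage(count)

  override def abort(): Unit = ()
}
```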

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #21948 from cloud-fan/row-write.
cloud-fan committed Aug 6, 2018
1 parent 327bb30 commit ac527b5
Showing 17 changed files with 73 additions and 230 deletions.
@@ -42,11 +42,11 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage
*/
class KafkaStreamWriter(
topic: Option[String], producerParams: Map[String, String], schema: StructType)
extends StreamWriter with SupportsWriteInternalRow {
extends StreamWriter {

validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)

override def createInternalRowWriterFactory(): KafkaStreamWriterFactory =
override def createWriterFactory(): KafkaStreamWriterFactory =
KafkaStreamWriterFactory(topic, producerParams, schema)

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
@@ -18,8 +18,8 @@
package org.apache.spark.sql.sources.v2.writer;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.StreamWriteSupport;
import org.apache.spark.sql.sources.v2.WriteSupport;
@@ -61,7 +61,7 @@ public interface DataSourceWriter {
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
DataWriterFactory<Row> createWriterFactory();
DataWriterFactory<InternalRow> createWriterFactory();

/**
* Returns whether Spark should use the commit coordinator to ensure that at most one task for
@@ -53,9 +53,7 @@
* successfully, and have a way to revert committed data writers without the commit message, because
* Spark only accepts the commit message that arrives first and ignores others.
*
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
* source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers
* that mix in {@link SupportsWriteInternalRow}.
* Note that currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
*/
@InterfaceStability.Evolving
public interface DataWriter<T> {
@@ -33,7 +33,10 @@
public interface DataWriterFactory<T> extends Serializable {

/**
* Returns a data writer to do the actual writing work.
* Returns a data writer to do the actual writing work. Note that Spark will reuse the same data
* object instance when sending data to the data writer, for better performance. Data writers
* are responsible for making defensive copies if necessary, e.g. copying the data before
* buffering it in a list.
*
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
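
The reuse caveat added above has a concrete consequence: Spark hands the writer the same `InternalRow` instance on every `write()` call, so anything that buffers rows must `copy()` them first. A rough sketch of the pattern (the `BufferingDataWriter` and `BufferedRowsMessage` names are invented; `PackedRowDataWriter` further down in this diff does exactly this):

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Hypothetical commit message holding every row the task buffered.
case class BufferedRowsMessage(rows: Array[InternalRow]) extends WriterCommitMessage

class BufferingDataWriter extends DataWriter[InternalRow] {
  private val buffer = mutable.ArrayBuffer.empty[InternalRow]

  override def write(row: InternalRow): Unit = {
    // Spark will reuse `row` for the next record, so store a copy, not the reference.
    buffer += row.copy()
  }

  override def commit(): WriterCommitMessage = BufferedRowsMessage(buffer.toArray)

  override def abort(): Unit = buffer.clear()
}
```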

This file was deleted.

@@ -50,11 +50,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
override def output: Seq[Attribute] = Nil

override protected def doExecute(): RDD[InternalRow] = {
val writeTask = writer match {
case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
}

val writeTask = writer.createWriterFactory()
val useCommitCoordinator = writer.useCommitCoordinator
val rdd = query.execute()
val messages = new Array[WriterCommitMessage](rdd.partitions.length)
@@ -155,27 +151,3 @@ object DataWritingSparkTask extends Logging {
})
}
}

class InternalRowDataWriterFactory(
rowWriterFactory: DataWriterFactory[Row],
schema: StructType) extends DataWriterFactory[InternalRow] {

override def createDataWriter(
partitionId: Int,
taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
new InternalRowDataWriter(
rowWriterFactory.createDataWriter(partitionId, taskId, epochId),
RowEncoder.apply(schema).resolveAndBind())
}
}

class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row])
extends DataWriter[InternalRow] {

override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record))

override def commit(): WriterCommitMessage = rowWriter.commit()

override def abort(): Unit = rowWriter.abort()
}
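
The deleted `InternalRowDataWriterFactory`/`InternalRowDataWriter` pair was the bridge that let `Row`-based writers keep working; with it gone, a sink that still wants external `Row` values has to do the conversion inside its own writer. A hedged sketch of that conversion, using the same `RowEncoder` approach the removed wrapper relied on (`RowBasedDataWriter`, its `writeRow` callback, and `EmptyCommitMessage` are hypothetical):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

// Hypothetical marker message for a sink with nothing to report per task.
case object EmptyCommitMessage extends WriterCommitMessage

class RowBasedDataWriter(schema: StructType, writeRow: Row => Unit)
  extends DataWriter[InternalRow] {

  // Resolve and bind the encoder once, exactly as the removed InternalRowDataWriter did,
  // then use it to turn each incoming InternalRow back into an external Row.
  private val encoder: ExpressionEncoder[Row] = RowEncoder(schema).resolveAndBind()

  override def write(record: InternalRow): Unit = writeRow(encoder.fromRow(record))

  override def commit(): WriterCommitMessage = EmptyCommitMessage

  override def abort(): Unit = ()
}
```

A sink that goes this route pays an encoder conversion per record; the built-in sinks touched by this patch (Kafka, console, foreach) consume `InternalRow` directly instead.
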
@@ -28,10 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp,
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.{Clock, Utils}

@@ -498,12 +497,7 @@ class MicroBatchExecution(
newAttributePlan.schema,
outputMode,
new DataSourceOptions(extraOptions.asJava))
if (writer.isInstanceOf[SupportsWriteInternalRow]) {
WriteToDataSourceV2(
new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan)
} else {
WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
}
WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
}

@@ -17,13 +17,10 @@

package org.apache.spark.sql.execution.streaming.continuous

import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
import org.apache.spark.util.Utils

/**
@@ -47,7 +44,6 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
SparkEnv.get)
EpochTracker.initializeCurrentEpoch(
context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)

while (!context.isInterrupted() && !context.isCompleted()) {
var dataWriter: DataWriter[InternalRow] = null
// write the data and commit this writer.
@@ -19,18 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous

import scala.util.control.NonFatal

import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory}
import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.util.Utils

/**
* The physical plan for writing data into a continuous processing [[StreamWriter]].
@@ -41,11 +37,7 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
override def output: Seq[Attribute] = Nil

override protected def doExecute(): RDD[InternalRow] = {
val writerFactory = writer match {
case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
}

val writerFactory = writer.createWriterFactory()
val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)

logInfo(s"Start processing data source writer: $writer. " +
@@ -17,10 +17,10 @@

package org.apache.spark.sql.execution.streaming.sources

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
assert(SparkSession.getActiveSession.isDefined)
protected val spark = SparkSession.getActiveSession.get

def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
// We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
@@ -62,8 +62,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
println(printMessage)
println("-------------------------------------------")
// scalastyle:off println
spark
.createDataFrame(rows.toList.asJava, schema)
Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
.show(numRowsToShow, isTruncated)
}

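With `spark.createDataFrame` no longer usable here (it expects external `Row`s), the console sink above rebuilds a DataFrame by wrapping the collected `InternalRow`s in a `LocalRelation` and going through `Dataset.ofRows`. Both are Spark-internal APIs (`Dataset.ofRows` is `private[sql]`), so the pattern only applies to code living inside Spark's own `org.apache.spark.sql` packages; the helper object below is made up for illustration:

```scala
// Must live under org.apache.spark.sql to access the private[sql] Dataset.ofRows.
package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.StructType

object InternalRowDisplay {
  // Turn buffered InternalRows back into a DataFrame without round-tripping
  // through external Row objects, then print it the way the console sink does.
  def show(spark: SparkSession, schema: StructType, rows: Seq[InternalRow], numRows: Int): Unit = {
    val df = Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
    df.show(numRows, truncate = false)
  }
}
```
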
@@ -17,13 +17,13 @@

package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession}
import org.apache.spark.sql.{ForeachWriter, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.python.PythonForeachWriter
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@@ -46,11 +46,11 @@ case class ForeachWriterProvider[T](
schema: StructType,
mode: OutputMode,
options: DataSourceOptions): StreamWriter = {
new StreamWriter with SupportsWriteInternalRow {
new StreamWriter {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
override def createWriterFactory(): DataWriterFactory[InternalRow] = {
val rowConverter: InternalRow => T = converter match {
case Left(enc) =>
val boundEnc = enc.resolveAndBind(
@@ -17,9 +17,8 @@

package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter

/**
@@ -34,21 +33,5 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWr

override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)

override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory()
}

class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter)
extends DataSourceWriter with SupportsWriteInternalRow {
override def commit(messages: Array[WriterCommitMessage]): Unit = {
writer.commit(batchId, messages)
}

override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)

override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] =
writer match {
case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
case _ => throw new IllegalStateException(
"InternalRowMicroBatchWriter should only be created with base writer support")
}
override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory()
}
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources
import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}

/**
@@ -30,11 +30,11 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat
* Note that, because it sends all rows to the driver, this factory will generally be unsuitable
* for production-quality sinks. It's intended for use in tests.
*/
case object PackedRowWriterFactory extends DataWriterFactory[Row] {
case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] {
override def createDataWriter(
partitionId: Int,
taskId: Long,
epochId: Long): DataWriter[Row] = {
epochId: Long): DataWriter[InternalRow] = {
new PackedRowDataWriter()
}
}
@@ -43,15 +43,16 @@ case object PackedRowWriterFactory extends DataWriterFactory[Row] {
* Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most
* recent interval.
*/
case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage
case class PackedRowCommitMessage(rows: Array[InternalRow]) extends WriterCommitMessage

/**
* A simple [[DataWriter]] that just sends all the rows it's received as a commit message.
*/
class PackedRowDataWriter() extends DataWriter[Row] with Logging {
private val data = mutable.Buffer[Row]()
class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging {
private val data = mutable.Buffer[InternalRow]()

override def write(row: Row): Unit = data.append(row)
// Spark reuses the same `InternalRow` instance, so we copy it here before buffering it.
override def write(row: InternalRow): Unit = data.append(row.copy())

override def commit(): PackedRowCommitMessage = {
val msg = PackedRowCommitMessage(data.toArray)