diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 30f5fced5a8bf..c8dcf8b7cf92f 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -36,6 +36,15 @@ private[spark] object RpcUtils { rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } + def makeDriverRef( + name: String, + driverHost: String, + driverPort: Int, + rpcEnv: RpcEnv): RpcEndpointRef = { + Utils.checkHost(driverHost) + rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) + } + /** Returns the default Spark timeout to use for RPC ask operations. */ def askRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq(RPC_ASK_TIMEOUT.key, NETWORK_TIMEOUT.key), "120s") diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsRealTimeMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsRealTimeMode.java new file mode 100644 index 0000000000000..da2127d44b4e7 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsRealTimeMode.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read.streaming; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.read.InputPartition; + +/** + * A {@link MicroBatchStream} for streaming queries with real time mode. + * + */ +@Evolving +public interface SupportsRealTimeMode { + /** + * Returns a list of {@link InputPartition input partitions} given the start offset. Each + * {@link InputPartition} represents a data split that can be processed by one Spark task. The + * number of input partitions returned here is the same as the number of RDD partitions + * this scan outputs. + */ + InputPartition[] planInputPartitions(Offset start); + + /** + * Merge partitioned offsets coming from {@link SupportsRealTimeMode} instances + * for each partition to a single global offset. 
+   */
+  Offset mergeOffsets(PartitionOffset[] offsets);
+
+  /**
+   * Called during logical planning to inform the source that it is running in real time mode.
+   */
+  default void prepareForRealTimeMode() {}
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsRealTimeRead.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsRealTimeRead.java
new file mode 100644
index 0000000000000..5bed945432c95
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsRealTimeRead.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read.streaming;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.read.PartitionReader;
+
+/**
+ * A variation on {@link PartitionReader} for use with low latency streaming processing.
+ *
+ */
+@Evolving
+public interface SupportsRealTimeRead<T> extends PartitionReader<T> {
+
+  /**
+   * A class to represent the status of a record to be read, used as the return type of
+   * nextWithTimeout. It indicates whether the next record is available and, if the source
+   * connector provides the relevant info, the ingestion time of the record. Source connectors
+   * that expose an ingestion time include:
+   * - Kafka, when the record timestamp type is LogAppendTime
+   * - Kinesis, which has ApproximateArrivalTimestamp
+   */
+  class RecordStatus {
+    private final boolean hasRecord;
+    private final Optional<Long> recArrivalTime;
+
+    private RecordStatus(boolean hasRecord, Optional<Long> recArrivalTime) {
+      this.hasRecord = hasRecord;
+      this.recArrivalTime = recArrivalTime;
+    }
+
+    // Public factory methods to control instance creation
+    public static RecordStatus newStatusWithoutArrivalTime(boolean hasRecord) {
+      return new RecordStatus(hasRecord, Optional.empty());
+    }
+
+    public static RecordStatus newStatusWithArrivalTimeMs(Long recArrivalTime) {
+      return new RecordStatus(true, Optional.of(recArrivalTime));
+    }
+
+    public boolean hasRecord() {
+      return hasRecord;
+    }
+
+    public Optional<Long> recArrivalTime() {
+      return recArrivalTime;
+    }
+  }
+
+  /**
+   * Get the offset of the next record, or the start offset if no records have been read.
+   * <p>
+   * The execution engine will call this method along with get() to keep track of the current
+   * offset. When a task ends, the offset in each partition will be passed back to the driver.
+   * They will be used as the start offsets of the next batch.
+   */
+  PartitionOffset getOffset();
+
+  /**
+   * An alternative to next() that advances to the next record. The difference from next() is
+   * that, if no more records are currently available, the call keeps waiting until the
+   * timeout expires.
+   * @param timeout return if no record is available after this timeout (milliseconds)
+   * @return {@link RecordStatus} describing whether a record is available and its arrival time
+   * @throws IOException
+   */
+  RecordStatus nextWithTimeout(Long timeout) throws IOException;
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionOffsetWithIndex.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionOffsetWithIndex.scala
new file mode 100644
index 0000000000000..97f8e698cb9e0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionOffsetWithIndex.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal.connector;
+
+import java.io.Serializable;
+
+import org.apache.spark.sql.connector.read.streaming.PartitionOffset;
+
+/**
+ * Internal class used in real time mode to pass partition offsets from executors to the driver.
+ */
+private[sql] case class PartitionOffsetWithIndex(index: Long, partitionOffset: PartitionOffset)
+  extends Serializable;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala
new file mode 100644
index 0000000000000..32f3275d9f253
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import scala.jdk.OptionConverters._ + +import org.apache.spark.{SparkContext, TaskContext} +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.connector.read.{ + InputPartition, + PartitionReader, + PartitionReaderFactory, + Scan, + SupportsReportStatistics +} +import org.apache.spark.sql.connector.read.streaming.{ + MicroBatchStream, + Offset, + PartitionOffset, + SupportsRealTimeMode, + SupportsRealTimeRead +} +import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus +import org.apache.spark.sql.internal.connector.PartitionOffsetWithIndex +import org.apache.spark.util.{Clock, CollectionAccumulator, ManualClock, SystemClock} +import org.apache.spark.util.ArrayImplicits._ + +/* The singleton object to control the time in testing */ +object LowLatencyClock { + private var clock: Clock = new SystemClock + + def getClock: Clock = clock + + def getTimeMillis(): Long = { + clock.getTimeMillis() + } + + def waitTillTime(targetTime: Long): Unit = { + clock.waitTillTime(targetTime) + } + + /* Below methods are only for testing. */ + def setClock(inputClock: Clock): Unit = { + clock = inputClock + } +} + +/** + * A wrap reader that turns a Partition Reader extending SupportsRealTimeRead to a + * normal PartitionReader and follow the task termination time `lowLatencyEndTime`, and + * report end offsets in the end to `endOffsets`. + */ +case class LowLatencyReaderWrap( + reader: SupportsRealTimeRead[InternalRow], + lowLatencyEndTime: Long, + endOffsets: CollectionAccumulator[PartitionOffsetWithIndex]) + extends PartitionReader[InternalRow] { + + override def next(): Boolean = { + val curTime = LowLatencyClock.getTimeMillis() + val ret = if (curTime >= lowLatencyEndTime) { + RecordStatus.newStatusWithoutArrivalTime(false) + } else { + reader.nextWithTimeout(lowLatencyEndTime - curTime) + } + + if (!ret.hasRecord) { + // The way of using TaskContext.get().partitionId() to map to a partition + // may be fragile. + endOffsets.add( + new PartitionOffsetWithIndex(TaskContext.get().partitionId(), reader.getOffset) + ) + } + ret.hasRecord + } + + override def get(): InternalRow = { + reader.get() + } + + override def close(): Unit = {} +} + +/** + * Wrapper factory that creates LowLatencyReaderWrap from reader as SupportsRealTimeRead + */ +case class LowLatencyReaderFactoryWrap( + partitionReaderFactory: PartitionReaderFactory, + lowLatencyEndTime: Long, + endOffsets: CollectionAccumulator[PartitionOffsetWithIndex]) + extends PartitionReaderFactory + with Logging { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val rowReader = partitionReaderFactory.createReader(partition) + assert(rowReader.isInstanceOf[SupportsRealTimeRead[InternalRow]]) + logInfo( + log"Creating low latency PartitionReader, stopping at " + + log"${MDC(LogKeys.TO_TIME, lowLatencyEndTime)}" + ) + LowLatencyReaderWrap( + rowReader.asInstanceOf[SupportsRealTimeRead[InternalRow]], + lowLatencyEndTime, + endOffsets + ) + } +} + +/** + * Physical plan node for scanning a micro-batch of data from a data source. 
+ */ +case class RealTimeStreamScanExec( + output: Seq[Attribute], + @transient scan: Scan, + @transient stream: MicroBatchStream, + @transient start: Offset, + batchDurationMs: Long) + extends DataSourceV2ScanExecBase { + + override def keyGroupedPartitioning: Option[Seq[Expression]] = None + + override def ordering: Option[Seq[SortOrder]] = None + + val endOffsetsAccumulator: CollectionAccumulator[PartitionOffsetWithIndex] = { + assert(stream.isInstanceOf[SupportsRealTimeMode]) + SparkContext.getActive.map(_.collectionAccumulator[PartitionOffsetWithIndex]).get + } + + // There is a rule for the case of TransformWithState + Initial state in realtime mode + // that we overwrite the batch duration to 0 for the first batch. We include + // batchDurationMs in the equals/hashCode methods for the rule to take effect, since + // rule executor will determine the effectiveness of the rule through fast equal. + override def equals(other: Any): Boolean = other match { + case other: RealTimeStreamScanExec => + this.stream == other.stream && + this.batchDurationMs == other.batchDurationMs + case _ => false + } + + override def hashCode(): Int = Objects.hashCode(stream, batchDurationMs) + + override lazy val readerFactory: PartitionReaderFactory = stream.createReaderFactory() + + override lazy val inputPartitions: Seq[InputPartition] = { + val lls = stream.asInstanceOf[SupportsRealTimeMode] + assert(lls != null) + lls.planInputPartitions(start).toImmutableArraySeq + } + + override def simpleString(maxFields: Int): String = + s"${super.simpleString(maxFields)} [batchDurationMs=${batchDurationMs}ms]" + + override lazy val inputRDD: RDD[InternalRow] = { + // For RTM task monitoring + sparkContext.setLocalProperty("rtmBatchDurationMs", batchDurationMs.toString) + + val inputRDD = new DataSourceRDD( + sparkContext, + partitions, + LowLatencyReaderFactoryWrap( + readerFactory, + LowLatencyClock.getTimeMillis() + batchDurationMs, + endOffsetsAccumulator + ), + supportsColumnar, + customMetrics + ) + postDriverMetrics() + inputRDD + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala index 68eb3cc7688d2..5631444145c6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala @@ -158,6 +158,20 @@ case class MemoryStream[A : Encoder]( id: Int, sqlContext: SQLContext, numPartitions: Option[Int] = None) + extends MemoryStreamBaseClass[A]( + id, sqlContext, numPartitions = numPartitions) + +/** + * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] + * is intended for use in unit tests as it can only replay data when the object is still + * available. + * + * If numPartitions is provided, the rows will be redistributed to the given number of partitions. 
+ */ +abstract class MemoryStreamBaseClass[A: Encoder]( + id: Int, + sqlContext: SQLContext, + numPartitions: Option[Int] = None) extends MemoryStreamBase[A](sqlContext) with MicroBatchStream with SupportsTriggerAvailableNow @@ -298,7 +312,6 @@ case class MemoryStream[A : Encoder]( } } - class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition object MemoryStreamReaderFactory extends PartitionReaderFactory { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala new file mode 100644 index 0000000000000..260695b0f79ce --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable.ListBuffer + +import org.json4s.{Formats, NoTypeHints} +import org.json4s.jackson.Serialization + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} +import org.apache.spark.sql.{Encoder, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.connector.read.streaming.{ + Offset => OffsetV2, + PartitionOffset, + ReadLimit, + SupportsRealTimeMode, + SupportsRealTimeRead +} +import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus +import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock +import org.apache.spark.sql.execution.streaming.runtime._ +import org.apache.spark.util.{Clock, RpcUtils} + +/** + * A low latency memory source from memory, only for unit test purpose. + * This class is very similar to ContinuousMemoryStream, except that it implements the + * interface of SupportsRealTimeMode, rather than ContinuousStream + * The overall strategy here is: + * * LowLatencyMemoryStream maintains a list of records for each partition. addData() will + * distribute records evenly-ish across partitions. + * * RecordEndpoint is set up as an endpoint for executor-side + * LowLatencyMemoryStreamInputPartitionReader instances to poll. It returns the record at + * the specified offset within the list, or null if that offset doesn't yet have a record. 
+ */ +class LowLatencyMemoryStream[A: Encoder]( + id: Int, + sqlContext: SQLContext, + numPartitions: Int = 2, + clock: Clock = LowLatencyClock.getClock) + extends MemoryStreamBaseClass[A](0, sqlContext) + with SupportsRealTimeMode { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + // LowLatencyReader implementation + + @GuardedBy("this") + private val records = Seq.fill(numPartitions)(new ListBuffer[UnsafeRow]) + + private val recordEndpoint = new ContinuousRecordEndpoint(records, this) + @volatile private var endpointRef: RpcEndpointRef = _ + + override def addData(data: IterableOnce[A]): Offset = synchronized { + // Distribute data evenly among partition lists. + data.iterator.to(Seq).zipWithIndex.map { + case (item, index) => + records(index % numPartitions) += toRow(item).copy().asInstanceOf[UnsafeRow] + } + + // The new target offset is the offset where all records in all partitions have been processed. + LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) + } + + def addData(partitionId: Int, data: IterableOnce[A]): Offset = synchronized { + require( + partitionId >= 0 && partitionId < numPartitions, + s"Partition ID $partitionId is out of bounds for $numPartitions partitions." + ) + + // Add data to the specified partition. + records(partitionId) ++= data.iterator.map(item => toRow(item).copy().asInstanceOf[UnsafeRow]) + + // The new target offset is the offset where all records in all partitions have been processed. + LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) + } + + override def initialOffset(): OffsetV2 = { + LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) + } + + override def latestOffset(startOffset: OffsetV2, limit: ReadLimit): OffsetV2 = { + LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) + } + + override def deserializeOffset(json: String): LowLatencyMemoryStreamOffset = { + LowLatencyMemoryStreamOffset(Serialization.read[Map[Int, Int]](json)) + } + + override def mergeOffsets(offsets: Array[PartitionOffset]): LowLatencyMemoryStreamOffset = { + LowLatencyMemoryStreamOffset( + offsets.map { + case ContinuousRecordPartitionOffset(part, num) => (part, num) + }.toMap + ) + } + + override def planInputPartitions(start: OffsetV2): Array[InputPartition] = { + val startOffset = start.asInstanceOf[LowLatencyMemoryStreamOffset] + synchronized { + val endpointName = s"ContinuousRecordEndpoint-${java.util.UUID.randomUUID()}-$id" + endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + startOffset.partitionNums.map { + case (part, index) => + LowLatencyMemoryStreamInputPartition( + endpointName, + endpointRef.address, + part, + index, + Int.MaxValue + ) + }.toArray + } + } + + override def planInputPartitions(start: OffsetV2, end: OffsetV2): Array[InputPartition] = { + val startOffset = start.asInstanceOf[LowLatencyMemoryStreamOffset] + val endOffset = end.asInstanceOf[LowLatencyMemoryStreamOffset] + synchronized { + val endpointName = s"ContinuousRecordEndpoint-${java.util.UUID.randomUUID()}-$id" + endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + startOffset.partitionNums.map { + case (part, index) => + LowLatencyMemoryStreamInputPartition( + endpointName, + endpointRef.address, + part, + index, + endOffset.partitionNums(part) + ) + }.toArray + } + } + + override def createReaderFactory(): PartitionReaderFactory = { + new 
LowLatencyMemoryStreamReaderFactory(clock) + } + + override def stop(): Unit = { + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + override def commit(end: OffsetV2): Unit = {} + + override def reset(): Unit = { + super.reset() + records.foreach(_.clear()) + } +} + +object LowLatencyMemoryStream { + protected val memoryStreamId = new AtomicInteger(0) + + def apply[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = + new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) + + def apply[A: Encoder](numPartitions: Int)( + implicit + sqlContext: SQLContext): LowLatencyMemoryStream[A] = + new LowLatencyMemoryStream[A]( + memoryStreamId.getAndIncrement(), + sqlContext, + numPartitions = numPartitions + ) + + def singlePartition[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = + new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1) +} + +/** + * An input partition for LowLatency memory stream. + */ +case class LowLatencyMemoryStreamInputPartition( + driverEndpointName: String, + driverEndpointAddress: RpcAddress, + partition: Int, + startOffset: Int, + endOffset: Int) + extends InputPartition + +class LowLatencyMemoryStreamReaderFactory(clock: Clock) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[LowLatencyMemoryStreamInputPartition] + new LowLatencyMemoryStreamPartitionReader( + p.driverEndpointName, + p.driverEndpointAddress, + p.partition, + p.startOffset, + p.endOffset, + clock + ) + } +} + +/** + * An input partition reader for LowLatency memory stream. + * + * Polls the driver endpoint for new records. + */ +class LowLatencyMemoryStreamPartitionReader( + driverEndpointName: String, + driverEndpointAddress: RpcAddress, + partition: Int, + startOffset: Int, + endOffset: Int, + clock: Clock) + extends SupportsRealTimeRead[InternalRow] { + // ES-1365239: Avoid tracking the ref, given that we create a new one for each partition reader + // because a new driver endpoint is created for each LowLatencyMemoryStream. If we track the ref, + // we can end up with a lot of refs (1000s) if a test suite has so many test cases and can lead to + // issues with the tracking array. Causing the test suite to be flaky. + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + driverEndpointAddress.host, + driverEndpointAddress.port, + SparkEnv.get.rpcEnv + ) + + private var currentOffset = startOffset + private var current: Option[InternalRow] = None + + // Defense-in-depth against failing to propagate the task context. Since it's not inheritable, + // we have to do a bit of error prone work to get it into every thread used by LowLatency + // processing. We hope that some unit test will end up instantiating a LowLatency memory stream + // in such cases. 
+ if (TaskContext.get() == null) { + throw new IllegalStateException("Task context was not set!") + } + override def nextWithTimeout(timeout: java.lang.Long): RecordStatus = { + val startReadTime = clock.nanoTime() + var elapsedTimeMs = 0L + current = getRecord + while (current.isEmpty) { + val POLL_TIME = 10L + if (elapsedTimeMs >= timeout) { + return RecordStatus.newStatusWithoutArrivalTime(false) + } + Thread.sleep(POLL_TIME) + current = getRecord + elapsedTimeMs = (clock.nanoTime() - startReadTime) / 1000 / 1000 + } + currentOffset += 1 + RecordStatus.newStatusWithoutArrivalTime(true) + } + + override def next(): Boolean = { + current = getRecord + if (current.isDefined) { + currentOffset += 1 + true + } else { + false + } + } + + override def get(): InternalRow = current.get + + override def close(): Unit = {} + + override def getOffset: ContinuousRecordPartitionOffset = + ContinuousRecordPartitionOffset(partition, currentOffset) + + private def getRecord: Option[InternalRow] = { + if (currentOffset >= endOffset) { + return None + } + endpoint.askSync[Option[InternalRow]]( + GetRecord(ContinuousRecordPartitionOffset(partition, currentOffset)) + ) + } +} + +case class LowLatencyMemoryStreamOffset(partitionNums: Map[Int, Int]) extends Offset { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + override def json(): String = Serialization.write(partitionNums) +}
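
For illustration only (not part of the patch): a minimal sketch of how a connector-side reader might satisfy the SupportsRealTimeRead contract introduced above. The queue-backed design and the QueuePartitionOffset/QueueRealTimeReader names are hypothetical assumptions; only SupportsRealTimeRead, RecordStatus, PartitionOffset, and InternalRow come from Spark and this change.

import java.util.concurrent.{BlockingQueue, TimeUnit}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.streaming.{PartitionOffset, SupportsRealTimeRead}
import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus

// Hypothetical per-partition offset: how many records this reader has emitted so far.
case class QueuePartitionOffset(partitionId: Int, recordsRead: Long) extends PartitionOffset

// Hypothetical reader that drains an in-memory queue. nextWithTimeout() blocks until a record
// arrives or the per-call timeout elapses and returns a RecordStatus either way, which is the
// behavior RealTimeStreamScanExec's LowLatencyReaderWrap relies on.
class QueueRealTimeReader(partitionId: Int, queue: BlockingQueue[InternalRow])
  extends SupportsRealTimeRead[InternalRow] {

  private var current: InternalRow = _
  private var recordsRead = 0L

  override def nextWithTimeout(timeout: java.lang.Long): RecordStatus = {
    // Wait at most `timeout` ms; poll() returns null if nothing arrived in time.
    current = queue.poll(timeout.longValue(), TimeUnit.MILLISECONDS)
    if (current == null) {
      RecordStatus.newStatusWithoutArrivalTime(false)
    } else {
      recordsRead += 1
      // This toy source has no ingestion timestamp, so only report availability.
      RecordStatus.newStatusWithoutArrivalTime(true)
    }
  }

  // Non-blocking variant used outside real time mode.
  override def next(): Boolean = {
    current = queue.poll()
    if (current != null) recordsRead += 1
    current != null
  }

  override def get(): InternalRow = current

  override def getOffset: PartitionOffset = QueuePartitionOffset(partitionId, recordsRead)

  override def close(): Unit = {}
}

When the batch duration elapses (or a call times out), LowLatencyReaderWrap stops reading, calls getOffset, and adds the result to the driver-side accumulator, so a reader like this only has to keep its cursor (recordsRead here) consistent with what it has returned from get().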