memory.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming

import java.util
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2, SparkDataStream}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object MemoryStream {
protected val currentBlockId = new AtomicInteger(0)
protected val memoryStreamId = new AtomicInteger(0)
def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
def apply[A : Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] =
new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, Some(numPartitions))
}
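// A minimal usage sketch for the companion `apply` methods (illustrative only, not part of
// this file; assumes a test SparkSession named `spark` is in scope):
//
//   import spark.implicits._
//   implicit val sqlContext: SQLContext = spark.sqlContext
//   val input = MemoryStream[Int]        // one new batch per addData call
//   val spread = MemoryStream[Int](4)    // rows redistributed across 4 partitions
//   input.addData(1, 2, 3)               // returns the LongOffset of the new batch
//   val df: DataFrame = input.toDF()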
/**
* A base class for memory stream implementations. Supports adding data and resetting.
*/
abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends SparkDataStream {
val encoder = encoderFor[A]
protected val attributes = encoder.schema.toAttributes
protected lazy val toRow: ExpressionEncoder.Serializer[A] = encoder.createSerializer()
def toDS(): Dataset[A] = {
Dataset[A](sqlContext.sparkSession, logicalPlan)
}
def toDF(): DataFrame = {
Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
}
def addData(data: A*): OffsetV2 = {
addData(data.toTraversable)
}
def addData(data: TraversableOnce[A]): OffsetV2
def fullSchema(): StructType = encoder.schema
protected val logicalPlan: LogicalPlan = {
StreamingRelationV2(
Some(MemoryStreamTableProvider),
"memory",
new MemoryStreamTable(this),
CaseInsensitiveStringMap.empty(),
attributes,
None)
}
override def initialOffset(): OffsetV2 = {
throw new IllegalStateException("should not be called.")
}
override def deserializeOffset(json: String): OffsetV2 = {
throw new IllegalStateException("should not be called.")
}
override def commit(end: OffsetV2): Unit = {
throw new IllegalStateException("should not be called.")
}
}
// This class is used to indicate the memory stream data source. We don't actually use it, as
// the memory stream is for tests only and we never look it up by name.
object MemoryStreamTableProvider extends SimpleTableProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
throw new IllegalStateException("MemoryStreamTableProvider should not be used.")
}
}
class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table with SupportsRead {
override def name(): String = "MemoryStreamDataSource"
override def schema(): StructType = stream.fullSchema()
override def capabilities(): util.Set[TableCapability] = {
Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava
}
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MemoryStreamScanBuilder(stream)
}
}
class MemoryStreamScanBuilder(stream: MemoryStreamBase[_]) extends ScanBuilder with Scan {
override def build(): Scan = this
override def description(): String = "MemoryStreamDataSource"
override def readSchema(): StructType = stream.fullSchema()
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
stream.asInstanceOf[MicroBatchStream]
}
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
stream.asInstanceOf[ContinuousStream]
}
}
/**
* A [[Source]] that produces values stored in memory as they are added by the user. This [[Source]]
* is intended for use in unit tests as it can only replay data when the object is still
* available.
*
* If numPartitions is provided, the rows will be redistributed to the given number of partitions.
*/
case class MemoryStream[A : Encoder](
id: Int,
sqlContext: SQLContext,
numPartitions: Option[Int] = None)
extends MemoryStreamBase[A](sqlContext) with MicroBatchStream with Logging {
protected val output = logicalPlan.output
/**
* All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
* Stored in a ListBuffer to facilitate removing committed batches.
*/
@GuardedBy("this")
protected val batches = new ListBuffer[Array[UnsafeRow]]
@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)
@GuardedBy("this")
protected var startOffset = new LongOffset(-1)
@GuardedBy("this")
private var endOffset = new LongOffset(-1)
/**
* Last offset that was discarded, or -1 if no commits have occurred. Note that the value
* -1 is used in calculations below and isn't just an arbitrary constant.
*/
@GuardedBy("this")
protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
def addData(data: TraversableOnce[A]): Offset = {
val objects = data.toSeq
val rows = objects.iterator.map(d => toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
logDebug(s"Adding: $objects")
this.synchronized {
currentOffset = currentOffset + 1
batches += rows
currentOffset
}
}
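// Illustrative bookkeeping for addData (a sketch, not code from this file): currentOffset
// starts at -1, so the first call returns LongOffset(0) and appends one Array[UnsafeRow]
// to `batches`; a second call returns LongOffset(1), and so on.
//
//   val off0 = input.addData(Seq(1, 2))   // LongOffset(0), batches.size == 1
//   val off1 = input.addData(Seq(3, 4))   // LongOffset(1), batches.size == 2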
override def toString: String = {
s"MemoryStream[${truncatedString(output, ",", SQLConf.get.maxToStringFields)}]"
}
override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
override def initialOffset: OffsetV2 = LongOffset(-1)
override def latestOffset(): OffsetV2 = {
if (currentOffset.offset == -1) null else currentOffset
}
override def planInputPartitions(start: OffsetV2, end: OffsetV2): Array[InputPartition] = {
val startOffset = start.asInstanceOf[LongOffset]
val endOffset = end.asInstanceOf[LongOffset]
synchronized {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
val startOrdinal = startOffset.offset.toInt + 1
val endOrdinal = endOffset.offset.toInt + 1
// Internal buffer only holds the batches after lastCommittedOffset.
val newBlocks = synchronized {
val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
batches.slice(sliceStart, sliceEnd)
}
logDebug(generateDebugString(newBlocks.flatten.toSeq, startOrdinal, endOrdinal))
numPartitions match {
case Some(numParts) =>
// When the number of partitions is provided, we redistribute the rows into
// the given number of partitions in a round-robin manner.
val inputRows = newBlocks.flatten.toArray
(0 until numParts).map { newPartIdx =>
val records = inputRows.zipWithIndex.filter { case (_, idx) =>
idx % numParts == newPartIdx
}.map(_._1)
new MemoryStreamInputPartition(records)
}.toArray
case _ =>
newBlocks.map { block =>
new MemoryStreamInputPartition(block)
}.toArray
}
}
}
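// Worked example of the ordinal/slice arithmetic above (illustrative numbers only):
// with lastOffsetCommitted = 1, start = LongOffset(1) and end = LongOffset(3), we get
// startOrdinal = 2 and endOrdinal = 4, hence sliceStart = 0 and sliceEnd = 2, so the
// first two uncommitted entries of `batches` (the data for offsets 2 and 3) are planned.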
override def createReaderFactory(): PartitionReaderFactory = {
MemoryStreamReaderFactory
}
private def generateDebugString(
rows: Seq[UnsafeRow],
startOrdinal: Int,
endOrdinal: Int): String = {
val fromRow = encoder.resolveAndBind().createDeserializer()
s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
s"${rows.map(row => fromRow(row)).mkString(", ")}"
}
override def commit(end: OffsetV2): Unit = synchronized {
val newOffset = end.asInstanceOf[LongOffset]
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
if (offsetDiff < 0) {
sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
}
batches.trimStart(offsetDiff)
lastOffsetCommitted = newOffset
}
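// Illustrative commit sequence (numbers are examples only): with lastOffsetCommitted = 1,
// commit(LongOffset(3)) gives offsetDiff = 2, so the batches for offsets 2 and 3 are
// trimmed from the buffer and lastOffsetCommitted becomes 3; committing an offset lower
// than lastOffsetCommitted raises an error via sys.error.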
override def stop(): Unit = {}
def reset(): Unit = synchronized {
batches.clear()
startOffset = LongOffset(-1)
endOffset = LongOffset(-1)
currentOffset = new LongOffset(-1)
lastOffsetCommitted = new LongOffset(-1)
}
}
class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition
object MemoryStreamReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val records = partition.asInstanceOf[MemoryStreamInputPartition].records
new PartitionReader[InternalRow] {
private var currentIndex = -1
override def next(): Boolean = {
// Return true as long as the new index is in the array.
currentIndex += 1
currentIndex < records.length
}
override def get(): UnsafeRow = records(currentIndex)
override def close(): Unit = {}
}
}
}
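// End-to-end sketch of how this source is typically driven from a test (assumes a test
// SparkSession named `spark`; the query name "result" is illustrative):
//
//   import spark.implicits._
//   implicit val sqlContext: SQLContext = spark.sqlContext
//   val input = MemoryStream[String]
//   val query = input.toDF().writeStream
//     .format("memory")
//     .queryName("result")
//     .start()
//   input.addData("a", "b")
//   query.processAllAvailable()
//   spark.table("result").show()
//   query.stop()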