Commit
Refactored code, fixed bugs, added unit tests
tdas committed Nov 2, 2015
1 parent bd9cd94 commit be8cffc
Showing 9 changed files with 690 additions and 136 deletions.
@@ -90,6 +90,7 @@ private[streaming] class StateImpl[S] extends State[S] {

def remove(): Unit = {
require(!timingOut, "Cannot remove the state that is timing out")
require(!removed, "Cannot remove the state that has already been removed")
removed = true
}

@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.dstream

import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}



abstract class EmittedRecordsDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
ssc: StreamingContext) extends DStream[T](ssc) {

def stateSnapshots(): DStream[(K, S)]
}


private[streaming] class EmittedRecordsDStreamImpl[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
trackStateDStream: TrackStateDStream[K, V, S, T])
extends EmittedRecordsDStream[K, V, S, T](trackStateDStream.context) {

override def slideDuration: Duration = trackStateDStream.slideDuration

override def dependencies: List[DStream[_]] = List(trackStateDStream)

override def compute(validTime: Time): Option[RDD[T]] = {
trackStateDStream.getOrCompute(validTime).map { _.flatMap[T] { _.emittedRecords } }
}

def stateSnapshots(): DStream[(K, S)] = {
trackStateDStream.flatMap[(K, S)] { _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable }
}
}

/**
* A DStream that allows per-key state to be maintained, and arbitrary records to be generated
* based on updates to the state.
*
* @param parent Parent (key, value) stream that is the source
* @param spec Specifications of the trackStateByKey operation
* @tparam K Key type
* @tparam V Value type
* @tparam S Type of the state maintained
* @tparam T Type of the emitted records
*/
private[streaming] class TrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T])
extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) {

persist(StorageLevel.MEMORY_ONLY)

private val partitioner = spec.getPartitioner().getOrElse(
new HashPartitioner(ssc.sc.defaultParallelism))

private val trackingFunction = spec.getFunction()

override def slideDuration: Duration = parent.slideDuration

override def dependencies: List[DStream[_]] = List(parent)

override val mustCheckpoint = true

/** Method that generates an RDD for the given time */
override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, T]]] = {
val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
TrackStateRDD.createFromPairRDD[K, V, S, T](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner,
validTime.milliseconds
)
}
val newDataRDD = parent.getOrCompute(validTime).get
val partitionedDataRDD = newDataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}

Some(new TrackStateRDD(prevStateRDD, partitionedDataRDD,
trackingFunction, validTime.milliseconds, timeoutThresholdTime))
}
}
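
For orientation, a minimal sketch of the kind of per-key tracking function this DStream runs. The (K, Option[V], State[S]) => Option[T] shape matches the trackingFunction parameter of TrackStateRDD further down in this commit; the running-count logic and the State methods exists/get/update are illustrative assumptions, since only remove() appears in the hunks above.

// Illustrative only: a per-key tracking function of shape (K, Option[V], State[S]) => Option[T].
// exists/get/update on State are assumed here; this diff only shows State.remove().
val runningCount: (String, Option[Int], State[Int]) => Option[(String, Int)] =
  (key, value, state) => {
    val newCount = (if (state.exists()) state.get() else 0) + value.getOrElse(0)
    state.update(newCount)                 // keep the per-key state for the next batch
    Some((key, newCount))                  // this tuple becomes an emitted record (T)
  }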
@@ -349,13 +349,13 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
)
}

def trackStateByKey[S: ClassTag, T: ClassTag](spec: TrackStateSpec[K, V, S, T]): DStream[T] = {
new TrackeStateDStream[K, V, S, T](
self,
spec.asInstanceOf[TrackStateSpecImpl[K, V, S, T]]
).mapPartitions { partitionIter =>
partitionIter.flatMap { _.emittedRecords }
}
def trackStateByKey[S: ClassTag, T: ClassTag](spec: TrackStateSpec[K, V, S, T]): EmittedRecordsDStream[K, V, S, T] = {
new EmittedRecordsDStreamImpl[K, V, S, T](
new TrackStateDStream[K, V, S, T](
self,
spec.asInstanceOf[TrackStateSpecImpl[K, V, S, T]]
)
)
}
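
A hedged sketch of how the refactored API might be invoked. The construction of the TrackStateSpec is not shown in this diff, so it is left abstract; wordValues is a hypothetical DStream[(String, Int)] and the implicit conversion to PairDStreamFunctions is assumed.

// Sketch only: spec construction is outside this diff, so it is left unspecified here.
val spec: TrackStateSpec[String, Int, Int, (String, Int)] = ???   // built elsewhere, e.g. around runningCount
val emitted: EmittedRecordsDStream[String, Int, Int, (String, Int)] =
  wordValues.trackStateByKey(spec)   // wordValues: DStream[(String, Int)], hypothetical
emitted.print()                      // per-batch records emitted by the tracking function
emitted.stateSnapshots().print()     // per-batch snapshot of all (key, state) pairs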


@@ -1,45 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.dstream
package org.apache.spark.streaming.rdd

import java.io.{IOException, ObjectOutputStream}
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
import org.apache.spark.streaming.util.StateMap
import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
import org.apache.spark.streaming.{StateImpl, State}
import org.apache.spark.streaming.util.{EmptyStateMap, StateMap}
import org.apache.spark.util.Utils
import org.apache.spark._




private[streaming] case class TrackStateRDDRecord[K, S, T](
var stateMap: StateMap[K, S], var emittedRecords: Seq[T]) {
/*
private def writeObject(outputStream: ObjectOutputStream): Unit = {
outputStream.writeObject(stateMap)
outputStream.writeInt(emittedRecords.size)
val iterator = emittedRecords.iterator
while(iterator.hasNext) {
outputStream.writeObject(iterator.next)
}
}
private[streaming] case class TrackStateRDDRecord[K: ClassTag, S: ClassTag, T: ClassTag](
stateMap: StateMap[K, S], emittedRecords: Seq[T])
private def readObject(inputStream: ObjectInputStream): Unit = {
stateMap = inputStream.readObject().asInstanceOf[StateMap[K, S]]
val numEmittedRecords = inputStream.readInt()
val array = new Array[T](numEmittedRecords)
var i = 0
while(i < numEmittedRecords) {
array(i) = inputStream.readObject().asInstanceOf[T]
}
emittedRecords = array.toSeq
}*/
}


private[streaming] class TrackStateRDDPartition(
idx: Int,
@transient private var prevStateRDD: RDD[_],
@transient private var partitionedDataRDD: RDD[_]) extends Partition {

private[dstream] var previousSessionRDDPartition: Partition = null
private[dstream] var partitionedDataRDDPartition: Partition = null
private[rdd] var previousSessionRDDPartition: Partition = null
private[rdd] var partitionedDataRDDPartition: Partition = null

override def index: Int = idx
override def hashCode(): Int = idx
@@ -53,22 +59,28 @@ private[streaming] class TrackStateRDDPartition(
}
}




/**
* RDD storing the keyed state of trackStateByKey and the corresponding emitted records.
* Each partition of this RDD has a single record that contains a StateMap storing the
* per-key state, along with the records emitted by the tracking function for that batch.
*/
private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
_sc: SparkContext,
private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
private var partitionedDataRDD: RDD[(K, V)],
trackingFunction: (K, Option[V], State[S]) => Option[T],
currentTime: Long, timeoutThresholdTime: Option[Long]
) extends RDD[TrackStateRDDRecord[K, S, T]](
_sc,
partitionedDataRDD.sparkContext,
List(
new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
new OneToOneDependency(partitionedDataRDD))
) {

@volatile private var doFullScan = false

require(partitionedDataRDD.partitioner.nonEmpty)
require(prevStateRDD.partitioner.nonEmpty)
require(partitionedDataRDD.partitioner == prevStateRDD.partitioner)

override val partitioner = prevStateRDD.partitioner
Expand All @@ -86,11 +98,13 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
stateRDDPartition.previousSessionRDDPartition, context)
val dataIterator = partitionedDataRDD.iterator(
stateRDDPartition.partitionedDataRDDPartition, context)
if (!prevStateRDDIterator.hasNext) {
throw new SparkException(s"Could not find state map in previous state RDD")

val newStateMap = if (prevStateRDDIterator.hasNext) {
prevStateRDDIterator.next().stateMap.copy()
} else {
new EmptyStateMap[K, S]()
}

val newStateMap = prevStateRDDIterator.next().stateMap.copy()
val emittedRecords = new ArrayBuffer[T]

val wrappedState = new StateImpl[S]()
@@ -132,61 +146,32 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
}

private[streaming] object TrackStateRDD {
def createFromPairRDD[K: ClassTag, S: ClassTag, T: ClassTag](

def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
pairRDD: RDD[(K, S)],
partitioner: Partitioner,
updateTime: Long): RDD[TrackStateRDDRecord[K, S, T]] = {
updateTime: Long): TrackStateRDD[K, V, S, T] = {

val createRecord = (iterator: Iterator[(K, S)]) => {
val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime) }
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
}
pairRDD.partitionBy(partitioner).mapPartitions[TrackStateRDDRecord[K, S, T]](
createRecord, true)
}
}


// -----------------------------------------------
// ---------------- SessionDStream ---------------
// -----------------------------------------------


private[streaming] class TrackeStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
parent: DStream[(K, V)], spec: TrackStateSpecImpl[K, V, S, T])
extends DStream[TrackStateRDDRecord[K, S, T]](parent.context) {
}, preservesPartitioning = true)

persist(StorageLevel.MEMORY_ONLY)
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

private val partitioner = spec.getPartitioner().getOrElse(
new HashPartitioner(ssc.sc.defaultParallelism))
val noOpFunc = (key: K, value: Option[V], state: State[S]) => None

private val trackingFunction = spec.getFunction()

override def slideDuration: Duration = parent.slideDuration

override def dependencies: List[DStream[_]] = List(parent)

override val mustCheckpoint = true

/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, T]]] = {
val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
TrackStateRDD.createFromPairRDD[K, S, T](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner,
validTime.milliseconds
)
}
val newDataRDD = parent.getOrCompute(validTime).get
val partitionedDataRDD = newDataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
}
}

Some(new TrackStateRDD(
ssc.sparkContext, prevStateRDD, partitionedDataRDD,
trackingFunction, validTime.milliseconds, timeoutThresholdTime))
private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) {
override protected def getPartitions: Array[Partition] = parent.partitions
override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
parent.compute(partition, context).flatMap { _.emittedRecords }
}
}

private[streaming] class StateSnapshotRDD[K: ClassTag, V: ClassTag]
@@ -78,7 +78,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa

/** Implementation of StateMap based on Spark's OpenHashMap */
private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
@transient @volatile private var parentStateMap: StateMap[K, S],
@transient @volatile var parentStateMap: StateMap[K, S],
initialCapacity: Int = 64,
deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
) extends StateMap[K, S] { self =>
@@ -99,8 +99,12 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
/** Get the session data if it exists */
override def get(key: K): Option[S] = {
val stateInfo = deltaMap(key)
if (stateInfo != null && !stateInfo.deleted) {
Some(stateInfo.data)
if (stateInfo != null) {
if (!stateInfo.deleted) {
Some(stateInfo.data)
} else {
None
}
} else {
parentStateMap.get(key)
}
@@ -185,6 +189,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "")
}

override def toString(): String = {
s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]"
}

private def writeObject(outputStream: ObjectOutputStream): Unit = {

outputStream.defaultWriteObject()
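
To make the parent/delta structure in the StateMap hunks above easier to follow, here is a stripped-down sketch of the copy-on-write delta-chain idea that OpenHashMapBasedStateMap's get() walks: reads consult the local delta first and fall back to the parent map, and copy() starts a new delta chained on the current map. This is not Spark's implementation; it omits OpenHashMap, delta-chain compaction, and the custom serialization touched by this commit.

// Simplified illustration of the delta-chain lookup; names and structure are assumptions.
class DeltaChainMap[K, S](parent: Option[DeltaChainMap[K, S]] = None) {
  private val delta = scala.collection.mutable.Map.empty[K, Option[S]]  // None marks a deleted key
  def put(key: K, state: S): Unit = delta(key) = Some(state)
  def remove(key: K): Unit = delta(key) = None                          // tombstone shadows the parent
  def get(key: K): Option[S] = delta.get(key) match {
    case Some(valueOrTombstone) => valueOrTombstone                     // found in this map's own delta
    case None                   => parent.flatMap(_.get(key))           // otherwise walk up the chain
  }
  def copy(): DeltaChainMap[K, S] = new DeltaChainMap(Some(this))       // next batch writes to a fresh delta
}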