Skip to content

Commit

Permalink
[SPARK-3054][STREAMING] Add unit tests for Spark Sink.
Browse files Browse the repository at this point in the history
This patch adds unit tests for Spark Sink.

It also removes the private[flume] for Spark Sink,
since the sink is instantiated from Flume configuration (looks like this is ignored by reflection which is used by
Flume, but we should still remove it anyway).
  • Loading branch information
harishreedharan committed Aug 15, 2014
1 parent eaeb0f7 commit c86d615
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 2 deletions.
7 changes: 7 additions & 0 deletions external/flume-sink/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope> <!-- Need it only for tests, don't package it -->
</dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ import org.apache.flume.sink.AbstractSink
*
*/

private[flume]
class SparkSink extends AbstractSink with Logging with Configurable {

// Size of the pool to use for holding transaction processors.
Expand Down Expand Up @@ -131,6 +130,14 @@ class SparkSink extends AbstractSink with Logging with Configurable {
blockingLatch.await()
Status.BACKOFF
}

private[flume] def getPort(): Int = {
serverOpt
.map(_.getPort)
.getOrElse(
throw new RuntimeException("Server was not started!")
)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
package org.apache.spark.streaming.flume.sink

import java.net.InetSocketAddress
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{CountDownLatch, Executors}

import scala.collection.JavaConversions._
import scala.concurrent.{Promise, Future}
import scala.util.{Failure, Success, Try}

import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.flume.Context
import org.apache.flume.channel.MemoryChannel
import org.apache.flume.event.EventBuilder
import org.apache.spark.streaming.TestSuiteBase
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory


/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
class SparkSinkSuite extends TestSuiteBase {
val batchCount = 5
val eventsPerBatch = 100
val totalEventsPerChannel = batchCount * eventsPerBatch
val channelCapacity = 5000
val maxAttempts = 5

test("Success test") {
val (channel, sink) = initializeChannelAndSink(None)
channel.start()
sink.start()

putEvents(channel, 1000)

val port = sink.getPort
val address = new InetSocketAddress("0.0.0.0", port)

val (transceiver, client) = getTransceiverAndClient(address, 1)(0)
val events = client.getEventBatch(1000)
client.ack(events.getSequenceNumber)
assert(events.getEvents.size() === 1000)
assertChannelIsEmpty(channel)
sink.stop()
channel.stop()
transceiver.close()
}

test("Nack") {
val (channel, sink) = initializeChannelAndSink(None)
channel.start()
sink.start()
putEvents(channel, 1000)

val port = sink.getPort
val address = new InetSocketAddress("0.0.0.0", port)

val (transceiver, client) = getTransceiverAndClient(address, 1)(0)
val events = client.getEventBatch(1000)
assert(events.getEvents.size() === 1000)
client.nack(events.getSequenceNumber)
assert(availableChannelSlots(channel) === 4000)
sink.stop()
channel.stop()
transceiver.close()
}

test("Timeout") {
val (channel, sink) = initializeChannelAndSink(Option(Map(SparkSinkConfig
.CONF_TRANSACTION_TIMEOUT -> 1.toString)))
channel.start()
sink.start()
putEvents(channel, 1000)
val port = sink.getPort
val address = new InetSocketAddress("0.0.0.0", port)

val (transceiver, client) = getTransceiverAndClient(address, 1)(0)
val events = client.getEventBatch(1000)
assert(events.getEvents.size() === 1000)
Thread.sleep(1000)
assert(availableChannelSlots(channel) === 4000)
sink.stop()
channel.stop()
transceiver.close()
}

test("Multiple consumers") {
multipleClients(failSome = false)
}

test("Multiple consumers With Some Failures") {
multipleClients(failSome = true)
}

def multipleClients(failSome: Boolean): Unit = {
import scala.concurrent.ExecutionContext.Implicits.global
val (channel, sink) = initializeChannelAndSink(None)
channel.start()
sink.start()
(1 to 5).map(_ =>putEvents(channel, 1000))
val port = sink.getPort
val address = new InetSocketAddress("0.0.0.0", port)

val transAndClient = getTransceiverAndClient(address, 5)
val batchCounter = new CountDownLatch(5)
val counter = new AtomicInteger(0)
transAndClient.foreach(x => {
val promise = Promise[EventBatch]()
val future = promise.future
Future {
val client = x._2
var events: EventBatch = null
Try {
events = client.getEventBatch(1000)
if(!failSome || counter.incrementAndGet() % 2 == 0) {
client.ack(events.getSequenceNumber)
} else {
client.nack(events.getSequenceNumber)
}
}.map(_ => promise.success(events)).recover({
case e => promise.failure(e)
})
}
future.onComplete {
case Success(events) => assert(events.getEvents.size() === 1000)
batchCounter.countDown()
case Failure(t) =>
batchCounter.countDown()
throw t
}
})
batchCounter.await()
if(failSome) {
assert(availableChannelSlots(channel) === 2000)
} else {
assertChannelIsEmpty(channel)
}
sink.stop()
channel.stop()
transAndClient.foreach(x => x._1.close())
}

def initializeChannelAndSink(overrides: Option[Map[String, String]]):
(MemoryChannel, SparkSink) = {
val channel = new MemoryChannel()
val channelContext = new Context()

channelContext.put("capacity", channelCapacity.toString)
channelContext.put("transactionCapacity", 1000.toString)
channelContext.put("keep-alive", 0.toString)
overrides.foreach(channelContext.putAll(_))
channel.configure(channelContext)

val sink = new SparkSink()
val sinkContext = new Context()
sinkContext.put(SparkSinkConfig.CONF_HOSTNAME, "0.0.0.0")
sinkContext.put(SparkSinkConfig.CONF_PORT, 0.toString)
sink.setChannel(channel)
(channel, sink)
}

private def putEvents(ch: MemoryChannel, count: Int): Unit = {
val tx = ch.getTransaction
tx.begin()
(1 to count).map(x => ch.put(EventBuilder.withBody(x.toString.getBytes)))
tx.commit()
tx.close()
}

private def getTransceiverAndClient(address: InetSocketAddress, count: Int):
Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = {

(1 to count).map(_ => {
lazy val channelFactoryExecutor =
Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true).
setNameFormat("Flume Receiver Channel Thread - %d").build())
lazy val channelFactory =
new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor)
val transceiver = new NettyTransceiver(address, channelFactory)
val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver)
(transceiver, client)
})
}

private def assertChannelIsEmpty(channel: MemoryChannel) = {
assert(availableChannelSlots(channel) === 5000)
}

private def availableChannelSlots(channel: MemoryChannel): Int = {
val queueRemaining = channel.getClass.getDeclaredField("queueRemaining")
queueRemaining.setAccessible(true)
val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits")
m.invoke(queueRemaining.get(channel)).asInstanceOf[Int]
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class FlumePollingStreamSuite extends TestSuiteBase {
}

def assertChannelIsEmpty(channel: MemoryChannel) = {
val queueRemaining = channel.getClass.getDeclaredField("queueRemaining");
val queueRemaining = channel.getClass.getDeclaredField("queueRemaining")
queueRemaining.setAccessible(true)
val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits")
assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000)
Expand Down

0 comments on commit c86d615

Please sign in to comment.