Refactor Flume unit tests and also add tests for Python API
zsxwing committed Jun 19, 2015
1 parent 9f33873 commit 0336579
Showing 6 changed files with 546 additions and 218 deletions.
@@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.flume

import java.net.{InetSocketAddress, ServerSocket}
import java.nio.ByteBuffer
import java.util.{List => JList}

import scala.collection.JavaConversions._

import com.google.common.base.Charsets.UTF_8
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.commons.lang3.RandomUtils
import org.apache.flume.source.avro
import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent}
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder}

import org.apache.spark.util.Utils
import org.apache.spark.SparkConf

/**
* Shared utilities for the Scala and Python Flume unit tests
*/
private[flume] class FlumeTestUtils {

private var transceiver: NettyTransceiver = null

private val testPort: Int = findFreePort()

def getTestPort(): Int = testPort

/** Find a free port */
private def findFreePort(): Int = {
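// Pick a random candidate port and let Utils.startServiceOnPort retry on successive
// ports if binding fails; opening and immediately closing a ServerSocket is enough
// to prove the port is currently free.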
val candidatePort = RandomUtils.nextInt(1024, 65536)
Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
val socket = new ServerSocket(trialPort)
socket.close()
(null, trialPort)
}, new SparkConf())._2
}

/** Send data to the Flume receiver */
def writeInput(input: JList[String], enableCompression: Boolean): Unit = {
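// Build one AvroFlumeEvent per input string, connect a (possibly compressed)
// transceiver to the receiver's port, and send all events as a single Avro batch.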
val testAddress = new InetSocketAddress("localhost", testPort)

val inputEvents = input.map { item =>
val event = new AvroFlumeEvent
event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8)))
event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header"))
event
}

// Close the transceiver left over from a previous successful attempt, if any
close()

// Create transceiver
transceiver = {
if (enableCompression) {
new NettyTransceiver(testAddress, new CompressionChannelFactory(6))
} else {
new NettyTransceiver(testAddress)
}
}

// Create Avro client with the transceiver
val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver)
if (client == null) {
throw new AssertionError("Cannot create client")
}

// Send data
val status = client.appendBatch(inputEvents.toList)
if (status != avro.Status.OK) {
throw new AssertionError("Sent events unsuccessfully")
}
}

def close(): Unit = {
if (transceiver != null) {
transceiver.close()
transceiver = null
}
}

/** Class to create socket channel with compression */
private class CompressionChannelFactory(compressionLevel: Int)
extends NioClientSocketChannelFactory {

override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
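// Both codecs are added with addFirst, so they sit at the head of the pipeline
// (closest to the socket): outgoing requests are deflated by the ZlibEncoder and
// incoming responses inflated by the ZlibDecoder before the Avro layer sees them.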
val encoder = new ZlibEncoder(compressionLevel)
pipeline.addFirst("deflater", encoder)
pipeline.addFirst("inflater", new ZlibDecoder())
super.newChannel(pipeline)
}
}

}
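
// A minimal usage sketch (not part of this commit, for illustration): a suite would
// first attach a Flume receiver/stream to utils.getTestPort(), then push data:
//
//   val utils = new FlumeTestUtils
//   try {
//     utils.writeInput(java.util.Arrays.asList("event-1", "event-2"), enableCompression = false)
//     // ... assert that the stream under test received both events ...
//   } finally {
//     utils.close()
//   }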
@@ -0,0 +1,211 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.flume

import java.util.concurrent._
import java.util.{List => JList, Map => JMap}

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer

import com.google.common.base.Charsets.UTF_8
import org.apache.flume.event.EventBuilder
import org.apache.flume.Context
import org.apache.flume.channel.MemoryChannel
import org.apache.flume.conf.Configurables

import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink}

/**
* Shared utilities for the Scala and Python Flume polling unit tests
*/
private[flume] class PollingFlumeTestUtils {

private val batchCount = 5
private val eventsPerBatch = 100
private val totalEventsPerChannel = batchCount * eventsPerBatch
private val channelCapacity = 5000
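// Each channel carries batchCount * eventsPerBatch = 500 events in total; a capacity
// of 5000 leaves plenty of headroom for the events pushed through each channel.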

def getEventsPerBatch: Int = eventsPerBatch

def getTotalEvents: Int = totalEventsPerChannel * channels.size

private val channels = new ArrayBuffer[MemoryChannel]
private val sinks = new ArrayBuffer[SparkSink]

/**
* Start a single sink and return the port it bound to
*/
def startSingleSink(): Int = {
channels.clear()
sinks.clear()

// Start the channel and sink.
val context = new Context()
context.put("capacity", channelCapacity.toString)
context.put("transactionCapacity", "1000")
context.put("keep-alive", "0")
val channel = new MemoryChannel()
Configurables.configure(channel, context)

val sink = new SparkSink()
context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost")
context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0))
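// Port 0 makes the sink bind to an ephemeral port; the port actually chosen is read
// back via sink.getPort() and returned to the caller.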
Configurables.configure(sink, context)
sink.setChannel(channel)
sink.start()

channels += channel
sinks += sink

sink.getPort()
}

/**
* Start two sinks and return the ports they bound to
*/
def startMultipleSinks(): JList[Int] = {
channels.clear()
sinks.clear()

// Start the channel and sink.
val context = new Context()
context.put("capacity", channelCapacity.toString)
context.put("transactionCapacity", "1000")
context.put("keep-alive", "0")
val channel = new MemoryChannel()
Configurables.configure(channel, context)

val channel2 = new MemoryChannel()
Configurables.configure(channel2, context)

val sink = new SparkSink()
context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost")
context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0))
Configurables.configure(sink, context)
sink.setChannel(channel)
sink.start()

val sink2 = new SparkSink()
context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost")
context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0))
Configurables.configure(sink2, context)
sink2.setChannel(channel2)
sink2.start()

sinks += sink
sinks += sink2
channels += channel
channels += channel2

sinks.map(_.getPort())
}

/**
* Send data and wait until all data has been received
*/
def sendDataAndEnsureAllDataHasBeenReceived(): Unit = {
val executor = Executors.newCachedThreadPool()
val executorCompletion = new ExecutorCompletionService[Void](executor)

val latch = new CountDownLatch(batchCount * channels.size)
sinks.foreach(_.countdownWhenBatchReceived(latch))
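// Each sink counts the latch down once per batch it receives, so the latch only opens
// after batchCount batches have arrived from every channel.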

channels.foreach(channel => {
executorCompletion.submit(new TxnSubmitter(channel))
})

for (i <- 0 until channels.size) {
executorCompletion.take()
}

// Ensure all data has been received; fail instead of silently ignoring a timeout.
if (!latch.await(15, TimeUnit.SECONDS)) {
  throw new AssertionError("Timed out while waiting for all data to be received")
}
executor.shutdown()
}

/**
* A Python-friendly method to assert the output
*/
def assertOutput(
outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = {
require(outputHeaders.size == outputBodies.size)
val eventSize = outputHeaders.size
if (eventSize != totalEventsPerChannel * channels.size) {
throw new AssertionError(
s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize")
}
var counter = 0
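// For every expected (body, header) pair, scan the received output once and count how
// many of the expected events were found; the total must match what was sent.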
for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) {
val eventBodyToVerify = s"${channels(k).getName}-$i"
val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header")
var found = false
var j = 0
while (j < eventSize && !found) {
if (eventBodyToVerify == outputBodies.get(j) &&
eventHeaderToVerify == outputHeaders.get(j)) {
found = true
counter += 1
}
j += 1
}
}
if (counter != totalEventsPerChannel * channels.size) {
throw new AssertionError(
s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter")
}
}

def assertChannelsAreEmpty(): Unit = {
channels.foreach(assertChannelIsEmpty)
}

private def assertChannelIsEmpty(channel: MemoryChannel): Unit = {
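// MemoryChannel does not expose its queue, so read the private "queueRemaining"
// semaphore via reflection: the channel is empty when all permits are available again.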
val queueRemaining = channel.getClass.getDeclaredField("queueRemaining")
queueRemaining.setAccessible(true)
val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits")
if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != channelCapacity) {
throw new AssertionError(s"Channel ${channel.getName} is not empty")
}
}

def close(): Unit = {
sinks.foreach(_.stop())
sinks.clear()
channels.foreach(_.stop())
channels.clear()
}

private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] {
override def call(): Void = {
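// Push batchCount transactions of eventsPerBatch events each into the channel,
// pausing briefly between batches so the sink has a chance to drain them.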
var t = 0
for (i <- 0 until batchCount) {
val tx = channel.getTransaction
tx.begin()
for (j <- 0 until eventsPerBatch) {
channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8),
Map[String, String](s"test-$t" -> "header")))
t += 1
}
tx.commit()
tx.close()
Thread.sleep(500) // Allow some time for the events to reach the sink
}
null
}
}

}
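
// A minimal usage sketch (not part of this commit, for illustration): a polling test
// would start the sink(s), point a polling Flume stream at the returned port(s), and
// then drive and verify the data:
//
//   val utils = new PollingFlumeTestUtils
//   val port = utils.startSingleSink()
//   // ... create a polling stream on ("localhost", port), start the streaming context ...
//   utils.sendDataAndEnsureAllDataHasBeenReceived()
//   // outputHeaders / outputBodies are collected from the stream by the test itself:
//   utils.assertOutput(outputHeaders, outputBodies)
//   utils.assertChannelsAreEmpty()
//   utils.close()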