-
Notifications
You must be signed in to change notification settings - Fork 28.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor Flume unit tests and also add tests for Python API
- Loading branch information
Showing
6 changed files
with
546 additions
and
218 deletions.
There are no files selected for viewing
116 changes: 116 additions & 0 deletions
116
external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.streaming.flume | ||
|
||
import java.net.{InetSocketAddress, ServerSocket} | ||
import java.nio.ByteBuffer | ||
import java.util.{List => JList} | ||
|
||
import scala.collection.JavaConversions._ | ||
|
||
import com.google.common.base.Charsets.UTF_8 | ||
import org.apache.avro.ipc.NettyTransceiver | ||
import org.apache.avro.ipc.specific.SpecificRequestor | ||
import org.apache.commons.lang3.RandomUtils | ||
import org.apache.flume.source.avro | ||
import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} | ||
import org.jboss.netty.channel.ChannelPipeline | ||
import org.jboss.netty.channel.socket.SocketChannel | ||
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory | ||
import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} | ||
|
||
import org.apache.spark.util.Utils | ||
import org.apache.spark.SparkConf | ||
|
||
/**
 * Helper shared by the Scala and Python unit tests for the push-based Flume receiver.
 *
 * An instance reserves a port number when it is constructed and can then push batches of
 * Avro Flume events to `localhost` on that port, optionally over a zlib-compressed channel.
 */
private[flume] class FlumeTestUtils {

  /** The transceiver opened by the latest `writeInput` call, or `null` when none is open. */
  private var transceiver: NettyTransceiver = null

  /** Port chosen once at construction time; every `writeInput` call connects to it. */
  private val testPort: Int = findFreePort()

  def getTestPort(): Int = testPort

  /**
   * Pick a port that could be bound at the time of the call.
   *
   * A random candidate is handed to `Utils.startServiceOnPort` together with a probe that
   * binds a `ServerSocket` and closes it straight away. The port is therefore NOT held once
   * this method returns; it is only known to have been bindable a moment ago.
   */
  private def findFreePort(): Int = {
    val startingPort = RandomUtils.nextInt(1024, 65536)
    val probeResult = Utils.startServiceOnPort(startingPort, (trialPort: Int) => {
      // Binding succeeds only if the port is free; release it immediately afterwards.
      val probe = new ServerSocket(trialPort)
      probe.close()
      (null, trialPort)
    }, new SparkConf())
    probeResult._2
  }

  /**
   * Send the given strings to the Flume receiver as a single Avro batch.
   *
   * Each string becomes the UTF-8 body of one event carrying the header `test -> header`.
   * Any transceiver left open by a previous call is closed before a new one is created.
   *
   * @param input             event bodies to send
   * @param enableCompression whether to connect through a zlib-compressing channel factory
   * @throws AssertionError if no client could be created or the batch status is not OK
   */
  def writeInput(input: JList[String], enableCompression: Boolean): Unit = {
    val testAddress = new InetSocketAddress("localhost", testPort)

    val inputEvents = for (item <- input) yield {
      val event = new AvroFlumeEvent
      event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8)))
      event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header"))
      event
    }

    // Drop the connection of the previous attempt, if it got as far as opening one.
    close()

    // Open a fresh connection, compressed or plain.
    transceiver =
      if (enableCompression) {
        new NettyTransceiver(testAddress, new CompressionChannelFactory(6))
      } else {
        new NettyTransceiver(testAddress)
      }

    // Wrap the connection in an Avro client and push the whole batch in one call.
    Option(SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver)) match {
      case None =>
        throw new AssertionError("Cannot create client")
      case Some(client) =>
        val status = client.appendBatch(inputEvents.toList)
        if (status != avro.Status.OK) {
          throw new AssertionError("Sent events unsuccessfully")
        }
    }
  }

  /** Close the current transceiver, if any, and forget it. Safe to call repeatedly. */
  def close(): Unit = {
    Option(transceiver).foreach { open =>
      open.close()
      transceiver = null
    }
  }

  /**
   * Client channel factory that puts a zlib encoder ("deflater") and a zlib decoder
   * ("inflater") at the front of every pipeline before creating the channel.
   *
   * @param compressionLevel level passed to the `ZlibEncoder`
   */
  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      // Both handlers are added with addFirst, so the inflater ends up ahead of the deflater.
      pipeline.addFirst("deflater", new ZlibEncoder(compressionLevel))
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }

}
211 changes: 211 additions & 0 deletions
211
external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.streaming.flume | ||
|
||
import java.util.concurrent._ | ||
import java.util.{List => JList, Map => JMap} | ||
|
||
import scala.collection.JavaConversions._ | ||
import scala.collection.mutable.ArrayBuffer | ||
|
||
import com.google.common.base.Charsets.UTF_8 | ||
import org.apache.flume.event.EventBuilder | ||
import org.apache.flume.Context | ||
import org.apache.flume.channel.MemoryChannel | ||
import org.apache.flume.conf.Configurables | ||
|
||
import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} | ||
|
||
/**
 * Helper shared by the Scala and Python unit tests for the polling (SparkSink based)
 * Flume stream.
 *
 * It starts one or two `SparkSink`s backed by `MemoryChannel`s, fills the channels with a
 * fixed number of events and offers assertions on the received output and channel state.
 */
private[flume] class PollingFlumeTestUtils {

  private val batchCount = 5
  private val eventsPerBatch = 100
  private val totalEventsPerChannel = batchCount * eventsPerBatch
  private val channelCapacity = 5000

  // Channels and sinks created by the latest startSingleSink/startMultipleSinks call.
  private val channels = new ArrayBuffer[MemoryChannel]
  private val sinks = new ArrayBuffer[SparkSink]

  def getEventsPerBatch: Int = eventsPerBatch

  /** Total number of events written by one send, across all current channels. */
  def getTotalEvents: Int = totalEventsPerChannel * channels.size

  /**
   * Start a sink and return the port of this sink
   */
  def startSingleSink(): Int = {
    channels.clear()
    sinks.clear()

    val context = newChannelContext()
    val channel = newConfiguredChannel(context)
    val sink = startSink(channel, context)

    channels += channel
    sinks += sink

    sink.getPort()
  }

  /**
   * Start 2 sinks and return the ports
   */
  def startMultipleSinks(): JList[Int] = {
    channels.clear()
    sinks.clear()

    // Both channels are configured before any sink setting is added to the shared context.
    val context = newChannelContext()
    val channel = newConfiguredChannel(context)
    val channel2 = newConfiguredChannel(context)

    val sink = startSink(channel, context)
    val sink2 = startSink(channel2, context)

    sinks += sink
    sinks += sink2
    channels += channel
    channels += channel2

    sinks.map(_.getPort())
  }

  /** Build the Flume context holding the memory channel settings used by these tests. */
  private def newChannelContext(): Context = {
    val context = new Context()
    context.put("capacity", channelCapacity.toString)
    context.put("transactionCapacity", "1000")
    context.put("keep-alive", "0")
    context
  }

  /** Create a memory channel and configure it from the given context. */
  private def newConfiguredChannel(context: Context): MemoryChannel = {
    val channel = new MemoryChannel()
    Configurables.configure(channel, context)
    channel
  }

  /**
   * Create a sink on localhost attached to `channel` and start it. Port 0 is requested;
   * the port actually in use is read back through `getPort()` by the callers.
   */
  private def startSink(channel: MemoryChannel, context: Context): SparkSink = {
    val sink = new SparkSink()
    context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost")
    context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0))
    Configurables.configure(sink, context)
    sink.setChannel(channel)
    sink.start()
    sink
  }

  /**
   * Send data and wait until all data has been received
   *
   * One `TxnSubmitter` per channel is run on a temporary thread pool. After all submitters
   * finished, this waits up to 15 seconds for the sinks to count down one latch step per
   * batch. A timeout is not reported here; callers check the outcome with `assertOutput`.
   *
   * @throws ExecutionException if a submitter failed while writing to its channel
   */
  def sendDatAndEnsureAllDataHasBeenReceived(): Unit = {
    val executor = Executors.newCachedThreadPool()
    try {
      val executorCompletion = new ExecutorCompletionService[Void](executor)

      val latch = new CountDownLatch(batchCount * channels.size)
      sinks.foreach(_.countdownWhenBatchReceived(latch))

      channels.foreach { channel =>
        executorCompletion.submit(new TxnSubmitter(channel))
      }

      for (_ <- 0 until channels.size) {
        // get() rethrows a submitter failure instead of silently dropping it.
        executorCompletion.take().get()
      }

      latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received.
    } finally {
      // Release the pool's threads; previously the executor was never shut down.
      executor.shutdown()
    }
  }

  /**
   * A Python-friendly method to assert the output
   *
   * Verifies that the number of received events equals the number sent and that every
   * expected (body, headers) pair occurs at the same index of the two output lists.
   *
   * @param outputHeaders headers of the received events
   * @param outputBodies  bodies of the received events, index-aligned with `outputHeaders`
   * @throws IllegalArgumentException if the two lists differ in size
   * @throws AssertionError if the event count is wrong or an expected event is missing
   */
  def assertOutput(
      outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = {
    require(outputHeaders.size == outputBodies.size)
    val eventSize = outputHeaders.size
    val expectedTotal = totalEventsPerChannel * channels.size
    if (eventSize != expectedTotal) {
      throw new AssertionError(
        s"Expected $expectedTotal events, but was $eventSize")
    }
    var counter = 0
    for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) {
      val eventBodyToVerify = s"${channels(k).getName}-$i"
      val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header")
      // An expected event counts once, no matter how many outputs match it.
      val found = (0 until eventSize).exists { j =>
        eventBodyToVerify == outputBodies.get(j) &&
          eventHeaderToVerify == outputHeaders.get(j)
      }
      if (found) {
        counter += 1
      }
    }
    if (counter != expectedTotal) {
      throw new AssertionError(
        s"Expected $expectedTotal events with matching body and headers, but found $counter")
    }
  }

  /** Assert that every current channel is empty. */
  def assertChannelsAreEmpty(): Unit = {
    channels.foreach(assertChannelIsEmpty)
  }

  /**
   * Assert that a channel is empty by reading its private `queueRemaining` field through
   * reflection and checking that `availablePermits` equals the configured capacity.
   */
  private def assertChannelIsEmpty(channel: MemoryChannel): Unit = {
    val queueRemainingField = channel.getClass.getDeclaredField("queueRemaining")
    queueRemainingField.setAccessible(true)
    val queueRemaining = queueRemainingField.get(channel)
    val availablePermits = queueRemaining.getClass.getDeclaredMethod("availablePermits")
    // Compare with the capacity the channel was configured with, not a repeated literal.
    if (availablePermits.invoke(queueRemaining).asInstanceOf[Int] != channelCapacity) {
      throw new AssertionError(s"Channel ${channel.getName} is not empty")
    }
  }

  /** Stop all sinks and channels and forget them. */
  def close(): Unit = {
    sinks.foreach(_.stop())
    sinks.clear()
    channels.foreach(_.stop())
    channels.clear()
  }

  /**
   * Task that writes `batchCount` transactions of `eventsPerBatch` events into a channel.
   * Event number n has the body `<channel name>-n` and the header `test-n -> header`.
   */
  private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] {
    override def call(): Void = {
      for (i <- 0 until batchCount) {
        val tx = channel.getTransaction
        tx.begin()
        for (j <- 0 until eventsPerBatch) {
          // Running event number across all batches of this channel.
          val t = i * eventsPerBatch + j
          channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8),
            Map[String, String](s"test-$t" -> "header")))
        }
        tx.commit()
        tx.close()
        Thread.sleep(500) // Allow some time for the events to reach
      }
      null
    }
  }

}
Oops, something went wrong.