Commit

Fix bugs and address the issues

jerryshao committed Apr 10, 2015
1 parent 64d9877 commit 61a04f0

Showing 10 changed files with 107 additions and 100 deletions.
@@ -27,6 +27,7 @@ import java.util.concurrent.TimeoutException
 import scala.annotation.tailrec
 import scala.language.postfixOps
 import scala.util.Random
+import scala.util.control.NonFatal
 
 import kafka.admin.AdminUtils
 import kafka.common.KafkaException
@@ -44,6 +45,8 @@ import org.apache.spark.util.Utils
 /**
  * This is a helper class for Kafka test suites. This has the functionality to set up
  * and tear down local Kafka servers, and to push data using Kafka producers.
+ *
+ * The reason to put Kafka test utility class in src is to test Python related Kafka APIs.
  */
 private class KafkaTestUtils extends Logging {
 
@@ -55,7 +58,7 @@ private class KafkaTestUtils extends Logging {
 
   private var zookeeper: EmbeddedZookeeper = _
 
-  var zkClient: ZkClient = _
+  private var zkClient: ZkClient = _
 
   // Kafka broker related configurations
   private val brokerHost = "localhost"
@@ -82,18 +85,25 @@ private class KafkaTestUtils extends Logging {
     s"$brokerHost:$brokerPort"
   }
 
-  /** Set up the Embedded Zookeeper server and get the proper Zookeeper port */
-  def setupEmbeddedZookeeper(): Unit = {
+  def zookeeperClient: ZkClient = {
+    assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client")
+    Option(zkClient).getOrElse(
+      throw new IllegalStateException("Zookeeper client is not yet initialized"))
+  }
+
+  // Set up the Embedded Zookeeper server and get the proper Zookeeper port
+  private def setupEmbeddedZookeeper(): Unit = {
     // Zookeeper server startup
     zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort")
     // Get the actual zookeeper binding port
     zkPort = zookeeper.actualPort
-    zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout,
-      ZKStringSerializer)
     zkReady = true
+    zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer)
   }

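With `zkClient` now private, tests reach Zookeeper only through the guarded `zookeeperClient` accessor, which fails fast if setup has not run. A minimal caller sketch, assuming Kafka 0.8.x's `AdminUtils.createTopic(zkClient, topic, partitions, replicationFactor)` signature; the topic name and counts are made up for illustration:

```scala
import kafka.admin.AdminUtils

// Sketch only: the accessor asserts that setup ran and that the client
// exists, so it can be handed straight to Kafka's admin API.
val utils = new KafkaTestUtils
utils.setupEmbeddedServers()
AdminUtils.createTopic(utils.zookeeperClient, "demo-topic", 1, 1)  // hypothetical topic
utils.teardownEmbeddedServers()
```
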
-  /** Set up the Embedded Kafka server */
-  def setupEmbeddedKafkaServer(): Unit = {
+  // Set up the Embedded Kafka server
+  private def setupEmbeddedKafkaServer(): Unit = {
+    assert(zkReady, "Zookeeper should be set up beforehand")
     // Kafka broker startup
     var bindSuccess: Boolean = false
@@ -116,8 +126,14 @@ private class KafkaTestUtils extends Logging {
     brokerReady = true
   }

+  /** Set up the whole embedded servers, including Zookeeper and Kafka brokers */
+  def setupEmbeddedServers(): Unit = {
+    setupEmbeddedZookeeper()
+    setupEmbeddedKafkaServer()
+  }
+
   /** Tear down the whole servers, including Kafka broker and Zookeeper */
-  def tearDownEmbeddedServers(): Unit = {
+  def teardownEmbeddedServers(): Unit = {
     brokerReady = false
     zkReady = false
 
@@ -151,7 +167,7 @@ private class KafkaTestUtils extends Logging {
     waitUntilMetadataIsPropagated(topic, 0)
   }
 
-  /** Java function for sending messages to the Kafka broker */
+  /** Java-friendly function for sending messages to the Kafka broker */
   def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = {
     import scala.collection.JavaConversions._
     sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*))
@@ -191,6 +207,37 @@ private class KafkaTestUtils extends Logging {
   }
 
   private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = {
+    // A simplified version of scalatest eventually, rewritten here to avoid adding an extra
+    // test dependency
+    def eventually[T](timeout: Time, interval: Time)(func: => T): T = {
+      def makeAttempt(): Either[Throwable, T] = {
+        try {
+          Right(func)
+        } catch {
+          case e if NonFatal(e) => Left(e)
+        }
+      }
+
+      val startTime = System.currentTimeMillis()
+      @tailrec
+      def tryAgain(attempt: Int): T = {
+        makeAttempt() match {
+          case Right(result) => result
+          case Left(e) =>
+            val duration = System.currentTimeMillis() - startTime
+            if (duration < timeout.milliseconds) {
+              Thread.sleep(interval.milliseconds)
+            } else {
+              throw new TimeoutException(e.getMessage)
+            }
+
+            tryAgain(attempt + 1)
+        }
+      }
+
+      tryAgain(1)
+    }
+
     eventually(Time(10000), Time(100)) {
       assert(
         server.apis.metadataCache.containsTopicAndPartition(topic, partition),
@@ -199,38 +246,7 @@ private class KafkaTestUtils extends Logging {
     }
   }

-  // A simplified version of scalatest eventually, rewrite here is to avoid adding extra test
-  // dependency
-  private def eventually[T](timeout: Time, interval: Time)(func: => T): T = {
-    def makeAttempt(): Either[Throwable, T] = {
-      try {
-        Right(func)
-      } catch {
-        case e: Throwable => Left(e)
-      }
-    }
-
-    val startTime = System.currentTimeMillis()
-    @tailrec
-    def tryAgain(attempt: Int): T = {
-      makeAttempt() match {
-        case Right(result) => result
-        case Left(e) =>
-          val duration = System.currentTimeMillis() - startTime
-          if (duration < timeout.milliseconds) {
-            Thread.sleep(interval.milliseconds)
-          } else {
-            throw new TimeoutException(e.getMessage)
-          }
-
-          tryAgain(attempt + 1)
-      }
-    }
-
-    tryAgain(1)
-  }
-
-  class EmbeddedZookeeper(val zkConnect: String) {
+  private class EmbeddedZookeeper(val zkConnect: String) {
     val random = new Random()
     val snapshotDir = Utils.createTempDir()
     val logDir = Utils.createTempDir()
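
The `eventually` helper now lives inside `waitUntilMetadataIsPropagated`, its only caller, and it catches only `NonFatal` errors instead of every `Throwable`, so fatal JVM errors are no longer swallowed by the retry loop. A self-contained sketch of the same retry pattern, with plain millisecond `Long`s standing in for Spark's `Time` (an assumption made so the snippet compiles on its own):

```scala
import java.util.concurrent.TimeoutException
import scala.annotation.tailrec
import scala.util.control.NonFatal

// Retry `func` until it stops throwing or `timeoutMs` elapses, polling
// every `intervalMs`. Mirrors the helper above; the Long parameters are
// an assumption to keep the sketch self-contained.
def eventually[T](timeoutMs: Long, intervalMs: Long)(func: => T): T = {
  val startTime = System.currentTimeMillis()
  @tailrec
  def tryAgain(): T = {
    val attempt = try Right(func) catch { case NonFatal(e) => Left(e) }
    attempt match {
      case Right(result) => result
      case Left(e) =>
        if (System.currentTimeMillis() - startTime >= timeoutMs) {
          throw new TimeoutException(e.getMessage)
        }
        Thread.sleep(intervalMs)
        tryAgain()
    }
  }
  tryAgain()
}

// Usage: eventually(10000, 100) { assert(conditionThatEventuallyHolds) }
```
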
@@ -46,8 +46,7 @@ public class JavaDirectKafkaStreamSuite implements Serializable {
   @Before
   public void setUp() {
     kafkaTestUtils = new KafkaTestUtils();
-    kafkaTestUtils.setupEmbeddedZookeeper();
-    kafkaTestUtils.setupEmbeddedKafkaServer();
+    kafkaTestUtils.setupEmbeddedServers();
     System.clearProperty("spark.driver.port");
     SparkConf sparkConf = new SparkConf()
       .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
@@ -64,7 +63,7 @@ public void tearDown() {
     System.clearProperty("spark.driver.port");
 
     if (kafkaTestUtils != null) {
-      kafkaTestUtils.tearDownEmbeddedServers();
+      kafkaTestUtils.teardownEmbeddedServers();
       kafkaTestUtils = null;
     }
   }
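
Every suite in this commit now uses the same one-call setup and teardown. Pulling the helper's public API together, a hedged end-to-end Scala sketch (topic name and message frequencies are illustrative):

```scala
// Sketch only: drive KafkaTestUtils end to end the way the suites below do.
val kafkaTestUtils = new KafkaTestUtils
kafkaTestUtils.setupEmbeddedServers()     // Zookeeper first, then the Kafka broker

kafkaTestUtils.createTopic("demo-topic")  // hypothetical topic
kafkaTestUtils.sendMessages("demo-topic", Map("a" -> 5, "b" -> 3))

// brokerAddress plugs straight into consumer configuration:
val kafkaParams = Map(
  "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
  "auto.offset.reset" -> "smallest")

kafkaTestUtils.teardownEmbeddedServers()
```
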
@@ -42,8 +42,7 @@ public class JavaKafkaRDDSuite implements Serializable {
   @Before
   public void setUp() {
     kafkaTestUtils = new KafkaTestUtils();
-    kafkaTestUtils.setupEmbeddedZookeeper();
-    kafkaTestUtils.setupEmbeddedKafkaServer();
+    kafkaTestUtils.setupEmbeddedServers();
     System.clearProperty("spark.driver.port");
     SparkConf sparkConf = new SparkConf()
       .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
@@ -59,7 +58,7 @@ public void tearDown() {
     System.clearProperty("spark.driver.port");
 
     if (kafkaTestUtils != null) {
-      kafkaTestUtils.tearDownEmbeddedServers();
+      kafkaTestUtils.teardownEmbeddedServers();
       kafkaTestUtils = null;
     }
   }
@@ -22,7 +22,6 @@
 import java.util.List;
 import java.util.Random;
 
-import scala.Predef;
 import scala.Tuple2;
 
 import kafka.serializer.StringDecoder;
@@ -48,8 +47,7 @@ public class JavaKafkaStreamSuite implements Serializable {
   @Before
   public void setUp() {
     kafkaTestUtils = new KafkaTestUtils();
-    kafkaTestUtils.setupEmbeddedZookeeper();
-    kafkaTestUtils.setupEmbeddedKafkaServer();
+    kafkaTestUtils.setupEmbeddedServers();
     System.clearProperty("spark.driver.port");
     SparkConf sparkConf = new SparkConf()
       .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
@@ -66,7 +64,7 @@ public void tearDown() {
     System.clearProperty("spark.driver.port");
 
     if (kafkaTestUtils != null) {
-      kafkaTestUtils.tearDownEmbeddedServers();
+      kafkaTestUtils.teardownEmbeddedServers();
       kafkaTestUtils = null;
     }
   }
@@ -54,13 +54,12 @@ class DirectKafkaStreamSuite
 
   override def beforeAll {
     kafkaTestUtils = new KafkaTestUtils
-    kafkaTestUtils.setupEmbeddedZookeeper()
-    kafkaTestUtils.setupEmbeddedKafkaServer()
+    kafkaTestUtils.setupEmbeddedServers()
   }
 
   override def afterAll {
     if (kafkaTestUtils != null) {
-      kafkaTestUtils.tearDownEmbeddedServers()
+      kafkaTestUtils.teardownEmbeddedServers()
       kafkaTestUtils = null
     }
   }
@@ -88,7 +87,7 @@ class DirectKafkaStreamSuite
     }
     val totalSent = data.values.sum * topics.size
     val kafkaParams = Map(
-      "metadata.broker.list" -> s"${kafkaTestUtils.brokerAddress}",
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "smallest"
     )
 
@@ -134,7 +133,7 @@ class DirectKafkaStreamSuite
     val data = Map("a" -> 10)
     kafkaTestUtils.createTopic(topic)
     val kafkaParams = Map(
-      "metadata.broker.list" -> s"${kafkaTestUtils.brokerAddress}",
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "largest"
     )
     val kc = new KafkaCluster(kafkaParams)
@@ -179,7 +178,7 @@ class DirectKafkaStreamSuite
     val data = Map("a" -> 10)
     kafkaTestUtils.createTopic(topic)
     val kafkaParams = Map(
-      "metadata.broker.list" -> s"${kafkaTestUtils.brokerAddress}",
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "largest"
     )
     val kc = new KafkaCluster(kafkaParams)
@@ -225,7 +224,7 @@ class DirectKafkaStreamSuite
     testDir = Utils.createTempDir()
 
     val kafkaParams = Map(
-      "metadata.broker.list" -> s"${kafkaTestUtils.brokerAddress}",
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "smallest"
     )
 
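
These tests flip `auto.offset.reset` between "smallest" and "largest". Under the Kafka 0.8.x consumer property names, "smallest" replays from the earliest offset still retained on the broker, while "largest" consumes only messages produced from now on. A side-by-side sketch (the broker address is illustrative):

```scala
// Kafka 0.8.x property names; newer clients renamed these to "earliest"/"latest".
val fromBeginning = Map(
  "metadata.broker.list" -> "localhost:9092",  // illustrative address
  "auto.offset.reset" -> "smallest")           // replay everything retained

val onlyNew = Map(
  "metadata.broker.list" -> "localhost:9092",
  "auto.offset.reset" -> "largest")            // only messages sent from now on
```
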
@@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka
 import scala.util.Random
 
 import kafka.common.TopicAndPartition
-import org.scalatest.{FunSuite, BeforeAndAfterAll}
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll {
   private val topic = "kcsuitetopic" + Random.nextInt(10000)
@@ -31,17 +31,16 @@ class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll {
 
   override def beforeAll() {
     kafkaTestUtils = new KafkaTestUtils
-    kafkaTestUtils.setupEmbeddedZookeeper()
-    kafkaTestUtils.setupEmbeddedKafkaServer()
+    kafkaTestUtils.setupEmbeddedServers()
 
     kafkaTestUtils.createTopic(topic)
     kafkaTestUtils.sendMessages(topic, Map("a" -> 1))
-    kc = new KafkaCluster(Map("metadata.broker.list" -> s"${kafkaTestUtils.brokerAddress}"))
+    kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress))
   }
 
   override def afterAll() {
     if (kafkaTestUtils != null) {
-      kafkaTestUtils.tearDownEmbeddedServers()
+      kafkaTestUtils.teardownEmbeddedServers()
       kafkaTestUtils = null
     }
   }
@@ -22,7 +22,7 @@ import scala.util.Random
 import kafka.serializer.StringDecoder
 import kafka.common.TopicAndPartition
 import kafka.message.MessageAndMetadata
-import org.scalatest.{FunSuite, BeforeAndAfterAll}
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark._
 
@@ -37,8 +37,7 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
   override def beforeAll {
     sc = new SparkContext(sparkConf)
     kafkaTestUtils = new KafkaTestUtils
-    kafkaTestUtils.setupEmbeddedZookeeper()
-    kafkaTestUtils.setupEmbeddedKafkaServer()
+    kafkaTestUtils.setupEmbeddedServers()
   }
 
   override def afterAll {
@@ -48,7 +47,7 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
     }
 
     if (kafkaTestUtils != null) {
-      kafkaTestUtils.tearDownEmbeddedServers()
+      kafkaTestUtils.teardownEmbeddedServers()
       kafkaTestUtils = null
     }
   }
@@ -23,31 +23,30 @@ import scala.language.postfixOps
 import scala.util.Random
 
 import kafka.serializer.StringDecoder
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import org.scalatest.concurrent.Eventually
 
 import org.apache.spark.SparkConf
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.{Milliseconds, StreamingContext}
 
-class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
+class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll {
   private var ssc: StreamingContext = _
   private var kafkaTestUtils: KafkaTestUtils = _
 
-  before {
+  override def beforeAll(): Unit = {
     kafkaTestUtils = new KafkaTestUtils
-    kafkaTestUtils.setupEmbeddedZookeeper()
-    kafkaTestUtils.setupEmbeddedKafkaServer()
+    kafkaTestUtils.setupEmbeddedServers()
   }
 
-  after {
+  override def afterAll(): Unit = {
     if (ssc != null) {
       ssc.stop()
       ssc = null
     }
 
     if (kafkaTestUtils != null) {
-      kafkaTestUtils.tearDownEmbeddedServers()
+      kafkaTestUtils.teardownEmbeddedServers()
       kafkaTestUtils = null
     }
   }
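
KafkaStreamSuite also switches from `BeforeAndAfter`, whose `before`/`after` hooks run around every test, to `BeforeAndAfterAll`, whose hooks run once per suite, so the embedded servers start a single time. A minimal contrast sketch (suite name and test bodies are hypothetical):

```scala
import org.scalatest.{BeforeAndAfterAll, FunSuite}

// Hypothetical suite: beforeAll/afterAll bracket the whole suite, so an
// expensive fixture like an embedded Kafka broker is set up only once.
class OncePerSuiteExample extends FunSuite with BeforeAndAfterAll {
  override def beforeAll(): Unit = println("start servers once")
  override def afterAll(): Unit = println("stop servers once")

  test("first test reuses the fixture") { assert(1 + 1 == 2) }
  test("second test reuses the fixture") { assert("kafka".nonEmpty) }
}
```
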
(The remaining changed files did not load in this capture and are not shown.)