diff --git a/src/main/scala/kafka4m/consumer/ConsumerCommand.scala b/src/main/scala/kafka4m/consumer/ConsumerCommand.scala index 64f0b49..9395c4f 100644 --- a/src/main/scala/kafka4m/consumer/ConsumerCommand.scala +++ b/src/main/scala/kafka4m/consumer/ConsumerCommand.scala @@ -2,6 +2,15 @@ package kafka4m.consumer import scala.concurrent.Promise import scala.util.Try +/** + * A wrapper for an operation which needs to be performed on the kafka consumer thread + * + * @param f the operation to perform on a RichKafkaConsumer + * @param promise the result promise to complete once this task completes + * @tparam K the kafka consumer key type + * @tparam V the kafka consumer value type + * @tparam A the task result type + */ private[consumer] final case class ExecOnConsumer[K, V, A](f: RichKafkaConsumer[K, V] => A, promise: Promise[A] = Promise[A]()) { def run(inst: RichKafkaConsumer[K, V]) = { promise.tryComplete(Try(f(inst))) diff --git a/src/main/scala/kafka4m/consumer/RichKafkaConsumer.scala b/src/main/scala/kafka4m/consumer/RichKafkaConsumer.scala index 511bfee..6279adc 100644 --- a/src/main/scala/kafka4m/consumer/RichKafkaConsumer.scala +++ b/src/main/scala/kafka4m/consumer/RichKafkaConsumer.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal * A means of driving a kafka-stream using the consumer (not kafka streaming) API */ final class RichKafkaConsumer[K, V] private (val consumer: KafkaConsumer[K, V], - val topics: Set[String], + val defaultTopics: Set[String], val defaultPollTimeout: Duration, commandQueue: ConcurrentQueue[Task, ExecOnConsumer[K, V, _]], kafkaScheduler: Scheduler, @@ -41,28 +41,27 @@ final class RichKafkaConsumer[K, V] private (val consumer: KafkaConsumer[K, V], @volatile private var closed = false - require(topics.nonEmpty, "empty topic set for consumer") - require(topics.forall(_.nonEmpty), "blank topic set for consumer") - private val javaPollDuration: time.Duration = RichKafkaConsumer.asJavaDuration(defaultPollTimeout) - def
partitionsByTopic(limitToOurTopic: Boolean = true): Map[String, List[KafkaPartitionInfo]] = { + def partitionsByTopic(limitToOurTopic: Boolean = false): Map[String, List[KafkaPartitionInfo]] = { val view = consumer.listTopics().asScala.view.mapValues(_.asScala.map(KafkaPartitionInfo.apply).toList) if (limitToOurTopic) { - view.filterKeys(topics.contains).toMap + view.filterKeys(defaultTopics.contains).toMap } else { view.toMap } } - def subscribe(topic: String, listener: ConsumerRebalanceListener = RebalanceListener): Unit = { - logger.info(s"Subscribing to $topic") - consumer.subscribe(java.util.Collections.singletonList(topic), listener) + def subscribe(topic: String): Unit = subscribe(Set(topic)) + + def subscribe(topics: Set[String], listener: ConsumerRebalanceListener = RebalanceListener): Unit = { + logger.info(s"Subscribing to $topics") + consumer.subscribe(topics.asJava, listener) } def partitions: List[KafkaPartitionInfo] = { val byTopic = partitionsByTopic(true) - topics.toList.flatMap(byTopic.getOrElse(_, Nil)) + byTopic.valuesIterator.flatten.toList } /** @@ -74,9 +73,7 @@ final class RichKafkaConsumer[K, V] private (val consumer: KafkaConsumer[K, V], try { val records: ConsumerRecords[K, V] = consumer.poll(timeout) logger.debug(s"Got ${records.count()} records from ${records.partitions().asScala.mkString(s"[", ",", "]")}") - val forTopic: Iterable[ConsumerRecord[K, V]] = records.asScala - logger.trace(s"Got ${forTopic.size} of ${records.count()} for topic '$topics' records from ${records.partitions().asScala.mkString(s"[", ",", "]")}") - forTopic + records.asScala } catch { case NonFatal(e) => logger.warn(s"Poll threw $e") @@ -118,10 +115,10 @@ final class RichKafkaConsumer[K, V] private (val consumer: KafkaConsumer[K, V], /** * @return a task which will run any exec commands on our kafka scheduler */ - def execNext() = { - require(!closed, "RickKafkaConsumer is already closed") + private def execNext() = { + require(!closed, "RichKafkaConsumer is already closed") commandQueue.tryPoll.flatMap { - case
Some(exec: ExecOnConsumer[K, V, _]) => + case Some(exec) => Task(exec.run(self)).executeOn(kafkaScheduler).map(_ => NoResults).void case _ => Task.unit } @@ -154,34 +151,34 @@ final class RichKafkaConsumer[K, V] private (val consumer: KafkaConsumer[K, V], Try(thunk).map(_ => true) } - def seekToBeginning(partition: Int) = swallow { - logger.info(s"seekToBeginning(${partition})") + def seekToBeginningOnPartition(partition: Int, topics: Set[String] = defaultTopics) = swallow { + logger.info(s"seekToBeginning(${partition}, $topics)") topics.foreach { topic => val tp = new TopicPartition(topic, partition) consumer.seekToBeginning(java.util.Collections.singletonList(tp)) } } - def seekToBeginning() = swallow { - logger.info(s"seekToBeginning") + def seekToBeginning(topics: Set[String] = defaultTopics) = swallow { + logger.info(s"seekToBeginning($topics)") topics.foreach { topic => - val topicPartitions = assignmentPartitions.map { partition => + val topicPartitions = assignmentPartitions(topics).map { partition => new TopicPartition(topic, partition) } consumer.seekToBeginning(topicPartitions.asJava) } } - def seekToEnd() = swallow { + def seekToEnd(topics: Set[String] = defaultTopics) = swallow { logger.info("seekToEndUnsafe") topics.foreach { topic => - val topicPartitions = assignmentPartitions.map { partition => + val topicPartitions = assignmentPartitions(topics).map { partition => new TopicPartition(topic, partition) } consumer.seekToEnd(topicPartitions.asJava) } } - def assignToTopics(): Try[Set[TopicPartition]] = { + def assignToTopics(topics: Set[String] = defaultTopics): Try[Set[TopicPartition]] = { val pbt = partitionsByTopic() val allTopicPartitions = topics.flatMap { topic => val topicPartitions = pbt.get(topic).map { partitions: List[KafkaPartitionInfo] => @@ -208,7 +205,7 @@ final class RichKafkaConsumer[K, V] private (val consumer: KafkaConsumer[K, V], } } - def seekTo(topicPartitionState: PartitionOffsetState) = swallow { + def
seekTo(topicPartitionState: PartitionOffsetState, topics: Set[String] = defaultTopics) = swallow { logger.info(s"seekToUnsafe(${topicPartitionState})") for { topic <- topics @@ -219,21 +216,21 @@ final class RichKafkaConsumer[K, V] private (val consumer: KafkaConsumer[K, V], } } - def positionsFor(partition: Int) = { + def positionsFor(partition: Int, topics: Set[String] = defaultTopics) = { val byTopic = topics.map { topic => topic -> consumer.position(new TopicPartition(topic, partition)) } byTopic.toMap } - def committed(partition: Int): Map[String, OffsetAndMetadata] = { + def committed(partition: Int, topics: Set[String] = defaultTopics): Map[String, OffsetAndMetadata] = { val byTopic = topics.map { topic => topic -> consumer.committed(new TopicPartition(topic, partition)) } byTopic.toMap } - def assignmentPartitions: List[Int] = { + def assignmentPartitions(topics: Set[String] = defaultTopics): List[Int] = { assignments().map { tp => require(topics.contains(tp.topic()), s"consumer for topics $topics has assignment on ${tp.topic()}") tp.partition() @@ -241,15 +238,15 @@ final class RichKafkaConsumer[K, V] private (val consumer: KafkaConsumer[K, V], } def assignments() = consumer.assignment().asScala.toList - def status(verbose: Boolean): String = { + def status(verbose: Boolean, topics: Set[String] = defaultTopics): String = { val byTopic = partitionsByTopic() val topicStatuses = topics.map { topic => byTopic.get(topic).fold(s"topic '${topic}' doesn't exist") { partitions => val ourAssignments = { - val all: List[Int] = assignmentPartitions + val all: List[Int] = assignmentPartitions(topics) val detail = if (verbose) { - val committedStatus: Seq[Map[String, OffsetAndMetadata]] = all.map(committed) + val committedStatus: Seq[Map[String, OffsetAndMetadata]] = all.map(i => committed(i, topics)) committedStatus.mkString("\n\tCommit status:\n\t", "\n\t", "\n") } else { "" @@ -266,8 +263,8 @@ final class RichKafkaConsumer[K, V] private (val consumer:
KafkaConsumer[K, V], /** * @return a scala-friendly data structure containing the commit status of the kafka cluster */ - def committedStatus(): List[CommittedStatus] = { - val all: List[Int] = assignmentPartitions + def committedStatus(topics: Set[String] = defaultTopics): List[CommittedStatus] = { + val all: List[Int] = assignmentPartitions(topics) partitionsByTopic().collect { case (topic, kafkaPartitions) => val weAreSubscribed: Boolean = topics.contains(topic) diff --git a/src/test/scala/kafka4m/consumer/RichKafkaConsumerTest.scala b/src/test/scala/kafka4m/consumer/RichKafkaConsumerTest.scala index 69f42f1..566ed06 100644 --- a/src/test/scala/kafka4m/consumer/RichKafkaConsumerTest.scala +++ b/src/test/scala/kafka4m/consumer/RichKafkaConsumerTest.scala @@ -37,7 +37,7 @@ class RichKafkaConsumerTest extends BaseKafka4mDockerSpec { val (topic, config) = Kafka4mTestConfig.next() Using(RichKafkaConsumer.byteArrayValues(config, FixedScheduler().scheduler, sched)) { consumer => Using(RichKafkaAdmin(config))(_.createTopicSync(topic, testTimeout)) - consumer.assignmentPartitions shouldBe empty + consumer.assignmentPartitions() shouldBe empty } } } @@ -54,7 +54,7 @@ class RichKafkaConsumerTest extends BaseKafka4mDockerSpec { val third = producer.sendAsync(topic, "third", "value".getBytes(), partition = 0).futureValue When("We subscribe and consume to the end") - consumer.subscribe(topic, RebalanceListener) + consumer.subscribe(topic) eventually { consumer.unsafePoll().toList.size shouldBe 3 @@ -62,7 +62,7 @@ class RichKafkaConsumerTest extends BaseKafka4mDockerSpec { And("seek to the beginning") eventually { - consumer.seekToBeginning(0) shouldBe Success(true) + consumer.seekToBeginningOnPartition(0) shouldBe Success(true) } Then("we should see that offset as the first message")