KAFKA-13772: Partitions are not correctly re-partitioned when the fetcher thread pool is resized (#11953)

Partitions are assigned to fetcher threads based on their hash modulo the number of fetcher threads. When the fetcher thread pool is resized, all partitions are re-distributed based on the new pool size. The issue is that the resizing logic updates the `fetcherThreadMap` while iterating over it. The `Map` gives no guarantee in this case, especially when the underlying map is re-hashed, so the iteration did not visit all fetcher threads and some partitions were left on the wrong fetcher threads.

Reviewers: Luke Chen <showuon@gmail.com>, David Jacot <djacot@confluent.io>
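
To illustrate the hazard described in the commit message, here is a minimal Scala sketch (not Kafka code; the object and variable names are hypothetical). It mutates a `mutable.HashMap` while iterating over it, which mirrors how the old `resizeThreadPool` logic removed and added fetcher threads in `fetcherThreadMap` from inside the `forKeyValue` loop; once the underlying table re-hashes, the iteration may skip or revisit entries, or otherwise misbehave, because the behaviour is undefined.

```scala
import scala.collection.mutable

// Minimal sketch, not Kafka code: names are made up for illustration.
object ResizeWhileIteratingSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for fetcherThreadMap: 50 "fetcher threads" keyed by fetcher id.
    val fetchers = mutable.HashMap.empty[Int, String]
    (0 until 50).foreach(id => fetchers += id -> s"fetcher-$id")

    var visited = 0
    fetchers.foreach { case (id, _) =>
      // Only react to the original entries so the sketch is guaranteed to terminate.
      if (id < 50) {
        visited += 1
        // Simulate the old resize path: entries are removed and new ones added
        // while the iteration over the same map is still in progress.
        fetchers.remove(id)
        fetchers += (id + 1000) -> s"fetcher-${id + 1000}"
      }
    }

    // Undefined behaviour: there is no guarantee that all 50 original entries
    // were visited exactly once; the outcome depends on when the map re-hashes.
    println(s"visited $visited of 50 original fetchers")
  }
}
```

The fix below avoids mutating `fetcherThreadMap` mid-iteration: each thread's partitions are drained with the new `removeAllPartitions()` method, accumulated into `allRemovedPartitionsMap`, and re-added with a single `addFetcherForPartitions` call after the loop over `fetcherThreadMap` has completed.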
yufeiyan1220 committed Mar 31, 2022
1 parent ce7788a commit 430f9c9
Showing 3 changed files with 145 additions and 9 deletions.
23 changes: 15 additions & 8 deletions core/src/main/scala/kafka/server/AbstractFetcherManager.scala
@@ -62,19 +62,22 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri

def resizeThreadPool(newSize: Int): Unit = {
def migratePartitions(newSize: Int): Unit = {
val allRemovedPartitionsMap = mutable.Map[TopicPartition, InitialFetchState]()
fetcherThreadMap.forKeyValue { (id, thread) =>
val partitionStates = removeFetcherForPartitions(thread.partitions)
val partitionStates = thread.removeAllPartitions()
if (id.fetcherId >= newSize)
thread.shutdown()
val fetchStates = partitionStates.map { case (topicPartition, currentFetchState) =>
val initialFetchState = InitialFetchState(currentFetchState.topicId, thread.sourceBroker,
currentLeaderEpoch = currentFetchState.currentLeaderEpoch,
initOffset = currentFetchState.fetchOffset)
topicPartition -> initialFetchState
partitionStates.forKeyValue { (topicPartition, currentFetchState) =>
val initialFetchState = InitialFetchState(currentFetchState.topicId, thread.sourceBroker,
currentLeaderEpoch = currentFetchState.currentLeaderEpoch,
initOffset = currentFetchState.fetchOffset)
allRemovedPartitionsMap += topicPartition -> initialFetchState
}
addFetcherForPartitions(fetchStates)
}
// failed partitions are removed when adding partitions to fetcher
addFetcherForPartitions(allRemovedPartitionsMap)
}

lock synchronized {
val currentSize = numFetchersPerBroker
info(s"Resizing fetcher thread pool size from $currentSize to $newSize")
@@ -145,7 +148,7 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
case None =>
addAndStartFetcherThread(brokerAndFetcherId, brokerIdAndFetcherId)
}

// failed partitions are removed when adding partitions to the thread
addPartitionsToFetcherThread(fetcherThread, initialFetchOffsets)
}
}
@@ -251,6 +254,10 @@ class FailedPartitions {
def contains(topicPartition: TopicPartition): Boolean = synchronized {
failedPartitionsSet.contains(topicPartition)
}

def partitions(): Set[TopicPartition] = synchronized {
failedPartitionsSet.toSet
}
}

case class BrokerAndFetcherId(broker: BrokerEndPoint, fetcherId: Int)
12 changes: 12 additions & 0 deletions core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -743,6 +743,18 @@ abstract class AbstractFetcherThread(name: String,
} finally partitionMapLock.unlock()
}

def removeAllPartitions(): Map[TopicPartition, PartitionFetchState] = {
partitionMapLock.lockInterruptibly()
try {
val allPartitionState = partitionStates.partitionStateMap.asScala.toMap
allPartitionState.keys.foreach { tp =>
partitionStates.remove(tp)
fetcherLagStats.unregister(tp)
}
allPartitionState
} finally partitionMapLock.unlock()
}

def partitionCount: Int = {
partitionMapLock.lockInterruptibly()
try partitionStates.size
119 changes: 118 additions & 1 deletion core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala
@@ -18,13 +18,20 @@ package kafka.server

import com.yammer.metrics.core.Gauge
import kafka.cluster.BrokerEndPoint
import kafka.log.LogAppendInfo
import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions}
import kafka.utils.Implicits.MapExtensionMethods
import kafka.utils.TestUtils
import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
import org.apache.kafka.common.requests.FetchRequest
import org.apache.kafka.common.utils.Utils
import org.apache.kafka.common.{TopicPartition, Uuid}
import org.apache.kafka.server.metrics.KafkaYammerMetrics
import org.junit.jupiter.api.{BeforeEach, Test}
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.{BeforeEach, Test}
import org.mockito.Mockito.{mock, verify, when}

import scala.collection.{Map, Set, mutable}
import scala.jdk.CollectionConverters._

class AbstractFetcherManagerTest {
@@ -100,6 +107,7 @@ class AbstractFetcherManagerTest {
fetcherManager.removeFetcherForPartitions(Set(tp))
assertEquals(0, getMetricValue(metricName))
}

@Test
def testDeadThreadCountMetric(): Unit = {
val fetcher: AbstractFetcherThread = mock(classOf[AbstractFetcherThread])
@@ -210,4 +218,113 @@ class AbstractFetcherManagerTest {
verify(fetcher).maybeUpdateTopicIds(Set(tp1), topicIds)
verify(fetcher).maybeUpdateTopicIds(Set(tp2), topicIds)
}

@Test
def testExpandThreadPool(): Unit = {
testResizeThreadPool(10, 50)
}

@Test
def testShrinkThreadPool(): Unit = {
testResizeThreadPool(50, 10)
}

private def testResizeThreadPool(currentFetcherSize: Int, newFetcherSize: Int, brokerNum: Int = 6): Unit = {
val fetchingTopicPartitions = makeTopicPartition(10, 100)
val failedTopicPartitions = makeTopicPartition(2, 5, "topic_failed")
val fetcherManager = new AbstractFetcherManager[AbstractFetcherThread]("fetcher-manager", "fetcher-manager", currentFetcherSize) {
override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): AbstractFetcherThread = {
new TestResizeFetcherThread(sourceBroker, failedPartitions)
}
}
try {
fetcherManager.addFetcherForPartitions(fetchingTopicPartitions.map { tp =>
val brokerId = getBrokerId(tp, brokerNum)
val brokerEndPoint = new BrokerEndPoint(brokerId, s"kafka-host-$brokerId", 9092)
tp -> InitialFetchState(None, brokerEndPoint, 0, 0)
}.toMap)

// Mark some of these partitions failed within resizing scope
fetchingTopicPartitions.take(20).foreach(fetcherManager.addFailedPartition)
// Mark failed partitions out of resizing scope
failedTopicPartitions.foreach(fetcherManager.addFailedPartition)

fetcherManager.resizeThreadPool(newFetcherSize)

val ownedPartitions = mutable.Set.empty[TopicPartition]
fetcherManager.fetcherThreadMap.forKeyValue { (brokerIdAndFetcherId, fetcherThread) =>
val fetcherId = brokerIdAndFetcherId.fetcherId
val brokerId = brokerIdAndFetcherId.brokerId

fetcherThread.partitions.foreach { tp =>
ownedPartitions += tp
assertEquals(fetcherManager.getFetcherId(tp), fetcherId)
assertEquals(getBrokerId(tp, brokerNum), brokerId)
}
}
// Verify that all partitions are owned by the fetcher threads.
assertEquals(fetchingTopicPartitions, ownedPartitions)

// Only failed partitions should still be kept after resizing
assertEquals(failedTopicPartitions, fetcherManager.failedPartitions.partitions())
} finally {
fetcherManager.closeAllFetchers()
}
}


private def makeTopicPartition(topicNum: Int, partitionNum: Int, topicPrefix: String = "topic_"): Set[TopicPartition] = {
val res = mutable.Set[TopicPartition]()
for (i <- 0 to topicNum - 1) {
val topic = topicPrefix + i
for (j <- 0 to partitionNum - 1) {
res += new TopicPartition(topic, j)
}
}
res.toSet
}

private def getBrokerId(tp: TopicPartition, brokerNum: Int): Int = {
Utils.abs(tp.hashCode) % brokerNum
}

private class TestResizeFetcherThread(sourceBroker: BrokerEndPoint, failedPartitions: FailedPartitions)
extends AbstractFetcherThread(
name = "test-resize-fetcher",
clientId = "mock-fetcher",
sourceBroker,
failedPartitions,
fetchBackOffMs = 0,
brokerTopicStats = new BrokerTopicStats) {

override protected def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = {
None
}

override protected def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {}

override protected def truncateFullyAndStartAt(topicPartition: TopicPartition, offset: Long): Unit = {}

override protected def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = ResultWithPartitions(None, Set.empty)

override protected def latestEpoch(topicPartition: TopicPartition): Option[Int] = Some(0)

override protected def logStartOffset(topicPartition: TopicPartition): Long = 1

override protected def logEndOffset(topicPartition: TopicPartition): Long = 1

override protected def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch] = Some(OffsetAndEpoch(1, 0))

override protected def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = Map.empty

override protected def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = Map.empty

override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = 1

override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = 1

override protected val isOffsetForLeaderEpochSupported: Boolean = false
override protected val isTruncationOnFetchSupported: Boolean = false
}

}
