From 56dcb837a2f1c1d8c016cfccf8268a910bb77a36 Mon Sep 17 00:00:00 2001 From: Justine Olshan Date: Wed, 12 Apr 2023 17:04:51 -0700 Subject: [PATCH] KAFKA-14561: Improve transactions experience for older clients by ensuring ongoing transaction (#13391) Added check for ongoing transaction Thread to send and receive verify only add partition to txn requests Code to send on request thread courtesy of @artemlivshits Reviewers: Artem Livshits , Jun Rao --- .../message/AddPartitionsToTxnRequest.json | 5 +- .../main/java/kafka/server/NetworkUtils.java | 90 +++++++ .../builders/ReplicaManagerBuilder.java | 12 +- .../main/scala/kafka/cluster/Partition.scala | 4 + .../src/main/scala/kafka/log/UnifiedLog.scala | 5 + .../scala/kafka/network/RequestChannel.scala | 40 ++- .../server/AddPartitionsToTxnManager.scala | 180 +++++++++++++ .../scala/kafka/server/BrokerServer.scala | 8 +- .../main/scala/kafka/server/KafkaApis.scala | 15 +- .../main/scala/kafka/server/KafkaConfig.scala | 10 + .../kafka/server/KafkaRequestHandler.scala | 74 ++++++ .../main/scala/kafka/server/KafkaServer.scala | 7 +- .../scala/kafka/server/ReplicaManager.scala | 206 +++++++++++---- .../server/KafkaRequestHandlerTest.scala | 80 ++++++ .../unit/kafka/cluster/PartitionTest.scala | 60 ++++- .../AbstractCoordinatorConcurrencyTest.scala | 4 +- .../group/GroupCoordinatorTest.scala | 8 + .../group/GroupMetadataManagerTest.scala | 20 ++ .../TransactionStateManagerTest.scala | 12 + .../unit/kafka/network/SocketServerTest.scala | 2 + .../AddPartitionsToTxnManagerTest.scala | 245 ++++++++++++++++++ .../AddPartitionsToTxnRequestServerTest.scala | 1 - .../unit/kafka/server/KafkaApisTest.scala | 66 ++++- .../kafka/server/ReplicaManagerTest.scala | 197 +++++++++++++- 24 files changed, 1275 insertions(+), 76 deletions(-) create mode 100644 core/src/main/java/kafka/server/NetworkUtils.java create mode 100644 core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala create mode 100644 core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala create mode 100644 core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala diff --git a/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json b/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json index 32bb9b8d1f76..1b89c54d8640 100644 --- a/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json +++ b/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json @@ -26,10 +26,7 @@ // // Version 4 adds VerifyOnly field to check if partitions are already in transaction and adds support to batch multiple transactions. // Versions 3 and below will be exclusively used by clients and versions 4 and above will be used by brokers. - // The AddPartitionsToTxnRequest version 4 API is added as part of KIP-890 and is still - // under developement. Hence, the API is not exposed by default by brokers - // unless explicitely enabled. - "latestVersionUnstable": true, + "latestVersionUnstable": false, "validVersions": "0-4", "flexibleVersions": "3+", "fields": [ diff --git a/core/src/main/java/kafka/server/NetworkUtils.java b/core/src/main/java/kafka/server/NetworkUtils.java new file mode 100644 index 000000000000..87dc7e961074 --- /dev/null +++ b/core/src/main/java/kafka/server/NetworkUtils.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ManualMetadataUpdater; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.common.Reconfigurable; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.ChannelBuilders; +import org.apache.kafka.common.network.NetworkReceive; +import org.apache.kafka.common.network.Selectable; +import org.apache.kafka.common.network.Selector; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; + +import java.util.Collections; + +public class NetworkUtils { + + public static NetworkClient buildNetworkClient(String prefix, + KafkaConfig config, + Metrics metrics, + Time time, + LogContext logContext) { + ChannelBuilder channelBuilder = ChannelBuilders.clientChannelBuilder( + config.interBrokerSecurityProtocol(), + JaasContext.Type.SERVER, + config, + config.interBrokerListenerName(), + config.saslMechanismInterBrokerProtocol(), + time, + config.saslInterBrokerHandshakeRequestEnable(), + logContext + ); + + if (channelBuilder instanceof Reconfigurable) { + config.addReconfigurable((Reconfigurable) channelBuilder); + } + + String metricGroupPrefix = prefix + "-channel"; + + Selector selector = new Selector( + NetworkReceive.UNLIMITED, + config.connectionsMaxIdleMs(), + metrics, + time, + metricGroupPrefix, + Collections.emptyMap(), + false, + channelBuilder, + logContext + ); + + String clientId = prefix + "-client-" + config.nodeId(); + return new NetworkClient( + selector, + new ManualMetadataUpdater(), + clientId, + 1, + 50, + 50, + Selectable.USE_DEFAULT_BUFFER_SIZE, + config.socketReceiveBufferBytes(), + config.requestTimeoutMs(), + config.connectionSetupTimeoutMs(), + config.connectionSetupTimeoutMaxMs(), + time, + false, + new ApiVersions(), + logContext + ); + } +} \ No newline at end of file diff --git a/core/src/main/java/kafka/server/builders/ReplicaManagerBuilder.java b/core/src/main/java/kafka/server/builders/ReplicaManagerBuilder.java index d9d7c1d82c41..93d6f4ff3f30 100644 --- a/core/src/main/java/kafka/server/builders/ReplicaManagerBuilder.java +++ b/core/src/main/java/kafka/server/builders/ReplicaManagerBuilder.java @@ -18,8 +18,8 @@ package kafka.server.builders; import kafka.log.LogManager; +import kafka.server.AddPartitionsToTxnManager; import kafka.server.AlterPartitionManager; -import kafka.log.remote.RemoteLogManager; import kafka.server.BrokerTopicStats; import kafka.server.DelayedDeleteRecords; import kafka.server.DelayedElectLeader; @@ -30,6 +30,7 @@ import kafka.server.MetadataCache; import kafka.server.QuotaFactory.QuotaManagers; import kafka.server.ReplicaManager; +import kafka.log.remote.RemoteLogManager; import kafka.zk.KafkaZkClient; import org.apache.kafka.common.metrics.Metrics; import 
org.apache.kafka.common.utils.Time; @@ -62,6 +63,7 @@ public class ReplicaManagerBuilder { private Optional> delayedElectLeaderPurgatory = Optional.empty(); private Optional threadNamePrefix = Optional.empty(); private Long brokerEpoch = -1L; + private Optional addPartitionsToTxnManager = Optional.empty(); public ReplicaManagerBuilder setConfig(KafkaConfig config) { this.config = config; @@ -158,6 +160,11 @@ public ReplicaManagerBuilder setBrokerEpoch(long brokerEpoch) { return this; } + public ReplicaManagerBuilder setAddPartitionsToTransactionManager(AddPartitionsToTxnManager addPartitionsToTxnManager) { + this.addPartitionsToTxnManager = Optional.of(addPartitionsToTxnManager); + return this; + } + public ReplicaManager build() { if (config == null) config = new KafkaConfig(Collections.emptyMap()); if (metrics == null) metrics = new Metrics(); @@ -183,6 +190,7 @@ public ReplicaManager build() { OptionConverters.toScala(delayedDeleteRecordsPurgatory), OptionConverters.toScala(delayedElectLeaderPurgatory), OptionConverters.toScala(threadNamePrefix), - () -> brokerEpoch); + () -> brokerEpoch, + OptionConverters.toScala(addPartitionsToTxnManager)); } } diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala index 7eba74e18db6..0f37d2923856 100755 --- a/core/src/main/scala/kafka/cluster/Partition.scala +++ b/core/src/main/scala/kafka/cluster/Partition.scala @@ -575,6 +575,10 @@ class Partition(val topicPartition: TopicPartition, } } + def hasOngoingTransaction(producerId: Long): Boolean = { + leaderLogIfLocal.exists(leaderLog => leaderLog.hasOngoingTransaction(producerId)) + } + // Return true if the future replica exists and it has caught up with the current replica for this partition // Only ReplicaAlterDirThread will call this method and ReplicaAlterDirThread should remove the partition // from its partitionStates if this method returns true diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala index 5db020d0adf0..f4529516e350 100644 --- a/core/src/main/scala/kafka/log/UnifiedLog.scala +++ b/core/src/main/scala/kafka/log/UnifiedLog.scala @@ -579,6 +579,11 @@ class UnifiedLog(@volatile var logStartOffset: Long, result } + def hasOngoingTransaction(producerId: Long): Boolean = lock synchronized { + val entry = producerStateManager.activeProducers.get(producerId) + entry != null && entry.currentTxnFirstOffset.isPresent + } + /** * The number of segments in the log. * Take care! this is an O(n) operation. 
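For illustration, the two hasOngoingTransaction helpers added above are what the produce path later in this patch uses to decide which partitions can skip verification: ReplicaManager.appendRecords splits incoming transactional batches into partitions whose producer already has an open transaction (appended immediately) and partitions that still need a verify-only round trip. A minimal sketch of that split follows; the lookup is passed in as a plain function to keep it self-contained, and splitByVerificationNeed is an illustrative name, not part of the patch.

    import org.apache.kafka.common.TopicPartition
    import org.apache.kafka.common.record.MemoryRecords

    // Illustrative helper (not part of the patch). `hasOngoingTxn` stands in for
    // getPartitionOrException(tp).hasOngoingTransaction(producerId) from this patch.
    def splitByVerificationNeed(
      entriesPerPartition: Map[TopicPartition, MemoryRecords],
      hasOngoingTxn: (TopicPartition, Long) => Boolean
    ): (Map[TopicPartition, MemoryRecords], Map[TopicPartition, MemoryRecords]) =
      entriesPerPartition.partition { case (tp, records) =>
        // A producer with an open transaction on the leader log needs no further verification.
        hasOngoingTxn(tp, records.firstBatch().producerId())
      }

Partitions that land in the second map are the ones for which the broker sends a verify-only AddPartitionsToTxn request before appending.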
diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala index 0c5cc373f5d2..34a860a20986 100644 --- a/core/src/main/scala/kafka/network/RequestChannel.scala +++ b/core/src/main/scala/kafka/network/RequestChannel.scala @@ -55,6 +55,7 @@ object RequestChannel extends Logging { sealed trait BaseRequest case object ShutdownRequest extends BaseRequest + case object WakeupRequest extends BaseRequest case class Session(principal: KafkaPrincipal, clientAddress: InetAddress) { val sanitizedUser: String = Sanitizer.sanitize(principal.getName) @@ -79,6 +80,9 @@ object RequestChannel extends Logging { } } + case class CallbackRequest(fun: () => Unit, + originalRequest: Request) extends BaseRequest + class Request(val processor: Int, val context: RequestContext, val startTimeNanos: Long, @@ -96,6 +100,8 @@ object RequestChannel extends Logging { @volatile var apiThrottleTimeMs = 0L @volatile var temporaryMemoryBytes = 0L @volatile var recordNetworkThreadTimeCallback: Option[Long => Unit] = None + @volatile var callbackRequestDequeueTimeNanos: Option[Long] = None + @volatile var callbackRequestCompleteTimeNanos: Option[Long] = None val session = Session(context.principal, context.clientAddress) @@ -227,8 +233,9 @@ object RequestChannel extends Logging { } val requestQueueTimeMs = nanosToMs(requestDequeueTimeNanos - startTimeNanos) - val apiLocalTimeMs = nanosToMs(apiLocalCompleteTimeNanos - requestDequeueTimeNanos) - val apiRemoteTimeMs = nanosToMs(responseCompleteTimeNanos - apiLocalCompleteTimeNanos) + val callbackRequestTimeNanos = callbackRequestCompleteTimeNanos.getOrElse(0L) - callbackRequestDequeueTimeNanos.getOrElse(0L) + val apiLocalTimeMs = nanosToMs(apiLocalCompleteTimeNanos - requestDequeueTimeNanos + callbackRequestTimeNanos) + val apiRemoteTimeMs = nanosToMs(responseCompleteTimeNanos - apiLocalCompleteTimeNanos - callbackRequestTimeNanos) val responseQueueTimeMs = nanosToMs(responseDequeueTimeNanos - responseCompleteTimeNanos) val responseSendTimeMs = nanosToMs(endTimeNanos - responseDequeueTimeNanos) val messageConversionsTimeMs = nanosToMs(messageConversionsTimeNanos) @@ -354,6 +361,7 @@ class RequestChannel(val queueSize: Int, private val processors = new ConcurrentHashMap[Int, Processor]() val requestQueueSizeMetricName = metricNamePrefix.concat(RequestQueueSizeMetric) val responseQueueSizeMetricName = metricNamePrefix.concat(ResponseQueueSizeMetric) + private val callbackQueue = new ArrayBlockingQueue[BaseRequest](queueSize) metricsGroup.newGauge(requestQueueSizeMetricName, () => requestQueue.size) @@ -444,6 +452,9 @@ class RequestChannel(val queueSize: Int, request.responseCompleteTimeNanos = timeNanos if (request.apiLocalCompleteTimeNanos == -1L) request.apiLocalCompleteTimeNanos = timeNanos + // If this callback was executed after KafkaApis returned we will need to adjust the callback completion time here. 
+ if (request.callbackRequestDequeueTimeNanos.isDefined && request.callbackRequestCompleteTimeNanos.isEmpty) + request.callbackRequestCompleteTimeNanos = Some(time.nanoseconds()) // For a given request, these may happen in addition to one in the previous section, skip updating the metrics case _: StartThrottlingResponse | _: EndThrottlingResponse => () } @@ -456,9 +467,21 @@ class RequestChannel(val queueSize: Int, } } - /** Get the next request or block until specified time has elapsed */ - def receiveRequest(timeout: Long): RequestChannel.BaseRequest = - requestQueue.poll(timeout, TimeUnit.MILLISECONDS) + /** Get the next request or block until specified time has elapsed + * Check the callback queue and execute first if present since these + * requests have already waited in line. */ + def receiveRequest(timeout: Long): RequestChannel.BaseRequest = { + val callbackRequest = callbackQueue.poll() + if (callbackRequest != null) + callbackRequest + else { + val request = requestQueue.poll(timeout, TimeUnit.MILLISECONDS) + request match { + case WakeupRequest => callbackQueue.poll() + case _ => request + } + } + } /** Get the next request or block until there is one */ def receiveRequest(): RequestChannel.BaseRequest = @@ -472,6 +495,7 @@ class RequestChannel(val queueSize: Int, def clear(): Unit = { requestQueue.clear() + callbackQueue.clear() } def shutdown(): Unit = { @@ -481,6 +505,12 @@ class RequestChannel(val queueSize: Int, def sendShutdownRequest(): Unit = requestQueue.put(ShutdownRequest) + def sendCallbackRequest(request: CallbackRequest): Unit = { + callbackQueue.put(request) + if (!requestQueue.offer(RequestChannel.WakeupRequest)) + trace("Wakeup request could not be added to queue. This means queue is full, so we will still process callback.") + } + } object RequestMetrics { diff --git a/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala b/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala new file mode 100644 index 000000000000..c1c82eb3c886 --- /dev/null +++ b/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala @@ -0,0 +1,180 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} +import org.apache.kafka.clients.{ClientResponse, NetworkClient, RequestCompletionHandler} +import org.apache.kafka.common.{Node, TopicPartition} +import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTransaction, AddPartitionsToTxnTransactionCollection} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AddPartitionsToTxnRequest, AddPartitionsToTxnResponse} +import org.apache.kafka.common.utils.Time + +import scala.collection.mutable + +object AddPartitionsToTxnManager { + type AppendCallback = Map[TopicPartition, Errors] => Unit +} + + +class TransactionDataAndCallbacks(val transactionData: AddPartitionsToTxnTransactionCollection, + val callbacks: mutable.Map[String, AddPartitionsToTxnManager.AppendCallback]) + + +class AddPartitionsToTxnManager(config: KafkaConfig, client: NetworkClient, time: Time) + extends InterBrokerSendThread("AddPartitionsToTxnSenderThread-" + config.brokerId, client, config.requestTimeoutMs, time) { + + private val inflightNodes = mutable.HashSet[Node]() + private val nodesToTransactions = mutable.Map[Node, TransactionDataAndCallbacks]() + + def addTxnData(node: Node, transactionData: AddPartitionsToTxnTransaction, callback: AddPartitionsToTxnManager.AppendCallback): Unit = { + nodesToTransactions.synchronized { + // Check if we have already have either node or individual transaction. Add the Node if it isn't there. + val currentNodeAndTransactionData = nodesToTransactions.getOrElseUpdate(node, + new TransactionDataAndCallbacks( + new AddPartitionsToTxnTransactionCollection(1), + mutable.Map[String, AddPartitionsToTxnManager.AppendCallback]())) + + val currentTransactionData = currentNodeAndTransactionData.transactionData.find(transactionData.transactionalId) + + // Check if we already have txn ID -- if the epoch is bumped, return invalid producer epoch, otherwise, the client likely disconnected and + // reconnected so return the retriable network exception. + if (currentTransactionData != null) { + val error = if (currentTransactionData.producerEpoch() < transactionData.producerEpoch()) + Errors.INVALID_PRODUCER_EPOCH + else + Errors.NETWORK_EXCEPTION + val topicPartitionsToError = mutable.Map[TopicPartition, Errors]() + currentTransactionData.topics().forEach { topic => + topic.partitions().forEach { partition => + topicPartitionsToError.put(new TopicPartition(topic.name(), partition), error) + } + } + val oldCallback = currentNodeAndTransactionData.callbacks(transactionData.transactionalId()) + currentNodeAndTransactionData.transactionData.remove(transactionData) + oldCallback(topicPartitionsToError.toMap) + } + currentNodeAndTransactionData.transactionData.add(transactionData) + currentNodeAndTransactionData.callbacks.put(transactionData.transactionalId(), callback) + wakeup() + } + } + + private class AddPartitionsToTxnHandler(node: Node, transactionDataAndCallbacks: TransactionDataAndCallbacks) extends RequestCompletionHandler { + override def onComplete(response: ClientResponse): Unit = { + // Note: Synchronization is not needed on inflightNodes since it is always accessed from this thread. 
+      inflightNodes.remove(node)
+      if (response.authenticationException() != null) {
+        error(s"AddPartitionsToTxnRequest failed for node ${response.destination()} with an " +
+          "authentication exception.", response.authenticationException)
+        transactionDataAndCallbacks.callbacks.foreach { case (txnId, callback) =>
+          callback(buildErrorMap(txnId, Errors.forException(response.authenticationException()).code()))
+        }
+      } else if (response.versionMismatch != null) {
+        // We may see an unsupported version exception if we try to send a verify-only request to a broker that can't handle it.
+        // In this case, skip verification.
+        warn(s"AddPartitionsToTxnRequest failed for node ${response.destination()} with an invalid version exception. This suggests verification is not supported. " +
+          s"Continuing to handle the produce request.")
+        transactionDataAndCallbacks.callbacks.values.foreach(_(Map.empty))
+      } else if (response.wasDisconnected() || response.wasTimedOut()) {
+        warn(s"AddPartitionsToTxnRequest failed for node ${response.destination()} with a network exception.")
+        transactionDataAndCallbacks.callbacks.foreach { case (txnId, callback) =>
+          callback(buildErrorMap(txnId, Errors.NETWORK_EXCEPTION.code()))
+        }
+      } else {
+        val addPartitionsToTxnResponseData = response.responseBody.asInstanceOf[AddPartitionsToTxnResponse].data
+        if (addPartitionsToTxnResponseData.errorCode != 0) {
+          error(s"AddPartitionsToTxnRequest for node ${response.destination()} returned with error ${Errors.forCode(addPartitionsToTxnResponseData.errorCode)}.")
+          // The client should not be exposed to CLUSTER_AUTHORIZATION_FAILED, so modify the error to signify that the verification did not complete.
+          // Older clients return INVALID_RECORD and newer ones can return INVALID_TXN_STATE.
+          val finalError = if (addPartitionsToTxnResponseData.errorCode() == Errors.CLUSTER_AUTHORIZATION_FAILED.code)
+            Errors.INVALID_RECORD.code
+          else
+            addPartitionsToTxnResponseData.errorCode()
+
+          transactionDataAndCallbacks.callbacks.foreach { case (txnId, callback) =>
+            callback(buildErrorMap(txnId, finalError))
+          }
+        } else {
+          addPartitionsToTxnResponseData.resultsByTransaction().forEach { transactionResult =>
+            val unverified = mutable.Map[TopicPartition, Errors]()
+            transactionResult.topicResults().forEach { topicResult =>
+              topicResult.resultsByPartition().forEach { partitionResult =>
+                val tp = new TopicPartition(topicResult.name(), partitionResult.partitionIndex())
+                if (partitionResult.partitionErrorCode() != Errors.NONE.code()) {
+                  // Producers expect to handle INVALID_PRODUCER_EPOCH in this scenario.
+ val code = + if (partitionResult.partitionErrorCode() == Errors.PRODUCER_FENCED.code) + Errors.INVALID_PRODUCER_EPOCH.code + // Older clients return INVALID_RECORD + else if (partitionResult.partitionErrorCode() == Errors.INVALID_TXN_STATE.code) + Errors.INVALID_RECORD.code + else + partitionResult.partitionErrorCode() + unverified.put(tp, Errors.forCode(code)) + } + } + } + val callback = transactionDataAndCallbacks.callbacks(transactionResult.transactionalId()) + callback(unverified.toMap) + } + } + } + wakeup() + } + + private def buildErrorMap(transactionalId: String, errorCode: Short): Map[TopicPartition, Errors] = { + val errors = new mutable.HashMap[TopicPartition, Errors]() + val transactionData = transactionDataAndCallbacks.transactionData.find(transactionalId) + transactionData.topics.forEach { topic => + topic.partitions().forEach { partition => + errors.put(new TopicPartition(topic.name(), partition), Errors.forCode(errorCode)) + } + } + errors.toMap + } + } + + override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + + // build and add requests to queue + val buffer = mutable.Buffer[RequestAndCompletionHandler]() + val currentTimeMs = time.milliseconds() + val removedNodes = mutable.Set[Node]() + nodesToTransactions.synchronized { + nodesToTransactions.foreach { case (node, transactionDataAndCallbacks) => + if (!inflightNodes.contains(node)) { + buffer += RequestAndCompletionHandler( + currentTimeMs, + node, + AddPartitionsToTxnRequest.Builder.forBroker(transactionDataAndCallbacks.transactionData), + new AddPartitionsToTxnHandler(node, transactionDataAndCallbacks) + ) + + removedNodes.add(node) + } + } + removedNodes.foreach { node => + inflightNodes.add(node) + nodesToTransactions.remove(node) + } + } + buffer + } + +} diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala b/core/src/main/scala/kafka/server/BrokerServer.scala index 605455b82736..9452f0afd938 100644 --- a/core/src/main/scala/kafka/server/BrokerServer.scala +++ b/core/src/main/scala/kafka/server/BrokerServer.scala @@ -27,6 +27,7 @@ import kafka.raft.KafkaRaftManager import kafka.security.CredentialProvider import kafka.server.metadata.{BrokerMetadataPublisher, ClientQuotaMetadataManager, DynamicClientQuotaPublisher, DynamicConfigPublisher, KRaftMetadataCache, ScramPublisher} import kafka.utils.CoreUtils +import org.apache.kafka.clients.NetworkClient import org.apache.kafka.common.feature.SupportedVersionRange import org.apache.kafka.common.message.ApiMessageType.ListenerType import org.apache.kafka.common.message.BrokerRegistrationRequestData.{Listener, ListenerCollection} @@ -250,6 +251,10 @@ class BrokerServer( ) alterPartitionManager.start() + val addPartitionsLogContext = new LogContext(s"[AddPartitionsToTxnManager broker=${config.brokerId}]") + val addPartitionsToTxnNetworkClient: NetworkClient = NetworkUtils.buildNetworkClient("AddPartitionsManager", config, metrics, time, addPartitionsLogContext) + val addPartitionsToTxnManager: AddPartitionsToTxnManager = new AddPartitionsToTxnManager(config, addPartitionsToTxnNetworkClient, time) + this._replicaManager = new ReplicaManager( config = config, metrics = metrics, @@ -265,7 +270,8 @@ class BrokerServer( isShuttingDown = isShuttingDown, zkClient = None, threadNamePrefix = None, // The ReplicaManager only runs on the broker, and already includes the ID in thread names. 
- brokerEpochSupplier = () => lifecycleManager.brokerEpoch + brokerEpochSupplier = () => lifecycleManager.brokerEpoch, + addPartitionsToTxnManager = Some(addPartitionsToTxnManager) ) /* start token manager */ diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index bc55a7f9ccf9..6a4971c44d35 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -670,6 +670,12 @@ class KafkaApis(val requestChannel: RequestChannel, else { val internalTopicsAllowed = request.header.clientId == AdminUtils.AdminClientId + val transactionStatePartition = + if (produceRequest.transactionalId() == null) + None + else + Some(txnCoordinator.partitionFor(produceRequest.transactionalId())) + // call the replica manager to append messages to the replicas replicaManager.appendRecords( timeout = produceRequest.timeout.toLong, @@ -679,7 +685,9 @@ class KafkaApis(val requestChannel: RequestChannel, entriesPerPartition = authorizedRequestInfo, requestLocal = requestLocal, responseCallback = sendResponseCallback, - recordConversionStatsCallback = processingStatsCallback) + recordConversionStatsCallback = processingStatsCallback, + transactionalId = produceRequest.transactionalId(), + transactionStatePartition = transactionStatePartition) // if the request is put into the purgatory, it will have a held reference and hence cannot be garbage collected; // hence we clear its data here in order to let GC reclaim its memory since it is already appended to log @@ -2432,6 +2440,10 @@ class KafkaApis(val requestChannel: RequestChannel, txns.forEach { transaction => val transactionalId = transaction.transactionalId + + if (transactionalId == null) + throw new InvalidRequestException("Transactional ID can not be null in request.") + val partitionsToAdd = partitionsByTransaction.get(transactionalId).asScala // Versions < 4 come from clients and must be authorized to write for the given transaction and for the given topics. 
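For illustration, the transactionStatePartition computed in the produce path above is what lets ReplicaManager (further down in this patch) locate the transaction coordinator and send it a verify-only AddPartitionsToTxn request for partitions without an ongoing transaction. The sketch below assembles that per-transaction request data, mirroring the construction in ReplicaManager.appendRecords; buildVerifyOnlyTransaction and its parameter list are illustrative names, not part of the patch.

    import org.apache.kafka.common.TopicPartition
    import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic, AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction}

    import scala.jdk.CollectionConverters._

    // Illustrative only: builds the verify-only per-transaction data that the broker-side
    // manager sends to the transaction coordinator for the not-yet-verified partitions.
    def buildVerifyOnlyTransaction(transactionalId: String,
                                   producerId: Long,
                                   producerEpoch: Short,
                                   partitions: Set[TopicPartition]): AddPartitionsToTxnTransaction = {
      val topics = new AddPartitionsToTxnTopicCollection()
      partitions.groupBy(_.topic()).foreach { case (topic, tps) =>
        topics.add(new AddPartitionsToTxnTopic()
          .setName(topic)
          .setPartitions(tps.map(tp => Integer.valueOf(tp.partition())).toList.asJava))
      }
      new AddPartitionsToTxnTransaction()
        .setTransactionalId(transactionalId)
        .setProducerId(producerId)
        .setProducerEpoch(producerEpoch)
        .setVerifyOnly(true) // verify only; does not actually add the partitions to the transaction
        .setTopics(topics)
    }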
@@ -2478,7 +2490,6 @@ class KafkaApis(val requestChannel: RequestChannel, addResultAndMaybeSendResponse(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId, finalError)) } - if (!transaction.verifyOnly) { txnCoordinator.handleAddPartitionsToTransaction(transactionalId, transaction.producerId, diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala index 44858e4b7642..c43d329236c8 100755 --- a/core/src/main/scala/kafka/server/KafkaConfig.scala +++ b/core/src/main/scala/kafka/server/KafkaConfig.scala @@ -200,6 +200,8 @@ object Defaults { val TransactionsAbortTimedOutTransactionsCleanupIntervalMS = TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs val TransactionsRemoveExpiredTransactionsCleanupIntervalMS = TransactionStateManager.DefaultRemoveExpiredTransactionalIdsIntervalMs + val TransactionPartitionVerificationEnable = true + val ProducerIdExpirationMs = 86400000 val ProducerIdExpirationCheckIntervalMs = 600000 @@ -541,6 +543,8 @@ object KafkaConfig { val TransactionsAbortTimedOutTransactionCleanupIntervalMsProp = "transaction.abort.timed.out.transaction.cleanup.interval.ms" val TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp = "transaction.remove.expired.transaction.cleanup.interval.ms" + val TransactionPartitionVerificationEnableProp = "transaction.partition.verification.enable" + val ProducerIdExpirationMsProp = ProducerStateManagerConfig.PRODUCER_ID_EXPIRATION_MS val ProducerIdExpirationCheckIntervalMsProp = "producer.id.expiration.check.interval.ms" @@ -1009,6 +1013,8 @@ object KafkaConfig { val TransactionsTopicSegmentBytesDoc = "The transaction topic segment bytes should be kept relatively small in order to facilitate faster log compaction and cache loads" val TransactionsAbortTimedOutTransactionsIntervalMsDoc = "The interval at which to rollback transactions that have timed out" val TransactionsRemoveExpiredTransactionsIntervalMsDoc = "The interval at which to remove transactions that have expired due to transactional.id.expiration.ms passing" + + val TransactionPartitionVerificationEnableDoc = "Enable verification that checks that the partition has been added to the transaction before writing transactional records to the partition" val ProducerIdExpirationMsDoc = "The time in ms that a topic partition leader will wait before expiring producer IDs. Producer IDs will not expire while a transaction associated to them is still ongoing. " + "Note that producer IDs may expire sooner if the last write from the producer ID is deleted due to the topic's retention settings. 
Setting this value the same or higher than " + @@ -1357,6 +1363,8 @@ object KafkaConfig { .define(TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, INT, Defaults.TransactionsAbortTimedOutTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsAbortTimedOutTransactionsIntervalMsDoc) .define(TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp, INT, Defaults.TransactionsRemoveExpiredTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsRemoveExpiredTransactionsIntervalMsDoc) + .define(TransactionPartitionVerificationEnableProp, BOOLEAN, Defaults.TransactionPartitionVerificationEnable, LOW, TransactionPartitionVerificationEnableDoc) + .define(ProducerIdExpirationMsProp, INT, Defaults.ProducerIdExpirationMs, atLeast(1), LOW, ProducerIdExpirationMsDoc) // Configuration for testing only as default value should be sufficient for typical usage .defineInternal(ProducerIdExpirationCheckIntervalMsProp, INT, Defaults.ProducerIdExpirationCheckIntervalMs, atLeast(1), LOW, ProducerIdExpirationCheckIntervalMsDoc) @@ -1945,6 +1953,8 @@ class KafkaConfig private(doLog: Boolean, val props: java.util.Map[_, _], dynami val transactionTopicSegmentBytes = getInt(KafkaConfig.TransactionsTopicSegmentBytesProp) val transactionAbortTimedOutTransactionCleanupIntervalMs = getInt(KafkaConfig.TransactionsAbortTimedOutTransactionCleanupIntervalMsProp) val transactionRemoveExpiredTransactionalIdCleanupIntervalMs = getInt(KafkaConfig.TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp) + + val transactionPartitionVerificationEnable = getBoolean(KafkaConfig.TransactionPartitionVerificationEnableProp) val producerIdExpirationMs = getInt(KafkaConfig.ProducerIdExpirationMsProp) val producerIdExpirationCheckIntervalMs = getInt(KafkaConfig.ProducerIdExpirationCheckIntervalMsProp) diff --git a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala index bd5ea797fe22..325c288c58e2 100755 --- a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala +++ b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala @@ -19,6 +19,7 @@ package kafka.server import kafka.network._ import kafka.utils._ +import kafka.server.KafkaRequestHandler.{threadCurrentRequest, threadRequestChannel} import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicInteger @@ -35,6 +36,44 @@ trait ApiRequestHandler { def handle(request: RequestChannel.Request, requestLocal: RequestLocal): Unit } +object KafkaRequestHandler { + // Support for scheduling callbacks on a request thread. + private val threadRequestChannel = new ThreadLocal[RequestChannel] + private val threadCurrentRequest = new ThreadLocal[RequestChannel.Request] + + // For testing + @volatile private var bypassThreadCheck = false + def setBypassThreadCheck(bypassCheck: Boolean): Unit = { + bypassThreadCheck = bypassCheck + } + + def currentRequestOnThread(): RequestChannel.Request = { + threadCurrentRequest.get() + } + + /** + * Wrap callback to schedule it on a request thread. + * NOTE: this function must be called on a request thread. 
+ * @param fun Callback function to execute + * @return Wrapped callback that would execute `fun` on a request thread + */ + def wrap[T](fun: T => Unit): T => Unit = { + val requestChannel = threadRequestChannel.get() + val currentRequest = threadCurrentRequest.get() + if (requestChannel == null || currentRequest == null) { + if (!bypassThreadCheck) + throw new IllegalStateException("Attempted to reschedule to request handler thread from non-request handler thread.") + T => fun(T) + } else { + T => { + // The requestChannel and request are captured in this lambda, so when it's executed on the callback thread + // we can re-schedule the original callback on a request thread and update the metrics accordingly. + requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(() => fun(T), currentRequest)) + } + } + } +} + /** * A thread that answers kafka requests. */ @@ -51,6 +90,7 @@ class KafkaRequestHandler(id: Int, @volatile private var stopped = false def run(): Unit = { + threadRequestChannel.set(requestChannel) while (!stopped) { // We use a single meter for aggregate idle percentage for the thread pool. // Since meter is calculated as total_recorded_value / time_window and @@ -69,10 +109,39 @@ class KafkaRequestHandler(id: Int, completeShutdown() return + case callback: RequestChannel.CallbackRequest => + try { + val originalRequest = callback.originalRequest + + // If we've already executed a callback for this request, reset the times and subtract the callback time from the + // new dequeue time. This will allow calculation of multiple callback times. + // Otherwise, set dequeue time to now. + if (originalRequest.callbackRequestDequeueTimeNanos.isDefined) { + val prevCallbacksTimeNanos = originalRequest.callbackRequestCompleteTimeNanos.getOrElse(0L) - originalRequest.callbackRequestDequeueTimeNanos.getOrElse(0L) + originalRequest.callbackRequestCompleteTimeNanos = None + originalRequest.callbackRequestDequeueTimeNanos = Some(time.nanoseconds() - prevCallbacksTimeNanos) + } else { + originalRequest.callbackRequestDequeueTimeNanos = Some(time.nanoseconds()) + } + + threadCurrentRequest.set(originalRequest) + callback.fun() + if (originalRequest.callbackRequestCompleteTimeNanos.isEmpty) + originalRequest.callbackRequestCompleteTimeNanos = Some(time.nanoseconds()) + } catch { + case e: FatalExitError => + completeShutdown() + Exit.exit(e.statusCode) + case e: Throwable => error("Exception when handling request", e) + } finally { + threadCurrentRequest.remove() + } + case request: RequestChannel.Request => try { request.requestDequeueTimeNanos = endTime trace(s"Kafka request handler $id on broker $brokerId handling request $request") + threadCurrentRequest.set(request) apis.handle(request, requestLocal) } catch { case e: FatalExitError => @@ -80,9 +149,14 @@ class KafkaRequestHandler(id: Int, Exit.exit(e.statusCode) case e: Throwable => error("Exception when handling request", e) } finally { + threadCurrentRequest.remove() request.releaseBuffer() } + case RequestChannel.WakeupRequest => + // We should handle this in receiveRequest by polling callbackQueue. 
+ warn("Received a wakeup request outside of typical usage.") + case null => // continue } } diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index d2fbb6073242..7a3d2f225b6c 100755 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -611,6 +611,10 @@ class KafkaServer( } protected def createReplicaManager(isShuttingDown: AtomicBoolean): ReplicaManager = { + val addPartitionsLogContext = new LogContext(s"[AddPartitionsToTxnManager broker=${config.brokerId}]") + val addPartitionsToTxnNetworkClient: NetworkClient = NetworkUtils.buildNetworkClient("AddPartitionsManager", config, metrics, time, addPartitionsLogContext) + val addPartitionsToTxnManager: AddPartitionsToTxnManager = new AddPartitionsToTxnManager(config, addPartitionsToTxnNetworkClient, time) + new ReplicaManager( metrics = metrics, config = config, @@ -626,7 +630,8 @@ class KafkaServer( isShuttingDown = isShuttingDown, zkClient = Some(zkClient), threadNamePrefix = threadNamePrefix, - brokerEpochSupplier = brokerEpochSupplier) + brokerEpochSupplier = brokerEpochSupplier, + addPartitionsToTxnManager = Some(addPartitionsToTxnManager)) } private def initZkClient(time: Time): Unit = { diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index 0039611c1cbc..e8bb496436f0 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -31,6 +31,7 @@ import kafka.utils._ import kafka.zk.KafkaZkClient import org.apache.kafka.common.errors._ import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic, AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction} import org.apache.kafka.common.message.DeleteRecordsResponseData.DeleteRecordsPartitionResult import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState import org.apache.kafka.common.message.LeaderAndIsrResponseData.{LeaderAndIsrPartitionError, LeaderAndIsrTopicError} @@ -194,7 +195,8 @@ class ReplicaManager(val config: KafkaConfig, delayedDeleteRecordsPurgatoryParam: Option[DelayedOperationPurgatory[DelayedDeleteRecords]] = None, delayedElectLeaderPurgatoryParam: Option[DelayedOperationPurgatory[DelayedElectLeader]] = None, threadNamePrefix: Option[String] = None, - val brokerEpochSupplier: () => Long = () => -1 + val brokerEpochSupplier: () => Long = () => -1, + addPartitionsToTxnManager: Option[AddPartitionsToTxnManager] = None ) extends Logging { private val metricsGroup = new KafkaMetricsGroup(this.getClass) @@ -312,6 +314,7 @@ class ReplicaManager(val config: KafkaConfig, val haltBrokerOnFailure = metadataCache.metadataVersion().isLessThan(IBP_1_0_IV0) logDirFailureHandler = new LogDirFailureHandler("LogDirFailureHandler", haltBrokerOnFailure) logDirFailureHandler.start() + addPartitionsToTxnManager.foreach(_.start()) } private def maybeRemoveTopicMetrics(topic: String): Unit = { @@ -607,6 +610,18 @@ class ReplicaManager(val config: KafkaConfig, * Noted that all pending delayed check operations are stored in a queue. All callers to ReplicaManager.appendRecords() * are expected to call ActionQueue.tryCompleteActions for all affected partitions, without holding any conflicting * locks. 
+ * + * @param timeout maximum time we will wait to append before returning + * @param requiredAcks number of replicas who must acknowledge the append before sending the response + * @param internalTopicsAllowed boolean indicating whether internal topics can be appended to + * @param origin source of the append request (ie, client, replication, coordinator) + * @param entriesPerPartition the records per partition to be appended + * @param responseCallback callback for sending the response + * @param delayedProduceLock lock for the delayed actions + * @param recordConversionStatsCallback callback for updating stats on record conversions + * @param requestLocal container for the stateful instances scoped to this request + * @param transactionalId transactional ID if the request is from a producer and the producer is transactional + * @param transactionStatePartition partition that holds the transactional state if transactionalId is present */ def appendRecords(timeout: Long, requiredAcks: Short, @@ -616,66 +631,128 @@ class ReplicaManager(val config: KafkaConfig, responseCallback: Map[TopicPartition, PartitionResponse] => Unit, delayedProduceLock: Option[Lock] = None, recordConversionStatsCallback: Map[TopicPartition, RecordConversionStats] => Unit = _ => (), - requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + requestLocal: RequestLocal = RequestLocal.NoCaching, + transactionalId: String = null, + transactionStatePartition: Option[Int] = None): Unit = { if (isValidRequiredAcks(requiredAcks)) { val sTime = time.milliseconds - val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed, - origin, entriesPerPartition, requiredAcks, requestLocal) - debug("Produce to local log in %d ms".format(time.milliseconds - sTime)) - - val produceStatus = localProduceResults.map { case (topicPartition, result) => - topicPartition -> ProducePartitionStatus( - result.info.lastOffset + 1, // required offset - new PartitionResponse( - result.error, - result.info.firstOffset.map[Long](_.messageOffset).orElse(-1L), - result.info.logAppendTime, - result.info.logStartOffset, - result.info.recordErrors, - result.info.errorMessage + + val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition) = + if (transactionStatePartition.isEmpty || !config.transactionPartitionVerificationEnable) + (entriesPerPartition, Map.empty) + else + entriesPerPartition.partition { case (topicPartition, records) => + getPartitionOrException(topicPartition).hasOngoingTransaction(records.firstBatch().producerId()) + } + + def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = { + val verifiedEntries = + if (unverifiedEntries.isEmpty) + allEntries + else + allEntries.filter { case (tp, _) => + !unverifiedEntries.contains(tp) + } + + val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed, + origin, verifiedEntries, requiredAcks, requestLocal) + debug("Produce to local log in %d ms".format(time.milliseconds - sTime)) + + val unverifiedResults = unverifiedEntries.map { case (topicPartition, error) => + // NOTE: Older clients return INVALID_RECORD, but newer clients will return INVALID_TXN_STATE + val message = if (error.equals(Errors.INVALID_RECORD)) "Partition was not added to the transaction" else error.message() + topicPartition -> LogAppendResult( + LogAppendInfo.UNKNOWN_LOG_APPEND_INFO, + Some(error.exception(message)) ) - ) // response status - } + } + + val allResults = 
localProduceResults ++ unverifiedResults + + val produceStatus = allResults.map { case (topicPartition, result) => + topicPartition -> ProducePartitionStatus( + result.info.lastOffset + 1, // required offset + new PartitionResponse( + result.error, + result.info.firstOffset.map[Long](_.messageOffset).orElse(-1L), + result.info.logAppendTime, + result.info.logStartOffset, + result.info.recordErrors, + result.info.errorMessage + ) + ) // response status + } - actionQueue.add { - () => - localProduceResults.foreach { - case (topicPartition, result) => - val requestKey = TopicPartitionOperationKey(topicPartition) - result.info.leaderHwChange match { - case LeaderHwChange.INCREASED => - // some delayed operations may be unblocked after HW changed - delayedProducePurgatory.checkAndComplete(requestKey) - delayedFetchPurgatory.checkAndComplete(requestKey) - delayedDeleteRecordsPurgatory.checkAndComplete(requestKey) - case LeaderHwChange.SAME => - // probably unblock some follower fetch requests since log end offset has been updated - delayedFetchPurgatory.checkAndComplete(requestKey) - case LeaderHwChange.NONE => + actionQueue.add { + () => + allResults.foreach { + case (topicPartition, result) => + val requestKey = TopicPartitionOperationKey(topicPartition) + result.info.leaderHwChange match { + case LeaderHwChange.INCREASED => + // some delayed operations may be unblocked after HW changed + delayedProducePurgatory.checkAndComplete(requestKey) + delayedFetchPurgatory.checkAndComplete(requestKey) + delayedDeleteRecordsPurgatory.checkAndComplete(requestKey) + case LeaderHwChange.SAME => + // probably unblock some follower fetch requests since log end offset has been updated + delayedFetchPurgatory.checkAndComplete(requestKey) + case LeaderHwChange.NONE => // nothing - } - } - } + } + } + } + + recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats }) - recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats }) + if (delayedProduceRequestRequired(requiredAcks, allEntries, allResults)) { + // create delayed produce operation + val produceMetadata = ProduceMetadata(requiredAcks, produceStatus) + val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock) - if (delayedProduceRequestRequired(requiredAcks, entriesPerPartition, localProduceResults)) { - // create delayed produce operation - val produceMetadata = ProduceMetadata(requiredAcks, produceStatus) - val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock) + // create a list of (topic, partition) pairs to use as keys for this delayed produce operation + val producerRequestKeys = allEntries.keys.map(TopicPartitionOperationKey(_)).toSeq - // create a list of (topic, partition) pairs to use as keys for this delayed produce operation - val producerRequestKeys = entriesPerPartition.keys.map(TopicPartitionOperationKey(_)).toSeq + // try to complete the request immediately, otherwise put it into the purgatory + // this is because while the delayed produce operation is being created, new + // requests may arrive and hence make this operation completable. 
+ delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys) - // try to complete the request immediately, otherwise put it into the purgatory - // this is because while the delayed produce operation is being created, new - // requests may arrive and hence make this operation completable. - delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys) + } else { + // we can respond immediately + val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus } + responseCallback(produceResponseStatus) + } + } + if (notYetVerifiedEntriesPerPartition.isEmpty || addPartitionsToTxnManager.isEmpty) { + appendEntries(verifiedEntriesPerPartition)(Map.empty) } else { - // we can respond immediately - val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus } - responseCallback(produceResponseStatus) + // For unverified entries, send a request to verify. When verified, the append process will proceed via the callback. + val (error, node) = getTransactionCoordinator(transactionStatePartition.get) + + if (error != Errors.NONE) { + throw error.exception() // Can throw coordinator not available -- which is retriable + } + + val topicGrouping = notYetVerifiedEntriesPerPartition.keySet.groupBy(tp => tp.topic()) + val topicCollection = new AddPartitionsToTxnTopicCollection() + topicGrouping.foreach { case (topic, tps) => + topicCollection.add(new AddPartitionsToTxnTopic() + .setName(topic) + .setPartitions(tps.map(tp => Integer.valueOf(tp.partition())).toList.asJava)) + } + + // map not yet verified partitions to a request object + val batchInfo = notYetVerifiedEntriesPerPartition.head._2.firstBatch() + val notYetVerifiedTransaction = new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(batchInfo.producerId()) + .setProducerEpoch(batchInfo.producerEpoch()) + .setVerifyOnly(true) + .setTopics(topicCollection) + + addPartitionsToTxnManager.foreach(_.addTxnData(node, notYetVerifiedTransaction, KafkaRequestHandler.wrap(appendEntries(entriesPerPartition)(_)))) } } else { // If required.acks is outside accepted range, something is wrong with the client @@ -1957,6 +2034,7 @@ class ReplicaManager(val config: KafkaConfig, checkpointHighWatermarks() replicaSelectorOpt.foreach(_.close) removeAllTopicMetrics() + addPartitionsToTxnManager.foreach(_.shutdown()) info("Shut down completely") } @@ -2295,4 +2373,32 @@ class ReplicaManager(val config: KafkaConfig, } } } + + private[server] def getTransactionCoordinator(partition: Int): (Errors, Node) = { + val listenerName = config.interBrokerListenerName + + val topicMetadata = metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), listenerName) + + if (topicMetadata.headOption.isEmpty) { + // If topic is not created, then the transaction is definitely not started. + (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode) + } else { + if (topicMetadata.head.errorCode != Errors.NONE.code) { + (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode) + } else { + val coordinatorEndpoint = topicMetadata.head.partitions.asScala + .find(_.partitionIndex == partition) + .filter(_.leaderId != MetadataResponse.NO_LEADER_ID) + .flatMap(metadata => metadataCache. 
+ getAliveBrokerNode(metadata.leaderId, listenerName)) + + coordinatorEndpoint match { + case Some(endpoint) => + (Errors.NONE, endpoint) + case _ => + (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode) + } + } + } + } } diff --git a/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala b/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala new file mode 100644 index 000000000000..d29267e2a904 --- /dev/null +++ b/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import com.yammer.metrics.core.Meter +import kafka.network.RequestChannel +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.network.{ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{RequestContext, RequestHeader} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.utils.MockTime +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.mockito.ArgumentMatchers +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.{mock, when} + +import java.net.InetAddress +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger + +class KafkaRequestHandlerTest { + + @Test + def testCallbackTiming(): Unit = { + val time = new MockTime() + val startTime = time.nanoseconds() + val metrics = new RequestChannel.Metrics(None) + val requestChannel = new RequestChannel(10, "", time, metrics) + val apiHandler = mock(classOf[ApiRequestHandler]) + + // Make unsupported API versions request to avoid having to parse a real request + val requestHeader = mock(classOf[RequestHeader]) + when(requestHeader.apiKey()).thenReturn(ApiKeys.API_VERSIONS) + when(requestHeader.apiVersion()).thenReturn(0.toShort) + + val context = new RequestContext(requestHeader, "0", mock(classOf[InetAddress]), new KafkaPrincipal("", ""), + new ListenerName(""), SecurityProtocol.PLAINTEXT, mock(classOf[ClientInformation]), false) + val request = new RequestChannel.Request(0, context, time.nanoseconds(), + mock(classOf[MemoryPool]), mock(classOf[ByteBuffer]), metrics) + + val handler = new KafkaRequestHandler(0, 0, mock(classOf[Meter]), new AtomicInteger(1), requestChannel, apiHandler, time) + + requestChannel.sendRequest(request) + + def callback(ms: Int): Unit = { + time.sleep(ms) + handler.stop() + } + + when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ => + time.sleep(2) + KafkaRequestHandler.wrap(callback(_: Int))(1) + request.apiLocalCompleteTimeNanos = time.nanoseconds + } + + handler.run() + + assertEquals(startTime, 
request.requestDequeueTimeNanos) + assertEquals(startTime + 2000000, request.apiLocalCompleteTimeNanos) + assertEquals(Some(startTime + 2000000), request.callbackRequestDequeueTimeNanos) + assertEquals(Some(startTime + 3000000), request.callbackRequestCompleteTimeNanos) + } +} diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala index 47b4a99274b8..b7ff476988fb 100644 --- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala +++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala @@ -1184,11 +1184,24 @@ class PartitionTest extends AbstractPartitionTest { builder.build() } + def createIdempotentRecords(records: Iterable[SimpleRecord], + baseOffset: Long, + baseSequence: Int = 0, + producerId: Long = 1L): MemoryRecords = { + val producerEpoch = 0.toShort + val isTransactional = false + val buf = ByteBuffer.allocate(DefaultRecordBatch.sizeInBytes(records.asJava)) + val builder = MemoryRecords.builder(buf, CompressionType.NONE, baseOffset, producerId, + producerEpoch, baseSequence, isTransactional) + records.foreach(builder.append) + builder.build() + } + def createTransactionalRecords(records: Iterable[SimpleRecord], - baseOffset: Long): MemoryRecords = { - val producerId = 1L + baseOffset: Long, + baseSequence: Int = 0, + producerId: Long = 1L): MemoryRecords = { val producerEpoch = 0.toShort - val baseSequence = 0 val isTransactional = true val buf = ByteBuffer.allocate(DefaultRecordBatch.sizeInBytes(records.asJava)) val builder = MemoryRecords.builder(buf, CompressionType.NONE, baseOffset, producerId, @@ -3232,6 +3245,47 @@ class PartitionTest extends AbstractPartitionTest { listener.verify(expectedHighWatermark = partition.localLogOrException.logEndOffset) } + @Test + def testHasOngoingTransaction(): Unit = { + val controllerEpoch = 0 + val leaderEpoch = 5 + val replicas = List[Integer](brokerId, brokerId + 1).asJava + val isr = replicas + val producerId = 22L + + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + + assertTrue(partition.makeLeader(new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setPartitionEpoch(1) + .setReplicas(replicas) + .setIsNew(true), offsetCheckpoints, None), "Expected become leader transition to succeed") + assertEquals(leaderEpoch, partition.getLeaderEpoch) + assertFalse(partition.hasOngoingTransaction(producerId)) + + val idempotentRecords = createIdempotentRecords(List( + new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes), + new SimpleRecord("k3".getBytes, "v3".getBytes)), + baseOffset = 0L, + producerId = producerId) + partition.appendRecordsToLeader(idempotentRecords, origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching) + assertFalse(partition.hasOngoingTransaction(producerId)) + + val transactionRecords = createTransactionalRecords(List( + new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes), + new SimpleRecord("k3".getBytes, "v3".getBytes)), + baseOffset = 0L, + baseSequence = 3, + producerId = producerId) + partition.appendRecordsToLeader(transactionRecords, origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching) + assertTrue(partition.hasOngoingTransaction(producerId)) + } + private def makeLeader( topicId: Option[Uuid], controllerEpoch: Int, diff --git 
a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala index 6d159f18a5ad..492e38439068 100644 --- a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala @@ -178,7 +178,9 @@ object AbstractCoordinatorConcurrencyTest { responseCallback: Map[TopicPartition, PartitionResponse] => Unit, delayedProduceLock: Option[Lock] = None, processingStatsCallback: Map[TopicPartition, RecordConversionStats] => Unit = _ => (), - requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + requestLocal: RequestLocal = RequestLocal.NoCaching, + transactionalId: String = null, + transactionStatePartition: Option[Int]): Unit = { if (entriesPerPartition.isEmpty) return diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala index cbb1e5a7d60e..403a0c0885ae 100644 --- a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala @@ -3863,6 +3863,8 @@ class GroupCoordinatorTest { capturedArgument.capture(), any[Option[ReentrantLock]], any(), + any(), + any(), any() )).thenAnswer(_ => { capturedArgument.getValue.apply( @@ -3897,6 +3899,8 @@ class GroupCoordinatorTest { capturedArgument.capture(), any[Option[ReentrantLock]], any(), + any(), + any(), any())).thenAnswer(_ => { capturedArgument.getValue.apply( Map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) -> @@ -4041,6 +4045,8 @@ class GroupCoordinatorTest { capturedArgument.capture(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) ).thenAnswer(_ => { capturedArgument.getValue.apply( @@ -4074,6 +4080,8 @@ class GroupCoordinatorTest { capturedArgument.capture(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) ).thenAnswer(_ => { capturedArgument.getValue.apply( diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala index 333aeac5af71..b74e76ae8791 100644 --- a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala @@ -1179,6 +1179,8 @@ class GroupMetadataManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) verify(replicaManager).getMagic(any()) } @@ -1215,6 +1217,8 @@ class GroupMetadataManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) verify(replicaManager).getMagic(any()) } @@ -1289,6 +1293,8 @@ class GroupMetadataManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) // Will update sensor after commit assertEquals(1, TestUtils.totalMetricValue(metrics, "offset-commit-count")) @@ -1329,6 +1335,8 @@ class GroupMetadataManagerTest { capturedResponseCallback.capture(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) verify(replicaManager).getMagic(any()) capturedResponseCallback.getValue.apply(Map(groupTopicPartition -> @@ -1387,6 +1395,8 @@ class GroupMetadataManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) verify(replicaManager).getMagic(any()) } @@ -1435,6 +1445,8 @@ class GroupMetadataManagerTest { any(), 
any[Option[ReentrantLock]], any(), + any(), + any(), any()) verify(replicaManager).getMagic(any()) } @@ -1585,6 +1597,8 @@ class GroupMetadataManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) verify(replicaManager).getMagic(any()) assertEquals(1, TestUtils.totalMetricValue(metrics, "offset-commit-count")) @@ -1690,6 +1704,8 @@ class GroupMetadataManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) verify(replicaManager, times(2)).getMagic(any()) } @@ -2584,6 +2600,8 @@ class GroupMetadataManagerTest { capturedArgument.capture(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) capturedArgument } @@ -2599,6 +2617,8 @@ class GroupMetadataManagerTest { capturedCallback.capture(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) ).thenAnswer(_ => { capturedCallback.getValue.apply( diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala index 9dec3b055a16..78bf5743931f 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala @@ -654,6 +654,8 @@ class TransactionStateManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any() ) @@ -697,6 +699,8 @@ class TransactionStateManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any() ) @@ -737,6 +741,8 @@ class TransactionStateManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) assertEquals(Set.empty, listExpirableTransactionalIds()) @@ -793,6 +799,8 @@ class TransactionStateManagerTest { any(), any[Option[ReentrantLock]], any(), + any(), + any(), any() ) @@ -941,6 +949,8 @@ class TransactionStateManagerTest { callbackCapture.capture(), any[Option[ReentrantLock]], any(), + any(), + any(), any() )).thenAnswer(_ => callbackCapture.getValue.apply( recordsCapture.getValue.map { case (topicPartition, records) => @@ -1091,6 +1101,8 @@ class TransactionStateManagerTest { capturedArgument.capture(), any[Option[ReentrantLock]], any(), + any(), + any(), any()) ).thenAnswer(_ => capturedArgument.getValue.apply( Map(new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId) -> diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index cc345ee4bdc2..28b797f7c167 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -137,6 +137,8 @@ class SocketServerTest { private def receiveRequest(channel: RequestChannel, timeout: Long = 2000L): RequestChannel.Request = { channel.receiveRequest(timeout) match { case request: RequestChannel.Request => request + case RequestChannel.WakeupRequest => throw new AssertionError("Unexpected wakeup received") + case request: RequestChannel.CallbackRequest => throw new AssertionError("Unexpected callback received") case RequestChannel.ShutdownRequest => throw new AssertionError("Unexpected shutdown received") case null => throw new AssertionError("receiveRequest timed out") } diff --git a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala new file mode 100644 index 000000000000..8917a332d0bf --- /dev/null +++ 
b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala @@ -0,0 +1,245 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unit.kafka.server + +import kafka.common.RequestAndCompletionHandler +import kafka.server.{AddPartitionsToTxnManager, KafkaConfig} +import kafka.utils.TestUtils +import org.apache.kafka.clients.{ClientResponse, NetworkClient} +import org.apache.kafka.common.errors.{AuthenticationException, SaslAuthenticationException, UnsupportedVersionException} +import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic, AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction, AddPartitionsToTxnTransactionCollection} +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResultCollection +import org.apache.kafka.common.{Node, TopicPartition} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AbstractResponse, AddPartitionsToTxnRequest, AddPartitionsToTxnResponse} +import org.apache.kafka.common.utils.MockTime +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.mockito.Mockito.mock + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ + +class AddPartitionsToTxnManagerTest { + private val networkClient: NetworkClient = mock(classOf[NetworkClient]) + + private val time = new MockTime + + private var addPartitionsToTxnManager: AddPartitionsToTxnManager = _ + + val topic = "foo" + val topicPartitions = List(new TopicPartition(topic, 1), new TopicPartition(topic, 2), new TopicPartition(topic, 3)) + + private val node0 = new Node(0, "host1", 0) + private val node1 = new Node(1, "host2", 1) + private val node2 = new Node(2, "host2", 2) + + private val transactionalId1 = "txn1" + private val transactionalId2 = "txn2" + private val transactionalId3 = "txn3" + + private val producerId1 = 0L + private val producerId2 = 1L + private val producerId3 = 2L + + private val authenticationErrorResponse = clientResponse(null, authException = new SaslAuthenticationException("")) + private val versionMismatchResponse = clientResponse(null, mismatchException = new UnsupportedVersionException("")) + private val disconnectedResponse = clientResponse(null, disconnected = true) + + @BeforeEach + def setup(): Unit = { + addPartitionsToTxnManager = new AddPartitionsToTxnManager( + KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:2181")), + networkClient, + time) + } + + @AfterEach + def teardown(): Unit = { + addPartitionsToTxnManager.shutdown() + } + + def setErrors(errors: mutable.Map[TopicPartition, Errors])(callbackErrors: 
Map[TopicPartition, Errors]): Unit = { + callbackErrors.foreach { + case (tp, error) => errors.put(tp, error) + } + } + + @Test + def testAddTxnData(): Unit = { + val transaction1Errors = mutable.Map[TopicPartition, Errors]() + val transaction2Errors = mutable.Map[TopicPartition, Errors]() + val transaction3Errors = mutable.Map[TopicPartition, Errors]() + + addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId1, producerId1), setErrors(transaction1Errors)) + addPartitionsToTxnManager.addTxnData(node1, transactionData(transactionalId2, producerId2), setErrors(transaction2Errors)) + addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId3, producerId3), setErrors(transaction3Errors)) + + val transaction1AgainErrorsOldEpoch = mutable.Map[TopicPartition, Errors]() + val transaction1AgainErrorsNewEpoch = mutable.Map[TopicPartition, Errors]() + // Trying to add more transactional data for the same transactional ID, producer ID, and epoch should simply replace the old data and send a retriable response. + addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId1, producerId1), setErrors(transaction1AgainErrorsOldEpoch)) + val expectedNetworkErrors = topicPartitions.map(_ -> Errors.NETWORK_EXCEPTION).toMap + assertEquals(expectedNetworkErrors, transaction1Errors) + + // Trying to add more transactional data for the same transactional ID and producer ID, but new epoch should replace the old data and send an error response for it. + addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId1, producerId1, producerEpoch = 1), setErrors(transaction1AgainErrorsNewEpoch)) + + val expectedEpochErrors = topicPartitions.map(_ -> Errors.INVALID_PRODUCER_EPOCH).toMap + assertEquals(expectedEpochErrors, transaction1AgainErrorsOldEpoch) + + val requestsAndHandlers = addPartitionsToTxnManager.generateRequests() + requestsAndHandlers.foreach { requestAndHandler => + if (requestAndHandler.destination == node0) { + assertEquals(time.milliseconds(), requestAndHandler.creationTimeMs) + assertEquals(AddPartitionsToTxnRequest.Builder.forBroker( + new AddPartitionsToTxnTransactionCollection(Seq(transactionData(transactionalId3, producerId3), transactionData(transactionalId1, producerId1, producerEpoch = 1)).iterator.asJava)).data, + requestAndHandler.request.asInstanceOf[AddPartitionsToTxnRequest.Builder].data) // insertion order + } else { + verifyRequest(node1, transactionalId2, producerId2, requestAndHandler) + } + } + } + + @Test + def testGenerateRequests(): Unit = { + val transactionErrors = mutable.Map[TopicPartition, Errors]() + + addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId1, producerId1), setErrors(transactionErrors)) + addPartitionsToTxnManager.addTxnData(node1, transactionData(transactionalId2, producerId2), setErrors(transactionErrors)) + + val requestsAndHandlers = addPartitionsToTxnManager.generateRequests() + assertEquals(2, requestsAndHandlers.size) + // Note: handlers are tested in testAddPartitionsToTxnHandlerErrorHandling + requestsAndHandlers.foreach{ requestAndHandler => + if (requestAndHandler.destination == node0) { + verifyRequest(node0, transactionalId1, producerId1, requestAndHandler) + } else { + verifyRequest(node1, transactionalId2, producerId2, requestAndHandler) + } + } + + addPartitionsToTxnManager.addTxnData(node1, transactionData(transactionalId2, producerId2), setErrors(transactionErrors)) + addPartitionsToTxnManager.addTxnData(node2, transactionData(transactionalId3, 
producerId3), setErrors(transactionErrors)) + + // Test creationTimeMs increases too. + time.sleep(1000) + + val requestsAndHandlers2 = addPartitionsToTxnManager.generateRequests() + // The request for node1 should not be added because one request is already inflight. + assertEquals(1, requestsAndHandlers2.size) + requestsAndHandlers2.foreach { requestAndHandler => + verifyRequest(node2, transactionalId3, producerId3, requestAndHandler) + } + + // Complete the request for node1 so the new one can go through. + requestsAndHandlers.filter(_.destination == node1).head.handler.onComplete(authenticationErrorResponse) + val requestsAndHandlers3 = addPartitionsToTxnManager.generateRequests() + assertEquals(1, requestsAndHandlers3.size) + requestsAndHandlers3.foreach { requestAndHandler => + verifyRequest(node1, transactionalId2, producerId2, requestAndHandler) + } + } + + @Test + def testAddPartitionsToTxnHandlerErrorHandling(): Unit = { + val transaction1Errors = mutable.Map[TopicPartition, Errors]() + val transaction2Errors = mutable.Map[TopicPartition, Errors]() + + def addTransactionsToVerify(): Unit = { + transaction1Errors.clear() + transaction2Errors.clear() + + addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId1, producerId1), setErrors(transaction1Errors)) + addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId2, producerId2), setErrors(transaction2Errors)) + } + + val expectedAuthErrors = topicPartitions.map(_ -> Errors.SASL_AUTHENTICATION_FAILED).toMap + addTransactionsToVerify() + receiveResponse(authenticationErrorResponse) + assertEquals(expectedAuthErrors, transaction1Errors) + assertEquals(expectedAuthErrors, transaction2Errors) + + // On version mismatch we ignore errors and keep handling. + val expectedVersionMismatchErrors = mutable.HashMap[TopicPartition, Errors]() + addTransactionsToVerify() + receiveResponse(versionMismatchResponse) + assertEquals(expectedVersionMismatchErrors, transaction1Errors) + assertEquals(expectedVersionMismatchErrors, transaction2Errors) + + val expectedDisconnectedErrors = topicPartitions.map(_ -> Errors.NETWORK_EXCEPTION).toMap + addTransactionsToVerify() + receiveResponse(disconnectedResponse) + assertEquals(expectedDisconnectedErrors, transaction1Errors) + assertEquals(expectedDisconnectedErrors, transaction2Errors) + + val expectedTopLevelErrors = topicPartitions.map(_ -> Errors.INVALID_RECORD).toMap + val topLevelErrorAddPartitionsResponse = new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData().setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code())) + val topLevelErrorResponse = clientResponse(topLevelErrorAddPartitionsResponse) + addTransactionsToVerify() + receiveResponse(topLevelErrorResponse) + assertEquals(expectedTopLevelErrors, transaction1Errors) + assertEquals(expectedTopLevelErrors, transaction2Errors) + + val preConvertedTransaction1Errors = topicPartitions.map(_ -> Errors.PRODUCER_FENCED).toMap + val expectedTransaction1Errors = topicPartitions.map(_ -> Errors.INVALID_PRODUCER_EPOCH).toMap + val preConvertedTransaction2Errors = Map(new TopicPartition("foo", 1) -> Errors.NONE, + new TopicPartition("foo", 2) -> Errors.INVALID_RECORD, + new TopicPartition("foo", 3) -> Errors.NONE) + val expectedTransaction2Errors = Map(new TopicPartition("foo", 2) -> Errors.INVALID_RECORD) + + val transaction1ErrorResponse = AddPartitionsToTxnResponse.resultForTransaction(transactionalId1, preConvertedTransaction1Errors.asJava) + val transaction2ErrorResponse = 
AddPartitionsToTxnResponse.resultForTransaction(transactionalId2, preConvertedTransaction2Errors.asJava) + val mixedErrorsAddPartitionsResponse = new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData() + .setResultsByTransaction(new AddPartitionsToTxnResultCollection(Seq(transaction1ErrorResponse, transaction2ErrorResponse).iterator.asJava))) + val mixedErrorsResponse = clientResponse(mixedErrorsAddPartitionsResponse) + + addTransactionsToVerify() + receiveResponse(mixedErrorsResponse) + assertEquals(expectedTransaction1Errors, transaction1Errors) + assertEquals(expectedTransaction2Errors, transaction2Errors) + } + + private def clientResponse(response: AbstractResponse, authException: AuthenticationException = null, mismatchException: UnsupportedVersionException = null, disconnected: Boolean = false): ClientResponse = { + new ClientResponse(null, null, null, 0, 0, disconnected, mismatchException, authException, response) + } + + private def transactionData(transactionalId: String, producerId: Long, producerEpoch: Short = 0): AddPartitionsToTxnTransaction = { + new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setTopics(new AddPartitionsToTxnTopicCollection( + Seq(new AddPartitionsToTxnTopic() + .setName(topic) + .setPartitions(Seq[Integer](1, 2, 3).asJava)).iterator.asJava)) + } + + private def receiveResponse(response: ClientResponse): Unit = { + addPartitionsToTxnManager.generateRequests().head.handler.onComplete(response) + } + + private def verifyRequest(expectedDestination: Node, transactionalId: String, producerId: Long, requestAndHandler: RequestAndCompletionHandler): Unit = { + assertEquals(time.milliseconds(), requestAndHandler.creationTimeMs) + assertEquals(expectedDestination, requestAndHandler.destination) + assertEquals(AddPartitionsToTxnRequest.Builder.forBroker( + new AddPartitionsToTxnTransactionCollection(Seq(transactionData(transactionalId, producerId)).iterator.asJava)).data, + requestAndHandler.request.asInstanceOf[AddPartitionsToTxnRequest.Builder].data) + } +} diff --git a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala index 5673315cf31e..e59ed821c219 100644 --- a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala +++ b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala @@ -43,7 +43,6 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val numPartitions = 1 override def brokerPropertyOverrides(properties: Properties): Unit = { - properties.put(KafkaConfig.UnstableApiVersionsEnableProp, "true") properties.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString) } diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala index 48303b7a52db..2def5df4603b 100644 --- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -2198,6 +2198,8 @@ class KafkaApisTest { responseCallback.capture(), any(), any(), + any(), + any(), any()) ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new PartitionResponse(Errors.INVALID_PRODUCER_EPOCH)))) @@ -2218,6 +2220,58 @@ class KafkaApisTest { } } + @Test + def testTransactionalParametersSetCorrectly(): Unit = { + val topic = "topic" + val transactionalId = "txn1" + val 
transactionCoordinatorPartition = 35 + + addTopicToMetadataCache(topic, numPartitions = 2) + + for (version <- 3 to ApiKeys.PRODUCE.latestVersion) { + + reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + + val responseCallback: ArgumentCaptor[Map[TopicPartition, PartitionResponse] => Unit] = ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit]) + + val tp = new TopicPartition("topic", 0) + + val produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection( + Collections.singletonList(new ProduceRequestData.TopicProduceData() + .setName(tp.topic).setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(tp.partition) + .setRecords(MemoryRecords.withTransactionalRecords(CompressionType.NONE, 0, 0, 0, new SimpleRecord("test".getBytes)))))) + .iterator)) + .setAcks(1.toShort) + .setTransactionalId(transactionalId) + .setTimeoutMs(5000)) + .build(version.toShort) + val request = buildRequest(produceRequest) + + val kafkaApis = createKafkaApis() + + when(txnCoordinator.partitionFor( + ArgumentMatchers.eq(transactionalId)) + ).thenReturn(transactionCoordinatorPartition) + + kafkaApis.handleProduceRequest(request, RequestLocal.withThreadConfinedCaching) + + verify(replicaManager).appendRecords(anyLong, + anyShort, + ArgumentMatchers.eq(false), + ArgumentMatchers.eq(AppendOrigin.CLIENT), + any(), + responseCallback.capture(), + any(), + any(), + any(), + ArgumentMatchers.eq(transactionalId), + ArgumentMatchers.eq(Some(transactionCoordinatorPartition))) + } + } + @Test def testAddPartitionsToTxnWithInvalidPartition(): Unit = { val topic = "topic" @@ -2339,7 +2393,9 @@ class KafkaApisTest { responseCallback.capture(), any(), any(), - ArgumentMatchers.eq(requestLocal)) + ArgumentMatchers.eq(requestLocal), + any(), + any()) ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp2 -> new PartitionResponse(Errors.NONE)))) createKafkaApis().handleWriteTxnMarkersRequest(request, requestLocal) @@ -2469,7 +2525,9 @@ class KafkaApisTest { responseCallback.capture(), any(), any(), - ArgumentMatchers.eq(requestLocal)) + ArgumentMatchers.eq(requestLocal), + any(), + any()) ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp2 -> new PartitionResponse(Errors.NONE)))) createKafkaApis().handleWriteTxnMarkersRequest(request, requestLocal) @@ -2501,7 +2559,9 @@ class KafkaApisTest { any(), any(), any(), - ArgumentMatchers.eq(requestLocal)) + ArgumentMatchers.eq(requestLocal), + any(), + any()) } @Test diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala index 0ef3127fe1a2..1b7e99356779 100644 --- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala @@ -67,9 +67,12 @@ import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ValueSource import com.yammer.metrics.core.Gauge +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic, AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction} +import org.apache.kafka.common.message.MetadataResponseData.{MetadataResponsePartition, MetadataResponseTopic} import org.mockito.invocation.InvocationOnMock import 
org.mockito.stubbing.Answer -import org.mockito.ArgumentMatchers +import org.mockito.{ArgumentCaptor, ArgumentMatchers} import org.mockito.ArgumentMatchers.{any, anyInt, anyString} import org.mockito.Mockito.{mock, never, reset, times, verify, when} @@ -2053,6 +2056,190 @@ class ReplicaManagerTest { assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, produceResult.get.error) } + @Test + def testVerificationForTransactionalPartitions(): Unit = { + val tp = new TopicPartition(topic, 0) + val transactionalId = "txn1" + val producerId = 24L + val producerEpoch = 0.toShort + val sequence = 0 + + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_))) + val metadataCache = mock(classOf[MetadataCache]) + val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager]) + + val replicaManager = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = metadataCache, + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterPartitionManager = alterPartitionManager, + addPartitionsToTxnManager = Some(addPartitionsToTxnManager)) + + try { + val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), tp, Seq(0, 1), LeaderAndIsr(1, List(0, 1))) + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + // We must set up the metadata cache to handle the append and verification. + val metadataResponseTopic = Seq(new MetadataResponseTopic() + .setName(Topic.TRANSACTION_STATE_TOPIC_NAME) + .setPartitions(Seq( + new MetadataResponsePartition() + .setPartitionIndex(0) + .setLeaderId(0)).asJava)) + val node = new Node(0, "host1", 0) + + when(metadataCache.contains(tp)).thenReturn(true) + when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic) + when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName)).thenReturn(Some(node)) + when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName)).thenReturn(None) + + // We will attempt to schedule to the request handler thread using a non request handler thread. Set this to avoid error. + KafkaRequestHandler.setBypassThreadCheck(true) + + // Append some transactional records. + val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence, + new SimpleRecord(s"message $sequence".getBytes)) + val result = appendRecords(replicaManager, tp, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0)) + + val transactionToAdd = new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setVerifyOnly(true) + .setTopics(new AddPartitionsToTxnTopicCollection( + Seq(new AddPartitionsToTxnTopic().setName(tp.topic).setPartitions(Collections.singletonList(tp.partition))).iterator.asJava + )) + + val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback]) + // We should add these partitions to the manager to verify. + verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback.capture()) + + // Confirm we did not write to the log and instead returned error. 
+ val callback: AddPartitionsToTxnManager.AppendCallback = appendCallback.getValue() + callback(Map(tp -> Errors.INVALID_RECORD).toMap) + assertEquals(Errors.INVALID_RECORD, result.assertFired.error) + + // If we don't supply a transaction coordinator partition, we do not verify, so counter stays the same. + val transactionalRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1, + new SimpleRecord(s"message $sequence".getBytes)) + appendRecords(replicaManager, tp, transactionalRecords2) + verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]()) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDisabledVerification(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect) + props.put("transaction.partition.verification.enable", "false") + val config = KafkaConfig.fromProps(props) + + val tp = new TopicPartition(topic, 0) + val transactionalId = "txn1" + val producerId = 24L + val producerEpoch = 0.toShort + val sequence = 0 + + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_))) + val metadataCache = mock(classOf[MetadataCache]) + val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager]) + + val replicaManager = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = metadataCache, + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterPartitionManager = alterPartitionManager, + addPartitionsToTxnManager = Some(addPartitionsToTxnManager)) + + try { + val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), tp, Seq(0, 1), LeaderAndIsr(0, List(0, 1))) + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + when(metadataCache.contains(tp)).thenReturn(true) + + val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence, + new SimpleRecord(s"message $sequence".getBytes)) + appendRecords(replicaManager, tp, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0)) + + // We should not add these partitions to the manager to verify. 
+ verify(metadataCache, times(0)).getTopicMetadata(any(), any(), any(), any()) + verify(metadataCache, times(0)).getAliveBrokerNode(any(), any()) + verify(metadataCache, times(0)).getAliveBrokerNode(any(), any()) + verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any()) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testGetTransactionCoordinator(): Unit = { + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_))) + + val metadataCache = mock(classOf[MetadataCache]) + + val replicaManager = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = metadataCache, + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterPartitionManager = alterPartitionManager) + + try { + val txnCoordinatorPartition0 = 0 + val txnCoordinatorPartition1 = 1 + + // Before we set up the metadata cache, return nothing for the topic. + when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(Seq()) + assertEquals((Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode), replicaManager.getTransactionCoordinator(txnCoordinatorPartition0)) + + // Return an error response. + when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)). + thenReturn(Seq(new MetadataResponseTopic().setErrorCode(Errors.UNSUPPORTED_VERSION.code))) + assertEquals((Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode), replicaManager.getTransactionCoordinator(txnCoordinatorPartition0)) + + val metadataResponseTopic = Seq(new MetadataResponseTopic() + .setName(Topic.TRANSACTION_STATE_TOPIC_NAME) + .setPartitions(Seq( + new MetadataResponsePartition() + .setPartitionIndex(0) + .setLeaderId(0), + new MetadataResponsePartition() + .setPartitionIndex(1) + .setLeaderId(1)).asJava)) + val node0 = new Node(0, "host1", 0) + + when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic) + when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName)).thenReturn(Some(node0)) + when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName)).thenReturn(None) + + assertEquals((Errors.NONE, node0), replicaManager.getTransactionCoordinator(txnCoordinatorPartition0)) + assertEquals((Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode), replicaManager.getTransactionCoordinator(txnCoordinatorPartition1)) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + private def sendProducerAppend( replicaManager: ReplicaManager, topicPartition: TopicPartition, @@ -2296,7 +2483,9 @@ class ReplicaManagerTest { partition: TopicPartition, records: MemoryRecords, origin: AppendOrigin = AppendOrigin.CLIENT, - requiredAcks: Short = -1): CallbackResult[PartitionResponse] = { + requiredAcks: Short = -1, + transactionalId: String = null, + transactionStatePartition: Option[Int] = None): CallbackResult[PartitionResponse] = { val result = new CallbackResult[PartitionResponse]() def appendCallback(responses: Map[TopicPartition, PartitionResponse]): Unit = { val response = responses.get(partition) @@ -2310,7 +2499,9 @@ class ReplicaManagerTest { internalTopicsAllowed = false, origin = origin, entriesPerPartition = Map(partition -> records), - responseCallback 
= appendCallback) + responseCallback = appendCallback, + transactionalId = transactionalId, + transactionStatePartition = transactionStatePartition) result }
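
Note on the call pattern exercised above: AddPartitionsToTxnManagerTest drives AddPartitionsToTxnManager.addTxnData(node, transactionData, callback) directly, while ReplicaManagerTest checks that an appendRecords call carrying a transactionalId and transactionStatePartition is routed to the manager as a verify-only AddPartitionsToTxn entry, with the produce result completed from the errors handed to the captured AppendCallback. The following is a minimal sketch of that pattern, assuming only the API shapes visible in these tests (addTxnData taking a destination Node, an AddPartitionsToTxnTransaction, and a Map[TopicPartition, Errors] => Unit callback, plus the verify-only request fields); the object and method names here are illustrative and are not part of the change itself.

object VerifyOnlyExample {
  import kafka.server.AddPartitionsToTxnManager
  import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic, AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction}
  import org.apache.kafka.common.protocol.Errors
  import org.apache.kafka.common.{Node, TopicPartition}
  import scala.jdk.CollectionConverters._

  // Illustrative helper: build a verify-only entry for a single partition and hand it to the
  // manager. The manager batches entries per destination node and invokes the callback with
  // per-partition errors once the broker-to-broker AddPartitionsToTxn response arrives.
  def verifyBeforeAppend(manager: AddPartitionsToTxnManager,
                         coordinatorNode: Node,
                         transactionalId: String,
                         producerId: Long,
                         producerEpoch: Short,
                         tp: TopicPartition)
                        (onComplete: Map[TopicPartition, Errors] => Unit): Unit = {
    val txn = new AddPartitionsToTxnTransaction()
      .setTransactionalId(transactionalId)
      .setProducerId(producerId)
      .setProducerEpoch(producerEpoch)
      .setVerifyOnly(true) // only check for an ongoing transaction; do not add the partition
      .setTopics(new AddPartitionsToTxnTopicCollection(
        Seq(new AddPartitionsToTxnTopic()
          .setName(tp.topic)
          .setPartitions(Seq[Integer](tp.partition).asJava)).iterator.asJava))

    manager.addTxnData(coordinatorNode, txn, onComplete)
  }
}

A caller would pass a callback that fails the pending append when the returned map contains a non-NONE error for the partition, mirroring how testVerificationForTransactionalPartitions captures the AppendCallback and completes it with Errors.INVALID_RECORD.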