Skip to content

Commit

Permalink
KAFKA-14561: Improve transactions experience for older clients by ens…
Browse files Browse the repository at this point in the history
…uring ongoing transaction (#13391)

Added check for ongoing transaction
Thread to send and receive verify only add partition to txn requests
Code to send on request thread courtesy of @artemlivshits

Reviewers: Artem Livshits <alivshits@confluent.io>, Jun Rao <junrao@gmail.com>
  • Loading branch information
jolshan committed Apr 13, 2023
1 parent 88e2d6b commit 56dcb83
Show file tree
Hide file tree
Showing 24 changed files with 1,275 additions and 76 deletions.
Expand Up @@ -26,10 +26,7 @@
//
// Version 4 adds VerifyOnly field to check if partitions are already in transaction and adds support to batch multiple transactions.
// Versions 3 and below will be exclusively used by clients and versions 4 and above will be used by brokers.
// The AddPartitionsToTxnRequest version 4 API is added as part of KIP-890 and is still
// under development. Hence, the API is not exposed by default by brokers
// unless explicitly enabled.
"latestVersionUnstable": true,
"latestVersionUnstable": false,
"validVersions": "0-4",
"flexibleVersions": "3+",
"fields": [
Expand Down
90 changes: 90 additions & 0 deletions core/src/main/java/kafka/server/NetworkUtils.java
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.server;

import org.apache.kafka.clients.ApiVersions;
import org.apache.kafka.clients.ManualMetadataUpdater;
import org.apache.kafka.clients.NetworkClient;
import org.apache.kafka.common.Reconfigurable;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.network.ChannelBuilder;
import org.apache.kafka.common.network.ChannelBuilders;
import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.network.Selectable;
import org.apache.kafka.common.network.Selector;
import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time;

import java.util.Collections;

/**
 * Static helpers for constructing network clients from a broker configuration.
 *
 * <p>The class exposes only static methods, so it is declared {@code final} and has a
 * private constructor to prevent accidental instantiation or subclassing.
 */
public final class NetworkUtils {

    private NetworkUtils() {
        // Utility class: not meant to be instantiated.
    }

    /**
     * Builds a {@link NetworkClient} whose channel builder is derived from the inter-broker
     * settings of {@code config} (security protocol, listener name, SASL mechanism and
     * SASL handshake flag).
     *
     * <p>If the created channel builder implements {@link Reconfigurable}, it is registered
     * with {@code config} via {@code addReconfigurable}.
     *
     * @param prefix     prefix used for the selector metric group ({@code prefix + "-channel"})
     *                   and for the client id ({@code prefix + "-client-" + config.nodeId()})
     * @param config     broker configuration the network settings are read from
     * @param metrics    metrics registry handed to the {@link Selector}
     * @param time       time source shared by the channel builder, selector and client
     * @param logContext log context shared by the channel builder, selector and client
     * @return a newly constructed {@link NetworkClient}
     */
    public static NetworkClient buildNetworkClient(String prefix,
                                                   KafkaConfig config,
                                                   Metrics metrics,
                                                   Time time,
                                                   LogContext logContext) {
        ChannelBuilder channelBuilder = ChannelBuilders.clientChannelBuilder(
            config.interBrokerSecurityProtocol(),
            JaasContext.Type.SERVER,
            config,
            config.interBrokerListenerName(),
            config.saslMechanismInterBrokerProtocol(),
            time,
            config.saslInterBrokerHandshakeRequestEnable(),
            logContext
        );

        // Let dynamic config updates reach the channel builder when it supports them.
        if (channelBuilder instanceof Reconfigurable) {
            config.addReconfigurable((Reconfigurable) channelBuilder);
        }

        String metricGroupPrefix = prefix + "-channel";

        // NOTE(review): the positional arguments after metricGroupPrefix are presumably the
        // metric tags (none) and a per-connection-metrics flag (disabled) -- confirm against
        // the Selector constructor.
        Selector selector = new Selector(
            NetworkReceive.UNLIMITED,
            config.connectionsMaxIdleMs(),
            metrics,
            time,
            metricGroupPrefix,
            Collections.emptyMap(),
            false,
            channelBuilder,
            logContext
        );

        String clientId = prefix + "-client-" + config.nodeId();

        // NOTE(review): the literals 1, 50, 50 are presumably max in-flight requests per
        // connection and the reconnect backoff / max backoff in ms, and the trailing `false`
        // presumably disables broker version discovery -- confirm against the NetworkClient
        // constructor.
        return new NetworkClient(
            selector,
            new ManualMetadataUpdater(),
            clientId,
            1,
            50,
            50,
            Selectable.USE_DEFAULT_BUFFER_SIZE,
            config.socketReceiveBufferBytes(),
            config.requestTimeoutMs(),
            config.connectionSetupTimeoutMs(),
            config.connectionSetupTimeoutMaxMs(),
            time,
            false,
            new ApiVersions(),
            logContext
        );
    }
}
Expand Up @@ -18,8 +18,8 @@
package kafka.server.builders;

import kafka.log.LogManager;
import kafka.server.AddPartitionsToTxnManager;
import kafka.server.AlterPartitionManager;
import kafka.log.remote.RemoteLogManager;
import kafka.server.BrokerTopicStats;
import kafka.server.DelayedDeleteRecords;
import kafka.server.DelayedElectLeader;
Expand All @@ -30,6 +30,7 @@
import kafka.server.MetadataCache;
import kafka.server.QuotaFactory.QuotaManagers;
import kafka.server.ReplicaManager;
import kafka.log.remote.RemoteLogManager;
import kafka.zk.KafkaZkClient;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.utils.Time;
Expand Down Expand Up @@ -62,6 +63,7 @@ public class ReplicaManagerBuilder {
private Optional<DelayedOperationPurgatory<DelayedElectLeader>> delayedElectLeaderPurgatory = Optional.empty();
private Optional<String> threadNamePrefix = Optional.empty();
private Long brokerEpoch = -1L;
private Optional<AddPartitionsToTxnManager> addPartitionsToTxnManager = Optional.empty();

public ReplicaManagerBuilder setConfig(KafkaConfig config) {
this.config = config;
Expand Down Expand Up @@ -158,6 +160,11 @@ public ReplicaManagerBuilder setBrokerEpoch(long brokerEpoch) {
return this;
}

/**
 * Sets the manager handed to the built {@code ReplicaManager} for add-partitions-to-txn handling.
 *
 * @param addPartitionsToTxnManager the manager to use; must not be {@code null}
 *                                  because it is wrapped with {@link Optional#of(Object)}
 * @return this builder, to allow call chaining
 */
public ReplicaManagerBuilder setAddPartitionsToTransactionManager(AddPartitionsToTxnManager addPartitionsToTxnManager) {
    final Optional<AddPartitionsToTxnManager> wrappedManager = Optional.of(addPartitionsToTxnManager);
    this.addPartitionsToTxnManager = wrappedManager;
    return this;
}

public ReplicaManager build() {
if (config == null) config = new KafkaConfig(Collections.emptyMap());
if (metrics == null) metrics = new Metrics();
Expand All @@ -183,6 +190,7 @@ public ReplicaManager build() {
OptionConverters.toScala(delayedDeleteRecordsPurgatory),
OptionConverters.toScala(delayedElectLeaderPurgatory),
OptionConverters.toScala(threadNamePrefix),
() -> brokerEpoch);
() -> brokerEpoch,
OptionConverters.toScala(addPartitionsToTxnManager));
}
}
4 changes: 4 additions & 0 deletions core/src/main/scala/kafka/cluster/Partition.scala
Expand Up @@ -575,6 +575,10 @@ class Partition(val topicPartition: TopicPartition,
}
}

/**
 * Whether the local leader log reports an ongoing transaction for the given producer id.
 * Returns false when `leaderLogIfLocal` is empty.
 */
def hasOngoingTransaction(producerId: Long): Boolean =
  leaderLogIfLocal.exists(_.hasOngoingTransaction(producerId))

// Return true if the future replica exists and it has caught up with the current replica for this partition
// Only ReplicaAlterDirThread will call this method and ReplicaAlterDirThread should remove the partition
// from its partitionStates if this method returns true
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/scala/kafka/log/UnifiedLog.scala
Expand Up @@ -579,6 +579,11 @@ class UnifiedLog(@volatile var logStartOffset: Long,
result
}

/**
 * Whether the given producer id has an active-producer entry whose current transaction
 * first offset is present. The lookup is performed while holding `lock`.
 */
def hasOngoingTransaction(producerId: Long): Boolean = lock synchronized {
  // The lookup yields null when the producer id has no entry; Option(...) maps that to None.
  Option(producerStateManager.activeProducers.get(producerId))
    .exists(_.currentTxnFirstOffset.isPresent)
}

/**
* The number of segments in the log.
* Take care! this is an O(n) operation.
Expand Down
40 changes: 35 additions & 5 deletions core/src/main/scala/kafka/network/RequestChannel.scala
Expand Up @@ -55,6 +55,7 @@ object RequestChannel extends Logging {

sealed trait BaseRequest
case object ShutdownRequest extends BaseRequest
case object WakeupRequest extends BaseRequest

case class Session(principal: KafkaPrincipal, clientAddress: InetAddress) {
val sanitizedUser: String = Sanitizer.sanitize(principal.getName)
Expand All @@ -79,6 +80,9 @@ object RequestChannel extends Logging {
}
}

/**
 * A request that carries a callback to run together with the request it belongs to.
 *
 * @param fun             the function to execute
 * @param originalRequest the request this callback is associated with
 */
case class CallbackRequest(fun: () => Unit,
originalRequest: Request) extends BaseRequest

class Request(val processor: Int,
val context: RequestContext,
val startTimeNanos: Long,
Expand All @@ -96,6 +100,8 @@ object RequestChannel extends Logging {
@volatile var apiThrottleTimeMs = 0L
@volatile var temporaryMemoryBytes = 0L
@volatile var recordNetworkThreadTimeCallback: Option[Long => Unit] = None
@volatile var callbackRequestDequeueTimeNanos: Option[Long] = None
@volatile var callbackRequestCompleteTimeNanos: Option[Long] = None

val session = Session(context.principal, context.clientAddress)

Expand Down Expand Up @@ -227,8 +233,9 @@ object RequestChannel extends Logging {
}

val requestQueueTimeMs = nanosToMs(requestDequeueTimeNanos - startTimeNanos)
val apiLocalTimeMs = nanosToMs(apiLocalCompleteTimeNanos - requestDequeueTimeNanos)
val apiRemoteTimeMs = nanosToMs(responseCompleteTimeNanos - apiLocalCompleteTimeNanos)
val callbackRequestTimeNanos = callbackRequestCompleteTimeNanos.getOrElse(0L) - callbackRequestDequeueTimeNanos.getOrElse(0L)
val apiLocalTimeMs = nanosToMs(apiLocalCompleteTimeNanos - requestDequeueTimeNanos + callbackRequestTimeNanos)
val apiRemoteTimeMs = nanosToMs(responseCompleteTimeNanos - apiLocalCompleteTimeNanos - callbackRequestTimeNanos)
val responseQueueTimeMs = nanosToMs(responseDequeueTimeNanos - responseCompleteTimeNanos)
val responseSendTimeMs = nanosToMs(endTimeNanos - responseDequeueTimeNanos)
val messageConversionsTimeMs = nanosToMs(messageConversionsTimeNanos)
Expand Down Expand Up @@ -354,6 +361,7 @@ class RequestChannel(val queueSize: Int,
private val processors = new ConcurrentHashMap[Int, Processor]()
val requestQueueSizeMetricName = metricNamePrefix.concat(RequestQueueSizeMetric)
val responseQueueSizeMetricName = metricNamePrefix.concat(ResponseQueueSizeMetric)
private val callbackQueue = new ArrayBlockingQueue[BaseRequest](queueSize)

metricsGroup.newGauge(requestQueueSizeMetricName, () => requestQueue.size)

Expand Down Expand Up @@ -444,6 +452,9 @@ class RequestChannel(val queueSize: Int,
request.responseCompleteTimeNanos = timeNanos
if (request.apiLocalCompleteTimeNanos == -1L)
request.apiLocalCompleteTimeNanos = timeNanos
// If this callback was executed after KafkaApis returned we will need to adjust the callback completion time here.
if (request.callbackRequestDequeueTimeNanos.isDefined && request.callbackRequestCompleteTimeNanos.isEmpty)
request.callbackRequestCompleteTimeNanos = Some(time.nanoseconds())
// For a given request, these may happen in addition to one in the previous section, skip updating the metrics
case _: StartThrottlingResponse | _: EndThrottlingResponse => ()
}
Expand All @@ -456,9 +467,21 @@ class RequestChannel(val queueSize: Int,
}
}

/** Get the next request or block until specified time has elapsed */
def receiveRequest(timeout: Long): RequestChannel.BaseRequest =
requestQueue.poll(timeout, TimeUnit.MILLISECONDS)
/**
 * Get the next request, waiting up to `timeout` milliseconds on the request queue.
 *
 * The callback queue is checked first (without blocking) because those requests have
 * already waited in line. If the request queue yields a WakeupRequest, the callback queue
 * is polled again and its result (possibly null) is returned instead.
 */
def receiveRequest(timeout: Long): RequestChannel.BaseRequest = {
  Option(callbackQueue.poll()).getOrElse {
    requestQueue.poll(timeout, TimeUnit.MILLISECONDS) match {
      case WakeupRequest => callbackQueue.poll()
      case polled => polled
    }
  }
}

/** Get the next request or block until there is one */
def receiveRequest(): RequestChannel.BaseRequest =
Expand All @@ -472,6 +495,7 @@ class RequestChannel(val queueSize: Int,

/** Empties the request queue and then the callback queue. */
def clear(): Unit =
  Seq(requestQueue, callbackQueue).foreach(_.clear())

def shutdown(): Unit = {
Expand All @@ -481,6 +505,12 @@ class RequestChannel(val queueSize: Int,

/** Puts the ShutdownRequest marker on the request queue. */
def sendShutdownRequest(): Unit = {
  requestQueue.put(ShutdownRequest)
}

/**
 * Enqueues a callback request and tries to wake up a handler waiting on the request queue.
 *
 * The callback is put on the callback queue first; a WakeupRequest is then offered to the
 * request queue without blocking. A rejected offer is only logged at trace level.
 */
def sendCallbackRequest(request: CallbackRequest): Unit = {
  callbackQueue.put(request)
  val wakeupEnqueued = requestQueue.offer(RequestChannel.WakeupRequest)
  if (!wakeupEnqueued)
    trace("Wakeup request could not be added to queue. This means queue is full, so we will still process callback.")
}

}

object RequestMetrics {
Expand Down

0 comments on commit 56dcb83

Please sign in to comment.