diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java index ceeeb6c191607..cb9a38d0ddc1e 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java @@ -527,6 +527,7 @@ private void onSuccessResponse(final StreamsGroupHeartbeatResponse response, fin heartbeatRequestState.updateHeartbeatIntervalMs(data.heartbeatIntervalMs()); heartbeatRequestState.onSuccessfulAttempt(currentTimeMs); heartbeatState.setEndpointInformationEpoch(data.endpointInformationEpoch()); + streamsRebalanceData.setHeartbeatIntervalMs(data.heartbeatIntervalMs()); if (data.partitionsByUserEndpoint() != null) { streamsRebalanceData.setPartitionsByHost(convertHostInfoMap(data)); diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java index 2fe7ae8ad35d2..c6fe1fd9215ee 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java @@ -30,6 +30,7 @@ import java.util.Set; import java.util.UUID; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; /** @@ -329,6 +330,8 @@ public String toString() { private final AtomicReference> statuses = new AtomicReference<>(List.of()); + private final AtomicInteger heartbeatIntervalMs = new AtomicInteger(-1); + public StreamsRebalanceData(final UUID processId, final Optional endpoint, final Map subtopologies, @@ -395,4 +398,14 @@ public List 
statuses() { return statuses.get(); } + /** Updated whenever a heartbeat response is received from the broker. */ + public void setHeartbeatIntervalMs(final int heartbeatIntervalMs) { + this.heartbeatIntervalMs.set(heartbeatIntervalMs); + } + + /** Returns the heartbeat interval in milliseconds, or -1 if not yet set. */ + public int heartbeatIntervalMs() { + return heartbeatIntervalMs.get(); + } + } diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManagerTest.java index f4a2726b9e570..9e4b843714447 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManagerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManagerTest.java @@ -1507,6 +1507,35 @@ public void testResetPollTimerWhenExpired() { } } + @Test + public void testStreamsRebalanceDataHeartbeatIntervalMsUpdatedOnSuccess() { + try ( + final MockedConstruction ignored = mockConstruction( + HeartbeatRequestState.class, + (mock, context) -> when(mock.canSendRequest(time.milliseconds())).thenReturn(true)) + ) { + final StreamsGroupHeartbeatRequestManager heartbeatRequestManager = createStreamsGroupHeartbeatRequestManager(); + when(coordinatorRequestManager.coordinator()).thenReturn(Optional.of(coordinatorNode)); + when(membershipManager.groupId()).thenReturn(GROUP_ID); + when(membershipManager.memberId()).thenReturn(MEMBER_ID); + when(membershipManager.memberEpoch()).thenReturn(MEMBER_EPOCH); + when(membershipManager.groupInstanceId()).thenReturn(Optional.of(INSTANCE_ID)); + + // Initially, heartbeatIntervalMs should be -1 + assertEquals(-1, streamsRebalanceData.heartbeatIntervalMs()); + + final NetworkClientDelegate.PollResult result = heartbeatRequestManager.poll(time.milliseconds()); + assertEquals(1, result.unsentRequests.size()); 
+ + final NetworkClientDelegate.UnsentRequest networkRequest = result.unsentRequests.get(0); + final ClientResponse response = buildClientResponse(); + networkRequest.handler().onComplete(response); + + // After successful response, heartbeatIntervalMs should be updated + assertEquals(RECEIVED_HEARTBEAT_INTERVAL_MS, streamsRebalanceData.heartbeatIntervalMs()); + } + } + private static ConsumerConfig config() { Properties prop = new Properties(); prop.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java index 606ba0b735027..f2376640c0102 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceDataTest.java @@ -437,4 +437,41 @@ public void streamsRebalanceDataShouldBeConstructedWithEmptyStatuses() { assertTrue(streamsRebalanceData.statuses().isEmpty()); } + @Test + public void streamsRebalanceDataShouldBeConstructedWithHeartbeatIntervalMsSetToMinusOne() { + final UUID processId = UUID.randomUUID(); + final Optional endpoint = Optional.of(new + StreamsRebalanceData.HostInfo("localhost", 9090)); + final Map subtopologies = Map.of(); + final Map clientTags = Map.of("clientTag1", + "clientTagValue1"); + final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData( + processId, + endpoint, + subtopologies, + clientTags + ); + + assertEquals(-1, streamsRebalanceData.heartbeatIntervalMs()); + } + + @Test + public void streamsRebalanceDataShouldBeAbleToUpdateHeartbeatIntervalMs() { + final UUID processId = UUID.randomUUID(); + final Optional endpoint = Optional.of(new + StreamsRebalanceData.HostInfo("localhost", 9090)); + final Map subtopologies = Map.of(); + final Map clientTags = 
Map.of("clientTag1", + "clientTagValue1"); + final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData( + processId, + endpoint, + subtopologies, + clientTags + ); + + streamsRebalanceData.setHeartbeatIntervalMs(1000); + assertEquals(1000, streamsRebalanceData.heartbeatIntervalMs()); + } + } diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/HandlingSourceTopicDeletionIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/HandlingSourceTopicDeletionIntegrationTest.java index d8f9061dfdb11..f31e79c53992f 100644 --- a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/HandlingSourceTopicDeletionIntegrationTest.java +++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/HandlingSourceTopicDeletionIntegrationTest.java @@ -33,8 +33,9 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import java.io.IOException; import java.util.Properties; @@ -75,8 +76,9 @@ public void after() throws InterruptedException { CLUSTER.deleteTopics(INPUT_TOPIC, OUTPUT_TOPIC); } - @Test - public void shouldThrowErrorAfterSourceTopicDeleted(final TestInfo testName) throws InterruptedException { + @ParameterizedTest + @ValueSource(strings = {"classic", "streams"}) + public void shouldThrowErrorAfterSourceTopicDeleted(final String groupProtocol, final TestInfo testName) throws InterruptedException { final StreamsBuilder builder = new StreamsBuilder(); builder.stream(INPUT_TOPIC, Consumed.with(Serdes.Integer(), Serdes.String())) .to(OUTPUT_TOPIC, Produced.with(Serdes.Integer(), Serdes.String())); @@ -91,6 +93,7 @@ public void shouldThrowErrorAfterSourceTopicDeleted(final TestInfo testName) thr 
streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, NUM_THREADS); streamsConfiguration.put(StreamsConfig.METADATA_MAX_AGE_CONFIG, 2000); + streamsConfiguration.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, groupProtocol); final Topology topology = builder.build(); final AtomicBoolean calledUncaughtExceptionHandler1 = new AtomicBoolean(false); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java index 3521b31d8a3f0..f208567c32db7 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -1587,11 +1587,15 @@ public void handleStreamsRebalanceData() { } private void handleMissingSourceTopicsWithTimeout(final String missingTopicsDetail) { + // Use 2 * heartbeatIntervalMs as the timeout to ensure at least one heartbeat is sent before raising the exception. NOTE(review): heartbeatIntervalMs() returns -1 until a heartbeat response has been processed, which would make timeoutMs negative; confirm this method is only reached after a successful heartbeat. + final int heartbeatIntervalMs = streamsRebalanceData.get().heartbeatIntervalMs(); + final long timeoutMs = 2L * heartbeatIntervalMs; + // Start timeout tracking on first encounter with missing topics if (topicsReadyTimer == null) { - topicsReadyTimer = time.timer(maxPollTimeMs); + topicsReadyTimer = time.timer(timeoutMs); log.info("Missing source topics detected: {}. 
Will wait up to {}ms before failing.", - missingTopicsDetail, maxPollTimeMs); + missingTopicsDetail, timeoutMs); } else { topicsReadyTimer.update(); } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java index d7f61883adb1b..66b90ffcc032d 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -3966,11 +3966,13 @@ public void testStreamsProtocolRunOnceWithoutProcessingThreadsMissingSourceTopic .setStatusDetail("Missing source topics") )); + streamsRebalanceData.setHeartbeatIntervalMs(5000); + // First call should not throw exception (within timeout) thread.runOnceWithoutProcessingThreads(); - // Advance time beyond max.poll.interval.ms (default is 300000ms) to trigger timeout - mockTime.sleep(300001); + // Advance time beyond 2 * heartbeatIntervalMs (set to 5000ms above) to trigger timeout + mockTime.sleep(10001); final MissingSourceTopicException exception = assertThrows(MissingSourceTopicException.class, () -> thread.runOnceWithoutProcessingThreads()); assertTrue(exception.getMessage().contains("Missing source topics")); @@ -4153,11 +4155,13 @@ public void testStreamsProtocolRunOnceWithProcessingThreadsMissingSourceTopic() .setStatusDetail("Missing source topics") )); + streamsRebalanceData.setHeartbeatIntervalMs(5000); + // First call should not throw exception (within timeout) thread.runOnceWithProcessingThreads(); - // Advance time beyond max.poll.interval.ms (default is 300000ms) to trigger timeout - mockTime.sleep(300001); + // Advance time beyond 2 * heartbeatIntervalMs (set to 5000ms above) to trigger timeout + mockTime.sleep(10001); final MissingSourceTopicException exception = assertThrows(MissingSourceTopicException.class, () -> thread.runOnceWithProcessingThreads()); assertTrue(exception.getMessage().contains("Missing source topics")); @@ -4221,11 +4225,13 @@ public 
void testStreamsProtocolMissingSourceTopicRecovery() { .setStatusDetail("Missing source topics") )); + streamsRebalanceData.setHeartbeatIntervalMs(5000); + // First call should not throw exception (within timeout) thread.runOnceWithoutProcessingThreads(); // Advance time but not beyond timeout - mockTime.sleep(150000); // Half of max.poll.interval.ms + mockTime.sleep(5000); // Half of the 2 * heartbeatIntervalMs timeout (10000ms) // Should still not throw exception thread.runOnceWithoutProcessingThreads(); @@ -4243,13 +4249,13 @@ public void testStreamsProtocolMissingSourceTopicRecovery() { .setStatusDetail("Different missing topics") )); - // Advance time by 250 seconds to test if timer was reset - // Total time from beginning: 150000 + 250000 = 400000ms (400s) - // If timer was NOT reset: elapsed time = 400s > 300s → should throw - // If timer WAS reset: elapsed time = 250s < 300s → should NOT throw - mockTime.sleep(250000); // Advance by 250 seconds + // Advance time by 6 seconds to test if timer was reset + // Total time from beginning: 5000 + 6000 = 11000ms (11s) + // If timer was NOT reset: elapsed time = 11s > 10s → should throw + // If timer WAS reset: elapsed time = 6s < 10s → should NOT throw + mockTime.sleep(6000); // Advance by 6 seconds - // Should not throw because timer was reset - only 250s elapsed from reset point + // Should not throw because timer was reset - only 6s elapsed from reset point assertDoesNotThrow(() -> thread.runOnceWithoutProcessingThreads()); }