diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedKeyValueBuffer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedKeyValueBuffer.java index 956d72f5ffa0b..cf7eb3912d119 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedKeyValueBuffer.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedKeyValueBuffer.java @@ -231,7 +231,7 @@ public void evictWhile(final Supplier predicate, final Consumer predicate, final Consumer(key, value, bufferValue.context())); diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedKeyValueBufferTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedKeyValueBufferTest.java index e288c04517e27..1dbdd5fbdbb63 100644 --- a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedKeyValueBufferTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedKeyValueBufferTest.java @@ -16,6 +16,8 @@ */ package org.apache.kafka.streams.state.internals; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.metrics.Metrics; import org.apache.kafka.common.metrics.Sensor; @@ -40,11 +42,15 @@ import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; +import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -205,6 +211,49 @@ public void shouldHandleCollidingKeys() { assertNumSizeAndTimestamp(buffer, 1, 7, 42); } + @Test + public void shouldPropagateHeadersThroughEviction() { + createBuffer(Duration.ZERO, Serdes.String()); + final RecordHeaders headers = new RecordHeaders(new Header[]{ + new RecordHeader("h1", "v1".getBytes(StandardCharsets.UTF_8)) + }); + context.setRecordContext(new ProcessorRecordContext(0L, offset++, 0, "testing", headers)); + buffer.put(0L, new Record<>("k", "v", 0L, headers), context.recordContext()); + + final List> evicted = new ArrayList<>(); + buffer.evictWhile(() -> buffer.numRecords() > 0, evicted::add); + + assertThat(evicted.size(), is(1)); + assertThat(evicted.get(0).recordContext().headers(), is(headers)); + } + + @Test + public void shouldNotBeAffectedByProcessorContextHeaderMutationBetweenPutAndEvict() { + createBuffer(Duration.ofMillis(1), Serdes.String()); + final RecordHeaders putHeaders = new RecordHeaders(new Header[]{ + new RecordHeader("at-put", "first".getBytes(StandardCharsets.UTF_8)) + }); + context.setRecordContext(new ProcessorRecordContext(0L, offset++, 0, "testing", putHeaders)); + buffer.put(0L, new Record<>("k", "v", 0L, putHeaders), context.recordContext()); + + // Simulate the processor moving on to handle a different record with different headers + // before the grace period expires and eviction runs. + final RecordHeaders laterHeaders = new RecordHeaders(new Header[]{ + new RecordHeader("at-evict", "second".getBytes(StandardCharsets.UTF_8)) + }); + context.setRecordContext(new ProcessorRecordContext(10L, offset++, 0, "testing", laterHeaders)); + // Advance stream time past the grace period for the original record. + buffer.put(10L, new Record<>("trigger", "v", 10L, laterHeaders), context.recordContext()); + + final List> evicted = new ArrayList<>(); + buffer.evictWhile(() -> true, evicted::add); + + // Only the original "k" record at t=0 falls outside the grace window of t=10. + assertThat(evicted.size(), is(1)); + assertThat(evicted.get(0).key(), is("k")); + assertThat(evicted.get(0).recordContext().headers(), is(putHeaders)); + } + private void assertNumSizeAndTimestamp(final TimeOrderedKeyValueBuffer buffer, final int num, final long time,