KAFKA-15481: Fix concurrency bug in RemoteIndexCache #14483

Merged
16 commits merged on Oct 23, 2023
Changes from 13 commits
132 changes: 103 additions & 29 deletions core/src/test/scala/unit/kafka/log/remote/RemoteIndexCacheTest.scala
@@ -23,7 +23,7 @@ import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
import org.apache.kafka.server.log.remote.storage.RemoteStorageManager.IndexType
import org.apache.kafka.server.log.remote.storage.{RemoteLogSegmentId, RemoteLogSegmentMetadata, RemoteResourceNotFoundException, RemoteStorageManager}
import org.apache.kafka.server.util.MockTime
-import org.apache.kafka.storage.internals.log.RemoteIndexCache.{Entry, REMOTE_LOG_INDEX_CACHE_CLEANER_THREAD, remoteDeletedSuffixIndexFileName, remoteOffsetIndexFile, remoteOffsetIndexFileName, remoteTimeIndexFile, remoteTimeIndexFileName, remoteTransactionIndexFile, remoteTransactionIndexFileName}
+import org.apache.kafka.storage.internals.log.RemoteIndexCache.{DIR_NAME, Entry, REMOTE_LOG_INDEX_CACHE_CLEANER_THREAD, remoteDeletedSuffixIndexFileName, remoteOffsetIndexFile, remoteOffsetIndexFileName, remoteTimeIndexFile, remoteTimeIndexFileName, remoteTransactionIndexFile, remoteTransactionIndexFileName}
import org.apache.kafka.storage.internals.log.{AbortedTxn, CorruptIndexException, LogFileUtils, OffsetIndex, OffsetPosition, RemoteIndexCache, TimeIndex, TransactionIndex}
import org.apache.kafka.test.{TestUtils => JTestUtils}
import org.junit.jupiter.api.Assertions._
@@ -32,6 +32,7 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.EnumSource
import org.mockito.ArgumentMatchers
import org.mockito.ArgumentMatchers.any
+import org.mockito.invocation.InvocationOnMock
import org.mockito.Mockito._
import org.slf4j.{Logger, LoggerFactory}

@@ -138,8 +139,8 @@ class RemoteIndexCacheTest {
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
-val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
-val timeIdx = createTimeIndexForSegmentMetadata(metadata)
+val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
+val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
indexType match {
case IndexType.OFFSET => new FileInputStream(offsetIdx.file)
@@ -249,7 +250,7 @@ class RemoteIndexCacheTest {
}

@Test
-def testCacheEntryIsDeletedOnInvalidation(): Unit = {
+def testCacheEntryIsDeletedOnRemoval(): Unit = {
def getIndexFileFromDisk(suffix: String) = {
Files.walk(tpDir.toPath)
.filter(Files.isRegularFile(_))
@@ -271,8 +272,8 @@
// no expired entries yet
assertEquals(0, cache.expiredIndexes.size, "expiredIndex queue should be zero at start of test")

-// invalidate the cache. it should async mark the entry for removal
-cache.internalCache.invalidate(internalIndexKey)
+// call remove function to mark the entry for removal
+cache.remove(internalIndexKey)

// wait until entry is marked for deletion
TestUtils.waitUntilTrue(() => cacheEntry.isMarkedForCleanup,
@@ -672,16 +673,89 @@
verifyFetchIndexInvocation(count = 1)
}

@Test
def testConcurrentRemoveReadForCache(): Unit = {
// Create a spy Cache Entry
val rlsMetadata = new RemoteLogSegmentMetadata(RemoteLogSegmentId.generateNew(idPartition), baseOffset, lastOffset,
time.milliseconds(), brokerId, time.milliseconds(), segmentSize, Collections.singletonMap(0, 0L))

val timeIndex = spy(createTimeIndexForSegmentMetadata(rlsMetadata, new File(tpDir, DIR_NAME)))
val txIndex = spy(createTxIndexForSegmentMetadata(rlsMetadata, new File(tpDir, DIR_NAME)))
val offsetIndex = spy(createOffsetIndexForSegmentMetadata(rlsMetadata, new File(tpDir, DIR_NAME)))

val spyEntry = spy(new RemoteIndexCache.Entry(offsetIndex, timeIndex, txIndex))
cache.internalCache.put(rlsMetadata.remoteLogSegmentId().id(), spyEntry)

assertCacheSize(1)

var entry: RemoteIndexCache.Entry = null

val latchForCacheRead = new CountDownLatch(1)
val latchForCacheRemove = new CountDownLatch(1)
val latchForTestWait = new CountDownLatch(1)

var markForCleanupCallCount = 0

doAnswer((invocation: InvocationOnMock) => {
markForCleanupCallCount += 1

if (markForCleanupCallCount == 1) {
// Signal the CacheRead to unblock itself
latchForCacheRead.countDown()
// Wait for signal to start renaming the files
latchForCacheRemove.await()
// Calling the markForCleanup() actual method to start renaming the files
invocation.callRealMethod()
Contributor:
@jeel2420 Why are we invoking invocation.callRealMethod() again here? It is already called at line 712: when(spyEntry).markForCleanup()

Contributor Author (jeel2420):
invocation.callRealMethod() is called to run the real markForCleanup() after the read has happened and before we start asserting, so that the indexes get renamed before we assert the results.

Contributor (iit2009060):
@jeel2420 markForCleanup should be called only once. We need to test the behaviour when a concurrent read and remove happen on the cache for the same entry.
In the test we just need to assert the way @showuon suggested:

// So, maybe we verify with this:
if (Files.exists(entry.offsetIndex().file().toPath)) {
assertCacheSize(1)
} else {
assertCacheSize(0)
}

Calling markForCleanup twice will always result in a cache size of 0 eventually.

Contributor Author (jeel2420):
@iit2009060 markForCleanup() is not getting called twice. Please look at the mock: during the first execution of markForCleanup() I call the actual markForCleanup() (i.e. the indexes get renamed), but for subsequent calls the mock does nothing, so the real markForCleanup() that renames the indexes is called only once, as expected.

I have verified this behaviour as well.

Contributor (iit2009060):
@jeel2420 During a read, markForCleanup should not be called even once, as per the functionality.

Then why do we need to call it explicitly again via invocation.callRealMethod()?
I am seeing two invocations of markForCleanup:

  1. }).when(spyEntry).markForCleanup() at line 712
  2. invocation.callRealMethod() at line 708

Contributor (showuon), Oct 19, 2023:
> There are two times when markForCleanup is called.
>
>   1. the remove function which we are calling in the removeCache Runnable.
>   2. one at invocation.callRealMethod() (line 708)

You're right, but they are "different" markForCleanup calls.
For (1), markForCleanup is an injected mock method for controlling the invocation order, which is why there are latch await()/countDown() calls.
For (2), it is the real markForCleanup method that renames the cache files.

The goal is to simulate the race condition that happened in KAFKA-15481.

> Even I tried running your test case locally, it always asserts with cacheSize 0, as the entry eventually gets deleted.

Yes, I think so. But in some cases it could be 1, if getEntry goes after. Thread scheduling is decided by the OS, so we can't be sure which one will go first, right?

I think the goal of this test is to make sure the issue in KAFKA-15481 will not happen again. That's why I added this comment.

> IMO we should read and remove concurrently in separate threads and validate the cacheSize based on the order of execution.

I'm not following you here. What we're doing in this test is reading and removing concurrently in separate threads. As for validating the cacheSize based on the order of execution: since we can't be sure which thread will be executed first, we can't do that, right? If we could decide the execution order, it would mean they are not running concurrently, is that right?

> We should not need to call it explicitly for this scenario.

Maybe you can show us what test you would create if it were you. Some pseudo code is enough. Thank you.

Contributor (iit2009060):
@showuon

> For (1), markForCleanup is an injected mock method for controlling the invocation order, which is why there are latch await()/countDown() calls.

Do you mean this is a mock method and no rename would happen in this case?
Effectively, the logic of markForCleanup is executed only one time?

I was thinking of something like this:

    val latchForTestWait = new CountDownLatch(2)

    val removeCache = (() => {
      cache.remove(rlsMetadata.remoteLogSegmentId().id())
      latchForTestWait.countDown()
    }): Runnable

    val readCache = (() => {
      entry = cache.getIndexEntry(rlsMetadata)
      latchForTestWait.countDown()
    }): Runnable

    val executor = Executors.newFixedThreadPool(2)
    try {
      executor.submit(removeCache: Runnable)
      executor.submit(readCache: Runnable)

      // Wait for both tasks to complete
      latchForTestWait.await()

      // Validate the cache size based on file existence:
      // - if the offset file exists (execution order was remove, then read), an RSM fetch call should have happened
      // - if the cache size is 0 (execution order was read, then remove), no RSM fetch call should have happened

In the test case mentioned in the JIRA KAFKA-15481, the execution order is remove, then read, and the overall result is a cache size of 0, which is wrong and is caused by the time gap between removing the entry and renaming the files. Here we are validating the same thing with the RSM call count: if removal and renaming are atomic, the RSM fetch should happen and the files should be restored.

Contributor (showuon), Oct 19, 2023:
> Do you mean this is a mock method and no rename would happen in this case?
> Effectively, the logic of markForCleanup is executed only one time?

Correct.

> // Validate the cache size based on file existence:
> // - if the offset file exists (execution order was remove, then read), an RSM fetch call should have happened
> // - if the cache size is 0 (execution order was read, then remove), no RSM fetch call should have happened

Yes, that is basically similar to what we have now. Injecting a mock implementation of markForCleanup just brings the execution of the two threads closer together. In the end, what we have now also invokes the real method, which is what you did above. I'm fine if you think we should validate the RSM call count, but again, they are basically testing the same thing.

Contributor (iit2009060):
@showuon Yes, correct, it is testing the same thing, and I am also fine with it. But from a readability perspective the one I proposed is simpler to understand and does not require any future change if the markForCleanup function changes. I will leave it to @jeel2420 to make the decision here.

Contributor Author (jeel2420):
@iit2009060 As the current test case is able to reproduce the case mentioned in the JIRA KAFKA-15481, I think we should be fine.

The only reason to have the markForCleanup mock is to have control over the execution of the two threads.

// Signal TestWait to unblock itself so that test can be completed
latchForTestWait.countDown()
}
}).when(spyEntry).markForCleanup()

val removeCache = (() => {
cache.remove(rlsMetadata.remoteLogSegmentId().id())
}): Runnable

val readCache = (() => {
// Wait for signal to start CacheRead
latchForCacheRead.await()
entry = cache.getIndexEntry(rlsMetadata)
// Signal the CacheRemove to start renaming the files
latchForCacheRemove.countDown()
}): Runnable

val executor = Executors.newFixedThreadPool(2)
try {
executor.submit(removeCache: Runnable)
executor.submit(readCache: Runnable)
Contributor (showuon):
@jeel2420, sorry, I had another look and found that we should also verify that these two threads throw no exception. In the issue description, without this fix an IOException is thrown. So we should verify there is no exception using the Futures returned from executor.submit. Thanks.

Contributor Author (jeel2420):
@showuon Nice catch. I am now calling .get() on both task Future objects, so if there is any error the test will fail with that exception.
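A minimal sketch of that verification, using the removeCache and readCache Runnables defined in the test above (the Future variable names here are illustrative, not necessarily the ones in the final commit):

import java.util.concurrent.Executors

val executor = Executors.newFixedThreadPool(2)
// Keep the Futures returned by submit instead of discarding them
val removeCacheFuture = executor.submit(removeCache: Runnable)
val readCacheFuture = executor.submit(readCache: Runnable)

// get() blocks until the task finishes and rethrows any exception raised inside the
// task (wrapped in an ExecutionException), so an unexpected IOException from the
// concurrent read/remove fails the test instead of being silently swallowed.
removeCacheFuture.get()
readCacheFuture.get()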


// Wait for signal to complete the test
latchForTestWait.await()
// We can't determine read thread or remove thread will go first so if,
// 1. Read thread go first, cache file should not exist and cache size should be zero.
// 2. Remove thread go first, cache file should present and cache size should be one.
// so basically here we are making sure that if cache existed, the cache file should exist,
// and if cache is non-existed, the cache file should not exist.
if (getIndexFileFromRemoteCacheDir(cache, LogFileUtils.INDEX_FILE_SUFFIX).isPresent) {
assertCacheSize(1)
} else {
assertCacheSize(0)
}
} finally {
executor.shutdownNow()
}

}

@Test
def testMultipleIndexEntriesExecutionInCorruptException(): Unit = {
reset(rsm)
when(rsm.fetchIndex(any(classOf[RemoteLogSegmentMetadata]), any(classOf[IndexType])))
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
-val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
-val timeIdx = createTimeIndexForSegmentMetadata(metadata)
-val txnIdx = createTxIndexForSegmentMetadata(metadata)
+val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
+val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
+val txnIdx = createTxIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
// Create corrupted index file
createCorruptTimeIndexOffsetFile(tpDir)
@@ -717,9 +791,9 @@
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
-val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
-val timeIdx = createTimeIndexForSegmentMetadata(metadata)
-val txnIdx = createTxIndexForSegmentMetadata(metadata)
+val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
+val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
+val txnIdx = createTxIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
indexType match {
case IndexType.OFFSET => new FileInputStream(offsetIdx.file)
@@ -764,7 +838,7 @@
Files.copy(entry.txnIndex().file().toPath(), Paths.get(Utils.replaceSuffix(entry.txnIndex().file().getPath(), "", tempSuffix)))
Files.copy(entry.timeIndex().file().toPath(), Paths.get(Utils.replaceSuffix(entry.timeIndex().file().getPath(), "", tempSuffix)))

-cache.internalCache().invalidate(rlsMetadata.remoteLogSegmentId().id())
+cache.remove(rlsMetadata.remoteLogSegmentId().id())

// wait until entry is marked for deletion
TestUtils.waitUntilTrue(() => entry.isMarkedForCleanup,
@@ -792,9 +866,9 @@
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
-val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
-val timeIdx = createTimeIndexForSegmentMetadata(metadata)
-val txnIdx = createTxIndexForSegmentMetadata(metadata)
+val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
+val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
+val txnIdx = createTxIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
// Create corrupt index file return from RSM
createCorruptedIndexFile(testIndexType, tpDir)
@@ -839,7 +913,7 @@
// verify deleted file exists on disk
assertTrue(getRemoteCacheIndexFileFromDisk(LogFileUtils.DELETED_FILE_SUFFIX).isPresent, s"Deleted Offset index file should be present on disk at ${remoteIndexCacheDir.toPath}")

-cache.internalCache().invalidate(rlsMetadata.remoteLogSegmentId().id())
+cache.remove(rlsMetadata.remoteLogSegmentId().id())

// wait until entry is marked for deletion
TestUtils.waitUntilTrue(() => entry.isMarkedForCleanup,
@@ -862,9 +936,9 @@
= RemoteLogSegmentId.generateNew(idPartition)): RemoteIndexCache.Entry = {
val rlsMetadata = new RemoteLogSegmentMetadata(remoteLogSegmentId, baseOffset, lastOffset,
time.milliseconds(), brokerId, time.milliseconds(), segmentSize, Collections.singletonMap(0, 0L))
-val timeIndex = spy(createTimeIndexForSegmentMetadata(rlsMetadata))
-val txIndex = spy(createTxIndexForSegmentMetadata(rlsMetadata))
-val offsetIndex = spy(createOffsetIndexForSegmentMetadata(rlsMetadata))
+val timeIndex = spy(createTimeIndexForSegmentMetadata(rlsMetadata, tpDir))
+val txIndex = spy(createTxIndexForSegmentMetadata(rlsMetadata, tpDir))
+val offsetIndex = spy(createOffsetIndexForSegmentMetadata(rlsMetadata, tpDir))
spy(new RemoteIndexCache.Entry(offsetIndex, timeIndex, txIndex))
}

@@ -892,8 +966,8 @@
}
}

-private def createTxIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata): TransactionIndex = {
-val txnIdxFile = remoteTransactionIndexFile(tpDir, metadata)
+private def createTxIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata, dir: File): TransactionIndex = {
+val txnIdxFile = remoteTransactionIndexFile(dir, metadata)
txnIdxFile.createNewFile()
new TransactionIndex(metadata.startOffset(), txnIdxFile)
}
@@ -914,14 +988,14 @@
return new TransactionIndex(100L, txnIdxFile)
}

-private def createTimeIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata): TimeIndex = {
+private def createTimeIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata, dir: File): TimeIndex = {
val maxEntries = (metadata.endOffset() - metadata.startOffset()).asInstanceOf[Int]
-new TimeIndex(remoteTimeIndexFile(tpDir, metadata), metadata.startOffset(), maxEntries * 12)
+new TimeIndex(remoteTimeIndexFile(dir, metadata), metadata.startOffset(), maxEntries * 12)
}

-private def createOffsetIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata) = {
+private def createOffsetIndexForSegmentMetadata(metadata: RemoteLogSegmentMetadata, dir: File) = {
val maxEntries = (metadata.endOffset() - metadata.startOffset()).asInstanceOf[Int]
-new OffsetIndex(remoteOffsetIndexFile(tpDir, metadata), metadata.startOffset(), maxEntries * 8)
+new OffsetIndex(remoteOffsetIndexFile(dir, metadata), metadata.startOffset(), maxEntries * 8)
}

private def generateRemoteLogSegmentMetadata(size: Int,
@@ -969,9 +1043,9 @@
.thenAnswer(ans => {
val metadata = ans.getArgument[RemoteLogSegmentMetadata](0)
val indexType = ans.getArgument[IndexType](1)
-val offsetIdx = createOffsetIndexForSegmentMetadata(metadata)
-val timeIdx = createTimeIndexForSegmentMetadata(metadata)
-val txnIdx = createTxIndexForSegmentMetadata(metadata)
+val offsetIdx = createOffsetIndexForSegmentMetadata(metadata, tpDir)
+val timeIdx = createTimeIndexForSegmentMetadata(metadata, tpDir)
+val txnIdx = createTxIndexForSegmentMetadata(metadata, tpDir)
maybeAppendIndexEntries(offsetIdx, timeIdx)
indexType match {
case IndexType.OFFSET => new FileInputStream(offsetIdx.file)
RemoteIndexCache.java (org.apache.kafka.storage.internals.log):
@@ -151,9 +151,8 @@ private Cache<Uuid, Entry> initEmptyCache(long maxSize) {
.weigher((Uuid key, Entry entry) -> {
return (int) entry.entrySizeBytes;
})
-// removeListener is invoked when either the entry is invalidated (means manual removal by the caller) or
-// evicted (means removal due to the policy)
-.removalListener((Uuid key, Entry entry, RemovalCause cause) -> {
+// evictionListener is invoked when RemovalCause.wasEvicted() is true
+.evictionListener((Uuid key, Entry entry, RemovalCause cause) -> {
// Mark the entries for cleanup and add them to the queue to be garbage collected later by the background thread.
if (entry != null) {
try {
@@ -187,7 +186,18 @@ public File cacheDir() {
public void remove(Uuid key) {
lock.readLock().lock();
try {
-internalCache.invalidate(key);
+internalCache.asMap().computeIfPresent(key, (k, v) -> {
+try {
+v.markForCleanup();
+if (!expiredIndexes.offer(v)) {
+log.error("Error while inserting entry {} for key {} into the cleaner queue because queue is full.", v, k);
+}
+} catch (IOException e) {
+throw new KafkaException(e);
+}
+// Returning null to remove the key from the cache
+return null;
+});
} finally {
lock.readLock().unlock();
}
@@ -196,7 +206,18 @@ public void remove(Uuid key) {
public void removeAll(Collection<Uuid> keys) {
lock.readLock().lock();
try {
-internalCache.invalidateAll(keys);
+keys.forEach(key -> internalCache.asMap().computeIfPresent(key, (k, v) -> {
+try {
+v.markForCleanup();
+if (!expiredIndexes.offer(v)) {
+log.error("Error while inserting entry {} for key {} into the cleaner queue because queue is full.", v, k);
+}
+} catch (IOException e) {
+throw new KafkaException(e);
+}
+// Returning null to remove the key from the cache
+return null;
+}));
} finally {
lock.readLock().unlock();
}
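For context on why replacing invalidate() with computeIfPresent() closes the race: the remapping function runs atomically for that key, so marking the entry for cleanup and dropping it from the map happen as a single step with respect to other operations on the same key; a concurrent read either completes before the removal or finds the key already gone. A minimal, illustrative sketch of the idiom on a plain ConcurrentHashMap (not the actual Kafka code):

import java.util.concurrent.ConcurrentHashMap
import java.util.function.BiFunction

val cacheMap = new ConcurrentHashMap[String, String]()
cacheMap.put("segment-id", "cached-entry")

val removeWithCleanup = new BiFunction[String, String, String] {
  override def apply(key: String, value: String): String = {
    // In the real code this is where markForCleanup() renames the index files and
    // offers the entry to the cleaner queue.
    println(s"marking $value for cleanup before removing $key")
    null // returning null removes the mapping, as remove()/removeAll() do above
  }
}

// computeIfPresent runs the function atomically for the key: a concurrent lookup of
// "segment-id" either completes before it runs or sees the key already removed,
// never a half-cleaned-up entry.
cacheMap.computeIfPresent("segment-id", removeWithCleanup)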