Skip to content

Commit

Permalink
move the read.unlock bounds check into the critical section to avoid …
Browse files Browse the repository at this point in the history
…a race condition
  • Loading branch information
frett committed Oct 9, 2020
1 parent 19ceaeb commit a48bc70
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
@@ -1,7 +1,6 @@
package org.ccci.gto.android.common.kotlin.coroutines

import androidx.annotation.VisibleForTesting
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
Expand All @@ -16,28 +15,29 @@ fun ReadWriteMutex(): ReadWriteMutex = ReadWriteMutexImpl()
@VisibleForTesting
internal class ReadWriteMutexImpl : ReadWriteMutex {
private val stateMutex = Mutex()
private val readerOwner = Any()
@VisibleForTesting
internal val readers = AtomicLong(0)
internal var readers = 0L

override val write = Mutex()
override val read = object : Mutex {
override suspend fun lock(owner: Any?) {
stateMutex.withLock {
check(readers.get() < Long.MAX_VALUE) {
check(readers < Long.MAX_VALUE) {
"Attempt to lock the read mutex more than ${Long.MAX_VALUE} times concurrently"
}
// first reader should lock the write mutex
if (readers.get() == 0L) write.lock(readers)
readers.incrementAndGet()
if (readers == 0L) write.lock(readerOwner)
readers++
}
}

override fun unlock(owner: Any?) {
runBlocking {
check(readers.get() > 0L) { "Attempt to unlock the read mutex when it wasn't locked" }
stateMutex.withLock {
check(readers > 0L) { "Attempt to unlock the read mutex when it wasn't locked" }
// release the write mutex lock when this is the last reader
if (readers.decrementAndGet() == 0L) write.unlock(readers)
if (--readers == 0L) write.unlock(readerOwner)
}
}
}
Expand Down
Expand Up @@ -129,7 +129,7 @@ class ReadWriteMutexTest {
@Test(expected = IllegalStateException::class)
fun testReadLockTooManyTimes() {
runBlocking {
(mutex as ReadWriteMutexImpl).readers.set(Long.MAX_VALUE - 1)
(mutex as ReadWriteMutexImpl).readers = Long.MAX_VALUE - 1
while (true) mutex.read.lock()
}
}
Expand Down Expand Up @@ -159,7 +159,7 @@ class ReadWriteMutexTest {
mutex.read.lock()
running.set(false)
tasks.joinAll()
assertEquals(0, (mutex as ReadWriteMutexImpl).readers.get())
assertEquals(0, (mutex as ReadWriteMutexImpl).readers)
}
}
}
Expand Down

0 comments on commit a48bc70

Please sign in to comment.