Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve FieldWalker, don't access JDK classes #1799

Merged
merged 2 commits into from
Feb 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 112 additions & 73 deletions kotlinx-coroutines-core/jvm/test/FieldWalker.kt
Original file line number Diff line number Diff line change
@@ -1,115 +1,154 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import java.lang.reflect.*
import java.util.*
import java.util.Collections.*
import java.util.concurrent.atomic.*
import kotlin.collections.ArrayList
import kotlin.test.*

object FieldWalker {
sealed class Ref {
object RootRef : Ref()
class FieldRef(val parent: Any, val name: String) : Ref()
class ArrayRef(val parent: Any, val index: Int) : Ref()
}

private val fieldsCache = HashMap<Class<*>, List<Field>>()

init {
// excluded/terminal classes (don't walk them)
fieldsCache += listOf(Any::class, String::class, Thread::class, Throwable::class)
.map { it.java }
.associateWith { emptyList<Field>() }
}

/*
* Reflectively starts to walk through object graph and returns identity set of all reachable objects.
* Use [walkRefs] if you need a path from root for debugging.
*/
public fun walk(root: Any?): Set<Any> = walkRefs(root).keys

public fun assertReachableCount(expected: Int, root: Any?, predicate: (Any) -> Boolean) {
val visited = walkRefs(root)
val actual = visited.keys.filter(predicate)
if (actual.size != expected) {
val textDump = actual.joinToString("") { "\n\t" + showPath(it, visited) }
assertEquals(
expected, actual.size,
"Unexpected number objects. Expected $expected, found ${actual.size}$textDump"
)
}
}

/*
* Reflectively starts to walk through object graph and map to all the reached object to their path
* in from root. Use [showPath] do display a path if needed.
*/
public fun walk(root: Any): Set<Any> {
val result = newSetFromMap<Any>(IdentityHashMap())
result.add(root)
private fun walkRefs(root: Any?): Map<Any, Ref> {
val visited = IdentityHashMap<Any, Ref>()
if (root == null) return visited
visited[root] = Ref.RootRef
val stack = ArrayDeque<Any>()
stack.addLast(root)
while (stack.isNotEmpty()) {
val element = stack.removeLast()
val type = element.javaClass
type.visit(element, result, stack)
try {
visit(element, visited, stack)
} catch (e: Exception) {
error("Failed to visit element ${showPath(element, visited)}: $e")
}
}
return result
return visited
}

private fun Class<*>.visit(
element: Any,
result: MutableSet<Any>,
stack: ArrayDeque<Any>
) {
val fields = fields()
fields.forEach {
it.isAccessible = true
val value = it.get(element) ?: return@forEach
if (result.add(value)) {
stack.addLast(value)
private fun showPath(element: Any, visited: Map<Any, Ref>): String {
val path = ArrayList<String>()
var cur = element
while (true) {
val ref = visited.getValue(cur)
if (ref is Ref.RootRef) break
when (ref) {
is Ref.FieldRef -> {
cur = ref.parent
path += ".${ref.name}"
}
is Ref.ArrayRef -> {
cur = ref.parent
path += "[${ref.index}]"
}
}
}
path.reverse()
return path.joinToString("")
}

if (isArray && !componentType.isPrimitive) {
val array = element as Array<Any?>
array.filterNotNull().forEach {
if (result.add(it)) {
stack.addLast(it)
private fun visit(element: Any, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>) {
val type = element.javaClass
when {
// Special code for arrays
type.isArray && !type.componentType.isPrimitive -> {
@Suppress("UNCHECKED_CAST")
val array = element as Array<Any?>
array.forEachIndexed { index, value ->
push(value, visited, stack) { Ref.ArrayRef(element, index) }
}
}
// Special code for platform types that cannot be reflectively accessed on modern JDKs
type.name.startsWith("java.") && element is Collection<*> -> {
element.forEachIndexed { index, value ->
push(value, visited, stack) { Ref.ArrayRef(element, index) }
}
}
type.name.startsWith("java.") && element is Map<*, *> -> {
qwwdfsad marked this conversation as resolved.
Show resolved Hide resolved
push(element.keys, visited, stack) { Ref.FieldRef(element, "keys") }
push(element.values, visited, stack) { Ref.FieldRef(element, "values") }
}
element is AtomicReference<*> -> {
push(element.get(), visited, stack) { Ref.FieldRef(element, "value") }
}
// All the other classes are reflectively scanned
else -> fields(type).forEach { field ->
push(field.get(element), visited, stack) { Ref.FieldRef(element, field.name) }
// special case to scan Throwable cause (cannot get it reflectively)
if (element is Throwable) {
push(element.cause, visited, stack) { Ref.FieldRef(element, "cause") }
}
}
}
}

private fun Class<*>.fields(): List<Field> {
private inline fun push(value: Any?, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, ref: () -> Ref) {
if (value != null && !visited.containsKey(value)) {
visited[value] = ref()
stack.addLast(value)
}
}

private fun fields(type0: Class<*>): List<Field> {
fieldsCache[type0]?.let { return it }
val result = ArrayList<Field>()
var type = this
while (type != Any::class.java) {
var type = type0
while (true) {
val fields = type.declaredFields.filter {
!it.type.isPrimitive
&& !Modifier.isStatic(it.modifiers)
&& !(it.type.isArray && it.type.componentType.isPrimitive)
}
fields.forEach { it.isAccessible = true } // make them all accessible
result.addAll(fields)
type = type.superclass
}

return result
}

// Debugging-only
@Suppress("UNUSED")
fun printPath(from: Any, to: Any) {
val pathNodes = ArrayList<String>()
val visited = newSetFromMap<Any>(IdentityHashMap())
visited.add(from)
if (findPath(from, to, visited, pathNodes)) {
pathNodes.reverse()
println(pathNodes.joinToString(" -> ", from.javaClass.simpleName + " -> ", "-> " + to.javaClass.simpleName))
} else {
println("Path from $from to $to not found")
}
}

private fun findPath(from: Any, to: Any, visited: MutableSet<Any>, pathNodes: MutableList<String>): Boolean {
if (from === to) {
return true
}

val type = from.javaClass
if (type.isArray) {
if (type.componentType.isPrimitive) return false
val array = from as Array<Any?>
array.filterNotNull().forEach {
if (findPath(it, to, visited, pathNodes)) {
return true
}
val superFields = fieldsCache[type] // will stop at Any anyway
if (superFields != null) {
result.addAll(superFields)
break
}
return false
}

val fields = type.fields()
fields.forEach {
it.isAccessible = true
val value = it.get(from) ?: return@forEach
if (!visited.add(value)) return@forEach
val found = findPath(value, to, visited, pathNodes)
if (found) {
pathNodes += from.javaClass.simpleName + ":" + it.name
return true
}
}

return false
fieldsCache[type0] = result
return result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class ReusableCancellableContinuationTest : TestBase() {
expect(4)
ensureActive()
// Verify child was bound
assertNotNull(FieldWalker.walk(coroutineContext[Job]!!).single { it === continuation })
FieldWalker.assertReachableCount(1, coroutineContext[Job]) { it === continuation }
suspendAtomicCancellableCoroutineReusable<Unit> {
expect(5)
coroutineContext[Job]!!.cancel()
Expand All @@ -97,7 +97,7 @@ class ReusableCancellableContinuationTest : TestBase() {
cont = it
}
ensureActive()
assertTrue { FieldWalker.walk(coroutineContext[Job]!!).contains(cont!!) }
assertTrue { FieldWalker.walk(coroutineContext[Job]).contains(cont!!) }
finish(2)
}

Expand All @@ -112,7 +112,7 @@ class ReusableCancellableContinuationTest : TestBase() {
cont = it
}
ensureActive()
assertFalse { FieldWalker.walk(coroutineContext[Job]!!).contains(cont!!) }
FieldWalker.assertReachableCount(0, coroutineContext[Job]) { it === cont }
finish(2)
}

Expand All @@ -127,7 +127,7 @@ class ReusableCancellableContinuationTest : TestBase() {
}
expectUnreached()
} catch (e: CancellationException) {
assertFalse { FieldWalker.walk(coroutineContext[Job]!!).contains(cont!!) }
FieldWalker.assertReachableCount(0, coroutineContext[Job]) { it === cont }
finish(2)
}
}
Expand All @@ -148,19 +148,19 @@ class ReusableCancellableContinuationTest : TestBase() {
expect(4)
ensureActive()
// Verify child was bound
assertEquals(1, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(1, currentJob) { it is CancellableContinuation<*> }
currentJob.cancel()
assertFalse(isActive)
// Child detached
assertEquals(0, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(0, currentJob) { it is CancellableContinuation<*> }
suspendAtomicCancellableCoroutineReusable<Unit> { it.resume(Unit) }
suspendAtomicCancellableCoroutineReusable<Unit> { it.resume(Unit) }
assertEquals(0, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(0, currentJob) { it is CancellableContinuation<*> }

try {
suspendAtomicCancellableCoroutineReusable<Unit> {}
} catch (e: CancellationException) {
assertEquals(0, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(0, currentJob) { it is CancellableContinuation<*> }
finish(5)
}
}
Expand All @@ -184,12 +184,12 @@ class ReusableCancellableContinuationTest : TestBase() {
expect(2)
val job = coroutineContext[Job]!!
// 1 for reusable CC, another one for outer joiner
assertEquals(2, FieldWalker.walk(job).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(2, job) { it is CancellableContinuation<*> }
}
expect(1)
receiver.join()
// Reference should be claimed at this point
assertEquals(0, FieldWalker.walk(receiver).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(0, receiver) { it is CancellableContinuation<*> }
finish(3)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ConsumeAsFlowLeakTest : TestBase() {
if (shouldSuspendOnSend) yield()
channel.send(second)
yield()
assertEquals(0, FieldWalker.walk(channel).count { it === second })
FieldWalker.assertReachableCount(0, channel) { it === second }
finish(6)
job.cancelAndJoin()
}
Expand Down