-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve FieldWalker, don't access JDK classes (#1799)
* Improve FieldWalker, don't access JDK classes * Works on future JDKs that forbid reflective access to JDK classes * Show human-readable path to field is something fails
- Loading branch information
Showing
3 changed files
with
123 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,115 +1,154 @@ | ||
/* | ||
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. | ||
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. | ||
*/ | ||
|
||
package kotlinx.coroutines | ||
|
||
import java.lang.reflect.* | ||
import java.util.* | ||
import java.util.Collections.* | ||
import java.util.concurrent.atomic.* | ||
import kotlin.collections.ArrayList | ||
import kotlin.test.* | ||
|
||
object FieldWalker { | ||
sealed class Ref { | ||
object RootRef : Ref() | ||
class FieldRef(val parent: Any, val name: String) : Ref() | ||
class ArrayRef(val parent: Any, val index: Int) : Ref() | ||
} | ||
|
||
private val fieldsCache = HashMap<Class<*>, List<Field>>() | ||
|
||
init { | ||
// excluded/terminal classes (don't walk them) | ||
fieldsCache += listOf(Any::class, String::class, Thread::class, Throwable::class) | ||
.map { it.java } | ||
.associateWith { emptyList<Field>() } | ||
} | ||
|
||
/* | ||
* Reflectively starts to walk through object graph and returns identity set of all reachable objects. | ||
* Use [walkRefs] if you need a path from root for debugging. | ||
*/ | ||
public fun walk(root: Any?): Set<Any> = walkRefs(root).keys | ||
|
||
public fun assertReachableCount(expected: Int, root: Any?, predicate: (Any) -> Boolean) { | ||
val visited = walkRefs(root) | ||
val actual = visited.keys.filter(predicate) | ||
if (actual.size != expected) { | ||
val textDump = actual.joinToString("") { "\n\t" + showPath(it, visited) } | ||
assertEquals( | ||
expected, actual.size, | ||
"Unexpected number objects. Expected $expected, found ${actual.size}$textDump" | ||
) | ||
} | ||
} | ||
|
||
/* | ||
* Reflectively starts to walk through object graph and map to all the reached object to their path | ||
* in from root. Use [showPath] do display a path if needed. | ||
*/ | ||
public fun walk(root: Any): Set<Any> { | ||
val result = newSetFromMap<Any>(IdentityHashMap()) | ||
result.add(root) | ||
private fun walkRefs(root: Any?): Map<Any, Ref> { | ||
val visited = IdentityHashMap<Any, Ref>() | ||
if (root == null) return visited | ||
visited[root] = Ref.RootRef | ||
val stack = ArrayDeque<Any>() | ||
stack.addLast(root) | ||
while (stack.isNotEmpty()) { | ||
val element = stack.removeLast() | ||
val type = element.javaClass | ||
type.visit(element, result, stack) | ||
try { | ||
visit(element, visited, stack) | ||
} catch (e: Exception) { | ||
error("Failed to visit element ${showPath(element, visited)}: $e") | ||
} | ||
} | ||
return result | ||
return visited | ||
} | ||
|
||
private fun Class<*>.visit( | ||
element: Any, | ||
result: MutableSet<Any>, | ||
stack: ArrayDeque<Any> | ||
) { | ||
val fields = fields() | ||
fields.forEach { | ||
it.isAccessible = true | ||
val value = it.get(element) ?: return@forEach | ||
if (result.add(value)) { | ||
stack.addLast(value) | ||
private fun showPath(element: Any, visited: Map<Any, Ref>): String { | ||
val path = ArrayList<String>() | ||
var cur = element | ||
while (true) { | ||
val ref = visited.getValue(cur) | ||
if (ref is Ref.RootRef) break | ||
when (ref) { | ||
is Ref.FieldRef -> { | ||
cur = ref.parent | ||
path += ".${ref.name}" | ||
} | ||
is Ref.ArrayRef -> { | ||
cur = ref.parent | ||
path += "[${ref.index}]" | ||
} | ||
} | ||
} | ||
path.reverse() | ||
return path.joinToString("") | ||
} | ||
|
||
if (isArray && !componentType.isPrimitive) { | ||
val array = element as Array<Any?> | ||
array.filterNotNull().forEach { | ||
if (result.add(it)) { | ||
stack.addLast(it) | ||
private fun visit(element: Any, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>) { | ||
val type = element.javaClass | ||
when { | ||
// Special code for arrays | ||
type.isArray && !type.componentType.isPrimitive -> { | ||
@Suppress("UNCHECKED_CAST") | ||
val array = element as Array<Any?> | ||
array.forEachIndexed { index, value -> | ||
push(value, visited, stack) { Ref.ArrayRef(element, index) } | ||
} | ||
} | ||
// Special code for platform types that cannot be reflectively accessed on modern JDKs | ||
type.name.startsWith("java.") && element is Collection<*> -> { | ||
element.forEachIndexed { index, value -> | ||
push(value, visited, stack) { Ref.ArrayRef(element, index) } | ||
} | ||
} | ||
type.name.startsWith("java.") && element is Map<*, *> -> { | ||
push(element.keys, visited, stack) { Ref.FieldRef(element, "keys") } | ||
push(element.values, visited, stack) { Ref.FieldRef(element, "values") } | ||
} | ||
element is AtomicReference<*> -> { | ||
push(element.get(), visited, stack) { Ref.FieldRef(element, "value") } | ||
} | ||
// All the other classes are reflectively scanned | ||
else -> fields(type).forEach { field -> | ||
push(field.get(element), visited, stack) { Ref.FieldRef(element, field.name) } | ||
// special case to scan Throwable cause (cannot get it reflectively) | ||
if (element is Throwable) { | ||
push(element.cause, visited, stack) { Ref.FieldRef(element, "cause") } | ||
} | ||
} | ||
} | ||
} | ||
|
||
private fun Class<*>.fields(): List<Field> { | ||
private inline fun push(value: Any?, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, ref: () -> Ref) { | ||
if (value != null && !visited.containsKey(value)) { | ||
visited[value] = ref() | ||
stack.addLast(value) | ||
} | ||
} | ||
|
||
private fun fields(type0: Class<*>): List<Field> { | ||
fieldsCache[type0]?.let { return it } | ||
val result = ArrayList<Field>() | ||
var type = this | ||
while (type != Any::class.java) { | ||
var type = type0 | ||
while (true) { | ||
val fields = type.declaredFields.filter { | ||
!it.type.isPrimitive | ||
&& !Modifier.isStatic(it.modifiers) | ||
&& !(it.type.isArray && it.type.componentType.isPrimitive) | ||
} | ||
fields.forEach { it.isAccessible = true } // make them all accessible | ||
result.addAll(fields) | ||
type = type.superclass | ||
} | ||
|
||
return result | ||
} | ||
|
||
// Debugging-only | ||
@Suppress("UNUSED") | ||
fun printPath(from: Any, to: Any) { | ||
val pathNodes = ArrayList<String>() | ||
val visited = newSetFromMap<Any>(IdentityHashMap()) | ||
visited.add(from) | ||
if (findPath(from, to, visited, pathNodes)) { | ||
pathNodes.reverse() | ||
println(pathNodes.joinToString(" -> ", from.javaClass.simpleName + " -> ", "-> " + to.javaClass.simpleName)) | ||
} else { | ||
println("Path from $from to $to not found") | ||
} | ||
} | ||
|
||
private fun findPath(from: Any, to: Any, visited: MutableSet<Any>, pathNodes: MutableList<String>): Boolean { | ||
if (from === to) { | ||
return true | ||
} | ||
|
||
val type = from.javaClass | ||
if (type.isArray) { | ||
if (type.componentType.isPrimitive) return false | ||
val array = from as Array<Any?> | ||
array.filterNotNull().forEach { | ||
if (findPath(it, to, visited, pathNodes)) { | ||
return true | ||
} | ||
val superFields = fieldsCache[type] // will stop at Any anyway | ||
if (superFields != null) { | ||
result.addAll(superFields) | ||
break | ||
} | ||
return false | ||
} | ||
|
||
val fields = type.fields() | ||
fields.forEach { | ||
it.isAccessible = true | ||
val value = it.get(from) ?: return@forEach | ||
if (!visited.add(value)) return@forEach | ||
val found = findPath(value, to, visited, pathNodes) | ||
if (found) { | ||
pathNodes += from.javaClass.simpleName + ":" + it.name | ||
return true | ||
} | ||
} | ||
|
||
return false | ||
fieldsCache[type0] = result | ||
return result | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters