Skip to content

Commit

Permalink
[Prototype] do not use tree-based decoding for fast-path polymorphism
Browse files Browse the repository at this point in the history
Fixes #1839
  • Loading branch information
qwwdfsad committed May 24, 2022
1 parent bf3269b commit 3e93bfd
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 20 deletions.
5 changes: 2 additions & 3 deletions benchmark/build.gradle
Expand Up @@ -6,13 +6,12 @@ apply plugin: 'java'
apply plugin: 'kotlin'
apply plugin: 'kotlinx-serialization'
apply plugin: 'idea'
apply plugin: 'net.ltgt.apt'
apply plugin: 'com.github.johnrengelman.shadow'
apply plugin: 'me.champeau.gradle.jmh'
apply plugin: 'me.champeau.jmh'

sourceCompatibility = 1.8
targetCompatibility = 1.8
jmh.jmhVersion = 1.22
jmh.jmhVersion = "1.22"

jmhJar {
baseName 'benchmarks'
Expand Down
@@ -0,0 +1,53 @@
package kotlinx.benchmarks.json

import kotlinx.serialization.*
import kotlinx.serialization.json.*
import kotlinx.serialization.modules.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.*

/**
 * JMH throughput benchmark that contrasts decoding [Impl] directly via its own serializer
 * with decoding the very same instance through the polymorphic serializer of [Poly].
 *
 * Both payloads are produced once, during state construction, by the same [Json] instance
 * that is later used for decoding.
 */
@State(Scope.Benchmark)
@Fork(1)
@Warmup(iterations = 7, time = 1)
@Measurement(iterations = 5, time = 1)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
open class PolymorphismOverheadBenchmark {

    /** Wrapper with one polymorphic and one concrete property. */
    @Serializable
    @JsonClassDiscriminator("poly")
    data class PolymorphicWrapper(val i: @Polymorphic Poly, val i2: Impl) // amortize the cost a bit

    /** Wrapper with two concrete properties. */
    @Serializable
    data class BaseWrapper(val i: Impl, val i2: Impl)

    /** Polymorphic base type; [Impl] is the only subclass registered in the module below. */
    @JsonClassDiscriminator("poly")
    interface Poly

    /** The single concrete implementation of [Poly] used by both benchmarks. */
    @Serializable
    @JsonClassDiscriminator("poly")
    class Impl(val a: Int, val b: String) : Poly

    // The instance that is encoded once up front and decoded on every benchmark invocation.
    private val sample = Impl(239, "average_size_string")

    // Registers Impl as a subclass of Poly so the polymorphic serializer can resolve it.
    private val polyModule = SerializersModule {
        polymorphic(Poly::class) {
            subclass(Impl.serializer())
        }
    }

    private val format = Json { serializersModule = polyModule }

    // The same object encoded twice: as its concrete type and as the polymorphic base type.
    private val concretePayload = format.encodeToString(sample)
    private val polymorphicPayload = format.encodeToString<Poly>(sample)
    private val polySerializer = serializer<Poly>()

    /**
     * Non-polymorphic reference point.
     *
     * Author's recorded figure: 5000 (presumably ops/ms given the output time unit -- TODO confirm).
     */
    @Benchmark
    fun base(): Impl {
        return format.decodeFromString(Impl.serializer(), concretePayload)
    }

    /**
     * Polymorphic decoding of the same object.
     *
     * Author's recorded figures (same presumed unit as for [base]):
     *  - Baseline -- 1500
     *  - v1, no skip -- 2000
     *  - v2, with skip -- 3000
     */
    @Benchmark
    fun poly(): Poly {
        return format.decodeFromString(polySerializer, polymorphicPayload)
    }
}
3 changes: 1 addition & 2 deletions build.gradle
Expand Up @@ -74,8 +74,7 @@ buildscript {

// Various benchmarking stuff
classpath "com.github.jengelman.gradle.plugins:shadow:4.0.2"
classpath "me.champeau.gradle:jmh-gradle-plugin:0.5.3"
classpath "net.ltgt.gradle:gradle-apt-plugin:0.21"
classpath "me.champeau.jmh:jmh-gradle-plugin:0.6.6"
}
}

Expand Down
Expand Up @@ -9,6 +9,7 @@ import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.internal.*
import kotlinx.serialization.json.*
import kotlin.jvm.*

@Suppress("UNCHECKED_CAST")
internal inline fun <T> JsonEncoder.encodePolymorphically(
Expand Down Expand Up @@ -55,12 +56,13 @@ internal fun checkKind(kind: SerialKind) {
}

internal fun <T> JsonDecoder.decodeSerializableValuePolymorphic(deserializer: DeserializationStrategy<T>): T {
// NB: changes in this method should be reflected in StreamingJsonDecoder#decodeSerializableValue
if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) {
return deserializer.deserialize(this)
}
val discriminator = deserializer.descriptor.classDiscriminator(json)

val jsonTree = cast<JsonObject>(decodeJsonElement(), deserializer.descriptor)
val discriminator = deserializer.descriptor.classDiscriminator(json)
val type = jsonTree[discriminator]?.jsonPrimitive?.content
val actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type)
?: throwSerializerNotFound(type, jsonTree)
Expand All @@ -69,7 +71,8 @@ internal fun <T> JsonDecoder.decodeSerializableValuePolymorphic(deserializer: De
return json.readPolymorphicJson(discriminator, jsonTree, actualSerializer as DeserializationStrategy<T>)
}

private fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothing {
@JvmName("throwSerializerNotFound")
internal fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothing {
val suffix =
if (type == null) "missing class discriminator ('null')"
else "class discriminator '$type'"
Expand Down
Expand Up @@ -9,6 +9,7 @@ import kotlinx.serialization.descriptors.*
import kotlinx.serialization.encoding.*
import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE
import kotlinx.serialization.encoding.CompositeDecoder.Companion.UNKNOWN_NAME
import kotlinx.serialization.internal.*
import kotlinx.serialization.json.*
import kotlinx.serialization.modules.*
import kotlin.jvm.*
Expand All @@ -35,7 +36,38 @@ internal open class StreamingJsonDecoder(
@Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
try {
return decodeSerializableValuePolymorphic(deserializer)
/*
* This is an optimized path over decodeSerializableValuePolymorphic(deserializer):
* dSVP reads the very next JSON tree into memory as a JsonElement and then runs TreeJsonDecoder over it
* in order to deal with an arbitrary order of keys, but with the price of additional memory pressure
* and CPU consumption.
* We would like to provide best possible performance for data produced by kotlinx.serialization
* itself, for that we do the following optimistic optimization:
*
* 0) Remember the current position in the string
* 1) Read the very next key of the JSON structure
* 2) If it matches* the discriminator key, read the value and remember the position right after it
* 3) Return the value and restore the initial position
* 4) Right after beginning the structure, immediately skip over the discriminator
*    using the position from '2'.
* In such a scenario we neither process the same input twice nor create auxiliary data structures.
*
* (*) -- if it doesn't match, fallback to dSVP method.
*/
if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) {
return deserializer.deserialize(this)
}

val discriminator = deserializer.descriptor.classDiscriminator(json)
val type = lexer.consumeLeadingMatchingValue(discriminator, configuration.isLenient)
val actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type)
// TODO ask, seems to be inference bug?
if (actualSerializer == null) {
return decodeSerializableValuePolymorphic<T>(deserializer as DeserializationStrategy<T>)
}
@Suppress("UNCHECKED_CAST")
return actualSerializer.deserialize(this) as T

} catch (e: MissingFieldException) {
throw MissingFieldException(e.message + " at path: " + lexer.path.getPath(), e)
}
Expand All @@ -54,10 +86,13 @@ internal open class StreamingJsonDecoder(
lexer,
descriptor
)
else -> if (mode == newMode && json.configuration.explicitNulls) {
this
} else {
StreamingJsonDecoder(json, newMode, lexer, descriptor)
else -> {
lexer.skipReadPrefix()
if (mode == newMode && json.configuration.explicitNulls) {
this
} else {
StreamingJsonDecoder(json, newMode, lexer, descriptor)
}
}
}
}
Expand Down
Expand Up @@ -140,6 +140,19 @@ internal abstract class AbstractJsonLexer {
@JvmField
val path = JsonPath()

/*
* Position in the input right after the class discriminator value
* that was successfully looked up by 'consumeLeadingMatchingValue',
* or -1 when no such position is currently recorded.
*/
private var snapshotPosition: Int = -1

/**
 * Jumps to the position recorded in `snapshotPosition`, if any, and clears the record
 * so that the jump happens at most once per recorded position.
 * Does nothing when no position is recorded (the value is -1).
 */
fun skipReadPrefix() {
    val resumeAt = snapshotPosition
    if (resumeAt == -1) return
    currentPosition = resumeAt
    snapshotPosition = -1
}

// Hook with an empty default body; declared `open` so that subclasses may override it.
// NOTE(review): presumably lets lexers over non-string sources load more input -- confirm against the overrides.
open fun ensureHaveChars() {}

/** Returns `true` unless the next token reported by [peekNextToken] is `TC_EOF`. */
fun isNotEof(): Boolean {
    val upcoming = peekNextToken()
    return upcoming != TC_EOF
}
Expand Down Expand Up @@ -283,6 +296,25 @@ internal abstract class AbstractJsonLexer {
return current
}

/**
 * Optimistically inspects the beginning of the upcoming JSON object and, if its very first key
 * equals [keyToMatch], returns the string stored under that key.
 *
 * The lexer position is always restored before this method returns, so the caller still sees
 * the input as unread. On a successful match the position right after the value is stored in
 * `snapshotPosition`, which [skipReadPrefix] later uses to jump over the already-read prefix.
 *
 * The lenient readers are used only when [isLenient] is `true`. Previously the two modes were
 * swapped: lenient input was read with the strict readers, and strict input with the lenient one.
 *
 * @param keyToMatch key expected to be the first key of the object (the class discriminator name)
 * @param isLenient whether the lenient string readers may be used for the key and the value
 * @return the value of the leading key, or `null` when the input does not start with an object,
 * the first key differs from [keyToMatch], the colon is missing, or (in non-lenient mode) the key
 * or the value is not a string token; `null` tells the caller to fall back to the generic path
 */
fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? {
    val positionSnapshot = currentPosition
    try {
        // Malformed JSON, bailout
        if (consumeNextToken() != TC_BEGIN_OBJ) return null
        val firstKey = when {
            isLenient -> consumeStringLenientNotNull()
            // Strict mode: anything but a string token is left for the fallback path to report,
            // instead of failing here inside an optimistic look-ahead.
            // NOTE(review): TC_STRING is assumed to be declared next to TC_EOF / TC_BEGIN_OBJ / TC_COLON -- confirm.
            peekNextToken() != TC_STRING -> return null
            else -> consumeKeyString()
        }
        if (firstKey != keyToMatch) return null
        if (consumeNextToken() != TC_COLON) return null
        val result = when {
            isLenient -> consumeStringLenientNotNull()
            // Same reasoning as for the key: a non-string discriminator value is not handled here.
            peekNextToken() != TC_STRING -> return null
            else -> consumeString()
        }
        // Remember where the discriminator entry ends so that skipReadPrefix() can jump past it.
        snapshotPosition = currentPosition
        return result
    } finally {
        // Restore the position
        currentPosition = positionSnapshot
    }
}

fun peekString(isLenient: Boolean): String? {
val token = peekNextToken()
val string = if (isLenient) {
Expand Down
Expand Up @@ -78,10 +78,10 @@ internal class StringJsonLexer(override val source: String) : AbstractJsonLexer(

override fun consumeKeyString(): String {
/*
* For strings we assume that escaped symbols are rather an exception, so firstly
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
* than do our pessimistic check for backslash and fallback to slow-path if necessary.
*/
* For strings we assume that escaped symbols are rather an exception, so firstly
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
* than do our pessimistic check for backslash and fallback to slow-path if necessary.
*/
consumeNextToken(STRING)
val current = currentPosition
val closingQuote = source.indexOf('"', current)
Expand Down
Expand Up @@ -133,10 +133,10 @@ internal class ReaderJsonLexer(

override fun consumeKeyString(): String {
/*
* For strings we assume that escaped symbols are rather an exception, so firstly
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
* than do our pessimistic check for backslash and fallback to slow-path if necessary.
*/
* For strings we assume that escaped symbols are rather an exception, so firstly
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
* than do our pessimistic check for backslash and fallback to slow-path if necessary.
*/
consumeNextToken(STRING)
var current = currentPosition
val closingQuote = indexOf('"', current)
Expand Down

0 comments on commit 3e93bfd

Please sign in to comment.