Skip to content

Commit

Permalink
Do not use tree-based decoding for fast-path polymorphism (#1919)
Browse files Browse the repository at this point in the history
Do not use tree-based decoding for fast-path polymorphism and try to optimistically read it as very first key and then silently skip

Fixes #1839
  • Loading branch information
qwwdfsad committed Jun 24, 2022
1 parent bb18d62 commit 93a06df
Show file tree
Hide file tree
Showing 15 changed files with 233 additions and 28 deletions.
5 changes: 2 additions & 3 deletions benchmark/build.gradle
Expand Up @@ -6,13 +6,12 @@ apply plugin: 'java'
apply plugin: 'kotlin'
apply plugin: 'kotlinx-serialization'
apply plugin: 'idea'
apply plugin: 'net.ltgt.apt'
apply plugin: 'com.github.johnrengelman.shadow'
apply plugin: 'me.champeau.gradle.jmh'
apply plugin: 'me.champeau.jmh'

sourceCompatibility = 1.8
targetCompatibility = 1.8
jmh.jmhVersion = 1.22
jmh.jmhVersion = "1.22"

jmhJar {
baseName 'benchmarks'
Expand Down
@@ -0,0 +1,54 @@
package kotlinx.benchmarks.json

import kotlinx.serialization.*
import kotlinx.serialization.json.*
import kotlinx.serialization.modules.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.*

@Warmup(iterations = 7, time = 1)
@Measurement(iterations = 5, time = 1)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Fork(1)
open class PolymorphismOverheadBenchmark {

@Serializable
@JsonClassDiscriminator("poly")
data class PolymorphicWrapper(val i: @Polymorphic Poly, val i2: Impl) // amortize the cost a bit

@Serializable
data class BaseWrapper(val i: Impl, val i2: Impl)

@JsonClassDiscriminator("poly")
interface Poly

@Serializable
@JsonClassDiscriminator("poly")
class Impl(val a: Int, val b: String) : Poly

private val impl = Impl(239, "average_size_string")
private val module = SerializersModule {
polymorphic(Poly::class) {
subclass(Impl.serializer())
}
}

private val json = Json { serializersModule = module }
private val implString = json.encodeToString(impl)
private val polyString = json.encodeToString<Poly>(impl)
private val serializer = serializer<Poly>()

// 5000
@Benchmark
fun base() = json.decodeFromString(Impl.serializer(), implString)

// As of 1.3.x
// Baseline -- 1500
// v1, no skip -- 2000
// v2, with skip -- 3000 [withdrawn]
@Benchmark
fun poly() = json.decodeFromString(serializer, polyString)

}
3 changes: 1 addition & 2 deletions build.gradle
Expand Up @@ -74,8 +74,7 @@ buildscript {

// Various benchmarking stuff
classpath "com.github.jengelman.gradle.plugins:shadow:4.0.2"
classpath "me.champeau.gradle:jmh-gradle-plugin:0.5.3"
classpath "net.ltgt.gradle:gradle-apt-plugin:0.21"
classpath "me.champeau.jmh:jmh-gradle-plugin:0.6.6"
}
}

Expand Down
Expand Up @@ -96,7 +96,7 @@ public sealed class Json(
*/
public final override fun <T> decodeFromString(deserializer: DeserializationStrategy<T>, string: String): T {
val lexer = StringJsonLexer(string)
val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor)
val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null)
val result = input.decodeSerializableValue(deserializer)
lexer.expectEof()
return result
Expand Down
Expand Up @@ -24,7 +24,7 @@ internal class JsonPath {

// Tombstone indicates that we are within a map, but the map key is currently being decoded.
// It is also used to overwrite a previous map key to avoid memory leaks and misattribution.
object Tombstone
private object Tombstone

/*
* Serial descriptor, map key or the tombstone for map key
Expand Down
Expand Up @@ -9,6 +9,7 @@ import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.internal.*
import kotlinx.serialization.json.*
import kotlin.jvm.*

@Suppress("UNCHECKED_CAST")
internal inline fun <T> JsonEncoder.encodePolymorphically(
Expand Down Expand Up @@ -55,12 +56,13 @@ internal fun checkKind(kind: SerialKind) {
}

internal fun <T> JsonDecoder.decodeSerializableValuePolymorphic(deserializer: DeserializationStrategy<T>): T {
// NB: changes in this method should be reflected in StreamingJsonDecoder#decodeSerializableValue
if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) {
return deserializer.deserialize(this)
}
val discriminator = deserializer.descriptor.classDiscriminator(json)

val jsonTree = cast<JsonObject>(decodeJsonElement(), deserializer.descriptor)
val discriminator = deserializer.descriptor.classDiscriminator(json)
val type = jsonTree[discriminator]?.jsonPrimitive?.content
val actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type)
?: throwSerializerNotFound(type, jsonTree)
Expand All @@ -69,7 +71,8 @@ internal fun <T> JsonDecoder.decodeSerializableValuePolymorphic(deserializer: De
return json.readPolymorphicJson(discriminator, jsonTree, actualSerializer as DeserializationStrategy<T>)
}

private fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothing {
@JvmName("throwSerializerNotFound")
internal fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothing {
val suffix =
if (type == null) "missing class discriminator ('null')"
else "class discriminator '$type'"
Expand Down
Expand Up @@ -9,6 +9,7 @@ import kotlinx.serialization.descriptors.*
import kotlinx.serialization.encoding.*
import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE
import kotlinx.serialization.encoding.CompositeDecoder.Companion.UNKNOWN_NAME
import kotlinx.serialization.internal.*
import kotlinx.serialization.json.*
import kotlinx.serialization.modules.*
import kotlin.jvm.*
Expand All @@ -21,11 +22,27 @@ internal open class StreamingJsonDecoder(
final override val json: Json,
private val mode: WriteMode,
@JvmField internal val lexer: AbstractJsonLexer,
descriptor: SerialDescriptor
descriptor: SerialDescriptor,
discriminatorHolder: DiscriminatorHolder?
) : JsonDecoder, AbstractDecoder() {

// A mutable reference to the discriminator that have to be skipped when in optimistic phase
// of polymorphic serialization, see `decodeSerializableValue`
internal class DiscriminatorHolder(@JvmField var discriminatorToSkip: String?)

private fun DiscriminatorHolder?.trySkip(unknownKey: String): Boolean {
if (this == null) return false
if (discriminatorToSkip == unknownKey) {
discriminatorToSkip = null
return true
}
return false
}


override val serializersModule: SerializersModule = json.serializersModule
private var currentIndex = -1
private var discriminatorHolder: DiscriminatorHolder? = discriminatorHolder
private val configuration = json.configuration

private val elementMarker: JsonElementMarker? = if (configuration.explicitNulls) null else JsonElementMarker(descriptor)
Expand All @@ -35,7 +52,40 @@ internal open class StreamingJsonDecoder(
@Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
try {
return decodeSerializableValuePolymorphic(deserializer)
/*
* This is an optimized path over decodeSerializableValuePolymorphic(deserializer):
* dSVP reads the very next JSON tree into a memory as JsonElement and then runs TreeJsonDecoder over it
* in order to deal with an arbitrary order of keys, but with the price of additional memory pressure
* and CPU consumption.
* We would like to provide best possible performance for data produced by kotlinx.serialization
* itself, for that we do the following optimistic optimization:
*
* 0) Remember current position in the string
* 1) Read the very next key of JSON structure
* 2) If it matches* the descriminator key, read the value, remember current position
* 3) Return the value, recover an initial position
* (*) -- if it doesn't match, fallback to dSVP method.
*/
if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) {
return deserializer.deserialize(this)
}

val discriminator = deserializer.descriptor.classDiscriminator(json)
val type = lexer.consumeLeadingMatchingValue(discriminator, configuration.isLenient)
var actualSerializer: DeserializationStrategy<out Any>? = null
if (type != null) {
actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type)
}
if (actualSerializer == null) {
// Fallback if we haven't found discriminator or serializer
return decodeSerializableValuePolymorphic<T>(deserializer as DeserializationStrategy<T>)
}

discriminatorHolder = DiscriminatorHolder(discriminator)
@Suppress("UNCHECKED_CAST")
val result = actualSerializer.deserialize(this) as T
return result

} catch (e: MissingFieldException) {
throw MissingFieldException(e.message + " at path: " + lexer.path.getPath(), e)
}
Expand All @@ -52,12 +102,13 @@ internal open class StreamingJsonDecoder(
json,
newMode,
lexer,
descriptor
descriptor,
discriminatorHolder
)
else -> if (mode == newMode && json.configuration.explicitNulls) {
this
} else {
StreamingJsonDecoder(json, newMode, lexer, descriptor)
StreamingJsonDecoder(json, newMode, lexer, descriptor, discriminatorHolder)
}
}
}
Expand Down Expand Up @@ -193,7 +244,7 @@ internal open class StreamingJsonDecoder(
}

private fun handleUnknown(key: String): Boolean {
if (configuration.ignoreUnknownKeys) {
if (configuration.ignoreUnknownKeys || discriminatorHolder.trySkip(key)) {
lexer.skipElement(configuration.isLenient)
} else {
// Here we cannot properly update json path indicies
Expand Down
Expand Up @@ -283,6 +283,8 @@ internal abstract class AbstractJsonLexer {
return current
}

abstract fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String?

fun peekString(isLenient: Boolean): String? {
val token = peekNextToken()
val string = if (isLenient) {
Expand Down
Expand Up @@ -78,10 +78,10 @@ internal class StringJsonLexer(override val source: String) : AbstractJsonLexer(

override fun consumeKeyString(): String {
/*
* For strings we assume that escaped symbols are rather an exception, so firstly
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
* than do our pessimistic check for backslash and fallback to slow-path if necessary.
*/
* For strings we assume that escaped symbols are rather an exception, so firstly
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
* than do our pessimistic check for backslash and fallback to slow-path if necessary.
*/
consumeNextToken(STRING)
val current = currentPosition
val closingQuote = source.indexOf('"', current)
Expand All @@ -96,4 +96,22 @@ internal class StringJsonLexer(override val source: String) : AbstractJsonLexer(
this.currentPosition = closingQuote + 1
return source.substring(current, closingQuote)
}

override fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? {
val positionSnapshot = currentPosition
try {
// Malformed JSON, bailout
if (consumeNextToken() != TC_BEGIN_OBJ) return null
val firstKey = if (isLenient) consumeKeyString() else consumeStringLenientNotNull()
if (firstKey == keyToMatch) {
if (consumeNextToken() != TC_COLON) return null
val result = if (isLenient) consumeString() else consumeStringLenientNotNull()
return result
}
return null
} finally {
// Restore the position
currentPosition = positionSnapshot
}
}
}
@@ -0,0 +1,35 @@
/*
* Copyright 2017-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.serialization.features

import kotlinx.serialization.*
import kotlinx.serialization.json.*
import kotlinx.serialization.modules.*
import kotlin.test.*

class DefaultPolymorphicSerializerTest : JsonTestBase() {

@Serializable
abstract class Project {
abstract val name: String
}

@Serializable
data class DefaultProject(override val name: String, val type: String): Project()

val module = SerializersModule {
polymorphic(Project::class) {
defaultDeserializer { DefaultProject.serializer() }
}
}

private val json = Json { serializersModule = module }

@Test
fun test() = parametrizedTest {
assertEquals(DefaultProject("example", "unknown"),
json.decodeFromString<Project>(""" {"type":"unknown","name":"example"}""", it))
}

}
Expand Up @@ -67,7 +67,7 @@ abstract class JsonTestBase {
}
JsonTestingMode.TREE -> {
val lexer = StringJsonLexer(source)
val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor)
val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null)
val tree = input.decodeJsonElement()
lexer.expectEof()
readJson(tree, deserializer)
Expand Down
Expand Up @@ -61,7 +61,7 @@ public fun <T> Json.decodeFromStream(
stream: InputStream
): T {
val lexer = ReaderJsonLexer(stream)
val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor)
val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null)
val result = input.decodeSerializableValue(deserializer)
lexer.expectEof()
return result
Expand Down
Expand Up @@ -56,7 +56,7 @@ private class JsonIteratorWsSeparated<T>(
private val deserializer: DeserializationStrategy<T>
) : Iterator<T> {
override fun next(): T =
StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor)
StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor, null)
.decodeSerializableValue(deserializer)

override fun hasNext(): Boolean = lexer.isNotEof()
Expand All @@ -75,7 +75,7 @@ private class JsonIteratorArrayWrapped<T>(
} else {
lexer.consumeNextToken(COMMA)
}
val input = StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor)
val input = StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor, null)
return input.decodeSerializableValue(deserializer)
}

Expand Down
Expand Up @@ -133,10 +133,10 @@ internal class ReaderJsonLexer(

override fun consumeKeyString(): String {
/*
* For strings we assume that escaped symbols are rather an exception, so firstly
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
* than do our pessimistic check for backslash and fallback to slow-path if necessary.
*/
* For strings we assume that escaped symbols are rather an exception, so firstly
* we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf',
* than do our pessimistic check for backslash and fallback to slow-path if necessary.
*/
consumeNextToken(STRING)
var current = currentPosition
val closingQuote = indexOf('"', current)
Expand Down Expand Up @@ -174,4 +174,7 @@ internal class ReaderJsonLexer(
override fun appendRange(fromIndex: Int, toIndex: Int) {
escapedString.append(_source, fromIndex, toIndex - fromIndex)
}

// Can be carefully implemented but postponed for now
override fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? = null
}

0 comments on commit 93a06df

Please sign in to comment.