Skip to content

Commit

Permalink
fixes #365 CoroutineCollection#save doesn't use specified Serializer …
Browse files Browse the repository at this point in the history
…(kotlinx.serialization)
  • Loading branch information
zigzago committed Sep 10, 2022
1 parent a36e506 commit d48c933
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 31 deletions.
Expand Up @@ -29,9 +29,7 @@ import org.bson.codecs.pojo.annotations.BsonProperty
import org.bson.json.JsonMode
import org.bson.json.JsonWriter
import org.bson.json.JsonWriterSettings
import org.litote.kmongo.service.ClassMappingType
import org.litote.kmongo.service.ClassMappingTypeService
import org.litote.kmongo.util.ObjectMappingConfiguration
import java.io.StringWriter
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
Expand Down Expand Up @@ -121,11 +119,11 @@ internal class PojoClassMappingTypeService : ClassMappingTypeService {
}

override fun coreCodecRegistry(baseCodecRegistry: CodecRegistry): CodecRegistry {
internalCodecRegistry = ClassMappingType.codecRegistry(
internalCodecRegistry = codecRegistryWithCustomCodecs(
baseCodecRegistry,
codecRegistry
)
internalNullCodecRegistry = ClassMappingType.codecRegistry(
internalNullCodecRegistry = codecRegistryWithCustomCodecs(
baseCodecRegistry,
codecRegistryWithNullSerialization
)
Expand All @@ -134,11 +132,11 @@ internal class PojoClassMappingTypeService : ClassMappingTypeService {

override fun <T> calculatePath(property: KProperty<T>): String {
val owner = property.javaField?.declaringClass
?: try {
property.javaGetter?.declaringClass
} catch (e: Exception) {
null
}
?: try {
property.javaGetter?.declaringClass
} catch (e: Exception) {
null
}

return if (owner?.kotlin?.let { findIdProperty(it) }?.name == property.name)
"_id"
Expand Down
Expand Up @@ -127,6 +127,9 @@ internal object KMongoSerializationRepository {
UUID::class to UUIDSerializer
)

private fun getCustomSerializer(kClass: KClass<*>): KSerializer<*>? =
customSerializersMap[kClass] ?: serializersMap[kClass]

@ExperimentalSerializationApi
@InternalSerializationApi
private fun <T : Any> getBaseSerializer(
Expand All @@ -142,12 +145,14 @@ internal object KMongoSerializationRepository {
getSerializer(obj.second),
getSerializer(obj.third)
)

is Array<*> -> ArraySerializer(
kClass as KClass<Any>,
obj.filterNotNull().let {
if (it.isEmpty()) String.serializer() else getSerializer(it.first())
} as KSerializer<Any>
)

else -> module.getContextual(kClass)
?: findPolymorphic(kClass, obj)?.let {
PolymorphicSerializer(it)
Expand Down Expand Up @@ -179,7 +184,7 @@ internal object KMongoSerializationRepository {
if (obj == null) {
error("no serializer for null")
} else {
(serializersMap[kClass]
(getCustomSerializer(kClass)
?: getBaseSerializer(obj, kClass)
?: kClass.serializer()) as? KSerializer<T>
?: error("no serializer for $obj of class $kClass")
Expand All @@ -192,7 +197,7 @@ internal object KMongoSerializationRepository {
if (obj == null) {
error("no serializer for null")
} else {
(serializersMap[obj.javaClass.kotlin]
(getCustomSerializer(obj.javaClass.kotlin)
?: getBaseSerializer(obj)
?: obj.javaClass.kotlin.serializer()) as? KSerializer<T>
?: error("no serializer for $obj of class ${obj.javaClass.kotlin}")
Expand All @@ -202,7 +207,7 @@ internal object KMongoSerializationRepository {
@InternalSerializationApi
@Suppress("UNCHECKED_CAST")
fun <T : Any> getSerializer(kClass: KClass<T>): KSerializer<T> =
(serializersMap[kClass]
(getCustomSerializer(kClass)
?: module.getContextual(kClass)
?: try {
kClass.serializer()
Expand Down
Expand Up @@ -35,6 +35,14 @@ import kotlin.reflect.KProperty1
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.hasAnnotation

private class BaseRegistryWithoutCustomSerializers(private val codecRegistry: CodecRegistry) : CodecRegistry {
override fun <T : Any> get(clazz: Class<T>): Codec<T>? =
if (customSerializersMap.containsKey(clazz.kotlin)) null else codecRegistry.get(clazz)

override fun <T : Any> get(clazz: Class<T>, registry: CodecRegistry): Codec<T>? =
if (customSerializersMap.containsKey(clazz.kotlin)) null else codecRegistry.get(clazz, registry)
}

/**
* kotlinx serialization ClassMapping.
*/
Expand Down Expand Up @@ -91,30 +99,32 @@ class SerializationClassMappingTypeService : ClassMappingTypeService {
idController.getIdValue(idProperty, instance)

override fun coreCodecRegistry(baseCodecRegistry: CodecRegistry): CodecRegistry {

val withNonEncodeNull = SerializationCodecRegistry(configuration.copy(nonEncodeNull = true))
codecRegistryWithNonEncodeNull =
codecRegistry(
baseCodecRegistry,
SerializationCodecRegistry(configuration.copy(nonEncodeNull = true))
codecRegistryWithCustomCodecs(
filterBaseCodecRegistry(baseCodecRegistry),
withNonEncodeNull
)
codecRegistryWithEncodeNull = codecRegistry(
baseCodecRegistry,
SerializationCodecRegistry(configuration.copy(nonEncodeNull = false))
val withEncodeNull = SerializationCodecRegistry(configuration.copy(nonEncodeNull = false))
codecRegistryWithEncodeNull = codecRegistryWithCustomCodecs(
filterBaseCodecRegistry(baseCodecRegistry),
withEncodeNull
)
return object : CodecRegistry {
override fun <T : Any?> get(clazz: Class<T>): Codec<T>? =
if (ObjectMappingConfiguration.serializeNull)
codecRegistryWithEncodeNull.get(clazz)
else codecRegistryWithNonEncodeNull.get(clazz)
override fun <T : Any> get(clazz: Class<T>): Codec<T> =
if (ObjectMappingConfiguration.serializeNull) withEncodeNull.get(clazz)
else withNonEncodeNull.get(clazz)


override fun <T : Any?> get(clazz: Class<T>, registry: CodecRegistry): Codec<T>? =
if (ObjectMappingConfiguration.serializeNull)
codecRegistryWithEncodeNull.get(clazz, registry)
else codecRegistryWithNonEncodeNull.get(clazz, registry)
override fun <T : Any> get(clazz: Class<T>, registry: CodecRegistry): Codec<T> =
if (ObjectMappingConfiguration.serializeNull) withEncodeNull.get(clazz, registry)
else withNonEncodeNull.get(clazz, registry)
}
}

override fun filterBaseCodecRegistry(baseCodecRegistry: CodecRegistry): CodecRegistry =
BaseRegistryWithoutCustomSerializers(baseCodecRegistry)

override fun <T> calculatePath(property: KProperty<T>): String =
property.findAnnotation<SerialName>()?.value
?: (if (property.hasAnnotation<MongoId>()) "_id" else property.findAnnotation<MongoProperty>()?.value)
Expand Down
92 changes: 92 additions & 0 deletions kmongo-serialization/src/test/kotlin/Issue365UUIDSerializer.kt
@@ -0,0 +1,92 @@
/*
* Copyright (C) 2016/2021 Litote
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.litote.kmongo.issues

import kotlinx.serialization.Contextual
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import org.junit.Test
import org.litote.kmongo.AllCategoriesKMongoBaseTest
import org.litote.kmongo.findOne
import org.litote.kmongo.save
import org.litote.kmongo.serialization.registerSerializer
import java.util.UUID
import kotlin.test.assertEquals
import kotlin.test.assertTrue

@Serializable
data class DBPlayer(
@SerialName("_id") @Contextual val uuid: UUID,
)

object UUIDSerializer : KSerializer<UUID> {
override val descriptor: SerialDescriptor =
PrimitiveSerialDescriptor("UUID", PrimitiveKind.STRING)

override fun deserialize(decoder: Decoder): UUID {
deserialized = true
return UUID.fromString(decoder.decodeString())
}

override fun serialize(encoder: Encoder, value: UUID) {
serialized = true
encoder.encodeString(value.toString())
}
}

private var deserialized = false
private var serialized = false

/**
*
*/
class Issue365UUIDSerializer : AllCategoriesKMongoBaseTest<DBPlayer>() {

init {
registerSerializer(UUIDSerializer)
}

@Test
fun testInsertSerialization() {
val test = DBPlayer(UUID.randomUUID())

col.insertOne(test)
val test2 = col.findOne()

assertEquals(test, test2)
assertTrue(deserialized)
assertTrue(serialized)
}

@Test
fun testSaveSerialization() {
val test = DBPlayer(UUID.randomUUID())

col.save(test)
val test2 = col.findOne()

assertEquals(test, test2)
assertTrue(deserialized)
assertTrue(serialized)
}
}
Expand Up @@ -72,17 +72,23 @@ interface ClassMappingTypeService {
fun <T, R> getIdValue(idProperty: KProperty1<T, R>, instance: T): R?

/**
* Returns a codec registry built with [baseCodecRegistry], and [coreCodeRegistry].
* Returns a codec registry built with [baseCodecRegistry].
*/
fun codecRegistry(
baseCodecRegistry: CodecRegistry
): CodecRegistry = codecRegistryWithCustomCodecs(baseCodecRegistry, coreCodecRegistry(baseCodecRegistry))

fun codecRegistryWithCustomCodecs(
baseCodecRegistry: CodecRegistry,
coreCodeRegistry: CodecRegistry = coreCodecRegistry(baseCodecRegistry)
): CodecRegistry = CodecRegistries.fromProviders(
baseCodecRegistry,
coreCodeRegistry: CodecRegistry
): CodecRegistry = CodecRegistries.fromProviders(
filterBaseCodecRegistry(baseCodecRegistry),
CustomCodecProvider,
coreCodeRegistry
)

fun filterBaseCodecRegistry(baseCodecRegistry: CodecRegistry) : CodecRegistry = baseCodecRegistry

fun coreCodecRegistry(baseCodecRegistry: CodecRegistry = KMongoUtil.defaultCodecRegistry): CodecRegistry

fun <T> getPath(property: KProperty<T>): String {
Expand Down
69 changes: 69 additions & 0 deletions kmongo/src/test/kotlin/org/litote/kmongo/issues/TypedTest.kt
@@ -0,0 +1,69 @@
/*
* Copyright (C) 2016/2021 Litote
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.litote.kmongo.issues

import org.bson.codecs.pojo.annotations.BsonId
import org.bson.types.ObjectId
import org.junit.Test
import org.litote.kmongo.AllCategoriesKMongoBaseTest
import org.litote.kmongo.MongoOperator
import org.litote.kmongo.aggregate
import org.litote.kmongo.div
import org.litote.kmongo.eq
import org.litote.kmongo.from
import org.litote.kmongo.match
import org.litote.kmongo.project
import org.litote.kmongo.projection
import org.litote.kmongo.save
import java.time.Instant
import java.time.ZoneOffset
import kotlin.test.assertEquals

data class MyClass(
val items: List<Item>,
val timestamp: Long = Instant.now().atOffset(ZoneOffset.UTC).toEpochSecond(),
@BsonId
val id: String = ObjectId().toString()
)

data class Item(
val name: String,
val price: Double,
val qty: Double
)


data class Result(val sumPrice: Double)

/**
*
*/
class TypedTest :
AllCategoriesKMongoBaseTest<MyClass>() {

@Test
fun `serialization and deserialization is ok`() {
val d = MyClass(listOf(Item("a", 12.9, 1.0)), id = "id")
col.save(d)

val r: Double = col.aggregate<Result>(
match(MyClass::id eq "id"),
project(Result::sumPrice to MongoOperator.sum.from((MyClass::items / Item::price).projection))
).first()?.sumPrice ?: 0.0
assertEquals(12.9, r)
}
}

0 comments on commit d48c933

Please sign in to comment.