Skip to content

Commit

Permalink
Fix Scala case class with fields in body serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Elmacioro committed Apr 4, 2024
1 parent 19eb836 commit d649645
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 6 deletions.
1 change: 1 addition & 0 deletions build.sbt
Expand Up @@ -1155,6 +1155,7 @@ lazy val flinkScalaUtils = (project in flink("scala-utils"))
) ++ flinkLibScalaDeps(scalaVersion.value, Some("provided"))
}
)
.dependsOn(testUtils % Test)

lazy val flinkTestUtils = (project in flink("test-utils"))
.settings(commonSettings)
Expand Down
1 change: 1 addition & 0 deletions docs/Changelog.md
Expand Up @@ -7,6 +7,7 @@
* openapi-circe-yaml: 0.6.0 -> 0.7.4
* [#5438](https://github.com/TouK/nussknacker/pull/5438) [#5495](https://github.com/TouK/nussknacker/pull/5495) Improvement in DeploymentManager API:
* Alignment in the api of primary (deploy/cancel) actions and the experimental api of custom actions.
* [#5780](https://github.com/TouK/nussknacker/pull/5780) Fixed Scala case classes serialization when a class has additional fields in its body

1.14.0 (21 Mar 2024)
-------------------------
Expand Down
Expand Up @@ -8,6 +8,7 @@ import org.apache.flink.api.java.typeutils.runtime.NullableSerializer

import java.lang.reflect.Type
import scala.reflect._
import scala.reflect.runtime.universe._

// Generic class factory for creating CaseClassTypeInfo
abstract class CaseClassTypeInfoFactory[T <: Product: ClassTag] extends TypeInfoFactory[T] with Serializable {
Expand All @@ -16,18 +17,33 @@ abstract class CaseClassTypeInfoFactory[T <: Product: ClassTag] extends TypeInfo
t: Type,
genericParameters: java.util.Map[String, TypeInformation[_]]
): TypeInformation[T] = {
val tClass = classTag[T].runtimeClass.asInstanceOf[Class[T]]
val fieldNames = tClass.getDeclaredFields.map(_.getName).toList
val fieldTypes = tClass.getDeclaredFields.map(_.getType).map(TypeExtractor.getForClass(_))

new CaseClassTypeInfo[T](tClass, Array.empty, fieldTypes.toIndexedSeq, fieldNames) {
val runtimeClassType = classTag[T].runtimeClass
val (fieldNames, fieldTypes) = getClassFieldsInfo(runtimeClassType)
val classType = runtimeClassType.asInstanceOf[Class[T]]
new CaseClassTypeInfo[T](classType, Array.empty, fieldTypes.toIndexedSeq, fieldNames) {
override def createSerializer(config: ExecutionConfig): TypeSerializer[T] = {
new ScalaCaseClassSerializer[T](
tClass,
classType,
fieldTypes.map(typeInfo => NullableSerializer.wrap(typeInfo.createSerializer(config), true)).toArray
)
}
}
}

private def getClassFieldsInfo(runtimeClassType: Class[_]): (List[String], List[TypeInformation[_]]) = {
val mirror = runtimeMirror(getClass.getClassLoader)
val fields = mirror
.classSymbol(runtimeClassType)
.primaryConstructor
.asMethod
.paramLists
.head
val fieldNames = fields.map(_.name.decodedName.toString)
val fieldTypes = fields.map { field =>
val fieldClass = mirror.runtimeClass(field.typeSignature)
TypeExtractor.getForClass(fieldClass)
}
(fieldNames, fieldTypes)
}

}
@@ -0,0 +1,96 @@
package pl.touk.nussknacker.engine.flink.api.typeinfo.caseclass

import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.typeinfo.TypeInfo
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.core.memory.{DataInputViewStreamWrapper, DataOutputViewStreamWrapper}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.must.Matchers
import pl.touk.nussknacker.test.ProcessUtils.convertToAnyShouldWrapper

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.reflect.{ClassTag, classTag}

class CaseClassSerializationTest extends AnyFunSuite with Matchers {

private val executionConfig = new ExecutionConfig()

private val bufferSize = 1024

test("Simple case class should be the same after serialization and deserialization") {
val input = SimpleCaseClass("value")

val serializer = getSerializer[SimpleCaseClass]
val deserialized = serializeAndDeserialize(serializer, input)

serializer shouldBe a[ScalaCaseClassSerializer[_]]
deserialized shouldEqual input
}

test("Case class with field in body should be the same after serialization and deserialization") {
val input = SimpleCaseClassWithAdditionalField("value")

val serializer = getSerializer[SimpleCaseClassWithAdditionalField]
val deserialized = serializeAndDeserialize(serializer, input)

serializer shouldBe a[ScalaCaseClassSerializer[_]]
deserialized shouldEqual input
}

test("Case class with secondary constructor should be the same after serialization and deserialization") {
val input = new SimpleCaseClassWithMultipleConstructors(2, "value")

val serializer = getSerializer[SimpleCaseClassWithMultipleConstructors]
val deserialized = serializeAndDeserialize(serializer, input)

serializer shouldBe a[ScalaCaseClassSerializer[_]]
deserialized shouldEqual input
}

private def getSerializer[T: ClassTag] =
TypeExtractor.getForClass(classTag[T].runtimeClass.asInstanceOf[Class[T]]).createSerializer(executionConfig)

private def serializeAndDeserialize[T](serializer: TypeSerializer[T], in: T): T = {
val outStream = new ByteArrayOutputStream(bufferSize)
val outWrapper = new DataOutputViewStreamWrapper(outStream)

serializer.serialize(in, outWrapper)
serializer.deserialize(new DataInputViewStreamWrapper(new ByteArrayInputStream(outStream.toByteArray)))
}

}

@TypeInfo(classOf[SimpleCaseClass.TypeInfoFactory])
final case class SimpleCaseClass(constructorField: String)

object SimpleCaseClass {
class TypeInfoFactory extends CaseClassTypeInfoFactory[SimpleCaseClass]
}

@TypeInfo(classOf[SimpleCaseClassWithAdditionalField.TypeInfoFactory])
final case class SimpleCaseClassWithAdditionalField(constructorField: String) {
val bodyField: String = "body " + constructorField
}

object SimpleCaseClassWithAdditionalField {
class TypeInfoFactory extends CaseClassTypeInfoFactory[SimpleCaseClassWithAdditionalField]
}

@TypeInfo(classOf[SimpleCaseClassWithMultipleConstructors.TypeInfoFactory])
final case class SimpleCaseClassWithMultipleConstructors(firstField: String, secondField: Double) {
val bodyField: String = "body " + firstField

def this(someField: Int, someSecondField: String) = {
this(someSecondField, someField)
}

def this(someField: Int, someSecondField: String, toIgnore: String) = {
this(someSecondField, someField)
}

}

object SimpleCaseClassWithMultipleConstructors {
class TypeInfoFactory extends CaseClassTypeInfoFactory[SimpleCaseClassWithMultipleConstructors]
}

0 comments on commit d649645

Please sign in to comment.