From da3893c76be66489e69f60cf03344204bfbae1d6 Mon Sep 17 00:00:00 2001 From: Mateusz Kubuszok Date: Tue, 4 Oct 2022 22:24:37 +0200 Subject: [PATCH] Initial implementation of addFallbackValue config --- .../PlatformProductCaseGeneration.scala | 68 +++++++++++++------ .../PlatformProductCaseGeneration.scala | 62 +++++++++++++---- .../scala/pipez/internal/Definitions.scala | 4 ++ .../internal/ProductCaseGeneration.scala | 4 +- 4 files changed, 102 insertions(+), 36 deletions(-) diff --git a/pipez/src/main/scala-2/pipez/internal/PlatformProductCaseGeneration.scala b/pipez/src/main/scala-2/pipez/internal/PlatformProductCaseGeneration.scala index 979b137..a92e54d 100644 --- a/pipez/src/main/scala-2/pipez/internal/PlatformProductCaseGeneration.scala +++ b/pipez/src/main/scala-2/pipez/internal/PlatformProductCaseGeneration.scala @@ -1,8 +1,10 @@ package pipez.internal import pipez.internal.Definitions.{ Context, Result } +import pipez.internal.ProductCaseGeneration.inputNameMatchesOutputName import scala.annotation.nowarn +import scala.collection.AnyStepper import scala.collection.immutable.ListMap import scala.util.chaining.* import scala.language.existentials @@ -55,28 +57,45 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] private def isVar(setter: Symbol): Boolean = setter.isTerm && setter.asTerm.name.toString.endsWith("_$eq") && setter.isPublic - final def extractProductInData(settings: Settings): DerivationResult[ProductInData] = - In.decls + final private case class Getter[Extracted, ExtractedField]( + tpe: Type[ExtractedField], + get: Expr[Extracted] => Expr[ExtractedField] + ) + private def extractGetters[Extracted: Type]: ListMap[String, Getter[Extracted, Any]] = + typeOf[Extracted].decls .to(List) .filterNot(isGarbage) .filter(m => isCaseClassField(m) || isJavaGetter(m)) .map { getter => val name = getter.name.toString val termName = getter.asMethod.name.toTermName - name -> ProductInData.Getter[Any]( - name = name, - tpe = returnTypeOf(In, getter).asInstanceOf[Type[Any]], + name -> Getter[Extracted, Any]( + tpe = returnTypeOf(typeOf[Extracted], getter).asInstanceOf[Type[Any]], get = - if (getter.asMethod.paramLists.isEmpty) (in: Expr[In]) => c.Expr[Any](q"$in.$termName") - else (in: Expr[In]) => c.Expr[Any](q"$in.$termName()"), - path = Path.Field(Path.Root, name) + if (getter.asMethod.paramLists.isEmpty) (in: Expr[Extracted]) => c.Expr[Any](q"$in.$termName") + else (in: Expr[Extracted]) => c.Expr[Any](q"$in.$termName()") ) } .to(ListMap) + + final def extractProductInData(settings: Settings): DerivationResult[ProductInData] = + extractGetters[In] + .map { case (name, Getter(tpe, get)) => + name -> ProductInData.Getter(name, tpe, get, Path.Field(Path.Root, name)) + } .pipe(ProductInData(_)) .pipe(DerivationResult.pure(_)) .logSuccess(data => s"Resolved input: $data") + private def fallbackValueGetters(settings: Settings): Map[String, FieldFallback[?]] = + settings.fallbackValues + .collect { case ConfigEntry.AddFallbackValue(fallbackType, fallbackValue) => + extractGetters(fallbackType).map { case (name, Getter(fallbackFieldType, extractFallbackValue)) => + name -> FieldFallback.Value(fallbackFieldType, extractFallbackValue(fallbackValue)) + } + } + .fold(Map.empty[String, FieldFallback[?]])((left, right) => right ++ left) // preserve first rather than last + final def extractProductOutData(settings: Settings): DerivationResult[ProductOutData] = if (isJavaBean[Out]) { // Java Bean case @@ -85,6 +104,8 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] Out.decls.filterNot(isGarbage).find(isDefaultConstructor).map(_ => c.Expr[Out](q"new $Out()")) )(DerivationError.MissingPublicConstructor) + val fallbackValues = fallbackValueGetters(settings) + val setters = Out.decls .to(List) .filterNot(isGarbage) @@ -100,7 +121,13 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] name = name, tpe = paramListsOf(Out, setter).flatten.head.typeSignature.asInstanceOf[Type[Any]], set = (out: Expr[Out], value: Expr[Any]) => c.Expr[Unit](q"$out.$termName($value)"), - fallback = FieldFallback.Unavailable // TODO: .addFallbackValue + fallback = fallbackValues + .collectFirst { + case (fallbackName, fallback) + if inputNameMatchesOutputName(name, fallbackName, settings.isFieldCaseInsensitive) => + fallback + } + .getOrElse(FieldFallback.Unavailable) ) } .to(ListMap) @@ -127,15 +154,12 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] Out.decls.to(List).filterNot(isGarbage).find(m => m.isPublic && m.isConstructor) )(DerivationError.MissingPublicConstructor) // default value for case class field n (1 indexed) is obtained from Companion.apply$default$n - fallbacks = primaryConstructor.typeSignature.paramLists.headOption.toList.flatten.zipWithIndex - .collect { - case (param, idx) if param.asTerm.isParamWithDefault => - param.name.toString -> FieldFallback.Default( - c.Expr[Any](q"${Out.typeSymbol.companion}.${TermName("apply$default$" + (idx + 1))}") - ) - } - .toMap - .withDefaultValue(FieldFallback.Unavailable) + fallbackValues = primaryConstructor.typeSignature.paramLists.headOption.toList.flatten.zipWithIndex.collect { + case (param, idx) if param.asTerm.isParamWithDefault => + param.name.toString -> FieldFallback.Default( + c.Expr[Any](q"${Out.typeSymbol.companion}.${TermName("apply$default$" + (idx + 1))}") + ) + }.toMap ++ fallbackValueGetters(settings) // we want defaults to be overridden by provided values } yield ProductOutData.CaseClass( caller = params => c.Expr(q"new $Out(...$params)"), params = paramListsOf(Out, primaryConstructor).map { params => @@ -145,7 +169,13 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] name -> ProductOutData.ConstructorParam( name = name, tpe = param.typeSignature.asInstanceOf[Type[Any]], - fallback = fallbacks(name) // TODO: .addFallbackValue + fallback = fallbackValues + .collectFirst { + case (fallbackName, fallback) + if inputNameMatchesOutputName(name, fallbackName, settings.isFieldCaseInsensitive) => + fallback + } + .getOrElse(FieldFallback.Unavailable) ) } .to(ListMap) diff --git a/pipez/src/main/scala-3/pipez/internal/PlatformProductCaseGeneration.scala b/pipez/src/main/scala-3/pipez/internal/PlatformProductCaseGeneration.scala index f1050b8..8e9bf57 100644 --- a/pipez/src/main/scala-3/pipez/internal/PlatformProductCaseGeneration.scala +++ b/pipez/src/main/scala-3/pipez/internal/PlatformProductCaseGeneration.scala @@ -1,6 +1,7 @@ package pipez.internal import pipez.internal.Definitions.{ Context, Result } +import pipez.internal.ProductCaseGeneration.inputNameMatchesOutputName import scala.collection.immutable.ListMap import scala.util.chaining.* @@ -47,27 +48,44 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] private def isVar(setter: Symbol): Boolean = setter.isValDef && setter.flags.is(Flags.Mutable) - final def extractProductInData(settings: Settings): DerivationResult[ProductInData] = { - val sym = TypeRepr.of[In].typeSymbol + final private case class Getter[Extracted, ExtractedField]( + tpe: Type[ExtractedField], + get: Expr[Extracted] => Expr[ExtractedField] + ) + private def extractGetters[Extracted: Type]: ListMap[String, Getter[Extracted, Any]] = { + val sym = TypeRepr.of[Extracted].typeSymbol // apparently each case field is duplicated: "a" and "a ", "_1" and "_1" o_0 - the first is method, the other val // the exceptions are cases in Scala 3 enum: they only have vals (sym.caseFields.filter(if (sym.flags.is(Flags.Enum)) _.isValDef else _.isDefDef) ++ sym.declaredMethods.filter( isJavaGetter )).map { method => val name = method.name - name -> ProductInData.Getter[Any]( - name = name, + name -> Getter[Extracted, Any]( tpe = returnType[Any](TypeRepr.of[In].memberType(method)), get = - if (method.paramSymss.isEmpty) (in: Expr[In]) => in.asTerm.select(method).appliedToArgss(Nil).asExpr - else (in: Expr[In]) => in.asTerm.select(method).appliedToNone.asExpr, - path = Path.Field(Path.Root, name) + if (method.paramSymss.isEmpty) (in: Expr[Extracted]) => in.asTerm.select(method).appliedToArgss(Nil).asExpr + else (in: Expr[Extracted]) => in.asTerm.select(method).appliedToNone.asExpr ) }.to(ListMap) + } + + final def extractProductInData(settings: Settings): DerivationResult[ProductInData] = + extractGetters[In] + .map { case (name, Getter(tpe, get)) => + name -> ProductInData.Getter(name, tpe, get, Path.Field(Path.Root, name)) + } .pipe(ProductInData(_)) .pipe(DerivationResult.pure) .logSuccess(data => s"Resolved input: $data") - } + + private def fallbackValueGetters(settings: Settings): Map[String, FieldFallback[?]] = + settings.fallbackValues + .collect { case ConfigEntry.AddFallbackValue(fallbackType, fallbackValue) => + extractGetters(fallbackType).map { case (name, Getter(fallbackFieldType, extractFallbackValue)) => + name -> FieldFallback.Value(fallbackFieldType, extractFallbackValue(fallbackValue)) + } + } + .fold(Map.empty[String, FieldFallback[?]])((left, right) => right ++ left) // preserve first rather than last final def extractProductOutData(settings: Settings): DerivationResult[ProductOutData] = if (isJavaBean[Out]) { @@ -106,6 +124,8 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] } )(DerivationError.MissingPublicConstructor) + val fallbackValues = fallbackValueGetters(settings) + val setters = sym.declaredMethods .filter(s => isJavaSetter(s) || isVar(s)) .map { setter => @@ -118,7 +138,13 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] }, set = (out: Expr[Out], value: Expr[Any]) => out.asTerm.select(setter).appliedTo(value.asTerm).asExpr.asExprOf[Unit], - fallback = FieldFallback.Unavailable // TODO: .addFallbackValue + fallback = fallbackValues + .collectFirst { + case (fallbackName, fallback) + if inputNameMatchesOutputName(name, fallbackName, settings.isFieldCaseInsensitive) => + fallback + } + .getOrElse(FieldFallback.Unavailable) ) } .to(ListMap) @@ -155,7 +181,7 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] pair <- resolveTypeArgsForMethodArguments(TypeRepr.of[Out], primaryConstructor) (typeByName, typeParams) = pair // default value for case class field n (1 indexed) is obtained from Companion.apply$default$n - fallbacks = primaryConstructor.paramSymss + fallbackValues = primaryConstructor.paramSymss .pipe(if (typeParams.nonEmpty) ps => ps.tail else ps => ps) .headOption .toList @@ -167,8 +193,7 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] val sym = mod.declaredMethod("apply$default$" + (idx + 1)).head param.name -> FieldFallback.Default(Ref(mod).select(sym).asExpr.asInstanceOf[Expr[Any]]) } - .toMap - .withDefaultValue(FieldFallback.Unavailable) + .toMap ++ fallbackValueGetters(settings) // we want defaults to be overridden by provided values } yield ProductOutData.CaseClass( params => New(TypeTree.of[Out]) @@ -180,10 +205,17 @@ private[internal] trait PlatformProductCaseGeneration[Pipe[_, _], In, Out] primaryConstructor.paramSymss.pipe(if (typeParams.nonEmpty) ps => ps.tail else ps => ps).map { params => params .map { param => + val name = param.name param.name -> ProductOutData.ConstructorParam( - name = param.name, - tpe = typeByName(param.name).asType.asInstanceOf[Type[Any]], - fallback = fallbacks(param.name) // TODO: .addFallbackValue + name = name, + tpe = typeByName(name).asType.asInstanceOf[Type[Any]], + fallback = fallbackValues + .collectFirst { + case (fallbackName, fallback) + if inputNameMatchesOutputName(name, fallbackName, settings.isFieldCaseInsensitive) => + fallback + } + .getOrElse(FieldFallback.Unavailable) ) } .to(ListMap) diff --git a/pipez/src/main/scala/pipez/internal/Definitions.scala b/pipez/src/main/scala/pipez/internal/Definitions.scala index 2d7fbd3..31bbfb8 100644 --- a/pipez/src/main/scala/pipez/internal/Definitions.scala +++ b/pipez/src/main/scala/pipez/internal/Definitions.scala @@ -147,6 +147,10 @@ private[internal] trait Definitions[Pipe[_, _], In, Out] { self => lazy val isEnumCaseInsensitive: Boolean = entries.contains(EnumCaseInsensitive) + lazy val fallbackValues: List[ConfigEntry.AddFallbackValue[?]] = entries.collect { + case config @ ConfigEntry.AddFallbackValue(_, _) => config + } + lazy val isFallbackToDefaultEnabled: Boolean = entries.contains(EnableFallbackToDefaults) lazy val isRecursiveDerivationEnabled: Boolean = entries.contains(EnableRecursiveDerivation) diff --git a/pipez/src/main/scala/pipez/internal/ProductCaseGeneration.scala b/pipez/src/main/scala/pipez/internal/ProductCaseGeneration.scala index 1922909..1f78bdb 100644 --- a/pipez/src/main/scala/pipez/internal/ProductCaseGeneration.scala +++ b/pipez/src/main/scala/pipez/internal/ProductCaseGeneration.scala @@ -39,7 +39,7 @@ private[internal] trait ProductCaseGeneration[Pipe[_, _], In, Out] { sealed trait FieldFallback[+OutField] extends Product with Serializable object FieldFallback { - final case class Value[OutField, Value](expr: Expr[Value], tpe: Type[Value]) extends FieldFallback[OutField] + final case class Value[OutField, Value](tpe: Type[Value], expr: Expr[Value]) extends FieldFallback[OutField] final case class Default[OutField](expr: Expr[OutField]) extends FieldFallback[OutField] case object Unavailable extends FieldFallback[Nothing] } @@ -525,7 +525,7 @@ object ProductCaseGeneration { } val isSetterName: String => Boolean = name => setAccessor.matches(name) - private def inputNameMatchesOutputName( + def inputNameMatchesOutputName( inFieldName: String, outFieldName: String, caseInsensitive: Boolean