Skip to content

Commit

Permalink
#47 Support configurable discriminator for sum ADTs (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakipatryk committed Apr 10, 2024
1 parent c8ce0cb commit c4cc3e4
Show file tree
Hide file tree
Showing 3 changed files with 439 additions and 6 deletions.
33 changes: 32 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This library aims to avoid pollution of the model by custom annotations and depe
- support for basic Scala collections (`Map`, `Seq`, `Set`, `Array`) as types of `case class` parameters
- only top-level case classes need to be registered, child case classes are then recursively registered
- support for Scala `Enumeration` where simple `Value` constructor is used (without `name`)
- support for sum ADTs (`sealed trait` and `sealed abstract class`)
- support for sum ADTs (`sealed trait` and `sealed abstract class`) with optional discriminator

## Usage

Expand Down Expand Up @@ -215,6 +215,37 @@ Then, in `handleFn`, the handler creates a `Schema` object for `CustomClass`,
adds it to `Components` so that it can be referenced by name `CustomClass`,
and returns reference to that object.

### Registration configuration
It is possible to further customize registration by providing custom `RegistrationConfig` to `OpenAPIModelRegistration`.

#### Example
```scala
val components = ...
val registration = OpenAPIModelRegistration(
components,
config = RegistrationConfig(
OpenAPIModelRegistration.RegistrationConfig(
sumADTsShape =
// default values apply for discriminatorPropertyNameFn, addDiscriminatorPropertyOnlyToDirectChildren
OpenAPIModelRegistration.RegistrationConfig.SumADTsShape.WithDiscriminator()
)
)
)
```

#### sumADTsShape
This config property sets how sum ADTs are registered. It has two possible values:
- `RegistrationConfig.SumADTsShape.WithoutDiscriminator` - default option, doesn't add discriminators
- `RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addDiscriminatorPropertyOnlyToDirectChildren)` -
adds discriminator to sealed types schema,
and also adds discriminator to sum ADTs elements properties; discriminator property name is customizable by `discriminatorPropertyNameFn`,
by default it takes sealed type name, converts its first letter to lower case, and adds `"Type"` suffix,
for example if sealed type name is `Expression`, the property name is `expressionType`;
if `addDiscriminatorPropertyOnlyToDirectChildren` is `false`, discriminator property is added to all children,
so for example in `ADT = A | B | C; B = D | E` discriminator of `ADT` would be added to `A`, `C`, `D`, `E`
(`D` and `E` would have discriminator of `B` in addition to that)
while with `addDiscriminatorPropertyOnlyToDirectChildren` set to `true` (default)
it would be added only to `A` and `C`

## Examples

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@
package za.co.absa.springdocopenapiscala

import io.swagger.v3.oas.models.Components
import io.swagger.v3.oas.models.media.Schema
import io.swagger.v3.oas.models.media.{Discriminator, Schema}

import java.time.{Instant, LocalDate, LocalDateTime, ZonedDateTime}
import java.util.UUID
import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe._

import OpenAPIModelRegistration._

class OpenAPIModelRegistration(
components: Components,
extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling
extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling,
config: RegistrationConfig = RegistrationConfig()
) {

/**
Expand Down Expand Up @@ -144,13 +144,60 @@ class OpenAPIModelRegistration(
s.isTerm && s.asTerm.isVal && s.typeSignature <:< typeOf[Enumeration#Value]

private def handleSealedType(tpe: Type): Schema[_] = {

def addDiscriminatorPropertyToChildren(
currentSchema: Schema[_],
discriminatorPropertyName: String,
addOnlyToDirectChildren: Boolean,
discriminatorValue: Option[String] = None
): Unit = {
val children = currentSchema.getOneOf.asScala
children.foreach { s =>
val ref = s.get$ref
val name = extractSchemaNameFromRef(ref)
val actualSchema = components.getSchemas.get(name)
if (actualSchema.getType == "object") {
val constEnumSchema = createConstEnumSchema(discriminatorValue.getOrElse(name))
actualSchema.addProperty(discriminatorPropertyName, constEnumSchema)
actualSchema.addRequiredItem(discriminatorPropertyName)
} else if (
!addOnlyToDirectChildren &&
Option(actualSchema.getOneOf).map(!_.isEmpty).getOrElse(false) // is schema representing another sum ADT root
) {
addDiscriminatorPropertyToChildren(
actualSchema,
discriminatorPropertyName,
addOnlyToDirectChildren,
Some(name)
)
}
}
}

val classSymbol = tpe.typeSymbol.asClass
val name = tpe.typeSymbol.name.toString.trim
val children = classSymbol.knownDirectSubclasses
val childrenSchemas = children.map(_.asType.toType).map(handleType)
val schema = new Schema
schema.setOneOf(childrenSchemas.toList.asJava)

config.sumADTsShape match {
case RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addOnlyToDirectChildren) =>
val discriminatorPropertyName = discriminatorPropertyNameFn(name)
schema.setDiscriminator {
val discriminator = new Discriminator
discriminator.setPropertyName(discriminatorPropertyName)
discriminator
}
addDiscriminatorPropertyToChildren(
schema,
discriminatorPropertyName,
addOnlyToDirectChildren
)

case _ => ()
}

registerAsReference(name, schema)
}

Expand Down Expand Up @@ -187,10 +234,54 @@ class OpenAPIModelRegistration(
schemaReference
}

private def createConstEnumSchema(const: String): Schema[_] = {
val constEnumSchema = new Schema[String]
constEnumSchema.setType("string")
constEnumSchema.setEnum(Seq(const).asJava)
constEnumSchema
}

private def extractSchemaNameFromRef(ref: String): String = {
ref.substring(ref.lastIndexOf("/") + 1)
}

}

object OpenAPIModelRegistration {

/**
* Configuration of the registration class.
*
* @param sumADTsShape how sum ADTs should be registered (with or without discriminator)
*/
case class RegistrationConfig(
sumADTsShape: RegistrationConfig.SumADTsShape = RegistrationConfig.SumADTsShape.WithoutDiscriminator
)

object RegistrationConfig {

sealed abstract class SumADTsShape

object SumADTsShape {
case object WithoutDiscriminator extends SumADTsShape
case class WithDiscriminator(
discriminatorPropertyNameFn: WithDiscriminator.DiscriminatorPropertyNameFn =
WithDiscriminator.defaultDiscriminatorPropertyNameFn,
addDiscriminatorPropertyOnlyToDirectChildren: Boolean = true
) extends SumADTsShape

object WithDiscriminator {

/** Function from sealed type name to discriminator property name. */
type DiscriminatorPropertyNameFn = String => String

val defaultDiscriminatorPropertyNameFn: DiscriminatorPropertyNameFn = sealedTypeName =>
sealedTypeName.head.toLower + sealedTypeName.tail + "Type"
}
}

}

/**
* Context of model registration.
* Currently contains only `Components` that can be mutated if needed
Expand Down

0 comments on commit c4cc3e4

Please sign in to comment.