Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature to generate GraphQL enum types as sealed classes #2279

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -37,7 +37,8 @@ class GraphQLCompiler {
useSemanticNaming = args.useSemanticNaming,
packageNameProvider = args.packageNameProvider,
generateAsInternal = args.generateAsInternal,
kotlinMultiPlatformProject = args.kotlinMultiPlatformProject
kotlinMultiPlatformProject = args.kotlinMultiPlatformProject,
enumAsSealedClassPatternFilters = args.enumAsSealedClassPatternFilters.map { it.toRegex() }
).write(args.outputDir)
} else {
ir.writeJavaFiles(
Expand Down Expand Up @@ -121,7 +122,8 @@ class GraphQLCompiler {
val generateKotlinModels: Boolean = false,
val operationOutputFile: File? = null,
val generateAsInternal: Boolean = false,

// only if generateKotlinModels = true
val enumAsSealedClassPatternFilters: List<String>,
// only if generateKotlinModels = false
val nullableValueType: NullableValueType,
// only if generateKotlinModels = false
Expand Down
@@ -1,7 +1,11 @@
package com.apollographql.apollo.compiler.ast.builder

import com.apollographql.apollo.compiler.OperationIdGenerator
import com.apollographql.apollo.compiler.ast.*
import com.apollographql.apollo.compiler.ast.CustomTypes
import com.apollographql.apollo.compiler.ast.EnumType
import com.apollographql.apollo.compiler.ast.FieldType
import com.apollographql.apollo.compiler.ast.Schema
import com.apollographql.apollo.compiler.ast.TypeRef
import com.apollographql.apollo.compiler.escapeKotlinReservedWord
import com.apollographql.apollo.compiler.ir.CodeGenerationIR
import com.apollographql.apollo.compiler.ir.ScalarType
Expand Down
Expand Up @@ -4,19 +4,32 @@ import com.apollographql.apollo.compiler.applyIf
import com.apollographql.apollo.compiler.ast.EnumType
import com.squareup.kotlinpoet.*

internal fun EnumType.typeSpec(generateAsInternal: Boolean = false) =
TypeSpec
.enumBuilder(name)
.applyIf(description.isNotBlank()) { addKdoc("%L\n", description) }
.applyIf(generateAsInternal) { addModifiers(KModifier.INTERNAL) }
.primaryConstructor(primaryConstructorSpec)
.addProperty(rawValuePropertySpec)
.apply {
values.forEach { value -> addEnumConstant(value.constName, value.enumConstTypeSpec) }
addEnumConstant("UNKNOWN__", unknownEnumConstTypeSpec)
}
.addType(companionObjectSpec)
.build()
internal fun EnumType.typeSpec(
generateAsInternal: Boolean = false,
enumAsSealedClassPatternFilters: List<Regex>
): TypeSpec {
val asSealedClass = enumAsSealedClassPatternFilters.isNotEmpty() && enumAsSealedClassPatternFilters.any { pattern ->
martinbonnin marked this conversation as resolved.
Show resolved Hide resolved
name.matches(pattern)
}

return if (asSealedClass) toSealedClassTypeSpec(generateAsInternal)
else toEnumTypeSpec(generateAsInternal)
}

private fun EnumType.toEnumTypeSpec(generateAsInternal: Boolean): TypeSpec {
return TypeSpec
.enumBuilder(name)
.applyIf(description.isNotBlank()) { addKdoc("%L\n", description) }
.applyIf(generateAsInternal) { addModifiers(KModifier.INTERNAL) }
.primaryConstructor(primaryConstructorSpec)
.addProperty(rawValuePropertySpec)
.apply {
values.forEach { value -> addEnumConstant(value.constName, value.enumConstTypeSpec) }
addEnumConstant("UNKNOWN__", unknownEnumConstTypeSpec)
}
.addType(enumCompanionObjectSpec)
.build()
}

private val primaryConstructorSpec =
FunSpec
Expand All @@ -40,22 +53,24 @@ private val EnumType.Value.enumConstTypeSpec: TypeSpec
.build()
}

private val unknownEnumConstTypeSpec: TypeSpec =
TypeSpec
private val unknownEnumConstTypeSpec: TypeSpec
get() {
return TypeSpec
.anonymousClassBuilder()
.addKdoc("%L", "Auto generated constant for unknown enum values\n")
.addSuperclassConstructorParameter("%S", "UNKNOWN__")
.build()
}

private val EnumType.companionObjectSpec: TypeSpec
private val EnumType.enumCompanionObjectSpec: TypeSpec
get() {
return TypeSpec
.companionObjectBuilder()
.addFunction(safeValueOfFunSpec)
.addFunction(enumSafeValueOfFunSpec)
.build()
}

private val EnumType.safeValueOfFunSpec: FunSpec
private val EnumType.enumSafeValueOfFunSpec: FunSpec
get() {
return FunSpec
.builder("safeValueOf")
Expand All @@ -64,3 +79,62 @@ private val EnumType.safeValueOfFunSpec: FunSpec
.addStatement("return values().find·{·it.rawValue·==·rawValue·} ?: UNKNOWN__")
.build()
}

private fun EnumType.toSealedClassTypeSpec(generateAsInternal: Boolean): TypeSpec {
return TypeSpec
.classBuilder(name)
.applyIf(description.isNotBlank()) { addKdoc("%L\n", description) }
.applyIf(generateAsInternal) { addModifiers(KModifier.INTERNAL) }
.addModifiers(KModifier.SEALED)
.primaryConstructor(primaryConstructorSpec)
.addProperty(rawValuePropertySpec)
.addTypes(values.map { value -> value.toObjectTypeSpec(ClassName("", name)) })
.addType(unknownValueTypeSpec)
.addType(sealedClassCompanionObjectSpec)
.build()
}

private fun EnumType.Value.toObjectTypeSpec(superClass: TypeName): TypeSpec {
return TypeSpec.objectBuilder(constName)
.applyIf(description.isNotBlank()) { addKdoc("%L\n", description) }
.applyIf(isDeprecated) { addAnnotation(KotlinCodeGen.deprecatedAnnotation(deprecationReason)) }
.superclass(superClass)
.addSuperclassConstructorParameter("rawValue = %S", value)
.build()
}

private val EnumType.unknownValueTypeSpec: TypeSpec
get() {
return TypeSpec.classBuilder("UNKNOWN__")
.addKdoc("%L", "Auto generated constant for unknown enum values\n")
.primaryConstructor(primaryConstructorSpec)
.superclass(ClassName("", name))
.addSuperclassConstructorParameter("rawValue = rawValue")
.build()
}

private val EnumType.sealedClassCompanionObjectSpec: TypeSpec
get() {
return TypeSpec
.companionObjectBuilder()
.addFunction(sealedClassSafeValueOfFunSpec)
.build()
}

private val EnumType.sealedClassSafeValueOfFunSpec: FunSpec
get() {
val returnClassName = ClassName("", name)
return FunSpec
.builder("safeValueOf")
.addParameter("rawValue", String::class)
.returns(returnClassName)
.beginControlFlow("return when(rawValue)")
.addCode(
values
.map { CodeBlock.of("%S -> %L", it.value, it.constName) }
.joinToCode(separator = "\n", suffix = "\n")
)
.addCode("else -> UNKNOWN__(rawValue)\n")
.endControlFlow()
.build()
}
Expand Up @@ -17,7 +17,8 @@ class GraphQLKompiler(
private val useSemanticNaming: Boolean,
private val generateAsInternal: Boolean = false,
private val operationIdGenerator: OperationIdGenerator,
private val kotlinMultiPlatformProject: Boolean
private val kotlinMultiPlatformProject: Boolean,
private val enumAsSealedClassPatternFilters: List<Regex>
) {
fun write(outputDir: File) {
val customTypeMap = customTypeMap.supportedCustomTypes(ir.typesUsed)
Expand All @@ -28,11 +29,11 @@ class GraphQLKompiler(
useSemanticNaming = useSemanticNaming,
operationIdGenerator = operationIdGenerator
)

val schemaCodegen = SchemaCodegen(
packageNameProvider = packageNameProvider,
generateAsInternal = generateAsInternal,
kotlinMultiPlatformProject = kotlinMultiPlatformProject
kotlinMultiPlatformProject = kotlinMultiPlatformProject,
enumAsSealedClassPatternFilters = enumAsSealedClassPatternFilters
)
schemaCodegen.apply(schema::accept).writeTo(outputDir)
}
Expand Down
@@ -1,7 +1,12 @@
package com.apollographql.apollo.compiler.codegen.kotlin

import com.apollographql.apollo.compiler.PackageNameProvider
import com.apollographql.apollo.compiler.ast.*
import com.apollographql.apollo.compiler.ast.CustomTypes
import com.apollographql.apollo.compiler.ast.EnumType
import com.apollographql.apollo.compiler.ast.InputType
import com.apollographql.apollo.compiler.ast.ObjectType
import com.apollographql.apollo.compiler.ast.OperationType
import com.apollographql.apollo.compiler.ast.SchemaVisitor
import com.apollographql.apollo.compiler.codegen.kotlin.KotlinCodeGen.patchKotlinNativeOptionalArrayProperties
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.TypeSpec
Expand All @@ -10,7 +15,8 @@ import java.io.File
internal class SchemaCodegen(
private val packageNameProvider: PackageNameProvider,
private val generateAsInternal: Boolean = false,
private val kotlinMultiPlatformProject: Boolean
private val kotlinMultiPlatformProject: Boolean,
private val enumAsSealedClassPatternFilters: List<Regex>
) : SchemaVisitor {
private var fileSpecs: List<FileSpec> = emptyList()

Expand All @@ -19,7 +25,10 @@ internal class SchemaCodegen(
}

override fun visit(enumType: EnumType) {
fileSpecs = fileSpecs + enumType.typeSpec(generateAsInternal).fileSpec(packageNameProvider.typesPackageName)
fileSpecs = fileSpecs + enumType.typeSpec(
generateAsInternal = generateAsInternal,
enumAsSealedClassPatternFilters = enumAsSealedClassPatternFilters
).fileSpec(packageNameProvider.typesPackageName)
}

override fun visit(inputType: InputType) {
Expand Down
Expand Up @@ -11,43 +11,51 @@ import kotlin.String
/**
* The episodes in the Star Wars trilogy
*/
enum class Episode(
sealed class Episode(
val rawValue: String
) {
/**
* Star Wars Episode IV: A New Hope, released in 1977.
*/
NEWHOPE("NEWHOPE"),
object NEWHOPE : Episode(rawValue = "NEWHOPE")

/**
* Star Wars Episode V: The Empire Strikes Back, released in 1980.
*/
EMPIRE("EMPIRE"),
object EMPIRE : Episode(rawValue = "EMPIRE")

/**
* Star Wars Episode VI: Return of the Jedi, released in 1983.
*/
JEDI("JEDI"),
object JEDI : Episode(rawValue = "JEDI")

/**
* Test deprecated enum value
*/
@Deprecated(message = "For test purpose only")
DEPRECATED("DEPRECATED"),
object DEPRECATED : Episode(rawValue = "DEPRECATED")

/**
* Test java reserved word
*/
@Deprecated(message = "For test purpose only")
NEW("new"),
object NEW : Episode(rawValue = "new")

/**
* Auto generated constant for unknown enum values
*/
UNKNOWN__("UNKNOWN__");
class UNKNOWN__(
rawValue: String
) : Episode(rawValue = rawValue)

companion object {
fun safeValueOf(rawValue: String): Episode = values().find { it.rawValue == rawValue } ?:
UNKNOWN__
fun safeValueOf(rawValue: String): Episode = when(rawValue) {
"NEWHOPE" -> NEWHOPE
"EMPIRE" -> EMPIRE
"JEDI" -> JEDI
"DEPRECATED" -> DEPRECATED
"new" -> NEW
else -> UNKNOWN__(rawValue)
}
}
}
Expand Up @@ -11,43 +11,51 @@ import kotlin.String
/**
* The episodes in the Star Wars trilogy
*/
enum class Episode(
sealed class Episode(
val rawValue: String
) {
/**
* Star Wars Episode IV: A New Hope, released in 1977.
*/
NEWHOPE("NEWHOPE"),
object NEWHOPE : Episode(rawValue = "NEWHOPE")

/**
* Star Wars Episode V: The Empire Strikes Back, released in 1980.
*/
EMPIRE("EMPIRE"),
object EMPIRE : Episode(rawValue = "EMPIRE")

/**
* Star Wars Episode VI: Return of the Jedi, released in 1983.
*/
JEDI("JEDI"),
object JEDI : Episode(rawValue = "JEDI")

/**
* Test deprecated enum value
*/
@Deprecated(message = "For test purpose only")
DEPRECATED("DEPRECATED"),
object DEPRECATED : Episode(rawValue = "DEPRECATED")

/**
* Test java reserved word
*/
@Deprecated(message = "For test purpose only")
NEW("new"),
object NEW : Episode(rawValue = "new")

/**
* Auto generated constant for unknown enum values
*/
UNKNOWN__("UNKNOWN__");
class UNKNOWN__(
rawValue: String
) : Episode(rawValue = rawValue)

companion object {
fun safeValueOf(rawValue: String): Episode = values().find { it.rawValue == rawValue } ?:
UNKNOWN__
fun safeValueOf(rawValue: String): Episode = when(rawValue) {
"NEWHOPE" -> NEWHOPE
"EMPIRE" -> EMPIRE
"JEDI" -> JEDI
"DEPRECATED" -> DEPRECATED
"new" -> NEW
else -> UNKNOWN__(rawValue)
}
}
}