Skip to content

Commit

Permalink
Don't automatically add key fields to union selections
Browse files Browse the repository at this point in the history
  • Loading branch information
BoD committed Jan 24, 2024
1 parent 6efc382 commit 65ab702
Show file tree
Hide file tree
Showing 56 changed files with 2,221 additions and 249 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.apollographql.apollo3.ast.GQLFragmentSpread
import com.apollographql.apollo3.ast.GQLInlineFragment
import com.apollographql.apollo3.ast.GQLOperationDefinition
import com.apollographql.apollo3.ast.GQLSelection
import com.apollographql.apollo3.ast.GQLUnionTypeDefinition
import com.apollographql.apollo3.ast.Schema
import com.apollographql.apollo3.ast.definitionFromScope
import com.apollographql.apollo3.ast.isAbstract
Expand Down Expand Up @@ -45,6 +46,7 @@ private fun List<GQLSelection>.isPolymorphic(schema: Schema, fragments: Map<Stri
val tc = it.typeCondition?.name ?: rootType
!schema.isTypeASuperTypeOf(tc, rootType) || it.selections.isPolymorphic(schema, fragments, rootType)
}

is GQLFragmentSpread -> {
val fragmentDefinition = fragments[it.name] ?: error("cannot find fragment ${it.name}")
/**
Expand Down Expand Up @@ -75,16 +77,22 @@ private fun List<GQLSelection>.addRequiredFields(

val selectionSet = this

val requiresTypename = when(addTypename) {
val requiresTypename = when (addTypename) {
"ifPolymorphic" -> isRoot && isPolymorphic(schema, fragments, parentType)
"ifFragments" -> {
selectionSet.any { it is GQLFragmentSpread || it is GQLInlineFragment }
}

"ifAbstract" -> isRoot && schema.typeDefinition(parentType).isAbstract()
"always" -> isRoot
else -> error("Unknown addTypename option: $addTypename")
}
val requiredFieldNames = schema.keyFields(parentType).toMutableSet()
val requiredFieldNames = if (schema.typeDefinition(parentType) is GQLUnionTypeDefinition) {
// Can't select any fields on unions other than __typename
mutableSetOf()
} else {
schema.keyFields(parentType).toMutableSet()
}

if (requiredFieldNames.isNotEmpty() || requiresTypename) {
requiredFieldNames.add("__typename")
Expand All @@ -106,6 +114,7 @@ private fun List<GQLSelection>.addRequiredFields(
)
)
}

is GQLFragmentSpread -> it
is GQLField -> it.addRequiredFields(schema, addTypename, fragments, parentType)
}
Expand Down Expand Up @@ -142,7 +151,12 @@ private fun List<GQLSelection>.addRequiredFields(
return newSelections
}

private fun GQLField.addRequiredFields(schema: Schema, addTypename: String, fragments: Map<String, GQLFragmentDefinition>, parentType: String): GQLField {
private fun GQLField.addRequiredFields(
schema: Schema,
addTypename: String,
fragments: Map<String, GQLFragmentDefinition>,
parentType: String,
): GQLField {
val typeDefinition = definitionFromScope(schema, parentType)!!
val newSelectionSet = selections.addRequiredFields(
schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.apollographql.apollo3.ast.GQLObjectTypeDefinition
import com.apollographql.apollo3.ast.GQLOperationDefinition
import com.apollographql.apollo3.ast.GQLSelection
import com.apollographql.apollo3.ast.GQLStringValue
import com.apollographql.apollo3.ast.GQLUnionTypeDefinition
import com.apollographql.apollo3.ast.Schema
import com.apollographql.apollo3.ast.definitionFromScope
import com.apollographql.apollo3.ast.rawType
Expand All @@ -24,8 +25,15 @@ private class CheckKeyFieldsScope(
}

private val keyFieldsCache = mutableMapOf<String, Set<String>>()
fun keyFields(name: String) = keyFieldsCache.getOrPut(name) {
schema.keyFields(name)
fun keyFields(name: String, parentName: String): Set<String> {
return keyFieldsCache.getOrPut("$parentName/$name") {
schema.keyFields(name) + if (schema.typeDefinition(parentName) is GQLUnionTypeDefinition) {
// If parent is an union, need to check the presence of its key fields too, as these can't be added automatically
schema.keyFields(parentName)
} else {
emptySet()
}
}
}
}

Expand Down Expand Up @@ -64,7 +72,7 @@ private fun CheckKeyFieldsScope.checkFieldSet(path: String, selections: List<GQL
val fieldNames = mergedFields.map { it.first().field }
.filter { it.alias == null }
.map { it.name }.toSet()
val keyFieldNames = keyFields(possibleType)
val keyFieldNames = keyFields(possibleType, parentType)

val missingFieldNames = keyFieldNames.subtract(fieldNames)
check(missingFieldNames.isEmpty()) {
Expand Down Expand Up @@ -101,13 +109,15 @@ private fun CheckKeyFieldsScope.collectFields(

listOf(FieldWithParent(it, parentType))
}

is GQLInlineFragment -> {
if (it.directives.hasCondition()) {
return@flatMap emptyList()
}

collectFields(it.selections, it.typeCondition?.name ?: parentType, implementedTypes)
}

is GQLFragmentSpread -> {
if (it.directives.hasCondition()) {
return@flatMap emptyList()
Expand Down

0 comments on commit 65ab702

Please sign in to comment.