Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement graphql-transport-ws protocol for websocket subscriptions (webmvc & webflux) #1200

Merged
merged 1 commit into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions .idea/codeStyles/Project.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21,586 changes: 21,527 additions & 59 deletions graphql-dgs-example-shared/ui-example/package-lock.json

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion graphql-dgs-example-shared/ui-example/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
"version": "1.0.0",
"private": true,
"dependencies": {
"@apollo/client": "^3.2.7",
"@apollo/client": "^3.5.10",
"@reach/router": "^1.2.1",
"@types/node": "^12.12.14",
"@types/reach__router": "^1.2.6",
"@types/react": "^16.9.15",
"@types/react-dom": "^16.9.4",
"emotion": "^9.2.12",
"graphql": "^14.4.2",
"graphql-ws": "^5.10.0",
"polished": "^3.4.1",
"react": "^16.12.0",
"react-dom": "^16.12.0",
Expand Down
13 changes: 7 additions & 6 deletions graphql-dgs-example-shared/ui-example/src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ import {
useSubscription
} from '@apollo/client';

import {WebSocketLink} from "@apollo/client/link/ws";
import { GraphQLWsLink } from "@apollo/client/link/subscriptions";
import { createClient } from 'graphql-ws';

const httpLink = createHttpLink({uri:'http://localhost:8080/graphql' })
const webSocketLink = new GraphQLWsLink(createClient({
url: 'ws://localhost:8080/subscriptions',
}));

const webSocketLink = new WebSocketLink({
uri: 'ws://localhost:8080/subscriptions'
});

const httpLink = createHttpLink({uri:'http://localhost:8080/graphql' })
const client: ApolloClient<NormalizedCacheObject> = new ApolloClient({
link: split((operation) => {
return operation.operationName === "StockWatch"
Expand Down Expand Up @@ -117,4 +118,4 @@ ReactDOM.render(
<App/>
</ApolloProvider>,
document.getElementById('root'),
);
);
1 change: 1 addition & 0 deletions graphql-dgs-spring-webflux-autoconfigure/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
dependencies {
api(project(":graphql-dgs"))
api(project(":graphql-dgs-reactive"))
api(project(":graphql-dgs-subscription-types"))

implementation("org.springframework.boot:spring-boot-starter")
implementation("org.springframework:spring-webflux")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ open class DgsWebFluxAutoConfiguration(private val configProps: DgsWebfluxConfig
}

@Bean
open fun websocketSubscriptionHandler(dgsReactiveQueryExecutor: DgsReactiveQueryExecutor): SimpleUrlHandlerMapping {
open fun websocketSubscriptionHandler(dgsReactiveQueryExecutor: DgsReactiveQueryExecutor, webfluxConfigurationProperties: DgsWebfluxConfigurationProperties): SimpleUrlHandlerMapping {
val simpleUrlHandlerMapping =
SimpleUrlHandlerMapping(mapOf("/subscriptions" to DgsReactiveWebsocketHandler(dgsReactiveQueryExecutor)))
SimpleUrlHandlerMapping(mapOf("/subscriptions" to DgsReactiveWebsocketHandler(dgsReactiveQueryExecutor, webfluxConfigurationProperties.websocket.connectionInitTimeout)))
simpleUrlHandlerMapping.order = 1
return simpleUrlHandlerMapping
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,30 @@ import org.springframework.boot.context.properties.ConfigurationProperties
import org.springframework.boot.context.properties.ConstructorBinding
import org.springframework.boot.context.properties.NestedConfigurationProperty
import org.springframework.boot.context.properties.bind.DefaultValue
import java.time.Duration
import javax.annotation.PostConstruct

@ConstructorBinding
@ConfigurationProperties(prefix = "dgs.graphql")
@Suppress("ConfigurationProperties")
class DgsWebfluxConfigurationProperties(
/** Websocket configuration. */
@NestedConfigurationProperty var websocket: DgsWebsocketConfigurationProperties = DgsWebsocketConfigurationProperties(
DEFAULT_CONNECTION_INIT_TIMEOUT_DURATION
),
/** Path to the endpoint that will serve GraphQL requests. */
@DefaultValue("/graphql") var path: String = "/graphql",
@NestedConfigurationProperty var graphiql: DgsGraphiQLConfigurationProperties = DgsGraphiQLConfigurationProperties(),
@NestedConfigurationProperty var schemaJson: DgsSchemaJsonConfigurationProperties = DgsSchemaJsonConfigurationProperties()
) {
/**
* Configuration properties for websockets.
*/
data class DgsWebsocketConfigurationProperties(
/** Connection Initialization timeout for graphql-transport-ws. */
@DefaultValue(DEFAULT_CONNECTION_INIT_TIMEOUT) var connectionInitTimeout: Duration
)

/**
* Configuration properties for the GraphiQL endpoint.
*/
Expand Down Expand Up @@ -60,4 +73,9 @@ class DgsWebfluxConfigurationProperties(
throw IllegalArgumentException("$pathProperty must start with '/' and not end with '/' but was '$path'")
}
}

companion object {
const val DEFAULT_CONNECTION_INIT_TIMEOUT = "10s"
val DEFAULT_CONNECTION_INIT_TIMEOUT_DURATION: Duration = Duration.ofSeconds(10)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,165 +16,27 @@

package com.netflix.graphql.dgs.webflux.handlers

import com.fasterxml.jackson.annotation.JsonProperty
import com.fasterxml.jackson.module.kotlin.convertValue
import com.netflix.graphql.dgs.reactive.DgsReactiveQueryExecutor
import graphql.ExecutionResult
import org.reactivestreams.Publisher
import org.reactivestreams.Subscription
import org.slf4j.LoggerFactory
import org.springframework.core.ResolvableType
import org.springframework.core.io.buffer.DataBuffer
import org.springframework.core.io.buffer.DataBufferUtils
import org.springframework.http.codec.json.Jackson2JsonDecoder
import org.springframework.http.codec.json.Jackson2JsonEncoder
import org.springframework.util.MimeTypeUtils
import com.netflix.graphql.types.subscription.GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL
import com.netflix.graphql.types.subscription.GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL
import org.springframework.web.reactive.socket.WebSocketHandler
import org.springframework.web.reactive.socket.WebSocketMessage
import org.springframework.web.reactive.socket.WebSocketSession
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import java.util.concurrent.ConcurrentHashMap
import java.time.Duration

class DgsReactiveWebsocketHandler(private val dgsReactiveQueryExecutor: DgsReactiveQueryExecutor) : WebSocketHandler {
class DgsReactiveWebsocketHandler(dgsReactiveQueryExecutor: DgsReactiveQueryExecutor, connectionInitTimeout: Duration) : WebSocketHandler {

private val resolvableType = ResolvableType.forType(OperationMessage::class.java)
private val subscriptions = ConcurrentHashMap<String, MutableMap<String, Subscription>>()
private val decoder = Jackson2JsonDecoder()
private val encoder = Jackson2JsonEncoder(decoder.objectMapper)

override fun getSubProtocols(): List<String> = listOf("graphql-ws")
private val graphqlWSHandler = WebsocketGraphQLWSProtocolHandler(dgsReactiveQueryExecutor)
private val graphqlTransportWSHandler = WebsocketGraphQLTransportWSProtocolHandler(dgsReactiveQueryExecutor, connectionInitTimeout)
override fun getSubProtocols(): List<String> = listOf(GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL, GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL)

override fun handle(webSocketSession: WebSocketSession): Mono<Void> {
return webSocketSession.send(
webSocketSession.receive()
.flatMap { message ->
val buffer: DataBuffer = DataBufferUtils.retain(message.payload)

val operationMessage: OperationMessage = decoder.decode(
buffer,
resolvableType,
MimeTypeUtils.APPLICATION_JSON,
null
) as OperationMessage

when (operationMessage.type) {
GQL_CONNECTION_INIT -> Flux.just(
toWebsocketMessage(
OperationMessage(GQL_CONNECTION_ACK), webSocketSession
)
)
GQL_START -> {
val queryPayload = decoder.objectMapper.convertValue<QueryPayload>(
operationMessage.payload ?: error("payload == null")
)
logger.debug("Starting subscription {} for session {}", queryPayload, webSocketSession.id)
dgsReactiveQueryExecutor.execute(queryPayload.query, queryPayload.variables)
.flatMapMany { executionResult ->
val publisher: Publisher<ExecutionResult> = executionResult.getData()
Flux.from(publisher).map { executionResult ->
toWebsocketMessage(
OperationMessage(GQL_DATA, DataPayload(data = executionResult.getData(), errors = executionResult.errors), operationMessage.id),
webSocketSession
)
}.doOnSubscribe {
if (operationMessage.id != null) {
subscriptions[webSocketSession.id] = mutableMapOf(operationMessage.id to it)
}
}.doOnComplete {
webSocketSession.send(
Flux.just(
toWebsocketMessage(
OperationMessage(GQL_COMPLETE, null, operationMessage.id),
webSocketSession
)
)
).subscribe()

subscriptions[webSocketSession.id]?.remove(operationMessage.id)
logger.debug(
"Completing subscription {} for connection {}",
operationMessage.id, webSocketSession.id
)
}.doOnError {
webSocketSession.send(
Flux.just(
toWebsocketMessage(
OperationMessage(GQL_ERROR, DataPayload(null, listOf(it.message!!)), operationMessage.id),
webSocketSession
)
)
).subscribe()

subscriptions[webSocketSession.id]?.remove(operationMessage.id)
logger.debug(
"Subscription publisher error for input {} for subscription {} for connection {}",
queryPayload, operationMessage.id, webSocketSession.id, it
)
}
}
}

GQL_STOP -> {
subscriptions[webSocketSession.id]?.remove(operationMessage.id)?.cancel()
logger.debug(
"Client stopped subscription {} for connection {}",
operationMessage.id, webSocketSession.id
)
Flux.empty()
}

GQL_CONNECTION_TERMINATE -> {
subscriptions[webSocketSession.id]?.values?.forEach { it.cancel() }
subscriptions.remove(webSocketSession.id)
webSocketSession.close()
logger.debug("Connection {} terminated", webSocketSession.id)
Flux.empty()
}
if (webSocketSession.handshakeInfo.subProtocol.equals(GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL, ignoreCase = true)) {
return graphqlWSHandler.handle(webSocketSession)
} else if (webSocketSession.handshakeInfo.subProtocol.equals(GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL, ignoreCase = true)) {
return graphqlTransportWSHandler.handle(webSocketSession)
}

else -> Flux.empty()
}
}
)
}

private fun toWebsocketMessage(operationMessage: OperationMessage, session: WebSocketSession): WebSocketMessage {
return WebSocketMessage(
WebSocketMessage.Type.TEXT,
encoder.encodeValue(
operationMessage,
session.bufferFactory(),
resolvableType,
MimeTypeUtils.APPLICATION_JSON,
null
)
)
}

companion object {
private val logger = LoggerFactory.getLogger(DgsReactiveQueryExecutor::class.java)

const val GQL_CONNECTION_INIT = "connection_init"
const val GQL_CONNECTION_ACK = "connection_ack"
const val GQL_START = "start"
const val GQL_STOP = "stop"
const val GQL_DATA = "data"
const val GQL_ERROR = "error"
const val GQL_COMPLETE = "complete"
const val GQL_CONNECTION_TERMINATE = "connection_terminate"
return Mono.empty()
}
}

data class DataPayload(val data: Any?, val errors: List<Any>? = emptyList())
data class OperationMessage(
@JsonProperty("type") val type: String,
@JsonProperty("payload") val payload: Any? = null,
@JsonProperty("id", required = false) val id: String? = ""
)

data class QueryPayload(
@JsonProperty("variables") val variables: Map<String, Any> = emptyMap(),
@JsonProperty("extensions") val extensions: Map<String, Any> = emptyMap(),
@JsonProperty("operationName") val operationName: String?,
@JsonProperty("query") val query: String
)