diff --git a/src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt b/src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt index 712bc66..03ccf5a 100644 --- a/src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt +++ b/src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt @@ -27,6 +27,7 @@ import io.netty.handler.codec.http.HttpRequest import io.netty.handler.codec.http.LastHttpContent import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory import net.ccbluex.netty.http.HttpServer.Companion.logger +import net.ccbluex.netty.http.middleware.Middleware import net.ccbluex.netty.http.model.RequestContext import net.ccbluex.netty.http.websocket.WebSocketHandler import java.net.URLDecoder @@ -66,6 +67,11 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun if (connection.equals("Upgrade", ignoreCase = true) && upgrade.equals("WebSocket", ignoreCase = true)) { + + if (server.middlewares.any { + it is Middleware.OnWebSocketUpgrade && !it.invoke(ctx, msg) + }) return + // Takes out Http Request Handler from the pipeline and replaces it with WebSocketHandler ctx.pipeline().replace(this, "websocketHandler", WebSocketHandler(server)) @@ -90,6 +96,10 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun msg.headers().associate { it.key to it.value }, ) + if (server.middlewares.any { + it is Middleware.OnRequestStart && !it.invoke(ctx, msg, requestContext) + }) return + localRequestContext.set(requestContext) } } @@ -109,9 +119,13 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun if (msg is LastHttpContent) { localRequestContext.remove() - val response = server.processRequestContext(requestContext) - val httpResponse = server.middlewares.fold(response) { acc, f -> f(requestContext, acc) } - ctx.writeAndFlush(httpResponse) + var response = server.processRequestContext(requestContext) + server.middlewares.forEach { + if (it is Middleware.OnFullHttpResponse) { + response = it.invoke(requestContext, response) + } + } + ctx.writeAndFlush(response) } } diff --git a/src/main/kotlin/net/ccbluex/netty/http/middleware/CorsMiddleware.kt b/src/main/kotlin/net/ccbluex/netty/http/middleware/CorsMiddleware.kt index 1b0c036..42e52ca 100644 --- a/src/main/kotlin/net/ccbluex/netty/http/middleware/CorsMiddleware.kt +++ b/src/main/kotlin/net/ccbluex/netty/http/middleware/CorsMiddleware.kt @@ -24,7 +24,7 @@ class CorsMiddleware( listOf("GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"), private val allowedHeaders: List = listOf("Content-Type", "Content-Length", "Authorization", "Accept", "X-Requested-With") -): Middleware { +): Middleware.OnFullHttpResponse { /** * Middleware to handle CORS requests. diff --git a/src/main/kotlin/net/ccbluex/netty/http/middleware/Middleware.kt b/src/main/kotlin/net/ccbluex/netty/http/middleware/Middleware.kt index 5ca240f..cc2f5cc 100644 --- a/src/main/kotlin/net/ccbluex/netty/http/middleware/Middleware.kt +++ b/src/main/kotlin/net/ccbluex/netty/http/middleware/Middleware.kt @@ -1,8 +1,45 @@ +/* + * This file is part of Netty-Rest (https://github.com/CCBlueX/netty-rest) + * + * Copyright (c) 2024 CCBlueX + * + * LiquidBounce is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Netty-Rest is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Netty-Rest. If not, see . + * + */ package net.ccbluex.netty.http.middleware +import io.netty.channel.ChannelHandlerContext import io.netty.handler.codec.http.FullHttpResponse +import io.netty.handler.codec.http.HttpRequest import net.ccbluex.netty.http.model.RequestContext -fun interface Middleware { - operator fun invoke(context: RequestContext, response: FullHttpResponse): FullHttpResponse -} \ No newline at end of file +sealed interface Middleware { + fun interface OnWebSocketUpgrade : Middleware { + /** + * @return if it's accepted + */ + operator fun invoke(ctx: ChannelHandlerContext, request: HttpRequest): Boolean + } + + fun interface OnRequestStart : Middleware { + /** + * @return if it's accepted + */ + operator fun invoke(ctx: ChannelHandlerContext, request: HttpRequest, requestContext: RequestContext): Boolean + } + + fun interface OnFullHttpResponse : Middleware { + operator fun invoke(context: RequestContext, response: FullHttpResponse): FullHttpResponse + } +} diff --git a/src/test/kotlin/HttpMiddlewareServerTest.kt b/src/test/kotlin/HttpMiddlewareServerTest.kt index 8100a01..64c2d7e 100644 --- a/src/test/kotlin/HttpMiddlewareServerTest.kt +++ b/src/test/kotlin/HttpMiddlewareServerTest.kt @@ -1,13 +1,23 @@ import com.google.gson.JsonObject import io.netty.handler.codec.http.FullHttpResponse +import io.netty.handler.codec.http.HttpResponseStatus import net.ccbluex.netty.http.HttpServer +import net.ccbluex.netty.http.middleware.Middleware import net.ccbluex.netty.http.model.RequestObject +import net.ccbluex.netty.http.util.httpBadRequest import net.ccbluex.netty.http.util.httpOk import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response +import okhttp3.WebSocket +import okhttp3.WebSocketListener +import okio.Buffer import org.junit.jupiter.api.* +import java.net.ProtocolException +import java.util.concurrent.CompletableFuture +import java.util.concurrent.TimeUnit import kotlin.test.assertEquals +import kotlin.test.assertIs import kotlin.test.assertNotNull import kotlin.test.assertTrue @@ -55,7 +65,7 @@ class HttpMiddlewareServerTest { get("/", ::static) } - server.middleware { requestContext, fullHttpResponse -> + server.middleware(Middleware.OnFullHttpResponse { requestContext, fullHttpResponse -> // Add custom headers to the response fullHttpResponse.headers().add("X-Custom-Header", "Custom Value") @@ -66,7 +76,13 @@ class HttpMiddlewareServerTest { } fullHttpResponse - } + }).middleware(Middleware.OnWebSocketUpgrade { context, _ -> + context.writeAndFlush( + httpBadRequest("WebSocket unsupported") + ) + + false + }) server.start(8080) // Start the server on port 8080 return server @@ -125,4 +141,32 @@ class HttpMiddlewareServerTest { "Query parameter should be present in the response") } + @Test + fun testWebSocketShouldBeBadRequest() { + val future = CompletableFuture() + + client.newWebSocket( + Request.Builder() + .url("http://localhost:8080") + .build(), + object : WebSocketListener() { + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + assertIs(t) + assertNotNull(response) + assertEquals(HttpResponseStatus.BAD_REQUEST.code(), response.code()) + val exceptedResponseBody = httpBadRequest("WebSocket unsupported") + val buffer = Buffer() + buffer.write(exceptedResponseBody.content().nioBuffer()) + assertEquals( + buffer.readUtf8(), + response.body()!!.string() + ) + future.complete(true) + } + } + ) + + assertTrue(future.get(10, TimeUnit.SECONDS)) + } + }