Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/main/kotlin/net/ccbluex/netty/http/HttpServerHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import io.netty.handler.codec.http.HttpRequest
import io.netty.handler.codec.http.LastHttpContent
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory
import net.ccbluex.netty.http.HttpServer.Companion.logger
import net.ccbluex.netty.http.middleware.Middleware
import net.ccbluex.netty.http.model.RequestContext
import net.ccbluex.netty.http.websocket.WebSocketHandler
import java.net.URLDecoder
Expand Down Expand Up @@ -66,6 +67,11 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun

if (connection.equals("Upgrade", ignoreCase = true) &&
upgrade.equals("WebSocket", ignoreCase = true)) {

if (server.middlewares.any {
it is Middleware.OnWebSocketUpgrade && !it.invoke(ctx, msg)
}) return

// Takes out Http Request Handler from the pipeline and replaces it with WebSocketHandler
ctx.pipeline().replace(this, "websocketHandler", WebSocketHandler(server))

Expand All @@ -90,6 +96,10 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun
msg.headers().associate { it.key to it.value },
)

if (server.middlewares.any {
it is Middleware.OnRequestStart && !it.invoke(ctx, msg, requestContext)
}) return

localRequestContext.set(requestContext)
}
}
Expand All @@ -109,9 +119,13 @@ internal class HttpServerHandler(private val server: HttpServer) : ChannelInboun
if (msg is LastHttpContent) {
localRequestContext.remove()

val response = server.processRequestContext(requestContext)
val httpResponse = server.middlewares.fold(response) { acc, f -> f(requestContext, acc) }
ctx.writeAndFlush(httpResponse)
var response = server.processRequestContext(requestContext)
server.middlewares.forEach {
if (it is Middleware.OnFullHttpResponse) {
response = it.invoke(requestContext, response)
}
}
ctx.writeAndFlush(response)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CorsMiddleware(
listOf("GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"),
private val allowedHeaders: List<String> =
listOf("Content-Type", "Content-Length", "Authorization", "Accept", "X-Requested-With")
): Middleware {
): Middleware.OnFullHttpResponse {

/**
* Middleware to handle CORS requests.
Expand Down
43 changes: 40 additions & 3 deletions src/main/kotlin/net/ccbluex/netty/http/middleware/Middleware.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,45 @@
/*
* This file is part of Netty-Rest (https://github.com/CCBlueX/netty-rest)
*
* Copyright (c) 2024 CCBlueX
*
* LiquidBounce is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Netty-Rest is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Netty-Rest. If not, see <https://www.gnu.org/licenses/>.
*
*/
package net.ccbluex.netty.http.middleware

import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.http.FullHttpResponse
import io.netty.handler.codec.http.HttpRequest
import net.ccbluex.netty.http.model.RequestContext

fun interface Middleware {
operator fun invoke(context: RequestContext, response: FullHttpResponse): FullHttpResponse
}
sealed interface Middleware {
fun interface OnWebSocketUpgrade : Middleware {
/**
* @return if it's accepted
*/
operator fun invoke(ctx: ChannelHandlerContext, request: HttpRequest): Boolean
}

fun interface OnRequestStart : Middleware {
/**
* @return if it's accepted
*/
operator fun invoke(ctx: ChannelHandlerContext, request: HttpRequest, requestContext: RequestContext): Boolean
}

fun interface OnFullHttpResponse : Middleware {
operator fun invoke(context: RequestContext, response: FullHttpResponse): FullHttpResponse
}
}
48 changes: 46 additions & 2 deletions src/test/kotlin/HttpMiddlewareServerTest.kt
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import com.google.gson.JsonObject
import io.netty.handler.codec.http.FullHttpResponse
import io.netty.handler.codec.http.HttpResponseStatus
import net.ccbluex.netty.http.HttpServer
import net.ccbluex.netty.http.middleware.Middleware
import net.ccbluex.netty.http.model.RequestObject
import net.ccbluex.netty.http.util.httpBadRequest
import net.ccbluex.netty.http.util.httpOk
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import okio.Buffer
import org.junit.jupiter.api.*
import java.net.ProtocolException
import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
import kotlin.test.assertIs
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

Expand Down Expand Up @@ -55,7 +65,7 @@ class HttpMiddlewareServerTest {
get("/", ::static)
}

server.middleware { requestContext, fullHttpResponse ->
server.middleware(Middleware.OnFullHttpResponse { requestContext, fullHttpResponse ->
// Add custom headers to the response
fullHttpResponse.headers().add("X-Custom-Header", "Custom Value")

Expand All @@ -66,7 +76,13 @@ class HttpMiddlewareServerTest {
}

fullHttpResponse
}
}).middleware(Middleware.OnWebSocketUpgrade { context, _ ->
context.writeAndFlush(
httpBadRequest("WebSocket unsupported")
)

false
})

server.start(8080) // Start the server on port 8080
return server
Expand Down Expand Up @@ -125,4 +141,32 @@ class HttpMiddlewareServerTest {
"Query parameter should be present in the response")
}

@Test
fun testWebSocketShouldBeBadRequest() {
val future = CompletableFuture<Boolean>()

client.newWebSocket(
Request.Builder()
.url("http://localhost:8080")
.build(),
object : WebSocketListener() {
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
assertIs<ProtocolException>(t)
assertNotNull(response)
assertEquals(HttpResponseStatus.BAD_REQUEST.code(), response.code())
val exceptedResponseBody = httpBadRequest("WebSocket unsupported")
val buffer = Buffer()
buffer.write(exceptedResponseBody.content().nioBuffer())
assertEquals(
buffer.readUtf8(),
response.body()!!.string()
)
future.complete(true)
}
}
)

assertTrue(future.get(10, TimeUnit.SECONDS))
}

}