Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions src/main/kotlin/net/ccbluex/netty/http/HttpServer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
package net.ccbluex.netty.http

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.Channel
import io.netty.channel.ChannelOption
import io.netty.channel.EventLoopGroup
import io.netty.channel.epoll.Epoll
import io.netty.channel.epoll.EpollEventLoopGroup
import io.netty.channel.epoll.EpollServerSocketChannel
Expand All @@ -32,6 +34,9 @@ import net.ccbluex.netty.http.middleware.Middleware
import net.ccbluex.netty.http.rest.RouteController
import net.ccbluex.netty.http.websocket.WebSocketController
import org.apache.logging.log4j.LogManager
import java.net.InetSocketAddress
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock


/**
Expand All @@ -44,22 +49,32 @@ class HttpServer {
val routeController = RouteController()
val webSocketController = WebSocketController()

val middlewares = mutableListOf<Middleware>()
private val lock = ReentrantLock()

internal val middlewares = mutableListOf<Middleware>()

private var bossGroup: EventLoopGroup? = null
private var workerGroup: EventLoopGroup? = null
private var serverChannel: Channel? = null

companion object {
internal val logger = LogManager.getLogger("HttpServer")
}

fun middleware(middleware: Middleware) {
fun middleware(middleware: Middleware) = apply {
middlewares += middleware
}

/**
* Starts the Netty server on the specified port.
*
* @param port The port of HTTP server. `0` means to auto select one.
*
* @return actual port of server.
*/
fun start(port: Int) {
val bossGroup = if (Epoll.isAvailable()) EpollEventLoopGroup() else NioEventLoopGroup()
val workerGroup = if (Epoll.isAvailable()) EpollEventLoopGroup() else NioEventLoopGroup()
fun start(port: Int): Int = lock.withLock {
bossGroup = if (Epoll.isAvailable()) EpollEventLoopGroup(1) else NioEventLoopGroup(1)
workerGroup = if (Epoll.isAvailable()) EpollEventLoopGroup() else NioEventLoopGroup()

try {
logger.info("Starting Netty server...")
Expand All @@ -70,21 +85,35 @@ class HttpServer {
.handler(LoggingHandler(LogLevel.INFO))
.childHandler(HttpChannelInitializer(this))
val ch = b.bind(port).sync().channel()
serverChannel = ch

logger.info("Netty server started on port $port.")
ch.closeFuture().sync()
} catch (e: InterruptedException) {
logger.error("Netty server interrupted", e)

return@withLock (ch.localAddress() as InetSocketAddress).port
} catch (t: Throwable) {
logger.error("Netty server failed - $port", t)

stop()
// Forward the exception because we ran into an unexpected error
throw t
} finally {
bossGroup.shutdownGracefully()
workerGroup.shutdownGracefully()
}
}

/**
* Stops the Netty server gracefully.
*/
fun stop() = lock.withLock {
logger.info("Shutting down Netty server...")
try {
serverChannel?.close()?.sync()
bossGroup?.shutdownGracefully()?.sync()
workerGroup?.shutdownGracefully()?.sync()
} catch (e: Exception) {
logger.warn("Error during shutdown", e)
} finally {
serverChannel = null
bossGroup = null
workerGroup = null
}
logger.info("Netty server stopped.")
}

Expand Down
14 changes: 5 additions & 9 deletions src/test/kotlin/HttpMiddlewareServerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import org.junit.jupiter.api.*
import java.io.File
import java.nio.file.Files
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
Expand All @@ -21,7 +18,7 @@ import kotlin.test.assertTrue
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class HttpMiddlewareServerTest {

private lateinit var serverThread: Thread
private lateinit var server: HttpServer
private val client = OkHttpClient()

/**
Expand All @@ -32,9 +29,7 @@ class HttpMiddlewareServerTest {
@BeforeAll
fun initialize() {
// Start the HTTP server in a separate thread
serverThread = thread {
startHttpServer()
}
server = startHttpServer()

// Allow the server some time to start
Thread.sleep(1000)
Expand All @@ -46,14 +41,14 @@ class HttpMiddlewareServerTest {
*/
@AfterAll
fun cleanup() {
serverThread.interrupt()
server.stop()
}

/**
* This function starts the HTTP server with routing configured for
* different difficulty levels.
*/
private fun startHttpServer() {
private fun startHttpServer(): HttpServer {
val server = HttpServer()

server.routeController.apply {
Expand All @@ -74,6 +69,7 @@ class HttpMiddlewareServerTest {
}

server.start(8080) // Start the server on port 8080
return server
}

@Suppress("UNUSED_PARAMETER")
Expand Down
12 changes: 5 additions & 7 deletions src/test/kotlin/HttpServerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import okhttp3.Response
import org.junit.jupiter.api.*
import java.io.File
import java.nio.file.Files
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
Expand All @@ -22,7 +21,7 @@ import kotlin.test.assertTrue
class HttpServerTest {

private lateinit var folder: File
private lateinit var serverThread: Thread
private lateinit var server: HttpServer
private val client = OkHttpClient()

/**
Expand All @@ -44,9 +43,7 @@ class HttpServerTest {
File(subFolder, "index.html").writeText("Hello, World!")

// Start the HTTP server in a separate thread
serverThread = thread {
startHttpServer(folder)
}
server = startHttpServer(folder)

// Allow the server some time to start
Thread.sleep(1000)
Expand All @@ -58,15 +55,15 @@ class HttpServerTest {
*/
@AfterAll
fun cleanup() {
serverThread.interrupt()
server.stop()
folder.deleteRecursively() // Clean up the temporary folder
}

/**
* This function starts the HTTP server with routing configured for
* different difficulty levels.
*/
private fun startHttpServer(folder: File) {
private fun startHttpServer(folder: File): HttpServer {
val server = HttpServer()

server.routeController.apply {
Expand Down Expand Up @@ -97,6 +94,7 @@ class HttpServerTest {
}

server.start(8080) // Start the server on port 8080
return server
}

@Suppress("UNUSED_PARAMETER")
Expand Down