Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions lightbug_http/header.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct HeaderKey:
alias CONTENT_LENGTH = "content-length"
alias CONTENT_ENCODING = "content-encoding"
alias DATE = "date"
alias LOCATION = "location"
alias HOST = "host"


@value
Expand Down Expand Up @@ -70,16 +72,12 @@ struct Headers(Formattable, Stringable):
self._inner[key.lower()] = value

fn content_length(self) -> Int:
if HeaderKey.CONTENT_LENGTH not in self:
return 0
try:
return int(self[HeaderKey.CONTENT_LENGTH])
except:
return 0

fn parse_raw(
inout self, inout r: ByteReader
) raises -> (String, String, String):
fn parse_raw(inout self, inout r: ByteReader) raises -> (String, String, String):
var first_byte = r.peek()
if not first_byte:
raise Error("Failed to read first byte from response header")
Expand Down
84 changes: 53 additions & 31 deletions lightbug_http/http.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ fn encode(owned res: HTTPResponse) -> Bytes:
return res._encoded()


struct StatusCode:
alias OK = 200
alias MOVED_PERMANENTLY = 301
alias FOUND = 302
alias TEMPORARY_REDIRECT = 307
alias PERMANENT_REDIRECT = 308
alias NOT_FOUND = 404


@value
struct HTTPRequest(Formattable, Stringable):
var headers: Headers
Expand All @@ -48,9 +57,7 @@ struct HTTPRequest(Formattable, Stringable):
var timeout: Duration

@staticmethod
fn from_bytes(
addr: String, max_body_size: Int, owned b: Bytes
) raises -> HTTPRequest:
fn from_bytes(addr: String, max_body_size: Int, owned b: Bytes) raises -> HTTPRequest:
var reader = ByteReader(b^)
var headers = Headers()
var method: String
Expand All @@ -65,16 +72,10 @@ struct HTTPRequest(Formattable, Stringable):

var content_length = headers.content_length()

if (
content_length > 0
and max_body_size > 0
and content_length > max_body_size
):
if content_length > 0 and max_body_size > 0 and content_length > max_body_size:
raise Error("Request body too large")

var request = HTTPRequest(
uri, headers=headers, method=method, protocol=protocol
)
var request = HTTPRequest(uri, headers=headers, method=method, protocol=protocol)

try:
request.read_body(reader, content_length, max_body_size)
Expand Down Expand Up @@ -103,6 +104,8 @@ struct HTTPRequest(Formattable, Stringable):
self.set_content_length(len(body))
if HeaderKey.CONNECTION not in self.headers:
self.set_connection_close()
if HeaderKey.HOST not in self.headers:
self.headers[HeaderKey.HOST] = uri.host

fn set_connection_close(inout self):
self.headers[HeaderKey.CONNECTION] = "close"
Expand All @@ -114,20 +117,22 @@ struct HTTPRequest(Formattable, Stringable):
return self.headers[HeaderKey.CONNECTION] == "close"

@always_inline
fn read_body(
inout self, inout r: ByteReader, content_length: Int, max_body_size: Int
) raises -> None:
fn read_body(inout self, inout r: ByteReader, content_length: Int, max_body_size: Int) raises -> None:
if content_length > max_body_size:
raise Error("Request body too large")

r.consume(self.body_raw)
r.consume(self.body_raw, content_length)
self.set_content_length(content_length)

fn format_to(self, inout writer: Formatter):
writer.write(self.method, whitespace)
path = self.uri.path if len(self.uri.path) > 1 else strSlash
if len(self.uri.query_string) > 0:
path += "?" + self.uri.query_string

writer.write(path)

writer.write(
self.method,
whitespace,
self.uri.path if len(self.uri.path) > 1 else strSlash,
whitespace,
self.protocol,
lineBreak,
Expand All @@ -147,6 +152,8 @@ struct HTTPRequest(Formattable, Stringable):
writer.write(self.method)
writer.write(whitespace)
var path = self.uri.path if len(self.uri.path) > 1 else strSlash
if len(self.uri.query_string) > 0:
path += "?" + self.uri.query_string
writer.write(path)
writer.write(whitespace)
writer.write(self.protocol)
Expand Down Expand Up @@ -215,8 +222,16 @@ struct HTTPResponse(Formattable, Stringable):
self.status_text = status_text
self.protocol = protocol
self.body_raw = body_bytes
self.set_connection_keep_alive()
self.set_content_length(len(body_bytes))
if HeaderKey.CONNECTION not in self.headers:
self.set_connection_keep_alive()
if HeaderKey.CONTENT_LENGTH not in self.headers:
self.set_content_length(len(body_bytes))
if HeaderKey.DATE not in self.headers:
try:
var current_time = now(utc=True).__str__()
self.headers[HeaderKey.DATE] = current_time
except:
pass

fn get_body_bytes(self) -> Bytes:
return self.body_raw
Expand All @@ -236,9 +251,25 @@ struct HTTPResponse(Formattable, Stringable):
fn set_content_length(inout self, l: Int):
self.headers[HeaderKey.CONTENT_LENGTH] = str(l)

@always_inline
fn content_length(inout self) -> Int:
try:
return int(self.headers[HeaderKey.CONTENT_LENGTH])
except:
return 0

fn is_redirect(self) -> Bool:
return (
self.status_code == StatusCode.MOVED_PERMANENTLY
or self.status_code == StatusCode.FOUND
or self.status_code == StatusCode.TEMPORARY_REDIRECT
or self.status_code == StatusCode.PERMANENT_REDIRECT
)

@always_inline
fn read_body(inout self, inout r: ByteReader) raises -> None:
r.consume(self.body_raw)
r.consume(self.body_raw, self.content_length())
self.set_content_length(len(self.body_raw))

fn format_to(self, inout writer: Formatter):
writer.write(
Expand All @@ -252,13 +283,6 @@ struct HTTPResponse(Formattable, Stringable):
lineBreak,
)

if HeaderKey.DATE not in self.headers:
try:
var current_time = now(utc=True).__str__()
write_header(writer, HeaderKey.DATE, current_time)
except:
pass

self.headers.format_to(writer)

writer.write(lineBreak)
Expand Down Expand Up @@ -326,9 +350,7 @@ fn OK(body: Bytes, content_type: String) -> HTTPResponse:
)


fn OK(
body: Bytes, content_type: String, content_encoding: String
) -> HTTPResponse:
fn OK(body: Bytes, content_type: String, content_encoding: String) -> HTTPResponse:
return HTTPResponse(
headers=Headers(
Header(HeaderKey.CONTENT_TYPE, content_type),
Expand Down
56 changes: 15 additions & 41 deletions lightbug_http/libc.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -459,9 +459,7 @@ fn inet_ntop(
](af, src, dst, size)


fn inet_pton(
af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void]
) -> c_int:
fn inet_pton(af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void]) -> c_int:
"""Libc POSIX `inet_pton` function
Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html
Fn signature: int inet_pton(int af, const char *restrict src, void *restrict dst).
Expand Down Expand Up @@ -512,9 +510,7 @@ fn socket(domain: c_int, type: c_int, protocol: c_int) -> c_int:
protocol: The protocol to use.
Returns: A File Descriptor or -1 in case of failure.
"""
return external_call[
"socket", c_int, c_int, c_int, c_int # FnName, RetType # Args
](domain, type, protocol)
return external_call["socket", c_int, c_int, c_int, c_int](domain, type, protocol) # FnName, RetType # Args


fn setsockopt(
Expand Down Expand Up @@ -592,16 +588,12 @@ fn getpeername(
](sockfd, addr, address_len)


fn bind(
socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen_t
) -> c_int:
fn bind(socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen_t) -> c_int:
"""Libc POSIX `bind` function
Reference: https://man7.org/linux/man-pages/man3/bind.3p.html
Fn signature: int bind(int socket, const struct sockaddr *address, socklen_t address_len).
"""
return external_call[
"bind", c_int, c_int, UnsafePointer[sockaddr], socklen_t
](socket, address, address_len)
return external_call["bind", c_int, c_int, UnsafePointer[sockaddr], socklen_t](socket, address, address_len)


fn listen(socket: c_int, backlog: c_int) -> c_int:
Expand Down Expand Up @@ -639,9 +631,7 @@ fn accept(
](socket, address, address_len)


fn connect(
socket: c_int, address: Reference[sockaddr], address_len: socklen_t
) -> c_int:
fn connect(socket: c_int, address: Reference[sockaddr], address_len: socklen_t) -> c_int:
"""Libc POSIX `connect` function
Reference: https://man7.org/linux/man-pages/man3/connect.3p.html
Fn signature: int connect(int socket, const struct sockaddr *address, socklen_t address_len).
Expand Down Expand Up @@ -674,9 +664,7 @@ fn recv(
](socket, buffer, length, flags)


fn send(
socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c_int
) -> c_ssize_t:
fn send(socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c_int) -> c_ssize_t:
"""Libc POSIX `send` function
Reference: https://man7.org/linux/man-pages/man3/send.3p.html
Fn signature: ssize_t send(int socket, const void *buffer, size_t length, int flags).
Expand All @@ -699,11 +687,7 @@ fn shutdown(socket: c_int, how: c_int) -> c_int:
how: How to shutdown the socket.
Returns: 0 on success, -1 on error.
"""
return external_call[
"shutdown", c_int, c_int, c_int
]( # FnName, RetType # Args
socket, how
)
return external_call["shutdown", c_int, c_int, c_int](socket, how) # FnName, RetType # Args


fn getaddrinfo(
Expand Down Expand Up @@ -734,9 +718,7 @@ fn gai_strerror(ecode: c_int) -> UnsafePointer[c_char]:
Args: ecode: The error code.
Returns: A UnsafePointer to a string describing the error.
"""
return external_call[
"gai_strerror", UnsafePointer[c_char], c_int # FnName, RetType # Args
](ecode)
return external_call["gai_strerror", UnsafePointer[c_char], c_int](ecode) # FnName, RetType # Args


fn inet_pton(address_family: Int, address: String) -> Int:
Expand All @@ -745,9 +727,7 @@ fn inet_pton(address_family: Int, address: String) -> Int:
ip_buf_size = 16

var ip_buf = UnsafePointer[c_void].alloc(ip_buf_size)
var conv_status = inet_pton(
rebind[c_int](address_family), to_char_ptr(address), ip_buf
)
var conv_status = inet_pton(rebind[c_int](address_family), to_char_ptr(address), ip_buf)
return int(ip_buf.bitcast[c_uint]())


Expand All @@ -772,9 +752,7 @@ fn close(fildes: c_int) -> c_int:
return external_call["close", c_int, c_int](fildes)


fn open[
*T: AnyType
](path: UnsafePointer[c_char], oflag: c_int, *args: *T) -> c_int:
fn open[*T: AnyType](path: UnsafePointer[c_char], oflag: c_int, *args: *T) -> c_int:
"""Libc POSIX `open` function
Reference: https://man7.org/linux/man-pages/man3/open.3p.html
Fn signature: int open(const char *path, int oflag, ...).
Expand Down Expand Up @@ -814,9 +792,7 @@ fn read(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int:
nbyte: The number of bytes to read.
Returns: The number of bytes read or -1 in case of failure.
"""
return external_call[
"read", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t
](fildes, buf, nbyte)
return external_call["read", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t](fildes, buf, nbyte)


fn write(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int:
Expand All @@ -829,9 +805,7 @@ fn write(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int:
nbyte: The number of bytes to write.
Returns: The number of bytes written or -1 in case of failure.
"""
return external_call[
"write", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t
](fildes, buf, nbyte)
return external_call["write", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t](fildes, buf, nbyte)


fn __test_getaddrinfo__():
Expand All @@ -853,8 +827,8 @@ fn __test_getaddrinfo__():
UnsafePointer.address_of(servinfo),
)
var msg_ptr = gai_strerror(c_int(status))
_ = external_call[
"printf", c_int, UnsafePointer[c_char], UnsafePointer[c_char]
](to_char_ptr("gai_strerror: %s"), msg_ptr)
_ = external_call["printf", c_int, UnsafePointer[c_char], UnsafePointer[c_char]](
to_char_ptr("gai_strerror: %s"), msg_ptr
)
var msg = c_charptr_to_string(msg_ptr)
print("getaddrinfo satus: " + msg)
14 changes: 3 additions & 11 deletions lightbug_http/net.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ struct TCPAddr(Addr):

fn string(self) -> String:
if self.zone != "":
return join_host_port(
self.ip + "%" + self.zone, self.port.__str__()
)
return join_host_port(self.ip + "%" + self.zone, self.port.__str__())
return join_host_port(self.ip, self.port.__str__())


Expand All @@ -143,11 +141,7 @@ fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr:
host = host_port.host
port = host_port.port
portnum = atol(port.__str__())
elif (
network == NetworkType.ip.value
or network == NetworkType.ip4.value
or network == NetworkType.ip6.value
):
elif network == NetworkType.ip.value or network == NetworkType.ip4.value or network == NetworkType.ip6.value:
if address != "":
host = address
elif network == NetworkType.unix.value:
Expand Down Expand Up @@ -221,9 +215,7 @@ fn convert_binary_port_to_int(port: UInt16) -> Int:
return int(ntohs(port))


fn convert_binary_ip_to_string(
owned ip_address: UInt32, address_family: Int32, address_length: UInt32
) -> String:
fn convert_binary_ip_to_string(owned ip_address: UInt32, address_family: Int32, address_length: UInt32) -> String:
"""Convert a binary IP address to a string by calling inet_ntop.

Args:
Expand Down
4 changes: 1 addition & 3 deletions lightbug_http/server.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ trait ServerTrait:
fn get_concurrency(self) -> Int:
...

fn listen_and_serve(
self, address: String, handler: HTTPService
) raises -> None:
fn listen_and_serve(self, address: String, handler: HTTPService) raises -> None:
...

fn serve(self, ln: Listener, handler: HTTPService) raises -> None:
Expand Down
4 changes: 1 addition & 3 deletions lightbug_http/service.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ struct Printer(HTTPService):
var header = req.headers
print("Request protocol: ", req.protocol)
print("Request method: ", req.method)
print(
"Request Content-Type: ", to_string(header[HeaderKey.CONTENT_TYPE])
)
print("Request Content-Type: ", to_string(header[HeaderKey.CONTENT_TYPE]))

var body = req.body_raw
print("Request Body: ", to_string(body))
Expand Down
Loading