diff --git a/.github/workflows/package.yml b/.github/workflows/main.yml similarity index 67% rename from .github/workflows/package.yml rename to .github/workflows/main.yml index 1a7aa402..08d8ef3a 100644 --- a/.github/workflows/package.yml +++ b/.github/workflows/main.yml @@ -1,34 +1,42 @@ -name: Package and Release +name: Main pipeline on: push: branches: - main + permissions: contents: write + jobs: - run-tests: - name: Release package + setup: + name: Setup environment and install dependencies runs-on: ubuntu-latest - steps: - name: Checkout code uses: actions/checkout@v2 - - name: Install modular run: | curl -s https://get.modular.com | sh - modular auth examples - - name: Install Mojo run: modular install mojo - - name: Add to PATH run: echo "/home/runner/.modular/pkg/packages.modular.com_mojo/bin" >> $GITHUB_PATH - - - name: Create package + test: + name: Run tests + runs-on: ubuntu-latest + needs: setup + steps: + - name: Run the test suite + run: mojo run_tests.mojo + package: + name: Create package + runs-on: ubuntu-latest + needs: setup + steps: + - name: Run the package command run: mojo package lightbug_http -o lightbug_http.mojopkg - - name: Upload package to release uses: svenstaro/upload-release-action@v2 with: diff --git a/.gitignore b/.gitignore index 86957465..e9cf84f7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.📦 .DS_Store -.mojoenv \ No newline at end of file +.mojoenv +install_id \ No newline at end of file diff --git a/bench.mojo b/bench.mojo index c432bf09..4ae5dfad 100644 --- a/bench.mojo +++ b/bench.mojo @@ -2,19 +2,18 @@ import benchmark from lightbug_http.sys.server import SysServer from lightbug_http.python.server import PythonServer from lightbug_http.service import TechEmpowerRouter -from lightbug_http.tests.utils import ( +from tests.utils import ( TestStruct, FakeResponder, new_fake_listener, FakeServer, getRequest, ) -from external.libc import __test_socket_client__ fn main(): try: - var server = SysServer() + var server = SysServer(tcp_keep_alive=True) var handler = TechEmpowerRouter() server.listen_and_serve("0.0.0.0:8080", handler) except e: @@ -22,12 +21,6 @@ fn main(): return -fn lightbug_benchmark_get_1req_per_conn(): - var req_report = benchmark.run[__test_socket_client__](1, 10000, 0, 3, 100) - print("Request: ") - req_report.print(benchmark.Unit.ns) - - fn lightbug_benchmark_server(): var server_report = benchmark.run[run_fake_server](max_iters=1) print("Server: ") @@ -54,9 +47,9 @@ fn run_fake_server(): fn init_test_and_set_a_copy() -> None: var test = TestStruct("a", "b") - var newtest = test.set_a_copy("c") + _ = test.set_a_copy("c") fn init_test_and_set_a_direct() -> None: var test = TestStruct("a", "b") - var newtest = test.set_a_direct("c") + _ = test.set_a_direct("c") diff --git a/client.mojo b/client.mojo index e84e7bd8..2208c4ed 100644 --- a/client.mojo +++ b/client.mojo @@ -5,6 +5,12 @@ from lightbug_http.sys.client import MojoClient fn test_request(inout client: MojoClient) raises -> None: var uri = URI("http://httpbin.org/status/404") + try: + uri.parse() + except e: + print("error parsing uri: " + e.__str__()) + + var request = HTTPRequest(uri) var response = client.do(request) @@ -17,11 +23,12 @@ fn test_request(inout client: MojoClient) raises -> None: # print parsed headers (only some are parsed for now) print("Content-Type:", String(response.header.content_type())) print("Content-Length", response.header.content_length()) - print("Connection:", response.header.connection_close()) print("Server:", String(response.header.server())) + print("Is connection set to connection-close? ", response.header.connection_close()) + # print body - print(String(response.get_body())) + print(String(response.get_body_bytes())) fn main() raises -> None: diff --git a/external/gojo/__init__.mojo b/external/gojo/__init__.mojo index 3e354b74..e69de29b 100644 --- a/external/gojo/__init__.mojo +++ b/external/gojo/__init__.mojo @@ -1 +0,0 @@ -# gojo, created by thastoasty, https://github.com/thatstoasty/gojo/ diff --git a/external/gojo/bufio/bufio.mojo b/external/gojo/bufio/bufio.mojo index 95386fad..b1d0dee5 100644 --- a/external/gojo/bufio/bufio.mojo +++ b/external/gojo/bufio/bufio.mojo @@ -1,13 +1,11 @@ -from math import max -from collections.optional import Optional -from ..io import traits as io -from ..builtins import copy, panic, WrappedError, Result -from ..builtins.bytes import Byte, index_byte +import ..io +from ..builtins import copy, panic +from ..builtins.bytes import UInt8, index_byte from ..strings import StringBuilder alias MIN_READ_BUFFER_SIZE = 16 alias MAX_CONSECUTIVE_EMPTY_READS = 100 -alias DEFAULT_BUF_SIZE = 4096 +alias DEFAULT_BUF_SIZE = 8200 alias ERR_INVALID_UNREAD_BYTE = "bufio: invalid use of unread_byte" alias ERR_INVALID_UNREAD_RUNE = "bufio: invalid use of unread_rune" @@ -18,44 +16,42 @@ alias ERR_NEGATIVE_WRITE = "bufio: writer returned negative count from write" # buffered input -struct Reader[R: io.Reader]( - Sized, io.Reader, io.ByteReader, io.ByteScanner, io.WriterTo -): +struct Reader[R: io.Reader](Sized, io.Reader, io.ByteReader, io.ByteScanner): """Implements buffering for an io.Reader object.""" - var buf: List[Byte] + var buf: List[UInt8] var reader: R # reader provided by the client var read_pos: Int var write_pos: Int # buf read and write positions var last_byte: Int # last byte read for unread_byte; -1 means invalid var last_rune_size: Int # size of last rune read for unread_rune; -1 means invalid - var err: Optional[WrappedError] + var err: Error fn __init__( inout self, owned reader: R, - buf: List[Byte] = List[Byte](capacity=DEFAULT_BUF_SIZE), + buf: List[UInt8] = List[UInt8](capacity=DEFAULT_BUF_SIZE), read_pos: Int = 0, write_pos: Int = 0, last_byte: Int = -1, last_rune_size: Int = -1, ): self.buf = buf - self.reader = reader ^ + self.reader = reader^ self.read_pos = read_pos self.write_pos = write_pos self.last_byte = last_byte self.last_rune_size = last_rune_size - self.err = None + self.err = Error() fn __moveinit__(inout self, owned existing: Self): - self.buf = existing.buf ^ - self.reader = existing.reader ^ + self.buf = existing.buf^ + self.reader = existing.reader^ self.read_pos = existing.read_pos self.write_pos = existing.write_pos self.last_byte = existing.last_byte self.last_rune_size = existing.last_rune_size - self.err = existing.err ^ + self.err = existing.err^ # size returns the size of the underlying buffer in bytes. fn __len__(self) -> Int: @@ -74,14 +70,14 @@ struct Reader[R: io.Reader]( # return # # if self.buf == nil: - # # self.buf = make(List[Byte], DEFAULT_BUF_SIZE) + # # self.buf = make(List[UInt8], DEFAULT_BUF_SIZE) # self.reset(self.buf, r) - fn reset[R: io.Reader](inout self, buf: List[Byte], owned reader: R): + fn reset(inout self, buf: List[UInt8], owned reader: R): self = Reader[R]( buf=buf, - reader=reader ^, + reader=reader^, last_byte=-1, last_rune_size=-1, ) @@ -96,8 +92,8 @@ struct Reader[R: io.Reader]( self.write_pos -= self.read_pos self.read_pos = 0 - # Compares to the length of the entire List[Byte] object, including 0 initialized positions. - # IE. var b = List[Byte](capacity=4096), then trying to write at b[4096] and onwards will fail. + # Compares to the length of the entire List[UInt8] object, including 0 initialized positions. + # IE. var b = List[UInt8](capacity=8200), then trying to write at b[8200] and onwards will fail. if self.write_pos >= self.buf.capacity: panic("bufio.Reader: tried to fill full buffer") @@ -105,16 +101,18 @@ struct Reader[R: io.Reader]( var i: Int = MAX_CONSECUTIVE_EMPTY_READS while i > 0: # TODO: Using temp until slicing can return a Reference - var temp = List[Byte](capacity=DEFAULT_BUF_SIZE) - var result = self.reader.read(temp) - var bytes_read = copy(self.buf, temp, self.write_pos) + var temp = List[UInt8](capacity=DEFAULT_BUF_SIZE) + var bytes_read: Int + var err: Error + bytes_read, err = self.reader.read(temp) if bytes_read < 0: panic(ERR_NEGATIVE_READ) + bytes_read = copy(self.buf, temp, self.write_pos) self.write_pos += bytes_read - if result.has_error(): - self.err = result.get_error() + if err: + self.err = err return if bytes_read > 0: @@ -122,18 +120,17 @@ struct Reader[R: io.Reader]( i -= 1 - self.err = WrappedError(io.ERR_NO_PROGRESS) + self.err = Error(io.ERR_NO_PROGRESS) - fn read_error(inout self) -> Optional[WrappedError]: + fn read_error(inout self) -> Error: if not self.err: - return None + return Error() - var err = self.err.value() - self.err = None + var err = self.err + self.err = Error() return err - # Peek - fn peek(inout self, number_of_bytes: Int) -> Result[List[Byte]]: + fn peek(inout self, number_of_bytes: Int) -> (List[UInt8], Error): """Returns the next n bytes without advancing the reader. The bytes stop being valid at the next read call. If Peek returns fewer than n bytes, it also returns an error explaining why the read is short. The error is @@ -146,34 +143,29 @@ struct Reader[R: io.Reader]( number_of_bytes: The number of bytes to peek. """ if number_of_bytes < 0: - return Result(List[Byte](), WrappedError(ERR_NEGATIVE_COUNT)) + return List[UInt8](), Error(ERR_NEGATIVE_COUNT) self.last_byte = -1 self.last_rune_size = -1 - while ( - self.write_pos - self.read_pos < number_of_bytes - and self.write_pos - self.read_pos < self.buf.capacity - ): + while self.write_pos - self.read_pos < number_of_bytes and self.write_pos - self.read_pos < self.buf.capacity: self.fill() # self.write_pos-self.read_pos < self.buf.capacity => buffer is not full if number_of_bytes > self.buf.capacity: - return Result( - self.buf[self.read_pos : self.write_pos], WrappedError(ERR_BUFFER_FULL) - ) + return self.buf[self.read_pos : self.write_pos], Error(ERR_BUFFER_FULL) # 0 <= n <= self.buf.capacity - var err: Optional[WrappedError] = None + var err = Error() var available_space = self.write_pos - self.read_pos if available_space < number_of_bytes: # not enough data in buffer err = self.read_error() if not err: - err = WrappedError(ERR_BUFFER_FULL) + err = Error(ERR_BUFFER_FULL) - return Result(self.buf[self.read_pos : self.read_pos + number_of_bytes], err) + return self.buf[self.read_pos : self.read_pos + number_of_bytes], err - fn discard(inout self, number_of_bytes: Int) -> Result[Int]: + fn discard(inout self, number_of_bytes: Int) -> (Int, Error): """Discard skips the next n bytes, returning the number of bytes discarded. If Discard skips fewer than n bytes, it also returns an error. @@ -181,10 +173,10 @@ struct Reader[R: io.Reader]( reading from the underlying io.Reader. """ if number_of_bytes < 0: - return Result(0, WrappedError(ERR_NEGATIVE_COUNT)) + return 0, Error(ERR_NEGATIVE_COUNT) if number_of_bytes == 0: - return Result(0, None) + return 0, Error() self.last_byte = -1 self.last_rune_size = -1 @@ -202,31 +194,32 @@ struct Reader[R: io.Reader]( self.read_pos += skip remain -= skip if remain == 0: - return number_of_bytes - - # Read reads data into dest. - # It returns the number of bytes read into dest. - # The bytes are taken from at most one Read on the underlying [Reader], - # hence n may be less than len(src). - # To read exactly len(src) bytes, use io.ReadFull(b, src). - # If the underlying [Reader] can return a non-zero count with io.EOF, - # then this Read method can do so as well; see the [io.Reader] docs. - fn read(inout self, inout dest: List[Byte]) -> Result[Int]: + return number_of_bytes, Error() + + fn read(inout self, inout dest: List[UInt8]) -> (Int, Error): + """Reads data into dest. + It returns the number of bytes read into dest. + The bytes are taken from at most one Read on the underlying [Reader], + hence n may be less than len(src). + To read exactly len(src) bytes, use io.ReadFull(b, src). + If the underlying [Reader] can return a non-zero count with io.EOF, + then this Read method can do so as well; see the [io.Reader] docs.""" var space_available = dest.capacity - len(dest) if space_available == 0: if self.buffered() > 0: - return Result(0, None) - return Result(0, self.read_error()) + return 0, Error() + return 0, self.read_error() var bytes_read: Int = 0 if self.read_pos == self.write_pos: if space_available >= len(self.buf): # Large read, empty buffer. # Read directly into dest to avoid copy. - var result = self.reader.read(dest) + var bytes_read: Int + var err: Error + bytes_read, err = self.reader.read(dest) - self.err = result.get_error() - bytes_read = result.value + self.err = err if bytes_read < 0: panic(ERR_NEGATIVE_READ) @@ -234,20 +227,21 @@ struct Reader[R: io.Reader]( self.last_byte = int(dest[bytes_read - 1]) self.last_rune_size = -1 - return Result(bytes_read, self.read_error()) + return bytes_read, self.read_error() # One read. # Do not use self.fill, which will loop. self.read_pos = 0 self.write_pos = 0 - var result = self.reader.read(self.buf) + var bytes_read: Int + var err: Error + bytes_read, err = self.reader.read(self.buf) - bytes_read = result.value if bytes_read < 0: panic(ERR_NEGATIVE_READ) if bytes_read == 0: - return Result(0, self.read_error()) + return 0, self.read_error() self.write_pos += bytes_read @@ -258,23 +252,22 @@ struct Reader[R: io.Reader]( self.read_pos += bytes_read self.last_byte = int(self.buf[self.read_pos - 1]) self.last_rune_size = -1 - return Result(bytes_read, None) + return bytes_read, Error() - fn read_byte(inout self) -> Result[Byte]: - """Reads and returns a single byte from the internal buffer. If no byte is available, returns an error. - """ + fn read_byte(inout self) -> (UInt8, Error): + """Reads and returns a single byte from the internal buffer. If no byte is available, returns an error.""" self.last_rune_size = -1 while self.read_pos == self.write_pos: if self.err: - return Result(Int8(0), self.read_error()) + return UInt8(0), self.read_error() self.fill() # buffer is empty var c = self.buf[self.read_pos] self.read_pos += 1 self.last_byte = int(c) - return c + return c, Error() - fn unread_byte(inout self) -> Optional[WrappedError]: + fn unread_byte(inout self) -> Error: """Unreads the last byte. Only the most recently read byte can be unread. unread_byte returns an error if the most recent method called on the @@ -282,7 +275,7 @@ struct Reader[R: io.Reader]( considered read operations. """ if self.last_byte < 0 or self.read_pos == 0 and self.write_pos > 0: - return WrappedError(ERR_INVALID_UNREAD_BYTE) + return Error(ERR_INVALID_UNREAD_BYTE) # self.read_pos > 0 or self.write_pos == 0 if self.read_pos > 0: @@ -294,7 +287,7 @@ struct Reader[R: io.Reader]( self.buf[self.read_pos] = self.last_byte self.last_byte = -1 self.last_rune_size = -1 - return None + return Error() # # read_rune reads a single UTF-8 encoded Unicode character and returns the # # rune and its size in bytes. If the encoded rune is invalid, it consumes one byte @@ -337,7 +330,7 @@ struct Reader[R: io.Reader]( """ return self.write_pos - self.read_pos - fn read_slice(inout self, delim: Int8) -> Result[List[Byte]]: + fn read_slice(inout self, delim: UInt8) -> (List[UInt8], Error): """Reads until the first occurrence of delim in the input, returning a slice pointing at the bytes in the buffer. It includes the first occurrence of the delimiter. The bytes stop being valid at the next read. @@ -353,11 +346,11 @@ struct Reader[R: io.Reader]( delim: The delimiter to search for. Returns: - The List[Byte] from the internal buffer. + The List[UInt8] from the internal buffer. """ - var err: Optional[WrappedError] = None + var err = Error() var s = 0 # search start index - var line: List[Byte] = List[Byte](capacity=DEFAULT_BUF_SIZE) + var line: List[UInt8] = List[UInt8](capacity=DEFAULT_BUF_SIZE) while True: # Search buffer. var i = index_byte(self.buf[self.read_pos + s : self.write_pos], delim) @@ -378,7 +371,7 @@ struct Reader[R: io.Reader]( if self.buffered() >= self.buf.capacity: self.read_pos = self.write_pos line = self.buf - err = WrappedError(ERR_BUFFER_FULL) + err = Error(ERR_BUFFER_FULL) break s = self.write_pos - self.read_pos # do not rescan area we scanned before @@ -390,9 +383,9 @@ struct Reader[R: io.Reader]( self.last_byte = int(line[i]) self.last_rune_size = -1 - return Result(line, err) + return line, err - fn read_line(inout self) raises -> (List[Byte], Bool): + fn read_line(inout self) raises -> (List[UInt8], Bool): """Low-level line-reading primitive. Most callers should use [Reader.read_bytes]('\n') or [Reader.read_string]('\n') instead or use a [Scanner]. @@ -410,11 +403,11 @@ struct Reader[R: io.Reader]( (possibly a character belonging to the line end) even if that byte is not part of the line returned by read_line. """ - var result = self.read_slice(ord("\n")) - var line = result.value - var err = result.get_error() + var line: List[UInt8] + var err: Error + line, err = self.read_slice(ord("\n")) - if err and str(err.value()) == ERR_BUFFER_FULL: + if err and str(err) == ERR_BUFFER_FULL: # Handle the case where "\r\n" straddles the buffer. if len(line) > 0 and line[len(line) - 1] == ord("\r"): # Put the '\r' back on buf and drop it from line. @@ -439,45 +432,38 @@ struct Reader[R: io.Reader]( return line, False - fn collect_fragments( - inout self, - delim: Int8, - inout frag: List[Byte], - inout full_buffers: List[List[Byte]], - inout total_len: Int, - ) -> Optional[WrappedError]: + fn collect_fragments(inout self, delim: UInt8) -> (List[List[UInt8]], List[UInt8], Int, Error): """Reads until the first occurrence of delim in the input. It returns (slice of full buffers, remaining bytes before delim, total number of bytes in the combined first two elements, error). Args: delim: The delimiter to search for. - frag: The fragment to collect. - full_buffers: The full buffers to collect. - total_len: The total length of the combined first two elements. """ # Use read_slice to look for delim, accumulating full buffers. - var err: Optional[WrappedError] = None + var err = Error() + var full_buffers = List[List[UInt8]]() + var total_len = 0 + var frag = List[UInt8](capacity=8200) while True: - var result = self.read_slice(delim) - frag = result.value - if not result.has_error(): + frag, err = self.read_slice(delim) + if not err: break - var read_slice_error = result.get_error() - if str(read_slice_error.value()) != ERR_BUFFER_FULL: + var read_slice_error = err + if str(read_slice_error) != ERR_BUFFER_FULL: err = read_slice_error break # Make a copy of the buffer. - var buf = List[Byte](frag) + var buf = List[UInt8](frag) full_buffers.append(buf) total_len += len(buf) total_len += len(frag) - return err + return full_buffers, frag, total_len, err - fn read_bytes(inout self, delim: Int8) -> Result[List[Byte]]: + fn read_bytes(inout self, delim: UInt8) -> (List[UInt8], Error): """Reads until the first occurrence of delim in the input, returning a slice containing the data up to and including the delimiter. If read_bytes encounters an error before finding a delimiter, @@ -490,15 +476,16 @@ struct Reader[R: io.Reader]( delim: The delimiter to search for. Returns: - The List[Byte] from the internal buffer. + The List[UInt8] from the internal buffer. """ - var full = List[List[Byte]]() - var frag = List[Byte](capacity=4096) - var n: Int = 0 - var err = self.collect_fragments(delim, frag, full, n) + var full: List[List[UInt8]] + var frag: List[UInt8] + var n: Int + var err: Error + full, frag, n, err = self.collect_fragments(delim) # Allocate new buffer to hold the full pieces and the fragment. - var buf = List[Byte](capacity=n) + var buf = List[UInt8](capacity=n) n = 0 # copy full pieces and fragment in. @@ -508,9 +495,9 @@ struct Reader[R: io.Reader]( _ = copy(buf, frag, n) - return Result(buf, err) + return buf, err - fn read_string(inout self, delim: Int8) -> Result[String]: + fn read_string(inout self, delim: UInt8) -> (String, Error): """Reads until the first occurrence of delim in the input, returning a string containing the data up to and including the delimiter. If read_string encounters an error before finding a delimiter, @@ -525,91 +512,85 @@ struct Reader[R: io.Reader]( Returns: The String from the internal buffer. """ - var full = List[List[Byte]]() - var frag = List[Byte]() - var n: Int = 0 - var err = self.collect_fragments(delim, frag, full, n) + var full: List[List[UInt8]] + var frag: List[UInt8] + var n: Int + var err: Error + full, frag, n, err = self.collect_fragments(delim) # Allocate new buffer to hold the full pieces and the fragment. - var buf = StringBuilder(n) + var buf = StringBuilder(capacity=n) # copy full pieces and fragment in. for i in range(len(full)): var buffer = full[i] - _ = buf.write(buffer) - - _ = buf.write(frag) - return Result(str(buf), err) + _ = buf.write(Span(buffer)) - fn write_to[W: io.Writer](inout self, inout writer: W) -> Result[Int64]: - """Writes the internal buffer to the writer. This may make multiple calls to the [Reader.Read] method of the underlying [Reader]. - If the underlying reader supports the [Reader.WriteTo] method, - this calls the underlying [Reader.WriteTo] without buffering. - write_to implements io.WriterTo. + _ = buf.write(Span(frag)) + return str(buf), err - Args: - writer: The writer to write to. + # fn write_to[W: io.Writer](inout self, inout writer: W) -> (Int64, Error): + # """Writes the internal buffer to the writer. This may make multiple calls to the [Reader.Read] method of the underlying [Reader]. + # If the underlying reader supports the [Reader.WriteTo] method, + # this calls the underlying [Reader.WriteTo] without buffering. + # write_to implements io.WriterTo. - Returns: - The number of bytes written. - """ - self.last_byte = -1 - self.last_rune_size = -1 + # Args: + # writer: The writer to write to. - var result = self.write_buf(writer) - var bytes_written = result.value - var error = result.get_error() - if error: - return Result(bytes_written, error) + # Returns: + # The number of bytes written. + # """ + # self.last_byte = -1 + # self.last_rune_size = -1 - # if r, ok := self.reader.(io.WriterTo); ok: - # m, err := r.WriteTo(w) - # n += m - # return n, err + # var bytes_written: Int64 + # var err: Error + # bytes_written, err = self.write_buf(writer) + # if err: + # return bytes_written, err - # if w, ok := w.(io.ReaderFrom); ok: - # m, err := w.read_from(self.reader) - # n += m - # return n, err + # # internal buffer not full, fill before writing to writer + # if (self.write_pos - self.read_pos) < self.buf.capacity: + # self.fill() - # internal buffer not full, fill before writing to writer - if (self.write_pos - self.read_pos) < self.buf.capacity: - self.fill() + # while self.read_pos < self.write_pos: + # # self.read_pos < self.write_pos => buffer is not empty + # var bw: Int64 + # var err: Error + # bw, err = self.write_buf(writer) + # bytes_written += bw - while self.read_pos < self.write_pos: - # self.read_pos < self.write_pos => buffer is not empty - var res = self.write_buf(writer) - var bw = res.value - bytes_written += bw + # self.fill() # buffer is empty - self.fill() # buffer is empty + # return bytes_written, Error() - return bytes_written + # fn write_buf[W: io.Writer](inout self, inout writer: W) -> (Int64, Error): + # """Writes the [Reader]'s buffer to the writer. - fn write_buf[W: io.Writer](inout self, inout writer: W) -> Result[Int64]: - """Writes the [Reader]'s buffer to the writer. + # Args: + # writer: The writer to write to. - Args: - writer: The writer to write to. + # Returns: + # The number of bytes written. + # """ + # # Nothing to write + # if self.read_pos == self.write_pos: + # return Int64(0), Error() - Returns: - The number of bytes written. - """ - # Nothing to write - if self.read_pos == self.write_pos: - return Result(Int64(0), None) + # # Write the buffer to the writer, if we hit EOF it's fine. That's not a failure condition. + # var bytes_written: Int + # var err: Error + # var buf_to_write = self.buf[self.read_pos : self.write_pos] + # bytes_written, err = writer.write(Span(buf_to_write)) + # if err: + # return Int64(bytes_written), err - # Write the buffer to the writer, if we hit EOF it's fine. That's not a failure condition. - var result = writer.write(self.buf[self.read_pos : self.write_pos]) - if result.error: - return Result(Int64(result.value), result.error) - - var bytes_written = result.value - if bytes_written < 0: - panic(ERR_NEGATIVE_WRITE) + # if bytes_written < 0: + # panic(ERR_NEGATIVE_WRITE) - self.read_pos += bytes_written - return Int64(bytes_written) + # self.read_pos += bytes_written + # return Int64(bytes_written), Error() # fn new_reader_size[R: io.Reader](owned reader: R, size: Int) -> Reader[R]: @@ -630,7 +611,7 @@ struct Reader[R: io.Reader]( # # return b # var r = Reader(reader ^) -# r.reset(List[Byte](capacity=max(size, MIN_READ_BUFFER_SIZE)), reader ^) +# r.reset(List[UInt8](capacity=max(size, MIN_READ_BUFFER_SIZE)), reader ^) # return r @@ -648,9 +629,7 @@ struct Reader[R: io.Reader]( # buffered output # TODO: Reader and Writer maybe should not take ownership of the underlying reader/writer? Seems okay for now. -struct Writer[W: io.Writer]( - Sized, io.Writer, io.ByteWriter, io.StringWriter, io.ReaderFrom -): +struct Writer[W: io.Writer](Sized, io.Writer, io.ByteWriter, io.StringWriter): """Implements buffering for an [io.Writer] object. # If an error occurs writing to a [Writer], no more data will be # accepted and all subsequent writes, and [Writer.flush], will return the error. @@ -658,27 +637,27 @@ struct Writer[W: io.Writer]( # [Writer.flush] method to guarantee all data has been forwarded to # the underlying [io.Writer].""" - var buf: List[Byte] + var buf: List[UInt8] var bytes_written: Int var writer: W - var err: Optional[WrappedError] + var err: Error fn __init__( inout self, owned writer: W, - buf: List[Byte] = List[Byte](capacity=DEFAULT_BUF_SIZE), + buf: List[UInt8] = List[UInt8](capacity=DEFAULT_BUF_SIZE), bytes_written: Int = 0, ): self.buf = buf self.bytes_written = bytes_written - self.writer = writer ^ - self.err = None + self.writer = writer^ + self.err = Error() fn __moveinit__(inout self, owned existing: Self): - self.buf = existing.buf ^ + self.buf = existing.buf^ self.bytes_written = existing.bytes_written - self.writer = existing.writer ^ - self.err = existing.err ^ + self.writer = existing.writer^ + self.err = existing.err^ fn __len__(self) -> Int: """Returns the size of the underlying buffer in bytes.""" @@ -701,46 +680,46 @@ struct Writer[W: io.Writer]( # return # if self.buf == nil: - # self.buf = make(List[Byte], DEFAULT_BUF_SIZE) + # self.buf = make(List[UInt8], DEFAULT_BUF_SIZE) - self.err = None + self.err = Error() self.bytes_written = 0 - self.writer = writer ^ + self.writer = writer^ - fn flush(inout self) -> Optional[WrappedError]: + fn flush(inout self) -> Error: """Writes any buffered data to the underlying [io.Writer].""" # Prior to attempting to flush, check if there's a pre-existing error or if there's nothing to flush. + var err = Error() if self.err: return self.err if self.bytes_written == 0: - return None + return err - var result = self.writer.write(self.buf[0 : self.bytes_written]) - var bytes_written = result.value - var error = result.get_error() + var bytes_written: Int = 0 + bytes_written, err = self.writer.write(Span(self.buf[0 : self.bytes_written])) # If the write was short, set a short write error and try to shift up the remaining bytes. - if bytes_written < self.bytes_written and not error: - error = WrappedError(io.ERR_SHORT_WRITE) + if bytes_written < self.bytes_written and not err: + err = Error(io.ERR_SHORT_WRITE) - if error: + if err: if bytes_written > 0 and bytes_written < self.bytes_written: _ = copy(self.buf, self.buf[bytes_written : self.bytes_written]) self.bytes_written -= bytes_written - self.err = error - return error + self.err = err + return err # Reset the buffer - self.buf = List[Byte](capacity=self.buf.capacity) + self.buf = List[UInt8](capacity=self.buf.capacity) self.bytes_written = 0 - return None + return err fn available(self) -> Int: """Returns how many bytes are unused in the buffer.""" return self.buf.capacity - len(self.buf) - fn available_buffer(self) raises -> List[Byte]: + fn available_buffer(self) raises -> List[UInt8]: """Returns an empty buffer with self.available() capacity. This buffer is intended to be appended to and passed to an immediately succeeding [Writer.write] call. @@ -759,7 +738,7 @@ struct Writer[W: io.Writer]( """ return self.bytes_written - fn write(inout self, src: List[Byte]) -> Result[Int]: + fn write(inout self, src: Span[UInt8]) -> (Int, Error): """Writes the contents of src into the buffer. It returns the number of bytes written. If nn < len(src), it also returns an error explaining @@ -773,14 +752,14 @@ struct Writer[W: io.Writer]( """ var total_bytes_written: Int = 0 var src_copy = src + var err = Error() while len(src_copy) > self.available() and not self.err: - var bytes_written: Int + var bytes_written: Int = 0 if self.buffered() == 0: # Large write, empty buffer. # write directly from p to avoid copy. - var result = self.writer.write(src_copy) - bytes_written = result.value - self.err = result.get_error() + bytes_written, err = self.writer.write(src_copy) + self.err = err else: bytes_written = copy(self.buf, src_copy, self.bytes_written) self.bytes_written += bytes_written @@ -790,30 +769,30 @@ struct Writer[W: io.Writer]( src_copy = src_copy[bytes_written : len(src_copy)] if self.err: - return Result(total_bytes_written, self.err) + return total_bytes_written, self.err var n = copy(self.buf, src_copy, self.bytes_written) self.bytes_written += n total_bytes_written += n - return total_bytes_written + return total_bytes_written, err - fn write_byte(inout self, src: Int8) -> Result[Int]: + fn write_byte(inout self, src: UInt8) -> (Int, Error): """Writes a single byte to the internal buffer. Args: src: The byte to write. """ if self.err: - return Result(0, self.err) + return 0, self.err # If buffer is full, flush to the underlying writer. var err = self.flush() if self.available() <= 0 and err: - return Result(0, self.err) + return 0, self.err self.buf.append(src) self.bytes_written += 1 - return 1 + return 1, Error() # # WriteRune writes a single Unicode code point, returning # # the number of bytes written and any error. @@ -843,7 +822,7 @@ struct Writer[W: io.Writer]( # self.bytes_written += size # return size, nil - fn write_string(inout self, src: String) -> Result[Int]: + fn write_string(inout self, src: String) -> (Int, Error): """Writes a string to the internal buffer. It returns the number of bytes written. If the count is less than len(s), it also returns an error explaining @@ -855,9 +834,9 @@ struct Writer[W: io.Writer]( Returns: The number of bytes written. """ - return self.write(src.as_bytes()) + return self.write(src.as_bytes_slice()) - fn read_from[R: io.Reader](inout self, inout reader: R) -> Result[Int64]: + fn read_from[R: io.Reader](inout self, inout reader: R) -> (Int64, Error): """Implements [io.ReaderFrom]. If the underlying writer supports the read_from method, this calls the underlying read_from. If there is buffered data and an underlying read_from, this fills @@ -870,48 +849,47 @@ struct Writer[W: io.Writer]( The number of bytes read. """ if self.err: - return Result(Int64(0), self.err) + return Int64(0), self.err var bytes_read: Int = 0 var total_bytes_written: Int64 = 0 - var err: Optional[WrappedError] = None + var err = Error() while True: if self.available() == 0: var err = self.flush() if err: - return Result(total_bytes_written, err) + return total_bytes_written, err var nr = 0 while nr < MAX_CONSECUTIVE_EMPTY_READS: # TODO: should really be using a slice that returns refs and not a copy. # Read into remaining unused space in the buffer. We need to reserve capacity for the slice otherwise read will never hit EOF. - var sl = self.buf[self.bytes_written:len(self.buf)] + var sl = self.buf[self.bytes_written : len(self.buf)] sl.reserve(self.buf.capacity) - var result = reader.read(sl) - bytes_read = result.value - err = result.get_error() - _ = copy(self.buf, sl, self.bytes_written) - + bytes_read, err = reader.read(sl) + if bytes_read > 0: + bytes_read = copy(self.buf, sl, self.bytes_written) + if bytes_read != 0 or err: break nr += 1 if nr == MAX_CONSECUTIVE_EMPTY_READS: - return Result(Int64(bytes_read), WrappedError(io.ERR_NO_PROGRESS)) + return Int64(bytes_read), Error(io.ERR_NO_PROGRESS) self.bytes_written += bytes_read total_bytes_written += Int64(bytes_read) if err: break - if err and str(err.value()) == io.EOF: + if err and str(err) == io.EOF: # If we filled the buffer exactly, flush preemptively. if self.available() == 0: err = self.flush() else: - err = None - - return Result(total_bytes_written, None) + err = Error() + + return total_bytes_written, Error() fn new_writer_size[W: io.Writer](owned writer: W, size: Int) -> Writer[W]: @@ -928,8 +906,8 @@ fn new_writer_size[W: io.Writer](owned writer: W, size: Int) -> Writer[W]: buf_size = DEFAULT_BUF_SIZE return Writer[W]( - buf=List[Byte](capacity=size), - writer=writer ^, + buf=List[UInt8](capacity=size), + writer=writer^, bytes_written=0, ) @@ -938,7 +916,7 @@ fn new_writer[W: io.Writer](owned writer: W) -> Writer[W]: """Returns a new [Writer] whose buffer has the default size. # If the argument io.Writer is already a [Writer] with large enough buffer size, # it returns the underlying [Writer].""" - return new_writer_size[W](writer ^, DEFAULT_BUF_SIZE) + return new_writer_size[W](writer^, DEFAULT_BUF_SIZE) # buffered input and output @@ -950,13 +928,11 @@ struct ReadWriter[R: io.Reader, W: io.Writer](): var writer: W fn __init__(inout self, owned reader: R, owned writer: W): - self.reader = reader ^ - self.writer = writer ^ + self.reader = reader^ + self.writer = writer^ # new_read_writer -fn new_read_writer[ - R: io.Reader, W: io.Writer -](owned reader: Reader, owned writer: Writer) -> ReadWriter[R, W]: +fn new_read_writer[R: io.Reader, W: io.Writer](owned reader: R, owned writer: W) -> ReadWriter[R, W]: """Allocates a new [ReadWriter] that dispatches to r and w.""" - return ReadWriter[R, W](reader ^, writer ^) + return ReadWriter[R, W](reader^, writer^) diff --git a/external/gojo/bufio/scan.mojo b/external/gojo/bufio/scan.mojo index 3c31f0bc..046cc87b 100644 --- a/external/gojo/bufio/scan.mojo +++ b/external/gojo/bufio/scan.mojo @@ -1,13 +1,12 @@ import math from collections import Optional import ..io -from ..builtins import copy, panic, WrappedError, Result +from ..builtins import copy, panic, Error from ..builtins.bytes import Byte, index_byte from .bufio import MAX_CONSECUTIVE_EMPTY_READS alias MAX_INT: Int = 2147483647 -alias Err = Optional[WrappedError] struct Scanner[R: io.Reader](): @@ -37,7 +36,7 @@ struct Scanner[R: io.Reader](): var empties: Int # Count of successive empty tokens. var scan_called: Bool # Scan has been called; buffer is in use. var done: Bool # Scan has finished. - var err: Err + var err: Error fn __init__( inout self, @@ -52,7 +51,7 @@ struct Scanner[R: io.Reader](): scan_called: Bool = False, done: Bool = False, ): - self.reader = reader ^ + self.reader = reader^ self.split = split self.max_token_size = max_token_size self.token = token @@ -62,7 +61,7 @@ struct Scanner[R: io.Reader](): self.empties = empties self.scan_called = scan_called self.done = done - self.err = Err() + self.err = Error() fn current_token_as_bytes(self) -> List[Byte]: """Returns the most recent token generated by a call to [Scanner.Scan]. @@ -98,15 +97,13 @@ struct Scanner[R: io.Reader](): if (self.end > self.start) or self.err: var advance: Int var token = List[Byte](capacity=io.BUFFER_SIZE) - var err: Optional[WrappedError] = None + var err = Error() var at_eof = False if self.err: at_eof = True - advance = self.split( - self.buf[self.start : self.end], at_eof, token, err - ) + advance, token, err = self.split(self.buf[self.start : self.end], at_eof) if err: - if str(err.value()) == ERR_FINAL_TOKEN: + if str(err) == str(ERR_FINAL_TOKEN): self.token = token self.done = True # When token is not nil, it means the scanning stops @@ -114,7 +111,7 @@ struct Scanner[R: io.Reader](): # should be True to indicate the existence of the token. return len(token) != 0 - self.set_err(err.value()) + self.set_err(err) return False if not self.advance(advance): @@ -128,9 +125,7 @@ struct Scanner[R: io.Reader](): # Returning tokens not advancing input at EOF. self.empties += 1 if self.empties > MAX_CONSECUTIVE_EMPTY_READS: - panic( - "bufio.Scan: too many empty tokens without progressing" - ) + panic("bufio.Scan: too many empty tokens without progressing") return True @@ -145,9 +140,7 @@ struct Scanner[R: io.Reader](): # Must read more data. # First, shift data to beginning of buffer if there's lots of empty space # or space is needed. - if self.start > 0 and ( - self.end == len(self.buf) or self.start > int(len(self.buf) / 2) - ): + if self.start > 0 and (self.end == len(self.buf) or self.start > int(len(self.buf) / 2)): _ = copy(self.buf, self.buf[self.start : self.end]) self.end -= self.start self.start = 0 @@ -155,10 +148,8 @@ struct Scanner[R: io.Reader](): # Is the buffer full? If so, resize. if self.end == len(self.buf): # Guarantee no overflow in the multiplication below. - if len(self.buf) >= self.max_token_size or len(self.buf) > int( - MAX_INT / 2 - ): - self.set_err(WrappedError(ERR_TOO_LONG)) + if len(self.buf) >= self.max_token_size or len(self.buf) > int(MAX_INT / 2): + self.set_err(Error(str(ERR_TOO_LONG))) return False var new_size = len(self.buf) * 2 @@ -166,7 +157,7 @@ struct Scanner[R: io.Reader](): new_size = START_BUF_SIZE # Make a new List[Byte] buffer and copy the elements in - new_size = math.min(new_size, self.max_token_size) + new_size = min(new_size, self.max_token_size) var new_buf = List[Byte](capacity=new_size) _ = copy(new_buf, self.buf[self.start : self.end]) self.buf = new_buf @@ -178,22 +169,20 @@ struct Scanner[R: io.Reader](): # be extra careful: Scanner is for safe, simple jobs. var loop = 0 while True: - var bytes_read: Int = 0 + var bytes_read: Int var sl = self.buf[self.end : len(self.buf)] - var error: Optional[WrappedError] = None + var err: Error # Catch any reader errors and set the internal error field to that err instead of bubbling it up. - var result = self.reader.read(sl) - bytes_read = result.value - error = result.get_error() + bytes_read, err = self.reader.read(sl) _ = copy(self.buf, sl, self.end) if bytes_read < 0 or len(self.buf) - self.end < bytes_read: - self.set_err(WrappedError(ERR_BAD_READ_COUNT)) + self.set_err(Error(str(ERR_BAD_READ_COUNT))) break self.end += bytes_read - if error: - self.set_err(error) + if err: + self.set_err(err) break if bytes_read > 0: @@ -202,17 +191,17 @@ struct Scanner[R: io.Reader](): loop += 1 if loop > MAX_CONSECUTIVE_EMPTY_READS: - self.set_err(WrappedError(io.ERR_NO_PROGRESS)) + self.set_err(Error(io.ERR_NO_PROGRESS)) break - fn set_err(inout self, err: Err): + fn set_err(inout self, err: Error): """Set the internal error field to the provided error. Args: err: The error to set. """ if self.err: - var value = String(self.err.value()) + var value = str(self.err) if value == "" or value == io.EOF: self.err = err else: @@ -228,11 +217,11 @@ struct Scanner[R: io.Reader](): True if the advance was legal, False otherwise. """ if n < 0: - self.set_err(WrappedError(ERR_NEGATIVE_ADVANCE)) + self.set_err(Error(str(ERR_NEGATIVE_ADVANCE))) return False if n > self.end - self.start: - self.set_err(WrappedError(ERR_ADVANCE_TOO_FAR)) + self.set_err(Error(str(ERR_ADVANCE_TOO_FAR))) return False self.start += n @@ -297,19 +286,12 @@ struct Scanner[R: io.Reader](): # The function is never called with an empty data slice unless at_eof # is True. If at_eof is True, however, data may be non-empty and, # as always, holds unprocessed text. -# TODO: For now, passing in token and err to be modified by the SplitFunction. This is because the Tuple return unpacking cannot handle a memory only type (List[Byte] and Err) -alias SplitFunction = fn ( - data: List[Byte], at_eof: Bool, inout token: List[Byte], inout err: Err -) raises -> Int +alias SplitFunction = fn (data: List[Byte], at_eof: Bool) -> (Int, List[Byte], Error) # # Errors returned by Scanner. alias ERR_TOO_LONG = Error("bufio.Scanner: token too long") -alias ERR_NEGATIVE_ADVANCE = Error( - "bufio.Scanner: SplitFunction returns negative advance count" -) -alias ERR_ADVANCE_TOO_FAR = Error( - "bufio.Scanner: SplitFunction returns advance count beyond input" -) +alias ERR_NEGATIVE_ADVANCE = Error("bufio.Scanner: SplitFunction returns negative advance count") +alias ERR_ADVANCE_TOO_FAR = Error("bufio.Scanner: SplitFunction returns advance count beyond input") alias ERR_BAD_READ_COUNT = Error("bufio.Scanner: Read returned impossible count") # ERR_FINAL_TOKEN is a special sentinel error value. It is Intended to be # returned by a split function to indicate that the scanning should stop @@ -329,25 +311,22 @@ alias ERR_FINAL_TOKEN = Error("final token") # The actual maximum token size may be smaller as the buffer # may need to include, for instance, a newline. alias MAX_SCAN_TOKEN_SIZE = 64 * 1024 -alias START_BUF_SIZE = 4096 # Size of initial allocation for buffer. +alias START_BUF_SIZE = 8200 # Size of initial allocation for buffer. fn new_scanner[R: io.Reader](owned reader: R) -> Scanner[R]: """Returns a new [Scanner] to read from r. The split function defaults to [scan_lines].""" - return Scanner(reader ^) + return Scanner(reader^) ###### split functions ###### -fn scan_bytes( - data: List[Byte], at_eof: Bool, inout token: List[Byte], inout err: Err -) raises -> Int: +fn scan_bytes(data: List[Byte], at_eof: Bool) -> (Int, List[Byte], Error): """Split function for a [Scanner] that returns each byte as a token.""" if at_eof and data.capacity == 0: - return 0 + return 0, List[Byte](), Error() - token = data[0:1] - return 1 + return 1, data[0:1], Error() # var errorRune = List[Byte](string(utf8.RuneError)) @@ -390,7 +369,7 @@ fn scan_bytes( # return 1, errorRune, nil -fn drop_carriage_return(data: List[Byte]) raises -> List[Byte]: +fn drop_carriage_return(data: List[Byte]) -> List[Byte]: """Drops a terminal \r from the data. Args: @@ -407,9 +386,7 @@ fn drop_carriage_return(data: List[Byte]) raises -> List[Byte]: # TODO: Doing modification of token and err in these split functions, so we don't have to return any memory only types as part of the return tuple. -fn scan_lines( - data: List[Byte], at_eof: Bool, inout token: List[Byte], inout err: Err -) raises -> Int: +fn scan_lines(data: List[Byte], at_eof: Bool) -> (Int, List[Byte], Error): """Split function for a [Scanner] that returns each line of text, stripped of any trailing end-of-line marker. The returned line may be empty. The end-of-line marker is one optional carriage return followed @@ -419,30 +396,26 @@ fn scan_lines( Args: data: The data to split. at_eof: Whether the data is at the end of the file. - token: The token to return. - err: The error to return. Returns: The number of bytes to advance the input. """ if at_eof and data.capacity == 0: - return 0 + return 0, List[Byte](), Error() var i = index_byte(data, ord("\n")) - token = drop_carriage_return(data[0:i]) if i >= 0: # We have a full newline-terminated line. - return i + 1 + return i + 1, drop_carriage_return(data[0:i]), Error() # If we're at EOF, we have a final, non-terminated line. Return it. - token = drop_carriage_return(data) # if at_eof: - return data.capacity + return data.capacity, drop_carriage_return(data), Error() # Request more data. # return 0 -fn is_space(r: Int8) -> Bool: +fn is_space(r: UInt8) -> Bool: alias ALL_WHITESPACES: String = " \t\n\r\x0b\f" if chr(int(r)) in ALL_WHITESPACES: return True @@ -450,9 +423,7 @@ fn is_space(r: Int8) -> Bool: # TODO: Handle runes and utf8 decoding. For now, just assuming single byte length. -fn scan_words( - data: List[Byte], at_eof: Bool, inout token: List[Byte], inout err: Err -) raises -> Int: +fn scan_words(data: List[Byte], at_eof: Bool) -> (Int, List[Byte], Error): """Split function for a [Scanner] that returns each space-separated word of text, with surrounding spaces deleted. It will never return an empty string. The definition of space is set by @@ -475,15 +446,13 @@ fn scan_words( while i < data.capacity: width = len(data[i]) if is_space(data[i]): - token = data[start:i] - return i + width + return i + width, data[start:i], Error() i += width # If we're at EOF, we have a final, non-empty, non-terminated word. Return it. if at_eof and data.capacity > start: - token = data[start:] - return data.capacity + return data.capacity, data[start:], Error() # Request more data. - return start + return start, List[Byte](), Error() diff --git a/external/gojo/builtins/__init__.mojo b/external/gojo/builtins/__init__.mojo index 3168c9dc..3d1e11aa 100644 --- a/external/gojo/builtins/__init__.mojo +++ b/external/gojo/builtins/__init__.mojo @@ -1,5 +1,6 @@ from .bytes import Byte, index_byte, has_suffix, has_prefix, to_string from .list import equals -from .result import Result, WrappedError from .attributes import cap, copy from .errors import exit, panic + +alias Rune = Int32 diff --git a/external/gojo/builtins/attributes.mojo b/external/gojo/builtins/attributes.mojo index 86417b45..2bf21747 100644 --- a/external/gojo/builtins/attributes.mojo +++ b/external/gojo/builtins/attributes.mojo @@ -1,6 +1,4 @@ -fn copy[ - T: CollectionElement -](inout target: List[T], source: List[T], start: Int = 0) -> Int: +fn copy[T: CollectionElement](inout target: List[T], source: List[T], start: Int = 0) -> Int: """Copies the contents of source into target at the same index. Returns the number of bytes copied. Added a start parameter to specify the index to start copying into. @@ -24,6 +22,27 @@ fn copy[ return count +fn copy[T: CollectionElement](inout target: Span[T, True], source: Span[T], start: Int = 0) -> Int: + """Copies the contents of source into target at the same index. Returns the number of bytes copied. + Added a start parameter to specify the index to start copying into. + + Args: + target: The buffer to copy into. + source: The buffer to copy from. + start: The index to start copying into. + + Returns: + The number of bytes copied. + """ + var count = 0 + + for i in range(len(source)): + target[i + start] = source[i] + count += 1 + + return count + + fn cap[T: CollectionElement](iterable: List[T]) -> Int: """Returns the capacity of the List. diff --git a/external/gojo/builtins/bytes.mojo b/external/gojo/builtins/bytes.mojo index 2d72ee49..d8ba4066 100644 --- a/external/gojo/builtins/bytes.mojo +++ b/external/gojo/builtins/bytes.mojo @@ -1,7 +1,4 @@ -from .list import equals - - -alias Byte = Int8 +alias Byte = UInt8 fn has_prefix(bytes: List[Byte], prefix: List[Byte]) -> Bool: @@ -35,21 +32,20 @@ fn has_suffix(bytes: List[Byte], suffix: List[Byte]) -> Bool: fn index_byte(bytes: List[Byte], delim: Byte) -> Int: - """Return the index of the first occurrence of the byte delim. + """Return the index of the first occurrence of the byte delim. - Args: - bytes: The List[Byte] struct to search. - delim: The byte to search for. + Args: + bytes: The List[Byte] struct to search. + delim: The byte to search for. - Returns: - The index of the first occurrence of the byte delim. - """ - var i = 0 - for i in range(len(bytes)): - if bytes[i] == delim: - return i + Returns: + The index of the first occurrence of the byte delim. + """ + for i in range(len(bytes)): + if bytes[i] == delim: + return i - return -1 + return -1 fn to_string(bytes: List[Byte]) -> String: diff --git a/external/gojo/builtins/errors.mojo b/external/gojo/builtins/errors.mojo index 6488453e..19a0bd10 100644 --- a/external/gojo/builtins/errors.mojo +++ b/external/gojo/builtins/errors.mojo @@ -1,9 +1,4 @@ -fn exit(code: Int): - """Exits the program with the given exit code via libc. - TODO: Using this in the meantime until Mojo has a built in way to panic/exit. - """ - var status = external_call["exit", Int, Int](code) - +from sys import exit fn panic[T: Stringable](message: T, code: Int = 1): """Panics the program with the given message and exit code. diff --git a/external/gojo/builtins/list.mojo b/external/gojo/builtins/list.mojo deleted file mode 100644 index ddbfdbb9..00000000 --- a/external/gojo/builtins/list.mojo +++ /dev/null @@ -1,133 +0,0 @@ -fn equals(left: List[Int8], right: List[Int8]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[UInt8], right: List[UInt8]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[Int16], right: List[Int16]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[UInt16], right: List[UInt16]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[Int32], right: List[Int32]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[UInt32], right: List[UInt32]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[Int64], right: List[Int64]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[UInt64], right: List[UInt64]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[Int], right: List[Int]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[Float16], right: List[Float16]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[Float32], right: List[Float32]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[Float64], right: List[Float64]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[String], right: List[String]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[StringLiteral], right: List[StringLiteral]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True - - -fn equals(left: List[Bool], right: List[Bool]) -> Bool: - if len(left) != len(right): - return False - for i in range(len(left)): - if left[i] != right[i]: - return False - return True \ No newline at end of file diff --git a/external/gojo/builtins/result.mojo b/external/gojo/builtins/result.mojo deleted file mode 100644 index 38db539f..00000000 --- a/external/gojo/builtins/result.mojo +++ /dev/null @@ -1,51 +0,0 @@ -from collections.optional import Optional - - -@value -struct WrappedError(CollectionElement, Stringable): - """Wrapped Error struct is just to enable the use of optional Errors.""" - - var error: Error - - fn __init__(inout self, error: Error): - self.error = error - - fn __init__[T: Stringable](inout self, message: T): - self.error = Error(message) - - fn __str__(self) -> String: - return str(self.error) - - -alias ValuePredicateFn = fn[T: CollectionElement] (value: T) -> Bool -alias ErrorPredicateFn = fn (error: Error) -> Bool - - -@value -struct Result[T: CollectionElement](): - var value: T - var error: Optional[WrappedError] - - fn __init__( - inout self, - value: T, - error: Optional[WrappedError] = None, - ): - self.value = value - self.error = error - - fn has_error(self) -> Bool: - if self.error: - return True - return False - - fn has_error_and(self, f: ErrorPredicateFn) -> Bool: - if self.error: - return f(self.error.value().error) - return False - - fn get_error(self) -> Optional[WrappedError]: - return self.error - - fn unwrap_error(self) -> WrappedError: - return self.error.value().error diff --git a/external/gojo/bytes/buffer.mojo b/external/gojo/bytes/buffer.mojo index 8d2eafd4..33c66182 100644 --- a/external/gojo/bytes/buffer.mojo +++ b/external/gojo/bytes/buffer.mojo @@ -1,16 +1,5 @@ -from collections.optional import Optional -from ..io import ( - Reader, - Writer, - ReadWriter, - ByteReader, - ByteWriter, - WriterTo, - StringWriter, - ReaderFrom, - BUFFER_SIZE, -) -from ..builtins import cap, copy, Byte, Result, WrappedError, panic, index_byte +import ..io +from ..builtins import cap, copy, Byte, panic, index_byte alias Rune = Int32 @@ -46,17 +35,19 @@ alias ERR_NEGATIVE_READ = "buffer.Buffer: reader returned negative count from re alias ERR_SHORT_WRITE = "short write" +# TODO: Removed read_from and write_to for now. Until the span arg trait issue is resolved. +# https://github.com/modularml/mojo/issues/2917 @value struct Buffer( Copyable, Stringable, Sized, - ReadWriter, - StringWriter, - ByteReader, - ByteWriter, - WriterTo, - ReaderFrom, + io.ReadWriter, + io.StringWriter, + io.ByteReader, + io.ByteWriter, + # WriterTo, + # ReaderFrom, ): """A Buffer is a variable-sized buffer of bytes with [Buffer.read] and [Buffer.write] methods. The zero value for Buffer is an empty buffer ready to use. @@ -217,7 +208,7 @@ struct Buffer( var m = self.grow(n) self.buf = self.buf[:m] - fn write(inout self, src: List[Byte]) -> Result[Int]: + fn write(inout self, src: Span[Byte]) -> (Int, Error): """Appends the contents of p to the buffer, growing the buffer as needed. The return value n is the length of p; err is always nil. If the buffer becomes too large, write will panic with [ERR_TOO_LARGE]. @@ -236,9 +227,9 @@ struct Buffer( write_at = self.grow(len(src)) var bytes_written = copy(self.buf, src, write_at) - return Result(bytes_written, None) + return bytes_written, Error() - fn write_string(inout self, src: String) -> Result[Int]: + fn write_string(inout self, src: String) -> (Int, Error): """Appends the contents of s to the buffer, growing the buffer as needed. The return value n is the length of s; err is always nil. If the buffer becomes too large, write_string will panic with [ERR_TOO_LARGE]. @@ -256,38 +247,40 @@ struct Buffer( # if not ok: # m = self.grow(len(src)) # var b = self.buf[m:] - return self.write(src.as_bytes()) + return self.write(src.as_bytes_slice()) - fn read_from[R: Reader](inout self, inout reader: R) -> Result[Int64]: - """Reads data from r until EOF and appends it to the buffer, growing - the buffer as needed. The return value n is the number of bytes read. Any - error except io.EOF encountered during the read is also returned. If the - buffer becomes too large, read_from will panic with [ERR_TOO_LARGE]. + # fn read_from[R: Reader](inout self, inout reader: R) -> (Int64, Error): + # """Reads data from r until EOF and appends it to the buffer, growing + # the buffer as needed. The return value n is the number of bytes read. Any + # error except io.EOF encountered during the read is also returned. If the + # buffer becomes too large, read_from will panic with [ERR_TOO_LARGE]. - Args: - reader: The reader to read from. + # Args: + # reader: The reader to read from. - Returns: - The number of bytes read from the reader. - """ - self.last_read = OP_INVALID - var total_bytes_read: Int64 = 0 - while True: - _ = self.grow(MIN_READ) + # Returns: + # The number of bytes read from the reader. + # """ + # self.last_read = OP_INVALID + # var total_bytes_read: Int64 = 0 + # while True: + # _ = self.grow(MIN_READ) - var result = reader.read(self.buf) - var bytes_read = result.value - if bytes_read < 0: - panic(ERR_NEGATIVE_READ) + # var span = Span(self.buf) + # var bytes_read: Int + # var err: Error + # bytes_read, err = reader.read(span) + # if bytes_read < 0: + # panic(ERR_NEGATIVE_READ) - total_bytes_read += bytes_read + # total_bytes_read += bytes_read - if result.has_error(): - var error = result.get_error() - if String(error.value()) == io.EOF: - return Result(total_bytes_read, None) + # var err_message = str(err) + # if err_message != "": + # if err_message == io.EOF: + # return total_bytes_read, Error() - return Result(total_bytes_read, error) + # return total_bytes_read, err fn grow_slice(self, inout b: List[Byte], n: Int) -> List[Byte]: """Grows b by n, preserving the original content of self. @@ -318,46 +311,47 @@ struct Buffer( # b._vector.reserve(c) return resized_buffer[: b.capacity] - fn write_to[W: Writer](inout self, inout writer: W) -> Result[Int64]: - """Writes data to w until the buffer is drained or an error occurs. - The return value n is the number of bytes written; it always fits into an - Int, but it is int64 to match the io.WriterTo trait. Any error - encountered during the write is also returned. - - Args: - writer: The writer to write to. - - Returns: - The number of bytes written to the writer. - """ - self.last_read = OP_INVALID - var bytes_to_write = len(self.buf) - var total_bytes_written: Int64 = 0 - - if bytes_to_write > 0: - # TODO: Replace usage of this intermeidate slice when normal slicing, once slice references work. - var sl = self.buf[self.off : bytes_to_write] - var result = writer.write(sl) - var bytes_written = result.value - if bytes_written > bytes_to_write: - panic("bytes.Buffer.write_to: invalid write count") - - self.off += bytes_written - total_bytes_written = Int64(bytes_written) + # fn write_to[W: Writer](inout self, inout writer: W) -> (Int64, Error): + # """Writes data to w until the buffer is drained or an error occurs. + # The return value n is the number of bytes written; it always fits into an + # Int, but it is int64 to match the io.WriterTo trait. Any error + # encountered during the write is also returned. - if result.has_error(): - var error = result.get_error() - return Result(total_bytes_written, error) + # Args: + # writer: The writer to write to. - # all bytes should have been written, by definition of write method in io.Writer - if bytes_written != bytes_to_write: - return Result(total_bytes_written, WrappedError(ERR_SHORT_WRITE)) - - # Buffer is now empty; reset. - self.reset() - return Result(total_bytes_written, None) - - fn write_byte(inout self, byte: Byte) -> Result[Int]: + # Returns: + # The number of bytes written to the writer. + # """ + # self.last_read = OP_INVALID + # var bytes_to_write = len(self.buf) + # var total_bytes_written: Int64 = 0 + + # if bytes_to_write > 0: + # # TODO: Replace usage of this intermeidate slice when normal slicing, once slice references work. + # var sl = Span(self.buf[self.off : bytes_to_write]) + # var bytes_written: Int + # var err: Error + # bytes_written, err = writer.write(sl) + # if bytes_written > bytes_to_write: + # panic("bytes.Buffer.write_to: invalid write count") + + # self.off += bytes_written + # total_bytes_written = Int64(bytes_written) + + # var err_message = str(err) + # if err_message != "": + # return total_bytes_written, err + + # # all bytes should have been written, by definition of write method in io.Writer + # if bytes_written != bytes_to_write: + # return total_bytes_written, Error(ERR_SHORT_WRITE) + + # # Buffer is now empty; reset. + # self.reset() + # return total_bytes_written, Error() + + fn write_byte(inout self, byte: Byte) -> (Int, Error): """Appends the byte c to the buffer, growing the buffer as needed. The returned error is always nil, but is included to match [bufio.Writer]'s write_byte. If the buffer becomes too large, write_byte will panic with @@ -377,7 +371,7 @@ struct Buffer( write_at = self.grow(1) _ = copy(self.buf, List[Byte](byte), write_at) - return Result(write_at, None) + return write_at, Error() # fn write_rune(inout self, r: Rune) -> Int: # """Appends the UTF-8 encoding of Unicode code point r to the @@ -400,7 +394,7 @@ struct Buffer( # self.buf = utf8.AppendRune(self.buf[:write_at], r) # return len(self.buf) - write_at - fn read(inout self, inout dest: List[Byte]) -> Result[Int]: + fn read(inout self, inout dest: List[Byte]) -> (Int, Error): """Reads the next len(dest) bytes from the buffer or until the buffer is drained. The return value n is the number of bytes read. If the buffer has no data to return, err is io.EOF (unless len(dest) is zero); @@ -417,15 +411,15 @@ struct Buffer( # Buffer is empty, reset to recover space. self.reset() if dest.capacity == 0: - return Result(0, None) - return Result(0, WrappedError(io.EOF)) + return 0, Error() + return 0, Error(io.EOF) var bytes_read = copy(dest, self.buf[self.off : len(self.buf)]) self.off += bytes_read if bytes_read > 0: self.last_read = OP_READ - return Result(bytes_read, None) + return bytes_read, Error() fn next(inout self, number_of_bytes: Int) raises -> List[Byte]: """Returns a slice containing the next n bytes from the buffer, @@ -452,20 +446,20 @@ struct Buffer( return data - fn read_byte(inout self) -> Result[Byte]: + fn read_byte(inout self) -> (Byte, Error): """Reads and returns the next byte from the buffer. If no byte is available, it returns error io.EOF. """ if self.empty(): # Buffer is empty, reset to recover space. self.reset() - return Result(Byte(0), WrappedError(io.EOF)) + return Byte(0), Error(io.EOF) var byte = self.buf[self.off] self.off += 1 self.last_read = OP_READ - return byte + return byte, Error() # read_rune reads and returns the next UTF-8-encoded # Unicode code point from the buffer. @@ -507,25 +501,22 @@ struct Buffer( # var err_unread_byte = errors.New("buffer.Buffer: unread_byte: previous operation was not a successful read") - fn unread_byte(inout self) -> Optional[WrappedError]: + fn unread_byte(inout self) -> Error: """Unreads the last byte returned by the most recent successful read operation that read at least one byte. If a write has happened since the last read, if the last read returned an error, or if the read read zero bytes, unread_byte returns an error. """ if self.last_read == OP_INVALID: - return WrappedError( - "buffer.Buffer: unread_byte: previous operation was not a successful" - " read" - ) + return Error("buffer.Buffer: unread_byte: previous operation was not a successful read") self.last_read = OP_INVALID if self.off > 0: self.off -= 1 - return None + return Error() - fn read_bytes(inout self, delim: Byte) -> Result[List[Byte]]: + fn read_bytes(inout self, delim: Byte) -> (List[Byte], Error): """Reads until the first occurrence of delim in the input, returning a slice containing the data up to and including the delimiter. If read_bytes encounters an error before finding a delimiter, @@ -539,20 +530,19 @@ struct Buffer( Returns: A List[Byte] struct containing the data up to and including the delimiter. """ - var result = self.read_slice(delim) - var slice = result.value + var slice: List[Byte] + var err: Error + slice, err = self.read_slice(delim) # return a copy of slice. The buffer's backing array may # be overwritten by later calls. - var line = List[Byte](capacity=BUFFER_SIZE) + var line = List[Byte](capacity=io.BUFFER_SIZE) for i in range(len(slice)): line.append(slice[i]) - return line + return line, Error() - fn read_slice(inout self, delim: Byte) -> Result[List[Byte]]: + fn read_slice(inout self, delim: Byte) -> (List[Byte], Error): """Like read_bytes but returns a reference to internal buffer data. - TODO: not returning a reference yet. Also, this returns List[Byte] and Error in Go, - but we arent't returning Errors as values until Mojo tuple returns supports Memory Only types. Args: delim: The delimiter to read until. @@ -561,7 +551,7 @@ struct Buffer( A List[Byte] struct containing the data up to and including the delimiter. """ var at_eof = False - var i = index_byte(self.buf[self.off : len(self.buf)], (delim)) + var i = index_byte(self.buf[self.off : len(self.buf)], delim) var end = self.off + i + 1 if i < 0: @@ -573,11 +563,11 @@ struct Buffer( self.last_read = OP_READ if at_eof: - return Result(line, WrappedError(io.EOF)) + return line, Error(io.EOF) - return Result(line, None) + return line, Error() - fn read_string(inout self, delim: Byte) -> Result[String]: + fn read_string(inout self, delim: Byte) -> (String, Error): """Reads until the first occurrence of delim in the input, returning a string containing the data up to and including the delimiter. If read_string encounters an error before finding a delimiter, @@ -591,8 +581,11 @@ struct Buffer( Returns: A string containing the data up to and including the delimiter. """ - var result = self.read_slice(delim) - return Result(String(result.value), result.get_error()) + var slice: List[Byte] + var err: Error + slice, err = self.read_slice(delim) + slice.append(0) + return String(slice), err fn new_buffer() -> Buffer: @@ -606,8 +599,8 @@ fn new_buffer() -> Buffer: In most cases, new([Buffer]) (or just declaring a [Buffer] variable) is sufficient to initialize a [Buffer]. """ - var b = List[Byte](capacity=BUFFER_SIZE) - return Buffer(b ^) + var b = List[Byte](capacity=io.BUFFER_SIZE) + return Buffer(b^) fn new_buffer(owned buf: List[Byte]) -> Buffer: @@ -627,7 +620,7 @@ fn new_buffer(owned buf: List[Byte]) -> Buffer: Returns: A new [Buffer] initialized with the provided bytes. """ - return Buffer(buf ^) + return Buffer(buf^) fn new_buffer(owned s: String) -> Buffer: @@ -645,4 +638,4 @@ fn new_buffer(owned s: String) -> Buffer: A new [Buffer] initialized with the provided string. """ var bytes_buffer = List[Byte](s.as_bytes()) - return Buffer(bytes_buffer ^) + return Buffer(bytes_buffer^) diff --git a/external/gojo/bytes/reader.mojo b/external/gojo/bytes/reader.mojo index 7fe19454..0b91dcdc 100644 --- a/external/gojo/bytes/reader.mojo +++ b/external/gojo/bytes/reader.mojo @@ -1,5 +1,5 @@ from collections.optional import Optional -from ..builtins import cap, copy, Byte, Result, WrappedError, panic +from ..builtins import cap, copy, Byte, panic import ..io @@ -9,7 +9,7 @@ struct Reader( Sized, io.Reader, io.ReaderAt, - io.WriterTo, + # io.WriterTo, io.Seeker, io.ByteReader, io.ByteScanner, @@ -39,7 +39,7 @@ struct Reader( The result is unaffected by any method calls except [Reader.Reset].""" return len(self.buffer) - fn read(inout self, inout dest: List[Byte]) -> Result[Int]: + fn read(inout self, inout dest: List[Byte]) -> (Int, Error): """Reads from the internal buffer into the dest List[Byte] struct. Implements the [io.Reader] Interface. @@ -49,16 +49,16 @@ struct Reader( Returns: Int: The number of bytes read into dest.""" if self.index >= len(self.buffer): - return Result(0, WrappedError(io.EOF)) + return 0, Error(io.EOF) self.prev_rune = -1 var unread_bytes = self.buffer[int(self.index) : len(self.buffer)] var bytes_read = copy(dest, unread_bytes) self.index += bytes_read - return Result(bytes_read) + return bytes_read, Error() - fn read_at(self, inout dest: List[Byte], off: Int64) -> Result[Int]: + fn read_at(self, inout dest: List[Byte], off: Int64) -> (Int, Error): """Reads len(dest) bytes into dest beginning at byte offset off. Implements the [io.ReaderAt] Interface. @@ -71,40 +71,39 @@ struct Reader( """ # cannot modify state - see io.ReaderAt if off < 0: - return Result(0, WrappedError("bytes.Reader.read_at: negative offset")) + return 0, Error("bytes.Reader.read_at: negative offset") if off >= Int64(len(self.buffer)): - return Result(0, WrappedError(io.EOF)) + return 0, Error(io.EOF) var unread_bytes = self.buffer[int(off) : len(self.buffer)] var bytes_written = copy(dest, unread_bytes) if bytes_written < len(dest): - return Result(0, WrappedError(io.EOF)) + return 0, Error(io.EOF) - return bytes_written + return bytes_written, Error() - fn read_byte(inout self) -> Result[Byte]: - """Reads and returns a single byte from the internal buffer. Implements the [io.ByteReader] Interface. - """ + fn read_byte(inout self) -> (Byte, Error): + """Reads and returns a single byte from the internal buffer. Implements the [io.ByteReader] Interface.""" self.prev_rune = -1 if self.index >= len(self.buffer): - return Result(Int8(0), WrappedError(io.EOF)) + return UInt8(0), Error(io.EOF) var byte = self.buffer[int(self.index)] self.index += 1 - return byte + return byte, Error() - fn unread_byte(inout self) -> Optional[WrappedError]: + fn unread_byte(inout self) -> Error: """Unreads the last byte read by moving the read position back by one. Complements [Reader.read_byte] in implementing the [io.ByteScanner] Interface. """ if self.index <= 0: - return WrappedError("bytes.Reader.unread_byte: at beginning of slice") + return Error("bytes.Reader.unread_byte: at beginning of slice") self.prev_rune = -1 self.index -= 1 - return None + return Error() # # read_rune implements the [io.RuneReader] Interface. # fn read_rune(self) (ch rune, size Int, err error): @@ -133,7 +132,7 @@ struct Reader( # self.prev_rune = -1 # return nil - fn seek(inout self, offset: Int64, whence: Int) -> Result[Int64]: + fn seek(inout self, offset: Int64, whence: Int) -> (Int64, Error): """Moves the read position to the specified offset from the specified whence. Implements the [io.Seeker] Interface. @@ -154,38 +153,37 @@ struct Reader( elif whence == io.SEEK_END: position = len(self.buffer) + offset else: - return Result(Int64(0), WrappedError("bytes.Reader.seek: invalid whence")) + return Int64(0), Error("bytes.Reader.seek: invalid whence") if position < 0: - return Result( - Int64(0), WrappedError("bytes.Reader.seek: negative position") - ) + return Int64(0), Error("bytes.Reader.seek: negative position") self.index = position - return Result(position, None) + return position, Error() - fn write_to[W: io.Writer](inout self, inout writer: W) -> Result[Int64]: - """Writes data to w until the buffer is drained or an error occurs. - implements the [io.WriterTo] Interface. + # fn write_to[W: io.Writer](inout self, inout writer: W) -> (Int64, Error): + # """Writes data to w until the buffer is drained or an error occurs. + # implements the [io.WriterTo] Interface. - Args: - writer: The writer to write to. - """ - self.prev_rune = -1 - if self.index >= len(self.buffer): - return Result(Int64(0), None) + # Args: + # writer: The writer to write to. + # """ + # self.prev_rune = -1 + # if self.index >= len(self.buffer): + # return Int64(0), Error() - var bytes = self.buffer[int(self.index) : len(self.buffer)] - var result = writer.write(bytes) - var write_count = result.value - if write_count > len(bytes): - panic("bytes.Reader.write_to: invalid Write count") + # var bytes = Span(self.buffer[int(self.index) : len(self.buffer)]) + # var write_count: Int + # var err: Error + # write_count, err = writer.write(bytes) + # if write_count > len(bytes): + # panic("bytes.Reader.write_to: invalid Write count") - self.index += write_count - if write_count != len(bytes): - return Result(Int64(write_count), WrappedError(io.ERR_SHORT_WRITE)) + # self.index += write_count + # if write_count != len(bytes): + # return Int64(write_count), Error(io.ERR_SHORT_WRITE) - return Int64(write_count) + # return Int64(write_count), Error() fn reset(inout self, buffer: List[Byte]): """Resets the [Reader.Reader] to be reading from b. @@ -216,4 +214,3 @@ fn new_reader(buffer: String) -> Reader: """ return Reader(buffer.as_bytes(), 0, -1) - diff --git a/external/gojo/fmt/__init__.mojo b/external/gojo/fmt/__init__.mojo index fe5652b0..a4b04e30 100644 --- a/external/gojo/fmt/__init__.mojo +++ b/external/gojo/fmt/__init__.mojo @@ -1 +1 @@ -from .fmt import sprintf, printf +from .fmt import sprintf, printf, sprintf_str diff --git a/external/gojo/fmt/fmt.mojo b/external/gojo/fmt/fmt.mojo index 02c9617e..8997e50b 100644 --- a/external/gojo/fmt/fmt.mojo +++ b/external/gojo/fmt/fmt.mojo @@ -8,35 +8,41 @@ Boolean Integer %d base 10 +%q a single-quoted character literal. +%x base 16, with lower-case letters for a-f +%X base 16, with upper-case letters for A-F Floating-point and complex constituents: %f decimal point but no exponent, e.g. 123.456 String and slice of bytes (treated equivalently with these verbs): %s the uninterpreted bytes of the string or slice +%q a double-quoted string TODO: - Add support for more formatting options - Switch to buffered writing to avoid multiple string concatenations - Add support for width and precision formatting options +- Handle escaping for String's %q """ from utils.variant import Variant +from math import floor +from ..builtins import Byte - -alias Args = Variant[String, Int, Float64, Bool] +alias Args = Variant[String, Int, Float64, Bool, List[Byte]] fn replace_first(s: String, old: String, new: String) -> String: """Replace the first occurrence of a substring in a string. - Parameters: - s (str): The original string - old (str): The substring to be replaced - new (str): The new substring + Args: + s: The original string + old: The substring to be replaced + new: The new substring Returns: - String: The string with the first occurrence of the old substring replaced by the new one. + The string with the first occurrence of the old substring replaced by the new one. """ # Find the first occurrence of the old substring var index = s.find(old) @@ -49,57 +55,147 @@ fn replace_first(s: String, old: String, new: String) -> String: return s -fn format_string(s: String, arg: String) -> String: - return replace_first(s, String("%s"), arg) +fn find_first_verb(s: String, verbs: List[String]) -> String: + """Find the first occurrence of a verb in a string. + + Args: + s: The original string + verbs: The list of verbs to search for. + + Returns: + The verb to replace. + """ + var index = -1 + var verb: String = "" + + for v in verbs: + var i = s.find(v[]) + if i != -1 and (index == -1 or i < index): + index = i + verb = v[] + + return verb + + +alias BASE10_TO_BASE16 = List[String]("0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f") + + +fn convert_base10_to_base16(value: Int) -> String: + """Converts a base 10 number to base 16. + + Args: + value: Base 10 number. + + Returns: + Base 16 number as a String. + """ + + var val: Float64 = 0.0 + var result: Float64 = value + var base16: String = "" + while result > 1: + var temp = result / 16 + var floor_result = floor(temp) + var remainder = temp - floor_result + result = floor_result + val = 16 * remainder + + base16 = BASE10_TO_BASE16[int(val)] + base16 + + return base16 + + +fn format_string(format: String, arg: String) -> String: + var verb = find_first_verb(format, List[String]("%s", "%q")) + var arg_to_place = arg + if verb == "%q": + arg_to_place = '"' + arg + '"' + + return replace_first(format, String("%s"), arg) + + +fn format_bytes(format: String, arg: List[Byte]) -> String: + var argument = arg + if argument[-1] != 0: + argument.append(0) + + return format_string(format, argument) -fn format_integer(s: String, arg: Int) -> String: - return replace_first(s, String("%d"), arg) +fn format_integer(format: String, arg: Int) -> String: + var verb = find_first_verb(format, List[String]("%x", "%X", "%d", "%q")) + var arg_to_place = str(arg) + if verb == "%x": + arg_to_place = str(convert_base10_to_base16(arg)).lower() + elif verb == "%X": + arg_to_place = str(convert_base10_to_base16(arg)).upper() + elif verb == "%q": + arg_to_place = "'" + str(arg) + "'" + return replace_first(format, verb, arg_to_place) -fn format_float(s: String, arg: Float64) -> String: - return replace_first(s, String("%f"), arg) +fn format_float(format: String, arg: Float64) -> String: + return replace_first(format, str("%f"), str(arg)) -fn format_boolean(s: String, arg: Bool) -> String: - var value: String = "" + +fn format_boolean(format: String, arg: Bool) -> String: + var value: String = "False" if arg: value = "True" - else: - value = "False" - return replace_first(s, String("%t"), value) + return replace_first(format, String("%t"), value) + + +# If the number of arguments does not match the number of format specifiers +alias BadArgCount = "(BAD ARG COUNT)" -fn sprintf(formatting: String, *args: Args) raises -> String: +fn sprintf(formatting: String, *args: Args) -> String: var text = formatting - var formatter_count = formatting.count("%") + var raw_percent_count = formatting.count("%%") * 2 + var formatter_count = formatting.count("%") - raw_percent_count - if formatter_count > len(args): - raise Error("Not enough arguments for format string") - elif formatter_count < len(args): - raise Error("Too many arguments for format string") + if formatter_count != len(args): + return BadArgCount for i in range(len(args)): var argument = args[i] if argument.isa[String](): - text = format_string(text, argument.get[String]()[]) + text = format_string(text, argument[String]) + elif argument.isa[List[Byte]](): + text = format_bytes(text, argument[List[Byte]]) elif argument.isa[Int](): - text = format_integer(text, argument.get[Int]()[]) + text = format_integer(text, argument[Int]) elif argument.isa[Float64](): - text = format_float(text, argument.get[Float64]()[]) + text = format_float(text, argument[Float64]) elif argument.isa[Bool](): - text = format_boolean(text, argument.get[Bool]()[]) - else: - raise Error("Unknown for argument #" + String(i)) + text = format_boolean(text, argument[Bool]) return text -fn printf(formatting: String, *args: Args) raises: +# TODO: temporary until we have arg packing. +fn sprintf_str(formatting: String, args: List[String]) raises -> String: var text = formatting var formatter_count = formatting.count("%") + if formatter_count > len(args): + raise Error("Not enough arguments for format string") + elif formatter_count < len(args): + raise Error("Too many arguments for format string") + + for i in range(len(args)): + text = format_string(text, args[i]) + + return text + + +fn printf(formatting: String, *args: Args) raises: + var text = formatting + var raw_percent_count = formatting.count("%%") * 2 + var formatter_count = formatting.count("%") - raw_percent_count + if formatter_count > len(args): raise Error("Not enough arguments for format string") elif formatter_count < len(args): @@ -108,14 +204,16 @@ fn printf(formatting: String, *args: Args) raises: for i in range(len(args)): var argument = args[i] if argument.isa[String](): - text = format_string(text, argument.get[String]()[]) + text = format_string(text, argument[String]) + elif argument.isa[List[Byte]](): + text = format_bytes(text, argument[List[Byte]]) elif argument.isa[Int](): - text = format_integer(text, argument.get[Int]()[]) + text = format_integer(text, argument[Int]) elif argument.isa[Float64](): - text = format_float(text, argument.get[Float64]()[]) + text = format_float(text, argument[Float64]) elif argument.isa[Bool](): - text = format_boolean(text, argument.get[Bool]()[]) + text = format_boolean(text, argument[Bool]) else: - raise Error("Unknown for argument #" + String(i)) + raise Error("Unknown for argument #" + str(i)) print(text) diff --git a/external/gojo/io/__init__.mojo b/external/gojo/io/__init__.mojo index 979bee71..74b8a521 100644 --- a/external/gojo/io/__init__.mojo +++ b/external/gojo/io/__init__.mojo @@ -32,3 +32,8 @@ from .traits import ( EOF, ) from .io import write_string, read_at_least, read_full, read_all, BUFFER_SIZE + + +alias i1 = __mlir_type.i1 +alias i1_1 = __mlir_attr.`1: i1` +alias i1_0 = __mlir_attr.`0: i1` diff --git a/external/gojo/io/io.mojo b/external/gojo/io/io.mojo index 185677b4..61477052 100644 --- a/external/gojo/io/io.mojo +++ b/external/gojo/io/io.mojo @@ -1,12 +1,11 @@ from collections.optional import Optional -from ..builtins import cap, copy, Byte, Result, WrappedError, panic +from ..builtins import cap, copy, Byte, panic from .traits import ERR_UNEXPECTED_EOF +alias BUFFER_SIZE = 8200 -alias BUFFER_SIZE = 4096 - -fn write_string[W: Writer](inout writer: W, string: String) -> Result[Int]: +fn write_string[W: Writer](inout writer: W, string: String) -> (Int, Error): """Writes the contents of the string s to w, which accepts a slice of bytes. If w implements [StringWriter], [StringWriter.write_string] is invoked directly. Otherwise, [Writer.write] is called exactly once. @@ -18,10 +17,10 @@ fn write_string[W: Writer](inout writer: W, string: String) -> Result[Int]: Returns: The number of bytes written and an error, if any. """ - return writer.write(string.as_bytes()) + return writer.write(string.as_bytes_slice()) -fn write_string[W: StringWriter](inout writer: W, string: String) -> Result[Int]: +fn write_string[W: StringWriter](inout writer: W, string: String) -> (Int, Error): """Writes the contents of the string s to w, which accepts a slice of bytes. If w implements [StringWriter], [StringWriter.write_string] is invoked directly. Otherwise, [Writer.write] is called exactly once. @@ -35,9 +34,7 @@ fn write_string[W: StringWriter](inout writer: W, string: String) -> Result[Int] return writer.write_string(string) -fn read_at_least[ - R: Reader -](inout reader: R, inout dest: List[Byte], min: Int) -> Result[Int]: +fn read_at_least[R: Reader](inout reader: R, inout dest: List[Byte], min: Int) -> (Int, Error): """Reads from r into buf until it has read at least min bytes. It returns the number of bytes copied and an error if fewer bytes were read. The error is EOF only if no bytes were read. @@ -54,26 +51,26 @@ fn read_at_least[ Returns: The number of bytes read.""" - var error: Optional[WrappedError] = None + var error = Error() if len(dest) < min: - return Result(0, WrappedError(io.ERR_SHORT_BUFFER)) + return 0, Error(io.ERR_SHORT_BUFFER) var total_bytes_read: Int = 0 while total_bytes_read < min and not error: - var result = reader.read(dest) - var bytes_read = result.value - var error = result.get_error() + var bytes_read: Int + bytes_read, error = reader.read(dest) total_bytes_read += bytes_read if total_bytes_read >= min: - error = None - elif total_bytes_read > 0 and str(error.value()): - error = WrappedError(ERR_UNEXPECTED_EOF) + error = Error() + + elif total_bytes_read > 0 and str(error): + error = Error(ERR_UNEXPECTED_EOF) - return Result(total_bytes_read, None) + return total_bytes_read, error -fn read_full[R: Reader](inout reader: R, inout dest: List[Byte]) -> Result[Int]: +fn read_full[R: Reader](inout reader: R, inout dest: List[Byte]) -> (Int, Error): """Reads exactly len(buf) bytes from r into buf. It returns the number of bytes copied and an error if fewer bytes were read. The error is EOF only if no bytes were read. @@ -135,7 +132,7 @@ fn read_full[R: Reader](inout reader: R, inout dest: List[Byte]) -> Result[Int]: # } -# fn copy_buffer[W: Writer, R: Reader](dst: W, src: R, buf: List[Byte]) raises -> Int64: +# fn copy_buffer[W: Writer, R: Reader](dst: W, src: R, buf: Span[Byte]) raises -> Int64: # """Actual implementation of copy and CopyBuffer. # if buf is nil, one is allocated. # """ @@ -155,11 +152,11 @@ fn read_full[R: Reader](inout reader: R, inout dest: List[Byte]) -> Result[Int]: # return written -# fn copy_buffer[W: Writer, R: ReaderWriteTo](dst: W, src: R, buf: List[Byte]) -> Int64: +# fn copy_buffer[W: Writer, R: ReaderWriteTo](dst: W, src: R, buf: Span[Byte]) -> Int64: # return src.write_to(dst) -# fn copy_buffer[W: WriterReadFrom, R: Reader](dst: W, src: R, buf: List[Byte]) -> Int64: +# fn copy_buffer[W: WriterReadFrom, R: Reader](dst: W, src: R, buf: Span[Byte]) -> Int64: # return dst.read_from(src) # # LimitReader returns a Reader that reads from r @@ -409,7 +406,7 @@ fn read_full[R: Reader](inout reader: R, inout dest: List[Byte]) -> Result[Int]: # } -fn read_all[R: Reader](inout reader: R) -> Result[List[Byte]]: +fn read_all[R: Reader](inout reader: R) -> (List[Byte], Error): """Reads from r until an error or EOF and returns the data it read. A successful call returns err == nil, not err == EOF. Because ReadAll is defined to read from src until EOF, it does not treat an EOF from Read @@ -421,17 +418,17 @@ fn read_all[R: Reader](inout reader: R) -> Result[List[Byte]]: Returns: The data read.""" var dest = List[Byte](capacity=BUFFER_SIZE) - var index: Int = 0 var at_eof: Bool = False while True: var temp = List[Byte](capacity=BUFFER_SIZE) - var result = reader.read(temp) - var bytes_read = result.value - var err = result.get_error() - if err: - if str(err.value()) != EOF: - return Result(dest, err) + var bytes_read: Int + var err: Error + bytes_read, err = reader.read(temp) + var err_message = str(err) + if err_message != "": + if err_message != EOF: + return dest, err at_eof = True @@ -442,4 +439,4 @@ fn read_all[R: Reader](inout reader: R) -> Result[List[Byte]]: dest.extend(temp) if at_eof: - return Result(dest, err.value()) + return dest, err diff --git a/external/gojo/io/traits.mojo b/external/gojo/io/traits.mojo index 7b9f4d6e..ff1e8e6d 100644 --- a/external/gojo/io/traits.mojo +++ b/external/gojo/io/traits.mojo @@ -1,5 +1,5 @@ from collections.optional import Optional -from ..builtins import Byte, Result, WrappedError +from ..builtins import Byte alias Rune = Int32 @@ -78,7 +78,7 @@ trait Reader(Movable): Implementations must not retain p.""" - fn read(inout self, inout dest: List[Byte]) -> Result[Int]: + fn read(inout self, inout dest: List[Byte]) -> (Int, Error): ... @@ -94,7 +94,7 @@ trait Writer(Movable): Implementations must not retain p. """ - fn write(inout self, src: List[Byte]) -> Result[Int]: + fn write(inout self, src: Span[Byte]) -> (Int, Error): ... @@ -106,7 +106,7 @@ trait Closer(Movable): Specific implementations may document their own behavior. """ - fn close(inout self) raises: + fn close(inout self) -> Error: ... @@ -129,7 +129,7 @@ trait Seeker(Movable): is implementation-dependent. """ - fn seek(inout self, offset: Int64, whence: Int) -> Result[Int64]: + fn seek(inout self, offset: Int64, whence: Int) -> (Int64, Error): ... @@ -174,7 +174,7 @@ trait ReaderFrom: The [copy] function uses [ReaderFrom] if available.""" - fn read_from[R: Reader](inout self, inout reader: R) -> Result[Int64]: + fn read_from[R: Reader](inout self, inout reader: R) -> (Int64, Error): ... @@ -191,7 +191,7 @@ trait WriterTo: The copy function uses WriterTo if available.""" - fn write_to[W: Writer](inout self, inout writer: W) -> Result[Int64]: + fn write_to[W: Writer](inout self, inout writer: W) -> (Int64, Error): ... @@ -227,7 +227,7 @@ trait ReaderAt: Implementations must not retain p.""" - fn read_at(self, inout dest: List[Byte], off: Int64) -> Result[Int]: + fn read_at(self, inout dest: List[Byte], off: Int64) -> (Int, Error): ... @@ -248,7 +248,7 @@ trait WriterAt: Implementations must not retain p.""" - fn write_at(self, src: List[Byte], off: Int64) -> Result[Int]: + fn write_at(self, src: Span[Byte], off: Int64) -> (Int, Error): ... @@ -263,7 +263,7 @@ trait ByteReader: processing. A [Reader] that does not implement ByteReader can be wrapped using bufio.NewReader to add this method.""" - fn read_byte(inout self) -> Result[Byte]: + fn read_byte(inout self) -> (Byte, Error): ... @@ -277,14 +277,14 @@ trait ByteScanner(ByteReader): last-unread byte), or (in implementations that support the [Seeker] trait) seek to one byte before the current offset.""" - fn unread_byte(inout self) -> Optional[WrappedError]: + fn unread_byte(inout self) -> Error: ... trait ByteWriter: """ByteWriter is the trait that wraps the write_byte method.""" - fn write_byte(inout self, byte: Byte) -> Result[Int]: + fn write_byte(inout self, byte: Byte) -> (Int, Error): ... @@ -316,5 +316,5 @@ trait RuneScanner(RuneReader): trait StringWriter: """StringWriter is the trait that wraps the WriteString method.""" - fn write_string(inout self, src: String) -> Result[Int]: + fn write_string(inout self, src: String) -> (Int, Error): ... diff --git a/external/gojo/net/__init__.mojo b/external/gojo/net/__init__.mojo new file mode 100644 index 00000000..25876739 --- /dev/null +++ b/external/gojo/net/__init__.mojo @@ -0,0 +1,4 @@ +"""Adapted from go's net package + +A good chunk of the leg work here came from the lightbug_http project! https://github.com/saviorand/lightbug_http/tree/main +""" diff --git a/external/gojo/net/address.mojo b/external/gojo/net/address.mojo new file mode 100644 index 00000000..9bf5a50a --- /dev/null +++ b/external/gojo/net/address.mojo @@ -0,0 +1,145 @@ +@value +struct NetworkType: + var value: String + + alias empty = NetworkType("") + alias tcp = NetworkType("tcp") + alias tcp4 = NetworkType("tcp4") + alias tcp6 = NetworkType("tcp6") + alias udp = NetworkType("udp") + alias udp4 = NetworkType("udp4") + alias udp6 = NetworkType("udp6") + alias ip = NetworkType("ip") + alias ip4 = NetworkType("ip4") + alias ip6 = NetworkType("ip6") + alias unix = NetworkType("unix") + + +trait Addr(CollectionElement, Stringable): + fn network(self) -> String: + """Name of the network (for example, "tcp", "udp").""" + ... + + +@value +struct TCPAddr(Addr): + """Addr struct representing a TCP address. + + Args: + ip: IP address. + port: Port number. + zone: IPv6 addressing zone. + """ + + var ip: String + var port: Int + var zone: String # IPv6 addressing zone + + fn __init__(inout self): + self.ip = String("127.0.0.1") + self.port = 8000 + self.zone = "" + + fn __init__(inout self, ip: String, port: Int): + self.ip = ip + self.port = port + self.zone = "" + + fn __str__(self) -> String: + if self.zone != "": + return join_host_port(str(self.ip) + "%" + self.zone, str(self.port)) + return join_host_port(self.ip, str(self.port)) + + fn network(self) -> String: + return NetworkType.tcp.value + + +fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr: + var host: String = "" + var port: String = "" + var portnum: Int = 0 + if ( + network == NetworkType.tcp.value + or network == NetworkType.tcp4.value + or network == NetworkType.tcp6.value + or network == NetworkType.udp.value + or network == NetworkType.udp4.value + or network == NetworkType.udp6.value + ): + if address != "": + var host_port = split_host_port(address) + host = host_port.host + port = str(host_port.port) + portnum = atol(port.__str__()) + elif network == NetworkType.ip.value or network == NetworkType.ip4.value or network == NetworkType.ip6.value: + if address != "": + host = address + elif network == NetworkType.unix.value: + raise Error("Unix addresses not supported yet") + else: + raise Error("unsupported network type: " + network) + return TCPAddr(host, portnum) + + +alias missingPortError = Error("missing port in address") +alias tooManyColonsError = Error("too many colons in address") + + +struct HostPort(Stringable): + var host: String + var port: Int + + fn __init__(inout self, host: String, port: Int): + self.host = host + self.port = port + + fn __str__(self) -> String: + return join_host_port(self.host, str(self.port)) + + +fn join_host_port(host: String, port: String) -> String: + if host.find(":") != -1: # must be IPv6 literal + return "[" + host + "]:" + port + return host + ":" + port + + +fn split_host_port(hostport: String) raises -> HostPort: + var host: String = "" + var port: String = "" + var colon_index = hostport.rfind(":") + var j: Int = 0 + var k: Int = 0 + + if colon_index == -1: + raise missingPortError + if hostport[0] == "[": + var end_bracket_index = hostport.find("]") + if end_bracket_index == -1: + raise Error("missing ']' in address") + if end_bracket_index + 1 == len(hostport): + raise missingPortError + elif end_bracket_index + 1 == colon_index: + host = hostport[1:end_bracket_index] + j = 1 + k = end_bracket_index + 1 + else: + if hostport[end_bracket_index + 1] == ":": + raise tooManyColonsError + else: + raise missingPortError + else: + host = hostport[:colon_index] + if host.find(":") != -1: + raise tooManyColonsError + if hostport[j:].find("[") != -1: + raise Error("unexpected '[' in address") + if hostport[k:].find("]") != -1: + raise Error("unexpected ']' in address") + port = hostport[colon_index + 1 :] + + if port == "": + raise missingPortError + if host == "": + raise Error("missing host") + + return HostPort(host, atol(port)) diff --git a/external/gojo/net/dial.mojo b/external/gojo/net/dial.mojo new file mode 100644 index 00000000..f5719e67 --- /dev/null +++ b/external/gojo/net/dial.mojo @@ -0,0 +1,44 @@ +from .tcp import TCPAddr, TCPConnection, resolve_internet_addr +from .socket import Socket +from .address import split_host_port + + +@value +struct Dialer: + var local_address: TCPAddr + + fn dial(self, network: String, address: String) raises -> TCPConnection: + var tcp_addr = resolve_internet_addr(network, address) + var socket = Socket(local_address=self.local_address) + socket.connect(tcp_addr.ip, tcp_addr.port) + return TCPConnection(socket^) + + +fn dial_tcp(network: String, remote_address: TCPAddr) raises -> TCPConnection: + """Connects to the address on the named network. + + The network must be "tcp", "tcp4", or "tcp6". + Args: + network: The network type. + remote_address: The remote address to connect to. + + Returns: + The TCP connection. + """ + # TODO: Add conversion of domain name to ip address + return Dialer(remote_address).dial(network, remote_address.ip + ":" + str(remote_address.port)) + + +fn dial_tcp(network: String, remote_address: String) raises -> TCPConnection: + """Connects to the address on the named network. + + The network must be "tcp", "tcp4", or "tcp6". + Args: + network: The network type. + remote_address: The remote address to connect to. + + Returns: + The TCP connection. + """ + var address = split_host_port(remote_address) + return Dialer(TCPAddr(address.host, address.port)).dial(network, remote_address) diff --git a/external/gojo/net/fd.mojo b/external/gojo/net/fd.mojo new file mode 100644 index 00000000..6e4fc621 --- /dev/null +++ b/external/gojo/net/fd.mojo @@ -0,0 +1,77 @@ +from collections.optional import Optional +import ..io +from ..builtins import Byte +from ..syscall.file import close +from ..syscall.types import c_char +from ..syscall.net import ( + recv, + send, + strlen, +) + +alias O_RDWR = 0o2 + + +trait FileDescriptorBase(io.Reader, io.Writer, io.Closer): + ... + + +struct FileDescriptor(FileDescriptorBase): + var fd: Int + var is_closed: Bool + + # This takes ownership of a POSIX file descriptor. + fn __moveinit__(inout self, owned existing: Self): + self.fd = existing.fd + self.is_closed = existing.is_closed + + fn __init__(inout self, fd: Int): + self.fd = fd + self.is_closed = False + + fn __del__(owned self): + if not self.is_closed: + var err = self.close() + if err: + print(str(err)) + + fn close(inout self) -> Error: + """Mark the file descriptor as closed.""" + var close_status = close(self.fd) + if close_status == -1: + return Error("FileDescriptor.close: Failed to close socket") + + self.is_closed = True + return Error() + + fn dup(self) -> Self: + """Duplicate the file descriptor.""" + var new_fd = external_call["dup", Int, Int](self.fd) + return Self(new_fd) + + # TODO: Need faster approach to copying data from the file descriptor to the buffer. + fn read(inout self, inout dest: List[Byte]) -> (Int, Error): + """Receive data from the file descriptor and write it to the buffer provided.""" + var ptr = Pointer[UInt8]().alloc(dest.capacity) + var bytes_received = recv(self.fd, ptr, dest.capacity, 0) + if bytes_received == -1: + return 0, Error("Failed to receive message from socket.") + + var int8_ptr = ptr.bitcast[Int8]() + for i in range(bytes_received): + dest.append(int8_ptr[i]) + + if bytes_received < dest.capacity: + return bytes_received, Error(io.EOF) + + return bytes_received, Error() + + fn write(inout self, src: List[Byte]) -> (Int, Error): + """Write data from the buffer to the file descriptor.""" + var header_pointer = Pointer[Int8](src.data.address).bitcast[UInt8]() + + var bytes_sent = send(self.fd, header_pointer, strlen(header_pointer), 0) + if bytes_sent == -1: + return 0, Error("Failed to send message") + + return bytes_sent, Error() diff --git a/external/gojo/net/ip.mojo b/external/gojo/net/ip.mojo new file mode 100644 index 00000000..e76d5cc3 --- /dev/null +++ b/external/gojo/net/ip.mojo @@ -0,0 +1,179 @@ +from utils.variant import Variant +from utils.static_tuple import StaticTuple +from sys.info import os_is_linux, os_is_macos +from ..syscall.types import ( + c_int, + c_char, + c_void, + c_uint, +) +from ..syscall.net import ( + addrinfo, + addrinfo_unix, + AF_INET, + SOCK_STREAM, + AI_PASSIVE, + sockaddr, + sockaddr_in, + htons, + ntohs, + inet_pton, + inet_ntop, + getaddrinfo, + getaddrinfo_unix, + gai_strerror, + to_char_ptr, + c_charptr_to_string, +) + +alias AddrInfo = Variant[addrinfo, addrinfo_unix] + + +fn get_addr_info(host: String) raises -> AddrInfo: + var status: Int32 = 0 + if os_is_macos(): + var servinfo = Pointer[addrinfo]().alloc(1) + servinfo.store(addrinfo()) + var hints = addrinfo() + hints.ai_family = AF_INET + hints.ai_socktype = SOCK_STREAM + hints.ai_flags = AI_PASSIVE + + var host_ptr = to_char_ptr(host) + + var status = getaddrinfo( + host_ptr, + Pointer[UInt8](), + Pointer.address_of(hints), + Pointer.address_of(servinfo), + ) + if status != 0: + print("getaddrinfo failed to execute with status:", status) + var msg_ptr = gai_strerror(c_int(status)) + _ = external_call["printf", c_int, Pointer[c_char], Pointer[c_char]]( + to_char_ptr("gai_strerror: %s"), msg_ptr + ) + var msg = c_charptr_to_string(msg_ptr) + print("getaddrinfo error message: ", msg) + + if not servinfo: + print("servinfo is null") + raise Error("Failed to get address info. Pointer to addrinfo is null.") + + return servinfo.load() + elif os_is_linux(): + var servinfo = Pointer[addrinfo_unix]().alloc(1) + servinfo.store(addrinfo_unix()) + var hints = addrinfo_unix() + hints.ai_family = AF_INET + hints.ai_socktype = SOCK_STREAM + hints.ai_flags = AI_PASSIVE + + var host_ptr = to_char_ptr(host) + + var status = getaddrinfo_unix( + host_ptr, + Pointer[UInt8](), + Pointer.address_of(hints), + Pointer.address_of(servinfo), + ) + if status != 0: + print("getaddrinfo failed to execute with status:", status) + var msg_ptr = gai_strerror(c_int(status)) + _ = external_call["printf", c_int, Pointer[c_char], Pointer[c_char]]( + to_char_ptr("gai_strerror: %s"), msg_ptr + ) + var msg = c_charptr_to_string(msg_ptr) + print("getaddrinfo error message: ", msg) + + if not servinfo: + print("servinfo is null") + raise Error("Failed to get address info. Pointer to addrinfo is null.") + + return servinfo.load() + else: + raise Error("Windows is not supported yet! Sorry!") + + +fn get_ip_address(host: String) raises -> String: + """Get the IP address of a host.""" + # Call getaddrinfo to get the IP address of the host. + var result = get_addr_info(host) + var ai_addr: Pointer[sockaddr] + var address_family: Int32 = 0 + var address_length: UInt32 = 0 + if result.isa[addrinfo](): + var addrinfo = result[addrinfo] + ai_addr = addrinfo.ai_addr + address_family = addrinfo.ai_family + address_length = addrinfo.ai_addrlen + else: + var addrinfo = result[addrinfo_unix] + ai_addr = addrinfo.ai_addr + address_family = addrinfo.ai_family + address_length = addrinfo.ai_addrlen + + if not ai_addr: + print("ai_addr is null") + raise Error("Failed to get IP address. getaddrinfo was called successfully, but ai_addr is null.") + + # Cast sockaddr struct to sockaddr_in struct and convert the binary IP to a string using inet_ntop. + var addr_in = ai_addr.bitcast[sockaddr_in]().load() + + return convert_binary_ip_to_string(addr_in.sin_addr.s_addr, address_family, address_length).strip() + + +fn convert_port_to_binary(port: Int) -> UInt16: + return htons(UInt16(port)) + + +fn convert_binary_port_to_int(port: UInt16) -> Int: + return int(ntohs(port)) + + +fn convert_ip_to_binary(ip_address: String, address_family: Int) -> UInt32: + var ip_buffer = Pointer[c_void].alloc(4) + var status = inet_pton(address_family, to_char_ptr(ip_address), ip_buffer) + if status == -1: + print("Failed to convert IP address to binary") + + return ip_buffer.bitcast[c_uint]().load() + + +fn convert_binary_ip_to_string(owned ip_address: UInt32, address_family: Int32, address_length: UInt32) -> String: + """Convert a binary IP address to a string by calling inet_ntop. + + Args: + ip_address: The binary IP address. + address_family: The address family of the IP address. + address_length: The length of the address. + + Returns: + The IP address as a string. + """ + # It seems like the len of the buffer depends on the length of the string IP. + # Allocating 10 works for localhost (127.0.0.1) which I suspect is 9 bytes + 1 null terminator byte. So max should be 16 (15 + 1). + var ip_buffer = Pointer[c_void].alloc(16) + var ip_address_ptr = Pointer.address_of(ip_address).bitcast[c_void]() + _ = inet_ntop(address_family, ip_address_ptr, ip_buffer, 16) + + var string_buf = ip_buffer.bitcast[Int8]() + var index = 0 + while True: + if string_buf[index] == 0: + break + index += 1 + + return StringRef(string_buf, index) + + +fn build_sockaddr_pointer(ip_address: String, port: Int, address_family: Int) -> Pointer[sockaddr]: + """Build a sockaddr pointer from an IP address and port number. + https://learn.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 + https://learn.microsoft.com/en-us/windows/win32/api/ws2def/ns-ws2def-sockaddr_in. + """ + var bin_port = convert_port_to_binary(port) + var bin_ip = convert_ip_to_binary(ip_address, address_family) + + var ai = sockaddr_in(address_family, bin_port, bin_ip, StaticTuple[c_char, 8]()) + return Pointer[sockaddr_in].address_of(ai).bitcast[sockaddr]() diff --git a/external/gojo/net/net.mojo b/external/gojo/net/net.mojo new file mode 100644 index 00000000..74387d40 --- /dev/null +++ b/external/gojo/net/net.mojo @@ -0,0 +1,130 @@ +from memory.arc import Arc +import ..io +from ..builtins import Byte +from .socket import Socket +from .address import Addr, TCPAddr + +alias DEFAULT_BUFFER_SIZE = 8200 + + +trait Conn(io.Writer, io.Reader, io.Closer): + fn __init__(inout self, owned socket: Socket): + ... + + """Conn is a generic stream-oriented network connection.""" + + fn local_address(self) -> TCPAddr: + """Returns the local network address, if known.""" + ... + + fn remote_address(self) -> TCPAddr: + """Returns the local network address, if known.""" + ... + + # fn set_deadline(self, t: time.Time) -> Error: + # """Sets the read and write deadlines associated + # with the connection. It is equivalent to calling both + # SetReadDeadline and SetWriteDeadline. + + # A deadline is an absolute time after which I/O operations + # fail instead of blocking. The deadline applies to all future + # and pending I/O, not just the immediately following call to + # read or write. After a deadline has been exceeded, the + # connection can be refreshed by setting a deadline in the future. + + # If the deadline is exceeded a call to read or write or to other + # I/O methods will return an error that wraps os.ErrDeadlineExceeded. + # This can be tested using errors.Is(err, os.ErrDeadlineExceeded). + # The error's Timeout method will return true, but note that there + # are other possible errors for which the Timeout method will + # return true even if the deadline has not been exceeded. + + # An idle timeout can be implemented by repeatedly extending + # the deadline after successful read or write calls. + + # A zero value for t means I/O operations will not time out.""" + # ... + + # fn set_read_deadline(self, t: time.Time) -> Error: + # """Sets the deadline for future read calls + # and any currently-blocked read call. + # A zero value for t means read will not time out.""" + # ... + + # fn set_write_deadline(self, t: time.Time) -> Error: + # """Sets the deadline for future write calls + # and any currently-blocked write call. + # Even if write times out, it may return n > 0, indicating that + # some of the data was successfully written. + # A zero value for t means write will not time out.""" + # ... + + +@value +struct Connection(Conn): + """Connection is a concrete generic stream-oriented network connection. + It is used as the internal connection for structs like TCPConnection. + + Args: + fd: The file descriptor of the connection. + """ + + var fd: Arc[Socket] + + fn __init__(inout self, owned socket: Socket): + self.fd = Arc(socket^) + + fn read(inout self, inout dest: List[Byte]) -> (Int, Error): + """Reads data from the underlying file descriptor. + + Args: + dest: The buffer to read data into. + + Returns: + The number of bytes read, or an error if one occurred. + """ + var bytes_written: Int = 0 + var err = Error() + bytes_written, err = self.fd[].read(dest) + if err: + if str(err) != io.EOF: + return 0, err + + return bytes_written, err + + fn write(inout self, src: List[Byte]) -> (Int, Error): + """Writes data to the underlying file descriptor. + + Args: + src: The buffer to read data into. + + Returns: + The number of bytes written, or an error if one occurred. + """ + var bytes_read: Int = 0 + var err = Error() + bytes_read, err = self.fd[].write(src) + if err: + return 0, err + + return bytes_read, err + + fn close(inout self) -> Error: + """Closes the underlying file descriptor. + + Returns: + An error if one occurred, or None if the file descriptor was closed successfully. + """ + return self.fd[].close() + + fn local_address(self) -> TCPAddr: + """Returns the local network address. + The Addr returned is shared by all invocations of local_address, so do not modify it. + """ + return self.fd[].local_address + + fn remote_address(self) -> TCPAddr: + """Returns the remote network address. + The Addr returned is shared by all invocations of remote_address, so do not modify it. + """ + return self.fd[].remote_address diff --git a/external/gojo/net/socket.mojo b/external/gojo/net/socket.mojo new file mode 100644 index 00000000..e019255e --- /dev/null +++ b/external/gojo/net/socket.mojo @@ -0,0 +1,432 @@ +from collections.optional import Optional +from ..builtins import Byte +from ..syscall.file import close +from ..syscall.types import ( + c_void, + c_uint, + c_char, + c_int, +) +from ..syscall.net import ( + sockaddr, + sockaddr_in, + addrinfo, + addrinfo_unix, + socklen_t, + socket, + connect, + recv, + send, + shutdown, + inet_pton, + inet_ntoa, + inet_ntop, + to_char_ptr, + htons, + ntohs, + strlen, + getaddrinfo, + getaddrinfo_unix, + gai_strerror, + c_charptr_to_string, + bind, + listen, + accept, + setsockopt, + getsockopt, + getsockname, + getpeername, + AF_INET, + SOCK_STREAM, + SHUT_RDWR, + AI_PASSIVE, + SOL_SOCKET, + SO_REUSEADDR, + SO_RCVTIMEO, +) +from .fd import FileDescriptor, FileDescriptorBase +from .ip import ( + convert_binary_ip_to_string, + build_sockaddr_pointer, + convert_binary_port_to_int, +) +from .address import Addr, TCPAddr, HostPort + +alias SocketClosedError = Error("Socket: Socket is already closed") + + +struct Socket(FileDescriptorBase): + """Represents a network file descriptor. Wraps around a file descriptor and provides network functions. + + Args: + local_address: The local address of the socket (local address if bound). + remote_address: The remote address of the socket (peer's address if connected). + address_family: The address family of the socket. + socket_type: The socket type. + protocol: The protocol. + """ + + var sockfd: FileDescriptor + var address_family: Int + var socket_type: UInt8 + var protocol: UInt8 + var local_address: TCPAddr + var remote_address: TCPAddr + var _closed: Bool + var _is_connected: Bool + + fn __init__( + inout self, + local_address: TCPAddr = TCPAddr(), + remote_address: TCPAddr = TCPAddr(), + address_family: Int = AF_INET, + socket_type: UInt8 = SOCK_STREAM, + protocol: UInt8 = 0, + ) raises: + """Create a new socket object. + + Args: + local_address: The local address of the socket (local address if bound). + remote_address: The remote address of the socket (peer's address if connected). + address_family: The address family of the socket. + socket_type: The socket type. + protocol: The protocol. + """ + self.address_family = address_family + self.socket_type = socket_type + self.protocol = protocol + + var fd = socket(address_family, SOCK_STREAM, 0) + if fd == -1: + raise Error("Socket creation error") + self.sockfd = FileDescriptor(int(fd)) + self.local_address = local_address + self.remote_address = remote_address + self._closed = False + self._is_connected = False + + fn __init__( + inout self, + fd: Int32, + address_family: Int, + socket_type: UInt8, + protocol: UInt8, + local_address: TCPAddr = TCPAddr(), + remote_address: TCPAddr = TCPAddr(), + ): + """ + Create a new socket object when you already have a socket file descriptor. Typically through socket.accept(). + + Args: + fd: The file descriptor of the socket. + address_family: The address family of the socket. + socket_type: The socket type. + protocol: The protocol. + local_address: Local address of socket. + remote_address: Remote address of port. + """ + self.sockfd = FileDescriptor(int(fd)) + self.address_family = address_family + self.socket_type = socket_type + self.protocol = protocol + self.local_address = local_address + self.remote_address = remote_address + self._closed = False + self._is_connected = True + + fn __moveinit__(inout self, owned existing: Self): + self.sockfd = existing.sockfd^ + self.address_family = existing.address_family + self.socket_type = existing.socket_type + self.protocol = existing.protocol + self.local_address = existing.local_address^ + self.remote_address = existing.remote_address^ + self._closed = existing._closed + self._is_connected = existing._is_connected + + # fn __enter__(self) -> Self: + # return self + + # fn __exit__(inout self) raises: + # if self._is_connected: + # self.shutdown() + # if not self._closed: + # self.close() + + fn __del__(owned self): + if self._is_connected: + self.shutdown() + if not self._closed: + var err = self.close() + _ = self.sockfd.fd + if err: + print("Failed to close socket during deletion:", str(err)) + + @always_inline + fn accept(self) raises -> Self: + """Accept a connection. The socket must be bound to an address and listening for connections. + The return value is a connection where conn is a new socket object usable to send and receive data on the connection, + and address is the address bound to the socket on the other end of the connection. + """ + var their_addr_ptr = Pointer[sockaddr].alloc(1) + var sin_size = socklen_t(sizeof[socklen_t]()) + var new_sockfd = accept(self.sockfd.fd, their_addr_ptr, Pointer[socklen_t].address_of(sin_size)) + if new_sockfd == -1: + raise Error("Failed to accept connection") + + var remote = self.get_peer_name() + return Self( + new_sockfd, + self.address_family, + self.socket_type, + self.protocol, + self.local_address, + TCPAddr(remote.host, remote.port), + ) + + fn listen(self, backlog: Int = 0) raises: + """Enable a server to accept connections. + + Args: + backlog: The maximum number of queued connections. Should be at least 0, and the maximum is system-dependent (usually 5). + """ + var queued = backlog + if backlog < 0: + queued = 0 + if listen(self.sockfd.fd, queued) == -1: + raise Error("Failed to listen for connections") + + @always_inline + fn bind(inout self, address: String, port: Int) raises: + """Bind the socket to address. The socket must not already be bound. (The format of address depends on the address family). + + When a socket is created with Socket(), it exists in a name + space (address family) but has no address assigned to it. bind() + assigns the address specified by addr to the socket referred to + by the file descriptor sockfd. addrlen specifies the size, in + bytes, of the address structure pointed to by addr. + Traditionally, this operation is called 'assigning a name to a + socket'. + + Args: + address: String - The IP address to bind the socket to. + port: The port number to bind the socket to. + """ + var sockaddr_pointer = build_sockaddr_pointer(address, port, self.address_family) + + if bind(self.sockfd.fd, sockaddr_pointer, sizeof[sockaddr_in]()) == -1: + _ = shutdown(self.sockfd.fd, SHUT_RDWR) + raise Error("Binding socket failed. Wait a few seconds and try again?") + + var local = self.get_sock_name() + self.local_address = TCPAddr(local.host, local.port) + + @always_inline + fn file_no(self) -> Int32: + """Return the file descriptor of the socket.""" + return self.sockfd.fd + + @always_inline + fn get_sock_name(self) raises -> HostPort: + """Return the address of the socket.""" + if self._closed: + raise SocketClosedError + + # TODO: Add check to see if the socket is bound and error if not. + + var local_address_ptr = Pointer[sockaddr].alloc(1) + var local_address_ptr_size = socklen_t(sizeof[sockaddr]()) + var status = getsockname( + self.sockfd.fd, + local_address_ptr, + Pointer[socklen_t].address_of(local_address_ptr_size), + ) + if status == -1: + raise Error("Socket.get_sock_name: Failed to get address of local socket.") + var addr_in = local_address_ptr.bitcast[sockaddr_in]().load() + + return HostPort( + host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, 16), + port=convert_binary_port_to_int(addr_in.sin_port), + ) + + fn get_peer_name(self) raises -> HostPort: + """Return the address of the peer connected to the socket.""" + if self._closed: + raise SocketClosedError + + # TODO: Add check to see if the socket is bound and error if not. + var remote_address_ptr = Pointer[sockaddr].alloc(1) + var remote_address_ptr_size = socklen_t(sizeof[sockaddr]()) + var status = getpeername( + self.sockfd.fd, + remote_address_ptr, + Pointer[socklen_t].address_of(remote_address_ptr_size), + ) + if status == -1: + raise Error("Socket.get_peer_name: Failed to get address of remote socket.") + + # Cast sockaddr struct to sockaddr_in to convert binary IP to string. + var addr_in = remote_address_ptr.bitcast[sockaddr_in]().load() + + return HostPort( + host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, 16), + port=convert_binary_port_to_int(addr_in.sin_port), + ) + + fn get_socket_option(self, option_name: Int) raises -> Int: + """Return the value of the given socket option. + + Args: + option_name: The socket option to get. + """ + var option_value_pointer = Pointer[c_void].alloc(1) + var option_len = socklen_t(sizeof[socklen_t]()) + var option_len_pointer = Pointer.address_of(option_len) + var status = getsockopt( + self.sockfd.fd, + SOL_SOCKET, + option_name, + option_value_pointer, + option_len_pointer, + ) + if status == -1: + raise Error("Socket.get_sock_opt failed with status: " + str(status)) + + return option_value_pointer.bitcast[Int]().load() + + fn set_socket_option(self, option_name: Int, owned option_value: UInt8 = 1) raises: + """Return the value of the given socket option. + + Args: + option_name: The socket option to set. + option_value: The value to set the socket option to. + """ + var option_value_pointer = Pointer[c_void].address_of(option_value) + var option_len = sizeof[socklen_t]() + var status = setsockopt(self.sockfd.fd, SOL_SOCKET, option_name, option_value_pointer, option_len) + if status == -1: + raise Error("Socket.set_sock_opt failed with status: " + str(status)) + + fn connect(inout self, address: String, port: Int) raises: + """Connect to a remote socket at address. + + Args: + address: String - The IP address to connect to. + port: The port number to connect to. + """ + var sockaddr_pointer = build_sockaddr_pointer(address, port, self.address_family) + + if connect(self.sockfd.fd, sockaddr_pointer, sizeof[sockaddr_in]()) == -1: + self.shutdown() + raise Error("Socket.connect: Failed to connect to the remote socket at: " + address + ":" + str(port)) + + var remote = self.get_peer_name() + self.remote_address = TCPAddr(remote.host, remote.port) + + fn write(inout self: Self, src: List[Byte]) -> (Int, Error): + """Send data to the socket. The socket must be connected to a remote socket. + + Args: + src: The data to send. + + Returns: + The number of bytes sent. + """ + var bytes_written: Int + var err: Error + bytes_written, err = self.sockfd.write(src) + if err: + return 0, err + + return bytes_written, Error() + + fn send_all(self, src: List[Byte], max_attempts: Int = 3) raises: + """Send data to the socket. The socket must be connected to a remote socket. + + Args: + src: The data to send. + max_attempts: The maximum number of attempts to send the data. + """ + var header_pointer = src.unsafe_ptr() + var total_bytes_sent = 0 + var attempts = 0 + + # Try to send all the data in the buffer. If it did not send all the data, keep trying but start from the offset of the last successful send. + while total_bytes_sent < len(src): + if attempts > max_attempts: + raise Error("Failed to send message after " + str(max_attempts) + " attempts.") + + var bytes_sent = send( + self.sockfd.fd, + header_pointer.offset(total_bytes_sent), + strlen(header_pointer.offset(total_bytes_sent)), + 0, + ) + if bytes_sent == -1: + raise Error("Failed to send message, wrote" + String(total_bytes_sent) + "bytes before failing.") + total_bytes_sent += bytes_sent + attempts += 1 + + fn send_to(inout self, src: List[Byte], address: String, port: Int) raises -> Int: + """Send data to the a remote address by connecting to the remote socket before sending. + The socket must be not already be connected to a remote socket. + + Args: + src: The data to send. + address: The IP address to connect to. + port: The port number to connect to. + """ + var header_pointer = Pointer[Int8](src.data.address).bitcast[UInt8]() + self.connect(address, port) + var bytes_written: Int + var err: Error + bytes_written, err = self.write(src) + if err: + raise err + return bytes_written + + fn read(inout self, inout dest: List[Byte]) -> (Int, Error): + """Receive data from the socket.""" + # Not ideal since we can't use the pointer from the List[Byte] struct directly. So we use a temporary pointer to receive the data. + # Then we copy all the data over. + var bytes_written: Int + var err: Error + bytes_written, err = self.sockfd.read(dest) + if err: + if str(err) != "EOF": + return 0, err + + return bytes_written, Error() + + fn shutdown(self): + _ = shutdown(self.sockfd.fd, SHUT_RDWR) + + fn close(inout self) -> Error: + """Mark the socket closed. + Once that happens, all future operations on the socket object will fail. + The remote end will receive no more data (after queued data is flushed). + """ + self.shutdown() + var err = self.sockfd.close() + if err: + return err + + self._closed = True + return Error() + + # TODO: Trying to set timeout fails, but some other options don't? + # fn get_timeout(self) raises -> Seconds: + # """Return the timeout value for the socket.""" + # return self.get_socket_option(SO_RCVTIMEO) + + # fn set_timeout(self, owned duration: Seconds) raises: + # """Set the timeout value for the socket. + + # Args: + # duration: Seconds - The timeout duration in seconds. + # """ + # self.set_socket_option(SO_RCVTIMEO, duration) + + fn send_file(self, file: FileHandle, offset: Int = 0) raises: + self.send_all(file.read_bytes()) diff --git a/external/gojo/net/tcp.mojo b/external/gojo/net/tcp.mojo new file mode 100644 index 00000000..433bca95 --- /dev/null +++ b/external/gojo/net/tcp.mojo @@ -0,0 +1,207 @@ +from ..builtins import Byte +from ..syscall.net import SO_REUSEADDR +from .net import Connection, Conn +from .address import TCPAddr, NetworkType, split_host_port +from .socket import Socket + + +# Time in nanoseconds +alias Duration = Int +alias DEFAULT_BUFFER_SIZE = 8200 +alias DEFAULT_TCP_KEEP_ALIVE = Duration(15 * 1000 * 1000 * 1000) # 15 seconds + + +fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr: + var host: String = "" + var port: String = "" + var portnum: Int = 0 + if ( + network == NetworkType.tcp.value + or network == NetworkType.tcp4.value + or network == NetworkType.tcp6.value + or network == NetworkType.udp.value + or network == NetworkType.udp4.value + or network == NetworkType.udp6.value + ): + if address != "": + var host_port = split_host_port(address) + host = host_port.host + port = str(host_port.port) + portnum = atol(port.__str__()) + elif network == NetworkType.ip.value or network == NetworkType.ip4.value or network == NetworkType.ip6.value: + if address != "": + host = address + elif network == NetworkType.unix.value: + raise Error("Unix addresses not supported yet") + else: + raise Error("unsupported network type: " + network) + return TCPAddr(host, portnum) + + +# TODO: For now listener is paired with TCP until we need to support +# more than one type of Connection or Listener +@value +struct ListenConfig(CollectionElement): + var keep_alive: Duration + + fn listen(self, network: String, address: String) raises -> TCPListener: + var tcp_addr = resolve_internet_addr(network, address) + var socket = Socket(local_address=tcp_addr) + socket.bind(tcp_addr.ip, tcp_addr.port) + socket.set_socket_option(SO_REUSEADDR, 1) + socket.listen() + print(str("Listening on ") + str(socket.local_address)) + return TCPListener(socket^, self, network, address) + + +trait Listener(Movable): + # Raising here because a Result[Optional[Connection], Error] is funky. + fn accept(self) raises -> Connection: + ... + + fn close(inout self) -> Error: + ... + + fn addr(self) raises -> TCPAddr: + ... + + +@value +struct TCPConnection(Conn): + """TCPConn is an implementation of the Conn interface for TCP network connections. + + Args: + connection: The underlying Connection. + """ + + var _connection: Connection + + fn __init__(inout self, connection: Connection): + self._connection = connection + + fn __init__(inout self, owned socket: Socket): + self._connection = Connection(socket^) + + fn __moveinit__(inout self, owned existing: Self): + self._connection = existing._connection^ + + fn read(inout self, inout dest: List[Byte]) -> (Int, Error): + """Reads data from the underlying file descriptor. + + Args: + dest: The buffer to read data into. + + Returns: + The number of bytes read, or an error if one occurred. + """ + var bytes_written: Int + var err: Error + bytes_written, err = self._connection.read(dest) + if err: + if str(err) != io.EOF: + return 0, err + + return bytes_written, Error() + + fn write(inout self, src: List[Byte]) -> (Int, Error): + """Writes data to the underlying file descriptor. + + Args: + src: The buffer to read data into. + + Returns: + The number of bytes written, or an error if one occurred. + """ + var bytes_written: Int + var err: Error + bytes_written, err = self._connection.write(src) + if err: + return 0, err + + return bytes_written, Error() + + fn close(inout self) -> Error: + """Closes the underlying file descriptor. + + Returns: + An error if one occurred, or None if the file descriptor was closed successfully. + """ + return self._connection.close() + + fn local_address(self) -> TCPAddr: + """Returns the local network address. + The Addr returned is shared by all invocations of local_address, so do not modify it. + + Returns: + The local network address. + """ + return self._connection.local_address() + + fn remote_address(self) -> TCPAddr: + """Returns the remote network address. + The Addr returned is shared by all invocations of remote_address, so do not modify it. + + Returns: + The remote network address. + """ + return self._connection.remote_address() + + +fn listen_tcp(network: String, local_address: TCPAddr) raises -> TCPListener: + """Creates a new TCP listener. + + Args: + network: The network type. + local_address: The local address to listen on. + """ + return ListenConfig(DEFAULT_TCP_KEEP_ALIVE).listen(network, local_address.ip + ":" + str(local_address.port)) + + +fn listen_tcp(network: String, local_address: String) raises -> TCPListener: + """Creates a new TCP listener. + + Args: + network: The network type. + local_address: The address to listen on. The format is "host:port". + """ + return ListenConfig(DEFAULT_TCP_KEEP_ALIVE).listen(network, local_address) + + +struct TCPListener(Listener): + var _file_descriptor: Socket + var listen_config: ListenConfig + var network_type: String + var address: String + + fn __init__( + inout self, + owned file_descriptor: Socket, + listen_config: ListenConfig, + network_type: String, + address: String, + ): + self._file_descriptor = file_descriptor^ + self.listen_config = listen_config + self.network_type = network_type + self.address = address + + fn __moveinit__(inout self, owned existing: Self): + self._file_descriptor = existing._file_descriptor^ + self.listen_config = existing.listen_config^ + self.network_type = existing.network_type + self.address = existing.address + + fn listen(self) raises -> Self: + return self.listen_config.listen(self.network_type, self.address) + + fn accept(self) raises -> Connection: + return Connection(self._file_descriptor.accept()) + + fn accept_tcp(self) raises -> TCPConnection: + return TCPConnection(self._file_descriptor.accept()) + + fn close(inout self) -> Error: + return self._file_descriptor.close() + + fn addr(self) raises -> TCPAddr: + return resolve_internet_addr(self.network_type, self.address) diff --git a/external/gojo/strings/builder.mojo b/external/gojo/strings/builder.mojo index 520fd0c9..e4bdce99 100644 --- a/external/gojo/strings/builder.mojo +++ b/external/gojo/strings/builder.mojo @@ -1,17 +1,15 @@ -# Adapted from https://github.com/maniartech/mojo-strings/blob/master/strings/builder.mojo -# Modified to use List[Int8] instead of List[String] - import ..io -from ..builtins import Byte, Result, WrappedError +from ..builtins import Byte @value -struct StringBuilder(Stringable, Sized, io.Writer, io.ByteWriter, io.StringWriter): +struct StringBuilder[growth_factor: Float32 = 2](Stringable, Sized, io.Writer, io.StringWriter): """ A string builder class that allows for efficient string management and concatenation. This class is useful when you need to build a string by appending multiple strings - together. It is around 10x faster than using the `+` operator to concatenate - strings because it avoids the overhead of creating and destroying many + together. The performance increase is not linear. Compared to string concatenation, + I've observed around 20-30x faster for writing and rending ~4KB and up to 2100x-2300x + for ~4MB. This is because it avoids the overhead of creating and destroying many intermediate strings and performs memcopy operations. The result is a more efficient when building larger string concatenations. It @@ -24,17 +22,39 @@ struct StringBuilder(Stringable, Sized, io.Writer, io.ByteWriter, io.StringWrite from strings.builder import StringBuilder var sb = StringBuilder() - sb.append("mojo") - sb.append("jojo") - print(sb) # mojojojo + sb.write_string("Hello ") + sb.write_string("World!") + print(sb) # Hello World! ``` """ - var _vector: List[Byte] + var data: DTypePointer[DType.uint8] + var size: Int + var capacity: Int + + @always_inline + fn __init__(inout self, *, capacity: Int = 8200): + constrained[growth_factor >= 1.25]() + self.data = DTypePointer[DType.uint8]().alloc(capacity) + self.size = 0 + self.capacity = capacity + + @always_inline + fn __del__(owned self): + if self.data: + self.data.free() - fn __init__(inout self, size: Int = 4096): - self._vector = List[Byte](capacity=size) + @always_inline + fn __len__(self) -> Int: + """ + Returns the length of the string builder. + + Returns: + The length of the string builder. + """ + return self.size + @always_inline fn __str__(self) -> String: """ Converts the string builder to a string. @@ -43,91 +63,62 @@ struct StringBuilder(Stringable, Sized, io.Writer, io.ByteWriter, io.StringWrite The string representation of the string builder. Returns an empty string if the string builder is empty. """ - var copy = List[Byte](self._vector) - if copy[-1] != 0: - copy.append(0) - return String(copy) + var copy = DTypePointer[DType.uint8]().alloc(self.size) + memcpy(copy, self.data, self.size) + return StringRef(copy, self.size) - fn get_bytes(self) -> List[Int8]: + @always_inline + fn render(self) -> StringSlice[is_mutable=False, lifetime=ImmutableStaticLifetime]: """ - Returns a deepcopy of the byte array of the string builder. + Return a StringSlice view of the data owned by the builder. + Slightly faster than __str__, 10-20% faster in limited testing. Returns: - The byte array of the string builder. + The string representation of the string builder. Returns an empty string if the string builder is empty. """ - return List[Byte](self._vector) + return StringSlice[is_mutable=False, lifetime=ImmutableStaticLifetime](unsafe_from_utf8_strref=StringRef(self.data, self.size)) - fn get_null_terminated_bytes(self) -> List[Int8]: + @always_inline + fn _resize(inout self, capacity: Int) -> None: """ - Returns a deepcopy of the byte array of the string builder with a null terminator. + Resizes the string builder buffer. - Returns: - The byte array of the string builder with a null terminator. + Args: + capacity: The new capacity of the string builder buffer. """ - var copy = List[Byte](self._vector) - if copy[-1] != 0: - copy.append(0) + var new_data = DTypePointer[DType.uint8]().alloc(capacity) + memcpy(new_data, self.data, self.size) + self.data.free() + self.data = new_data + self.capacity = capacity - return copy + return None - fn write(inout self, src: List[Byte]) -> Result[Int]: + @always_inline + fn write(inout self, src: Span[Byte]) -> (Int, Error): """ - Appends a byte array to the builder buffer. + Appends a byte Span to the builder buffer. Args: src: The byte array to append. """ - self._vector.extend(src) - return Result(len(src), None) + if len(src) > self.capacity - self.size: + var new_capacity = int(self.capacity * growth_factor) + if new_capacity < self.capacity + len(src): + new_capacity = self.capacity + len(src) + self._resize(new_capacity) - fn write_byte(inout self, byte: Int8) -> Result[Int]: - """ - Appends a byte array to the builder buffer. + memcpy(self.data.offset(self.size), src._data, len(src)) + self.size += len(src) - Args: - byte: The byte array to append. - """ - self._vector.append(byte) - return Result(1, None) + return len(src), Error() - fn write_string(inout self, src: String) -> Result[Int]: + @always_inline + fn write_string(inout self, src: String) -> (Int, Error): """ Appends a string to the builder buffer. Args: src: The string to append. """ - var string_buffer = src.as_bytes() - self._vector.extend(string_buffer) - return Result(len(string_buffer), None) - - fn __len__(self) -> Int: - """ - Returns the length of the string builder. - - Returns: - The length of the string builder. - """ - return len(self._vector) - - fn __getitem__(self, index: Int) -> String: - """ - Returns the string at the given index. - - Args: - index: The index of the string to return. - - Returns: - The string at the given index. - """ - return self._vector[index] - - fn __setitem__(inout self, index: Int, value: Int8): - """ - Sets the string at the given index. - - Args: - index: The index of the string to set. - value: The value to set. - """ - self._vector[index] = value + return self.write(src.as_bytes_slice()) diff --git a/external/gojo/strings/reader.mojo b/external/gojo/strings/reader.mojo index 7ba0789d..8d8af287 100644 --- a/external/gojo/strings/reader.mojo +++ b/external/gojo/strings/reader.mojo @@ -1,12 +1,9 @@ import ..io -from collections.optional import Optional -from ..builtins import Byte, copy, Result, panic, WrappedError +from ..builtins import Byte, copy, panic @value -struct Reader( - Sized, io.Reader, io.ReaderAt, io.ByteReader, io.ByteScanner, io.Seeker, io.WriterTo -): +struct Reader(Sized, io.Reader, io.ReaderAt, io.ByteReader, io.ByteScanner, io.Seeker, io.WriterTo): """A Reader that implements the [io.Reader], [io.ReaderAt], [io.ByteReader], [io.ByteScanner], [io.Seeker], and [io.WriterTo] traits by reading from a string. The zero value for Reader operates like a Reader of an empty string. """ @@ -42,7 +39,7 @@ struct Reader( """ return Int64(len(self.string)) - fn read(inout self, inout dest: List[Byte]) -> Result[Int]: + fn read(inout self, inout dest: List[Byte]) -> (Int, Error): """Reads from the underlying string into the provided List[Byte] object. Implements the [io.Reader] trait. @@ -53,14 +50,14 @@ struct Reader( The number of bytes read into dest. """ if self.read_pos >= Int64(len(self.string)): - return Result(0, WrappedError(io.EOF)) + return 0, Error(io.EOF) self.prev_rune = -1 var bytes_written = copy(dest, self.string[int(self.read_pos) :].as_bytes()) self.read_pos += Int64(bytes_written) - return bytes_written + return bytes_written, Error() - fn read_at(self, inout dest: List[Byte], off: Int64) -> Result[Int]: + fn read_at(self, inout dest: List[Byte], off: Int64) -> (Int, Error): """Reads from the Reader into the dest List[Byte] starting at the offset off. It returns the number of bytes read into dest and an error if any. Implements the [io.ReaderAt] trait. @@ -74,19 +71,19 @@ struct Reader( """ # cannot modify state - see io.ReaderAt if off < 0: - return Result(0, WrappedError("strings.Reader.read_at: negative offset")) + return 0, Error("strings.Reader.read_at: negative offset") if off >= Int64(len(self.string)): - return Result(0, WrappedError(io.EOF)) + return 0, Error(io.EOF) - var error: Optional[WrappedError] = None + var error = Error() var copied_elements_count = copy(dest, self.string[int(off) :].as_bytes()) if copied_elements_count < len(dest): - error = WrappedError(io.EOF) + error = Error(io.EOF) - return copied_elements_count + return copied_elements_count, Error() - fn read_byte(inout self) -> Result[Byte]: + fn read_byte(inout self) -> (Byte, Error): """Reads the next byte from the underlying string. Implements the [io.ByteReader] trait. @@ -95,23 +92,23 @@ struct Reader( """ self.prev_rune = -1 if self.read_pos >= Int64(len(self.string)): - return Result(Byte(0), WrappedError(io.EOF)) + return Byte(0), Error(io.EOF) var b = self.string[int(self.read_pos)] self.read_pos += 1 - return Result(Byte(ord(b)), None) + return Byte(ord(b)), Error() - fn unread_byte(inout self) -> Optional[WrappedError]: + fn unread_byte(inout self) -> Error: """Unreads the last byte read. Only the most recent byte read can be unread. Implements the [io.ByteScanner] trait. """ if self.read_pos <= 0: - return WrappedError("strings.Reader.unread_byte: at beginning of string") + return Error("strings.Reader.unread_byte: at beginning of string") self.prev_rune = -1 self.read_pos -= 1 - return None + return Error() # # read_rune implements the [io.RuneReader] trait. # fn read_rune() (ch rune, size int, err error): @@ -140,7 +137,7 @@ struct Reader( # self.prev_rune = -1 # return nil - fn seek(inout self, offset: Int64, whence: Int) -> Result[Int64]: + fn seek(inout self, offset: Int64, whence: Int) -> (Int64, Error): """Seeks to a new position in the underlying string. The next read will start from that position. Implements the [io.Seeker] trait. @@ -161,17 +158,15 @@ struct Reader( elif whence == io.SEEK_END: position = Int64(len(self.string)) + offset else: - return Result(Int64(0), WrappedError("strings.Reader.seek: invalid whence")) + return Int64(0), Error("strings.Reader.seek: invalid whence") if position < 0: - return Result( - Int64(0), WrappedError("strings.Reader.seek: negative position") - ) + return Int64(0), Error("strings.Reader.seek: negative position") self.read_pos = position - return position + return position, Error() - fn write_to[W: io.Writer](inout self, inout writer: W) -> Result[Int64]: + fn write_to[W: io.Writer](inout self, inout writer: W) -> (Int64, Error): """Writes the remaining portion of the underlying string to the provided writer. Implements the [io.WriterTo] trait. @@ -183,20 +178,20 @@ struct Reader( """ self.prev_rune = -1 if self.read_pos >= Int64(len(self.string)): - return Result(Int64(0), None) + return Int64(0), Error() var chunk_to_write = self.string[int(self.read_pos) :] - var result = io.write_string(writer, chunk_to_write) - var bytes_written = result.value + var bytes_written: Int + var err: Error + bytes_written, err = io.write_string(writer, chunk_to_write) if bytes_written > len(chunk_to_write): panic("strings.Reader.write_to: invalid write_string count") - var error: Optional[WrappedError] = None self.read_pos += Int64(bytes_written) - if bytes_written != len(chunk_to_write) and result.has_error(): - error = WrappedError(io.ERR_SHORT_WRITE) + if bytes_written != len(chunk_to_write) and not err: + err = Error(io.ERR_SHORT_WRITE) - return Result(Int64(bytes_written), error) + return Int64(bytes_written), err # TODO: How can I differentiate between the two write_to methods when the writer implements both traits? # fn write_to[W: io.StringWriter](inout self, inout writer: W) raises -> Int64: diff --git a/external/gojo/syscall/__init__.mojo b/external/gojo/syscall/__init__.mojo new file mode 100644 index 00000000..d59a41f9 --- /dev/null +++ b/external/gojo/syscall/__init__.mojo @@ -0,0 +1,24 @@ +from .net import FD_STDIN, FD_STDOUT, FD_STDERR + +# Adapted from https://github.com/crisadamo/mojo-Libc . Huge thanks to Cristian! +# C types +alias c_void = UInt8 +alias c_char = UInt8 +alias c_schar = Int8 +alias c_uchar = UInt8 +alias c_short = Int16 +alias c_ushort = UInt16 +alias c_int = Int32 +alias c_uint = UInt32 +alias c_long = Int64 +alias c_ulong = UInt64 +alias c_float = Float32 +alias c_double = Float64 + +# `Int` is known to be machine's width +alias c_size_t = Int +alias c_ssize_t = Int + +alias ptrdiff_t = Int64 +alias intptr_t = Int64 +alias uintptr_t = UInt64 diff --git a/external/gojo/syscall/file.mojo b/external/gojo/syscall/file.mojo new file mode 100644 index 00000000..77f05a99 --- /dev/null +++ b/external/gojo/syscall/file.mojo @@ -0,0 +1,62 @@ +from . import c_int, c_char, c_void, c_size_t, c_ssize_t + + +# --- ( File Related Syscalls & Structs )--------------------------------------- +alias O_NONBLOCK = 16384 +alias O_ACCMODE = 3 +alias O_CLOEXEC = 524288 + + +fn close(fildes: c_int) -> c_int: + """Libc POSIX `close` function + Reference: https://man7.org/linux/man-pages/man3/close.3p.html + Fn signature: int close(int fildes). + + Args: + fildes: A File Descriptor to close. + + Returns: + Upon successful completion, 0 shall be returned; otherwise, -1 + shall be returned and errno set to indicate the error. + """ + return external_call["close", c_int, c_int](fildes) + + +fn open[*T: AnyType](path: UnsafePointer[c_char], oflag: c_int) -> c_int: + """Libc POSIX `open` function + Reference: https://man7.org/linux/man-pages/man3/open.3p.html + Fn signature: int open(const char *path, int oflag, ...). + + Args: + path: A pointer to a C string containing the path to open. + oflag: The flags to open the file with. + Returns: + A File Descriptor or -1 in case of failure + """ + return external_call["open", c_int, UnsafePointer[c_char], c_int](path, oflag) # FnName, RetType # Args + + +fn read(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int: + """Libc POSIX `read` function + Reference: https://man7.org/linux/man-pages/man3/read.3p.html + Fn signature: sssize_t read(int fildes, void *buf, size_t nbyte). + + Args: fildes: A File Descriptor. + buf: A pointer to a buffer to store the read data. + nbyte: The number of bytes to read. + Returns: The number of bytes read or -1 in case of failure. + """ + return external_call["read", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t](fildes, buf, nbyte) + + +fn write(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int: + """Libc POSIX `write` function + Reference: https://man7.org/linux/man-pages/man3/write.3p.html + Fn signature: ssize_t write(int fildes, const void *buf, size_t nbyte). + + Args: fildes: A File Descriptor. + buf: A pointer to a buffer to write. + nbyte: The number of bytes to write. + Returns: The number of bytes written or -1 in case of failure. + """ + return external_call["write", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t](fildes, buf, nbyte) diff --git a/external/gojo/syscall/net.mojo b/external/gojo/syscall/net.mojo new file mode 100644 index 00000000..2b0901af --- /dev/null +++ b/external/gojo/syscall/net.mojo @@ -0,0 +1,790 @@ +from . import c_char, c_int, c_ushort, c_uint, c_size_t, c_ssize_t +from .types import strlen +from .file import O_CLOEXEC, O_NONBLOCK +from utils.static_tuple import StaticTuple + +alias IPPROTO_IPV6 = 41 +alias IPV6_V6ONLY = 26 +alias EPROTONOSUPPORT = 93 + +# Adapted from https://github.com/gabrieldemarmiesse/mojo-stdlib-extensions/ . Huge thanks to Gabriel! + +alias FD_STDIN: c_int = 0 +alias FD_STDOUT: c_int = 1 +alias FD_STDERR: c_int = 2 + +alias SUCCESS = 0 +alias GRND_NONBLOCK: UInt8 = 1 + +alias char_pointer = DTypePointer[DType.uint8] + + +# --- ( error.h Constants )----------------------------------------------------- +alias EPERM = 1 +alias ENOENT = 2 +alias ESRCH = 3 +alias EINTR = 4 +alias EIO = 5 +alias ENXIO = 6 +alias E2BIG = 7 +alias ENOEXEC = 8 +alias EBADF = 9 +alias ECHILD = 10 +alias EAGAIN = 11 +alias ENOMEM = 12 +alias EACCES = 13 +alias EFAULT = 14 +alias ENOTBLK = 15 +alias EBUSY = 16 +alias EEXIST = 17 +alias EXDEV = 18 +alias ENODEV = 19 +alias ENOTDIR = 20 +alias EISDIR = 21 +alias EINVAL = 22 +alias ENFILE = 23 +alias EMFILE = 24 +alias ENOTTY = 25 +alias ETXTBSY = 26 +alias EFBIG = 27 +alias ENOSPC = 28 +alias ESPIPE = 29 +alias EROFS = 30 +alias EMLINK = 31 +alias EPIPE = 32 +alias EDOM = 33 +alias ERANGE = 34 +alias EWOULDBLOCK = EAGAIN + + +fn to_char_ptr(s: String) -> DTypePointer[DType.uint8]: + """Only ASCII-based strings.""" + var ptr = DTypePointer[DType.uint8]().alloc(len(s)) + for i in range(len(s)): + ptr.store(i, ord(s[i])) + return ptr + + +fn c_charptr_to_string(s: DTypePointer[DType.uint8]) -> String: + return String(s, strlen(s)) + + +fn cftob(val: c_int) -> Bool: + """Convert C-like failure (-1) to Bool.""" + return rebind[Bool](val > 0) + + +# --- ( Network Related Constants )--------------------------------------------- +alias sa_family_t = c_ushort +alias socklen_t = c_uint +alias in_addr_t = c_uint +alias in_port_t = c_ushort + +# Address Family Constants +alias AF_UNSPEC = 0 +alias AF_UNIX = 1 +alias AF_LOCAL = AF_UNIX +alias AF_INET = 2 +alias AF_AX25 = 3 +alias AF_IPX = 4 +alias AF_APPLETALK = 5 +alias AF_NETROM = 6 +alias AF_BRIDGE = 7 +alias AF_ATMPVC = 8 +alias AF_X25 = 9 +alias AF_INET6 = 10 +alias AF_ROSE = 11 +alias AF_DECnet = 12 +alias AF_NETBEUI = 13 +alias AF_SECURITY = 14 +alias AF_KEY = 15 +alias AF_NETLINK = 16 +alias AF_ROUTE = AF_NETLINK +alias AF_PACKET = 17 +alias AF_ASH = 18 +alias AF_ECONET = 19 +alias AF_ATMSVC = 20 +alias AF_RDS = 21 +alias AF_SNA = 22 +alias AF_IRDA = 23 +alias AF_PPPOX = 24 +alias AF_WANPIPE = 25 +alias AF_LLC = 26 +alias AF_CAN = 29 +alias AF_TIPC = 30 +alias AF_BLUETOOTH = 31 +alias AF_IUCV = 32 +alias AF_RXRPC = 33 +alias AF_ISDN = 34 +alias AF_PHONET = 35 +alias AF_IEEE802154 = 36 +alias AF_CAIF = 37 +alias AF_ALG = 38 +alias AF_NFC = 39 +alias AF_VSOCK = 40 +alias AF_KCM = 41 +alias AF_QIPCRTR = 42 +alias AF_MAX = 43 + +# Protocol family constants +alias PF_UNSPEC = AF_UNSPEC +alias PF_UNIX = AF_UNIX +alias PF_LOCAL = AF_LOCAL +alias PF_INET = AF_INET +alias PF_AX25 = AF_AX25 +alias PF_IPX = AF_IPX +alias PF_APPLETALK = AF_APPLETALK +alias PF_NETROM = AF_NETROM +alias PF_BRIDGE = AF_BRIDGE +alias PF_ATMPVC = AF_ATMPVC +alias PF_X25 = AF_X25 +alias PF_INET6 = AF_INET6 +alias PF_ROSE = AF_ROSE +alias PF_DECnet = AF_DECnet +alias PF_NETBEUI = AF_NETBEUI +alias PF_SECURITY = AF_SECURITY +alias PF_KEY = AF_KEY +alias PF_NETLINK = AF_NETLINK +alias PF_ROUTE = AF_ROUTE +alias PF_PACKET = AF_PACKET +alias PF_ASH = AF_ASH +alias PF_ECONET = AF_ECONET +alias PF_ATMSVC = AF_ATMSVC +alias PF_RDS = AF_RDS +alias PF_SNA = AF_SNA +alias PF_IRDA = AF_IRDA +alias PF_PPPOX = AF_PPPOX +alias PF_WANPIPE = AF_WANPIPE +alias PF_LLC = AF_LLC +alias PF_CAN = AF_CAN +alias PF_TIPC = AF_TIPC +alias PF_BLUETOOTH = AF_BLUETOOTH +alias PF_IUCV = AF_IUCV +alias PF_RXRPC = AF_RXRPC +alias PF_ISDN = AF_ISDN +alias PF_PHONET = AF_PHONET +alias PF_IEEE802154 = AF_IEEE802154 +alias PF_CAIF = AF_CAIF +alias PF_ALG = AF_ALG +alias PF_NFC = AF_NFC +alias PF_VSOCK = AF_VSOCK +alias PF_KCM = AF_KCM +alias PF_QIPCRTR = AF_QIPCRTR +alias PF_MAX = AF_MAX + +# Socket Type constants +alias SOCK_STREAM = 1 +alias SOCK_DGRAM = 2 +alias SOCK_RAW = 3 +alias SOCK_RDM = 4 +alias SOCK_SEQPACKET = 5 +alias SOCK_DCCP = 6 +alias SOCK_PACKET = 10 +alias SOCK_CLOEXEC = O_CLOEXEC +alias SOCK_NONBLOCK = O_NONBLOCK + +# Address Information +alias AI_PASSIVE = 1 +alias AI_CANONNAME = 2 +alias AI_NUMERICHOST = 4 +alias AI_V4MAPPED = 2048 +alias AI_ALL = 256 +alias AI_ADDRCONFIG = 1024 +alias AI_IDN = 64 + +alias INET_ADDRSTRLEN = 16 +alias INET6_ADDRSTRLEN = 46 + +alias SHUT_RD = 0 +alias SHUT_WR = 1 +alias SHUT_RDWR = 2 + +alias SOL_SOCKET = 65535 + +# Socket Options +alias SO_DEBUG = 1 +alias SO_REUSEADDR = 4 +alias SO_TYPE = 4104 +alias SO_ERROR = 4103 +alias SO_DONTROUTE = 16 +alias SO_BROADCAST = 32 +alias SO_SNDBUF = 4097 +alias SO_RCVBUF = 4098 +alias SO_KEEPALIVE = 8 +alias SO_OOBINLINE = 256 +alias SO_LINGER = 128 +alias SO_REUSEPORT = 512 +alias SO_RCVLOWAT = 4100 +alias SO_SNDLOWAT = 4099 +alias SO_RCVTIMEO = 4102 +alias SO_SNDTIMEO = 4101 +alias SO_RCVTIMEO_OLD = 4102 +alias SO_SNDTIMEO_OLD = 4101 +alias SO_ACCEPTCONN = 2 + +# unsure of these socket options, they weren't available via python +alias SO_NO_CHECK = 11 +alias SO_PRIORITY = 12 +alias SO_BSDCOMPAT = 14 +alias SO_PASSCRED = 16 +alias SO_PEERCRED = 17 +alias SO_SECURITY_AUTHENTICATION = 22 +alias SO_SECURITY_ENCRYPTION_TRANSPORT = 23 +alias SO_SECURITY_ENCRYPTION_NETWORK = 24 +alias SO_BINDTODEVICE = 25 +alias SO_ATTACH_FILTER = 26 +alias SO_DETACH_FILTER = 27 +alias SO_GET_FILTER = SO_ATTACH_FILTER +alias SO_PEERNAME = 28 +alias SO_TIMESTAMP = 29 +alias SO_TIMESTAMP_OLD = 29 +alias SO_PEERSEC = 31 +alias SO_SNDBUFFORCE = 32 +alias SO_RCVBUFFORCE = 33 +alias SO_PASSSEC = 34 +alias SO_TIMESTAMPNS = 35 +alias SO_TIMESTAMPNS_OLD = 35 +alias SO_MARK = 36 +alias SO_TIMESTAMPING = 37 +alias SO_TIMESTAMPING_OLD = 37 +alias SO_PROTOCOL = 38 +alias SO_DOMAIN = 39 +alias SO_RXQ_OVFL = 40 +alias SO_WIFI_STATUS = 41 +alias SCM_WIFI_STATUS = SO_WIFI_STATUS +alias SO_PEEK_OFF = 42 +alias SO_NOFCS = 43 +alias SO_LOCK_FILTER = 44 +alias SO_SELECT_ERR_QUEUE = 45 +alias SO_BUSY_POLL = 46 +alias SO_MAX_PACING_RATE = 47 +alias SO_BPF_EXTENSIONS = 48 +alias SO_INCOMING_CPU = 49 +alias SO_ATTACH_BPF = 50 +alias SO_DETACH_BPF = SO_DETACH_FILTER +alias SO_ATTACH_REUSEPORT_CBPF = 51 +alias SO_ATTACH_REUSEPORT_EBPF = 52 +alias SO_CNX_ADVICE = 53 +alias SCM_TIMESTAMPING_OPT_STATS = 54 +alias SO_MEMINFO = 55 +alias SO_INCOMING_NAPI_ID = 56 +alias SO_COOKIE = 57 +alias SCM_TIMESTAMPING_PKTINFO = 58 +alias SO_PEERGROUPS = 59 +alias SO_ZEROCOPY = 60 +alias SO_TXTIME = 61 +alias SCM_TXTIME = SO_TXTIME +alias SO_BINDTOIFINDEX = 62 +alias SO_TIMESTAMP_NEW = 63 +alias SO_TIMESTAMPNS_NEW = 64 +alias SO_TIMESTAMPING_NEW = 65 +alias SO_RCVTIMEO_NEW = 66 +alias SO_SNDTIMEO_NEW = 67 +alias SO_DETACH_REUSEPORT_BPF = 68 + + +# --- ( Network Related Structs )----------------------------------------------- +@value +@register_passable("trivial") +struct in_addr: + var s_addr: in_addr_t + + +@value +@register_passable("trivial") +struct in6_addr: + var s6_addr: StaticTuple[c_char, 16] + + +@value +@register_passable("trivial") +struct sockaddr: + var sa_family: sa_family_t + var sa_data: StaticTuple[c_char, 14] + + +@value +@register_passable("trivial") +struct sockaddr_in: + var sin_family: sa_family_t + var sin_port: in_port_t + var sin_addr: in_addr + var sin_zero: StaticTuple[c_char, 8] + + +@value +@register_passable("trivial") +struct sockaddr_in6: + var sin6_family: sa_family_t + var sin6_port: in_port_t + var sin6_flowinfo: c_uint + var sin6_addr: in6_addr + var sin6_scope_id: c_uint + + +@value +@register_passable("trivial") +struct addrinfo: + """Struct field ordering can vary based on platform. + For MacOS, I had to swap the order of ai_canonname and ai_addr. + https://stackoverflow.com/questions/53575101/calling-getaddrinfo-directly-from-python-ai-addr-is-null-pointer. + """ + + var ai_flags: c_int + var ai_family: c_int + var ai_socktype: c_int + var ai_protocol: c_int + var ai_addrlen: socklen_t + var ai_canonname: DTypePointer[DType.uint8] + var ai_addr: UnsafePointer[sockaddr] + var ai_next: UnsafePointer[addrinfo] + + fn __init__( + inout self, + ai_flags: c_int = 0, + ai_family: c_int = 0, + ai_socktype: c_int = 0, + ai_protocol: c_int = 0, + ai_addrlen: socklen_t = 0, + ai_canonname: DTypePointer[DType.uint8] = DTypePointer[DType.uint8](), + ai_addr: UnsafePointer[sockaddr] = UnsafePointer[sockaddr](), + ai_next: UnsafePointer[addrinfo] = UnsafePointer[addrinfo](), + ): + self.ai_flags = ai_flags + self.ai_family = ai_family + self.ai_socktype = ai_socktype + self.ai_protocol = ai_protocol + self.ai_addrlen = ai_addrlen + self.ai_canonname = ai_canonname + self.ai_addr = ai_addr + self.ai_next = ai_next + + # fn __init__() -> Self: + # return Self(0, 0, 0, 0, 0, DTypePointer[DType.uint8](), UnsafePointer[sockaddr](), UnsafePointer[addrinfo]()) + + +@value +@register_passable("trivial") +struct addrinfo_unix: + """Struct field ordering can vary based on platform. + For MacOS, I had to swap the order of ai_canonname and ai_addr. + https://stackoverflow.com/questions/53575101/calling-getaddrinfo-directly-from-python-ai-addr-is-null-pointer. + """ + + var ai_flags: c_int + var ai_family: c_int + var ai_socktype: c_int + var ai_protocol: c_int + var ai_addrlen: socklen_t + var ai_addr: UnsafePointer[sockaddr] + var ai_canonname: DTypePointer[DType.uint8] + var ai_next: UnsafePointer[addrinfo] + + fn __init__( + inout self, + ai_flags: c_int = 0, + ai_family: c_int = 0, + ai_socktype: c_int = 0, + ai_protocol: c_int = 0, + ai_addrlen: socklen_t = 0, + ai_canonname: DTypePointer[DType.uint8] = DTypePointer[DType.uint8](), + ai_addr: UnsafePointer[sockaddr] = UnsafePointer[sockaddr](), + ai_next: UnsafePointer[addrinfo] = UnsafePointer[addrinfo](), + ): + self.ai_flags = ai_flags + self.ai_family = ai_family + self.ai_socktype = ai_socktype + self.ai_protocol = ai_protocol + self.ai_addrlen = ai_addrlen + self.ai_canonname = ai_canonname + self.ai_addr = ai_addr + self.ai_next = ai_next + + +# --- ( Network Related Syscalls & Structs )------------------------------------ + + +fn htonl(hostlong: c_uint) -> c_uint: + """Libc POSIX `htonl` function + Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html + Fn signature: uint32_t htonl(uint32_t hostlong). + + Args: hostlong: A 32-bit integer in host byte order. + Returns: The value provided in network byte order. + """ + return external_call["htonl", c_uint, c_uint](hostlong) + + +fn htons(hostshort: c_ushort) -> c_ushort: + """Libc POSIX `htons` function + Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html + Fn signature: uint16_t htons(uint16_t hostshort). + + Args: hostshort: A 16-bit integer in host byte order. + Returns: The value provided in network byte order. + """ + return external_call["htons", c_ushort, c_ushort](hostshort) + + +fn ntohl(netlong: c_uint) -> c_uint: + """Libc POSIX `ntohl` function + Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html + Fn signature: uint32_t ntohl(uint32_t netlong). + + Args: netlong: A 32-bit integer in network byte order. + Returns: The value provided in host byte order. + """ + return external_call["ntohl", c_uint, c_uint](netlong) + + +fn ntohs(netshort: c_ushort) -> c_ushort: + """Libc POSIX `ntohs` function + Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html + Fn signature: uint16_t ntohs(uint16_t netshort). + + Args: netshort: A 16-bit integer in network byte order. + Returns: The value provided in host byte order. + """ + return external_call["ntohs", c_ushort, c_ushort](netshort) + + +fn inet_ntop( + af: c_int, src: DTypePointer[DType.uint8], dst: DTypePointer[DType.uint8], size: socklen_t +) -> DTypePointer[DType.uint8]: + """Libc POSIX `inet_ntop` function + Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html. + Fn signature: const char *inet_ntop(int af, const void *restrict src, char *restrict dst, socklen_t size). + + Args: + af: Address Family see AF_ aliases. + src: A pointer to a binary address. + dst: A pointer to a buffer to store the result. + size: The size of the buffer. + + Returns: + A pointer to the buffer containing the result. + """ + return external_call[ + "inet_ntop", + DTypePointer[DType.uint8], # FnName, RetType + c_int, + DTypePointer[DType.uint8], + DTypePointer[DType.uint8], + socklen_t, # Args + ](af, src, dst, size) + + +fn inet_pton(af: c_int, src: DTypePointer[DType.uint8], dst: DTypePointer[DType.uint8]) -> c_int: + """Libc POSIX `inet_pton` function + Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html + Fn signature: int inet_pton(int af, const char *restrict src, void *restrict dst). + + Args: af: Address Family see AF_ aliases. + src: A pointer to a string containing the address. + dst: A pointer to a buffer to store the result. + Returns: 1 on success, 0 if the input is not a valid address, -1 on error. + """ + return external_call[ + "inet_pton", + c_int, # FnName, RetType + c_int, + DTypePointer[DType.uint8], + DTypePointer[DType.uint8], # Args + ](af, src, dst) + + +fn inet_addr(cp: DTypePointer[DType.uint8]) -> in_addr_t: + """Libc POSIX `inet_addr` function + Reference: https://man7.org/linux/man-pages/man3/inet_addr.3p.html + Fn signature: in_addr_t inet_addr(const char *cp). + + Args: cp: A pointer to a string containing the address. + Returns: The address in network byte order. + """ + return external_call["inet_addr", in_addr_t, DTypePointer[DType.uint8]](cp) + + +fn inet_ntoa(addr: in_addr) -> DTypePointer[DType.uint8]: + """Libc POSIX `inet_ntoa` function + Reference: https://man7.org/linux/man-pages/man3/inet_addr.3p.html + Fn signature: char *inet_ntoa(struct in_addr in). + + Args: in: A pointer to a string containing the address. + Returns: The address in network byte order. + """ + return external_call["inet_ntoa", DTypePointer[DType.uint8], in_addr](addr) + + +fn socket(domain: c_int, type: c_int, protocol: c_int) -> c_int: + """Libc POSIX `socket` function + Reference: https://man7.org/linux/man-pages/man3/socket.3p.html + Fn signature: int socket(int domain, int type, int protocol). + + Args: domain: Address Family see AF_ aliases. + type: Socket Type see SOCK_ aliases. + protocol: The protocol to use. + Returns: A File Descriptor or -1 in case of failure. + """ + return external_call["socket", c_int, c_int, c_int, c_int](domain, type, protocol) # FnName, RetType # Args + + +fn setsockopt( + socket: c_int, + level: c_int, + option_name: c_int, + option_value: DTypePointer[DType.uint8], + option_len: socklen_t, +) -> c_int: + """Libc POSIX `setsockopt` function + Reference: https://man7.org/linux/man-pages/man3/setsockopt.3p.html + Fn signature: int setsockopt(int socket, int level, int option_name, const void *option_value, socklen_t option_len). + + Args: + socket: A File Descriptor. + level: The protocol level. + option_name: The option to set. + option_value: A pointer to the value to set. + option_len: The size of the value. + Returns: 0 on success, -1 on error. + """ + return external_call[ + "setsockopt", + c_int, # FnName, RetType + c_int, + c_int, + c_int, + DTypePointer[DType.uint8], + socklen_t, # Args + ](socket, level, option_name, option_value, option_len) + + +fn getsockopt( + socket: c_int, + level: c_int, + option_name: c_int, + option_value: DTypePointer[DType.uint8], + option_len: UnsafePointer[socklen_t], +) -> c_int: + """Libc POSIX `getsockopt` function + Reference: https://man7.org/linux/man-pages/man3/getsockopt.3p.html + Fn signature: int getsockopt(int socket, int level, int option_name, void *restrict option_value, socklen_t *restrict option_len). + + Args: socket: A File Descriptor. + level: The protocol level. + option_name: The option to get. + option_value: A pointer to the value to get. + option_len: DTypePointer to the size of the value. + Returns: 0 on success, -1 on error. + """ + return external_call[ + "getsockopt", + c_int, # FnName, RetType + c_int, + c_int, + c_int, + DTypePointer[DType.uint8], + UnsafePointer[socklen_t], # Args + ](socket, level, option_name, option_value, option_len) + + +fn getsockname(socket: c_int, address: UnsafePointer[sockaddr], address_len: UnsafePointer[socklen_t]) -> c_int: + """Libc POSIX `getsockname` function + Reference: https://man7.org/linux/man-pages/man3/getsockname.3p.html + Fn signature: int getsockname(int socket, struct sockaddr *restrict address, socklen_t *restrict address_len). + + Args: socket: A File Descriptor. + address: A pointer to a buffer to store the address of the peer. + address_len: A pointer to the size of the buffer. + Returns: 0 on success, -1 on error. + """ + return external_call[ + "getsockname", + c_int, # FnName, RetType + c_int, + UnsafePointer[sockaddr], + UnsafePointer[socklen_t], # Args + ](socket, address, address_len) + + +fn getpeername(sockfd: c_int, addr: UnsafePointer[sockaddr], address_len: UnsafePointer[socklen_t]) -> c_int: + """Libc POSIX `getpeername` function + Reference: https://man7.org/linux/man-pages/man2/getpeername.2.html + Fn signature: int getpeername(int socket, struct sockaddr *restrict addr, socklen_t *restrict address_len). + + Args: sockfd: A File Descriptor. + addr: A pointer to a buffer to store the address of the peer. + address_len: A pointer to the size of the buffer. + Returns: 0 on success, -1 on error. + """ + return external_call[ + "getpeername", + c_int, # FnName, RetType + c_int, + UnsafePointer[sockaddr], + UnsafePointer[socklen_t], # Args + ](sockfd, addr, address_len) + + +fn bind(socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen_t) -> c_int: + """Libc POSIX `bind` function + Reference: https://man7.org/linux/man-pages/man3/bind.3p.html + Fn signature: int bind(int socket, const struct sockaddr *address, socklen_t address_len). + """ + return external_call["bind", c_int, c_int, UnsafePointer[sockaddr], socklen_t]( # FnName, RetType # Args + socket, address, address_len + ) + + +fn listen(socket: c_int, backlog: c_int) -> c_int: + """Libc POSIX `listen` function + Reference: https://man7.org/linux/man-pages/man3/listen.3p.html + Fn signature: int listen(int socket, int backlog). + + Args: socket: A File Descriptor. + backlog: The maximum length of the queue of pending connections. + Returns: 0 on success, -1 on error. + """ + return external_call["listen", c_int, c_int, c_int](socket, backlog) + + +fn accept(socket: c_int, address: UnsafePointer[sockaddr], address_len: UnsafePointer[socklen_t]) -> c_int: + """Libc POSIX `accept` function + Reference: https://man7.org/linux/man-pages/man3/accept.3p.html + Fn signature: int accept(int socket, struct sockaddr *restrict address, socklen_t *restrict address_len). + + Args: socket: A File Descriptor. + address: A pointer to a buffer to store the address of the peer. + address_len: A pointer to the size of the buffer. + Returns: A File Descriptor or -1 in case of failure. + """ + return external_call[ + "accept", + c_int, # FnName, RetType + c_int, + UnsafePointer[sockaddr], + UnsafePointer[socklen_t], # Args + ](socket, address, address_len) + + +fn connect(socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen_t) -> c_int: + """Libc POSIX `connect` function + Reference: https://man7.org/linux/man-pages/man3/connect.3p.html + Fn signature: int connect(int socket, const struct sockaddr *address, socklen_t address_len). + + Args: socket: A File Descriptor. + address: A pointer to the address to connect to. + address_len: The size of the address. + Returns: 0 on success, -1 on error. + """ + return external_call["connect", c_int, c_int, UnsafePointer[sockaddr], socklen_t]( # FnName, RetType # Args + socket, address, address_len + ) + + +fn recv(socket: c_int, buffer: DTypePointer[DType.uint8], length: c_size_t, flags: c_int) -> c_ssize_t: + """Libc POSIX `recv` function + Reference: https://man7.org/linux/man-pages/man3/recv.3p.html + Fn signature: ssize_t recv(int socket, void *buffer, size_t length, int flags). + """ + return external_call[ + "recv", + c_ssize_t, # FnName, RetType + c_int, + DTypePointer[DType.uint8], + c_size_t, + c_int, # Args + ](socket, buffer, length, flags) + + +fn send(socket: c_int, buffer: DTypePointer[DType.uint8], length: c_size_t, flags: c_int) -> c_ssize_t: + """Libc POSIX `send` function + Reference: https://man7.org/linux/man-pages/man3/send.3p.html + Fn signature: ssize_t send(int socket, const void *buffer, size_t length, int flags). + + Args: socket: A File Descriptor. + buffer: A pointer to the buffer to send. + length: The size of the buffer. + flags: Flags to control the behaviour of the function. + Returns: The number of bytes sent or -1 in case of failure. + """ + return external_call[ + "send", + c_ssize_t, # FnName, RetType + c_int, + DTypePointer[DType.uint8], + c_size_t, + c_int, # Args + ](socket, buffer, length, flags) + + +fn shutdown(socket: c_int, how: c_int) -> c_int: + """Libc POSIX `shutdown` function + Reference: https://man7.org/linux/man-pages/man3/shutdown.3p.html + Fn signature: int shutdown(int socket, int how). + + Args: socket: A File Descriptor. + how: How to shutdown the socket. + Returns: 0 on success, -1 on error. + """ + return external_call["shutdown", c_int, c_int, c_int](socket, how) # FnName, RetType # Args + + +fn getaddrinfo( + nodename: DTypePointer[DType.uint8], + servname: DTypePointer[DType.uint8], + hints: UnsafePointer[addrinfo], + res: UnsafePointer[UnsafePointer[addrinfo]], +) -> c_int: + """Libc POSIX `getaddrinfo` function + Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html + Fn signature: int getaddrinfo(const char *restrict nodename, const char *restrict servname, const struct addrinfo *restrict hints, struct addrinfo **restrict res). + """ + return external_call[ + "getaddrinfo", + c_int, # FnName, RetType + DTypePointer[DType.uint8], + DTypePointer[DType.uint8], + UnsafePointer[addrinfo], # Args + UnsafePointer[UnsafePointer[addrinfo]], # Args + ](nodename, servname, hints, res) + + +fn getaddrinfo_unix( + nodename: DTypePointer[DType.uint8], + servname: DTypePointer[DType.uint8], + hints: UnsafePointer[addrinfo_unix], + res: UnsafePointer[UnsafePointer[addrinfo_unix]], +) -> c_int: + """Libc POSIX `getaddrinfo` function + Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html + Fn signature: int getaddrinfo(const char *restrict nodename, const char *restrict servname, const struct addrinfo *restrict hints, struct addrinfo **restrict res). + """ + return external_call[ + "getaddrinfo", + c_int, # FnName, RetType + DTypePointer[DType.uint8], + DTypePointer[DType.uint8], + UnsafePointer[addrinfo_unix], # Args + UnsafePointer[UnsafePointer[addrinfo_unix]], # Args + ](nodename, servname, hints, res) + + +fn gai_strerror(ecode: c_int) -> DTypePointer[DType.uint8]: + """Libc POSIX `gai_strerror` function + Reference: https://man7.org/linux/man-pages/man3/gai_strerror.3p.html + Fn signature: const char *gai_strerror(int ecode). + + Args: ecode: The error code. + Returns: A pointer to a string describing the error. + """ + return external_call["gai_strerror", DTypePointer[DType.uint8], c_int](ecode) # FnName, RetType # Args + + +# fn inet_pton(address_family: Int, address: String) -> Int: +# var ip_buf_size = 4 +# if address_family == AF_INET6: +# ip_buf_size = 16 + +# var ip_buf = DTypePointer[DType.uint8].alloc(ip_buf_size) +# var conv_status = inet_pton(rebind[c_int](address_family), to_char_ptr(address), ip_buf) +# return int(ip_buf.bitcast[c_uint]().load()) diff --git a/external/gojo/syscall/types.mojo b/external/gojo/syscall/types.mojo new file mode 100644 index 00000000..6b2c49ad --- /dev/null +++ b/external/gojo/syscall/types.mojo @@ -0,0 +1,9 @@ +fn strlen(s: DTypePointer[DType.uint8]) -> c_size_t: + """Libc POSIX `strlen` function + Reference: https://man7.org/linux/man-pages/man3/strlen.3p.html + Fn signature: size_t strlen(const char *s). + + Args: s: A pointer to a C string. + Returns: The length of the string. + """ + return external_call["strlen", c_size_t, DTypePointer[DType.uint8]](s) diff --git a/lightbug_http/tests/__init__.mojo b/external/gojo/tests/__init__.mojo similarity index 100% rename from lightbug_http/tests/__init__.mojo rename to external/gojo/tests/__init__.mojo diff --git a/external/gojo/tests/wrapper.mojo b/external/gojo/tests/wrapper.mojo new file mode 100644 index 00000000..ee73f268 --- /dev/null +++ b/external/gojo/tests/wrapper.mojo @@ -0,0 +1,38 @@ +from testing import testing + + +@value +struct MojoTest: + """ + A utility struct for testing. + """ + + var test_name: String + + fn __init__(inout self, test_name: String): + self.test_name = test_name + print("# " + test_name) + + fn assert_true(self, cond: Bool, message: String = ""): + try: + if message == "": + testing.assert_true(cond) + else: + testing.assert_true(cond, message) + except e: + print(e) + + fn assert_false(self, cond: Bool, message: String = ""): + try: + if message == "": + testing.assert_false(cond) + else: + testing.assert_false(cond, message) + except e: + print(e) + + fn assert_equal[T: testing.Testable](self, left: T, right: T): + try: + testing.assert_equal(left, right) + except e: + print(e) \ No newline at end of file diff --git a/external/gojo/unicode/__init__.mojo b/external/gojo/unicode/__init__.mojo new file mode 100644 index 00000000..bd4cba63 --- /dev/null +++ b/external/gojo/unicode/__init__.mojo @@ -0,0 +1 @@ +from .utf8 import string_iterator, rune_count_in_string diff --git a/external/gojo/unicode/utf8/__init__.mojo b/external/gojo/unicode/utf8/__init__.mojo new file mode 100644 index 00000000..b8732ec8 --- /dev/null +++ b/external/gojo/unicode/utf8/__init__.mojo @@ -0,0 +1,4 @@ +"""Almost all of the actual implementation in this module was written by @mzaks (https://github.com/mzaks)! +This would not be possible without his help. +""" +from .runes import string_iterator, rune_count_in_string diff --git a/external/gojo/unicode/utf8/runes.mojo b/external/gojo/unicode/utf8/runes.mojo new file mode 100644 index 00000000..7346162b --- /dev/null +++ b/external/gojo/unicode/utf8/runes.mojo @@ -0,0 +1,334 @@ +"""Almost all of the actual implementation in this module was written by @mzaks (https://github.com/mzaks)! +This would not be possible without his help. +""" + +from ...builtins import Rune +from algorithm.functional import vectorize +from memory.unsafe import DTypePointer +from sys.info import simdwidthof +from bit import countl_zero + + +# The default lowest and highest continuation byte. +alias locb = 0b10000000 +alias hicb = 0b10111111 +alias RUNE_SELF = 0x80 # Characters below RuneSelf are represented as themselves in a single byte + + +# acceptRange gives the range of valid values for the second byte in a UTF-8 +# sequence. +@value +struct AcceptRange(CollectionElement): + var lo: UInt8 # lowest value for second byte. + var hi: UInt8 # highest value for second byte. + + +# ACCEPT_RANGES has size 16 to avoid bounds checks in the code that uses it. +alias ACCEPT_RANGES = List[AcceptRange]( + AcceptRange(locb, hicb), + AcceptRange(0xA0, hicb), + AcceptRange(locb, 0x9F), + AcceptRange(0x90, hicb), + AcceptRange(locb, 0x8F), +) + +# These names of these constants are chosen to give nice alignment in the +# table below. The first nibble is an index into acceptRanges or F for +# special one-byte cases. The second nibble is the Rune length or the +# Status for the special one-byte case. +alias xx = 0xF1 # invalid: size 1 +alias as1 = 0xF0 # ASCII: size 1 +alias s1 = 0x02 # accept 0, size 2 +alias s2 = 0x13 # accept 1, size 3 +alias s3 = 0x03 # accept 0, size 3 +alias s4 = 0x23 # accept 2, size 3 +alias s5 = 0x34 # accept 3, size 4 +alias s6 = 0x04 # accept 0, size 4 +alias s7 = 0x44 # accept 4, size 4 + + +# first is information about the first byte in a UTF-8 sequence. +var first = List[UInt8]( + # 1 2 3 4 5 6 7 8 9 A B C D E F + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, # 0x00-0x0F + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, # 0x10-0x1F + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, # 0x20-0x2F + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, # 0x30-0x3F + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, # 0x40-0x4F + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, # 0x50-0x5F + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, # 0x60-0x6F + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, + as1, # 0x70-0x7F + # 1 2 3 4 5 6 7 8 9 A B C D E F + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, # 0x80-0x8F + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, # 0x90-0x9F + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, # 0xA0-0xAF + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, # 0xB0-0xBF + xx, + xx, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, # 0xC0-0xCF + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, + s1, # 0xD0-0xDF + s2, + s3, + s3, + s3, + s3, + s3, + s3, + s3, + s3, + s3, + s3, + s3, + s3, + s4, + s3, + s3, # 0xE0-0xEF + s5, + s6, + s6, + s6, + s7, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, + xx, # 0xF0-0xFF +) + + +alias simd_width_u8 = simdwidthof[DType.uint8]() + + +fn rune_count_in_string(s: String) -> Int: + """Count the number of runes in a string. + + Args: + s: The string to count runes in. + + Returns: + The number of runes in the string. + """ + var p = DTypePointer[DType.uint8](s.unsafe_uint8_ptr()) + var string_byte_length = len(s) + var result = 0 + + @parameter + fn count[simd_width: Int](offset: Int): + result += int(((p.load[width=simd_width](offset) >> 6) != 0b10).reduce_add()) + + vectorize[count, simd_width_u8](string_byte_length) + return result diff --git a/external/libc.mojo b/external/libc.mojo index 1e2eceb2..9e07b0e7 100644 --- a/external/libc.mojo +++ b/external/libc.mojo @@ -1,3 +1,4 @@ +from utils import StaticTuple from lightbug_http.io.bytes import Bytes alias IPPROTO_IPV6 = 41 @@ -13,7 +14,7 @@ alias FD_STDERR: c_int = 2 alias SUCCESS = 0 alias GRND_NONBLOCK: UInt8 = 1 -alias char_pointer = UnsafePointer[c_char] +alias char_UnsafePointer = UnsafePointer[c_char] # Adapted from https://github.com/crisadamo/mojo-Libc . Huge thanks to Cristian! # C types @@ -77,22 +78,21 @@ alias ERANGE = 34 alias EWOULDBLOCK = EAGAIN -fn to_char_ptr(s: String) -> Pointer[c_char]: +fn to_char_ptr(s: String) -> UnsafePointer[c_char]: """Only ASCII-based strings.""" - var ptr = Pointer[c_char]().alloc(len(s)) + var ptr = UnsafePointer[c_char]().alloc(len(s)) for i in range(len(s)): - ptr.store(i, ord(s[i])) + ptr[i] = ord(s[i]) return ptr - -fn to_char_ptr(s: Bytes) -> Pointer[c_char]: - var ptr = Pointer[c_char]().alloc(len(s)) +fn to_char_ptr(s: Bytes) -> UnsafePointer[c_char]: + var ptr = UnsafePointer[c_char]().alloc(len(s)) for i in range(len(s)): - ptr.store(i, int(s[i])) + ptr[i] = int(s[i]) return ptr -fn c_charptr_to_string(s: Pointer[c_char]) -> String: - return String(s.bitcast[Int8](), strlen(s)) +fn c_charptr_to_string(s: UnsafePointer[c_char]) -> String: + return String(s.bitcast[UInt8](), strlen(s)) fn cftob(val: c_int) -> Bool: @@ -352,26 +352,26 @@ struct addrinfo: var ai_socktype: c_int var ai_protocol: c_int var ai_addrlen: socklen_t - var ai_addr: Pointer[sockaddr] - var ai_canonname: Pointer[c_char] - # FIXME(cristian): This should be Pointer[addrinfo] - var ai_next: Pointer[c_void] + var ai_addr: UnsafePointer[sockaddr] + var ai_canonname: UnsafePointer[c_char] + # FIXME(cristian): This should be UnsafePointer[addrinfo] + var ai_next: UnsafePointer[c_void] fn __init__() -> Self: return Self( - 0, 0, 0, 0, 0, Pointer[sockaddr](), Pointer[c_char](), Pointer[c_void]() + 0, 0, 0, 0, 0, UnsafePointer[sockaddr](), UnsafePointer[c_char](), UnsafePointer[c_void]() ) -fn strlen(s: Pointer[c_char]) -> c_size_t: +fn strlen(s: UnsafePointer[c_char]) -> c_size_t: """Libc POSIX `strlen` function Reference: https://man7.org/linux/man-pages/man3/strlen.3p.html Fn signature: size_t strlen(const char *s). - Args: s: A pointer to a C string. + Args: s: A UnsafePointer to a C string. Returns: The length of the string. """ - return external_call["strlen", c_size_t, Pointer[c_char]](s) + return external_call["strlen", c_size_t, UnsafePointer[c_char]](s) # --- ( Network Related Syscalls & Structs )------------------------------------ @@ -422,70 +422,70 @@ fn ntohs(netshort: c_ushort) -> c_ushort: fn inet_ntop( - af: c_int, src: Pointer[c_void], dst: Pointer[c_char], size: socklen_t -) -> Pointer[c_char]: + af: c_int, src: UnsafePointer[c_void], dst: UnsafePointer[c_char], size: socklen_t +) -> UnsafePointer[c_char]: """Libc POSIX `inet_ntop` function Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html. Fn signature: const char *inet_ntop(int af, const void *restrict src, char *restrict dst, socklen_t size). Args: af: Address Family see AF_ aliases. - src: A pointer to a binary address. - dst: A pointer to a buffer to store the result. + src: A UnsafePointer to a binary address. + dst: A UnsafePointer to a buffer to store the result. size: The size of the buffer. Returns: - A pointer to the buffer containing the result. + A UnsafePointer to the buffer containing the result. """ return external_call[ "inet_ntop", - Pointer[c_char], # FnName, RetType + UnsafePointer[c_char], # FnName, RetType c_int, - Pointer[c_void], - Pointer[c_char], + UnsafePointer[c_void], + UnsafePointer[c_char], socklen_t, # Args ](af, src, dst, size) -fn inet_pton(af: c_int, src: Pointer[c_char], dst: Pointer[c_void]) -> c_int: +fn inet_pton(af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void]) -> c_int: """Libc POSIX `inet_pton` function Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html Fn signature: int inet_pton(int af, const char *restrict src, void *restrict dst). Args: af: Address Family see AF_ aliases. - src: A pointer to a string containing the address. - dst: A pointer to a buffer to store the result. + src: A UnsafePointer to a string containing the address. + dst: A UnsafePointer to a buffer to store the result. Returns: 1 on success, 0 if the input is not a valid address, -1 on error. """ return external_call[ "inet_pton", c_int, # FnName, RetType c_int, - Pointer[c_char], - Pointer[c_void], # Args + UnsafePointer[c_char], + UnsafePointer[c_void], # Args ](af, src, dst) -fn inet_addr(cp: Pointer[c_char]) -> in_addr_t: +fn inet_addr(cp: UnsafePointer[c_char]) -> in_addr_t: """Libc POSIX `inet_addr` function Reference: https://man7.org/linux/man-pages/man3/inet_addr.3p.html Fn signature: in_addr_t inet_addr(const char *cp). - Args: cp: A pointer to a string containing the address. + Args: cp: A UnsafePointer to a string containing the address. Returns: The address in network byte order. """ - return external_call["inet_addr", in_addr_t, Pointer[c_char]](cp) + return external_call["inet_addr", in_addr_t, UnsafePointer[c_char]](cp) -fn inet_ntoa(addr: in_addr) -> Pointer[c_char]: +fn inet_ntoa(addr: in_addr) -> UnsafePointer[c_char]: """Libc POSIX `inet_ntoa` function Reference: https://man7.org/linux/man-pages/man3/inet_addr.3p.html Fn signature: char *inet_ntoa(struct in_addr in). - Args: in: A pointer to a string containing the address. + Args: in: A UnsafePointer to a string containing the address. Returns: The address in network byte order. """ - return external_call["inet_ntoa", Pointer[c_char], in_addr](addr) + return external_call["inet_ntoa", UnsafePointer[c_char], in_addr](addr) fn socket(domain: c_int, type: c_int, protocol: c_int) -> c_int: @@ -507,7 +507,7 @@ fn setsockopt( socket: c_int, level: c_int, option_name: c_int, - option_value: Pointer[c_void], + option_value: UnsafePointer[c_void], option_len: socklen_t, ) -> c_int: """Libc POSIX `setsockopt` function @@ -517,7 +517,7 @@ fn setsockopt( Args: socket: A File Descriptor. level: The protocol level. option_name: The option to set. - option_value: A pointer to the value to set. + option_value: A UnsafePointer to the value to set. option_len: The size of the value. Returns: 0 on success, -1 on error. """ @@ -527,60 +527,60 @@ fn setsockopt( c_int, c_int, c_int, - Pointer[c_void], + UnsafePointer[c_void], socklen_t, # Args ](socket, level, option_name, option_value, option_len) fn getsockname( - socket: c_int, address: Pointer[sockaddr], address_len: Pointer[socklen_t] + socket: c_int, address: UnsafePointer[sockaddr], address_len: UnsafePointer[socklen_t] ) -> c_int: """Libc POSIX `getsockname` function Reference: https://man7.org/linux/man-pages/man3/getsockname.3p.html Fn signature: int getsockname(int socket, struct sockaddr *restrict address, socklen_t *restrict address_len). Args: socket: A File Descriptor. - address: A pointer to a buffer to store the address of the peer. - address_len: A pointer to the size of the buffer. + address: A UnsafePointer to a buffer to store the address of the peer. + address_len: A UnsafePointer to the size of the buffer. Returns: 0 on success, -1 on error. """ return external_call[ "getsockname", c_int, # FnName, RetType c_int, - Pointer[sockaddr], - Pointer[socklen_t], # Args + UnsafePointer[sockaddr], + UnsafePointer[socklen_t], # Args ](socket, address, address_len) fn getpeername( - sockfd: c_int, addr: Pointer[sockaddr], address_len: Pointer[socklen_t] + sockfd: c_int, addr: UnsafePointer[sockaddr], address_len: UnsafePointer[socklen_t] ) -> c_int: """Libc POSIX `getpeername` function Reference: https://man7.org/linux/man-pages/man2/getpeername.2.html Fn signature: int getpeername(int socket, struct sockaddr *restrict addr, socklen_t *restrict address_len). Args: sockfd: A File Descriptor. - addr: A pointer to a buffer to store the address of the peer. - address_len: A pointer to the size of the buffer. + addr: A UnsafePointer to a buffer to store the address of the peer. + address_len: A UnsafePointer to the size of the buffer. Returns: 0 on success, -1 on error. """ return external_call[ "getpeername", c_int, # FnName, RetType c_int, - Pointer[sockaddr], - Pointer[socklen_t], # Args + UnsafePointer[sockaddr], + UnsafePointer[socklen_t], # Args ](sockfd, addr, address_len) -fn bind(socket: c_int, address: Pointer[sockaddr], address_len: socklen_t) -> c_int: +fn bind(socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen_t) -> c_int: """Libc POSIX `bind` function Reference: https://man7.org/linux/man-pages/man3/bind.3p.html Fn signature: int bind(int socket, const struct sockaddr *address, socklen_t address_len). """ return external_call[ - "bind", c_int, c_int, Pointer[sockaddr], socklen_t # FnName, RetType # Args + "bind", c_int, c_int, UnsafePointer[sockaddr], socklen_t # FnName, RetType # Args ](socket, address, address_len) @@ -597,43 +597,63 @@ fn listen(socket: c_int, backlog: c_int) -> c_int: fn accept( - socket: c_int, address: Pointer[sockaddr], address_len: Pointer[socklen_t] + socket: c_int, address: UnsafePointer[sockaddr], address_len: UnsafePointer[socklen_t] ) -> c_int: """Libc POSIX `accept` function Reference: https://man7.org/linux/man-pages/man3/accept.3p.html Fn signature: int accept(int socket, struct sockaddr *restrict address, socklen_t *restrict address_len). Args: socket: A File Descriptor. - address: A pointer to a buffer to store the address of the peer. - address_len: A pointer to the size of the buffer. + address: A UnsafePointer to a buffer to store the address of the peer. + address_len: A UnsafePointer to the size of the buffer. Returns: A File Descriptor or -1 in case of failure. """ return external_call[ "accept", c_int, # FnName, RetType c_int, - Pointer[sockaddr], - Pointer[socklen_t], # Args + UnsafePointer[sockaddr], + UnsafePointer[socklen_t], # Args ](socket, address, address_len) -fn connect(socket: c_int, address: Pointer[sockaddr], address_len: socklen_t) -> c_int: +fn connect(socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen_t) -> c_int: """Libc POSIX `connect` function Reference: https://man7.org/linux/man-pages/man3/connect.3p.html Fn signature: int connect(int socket, const struct sockaddr *address, socklen_t address_len). Args: socket: A File Descriptor. - address: A pointer to the address to connect to. + address: A UnsafePointer to the address to connect to. address_len: The size of the address. Returns: 0 on success, -1 on error. """ return external_call[ - "connect", c_int, c_int, Pointer[sockaddr], socklen_t # FnName, RetType # Args + "connect", c_int, c_int, UnsafePointer[sockaddr], socklen_t # FnName, RetType # Args ](socket, address, address_len) +# fn recv( +# socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c_int +# ) -> c_ssize_t: +# """Libc POSIX `recv` function +# Reference: https://man7.org/linux/man-pages/man3/recv.3p.html +# Fn signature: ssize_t recv(int socket, void *buffer, size_t length, int flags). +# """ +# return external_call[ +# "recv", +# c_ssize_t, # FnName, RetType +# c_int, +# UnsafePointer[c_void], +# c_size_t, +# c_int, # Args +# ](socket, buffer, length, flags) + + fn recv( - socket: c_int, buffer: Pointer[c_void], length: c_size_t, flags: c_int + socket: c_int, + buffer: DTypePointer[DType.uint8], + length: c_size_t, + flags: c_int, ) -> c_ssize_t: """Libc POSIX `recv` function Reference: https://man7.org/linux/man-pages/man3/recv.3p.html @@ -643,21 +663,20 @@ fn recv( "recv", c_ssize_t, # FnName, RetType c_int, - Pointer[c_void], + DTypePointer[DType.uint8], c_size_t, c_int, # Args ](socket, buffer, length, flags) - fn send( - socket: c_int, buffer: Pointer[c_void], length: c_size_t, flags: c_int + socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c_int ) -> c_ssize_t: """Libc POSIX `send` function Reference: https://man7.org/linux/man-pages/man3/send.3p.html Fn signature: ssize_t send(int socket, const void *buffer, size_t length, int flags). Args: socket: A File Descriptor. - buffer: A pointer to the buffer to send. + buffer: A UnsafePointer to the buffer to send. length: The size of the buffer. flags: Flags to control the behaviour of the function. Returns: The number of bytes sent or -1 in case of failure. @@ -666,7 +685,7 @@ fn send( "send", c_ssize_t, # FnName, RetType c_int, - Pointer[c_void], + UnsafePointer[c_void], c_size_t, c_int, # Args ](socket, buffer, length, flags) @@ -687,10 +706,10 @@ fn shutdown(socket: c_int, how: c_int) -> c_int: fn getaddrinfo( - nodename: Pointer[c_char], - servname: Pointer[c_char], - hints: Pointer[addrinfo], - res: Pointer[Pointer[addrinfo]], + nodename: UnsafePointer[c_char], + servname: UnsafePointer[c_char], + hints: UnsafePointer[addrinfo], + res: UnsafePointer[UnsafePointer[addrinfo]], ) -> c_int: """Libc POSIX `getaddrinfo` function Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html @@ -699,23 +718,23 @@ fn getaddrinfo( return external_call[ "getaddrinfo", c_int, # FnName, RetType - Pointer[c_char], - Pointer[c_char], - Pointer[addrinfo], # Args - Pointer[Pointer[addrinfo]], # Args + UnsafePointer[c_char], + UnsafePointer[c_char], + UnsafePointer[addrinfo], # Args + UnsafePointer[UnsafePointer[addrinfo]], # Args ](nodename, servname, hints, res) -fn gai_strerror(ecode: c_int) -> Pointer[c_char]: +fn gai_strerror(ecode: c_int) -> UnsafePointer[c_char]: """Libc POSIX `gai_strerror` function Reference: https://man7.org/linux/man-pages/man3/gai_strerror.3p.html Fn signature: const char *gai_strerror(int ecode). Args: ecode: The error code. - Returns: A pointer to a string describing the error. + Returns: A UnsafePointer to a string describing the error. """ return external_call[ - "gai_strerror", Pointer[c_char], c_int # FnName, RetType # Args + "gai_strerror", UnsafePointer[c_char], c_int # FnName, RetType # Args ](ecode) @@ -724,11 +743,11 @@ fn inet_pton(address_family: Int, address: String) -> Int: if address_family == AF_INET6: ip_buf_size = 16 - var ip_buf = Pointer[c_void].alloc(ip_buf_size) + var ip_buf = UnsafePointer[c_void].alloc(ip_buf_size) var conv_status = inet_pton( rebind[c_int](address_family), to_char_ptr(address), ip_buf ) - return int(ip_buf.bitcast[c_uint]().load()) + return int(ip_buf.bitcast[c_uint]()) # --- ( File Related Syscalls & Structs )--------------------------------------- @@ -752,102 +771,102 @@ fn close(fildes: c_int) -> c_int: return external_call["close", c_int, c_int](fildes) -fn open[*T: AnyType](path: Pointer[c_char], oflag: c_int, *args: *T) -> c_int: +fn open[*T: AnyType](path: UnsafePointer[c_char], oflag: c_int, *args: *T) -> c_int: """Libc POSIX `open` function Reference: https://man7.org/linux/man-pages/man3/open.3p.html Fn signature: int open(const char *path, int oflag, ...). Args: - path: A pointer to a C string containing the path to open. + path: A UnsafePointer to a C string containing the path to open. oflag: The flags to open the file with. args: The optional arguments. Returns: A File Descriptor or -1 in case of failure """ return external_call[ - "open", c_int, Pointer[c_char], c_int # FnName, RetType # Args + "open", c_int, UnsafePointer[c_char], c_int # FnName, RetType # Args ](path, oflag, args) fn openat[ *T: AnyType -](fd: c_int, path: Pointer[c_char], oflag: c_int, *args: *T) -> c_int: +](fd: c_int, path: UnsafePointer[c_char], oflag: c_int, *args: *T) -> c_int: """Libc POSIX `open` function Reference: https://man7.org/linux/man-pages/man3/open.3p.html Fn signature: int openat(int fd, const char *path, int oflag, ...). Args: fd: A File Descriptor. - path: A pointer to a C string containing the path to open. + path: A UnsafePointer to a C string containing the path to open. oflag: The flags to open the file with. args: The optional arguments. Returns: A File Descriptor or -1 in case of failure """ return external_call[ - "openat", c_int, c_int, Pointer[c_char], c_int # FnName, RetType # Args + "openat", c_int, c_int, UnsafePointer[c_char], c_int # FnName, RetType # Args ](fd, path, oflag, args) -fn printf[*T: AnyType](format: Pointer[c_char], *args: *T) -> c_int: +fn printf[*T: AnyType](format: UnsafePointer[c_char], *args: *T) -> c_int: """Libc POSIX `printf` function Reference: https://man7.org/linux/man-pages/man3/fprintf.3p.html Fn signature: int printf(const char *restrict format, ...). - Args: format: A pointer to a C string containing the format. + Args: format: A UnsafePointer to a C string containing the format. args: The optional arguments. Returns: The number of bytes written or -1 in case of failure. """ return external_call[ "printf", c_int, # FnName, RetType - Pointer[c_char], # Args + UnsafePointer[c_char], # Args ](format, args) fn sprintf[ *T: AnyType -](s: Pointer[c_char], format: Pointer[c_char], *args: *T) -> c_int: +](s: UnsafePointer[c_char], format: UnsafePointer[c_char], *args: *T) -> c_int: """Libc POSIX `sprintf` function Reference: https://man7.org/linux/man-pages/man3/fprintf.3p.html Fn signature: int sprintf(char *restrict s, const char *restrict format, ...). - Args: s: A pointer to a buffer to store the result. - format: A pointer to a C string containing the format. + Args: s: A UnsafePointer to a buffer to store the result. + format: A UnsafePointer to a C string containing the format. args: The optional arguments. Returns: The number of bytes written or -1 in case of failure. """ return external_call[ - "sprintf", c_int, Pointer[c_char], Pointer[c_char] # FnName, RetType # Args + "sprintf", c_int, UnsafePointer[c_char], UnsafePointer[c_char] # FnName, RetType # Args ](s, format, args) -fn read(fildes: c_int, buf: Pointer[c_void], nbyte: c_size_t) -> c_int: +fn read(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int: """Libc POSIX `read` function Reference: https://man7.org/linux/man-pages/man3/read.3p.html Fn signature: sssize_t read(int fildes, void *buf, size_t nbyte). Args: fildes: A File Descriptor. - buf: A pointer to a buffer to store the read data. + buf: A UnsafePointer to a buffer to store the read data. nbyte: The number of bytes to read. Returns: The number of bytes read or -1 in case of failure. """ - return external_call["read", c_ssize_t, c_int, Pointer[c_void], c_size_t]( + return external_call["read", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t]( fildes, buf, nbyte ) -fn write(fildes: c_int, buf: Pointer[c_void], nbyte: c_size_t) -> c_int: +fn write(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int: """Libc POSIX `write` function Reference: https://man7.org/linux/man-pages/man3/write.3p.html Fn signature: ssize_t write(int fildes, const void *buf, size_t nbyte). Args: fildes: A File Descriptor. - buf: A pointer to a buffer to write. + buf: A UnsafePointer to a buffer to write. nbyte: The number of bytes to write. Returns: The number of bytes written or -1 in case of failure. """ - return external_call["write", c_ssize_t, c_int, Pointer[c_void], c_size_t]( + return external_call["write", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t]( fildes, buf, nbyte ) @@ -859,8 +878,8 @@ fn __test_getaddrinfo__(): var ip_addr = "127.0.0.1" var port = 8083 - var servinfo = Pointer[addrinfo]().alloc(1) - servinfo.store(addrinfo()) + var servinfo = UnsafePointer[addrinfo]().alloc(1) + servinfo[0] = addrinfo() var hints = addrinfo() hints.ai_family = AF_INET @@ -870,139 +889,138 @@ fn __test_getaddrinfo__(): var status = getaddrinfo( to_char_ptr(ip_addr), - Pointer[UInt8](), - Pointer.address_of(hints), - Pointer.address_of(servinfo), + UnsafePointer[UInt8](), + UnsafePointer.address_of(hints), + UnsafePointer.address_of(servinfo), ) var msg_ptr = gai_strerror(c_int(status)) - _ = external_call["printf", c_int, Pointer[c_char], Pointer[c_char]]( + _ = external_call["printf", c_int, UnsafePointer[c_char], UnsafePointer[c_char]]( to_char_ptr("gai_strerror: %s"), msg_ptr ) var msg = c_charptr_to_string(msg_ptr) print("getaddrinfo satus: " + msg) -fn __test_socket_client__(): - var ip_addr = "127.0.0.1" # The server's hostname or IP address - var port = 8080 # The port used by the server - var address_family = AF_INET - - var ip_buf = Pointer[c_void].alloc(4) - var conv_status = inet_pton(address_family, to_char_ptr(ip_addr), ip_buf) - var raw_ip = ip_buf.bitcast[c_uint]().load() - - print("inet_pton: " + raw_ip.__str__() + " :: status: " + conv_status.__str__()) - - var bin_port = htons(UInt16(port)) - print("htons: " + "\n" + bin_port.__str__()) - - var ai = sockaddr_in(address_family, bin_port, raw_ip, StaticTuple[c_char, 8]()) - var ai_ptr = Pointer[sockaddr_in].address_of(ai).bitcast[sockaddr]() - - var sockfd = socket(address_family, SOCK_STREAM, 0) - if sockfd == -1: - print("Socket creation error") - print("sockfd: " + "\n" + sockfd.__str__()) - - if connect(sockfd, ai_ptr, sizeof[sockaddr_in]()) == -1: - _ = shutdown(sockfd, SHUT_RDWR) - print("Connection error") - return # Ensure to exit if connection fails - - var msg = to_char_ptr("Hello, world Server") - var bytes_sent = send(sockfd, msg, strlen(msg), 0) - if bytes_sent == -1: - print("Failed to send message") - else: - print("Message sent") - var buf_size = 1024 - var buf = Pointer[UInt8]().alloc(buf_size) - var bytes_recv = recv(sockfd, buf, buf_size, 0) - if bytes_recv == -1: - print("Failed to receive message") - else: - print("Received Message: ") - print(String(buf.bitcast[Int8](), bytes_recv)) - - _ = shutdown(sockfd, SHUT_RDWR) - var close_status = close(sockfd) - if close_status == -1: - print("Failed to close socket") - - -fn __test_socket_server__() raises: - var ip_addr = "127.0.0.1" - var port = 8083 - - var address_family = AF_INET - var ip_buf_size = 4 - if address_family == AF_INET6: - ip_buf_size = 16 - - var ip_buf = Pointer[c_void].alloc(ip_buf_size) - var conv_status = inet_pton(address_family, to_char_ptr(ip_addr), ip_buf) - var raw_ip = ip_buf.bitcast[c_uint]().load() - - print("inet_pton: " + raw_ip.__str__() + " :: status: " + conv_status.__str__()) - - var bin_port = htons(UInt16(port)) - print("htons: " + "\n" + bin_port.__str__()) - - var ai = sockaddr_in(address_family, bin_port, raw_ip, StaticTuple[c_char, 8]()) - var ai_ptr = Pointer[sockaddr_in].address_of(ai).bitcast[sockaddr]() - - var sockfd = socket(address_family, SOCK_STREAM, 0) - if sockfd == -1: - print("Socket creation error") - print("sockfd: " + "\n" + sockfd.__str__()) - - var yes: Int = 1 - if ( - setsockopt( - sockfd, - SOL_SOCKET, - SO_REUSEADDR, - Pointer[Int].address_of(yes).bitcast[c_void](), - sizeof[Int](), - ) - == -1 - ): - print("set socket options failed") - - if bind(sockfd, ai_ptr, sizeof[sockaddr_in]()) == -1: - # close(sockfd) - _ = shutdown(sockfd, SHUT_RDWR) - print("Binding socket failed. Wait a few seconds and try again?") - - if listen(sockfd, c_int(128)) == -1: - print("Listen failed.\n on sockfd " + sockfd.__str__()) - - print( - "server: started at " - + ip_addr - + ":" - + port.__str__() - + " on sockfd " - + sockfd.__str__() - + "Waiting for connections..." - ) - - var their_addr_ptr = Pointer[sockaddr].alloc(1) - var sin_size = socklen_t(sizeof[socklen_t]()) - var new_sockfd = accept( - sockfd, their_addr_ptr, Pointer[socklen_t].address_of(sin_size) - ) - if new_sockfd == -1: - print("Accept failed") - # close(sockfd) - _ = shutdown(sockfd, SHUT_RDWR) - - var msg = "Hello, Mojo!" - if send(new_sockfd, to_char_ptr(msg).bitcast[c_void](), len(msg), 0) == -1: - print("Failed to send response") - print("Message sent succesfully") - _ = shutdown(sockfd, SHUT_RDWR) - - var close_status = close(new_sockfd) - if close_status == -1: - print("Failed to close new_sockfd") +# fn __test_socket_client__(): +# var ip_addr = "127.0.0.1" # The server's hostname or IP address +# var port = 8080 # The port used by the server +# var address_family = AF_INET + +# var ip_buf = UnsafePointer[c_void].alloc(4) +# var conv_status = inet_pton(address_family, to_char_ptr(ip_addr), ip_buf) +# var raw_ip = ip_buf.bitcast[c_uint]() + +# print("inet_pton: " + raw_ip.__str__() + " :: status: " + conv_status.__str__()) + +# var bin_port = htons(UInt16(port)) +# print("htons: " + "\n" + bin_port.__str__()) + +# var ai = sockaddr_in(address_family, bin_port, raw_ip, StaticTuple[c_char, 8]()) +# var ai_ptr = UnsafePointer[sockaddr_in].address_of(ai).bitcast[sockaddr]() + +# var sockfd = socket(address_family, SOCK_STREAM, 0) +# if sockfd == -1: +# print("Socket creation error") +# print("sockfd: " + "\n" + sockfd.__str__()) + +# if connect(sockfd, ai_ptr, sizeof[sockaddr_in]()) == -1: +# _ = shutdown(sockfd, SHUT_RDWR) +# print("Connection error") +# return # Ensure to exit if connection fails + +# var msg = to_char_ptr("Hello, world Server") +# var bytes_sent = send(sockfd, msg, strlen(msg), 0) +# if bytes_sent == -1: +# print("Failed to send message") +# else: +# print("Message sent") +# var buf_size = 1024 +# var buf = UnsafePointer[UInt8]().alloc(buf_size) +# var bytes_recv = recv(sockfd, buf, buf_size, 0) +# if bytes_recv == -1: +# print("Failed to receive message") +# else: +# print("Received Message: ") +# print(String(buf.bitcast[UInt8](), bytes_recv)) + +# _ = shutdown(sockfd, SHUT_RDWR) +# var close_status = close(sockfd) +# if close_status == -1: +# print("Failed to close socket") + + +# fn __test_socket_server__() raises: +# var ip_addr = "127.0.0.1" +# var port = 8083 + +# var address_family = AF_INET +# var ip_buf_size = 4 +# if address_family == AF_INET6: +# ip_buf_size = 16 + +# var ip_buf = UnsafePointer[c_void].alloc(ip_buf_size) +# var conv_status = inet_pton(address_family, to_char_ptr(ip_addr), ip_buf) +# var raw_ip = ip_buf.bitcast[c_uint]() +# print("inet_pton: " + raw_ip.__str__() + " :: status: " + conv_status.__str__()) + +# var bin_port = htons(UInt16(port)) +# print("htons: " + "\n" + bin_port.__str__()) + +# var ai = sockaddr_in(address_family, bin_port, raw_ip, StaticTuple[c_char, 8]()) +# var ai_ptr = UnsafePointer[sockaddr_in].address_of(ai).bitcast[sockaddr]() + +# var sockfd = socket(address_family, SOCK_STREAM, 0) +# if sockfd == -1: +# print("Socket creation error") +# print("sockfd: " + "\n" + sockfd.__str__()) + +# var yes: Int = 1 +# if ( +# setsockopt( +# sockfd, +# SOL_SOCKET, +# SO_REUSEADDR, +# UnsafePointer[Int].address_of(yes).bitcast[c_void](), +# sizeof[Int](), +# ) +# == -1 +# ): +# print("set socket options failed") + +# if bind(sockfd, ai_ptr, sizeof[sockaddr_in]()) == -1: +# # close(sockfd) +# _ = shutdown(sockfd, SHUT_RDWR) +# print("Binding socket failed. Wait a few seconds and try again?") + +# if listen(sockfd, c_int(128)) == -1: +# print("Listen failed.\n on sockfd " + sockfd.__str__()) + +# print( +# "server: started at " +# + ip_addr +# + ":" +# + port.__str__() +# + " on sockfd " +# + sockfd.__str__() +# + "Waiting for connections..." +# ) + +# var their_addr_ptr = UnsafePointer[sockaddr].alloc(1) +# var sin_size = socklen_t(sizeof[socklen_t]()) +# var new_sockfd = accept( +# sockfd, their_addr_ptr, UnsafePointer[socklen_t].address_of(sin_size) +# ) +# if new_sockfd == -1: +# print("Accept failed") +# # close(sockfd) +# _ = shutdown(sockfd, SHUT_RDWR) + +# var msg = "Hello, Mojo!" +# if send(new_sockfd, to_char_ptr(msg).bitcast[c_void](), len(msg), 0) == -1: +# print("Failed to send response") +# print("Message sent succesfully") +# _ = shutdown(sockfd, SHUT_RDWR) + +# var close_status = close(new_sockfd) +# if close_status == -1: +# print("Failed to close new_sockfd") diff --git a/external/morrow.mojo b/external/morrow.mojo index 52bb8fd2..1c040acc 100644 --- a/external/morrow.mojo +++ b/external/morrow.mojo @@ -293,7 +293,7 @@ def normalize_timestamp(timestamp: Float64) -> Float64: timestamp /= 1_000_000 else: raise Error( - "The specified timestamp " + String(timestamp) + "is too large." + "The specified timestamp " + timestamp.__str__() + "is too large." ) return timestamp @@ -311,4 +311,4 @@ fn rjust(string: String, width: Int, fillchar: String = " ") -> String: fn rjust(string: Int, width: Int, fillchar: String = " ") -> String: - return rjust(String(string), width, fillchar) + return rjust(string.__str__(), width, fillchar) diff --git a/lightbug_http/error.mojo b/lightbug_http/error.mojo index e19e87d0..ab9091d7 100644 --- a/lightbug_http/error.mojo +++ b/lightbug_http/error.mojo @@ -1,9 +1,9 @@ from lightbug_http.http import HTTPResponse from lightbug_http.header import ResponseHeader - +from lightbug_http.io.bytes import bytes # TODO: Custom error handlers provided by the user @value struct ErrorHandler: fn Error(self) -> HTTPResponse: - return HTTPResponse(ResponseHeader(), String("TODO")._buffer) \ No newline at end of file + return HTTPResponse(ResponseHeader(), bytes("TODO")) \ No newline at end of file diff --git a/lightbug_http/header.mojo b/lightbug_http/header.mojo index bfc3fff8..43bcf0e9 100644 --- a/lightbug_http/header.mojo +++ b/lightbug_http/header.mojo @@ -1,17 +1,19 @@ +from external.gojo.bufio import Reader from lightbug_http.strings import ( - next_line, strHttp11, strHttp10, strSlash, strMethodGet, rChar, nChar, + colonChar, + whitespace, + tab ) -from lightbug_http.io.bytes import Bytes, bytes_equal +from lightbug_http.io.bytes import Bytes, Byte, BytesView, bytes_equal, bytes, index_byte, compare_case_insensitive, next_line, last_index_byte alias statusOK = 200 - @value struct RequestHeader: var disable_normalization: Bool @@ -25,6 +27,7 @@ struct RequestHeader: var __host: Bytes var __content_type: Bytes var __user_agent: Bytes + var __transfer_encoding: Bytes var raw_headers: Bytes var __trailer: Bytes @@ -40,6 +43,7 @@ struct RequestHeader: self.__host = Bytes() self.__content_type = Bytes() self.__user_agent = Bytes() + self.__transfer_encoding = Bytes() self.raw_headers = Bytes() self.__trailer = Bytes() @@ -52,9 +56,10 @@ struct RequestHeader: self.__method = Bytes() self.__request_uri = Bytes() self.proto = Bytes() - self.__host = host._buffer + self.__host = bytes(host) self.__content_type = Bytes() self.__user_agent = Bytes() + self.__transfer_encoding = Bytes() self.raw_headers = Bytes() self.__trailer = Bytes() @@ -70,6 +75,7 @@ struct RequestHeader: self.__host = Bytes() self.__content_type = Bytes() self.__user_agent = Bytes() + self.__transfer_encoding = Bytes() self.raw_headers = rawheaders self.__trailer = Bytes() @@ -86,6 +92,7 @@ struct RequestHeader: host: Bytes, content_type: Bytes, user_agent: Bytes, + transfer_encoding: Bytes, raw_headers: Bytes, trailer: Bytes, ) -> None: @@ -100,69 +107,75 @@ struct RequestHeader: self.__host = host self.__content_type = content_type self.__user_agent = user_agent + self.__transfer_encoding = transfer_encoding self.raw_headers = raw_headers self.__trailer = trailer fn set_content_type(inout self, content_type: String) -> Self: - self.__content_type = content_type._buffer + self.__content_type = bytes(content_type) return self fn set_content_type_bytes(inout self, content_type: Bytes) -> Self: self.__content_type = content_type return self - fn content_type(self) -> Bytes: - return self.__content_type + fn content_type(self) -> BytesView: + return BytesView(unsafe_ptr=self.__content_type.unsafe_ptr(), len=self.__content_type.size) fn set_host(inout self, host: String) -> Self: - self.__host = host._buffer + self.__host = bytes(host) return self fn set_host_bytes(inout self, host: Bytes) -> Self: self.__host = host return self - fn host(self) -> Bytes: - return self.__host + fn host(self) -> BytesView: + return BytesView(unsafe_ptr=self.__host.unsafe_ptr(), len=self.__host.size) fn set_user_agent(inout self, user_agent: String) -> Self: - self.__user_agent = user_agent._buffer + self.__user_agent = bytes(user_agent) return self fn set_user_agent_bytes(inout self, user_agent: Bytes) -> Self: self.__user_agent = user_agent return self - fn user_agent(self) -> Bytes: - return self.__user_agent + fn user_agent(self) -> BytesView: + return BytesView(unsafe_ptr=self.__user_agent.unsafe_ptr(), len=self.__user_agent.size) fn set_method(inout self, method: String) -> Self: - self.__method = method._buffer + self.__method = bytes(method) return self fn set_method_bytes(inout self, method: Bytes) -> Self: self.__method = method return self - fn method(self) -> Bytes: + fn method(self) -> BytesView: if len(self.__method) == 0: - return strMethodGet - return self.__method - - fn set_protocol(inout self, method: String) -> Self: - self.no_http_1_1 = bytes_equal(method._buffer, strHttp11) - self.proto = method._buffer + return strMethodGet.as_bytes_slice() + return BytesView(unsafe_ptr=self.__method.unsafe_ptr(), len=self.__method.size) + + fn set_protocol(inout self, proto: String) -> Self: + self.no_http_1_1 = False # hardcoded until HTTP/2 is supported + self.proto = bytes(proto) return self - fn set_protocol_bytes(inout self, method: Bytes) -> Self: - self.no_http_1_1 = bytes_equal(method, strHttp11) - self.proto = method + fn set_protocol_bytes(inout self, proto: Bytes) -> Self: + self.no_http_1_1 = False # hardcoded until HTTP/2 is supported + self.proto = proto return self - fn protocol(self) -> Bytes: + fn protocol_str(self) -> String: if len(self.proto) == 0: return strHttp11 - return self.proto + return String(self.proto) + + fn protocol(self) -> BytesView: + if len(self.proto) == 0: + return strHttp11.as_bytes_slice() + return BytesView(unsafe_ptr=self.proto.unsafe_ptr(), len=self.proto.size) fn content_length(self) -> Int: return self.__content_length @@ -176,25 +189,42 @@ struct RequestHeader: return self fn set_request_uri(inout self, request_uri: String) -> Self: - self.__request_uri = request_uri.as_bytes() + self.__request_uri = request_uri.as_bytes_slice() return self fn set_request_uri_bytes(inout self, request_uri: Bytes) -> Self: self.__request_uri = request_uri return self - fn request_uri(self) -> Bytes: - if len(self.__request_uri) == 0: - return strSlash - return self.__request_uri + fn request_uri(self) -> BytesView: + if len(self.__request_uri) <= 1: + return BytesView(unsafe_ptr=strSlash.as_bytes_slice().unsafe_ptr(), len=2) + return BytesView(unsafe_ptr=self.__request_uri.unsafe_ptr(), len=self.__request_uri.size) + + fn set_transfer_encoding(inout self, transfer_encoding: String) -> Self: + self.__transfer_encoding = bytes(transfer_encoding) + return self + + fn set_transfer_encoding_bytes(inout self, transfer_encoding: Bytes) -> Self: + self.__transfer_encoding = transfer_encoding + return self + + fn transfer_encoding(self) -> BytesView: + return BytesView(unsafe_ptr=self.__transfer_encoding.unsafe_ptr(), len=self.__transfer_encoding.size) fn set_trailer(inout self, trailer: String) -> Self: - self.__trailer = trailer._buffer + self.__trailer = bytes(trailer) return self fn set_trailer_bytes(inout self, trailer: Bytes) -> Self: self.__trailer = trailer return self + + fn trailer(self) -> BytesView: + return BytesView(unsafe_ptr=self.__trailer.unsafe_ptr(), len=self.__trailer.size) + + fn trailer_str(self) -> String: + return String(self.__trailer) fn set_connection_close(inout self) -> Self: self.__connection_close = True @@ -213,89 +243,129 @@ struct RequestHeader: fn headers(self) -> String: return String(self.raw_headers) - fn parse(inout self, request_line: String) raises -> None: - var headers = self.raw_headers - - var n = request_line.find(" ") - if n <= 0: - raise Error("Cannot find HTTP request method in the request") - - var method = request_line[:n] - var rest_of_request_line = request_line[n + 1 :] - - # Defaults to HTTP/1.1 - var proto_str = String(strHttp11) - - # Parse requestURI - n = rest_of_request_line.rfind(" ") - if n < 0: - n = len(rest_of_request_line) - proto_str = strHttp10 - elif n == 0: - raise Error("Request URI cannot be empty") - else: - var proto = rest_of_request_line[n + 1 :] - if proto != strHttp11: - proto_str = proto - - var request_uri = rest_of_request_line[:n + 1] + fn parse_raw(inout self, inout r: Reader) raises -> Int: + var first_byte = r.peek(1) + if len(first_byte) == 0: + raise Error("Failed to read first byte from request header") + + var buf: Bytes + var e: Error + + buf, e = r.peek(r.buffered()) + if e: + raise Error("Failed to read request header: " + e.__str__()) + if len(buf) == 0: + raise Error("Failed to read request header, empty buffer") + + var end_of_first_line = self.parse_first_line(buf) - _ = self.set_method(method) - _ = self.set_protocol(proto_str) - _ = self.set_request_uri(request_uri) + var header_len = self.read_raw_headers(buf[end_of_first_line:]) - # Now process the rest of the headers + self.parse_headers(buf[end_of_first_line:]) + + return end_of_first_line + header_len + + fn parse_first_line(inout self, buf: Bytes) raises -> Int: + var b_next = buf + var b = Bytes() + while len(b) == 0: + try: + b, b_next = next_line(b_next) + except e: + raise Error("Failed to read first line from request, " + e.__str__()) + + var first_whitespace = index_byte(b, bytes(whitespace, pop=False)[0]) + if first_whitespace <= 0: + raise Error("Could not find HTTP request method in request line: " + String(b)) + + _ = self.set_method_bytes(b[:first_whitespace]) + + var last_whitespace = last_index_byte(b, bytes(whitespace, pop=False)[0]) + 1 + + if last_whitespace < 0: + raise Error("Could not find request target or HTTP version in request line: " + String(b)) + elif last_whitespace == 0: + raise Error("Request URI is empty: " + String(b)) + var proto = b[last_whitespace :] + if len(proto) != len(bytes(strHttp11, pop=False)): + raise Error("Invalid protocol, HTTP version not supported: " + String(proto)) + _ = self.set_protocol_bytes(proto) + _ = self.set_request_uri_bytes(b[first_whitespace+1:last_whitespace]) + + return len(buf) - len(b_next) + + fn parse_headers(inout self, buf: Bytes) raises -> None: _ = self.set_content_length(-2) - var s = headerScanner() - s.b = headers - s.disable_normalization = self.disable_normalization + s.set_b(buf) while s.next(): - # The below is based on the code from Golang's FastHTTP library - if len(s.key) > 0: - # Spaces between the header key and colon are not allowed. - # See RFC 7230, Section 3.2.4. - if s.key.find(" ") != -1 or s.key.find("\t") != -1: - raise Error("Invalid header key") - - if s.key[0] == "h" or s.key[0] == "H": - if s.key.lower() == "host": - _ = self.set_host(s.value) - continue - elif s.key[0] == "u" or s.key[0] == "U": - if s.key.lower() == "user-agent": - _ = self.set_user_agent(s.value) - continue - elif s.key[0] == "c" or s.key[0] == "C": - if s.key.lower() == "content-type": - _ = self.set_content_type(s.value) - continue - if s.key.lower() == "content-length": - if self.content_length() != -1: - var content_length = s.value - _ = self.set_content_length(atol(content_length)) - _ = self.set_content_length_bytes(content_length._buffer) - continue - if s.key.lower() == "connection": - if s.value == "close": - _ = self.set_connection_close() - else: - _ = self.reset_connection_close() - # _ = self.appendargbytes(s.key, s.value) - continue - elif s.key[0] == "t" or s.key[0] == "T": - if s.key.lower() == "transfer-encoding": - if s.value != "identity": - _ = self.set_content_length(-1) - # _ = self.setargbytes(s.key, strChunked) - continue - if s.key.lower() == "trailer": - _ = self.set_trailer(s.value) - - # close connection for non-http/1.1 request unless 'Connection: keep-alive' is set. - # if self.no_http_1_1 and not self.__connection_close: - # self.__connection_close = not has_header_value(v, strKeepAlive) + if len(s.key()) > 0: + self.parse_header(s.key(), s.value()) + + fn parse_header(inout self, key: Bytes, value: Bytes) raises -> None: + if index_byte(key, bytes(colonChar, pop=False)[0]) == -1 or index_byte(key, bytes(tab, pop=False)[0]) != -1: + raise Error("Invalid header key: " + String(key)) + + var key_first = key[0].__xor__(0x20) + + if key_first == bytes("h", pop=False)[0] or key_first == bytes("H", pop=False)[0]: + if compare_case_insensitive(key, bytes("host", pop=False)): + _ = self.set_host_bytes(bytes(value)) + return + elif key_first == bytes("u", pop=False)[0] or key_first == bytes("U", pop=False)[0]: + if compare_case_insensitive(key, bytes("user-agent", pop=False)): + _ = self.set_user_agent_bytes(bytes(value)) + return + elif key_first == bytes("c", pop=False)[0] or key_first == bytes("C", pop=False)[0]: + if compare_case_insensitive(key, bytes("content-type", pop=False)): + _ = self.set_content_type_bytes(bytes(value)) + return + if compare_case_insensitive(key, bytes("content-length", pop=False)): + if self.content_length() != -1: + _ = self.set_content_length(atol(value)) + return + if compare_case_insensitive(key, bytes("connection", pop=False)): + if compare_case_insensitive(bytes(value), bytes("close", pop=False)): + _ = self.set_connection_close() + else: + _ = self.reset_connection_close() + return + elif key_first == bytes("t", pop=False)[0] or key_first == bytes("T", pop=False)[0]: + if compare_case_insensitive(key, bytes("transfer-encoding", pop=False)): + _ = self.set_transfer_encoding_bytes(bytes(value, pop=False)) + return + if compare_case_insensitive(key, bytes("trailer", pop=False)): + _ = self.set_trailer_bytes(bytes(value, pop=False)) + return + if self.content_length() < 0: + _ = self.set_content_length(0) + return + + fn read_raw_headers(inout self, buf: Bytes) raises -> Int: + var n = index_byte(buf, bytes(nChar, pop=False)[0]) + if n == -1: + self.raw_headers = self.raw_headers[:0] + raise Error("Failed to find a newline in headers") + + if n == 0 or (n == 1 and (buf[0] == bytes(rChar, pop=False)[0])): + # empty line -> end of headers + return n + 1 + + n += 1 + var b = buf + var m = n + while True: + b = b[m:] + m = index_byte(b, bytes(nChar, pop=False)[0]) + if m == -1: + raise Error("Failed to find a newline in headers") + m += 1 + n += m + if m == 2 and (b[0] == bytes(rChar, pop=False)[0]) or m == 1: + self.raw_headers = self.raw_headers + buf[:n] + return n + @value @@ -452,26 +522,29 @@ struct ResponseHeader: fn set_status_message(inout self, message: Bytes) -> Self: self.__status_message = message return self + + fn status_message(self) -> BytesView: + return BytesView(unsafe_ptr=self.__status_message.unsafe_ptr(), len=self.__status_message.size) + + fn status_message_str(self) -> String: + return String(self.status_message()) - fn status_message(self) -> Bytes: - return self.__status_message - - fn content_type(self) -> Bytes: - return self.__content_type + fn content_type(self) -> BytesView: + return BytesView(unsafe_ptr=self.__content_type.unsafe_ptr(), len=self.__content_type.size) fn set_content_type(inout self, content_type: String) -> Self: - self.__content_type = content_type._buffer + self.__content_type = bytes(content_type) return self fn set_content_type_bytes(inout self, content_type: Bytes) -> Self: self.__content_type = content_type return self - fn content_encoding(self) -> Bytes: - return self.__content_encoding + fn content_encoding(self) -> BytesView: + return BytesView(unsafe_ptr=self.__content_encoding.unsafe_ptr(), len=self.__content_encoding.size) fn set_content_encoding(inout self, content_encoding: String) -> Self: - self.__content_encoding = content_encoding._buffer + self.__content_encoding = bytes(content_encoding) return self fn set_content_encoding_bytes(inout self, content_encoding: Bytes) -> Self: @@ -489,34 +562,51 @@ struct ResponseHeader: self.__content_length_bytes = content_length return self - fn server(self) -> Bytes: - return self.__server + fn server(self) -> BytesView: + return BytesView(unsafe_ptr=self.__server.unsafe_ptr(), len=self.__server.size) fn set_server(inout self, server: String) -> Self: - self.__server = server._buffer + self.__server = bytes(server) return self fn set_server_bytes(inout self, server: Bytes) -> Self: self.__server = server return self - fn set_protocol(inout self, protocol: Bytes) -> Self: + fn set_protocol(inout self, proto: String) -> Self: + self.no_http_1_1 = False # hardcoded until HTTP/2 is supported + self.__protocol = bytes(proto) + return self + + fn set_protocol_bytes(inout self, protocol: Bytes) -> Self: + self.no_http_1_1 = False # hardcoded until HTTP/2 is supported self.__protocol = protocol return self - fn protocol(self) -> Bytes: + fn protocol_str(self) -> String: if len(self.__protocol) == 0: return strHttp11 - return self.__protocol + return String(self.__protocol) + + fn protocol(self) -> BytesView: + if len(self.__protocol) == 0: + return BytesView(unsafe_ptr=strHttp11.as_bytes_slice().unsafe_ptr(), len=8) + return BytesView(unsafe_ptr=self.__protocol.unsafe_ptr(), len=self.__protocol.size) fn set_trailer(inout self, trailer: String) -> Self: - self.__trailer = trailer._buffer + self.__trailer = bytes(trailer) return self fn set_trailer_bytes(inout self, trailer: Bytes) -> Self: self.__trailer = trailer return self + + fn trailer(self) -> BytesView: + return BytesView(unsafe_ptr=self.__trailer.unsafe_ptr(), len=self.__trailer.size) + fn trailer_str(self) -> String: + return String(self.trailer()) + fn set_connection_close(inout self) -> Self: self.__connection_close = True return self @@ -534,122 +624,247 @@ struct ResponseHeader: fn headers(self) -> String: return String(self.raw_headers) - fn parse(inout self, first_line: String) raises -> None: - var headers = self.raw_headers + fn parse_raw(inout self, inout r: Reader) raises -> Int: + var first_byte = r.peek(1) + if len(first_byte) == 0: + raise Error("Failed to read first byte from response header") + + var buf: Bytes + var e: Error + + buf, e = r.peek(r.buffered()) + if e: + raise Error("Failed to read response header: " + e.__str__()) + if len(buf) == 0: + raise Error("Failed to read response header, empty buffer") - # Defaults to HTTP/1.1 - var proto_str = String(strHttp11) + var end_of_first_line = self.parse_first_line(buf) - var n = first_line.find(" ") - var proto = first_line[:n] - if proto != strHttp11: - proto_str = proto + var header_len = self.read_raw_headers(buf[end_of_first_line:]) - var rest_of_response_line = first_line[n + 1 :] - var status_code = atol(rest_of_response_line[:3]) - var message = rest_of_response_line[4:] + self.parse_headers(buf[end_of_first_line:]) + + return end_of_first_line + header_len + + fn parse_first_line(inout self, buf: Bytes) raises -> Int: + var b_next = buf + var b = Bytes() + while len(b) == 0: + try: + b, b_next = next_line(b_next) + except e: + raise Error("Failed to read first line from response, " + e.__str__()) + + var first_whitespace = index_byte(b, bytes(whitespace, pop=False)[0]) + if first_whitespace <= 0: + raise Error("Could not find HTTP version in response line: " + String(b)) + + _ = self.set_protocol(b[:first_whitespace+2]) + + var end_of_status_code = first_whitespace+5 # status code is always 3 digits, this calculation includes null terminator - _ = self.set_protocol(proto_str._buffer) + var status_code = atol(b[first_whitespace+1:end_of_status_code]) _ = self.set_status_code(status_code) - _ = self.set_status_message(message._buffer) - _ = self.set_content_length(-2) + var status_text = b[end_of_status_code :] + if len(status_text) > 1: + _ = self.set_status_message(status_text) + + return len(buf) - len(b_next) + + fn parse_headers(inout self, buf: Bytes) raises -> None: + _ = self.set_content_length(-2) var s = headerScanner() - s.b = headers - s.disable_normalization = self.disable_normalization + s.set_b(buf) while s.next(): - if len(s.key) > 0: - # Spaces between header key and colon not allowed (RFC 7230, 3.2.4) - if s.key.find(" ") != -1 or s.key.find("\t") != -1: - raise Error("Invalid header key") - elif s.key[0] == "c" or s.key[0] == "C": - if s.key.lower() == "content-type": - _ = self.set_content_type(s.value) - continue - if s.key.lower() == "content-encoding": - _ = self.set_content_encoding(s.value) - continue - if s.key.lower() == "content-length": - if self.content_length() != -1: - var content_length = s.value - _ = self.set_content_length(atol(content_length)) - _ = self.set_content_length_bytes(content_length._buffer) - continue - if s.key.lower() == "connection": - if s.value == "close": - _ = self.set_connection_close() - else: - _ = self.reset_connection_close() - continue - elif s.key[0] == "s" or s.key[0] == "S": - if s.key.lower() == "server": - _ = self.set_server(s.value) - continue - elif s.key[0] == "t" or s.key[0] == "T": - if s.key.lower() == "transfer-encoding": - if s.value != "identity": - _ = self.set_content_length(-1) - continue - if s.key.lower() == "trailer": - _ = self.set_trailer(s.value) - + if len(s.key()) > 0: + self.parse_header(s.key(), s.value()) + + fn parse_header(inout self, key: Bytes, value: Bytes) raises -> None: + if index_byte(key, bytes(colonChar, pop=False)[0]) == -1 or index_byte(key, bytes(tab, pop=False)[0]) != -1: + raise Error("Invalid header key: " + String(key)) + + var key_first = key[0].__xor__(0x20) + + if key_first == bytes("c", pop=False)[0] or key_first == bytes("C", pop=False)[0]: + if compare_case_insensitive(key, bytes("content-type", pop=False)): + _ = self.set_content_type_bytes(bytes(value)) + return + if compare_case_insensitive(key, bytes("content-encoding", pop=False)): + _ = self.set_content_encoding_bytes(bytes(value)) + return + if compare_case_insensitive(key, bytes("content-length", pop=False)): + if self.content_length() != -1: + var content_length = value + _ = self.set_content_length(atol(content_length)) + _ = self.set_content_length_bytes(bytes(content_length)) + return + if compare_case_insensitive(key, bytes("connection", pop=False)): + if compare_case_insensitive(bytes(value), bytes("close", pop=False)): + _ = self.set_connection_close() + else: + _ = self.reset_connection_close() + return + elif key_first == bytes("s", pop=False)[0] or key_first == bytes("S", pop=False)[0]: + if compare_case_insensitive(key, bytes("server", pop=False)): + _ = self.set_server_bytes(bytes(value)) + return + elif key_first == bytes("t", pop=False)[0] or key_first == bytes("T", pop=False)[0]: + if compare_case_insensitive(key, bytes("transfer-encoding", pop=False)): + if not compare_case_insensitive(value, bytes("identity", pop=False)): + _ = self.set_content_length(-1) + return + if compare_case_insensitive(key, bytes("trailer", pop=False)): + _ = self.set_trailer_bytes(bytes(value)) + + fn read_raw_headers(inout self, buf: Bytes) raises -> Int: + var n = index_byte(buf, bytes(nChar, pop=False)[0]) + + if n == -1: + self.raw_headers = self.raw_headers[:0] + raise Error("Failed to find a newline in headers") + + if n == 0 or (n == 1 and (buf[0] == bytes(rChar, pop=False)[0])): + # empty line -> end of headers + return n + 1 + + n += 1 + var b = buf + var m = n + while True: + b = b[m:] + m = index_byte(b, bytes(nChar, pop=False)[0]) + if m == -1: + raise Error("Failed to find a newline in headers") + m += 1 + n += m + if m == 2 and (b[0] == bytes(rChar, pop=False)[0]) or m == 1: + self.raw_headers = self.raw_headers + buf[:n] + return n struct headerScanner: - var b: String # string for now until we have a better way to subset Bytes - var key: String - var value: String - var err: Error - var subslice_len: Int + var __b: Bytes + var __key: Bytes + var __value: Bytes + var __subslice_len: Int var disable_normalization: Bool - var next_colon: Int - var next_line: Int - var initialized: Bool + var __next_colon: Int + var __next_line: Int + var __initialized: Bool fn __init__(inout self) -> None: - self.b = "" - self.key = "" - self.value = "" - self.err = Error() - self.subslice_len = 0 + self.__b = Bytes() + self.__key = Bytes() + self.__value = Bytes() + self.__subslice_len = 0 self.disable_normalization = False - self.next_colon = 0 - self.next_line = 0 - self.initialized = False + self.__next_colon = 0 + self.__next_line = 0 + self.__initialized = False + + fn b(self) -> Bytes: + return self.__b + + fn set_b(inout self, b: Bytes) -> None: + self.__b = b + + fn key(self) -> Bytes: + return self.__key - fn next(inout self) -> Bool: - if not self.initialized: - self.initialized = True + fn set_key(inout self, key: Bytes) -> None: + self.__key = key - if self.b.startswith('\r\n\r\n'): - self.b = self.b[2:] - return False + fn value(self) -> Bytes: + return self.__value + + fn set_value(inout self, value: Bytes) -> None: + self.__value = value + + fn subslice_len(self) -> Int: + return self.__subslice_len + + fn set_subslice_len(inout self, n: Int) -> None: + self.__subslice_len = n - if self.b.startswith('\r\n'): - self.b = self.b[1:] - return False + fn next_colon(self) -> Int: + return self.__next_colon - var n = self.b.find(':') - var x = self.b.find('\r\n') - if x != -1 and x < n: - return False + fn set_next_colon(inout self, n: Int) -> None: + self.__next_colon = n + + fn next_line(self) -> Int: + return self.__next_line + + fn set_next_line(inout self, n: Int) -> None: + self.__next_line = n + + fn initialized(self) -> Bool: + return self.__initialized - if n == -1: - # If we don't find a colon, assume we have reached the end + fn set_initialized(inout self) -> None: + self.__initialized = True + + fn next(inout self) raises -> Bool: + if not self.initialized(): + self.set_next_colon(-1) + self.set_next_line(-1) + self.set_initialized() + + var b_len = len(self.b()) + + if b_len >= 2 and (self.b()[0] == bytes(rChar, pop=False)[0]) and (self.b()[1] == bytes(nChar, pop=False)[0]): + self.set_b(self.b()[2:]) + self.set_subslice_len(2) + return False + + if b_len >= 1 and (self.b()[0] == bytes(nChar, pop=False)[0]): + self.set_b(self.b()[1:]) + self.set_subslice_len(self.subslice_len() + 1) return False + + var colon: Int + if self.next_colon() >= 0: + colon = self.next_colon() + self.set_next_colon(-1) + else: + colon = index_byte(self.b(), bytes(colonChar, pop=False)[0]) + var newline = index_byte(self.b(), bytes(nChar, pop=False)[0]) + if newline < 0: + raise Error("Invalid header, did not find a newline at the end of the header") + if newline < colon: + raise Error("Invalid header, found a newline before the colon") + if colon < 0: + raise Error("Invalid header, did not find a colon") + + var jump_to = colon + 1 + self.set_key(self.b()[:jump_to]) - self.key = self.b[:n].strip() - self.b = self.b[n+1:].strip() + while len(self.b()) > jump_to and (self.b()[jump_to] == bytes(whitespace, pop=False)[0]): + jump_to += 1 + self.set_next_line(self.next_line() - 1) + + self.set_subslice_len(self.subslice_len() + jump_to) + self.set_b(self.b()[jump_to:]) - x = self.b.find('\r\n') - if x == -1: - if len(self.b) == 0: - return False - self.value = self.b.strip() - self.b = '' + if self.next_line() >= 0: + jump_to = self.next_line() + self.set_next_line(-1) else: - self.value = self.b[:x].strip() - self.b = self.b[x+1:] + jump_to = index_byte(self.b(), bytes(nChar, pop=False)[0]) + if jump_to < 0: + raise Error("Invalid header, did not find a newline") + + jump_to += 1 + self.set_value(self.b()[:jump_to]) + self.set_subslice_len(self.subslice_len() + jump_to) + self.set_b(self.b()[jump_to:]) + + if jump_to > 0 and (self.value()[jump_to-1] == bytes(rChar, pop=False)[0]): + jump_to -= 1 + while jump_to > 0 and (self.value()[jump_to-1] == bytes(whitespace, pop=False)[0]): + jump_to -= 1 + self.set_value(self.value()[:jump_to]) return True diff --git a/lightbug_http/http.mojo b/lightbug_http/http.mojo index 4734806a..af2e08dd 100644 --- a/lightbug_http/http.mojo +++ b/lightbug_http/http.mojo @@ -1,13 +1,13 @@ from time import now from external.morrow import Morrow -from external.gojo.strings import StringBuilder +from external.gojo.strings.builder import StringBuilder +from external.gojo.bufio import Reader from lightbug_http.uri import URI -from lightbug_http.io.bytes import Bytes, bytes +from lightbug_http.io.bytes import Bytes, BytesView, bytes from lightbug_http.header import RequestHeader, ResponseHeader from lightbug_http.io.sync import Duration from lightbug_http.net import Addr, TCPAddr -from lightbug_http.strings import strHttp11, strHttp - +from lightbug_http.strings import strHttp11, strHttp, strSlash, whitespace, rChar, nChar trait Request: fn __init__(inout self, uri: URI): @@ -43,7 +43,7 @@ trait Request: fn request_uri(inout self) -> String: ... - fn set_connection_close(inout self, connection_close: Bool) -> Self: + fn set_connection_close(inout self) -> Self: ... fn connection_close(self) -> Bool: @@ -60,7 +60,7 @@ trait Response: fn status_code(self) -> Int: ... - fn set_connection_close(inout self, connection_close: Bool) -> Self: + fn set_connection_close(inout self) -> Self: ... fn connection_close(self) -> Bool: @@ -79,7 +79,7 @@ struct HTTPRequest(Request): var disable_redirect_path_normalization: Bool fn __init__(inout self, uri: URI): - self.header = RequestHeader(String("127.0.0.1")) + self.header = RequestHeader("127.0.0.1") self.__uri = uri self.body_raw = Bytes() self.parsed_uri = False @@ -88,7 +88,7 @@ struct HTTPRequest(Request): self.disable_redirect_path_normalization = False fn __init__(inout self, uri: URI, headers: RequestHeader): - self.header = RequestHeader() + self.header = headers self.__uri = uri self.body_raw = Bytes() self.parsed_uri = False @@ -123,8 +123,12 @@ struct HTTPRequest(Request): self.timeout = timeout self.disable_redirect_path_normalization = disable_redirect_path_normalization - fn get_body(self) -> Bytes: - return self.body_raw + fn get_body_bytes(self) -> BytesView: + return BytesView(unsafe_ptr=self.body_raw.unsafe_ptr(), len=self.body_raw.size) + + fn set_body_bytes(inout self, body: Bytes) -> Self: + self.body_raw = body + return self fn set_host(inout self, host: String) -> Self: _ = self.__uri.set_host(host) @@ -135,7 +139,7 @@ struct HTTPRequest(Request): return self fn host(self) -> String: - return self.__uri.host() + return self.__uri.host_str() fn set_request_uri(inout self, request_uri: String) -> Self: _ = self.header.set_request_uri(request_uri.as_bytes()) @@ -154,13 +158,23 @@ struct HTTPRequest(Request): fn uri(self) -> URI: return self.__uri - fn set_connection_close(inout self, connection_close: Bool) -> Self: + fn set_connection_close(inout self) -> Self: _ = self.header.set_connection_close() return self fn connection_close(self) -> Bool: return self.header.connection_close() + + fn read_body(inout self, inout r: Reader, content_length: Int, header_len: Int, max_body_size: Int) raises -> None: + if content_length > max_body_size: + raise Error("Request body too large") + _ = r.discard(header_len) + + var body_buf: Bytes + body_buf, _ = r.peek(r.buffered()) + + _ = self.set_body_bytes(bytes(body_buf)) @value struct HTTPResponse(Response): @@ -176,7 +190,7 @@ struct HTTPResponse(Response): self.header = ResponseHeader( 200, bytes("OK"), - bytes("Content-Type: application/octet-stream\r\n"), + bytes("application/octet-stream"), ) self.stream_immediate_header_flush = False self.stream_body = False @@ -193,10 +207,17 @@ struct HTTPResponse(Response): self.skip_reading_writing_body = False self.raddr = TCPAddr() self.laddr = TCPAddr() + + fn get_body_bytes(self) -> BytesView: + return BytesView(unsafe_ptr=self.body_raw.unsafe_ptr(), len=self.body_raw.size) fn get_body(self) -> Bytes: return self.body_raw + fn set_body_bytes(inout self, body: Bytes) -> Self: + self.body_raw = body + return self + fn set_status_code(inout self, status_code: Int) -> Self: _ = self.header.set_status_code(status_code) return self @@ -204,16 +225,24 @@ struct HTTPResponse(Response): fn status_code(self) -> Int: return self.header.status_code() - fn set_connection_close(inout self, connection_close: Bool) -> Self: + fn set_connection_close(inout self) -> Self: _ = self.header.set_connection_close() return self fn connection_close(self) -> Bool: return self.header.connection_close() + + fn read_body(inout self, inout r: Reader, header_len: Int) raises -> None: + _ = r.discard(header_len) + + var body_buf: Bytes + body_buf, _ = r.peek(r.buffered()) + + _ = self.set_body_bytes(bytes(body_buf)) fn OK(body: StringLiteral) -> HTTPResponse: return HTTPResponse( - ResponseHeader(200, bytes("OK"), bytes("Content-Type: text/plain")), bytes(body), + ResponseHeader(200, bytes("OK"), bytes("text/plain")), bytes(body), ) fn OK(body: StringLiteral, content_type: String) -> HTTPResponse: @@ -223,7 +252,7 @@ fn OK(body: StringLiteral, content_type: String) -> HTTPResponse: fn OK(body: String) -> HTTPResponse: return HTTPResponse( - ResponseHeader(200, bytes("OK"), bytes("Content-Type: text/plain")), bytes(body), + ResponseHeader(200, bytes("OK"), bytes("text/plain")), bytes(body), ) fn OK(body: String, content_type: String) -> HTTPResponse: @@ -233,7 +262,7 @@ fn OK(body: String, content_type: String) -> HTTPResponse: fn OK(body: Bytes) -> HTTPResponse: return HTTPResponse( - ResponseHeader(200, bytes("OK"), bytes("Content-Type: text/plain")), body, + ResponseHeader(200, bytes("OK"), bytes("text/plain")), body, ) fn OK(body: Bytes, content_type: String) -> HTTPResponse: @@ -251,54 +280,58 @@ fn NotFound(path: String) -> HTTPResponse: ResponseHeader(404, bytes("Not Found"), bytes("text/plain")), bytes("path " + path + " not found"), ) -fn encode(req: HTTPRequest, uri: URI) raises -> Bytes: - var res_str = String() - var protocol = strHttp11 - var current_time = String() - +fn encode(req: HTTPRequest) raises -> StringSlice[False, ImmutableStaticLifetime]: var builder = StringBuilder() _ = builder.write(req.header.method()) - _ = builder.write_string(String(" ")) - if len(uri.request_uri()) > 1: - _ = builder.write_string(uri.request_uri()) + _ = builder.write_string(whitespace) + if len(req.uri().path_bytes()) > 1: + _ = builder.write_string(req.uri().path()) else: - _ = builder.write_string("/") - _ = builder.write_string(String(" ")) - _ = builder.write(protocol) - _ = builder.write_string(String("\r\n")) - - _ = builder.write_string(String("Host: " + String(uri.host()))) - _ = builder.write_string(String("\r\n")) + _ = builder.write_string(strSlash) + _ = builder.write_string(whitespace) + + _ = builder.write(req.header.protocol()) - if len(req.body_raw) > 0: - _ = builder.write_string(String("Content-Type: ")) - _ = builder.write(req.header.content_type()) - _ = builder.write_string(String("\r\n")) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) - _ = builder.write_string(String("Content-Length: ")) - _ = builder.write_string(String(len(req.body_raw))) - _ = builder.write_string(String("\r\n")) + if len(req.header.host()) > 0: + _ = builder.write_string("Host: ") + _ = builder.write(req.header.host()) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) - _ = builder.write_string(String("Connection: ")) + if len(req.body_raw) > 0: + if len(req.header.content_type()) > 0: + _ = builder.write_string("Content-Type: ") + _ = builder.write(req.header.content_type()) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) + + _ = builder.write_string("Content-Length: ") + _ = builder.write_string(len(req.body_raw).__str__()) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) + + _ = builder.write_string("Connection: ") if req.connection_close(): - _ = builder.write_string(String("close")) + _ = builder.write_string("close") else: - _ = builder.write_string(String("keep-alive")) + _ = builder.write_string("keep-alive") - _ = builder.write_string(String("\r\n")) - _ = builder.write_string(String("\r\n")) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) if len(req.body_raw) > 0: - _ = builder.write_string(String("\r\n")) - _ = builder.write(req.body_raw) - - return builder.get_bytes() + _ = builder.write(req.get_body_bytes()) + + return StringSlice[False, ImmutableStaticLifetime](unsafe_from_utf8_ptr=builder.render().unsafe_ptr(), len=builder.size) fn encode(res: HTTPResponse) raises -> Bytes: - var res_str = String() - var protocol = strHttp11 var current_time = String() try: current_time = Morrow.utcnow().__str__() @@ -308,43 +341,90 @@ fn encode(res: HTTPResponse) raises -> Bytes: var builder = StringBuilder() - _ = builder.write(protocol) - _ = builder.write_string(String(" ")) - _ = builder.write_string(String(res.header.status_code())) - _ = builder.write_string(String(" ")) + _ = builder.write(res.header.protocol()) + _ = builder.write_string(whitespace) + _ = builder.write_string(res.header.status_code().__str__()) + _ = builder.write_string(whitespace) _ = builder.write(res.header.status_message()) - _ = builder.write_string(String("\r\n")) - _ = builder.write_string(String("Server: lightbug_http")) - _ = builder.write_string(String("\r\n")) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) + + _ = builder.write_string("Server: lightbug_http") + + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) - _ = builder.write_string(String("Content-Type: ")) + _ = builder.write_string("Content-Type: ") _ = builder.write(res.header.content_type()) - _ = builder.write_string(String("\r\n")) + + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) if len(res.header.content_encoding()) > 0: - _ = builder.write_string(String("Content-Encoding: ")) + _ = builder.write_string("Content-Encoding: ") _ = builder.write(res.header.content_encoding()) - _ = builder.write_string(String("\r\n")) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) if len(res.body_raw) > 0: - _ = builder.write_string(String("Content-Length: ")) - _ = builder.write_string(String(len(res.body_raw))) - _ = builder.write_string(String("\r\n")) + _ = builder.write_string("Content-Length: ") + _ = builder.write_string(len(res.body_raw).__str__()) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) + else: + _ = builder.write_string("Content-Length: 0") + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) - _ = builder.write_string(String("Connection: ")) + _ = builder.write_string("Connection: ") if res.connection_close(): - _ = builder.write_string(String("close")) + _ = builder.write_string("close") else: - _ = builder.write_string(String("keep-alive")) - _ = builder.write_string(String("\r\n")) + _ = builder.write_string("keep-alive") + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) - _ = builder.write_string(String("Date: ")) - _ = builder.write_string(String(current_time)) + _ = builder.write_string("Date: ") + _ = builder.write_string(current_time) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) + _ = builder.write_string(rChar) + _ = builder.write_string(nChar) + if len(res.body_raw) > 0: - _ = builder.write_string(String("\r\n")) - _ = builder.write_string(String("\r\n")) - _ = builder.write(res.body_raw) + _ = builder.write(res.get_body_bytes()) + + return builder.render().as_bytes_slice() + +fn split_http_string(buf: Bytes) raises -> (String, String, String): + var request = String(buf) + + var request_first_line_headers_body = request.split("\r\n\r\n") + + if len(request_first_line_headers_body) == 0: + raise Error("Invalid HTTP string, did not find a double newline") + + var request_first_line_headers = request_first_line_headers_body[0] + + var request_body = String() + + if len(request_first_line_headers_body) > 1: + request_body = request_first_line_headers_body[1] + + var request_first_line_headers_list = request_first_line_headers.split("\r\n", 1) + + var request_first_line = String() + var request_headers = String() + + if len(request_first_line_headers_list) == 0: + raise Error("Invalid HTTP string, did not find a newline in the first line") + + if len(request_first_line_headers_list) == 1: + request_first_line = request_first_line_headers_list[0] + else: + request_first_line = request_first_line_headers_list[0] + request_headers = request_first_line_headers_list[1] - return builder.get_bytes() + return (request_first_line, request_headers, request_body) \ No newline at end of file diff --git a/lightbug_http/io/bytes.mojo b/lightbug_http/io/bytes.mojo index bdee734d..ab97a17a 100644 --- a/lightbug_http/io/bytes.mojo +++ b/lightbug_http/io/bytes.mojo @@ -1,36 +1,73 @@ from python import PythonObject +from lightbug_http.strings import nChar, rChar -alias Bytes = List[Int8] +alias Byte = UInt8 +alias Bytes = List[Byte] +alias BytesView = Span[is_mutable=False, T=Byte, lifetime=ImmutableStaticLifetime] -fn bytes(s: StringLiteral) -> Bytes: +fn bytes(s: StringLiteral, pop: Bool = True) -> Bytes: # This is currently null-terminated, which we don't want in HTTP responses var buf = String(s)._buffer - _ = buf.pop() + if pop: + _ = buf.pop() return buf -fn bytes(s: String) -> Bytes: +fn bytes(s: String, pop: Bool = True) -> Bytes: # This is currently null-terminated, which we don't want in HTTP responses var buf = s._buffer - _ = buf.pop() + if pop: + _ = buf.pop() return buf +fn bytes_equal(a: Bytes, b: Bytes) -> Bool: + return String(a) == String(b) + +fn index_byte(buf: Bytes, c: Byte) -> Int: + for i in range(len(buf)): + if buf[i] == c: + return i + return -1 + +fn last_index_byte(buf: Bytes, c: Byte) -> Int: + for i in range(len(buf)-1, -1, -1): + if buf[i] == c: + return i + return -1 + +fn compare_case_insensitive(a: Bytes, b: Bytes) -> Bool: + if len(a) != len(b): + return False + for i in range(len(a) - 1): + if (a[i] | 0x20) != (b[i] | 0x20): + return False + return True + +fn next_line(b: Bytes) raises -> (Bytes, Bytes): + var n_next = index_byte(b, bytes(nChar, pop=False)[0]) + if n_next < 0: + raise Error("next_line: newline not found") + var n = n_next + if n > 0 and (b[n-1] == bytes(rChar, pop=False)[0]): + n -= 1 + return (b[:n+1], b[n_next+1:]) + @value @register_passable("trivial") struct UnsafeString: - var data: Pointer[Int8] + var data: Pointer[UInt8] var len: Int fn __init__(str: StringLiteral) -> UnsafeString: var l = str.__len__() var s = String(str) - var p = Pointer[Int8].alloc(l) + var p = Pointer[UInt8].alloc(l) for i in range(l): p.store(i, s._buffer[i]) return UnsafeString(p, l) fn __init__(str: String) -> UnsafeString: var l = str.__len__() - var p = Pointer[Int8].alloc(l) + var p = Pointer[UInt8].alloc(l) for i in range(l): p.store(i, str._buffer[i]) return UnsafeString(p, l) @@ -39,6 +76,3 @@ struct UnsafeString: var s = String(self.data, self.len) return s - -fn bytes_equal(a: Bytes, b: Bytes) -> Bool: - return String(a) == String(b) diff --git a/lightbug_http/net.mojo b/lightbug_http/net.mojo index d3736970..f7ab1a26 100644 --- a/lightbug_http/net.mojo +++ b/lightbug_http/net.mojo @@ -119,8 +119,8 @@ struct TCPAddr(Addr): fn string(self) -> String: if self.zone != "": - return join_host_port(String(self.ip) + "%" + self.zone, self.port) - return join_host_port(self.ip, self.port) + return join_host_port(self.ip + "%" + self.zone, self.port.__str__()) + return join_host_port(self.ip, self.port.__str__()) fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr: @@ -233,8 +233,8 @@ fn convert_binary_ip_to_string( """ # It seems like the len of the buffer depends on the length of the string IP. # Allocating 10 works for localhost (127.0.0.1) which I suspect is 9 bytes + 1 null terminator byte. So max should be 16 (15 + 1). - var ip_buffer = Pointer[c_void].alloc(16) - var ip_address_ptr = Pointer.address_of(ip_address).bitcast[c_void]() + var ip_buffer = UnsafePointer[c_void].alloc(16) + var ip_address_ptr = UnsafePointer.address_of(ip_address).bitcast[c_void]() _ = inet_ntop(address_family, ip_address_ptr, ip_buffer, 16) var string_buf = ip_buffer.bitcast[Int8]() @@ -249,39 +249,40 @@ fn convert_binary_ip_to_string( fn get_sock_name(fd: Int32) raises -> HostPort: """Return the address of the socket.""" - var local_address_ptr = Pointer[sockaddr].alloc(1) + var local_address_ptr = UnsafePointer[sockaddr].alloc(1) var local_address_ptr_size = socklen_t(sizeof[sockaddr]()) var status = getsockname( fd, local_address_ptr, - Pointer[socklen_t].address_of(local_address_ptr_size), + UnsafePointer[socklen_t].address_of(local_address_ptr_size), ) if status == -1: raise Error("get_sock_name: Failed to get address of local socket.") - var addr_in = local_address_ptr.bitcast[sockaddr_in]().load() + var addr_in = local_address_ptr.bitcast[sockaddr_in]()[] return HostPort( host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, 16), - port=convert_binary_port_to_int(addr_in.sin_port), + port=convert_binary_port_to_int(addr_in.sin_port).__str__(), ) fn get_peer_name(fd: Int32) raises -> HostPort: """Return the address of the peer connected to the socket.""" - var remote_address_ptr = Pointer[sockaddr].alloc(1) + var remote_address_ptr = UnsafePointer[sockaddr].alloc(1) var remote_address_ptr_size = socklen_t(sizeof[sockaddr]()) + var status = getpeername( fd, remote_address_ptr, - Pointer[socklen_t].address_of(remote_address_ptr_size), + UnsafePointer[socklen_t].address_of(remote_address_ptr_size), ) if status == -1: raise Error("get_peer_name: Failed to get address of remote socket.") # Cast sockaddr struct to sockaddr_in to convert binary IP to string. - var addr_in = remote_address_ptr.bitcast[sockaddr_in]().load() + var addr_in = remote_address_ptr.bitcast[sockaddr_in]()[] return HostPort( host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, 16), - port=convert_binary_port_to_int(addr_in.sin_port), + port=convert_binary_port_to_int(addr_in.sin_port).__str__(), ) diff --git a/lightbug_http/python/client.mojo b/lightbug_http/python/client.mojo index ec3e6d6f..81ae43db 100644 --- a/lightbug_http/python/client.mojo +++ b/lightbug_http/python/client.mojo @@ -1,7 +1,7 @@ from lightbug_http.client import Client from lightbug_http.http import HTTPRequest, HTTPResponse from lightbug_http.python import Modules -from lightbug_http.io.bytes import Bytes, UnsafeString +from lightbug_http.io.bytes import Bytes, UnsafeString, bytes from lightbug_http.strings import CharSet @@ -57,4 +57,4 @@ struct PythonClient(Client): var res = self.socket.recv(1024).decode() _ = self.socket.close() - return HTTPResponse(res.__str__()._buffer) + return HTTPResponse(bytes(res)) diff --git a/lightbug_http/python/net.mojo b/lightbug_http/python/net.mojo index 66e917a4..2174c511 100644 --- a/lightbug_http/python/net.mojo +++ b/lightbug_http/python/net.mojo @@ -1,5 +1,5 @@ from lightbug_http.python import Modules -from lightbug_http.io.bytes import Bytes, UnsafeString +from lightbug_http.io.bytes import Bytes, UnsafeString, bytes from lightbug_http.io.sync import Duration from lightbug_http.net import ( Net, @@ -98,8 +98,8 @@ struct PythonConnection(Connection): fn __init__(inout self, laddr: TCPAddr, raddr: TCPAddr) raises: self.conn = None - self.raddr = PythonObject(raddr.ip + ":" + raddr.port) - self.laddr = PythonObject(laddr.ip + ":" + laddr.port) + self.raddr = PythonObject(raddr.ip + ":" + raddr.port.__str__()) + self.laddr = PythonObject(laddr.ip + ":" + laddr.port.__str__()) self.pymodules = Modules().builtins fn __init__(inout self, pymodules: PythonObject, py_conn_addr: PythonObject) raises: @@ -110,9 +110,9 @@ struct PythonConnection(Connection): fn read(self, inout buf: Bytes) raises -> Int: var data = self.conn.recv(default_buffer_size) - buf = String( + buf = bytes( self.pymodules.bytes.decode(data, CharSet.utf8.value).__str__() - )._buffer + ) return len(buf) fn write(self, buf: Bytes) raises -> Int: diff --git a/lightbug_http/python/server.mojo b/lightbug_http/python/server.mojo index db9088cd..eef0ba11 100644 --- a/lightbug_http/python/server.mojo +++ b/lightbug_http/python/server.mojo @@ -1,6 +1,6 @@ from lightbug_http.server import DefaultConcurrency from lightbug_http.net import Listener -from lightbug_http.http import HTTPRequest, encode +from lightbug_http.http import HTTPRequest, encode, split_http_string from lightbug_http.uri import URI from lightbug_http.header import RequestHeader from lightbug_http.python.net import ( @@ -13,7 +13,7 @@ from lightbug_http.service import HTTPService from lightbug_http.io.sync import Duration from lightbug_http.io.bytes import Bytes from lightbug_http.error import ErrorHandler -from lightbug_http.strings import next_line, NetworkType +from lightbug_http.strings import NetworkType struct PythonServer: @@ -70,20 +70,23 @@ struct PythonServer: if read_len == 0: conn.close() break - var first_line_and_headers = next_line(buf) - var request_line = first_line_and_headers.first_line - var rest_of_headers = first_line_and_headers.rest - - var uri = URI(request_line) + + var request_first_line: String + var request_headers: String + var request_body: String + + request_first_line, request_headers, request_body = split_http_string(buf) + + var uri = URI(request_first_line) try: uri.parse() except: conn.close() raise Error("Failed to parse request line") - var header = RequestHeader(buf) + var header = RequestHeader(request_headers.as_bytes()) try: - header.parse(request_line) + header.parse_raw(request_first_line) except: conn.close() raise Error("Failed to parse request header") @@ -96,5 +99,5 @@ struct PythonServer: ) ) var res_encoded = encode(res) - _ = conn.write(res_encoded) + _ = conn.write(res_encoded.as_bytes_slice()) conn.close() diff --git a/lightbug_http/service.mojo b/lightbug_http/service.mojo index 908feeab..375cd2d3 100644 --- a/lightbug_http/service.mojo +++ b/lightbug_http/service.mojo @@ -56,7 +56,7 @@ struct ExampleRouter(HTTPService): @value struct TechEmpowerRouter(HTTPService): fn func(self, req: HTTPRequest) raises -> HTTPResponse: - var body = req.body_raw + # var body = req.body_raw var uri = req.uri() if uri.path() == "/plaintext": diff --git a/lightbug_http/strings.mojo b/lightbug_http/strings.mojo index f74cdde9..372537e1 100644 --- a/lightbug_http/strings.mojo +++ b/lightbug_http/strings.mojo @@ -1,39 +1,22 @@ from lightbug_http.io.bytes import Bytes -alias strSlash = String("/").as_bytes() -alias strHttp = String("http").as_bytes() -alias http = String("http") -alias strHttps = String("https").as_bytes() -alias https = String("https") -alias strHttp11 = String("HTTP/1.1").as_bytes() -alias strHttp10 = String("HTTP/1.0").as_bytes() +alias strSlash = "/" +alias strHttp = "http" +alias http = "http" +alias strHttps = "https" +alias https = "https" +alias strHttp11 = "HTTP/1.1" +alias strHttp10 = "HTTP/1.0" -alias strMethodGet = String("GET").as_bytes() +alias strMethodGet = "GET" -alias rChar = String("\r").as_bytes() -alias nChar = String("\n").as_bytes() - - -# This is temporary due to no string support in tuples in Mojo, to be removed -@value -struct TwoLines: - var first_line: String - var rest: String - - fn __init__(inout self, first_line: String, rest: String) -> None: - self.first_line = first_line - self.rest = rest - - -# Helper function to split a string into two lines by delimiter -fn next_line(s: String, delimiter: String = "\n") raises -> TwoLines: - var first_newline = s.find(delimiter) - if first_newline == -1: - return TwoLines(s, String()) - var before_newline = s[0:first_newline] - var after_newline = s[first_newline + 1 :] - return TwoLines(before_newline.strip(), after_newline) +alias rChar = "\r" +alias nChar = "\n" +alias colonChar = ":" +alias empty_string = "" +alias whitespace = " " +alias tab = "\t" @value struct NetworkType: @@ -51,7 +34,6 @@ struct NetworkType: alias ip6 = NetworkType("ip6") alias unix = NetworkType("unix") - @value struct ConnType: var value: String @@ -60,7 +42,6 @@ struct ConnType: alias http = ConnType("http") alias websocket = ConnType("websocket") - @value struct RequestMethod: var value: String @@ -73,14 +54,12 @@ struct RequestMethod: alias patch = RequestMethod("PATCH") alias options = RequestMethod("OPTIONS") - @value struct CharSet: var value: String alias utf8 = CharSet("utf-8") - @value struct MediaType: var value: String @@ -89,7 +68,6 @@ struct MediaType: alias plain = MediaType("text/plain") alias json = MediaType("application/json") - @value struct Message: var type: String diff --git a/lightbug_http/sys/client.mojo b/lightbug_http/sys/client.mojo index de93d1d8..0e981e1f 100644 --- a/lightbug_http/sys/client.mojo +++ b/lightbug_http/sys/client.mojo @@ -1,9 +1,5 @@ -from lightbug_http.client import Client -from lightbug_http.http import HTTPRequest, HTTPResponse, encode -from lightbug_http.header import ResponseHeader -from lightbug_http.sys.net import create_connection -from lightbug_http.io.bytes import Bytes -from lightbug_http.strings import next_line +from external.gojo.bufio import Reader, Scanner, scan_words, scan_bytes +from external.gojo.bytes import buffer from external.libc import ( c_int, AF_INET, @@ -14,6 +10,12 @@ from external.libc import ( recv, close, ) +from lightbug_http.client import Client +from lightbug_http.net import default_buffer_size +from lightbug_http.http import HTTPRequest, HTTPResponse, encode, split_http_string +from lightbug_http.header import ResponseHeader +from lightbug_http.sys.net import create_connection +from lightbug_http.io.bytes import Bytes struct MojoClient(Client): @@ -63,11 +65,6 @@ struct MojoClient(Client): If there is a failure in sending or receiving the message. """ var uri = req.uri() - try: - _ = uri.parse() - except e: - print("error parsing uri: " + e.__str__()) - var host = String(uri.host()) if host == "": @@ -92,48 +89,51 @@ struct MojoClient(Client): port = 80 var conn = create_connection(self.fd, host_str, port) + + var req_encoded = encode(req) - var req_encoded = encode(req, uri) var bytes_sent = conn.write(req_encoded) if bytes_sent == -1: raise Error("Failed to send message") - var new_buf = Bytes() - + var new_buf = Bytes(capacity=default_buffer_size) var bytes_recv = conn.read(new_buf) + if bytes_recv == 0: conn.close() - - var response_first_line_headers_and_body = next_line(new_buf, "\r\n\r\n") - var response_first_line_headers = response_first_line_headers_and_body.first_line - var response_body = response_first_line_headers_and_body.rest - var response_first_line_and_headers = next_line(response_first_line_headers, "\r\n") - var response_first_line = response_first_line_and_headers.first_line - var response_headers = response_first_line_and_headers.rest + var buf = buffer.new_buffer(new_buf^) + var reader = Reader(buf^) - # Ugly hack for now in case the default buffer is too large and we read additional responses from the server - var newline_in_body = response_body.find("\r\n") - if newline_in_body != -1: - response_body = response_body[:newline_in_body] + var error = Error() - var header = ResponseHeader(response_headers._buffer) + # # Ugly hack for now in case the default buffer is too large and we read additional responses from the server + # var newline_in_body = response_body.find("\r\n") + # if newline_in_body != -1: + # response_body = response_body[:newline_in_body] + var header = ResponseHeader() + var first_line_and_headers_len = 0 try: - header.parse(response_first_line) + first_line_and_headers_len = header.parse_raw(reader) except e: conn.close() - raise Error("Failed to parse response header: " + e.__str__()) + error = Error("Failed to parse response headers: " + e.__str__()) - var total_recv = bytes_recv + var response = HTTPResponse(header, Bytes()) - while header.content_length() > total_recv: - if header.content_length() != 0 and header.content_length() != -2: - var remaining_body = Bytes() - var read_len = conn.read(remaining_body) - response_body += remaining_body - total_recv += read_len + try: + response.read_body(reader, first_line_and_headers_len,) + except e: + error = Error("Failed to read request body: " + e.__str__()) + # var total_recv = bytes_recv + # while header.content_length() > total_recv: + # if header.content_length() != 0 and header.content_length() != -2: + # var remaining_body = Bytes() + # var read_len = conn.read(remaining_body) + # response_body += remaining_body + # total_recv += read_len conn.close() - return HTTPResponse(header, response_body._buffer) + return response diff --git a/lightbug_http/sys/net.mojo b/lightbug_http/sys/net.mojo index 2a0b7409..c3226a1a 100644 --- a/lightbug_http/sys/net.mojo +++ b/lightbug_http/sys/net.mojo @@ -1,3 +1,4 @@ +from utils import StaticTuple from lightbug_http.net import ( Listener, ListenConfig, @@ -10,7 +11,7 @@ from lightbug_http.net import ( get_peer_name, ) from lightbug_http.strings import NetworkType -from lightbug_http.io.bytes import Bytes +from lightbug_http.io.bytes import Bytes, bytes from lightbug_http.io.sync import Duration from external.libc import ( c_void, @@ -57,10 +58,10 @@ trait AnAddrInfo: fn getaddrinfo[ T: AnAddrInfo ]( - nodename: Pointer[c_char], - servname: Pointer[c_char], - hints: Pointer[T], - res: Pointer[Pointer[T]], + nodename: UnsafePointer[c_char], + servname: UnsafePointer[c_char], + hints: UnsafePointer[T], + res: UnsafePointer[UnsafePointer[T]], ) -> c_int: """ Overwrites the existing libc `getaddrinfo` function to use the AnAddrInfo trait. @@ -72,10 +73,10 @@ fn getaddrinfo[ return external_call[ "getaddrinfo", c_int, # FnName, RetType - Pointer[c_char], - Pointer[c_char], - Pointer[T], # Args - Pointer[Pointer[T]], # Args + UnsafePointer[c_char], + UnsafePointer[c_char], + UnsafePointer[T], # Args + UnsafePointer[UnsafePointer[T]], # Args ](nodename, servname, hints, res) @@ -101,13 +102,13 @@ struct SysListener: self.fd = fd fn accept(self) raises -> SysConnection: - var their_addr_ptr = Pointer[sockaddr].alloc(1) + var their_addr_ptr = UnsafePointer[sockaddr].alloc(1) var sin_size = socklen_t(sizeof[socklen_t]()) var new_sockfd = accept( - self.fd, their_addr_ptr, Pointer[socklen_t].address_of(sin_size) + self.fd, their_addr_ptr, UnsafePointer[socklen_t].address_of(sin_size) ) if new_sockfd == -1: - print("Failed to accept connection") + print("Failed to accept connection, system accept() returned an error.") var peer = get_peer_name(new_sockfd) return SysConnection( @@ -140,14 +141,14 @@ struct SysListenConfig(ListenConfig): if address_family == AF_INET6: ip_buf_size = 16 - var ip_buf = Pointer[c_void].alloc(ip_buf_size) + var ip_buf = UnsafePointer[c_void].alloc(ip_buf_size) var conv_status = inet_pton(address_family, to_char_ptr(addr.ip), ip_buf) - var raw_ip = ip_buf.bitcast[c_uint]().load() + var raw_ip = ip_buf.bitcast[c_uint]()[] var bin_port = htons(UInt16(addr.port)) var ai = sockaddr_in(address_family, bin_port, raw_ip, StaticTuple[c_char, 8]()) - var ai_ptr = Pointer[sockaddr_in].address_of(ai).bitcast[sockaddr]() + var ai_ptr = UnsafePointer[sockaddr_in].address_of(ai).bitcast[sockaddr]() var sockfd = socket(address_family, SOCK_STREAM, 0) if sockfd == -1: @@ -158,7 +159,7 @@ struct SysListenConfig(ListenConfig): sockfd, SOL_SOCKET, SO_REUSEADDR, - Pointer[Int].address_of(yes).bitcast[c_void](), + UnsafePointer[Int].address_of(yes).bitcast[c_void](), sizeof[Int](), ) @@ -216,14 +217,14 @@ struct SysConnection(Connection): self.fd = fd fn read(self, inout buf: Bytes) raises -> Int: - var new_buf = Pointer[UInt8]().alloc(default_buffer_size) - var bytes_recv = recv(self.fd, new_buf, default_buffer_size, 0) + var bytes_recv = recv(self.fd, DTypePointer[DType.uint8](buf.unsafe_ptr()).offset(buf.size), buf.capacity - buf.size, 0) if bytes_recv == -1: return 0 + buf.size += bytes_recv if bytes_recv == 0: return 0 - var bytes_str = String(new_buf.bitcast[Int8](), bytes_recv) - buf = bytes_str._buffer + if bytes_recv < buf.capacity: + return bytes_recv return bytes_recv fn write(self, msg: String) raises -> Int: @@ -275,13 +276,13 @@ struct addrinfo_macos(AnAddrInfo): var ai_socktype: c_int var ai_protocol: c_int var ai_addrlen: socklen_t - var ai_canonname: Pointer[c_char] - var ai_addr: Pointer[sockaddr] - var ai_next: Pointer[c_void] + var ai_canonname: UnsafePointer[c_char] + var ai_addr: UnsafePointer[sockaddr] + var ai_next: UnsafePointer[c_void] fn __init__() -> Self: return Self( - 0, 0, 0, 0, 0, Pointer[c_char](), Pointer[sockaddr](), Pointer[c_void]() + 0, 0, 0, 0, 0, UnsafePointer[c_char](), UnsafePointer[sockaddr](), UnsafePointer[c_void]() ) fn get_ip_address(self, host: String) raises -> in_addr: @@ -296,8 +297,8 @@ struct addrinfo_macos(AnAddrInfo): UInt32 - The IP address. """ var host_ptr = to_char_ptr(host) - var servinfo = Pointer[Self]().alloc(1) - servinfo.store(Self()) + var servinfo = UnsafePointer[Self]().alloc(1) + initialize_pointee_move(servinfo, Self()) var hints = Self() hints.ai_family = AF_INET @@ -306,15 +307,15 @@ struct addrinfo_macos(AnAddrInfo): var error = getaddrinfo[Self]( host_ptr, - Pointer[UInt8](), - Pointer.address_of(hints), - Pointer.address_of(servinfo), + UnsafePointer[UInt8](), + UnsafePointer.address_of(hints), + UnsafePointer.address_of(servinfo), ) if error != 0: print("getaddrinfo failed") raise Error("Failed to get IP address. getaddrinfo failed.") - var addrinfo = servinfo.load() + var addrinfo = servinfo[] var ai_addr = addrinfo.ai_addr if not ai_addr: @@ -324,7 +325,7 @@ struct addrinfo_macos(AnAddrInfo): " ai_addr is null." ) - var addr_in = ai_addr.bitcast[sockaddr_in]().load() + var addr_in = ai_addr.bitcast[sockaddr_in]()[] return addr_in.sin_addr @@ -341,13 +342,13 @@ struct addrinfo_unix(AnAddrInfo): var ai_socktype: c_int var ai_protocol: c_int var ai_addrlen: socklen_t - var ai_addr: Pointer[sockaddr] - var ai_canonname: Pointer[c_char] - var ai_next: Pointer[c_void] + var ai_addr: UnsafePointer[sockaddr] + var ai_canonname: UnsafePointer[c_char] + var ai_next: UnsafePointer[c_void] fn __init__() -> Self: return Self( - 0, 0, 0, 0, 0, Pointer[sockaddr](), Pointer[c_char](), Pointer[c_void]() + 0, 0, 0, 0, 0, UnsafePointer[sockaddr](), UnsafePointer[c_char](), UnsafePointer[c_void]() ) fn get_ip_address(self, host: String) raises -> in_addr: @@ -362,8 +363,8 @@ struct addrinfo_unix(AnAddrInfo): UInt32 - The IP address. """ var host_ptr = to_char_ptr(String(host)) - var servinfo = Pointer[Self]().alloc(1) - servinfo.store(Self()) + var servinfo = UnsafePointer[Self]().alloc(1) + initialize_pointee_move(servinfo, Self()) var hints = Self() hints.ai_family = AF_INET @@ -372,15 +373,15 @@ struct addrinfo_unix(AnAddrInfo): var error = getaddrinfo[Self]( host_ptr, - Pointer[UInt8](), - Pointer.address_of(hints), - Pointer.address_of(servinfo), + UnsafePointer[UInt8](), + UnsafePointer.address_of(hints), + UnsafePointer.address_of(servinfo), ) if error != 0: print("getaddrinfo failed") raise Error("Failed to get IP address. getaddrinfo failed.") - var addrinfo = servinfo.load() + var addrinfo = servinfo[] var ai_addr = addrinfo.ai_addr if not ai_addr: @@ -390,7 +391,7 @@ struct addrinfo_unix(AnAddrInfo): " ai_addr is null." ) - var addr_in = ai_addr.bitcast[sockaddr_in]().load() + var addr_in = ai_addr.bitcast[sockaddr_in]()[] return addr_in.sin_addr @@ -417,7 +418,7 @@ fn create_connection(sock: c_int, host: String, port: UInt16) raises -> SysConne var addr: sockaddr_in = sockaddr_in( AF_INET, htons(port), ip, StaticTuple[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0) ) - var addr_ptr = Pointer[sockaddr_in].address_of(addr).bitcast[sockaddr]() + var addr_ptr = UnsafePointer[sockaddr_in].address_of(addr).bitcast[sockaddr]() if connect(sock, addr_ptr, sizeof[sockaddr_in]()) == -1: _ = shutdown(sock, SHUT_RDWR) diff --git a/lightbug_http/sys/server.mojo b/lightbug_http/sys/server.mojo index c407f60f..8ba64f69 100644 --- a/lightbug_http/sys/server.mojo +++ b/lightbug_http/sys/server.mojo @@ -1,14 +1,18 @@ +from external.gojo.bufio import Reader, Scanner, scan_words, scan_bytes +from external.gojo.bytes import buffer from lightbug_http.server import DefaultConcurrency -from lightbug_http.net import Listener -from lightbug_http.http import HTTPRequest, encode +from lightbug_http.net import Listener, default_buffer_size +from lightbug_http.http import HTTPRequest, encode, split_http_string from lightbug_http.uri import URI from lightbug_http.header import RequestHeader from lightbug_http.sys.net import SysListener, SysConnection, SysNet from lightbug_http.service import HTTPService from lightbug_http.io.sync import Duration -from lightbug_http.io.bytes import Bytes +from lightbug_http.io.bytes import Bytes, bytes from lightbug_http.error import ErrorHandler -from lightbug_http.strings import next_line, NetworkType +from lightbug_http.strings import NetworkType + +alias default_max_request_body_size = 4 * 1024 * 1024 # 4MB @value struct SysServer: @@ -23,7 +27,7 @@ struct SysServer: var max_concurrent_connections: Int var max_requests_per_connection: Int - var max_request_body_size: Int + var __max_request_body_size: Int var tcp_keep_alive: Bool var ln: SysListener @@ -34,17 +38,27 @@ struct SysServer: self.__address = "127.0.0.1" self.max_concurrent_connections = 1000 self.max_requests_per_connection = 0 - self.max_request_body_size = 0 + self.__max_request_body_size = default_max_request_body_size self.tcp_keep_alive = False self.ln = SysListener() + fn __init__(inout self, tcp_keep_alive: Bool) raises: + self.error_handler = ErrorHandler() + self.name = "lightbug_http" + self.__address = "127.0.0.1" + self.max_concurrent_connections = 1000 + self.max_requests_per_connection = 0 + self.__max_request_body_size = default_max_request_body_size + self.tcp_keep_alive = tcp_keep_alive + self.ln = SysListener() + fn __init__(inout self, own_address: String) raises: self.error_handler = ErrorHandler() self.name = "lightbug_http" self.__address = own_address self.max_concurrent_connections = 1000 self.max_requests_per_connection = 0 - self.max_request_body_size = 0 + self.__max_request_body_size = default_max_request_body_size self.tcp_keep_alive = False self.ln = SysListener() @@ -54,10 +68,30 @@ struct SysServer: self.__address = "127.0.0.1" self.max_concurrent_connections = 1000 self.max_requests_per_connection = 0 - self.max_request_body_size = 0 + self.__max_request_body_size = default_max_request_body_size self.tcp_keep_alive = False self.ln = SysListener() + fn __init__(inout self, max_request_body_size: Int) raises: + self.error_handler = ErrorHandler() + self.name = "lightbug_http" + self.__address = "127.0.0.1" + self.max_concurrent_connections = 1000 + self.max_requests_per_connection = 0 + self.__max_request_body_size = max_request_body_size + self.tcp_keep_alive = False + self.ln = SysListener() + + fn __init__(inout self, max_request_body_size: Int, tcp_keep_alive: Bool) raises: + self.error_handler = ErrorHandler() + self.name = "lightbug_http" + self.__address = "127.0.0.1" + self.max_concurrent_connections = 1000 + self.max_requests_per_connection = 0 + self.__max_request_body_size = max_request_body_size + self.tcp_keep_alive = tcp_keep_alive + self.ln = SysListener() + fn address(self) -> String: return self.__address @@ -65,6 +99,13 @@ struct SysServer: self.__address = own_address return self + fn max_request_body_size(self) -> Int: + return self.__max_request_body_size + + fn set_max_request_body_size(inout self, size: Int) -> Self: + self.__max_request_body_size = size + return self + fn get_concurrency(self) -> Int: """ Retrieve the concurrency level which is either @@ -108,56 +149,85 @@ struct SysServer: while True: var conn = self.ln.accept() - var buf = Bytes() - var read_len = conn.read(buf) - - if read_len == 0: - conn.close() - break - - var request_first_line_headers_and_body = next_line(buf, "\r\n\r\n") - var request_first_line_headers = request_first_line_headers_and_body.first_line - var request_body = request_first_line_headers_and_body.rest - - var request_first_line_headers_split = next_line(request_first_line_headers, "\r\n") - var request_first_line = request_first_line_headers_split.first_line - var request_headers = request_first_line_headers_split.rest + self.serve_connection(conn, handler) + + fn serve_connection[T: HTTPService](inout self, conn: SysConnection, handler: T) raises -> None: + """ + Serve a single connection. - var header = RequestHeader(request_headers._buffer) + Args: + conn : SysConnection - A connection object that represents a client connection. + handler : HTTPService - An object that handles incoming HTTP requests. + Raises: + If there is an error while serving the connection. + """ + var b = Bytes(capacity=default_buffer_size) + var bytes_recv = conn.read(b) + if bytes_recv == 0: + conn.close() + return + + var buf = buffer.new_buffer(b^) + var reader = Reader(buf^) + + var error = Error() + + var max_request_body_size = self.max_request_body_size() + if max_request_body_size <= 0: + max_request_body_size = default_max_request_body_size + + var req_number = 0 + + while True: + req_number += 1 + + if req_number > 1: + var b = Bytes(capacity=default_buffer_size) + var bytes_recv = conn.read(b) + if bytes_recv == 0: + conn.close() + break + buf = buffer.new_buffer(b^) + reader = Reader(buf^) + + var header = RequestHeader() + var first_line_and_headers_len = 0 try: - header.parse(request_first_line) + first_line_and_headers_len = header.parse_raw(reader) except e: - conn.close() - raise Error("Failed to parse request header: " + e.__str__()) - + error = Error("Failed to parse request headers: " + e.__str__()) + var uri = URI(self.address() + String(header.request_uri())) try: uri.parse() except e: - conn.close() - raise Error("Failed to parse request line:" + e.__str__()) - - if header.content_length() != 0 and header.content_length() != (len(request_body) + 1): - var remaining_body = Bytes() - var remaining_len = header.content_length() - len(request_body + 1) - while remaining_len > 0: - var read_len = conn.read(remaining_body) - buf.extend(remaining_body) - remaining_len -= read_len - - var res = handler.func( - HTTPRequest( + error = Error("Failed to parse request line:" + e.__str__()) + + if header.content_length() > 0: + if max_request_body_size > 0 and header.content_length() > max_request_body_size: + error = Error("Request body too large") + + var request = HTTPRequest( uri, - buf, + Bytes(), header, ) - ) - # Always close the connection as long as we don't support concurrency - _ = res.set_connection_close(True) + try: + request.read_body(reader, header.content_length(), first_line_and_headers_len, max_request_body_size) + except e: + error = Error("Failed to read request body: " + e.__str__()) + + var res = handler.func(request) + + if not self.tcp_keep_alive: + _ = res.set_connection_close() var res_encoded = encode(res) + _ = conn.write(res_encoded) - conn.close() + if not self.tcp_keep_alive: + conn.close() + return diff --git a/lightbug_http/tests/run.mojo b/lightbug_http/tests/run.mojo deleted file mode 100644 index b1533c0d..00000000 --- a/lightbug_http/tests/run.mojo +++ /dev/null @@ -1,19 +0,0 @@ -from lightbug_http.python.client import PythonClient -from lightbug_http.sys.client import MojoClient -from lightbug_http.tests.test_client import ( - test_python_client_lightbug, - test_mojo_client_lightbug, - test_mojo_client_lightbug_external_req, -) - - -fn run_tests() raises: - run_client_tests() - - -fn run_client_tests() raises: - var py_client = PythonClient() - var mojo_client = MojoClient() - # test_mojo_client_lightbug_external_req(mojo_client) - test_mojo_client_lightbug(mojo_client) - test_python_client_lightbug(py_client) diff --git a/lightbug_http/tests/test_client.mojo b/lightbug_http/tests/test_client.mojo deleted file mode 100644 index 3eaa8e66..00000000 --- a/lightbug_http/tests/test_client.mojo +++ /dev/null @@ -1,212 +0,0 @@ -import testing -from lightbug_http.python.client import PythonClient -from lightbug_http.sys.client import MojoClient -from lightbug_http.http import HTTPRequest, encode -from lightbug_http.uri import URI -from lightbug_http.header import RequestHeader -from external.morrow import Morrow -from lightbug_http.tests.utils import ( - default_server_conn_string, - getRequest, -) - - -fn test_mojo_client_lightbug(client: MojoClient) raises: - var res = client.do( - HTTPRequest( - URI(default_server_conn_string), - String("Hello world!")._buffer, - RequestHeader(getRequest), - ) - ) - testing.assert_equal( - String(res.body_raw[0:112]), - String( - "HTTP/1.1 200 OK\r\nServer: lightbug_http\r\nContent-Type:" - " text/plain\r\nContent-Length: 12\r\nConnection: close\r\nDate: " - ), - ) - - -fn test_mojo_client_lightbug_external_req(client: MojoClient) raises: - var req = HTTPRequest( - URI("http://grandinnerastoundingspell.neverssl.com/online/"), - ) - try: - var res = client.do(req) - testing.assert_equal(res.header.status_code(), 200) - except e: - print(e) - - -fn test_python_client_lightbug(client: PythonClient) raises: - var res = client.do( - HTTPRequest( - URI(default_server_conn_string), - String("Hello world!")._buffer, - RequestHeader(getRequest), - ) - ) - testing.assert_equal( - String(res.body_raw[0:112]), - String( - "HTTP/1.1 200 OK\r\nServer: lightbug_http\r\nContent-Type:" - " text/plain\r\nContent-Length: 12\r\nConnection: close\r\nDate: " - ), - ) - - -fn test_request_simple_url(inout client: PythonClient) raises -> None: - """ - Test making a simple GET request without parameters. - Validate that we get a 200 OK response. - """ - var uri = URI("http", "localhost", "/123") - var response = client.do(HTTPRequest(uri)) - testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_simple_url_with_parameters(inout client: PythonClient) raises -> None: - """ - WIP: Test making a simple GET request with query parameters. - Validate that we get a 200 OK response and that server can parse the query parameters. - """ - # This test is a WIP - var uri = URI("http", "localhost", "/123") - # uri.add_query_parameter("foo", "bar") - # uri.add_query_parameter("baz", "qux") - var response = client.do(HTTPRequest(uri)) - testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_simple_url_with_headers(inout client: PythonClient) raises -> None: - """ - WIP: Test making a simple GET request with headers. - Validate that we get a 200 OK response and that server can parse the headers. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.header.add("foo", "bar") - var response = client.do(request) - testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_post_plain_text(inout client: PythonClient) raises -> None: - """ - WIP: Test making a POST request with PLAIN TEXT body. - Validate that request is properly received and the server can parse the body. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.body = "Hello World" - # var response = client.post(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_post_json(inout client: PythonClient) raises -> None: - """ - WIP: Test making a POST request with JSON body. - Validate that the request is properly received and the server can parse the JSON. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.body = "{\"foo\": \"bar\"}" - # var response = client.post(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_post_form(inout client: PythonClient) raises -> None: - """ - WIP: Test making a POST request with a FORM body. - Validate that the request is properly received and the server can parse the form. - Include URL encoded strings in test cases. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.body = "foo=bar&baz=qux" - # var response = client.post(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_post_file(inout client: PythonClient) raises -> None: - """ - WIP: Test making a POST request with a FILE body. - Validate that the request is properly received and the server can parse the body. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.body = "foo=bar&baz=qux" - # var response = client.post(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_post_stream(inout client: PythonClient) raises -> None: - """ - WIP: Test making a POST request with a stream body. - Validate that the request is properly received and the server can parse the body. - Try stream only, stream then body, and body then stream. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.body = "foo=bar&baz=qux" - # var response = client.post(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_put(inout client: PythonClient) raises -> None: - """ - WIP: Test making a PUT request. - Validate that the PUT request is properly received and the server can parse the body. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.body = "foo=bar&baz=qux" - # var response = client.put(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_patch(inout client: PythonClient) raises -> None: - """ - WIP: Test making a PATCH request. - Validate that the PATCH request is properly received and the server can parse the body. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.body = "foo=bar&baz=qux" - # var response = client.patch(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_options(inout client: PythonClient) raises -> None: - """ - WIP: Test making an OPTIONS request. - Validate that the OPTIONS request is properly received and the server can parse the body. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.body = "foo=bar&baz=qux" - # var response = client.options(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_delete(inout client: PythonClient) raises -> None: - """ - WIP: Test making a DELETE request. - Validate that the DELETE request is properly received and the server can parse the body. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # request.body = "foo=bar&baz=qux" - # var response = client.delete(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_head(inout client: PythonClient) raises -> None: - """ - WIP: Test making a HEAD request. - Validate that the HEAD request is properly received and the server can parse the body. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # var response = client.head(request) - # testing.assert_equal(response.header.status_code(), 200) diff --git a/lightbug_http/tests/test_connection.mojo b/lightbug_http/tests/test_connection.mojo deleted file mode 100644 index 86cca84a..00000000 --- a/lightbug_http/tests/test_connection.mojo +++ /dev/null @@ -1,54 +0,0 @@ -import testing -from lightbug_http.client import Client -from lightbug_http.uri import URI -from lightbug_http.http import HTTPRequest, HTTPResponse - - -fn test_multiple_connections[T: Client](inout client: T) raises -> None: - """ - WIP: Test making multiple simultaneous connections. - Validate that the server can handle multiple simultaneous connections without dropping or mixing up data. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # var response1 = client.get(request) - # var response2 = client.get(request) - # var response3 = client.get(request) - # testing.assert_equal(response1.header.status_code(), 200) - # testing.assert_equal(response2.header.status_code(), 200) - # testing.assert_equal(response3.header.status_code(), 200) - - -fn test_idle_connections[T: Client](inout client: T) raises -> None: - """ - WIP: Test idle connections. - Establish a connection and then remain idle for longer than the server’s timeout setting to ensure the server properly closes the connection. - """ - var uri = URI("http", "localhost", "/123") - var request = HTTPRequest(uri) - # var response = client.get(request) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_keep_alive[T: Client](inout client: T) raises -> None: - """ - WIP: Test Keep-Alive. - Validate that the server keeps the connection active during periods of inactivity as expected. - """ - ... - - -fn test_keep_alive_timeout[T: Client](inout client: T) raises -> None: - """ - WIP: Test Keep-Alive Timeout. - Validate that the server closes the connection after the configured timeout period. - """ - ... - - -fn test_port_reuse[T: Client](inout client: T) raises -> None: - """ - WIP: Test Port Reusability. - After the server is stopped, ensure that the TCP port it was using can be immediately reused. Validate that the server is not leaving the port in a TIME_WAIT state. - """ - ... diff --git a/lightbug_http/tests/test_cookies.mojo b/lightbug_http/tests/test_cookies.mojo deleted file mode 100644 index 44d2545f..00000000 --- a/lightbug_http/tests/test_cookies.mojo +++ /dev/null @@ -1,31 +0,0 @@ -import testing -from lightbug_http.client import Client -from lightbug_http.uri import URI -from lightbug_http.http import HTTPRequest, HTTPResponse - - -fn test_request_with_cookies[T: Client](inout client: T) raises -> None: - """ - WIP: Test making a simple GET request with cookies. - Validate that the cookies are parsed correctly. - """ - var uri = URI("http", "localhost", "/123") - # var cookies = [Cookie("foo", "bar")] - # var response = client.get(HTTPRequest(uri, cookies=cookies)) - # testing.assert_equal(response.header.status_code(), 200) - - -fn test_request_with_invalid_cookies[T: Client](inout client: T) raises -> None: - """ - WIP: We should be able to parse invalid or non-spec conformant cookies, such as the ones set by Okta (see below). - From Starlette (https://github.com/encode/starlette/blob/master/tests/test_requests.py). - """ - var uri = URI("http", "localhost", "/123") - # var cookies = [ - # Cookie("importantCookie", "importantValue"), - # Cookie("okta-oauth-redirect-params", '{"responseType":"code","state":"somestate","nonce":"somenonce","scopes":["openid","profile","email","phone"],"urls":{"issuer":"https://subdomain.okta.com/oauth2/authServer","authorizeUrl":"https://subdomain.okta.com/oauth2/authServer/v1/authorize","userinfoUrl":"https://subdomain.okta.com/oauth2/authServer/v1/userinfo"}}'), - # Cookie("provider-oauth-nonce", "validAsciiblabla"), - # Cookie("sessionCookie", "importantSessionValue"), - # ] - # var response = client.get(HTTPRequest(uri, cookies=cookies)) - # testing.assert_equal(response.header.status_code(), 200) diff --git a/lightbug_http/tests/test_io.mojo b/lightbug_http/tests/test_io.mojo deleted file mode 100644 index 59121df7..00000000 --- a/lightbug_http/tests/test_io.mojo +++ /dev/null @@ -1,9 +0,0 @@ -import testing -from lightbug_http.io.bytes import Bytes, bytes_equal - - -fn test_bytes_equal() raises: - var test1 = String("test")._buffer - var test2 = String("test")._buffer - var equal = bytes_equal(test1, test2) - testing.assert_true(equal) diff --git a/lightbug_http/tests/test_server.mojo b/lightbug_http/tests/test_server.mojo deleted file mode 100644 index 88673ddb..00000000 --- a/lightbug_http/tests/test_server.mojo +++ /dev/null @@ -1,56 +0,0 @@ -import testing -from lightbug_http.python.server import PythonServer, PythonTCPListener -from lightbug_http.python.client import PythonClient -from lightbug_http.python.net import PythonListenConfig -from lightbug_http.http import HTTPRequest -from lightbug_http.net import Listener -from lightbug_http.client import Client -from lightbug_http.uri import URI -from lightbug_http.header import RequestHeader -from lightbug_http.tests.utils import ( - getRequest, - default_server_conn_string, - defaultExpectedGetResponse, -) - - -fn test_python_server[C: Client](client: C, ln: PythonListenConfig) raises -> None: - """ - Run a server listening on a port. - Validate that the server is listening on the provided port. - """ - ... - # var conn = ln.accept() - # var res = client.do( - # HTTPRequest( - # URI(default_server_conn_string), - # String("Hello world!")._buffer, - # RequestHeader(getRequest), - # ) - # ) - # testing.assert_equal( - # String(res.body_raw), - # defaultExpectedGetResponse, - # ) - - -fn test_server_busy_port() raises -> None: - """ - Test that we get an error if we try to run a server on a port that is already in use. - """ - ... - - -fn test_server_invalid_host() raises -> None: - """ - Test that we get an error if we try to run a server on an invalid host. - """ - ... - - -fn test_tls() raises -> None: - """ - TLS Support. - Validate that the server supports TLS. - """ - ... diff --git a/lightbug_http/uri.mojo b/lightbug_http/uri.mojo index 221012d7..ca31a0cb 100644 --- a/lightbug_http/uri.mojo +++ b/lightbug_http/uri.mojo @@ -1,4 +1,4 @@ -from lightbug_http.io.bytes import Bytes, bytes_equal +from lightbug_http.io.bytes import Bytes, BytesView, bytes_equal, bytes from lightbug_http.strings import ( strSlash, strHttp11, @@ -37,10 +37,28 @@ struct URI: self.__path = Bytes() self.__query_string = Bytes() self.__hash = Bytes() - self.__host = String("127.0.0.1")._buffer + self.__host = Bytes() self.__http_version = Bytes() self.disable_path_normalization = False - self.__full_uri = full_uri._buffer + self.__full_uri = bytes(full_uri, pop=False) + self.__request_uri = Bytes() + self.__username = Bytes() + self.__password = Bytes() + + fn __init__( + inout self, + full_uri: String, + host: String + ) -> None: + self.__path_original = Bytes() + self.__scheme = Bytes() + self.__path = Bytes() + self.__query_string = Bytes() + self.__hash = Bytes() + self.__host = bytes(host) + self.__http_version = Bytes() + self.disable_path_normalization = False + self.__full_uri = bytes(full_uri) self.__request_uri = Bytes() self.__username = Bytes() self.__password = Bytes() @@ -51,12 +69,12 @@ struct URI: host: String, path: String, ) -> None: - self.__path_original = path._buffer - self.__scheme = scheme._buffer - self.__path = normalise_path(path._buffer, self.__path_original) + self.__path_original = bytes(path) + self.__scheme = scheme.as_bytes() + self.__path = normalise_path(bytes(path), self.__path_original) self.__query_string = Bytes() self.__hash = Bytes() - self.__host = host._buffer + self.__host = bytes(host) self.__http_version = Bytes() self.disable_path_normalization = False self.__full_uri = Bytes() @@ -92,102 +110,146 @@ struct URI: self.__username = username self.__password = password - fn path_original(self) -> Bytes: - return self.__path_original + fn path_original(self) -> BytesView: + return BytesView(unsafe_ptr=self.__path_original.unsafe_ptr(), len=self.__path_original.size) fn set_path(inout self, path: String) -> Self: - self.__path = normalise_path(path._buffer, self.__path_original) + self.__path = normalise_path(bytes(path), self.__path_original) return self - fn set_path_sbytes(inout self, path: Bytes) -> Self: + fn set_path_bytes(inout self, path: Bytes) -> Self: self.__path = normalise_path(path, self.__path_original) return self fn path(self) -> String: - var processed_path = self.__path - if len(processed_path) == 0: - processed_path = strSlash - return String(processed_path) + if len(self.__path) == 0: + return strSlash + return String(self.__path) + + fn path_bytes(self) -> BytesView: + if len(self.__path) == 0: + return BytesView(unsafe_ptr=strSlash.as_bytes_slice().unsafe_ptr(), len=2) + return BytesView(unsafe_ptr=self.__path.unsafe_ptr(), len=self.__path.size) fn set_scheme(inout self, scheme: String) -> Self: - self.__scheme = scheme._buffer + self.__scheme = bytes(scheme) return self fn set_scheme_bytes(inout self, scheme: Bytes) -> Self: self.__scheme = scheme return self - fn scheme(self) -> Bytes: - var processed_scheme = self.__scheme - if len(processed_scheme) == 0: - processed_scheme = strHttp - return processed_scheme + fn scheme(self) -> BytesView: + if len(self.__scheme) == 0: + return BytesView(unsafe_ptr=strHttp.as_bytes_slice().unsafe_ptr(), len=5) + return BytesView(unsafe_ptr=self.__scheme.unsafe_ptr(), len=self.__scheme.size) + + fn http_version(self) -> BytesView: + if len(self.__http_version) == 0: + return BytesView(unsafe_ptr=strHttp11.as_bytes_slice().unsafe_ptr(), len=9) + return BytesView(unsafe_ptr=self.__http_version.unsafe_ptr(), len=self.__http_version.size) - fn http_version(self) -> Bytes: + fn http_version_str(self) -> String: return self.__http_version fn set_http_version(inout self, http_version: String) -> Self: - self.__http_version = http_version._buffer + self.__http_version = bytes(http_version) + return self + + fn set_http_version_bytes(inout self, http_version: Bytes) -> Self: + self.__http_version = http_version return self fn is_http_1_1(self) -> Bool: - return bytes_equal(self.__http_version, strHttp11) + return bytes_equal(self.http_version(), bytes(strHttp11, pop=False)) fn is_http_1_0(self) -> Bool: - return bytes_equal(self.__http_version, strHttp10) + return bytes_equal(self.http_version(), bytes(strHttp10, pop=False)) fn is_https(self) -> Bool: - return bytes_equal(self.__scheme, https._buffer) + return bytes_equal(self.__scheme, bytes(https, pop=False)) fn is_http(self) -> Bool: - return bytes_equal(self.__scheme, http._buffer) or len(self.__scheme) == 0 + return bytes_equal(self.__scheme, bytes(http, pop=False)) or len(self.__scheme) == 0 fn set_request_uri(inout self, request_uri: String) -> Self: - self.__request_uri = request_uri._buffer + self.__request_uri = bytes(request_uri) return self fn set_request_uri_bytes(inout self, request_uri: Bytes) -> Self: self.__request_uri = request_uri return self + + fn request_uri(self) -> BytesView: + return BytesView(unsafe_ptr=self.__request_uri.unsafe_ptr(), len=self.__request_uri.size) fn set_query_string(inout self, query_string: String) -> Self: - self.__query_string = query_string._buffer + self.__query_string = bytes(query_string) return self fn set_query_string_bytes(inout self, query_string: Bytes) -> Self: self.__query_string = query_string return self + + fn query_string(self) -> BytesView: + return BytesView(unsafe_ptr=self.__query_string.unsafe_ptr(), len=self.__query_string.size) fn set_hash(inout self, hash: String) -> Self: - self.__hash = hash._buffer + self.__hash = bytes(hash) return self fn set_hash_bytes(inout self, hash: Bytes) -> Self: self.__hash = hash return self - fn hash(self) -> Bytes: - return self.__hash + fn hash(self) -> BytesView: + return BytesView(unsafe_ptr=self.__hash.unsafe_ptr(), len=self.__hash.size) fn set_host(inout self, host: String) -> Self: - self.__host = host._buffer + self.__host = bytes(host) return self fn set_host_bytes(inout self, host: Bytes) -> Self: self.__host = host return self - fn host(self) -> Bytes: + fn host(self) -> BytesView: + return BytesView(unsafe_ptr=self.__host.unsafe_ptr(), len=self.__host.size) + + fn host_str(self) -> String: return self.__host + fn full_uri(self) -> BytesView: + return BytesView(unsafe_ptr=self.__full_uri.unsafe_ptr(), len=self.__full_uri.size) + + fn set_username(inout self, username: String) -> Self: + self.__username = bytes(username) + return self + + fn set_username_bytes(inout self, username: Bytes) -> Self: + self.__username = username + return self + + fn username(self) -> BytesView: + return BytesView(unsafe_ptr=self.__username.unsafe_ptr(), len=self.__username.size) + + fn set_password(inout self, password: String) -> Self: + self.__password = bytes(password) + return self + + fn set_password_bytes(inout self, password: Bytes) -> Self: + self.__password = password + return self + + fn password(self) -> BytesView: + return BytesView(unsafe_ptr=self.__password.unsafe_ptr(), len=self.__password.size) + fn parse(inout self) raises -> None: var raw_uri = String(self.__full_uri) - # Defaults to HTTP/1.1 var proto_str = String(strHttp11) var is_https = False - # Parse the protocol var proto_end = raw_uri.find("://") var remainder_uri: String if proto_end >= 0: @@ -197,55 +259,36 @@ struct URI: remainder_uri = raw_uri[proto_end + 3:] else: remainder_uri = raw_uri - # Parse the host and optional port + + _ = self.set_scheme_bytes(proto_str.as_bytes_slice()) + var path_start = remainder_uri.find("/") var host_and_port: String var request_uri: String if path_start >= 0: host_and_port = remainder_uri[:path_start] request_uri = remainder_uri[path_start:] - self.__host = host_and_port[:path_start]._buffer + _ = self.set_host_bytes(bytes(host_and_port[:path_start], pop=False)) else: host_and_port = remainder_uri request_uri = strSlash - self.__host = host_and_port._buffer + _ = self.set_host_bytes(bytes(host_and_port, pop=False)) if is_https: - _ = self.set_scheme(https) + _ = self.set_scheme_bytes(bytes(https, pop=False)) else: - _ = self.set_scheme(http) + _ = self.set_scheme_bytes(bytes(http, pop=False)) - # Parse path var n = request_uri.find("?") if n >= 0: - self.__path_original = request_uri[:n]._buffer - self.__query_string = request_uri[n + 1 :]._buffer + self.__path_original = bytes(request_uri[:n], pop=False) + self.__query_string = bytes(request_uri[n + 1 :], pop=False) else: - self.__path_original = request_uri._buffer + self.__path_original = bytes(request_uri, pop=False) self.__query_string = Bytes() - self.__path = normalise_path(self.__path_original, self.__path_original) - - _ = self.set_request_uri(request_uri) - - fn request_uri(self) -> Bytes: - return self.__request_uri - - fn set_username(inout self, username: String) -> Self: - self.__username = username._buffer - return self - - fn set_username_bytes(inout self, username: Bytes) -> Self: - self.__username = username - return self - - fn set_password(inout self, password: String) -> Self: - self.__password = password._buffer - return self - - fn set_password_bytes(inout self, password: Bytes) -> Self: - self.__password = password - return self + _ = self.set_path_bytes(normalise_path(self.__path_original, self.__path_original)) + _ = self.set_request_uri_bytes(bytes(request_uri, pop=False)) fn normalise_path(path: Bytes, path_original: Bytes) -> Bytes: diff --git a/run_tests.mojo b/run_tests.mojo new file mode 100644 index 00000000..5d3411c3 --- /dev/null +++ b/run_tests.mojo @@ -0,0 +1,13 @@ +from tests.test_io import test_io +from tests.test_http import test_http +from tests.test_header import test_header +from tests.test_uri import test_uri +# from lightbug_http.test.test_client import test_client + +fn main() raises: + test_io() + test_http() + test_header() + test_uri() + # test_client() + diff --git a/tests/__init__.mojo b/tests/__init__.mojo new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_client.mojo b/tests/test_client.mojo new file mode 100644 index 00000000..47c8059b --- /dev/null +++ b/tests/test_client.mojo @@ -0,0 +1,67 @@ +from external.gojo.tests.wrapper import MojoTest +from external.morrow import Morrow +from tests.utils import ( + default_server_conn_string, + getRequest, +) +from lightbug_http.python.client import PythonClient +from lightbug_http.sys.client import MojoClient +from lightbug_http.http import HTTPRequest, encode +from lightbug_http.uri import URI +from lightbug_http.header import RequestHeader +from lightbug_http.io.bytes import bytes + + +def test_client(): + var mojo_client = MojoClient() + var py_client = PythonClient() + test_mojo_client_lightbug_external_req(mojo_client) + test_python_client_lightbug(py_client) + + +fn test_mojo_client_lightbug(client: MojoClient) raises: + var test = MojoTest("test_mojo_client_lightbug") + var res = client.do( + HTTPRequest( + URI(default_server_conn_string), + bytes("Hello world!"), + RequestHeader(getRequest), + ) + ) + test.assert_equal( + String(res.body_raw[0:112]), + String( + "HTTP/1.1 200 OK\r\nServer: lightbug_http\r\nContent-Type:" + " text/plain\r\nContent-Length: 12\r\nConnection: close\r\nDate: " + ), + ) + + +fn test_mojo_client_lightbug_external_req(client: MojoClient) raises: + var test = MojoTest("test_mojo_client_lightbug_external_req") + var req = HTTPRequest( + URI("http://grandinnerastoundingspell.neverssl.com/online/"), + ) + try: + var res = client.do(req) + test.assert_equal(res.header.status_code(), 200) + except e: + print(e) + + +fn test_python_client_lightbug(client: PythonClient) raises: + var test = MojoTest("test_python_client_lightbug") + var res = client.do( + HTTPRequest( + URI(default_server_conn_string), + bytes("Hello world!"), + RequestHeader(getRequest), + ) + ) + test.assert_equal( + String(res.body_raw[0:112]), + String( + "HTTP/1.1 200 OK\r\nServer: lightbug_http\r\nContent-Type:" + " text/plain\r\nContent-Length: 12\r\nConnection: close\r\nDate: " + ), + ) diff --git a/tests/test_header.mojo b/tests/test_header.mojo new file mode 100644 index 00000000..57c265f7 --- /dev/null +++ b/tests/test_header.mojo @@ -0,0 +1,47 @@ +from external.gojo.tests.wrapper import MojoTest +from external.gojo.bytes import buffer +from external.gojo.bufio import Reader +from lightbug_http.header import RequestHeader, ResponseHeader +from lightbug_http.io.bytes import Bytes, bytes +from lightbug_http.strings import empty_string +from lightbug_http.net import default_buffer_size + +def test_header(): + test_parse_request_header() + test_parse_response_header() + +def test_parse_request_header(): + var test = MojoTest("test_parse_request_header") + var headers_str = bytes('''GET /index.html HTTP/1.1\r\nHost: example.com\r\nUser-Agent: Mozilla/5.0\r\nContent-Type: text/html\r\nContent-Length: 1234\r\nConnection: close\r\nTrailer: end-of-message\r\n\r\n''') + var header = RequestHeader() + var b = Bytes(headers_str) + var buf = buffer.new_buffer(b^) + var reader = Reader(buf^) + _ = header.parse_raw(reader) + test.assert_equal(String(header.request_uri()), "/index.html") + test.assert_equal(String(header.protocol()), "HTTP/1.1") + test.assert_equal(header.no_http_1_1, False) + test.assert_equal(String(header.host()), String("example.com")) + test.assert_equal(String(header.user_agent()), "Mozilla/5.0") + test.assert_equal(String(header.content_type()), "text/html") + test.assert_equal(header.content_length(), 1234) + test.assert_equal(header.connection_close(), True) + +def test_parse_response_header(): + var test = MojoTest("test_parse_response_header") + var headers_str = bytes('''HTTP/1.1 200 OK\r\nServer: example.com\r\nUser-Agent: Mozilla/5.0\r\nContent-Type: text/html\r\nContent-Encoding: gzip\r\nContent-Length: 1234\r\nConnection: close\r\nTrailer: end-of-message\r\n\r\n''') + var header = ResponseHeader() + var b = Bytes(headers_str) + var buf = buffer.new_buffer(b^) + var reader = Reader(buf^) + _ = header.parse_raw(reader) + test.assert_equal(String(header.protocol()), "HTTP/1.1") + test.assert_equal(header.no_http_1_1, False) + test.assert_equal(header.status_code(), 200) + test.assert_equal(String(header.status_message()), "OK") + test.assert_equal(String(header.server()), "example.com") + test.assert_equal(String(header.content_type()), "text/html") + test.assert_equal(String(header.content_encoding()), "gzip") + test.assert_equal(header.content_length(), 1234) + test.assert_equal(header.connection_close(), True) + test.assert_equal(header.trailer_str(), "end-of-message") diff --git a/tests/test_http.mojo b/tests/test_http.mojo new file mode 100644 index 00000000..f5f85400 --- /dev/null +++ b/tests/test_http.mojo @@ -0,0 +1,65 @@ +from external.gojo.tests.wrapper import MojoTest +from lightbug_http.io.bytes import Bytes, bytes +from lightbug_http.http import HTTPRequest, HTTPResponse, split_http_string, encode +from lightbug_http.header import RequestHeader +from lightbug_http.uri import URI +from tests.utils import ( + default_server_conn_string, + getRequest, +) + +def test_http(): + test_split_http_string() + test_encode_http_request() + test_encode_http_response() + +def test_split_http_string(): + var test = MojoTest("test_split_http_string") + var cases = Dict[StringLiteral, List[StringLiteral]]() + + cases["GET /index.html HTTP/1.1\r\nHost: www.example.com\r\nUser-Agent: Mozilla/5.0\r\nContent-Type: text/html\r\nContent-Length: 1234\r\nConnection: close\r\nTrailer: end-of-message\r\n\r\nHello, World!\0"] = + List("GET /index.html HTTP/1.1", + "Host: www.example.com\r\nUser-Agent: Mozilla/5.0\r\nContent-Type: text/html\r\nContent-Length: 1234\r\nConnection: close\r\nTrailer: end-of-message", + "Hello, World!") + + for c in cases.items(): + var buf = bytes((c[].key)) + request_first_line, request_headers, request_body = split_http_string(buf) + test.assert_equal(request_first_line, c[].value[0]) + test.assert_equal(request_headers, String(c[].value[1])) + test.assert_equal(request_body, c[].value[2]) + +def test_encode_http_request(): + var test = MojoTest("test_encode_http_request") + var uri = URI(default_server_conn_string) + var req = HTTPRequest( + uri, + String("Hello world!").as_bytes(), + RequestHeader(getRequest), + ) + + var req_encoded = encode(req) + test.assert_equal(String(req_encoded), "GET / HTTP/1.1\r\nContent-Length: 12\r\nConnection: keep-alive\r\n\r\nHello world!") + +def test_encode_http_response(): + var test = MojoTest("test_encode_http_response") + var res = HTTPResponse( + bytes("Hello, World!"), + ) + + var res_encoded = encode(res) + var res_str = String(res_encoded) + + # Since we cannot compare the exact date, we will only compare the headers until the date and the body + var expected_full = "HTTP/1.1 200 OK\r\nServer: lightbug_http\r\nContent-Type: application/octet-stream\r\nContent-Length: 13\r\nConnection: keep-alive\r\nDate: 2024-06-02T13:41:50.766880+00:00\r\n\r\nHello, World!" + + var expected_headers_len = 124 + var hello_world_len = len(String("Hello, World!")) - 1 # -1 for the null terminator + var date_header_len = len(String("Date: 2024-06-02T13:41:50.766880+00:00")) + + var expected_split = String(expected_full).split("\r\n\r\n") + var expected_headers = expected_split[0] + var expected_body = expected_split[1] + + test.assert_equal(res_str[:expected_headers_len], expected_headers[:len(expected_headers) - date_header_len]) + test.assert_equal(res_str[(len(res_str) - hello_world_len):len(res_str) + 1], expected_body) \ No newline at end of file diff --git a/tests/test_io.mojo b/tests/test_io.mojo new file mode 100644 index 00000000..93363c34 --- /dev/null +++ b/tests/test_io.mojo @@ -0,0 +1,31 @@ +from external.gojo.tests.wrapper import MojoTest +from lightbug_http.io.bytes import Bytes, bytes_equal, bytes + +def test_io(): + test_string_literal_to_bytes() + +fn test_string_literal_to_bytes() raises: + var test = MojoTest("test_string_to_bytes") + var cases = Dict[StringLiteral, Bytes]() + cases[""] = Bytes() + cases["Hello world!"] = List[UInt8](72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, 33) + cases["\0"] = List[UInt8](0) + cases["\0\0\0\0"] = List[UInt8](0, 0, 0, 0) + cases["OK"] = List[UInt8](79, 75) + cases["HTTP/1.1 200 OK"] = List[UInt8](72, 84, 84, 80, 47, 49, 46, 49, 32, 50, 48, 48, 32, 79, 75) + + for c in cases.items(): + test.assert_true(bytes_equal(bytes(c[].key), c[].value)) + +fn test_string_to_bytes() raises: + var test = MojoTest("test_string_to_bytes") + var cases = Dict[String, Bytes]() + cases[String("")] = Bytes() + cases[String("Hello world!")] = List[UInt8](72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, 33) + cases[String("\0")] = List[UInt8](0) + cases[String("\0\0\0\0")] = List[UInt8](0, 0, 0, 0) + cases[String("OK")] = List[UInt8](79, 75) + cases[String("HTTP/1.1 200 OK")] = List[UInt8](72, 84, 84, 80, 47, 49, 46, 49, 32, 50, 48, 48, 32, 79, 75) + + for c in cases.items(): + test.assert_true(bytes_equal(bytes(c[].key), c[].value)) \ No newline at end of file diff --git a/tests/test_net.mojo b/tests/test_net.mojo new file mode 100644 index 00000000..23b91c3f --- /dev/null +++ b/tests/test_net.mojo @@ -0,0 +1,6 @@ + +def test_net(): + test_split_host_port() + +def test_split_host_port(): + ... \ No newline at end of file diff --git a/tests/test_uri.mojo b/tests/test_uri.mojo new file mode 100644 index 00000000..3af3df30 --- /dev/null +++ b/tests/test_uri.mojo @@ -0,0 +1,113 @@ +from external.gojo.tests.wrapper import MojoTest +from lightbug_http.uri import URI +from lightbug_http.strings import empty_string +from lightbug_http.io.bytes import Bytes + +def test_uri(): + test_uri_no_parse_defaults() + test_uri_parse_http_with_port() + test_uri_parse_https_with_port() + test_uri_parse_http_with_path() + test_uri_parse_https_with_path() + test_uri_parse_http_basic() + test_uri_parse_http_basic_www() + test_uri_parse_http_with_query_string() + test_uri_parse_http_with_hash() + test_uri_parse_http_with_query_string_and_hash() + +def test_uri_no_parse_defaults(): + var test = MojoTest("test_uri_no_parse_defaults") + var uri = URI("http://example.com") + test.assert_equal(String(uri.full_uri()), "http://example.com") + test.assert_equal(String(uri.scheme()), "http") + test.assert_equal(String(uri.path()), "/") + +def test_uri_parse_http_with_port(): + var test = MojoTest("test_uri_parse_http_with_port") + var uri = URI("http://example.com:8080/index.html") + _ = uri.parse() + test.assert_equal(String(uri.scheme()), "http") + test.assert_equal(String(uri.host()), "example.com:8080") + test.assert_equal(String(uri.path()), "/index.html") + test.assert_equal(String(uri.path_original()), "/index.html") + test.assert_equal(String(uri.request_uri()), "/index.html") + test.assert_equal(String(uri.http_version()), "HTTP/1.1") + test.assert_equal(uri.is_http_1_0(), False) + test.assert_equal(uri.is_http_1_1(), True) + test.assert_equal(uri.is_https(), False) + test.assert_equal(uri.is_http(), True) + test.assert_equal(String(uri.query_string()), String(empty_string.as_bytes_slice())) + +def test_uri_parse_https_with_port(): + var test = MojoTest("test_uri_parse_https_with_port") + var uri = URI("https://example.com:8080/index.html") + _ = uri.parse() + test.assert_equal(String(uri.scheme()), "https") + test.assert_equal(String(uri.host()), "example.com:8080") + test.assert_equal(String(uri.path()), "/index.html") + test.assert_equal(String(uri.path_original()), "/index.html") + test.assert_equal(String(uri.request_uri()), "/index.html") + test.assert_equal(uri.is_https(), True) + test.assert_equal(uri.is_http(), False) + test.assert_equal(String(uri.query_string()), String(empty_string.as_bytes_slice())) + +def test_uri_parse_http_with_path(): + var test = MojoTest("test_uri_parse_http_with_path") + uri = URI("http://example.com/index.html") + _ = uri.parse() + test.assert_equal(String(uri.scheme()), "http") + test.assert_equal(String(uri.host()), "example.com") + test.assert_equal(String(uri.path()), "/index.html") + test.assert_equal(String(uri.path_original()), "/index.html") + test.assert_equal(String(uri.request_uri()), "/index.html") + test.assert_equal(uri.is_https(), False) + test.assert_equal(uri.is_http(), True) + test.assert_equal(String(uri.query_string()), String(empty_string.as_bytes_slice())) + +def test_uri_parse_https_with_path(): + var test = MojoTest("test_uri_parse_https_with_path") + uri = URI("https://example.com/index.html") + _ = uri.parse() + test.assert_equal(String(uri.scheme()), "https") + test.assert_equal(String(uri.host()), "example.com") + test.assert_equal(String(uri.path()), "/index.html") + test.assert_equal(String(uri.path_original()), "/index.html") + test.assert_equal(String(uri.request_uri()), "/index.html") + test.assert_equal(uri.is_https(), True) + test.assert_equal(uri.is_http(), False) + test.assert_equal(String(uri.query_string()), String(empty_string.as_bytes_slice())) + +def test_uri_parse_http_basic(): + var test = MojoTest("test_uri_parse_http_basic") + uri = URI("http://example.com") + _ = uri.parse() + test.assert_equal(String(uri.scheme()), "http") + test.assert_equal(String(uri.host()), "example.com") + test.assert_equal(String(uri.path()), "/") + test.assert_equal(String(uri.path_original()), "/") + test.assert_equal(String(uri.http_version()), "HTTP/1.1") + test.assert_equal(String(uri.request_uri()), "/") + test.assert_equal(String(uri.query_string()), String(empty_string.as_bytes_slice())) + +def test_uri_parse_http_basic_www(): + var test = MojoTest("test_uri_parse_http_basic_www") + uri = URI("http://www.example.com") + _ = uri.parse() + test.assert_equal(String(uri.scheme()), "http") + test.assert_equal(String(uri.host()), "www.example.com") + test.assert_equal(String(uri.path()), "/") + test.assert_equal(String(uri.path_original()), "/") + test.assert_equal(String(uri.request_uri()), "/") + test.assert_equal(String(uri.http_version()), "HTTP/1.1") + test.assert_equal(String(uri.query_string()), String(empty_string.as_bytes_slice())) + +def test_uri_parse_http_with_query_string(): + ... + +def test_uri_parse_http_with_hash(): + ... + +def test_uri_parse_http_with_query_string_and_hash(): + ... + + diff --git a/lightbug_http/tests/utils.mojo b/tests/utils.mojo similarity index 96% rename from lightbug_http/tests/utils.mojo rename to tests/utils.mojo index 8daec149..b255f315 100644 --- a/lightbug_http/tests/utils.mojo +++ b/tests/utils.mojo @@ -7,24 +7,23 @@ from lightbug_http.net import Listener, Addr, Connection, TCPAddr from lightbug_http.service import HTTPService, OK from lightbug_http.server import ServerTrait from lightbug_http.client import Client - +from lightbug_http.io.bytes import bytes alias default_server_conn_string = "http://localhost:8080" -alias getRequest = String( +alias getRequest = bytes( "GET /foobar?baz HTTP/1.1\r\nHost: google.com\r\nUser-Agent: aaa/bbb/ccc/ddd/eee" " Firefox Chrome MSIE Opera\r\n" + "Referer: http://example.com/aaa?bbb=ccc\r\nCookie: foo=bar; baz=baraz;" " aa=aakslsdweriwereowriewroire\r\n\r\n" -)._buffer +) -alias defaultExpectedGetResponse = String( +alias defaultExpectedGetResponse = bytes( "HTTP/1.1 200 OK\r\nServer: lightbug_http\r\nContent-Type:" " text/plain\r\nContent-Length: 12\r\nConnection: close\r\nDate: \r\n\r\nHello" " world!" ) - @parameter fn new_httpx_client() -> PythonObject: try: @@ -34,11 +33,9 @@ fn new_httpx_client() -> PythonObject: print("Could not set up httpx client: " + e.__str__()) return None - fn new_fake_listener(request_count: Int, request: Bytes) -> FakeListener: return FakeListener(request_count, request) - struct ReqInfo: var full_uri: URI var host: String @@ -49,7 +46,6 @@ struct ReqInfo: self.host = host self.is_tls = is_tls - struct FakeClient(Client): """FakeClient doesn't actually send any requests, but it extracts useful information from the input. """ @@ -78,7 +74,7 @@ struct FakeClient(Client): self.req_is_tls = False fn do(self, req: HTTPRequest) raises -> HTTPResponse: - return OK(String(defaultExpectedGetResponse)._buffer) + return OK(String(defaultExpectedGetResponse)) fn extract(inout self, req: HTTPRequest) raises -> ReqInfo: var full_uri = req.uri() @@ -101,7 +97,6 @@ struct FakeClient(Client): return ReqInfo(full_uri, host, is_tls) - struct FakeServer(ServerTrait): var __listener: FakeListener var __handler: FakeResponder @@ -132,15 +127,13 @@ struct FakeServer(ServerTrait): fn serve(self, ln: Listener, handler: HTTPService) raises -> None: ... - @value struct FakeResponder(HTTPService): fn func(self, req: HTTPRequest) raises -> HTTPResponse: var method = String(req.header.method()) if method != "GET": raise Error("Did not expect a non-GET request! Got: " + method) - return OK(String("Hello, world!")._buffer) - + return OK(bytes("Hello, world!")) @value struct FakeConnection(Connection): @@ -165,7 +158,6 @@ struct FakeConnection(Connection): fn remote_addr(self) raises -> TCPAddr: return TCPAddr() - @value struct FakeListener: var request_count: Int @@ -197,7 +189,6 @@ struct FakeListener: fn addr(self) -> TCPAddr: return TCPAddr() - @value struct TestStruct: var a: String @@ -209,7 +200,7 @@ struct TestStruct: fn __init__(inout self, a: String, b: String) -> None: self.a = a self.b = b - self.c = String("c")._buffer + self.c = bytes("c") self.d = 1 self.e = TestStructNested("a", 1) @@ -220,7 +211,6 @@ struct TestStruct: fn set_a_copy(self, a: String) -> Self: return Self(a, self.b) - @value struct TestStructNested: var a: String