In [1]:
import pdb
import os
import sys
from ctypes import (
    CDLL,
    POINTER,
    Structure,
    byref,
    string_at,
    c_char_p,
    c_int32,
    c_int64,
    c_uint64,
    c_ubyte,
)

from ctypes.util import find_library
from typing import Optional, Union

# Handle to the native agora_allosaurus_rs library. It stays None until
# _get_library() loads it on first use, so it is Optional rather than CDLL.
LIB: Optional[CDLL] = None

class FfiByteBuffer(Structure):
    """C-compatible ``(length, data)`` byte buffer exchanged with the native library.

    Instances are either built on the Python side (see ``_encode_bytes``) and
    handed to the library, or passed by reference so the library can fill in
    ``length`` and ``data``.
    """

    _fields_ = [("length", c_int64), ("data", POINTER(c_ubyte))]


class FfiError(Structure):
    """C-compatible error record passed by reference to every native call.

    ``code`` is a 32-bit integer status and ``message`` a NUL-terminated C
    string, which ctypes exposes as ``bytes`` (or ``None`` when NULL).
    """

    _fields_ = [("code", c_int32), ("message", c_char_p)]


def _decode_bytes(arg: Optional[Union[str, bytes, FfiByteBuffer]]) -> bytes:
    if isinstance(arg, FfiByteBuffer):
        return string_at(arg.data, arg.length)
    if isinstance(arg, memoryview):
        return string_at(arg.obj, arg.nbytes)
    if isinstance(arg, bytearray):
        return arg
    if arg is not None:
        if isinstance(arg, str):
            return arg.encode("utf-8")
    return bytearray()


def _encode_bytes(arg: Optional[Union[str, bytes, bytearray, memoryview, "FfiByteBuffer"]]) -> "FfiByteBuffer":
    """Wrap ``arg`` in an ``FfiByteBuffer`` that can be passed to the native library.

    Behavior by input type:

    - ``FfiByteBuffer`` is returned unchanged.
    - ``None`` and empty inputs yield ``length == 0`` with a NULL ``data`` pointer.
    - ``str`` is UTF-8 encoded and copied.
    - ``bytes`` is copied.
    - ``bytearray`` and writable C-contiguous ``memoryview`` objects are shared
      without copying, so the library reads the caller's memory directly.
    - Read-only or non-contiguous ``memoryview`` objects are copied.

    NOTE(review): keep the returned buffer (and any shared source object)
    referenced until the native call that uses it has returned.
    """
    if isinstance(arg, FfiByteBuffer):
        return arg
    buf = FfiByteBuffer()
    if arg is None:
        return buf
    if isinstance(arg, str):
        arg = arg.encode("utf-8")

    if isinstance(arg, memoryview):
        length = arg.nbytes
        # Share the view itself (not arg.obj) so any slice offset is honoured;
        # tobytes() flattens views that cannot be shared or copied directly.
        shareable = arg.c_contiguous and not arg.readonly
        source = arg if shareable else arg.tobytes()
    elif isinstance(arg, bytearray):
        length = len(arg)
        shareable = True
        source = arg
    else:
        length = len(arg)
        shareable = False
        source = arg

    buf.length = length
    if length > 0:
        array_type = c_ubyte * length
        if shareable:
            buf.data = array_type.from_buffer(source)
        else:
            buf.data = array_type.from_buffer_copy(source)
    return buf


def _load_library(lib_name: str) -> CDLL:
    lib_prefix_mapping = {"win32": ""}
    lib_suffix_mapping = {"darwin": ".dylib", "win32": ".dll"}
    try:
        os_name = sys.platform
        lib_prefix = lib_prefix_mapping.get(os_name, "lib")
        lib_suffix = lib_suffix_mapping.get(os_name, ".so")
        lib_path = os.path.join(
            os.path.dirname(os.getcwd()), f"agora-allosaurus-rs/target/release/{lib_prefix}{lib_name}{lib_suffix}"
        )
        return CDLL(lib_path)
    except KeyError:
        print ("Unknown platform for shared library")
    except OSError:
        print ("Library not loaded from python package")

    lib_path = find_library(lib_name)
    if not lib_path:
        if sys.platform == "darwin":
            ld = os.getenv("DYLD_LIBRARY_PATH")
            lib_path = os.path.join(ld, "liboberon.dylib")
            if os.path.exists(lib_path):
                return CDLL(lib_path)

            ld = os.getenv("DYLD_FALLBACK_LIBRARY_PATH")
            lib_path = os.path.join(ld, "liboberon.dylib")
            if os.path.exists(lib_path):
                return CDLL(lib_path)
        elif sys.platform != "win32":
            ld = os.getenv("LD_LIBRARY_PATH")
            lib_path = os.path.join(ld, "liboberon.so")
            if os.path.exists(lib_path):
                return CDLL(lib_path)

        raise Exception(f"Error loading library: {lib_name}")
    try:
        return CDLL(lib_path)
    except OSError as e:
        raise Exception(f"Error loading library: {lib_name}")


def _get_library() -> CDLL:
    """Return the ``agora_allosaurus_rs`` library, loading it into ``LIB`` on the first call."""
    global LIB
    if LIB is not None:
        return LIB
    LIB = _load_library("agora_allosaurus_rs")
    return LIB

def _get_func(fn_name: str):
    """Return the exported native function ``fn_name`` from the loaded library.

    ctypes raises AttributeError if the library does not export that symbol.
    """
    library = _get_library()
    return getattr(library, fn_name)

def _free_buffer(buffer: FfiByteBuffer):
    """Pass ``buffer`` by reference to the native ``allosaurus_byte_buffer_free`` routine.

    NOTE(review): presumably only valid for buffers the native library filled
    in; confirm before using it on buffers built by ``_encode_bytes``.
    """
    free_fn = _get_func("allosaurus_byte_buffer_free")
    free_fn(byref(buffer))


def _free_string(err: FfiError):
    """Pass ``err`` by reference to the native ``allosaurus_string_free`` routine.

    NOTE(review): presumably releases the native string behind ``err.message``;
    confirm that the Rust side expects a pointer to the whole error struct rather
    than to the char buffer.
    """
    free_fn = _get_func("allosaurus_string_free")
    free_fn(byref(err))


def _free_handle(handle: c_int64, err: FfiError):
    """Invoke the native ``allosaurus_create_proof_free`` routine on ``handle``.

    ``err`` is passed by reference so the library can report failures in it.
    NOTE(review): the symbol refers to a create-proof handle, not the server
    handle produced by ``new_server``; confirm which handles this is meant for.
    """
    _get_func("allosaurus_create_proof_free")(handle, byref(err))

In [2]:
def new_server() -> c_uint64:
    """Create a native Allosaurus server and return its opaque handle.

    Returns:
        The handle wrapped in ``c_uint64``, to pass to the other
        ``allosaurus_*`` server functions.

    Raises:
        Exception: if the library returns a zero handle; the message comes
            from the ``FfiError`` passed to the call.
    """
    err = FfiError()
    lib_fn = _get_func("allosaurus_new_server")
    lib_fn.restype = c_uint64

    handle = lib_fn(byref(err))
    if handle == 0:
        # err.message is a c_char_p field, so ctypes already yields bytes (or
        # None when NULL). string_at(None) would dereference a NULL pointer.
        message = (
            err.message.decode("utf-8", errors="replace")
            if err.message
            else f"allosaurus_new_server failed (code {err.code})"
        )
        raise Exception(message)
    return c_uint64(handle)


server = new_server()
server

c_ulong(4707134412381224962)

In [10]:
def new_user(server) -> bytes:
    """Ask the native library to create a new user for ``server``.

    Args:
        server: the handle returned by ``new_server``.

    Returns:
        The user as a ``bytes`` copy of the buffer filled by the library.

    Raises:
        Exception: if the library reports an error or leaves the buffer unset.
    """
    buffer = FfiByteBuffer()
    err = FfiError()
    lib_fn = _get_func("allosaurus_new_user")
    lib_fn(server, byref(buffer), byref(err))

    # A ctypes Structure is always truthy, so ``if not buffer`` never detected
    # failures. Check the error code and the (NULL-is-falsy) data pointer.
    # NOTE(review): assumes err.code == 0 means success — confirm against the Rust side.
    if err.code != 0 or not buffer.data:
        message = (
            err.message.decode("utf-8", errors="replace")
            if err.message
            else f"allosaurus_new_user failed (code {err.code})"
        )
        raise Exception(message)
    # NOTE(review): the native buffer is copied but not released here; confirm
    # ownership and free it (e.g. allosaurus_byte_buffer_free) if required.
    return _decode_bytes(buffer)


user = new_user(server)
user

b"@k_3\xd2\x12'N\xf2/\x02\xdb\xa8d\tt\xceO\x92\x17\xa0\nf\x80\x97\xe4)\x1b\xff\x049\xac\x00\x89\x1f\xb2\x98\xfd\xb0\xf1&J\x1f-\x0eD5\xe5\x03\x9bE\x99\xc4\xedku\xb5\xbeU\xc6\xd5C\x18]_0\xc8\x81j\x1c\xc5s\x95r\xdf\x1bRC$\x98\x00\x99\x90\xbeZ\x12\x05\xcd\x18s\x97!\xf3\xb0\xc1;\x04\xa1|?\x894|@5\xde\xc5\x1c\x10`\x87d-\xf5\xbc<\xa2\x9d\x83\xa1\xd4l8\xb4m\x87}\xc6\xe1\x18\x106\x0f\xe3]\x93\xde\xf5\x82\x8d\xd8/Q\xf7\xc2\xd2uHJ\x01\xc4\x83v\xa5}/2\x81\x7fY\xaa\xaah~G\xe1\xee\xc0\x05\x0e\xa3\x8b\x95Ls\xde\x11\xb7\xe4)\xe1\xebi\x0co\xdd\xfc\xf9l,\x10\xb4v\xab@\xd9\x1c\x81\xadd \x18\xfe\xd7\x16\xb2\xa8\x91}\xf1\xa5z\xbe)\x84\xb5\xe8\x86k[\xb1\xdf7>\xb6\x16\xbdqe\x16\x0f\xac\x81\x99\x93\xd0\xc6\x80\xfdX?\xa9}\x88\x88:\xe4\x89\x88\xc9\xc8\x9e\x08\xb6\xfa*T\x9f\x90\xfc\xfe\x0c\x87\xe8j.~\xbe9\x82\x11,W\x02"

In [11]:
def server_add(server, user) -> bytes:
    """Add ``user`` to ``server`` via ``allosaurus_server_add``.

    Args:
        server: the handle returned by ``new_server``.
        user: the user bytes (e.g. from ``new_user``); wrapped with
            ``_encode_bytes`` and passed to the library by value.

    Returns:
        A ``bytes`` copy of the buffer the library fills in.

    Raises:
        Exception: if the library reports an error or leaves the buffer unset.
    """
    buffer = FfiByteBuffer()
    err = FfiError()
    lib_fn = _get_func("allosaurus_server_add")
    lib_fn(server, _encode_bytes(user), byref(buffer), byref(err))

    # A ctypes Structure is always truthy, so ``if not buffer`` never detected
    # failures. Check the error code and the (NULL-is-falsy) data pointer.
    # NOTE(review): assumes err.code == 0 means success — confirm against the Rust side.
    if err.code != 0 or not buffer.data:
        message = (
            err.message.decode("utf-8", errors="replace")
            if err.message
            else f"allosaurus_server_add failed (code {err.code})"
        )
        raise Exception(message)
    # NOTE(review): the native buffer is copied but not released here; confirm
    # ownership and free it if required.
    return _decode_bytes(buffer)


server_add(server, user)

b'\xa6\xfc\xe3\xf6N d,w\xf2\xe4\xfd\xc7\xd2\xb9\xc0\x1a\xffn;\xcf:t\x9b\x9b\x1d\xc9\xc6\xbb\xe3\tx3\xbbnx\xfb\xe4XZ\xc5\xe3\x03\xc8\x19\xf1\x9cS'

In [13]:
def server_delete(server, user) -> bytes:
    """Delete ``user`` from ``server`` via ``allosaurus_server_delete``.

    Args:
        server: the handle returned by ``new_server``.
        user: the user bytes; wrapped with ``_encode_bytes`` and passed to the
            library by value.

    Returns:
        A ``bytes`` copy of the buffer the library fills in.

    Raises:
        Exception: if the library reports an error or leaves the buffer unset.
    """
    buffer = FfiByteBuffer()
    err = FfiError()
    lib_fn = _get_func("allosaurus_server_delete")
    lib_fn(server, _encode_bytes(user), byref(buffer), byref(err))

    # A ctypes Structure is always truthy, so ``if not buffer`` never detected
    # failures. Check the error code and the (NULL-is-falsy) data pointer.
    # NOTE(review): assumes err.code == 0 means success — confirm against the Rust side.
    if err.code != 0 or not buffer.data:
        message = (
            err.message.decode("utf-8", errors="replace")
            if err.message
            else f"allosaurus_server_delete failed (code {err.code})"
        )
        raise Exception(message)
    # NOTE(review): the native buffer is copied but not released here; confirm
    # ownership and free it if required.
    return _decode_bytes(buffer)


server_delete(server, user)

b'\xa6\xfc\xe3\xf6N d,w\xf2\xe4\xfd\xc7\xd2\xb9\xc0\x1a\xffn;\xcf:t\x9b\x9b\x1d\xc9\xc6\xbb\xe3\tx3\xbbnx\xfb\xe4XZ\xc5\xe3\x03\xc8\x19\xf1\x9cS'

In [14]:
def server_get_epoch(server) -> int:
    """Return the epoch reported by ``allosaurus_server_get_epoch`` for ``server``.

    Raises:
        Exception: if the library sets a non-zero error code.
    """
    err = FfiError()
    lib_fn = _get_func("allosaurus_server_get_epoch")
    # NOTE(review): restype is left at the ctypes default (C int); if the native
    # function returns a 64-bit value, set lib_fn.restype accordingly.
    epoch = lib_fn(server, byref(err))
    # Without this check a failed call silently returned a meaningless epoch.
    # NOTE(review): assumes err.code == 0 means success — confirm against the Rust side.
    if err.code != 0:
        message = (
            err.message.decode("utf-8", errors="replace")
            if err.message
            else f"allosaurus_server_get_epoch failed (code {err.code})"
        )
        raise Exception(message)
    return epoch


server_get_epoch(server)

3