Skip to content

Commit

Permalink
Reorganized Python interface and added Jax (#96)
Browse files Browse the repository at this point in the history
* Reorganized Python interface and added Jax
* Allowed string parameter and added tests
* Fixed import
* Improved docs
  • Loading branch information
Jegp committed Jan 8, 2024
1 parent 2498223 commit 2902b01
Show file tree
Hide file tree
Showing 14 changed files with 133 additions and 130 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ jobs:
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v3
# - uses: Jimver/cuda-toolkit@v0.2.8
# id: cuda-toolkit
# with:
# cuda: '11.7.0'
- name: Install dependencies
run: |
sudo apt install libsdl2-dev
Expand Down
16 changes: 7 additions & 9 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
import sys

extensions = [
"sphinx.ext.autodoc", # Imports modules and docs
"sphinx.ext.intersphinx", # Links to external libs docs
"sphinx.ext.napoleon", # Converts docs to rst format
"sphinx.ext.autodoc", # Imports modules and docs
"sphinx.ext.intersphinx", # Links to external libs docs
"sphinx.ext.napoleon", # Converts docs to rst format
"sphinx.ext.autosummary",
"sphinx_copybutton",
"myst_parser",
"sphinx.ext.napoleon"
"sphinx.ext.napoleon",
]

intersphinx_mapping = {'python': ('https://docs.python.org/3', None)}
intersphinx_mapping = {"python": ("https://docs.python.org/3", None)}

autosummary_generate = True

Expand All @@ -33,7 +33,7 @@
# The short X.Y version.
version = "0.6"
# The full version, including alpha/beta/rc tags.
release = "0.6.1"
release = "0.6.2"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
Expand All @@ -42,12 +42,10 @@
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"


# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False


html_theme = "furo"
html_theme = 'furo'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
Expand Down
11 changes: 11 additions & 0 deletions docs/python_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,14 @@ with SpeckInput() as stream:
```

> Example: [Visualize events from the Speck chip](https://github.com/aestream/aestream/blob/main/example/speck_video.py)

## Using different backends

AEStream is built with [Nanobind](https://nanobind.readthedocs.io/), which can directly [expose arrays in various memory formats](https://nanobind.readthedocs.io/), including [PyTorch](https://pytorch.org), [Numpy](https://numpy.org), and [Jax](https://jax.readthedocs.io).
You can directly decide which backend to use by passing a `backend` argument to the `read` function:

```python
with FileInput(...) as stream:
...
stream.read(backend="torch")
```
19 changes: 5 additions & 14 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
[build-system]
requires = [
"scikit-build-core",
"nanobind>=1.2",
]
requires = ["scikit-build-core", "nanobind>=1.2"]
build-backend = "scikit_build_core.build"

[project]
name = "aestream"
version = "0.6.1"
version = "0.6.2"
description = "Streaming library for Address-Event Representation (AER) data"
readme = "README.md"
requires-python = ">=3.8"
authors = [
{ name = "Jens E. Pedersen", email = "jens@jepedersen.dk" },
]
authors = [{ name = "Jens E. Pedersen", email = "jens@jepedersen.dk" }]
license = { text = "MIT" }
classifiers = [
"License :: OSI Approved :: MIT License",
Expand All @@ -22,11 +17,7 @@ classifiers = [
"Topic :: Software Development :: Libraries",
"Topic :: System :: Hardware :: Universal Serial Bus (USB)",
]
dependencies = [
"numpy",
"nanobind>=1.2",
"pysdl2-dll"
]
dependencies = ["numpy", "nanobind>=1.2", "pysdl2-dll"]


[project.urls]
Expand All @@ -43,4 +34,4 @@ build-verbosity = 1

# Needed for full C++17 support
[tool.cibuildwheel.macos.environment]
MACOSX_DEPLOYMENT_TARGET = "10.14"
MACOSX_DEPLOYMENT_TARGET = "10.14"
7 changes: 1 addition & 6 deletions src/aestream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,8 @@
"""
import logging

try:
import torch
except ImportError:
logging.debug("Failed to import Torch: AEStream is running in Numpy mode")

# Import AEStream modules
from aestream.aestream_ext import Event
from aestream.aestream_ext import Backend, Event
from aestream._input import FileInput, UDPInput

modules = []
Expand Down
135 changes: 61 additions & 74 deletions src/aestream/_input.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Optional, Union
from aestream import aestream_ext as ext

# Set Numpy Event dtype
Expand All @@ -20,106 +21,92 @@
except ImportError as e:
raise ImportError("Numpy is required but could not be imported", e)

try:
import torch

USE_TORCH = True
except ImportError:
USE_TORCH = False
def _convert_parameter_to_backend(backend: Union[ext.Backend, str]):
    """Normalize a backend specifier to an ``ext.Backend`` enum member.

    Parameters:
        backend: Either an ``ext.Backend`` member (returned unchanged) or the
            member's name as a string (e.g. ``"Numpy"``, ``"Torch"``,
            ``"Jax"``, ``"GeNN"``).

    Returns:
        The corresponding ``ext.Backend`` member.

    Raises:
        ValueError: If the string does not name a known backend.
        TypeError: If ``backend`` is neither an ``ext.Backend`` nor a ``str``.
    """
    if isinstance(backend, ext.Backend):
        return backend
    if isinstance(backend, str):
        try:
            return getattr(ext.Backend, backend)
        except AttributeError:
            # Surface a helpful error instead of a bare AttributeError from getattr
            raise ValueError(
                f"Unknown backend {backend!r}; expected one of "
                + ", ".join(member.name for member in ext.Backend)
            ) from None
    raise TypeError("backend must be either ext.Backend or str")


def _read_backend(obj: Any, backend: ext.Backend, population: Optional[Any]):
    """Read pending events from ``obj`` and return them as a tensor of the
    requested backend.

    Parameters:
        obj: A stream object exposing ``read_buffer()`` (and ``read_genn`` for
            the GeNN backend).
        backend: Target backend, as an ``ext.Backend`` member or its string name.
        population: GeNN population to read into; required for (and only used
            by) the GeNN backend.

    Returns:
        The events as a Jax/Torch/Numpy tensor, or the GeNN population's
        ``input`` extra-global-parameter view.

    Raises:
        ValueError: If the GeNN backend is requested without a population.
    """
    backend = _convert_parameter_to_backend(backend)
    if backend == ext.Backend.GeNN:
        if population is None:
            # All read() wrappers default population to None; fail loudly
            # instead of crashing with AttributeError on None below.
            raise ValueError("The GeNN backend requires a population argument")
        # Read from stream into GeNN-owned memory
        obj.read_genn(population.extra_global_params["input"].view)
        # Copy data to device (may be a NOP on the CPU backend)
        population.push_extra_global_param_to_device("input")
        return population.extra_global_params["input"].view
    t = obj.read_buffer()
    if backend == ext.Backend.Jax:
        return t.to_jax()
    if backend == ext.Backend.Torch:
        return t.to_torch()
    # Default: Numpy
    return t.to_numpy()


class FileInput(ext.FileInput):
"""
Reads events from a file.
Parameters:
filename (str): Path to file.
shape (tuple): Shape of the camera surface in pixels (X, Y).
device (str): Device name. Defaults to "cpu"
ignore_time (bool): Whether to ignore the timestamps for the events when
streaming. If set to True, the events will be streamed as fast as possible.
Defaults to False.
"""

def load(self):
buffer = self.load_all()
return np.frombuffer(buffer.data, NUMPY_EVENT_DTYPE)

def read(self):
t = self.read_buffer()
if USE_TORCH:
return t.to_torch()
else:
return t.to_numpy()

def read_genn(self, population):
# **YUCK** I would like to implement this with a mixin
# to reduce copy-paste but seemingly nanobind doesn't like this
# Read from stream into GeNN-owned memory
super().read_genn(population.extra_global_params["input"].view)

# Copy data to device
# **NOTE** this may be a NOP if CPU backend is used
population.push_extra_global_param_to_device("input")

return population.extra_global_params["input"].view
def read(self, backend: ext.Backend = ext.Backend.Numpy):
return _read_backend(self, backend, None)


class UDPInput(ext.UDPInput):
def read(self):
t = self.read_buffer()
if USE_TORCH:
return t.to_torch()
else:
return t.to_numpy()

def read_genn(self, population):
# **YUCK** I would like to implement this with a mixin
# to reduce copy-paste but seemingly nanobind doesn't like this
# Read from stream into GeNN-owned memory
super().read_genn(population.extra_global_params["input"].view)

# Copy data to device
# **NOTE** this may be a NOP if CPU backend is used
population.push_extra_global_param_to_device("input")
"""
Reads events from a UDP socket.
return population.extra_global_params["input"].view
Parameters:
shape (tuple): Shape of the camera surface in pixels (X, Y).
device (str): Device name. Defaults to "cpu"
port (int): Port to listen on. Defaults to 3333.
"""

def read(self, backend: ext.Backend = ext.Backend.Numpy):
return _read_backend(self, backend, None)


try:

class USBInput(ext.USBInput):
def read(self):
t = self.read_buffer()
if USE_TORCH:
return t.to_torch()
else:
return t.to_numpy()

def read_genn(self, population):
# **YUCK** I would like to implement this with a mixin
# to reduce copy-paste but seemingly nanobind doesn't like this
# Read from stream into GeNN-owned memory
super().read_genn(population.extra_global_params["input"].view)
"""
Reads events from a USB camera.
# Copy data to device
# **NOTE** this may be a NOP if CPU backend is used
population.push_extra_global_param_to_device("input")
Parameters:
shape (tuple): Shape of the camera surface in pixels (X, Y).
device (str): Device name. Defaults to "cpu"
device_id (int): Device ID. Defaults to 0.
device_address (int): Device address, typically on the bus. Defaults to 0.
"""

return population.extra_global_params["input"].view
def read(self, backend: ext.Backend = ext.Backend.Numpy):
return _read_backend(self, backend, None)

except:
pass # Ignore if drivers are not installed

try:

class SpeckInput(ext.SpeckInput):
def read(self):
t = self.read_buffer()
if USE_TORCH:
return t.to_torch()
else:
return t.to_numpy()

def read_genn(self, population):
# **YUCK** I would like to implement this with a mixin
# to reduce copy-paste but seemingly nanobind doesn't like this
# Read from stream into GeNN-owned memory
super().read_genn(population.extra_global_params["input"].view)

# Copy data to device
# **NOTE** this may be a NOP if CPU backend is used
population.push_extra_global_param_to_device("input")

return population.extra_global_params["input"].view
def read(self, backend: ext.Backend = ext.Backend.Numpy):
return _read_backend(self, backend, None)

except Exception as e:
pass # Ignore if Speck/ZMQ isn't installed
5 changes: 1 addition & 4 deletions src/python/file.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "../cpp/aer.hpp"
#include "../cpp/generator.hpp"
#include "../cpp/input/file.hpp"
#include "types.hpp"

#include "tensor_buffer.hpp"
#include "tensor_iterator.hpp"
Expand Down Expand Up @@ -43,10 +44,6 @@ class FileInput {

nb::ndarray<nb::numpy, uint8_t, nb::shape<1, nb::any>> load();

// py::array_t<AER::Event> events_co();

// Generator<py::array_t<AER::Event>> parts_co(size_t n_events_per_part);

FileInput *start_stream();

bool stop_stream(nb::object &a, nb::object &b, nb::object &c);
Expand Down
7 changes: 0 additions & 7 deletions src/python/input.hpp

This file was deleted.

7 changes: 7 additions & 0 deletions src/python/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ NB_MODULE(aestream_ext, m) {
.def_rw("y", &AER::Event::y)
.def_rw("polarity", &AER::Event::polarity);

nb::enum_<Backend>(m, "Backend")
.value("GeNN", Backend::GeNN)
.value("Jax", Backend::Jax)
.value("Numpy", Backend::Numpy)
.value("Torch", Backend::Torch);

nb::class_<BufferPointer>(m, "BufferPointer")
.def("to_jax", &BufferPointer::to_jax)
.def("to_numpy", &BufferPointer::to_numpy)
.def("to_torch", &BufferPointer::to_torch);

Expand Down
30 changes: 19 additions & 11 deletions src/python/tensor_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,8 @@ BufferPointer::BufferPointer(buffer_t data, const std::vector<size_t> &shape,
const std::string &device)
: data(std::move(data)), shape(shape), device(device) {}

tensor_numpy BufferPointer::to_numpy() {
const size_t s[2] = {shape[0], shape[1]};
float *ptr = data.release();
nb::capsule owner(ptr, [](void *p) noexcept { delete[](float *) p; });
return tensor_numpy(ptr, 2, s, owner);
}

tensor_torch BufferPointer::to_torch() {
template <typename tensor_type>
inline tensor_type BufferPointer::to_tensor_type() {
const size_t s[2] = {shape[0], shape[1]};
float *ptr = data.release();
nb::capsule owner;
Expand All @@ -169,7 +163,21 @@ tensor_torch BufferPointer::to_torch() {

int32_t device_type =
device == "cuda" ? nb::device::cuda::value : nb::device::cpu::value;
return tensor_torch(ptr, 2, s, owner, /* owner */
nullptr, /* strides */
nanobind::dtype<float>(), device_type);
return tensor_type(ptr, 2, s, owner, /* owner */
nullptr, /* strides */
nanobind::dtype<float>(), device_type);
}

// Convert the owned buffer to a Numpy-backed nanobind tensor.
// Ownership of the data is released to the returned tensor via to_tensor_type.
tensor_numpy BufferPointer::to_numpy() { return to_tensor_type<tensor_numpy>(); }

tensor_jax BufferPointer::to_jax() { return to_tensor_type<tensor_jax>(); }

// Convert the owned buffer to a PyTorch-backed nanobind tensor; see to_tensor_type.
tensor_torch BufferPointer::to_torch() {
  return to_tensor_type<tensor_torch>();
}
Loading

0 comments on commit 2902b01

Please sign in to comment.