Skip to content

Commit

Permalink
Merge pull request #12 from JamesHallowell/typed-stream-endpoints
Browse files Browse the repository at this point in the history
Add typed stream endpoints
  • Loading branch information
JamesHallowell committed Mar 28, 2024
2 parents c434c2d + 581612d commit 5bc14a5
Show file tree
Hide file tree
Showing 9 changed files with 279 additions and 23 deletions.
4 changes: 2 additions & 2 deletions examples/hello_world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

performer.set_block_size(BLOCK_SIZE);

let output = performer.endpoints().get_handle("out").unwrap();
let mut performer = performer.with_output_stream::<f32>("out")?;

let stream = cpal::default_host()
.default_output_device()
Expand All @@ -89,7 +89,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
},
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
performer.advance();
unsafe { performer.read_stream_unchecked(output, data) };
performer.read_stream(data);
},
|err| eprintln!("an error occurred on stream: {}", err),
None,
Expand Down
5 changes: 4 additions & 1 deletion src/performer/endpoints/input_event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ impl Endpoint<InputEvent<Value>> {

#[sealed]
impl PerformerEndpoint for InputEvent<Value> {
fn make(id: &str, performer: &mut Performer) -> Result<Endpoint<Self>, EndpointError> {
fn make<Streams>(
id: &str,
performer: &mut Performer<Streams>,
) -> Result<Endpoint<Self>, EndpointError> {
let (handle, endpoint) = performer
.endpoints
.get_by_id(id)
Expand Down
5 changes: 4 additions & 1 deletion src/performer/endpoints/input_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ impl<T> PerformerEndpoint for InputValue<T>
where
T: 'static,
{
fn make(id: &str, performer: &mut Performer) -> Result<Endpoint<Self>, EndpointError> {
fn make<Streams>(
id: &str,
performer: &mut Performer<Streams>,
) -> Result<Endpoint<Self>, EndpointError> {
let (handle, endpoint) = performer
.endpoints
.get_by_id(id)
Expand Down
1 change: 1 addition & 0 deletions src/performer/endpoints/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod input_event;
mod input_value;
mod output_value;
mod stream;

/// An endpoint.
pub struct Endpoint<T> {
Expand Down
5 changes: 4 additions & 1 deletion src/performer/endpoints/output_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ impl<T> PerformerEndpoint for OutputValue<T>
where
T: 'static,
{
fn make(id: &str, performer: &mut Performer) -> Result<Endpoint<Self>, EndpointError> {
fn make<Streams>(
id: &str,
performer: &mut Performer<Streams>,
) -> Result<Endpoint<Self>, EndpointError> {
let (handle, endpoint) = performer
.endpoints
.get_by_id(id)
Expand Down
204 changes: 204 additions & 0 deletions src/performer/endpoints/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
use {
crate::{
endpoint::{EndpointDirection, EndpointHandle},
performer::{EndpointError, Performer},
value::types::{Primitive, Type},
},
sealed::sealed,
std::marker::PhantomData,
};

impl<Input> Performer<(Input, ())> {
/// Bind an output stream to the performer.
pub fn with_output_stream<T>(
self,
id: impl AsRef<str>,
) -> Result<Performer<(Input, OutputStream<T>)>, EndpointError>
where
T: StreamType,
{
let (handle, endpoint) = self
.endpoints
.get_by_id(id)
.ok_or(EndpointError::EndpointDoesNotExist)?;

if endpoint.direction() != EndpointDirection::Output {
return Err(EndpointError::DirectionMismatch);
}

let stream_endpoint = endpoint
.as_stream()
.ok_or(EndpointError::EndpointTypeMismatch)?;

match stream_endpoint.ty() {
Type::Primitive(primitive) => {
if T::EXTENT != 1 || &T::ELEMENT != primitive {
return Err(EndpointError::DataTypeMismatch);
}
}
Type::Array(array) => {
if T::EXTENT != array.len() || &Type::Primitive(T::ELEMENT) != array.elem_ty() {
return Err(EndpointError::DataTypeMismatch);
}
}
_ => return Err(EndpointError::EndpointTypeMismatch),
}

let Self {
inner,
endpoints,
inputs,
outputs,
cached_input_values,
streams: (input, ()),
} = self;

Ok(Performer {
inner,
endpoints,
inputs,
outputs,
cached_input_values,
streams: (
input,
OutputStream {
handle,
_marker: PhantomData,
},
),
})
}
}

impl<Output> Performer<((), Output)> {
/// Bind an input stream to the performer.
pub fn with_input_stream<T>(
self,
id: impl AsRef<str>,
) -> Result<Performer<(InputStream<T>, Output)>, EndpointError>
where
T: StreamType,
{
let (handle, endpoint) = self
.endpoints
.get_by_id(id)
.ok_or(EndpointError::EndpointDoesNotExist)?;

if endpoint.direction() != EndpointDirection::Input {
return Err(EndpointError::DirectionMismatch);
}

let stream_endpoint = endpoint
.as_stream()
.ok_or(EndpointError::EndpointTypeMismatch)?;

match stream_endpoint.ty() {
Type::Primitive(primitive) => {
if T::EXTENT != 1 || &T::ELEMENT != primitive {
return Err(EndpointError::DataTypeMismatch);
}
}
Type::Array(array) => {
if T::EXTENT != array.len() || &Type::Primitive(T::ELEMENT) != array.elem_ty() {
return Err(EndpointError::DataTypeMismatch);
}
}
_ => return Err(EndpointError::EndpointTypeMismatch),
}

let Self {
inner,
endpoints,
inputs,
outputs,
cached_input_values,
streams: ((), output),
} = self;

Ok(Performer {
inner,
endpoints,
inputs,
outputs,
cached_input_values,
streams: (
InputStream {
handle,
_marker: PhantomData,
},
output,
),
})
}
}

/// An input stream.
pub struct InputStream<T> {
handle: EndpointHandle,
_marker: PhantomData<T>,
}

/// An output stream.
pub struct OutputStream<T> {
handle: EndpointHandle,
_marker: PhantomData<T>,
}

impl<T, Output> Performer<(InputStream<T>, Output)>
where
T: StreamType,
{
/// Write to the performers input stream.
pub fn write_stream(&mut self, frames: &[T]) {
unsafe { self.write_stream_unchecked(self.streams.0.handle, frames) }
}
}

impl<T, Input> Performer<(Input, OutputStream<T>)>
where
T: StreamType,
{
/// Read from the performers output stream.
pub fn read_stream(&mut self, frames: &mut [T]) {
unsafe { self.read_stream_unchecked(self.streams.1.handle, frames) }
}
}

#[sealed]
pub trait StreamType: Copy {
const ELEMENT: Primitive;
const EXTENT: usize;
}

#[sealed]
impl StreamType for i32 {
const ELEMENT: Primitive = Primitive::Int32;
const EXTENT: usize = 1;
}

#[sealed]
impl StreamType for i64 {
const ELEMENT: Primitive = Primitive::Int64;
const EXTENT: usize = 1;
}

#[sealed]
impl StreamType for f32 {
const ELEMENT: Primitive = Primitive::Float32;
const EXTENT: usize = 1;
}

#[sealed]
impl StreamType for f64 {
const ELEMENT: Primitive = Primitive::Float64;
const EXTENT: usize = 1;
}

#[sealed]
impl<T, const EXTENT: usize> StreamType for [T; EXTENT]
where
T: StreamType,
{
const ELEMENT: Primitive = T::ELEMENT;
const EXTENT: usize = EXTENT;
}
19 changes: 13 additions & 6 deletions src/performer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,39 @@ use {
crate::{
endpoint::{EndpointDirection, EndpointHandle, EndpointType, ProgramEndpoints},
ffi::PerformerPtr,
value::{types::IsScalar, ValueRef},
value::ValueRef,
},
endpoints::CachedInputValues,
sealed::sealed,
std::sync::Arc,
};

/// A Cmajor performer.
pub struct Performer {
pub struct Performer<Streams = ((), ())> {
inner: PerformerPtr,
endpoints: Arc<ProgramEndpoints>,
inputs: Vec<EndpointHandler>,
outputs: Vec<EndpointHandler>,
cached_input_values: CachedInputValues,
streams: Streams,
}

pub(crate) type EndpointHandler = Box<dyn FnMut(&mut PerformerPtr) + Send>;

impl Performer {
impl Performer<((), ())> {
pub(crate) fn new(performer: PerformerPtr, endpoints: Arc<ProgramEndpoints>) -> Self {
Performer {
inner: performer,
endpoints: Arc::clone(&endpoints),
inputs: vec![],
outputs: vec![],
cached_input_values: CachedInputValues::default(),
streams: ((), ()),
}
}
}

impl<Streams> Performer<Streams> {
/// Returns an endpoint of the performer.
pub fn endpoint<T>(&mut self, id: impl AsRef<str>) -> Result<Endpoint<T>, EndpointError>
where
Expand Down Expand Up @@ -86,7 +90,7 @@ impl Performer {
/// given slice.
pub unsafe fn read_stream_unchecked<T>(&mut self, handle: EndpointHandle, frames: &mut [T])
where
T: Copy + IsScalar,
T: Copy,
{
self.inner.copy_output_frames(handle, frames);
}
Expand All @@ -102,7 +106,7 @@ impl Performer {
/// given slice.
pub unsafe fn write_stream_unchecked<T>(&mut self, handle: EndpointHandle, frames: &[T])
where
T: Copy + IsScalar,
T: Copy,
{
self.inner.set_input_frames(handle, frames);
}
Expand Down Expand Up @@ -190,7 +194,10 @@ pub enum EndpointError {
#[doc(hidden)]
#[sealed(pub(crate))]
pub trait PerformerEndpoint {
fn make(id: &str, performer: &mut Performer) -> Result<Endpoint<Self>, EndpointError>
fn make<Streams>(
id: &str,
performer: &mut Performer<Streams>,
) -> Result<Endpoint<Self>, EndpointError>
where
Self: Sized;
}
2 changes: 1 addition & 1 deletion src/value/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ impl_is_primitive!(bool, i32, i64, f32, f64);

/// Implemented for scalar types.
#[sealed]
pub trait IsScalar {}
pub trait IsScalar: IsPrimitive {}

macro_rules! impl_is_scalar {
($($ty:ty),*) => {
Expand Down

0 comments on commit 5bc14a5

Please sign in to comment.