Skip to content

Commit

Permalink
Merge pull request #11 from JamesHallowell/typesafe-endpoints
Browse files Browse the repository at this point in the history
Add type safety to endpoints
  • Loading branch information
JamesHallowell committed Mar 28, 2024
2 parents 730c488 + 61dae51 commit c434c2d
Show file tree
Hide file tree
Showing 17 changed files with 1,399 additions and 295 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
*.dylib
/cmaj
.env
cmake-build-*
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ static = ["dep:cmake"]
bytes = "1.5.0"
dotenvy = "0.15.7"
libloading = "0.8.0"
real-time = { git = "https://github.com/JamesHallowell/real-time", branch = "master" }
sealed = "0.5"
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
smallvec = "1.11.1"
smallvec = { version = "1.11.1", features = ["serde"] }
thiserror = "1.0.48"

[dev-dependencies]
Expand Down
63 changes: 43 additions & 20 deletions src/endpoint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl PartialEq<str> for EndpointId {

/// A handle used to reference an endpoint.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, Hash, PartialEq)]
pub struct EndpointHandle(u32);
pub struct EndpointHandle(pub(crate) u32);

impl From<u32> for EndpointHandle {
fn from(handle: u32) -> Self {
Expand All @@ -49,7 +49,7 @@ impl From<EndpointHandle> for u32 {

/// An endpoint.
#[derive(Debug)]
pub enum Endpoint {
pub enum EndpointType {
/// A stream endpoint.
Stream(StreamEndpoint),

Expand Down Expand Up @@ -79,7 +79,7 @@ pub struct StreamEndpoint {
annotation: Annotation,
}

impl From<StreamEndpoint> for Endpoint {
impl From<StreamEndpoint> for EndpointType {
fn from(endpoint: StreamEndpoint) -> Self {
Self::Stream(endpoint)
}
Expand All @@ -94,7 +94,7 @@ pub struct EventEndpoint {
annotation: Annotation,
}

impl From<EventEndpoint> for Endpoint {
impl From<EventEndpoint> for EndpointType {
fn from(endpoint: EventEndpoint) -> Self {
Self::Event(endpoint)
}
Expand All @@ -109,13 +109,13 @@ pub struct ValueEndpoint {
annotation: Annotation,
}

impl From<ValueEndpoint> for Endpoint {
impl From<ValueEndpoint> for EndpointType {
fn from(endpoint: ValueEndpoint) -> Self {
Self::Value(endpoint)
}
}

impl Endpoint {
impl EndpointType {
/// The endpoint's identifier (or name).
pub fn id(&self) -> &EndpointId {
match self {
Expand All @@ -142,6 +142,30 @@ impl Endpoint {
Self::Value(endpoint) => &endpoint.annotation,
}
}

/// Get the endpoint as a value endpoint.
pub fn as_stream(&self) -> Option<&StreamEndpoint> {
match self {
Self::Stream(endpoint) => Some(endpoint),
_ => None,
}
}

/// Get the endpoint as an event endpoint.
pub fn as_event(&self) -> Option<&EventEndpoint> {
match self {
Self::Event(endpoint) => Some(endpoint),
_ => None,
}
}

/// Get the endpoint as a value endpoint.
pub fn as_value(&self) -> Option<&ValueEndpoint> {
match self {
Self::Value(endpoint) => Some(endpoint),
_ => None,
}
}
}

impl ValueEndpoint {
Expand Down Expand Up @@ -257,41 +281,40 @@ impl EventEndpoint {
self.ty
.iter()
.position(|t| t.as_ref() == ty)
.map(|index| index as u32)
.map(EndpointTypeIndex::from)
}

/// The type at the given index in the endpoint's type list.
pub fn get_type(&self, index: EndpointTypeIndex) -> Option<&Type> {
self.ty.get(index.0 as usize)
self.ty.get(usize::from(index))
}
}

/// An index into an event endpoint's type list.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct EndpointTypeIndex(u32);
pub struct EndpointTypeIndex(usize);

impl From<u32> for EndpointTypeIndex {
fn from(index: u32) -> Self {
impl From<usize> for EndpointTypeIndex {
fn from(index: usize) -> Self {
Self(index)
}
}

impl From<EndpointTypeIndex> for u32 {
impl From<EndpointTypeIndex> for usize {
fn from(index: EndpointTypeIndex) -> Self {
index.0
}
}

/// A collection of endpoints.
#[derive(Debug)]
pub struct Endpoints {
endpoints: HashMap<EndpointHandle, Endpoint>,
pub struct ProgramEndpoints {
endpoints: HashMap<EndpointHandle, EndpointType>,
ids: HashMap<EndpointId, EndpointHandle>,
}

impl Endpoints {
pub(crate) fn new(endpoints: impl IntoIterator<Item = (EndpointHandle, Endpoint)>) -> Self {
impl ProgramEndpoints {
pub(crate) fn new(endpoints: impl IntoIterator<Item = (EndpointHandle, EndpointType)>) -> Self {
let endpoints: HashMap<_, _> = endpoints.into_iter().collect();
let ids = endpoints
.iter()
Expand All @@ -302,26 +325,26 @@ impl Endpoints {
}

/// Get an iterator over the input endpoints.
pub fn inputs(&self) -> impl Iterator<Item = &Endpoint> {
pub fn inputs(&self) -> impl Iterator<Item = &EndpointType> {
self.endpoints
.values()
.filter(|endpoint| endpoint.direction() == EndpointDirection::Input)
}

/// Get an interator over the output endpoints.
pub fn outputs(&self) -> impl Iterator<Item = &Endpoint> {
pub fn outputs(&self) -> impl Iterator<Item = &EndpointType> {
self.endpoints
.values()
.filter(|endpoint| endpoint.direction() == EndpointDirection::Output)
}

/// Get an endpoint by its handle.
pub fn get(&self, handle: EndpointHandle) -> Option<&Endpoint> {
pub fn get(&self, handle: EndpointHandle) -> Option<&EndpointType> {
self.endpoints.get(&handle)
}

/// Get an endpoint by its ID.
pub fn get_by_id(&self, id: impl AsRef<str>) -> Option<(EndpointHandle, &Endpoint)> {
pub fn get_by_id(&self, id: impl AsRef<str>) -> Option<(EndpointHandle, &EndpointType)> {
let handle = self.ids.get(id.as_ref()).copied()?;
self.endpoints
.get(&handle)
Expand Down
6 changes: 3 additions & 3 deletions src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod program_details;
pub use annotation::Annotation;
use {
crate::{
endpoint::{EndpointHandle, Endpoints},
endpoint::{EndpointHandle, ProgramEndpoints},
engine::program_details::ProgramDetails,
ffi::EnginePtr,
performer::Performer,
Expand Down Expand Up @@ -131,7 +131,7 @@ pub struct Loaded;
#[doc(hidden)]
#[derive(Debug)]
pub struct Linked {
endpoints: Arc<Endpoints>,
endpoints: Arc<ProgramEndpoints>,
}

impl Engine<Idle> {
Expand Down Expand Up @@ -188,7 +188,7 @@ impl Engine<Loaded> {
.map(|handle| (handle, endpoint))
});

let endpoints = Endpoints::new(endpoints);
let endpoints = ProgramEndpoints::new(endpoints);

match self.inner.link() {
Ok(_) => {
Expand Down
21 changes: 11 additions & 10 deletions src/engine/program_details.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use {
crate::{
endpoint::{
Endpoint, EndpointDirection, EndpointId, EventEndpoint, StreamEndpoint, ValueEndpoint,
EndpointDirection, EndpointId, EndpointType, EventEndpoint, StreamEndpoint,
ValueEndpoint,
},
engine::program_details::ParseEndpointError::UnsupportedType,
value::types::{Array, Object, Primitive, Type},
Expand All @@ -23,7 +24,7 @@ pub struct ProgramDetails {
}

impl ProgramDetails {
pub fn endpoints(&self) -> impl Iterator<Item = Endpoint> + '_ {
pub fn endpoints(&self) -> impl Iterator<Item = EndpointType> + '_ {
let inputs = self.inputs.iter().zip(repeat(EndpointDirection::Input));
let outputs = self.outputs.iter().zip(repeat(EndpointDirection::Output));

Expand All @@ -45,7 +46,7 @@ struct EndpointDetails {
id: EndpointId,

#[serde(rename = "endpointType")]
endpoint_type: EndpointType,
endpoint_type: EndpointVariant,

#[serde(
rename = "dataType",
Expand All @@ -62,7 +63,7 @@ struct EndpointDetails {
}

#[derive(Debug, Copy, Clone, Deserialize, PartialEq)]
enum EndpointType {
enum EndpointVariant {
#[serde(rename = "stream")]
Stream,

Expand Down Expand Up @@ -201,21 +202,21 @@ fn try_make_endpoint(
..
}: &EndpointDetails,
direction: EndpointDirection,
) -> Result<Endpoint, ParseEndpointError> {
) -> Result<EndpointType, ParseEndpointError> {
let annotation = annotation.clone().unwrap_or_default().into();

Ok(match endpoint_type {
EndpointType::Stream => {
EndpointVariant::Stream => {
if value_type.len() != 1 {
return Err(ParseEndpointError::UnexpectedNumberOfTypes);
}

StreamEndpoint::new(id.clone(), direction, value_type[0].clone(), annotation).into()
}
EndpointType::Event => {
EndpointVariant::Event => {
EventEndpoint::new(id.clone(), direction, value_type.clone(), annotation).into()
}
EndpointType::Value => {
EndpointVariant::Value => {
if value_type.len() != 1 {
return Err(ParseEndpointError::UnexpectedNumberOfTypes);
}
Expand Down Expand Up @@ -286,7 +287,7 @@ mod test {
let details: EndpointDetails = serde_json::from_str(json).unwrap();

assert_eq!(details.id.as_ref(), "out");
assert_eq!(details.endpoint_type, EndpointType::Stream);
assert_eq!(details.endpoint_type, EndpointVariant::Stream);
assert_eq!(
details.value_type,
vec![Type::Primitive(Primitive::Float32)]
Expand All @@ -313,7 +314,7 @@ mod test {
let details: EndpointDetails = serde_json::from_str(json).unwrap();

assert_eq!(details.id.as_ref(), "out");
assert_eq!(details.endpoint_type, EndpointType::Event);
assert_eq!(details.endpoint_type, EndpointVariant::Event);
assert_eq!(
details.value_type,
vec![
Expand Down
4 changes: 2 additions & 2 deletions src/ffi/performer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl PerformerPtr {
((*(*self.performer).vtable).add_input_event)(
self.performer,
handle.into(),
type_index.into(),
usize::from(type_index) as u32,
data_ptr,
)
};
Expand Down Expand Up @@ -158,7 +158,7 @@ impl PerformerPtr {
(*callback)(
frame_offset as usize,
endpoint.into(),
type_index.into(),
(type_index as usize).into(),
data,
);
});
Expand Down
27 changes: 27 additions & 0 deletions src/performer/atomic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};

#[derive(Default)]
pub struct AtomicF32(AtomicU32);

impl AtomicF32 {
pub fn load(&self, order: Ordering) -> f32 {
f32::from_bits(self.0.load(order))
}

pub fn store(&self, value: f32, order: Ordering) {
self.0.store(value.to_bits(), order);
}
}

#[derive(Default)]
pub struct AtomicF64(AtomicU64);

impl AtomicF64 {
pub fn load(&self, order: Ordering) -> f64 {
f64::from_bits(self.0.load(order))
}

pub fn store(&self, value: f64, order: Ordering) {
self.0.store(value.to_bits(), order);
}
}

0 comments on commit c434c2d

Please sign in to comment.