Skip to content

Commit

Permalink
Merge pull request #9 from JamesHallowell/type-rework
Browse files Browse the repository at this point in the history
Add separate primitive enum
  • Loading branch information
JamesHallowell committed Mar 16, 2024
2 parents 63d61bc + 02f8f12 commit 00937df
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 86 deletions.
27 changes: 18 additions & 9 deletions src/engine/program_details.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use {
Endpoint, EndpointDirection, EndpointId, EventEndpoint, StreamEndpoint, ValueEndpoint,
},
engine::program_details::ParseEndpointError::UnsupportedType,
value::types::{Array, Object, Type},
value::types::{Array, Object, Primitive, Type},
},
serde::{
de::{value::MapAccessDeserializer, Visitor},
Expand Down Expand Up @@ -161,12 +161,12 @@ impl TryFrom<&EndpointDataType> for Type {
}: &EndpointDataType,
) -> Result<Self, Self::Error> {
match *ty {
ValueType::Void => Ok(Type::Void),
ValueType::Bool => Ok(Type::Bool),
ValueType::Int32 => Ok(Type::Int32),
ValueType::Int64 => Ok(Type::Int64),
ValueType::Float32 => Ok(Type::Float32),
ValueType::Float64 => Ok(Type::Float64),
ValueType::Void => Ok(Type::Primitive(Primitive::Void)),
ValueType::Bool => Ok(Type::Primitive(Primitive::Bool)),
ValueType::Int32 => Ok(Type::Primitive(Primitive::Int32)),
ValueType::Int64 => Ok(Type::Primitive(Primitive::Int64)),
ValueType::Float32 => Ok(Type::Primitive(Primitive::Float32)),
ValueType::Float64 => Ok(Type::Primitive(Primitive::Float64)),
ValueType::Object => {
let mut object = Object::new();
for (name, value) in members.as_ref().ok_or(Self::Error::StructHasNoMembers)? {
Expand Down Expand Up @@ -287,7 +287,10 @@ mod test {

assert_eq!(details.id.as_ref(), "out");
assert_eq!(details.endpoint_type, EndpointType::Stream);
assert_eq!(details.value_type, vec![Type::Float32]);
assert_eq!(
details.value_type,
vec![Type::Primitive(Primitive::Float32)]
);
}

#[test]
Expand All @@ -311,6 +314,12 @@ mod test {

assert_eq!(details.id.as_ref(), "out");
assert_eq!(details.endpoint_type, EndpointType::Event);
assert_eq!(details.value_type, vec![Type::Float32, Type::Int32]);
assert_eq!(
details.value_type,
vec![
Type::Primitive(Primitive::Float32),
Type::Primitive(Primitive::Int32)
]
);
}
}
6 changes: 3 additions & 3 deletions src/performer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use {
crate::{
endpoint::{Endpoint, EndpointDirection, EndpointHandle, Endpoints},
ffi::PerformerPtr,
value::{Value, ValueRef},
value::{types::IsScalar, Value, ValueRef},
},
std::sync::Arc,
};
Expand Down Expand Up @@ -115,7 +115,7 @@ impl Performer {
/// given slice.
pub unsafe fn read_stream_unchecked<T>(&mut self, handle: EndpointHandle, frames: &mut [T])
where
T: Copy,
T: Copy + IsScalar,
{
self.inner.copy_output_frames(handle, frames);
}
Expand All @@ -131,7 +131,7 @@ impl Performer {
/// given slice.
pub unsafe fn write_stream_unchecked<T>(&mut self, handle: EndpointHandle, frames: &[T])
where
T: Copy,
T: Copy + IsScalar,
{
self.inner.set_input_frames(handle, frames);
}
Expand Down
86 changes: 47 additions & 39 deletions src/value/types.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
//! Types of Cmajor values.

use smallvec::SmallVec;
use {crate::value::types::sealed::Sealed, smallvec::SmallVec};

/// The type of a Cmajor value.
/// A Cmajor type.
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
/// A void type.
Void,

/// A boolean type.
Bool,

/// A 32-bit signed integer type.
Int32,

/// A 64-bit signed integer type.
Int64,

/// A 32-bit floating-point type.
Float32,

/// A 64-bit floating-point type.
Float64,
/// A primitive type.
Primitive(Primitive),

/// An array type.
Array(Box<Array>),
Expand All @@ -30,9 +15,9 @@ pub enum Type {
Object(Box<Object>),
}

/// A reference to a Cmajor [`Type`].
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TypeRef<'a> {
/// A Cmajor primitive.
pub enum Primitive {
/// A void type.
Void,

Expand All @@ -50,6 +35,13 @@ pub enum TypeRef<'a> {

/// A 64-bit floating-point type.
Float64,
}

/// A reference to a Cmajor [`Type`].
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TypeRef<'a> {
/// A primitive type.
Primitive(Primitive),

/// An array type.
Array(&'a Array),
Expand Down Expand Up @@ -87,12 +79,7 @@ impl Type {
/// Get a reference to the type.
pub fn as_ref(&self) -> TypeRef<'_> {
match self {
Type::Void => TypeRef::Void,
Type::Bool => TypeRef::Bool,
Type::Int32 => TypeRef::Int32,
Type::Int64 => TypeRef::Int64,
Type::Float32 => TypeRef::Float32,
Type::Float64 => TypeRef::Float64,
Type::Primitive(primitive) => TypeRef::Primitive(*primitive),
Type::Array(array) => TypeRef::Array(array.as_ref()),
Type::Object(object) => TypeRef::Object(object.as_ref()),
}
Expand All @@ -103,12 +90,12 @@ impl TypeRef<'_> {
/// The size of the type in bytes.
pub fn size(&self) -> usize {
match self {
TypeRef::Void => 0,
TypeRef::Bool => 4,
TypeRef::Int32 => 4,
TypeRef::Int64 => 8,
TypeRef::Float32 => 4,
TypeRef::Float64 => 8,
TypeRef::Primitive(Primitive::Void) => 0,
TypeRef::Primitive(Primitive::Bool) => 4,
TypeRef::Primitive(Primitive::Int32) => 4,
TypeRef::Primitive(Primitive::Int64) => 8,
TypeRef::Primitive(Primitive::Float32) => 4,
TypeRef::Primitive(Primitive::Float64) => 8,
TypeRef::Array(array) => array.size(),
TypeRef::Object(object) => object.size(),
}
Expand All @@ -117,12 +104,7 @@ impl TypeRef<'_> {
/// Convert the type reference into an owned [`Type`].
pub fn to_owned(&self) -> Type {
match *self {
TypeRef::Void => Type::Void,
TypeRef::Bool => Type::Bool,
TypeRef::Int32 => Type::Int32,
TypeRef::Int64 => Type::Int64,
TypeRef::Float32 => Type::Float32,
TypeRef::Float64 => Type::Float64,
TypeRef::Primitive(primitive) => Type::Primitive(primitive),
TypeRef::Array(array) => Type::Array(Box::new(array.clone())),
TypeRef::Object(object) => Type::Object(Box::new(object.clone())),
}
Expand Down Expand Up @@ -213,3 +195,29 @@ impl Field {
&self.ty
}
}

/// Implemented for primitive types.
pub trait IsPrimitive: Sealed {}

impl IsPrimitive for bool {}
impl IsPrimitive for i32 {}
impl IsPrimitive for i64 {}
impl IsPrimitive for f32 {}
impl IsPrimitive for f64 {}

/// Implemented for scalar types.
pub trait IsScalar: Sealed {}

impl IsScalar for i32 {}
impl IsScalar for i64 {}
impl IsScalar for f32 {}
impl IsScalar for f64 {}

mod sealed {
pub trait Sealed {}
impl Sealed for bool {}
impl Sealed for i32 {}
impl Sealed for i64 {}
impl Sealed for f32 {}
impl Sealed for f64 {}
}
63 changes: 33 additions & 30 deletions src/value/values.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use {
crate::value::types::{Array, Object, Type, TypeRef},
crate::value::types::{Array, Object, Primitive, Type, TypeRef},
bytes::Buf,
smallvec::SmallVec,
};
Expand Down Expand Up @@ -92,12 +92,12 @@ impl Value {
/// Get the type of the value.
pub fn ty(&self) -> TypeRef<'_> {
match self {
Self::Void => TypeRef::Void,
Self::Bool(_) => TypeRef::Bool,
Self::Int32(_) => TypeRef::Int32,
Self::Int64(_) => TypeRef::Int64,
Self::Float32(_) => TypeRef::Float32,
Self::Float64(_) => TypeRef::Float64,
Self::Void => TypeRef::Primitive(Primitive::Void),
Self::Bool(_) => TypeRef::Primitive(Primitive::Bool),
Self::Int32(_) => TypeRef::Primitive(Primitive::Int32),
Self::Int64(_) => TypeRef::Primitive(Primitive::Int64),
Self::Float32(_) => TypeRef::Primitive(Primitive::Float32),
Self::Float64(_) => TypeRef::Primitive(Primitive::Float64),
Self::Array(array) => TypeRef::Array(&array.ty),
Self::Object(object) => TypeRef::Object(&object.ty),
}
Expand Down Expand Up @@ -128,12 +128,12 @@ impl<'a> ValueRef<'a> {
'b: 'a,
{
match ty {
TypeRef::Void => Self::Void,
TypeRef::Bool => Self::Bool(data.get_u32_ne() != 0),
TypeRef::Int32 => Self::Int32(data.get_i32_ne()),
TypeRef::Int64 => Self::Int64(data.get_i64_ne()),
TypeRef::Float32 => Self::Float32(data.get_f32_ne()),
TypeRef::Float64 => Self::Float64(data.get_f64_ne()),
TypeRef::Primitive(Primitive::Void) => Self::Void,
TypeRef::Primitive(Primitive::Bool) => Self::Bool(data.get_u32_ne() != 0),
TypeRef::Primitive(Primitive::Int32) => Self::Int32(data.get_i32_ne()),
TypeRef::Primitive(Primitive::Int64) => Self::Int64(data.get_i64_ne()),
TypeRef::Primitive(Primitive::Float32) => Self::Float32(data.get_f32_ne()),
TypeRef::Primitive(Primitive::Float64) => Self::Float64(data.get_f64_ne()),
TypeRef::Array(array) => Self::Array(ArrayValueRef::new_from_slice(array, data)),
TypeRef::Object(object) => Self::Object(ObjectValueRef::new_from_slice(object, data)),
}
Expand All @@ -158,12 +158,12 @@ impl<'a> ValueRef<'a> {
/// Get the type of the value.
pub fn ty(&self) -> TypeRef<'_> {
match self {
Self::Void => TypeRef::Void,
Self::Bool(_) => TypeRef::Bool,
Self::Int32(_) => TypeRef::Int32,
Self::Int64(_) => TypeRef::Int64,
Self::Float32(_) => TypeRef::Float32,
Self::Float64(_) => TypeRef::Float64,
Self::Void => TypeRef::Primitive(Primitive::Void),
Self::Bool(_) => TypeRef::Primitive(Primitive::Bool),
Self::Int32(_) => TypeRef::Primitive(Primitive::Int32),
Self::Int64(_) => TypeRef::Primitive(Primitive::Int64),
Self::Float32(_) => TypeRef::Primitive(Primitive::Float32),
Self::Float64(_) => TypeRef::Primitive(Primitive::Float64),
Self::Array(array) => TypeRef::Array(array.ty),
Self::Object(object) => TypeRef::Object(object.ty),
}
Expand Down Expand Up @@ -264,11 +264,11 @@ impl<'a> ArrayValueRef<'a> {
/// # Example
///
/// ```
/// # use cmajor::value::{ArrayValue, types::Type};
/// # use cmajor::value::{ArrayValue, types::{Type, Primitive}};
/// let array: ArrayValue = [1, 2, 3].into();
/// let array_ref = array.as_ref();
///
/// assert_eq!(array_ref.elem_ty(), &Type::Int32);
/// assert_eq!(array_ref.elem_ty(), &Type::Primitive(Primitive::Int32));
pub fn elem_ty(&self) -> &Type {
self.ty.elem_ty()
}
Expand Down Expand Up @@ -427,8 +427,8 @@ pub struct Complex64 {
impl From<Complex32> for Value {
fn from(value: Complex32) -> Self {
let object = Object::new()
.with_field("imag", Type::Float32)
.with_field("real", Type::Float32);
.with_field("imag", Type::Primitive(Primitive::Float32))
.with_field("real", Type::Primitive(Primitive::Float32));

let mut data = SmallVec::new();
data.extend_from_slice(&value.imag.to_ne_bytes());
Expand Down Expand Up @@ -457,8 +457,8 @@ impl TryFrom<ValueRef<'_>> for Complex32 {
impl From<Complex64> for Value {
fn from(value: Complex64) -> Self {
let object = Object::new()
.with_field("imag", Type::Float64)
.with_field("real", Type::Float64);
.with_field("imag", Type::Primitive(Primitive::Float64))
.with_field("real", Type::Primitive(Primitive::Float64));

let mut data = SmallVec::new();
data.extend_from_slice(&value.imag.to_ne_bytes());
Expand Down Expand Up @@ -564,7 +564,7 @@ mod test {

#[test]
fn array_as_value() {
let array: Type = Array::new(Type::Int32, 3).into();
let array: Type = Array::new(Type::Primitive(Primitive::Int32), 3).into();
assert_eq!(array.size(), 12);

let values = [5, 6, 7];
Expand All @@ -586,7 +586,7 @@ mod test {

#[test]
fn multi_dimensional_array_as_value() {
let array: Type = Array::new(Array::new(Type::Int32, 3), 2).into();
let array: Type = Array::new(Array::new(Type::Primitive(Primitive::Int32), 3), 2).into();
assert_eq!(array.size(), 24);

let multi_dimensional_array = [[5, 6, 7], [8, 9, 10]];
Expand Down Expand Up @@ -623,9 +623,12 @@ mod test {
#[test]
fn object_as_value() {
let ty = Object::new()
.with_field("a", Type::Int32)
.with_field("b", Type::Int64)
.with_field("c", Object::new().with_field("d", Type::Bool));
.with_field("a", Type::Primitive(Primitive::Int32))
.with_field("b", Type::Primitive(Primitive::Int64))
.with_field(
"c",
Object::new().with_field("d", Type::Primitive(Primitive::Bool)),
);

let mut data = Vec::new();
data.extend_from_slice(&5_i32.to_ne_bytes());
Expand Down
12 changes: 7 additions & 5 deletions tests/endpoints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use cmajor::{
endpoint::Endpoint,
performer::{EndpointError, Performer},
value::{
types::{Object, Type},
types::{Object, Primitive, Type},
Complex32, Complex64, ValueRef,
},
Cmajor,
Expand Down Expand Up @@ -423,15 +423,15 @@ fn can_query_endpoint_information() {
};

assert_eq!(a.id(), "a");
assert_eq!(a.ty(), &Type::Int32);
assert_eq!(a.ty(), &Type::Primitive(Primitive::Int32));

let b = match performer.endpoints().get_by_id("b").unwrap() {
(_, Endpoint::Value(endpoint)) => endpoint,
_ => panic!("expected a value"),
};

assert_eq!(b.id(), "b");
assert_eq!(b.ty(), &Type::Float32);
assert_eq!(b.ty(), &Type::Primitive(Primitive::Float32));

let c = match performer.endpoints().get_by_id("c").unwrap() {
(_, Endpoint::Event(endpoint)) => endpoint,
Expand All @@ -442,8 +442,10 @@ fn can_query_endpoint_information() {
assert_eq!(
c.types(),
vec![
Type::Int32,
Object::new().with_field("d", Type::Bool).into()
Type::Primitive(Primitive::Int32),
Object::new()
.with_field("d", Type::Primitive(Primitive::Bool))
.into()
]
);
}
Expand Down

0 comments on commit 00937df

Please sign in to comment.