Skip to content

Commit

Permalink
Do not allow operations on mismatched devices
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Mar 27, 2024
1 parent 8dd39d1 commit 772bda4
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ num-complex = { version = "0.4", default-features = false }

# Use via the `opencl` crate feature!
hasty_ = { version = "0.2", optional = true, package = "hasty", default-features = false }
#hasty_ = { path = "../../hasty_dev/hasty", optional = true, package = "hasty", default-features = false }
# hasty_ = { path = "../../hasty_dev/hasty", optional = true, package = "hasty", default-features = false }

# Use via the `rayon` crate feature!
rayon_ = { version = "1.0.3", optional = true, package = "rayon" }
Expand Down
36 changes: 32 additions & 4 deletions src/data_repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@ impl<A> OwnedRepr<A> {
}
}

pub(crate) fn device(&self) -> Device {
self.device
}

/// Move this storage object to a specified device.
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn copy_to_device(self, device: Device) -> Option<Self> {
pub(crate) fn move_to_device(self, device: Device) -> Option<Self> {
// println!("Copying to {device:?}");
// let mut self_ = ManuallyDrop::new(self);
// self_.device = device;
Expand Down Expand Up @@ -209,7 +213,8 @@ impl<A> OwnedRepr<A> {
/// on the host device.
///
/// # Panics
///
/// Panics, in every build profile, if this storage is not tagged as
/// `Device::Host`.
pub(crate) fn as_slice(&self) -> &[A] {
    // This must be a hard assertion rather than a debug-only one: the unsafe
    // block below hands the raw pointer to `slice::from_raw_parts`, which
    // requires memory that is readable from the host. Letting a device
    // pointer through in release builds would be undefined behaviour.
    assert_eq!(self.device, Device::Host, "Cannot create a slice of a device pointer");

    // SAFETY: the assertion above rules out non-host storage. This still
    // relies on `ptr`/`len` describing a live allocation of `len` elements
    // owned by `self` -- NOTE(review): that invariant is maintained elsewhere
    // in this type and is not visible from here.
    unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
}

Expand Down Expand Up @@ -337,8 +342,31 @@ where A: Clone
#[cfg(feature = "opencl")]
Device::OpenCL => {
println!("Performing OpenCL Clone");
// todo: OpenCL clone
Self::from(self.as_slice().to_owned())
unsafe {
// Allocate new buffer
let bytes = std::mem::size_of::<A>() * self.len();

match hasty_::opencl::opencl_allocate(bytes, hasty_::opencl::OpenCLMemoryType::ReadWrite) {
Ok(buffer_ptr) => {
if let Err(err_code) =
hasty_::opencl::opencl_copy(buffer_ptr, self.as_ptr() as *const std::ffi::c_void, bytes)
{
panic!("Failed to copy to OpenCL buffer. Exited with status: {:?}", err_code);
}

Self {
ptr: NonNull::new(buffer_ptr as *mut A).unwrap(),
len: self.len,
capacity: self.capacity,
device: self.device,
}
}

Err(err_code) => {
panic!("Failed to clone OpenCL buffer. Exited with status: {:?}", err_code);
}
}
}
}

#[cfg(feature = "cuda")]
Expand Down
11 changes: 10 additions & 1 deletion src/data_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::mem::MaybeUninit;
use std::mem::{self, size_of};
use std::ptr::NonNull;

use crate::{ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr};
use crate::{ArcArray, Array, ArrayBase, CowRepr, Device, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr};

/// Array representation trait.
///
Expand All @@ -41,6 +41,11 @@ pub unsafe trait RawData: Sized {
#[doc(hidden)]
fn _is_pointer_inbounds(&self, ptr: *const Self::Elem) -> bool;

/// Reports which device the underlying data lives on, if the
/// representation tracks one.
///
/// The provided implementation reports nothing; representations that
/// record a device override it.
#[doc(hidden)]
fn _device(&self) -> Option<Device> {
    // No device information is available by default.
    Option::None
}

private_decl! {}
}

Expand Down Expand Up @@ -330,6 +335,10 @@ unsafe impl<A> RawData for OwnedRepr<A> {
self_ptr >= ptr && self_ptr <= end
}

/// Owned storage always knows its device, so this always yields `Some`.
fn _device(&self) -> Option<Device> {
    let device = self.device();
    Some(device)
}

private_impl! {}
}

Expand Down
14 changes: 12 additions & 2 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2964,6 +2964,12 @@ where
f(&*prev, &mut *curr)
});
}

/// Returns the device that this array's data representation reports.
///
/// Representations that do not report a device are treated as living on
/// `Device::Host`.
pub fn device(&self) -> Device {
    match self.data._device() {
        // The representation tracks its device explicitly.
        Some(device) => device,
        // Nothing reported: it's fairly safe to assume host memory.
        None => Device::Host,
    }
}
}

/// Transmute from A to B.
Expand All @@ -2986,10 +2992,14 @@ type DimMaxOf<A, B> = <A as DimMax<B>>::Output;
impl<A, D> ArrayBase<OwnedRepr<A>, D>
where A: std::fmt::Debug
{
pub fn copy_to_device(self, device: Device) -> Option<Self> {
// pub fn device(&self) -> Device {
// self.data.device()
// }

pub fn move_to_device(self, device: Device) -> Option<Self> {
let dim = self.dim;
let strides = self.strides;
let data = self.data.copy_to_device(device)?;
let data = self.data.move_to_device(device)?;
let ptr = std::ptr::NonNull::new(data.as_ptr() as *mut A).unwrap();

Some(Self {
Expand Down
83 changes: 49 additions & 34 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,18 @@ impl ScalarOperand for f64 {}
impl ScalarOperand for Complex<f32> {}
impl ScalarOperand for Complex<f64> {}

/// Asserts that both operands of a binary operation report the same device.
///
/// Each argument must be an expression exposing a `device()` method whose
/// result can be compared with `==` and formatted with `Debug`.
///
/// The check is deliberately *not* debug-only: the operator implementations
/// go on to iterate over both operands element by element, so letting a
/// device mismatch through in release builds would silently operate on
/// memory that is not where the code expects it. Comparing two device tags
/// is cheap enough to keep in every build profile.
///
/// # Panics
///
/// Panics if `$self.device() != $rhs.device()`.
macro_rules! device_check_assert {
    ($self:expr, $rhs:expr) => {
        assert_eq!(
            $self.device(),
            $rhs.device(),
            "Cannot perform operation on arrays on different devices. \
             Please move them to the same device first."
        );
    };
}

macro_rules! impl_binary_op(
($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => (
($rs_trait:ident, $operator:tt, $math_op:ident, $inplace_op:tt, $docstring:expr) => (
/// Perform elementwise
#[doc=$doc]
#[doc=$docstring]
/// between `self` and `rhs`,
/// and return the result.
///
Expand All @@ -62,9 +70,9 @@ macro_rules! impl_binary_op(
/// If their shapes disagree, `self` is broadcast to their broadcast shape.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
impl<A, B, S, S2, D, E> $rs_trait<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
A: Clone + $rs_trait<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S2: Data<Elem=B>,
Expand All @@ -73,14 +81,15 @@ where
{
type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
#[track_caller]
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
fn $math_op(self, rhs: ArrayBase<S2, E>) -> Self::Output
{
self.$mth(&rhs)
device_check_assert!(self, rhs);
self.$math_op(&rhs)
}
}

/// Perform elementwise
#[doc=$doc]
#[doc=$docstring]
/// between `self` and reference `rhs`,
/// and return the result.
///
Expand All @@ -90,9 +99,9 @@ where
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
impl<'a, A, B, S, S2, D, E> $rs_trait<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
A: Clone + $rs_trait<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S2: Data<Elem=B>,
Expand All @@ -101,27 +110,29 @@ where
{
type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
#[track_caller]
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
fn $math_op(self, rhs: &ArrayBase<S2, E>) -> Self::Output
{
device_check_assert!(self, rhs);

if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
out.zip_mut_with_same_shape(rhs, clone_iopf(A::$math_op));
out
} else {
let (lhs_view, rhs_view) = self.broadcast_with(&rhs).unwrap();
if lhs_view.shape() == self.shape() {
let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth));
out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$math_op));
out
} else {
Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$math_op))
}
}
}
}

/// Perform elementwise
#[doc=$doc]
#[doc=$docstring]
/// between reference `self` and `rhs`,
/// and return the result.
///
Expand All @@ -131,9 +142,9 @@ where
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
impl<'a, A, B, S, S2, D, E> $rs_trait<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
A: Clone + $trt<B, Output=B>,
A: Clone + $rs_trait<B, Output=B>,
B: Clone,
S: Data<Elem=A>,
S2: DataOwned<Elem=B> + DataMut,
Expand All @@ -142,38 +153,40 @@ where
{
type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;
#[track_caller]
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
where
fn $math_op(self, rhs: ArrayBase<S2, E>) -> Self::Output
// where
{
device_check_assert!(self, rhs);

if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$math_op));
out
} else {
let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap();
if rhs_view.shape() == rhs.shape() {
let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth));
out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$math_op));
out
} else {
Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$math_op))
}
}
}
}

/// Perform elementwise
#[doc=$doc]
#[doc=$docstring]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `self` and `rhs` is broadcast to their broadcast shape,
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
impl<'a, A, B, S, S2, D, E> $rs_trait<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
A: Clone + $rs_trait<B, Output=A>,
B: Clone,
S: Data<Elem=A>,
S2: Data<Elem=B>,
Expand All @@ -182,32 +195,34 @@ where
{
type Output = Array<A, <D as DimMax<E>>::Output>;
#[track_caller]
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
fn $math_op(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
device_check_assert!(self, rhs);

let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let lhs = self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
let rhs = rhs.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
(lhs, rhs)
} else {
self.broadcast_with(rhs).unwrap()
};
Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth))
Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$math_op))
}
}

/// Perform elementwise
#[doc=$doc]
#[doc=$docstring]
/// between `self` and the scalar `x`,
/// and return the result (based on `self`).
///
/// `self` must be an `Array` or `ArcArray`.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
where A: Clone + $trt<B, Output=A>,
impl<A, S, D, B> $rs_trait<B> for ArrayBase<S, D>
where A: Clone + $rs_trait<B, Output=A>,
S: DataOwned<Elem=A> + DataMut,
D: Dimension,
B: ScalarOperand,
{
type Output = ArrayBase<S, D>;
fn $mth(mut self, x: B) -> ArrayBase<S, D> {
fn $math_op(mut self, x: B) -> ArrayBase<S, D> {
self.map_inplace(move |elt| {
*elt = elt.clone() $operator x.clone();
});
Expand All @@ -216,17 +231,17 @@ impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
}

/// Perform elementwise
#[doc=$doc]
#[doc=$docstring]
/// between the reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
where A: Clone + $trt<B, Output=A>,
impl<'a, A, S, D, B> $rs_trait<B> for &'a ArrayBase<S, D>
where A: Clone + $rs_trait<B, Output=A>,
S: Data<Elem=A>,
D: Dimension,
B: ScalarOperand,
{
type Output = Array<A, D>;
fn $mth(self, x: B) -> Self::Output {
fn $math_op(self, x: B) -> Self::Output {
self.map(move |elt| elt.clone() $operator x.clone())
}
}
Expand Down

0 comments on commit 772bda4

Please sign in to comment.