From 772bda49629614d78214069b12d3c0faf64569a7 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Wed, 27 Mar 2024 12:55:43 +0000 Subject: [PATCH] Do not allow operations on mismatched devices --- Cargo.toml | 2 +- src/data_repr.rs | 36 +++++++++++++++++--- src/data_traits.rs | 11 +++++- src/impl_methods.rs | 14 ++++++-- src/impl_ops.rs | 83 ++++++++++++++++++++++++++------------------- 5 files changed, 104 insertions(+), 42 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 36cceffed..feec26d33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ num-complex = { version = "0.4", default-features = false } # Use via the `opencl` crate feature! hasty_ = { version = "0.2", optional = true, package = "hasty", default-features = false } -#hasty_ = { path = "../../hasty_dev/hasty", optional = true, package = "hasty", default-features = false } +# hasty_ = { path = "../../hasty_dev/hasty", optional = true, package = "hasty", default-features = false } # Use via the `rayon` crate feature! rayon_ = { version = "1.0.3", optional = true, package = "rayon" } diff --git a/src/data_repr.rs b/src/data_repr.rs index 311e4653b..781fc9350 100644 --- a/src/data_repr.rs +++ b/src/data_repr.rs @@ -53,9 +53,13 @@ impl OwnedRepr { } } + pub(crate) fn device(&self) -> Device { + self.device + } + /// Move this storage object to a specified device. #[allow(clippy::unnecessary_wraps)] - pub(crate) fn copy_to_device(self, device: Device) -> Option { + pub(crate) fn move_to_device(self, device: Device) -> Option { // println!("Copying to {device:?}"); // let mut self_ = ManuallyDrop::new(self); // self_.device = device; @@ -209,7 +213,8 @@ impl OwnedRepr { /// on the host device. pub(crate) fn as_slice(&self) -> &[A] { // Cannot create a slice of a device pointer - assert_eq!(self.device, Device::Host); + debug_assert_eq!(self.device, Device::Host, "Cannot create a slice of a device pointer"); + unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) } } @@ -337,8 +342,31 @@ where A: Clone #[cfg(feature = "opencl")] Device::OpenCL => { println!("Performing OpenCL Clone"); - // todo: OpenCL clone - Self::from(self.as_slice().to_owned()) + unsafe { + // Allocate new buffer + let bytes = std::mem::size_of::() * self.len(); + + match hasty_::opencl::opencl_allocate(bytes, hasty_::opencl::OpenCLMemoryType::ReadWrite) { + Ok(buffer_ptr) => { + if let Err(err_code) = + hasty_::opencl::opencl_copy(buffer_ptr, self.as_ptr() as *const std::ffi::c_void, bytes) + { + panic!("Failed to copy to OpenCL buffer. Exited with status: {:?}", err_code); + } + + Self { + ptr: NonNull::new(buffer_ptr as *mut A).unwrap(), + len: self.len, + capacity: self.capacity, + device: self.device, + } + } + + Err(err_code) => { + panic!("Failed to clone OpenCL buffer. Exited with status: {:?}", err_code); + } + } + } } #[cfg(feature = "cuda")] diff --git a/src/data_traits.rs b/src/data_traits.rs index 42b01ed0e..b06163081 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -17,7 +17,7 @@ use std::mem::MaybeUninit; use std::mem::{self, size_of}; use std::ptr::NonNull; -use crate::{ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr}; +use crate::{ArcArray, Array, ArrayBase, CowRepr, Device, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr}; /// Array representation trait. /// @@ -41,6 +41,11 @@ pub unsafe trait RawData: Sized { #[doc(hidden)] fn _is_pointer_inbounds(&self, ptr: *const Self::Elem) -> bool; + #[doc(hidden)] + fn _device(&self) -> Option { + None + } + private_decl! {} } @@ -330,6 +335,10 @@ unsafe impl RawData for OwnedRepr { self_ptr >= ptr && self_ptr <= end } + fn _device(&self) -> Option { + Some(self.device()) + } + private_impl! {} } diff --git a/src/impl_methods.rs b/src/impl_methods.rs index e9f9a01ec..1faa3436b 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2964,6 +2964,12 @@ where f(&*prev, &mut *curr) }); } + + pub fn device(&self) -> Device { + // If a device is returned, use that. Otherwise, it's fairly safe to + // assume that the data is on the host. + self.data._device().unwrap_or(Device::Host) + } } /// Transmute from A to B. @@ -2986,10 +2992,14 @@ type DimMaxOf = >::Output; impl ArrayBase, D> where A: std::fmt::Debug { - pub fn copy_to_device(self, device: Device) -> Option { + // pub fn device(&self) -> Device { + // self.data.device() + // } + + pub fn move_to_device(self, device: Device) -> Option { let dim = self.dim; let strides = self.strides; - let data = self.data.copy_to_device(device)?; + let data = self.data.move_to_device(device)?; let ptr = std::ptr::NonNull::new(data.as_ptr() as *mut A).unwrap(); Some(Self { diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 8d02364d1..9475ebba3 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -50,10 +50,18 @@ impl ScalarOperand for f64 {} impl ScalarOperand for Complex {} impl ScalarOperand for Complex {} +macro_rules! device_check_assert( + ($self:expr, $rhs:expr) => { + debug_assert_eq!($self.device(), $rhs.device(), + "Cannot perform operation on arrays on different devices. \ + Please move them to the same device first."); + } +); + macro_rules! impl_binary_op( - ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => ( + ($rs_trait:ident, $operator:tt, $math_op:ident, $inplace_op:tt, $docstring:expr) => ( /// Perform elementwise -#[doc=$doc] +#[doc=$docstring] /// between `self` and `rhs`, /// and return the result. /// @@ -62,9 +70,9 @@ macro_rules! impl_binary_op( /// If their shapes disagree, `self` is broadcast to their broadcast shape. /// /// **Panics** if broadcasting isn’t possible. -impl $trt> for ArrayBase +impl $rs_trait> for ArrayBase where - A: Clone + $trt, + A: Clone + $rs_trait, B: Clone, S: DataOwned + DataMut, S2: Data, @@ -73,14 +81,15 @@ where { type Output = ArrayBase>::Output>; #[track_caller] - fn $mth(self, rhs: ArrayBase) -> Self::Output + fn $math_op(self, rhs: ArrayBase) -> Self::Output { - self.$mth(&rhs) + device_check_assert!(self, rhs); + self.$math_op(&rhs) } } /// Perform elementwise -#[doc=$doc] +#[doc=$docstring] /// between `self` and reference `rhs`, /// and return the result. /// @@ -90,9 +99,9 @@ where /// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. -impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase +impl<'a, A, B, S, S2, D, E> $rs_trait<&'a ArrayBase> for ArrayBase where - A: Clone + $trt, + A: Clone + $rs_trait, B: Clone, S: DataOwned + DataMut, S2: Data, @@ -101,27 +110,29 @@ where { type Output = ArrayBase>::Output>; #[track_caller] - fn $mth(self, rhs: &ArrayBase) -> Self::Output + fn $math_op(self, rhs: &ArrayBase) -> Self::Output { + device_check_assert!(self, rhs); + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let mut out = self.into_dimensionality::<>::Output>().unwrap(); - out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth)); + out.zip_mut_with_same_shape(rhs, clone_iopf(A::$math_op)); out } else { let (lhs_view, rhs_view) = self.broadcast_with(&rhs).unwrap(); if lhs_view.shape() == self.shape() { let mut out = self.into_dimensionality::<>::Output>().unwrap(); - out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth)); + out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$math_op)); out } else { - Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth)) + Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$math_op)) } } } } /// Perform elementwise -#[doc=$doc] +#[doc=$docstring] /// between reference `self` and `rhs`, /// and return the result. /// @@ -131,9 +142,9 @@ where /// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. -impl<'a, A, B, S, S2, D, E> $trt> for &'a ArrayBase +impl<'a, A, B, S, S2, D, E> $rs_trait> for &'a ArrayBase where - A: Clone + $trt, + A: Clone + $rs_trait, B: Clone, S: Data, S2: DataOwned + DataMut, @@ -142,28 +153,30 @@ where { type Output = ArrayBase>::Output>; #[track_caller] - fn $mth(self, rhs: ArrayBase) -> Self::Output - where + fn $math_op(self, rhs: ArrayBase) -> Self::Output + // where { + device_check_assert!(self, rhs); + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); - out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth)); + out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$math_op)); out } else { let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap(); if rhs_view.shape() == rhs.shape() { let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); - out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth)); + out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$math_op)); out } else { - Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth)) + Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$math_op)) } } } } /// Perform elementwise -#[doc=$doc] +#[doc=$docstring] /// between references `self` and `rhs`, /// and return the result as a new `Array`. /// @@ -171,9 +184,9 @@ where /// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. -impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for &'a ArrayBase +impl<'a, A, B, S, S2, D, E> $rs_trait<&'a ArrayBase> for &'a ArrayBase where - A: Clone + $trt, + A: Clone + $rs_trait, B: Clone, S: Data, S2: Data, @@ -182,7 +195,9 @@ where { type Output = Array>::Output>; #[track_caller] - fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { + fn $math_op(self, rhs: &'a ArrayBase) -> Self::Output { + device_check_assert!(self, rhs); + let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let lhs = self.view().into_dimensionality::<>::Output>().unwrap(); let rhs = rhs.view().into_dimensionality::<>::Output>().unwrap(); @@ -190,24 +205,24 @@ where } else { self.broadcast_with(rhs).unwrap() }; - Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth)) + Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$math_op)) } } /// Perform elementwise -#[doc=$doc] +#[doc=$docstring] /// between `self` and the scalar `x`, /// and return the result (based on `self`). /// /// `self` must be an `Array` or `ArcArray`. -impl $trt for ArrayBase - where A: Clone + $trt, +impl $rs_trait for ArrayBase + where A: Clone + $rs_trait, S: DataOwned + DataMut, D: Dimension, B: ScalarOperand, { type Output = ArrayBase; - fn $mth(mut self, x: B) -> ArrayBase { + fn $math_op(mut self, x: B) -> ArrayBase { self.map_inplace(move |elt| { *elt = elt.clone() $operator x.clone(); }); @@ -216,17 +231,17 @@ impl $trt for ArrayBase } /// Perform elementwise -#[doc=$doc] +#[doc=$docstring] /// between the reference `self` and the scalar `x`, /// and return the result as a new `Array`. -impl<'a, A, S, D, B> $trt for &'a ArrayBase - where A: Clone + $trt, +impl<'a, A, S, D, B> $rs_trait for &'a ArrayBase + where A: Clone + $rs_trait, S: Data, D: Dimension, B: ScalarOperand, { type Output = Array; - fn $mth(self, x: B) -> Self::Output { + fn $math_op(self, x: B) -> Self::Output { self.map(move |elt| elt.clone() $operator x.clone()) } }