diff --git a/Cargo.toml b/Cargo.toml
index 36cceffed..feec26d33 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -34,7 +34,7 @@ num-complex = { version = "0.4", default-features = false }
# Use via the `opencl` crate feature!
hasty_ = { version = "0.2", optional = true, package = "hasty", default-features = false }
-#hasty_ = { path = "../../hasty_dev/hasty", optional = true, package = "hasty", default-features = false }
+# hasty_ = { path = "../../hasty_dev/hasty", optional = true, package = "hasty", default-features = false }
# Use via the `rayon` crate feature!
rayon_ = { version = "1.0.3", optional = true, package = "rayon" }
diff --git a/src/data_repr.rs b/src/data_repr.rs
index 311e4653b..781fc9350 100644
--- a/src/data_repr.rs
+++ b/src/data_repr.rs
@@ -53,9 +53,13 @@ impl OwnedRepr {
}
}
+ pub(crate) fn device(&self) -> Device {
+ self.device // the device tag recorded on this storage object, returned by value
+ }
+
/// Move this storage object to a specified device.
#[allow(clippy::unnecessary_wraps)]
- pub(crate) fn copy_to_device(self, device: Device) -> Option {
+ pub(crate) fn move_to_device(self, device: Device) -> Option {
// println!("Copying to {device:?}");
// let mut self_ = ManuallyDrop::new(self);
// self_.device = device;
@@ -209,7 +213,8 @@ impl OwnedRepr {
/// on the host device.
pub(crate) fn as_slice(&self) -> &[A] {
// Cannot create a slice of a device pointer
- assert_eq!(self.device, Device::Host);
+ assert_eq!(self.device, Device::Host, "Cannot create a slice of a device pointer");
+ // SAFETY: the host check above runs in every build profile; assumes `ptr`/`len` describe a live allocation.
unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
}
@@ -337,8 +342,31 @@ where A: Clone
#[cfg(feature = "opencl")]
Device::OpenCL => {
println!("Performing OpenCL Clone");
- // todo: OpenCL clone
- Self::from(self.as_slice().to_owned())
+ unsafe { // SAFETY(review): soundness rests on the `hasty_` OpenCL calls' contracts, which are not visible here — confirm
+ // Allocate a new buffer sized for `len` elements, then copy `bytes` bytes from `self.as_ptr()` into it.
+ let bytes = std::mem::size_of::() * self.len();
+
+ match hasty_::opencl::opencl_allocate(bytes, hasty_::opencl::OpenCLMemoryType::ReadWrite) {
+ Ok(buffer_ptr) => {
+ if let Err(err_code) =
+ hasty_::opencl::opencl_copy(buffer_ptr, self.as_ptr() as *const std::ffi::c_void, bytes)
+ {
+ panic!("Failed to copy to OpenCL buffer. Exited with status: {:?}", err_code);
+ }
+
+ Self {
+ ptr: NonNull::new(buffer_ptr as *mut A).unwrap(),
+ len: self.len,
+ capacity: self.capacity, // NOTE(review): the buffer was sized from `len`, not `capacity` — confirm nothing sizes device accesses by `capacity`
+ device: self.device,
+ }
+ }
+
+ Err(err_code) => {
+ panic!("Failed to clone OpenCL buffer. Exited with status: {:?}", err_code);
+ }
+ }
+ }
}
#[cfg(feature = "cuda")]
diff --git a/src/data_traits.rs b/src/data_traits.rs
index 42b01ed0e..b06163081 100644
--- a/src/data_traits.rs
+++ b/src/data_traits.rs
@@ -17,7 +17,7 @@ use std::mem::MaybeUninit;
use std::mem::{self, size_of};
use std::ptr::NonNull;
-use crate::{ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr};
+use crate::{ArcArray, Array, ArrayBase, CowRepr, Device, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr};
/// Array representation trait.
///
@@ -41,6 +41,11 @@ pub unsafe trait RawData: Sized {
#[doc(hidden)]
fn _is_pointer_inbounds(&self, ptr: *const Self::Elem) -> bool;
+ #[doc(hidden)]
+ fn _device(&self) -> Option {
+ None // default: this storage reports no device; implementors may override
+ }
+
private_decl! {}
}
@@ -330,6 +335,10 @@ unsafe impl RawData for OwnedRepr {
self_ptr >= ptr && self_ptr <= end
}
+ fn _device(&self) -> Option {
+ Some(self.device()) // owned storage always reports the device it records
+ }
+
private_impl! {}
}
diff --git a/src/impl_methods.rs b/src/impl_methods.rs
index e9f9a01ec..1faa3436b 100644
--- a/src/impl_methods.rs
+++ b/src/impl_methods.rs
@@ -2964,6 +2964,12 @@ where
f(&*prev, &mut *curr)
});
}
+
+ /// Return the device reported by the underlying storage, or `Device::Host` if it reports none.
+ pub fn device(&self) -> Device {
+ // Storage types that keep the default `_device` return `None`; that is treated as host data.
+ self.data._device().unwrap_or(Device::Host)
+ }
}
/// Transmute from A to B.
@@ -2986,10 +2992,14 @@ type DimMaxOf = >::Output;
impl ArrayBase, D>
where A: std::fmt::Debug
{
- pub fn copy_to_device(self, device: Device) -> Option {
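+ // No owned-only `device()` accessor is defined in this impl: the `device()`
+ // method defined earlier in this file reads the device through `RawData::_device`
+ // and falls back to `Device::Host` when the storage reports none.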
+
+ pub fn move_to_device(self, device: Device) -> Option {
let dim = self.dim;
let strides = self.strides;
- let data = self.data.copy_to_device(device)?;
+ let data = self.data.move_to_device(device)?;
let ptr = std::ptr::NonNull::new(data.as_ptr() as *mut A).unwrap();
Some(Self {
diff --git a/src/impl_ops.rs b/src/impl_ops.rs
index 8d02364d1..9475ebba3 100644
--- a/src/impl_ops.rs
+++ b/src/impl_ops.rs
@@ -50,10 +50,18 @@ impl ScalarOperand for f64 {}
impl ScalarOperand for Complex {}
impl ScalarOperand for Complex {}
+// Panics in every build profile (not only debug) when the two operands report different devices.
+macro_rules! device_check_assert(
+ ($self:expr, $rhs:expr) => {
+ assert_eq!($self.device(), $rhs.device(),
+ "Cannot perform operation on arrays on different devices. Please move them to the same device first.");
+ }
+);
+
macro_rules! impl_binary_op(
- ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => (
+ ($rs_trait:ident, $operator:tt, $math_op:ident, $inplace_op:tt, $docstring:expr) => (
/// Perform elementwise
-#[doc=$doc]
+#[doc=$docstring]
/// between `self` and `rhs`,
/// and return the result.
///
@@ -62,9 +70,9 @@ macro_rules! impl_binary_op(
/// If their shapes disagree, `self` is broadcast to their broadcast shape.
///
/// **Panics** if broadcasting isn’t possible.
-impl $trt> for ArrayBase
+impl $rs_trait> for ArrayBase
where
- A: Clone + $trt,
+ A: Clone + $rs_trait,
B: Clone,
S: DataOwned + DataMut,
S2: Data,
@@ -73,14 +81,15 @@ where
{
type Output = ArrayBase>::Output>;
#[track_caller]
- fn $mth(self, rhs: ArrayBase) -> Self::Output
+ fn $math_op(self, rhs: ArrayBase) -> Self::Output
{
- self.$mth(&rhs)
+ device_check_assert!(self, rhs);
+ self.$math_op(&rhs)
}
}
/// Perform elementwise
-#[doc=$doc]
+#[doc=$docstring]
/// between `self` and reference `rhs`,
/// and return the result.
///
@@ -90,9 +99,9 @@ where
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
-impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase
+impl<'a, A, B, S, S2, D, E> $rs_trait<&'a ArrayBase> for ArrayBase
where
- A: Clone + $trt,
+ A: Clone + $rs_trait,
B: Clone,
S: DataOwned + DataMut,
S2: Data,
@@ -101,27 +110,29 @@ where
{
type Output = ArrayBase>::Output>;
#[track_caller]
- fn $mth(self, rhs: &ArrayBase) -> Self::Output
+ fn $math_op(self, rhs: &ArrayBase) -> Self::Output
{
+ device_check_assert!(self, rhs);
+
if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let mut out = self.into_dimensionality::<>::Output>().unwrap();
- out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
+ out.zip_mut_with_same_shape(rhs, clone_iopf(A::$math_op));
out
} else {
let (lhs_view, rhs_view) = self.broadcast_with(&rhs).unwrap();
if lhs_view.shape() == self.shape() {
let mut out = self.into_dimensionality::<>::Output>().unwrap();
- out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth));
+ out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$math_op));
out
} else {
- Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
+ Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$math_op))
}
}
}
}
/// Perform elementwise
-#[doc=$doc]
+#[doc=$docstring]
/// between reference `self` and `rhs`,
/// and return the result.
///
@@ -131,9 +142,9 @@ where
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
-impl<'a, A, B, S, S2, D, E> $trt> for &'a ArrayBase
+impl<'a, A, B, S, S2, D, E> $rs_trait> for &'a ArrayBase
where
- A: Clone + $trt,
+ A: Clone + $rs_trait,
B: Clone,
S: Data,
S2: DataOwned + DataMut,
@@ -142,28 +153,30 @@ where
{
type Output = ArrayBase>::Output>;
#[track_caller]
- fn $mth(self, rhs: ArrayBase) -> Self::Output
- where
+ fn $math_op(self, rhs: ArrayBase) -> Self::Output
+ // where
{
+ device_check_assert!(self, rhs);
+
if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let mut out = rhs.into_dimensionality::<>::Output>().unwrap();
- out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
+ out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$math_op));
out
} else {
let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap();
if rhs_view.shape() == rhs.shape() {
let mut out = rhs.into_dimensionality::<>::Output>().unwrap();
- out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth));
+ out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$math_op));
out
} else {
- Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
+ Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$math_op))
}
}
}
}
/// Perform elementwise
-#[doc=$doc]
+#[doc=$docstring]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
@@ -171,9 +184,9 @@ where
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
-impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for &'a ArrayBase
+impl<'a, A, B, S, S2, D, E> $rs_trait<&'a ArrayBase> for &'a ArrayBase
where
- A: Clone + $trt,
+ A: Clone + $rs_trait,
B: Clone,
S: Data,
S2: Data,
@@ -182,7 +195,9 @@ where
{
type Output = Array>::Output>;
#[track_caller]
- fn $mth(self, rhs: &'a ArrayBase) -> Self::Output {
+ fn $math_op(self, rhs: &'a ArrayBase) -> Self::Output {
+ device_check_assert!(self, rhs);
+
let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
let lhs = self.view().into_dimensionality::<>::Output>().unwrap();
let rhs = rhs.view().into_dimensionality::<>::Output>().unwrap();
@@ -190,24 +205,24 @@ where
} else {
self.broadcast_with(rhs).unwrap()
};
- Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth))
+ Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$math_op))
}
}
/// Perform elementwise
-#[doc=$doc]
+#[doc=$docstring]
/// between `self` and the scalar `x`,
/// and return the result (based on `self`).
///
/// `self` must be an `Array` or `ArcArray`.
-impl $trt for ArrayBase
- where A: Clone + $trt,
+impl $rs_trait for ArrayBase
+ where A: Clone + $rs_trait,
S: DataOwned + DataMut,
D: Dimension,
B: ScalarOperand,
{
type Output = ArrayBase;
- fn $mth(mut self, x: B) -> ArrayBase {
+ fn $math_op(mut self, x: B) -> ArrayBase {
self.map_inplace(move |elt| {
*elt = elt.clone() $operator x.clone();
});
@@ -216,17 +231,17 @@ impl $trt for ArrayBase
}
/// Perform elementwise
-#[doc=$doc]
+#[doc=$docstring]
/// between the reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
-impl<'a, A, S, D, B> $trt for &'a ArrayBase
- where A: Clone + $trt,
+impl<'a, A, S, D, B> $rs_trait for &'a ArrayBase
+ where A: Clone + $rs_trait,
S: Data,
D: Dimension,
B: ScalarOperand,
{
type Output = Array;
- fn $mth(self, x: B) -> Self::Output {
+ fn $math_op(self, x: B) -> Self::Output {
self.map(move |elt| elt.clone() $operator x.clone())
}
}