diff --git a/crates/cudnn/src/attention/mod.rs b/crates/cudnn/src/attention/mod.rs
index 9ed25cb6..00317436 100644
--- a/crates/cudnn/src/attention/mod.rs
+++ b/crates/cudnn/src/attention/mod.rs
@@ -180,9 +180,7 @@ impl CudnnContext {
                 reserve_space_size,
                 reserve_space_ptr,
             )
-            .into_result()?;
-
-            Ok(())
+            .into_result()
         }
     }
 }
diff --git a/crates/cudnn/src/lib.rs b/crates/cudnn/src/lib.rs
index 70a01c51..41bab0a1 100644
--- a/crates/cudnn/src/lib.rs
+++ b/crates/cudnn/src/lib.rs
@@ -13,6 +13,7 @@ mod math_type;
 mod nan_propagation;
 mod op_tensor;
 mod rnn;
+mod softmax;
 mod tensor;
 mod w_grad_mode;
 
@@ -27,6 +28,7 @@ pub use math_type::*;
 pub use nan_propagation::*;
 pub use op_tensor::*;
 pub use rnn::*;
+pub use softmax::*;
 pub use tensor::*;
 pub use w_grad_mode::*;
 
diff --git a/crates/cudnn/src/softmax/mod.rs b/crates/cudnn/src/softmax/mod.rs
new file mode 100644
index 00000000..dd9a8fd0
--- /dev/null
+++ b/crates/cudnn/src/softmax/mod.rs
@@ -0,0 +1,142 @@
+mod softmax_algo;
+mod softmax_mode;
+
+pub use softmax_algo::*;
+pub use softmax_mode::*;
+
+use crate::{sys, CudnnContext, CudnnError, DataType, IntoResult, SupportedOp, TensorDescriptor};
+use cust::memory::GpuBuffer;
+
+impl CudnnContext {
+    /// Computes the softmax function.
+    ///
+    /// # Arguments
+    ///
+    /// * `algo` - softmax algorithm to compute.
+    ///
+    /// * `mode` - specifies the softmax mode.
+    ///
+    /// * `alpha` - scaling factor for the result. Must be stored in host memory.
+    ///
+    /// * `x_desc` - tensor descriptor for the operand.
+    ///
+    /// * `x` - operand data in device memory.
+    ///
+    /// * `beta` - scaling factor for the destination tensor. Must be stored in host memory.
+    ///
+    /// * `y_desc` - tensor descriptor for the result.
+    ///
+    /// * `y` - output data in device memory.
+    ///
+    /// # Errors
+    ///
+    /// Returns errors if the supplied configuration is not supported, the tensor shapes differ,
+    /// or the data types of the input and destination tensors do not match.
+    pub fn softmax_forward<T, CompT>(
+        &self,
+        algo: SoftmaxAlgo,
+        mode: SoftmaxMode,
+        alpha: CompT,
+        x_desc: &TensorDescriptor<T>,
+        x: &impl GpuBuffer<T>,
+        beta: CompT,
+        y_desc: &TensorDescriptor<T>,
+        y: &mut impl GpuBuffer<T>,
+    ) -> Result<(), CudnnError>
+    where
+        T: DataType,
+        CompT: SupportedOp<T, T, T>,
+    {
+        let alpha_ptr = &alpha as *const CompT as *const _;
+        let x_ptr = x.as_device_ptr().as_ptr() as *const _;
+
+        let beta_ptr = &beta as *const CompT as *const _;
+        let y_ptr = y.as_device_ptr().as_mut_ptr() as *mut _;
+
+        unsafe {
+            sys::cudnnSoftmaxForward(
+                self.raw,
+                algo.into(),
+                mode.into(),
+                alpha_ptr,
+                x_desc.raw,
+                x_ptr,
+                beta_ptr,
+                y_desc.raw,
+                y_ptr,
+            )
+            .into_result()
+        }
+    }
+
+    /// Computes the gradient of the softmax function.
+    ///
+    /// # Arguments
+    ///
+    /// * `algo` - softmax algorithm whose gradient must be computed.
+    ///
+    /// * `mode` - specifies the softmax mode to compute the gradient of.
+    ///
+    /// * `alpha` - scaling factor for the result. Must be stored in host memory.
+    ///
+    /// * `y_desc` - tensor descriptor for the softmax output computed in the forward pass.
+    ///
+    /// * `y` - softmax output data in device memory.
+    ///
+    /// * `dy_desc` - tensor descriptor for the gradient with respect to the softmax output.
+    ///
+    /// * `dy` - gradient with respect to the softmax output, in device memory.
+    ///
+    /// * `beta` - scaling factor for the differential tensor. Must be stored in host memory.
+    ///
+    /// * `dx_desc` - tensor descriptor for the computed gradient with respect to the input.
+    ///
+    /// * `dx` - computed gradient with respect to the softmax input, in device memory.
+    ///
+    /// # Errors
+    ///
+    /// Returns errors if the supplied configuration is not supported, the tensor shapes differ,
+    /// or the data types of the input and differential tensors do not match.
+    pub fn softmax_backward<T, CompT>(
+        &self,
+        algo: SoftmaxAlgo,
+        mode: SoftmaxMode,
+        alpha: CompT,
+        y_desc: &TensorDescriptor<T>,
+        y: &impl GpuBuffer<T>,
+        dy_desc: &TensorDescriptor<T>,
+        dy: &impl GpuBuffer<T>,
+        beta: CompT,
+        dx_desc: &TensorDescriptor<T>,
+        dx: &mut impl GpuBuffer<T>,
+    ) -> Result<(), CudnnError>
+    where
+        T: DataType,
+        CompT: SupportedOp<T, T, T>,
+    {
+        let alpha_ptr = &alpha as *const CompT as *const _;
+        let y_ptr = y.as_device_ptr().as_ptr() as *const _;
+
+        let beta_ptr = &beta as *const CompT as *const _;
+        let dy_ptr = dy.as_device_ptr().as_ptr() as *const _;
+
+        let dx_ptr = dx.as_device_ptr().as_mut_ptr() as *mut _;
+
+        unsafe {
+            sys::cudnnSoftmaxBackward(
+                self.raw,
+                algo.into(),
+                mode.into(),
+                alpha_ptr,
+                y_desc.raw,
+                y_ptr,
+                dy_desc.raw,
+                dy_ptr,
+                beta_ptr,
+                dx_desc.raw,
+                dx_ptr,
+            )
+            .into_result()
+        }
+    }
+}
diff --git a/crates/cudnn/src/softmax/softmax_algo.rs b/crates/cudnn/src/softmax/softmax_algo.rs
new file mode 100644
index 00000000..b3bb8337
--- /dev/null
+++ b/crates/cudnn/src/softmax/softmax_algo.rs
@@ -0,0 +1,24 @@
+use crate::sys;
+
+/// Specifies the implementation of the softmax function.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum SoftmaxAlgo {
+    /// This implementation applies the straightforward softmax operation.
+    Fast,
+    /// This implementation scales each point of the softmax input domain by its maximum value
+    /// to avoid potential floating point overflows in the softmax evaluation.
+    Accurate,
+    /// This entry performs the log softmax operation, avoiding overflows by scaling each point
+    /// in the input domain as in the accurate version.
+    Log,
+}
+
+impl From<SoftmaxAlgo> for sys::cudnnSoftmaxAlgorithm_t {
+    fn from(algo: SoftmaxAlgo) -> Self {
+        match algo {
+            SoftmaxAlgo::Fast => Self::CUDNN_SOFTMAX_FAST,
+            SoftmaxAlgo::Accurate => Self::CUDNN_SOFTMAX_ACCURATE,
+            SoftmaxAlgo::Log => Self::CUDNN_SOFTMAX_LOG,
+        }
+    }
+}
diff --git a/crates/cudnn/src/softmax/softmax_mode.rs b/crates/cudnn/src/softmax/softmax_mode.rs
new file mode 100644
index 00000000..8c4e60b0
--- /dev/null
+++ b/crates/cudnn/src/softmax/softmax_mode.rs
@@ -0,0 +1,20 @@
+use crate::sys;
+
+/// Specifies how the softmax input must be processed.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum SoftmaxMode {
+    /// The softmax operation is computed per image (N) across the dimensions C, H, W.
+    Instance,
+    /// The softmax operation is computed per spatial location (H, W) per image (N) across
+    /// dimension C.
+    Channel,
+}
+
+impl From<SoftmaxMode> for sys::cudnnSoftmaxMode_t {
+    fn from(mode: SoftmaxMode) -> Self {
+        match mode {
+            SoftmaxMode::Channel => Self::CUDNN_SOFTMAX_MODE_CHANNEL,
+            SoftmaxMode::Instance => Self::CUDNN_SOFTMAX_MODE_INSTANCE,
+        }
+    }
+}
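
For review convenience, below is a minimal, untested usage sketch of the new `softmax_forward` entry point. It assumes a CUDA context created via `cust::quick_init`, a handle from `CudnnContext::new`, that `f32` scaling factors satisfy the `SupportedOp` bound, and a 4-D NCHW descriptor; the descriptor constructor (`TensorDescriptor::<f32>::new_format` with a `ScalarC::Nchw`-style format argument) is an assumption not shown in this patch and may need adjusting to the crate's actual API. `softmax_backward` follows the same call pattern, taking `y` and `dy` as read-only buffers and writing into `dx`.

```rust
use cudnn::{CudnnContext, ScalarC, SoftmaxAlgo, SoftmaxMode, TensorDescriptor};
use cust::memory::{CopyDestination, DeviceBuffer};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // cuDNN calls need a live CUDA context; cust::quick_init creates one on device 0.
    let _cuda = cust::quick_init()?;
    let cudnn = CudnnContext::new()?;

    // One image, five channels, 1x1 spatial extent: the softmax runs over the C axis.
    // NOTE: constructor name and format enum are assumptions, not confirmed by this diff.
    let desc = TensorDescriptor::<f32>::new_format(&[1, 5, 1, 1], ScalarC::Nchw)?;

    let x = DeviceBuffer::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0])?;
    let mut y = DeviceBuffer::from_slice(&[0.0f32; 5])?;

    // y = 1.0 * softmax(x) + 0.0 * y
    cudnn.softmax_forward(
        SoftmaxAlgo::Accurate,
        SoftmaxMode::Channel,
        1.0f32,
        &desc,
        &x,
        0.0f32,
        &desc,
        &mut y,
    )?;

    // Copy the result back to the host; the five values should sum to 1.
    let mut host = vec![0.0f32; 5];
    y.copy_to(&mut host)?;
    println!("{:?}", host);
    Ok(())
}
```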