Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/cudnn/bindgen.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
bindgen "/usr/include/cudnn.h" \
bindgen "${HOME}/local/include/cudnn.h" \
--size_t-is-usize \
--allowlist-type "cudnn.*" \
--allowlist-function "cudnn.*" \
Expand Down
141 changes: 141 additions & 0 deletions crates/cudnn/src/backend/conv_bwd_data.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use crate::{
backend::{ConvCfg, Descriptor, FloatDataType, Operation, Real, Tensor},
sys, CudnnError, DataType, IntoResult,
};

pub struct ConvBwdDataBuilder {
cfg: Option<ConvCfg>,
alpha: Option<Real>,
beta: Option<Real>,
w: Option<Tensor>,
dx: Option<Tensor>,
dy: Option<Tensor>,
}

impl ConvBwdDataBuilder {
pub fn set_cfg(mut self, cfg: ConvCfg) -> Self {
self.cfg = Some(cfg);
self
}

pub fn set_alpha<T>(mut self, alpha: T) -> Self
where
T: FloatDataType,
{
self.alpha = Some(alpha.wrap());
self
}

pub fn set_beta<T>(mut self, beta: T) -> Self
where
T: FloatDataType,
{
self.beta = Some(beta.wrap());
self
}

pub fn set_w(mut self, w: Tensor) -> Self {
self.w = Some(w);
self
}

pub fn set_dx(mut self, dx: Tensor) -> Self {
self.dx = Some(dx);
self
}

pub fn set_dy(mut self, dy: Tensor) -> Self {
self.dy = Some(dy);
self
}

pub fn build(self) -> Result<Operation, CudnnError> {
let cfg = self.cfg.expect("convolution configuration is required.");

let w = self.w.expect("w tensor is required");
let dx = self.dx.expect("dx tensor is required.");
let dy = self.dy.expect("dy tensor is required.");

let alpha = self.alpha.unwrap_or(Real::Float(1.0));
let beta = self.beta.unwrap_or(Real::Float(0.0));

unsafe {
let mut raw = Descriptor::new(
sys::cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
)?;

raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
1,
&cfg.raw.inner(),
)
?;

match self.alpha {
Some(Real::Float(ref alpha)) => raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
1,
alpha,
)?,
Some(Real::Double(ref alpha)) => raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
1,
alpha,
)?,
None => (),
}

match self.beta {
Some(Real::Float(ref beta)) => raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
1,
beta,
)?,
Some(Real::Double(ref beta)) => raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
1,
beta,
)?,
None => (),
}

raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
1,
&w.raw.inner(),
)?;

raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
1,
&dx.raw.inner(),
)?;

raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
1,
&dy.raw.inner(),
)?;

raw.finalize()?;

Ok(Operation::ConvBwdData {
raw,
cfg,
alpha,
beta,
w,
dx,
dy,
})
}
}
}
139 changes: 139 additions & 0 deletions crates/cudnn/src/backend/conv_bwd_filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use crate::{
backend::{ConvCfg, Descriptor, FloatDataType, Operation, Real, Tensor},
sys, CudnnError, DataType, IntoResult,
};

pub struct ConvBwdFilterBuilder {
cfg: Option<ConvCfg>,
alpha: Option<Real>,
beta: Option<Real>,
dw: Option<Tensor>,
x: Option<Tensor>,
dy: Option<Tensor>,
}

impl ConvBwdFilterBuilder {
pub fn set_cfg(mut self, cfg: ConvCfg) -> Self {
self.cfg = Some(cfg);
self
}

pub fn set_alpha<T>(mut self, alpha: T) -> Self
where
T: FloatDataType,
{
self.alpha = Some(alpha.wrap());
self
}

pub fn set_beta<T>(mut self, beta: T) -> Self
where
T: FloatDataType,
{
self.beta = Some(beta.wrap());
self
}

pub fn set_dw(mut self, dw: Tensor) -> Self {
self.dw = Some(dw);
self
}

pub fn set_dx(mut self, x: Tensor) -> Self {
self.x = Some(x);
self
}

pub fn set_dy(mut self, dy: Tensor) -> Self {
self.dy = Some(dy);
self
}

pub fn build(self) -> Result<Operation, CudnnError> {
let cfg = self.cfg.expect("convolution configuration is required.");
let dw = self.dw.expect("dw tensor is required");
let x = self.x.expect("x tensor is required.");
let dy = self.dy.expect("dy tensor is required.");

let alpha = self.alpha.unwrap_or(Real::Float(1.0));
let beta = self.beta.unwrap_or(Real::Float(0.0));

unsafe {
let mut raw = Descriptor::new(
sys::cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
)?;

raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
1,
&cfg.raw.inner(),
)?;

match self.alpha {
Some(Real::Float(ref alpha)) => raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
1,
alpha,
)?,
Some(Real::Double(ref alpha)) => raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
1,
alpha,
)?,
None => (),
}

match self.beta {
Some(Real::Float(ref beta)) => raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT,
1,
beta,
)?,
Some(Real::Double(ref beta)) => raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE,
1,
beta,
)?,
None => (),
}

raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
1,
&dw.raw.inner(),
)?;

raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
1,
&x.raw.inner(),
)?;

raw.set_attribute(
sys::cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY,
sys::cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR,
1,
&dy.raw.inner(),
)?;

raw.finalize()?;

Ok(Operation::ConvBwdFilter {
raw,
cfg,
alpha,
beta,
dw,
x,
dy,
})
}
}
}
Loading