Add fused bias linear layer with cublaslt (#400)
* Add fused bias linear layer with cublaslt

* Setup cublaslt

* Format and fix bs dim

* Improve test case

* Add arange test
EricLBuehler authored Jun 7, 2024
1 parent 69f8626 commit 8c01795
Showing 1 changed file with 127 additions and 1 deletion.
128 changes: 127 additions & 1 deletion mistralrs-core/src/layers.rs
@@ -12,7 +12,7 @@ use std::{

use candle_core::{
    quantized::{gguf_file, QMatMul, QTensor},
-    DType, Device, IndexOp, Result, Tensor,
+    DType, Device, IndexOp, Result, Shape, Tensor, D,
};
use candle_nn::{Linear, Module, VarBuilder};
use either::Either;
@@ -398,6 +398,57 @@ impl ScaledDotProductAttention {
    }
}

/// Linear layer with fused bias matmul.
#[derive(Debug)]
pub struct FusedBiasLinear {
    pub(crate) w: Tensor,
    pub(crate) b: Tensor,
}

impl TryFrom<Linear> for FusedBiasLinear {
    type Error = candle_core::Error;

    fn try_from(x: Linear) -> Result<Self> {
        if let Some(bias) = x.bias() {
            Ok(Self {
                w: x.weight().clone(),
                b: bias.clone(),
            })
        } else {
            candle_core::bail!("`FusedBiasLinear` expects a Linear layer with bias.")
        }
    }
}

impl Module for FusedBiasLinear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Broadcast the weight over the leading batch dims of the input.
        let w = match *x.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };
        // Target shape = input dims with the last dim replaced by out_features,
        // so the bias can be broadcast to the output shape.
        let mut tgt_shape = x.dims().to_vec();
        tgt_shape[x.dims().len() - 1] = w.dim(D::Minus2)?;
        let b = self.b.broadcast_as(Shape::from_dims(&tgt_shape))?;

        if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), *CUBLASLT_HANDLE.lock().unwrap()) {
            // CUDA with an initialized cuBLASLt handle: fuse the bias add into the batched matmul.
            cublaslt
                .batch_matmul(
                    x,
                    &w,
                    Some(&b.t()?.contiguous()?),
                    None,
                    Some(1.0),
                    None,
                    None,
                )?
                .t()
        } else {
            // Fallback: unfused matmul followed by a broadcast bias add.
            x.matmul(&w.t()?)? + b
        }
    }
}

#[derive(Debug, Clone)]
pub struct QLinear {
    inner: QMatMul,
@@ -493,3 +544,78 @@ impl Module for QLinear {
        }
    }
}

mod tests {

    #[test]
    fn fused_bias_linear() {
        use candle_core::{DType, Device, IndexOp, Tensor};
        use candle_nn::{Linear, Module};

        use crate::cublaslt::setup_cublas_lt_wrapper;
        use crate::layers::FusedBiasLinear;

        const IN: usize = 1921;
        const OUT: usize = 4096;
        const INNER: usize = 1024;

        let dev = Device::cuda_if_available(0).unwrap();
        setup_cublas_lt_wrapper();

        // Run in BF16 on CUDA, F32 elsewhere.
        let inner_dtype = if dev.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };

        let w = Tensor::arange(0f32, (OUT * IN) as f32, &dev)
            .unwrap()
            .to_dtype(inner_dtype)
            .unwrap()
            .reshape((OUT, IN))
            .unwrap();
        let b = Tensor::arange(0f32, OUT as f32, &dev)
            .unwrap()
            .to_dtype(inner_dtype)
            .unwrap()
            .reshape((OUT,))
            .unwrap();

        let xs = Tensor::arange(0f32, (INNER * IN) as f32, &dev)
            .unwrap()
            .to_dtype(inner_dtype)
            .unwrap()
            .reshape((1, INNER, IN))
            .unwrap();

        // Reference output from candle's unfused Linear.
        let lin = Linear::new(w.clone(), Some(b.clone()));
        let truth_out = lin.forward(&xs).unwrap();
        let truth_y = truth_out
            .to_dtype(DType::F32)
            .unwrap()
            .to_vec3::<f32>()
            .unwrap();

        // Output from the fused layer under test.
        let fused = FusedBiasLinear { w, b };
        let fused_out = fused.forward(&xs).unwrap();
        let fused_y = fused_out
            .to_dtype(DType::F32)
            .unwrap()
            .to_vec3::<f32>()
            .unwrap();

        assert_eq!(truth_out.shape(), fused_out.shape());
        if truth_y != fused_y {
            panic!(
                "Truth does not match fused kernel. Diff fused - truth:\n{:#?}",
                &(&fused_out - &truth_out)
                    .unwrap()
                    .i((0, 5..10, 0..5))
                    .unwrap()
                    .to_dtype(DType::F32)
                    .unwrap()
                    .to_vec2::<f32>()
            )
        }
    }
}
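For reference, `FusedBiasLinear` computes the usual affine map y = x·Wᵀ + b; on CUDA with an initialized cuBLASLt handle the bias is handed to `cublaslt.batch_matmul` (presumably as a fused epilogue) rather than applied as a separate broadcast add, and on other devices it falls back to `x.matmul(&w.t()?)? + b`. Below is a minimal usage sketch, not part of the commit: the `mistralrs_core::layers` import path and module visibility are assumptions (only the struct itself is `pub` in this diff), and the dimensions are arbitrary.

```rust
// Sketch only: `mistralrs_core::layers::FusedBiasLinear` is an assumed import path.
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module};
use mistralrs_core::layers::FusedBiasLinear;

fn main() -> Result<()> {
    let dev = Device::cuda_if_available(0)?;

    // A plain candle Linear with a bias (out_features = 8, in_features = 4)...
    let w = Tensor::randn(0f32, 1f32, (8, 4), &dev)?;
    let b = Tensor::zeros((8,), DType::F32, &dev)?;
    let lin = Linear::new(w, Some(b));

    // ...converted into the fused layer; the TryFrom impl errors if the bias is missing.
    let fused = FusedBiasLinear::try_from(lin)?;

    // (batch, seq, in_features) input. The fused cuBLASLt path is only taken on CUDA
    // after the global handle has been initialized (see setup_cublas_lt_wrapper in the
    // test above); otherwise the unfused fallback runs and produces the same result.
    let xs = Tensor::randn(0f32, 1f32, (1, 3, 4), &dev)?;
    let ys = fused.forward(&xs)?;
    assert_eq!(ys.dims(), &[1, 3, 8]);
    Ok(())
}
```

The conversion keeps the weight in the same (out_features, in_features) layout that candle's `Linear` uses, so no transposition is needed at construction time; `forward` transposes as required.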
