diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index e71d8bf50..45606c974 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -12,7 +12,7 @@ use std::{
 
 use candle_core::{
     quantized::{gguf_file, QMatMul, QTensor},
-    DType, Device, IndexOp, Result, Tensor,
+    DType, Device, IndexOp, Result, Shape, Tensor, D,
 };
 use candle_nn::{Linear, Module, VarBuilder};
 use either::Either;
@@ -398,6 +398,57 @@ impl ScaledDotProductAttention {
     }
 }
 
+/// Linear layer with fused bias matmul.
+#[derive(Debug)]
+pub struct FusedBiasLinear {
+    pub(crate) w: Tensor,
+    pub(crate) b: Tensor,
+}
+
+impl TryFrom<Linear> for FusedBiasLinear {
+    type Error = candle_core::Error;
+
+    fn try_from(x: Linear) -> Result<Self> {
+        if let Some(bias) = x.bias() {
+            Ok(Self {
+                w: x.weight().clone(),
+                b: bias.clone(),
+            })
+        } else {
+            candle_core::bail!("`FusedBiasLinear` expects a Linear layer with bias.")
+        }
+    }
+}
+
+impl Module for FusedBiasLinear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let w = match *x.dims() {
+            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
+            [bsize, _, _] => self.w.broadcast_left(bsize)?,
+            _ => self.w.clone(),
+        };
+        let mut tgt_shape = x.dims().to_vec();
+        tgt_shape[x.dims().len() - 1] = w.dim(D::Minus2)?;
+        let b = self.b.broadcast_as(Shape::from_dims(&tgt_shape))?;
+
+        if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), *CUBLASLT_HANDLE.lock().unwrap()) {
+            cublaslt
+                .batch_matmul(
+                    x,
+                    &w,
+                    Some(&b.t()?.contiguous()?),
+                    None,
+                    Some(1.0),
+                    None,
+                    None,
+                )?
+                .t()
+        } else {
+            x.matmul(&w.t()?)? + b
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct QLinear {
     inner: QMatMul,
@@ -493,3 +544,78 @@ impl Module for QLinear {
         }
     }
 }
+
+mod tests {
+
+    #[test]
+    fn fused_bias_linear() {
+        use candle_core::{DType, Device, IndexOp, Tensor};
+        use candle_nn::{Linear, Module};
+
+        use crate::cublaslt::setup_cublas_lt_wrapper;
+        use crate::layers::FusedBiasLinear;
+
+        const IN: usize = 1921;
+        const OUT: usize = 4096;
+        const INNER: usize = 1024;
+
+        let dev = Device::cuda_if_available(0).unwrap();
+        setup_cublas_lt_wrapper();
+
+        let inner_dtype = if dev.is_cuda() {
+            DType::BF16
+        } else {
+            DType::F32
+        };
+
+        let w = Tensor::arange(0f32, (OUT * IN) as f32, &dev)
+            .unwrap()
+            .to_dtype(inner_dtype)
+            .unwrap()
+            .reshape((OUT, IN))
+            .unwrap();
+        let b = Tensor::arange(0f32, OUT as f32, &dev)
+            .unwrap()
+            .to_dtype(inner_dtype)
+            .unwrap()
+            .reshape((OUT,))
+            .unwrap();
+
+        let xs = Tensor::arange(0f32, (INNER * IN) as f32, &dev)
+            .unwrap()
+            .to_dtype(inner_dtype)
+            .unwrap()
+            .reshape((1, INNER, IN))
+            .unwrap();
+
+        let lin = Linear::new(w.clone(), Some(b.clone()));
+        let truth_out = lin.forward(&xs).unwrap();
+        let truth_y = truth_out
+            .to_dtype(DType::F32)
+            .unwrap()
+            .to_vec3::<f32>()
+            .unwrap();
+
+        let fused = FusedBiasLinear { w, b };
+        let fused_out = fused.forward(&xs).unwrap();
+        let fused_y = fused_out
+            .to_dtype(DType::F32)
+            .unwrap()
+            .to_vec3::<f32>()
+            .unwrap();
+
+        assert_eq!(truth_out.shape(), fused_out.shape());
+        if truth_y != fused_y {
+            panic!(
+                "Truth does not match fused kernel. Diff fused - truth:\n{:#?}",
+                &(&fused_out - &truth_out)
+                    .unwrap()
+                    .i((0, 5..10, 0..5))
+                    .unwrap()
+                    .to_dtype(DType::F32)
+                    .unwrap()
+                    .to_vec2::<f32>()
+            )
+        }
+    }
+}
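
For context, a minimal usage sketch of the new layer (not part of the diff): it assumes it is compiled inside `mistralrs-core` so that `crate::layers::FusedBiasLinear` is in scope, and the `demo` function name is hypothetical. The fused cuBLASLt path is only taken on CUDA after the handle has been initialized (e.g. via `setup_cublas_lt_wrapper()`, as in the test above); otherwise the layer falls back to the plain `matmul + bias` branch.

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module};

use crate::layers::FusedBiasLinear;

fn demo() -> Result<()> {
    let dev = Device::cuda_if_available(0)?;

    // A plain candle Linear with a bias; weight shape is (out, in) = (8, 16).
    let w = Tensor::ones((8, 16), DType::F32, &dev)?;
    let b = Tensor::zeros((8,), DType::F32, &dev)?;
    let lin = Linear::new(w, Some(b));

    // The TryFrom impl rejects bias-less Linear layers, hence the `?`.
    let fused = FusedBiasLinear::try_from(lin)?;

    // Forward over a (batch, seq, in_dim) activation. On CUDA with the
    // cuBLASLt handle set up this hits the fused batch_matmul path;
    // otherwise it computes `x.matmul(w^T) + b`.
    let xs = Tensor::ones((1, 4, 16), DType::F32, &dev)?;
    let ys = fused.forward(&xs)?;
    assert_eq!(ys.dims(), &[1, 4, 8]);
    Ok(())
}
```

Note that the struct fields are `pub(crate)`, so outside this module the `TryFrom<Linear>` conversion is the intended way to construct a `FusedBiasLinear`.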