Merge pull request #317 from EricLBuehler/matmul_f16
Matmul via f16 when possible
EricLBuehler committed May 16, 2024
2 parents d78313c + 9b7d5ca commit c3e176f
Showing 35 changed files with 507 additions and 494 deletions.
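In brief, the PR routes matrix products through a new MatMul helper (added in mistralrs-core/src/layers.rs below) that is gated by a crate-wide flag, so GEMMs can run via f16 kernels when the GPU supports reduced precision. A minimal sketch of the intended call pattern follows; it uses only names that appear in this diff, but the surrounding function, shapes, and scaling are illustrative assumptions, and because set_use_matmul_via_f16 is pub(crate) such code would have to live inside mistralrs-core.

use candle_core::{Result, Tensor};

use crate::layers::{set_use_matmul_via_f16, MatMul};

// Hypothetical attention-score helper, not part of the PR.
fn scaled_scores(q: &Tensor, k_t: &Tensor, head_dim: usize) -> Result<Tensor> {
    // Enable the f16 GEMM path for this sequence of matmuls.
    set_use_matmul_via_f16(true);
    // Casts both operands to f16, runs the GEMM, casts back, then divides by the scale.
    MatMul.matmul_affine_div(q, k_t, (head_dim as f64).sqrt())
}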
1 change: 0 additions & 1 deletion Cargo.toml
@@ -2,7 +2,6 @@
members = [
    "mistralrs-server",
    "mistralrs-core",
    "mistralrs-lora",
    "mistralrs-pyo3",
    "mistralrs",
    "mistralrs-bench",
1 change: 0 additions & 1 deletion mistralrs-core/Cargo.toml
@@ -26,7 +26,6 @@ tokenizers = "0.15.2"
tqdm = "0.7.0"
range-checked = { git = "https://github.com/EricLBuehler/range-checked.git", version = "0.1.0" }
chrono = "0.4.34"
mistralrs-lora = { version = "0.1.7", path = "../mistralrs-lora" }
minijinja = "1.0.12"
either.workspace = true
indexmap.workspace = true
159 changes: 155 additions & 4 deletions mistralrs-core/src/layers.rs
@@ -1,18 +1,29 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, ops::Mul, str::FromStr, sync::Mutex};
use std::{
    collections::HashMap,
    ops::Mul,
    str::FromStr,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, Mutex,
    },
};

use candle_core::{quantized::QTensor, DType, Device, IndexOp, Result, Tensor, WithDType};
use candle_core::{
    quantized::{gguf_file, QMatMul, QTensor},
    DType, Device, IndexOp, Result, Tensor, WithDType,
};
use candle_nn::{
    layer_norm::{RmsNormNonQuantized, RmsNormQuantized},
    Module, VarBuilder,
    Linear, Module, VarBuilder,
};
use once_cell::sync::Lazy;

static MASKS: Lazy<Mutex<HashMap<(usize, usize), Tensor>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));

use crate::models::phi3;
use crate::{models::phi3, INHIBIT_GEMM_F16};

#[derive(Debug, Clone)]
pub struct RmsNorm {
@@ -366,6 +377,146 @@ impl CausalMasker {
    }
}

/// Matrix multiplication, optionally performed via f16 to use the faster GEMM kernels.
pub struct MatMul;

/// Whether matmuls should be computed via f16.
pub(crate) static USE_MATMUL_VIA_F16: AtomicBool = AtomicBool::new(false);

pub(crate) fn set_use_matmul_via_f16(via_f16: bool) {
    if !INHIBIT_GEMM_F16.load(Ordering::Relaxed) {
        USE_MATMUL_VIA_F16.store(via_f16, Ordering::Relaxed)
    }
}
pub fn get_use_matmul_via_f16() -> bool {
    USE_MATMUL_VIA_F16.load(Ordering::Relaxed)
}

impl MatMul {
    /// Compute matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        if !get_use_matmul_via_f16() {
            return a.matmul(b);
        }
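        // Cast both operands to f16 for the fast GEMM, then cast the product back to the original dtype.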
        let original_dtype = a.dtype();
        a.to_dtype(DType::F16)?
            .matmul(&b.to_dtype(DType::F16)?)?
            .to_dtype(original_dtype)
    }

    /// Compute matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
    /// The result will be divided by the `scale` parameter in an affine division.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        // TODO(EricLBuehler): Optimize this by using the gemm parameter
        self.matmul(a, b)? / scale
    }

    /// Compute quantized matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        if get_use_matmul_via_f16() {
            matmul.forward_via_f16(x)
        } else {
            matmul.forward(x)
        }
    }
}

#[derive(Debug, Clone)]
pub struct QLinear {
    inner: QMatMul,
    bias: Option<Tensor>,
    dtype: DType,
}

impl QLinear {
    pub fn new<R: std::io::Read + std::io::Seek>(
        ct: &gguf_file::Content,
        r: &mut R,
        name: &str,
        device: &Device,
    ) -> Result<Self> {
        let w = ct.tensor(r, &format!("{name}.weight"), device)?;
        let b = ct.tensor(r, &format!("{name}.bias"), device)?;
        let inner = QMatMul::from_qtensor(w)?;
        let bias = b.dequantize(device)?;
        Ok(Self {
            inner,
            bias: Some(bias),
            dtype: DType::F32,
        })
    }

    pub fn from_linear(linear: Linear) -> Self {
        Self {
            inner: QMatMul::Tensor(linear.weight().clone()),
            bias: linear.bias().cloned(),
            dtype: if linear.weight().device().is_cuda() {
                DType::BF16
            } else {
                DType::F32
            },
        }
    }

    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
        let dtype = if w.device().is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };
        Self {
            inner: QMatMul::Tensor(w),
            bias: b,
            dtype,
        }
    }

    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
        if let Some(ref b) = b {
            assert_eq!(b.dtype(), DType::F32);
        }
        Self {
            inner: QMatMul::QTensor(Arc::new(w)),
            bias: b,
            dtype: DType::F32,
        }
    }

    pub fn inner(&mut self) -> &mut QMatMul {
        &mut self.inner
    }

    pub fn is_quant(&self) -> bool {
        matches!(self.inner, QMatMul::QTensor(_))
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl Module for QLinear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
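        // Quantized matmuls run in f32, so cast the activations up front and restore the layer dtype at the end.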
        let xs = if self.is_quant() {
            xs.to_dtype(DType::F32)?
        } else {
            xs.clone()
        };
        let forward_fn = if get_use_matmul_via_f16() {
            QMatMul::forward_via_f16
        } else {
            QMatMul::forward
        };
        if let Some(bias) = &self.bias {
            forward_fn(&self.inner, &xs)?
                .broadcast_add(bias)?
                .to_dtype(self.dtype)
        } else {
            forward_fn(&self.inner, &xs)?.to_dtype(self.dtype)
        }
    }
}

#[cfg(feature = "flash-attn")]
pub fn flash_attn(
    q: &Tensor,
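For orientation, here is a minimal sketch of how the QLinear wrapper added in layers.rs above can be driven through candle_nn's Module trait. It assumes code running inside mistralrs-core; the layer, shapes, and random weights are illustrative, not taken from the PR.

use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

use crate::layers::QLinear;

// Hypothetical projection layer, not part of the PR.
fn run_projection(device: &Device) -> Result<Tensor> {
    // from_linear keeps the weights unquantized; on CUDA the output dtype is BF16, otherwise F32.
    let weight = Tensor::randn(0f32, 1f32, (16, 32), device)?;
    let proj = QLinear::from_linear(Linear::new(weight, None));
    let xs = Tensor::randn(0f32, 1f32, (2, 4, 32), device)?;
    proj.forward(&xs)
}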
43 changes: 40 additions & 3 deletions mistralrs-core/src/lib.rs
@@ -5,20 +5,21 @@ use std::{
    error::Error,
    fs::OpenOptions,
    io::Write,
    sync::{Arc, Mutex},
    sync::{atomic::AtomicBool, Arc, Mutex},
    thread,
    time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};

use engine::Engine;
pub use engine::TERMINATE_ALL_NEXT_STEP;
pub use mistralrs_lora::Ordering;
pub use lora::Ordering;
pub use pipeline::Pipeline;

mod aici;
mod device_map;
mod engine;
mod lora;
mod model_loader;
pub use model_loader::{get_tgt_non_granular_index, LoaderBuilder};
mod model_selected;
@@ -133,10 +134,46 @@ impl MistralRsBuilder {
    }
}

pub(crate) static INHIBIT_GEMM_F16: AtomicBool = AtomicBool::new(false);

#[cfg(feature = "cuda")]
fn set_gemm_reduced_precision_f16() {
    candle_core::cuda::set_gemm_reduced_precision_f16(true);
    use candle_core::{DType, Device, Tensor};

    // NOTE(EricLBuehler): When we support multi-GPU inference, we should check for each gpu here
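    // Probe reduced-precision GEMM support (first BF16, then F16) by running a tiny matmul and checking for CUBLAS_STATUS_NOT_SUPPORTED.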
    let a = Tensor::zeros((2, 2), DType::BF16, &Device::new_cuda(0).unwrap()).unwrap();
    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
    match a.matmul(&a) {
        Ok(_) => (),
        Err(e) => match e {
            candle_core::Error::Cuda(e) => {
                let x = e.downcast::<candle_core::cuda::cudarc::cublas::result::CublasError>();
                if format!("{x:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
                    tracing::info!("GEMM reduced precision in BF16 not supported.");
                    candle_core::cuda::set_gemm_reduced_precision_bf16(false);
                    INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed);
                }
            }
            _ => (),
        },
    }

    let a = Tensor::zeros((2, 2), DType::F16, &Device::new_cuda(0).unwrap()).unwrap();
    candle_core::cuda::set_gemm_reduced_precision_f16(true);
    match a.matmul(&a) {
        Ok(_) => (),
        Err(e) => match e {
            candle_core::Error::Cuda(e) => {
                let x = e.downcast::<candle_core::cuda::cudarc::cublas::result::CublasError>();
                if format!("{x:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
                    tracing::info!("GEMM reduced precision in F16 not supported.");
                    candle_core::cuda::set_gemm_reduced_precision_f16(false);
                    INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed);
                }
            }
            _ => (),
        },
    }
}

#[cfg(not(feature = "cuda"))]
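Taken together with the layers.rs changes above, the startup probe and the per-call toggle interact as follows. This is a summary sketch using only names from the diff; the wrapper function itself is hypothetical.

use crate::layers::{get_use_matmul_via_f16, set_use_matmul_via_f16};

// Hypothetical helper, not part of the PR: if the cuBLAS probe set INHIBIT_GEMM_F16,
// requests to enable the f16 path are silently ignored.
fn try_enable_f16_gemm() -> bool {
    set_use_matmul_via_f16(true);
    get_use_matmul_via_f16()
}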
@@ -8,9 +8,11 @@ use candle_core::{
use candle_nn::{Linear, VarBuilder};
use either::Either;

use crate::{
    apply_scalings_to_x, get_maybe_topk_scalings, layer::QLinear, make_adapter, Adapter,
    AdapterSwapper, LinearLayerLike, LoraConfig, LoraLinearConfig, Merge,
use crate::layers::QLinear;

use super::{
    apply_scalings_to_x, get_maybe_topk_scalings, make_adapter, Adapter, AdapterSwapper,
    LinearLayerLike, LoraConfig, LoraLinearConfig, Merge,
};

#[derive(Debug)]
@@ -22,7 +24,6 @@ pub struct LoraLinear {
    layer_n: usize,
    merged: bool,
    adapters: HashMap<String, Adapter>,
    linear_config: LoraLinearConfig,
}

impl LoraLinear {
@@ -112,7 +113,6 @@ impl LoraLinear {
                layer_n,
                merged: false,
                adapters,
                linear_config: linear_config.clone(),
            })
        } else {
            Ok(LoraLinear {
@@ -123,7 +123,6 @@
                layer_n,
                merged: false,
                adapters,
                linear_config: linear_config.clone(),
            })
        }
@@ -158,22 +157,6 @@ impl AdapterSwapper for LoraLinear {
        }
        Ok(())
    }
    fn has_adapter(&self, adapter: String) -> bool {
        self.adapters.contains_key(&adapter)
    }
    fn load_new_adapter(
        &mut self,
        name: String,
        vb: VarBuilder,
        cfg: &LoraConfig,
        module_prefix: String,
    ) -> Result<()> {
        let a_vb = vb.set_prefix(&module_prefix).pp("lora_A".to_string());
        let b_vb = vb.set_prefix(&module_prefix).pp("lora_B".to_string());
        let adapter = make_adapter(a_vb, b_vb, cfg, &self.linear_config)?;
        self.adapters.insert(name.clone(), adapter);
        Ok(())
    }
    fn can_load(&self) -> bool {
        true
    }