Merge branch 'master' into idefics2
EricLBuehler committed May 23, 2024
2 parents 5eb7f7a + 05f6ab8 commit 9585476
Showing 5 changed files with 172 additions and 60 deletions.
12 changes: 2 additions & 10 deletions mistralrs-bench/src/main.rs
@@ -2,7 +2,7 @@ use candle_core::Device;
 use clap::Parser;
 use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
 use mistralrs_core::{
-    Constraint, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder, ModelKind,
+    Constraint, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder,
     ModelSelected, NormalRequest, Request, RequestMessage, Response, SamplingParams,
     SchedulerMethod, TokenSource, Usage,
 };
@@ -313,15 +313,7 @@ fn main() -> anyhow::Result<()> {
     if use_flash_attn {
         info!("Using flash attention.");
     }
-    if use_flash_attn
-        && matches!(
-            loader.get_kind(),
-            ModelKind::QuantizedGGML
-                | ModelKind::QuantizedGGUF
-                | ModelKind::XLoraGGML
-                | ModelKind::XLoraGGUF
-        )
-    {
+    if use_flash_attn && loader.get_kind().is_quantized() {
         warn!("Using flash attention with a quantized model has no effect!")
     }
     info!("Model kind is: {}", loader.get_kind().to_string());
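Editor's note: the rewritten guard folds an open-coded `matches!` over four variants into a single `is_quantized()` query. The old pattern list did not include `ModelKind::LoraGGUF` or `ModelKind::LoraGGML`, so the helper also widens the warning to those quantized kinds. A minimal, self-contained sketch of the pattern; `Kind` here is a hypothetical stand-in, not the repository's `ModelKind`:

```rust
// Hypothetical stand-in enum; the repo derives its predicates instead.
#[derive(Debug)]
enum Kind {
    Plain,
    QuantizedGguf,
    QuantizedGgml,
}

impl Kind {
    // One place answers "is this quantized?", so call sites can no longer
    // drift out of sync as variants are added.
    fn is_quantized(&self) -> bool {
        matches!(self, Kind::QuantizedGguf | Kind::QuantizedGgml)
    }
}

fn main() {
    let kind = Kind::QuantizedGguf;
    let use_flash_attn = true;
    if use_flash_attn && kind.is_quantized() {
        println!("Using flash attention with a quantized model has no effect!");
    }
}
```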
1 change: 1 addition & 0 deletions mistralrs-core/Cargo.toml
@@ -55,6 +55,7 @@ once_cell.workspace = true
 toml = "0.8.12"
 strum = { version = "0.26", features = ["derive"] }
 image = "0.25.1"
+derive_more = { version = "0.99.17", default-features = false, features = ["from"] }
 
 [features]
 pyo3_macros = ["pyo3"]
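Editor's note: the new `derive_more` dependency (only its `from` feature) powers the `#[derive(derive_more::From)]` on `ModelKindB` below: each enum variant gets a `From` impl built from its field types, which is what makes conversions like `(AdapterKind::XLora, QuantizationKind::Ggml).into()` compile. A short sketch of the mechanism, using toy field types:

```rust
use derive_more::From;

// Toy enum: each variant gains `From` for its field type(s); multi-field
// variants convert from a tuple of those types.
#[derive(From, Debug)]
enum Kind {
    Quantized(u8),
    AdapterQuantized(char, u8),
}

fn main() {
    let q: Kind = 4u8.into();         // Kind::Quantized(4)
    let aq: Kind = ('x', 4u8).into(); // Kind::AdapterQuantized('x', 4)
    println!("{q:?} {aq:?}");
}
```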
158 changes: 145 additions & 13 deletions mistralrs-core/src/pipeline/mod.rs
@@ -241,23 +241,155 @@ pub enum ModelKind {
     },
 }
 
-impl Display for ModelKind {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+// TODO: Future replacement for `ModelKind` above:
+#[derive(Default, derive_more::From, strum::Display)]
+pub enum ModelKindB {
+    #[default]
+    #[strum(to_string = "normal (no quant, no adapters)")]
+    Plain,
+
+    #[strum(to_string = "quantized from {quant} (no adapters)")]
+    Quantized { quant: QuantizationKind },
+
+    #[strum(to_string = "{adapter}, (no quant)")]
+    Adapter { adapter: AdapterKind },
+
+    #[strum(to_string = "{adapter}, quantized from {quant}")]
+    AdapterQuantized {
+        adapter: AdapterKind,
+        quant: QuantizationKind,
+    },
+
+    // TODO: This would need to be changed later to reference `Self`, but the
+    // current way avoids having to handle the conversion logic with `ModelKind`.
+    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
+    Speculative {
+        target: Box<ModelKind>,
+        draft: Box<ModelKind>,
+    },
+}
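Editor's note: the `strum::Display` derive with `to_string` templates (a strum 0.26 feature) replaces the hand-written `Display` impl deleted above; `{quant}`, `{adapter}`, `{target}`, and `{draft}` interpolate each field through its own `Display`. A hedged sketch with toy variants:

```rust
// Toy types; only the to_string templating mechanism matches the diff.
#[derive(strum::Display)]
enum Kind {
    #[strum(to_string = "normal (no quant)")]
    Plain,
    #[strum(to_string = "quantized from {quant}")]
    Quantized { quant: Quant },
}

#[derive(strum::Display)]
#[strum(serialize_all = "kebab-case")]
enum Quant {
    Gguf,
}

fn main() {
    assert_eq!(Kind::Plain.to_string(), "normal (no quant)");
    let k = Kind::Quantized { quant: Quant::Gguf };
    assert_eq!(k.to_string(), "quantized from gguf");
}
```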
+
+#[derive(Clone, Copy, strum::Display, strum::EnumIs)]
+#[strum(serialize_all = "kebab-case")]
+pub enum QuantizationKind {
+    Ggml,
+    Gguf,
+}
+
+#[derive(Clone, Copy, strum::Display, strum::EnumIs)]
+#[strum(serialize_all = "kebab-case")]
+pub enum AdapterKind {
+    Lora,
+    XLora,
+}
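Editor's note: `strum::EnumIs` derives one `is_<variant>` predicate per variant, snake_cased (`XLora` yields `is_x_lora`); the `normal.rs` change further below relies on this via `is_adapted_and(|a| a.is_lora())`. A minimal sketch:

```rust
use strum::EnumIs;

// Same shape as the diff's AdapterKind, reduced to show the derive alone.
#[derive(Clone, Copy, EnumIs)]
enum AdapterKind {
    Lora,
    XLora,
}

fn main() {
    assert!(AdapterKind::Lora.is_lora());
    assert!(AdapterKind::XLora.is_x_lora());
    assert!(!AdapterKind::XLora.is_lora());
}
```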
+
+impl ModelKindB {
+    // Quantized helpers:
+    pub fn is_quantized(&self) -> bool {
+        self.quantized_kind().iter().any(|q| q.is_some())
+    }
+
+    pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
+        self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
+    }
+
+    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
+        use ModelKindB::*;
+
         match self {
-            ModelKind::Normal => write!(f, "normal (no quant, no adapters)"),
-            ModelKind::QuantizedGGML => write!(f, "quantized from ggml (no adapters)"),
-            ModelKind::QuantizedGGUF => write!(f, "quantized from gguf (no adapters)"),
-            ModelKind::XLoraNormal => write!(f, "x-lora (no quant)"),
-            ModelKind::XLoraGGML => write!(f, "x-lora, quantized from ggml"),
-            ModelKind::XLoraGGUF => write!(f, "x-lora, quantized from gguf"),
-            ModelKind::LoraGGUF => write!(f, "lora, quantized from gguf"),
-            ModelKind::LoraGGML => write!(f, "lora, quantized from ggml"),
-            ModelKind::LoraNormal => write!(f, "lora (no quant)"),
-            ModelKind::Speculative { target, draft } => {
-                write!(f, "speculative: target: `{target}`, draft: `{draft}`")
+            Plain | Adapter { .. } => vec![None],
+            Quantized { quant } | AdapterQuantized { quant, .. } => vec![Some(*quant)],
+            Speculative { target, draft } => {
+                let t = ModelKindB::from(*target.clone());
+                let d = ModelKindB::from(*draft.clone());
+
+                [t.quantized_kind(), d.quantized_kind()].concat()
             }
         }
     }
+
+    // Adapter helpers:
+    pub fn is_adapted(&self) -> bool {
+        self.adapted_kind().iter().any(|a| a.is_some())
+    }
+
+    pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
+        self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
+    }
+
+    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
+        use ModelKindB::*;
+
+        match self {
+            Plain | Quantized { .. } => vec![None],
+            Adapter { adapter } | AdapterQuantized { adapter, .. } => vec![Some(*adapter)],
+            Speculative { target, draft } => {
+                let t = ModelKindB::from(*target.clone());
+                let d = ModelKindB::from(*draft.clone());
+
+                [t.adapted_kind(), d.adapted_kind()].concat()
+            }
+        }
+    }
 }
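Editor's note: the helpers above share one pattern: flatten each component's optional kind into a `Vec` (so a speculative pair reports both target and draft), then ask "any `Some`?" or "any `Some` passing `f`?" via `Option::is_some_and`. A reduced sketch of that aggregate-and-query shape, with a hypothetical free function standing in for the methods:

```rust
// Hypothetical helper mirroring the any/is_some_and pattern in the diff.
fn any_some_and<T: Copy>(kinds: &[Option<T>], mut f: impl FnMut(T) -> bool) -> bool {
    kinds.iter().any(|k| k.is_some_and(&mut f))
}

fn main() {
    // E.g. a speculative pair: unquantized target (None), quantized draft (Some).
    let kinds = [None, Some(1u8)];
    assert!(any_some_and(&kinds, |_| true));   // quantized anywhere?
    assert!(any_some_and(&kinds, |q| q == 1)); // quantized as kind 1?
    assert!(!any_some_and(&kinds, |q| q == 2));
}
```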
+
+// TODO: Temporary compatibility layers follow (until a follow-up PR introduces a breaking change)
+impl Display for ModelKind {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", ModelKindB::from(self.clone()))
+    }
+}
+
+// Delegate to `ModelKindB` methods:
+impl ModelKind {
+    // Quantized helpers:
+    pub fn is_quantized(&self) -> bool {
+        let k = ModelKindB::from(self.clone());
+        k.is_quantized()
+    }
+
+    pub fn is_quantized_and(&self, f: impl FnMut(QuantizationKind) -> bool) -> bool {
+        let k = ModelKindB::from(self.clone());
+        k.is_quantized_and(f)
+    }
+
+    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
+        let k = ModelKindB::from(self.clone());
+        k.quantized_kind()
+    }
+
+    // Adapter helpers:
+    pub fn is_adapted(&self) -> bool {
+        let k = ModelKindB::from(self.clone());
+        k.is_adapted()
+    }
+
+    pub fn is_adapted_and(&self, f: impl FnMut(AdapterKind) -> bool) -> bool {
+        let k = ModelKindB::from(self.clone());
+        k.is_adapted_and(f)
+    }
+
+    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
+        let k = ModelKindB::from(self.clone());
+        k.adapted_kind()
+    }
+}
+
+impl From<ModelKind> for ModelKindB {
+    fn from(kind: ModelKind) -> Self {
+        match kind {
+            ModelKind::Normal => ModelKindB::Plain,
+            ModelKind::QuantizedGGML => (QuantizationKind::Ggml).into(),
+            ModelKind::QuantizedGGUF => (QuantizationKind::Gguf).into(),
+            ModelKind::XLoraNormal => (AdapterKind::XLora).into(),
+            ModelKind::XLoraGGML => (AdapterKind::XLora, QuantizationKind::Ggml).into(),
+            ModelKind::XLoraGGUF => (AdapterKind::XLora, QuantizationKind::Gguf).into(),
+            ModelKind::LoraNormal => (AdapterKind::Lora).into(),
+            ModelKind::LoraGGML => (AdapterKind::Lora, QuantizationKind::Ggml).into(),
+            ModelKind::LoraGGUF => (AdapterKind::Lora, QuantizationKind::Gguf).into(),
+            ModelKind::Speculative { target, draft } => (target, draft).into(),
+        }
+    }
+}
 
 /// The `Loader` trait abstracts the loading process. The primary entrypoint is the
49 changes: 22 additions & 27 deletions mistralrs-core/src/pipeline/normal.rs
@@ -235,10 +235,9 @@ impl Loader for NormalLoader {
             Device::Cpu
         };
 
-        let mut is_lora = false;
+        let is_lora = self.kind.is_adapted_and(|a| a.is_lora());
 
         let mut model = match self.kind {
-            ModelKind::QuantizedGGUF => unreachable!(),
-            ModelKind::QuantizedGGML => unreachable!(),
             ModelKind::Normal => normal_model_loader!(
                 paths,
                 dtype,
@@ -265,30 +264,26 @@
                 in_situ_quant.is_some(),
                 device.clone()
             ),
-            ModelKind::LoraNormal => {
-                is_lora = true;
-                lora_model_loader!(
-                    paths,
-                    dtype,
-                    default_dtype,
-                    &load_device,
-                    config,
-                    self.inner,
-                    self.config.use_flash_attn,
-                    silent,
-                    mapper,
-                    in_situ_quant.is_some(),
-                    device.clone()
-                )
-            }
-            ModelKind::XLoraGGUF => unreachable!(),
-            ModelKind::XLoraGGML => unreachable!(),
-            ModelKind::LoraGGUF => unreachable!(),
-            ModelKind::LoraGGML => unreachable!(),
-            ModelKind::Speculative {
-                target: _,
-                draft: _,
-            } => unreachable!(),
+            ModelKind::LoraNormal => lora_model_loader!(
+                paths,
+                dtype,
+                default_dtype,
+                &load_device,
+                config,
+                self.inner,
+                self.config.use_flash_attn,
+                silent,
+                mapper,
+                in_situ_quant.is_some(),
+                device.clone()
+            ),
+            ModelKind::QuantizedGGUF
+            | ModelKind::QuantizedGGML
+            | ModelKind::XLoraGGUF
+            | ModelKind::XLoraGGML
+            | ModelKind::LoraGGUF
+            | ModelKind::LoraGGML
+            | ModelKind::Speculative { .. } => unreachable!(),
         };
 
         let tokenizer = get_tokenizer(paths.get_tokenizer_filename())?;
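Editor's note: the `normal.rs` change removes a mutable `is_lora` flag that was set inside one match arm; the flag is now derived up front from the model kind via the new `is_adapted_and` helper, which also lets the `LoraNormal` arm collapse to a plain macro call and the impossible kinds merge into one `unreachable!()` arm. A toy sketch of the flag refactor (stand-in enum, not the repo's loader):

```rust
// Hypothetical Kind; only the derive-the-flag-before-matching shape matters.
#[derive(Clone, Copy)]
enum Kind {
    Plain,
    Lora,
    XLora,
}

impl Kind {
    fn is_lora(&self) -> bool {
        matches!(self, Kind::Lora)
    }
}

fn main() {
    let kind = Kind::Lora;
    // Immutable and decided in one place, instead of `let mut is_lora = false;`
    // mutated inside a match arm.
    let is_lora = kind.is_lora();
    let model = match kind {
        Kind::Plain => "plain model",
        Kind::Lora => "lora model",
        Kind::XLora => "x-lora model",
    };
    println!("loaded {model}; is_lora = {is_lora}");
}
```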
12 changes: 2 additions & 10 deletions mistralrs-server/src/main.rs
@@ -9,7 +9,7 @@ use candle_core::{quantized::GgmlDType, Device};
 use clap::Parser;
 use mistralrs_core::{
     get_tgt_non_granular_index, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs,
-    MistralRsBuilder, ModelKind, ModelSelected, Request, SchedulerMethod, TokenSource,
+    MistralRsBuilder, ModelSelected, Request, SchedulerMethod, TokenSource,
 };
 use openai::{ChatCompletionRequest, Message, ModelObjects, StopTokens};
 use serde::{Deserialize, Serialize};
@@ -269,15 +269,7 @@ async fn main() -> Result<()> {
     if use_flash_attn {
         info!("Using flash attention.");
     }
-    if use_flash_attn
-        && matches!(
-            loader.get_kind(),
-            ModelKind::QuantizedGGML
-                | ModelKind::QuantizedGGUF
-                | ModelKind::XLoraGGML
-                | ModelKind::XLoraGGUF
-        )
-    {
+    if use_flash_attn && loader.get_kind().is_quantized() {
         warn!("Using flash attention with a quantized model has no effect!")
     }
     info!("Model kind is: {}", loader.get_kind().to_string());
