Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove causal masks cache #412

Merged
merged 1 commit into from
Jun 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 20 additions & 69 deletions mistralrs-core/src/layers_masker.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, ops::Add, sync::Mutex};
use std::ops::Add;

use candle_core::{DType, Device, Result, Tensor, WithDType};
use once_cell::sync::Lazy;

// (bs, tgt_len, past_kv_len)
type MaskKey = (usize, usize, usize);
static MASKS: Lazy<Mutex<HashMap<MaskKey, Tensor>>> = Lazy::new(|| Mutex::new(HashMap::new()));

// https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py
pub struct CausalMasker;
Expand Down Expand Up @@ -93,14 +88,8 @@ impl CausalMasker {
if tgt_len == 1 {
return Ok(None);
}
let res = MASKS
.lock()
.unwrap()
.get(&(b_sz, tgt_len, past_kv_len))
.cloned();
let causal_mask = if let Some(mask) = res {
return Ok(Some(mask));
} else {

let causal_mask = {
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
Expand All @@ -119,10 +108,6 @@ impl CausalMasker {
f32::NEG_INFINITY,
)?;

MASKS
.lock()
.unwrap()
.insert((b_sz, tgt_len, past_kv_len), mask.clone());
Ok(mask)
});
let mask: Option<Tensor> = if let Some(mask) = causal_mask {
Expand Down Expand Up @@ -150,14 +135,8 @@ impl CausalMasker {
if tgt_len == 1 {
return Ok(None);
}
let res = MASKS
.lock()
.unwrap()
.get(&(b_sz, tgt_len, past_kv_len))
.cloned();
let causal_mask = if let Some(mask) = res {
return Ok(Some(mask));
} else {

let causal_mask = {
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let diagonal = past_kv_len as isize - sliding_window as isize - 1;
let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
Expand All @@ -180,10 +159,6 @@ impl CausalMasker {
f32::NEG_INFINITY,
)?;

MASKS
.lock()
.unwrap()
.insert((b_sz, tgt_len, past_kv_len), mask.clone());
Ok(mask)
});
let mask: Option<Tensor> = if let Some(mask) = causal_mask {
Expand All @@ -209,25 +184,13 @@ impl CausalMasker {
if tgt_len == 1 {
return Ok(None);
}
let res = MASKS
.lock()
.unwrap()
.get(&(b_sz, tgt_len, past_kv_len))
.cloned();
if let Some(mask) = res {
Ok(Some(mask))
} else {
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
.to_dtype(DType::U8)?;

MASKS
.lock()
.unwrap()
.insert((b_sz, tgt_len, past_kv_len), mask.clone());
Ok(Some(mask))
}
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
.to_dtype(DType::U8)?;

Ok(Some(mask))
}

#[deprecated(
Expand All @@ -251,28 +214,16 @@ impl CausalMasker {
if tgt_len == 1 {
return Ok(None);
}
let res = MASKS
.lock()
.unwrap()
.get(&(b_sz, tgt_len, past_kv_len))
.cloned();
if let Some(mask) = res {
Ok(Some(mask))
} else {
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let diagonal = past_kv_len as isize - sliding_window as isize - 1;
let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
.to_dtype(DType::U8)?;

MASKS
.lock()
.unwrap()
.insert((b_sz, tgt_len, past_kv_len), mask.clone());
Ok(Some(mask))
}
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let diagonal = past_kv_len as isize - sliding_window as isize - 1;
let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
let mask = mask
.expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
.to_dtype(DType::U8)?;

Ok(Some(mask))
}

pub fn apply_mask_one_and_zero(
Expand Down
Loading