Remove causal masks cache (#412)
EricLBuehler committed Jun 9, 2024
1 parent f257423 commit dfb9dc5
Showing 1 changed file with 20 additions and 69 deletions.
mistralrs-core/src/layers_masker.rs: 89 changes (20 additions & 69 deletions)
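In short: the file-global `MASKS` cache, a `Lazy<Mutex<HashMap<(usize, usize, usize), Tensor>>>` keyed by `(b_sz, tgt_len, past_kv_len)`, is deleted along with its `once_cell`, `HashMap`, and `Mutex` imports, and every mask-building method below now constructs its mask directly on each call instead of consulting and populating the cache.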
@@ -1,13 +1,8 @@
 #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
 
-use std::{collections::HashMap, ops::Add, sync::Mutex};
+use std::ops::Add;
 
 use candle_core::{DType, Device, Result, Tensor, WithDType};
-use once_cell::sync::Lazy;
-
-// (bs, tgt_len, past_kv_len)
-type MaskKey = (usize, usize, usize);
-static MASKS: Lazy<Mutex<HashMap<MaskKey, Tensor>>> = Lazy::new(|| Mutex::new(HashMap::new()));
 
 // https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py
 pub struct CausalMasker;
@@ -93,14 +88,8 @@ impl CausalMasker
 if tgt_len == 1 {
     return Ok(None);
 }
-let res = MASKS
-    .lock()
-    .unwrap()
-    .get(&(b_sz, tgt_len, past_kv_len))
-    .cloned();
-let causal_mask = if let Some(mask) = res {
-    return Ok(Some(mask));
-} else {
+
+let causal_mask = {
     let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
     let mask = mask
         .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
@@ -119,10 +108,6 @@ impl CausalMasker
         f32::NEG_INFINITY,
     )?;
 
-    MASKS
-        .lock()
-        .unwrap()
-        .insert((b_sz, tgt_len, past_kv_len), mask.clone());
     Ok(mask)
 });
 let mask: Option<Tensor> = if let Some(mask) = causal_mask {
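For orientation, here is a minimal self-contained sketch of the work the un-cached path now repeats on every call: build a `(tgt_len, tgt_len + past_kv_len)` causal bias and broadcast it to `(b_sz, 1, tgt_len, tgt_len + past_kv_len)`. The helper name `build_causal_bias` and the direct f32 construction are illustrative assumptions; the actual code goes through `self.make_mask` and `masked_fill` as shown in the hunks above.

```rust
use candle_core::{Device, Result, Tensor};

/// Illustrative stand-in for the un-cached path: 0.0 where a query may attend,
/// f32::NEG_INFINITY where it may not (columns past the query's diagonal).
fn build_causal_bias(
    b_sz: usize,
    tgt_len: usize,
    past_kv_len: usize,
    device: &Device,
) -> Result<Tensor> {
    let src_len = tgt_len + past_kv_len;
    // Row i (query i of the new chunk) may see columns 0..=i + past_kv_len.
    let bias: Vec<f32> = (0..tgt_len)
        .flat_map(|i| {
            (0..src_len).map(move |j| if j > i + past_kv_len { f32::NEG_INFINITY } else { 0.0 })
        })
        .collect();
    Tensor::from_vec(bias, (tgt_len, src_len), device)?
        // Broadcast over the batch and head dimensions, matching the expand() above.
        .expand((b_sz, 1, tgt_len, src_len))
}

fn main() -> Result<()> {
    // Two sequences, three new tokens, four cached key/value positions.
    let bias = build_causal_bias(2, 3, 4, &Device::Cpu)?;
    assert_eq!(bias.dims4()?, (2, 1, 3, 7));
    Ok(())
}
```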
@@ -150,14 +135,8 @@ impl CausalMasker
 if tgt_len == 1 {
     return Ok(None);
 }
-let res = MASKS
-    .lock()
-    .unwrap()
-    .get(&(b_sz, tgt_len, past_kv_len))
-    .cloned();
-let causal_mask = if let Some(mask) = res {
-    return Ok(Some(mask));
-} else {
+
+let causal_mask = {
     let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
     let diagonal = past_kv_len as isize - sliding_window as isize - 1;
     let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
@@ -180,10 +159,6 @@ impl CausalMasker
         f32::NEG_INFINITY,
     )?;
 
-    MASKS
-        .lock()
-        .unwrap()
-        .insert((b_sz, tgt_len, past_kv_len), mask.clone());
     Ok(mask)
 });
 let mask: Option<Tensor> = if let Some(mask) = causal_mask {
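The sliding-window path above adds a second constraint: besides hiding future positions, a query may only look back `sliding_window` tokens, which is what the `apply_tril` call with `diagonal = past_kv_len - sliding_window - 1` followed by `masked_fill` encodes. Below is a rough, self-contained equivalent; the function name is a made-up illustration, and it folds both the causal and the out-of-window fill into `f32::NEG_INFINITY`, whereas the code above uses `f32::MIN` for the out-of-window part.

```rust
use candle_core::{Device, Result, Tensor};

/// Illustrative sliding-window causal bias: column j is hidden from query i if it
/// lies in the future (j > i + past_kv_len) or further back than the window
/// (j + sliding_window < i + past_kv_len, i.e. below the shifted tril diagonal).
fn build_sliding_window_bias(
    b_sz: usize,
    tgt_len: usize,
    past_kv_len: usize,
    sliding_window: usize,
    device: &Device,
) -> Result<Tensor> {
    let src_len = tgt_len + past_kv_len;
    let bias: Vec<f32> = (0..tgt_len)
        .flat_map(|i| {
            (0..src_len).map(move |j| {
                let future = j > i + past_kv_len;
                let too_old = j + sliding_window < i + past_kv_len;
                if future || too_old { f32::NEG_INFINITY } else { 0.0 }
            })
        })
        .collect();
    Tensor::from_vec(bias, (tgt_len, src_len), device)?
        .expand((b_sz, 1, tgt_len, src_len))
}
```

With `tgt_len = 4`, `past_kv_len = 0`, and `sliding_window = 2`, for example, query 3 can attend to columns 1 through 3 but not to column 0.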
@@ -209,25 +184,13 @@ impl CausalMasker
 if tgt_len == 1 {
     return Ok(None);
 }
-let res = MASKS
-    .lock()
-    .unwrap()
-    .get(&(b_sz, tgt_len, past_kv_len))
-    .cloned();
-if let Some(mask) = res {
-    Ok(Some(mask))
-} else {
-    let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
-    let mask = mask
-        .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
-        .to_dtype(DType::U8)?;
-
-    MASKS
-        .lock()
-        .unwrap()
-        .insert((b_sz, tgt_len, past_kv_len), mask.clone());
-    Ok(Some(mask))
-}
+
+let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
+let mask = mask
+    .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
+    .to_dtype(DType::U8)?;
+
+Ok(Some(mask))
 }
 
 #[deprecated(
@@ -251,28 +214,16 @@ impl CausalMasker
 if tgt_len == 1 {
     return Ok(None);
 }
-let res = MASKS
-    .lock()
-    .unwrap()
-    .get(&(b_sz, tgt_len, past_kv_len))
-    .cloned();
-if let Some(mask) = res {
-    Ok(Some(mask))
-} else {
-    let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
-    let diagonal = past_kv_len as isize - sliding_window as isize - 1;
-    let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
-    let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
-    let mask = mask
-        .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
-        .to_dtype(DType::U8)?;
-
-    MASKS
-        .lock()
-        .unwrap()
-        .insert((b_sz, tgt_len, past_kv_len), mask.clone());
-    Ok(Some(mask))
-}
+
+let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
+let diagonal = past_kv_len as isize - sliding_window as isize - 1;
+let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
+let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
+let mask = mask
+    .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
+    .to_dtype(DType::U8)?;
+
+Ok(Some(mask))
 }
 
 pub fn apply_mask_one_and_zero(