From 1decc55cb24cf81b8eb210624e89e36060a9c2d7 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Sun, 9 Jun 2024 05:56:03 -0400 Subject: [PATCH] Remove causal masks cache --- mistralrs-core/src/layers_masker.rs | 89 +++++++---------------------- 1 file changed, 20 insertions(+), 69 deletions(-) diff --git a/mistralrs-core/src/layers_masker.rs b/mistralrs-core/src/layers_masker.rs index 71f21081d..586d67b9f 100644 --- a/mistralrs-core/src/layers_masker.rs +++ b/mistralrs-core/src/layers_masker.rs @@ -1,13 +1,8 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use std::{collections::HashMap, ops::Add, sync::Mutex}; +use std::ops::Add; use candle_core::{DType, Device, Result, Tensor, WithDType}; -use once_cell::sync::Lazy; - -// (bs, tgt_len, past_kv_len) -type MaskKey = (usize, usize, usize); -static MASKS: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); // https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py pub struct CausalMasker; @@ -93,14 +88,8 @@ impl CausalMasker { if tgt_len == 1 { return Ok(None); } - let res = MASKS - .lock() - .unwrap() - .get(&(b_sz, tgt_len, past_kv_len)) - .cloned(); - let causal_mask = if let Some(mask) = res { - return Ok(Some(mask)); - } else { + + let causal_mask = { let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?; let mask = mask .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))? @@ -119,10 +108,6 @@ impl CausalMasker { f32::NEG_INFINITY, )?; - MASKS - .lock() - .unwrap() - .insert((b_sz, tgt_len, past_kv_len), mask.clone()); Ok(mask) }); let mask: Option = if let Some(mask) = causal_mask { @@ -150,14 +135,8 @@ impl CausalMasker { if tgt_len == 1 { return Ok(None); } - let res = MASKS - .lock() - .unwrap() - .get(&(b_sz, tgt_len, past_kv_len)) - .cloned(); - let causal_mask = if let Some(mask) = res { - return Ok(Some(mask)); - } else { + + let causal_mask = { let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?; let diagonal = past_kv_len as isize - sliding_window as isize - 1; let context_mask = apply_tril(&mask.ones_like()?, diagonal)?; @@ -180,10 +159,6 @@ impl CausalMasker { f32::NEG_INFINITY, )?; - MASKS - .lock() - .unwrap() - .insert((b_sz, tgt_len, past_kv_len), mask.clone()); Ok(mask) }); let mask: Option = if let Some(mask) = causal_mask { @@ -209,25 +184,13 @@ impl CausalMasker { if tgt_len == 1 { return Ok(None); } - let res = MASKS - .lock() - .unwrap() - .get(&(b_sz, tgt_len, past_kv_len)) - .cloned(); - if let Some(mask) = res { - Ok(Some(mask)) - } else { - let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?; - let mask = mask - .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))? - .to_dtype(DType::U8)?; - MASKS - .lock() - .unwrap() - .insert((b_sz, tgt_len, past_kv_len), mask.clone()); - Ok(Some(mask)) - } + let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?; + let mask = mask + .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))? + .to_dtype(DType::U8)?; + + Ok(Some(mask)) } #[deprecated( @@ -251,28 +214,16 @@ impl CausalMasker { if tgt_len == 1 { return Ok(None); } - let res = MASKS - .lock() - .unwrap() - .get(&(b_sz, tgt_len, past_kv_len)) - .cloned(); - if let Some(mask) = res { - Ok(Some(mask)) - } else { - let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?; - let diagonal = past_kv_len as isize - sliding_window as isize - 1; - let context_mask = apply_tril(&mask.ones_like()?, diagonal)?; - let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?; - let mask = mask - .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))? - .to_dtype(DType::U8)?; - MASKS - .lock() - .unwrap() - .insert((b_sz, tgt_len, past_kv_len), mask.clone()); - Ok(Some(mask)) - } + let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?; + let diagonal = past_kv_len as isize - sliding_window as isize - 1; + let context_mask = apply_tril(&mask.ones_like()?, diagonal)?; + let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?; + let mask = mask + .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))? + .to_dtype(DType::U8)?; + + Ok(Some(mask)) } pub fn apply_mask_one_and_zero(