Remove some unnecessary &muts (#473)
* Remove some &muts from the cache traits

* Remove from model forward methods
EricLBuehler committed Jun 24, 2024
1 parent 3febef2 commit a1e41aa
Showing 29 changed files with 156 additions and 164 deletions.
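
Why the `&mut self` receivers were unnecessary: the per-model KV caches are already kept behind a lock and accessed through it (note the unchanged `let mut cache = self.cache.lock();` context line in quantized_phi3.rs below), so the forward passes only ever mutate the cache via interior mutability and a shared `&self` borrow is enough. The sketch below illustrates that pattern; it is a minimal, hypothetical stand-in (`Tensor`, `Cache`, `Layer`, `Model` here are simplified placeholders, not the real mistralrs-core or candle types):

use std::sync::Mutex;

// Hypothetical stand-ins for illustration only; the real crate uses
// candle tensors and its own cache types.
type Tensor = Vec<f32>;

// A per-layer KV cache behind a Mutex, so it can be updated through &self.
struct Cache {
    kv: Mutex<Option<(Tensor, Tensor)>>,
}

struct Layer {
    cache: Cache,
}

impl Layer {
    // Interior mutability: the lock hands out mutable access to the cached
    // keys/values even though the receiver is only `&self`.
    fn forward(&self, xs: &Tensor) -> Tensor {
        let mut kv = self.cache.kv.lock().unwrap();
        *kv = Some((xs.clone(), xs.clone())); // update cached K/V for this step
        xs.clone()
    }
}

struct Model {
    layers: Vec<Layer>,
}

impl Model {
    // With `&self` layer forwards, the loop can use `iter()` instead of
    // `iter_mut()`.
    fn forward(&self, xs: &Tensor) -> Tensor {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs);
        }
        xs
    }
}

fn main() {
    let model = Model {
        layers: vec![
            Layer { cache: Cache { kv: Mutex::new(None) } },
            Layer { cache: Cache { kv: Mutex::new(None) } },
        ],
    };
    println!("{:?}", model.forward(&vec![1.0, 2.0, 3.0]));
}

Once every forward method takes `&self`, the layer loops no longer need unique access either, which is why the hunks below also switch `self.layers.iter_mut()` to `self.layers.iter()`.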
12 changes: 6 additions & 6 deletions mistralrs-core/src/models/gemma.rs
@@ -161,7 +161,7 @@ impl Attention {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         attention_mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -277,7 +277,7 @@ impl DecoderLayer {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         attention_mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -374,7 +374,7 @@ impl Model {
     }
 
     pub fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -389,7 +389,7 @@ impl Model {
             xs.dtype(),
             self.layers[0].self_attn.num_heads,
         )?;
-        for (i, layer) in self.layers.iter_mut().enumerate() {
+        for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             xs = layer.forward(
                 &xs,
@@ -430,7 +430,7 @@ impl IsqModel for Model {
 
 impl NormalModel for Model {
     fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -445,7 +445,7 @@ impl NormalModel for Model {
         )
     }
     fn xlora_forward(
-        &mut self,
+        &self,
         _input_ids: &Tensor,
         _input_ids_full: &Tensor,
         _seqlen_offsets: &[usize],
6 changes: 3 additions & 3 deletions mistralrs-core/src/models/llama.rs
@@ -255,7 +255,7 @@ pub struct Llama {
 
 impl Llama {
     pub fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -374,7 +374,7 @@ impl IsqModel for Llama {
 
 impl NormalModel for Llama {
     fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -389,7 +389,7 @@ impl NormalModel for Llama {
         )
     }
     fn xlora_forward(
-        &mut self,
+        &self,
         _input_ids: &Tensor,
         _input_ids_full: &Tensor,
         _seqlen_offsets: &[usize],
14 changes: 7 additions & 7 deletions mistralrs-core/src/models/mistral.rs
@@ -113,7 +113,7 @@ impl Attention {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         attention_mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -238,7 +238,7 @@ impl DecoderLayer {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         attention_mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -356,7 +356,7 @@ impl Model {
     }
 
     pub fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -372,7 +372,7 @@ impl Model {
     }
 
     pub fn forward_embeds(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         input_embeds: Tensor,
         seqlen_offsets: &[usize],
@@ -388,7 +388,7 @@ impl Model {
             xs.dtype(),
             self.layers[0].self_attn.num_heads,
         )?;
-        for (i, layer) in self.layers.iter_mut().enumerate() {
+        for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             xs = layer.forward(
                 &xs,
@@ -429,7 +429,7 @@ impl IsqModel for Model {
 
 impl NormalModel for Model {
     fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -444,7 +444,7 @@ impl NormalModel for Model {
         )
     }
     fn xlora_forward(
-        &mut self,
+        &self,
         _input_ids: &Tensor,
         _input_ids_full: &Tensor,
         _seqlen_offsets: &[usize],
12 changes: 6 additions & 6 deletions mistralrs-core/src/models/mixtral.rs
@@ -77,7 +77,7 @@ impl Attention {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         attention_mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -339,7 +339,7 @@ impl DecoderLayer {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         attention_mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -443,7 +443,7 @@ impl Model {
     }
 
     pub fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -458,7 +458,7 @@ impl Model {
             xs.dtype(),
             self.layers[0].self_attn.num_heads,
         )?;
-        for (i, layer) in self.layers.iter_mut().enumerate() {
+        for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             xs = layer.forward(
                 &xs,
@@ -502,7 +502,7 @@ impl IsqModel for Model {
 
 impl NormalModel for Model {
     fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -517,7 +517,7 @@ impl NormalModel for Model {
         )
     }
     fn xlora_forward(
-        &mut self,
+        &self,
         _input_ids: &Tensor,
         _input_ids_full: &Tensor,
         _seqlen_offsets: &[usize],
12 changes: 6 additions & 6 deletions mistralrs-core/src/models/phi2.rs
@@ -130,7 +130,7 @@ impl Attention {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -253,7 +253,7 @@ impl DecoderLayer {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -347,7 +347,7 @@ impl Model {
     }
 
     pub fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -361,7 +361,7 @@ impl Model {
             xs.dtype(),
             self.layers[0].self_attn.num_heads,
         )?;
-        for (i, layer) in self.layers.iter_mut().enumerate() {
+        for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             xs = layer.forward(
                 &xs,
@@ -400,7 +400,7 @@ impl IsqModel for Model {
 
 impl NormalModel for Model {
     fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -415,7 +415,7 @@ impl NormalModel for Model {
         )
     }
     fn xlora_forward(
-        &mut self,
+        &self,
         _input_ids: &Tensor,
         _input_ids_full: &Tensor,
         _seqlen_offsets: &[usize],
12 changes: 6 additions & 6 deletions mistralrs-core/src/models/phi3.rs
@@ -91,7 +91,7 @@ impl Attention {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         attention_mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -254,7 +254,7 @@ impl DecoderLayer {
     }
 
     fn forward(
-        &mut self,
+        &self,
         xs: &Tensor,
         attention_mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -348,7 +348,7 @@ impl Model {
     }
 
     pub fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         position_ids: &[usize],
@@ -364,7 +364,7 @@ impl Model {
             self.layers[0].self_attn.num_heads,
         )?;
 
-        for (i, layer) in self.layers.iter_mut().enumerate() {
+        for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             xs = layer.forward(
                 &xs,
@@ -402,7 +402,7 @@ impl IsqModel for Model {
 
 impl NormalModel for Model {
     fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         _start_offsets_kernel: Tensor,
@@ -412,7 +412,7 @@ impl NormalModel for Model {
         self.forward(input_ids, seqlen_offsets, &position_ids, context_lens)
     }
     fn xlora_forward(
-        &mut self,
+        &self,
         _input_ids: &Tensor,
         _input_ids_full: &Tensor,
         _seqlen_offsets: &[usize],
6 changes: 3 additions & 3 deletions mistralrs-core/src/models/quantized_llama.rs
@@ -127,7 +127,7 @@ struct LayerWeights {
 
 impl LayerWeights {
     fn forward_attn(
-        &mut self,
+        &self,
         x: &Tensor,
         mask: Option<&Tensor>,
         start_offsets: &[usize],
@@ -484,7 +484,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
 
 impl ModelWeights {
     pub fn forward(
-        &mut self,
+        &self,
         x: &Tensor,
         start_offsets: &[usize],
         start_offsets_kernel: Tensor,
@@ -498,7 +498,7 @@ impl ModelWeights {
             DType::F32,
             self.layers[0].n_head,
         )?;
-        for (i, layer) in self.layers.iter_mut().enumerate() {
+        for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
                 layer_in = mapper.map(layer_in, i)?;
             }
6 changes: 3 additions & 3 deletions mistralrs-core/src/models/quantized_phi2.rs
@@ -58,7 +58,7 @@ impl LayerWeights {
     }
 
     fn forward_attn(
-        &mut self,
+        &self,
         x: &Tensor,
         mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -268,7 +268,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
 
 impl ModelWeights {
     pub fn forward(
-        &mut self,
+        &self,
         input_ids: &Tensor,
         seqlen_offsets: &[usize],
         context_lens: Vec<(usize, usize)>,
@@ -281,7 +281,7 @@ impl ModelWeights {
             DType::F32,
             self.layers[0].n_head,
         )?;
-        for (i, layer) in self.layers.iter_mut().enumerate() {
+        for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
             let residual = &xs;
             let xs_norm = xs.apply(&layer.attn_norm)?;
6 changes: 3 additions & 3 deletions mistralrs-core/src/models/quantized_phi3.rs
@@ -67,7 +67,7 @@ impl LayerWeights {
     }
 
     fn forward_attn(
-        &mut self,
+        &self,
         x: &Tensor,
         mask: Option<&Tensor>,
         seqlen_offsets: &[usize],
@@ -303,7 +303,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
 }
 
 impl ModelWeights {
-    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
+    pub fn forward(&self, input_ids: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
         let (_b_sz, seq_len) = input_ids.dims2()?;
         let mut xs = self.tok_embeddings.forward(input_ids)?;
         let mut cache = self.cache.lock();
@@ -314,7 +314,7 @@ impl ModelWeights {
             DType::F32,
             self.layers[0].n_head,
         )?;
-        for (i, layer) in self.layers.iter_mut().enumerate() {
+        for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
                 xs = mapper.map(xs, i)?;
             }