diff --git a/README.md b/README.md
index abd261e..7d896f1 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,7 @@ converted:
 - `llama`
 - `mistral`
 - `falcon`
+- `bert`
 
 ```rust
 use candle_core::{DType, Device, Module, Result, Tensor};
diff --git a/candle-lora-transformers/examples/bert.rs b/candle-lora-transformers/examples/bert.rs
new file mode 100644
index 0000000..e1d32ee
--- /dev/null
+++ b/candle-lora-transformers/examples/bert.rs
@@ -0,0 +1,223 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle_lora::LoraConfig;
+use candle_lora_transformers::bert::{BertModel, Config, DTYPE};
+
+use anyhow::{Error as E, Result};
+use candle_core::{Tensor, Var};
+use candle_nn::{VarBuilder, VarMap};
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::{PaddingParams, Tokenizer};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// Use the pytorch weights rather than the safetensors ones.
+    #[arg(long)]
+    use_pth: bool,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
+
+    /// L2 normalization for embeddings.
+    #[arg(long, default_value = "true")]
+    normalize_embeddings: bool,
+}
+
+impl Args {
+    fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
+        let device = candle_examples::device(self.cpu)?;
+        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+        let default_revision = "refs/pr/21".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = {
+            let api = Api::new()?;
+            let api = api.repo(repo);
+            let config = api.get("config.json")?;
+            let tokenizer = api.get("tokenizer.json")?;
+            let weights = if self.use_pth {
+                api.get("pytorch_model.bin")?
+            } else {
+                api.get("model.safetensors")?
+            };
+            (config, tokenizer, weights)
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+        let map = VarMap::new();
+        if self.use_pth {
+            let mut ws = map.data().lock().unwrap();
+
+            let tensors = candle_core::pickle::PthTensors::new(&weights_filename)?;
+            for (name, _) in tensors.tensor_infos() {
+                let tensor = tensors
+                    .get(&name)?
+                    .expect("Tensor not found")
+                    .to_device(&device)?
+                    .to_dtype(DTYPE)?;
+                ws.insert(name.to_string(), Var::from_tensor(&tensor)?);
+            }
+        } else {
+            let mut ws = map.data().lock().unwrap();
+
+            let tensors =
+                unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[weights_filename])? };
+            for (name, _) in tensors.tensors() {
+                let tensor = tensors
+                    .load(&name, &device)?
+                    .to_device(&device)?
+                    .to_dtype(DTYPE)?;
+                ws.insert(name, Var::from_tensor(&tensor)?);
+            }
+        };
+
+        let vb = VarBuilder::from_varmap(&map, DTYPE, &device);
+
+        let loraconfig = LoraConfig::new(1, 1., None);
+        let model = BertModel::load(vb, &config, true, loraconfig)?;
+        Ok((model, tokenizer))
+    }
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    let start = std::time::Instant::now();
+
+    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let device = &model.device;
+
+    if let Some(prompt) = args.prompt {
+        let tokenizer = tokenizer
+            .with_padding(None)
+            .with_truncation(None)
+            .map_err(E::msg)?;
+        let tokens = tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("Loaded and encoded {:?}", start.elapsed());
+        for idx in 0..args.n {
+            let start = std::time::Instant::now();
+            let ys = model.forward(&token_ids, &token_type_ids)?;
+            if idx == 0 {
+                println!("{ys}");
+            }
+            println!("Took {:?}", start.elapsed());
+        }
+    } else {
+        let sentences = [
+            "The cat sits outside",
+            "A man is playing guitar",
+            "I love pasta",
+            "The new movie is awesome",
+            "The cat plays in the garden",
+            "A woman watches TV",
+            "The new movie is so great",
+            "Do you like pizza?",
+        ];
+        let n_sentences = sentences.len();
+        if let Some(pp) = tokenizer.get_padding_mut() {
+            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+        } else {
+            let pp = PaddingParams {
+                strategy: tokenizers::PaddingStrategy::BatchLongest,
+                ..Default::default()
+            };
+            tokenizer.with_padding(Some(pp));
+        }
+        let tokens = tokenizer
+            .encode_batch(sentences.to_vec(), true)
+            .map_err(E::msg)?;
+        let token_ids = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_ids().to_vec();
+                Ok(Tensor::new(tokens.as_slice(), device)?)
+            })
+            .collect::<candle_core::Result<Vec<_>>>()?;
+
+        let token_ids = Tensor::stack(&token_ids, 0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("running inference on batch {:?}", token_ids.shape());
+        let embeddings = model.forward(&token_ids, &token_type_ids)?;
+        println!("generated embeddings {:?}", embeddings.shape());
+        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+        let embeddings = if args.normalize_embeddings {
+            normalize_l2(&embeddings)?
+        } else {
+            embeddings
+        };
+        println!("pooled embeddings {:?}", embeddings.shape());
+
+        let mut similarities = vec![];
+        for i in 0..n_sentences {
+            let e_i = embeddings.get(i)?;
+            for j in (i + 1)..n_sentences {
+                let e_j = embeddings.get(j)?;
+                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+                similarities.push((cosine_similarity, i, j))
+            }
+        }
+        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+        for &(score, i, j) in similarities[..5].iter() {
+            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+        }
+    }
+    Ok(())
+}
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}
diff --git a/candle-lora-transformers/src/bert.rs b/candle-lora-transformers/src/bert.rs
new file mode 100644
index 0000000..4594f62
--- /dev/null
+++ b/candle-lora-transformers/src/bert.rs
@@ -0,0 +1,729 @@
+use candle_core::{DType, Device, Result, Tensor};
+use candle_lora::{
+    EmbeddingLayerLike, LinearLayerLike, LoraConfig, LoraEmbeddingConfig, LoraLinearConfig,
+};
+use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
+use candle_nn::{Embedding, Linear, VarBuilder};
+use serde::Deserialize;
+use std::ops::Deref;
+
+pub const DTYPE: DType = DType::F32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
+enum HiddenAct {
+    Gelu,
+    Relu,
+}
+
+struct HiddenActLayer {
+    act: HiddenAct,
+    span: tracing::Span,
+}
+
+impl HiddenActLayer {
+    fn new(act: HiddenAct) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
+        Self { act, span }
+    }
+
+    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
+        let _enter = self.span.enter();
+        match self.act {
+            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
+            HiddenAct::Gelu => xs.gelu_erf(),
+            HiddenAct::Relu => xs.relu(),
+        }
+    }
+}
+
+#[derive(Debug)]
+#[replace_layer_fields]
+#[derive(AutoLoraConvert)]
+pub struct BertLinear {
+    inner: Linear,
+    span: tracing::Span,
+}
+
+impl BertLinear {
+    pub fn new(
+        vb: VarBuilder,
+        weight: Tensor,
+        bias: Option<Tensor>,
+        merge: bool,
+        lora_config: LoraConfig,
+    ) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        let dims = weight.dims2().unwrap();
+        let linear_config = LoraLinearConfig::new(dims.1, dims.0);
+        let mut this = Self {
+            inner: Box::new(Linear::new(weight, bias)),
+            span,
+        };
+
+        if merge {
+            this.get_merged_lora_model(
+                lora_config,
+                &vb.pp("lora_linear"),
+                Some(linear_config),
+                None,
+                None,
+                None,
+            )
+        } else {
+            this.get_lora_model(
+                lora_config,
+                &vb.pp("lora_linear"),
+                Some(linear_config),
+                None,
+                None,
+                None,
+            )
+        }
+
+        this
+    }
+
+    pub fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    eps: f64,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+        Self {
+            weight,
+            bias,
+            eps,
+            span,
+        }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let x = x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
+        Ok(x)
+    }
+}
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
+enum PositionEmbeddingType {
+    #[default]
+    Absolute,
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    vocab_size: usize,
+    hidden_size: usize,
+    num_hidden_layers: usize,
+    num_attention_heads: usize,
+    intermediate_size: usize,
+    hidden_act: HiddenAct,
+    hidden_dropout_prob: f64,
+    max_position_embeddings: usize,
+    type_vocab_size: usize,
+    initializer_range: f64,
+    layer_norm_eps: f64,
+    pad_token_id: usize,
+    #[serde(default)]
+    position_embedding_type: PositionEmbeddingType,
+    #[serde(default)]
+    use_cache: bool,
+    classifier_dropout: Option<f64>,
+    model_type: Option<String>,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            vocab_size: 30522,
+            hidden_size: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            intermediate_size: 3072,
+            hidden_act: HiddenAct::Gelu,
+            hidden_dropout_prob: 0.1,
+            max_position_embeddings: 512,
+            type_vocab_size: 2,
+            initializer_range: 0.02,
+            layer_norm_eps: 1e-12,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Absolute,
+            use_cache: true,
+            classifier_dropout: None,
+            model_type: Some("bert".to_string()),
+        }
+    }
+}
+
+impl Config {
+    fn _all_mini_lm_l6_v2() -> Self {
+        // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
+        Self {
+            vocab_size: 30522,
+            hidden_size: 384,
+            num_hidden_layers: 6,
+            num_attention_heads: 12,
+            intermediate_size: 1536,
+            hidden_act: HiddenAct::Gelu,
+            hidden_dropout_prob: 0.1,
+            max_position_embeddings: 512,
+            type_vocab_size: 2,
+            initializer_range: 0.02,
+            layer_norm_eps: 1e-12,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Absolute,
+            use_cache: true,
+            classifier_dropout: None,
+            model_type: Some("bert".to_string()),
+        }
+    }
+}
+
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
+    Ok(Embedding::new(embeddings, hidden_size))
+}
+
+fn linear(
+    size1: usize,
+    size2: usize,
+    vb: VarBuilder,
+    merge: bool,
+    lora_config: LoraConfig,
+) -> Result<BertLinear> {
+    let weight = vb.get((size2, size1), "weight")?;
+    let bias = vb.get(size2, "bias")?;
+    Ok(BertLinear::new(
+        vb.pp("lora_linear"),
+        weight,
+        Some(bias),
+        merge,
+        lora_config,
+    ))
+}
+
+struct Dropout {
+    #[allow(dead_code)]
+    pr: f64,
+}
+
+impl Dropout {
+    fn new(pr: f64) -> Self {
+        Self { pr }
+    }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        // TODO
+        Ok(x.clone())
+    }
+}
+
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
+        (Ok(weight), Ok(bias)) => (weight, bias),
+        (Err(err), _) | (_, Err(err)) => {
+            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
+                (weight, bias)
+            } else {
+                return Err(err);
+            }
+        }
+    };
+    Ok(LayerNorm::new(weight, bias, eps))
+}
+
+#[replace_layer_fields]
+#[derive(AutoLoraConvert)]
+struct BertEmbedding {
+    inner: Embedding,
+}
+
+impl Deref for BertEmbedding {
+    type Target = Box<dyn EmbeddingLayerLike>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+impl BertEmbedding {
+    fn new(
+        vb: VarBuilder,
+        vocab_size: usize,
+        hidden_size: usize,
+        merge: bool,
+        lora_config: LoraConfig,
+    ) -> Result<Self> {
+        let mut this = Self {
+            inner: Box::new(embedding(vocab_size, hidden_size, vb.clone())?),
+        };
+
+        let embed_config = LoraEmbeddingConfig::new(vocab_size, hidden_size);
+
+        if merge {
+            this.get_merged_lora_model(
+                lora_config,
+                &vb.pp("lora_embed"),
+                None,
+                None,
+                None,
+                Some(embed_config),
+            )
+        } else {
+            this.get_lora_model(
+                lora_config,
+                &vb.pp("lora_embed"),
+                None,
+                None,
+                None,
+                Some(embed_config),
+            )
+        }
+
+        Ok(this)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
+struct BertEmbeddings {
+    word_embeddings: BertEmbedding,
+    position_embeddings: Option<BertEmbedding>,
+    token_type_embeddings: BertEmbedding,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
+    span: tracing::Span,
+}
+
+impl BertEmbeddings {
+    fn load(vb: VarBuilder, config: &Config, merge: bool, lora_config: LoraConfig) -> Result<Self> {
+        let word_embeddings = BertEmbedding::new(
+            vb.pp("word_embeddings"),
+            config.vocab_size,
+            config.hidden_size,
+            merge,
+            lora_config.clone(),
+        )?;
+        let position_embeddings = BertEmbedding::new(
+            vb.pp("position_embeddings"),
+            config.max_position_embeddings,
+            config.hidden_size,
+            merge,
+            lora_config.clone(),
+        )?;
+        let token_type_embeddings = BertEmbedding::new(
+            vb.pp("token_type_embeddings"),
+            config.type_vocab_size,
+            config.hidden_size,
+            merge,
+            lora_config.clone(),
+        )?;
+        let layer_norm = layer_norm(
+            config.hidden_size,
+            config.layer_norm_eps,
+            vb.pp("LayerNorm"),
+        )?;
+
+        Ok(Self {
+            word_embeddings,
+            position_embeddings: Some(position_embeddings),
+            token_type_embeddings,
+            layer_norm,
+            dropout: Dropout::new(config.hidden_dropout_prob),
+            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+        })
+    }
+
+    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (_bsize, seq_len) = input_ids.dims2()?;
+        let input_embeddings = self.word_embeddings.forward(input_ids)?;
+        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
+        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
+        if let Some(position_embeddings) = &self.position_embeddings {
+            // TODO: Proper absolute positions?
+            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
+            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
+            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
+        }
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        let embeddings = self.dropout.forward(&embeddings)?;
+        Ok(embeddings)
+    }
+}
+
+struct BertSelfAttention {
+    query: BertLinear,
+    key: BertLinear,
+    value: BertLinear,
+    dropout: Dropout,
+    num_attention_heads: usize,
+    attention_head_size: usize,
+    span: tracing::Span,
+    span_softmax: tracing::Span,
+}
+
+impl BertSelfAttention {
+    fn load(vb: VarBuilder, config: &Config, merge: bool, lora_config: LoraConfig) -> Result<Self> {
+        let attention_head_size = config.hidden_size / config.num_attention_heads;
+        let all_head_size = config.num_attention_heads * attention_head_size;
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        let hidden_size = config.hidden_size;
+        let query = linear(
+            hidden_size,
+            all_head_size,
+            vb.pp("query"),
+            merge,
+            lora_config.clone(),
+        )?;
+        let value = linear(
+            hidden_size,
+            all_head_size,
+            vb.pp("value"),
+            merge,
+            lora_config.clone(),
+        )?;
+        let key = linear(
+            hidden_size,
+            all_head_size,
+            vb.pp("key"),
+            merge,
+            lora_config.clone(),
+        )?;
+        Ok(Self {
+            query,
+            key,
+            value,
+            dropout,
+            num_attention_heads: config.num_attention_heads,
+            attention_head_size,
+            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
+        })
+    }
+
+    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut new_x_shape = xs.dims().to_vec();
+        new_x_shape.pop();
+        new_x_shape.push(self.num_attention_heads);
+        new_x_shape.push(self.attention_head_size);
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
+        xs.contiguous()
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let query_layer = self.query.forward(hidden_states)?;
+        let key_layer = self.key.forward(hidden_states)?;
+        let value_layer = self.value.forward(hidden_states)?;
+
+        let query_layer = self.transpose_for_scores(&query_layer)?;
+        let key_layer = self.transpose_for_scores(&key_layer)?;
+        let value_layer = self.transpose_for_scores(&value_layer)?;
+
+        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
+        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+        let attention_probs = {
+            let _enter_sm = self.span_softmax.enter();
+            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
+        };
+        let attention_probs = self.dropout.forward(&attention_probs)?;
+
+        let context_layer = attention_probs.matmul(&value_layer)?;
+        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
+        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
+        Ok(context_layer)
+    }
+}
+
+struct BertSelfOutput {
+    dense: BertLinear,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
+    span: tracing::Span,
+}
+
+impl BertSelfOutput {
+    fn load(vb: VarBuilder, config: &Config, merge: bool, lora_config: LoraConfig) -> Result<Self> {
+        let dense = linear(
+            config.hidden_size,
+            config.hidden_size,
+            vb.pp("dense"),
+            merge,
+            lora_config,
+        )?;
+        let layer_norm = layer_norm(
+            config.hidden_size,
+            config.layer_norm_eps,
+            vb.pp("LayerNorm"),
+        )?;
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        Ok(Self {
+            dense,
+            layer_norm,
+            dropout,
+            span: tracing::span!(tracing::Level::TRACE, "self-out"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = self.dropout.forward(&hidden_states)?;
+        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+struct BertAttention {
+    self_attention: BertSelfAttention,
+    self_output: BertSelfOutput,
+    span: tracing::Span,
+}
+
+impl BertAttention {
+    fn load(vb: VarBuilder, config: &Config, merge: bool, lora_config: LoraConfig) -> Result<Self> {
+        let self_attention =
+            BertSelfAttention::load(vb.pp("self"), config, merge, lora_config.clone())?;
+        let self_output = BertSelfOutput::load(vb.pp("output"), config, merge, lora_config)?;
+        Ok(Self {
+            self_attention,
+            self_output,
+            span: tracing::span!(tracing::Level::TRACE, "attn"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
+        Ok(attention_output)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+struct BertIntermediate {
+    dense: BertLinear,
+    intermediate_act: HiddenActLayer,
+    span: tracing::Span,
+}
+
+impl BertIntermediate {
+    fn load(vb: VarBuilder, config: &Config, merge: bool, lora_config: LoraConfig) -> Result<Self> {
+        let dense = linear(
+            config.hidden_size,
+            config.intermediate_size,
+            vb.pp("dense"),
+            merge,
+            lora_config,
+        )?;
+        Ok(Self {
+            dense,
+            intermediate_act: HiddenActLayer::new(config.hidden_act),
+            span: tracing::span!(tracing::Level::TRACE, "inter"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let ys = self.intermediate_act.forward(&hidden_states)?;
+        Ok(ys)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+struct BertOutput {
+    dense: BertLinear,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
+    span: tracing::Span,
+}
+
+impl BertOutput {
+    fn load(vb: VarBuilder, config: &Config, merge: bool, lora_config: LoraConfig) -> Result<Self> {
+        let dense = linear(
+            config.intermediate_size,
+            config.hidden_size,
+            vb.pp("dense"),
+            merge,
+            lora_config,
+        )?;
+        let layer_norm = layer_norm(
+            config.hidden_size,
+            config.layer_norm_eps,
+            vb.pp("LayerNorm"),
+        )?;
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        Ok(Self {
+            dense,
+            layer_norm,
+            dropout,
+            span: tracing::span!(tracing::Level::TRACE, "out"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = self.dropout.forward(&hidden_states)?;
+        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
+struct BertLayer {
+    attention: BertAttention,
+    intermediate: BertIntermediate,
+    output: BertOutput,
+    span: tracing::Span,
+}
+
+impl BertLayer {
+    fn load(vb: VarBuilder, config: &Config, merge: bool, lora_config: LoraConfig) -> Result<Self> {
+        let attention =
+            BertAttention::load(vb.pp("attention"), config, merge, lora_config.clone())?;
+        let intermediate =
+            BertIntermediate::load(vb.pp("intermediate"), config, merge, lora_config.clone())?;
+        let output = BertOutput::load(vb.pp("output"), config, merge, lora_config.clone())?;
+        Ok(Self {
+            attention,
+            intermediate,
+            output,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let attention_output = self.attention.forward(hidden_states)?;
+        // TODO: Support cross-attention?
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+        // TODO: Support something similar to `apply_chunking_to_forward`?
+        let intermediate_output = self.intermediate.forward(&attention_output)?;
+        let layer_output = self
+            .output
+            .forward(&intermediate_output, &attention_output)?;
+        Ok(layer_output)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
+struct BertEncoder {
+    layers: Vec<BertLayer>,
+    span: tracing::Span,
+}
+
+impl BertEncoder {
+    fn load(vb: VarBuilder, config: &Config, merge: bool, lora_config: LoraConfig) -> Result<Self> {
+        let layers = (0..config.num_hidden_layers)
+            .map(|index| {
+                BertLayer::load(
+                    vb.pp(&format!("layer.{index}")),
+                    config,
+                    merge,
+                    lora_config.clone(),
+                )
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "encoder");
+        Ok(BertEncoder { layers, span })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let mut hidden_states = hidden_states.clone();
+        // Use a loop rather than a fold as it's easier to modify when adding debug/...
+        for layer in self.layers.iter() {
+            hidden_states = layer.forward(&hidden_states)?
+        }
+        Ok(hidden_states)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
+pub struct BertModel {
+    embeddings: BertEmbeddings,
+    encoder: BertEncoder,
+    pub device: Device,
+    span: tracing::Span,
+}
+
+impl BertModel {
+    pub fn load(
+        vb: VarBuilder,
+        config: &Config,
+        merge: bool,
+        lora_config: LoraConfig,
+    ) -> Result<Self> {
+        let (embeddings, encoder) = match (
+            BertEmbeddings::load(vb.pp("embeddings"), config, merge, lora_config.clone()),
+            BertEncoder::load(vb.pp("encoder"), config, merge, lora_config.clone()),
+        ) {
+            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+            (Err(err), _) | (_, Err(err)) => {
+                if let Some(model_type) = &config.model_type {
+                    if let (Ok(embeddings), Ok(encoder)) = (
+                        BertEmbeddings::load(
+                            vb.pp(&format!("{model_type}.embeddings")),
+                            config,
+                            merge,
+                            lora_config.clone(),
+                        ),
+                        BertEncoder::load(
+                            vb.pp(&format!("{model_type}.encoder")),
+                            config,
+                            merge,
+                            lora_config,
+                        ),
+                    ) {
+                        (embeddings, encoder)
+                    } else {
+                        return Err(err);
+                    }
+                } else {
+                    return Err(err);
+                }
+            }
+        };
+        Ok(Self {
+            embeddings,
+            encoder,
+            device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
+        let sequence_output = self.encoder.forward(&embedding_output)?;
+        Ok(sequence_output)
+    }
+}
diff --git a/candle-lora-transformers/src/lib.rs b/candle-lora-transformers/src/lib.rs
index bfd9447..6051643 100644
--- a/candle-lora-transformers/src/lib.rs
+++ b/candle-lora-transformers/src/lib.rs
@@ -1,5 +1,6 @@
 pub mod with_tracing;
 
+pub mod bert;
 pub mod falcon;
 pub mod llama;
 pub mod mistral;
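
For reference, a minimal sketch of how the new `bert` module is meant to be driven once the weights are in a `VarMap` (as `examples/bert.rs` above does); the helper name `embed_prompt` is hypothetical and simply mirrors the calls made in the example:

```rust
use anyhow::{Error as E, Result};
use candle_core::{Device, Tensor};
use candle_lora::LoraConfig;
use candle_lora_transformers::bert::{BertModel, Config, DTYPE};
use candle_nn::{VarBuilder, VarMap};
use tokenizers::Tokenizer;

/// Hypothetical helper: convert BERT to its LoRA form and embed one prompt.
fn embed_prompt(
    map: &VarMap,
    config: &Config,
    tokenizer: &Tokenizer,
    prompt: &str,
    device: &Device,
) -> Result<Tensor> {
    let vb = VarBuilder::from_varmap(map, DTYPE, device);
    // Rank-1 LoRA with alpha = 1 and no dropout, as in the example above.
    let lora_config = LoraConfig::new(1, 1., None);
    // `true` selects the merged conversion (get_merged_lora_model) in the converted layers.
    let model = BertModel::load(vb, config, true, lora_config)?;

    let tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let token_type_ids = token_ids.zeros_like()?;
    Ok(model.forward(&token_ids, &token_type_ids)?)
}
```

Passing `false` for the `merge` flag instead routes through `get_lora_model`, keeping the LoRA adapters separate from the base weights rather than merging them.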