-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6903445
commit e95f31a
Showing
4 changed files
with
954 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
#[cfg(feature = "mkl")] | ||
extern crate intel_mkl_src; | ||
|
||
#[cfg(feature = "accelerate")] | ||
extern crate accelerate_src; | ||
use candle_lora::LoraConfig; | ||
use candle_lora_transformers::bert::{BertModel, Config, DTYPE}; | ||
|
||
use anyhow::{Error as E, Result}; | ||
use candle_core::{Tensor, Var}; | ||
use candle_nn::{VarBuilder, VarMap}; | ||
use clap::Parser; | ||
use hf_hub::{api::sync::Api, Repo, RepoType}; | ||
use tokenizers::{PaddingParams, Tokenizer}; | ||
|
||
#[derive(Parser, Debug)] | ||
#[command(author, version, about, long_about = None)] | ||
struct Args { | ||
/// Run on CPU rather than on GPU. | ||
#[arg(long)] | ||
cpu: bool, | ||
|
||
/// Enable tracing (generates a trace-timestamp.json file). | ||
#[arg(long)] | ||
tracing: bool, | ||
|
||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending | ||
#[arg(long)] | ||
model_id: Option<String>, | ||
|
||
#[arg(long)] | ||
revision: Option<String>, | ||
|
||
/// When set, compute embeddings for this prompt. | ||
#[arg(long)] | ||
prompt: Option<String>, | ||
|
||
/// Use the pytorch weights rather than the safetensors ones | ||
#[arg(long)] | ||
use_pth: bool, | ||
|
||
/// The number of times to run the prompt. | ||
#[arg(long, default_value = "1")] | ||
n: usize, | ||
|
||
/// L2 normalization for embeddings. | ||
#[arg(long, default_value = "true")] | ||
normalize_embeddings: bool, | ||
} | ||
|
||
impl Args { | ||
fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> { | ||
let device = candle_examples::device(self.cpu)?; | ||
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string(); | ||
let default_revision = "refs/pr/21".to_string(); | ||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) { | ||
(Some(model_id), Some(revision)) => (model_id, revision), | ||
(Some(model_id), None) => (model_id, "main".to_string()), | ||
(None, Some(revision)) => (default_model, revision), | ||
(None, None) => (default_model, default_revision), | ||
}; | ||
|
||
let repo = Repo::with_revision(model_id, RepoType::Model, revision); | ||
let (config_filename, tokenizer_filename, weights_filename) = { | ||
let api = Api::new()?; | ||
let api = api.repo(repo); | ||
let config = api.get("config.json")?; | ||
let tokenizer = api.get("tokenizer.json")?; | ||
let weights = if self.use_pth { | ||
api.get("pytorch_model.bin")? | ||
} else { | ||
api.get("model.safetensors")? | ||
}; | ||
(config, tokenizer, weights) | ||
}; | ||
let config = std::fs::read_to_string(config_filename)?; | ||
let config: Config = serde_json::from_str(&config)?; | ||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; | ||
|
||
let map = VarMap::new(); | ||
if self.use_pth { | ||
let mut ws = map.data().lock().unwrap(); | ||
|
||
let tensors = candle_core::pickle::PthTensors::new(&weights_filename)?; | ||
for (name, _) in tensors.tensor_infos() { | ||
let tensor = tensors | ||
.get(&name)? | ||
.expect("Tensor not found") | ||
.to_device(&device)? | ||
.to_dtype(DTYPE)?; | ||
ws.insert(name.to_string(), Var::from_tensor(&tensor)?); | ||
} | ||
} else { | ||
let mut ws = map.data().lock().unwrap(); | ||
|
||
let tensors = | ||
unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[weights_filename])? }; | ||
for (name, _) in tensors.tensors() { | ||
let tensor = tensors | ||
.load(&name, &device)? | ||
.to_device(&device)? | ||
.to_dtype(DTYPE)?; | ||
ws.insert(name, Var::from_tensor(&tensor)?); | ||
} | ||
}; | ||
|
||
let vb = VarBuilder::from_varmap(&map, DTYPE, &device); | ||
|
||
let loraconfig = LoraConfig::new(1, 1., None); | ||
let model = BertModel::load(vb, &config, true, loraconfig)?; | ||
Ok((model, tokenizer)) | ||
} | ||
} | ||
|
||
fn main() -> Result<()> { | ||
use tracing_chrome::ChromeLayerBuilder; | ||
use tracing_subscriber::prelude::*; | ||
|
||
let args = Args::parse(); | ||
let _guard = if args.tracing { | ||
println!("tracing..."); | ||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); | ||
tracing_subscriber::registry().with(chrome_layer).init(); | ||
Some(guard) | ||
} else { | ||
None | ||
}; | ||
let start = std::time::Instant::now(); | ||
|
||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?; | ||
let device = &model.device; | ||
|
||
if let Some(prompt) = args.prompt { | ||
let tokenizer = tokenizer | ||
.with_padding(None) | ||
.with_truncation(None) | ||
.map_err(E::msg)?; | ||
let tokens = tokenizer | ||
.encode(prompt, true) | ||
.map_err(E::msg)? | ||
.get_ids() | ||
.to_vec(); | ||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; | ||
let token_type_ids = token_ids.zeros_like()?; | ||
println!("Loaded and encoded {:?}", start.elapsed()); | ||
for idx in 0..args.n { | ||
let start = std::time::Instant::now(); | ||
let ys = model.forward(&token_ids, &token_type_ids)?; | ||
if idx == 0 { | ||
println!("{ys}"); | ||
} | ||
println!("Took {:?}", start.elapsed()); | ||
} | ||
} else { | ||
let sentences = [ | ||
"The cat sits outside", | ||
"A man is playing guitar", | ||
"I love pasta", | ||
"The new movie is awesome", | ||
"The cat plays in the garden", | ||
"A woman watches TV", | ||
"The new movie is so great", | ||
"Do you like pizza?", | ||
]; | ||
let n_sentences = sentences.len(); | ||
if let Some(pp) = tokenizer.get_padding_mut() { | ||
pp.strategy = tokenizers::PaddingStrategy::BatchLongest | ||
} else { | ||
let pp = PaddingParams { | ||
strategy: tokenizers::PaddingStrategy::BatchLongest, | ||
..Default::default() | ||
}; | ||
tokenizer.with_padding(Some(pp)); | ||
} | ||
let tokens = tokenizer | ||
.encode_batch(sentences.to_vec(), true) | ||
.map_err(E::msg)?; | ||
let token_ids = tokens | ||
.iter() | ||
.map(|tokens| { | ||
let tokens = tokens.get_ids().to_vec(); | ||
Ok(Tensor::new(tokens.as_slice(), device)?) | ||
}) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
let token_ids = Tensor::stack(&token_ids, 0)?; | ||
let token_type_ids = token_ids.zeros_like()?; | ||
println!("running inference on batch {:?}", token_ids.shape()); | ||
let embeddings = model.forward(&token_ids, &token_type_ids)?; | ||
println!("generated embeddings {:?}", embeddings.shape()); | ||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) | ||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; | ||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; | ||
let embeddings = if args.normalize_embeddings { | ||
normalize_l2(&embeddings)? | ||
} else { | ||
embeddings | ||
}; | ||
println!("pooled embeddings {:?}", embeddings.shape()); | ||
|
||
let mut similarities = vec![]; | ||
for i in 0..n_sentences { | ||
let e_i = embeddings.get(i)?; | ||
for j in (i + 1)..n_sentences { | ||
let e_j = embeddings.get(j)?; | ||
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?; | ||
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?; | ||
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?; | ||
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt(); | ||
similarities.push((cosine_similarity, i, j)) | ||
} | ||
} | ||
similarities.sort_by(|u, v| v.0.total_cmp(&u.0)); | ||
for &(score, i, j) in similarities[..5].iter() { | ||
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j]) | ||
} | ||
} | ||
Ok(()) | ||
} | ||
|
||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> { | ||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?) | ||
} |
Oops, something went wrong.