Skip to content

Commit

Permalink
docs: docs.rs content
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Oct 2, 2023
1 parent 9ddf708 commit 31d60a9
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fastembed"
version = "0.4.0"
version = "0.5.0"
edition = "2021"
description = "Rust implementation of https://github.com/qdrant/fastembed"
license = "MIT"
Expand Down
152 changes: 117 additions & 35 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,80 @@
//! [FastEmbed](https://github.com/Anush008/fastembed-rs) - Fast, light, accurate library built for retrieval embedding generation.
//!
//! The library provides the FlagEmbedding struct to interface with the Flag embedding models.
//!
//! ### Instantiating [FlagEmbedding](crate::FlagEmbedding)
//! ```
//! use fastembed::{FlagEmbedding, InitOptions, EmbeddingModel, EmbeddingBase};
//!
//!# fn model_demo() -> anyhow::Result<()> {
//! // With default InitOptions
//! let model: FlagEmbedding = FlagEmbedding::try_new(Default::default())?;
//!
//! // List all supported models
//! dbg!(FlagEmbedding::list_supported_models());
//!
//! // With custom InitOptions
//! let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
//! model_name: EmbeddingModel::BGEBaseEN,
//! show_download_message: false,
//! ..Default::default()
//! })?;
//! # Ok(())
//! # }
//! ```
//! Find more info about the available options in the [InitOptions](crate::InitOptions) documentation.
//!
//! ### Embeddings generation
//!```
//!# use fastembed::{FlagEmbedding, InitOptions, EmbeddingModel, EmbeddingBase};
//!# fn embedding_demo() -> anyhow::Result<()> {
//!# let model: FlagEmbedding = FlagEmbedding::try_new(Default::default())?;
//! let documents = vec![
//! "passage: Hello, World!",
//! "query: Hello, World!",
//! "passage: This is an example passage.",
//! // You can leave out the prefix but it's recommended
//! "fastembed-rs is licensed under MIT"
//! ];
//!
//! // Generate embeddings with the default batch size, 256
//! let embeddings = model.embed(documents, None)?;
//!
//! println!("Embeddings length: {}", embeddings.len()); // -> Embeddings length: 4
//! # Ok(())
//! # }
//! ```
//!
//! ### Generate query and passage embeddings
//!```
//!# use fastembed::{FlagEmbedding, InitOptions, EmbeddingModel, EmbeddingBase};
//!# fn query_passage_demo() -> anyhow::Result<()> {
//!# let model: FlagEmbedding = FlagEmbedding::try_new(Default::default())?;
//! let passages = vec![
//! "This is the first passage. It contains provides more context for retrieval.",
//! "Here's the second passage, which is longer than the first one. It includes additional information.",
//! "And this is the third passage, the longest of all. It contains several sentences and is meant for more extensive testing."
//! ];
//!
//! // Generate embeddings for the passages
//! // The texts are prefixed with "passage" for better results
//! // The batch size is set to 1 for demonstration purposes
//! let embeddings = model.passage_embed(passages, Some(1))?;
//!
//! println!("Passage embeddings length: {}", embeddings.len()); // -> Embeddings length: 3
//!
//! let query = "What is the answer to this generic question?";
//!
//! // Generate embeddings for the query
//! // The text is prefixed with "query" for better retrieval
//! let query_embedding = model.query_embed(query)?;
//!
//! println!("Query embedding dimension: {}", query_embedding.len()); // -> Query embedding dimension: 768
//! # Ok(())
//! # }
//! ```
//!

use std::{
path::{Path, PathBuf},
thread::available_parallelism,
Expand All @@ -20,7 +97,9 @@ const DEFAULT_MAX_LENGTH: usize = 512;
const DEFAULT_CACHE_DIR: &str = "local_cache";
const DEFAULT_EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::BGESmallEN;

/// Type alias for the embedding vector
pub type Embedding = Vec<f32>;

type Tokenizer = tokenizers::TokenizerImpl<
tokenizers::ModelWrapper,
tokenizers::NormalizerWrapper,
Expand All @@ -29,6 +108,7 @@ type Tokenizer = tokenizers::TokenizerImpl<
tokenizers::DecoderWrapper,
>;

/// Enum for the available models
#[derive(Debug, Clone)]
pub enum EmbeddingModel {
/// Sentence Transformer model, MiniLM-L6-v2
Expand All @@ -52,6 +132,7 @@ impl ToString for EmbeddingModel {
}
}

/// Options for initializing the FlagEmbedding model
#[derive(Debug, Clone)]
pub struct InitOptions {
pub model_name: EmbeddingModel,
Expand All @@ -73,15 +154,20 @@ impl Default for InitOptions {
}
}

/// Data struct about the available models
#[derive(Debug, Clone)]
pub struct ModelInfo {
pub model: EmbeddingModel,
pub dim: usize,
pub description: String,
}

/// Base class for implemnting an embedding model
pub trait EmbeddingBase<S: AsRef<str>> {
/// The base embedding method for generating senytence embeddings
/// The base embedding method for generating sentence embeddings
fn embed(&self, texts: Vec<S>, batch_size: Option<usize>) -> Result<Vec<Embedding>>;

/// List the supported models by fastembed-rs
fn list_supported_models() -> Vec<ModelInfo>;

/// Generate sentence embeddings for passages, prefixed with "passage"
/// Generate sentence embeddings for passages, pre-fixed with "passage"
fn passage_embed(&self, texts: Vec<S>, batch_size: Option<usize>) -> Result<Vec<Embedding>>;

/// Generate embeddings for user queries pre-fixed with "query"
Expand Down Expand Up @@ -191,9 +277,35 @@ impl FlagEmbedding {

Ok(output_path)
}

/// Retrieve a list of supported modelsc
pub fn list_supported_models() -> Vec<ModelInfo> {
vec![ModelInfo {
model: EmbeddingModel::AllMiniLML6V2,
dim: 384,
description: String::from("Sentence Transformer model, MiniLM-L6-v2"),
},
ModelInfo {
model: EmbeddingModel::BGEBaseEN,
dim: 768,
description: String::from("Base English model"),
},
ModelInfo {
model: EmbeddingModel::BGESmallEN,
dim: 384,
description: String::from("Fast and Default English model"),
},
ModelInfo {
model: EmbeddingModel::MLE5Large,
dim: 1024,
description: String::from("Multilingual model, e5-large. Recommend using this model for non-English languages."),
}
]
}
}

/// EmbeddingBase implementation for FlagEmbedding
///
/// Generic type to accept String, &str, OsString, &OsStr
impl<S: AsRef<str> + Send + Sync> EmbeddingBase<S> for FlagEmbedding {
// Method to generate sentence embeddings for a Vec of str refs
Expand Down Expand Up @@ -294,36 +406,6 @@ impl<S: AsRef<str> + Send + Sync> EmbeddingBase<S> for FlagEmbedding {
let query_embedding = self.embed(vec![&query], None);
Ok(query_embedding?[0].to_owned())
}

fn list_supported_models() -> Vec<ModelInfo> {
vec![ModelInfo {
model: EmbeddingModel::AllMiniLML6V2,
dim: 384,
description: String::from("Sentence Transformer model, MiniLM-L6-v2"),
},
ModelInfo {
model: EmbeddingModel::BGEBaseEN,
dim: 768,
description: String::from("Base English model"),
},
ModelInfo {
model: EmbeddingModel::BGESmallEN,
dim: 384,
description: String::from("Fast and Default English model"),
},
ModelInfo {
model: EmbeddingModel::MLE5Large,
dim: 1024,
description: String::from("Multilingual model, e5-large. Recommend using this model for non-English languages."),
}
]
}
}

pub struct ModelInfo {
pub model: EmbeddingModel,
pub dim: usize,
pub description: String,
}

fn normalize(v: &mut [f32]) -> Vec<f32> {
Expand Down

0 comments on commit 31d60a9

Please sign in to comment.