feat: New FlagEmbedding models (#5)
* feat: new models

* chore: formatting

* chore: default to fast-bge-small-en-v1.5

* docs: Default model README.md

* chore: update list_supported_models()
Anush008 committed Oct 18, 2023
1 parent d4dd141 commit 335937b
Showing 3 changed files with 78 additions and 85 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions README.md
@@ -22,8 +22,11 @@ The default embedding supports "query" and "passage" prefixes for the input text

## 🤖 Models

- [**BAAI/bge-base-en**](https://huggingface.co/BAAI/bge-base-en)
- [**BAAI/bge-base-en-v1.5**](https://huggingface.co/BAAI/bge-base-en-v1.5)
- [**BAAI/bge-small-en**](https://huggingface.co/BAAI/bge-small-en)
- [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default
- [**BAAI/bge-small-zh-v1.5**](https://huggingface.co/BAAI/bge-small-zh-v1.5)
- [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
- [**intfloat/multilingual-e5-large**](https://huggingface.co/intfloat/multilingual-e5-large)
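
For reference, selecting one of these models from Rust looks like the sketch below. It assumes the crate is imported as `fastembed` and uses the `InitOptions`/`FlagEmbedding` API from `src/lib.rs` in this commit; the `query:` prefix follows the input-text convention noted above.

```rust
use fastembed::{EmbeddingModel, FlagEmbedding, InitOptions};

fn main() {
    // Pick the new default explicitly; omitting `model_name` now
    // falls back to fast-bge-small-en-v1.5 as well.
    let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
        model_name: EmbeddingModel::BGESmallENV15,
        ..Default::default()
    })
    .unwrap();

    // The default embedding supports "query" and "passage" prefixes.
    let documents = vec!["query: hello world"];
    let embeddings = model.embed(documents, None).unwrap();
    println!("dim = {}", embeddings[0].len()); // 384 for bge-small-en-v1.5
}
```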

158 changes: 74 additions & 84 deletions src/lib.rs
@@ -96,7 +96,7 @@ use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, TruncationParams};
const DEFAULT_BATCH_SIZE: usize = 256;
const DEFAULT_MAX_LENGTH: usize = 512;
const DEFAULT_CACHE_DIR: &str = "local_cache";
-const DEFAULT_EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::BGESmallEN;
+const DEFAULT_EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::BGESmallENV15;

/// Type alias for the embedding vector
pub type Embedding = Vec<f32>;
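
The one-line change above is the commit's main behavioral change: initialization without an explicit `model_name` now resolves to the v1.5 small English model. A minimal sketch of that equivalence, assuming `InitOptions` fills `model_name` from `DEFAULT_EMBEDDING_MODEL` in its `Default` impl:

```rust
use fastembed::{EmbeddingModel, FlagEmbedding, InitOptions};

fn main() {
    // After this commit, these two initializations load the same model:
    let _default = FlagEmbedding::try_new(InitOptions::default()).unwrap();
    let _explicit = FlagEmbedding::try_new(InitOptions {
        model_name: EmbeddingModel::BGESmallENV15, // the new DEFAULT_EMBEDDING_MODEL
        ..Default::default()
    })
    .unwrap();
}
```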
@@ -116,8 +116,14 @@ pub enum EmbeddingModel {
    AllMiniLML6V2,
    /// Base English model
    BGEBaseEN,
    /// v1.5 release of the Base English model
    BGEBaseENV15,
    /// Fast and Default English model
    BGESmallEN,
    /// v1.5 release of the BGESmallEN model
    BGESmallENV15,
    /// v1.5 release of the Fast Chinese model
    BGESmallZH,
    /// Multilingual model, e5-large. Recommend using this model for non-English languages.
    MLE5Large,
}
@@ -127,7 +133,10 @@ impl ToString for EmbeddingModel {
        match self {
            EmbeddingModel::AllMiniLML6V2 => String::from("fast-all-MiniLM-L6-v2"),
            EmbeddingModel::BGEBaseEN => String::from("fast-bge-base-en"),
            EmbeddingModel::BGEBaseENV15 => String::from("fast-bge-base-en-v1.5"),
            EmbeddingModel::BGESmallEN => String::from("fast-bge-small-en"),
            EmbeddingModel::BGESmallENV15 => String::from("fast-bge-small-en-v1.5"),
            EmbeddingModel::BGESmallZH => String::from("fast-bge-small-zh-v1.5"),
            EmbeddingModel::MLE5Large => String::from("fast-multilingual-e5-large"),
        }
    }
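
The `fast-*` strings above are the names the variants resolve to, presumably used as the model download and cache identifiers. A small sketch pinning the new mappings (values copied from the match arms); note that the `BGESmallZH` variant has no `V15` suffix even though its string points at the v1.5 release:

```rust
use fastembed::EmbeddingModel;

fn main() {
    assert_eq!(
        EmbeddingModel::BGEBaseENV15.to_string(),
        "fast-bge-base-en-v1.5"
    );
    assert_eq!(
        EmbeddingModel::BGESmallENV15.to_string(),
        "fast-bge-small-en-v1.5"
    );
    // The variant name lacks a V15 suffix, but it maps to the v1.5 release.
    assert_eq!(
        EmbeddingModel::BGESmallZH.to_string(),
        "fast-bge-small-zh-v1.5"
    );
}
```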
@@ -352,10 +361,25 @@ impl FlagEmbedding {
                dim: 768,
                description: String::from("Base English model"),
            },
            ModelInfo {
                model: EmbeddingModel::BGEBaseENV15,
                dim: 768,
                description: String::from("v1.5 release of the base English model"),
            },
            ModelInfo {
                model: EmbeddingModel::BGESmallEN,
                dim: 384,
-                description: String::from("Fast and Default English model"),
+                description: String::from("Fast English model"),
            },
            ModelInfo {
                model: EmbeddingModel::BGESmallENV15,
                dim: 384,
                description: String::from("v1.5 release of the fast and default English model"),
            },
            ModelInfo {
                model: EmbeddingModel::BGESmallZH,
                dim: 512,
description: String::from("v1.5 release of the fast and Chinese model"),
            },
            ModelInfo {
                model: EmbeddingModel::MLE5Large,
@@ -504,93 +528,59 @@ mod tests {
    const EPSILON: f32 = 1e-4;

    #[test]
-    fn test_bgesmall() {
-        let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
-            model_name: EmbeddingModel::BGESmallEN,
-            ..Default::default()
-        })
-        .unwrap();
-
-        let expected: Vec<f32> = vec![
-            -0.02313, -0.02552, 0.017357, -0.06393, -0.00061, 0.022123, -0.01472, 0.039255,
-            0.034447, 0.004598,
-        ];
-        let documents = vec!["hello world"];
-
-        // Generate embeddings with the default batch size, 256
-        let embeddings = model.embed(documents, None).unwrap();
-
-        for (i, v) in expected.into_iter().enumerate() {
-            let difference = (v - embeddings[0][i]).abs();
-            assert!(difference < EPSILON, "Difference: {}", difference)
-        }
-    }
-
-    #[test]
-    fn test_bgebase() {
-        let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
-            model_name: EmbeddingModel::BGEBaseEN,
-            ..Default::default()
-        })
-        .unwrap();
-
-        let expected: Vec<f32> = vec![
-            0.0114, 0.03722, 0.02941, 0.0123, 0.03451, 0.00876, 0.02356, 0.05414, -0.0294, -0.0547,
-        ];
-        let documents = vec!["hello world"];
-
-        // Generate embeddings with the default batch size, 256
-        let embeddings = model.embed(documents, None).unwrap();
-
-        for (i, v) in expected.into_iter().enumerate() {
-            let difference = (v - embeddings[0][i]).abs();
-            assert!(difference < EPSILON, "Difference: {}", difference)
-        }
-    }
-
-    #[test]
-    fn test_allminilm() {
-        let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
-            model_name: EmbeddingModel::AllMiniLML6V2,
-            ..Default::default()
-        })
-        .unwrap();
-
-        let expected: Vec<f32> = vec![
-            0.02591, 0.00573, 0.01147, 0.03796, -0.0232, -0.0549, 0.01404, -0.0107, -0.0244,
-            -0.01822,
+    fn test_embeddings() {
+        let models_and_expected_values = vec![
+            (
+                EmbeddingModel::BGESmallEN,
+                vec![-0.02313, -0.02552, 0.017357, -0.06393, -0.00061],
+            ),
+            (
+                EmbeddingModel::BGEBaseEN,
+                vec![0.0114, 0.03722, 0.02941, 0.0123, 0.03451],
+            ),
+            (
+                EmbeddingModel::AllMiniLML6V2,
+                vec![0.02591, 0.00573, 0.01147, 0.03796, -0.0232],
+            ),
+            (
+                EmbeddingModel::MLE5Large,
+                vec![0.00961, 0.00443, 0.00658, -0.03532, 0.00703],
+            ),
+            (
+                EmbeddingModel::BGEBaseENV15,
+                vec![0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045],
+            ),
+            (
+                EmbeddingModel::BGESmallENV15,
+                vec![0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434],
+            ),
+            (
+                EmbeddingModel::BGESmallZH,
+                vec![-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762],
+            ),
        ];
-        let documents = vec!["hello world"];

-        // Generate embeddings with the default batch size, 256
-        let embeddings = model.embed(documents, None).unwrap();
-
-        for (i, v) in expected.into_iter().enumerate() {
-            let difference = (v - embeddings[0][i]).abs();
-            assert!(difference < EPSILON, "Difference: {}", difference)
-        }
-    }
+        for (model_name, expected) in models_and_expected_values {
+            let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
+                model_name: model_name.clone(),
+                ..Default::default()
+            })
+            .unwrap();

-    #[test]
-    fn test_mle5large() {
-        let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
-            model_name: EmbeddingModel::MLE5Large,
-            ..Default::default()
-        })
-        .unwrap();
-
-        let expected: Vec<f32> = vec![
-            0.00961, 0.00443, 0.00658, -0.03532, 0.00703, -0.02878, -0.03671, 0.03482, 0.06343,
-            -0.04731,
-        ];
-        let documents = vec!["hello world"];
+            let documents = vec!["hello world"];

-        // Generate embeddings with the default batch size, 256
-        let embeddings = model.embed(documents, None).unwrap();
+            // Generate embeddings with the default batch size, 256
+            let embeddings = model.embed(documents, None).unwrap();

-        for (i, v) in expected.into_iter().enumerate() {
-            let difference = (v - embeddings[0][i]).abs();
-            assert!(difference < EPSILON, "Difference: {}", difference)
+            for (i, v) in expected.into_iter().enumerate() {
+                let difference = (v - embeddings[0][i]).abs();
+                assert!(
+                    difference < EPSILON,
+                    "Difference for {}: {}",
+                    model_name.to_string(),
+                    difference
+                )
+            }
        }
    }
}
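
Downstream code can discover the new entries, including the 512-dimensional Chinese model, through the registry updated in this commit. A sketch, assuming `list_supported_models()` is an associated function on `FlagEmbedding` with the `ModelInfo` fields shown above:

```rust
use fastembed::FlagEmbedding;

fn main() {
    // Prints, e.g.: "fast-bge-small-zh-v1.5 (512): v1.5 release of the fast Chinese model"
    for info in FlagEmbedding::list_supported_models() {
        println!(
            "{} ({}): {}",
            info.model.to_string(),
            info.dim,
            info.description
        );
    }
}
```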
