
v2.7.0 - CachedGISTEmbedLoss, easy Matryoshka inference & evaluation, CrossEncoder, Intel Gaudi2

@tomaarsen tomaarsen released this 17 Apr 13:16
· 8 commits to master since this release

This release introduces a promising new loss function, easier inference for Matryoshka models, new functionality for CrossEncoders, inference on Intel Gaudi2, and much more.

Install this version with

pip install sentence-transformers==2.7.0

New loss function: CachedGISTEmbedLoss (#2592)

For a number of years, MultipleNegativesRankingLoss (also known as SimCSE, InfoNCE, in-batch negatives loss) has been the state of the art in embedding model training. Notably, this loss function performs better with a larger batch size.
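
To make the in-batch negatives idea concrete, here is a minimal pure-Python sketch, not the library's implementation. It assumes cosine similarity and a scale factor of 20; the function names are made up for illustration. Each anchor is scored against every positive in the batch, and the loss is the cross-entropy of picking its own positive, so a larger batch means more negatives per anchor.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]

def in_batch_negatives_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where positives[i] is the match for anchors[i]
    and every other positive in the batch acts as a negative."""
    anchors = [normalize(a) for a in anchors]
    positives = [normalize(p) for p in positives]
    total = 0.0
    for i, anchor in enumerate(anchors):
        # Scaled cosine similarity of this anchor to every candidate in the batch.
        logits = [scale * dot(anchor, p) for p in positives]
        # Numerically stable log-sum-exp.
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(logit - max_logit) for logit in logits))
        total += log_sum_exp - logits[i]
    return total / len(anchors)

# Toy 2-dimensional "embeddings" (illustrative values only).
anchors = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
positives = [[0.9, 0.1], [0.1, 0.9], [-0.9, -0.1]]

aligned = in_batch_negatives_loss(anchors, positives)
# Rotating the positives pairs every anchor with the wrong text.
shuffled = in_batch_negatives_loss(anchors, positives[1:] + positives[:1])
print(f"aligned={aligned:.4f} shuffled={shuffled:.4f}")
```

The loss is small when each anchor is closest to its own positive and grows when the pairing is wrong.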

Recently, various improvements have been introduced:

  1. CachedMultipleNegativesRankingLoss was introduced, which allows much larger batch sizes (e.g. 65536) at constant memory usage.
  2. GISTEmbedLoss takes a guide model to guide the in-batch negative sample selection. This prevents false negatives, resulting in a stronger training signal.
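
The guided selection in the second item can be sketched as follows. This is a simplified illustration rather than the library's code: it assumes the guide model has already embedded the batch, and it only considers anchor-to-positive candidates. A candidate is dropped as a negative for an anchor when the guide finds it more similar to the anchor than the anchor's own positive. The name `guide_mask` and the toy vectors are hypothetical.

```python
import math

def cos(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def guide_mask(guide_anchors, guide_positives):
    """mask[i][j] is True when candidate j should be ignored as a negative
    for anchor i, judged by the guide model's embeddings."""
    mask = []
    for i, anchor in enumerate(guide_anchors):
        # Guide similarity between the anchor and its labelled positive.
        threshold = cos(anchor, guide_positives[i])
        row = [
            j != i and cos(anchor, candidate) > threshold
            for j, candidate in enumerate(guide_positives)
        ]
        mask.append(row)
    return mask

# Toy guide embeddings: positive 1 is closer to anchor 0 than anchor 0's own
# positive, so it is a likely false negative for anchor 0.
guide_anchors = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0]]
guide_positives = [[0.8, 0.3], [1.0, 0.05], [0.1, 1.0]]

mask = guide_mask(guide_anchors, guide_positives)
for row in mask:
    print(row)
```

During training, the masked entries would be excluded from the in-batch softmax (for example by setting their logits to negative infinity), so the model is not penalised for ranking a probable true match highly.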

Now, @JacksonCakes has combined these two approaches to produce the best of both worlds: CachedGISTEmbedLoss. This loss function allows for high batch sizes with constant memory usage, while also using a guide model to assist with the in-batch negative sample selection.

As can be seen in our Loss Overview, this loss function should be used with (anchor, positive) pairs or (anchor, positive, negative) triplets, much like MultipleNegativesRankingLoss, CachedMultipleNegativesRankingLoss, and GISTEmbedLoss. In short, any example using those loss functions can be updated to use CachedGISTEmbedLoss! Feel free to experiment, e.g. with this training script.

Automatic Matryoshka model truncation (#2573)

Sentence Transformers v2.4.0 introduced Matryoshka models: models whose embeddings are still useful after truncation. Since then, many useful Matryoshka models have been trained.

As of this release, the truncation for these Matryoshka embedding models can be done automatically via a new truncate_dim constructor argument:

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

matryoshka_dim = 64
model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True, truncate_dim=matryoshka_dim)

embeddings = model.encode(
    [
        "search_query: What is TSNE?",
        "search_document: t-distributed stochastic neighbor embedding (t-SNE) is a statistical method for visualizing high-dimensional data by giving each datapoint a location in a two or three-dimensional map.",
        "search_document: Amelia Mary Earhart was an American aviation pioneer and writer.",
    ]
)
print(embeddings.shape)
# => (3, 64)

similarities = cos_sim(embeddings[0], embeddings[1:])
print(similarities)
# => tensor([[0.7839, 0.4933]])
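
Conceptually, truncation keeps only the first dimensions of each embedding. The sketch below does this by hand in pure Python, with made-up 6-dimensional vectors standing in for real model output. It illustrates the idea only and is not how the library implements truncate_dim.

```python
import math

def truncate(embedding, dim):
    """Keep only the first `dim` values of an embedding."""
    return embedding[:dim]

def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

# Hypothetical embeddings (illustrative values, not produced by any model).
query = [0.8, 0.5, 0.1, -0.2, 0.05, 0.1]
relevant = [0.7, 0.6, 0.2, -0.1, 0.0, 0.2]
unrelated = [-0.3, 0.9, -0.4, 0.5, 0.3, -0.1]

for dim in (6, 2):
    q = truncate(query, dim)
    sim_relevant = cosine_similarity(q, truncate(relevant, dim))
    sim_unrelated = cosine_similarity(q, truncate(unrelated, dim))
    print(f"dim={dim}: relevant={sim_relevant:.4f} unrelated={sim_unrelated:.4f}")
```

Cosine similarity is scale-invariant, so the truncated vectors do not need to be re-normalised for this comparison. For a well-trained Matryoshka model, the relative ordering of similarities should largely survive truncation, as it does in this toy example.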

Extra information:

Model truncation in all evaluators (#2582)

Alongside easier inference with Matryoshka models, evaluating them is now also much easier: you can pass truncate_dim to any Evaluator. This way, you can easily check the performance of any Sentence Transformer model at various truncated dimensions (even if the model was not trained with MatryoshkaLoss!).

from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers import SentenceTransformer
import datasets

model = SentenceTransformer("tomaarsen/mpnet-base-nli-matryoshka")

stsb = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

for dim in [768, 512, 256, 128, 64, 32, 16, 8, 4]:
    evaluator = EmbeddingSimilarityEvaluator(
        stsb["sentence1"],
        stsb["sentence2"],
        [score / 5 for score in stsb["score"]],
        name=f"sts-test-{dim}",
        truncate_dim=dim,
    )
    print(f"dim={dim:<3}: {evaluator(model) * 100:.2f} Spearman Correlation")

dim=768: 86.81 Spearman Correlation
dim=512: 86.76 Spearman Correlation
dim=256: 86.66 Spearman Correlation
dim=128: 86.20 Spearman Correlation
dim=64 : 85.40 Spearman Correlation
dim=32 : 82.42 Spearman Correlation
dim=16 : 79.31 Spearman Correlation
dim=8  : 72.82 Spearman Correlation
dim=4  : 63.44 Spearman Correlation
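
For intuition about what such an evaluation measures, here is a rough pure-Python sketch. It truncates both sides of each sentence pair, scores the pair with cosine similarity, and computes the Spearman correlation with the gold scores. The helper names and toy data are hypothetical, and the real evaluator may differ in its details.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def average_ranks(values):
    """1-based ranks where tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        shared_rank = (start + end) / 2 + 1
        for k in range(start, end + 1):
            ranks[order[k]] = shared_rank
        start = end + 1
    return ranks

def pearson(x, y):
    mean_x, mean_y = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    std_x = math.sqrt(sum((a - mean_x) ** 2 for a in x))
    std_y = math.sqrt(sum((b - mean_y) ** 2 for b in y))
    return cov / (std_x * std_y)

def spearman(x, y):
    """Spearman correlation: Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))

def evaluate_at_dim(embeddings1, embeddings2, gold_scores, dim):
    """Spearman correlation between gold scores and truncated cosine similarities."""
    predictions = [cosine(a[:dim], b[:dim]) for a, b in zip(embeddings1, embeddings2)]
    return spearman(predictions, gold_scores)

# Toy data: four sentence pairs with hypothetical 4-dimensional embeddings.
embeddings1 = [[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]
embeddings2 = [[1, 0, 0, 0], [1, 1, 0, 0], [1, 0, 2, 0], [0, 1, 0, 1]]
gold_scores = [1.0, 0.7, 0.4, 0.0]

for dim in (4, 2):
    score = evaluate_at_dim(embeddings1, embeddings2, gold_scores, dim)
    print(f"dim={dim}: {score * 100:.2f} Spearman Correlation")
```

In this toy data, the third pair differs only in a dimension that truncation discards, so the correlation drops at the smaller size. This is the same kind of degradation the numbers above show for a real model.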

Here are some example training scripts that use this new truncate_dim option to assist with training Matryoshka models:

CrossEncoder improvements

This release improves the support for CrossEncoder reranker models.

push_to_hub (#2524)

You can now push trained CrossEncoder models to the 🤗 Hugging Face Hub!

from sentence_transformers import CrossEncoder

...

model = CrossEncoder("distilroberta-base")

# Train the model
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
)

model.push_to_hub("tomaarsen/distilroberta-base-stsb-cross-encoder")

trust_remote_code for custom models (#2595)

You can now load custom models from the Hugging Face Hub, i.e. models with custom modelling code that requires trust_remote_code to load.

from sentence_transformers import CrossEncoder

# Note: this model does not require `trust_remote_code=True` - there are currently no models that require it.
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", trust_remote_code=True)

# We want to compute the similarity between the query sentence
query = "A man is eating pasta."

# With all sentences in the corpus
corpus = [
    "A man is eating food.",
    "A man is eating a piece of bread.",
    "The girl is carrying a baby.",
    "A man is riding a horse.",
    "A woman is playing violin.",
    "Two men pushed carts through the woods.",
    "A man is riding a white horse on an enclosed ground.",
    "A monkey is playing drums.",
    "A cheetah is running behind its prey.",
]

# We rank all sentences in the corpus for the query
ranks = model.rank(query, corpus)

# Print the scores
print("Query:", query)
for rank in ranks:
    print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}")
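
The loop above relies on the result being a list of dictionaries with "corpus_id" and "score" keys, ordered from best to worst. The following pure-Python sketch mimics that shape with a deliberately naive word-overlap scorer in place of a real CrossEncoder. Both `rank_documents` and `word_overlap` are hypothetical helpers, not part of the library.

```python
import re

def word_overlap(query, document):
    """Hypothetical stand-in scorer: Jaccard overlap of lowercase word sets."""
    query_words = set(re.findall(r"\w+", query.lower()))
    document_words = set(re.findall(r"\w+", document.lower()))
    return len(query_words & document_words) / len(query_words | document_words)

def rank_documents(query, documents, score_fn):
    """Score every (query, document) pair and sort by descending score."""
    results = [
        {"corpus_id": index, "score": score_fn(query, document)}
        for index, document in enumerate(documents)
    ]
    results.sort(key=lambda result: result["score"], reverse=True)
    return results

query = "A man is eating pasta."
corpus = [
    "A man is eating food.",
    "The girl is carrying a baby.",
    "A man is riding a horse.",
]

ranks = rank_documents(query, corpus, word_overlap)
for rank in ranks:
    print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}")
```

A real CrossEncoder replaces the overlap heuristic with a learned relevance score for each (query, document) pair, but the result can be consumed in the same way.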

Inference on Intel Gaudi2 (#2557)

From this release onwards, you will be able to perform inference on Intel Gaudi2 accelerators. No modifications are needed, as the library will automatically detect the hpu device and configure the model accordingly. Thanks to Intel Habana for the support here.
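
These notes do not show the detection logic itself. As a rough illustration only, one plausible approach is to check whether the Habana software stack is importable. The sketch below assumes the `habana_frameworks` package signals Gaudi support and ignores other accelerators such as CUDA. The `detect_device` helper is hypothetical.

```python
import importlib.util

def detect_device(module_available=None):
    """Return "hpu" when the Habana stack looks installed, else "cpu".

    `module_available` is an injectable check (useful for testing); by
    default it asks importlib whether a module can be found.
    """
    if module_available is None:
        def module_available(name):
            return importlib.util.find_spec(name) is not None
    # Assumption: the presence of `habana_frameworks` implies a Gaudi device.
    if module_available("habana_frameworks"):
        return "hpu"
    return "cpu"

print(detect_device())
```

In practice you should not need code like this, since the library picks the device automatically.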

All changes

New Contributors

I especially want to thank @JacksonCakes for their excellent CachedGISTEmbedLoss PR and @kddubey for their wonderful PRs surrounding Matryoshka models and general repository housekeeping.

Full Changelog: v2.6.1...v2.7.0