diff --git a/assets/images/training_speed_vs_score.png b/assets/images/training_speed_vs_score.png
new file mode 100644
index 00000000..3781220b
Binary files /dev/null and b/assets/images/training_speed_vs_score.png differ
diff --git a/model2vec/train/README.md b/model2vec/train/README.md
index f7a7d61e..87365fc1 100644
--- a/model2vec/train/README.md
+++ b/model2vec/train/README.md
@@ -106,3 +106,7 @@ The core functionality of the `StaticModelForClassification` is contained in a c
 * `fit`: contains all the lightning-related fitting logic. The training of the model is done in a `lightning.LightningModule`, which can be modified but is very basic.
+
+# Results
+
+We ran extensive benchmarks comparing our model to several well-known architectures. The results can be found in the [training results](https://github.com/MinishLab/model2vec/tree/main/results#training-results) documentation.
diff --git a/results/README.md b/results/README.md
index 6072fa92..85120694 100644
--- a/results/README.md
+++ b/results/README.md
@@ -70,43 +70,56 @@ As can be seen, [potion-retrieval-32M](https://huggingface.co/minishlab/potion-r
 The main results for Model2Vec training are outlined in this section.
 
-We compare three different architectures:
+We compare five different architectures for our main results:
 - `model2vec + logreg`: A model2vec model with a scikit-learn `LogisticRegressionCV` on top.
 - `model2vec full finetune`: A model2vec classifier with the full model finetuned. This uses our `StaticModelForClassification`.
+- `tfidf`: A TF-IDF model with a scikit-learn `LogisticRegressionCV` on top.
 - `setfit`: A [SetFit](https://github.com/huggingface/setfit/tree/main) model trained using [all-minilm-l6-v2](sentence-transformers/all-MiniLM-L6-v2) as a base model.
+- `bge-base + logreg`: A [BGE-base](https://huggingface.co/BAAI/bge-base-en-v1.5) encoder model with a scikit-learn `LogisticRegressionCV` on top.
 
 We use 14 classification datasets, using 1000 examples from the train set, and the full test set. No parameters were tuned on any validation set. All datasets were taken from the [Setfit organization on Hugging Face](https://huggingface.co/datasets/SetFit).
 
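+For reference, the `model2vec + logreg` and `tfidf` baselines boil down to a scikit-learn classifier fitted on precomputed features. The sketch below is a minimal, illustrative version only: the toy texts and labels, the small `cv` value, and the choice of `minishlab/potion-base-32M` as the static model are our own assumptions, not the exact benchmark configuration.
+
+```python
+from model2vec import StaticModel
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegressionCV
+
+# Toy data standing in for one of the SetFit classification datasets (hypothetical).
+train_texts = ["great movie", "terrible plot", "loved every minute", "boring and slow",
+               "a joy to watch", "painfully dull", "wonderful acting", "a complete mess"]
+train_labels = [1, 0, 1, 0, 1, 0, 1, 0]
+
+# model2vec + logreg: encode with a static model, then fit LogisticRegressionCV on the embeddings.
+encoder = StaticModel.from_pretrained("minishlab/potion-base-32M")
+clf = LogisticRegressionCV(cv=2, max_iter=1000)  # small cv only because the toy set is tiny
+clf.fit(encoder.encode(train_texts), train_labels)
+
+# tfidf: the same classifier on top of TF-IDF features instead of static embeddings.
+tfidf = TfidfVectorizer().fit(train_texts)
+clf_tfidf = LogisticRegressionCV(cv=2, max_iter=1000)
+clf_tfidf.fit(tfidf.transform(train_texts), train_labels)
+
+print(clf.predict(encoder.encode(["what a fantastic film"])))
+```
+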
-| dataset | model2vec + logreg | model2vec full finetune | setfit |
-|:---------------------------|----------------------------------------------:|---------------------------------------:|-------------------------------------------------:|
-| 20_newgroups | 56.24 | 57.94 | 61.29 |
-| ade | 79.2 | 79.68 | 83.05 |
-| ag_news | 86.7 | 87.2 | 88.01 |
-| amazon_counterfactual | 90.96 | 91.93 | 95.51 |
-| bbc | 95.8 | 97.21 | 96.6 |
-| emotion | 65.57 | 67.11 | 72.86 |
-| enron_spam | 96.4 | 96.85 | 97.45 |
-| hatespeech_offensive | 83.54 | 85.61 | 87.69 |
-| imdb | 85.34 | 85.59 | 86 |
-| massive_scenario | 82.86 | 84.42 | 83.54 |
-| senteval_cr | 77.03 | 79.47 | 86.15 |
-| sst5 | 32.34 | 37.95 | 42.31 |
-| student | 83.2 | 85.02 | 89.62 |
-| subj | 89.2 | 89.85 | 93.8 |
-| tweet_sentiment_extraction | 64.96 | 62.65 | 75.15 |
-
-| | logreg | full finetune | setfit
-|:---------------------------|-----------:|---------------:|-------:|
-| average | 77.9 | 79.2 | 82.6 |
+| dataset | tfidf | model2vec + logreg | model2vec full finetune | setfit | bge-base + logreg |
+|:---------------------------|--------:|---------------------:|--------------------------:|---------:|--------------------:|
+| 20_newgroups | 50.71 | 56.24 | 57.94 | 61.29 | 67.39 |
+| ade | 71.46 | 79.20 | 79.68 | 83.05 | 86.12 |
+| ag_news | 81.68 | 86.70 | 87.20 | 88.01 | 88.95 |
+| amazon_counterfactual | 85.18 | 90.96 | 91.93 | 95.51 | 92.74 |
+| bbc | 95.09 | 95.80 | 97.21 | 96.60 | 97.50 |
+| emotion | 59.28 | 65.57 | 67.11 | 72.86 | 65.63 |
+| enron_spam | 96.00 | 96.40 | 96.85 | 97.45 | 97.30 |
+| hatespeech_offensive | 66.45 | 83.54 | 85.61 | 87.69 | 84.92 |
+| imdb | 80.44 | 85.34 | 85.59 | 86.00 | 92.25 |
+| massive_scenario | 77.26 | 82.86 | 84.42 | 83.54 | 87.07 |
+| senteval_cr | 65.61 | 77.03 | 79.47 | 86.15 | 90.53 |
+| sst5 | 18.52 | 32.34 | 37.95 | 42.31 | 38.49 |
+| student | 74.16 | 83.20 | 85.02 | 89.62 | 89.71 |
+| subj | 86.39 | 89.20 | 89.85 | 93.80 | 94.55 |
+| tweet_sentiment_extraction | 53.20 | 64.96 | 62.65 | 75.15 | 69.48 |
+
+| | tfidf | model2vec + logreg | model2vec full finetune | setfit | bge-base + logreg |
+|:--------|--------:|---------------------:|--------------------------:|---------:|--------------------:|
+| average | 70.8 | 78.0 | 79.2 | 82.6 | 82.8 |
+
 As can be seen, full fine-tuning brings modest performance improvements in some cases, but very large ones in other cases, leading to a pretty large increase in average score. Our advice is to test both if you can use `potion-base-32m`, and to use full fine-tuning if you are starting from another base model.
 
-The speed difference between model2vec and setfit is immense, with the full finetune being 35x faster than a setfit based on `all-minilm-l6-v2` on CPU.
+The speed difference between model2vec and the other models is immense, with the full finetune being 35x faster than a setfit model based on `all-minilm-l6-v2` on CPU and 200x faster than the `bge-base` transformer model.
 
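+For the `model2vec full finetune` results, training goes through `StaticModelForClassification` rather than scikit-learn. Below is a minimal sketch assuming the API described in the training README; the toy data and the `minishlab/potion-base-32M` checkpoint are illustrative and not the exact benchmark setup.
+
+```python
+# Requires model2vec's training extras (e.g. pip install "model2vec[train]").
+from model2vec.train import StaticModelForClassification
+
+# Toy data (hypothetical); repeated so the internal train/validation split has enough samples.
+train_texts = ["great movie", "terrible plot", "loved every minute", "boring and slow"] * 8
+train_labels = ["positive", "negative", "positive", "negative"] * 8
+
+# Initialize a trainable classifier on top of a static model, then fine-tune the full model.
+classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
+classifier.fit(train_texts, train_labels)
+
+print(classifier.predict(["what a fantastic film"]))
+```
+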
-| | logreg | full finetune | setfit
-|:---------------------------|-----------:|---------------:|-------:|
-| samples / second | 17925 | 24744 | 716 |
+| | tfidf | model2vec + logreg | model2vec full finetune | setfit | bge-base + logreg |
+|:-----------------|--------:|---------------------:|--------------------------:|---------:|--------------------:|
+| samples / second | 108434 | 17925 | 24744 | 716 | 118 |
+
+The figure below shows the relationship between the number of sentences processed per second and the average training score; we've included additional transformer-based models for comparison.
+
+| ![Training speed vs. average score](../assets/images/training_speed_vs_score.png) |
+|:--:|
+|*Figure: The average training score plotted against sentences per second (log scale).*|
 
 ## Ablations