Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs: Adds tutorial about custom few-shot classification with SetFit (#…
- Loading branch information
1 parent
fc39ce4
commit f02682a
Showing
4 changed files
with
296 additions
and
1 deletion.
There are no files selected for viewing
Binary file added
BIN
+592 KB
docs/_static/tutorials/few-shot-classification-with-setfit/setfit-labelled.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
287 changes: 287 additions & 0 deletions
287
docs/tutorials/few-shot-classification-with-setfit.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,287 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b15d20dd-31fe-44f9-bba1-b44201dceb91", | ||
"metadata": {}, | ||
"source": [ | ||
"# 🤯 Few-shot classification with SetFit and a custom dataset\n", | ||
"\n", | ||
"**SetFit** is an exciting open-source package for few-shot classification developed by teams at Hugging Face and Intel Labs. You can read all about it on the [project repository](https://github.com/huggingface/setfit). \n", | ||
"\n", | ||
"To showcase how powerful is **the combination of SetFit and Rubrix**:\n", | ||
"\n", | ||
"* We manually **label 55 examples** from the unlabelled split of the imdb dataset, \n", | ||
"* we train a model in **5 min**, \n", | ||
"* and without using a single example from the original imdb training set, we achieve **0.9 accuracy on the full test set!**\n", | ||
"\n", | ||
"\n", | ||
"## Summary\n", | ||
"\n", | ||
"\n", | ||
"In this tutorial, you'll learn to:\n", | ||
"\n", | ||
"1. **Load a unlabelled dataset** in Rubrix. We'll be using the unlabelled split from the `imdb` movie reviews sentiment dataset. This same workflow can be applied to any custom dataset, problem, and language!\n", | ||
"\n", | ||
"2. Manually **label a FEW examples** using the UI.\n", | ||
"\n", | ||
"3. **Train a SetFit model** to get highly competitive results. For this example, with **only 55 examples**, we get **0.9 accuracy** on the test set which is comparable to models fine-tuned on 3K examples. That means similar performance with `50x` less examples 🤯. \n", | ||
"\n", | ||
"For reference see the [Hugging Face Hub](https://huggingface.co/spaces/autoevaluate/leaderboards?dataset=imdb) and [PapersWithCode](https://paperswithcode.com/sota/text-classification-on-imdb) leaderboards.\n", | ||
"\n", | ||
"Let's get started!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7d44029c-c54e-474e-927b-a85749817f36", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setup Rubrix\n", | ||
"\n", | ||
"Rubrix is a free and **open-source data labeling framework for NLP**. \n", | ||
"\n", | ||
"To get started on your local machine, you just need three steps:\n", | ||
"\n", | ||
"1. Install the library:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 19, | ||
"id": "e6c26c6c-b844-4fa8-a225-28104d81d995", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install rubrix[server]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "92443b12-70a9-4acc-92f2-40c22e711a1f", | ||
"metadata": {}, | ||
"source": [ | ||
"2. Install and launch [Elasticsearch](https://rubrix.readthedocs.io/en/stable/getting_started/setup%26installation.html#launch-the-web-app).\n", | ||
"\n", | ||
"3. Launch the server and the UI from your terminal or notebook:\n", | ||
"\n", | ||
"```bash\n", | ||
"python -m rubrix\n", | ||
"```\n", | ||
"\n", | ||
"🎉 If everything went well, you can go to https://localhost:6900 and login using the default user/password: `rubrix/1234`.\n", | ||
"\n", | ||
"🆘 If you need help you can join our [Slack channel](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g) to get inmediate support." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ae0a45ae-10c8-43bb-8bb8-45a3792031b0", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setup `SetFit` and `datasets` libraries" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "dd97e294-7b35-431f-87af-92fd42e849f1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install setfit datasets -qqq" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "91a2d475-12d3-49a3-9cd7-a021f795f97d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from datasets import load_dataset\n", | ||
"from sentence_transformers.losses import CosineSimilarityLoss\n", | ||
"\n", | ||
"from setfit import SetFitModel, SetFitTrainer\n", | ||
"\n", | ||
"import rubrix as rb" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8eb457bf-f703-4242-9f0d-d6dea9817dbf", | ||
"metadata": {}, | ||
"source": [ | ||
"## Load unlabelled dataset in Rubrix\n", | ||
"\n", | ||
"First, we load the `unsupervised` split from the `imdb` dataset and create a new Rubrix dataset with 100 random examples:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b7ccfc09-ba4c-4881-838f-7186897bc1e9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"unlabelled = load_dataset(\"imdb\", split=\"unsupervised\").shuffle(seed=42).select(range(100))\n", | ||
"\n", | ||
"unlabelled = rb.DatasetForTextClassification.from_datasets(unlabelled)\n", | ||
"\n", | ||
"rb.log(unlabelled, \"imdb_unlabelled\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "574c9c3c-c841-43da-b5f0-16947cb261a4", | ||
"metadata": {}, | ||
"source": [ | ||
"## Manual labelling\n", | ||
"\n", | ||
"In this step, we create the labels `pos` and `neg` using the same label scheme as the original dataset. Then we use the UI to sequentially label a few examples. For the example, we spent literally 15 minutes.\n", | ||
"\n", | ||
"Watch the video below to get a sense of the steps and time you need to replicate the results.\n", | ||
"\n", | ||
"<video width=\"100%\" controls><source src=\"../_static/tutorials/few-shot-classification-with-setfit/setfit.mp4\" type=\"video/mp4\"></video>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ccff9d86-55d4-42ca-ae32-1428113a5684", | ||
"metadata": {}, | ||
"source": [ | ||
"Before training, you can easily share the dataset using the `push_to_hub` method. This might be useful if you don't have a GPU on your machine and want to use a training service or Colab for example." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "dffdb095-b546-4cdf-999c-00c59583866f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"rb.load(\"imdb_unlabelled\").prepare_for_training().push_to_hub(\"mini-imdb\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6960b9bd-49fb-4f79-a37c-ee386ad37b13", | ||
"metadata": {}, | ||
"source": [ | ||
"The dataset is available on the [HF hub](https://huggingface.co/datasets/dvilasuero/mini-imdb). You can see the summary in the UI below:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2ed84bd6-50af-4a61-a7bf-4f1a4719f773", | ||
"metadata": {}, | ||
"source": [ | ||
"![Labelled_dataset](../_static/tutorials/few-shot-classification-with-setfit/setfit-labelled.png)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2c8dbd6a-7dd1-438e-9081-d8152b4b8a00", | ||
"metadata": {}, | ||
"source": [ | ||
"## Train and evaluate SetFit model\n", | ||
"\n", | ||
"Finally, we are ready to test SetFit! \n", | ||
"\n", | ||
"Thanks to Rubrix's integration with `datasets` and the Hub, if you don't have a local GPU you can use this [Google Colab](https://colab.research.google.com/drive/166TrSY0aJfKYi8U9qWilaXN2b2-nGlVD?usp=sharing) to reproduce the training process with the labelled dataset. If you use a GPU runtime, it literally takes 5 minutes to train.\n", | ||
"\n", | ||
"Below we load the dataset from Rubrix, format it for training with transformers, load the full imbd test dataset, load a pre-trained sentence transformers model, train the SetFit model, and evaluate it!\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6d558541-6ab1-4917-9e39-26bd75849567", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the handlabelled dataset from Rubrix\n", | ||
"train_ds = rb.load(\"imdb_unlabelled\").prepare_for_training()\n", | ||
"\n", | ||
"# Load the full imdb test dataset\n", | ||
"test_ds = load_dataset(\"imdb\", split=\"test\")\n", | ||
"\n", | ||
"\n", | ||
"# Load SetFit model from Hub\n", | ||
"model = SetFitModel.from_pretrained(\"sentence-transformers/paraphrase-mpnet-base-v2\")\n", | ||
"\n", | ||
"# Create trainer\n", | ||
"trainer = SetFitTrainer(\n", | ||
" model=model,\n", | ||
" train_dataset=train_ds,\n", | ||
" eval_dataset=test_ds,\n", | ||
" loss_class=CosineSimilarityLoss,\n", | ||
" batch_size=16,\n", | ||
" num_iterations=20, # The number of text pairs to generate\n", | ||
")\n", | ||
"\n", | ||
"# Train and evaluate\n", | ||
"trainer.train()\n", | ||
"metrics = trainer.evaluate()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c2175cb8-65c9-4c8c-900f-03b0d187627d", | ||
"metadata": {}, | ||
"source": [ | ||
"Optionally, you can share your amazing model with the world!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2aec58bf-5871-41e6-81df-292237d6150b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trainer.push_to_hub(\"setfit-mini-imdb\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e4e9b078-7da9-4c92-a0d7-a6143d227400", | ||
"metadata": {}, | ||
"source": [ | ||
"## Conclusion\n", | ||
"\n", | ||
"The metrics object should give you around 0.9 accuracy on the full test set 🎉 \n", | ||
"\n", | ||
"And remember:\n", | ||
"\n", | ||
"- We have manually labelled 55 examples, \n", | ||
"- We haven't used a single example from the original training set, \n", | ||
"- and we've trained the model in 5 min!\n", | ||
"\n", | ||
"Now, I don't think you have any more excuses to not invest some time labeling a few good quality examples!" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters