docs: fixing the active learning tutorial with small-text (#1726)
* docs: fixing the active learning tutorial with `small-text`

* docs: using a tiny model

* docs: Change tutorial title

* docs: Change active learning title in card

(cherry picked from commit f4f2289)

Closes #1693
frascuchon committed Oct 5, 2022
1 parent 58e1a5b commit 909efdf
Showing 2 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/active-learning.md
@@ -11,7 +11,7 @@ These tutorials show you how to create Active Learning workflows with Rubrix.
Build an active learning prototype with Rubrix, ModAL and scikit-learn to filter spam from the YouTube Spam Collection dataset.
```
-```{grid-item-card} 👂 Learn actively, and listen carefully to small-text
+```{grid-item-card} 👂 Active learning for text classification with small-text
:img-top: ../_static/tutorials/active_learning_with_small_text/screenshot.png
:link: active_learning_with_small_text.ipynb
26 changes: 13 additions & 13 deletions docs/tutorials/active_learning_with_small_text.ipynb
@@ -5,7 +5,7 @@
"id": "e7b5a1fa-9265-4fb8-9c0e-a9156c22f4ff",
"metadata": {},
"source": [
"# 👂 Learn actively, and listen carefully to small-text"
"# 👂 Active learning for text classification with small-text"
]
},
{
@@ -111,7 +111,7 @@
"source": [
"import datasets\n",
"\n",
"trec = datasets.load_dataset('trec')"
"trec = datasets.load_dataset('trec', revision=\"bc790b9ce61d4c2b1ea9622cd65da40182725a61\")"
]
},
{
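The `revision` argument above pins `load_dataset` to a fixed commit of the dataset repository on the Hugging Face Hub, so later schema changes (such as the `label-coarse` to `coarse_label` column rename this commit adapts to) cannot silently break the notebook. A minimal sketch of the pinned load, using the revision hash from the diff:

```python
import datasets

# Pin the TREC dataset to a fixed commit of its Hub repository so the
# schema cannot change underneath the tutorial.
trec = datasets.load_dataset(
    "trec",
    revision="bc790b9ce61d4c2b1ea9622cd65da40182725a61",
)

# At this revision the coarse label column is named "coarse_label".
print(trec["train"].features["coarse_label"].names)
```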
@@ -143,7 +143,7 @@
"from transformers import AutoTokenizer\n",
"\n",
"# Choose transformer model\n",
"TRANSFORMER_MODEL = \"google/electra-small-discriminator\"\n",
"TRANSFORMER_MODEL = \"prajjwal1/bert-tiny\"\n",
"\n",
"# Init tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL)\n",
@@ -173,15 +173,14 @@
"outputs": [],
"source": [
"from small_text.integrations.transformers import TransformersDataset\n",
"from small_text.base import LABEL_UNLABELED\n",
"\n",
"# Set convenient output format \n",
"trec_tokenized.set_format(\"torch\")\n",
"\n",
"# Create tuples from the tokenized training data\n",
"data = [\n",
" # Need to add an extra dimension to indicate a batch size of 1 -> [None]\n",
" (row[\"input_ids\"][None], row[\"attention_mask\"][None], LABEL_UNLABELED) \n",
" (row[\"input_ids\"][None], row[\"attention_mask\"][None], int(row[\"coarse_label\"])) \n",
" for row in trec_tokenized[\"train\"]\n",
"]\n",
"\n",
@@ -206,7 +205,7 @@
"source": [
"# Create test dataset\n",
"data_test = [\n",
" (row[\"input_ids\"][None], row[\"attention_mask\"][None], int(row[\"label-coarse\"])) \n",
" (row[\"input_ids\"][None], row[\"attention_mask\"][None], int(row[\"coarse_label\"])) \n",
" for row in trec_tokenized[\"test\"]\n",
"]\n",
"dataset_test = TransformersDataset(data_test)"
@@ -254,7 +253,7 @@
" num_classes=6,\n",
" # If you have a cuda device, specify it here.\n",
" # Otherwise, just remove the following line.\n",
" kwargs={\"device\": \"cuda\"}\n",
" # kwargs={\"device\": \"cuda\"}\n",
")\n",
"\n",
"# Define our query strategy\n",
@@ -319,12 +318,11 @@
"source": [
"import rubrix as rb\n",
"\n",
"\n",
"# Choose a name for the dataset\n",
"DATASET_NAME = \"trec_with_active_learning\"\n",
"\n",
"# Define labeling schema\n",
"labels = trec[\"train\"].features[\"label-coarse\"].names\n",
"labels = trec[\"train\"].features[\"coarse_label\"].names\n",
"settings = rb.TextClassificationSettings(label_schema=labels)\n",
"\n",
"# Create dataset with a label schema\n",
@@ -367,7 +365,7 @@
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Define some helper variables\n",
"LABEL2INT = trec[\"train\"].features[\"label-coarse\"].str2int\n",
"LABEL2INT = trec[\"train\"].features[\"coarse_label\"].str2int\n",
"ACCURACIES = []\n",
"\n",
"# Set up the active learning loop with the listener decorator\n",
@@ -396,11 +394,11 @@
" # 2. Query active learner\n",
" print(\"Querying new data points ...\")\n",
" queried_indices = active_learner.query(num_samples=NUM_SAMPLES)\n",
" ctx.query_params[\"batch_id\"] += 1\n",
" new_batch = ctx.query_params[\"batch_id\"] + 1\n",
" new_records = [\n",
" rb.TextClassificationRecord(\n",
" text=trec[\"train\"][\"text\"][idx], \n",
" metadata={\"batch_id\": ctx.query_params[\"batch_id\"]},\n",
" metadata={\"batch_id\": new_batch},\n",
" id=idx,\n",
" ) \n",
" for idx in queried_indices\n",
@@ -415,9 +413,11 @@
" dataset_test.y, \n",
" active_learner.classifier.predict(dataset_test),\n",
" )\n",
" \n",
" ACCURACIES.append(accuracy)\n",
" ctx.query_params[\"batch_id\"] = new_batch\n",
" print(\"Done!\")\n",
" \n",
"\n",
" print(\"Waiting for annotations ...\")"
]
},
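The `batch_id` handling above is the substantive fix: instead of incrementing `ctx.query_params["batch_id"]` as soon as new points are queried, the new value is kept in a local `new_batch` and written back only after training and evaluation finish, so an exception part-way through leaves the shared counter untouched. The pattern in isolation, with a hypothetical `do_work` standing in for the train/evaluate/log steps:

```python
def do_work(batch_id: int) -> None:
    # Stand-in for the real work; in practice this may raise.
    print(f"processing batch {batch_id}")

def run_iteration(state: dict) -> None:
    # Compute the next value locally instead of mutating shared state up front.
    new_batch = state["batch_id"] + 1
    do_work(new_batch)
    # Commit only after the work succeeded, so a failed run can be retried.
    state["batch_id"] = new_batch

state = {"batch_id": 0}
run_iteration(state)
print(state["batch_id"])  # 1
```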
