diff --git a/docs/_source/_static/datasets/weak_supervision_tutorial/labeling_rules.csv b/docs/_source/_static/datasets/weak_supervision_tutorial/labeling_rules.csv new file mode 100644 index 0000000000..d535a4f55a --- /dev/null +++ b/docs/_source/_static/datasets/weak_supervision_tutorial/labeling_rules.csv @@ -0,0 +1,6 @@ +,query,label +0,your,SPAM +1,rich,SPAM +2,film,HAM +3,meeting,HAM +4,help,HAM diff --git a/docs/_source/guides/techniques/weak_supervision.ipynb b/docs/_source/guides/techniques/weak_supervision.ipynb index fda9e7c44f..7437c3df85 100644 --- a/docs/_source/guides/techniques/weak_supervision.ipynb +++ b/docs/_source/guides/techniques/weak_supervision.ipynb @@ -10,7 +10,7 @@ "\n", "This guide gives you a brief introduction to weak supervision with Argilla.\n", "\n", - "Argilla currently supports weak supervision for multi-class text classification use cases, but we'll be adding support for multilabel text classification and token classification (e.g., Named Entity Recognition) soon.\n", + "Argilla currently supports weak supervision for **multi-class** and **multi-label** text classification use cases. Support for token classification (e.g., Named Entity Recognition) will be added soon.\n", "\n", "![Labeling workflow](../../_static/images/guides/weak_supervision/weak_supervision.png \"Labeling workflow\")" ] @@ -25,8 +25,8 @@ "The recommended workflow for weak supervision is:\n", "\n", "- Log an unlabelled dataset into Argilla\n", - "- Use the `Annotate` mode for hand- and/or bulk-labelling a test set. This test is key to measure the quality and performance of your rules.\n", - "- Use the `Define rules` mode for testing and defining rules. Rules are defined with search queries (using ES query string DSL).\n", + "- Use the `Annotate` mode for hand- and/or bulk-labelling a validation set. This validation is key to measure the quality and performance of your rules. Additionally, you need to build a test set which is not used for defining rules. This test set will be used to measure the performance of your end model, as with any other supervised model.\n", + "- Use the `Define rules` mode for evaluating and defining rules. Rules are defined with search queries (using ES query string DSL). Additionally, you can use the Python client methods to add, delete, or modify rules programmatically, making them available for refinement in the UI.\n", "- Use the Python client for reading rules, defining additional rules if needed, and train a label (for building a training set) or a downstream model (for building an end classifier).\n", "\n", "The next sections cover the main components of this workflow. \n", @@ -34,12 +34,11 @@ "### Weak labeling using the UI\n", "\n", "Since version 0.8.0 you can find and define rules directly in the UI. \n", - "The [Define rules mode](../../reference/webapp/pages.html#metrics) is found in the right side bar of the [Dataset page](../../reference/webapp/pages.html#dataset).\n", - "The video below shows how you can interactively find and save rules with the UI. \n", + "The [Define rules mode](../../reference/webapp/pages.html#modes) is found in the right side bar of the [Dataset page](../../reference/webapp/pages.html#dataset).\n", "\n", "### Weak supervision from Python\n", "\n", - "Doing weak supervision with Argilla should be straightforward. Keeping the same spirit as other parts of the library, you can virtually use any weak supervision library or method, such as Snorkel or Flyingsquid. \n", + "Doing weak supervision with Argilla is straightforward. 
Keeping the same spirit as other parts of the library, you can use any weak supervision library or method, such as Snorkel or Flyingsquid. \n", "\n", "Argilla weak supervision support is built around two basic abstractions:\n", "\n", @@ -72,6 +71,13 @@ "\n", "A rule should either return a string value, that is a weak label, or a `None` type in case of abstention.\n", "\n", + "These rules can be:\n", + "\n", + "1. Defined using the no-code feature of the UI (see the [Define rules mode](../../reference/webapp/pages.html#modes) reference).\n", + "2. `Rule` objects can be created using Python as shown above. These objects can be either applied locally by developers (which might be interested for testing without overloading the server) or added to the dataset in the Argilla server, making these rules available from the UI. \n", + "3. Python functions cannot be defined with the no-code feature and can only be applied locally but not added to the dataset in the Argilla server. Data teams can use these Python labelling functions to add extra heuristics before building a weakly labelled dataset. This functions should be used for heuristics which are not possible to define using ES queries.\n", + "\n", + "\n", "\n", "### `Weak Labels`\n", "\n", @@ -104,8 +110,8 @@ "A typical workflow to use weak supervision is:\n", "\n", "1. Create a Argilla dataset with your raw dataset. If you actually have some labelled data you can log it into the the same dataset.\n", - "2. Define a set of weak labeling rules with the Rules definition mode in the UI.\n", - "3. Create a `WeakLabels` object and apply the rules. You can load the rules from your dataset and add additional rules and labeling functions using Python. Typically, you'll iterate between this step and step 2.\n", + "2. Define a set of weak labeling rules with the Rules definition mode in the UI or using the Python client `add_rules` method.\n", + "3. Create a `WeakLabels` object and apply the rules using the Python client. You can load the rules from your dataset and add additional rules and labeling functions using Python. Typically, you'll iterate between this step and step 2.\n", "4. Once you are satisfied with your weak labels, use the matrix of the `WeakLabels` instance with your library/method of choice to build a training set or even train a downstream text classification model.\n", "\n", "\n", @@ -232,8 +238,8 @@ "import pandas as pd\n", "\n", "# load data\n", - "train_df = pd.read_csv(\"../tutorials/data/yt_comments_train.csv\")\n", - "test_df = pd.read_csv(\"../tutorials/data/yt_comments_test.csv\")\n", + "train_df = pd.read_csv(\"../../tutorials/notebooks/data/yt_comments_train.csv\")\n", + "test_df = pd.read_csv(\"../../tutorials/notebooks/data/yt_comments_test.csv\")\n", "\n", "# preview data\n", "train_df.head()\n" @@ -278,7 +284,7 @@ "]\n", "\n", "# log records to Argilla\n", - "rg.log(records, name=\"weak_supervision_yt\")\n" + "rg.log(records, name=\"weak_supervision_yt\")" ] }, { @@ -294,13 +300,21 @@ "id": "dde95ce0-6e1e-4c9e-aff2-ac12643b9a48", "metadata": {}, "source": [ - "## 2. Defining rules\n", + "## 2. Define and manage rules\n", "\n", "Let's now define some of the rules proposed in the tutorial [Snorkel Intro Tutorial: Data Labeling](https://www.snorkel.org/use-cases/01-spam-tutorial). \n", - "Most of these rules can be defined directly with our web app in the [Define rules mode](../../reference/webapp/define_rules.md) and [Elasticsearch's query strings](../../reference/webapp/features.html#search-records). 
\n", + "\n", + "Rules in Argilla can be defined and used in several ways, In particular: (1) using the UI, (2) using the Python client to add rules to the server, and (3) using the Python client to add additional rules locally, either using Python functions or Rule objects.\n", + "\n", + "### Define rules using the UI\n", + "\n", + "Rules can be defined directly with our web app in the [Define rules mode](../../reference/webapp/define_rules.md) and [Elasticsearch's query strings](../../reference/webapp/features.html#search-records). \n", + "\n", "Afterward, you can conveniently load them into your notebook with the [load_rules function](../../reference/python/python_labeling.rst).\n", "\n", - "Rules can also be defined programmatically as shown below. Depending on your use case and team structure you can mix and match both interfaces (UI or Python).\n", + "### Define rules using the Python client\n", + "\n", + "Rules can also be defined programmatically as shown below. Depending on your use case and team structure you can mix and match both interfaces (UI or Python). Depending on your workflow, you can decide wether to use the `add_rules` method to add them to the dataset, or just apply them locally (without adding them to the Argilla dataset).\n", "\n", "Let's see here some programmatic rules:" ] @@ -356,6 +370,124 @@ " )\n" ] }, + { + "cell_type": "markdown", + "id": "9bf6eba2", + "metadata": {}, + "source": [ + "You can load your predefined rules and convert them to Rule instances, and add them to dataset " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fa36602", + "metadata": {}, + "outputs": [], + "source": [ + "labeling_rules_df = pd.read_csv(\"../../_static/datasets/weak_supervision_tutorial/labeling_rules.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "3de8bf31", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Unnamed: 0querylabel
00yourSPAM
11richSPAM
22filmHAM
33meetingHAM
44helpHAM
\n", + "
" + ], + "text/plain": [ + " Unnamed: 0 query label\n", + "0 0 your SPAM\n", + "1 1 rich SPAM\n", + "2 2 film HAM\n", + "3 3 meeting HAM\n", + "4 4 help HAM" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# preview labeling rules\n", + "labeling_rules_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0712bb49", + "metadata": {}, + "outputs": [], + "source": [ + "predefined_labeling_rules = []\n", + "for index, row in labeling_rules_df.iterrows():\n", + " predefined_labeling_rules.append(\n", + " Rule(row[\"query\"], row[\"label\"])\n", + " )" + ] + }, { "cell_type": "markdown", "id": "405c93ec-b136-43cf-af50-96c956b65f12", @@ -371,7 +503,7 @@ "metadata": {}, "outputs": [], "source": [ - "from argilla.labeling.text_classification import load_rules\n", + "from argilla.labeling.text_classification import load_rules, add_rules, delete_rules\n", "\n", "# bundle our rules in a list\n", "rules = [\n", @@ -380,22 +512,51 @@ " subscribe,\n", " my,\n", " song,\n", - " love,\n", + " love\n", + "]\n", + "\n", + "labeling_functions = [ \n", " contains_http,\n", " short_comment,\n", - " regex_check_out,\n", + " regex_check_out\n", "]\n", "\n", - "# optionally add the rules defined in the web app UI\n", - "rules += load_rules(dataset=\"weak_supervision_yt\")\n", + "# add rules to dataset\n", + "add_rules(dataset=\"weak_supervision_yt\", rules=rules)\n", + "\n", "\n", - "# apply the rules to a dataset to obtain the weak labels\n", - "weak_labels = WeakLabels(rules=rules, dataset=\"weak_supervision_yt\")\n" + "# add the predefined rules loaded from external file\n", + "add_rules(dataset=\"weak_supervision_yt\", rules=predefined_labeling_rules)\n" + ] + }, + { + "cell_type": "markdown", + "id": "40c8a439-5389-4329-86f8-e16d28eafb16", + "metadata": {}, + "source": [ + "After the above step, the rules will be accesible in the `weak_supervision_yt` dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d2cfec3", + "metadata": {}, + "outputs": [], + "source": [ + "# load all the rules available in the dataset including interactively defined in the UI \n", + "dataset_labeling_rules = load_rules(dataset=\"weak_supervision_yt\")\n", + "\n", + "# extend the labeling rules with labeling functions\n", + "dataset_labeling_rules.extend(labeling_functions)\n", + "\n", + "# apply the final rules to the dataset\n", + "weak_labels = WeakLabels(dataset=\"weak_supervision_yt\", rules=dataset_labeling_rules)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 27, "id": "f0584b8b-4b0d-4857-9e8b-9823d9172634", "metadata": {}, "outputs": [ @@ -434,32 +595,32 @@ " \n", " check out\n", " {SPAM}\n", - " 0.242919\n", - " 0.180\n", - " 0.235839\n", - " 0.029956\n", - " 45\n", + " 0.224401\n", + " 0.176\n", + " 0.224401\n", + " 0.031590\n", + " 44\n", " 0\n", " 1.000000\n", " \n", " \n", " plz OR please\n", " {SPAM}\n", - " 0.090414\n", - " 0.080\n", - " 0.081155\n", - " 0.019608\n", - " 20\n", + " 0.104575\n", + " 0.088\n", + " 0.098039\n", + " 0.036492\n", + " 22\n", " 0\n", " 1.000000\n", " \n", " \n", " subscribe\n", " {SPAM}\n", - " 0.106754\n", + " 0.101852\n", " 0.120\n", - " 0.083878\n", - " 0.028867\n", + " 0.082244\n", + " 0.031590\n", " 30\n", " 0\n", " 1.000000\n", @@ -467,43 +628,98 @@ " \n", " my\n", " {SPAM}\n", - " 0.190632\n", - " 0.188\n", - " 0.166667\n", - " 0.049564\n", - " 41\n", + " 0.192810\n", + " 0.192\n", + " 0.168845\n", + " 0.062636\n", + " 42\n", " 6\n", - " 0.872340\n", + " 0.875000\n", " \n", " \n", " song\n", " {HAM}\n", - " 0.132898\n", - " 0.192\n", - " 0.079521\n", - " 0.033769\n", - " 39\n", + " 0.118192\n", + " 0.172\n", + " 0.070806\n", + " 0.037037\n", + " 34\n", " 9\n", - " 0.812500\n", + " 0.790698\n", " \n", " \n", " love\n", " {HAM}\n", - " 0.092048\n", + " 0.090959\n", " 0.140\n", - " 0.070261\n", - " 0.031590\n", + " 0.071351\n", + " 0.034858\n", " 28\n", " 7\n", " 0.800000\n", " \n", " \n", + " your\n", + " {SPAM}\n", + " 0.052832\n", + " 0.088\n", + " 0.041939\n", + " 0.019608\n", + " 19\n", + " 3\n", + " 0.863636\n", + " \n", + " \n", + " rich\n", + " {SPAM}\n", + " 0.000545\n", + " 0.000\n", + " 0.000000\n", + " 0.000000\n", + " 0\n", + " 0\n", + " NaN\n", + " \n", + " \n", + " film\n", + " {}\n", + " 0.000000\n", + " 0.000\n", + " 0.000000\n", + " 0.000000\n", + " 0\n", + " 0\n", + " NaN\n", + " \n", + " \n", + " meeting\n", + " {}\n", + " 0.000000\n", + " 0.000\n", + " 0.000000\n", + " 0.000000\n", + " 0\n", + " 0\n", + " NaN\n", + " \n", + " \n", + " help\n", + " {HAM}\n", + " 0.027778\n", + " 0.036\n", + " 0.023965\n", + " 0.023965\n", + " 0\n", + " 9\n", + " 0.000000\n", + " \n", + " \n", " contains_http\n", " {SPAM}\n", " 0.106209\n", " 0.024\n", - " 0.073529\n", - " 0.049564\n", + " 0.078431\n", + " 0.055556\n", " 6\n", " 0\n", " 1.000000\n", @@ -513,7 +729,7 @@ " {HAM}\n", " 0.245098\n", " 0.368\n", - " 0.110566\n", + " 0.101307\n", " 0.064270\n", " 84\n", " 8\n", @@ -525,7 +741,7 @@ " 0.226580\n", " 0.180\n", " 0.226035\n", - " 0.027778\n", + " 0.032135\n", " 45\n", " 0\n", " 1.000000\n", @@ -533,13 +749,315 @@ " \n", " total\n", " {SPAM, HAM}\n", - " 0.754902\n", - " 0.836\n", - " 0.448802\n", - " 0.120915\n", - " 338\n", + " 0.762527\n", + " 0.880\n", + " 0.458061\n", + " 0.147059\n", + " 354\n", + " 42\n", + " 0.893939\n", + " \n", + " \n", + "\n", + "" + ], + "text/plain": [ + " label coverage annotated_coverage overlaps \\\n", + "check out {SPAM} 0.224401 
0.176 0.224401 \n", + "plz OR please {SPAM} 0.104575 0.088 0.098039 \n", + "subscribe {SPAM} 0.101852 0.120 0.082244 \n", + "my {SPAM} 0.192810 0.192 0.168845 \n", + "song {HAM} 0.118192 0.172 0.070806 \n", + "love {HAM} 0.090959 0.140 0.071351 \n", + "your {SPAM} 0.052832 0.088 0.041939 \n", + "rich {SPAM} 0.000545 0.000 0.000000 \n", + "film {} 0.000000 0.000 0.000000 \n", + "meeting {} 0.000000 0.000 0.000000 \n", + "help {HAM} 0.027778 0.036 0.023965 \n", + "contains_http {SPAM} 0.106209 0.024 0.078431 \n", + "short_comment {HAM} 0.245098 0.368 0.101307 \n", + "regex_check_out {SPAM} 0.226580 0.180 0.226035 \n", + "total {SPAM, HAM} 0.762527 0.880 0.458061 \n", + "\n", + " conflicts correct incorrect precision \n", + "check out 0.031590 44 0 1.000000 \n", + "plz OR please 0.036492 22 0 1.000000 \n", + "subscribe 0.031590 30 0 1.000000 \n", + "my 0.062636 42 6 0.875000 \n", + "song 0.037037 34 9 0.790698 \n", + "love 0.034858 28 7 0.800000 \n", + "your 0.019608 19 3 0.863636 \n", + "rich 0.000000 0 0 NaN \n", + "film 0.000000 0 0 NaN \n", + "meeting 0.000000 0 0 NaN \n", + "help 0.023965 0 9 0.000000 \n", + "contains_http 0.055556 6 0 1.000000 \n", + "short_comment 0.064270 84 8 0.913043 \n", + "regex_check_out 0.032135 45 0 1.000000 \n", + "total 0.147059 354 42 0.893939 " + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# show some stats about the rules, see the `summary()` docstring for details\n", + "weak_labels.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "dcf263ad", + "metadata": {}, + "source": [ + "You can remove the rules which are wrong from the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49a6806f", + "metadata": {}, + "outputs": [], + "source": [ + "not_informative_rules = [\n", + " Rule(\"rich\", \"SPAM\"),\n", + " Rule(\"film\", \"HAM\"),\n", + " Rule(\"meeting\", \"HAM\")\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e33e8b7", + "metadata": {}, + "outputs": [], + "source": [ + "from argilla.labeling.text_classification import delete_rules\n", + "delete_rules(dataset=\"weak_supervision_yt\", rules=not_informative_rules)" + ] + }, + { + "cell_type": "markdown", + "id": "cd384863", + "metadata": {}, + "source": [ + "You can update the rule:\n", + " \n", + " help\t{HAM}\t0.027778\t0.036\t0.023965\t0.023965\t0\t9\t0.000000" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e664fb92", + "metadata": {}, + "outputs": [], + "source": [ + "help_rule = Rule(\"help\", label=\"SPAM\")\n", + "help_rule.update_at_dataset(dataset=\"weak_supervision_yt\")" + ] + }, + { + "cell_type": "markdown", + "id": "d8d9d3f9", + "metadata": {}, + "source": [ + "Lets load the rules again and apply weak labelling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07c52c5a", + "metadata": {}, + "outputs": [], + "source": [ + "final_rules = labeling_functions + load_rules(dataset=\"weak_supervision_yt\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b9bc0df", + "metadata": {}, + "outputs": [], + "source": [ + "weak_labels = WeakLabels(dataset=\"weak_supervision_yt\", rules=final_rules)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "dbb7978a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
labelcoverageannotated_coverageoverlapsconflictscorrectincorrectprecision
contains_http{SPAM}0.1062090.0240.0784310.049020601.000000
short_comment{HAM}0.2450980.3680.1013070.0642708480.913043
regex_check_out{SPAM}0.2265800.1800.2260350.0277784501.000000
check out{SPAM}0.2244010.1760.2244010.0277784401.000000
plz OR please{SPAM}0.1045750.0880.0980390.0234202201.000000
subscribe{SPAM}0.1018520.1200.0822440.025054300.91847801.000000
my{SPAM}0.1928100.1920.1688450.0506544260.875000
song{HAM}0.1181920.1720.0708060.0370373490.790698
love{HAM}0.0909590.1400.0713510.0348582870.800000
your{SPAM}0.0528320.0880.0419390.0157951930.863636
help{SPAM}0.0277780.0360.0239650.003813901.000000
total{SPAM, HAM}0.7619830.8800.4580610.126906363330.916667
\n", @@ -547,38 +1065,41 @@ ], "text/plain": [ " label coverage annotated_coverage overlaps \\\n", - "check out {SPAM} 0.242919 0.180 0.235839 \n", - "plz OR please {SPAM} 0.090414 0.080 0.081155 \n", - "subscribe {SPAM} 0.106754 0.120 0.083878 \n", - "my {SPAM} 0.190632 0.188 0.166667 \n", - "song {HAM} 0.132898 0.192 0.079521 \n", - "love {HAM} 0.092048 0.140 0.070261 \n", - "contains_http {SPAM} 0.106209 0.024 0.073529 \n", - "short_comment {HAM} 0.245098 0.368 0.110566 \n", + "contains_http {SPAM} 0.106209 0.024 0.078431 \n", + "short_comment {HAM} 0.245098 0.368 0.101307 \n", "regex_check_out {SPAM} 0.226580 0.180 0.226035 \n", - "total {SPAM, HAM} 0.754902 0.836 0.448802 \n", + "check out {SPAM} 0.224401 0.176 0.224401 \n", + "plz OR please {SPAM} 0.104575 0.088 0.098039 \n", + "subscribe {SPAM} 0.101852 0.120 0.082244 \n", + "my {SPAM} 0.192810 0.192 0.168845 \n", + "song {HAM} 0.118192 0.172 0.070806 \n", + "love {HAM} 0.090959 0.140 0.071351 \n", + "your {SPAM} 0.052832 0.088 0.041939 \n", + "help {SPAM} 0.027778 0.036 0.023965 \n", + "total {SPAM, HAM} 0.761983 0.880 0.458061 \n", "\n", " conflicts correct incorrect precision \n", - "check out 0.029956 45 0 1.000000 \n", - "plz OR please 0.019608 20 0 1.000000 \n", - "subscribe 0.028867 30 0 1.000000 \n", - "my 0.049564 41 6 0.872340 \n", - "song 0.033769 39 9 0.812500 \n", - "love 0.031590 28 7 0.800000 \n", - "contains_http 0.049564 6 0 1.000000 \n", + "contains_http 0.049020 6 0 1.000000 \n", "short_comment 0.064270 84 8 0.913043 \n", "regex_check_out 0.027778 45 0 1.000000 \n", - "total 0.120915 338 30 0.918478 " + "check out 0.027778 44 0 1.000000 \n", + "plz OR please 0.023420 22 0 1.000000 \n", + "subscribe 0.025054 30 0 1.000000 \n", + "my 0.050654 42 6 0.875000 \n", + "song 0.037037 34 9 0.790698 \n", + "love 0.034858 28 7 0.800000 \n", + "your 0.015795 19 3 0.863636 \n", + "help 0.003813 9 0 1.000000 \n", + "total 0.126906 363 33 0.916667 " ] }, - "execution_count": 6, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# show some stats about the rules, see the `summary()` docstring for details\n", - "weak_labels.summary()\n" + "weak_labels.summary()" ] }, { @@ -612,7 +1133,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "7eb42d90-b2ec-4b26-ae0c-165b26a87458", "metadata": {}, "outputs": [], @@ -634,7 +1155,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 35, "id": "5acfce87-56e7-4f2f-8572-de797a22938d", "metadata": {}, "outputs": [ @@ -644,12 +1165,12 @@ "text": [ " precision recall f1-score support\n", "\n", - " SPAM 0.99 0.92 0.95 89\n", - " HAM 0.94 0.99 0.96 111\n", + " SPAM 0.99 0.93 0.96 102\n", + " HAM 0.94 0.99 0.96 108\n", "\n", - " accuracy 0.96 200\n", - " macro avg 0.96 0.96 0.96 200\n", - "weighted avg 0.96 0.96 0.96 200\n", + " accuracy 0.96 210\n", + " macro avg 0.96 0.96 0.96 210\n", + "weighted avg 0.96 0.96 0.96 210\n", "\n" ] } @@ -732,7 +1253,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 38, "id": "250d800e-85ba-4cc5-9fcf-a7105f94fbf7", "metadata": {}, "outputs": [ @@ -764,27 +1285,27 @@ " \n", " \n", " 0\n", - " Hi I'm lil m !!! Check out love the way yo...\n", + " http://www.rtbf.be/tv/emission/detail_the-voic...\n", " SPAM\n", " \n", " \n", " 1\n", - " LADIES!!! ----->> If you have a broken h...\n", + " http://www.ermail.pl/dolacz/V3VeYGIN CLICK ht...\n", " SPAM\n", " \n", " \n", " 2\n", - " Love these guys, love the song!\n", + " Perfect! 
<3\n", " HAM\n", " \n", " \n", " 3\n", - " She's awesome XD\n", - " HAM\n", + " Check out Melbourne shuffle, everybody!\n", + " SPAM\n", " \n", " \n", " 4\n", - " go check out our video\n", + " Check out my videos guy! :) Hope you guys had ...\n", " SPAM\n", " \n", " \n", @@ -793,53 +1314,53 @@ " ...\n", " \n", " \n", - " 1050\n", - " Nice\n", + " 1048\n", + " Great song\n", " HAM\n", " \n", " \n", - " 1051\n", - " all u should go check out j rants vi about eminem\n", - " SPAM\n", + " 1049\n", + " subscribe\n", + " HAM\n", " \n", " \n", - " 1052\n", - " Check out this playlist on YouTube:\n", - " SPAM\n", + " 1050\n", + " LoL\n", + " HAM\n", " \n", " \n", - " 1053\n", - " just came to check the view count\n", - " SPAM\n", + " 1051\n", + " Love this song\n", + " HAM\n", " \n", " \n", - " 1054\n", - " Fantastic!!!\n", + " 1052\n", + " LOVE THE WAY YOU LIE .."\n", " HAM\n", " \n", " \n", "\n", - "

1055 rows × 2 columns

\n", + "

1053 rows × 2 columns

\n", "
" ], "text/plain": [ " text label\n", - "0 Hi I'm lil m !!! Check out love the way yo... SPAM\n", - "1 LADIES!!! ----->> If you have a broken h... SPAM\n", - "2 Love these guys, love the song! HAM\n", - "3 She's awesome XD HAM\n", - "4 go check out our video SPAM\n", + "0 http://www.rtbf.be/tv/emission/detail_the-voic... SPAM\n", + "1 http://www.ermail.pl/dolacz/V3VeYGIN CLICK ht... SPAM\n", + "2 Perfect! <3 HAM\n", + "3 Check out Melbourne shuffle, everybody! SPAM\n", + "4 Check out my videos guy! :) Hope you guys had ... SPAM\n", "... ... ...\n", - "1050 Nice HAM\n", - "1051 all u should go check out j rants vi about eminem SPAM\n", - "1052 Check out this playlist on YouTube: SPAM\n", - "1053 just came to check the view count SPAM\n", - "1054 Fantastic!!! HAM\n", + "1048 Great song HAM\n", + "1049 subscribe HAM\n", + "1050 LoL HAM\n", + "1051 Love this song HAM\n", + "1052 LOVE THE WAY YOU LIE .." HAM\n", "\n", - "[1055 rows x 2 columns]" + "[1053 rows x 2 columns]" ] }, - "execution_count": 36, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -910,7 +1431,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 41, "id": "9c82339f-561d-4c59-89ee-c49ddbe8da83", "metadata": {}, "outputs": [ @@ -920,12 +1441,12 @@ "text": [ " precision recall f1-score support\n", "\n", - " SPAM 0.96 0.93 0.94 95\n", - " HAM 0.94 0.96 0.95 114\n", + " SPAM 0.93 0.93 0.93 106\n", + " HAM 0.94 0.94 0.94 114\n", "\n", - " accuracy 0.95 209\n", - " macro avg 0.95 0.95 0.95 209\n", - "weighted avg 0.95 0.95 0.95 209\n", + " accuracy 0.94 220\n", + " macro avg 0.94 0.94 0.94 220\n", + "weighted avg 0.94 0.94 0.94 220\n", "\n" ] } @@ -992,7 +1513,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 44, "id": "4e18b365-5412-47c7-b0e6-e140bd0c2dc6", "metadata": {}, "outputs": [ @@ -1024,28 +1545,28 @@ " \n", " \n", " 0\n", - " Check out Melbourne shuffle, everybody!\n", + " http://www.rtbf.be/tv/emission/detail_the-voic...\n", " SPAM\n", " \n", " \n", " 1\n", - " I love this song\n", - " HAM\n", + " http://www.ermail.pl/dolacz/V3VeYGIN CLICK ht...\n", + " SPAM\n", " \n", " \n", " 2\n", - " I fuckin love this song!<br /><br /><br />Afte...\n", + " Perfect! &lt;3\n", " HAM\n", " \n", " \n", " 3\n", - " Check out this video on YouTube:\n", + " Check out Melbourne shuffle, everybody!\n", " SPAM\n", " \n", " \n", " 4\n", - " Who&#39;s watching in 2015 Subscribe for me !\n", - " SPAM\n", + " Facebook account HACK!! http://hackfbaccountl...\n", + " HAM\n", " \n", " \n", " ...\n", @@ -1053,60 +1574,60 @@ " ...\n", " \n", " \n", - " 1172\n", - " Hey guys! Im a 12 yr old music producer. I mak...\n", - " SPAM\n", + " 1174\n", + " Great song\n", + " HAM\n", " \n", " \n", - " 1173\n", - " Hey, check out my new website!! This site is a...\n", - " SPAM\n", + " 1175\n", + " subscribe\n", + " HAM\n", " \n", " \n", - " 1174\n", - " :3\n", + " 1176\n", + " LoL\n", " HAM\n", " \n", " \n", - " 1175\n", - " Hey! I'm NERDY PEACH and I'm a new youtuber an...\n", - " SPAM\n", + " 1177\n", + " Love this song\n", + " HAM\n", " \n", " \n", - " 1176\n", - " Are those real animals\n", + " 1178\n", + " LOVE THE WAY YOU LIE ..&quot;\n", " HAM\n", " \n", " \n", "\n", - "

1177 rows × 2 columns

\n", + "

1179 rows × 2 columns

\n", "" ], "text/plain": [ " text label\n", - "0 Check out Melbourne shuffle, everybody! SPAM\n", - "1 I love this song HAM\n", - "2 I fuckin love this song!


Afte... HAM\n", - "3 Check out this video on YouTube: SPAM\n", - "4 Who's watching in 2015 Subscribe for me ! SPAM\n", + "0 http://www.rtbf.be/tv/emission/detail_the-voic... SPAM\n", + "1 http://www.ermail.pl/dolacz/V3VeYGIN CLICK ht... SPAM\n", + "2 Perfect! <3 HAM\n", + "3 Check out Melbourne shuffle, everybody! SPAM\n", + "4 Facebook account HACK!! http://hackfbaccountl... HAM\n", "... ... ...\n", - "1172 Hey guys! Im a 12 yr old music producer. I mak... SPAM\n", - "1173 Hey, check out my new website!! This site is a... SPAM\n", - "1174 :3 HAM\n", - "1175 Hey! I'm NERDY PEACH and I'm a new youtuber an... SPAM\n", - "1176 Are those real animals HAM\n", + "1174 Great song HAM\n", + "1175 subscribe HAM\n", + "1176 LoL HAM\n", + "1177 Love this song HAM\n", + "1178 LOVE THE WAY YOU LIE .." HAM\n", "\n", - "[1177 rows x 2 columns]" + "[1179 rows x 2 columns]" ] }, - "execution_count": 49, + "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# preview training data\n", - "training_data\n" + "training_data" ] }, { @@ -1158,7 +1679,7 @@ "flyingsquid_model = FlyingSquid(weak_labels)\n", "\n", "# we fit the model\n", - "flyingsquid_model.fit()\n" + "flyingsquid_model.fit()" ] }, { @@ -1177,7 +1698,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 47, "id": "2d81dbf5-edf9-4c8c-a876-5fcb16d73090", "metadata": {}, "outputs": [ @@ -1187,12 +1708,12 @@ "text": [ " precision recall f1-score support\n", "\n", - " SPAM 0.93 0.91 0.92 95\n", - " HAM 0.92 0.95 0.94 114\n", + " SPAM 0.92 0.93 0.93 106\n", + " HAM 0.94 0.92 0.93 114\n", "\n", - " accuracy 0.93 209\n", - " macro avg 0.93 0.93 0.93 209\n", - "weighted avg 0.93 0.93 0.93 209\n", + " accuracy 0.93 220\n", + " macro avg 0.93 0.93 0.93 220\n", + "weighted avg 0.93 0.93 0.93 220\n", "\n" ] } @@ -1223,7 +1744,7 @@ "\n", "# accuracy without abstentions: 0.93; accuracy of random classifier: 0.5\n", "print(\"accuracy_c:\", frac_non * 0.93 + frac_abs * 0.5)\n", - "# accuracy_c: 0.85948\n" + "# accuracy_c: 0.85948" ] }, { @@ -1253,12 +1774,12 @@ "# extract training data\n", "training_data = pd.DataFrame(\n", " [{\"text\": rec.text, \"label\": rec.prediction[0][0]} for rec in records_for_training]\n", - ")\n" + ")" ] }, { "cell_type": "code", - "execution_count": 231, + "execution_count": 50, "id": "0d641340-f82b-4af8-b86b-7c06eaf59f61", "metadata": {}, "outputs": [ @@ -1290,27 +1811,27 @@ " \n", " \n", " 0\n", - " Hey I&#39;m a British youtuber!!<br />I upload...\n", + " http://www.rtbf.be/tv/emission/detail_the-voic...\n", " SPAM\n", " \n", " \n", " 1\n", - " NOKIA spotted\n", - " HAM\n", + " http://www.ermail.pl/dolacz/V3VeYGIN CLICK ht...\n", + " SPAM\n", " \n", " \n", " 2\n", - " Dance :)\n", + " Perfect! &lt;3\n", " HAM\n", " \n", " \n", " 3\n", - " You guys should check out this EXTRAORDINARY w...\n", + " Check out Melbourne shuffle, everybody!\n", " SPAM\n", " \n", " \n", " 4\n", - " Need money ? check my channel and subscribe,so...\n", + " Facebook account HACK!! 
http://hackfbaccountl...\n", " SPAM\n", " \n", " \n", @@ -1319,60 +1840,60 @@ " ...\n", " \n", " \n", - " 1172\n", - " Please check out my acoustic cover channel :) ...\n", - " SPAM\n", + " 1174\n", + " Great song\n", + " HAM\n", " \n", " \n", - " 1173\n", - " PLEASE SUBSCRIBE ME!!!!!!!!!!!!!!!!!!!!!!!!!!!...\n", - " SPAM\n", + " 1175\n", + " subscribe\n", + " HAM\n", " \n", " \n", - " 1174\n", - " <a href=\"http://www.gofundme.com/Helpmypitbull...\n", - " SPAM\n", + " 1176\n", + " LoL\n", + " HAM\n", " \n", " \n", - " 1175\n", - " I love this song so much!:-D I've heard it so ...\n", + " 1177\n", + " Love this song\n", " HAM\n", " \n", " \n", - " 1176\n", - " Check out this video on YouTube:\n", - " SPAM\n", + " 1178\n", + " LOVE THE WAY YOU LIE ..&quot;\n", + " HAM\n", " \n", " \n", "\n", - "

1177 rows × 2 columns

\n", + "

1179 rows × 2 columns

\n", "" ], "text/plain": [ " text label\n", - "0 Hey I'm a British youtuber!!
I upload... SPAM\n", - "1 NOKIA spotted HAM\n", - "2 Dance :) HAM\n", - "3 You guys should check out this EXTRAORDINARY w... SPAM\n", - "4 Need money ? check my channel and subscribe,so... SPAM\n", + "0 http://www.rtbf.be/tv/emission/detail_the-voic... SPAM\n", + "1 http://www.ermail.pl/dolacz/V3VeYGIN CLICK ht... SPAM\n", + "2 Perfect! <3 HAM\n", + "3 Check out Melbourne shuffle, everybody! SPAM\n", + "4 Facebook account HACK!! http://hackfbaccountl... SPAM\n", "... ... ...\n", - "1172 Please check out my acoustic cover channel :) ... SPAM\n", - "1173 PLEASE SUBSCRIBE ME!!!!!!!!!!!!!!!!!!!!!!!!!!!... SPAM\n", - "1174 \n", " money\n", " {Business}\n", - " 0.008242\n", + " 0.008233\n", " 0.008816\n", - " 0.002450\n", - " 0.001925\n", + " 0.002500\n", + " 0.001975\n", " 31\n", " 36\n", " 0.462687\n", @@ -284,8 +304,8 @@ " {Business}\n", " 0.019775\n", " 0.021184\n", - " 0.005892\n", - " 0.005183\n", + " 0.005933\n", + " 0.005225\n", " 115\n", " 46\n", " 0.714286\n", @@ -295,8 +315,8 @@ " {Business}\n", " 0.016608\n", " 0.016974\n", - " 0.003492\n", - " 0.002850\n", + " 0.003533\n", + " 0.002892\n", " 98\n", " 31\n", " 0.759690\n", @@ -304,21 +324,21 @@ " \n", " war\n", " {World}\n", - " 0.011683\n", - " 0.008816\n", - " 0.003242\n", - " 0.001367\n", - " 44\n", - " 23\n", - " 0.656716\n", + " 0.015533\n", + " 0.013553\n", + " 0.004467\n", + " 0.001750\n", + " 76\n", + " 27\n", + " 0.737864\n", " \n", " \n", " gov*\n", " {World}\n", - " 0.045067\n", + " 0.045075\n", " 0.043158\n", - " 0.010800\n", - " 0.006225\n", + " 0.011092\n", + " 0.006233\n", " 156\n", " 172\n", " 0.475610\n", @@ -328,7 +348,7 @@ " {World}\n", " 0.030142\n", " 0.030263\n", - " 0.007508\n", + " 0.007883\n", " 0.002825\n", " 207\n", " 23\n", @@ -337,9 +357,9 @@ " \n", " conflict\n", " {World}\n", - " 0.003050\n", + " 0.003042\n", " 0.003684\n", - " 0.001025\n", + " 0.001125\n", " 0.000092\n", " 20\n", " 8\n", @@ -348,10 +368,10 @@ " \n", " footbal*\n", " {Sports}\n", - " 0.013050\n", + " 0.013042\n", " 0.015132\n", - " 0.004875\n", - " 0.000408\n", + " 0.004883\n", + " 0.000417\n", " 105\n", " 10\n", " 0.913043\n", @@ -361,7 +381,7 @@ " {Sports}\n", " 0.021183\n", " 0.021711\n", - " 0.007033\n", + " 0.007025\n", " 0.001225\n", " 146\n", " 19\n", @@ -370,21 +390,21 @@ " \n", " game\n", " {Sports}\n", - " 0.038950\n", - " 0.043026\n", - " 0.014067\n", - " 0.002375\n", - " 253\n", - " 74\n", - " 0.773700\n", + " 0.038808\n", + " 0.042763\n", + " 0.014042\n", + " 0.002392\n", + " 252\n", + " 73\n", + " 0.775385\n", " \n", " \n", " play*\n", " {Sports}\n", " 0.052608\n", " 0.057632\n", - " 0.016767\n", - " 0.004992\n", + " 0.016875\n", + " 0.005133\n", " 312\n", " 126\n", " 0.712329\n", @@ -394,8 +414,8 @@ " {Sci/Tech}\n", " 0.016433\n", " 0.015658\n", - " 0.002742\n", - " 0.001275\n", + " 0.002775\n", + " 0.001292\n", " 101\n", " 18\n", " 0.848739\n", @@ -403,10 +423,10 @@ " \n", " techno*\n", " {Sci/Tech}\n", - " 0.027150\n", + " 0.027142\n", " 0.028816\n", - " 0.008325\n", - " 0.003108\n", + " 0.008433\n", + " 0.003142\n", " 153\n", " 66\n", " 0.698630\n", @@ -414,21 +434,21 @@ " \n", " computer*\n", " {Sci/Tech}\n", - " 0.027275\n", - " 0.026447\n", - " 0.011100\n", - " 0.004483\n", + " 0.027550\n", + " 0.026842\n", + " 0.011333\n", + " 0.004542\n", " 167\n", - " 34\n", - " 0.830846\n", + " 37\n", + " 0.818627\n", " \n", " \n", " software\n", " {Sci/Tech}\n", - " 0.030283\n", + " 0.030233\n", " 0.032763\n", - " 0.009625\n", - " 0.003308\n", + " 0.009808\n", + " 0.003342\n", " 202\n", " 47\n", " 0.811245\n", @@ -436,24 
+456,24 @@ " \n", " web\n", " {Sci/Tech}\n", - " 0.015508\n", - " 0.016316\n", - " 0.004100\n", - " 0.001608\n", - " 111\n", - " 13\n", - " 0.895161\n", + " 0.017283\n", + " 0.018421\n", + " 0.004625\n", + " 0.001792\n", + " 124\n", + " 16\n", + " 0.885714\n", " \n", " \n", " total\n", - " {Sci/Tech, World, Sports, Business}\n", - " 0.317375\n", - " 0.327895\n", - " 0.053408\n", - " 0.019425\n", - " 2221\n", - " 746\n", - " 0.748568\n", + " {Sci/Tech, World, Business, Sports}\n", + " 0.321342\n", + " 0.333684\n", + " 0.054983\n", + " 0.019908\n", + " 2265\n", + " 755\n", + " 0.750000\n", " \n", " \n", "\n", @@ -461,45 +481,45 @@ ], "text/plain": [ " label coverage annotated_coverage \\\n", - "money {Business} 0.008242 0.008816 \n", + "money {Business} 0.008233 0.008816 \n", "financ* {Business} 0.019775 0.021184 \n", "dollar* {Business} 0.016608 0.016974 \n", - "war {World} 0.011683 0.008816 \n", - "gov* {World} 0.045067 0.043158 \n", + "war {World} 0.015533 0.013553 \n", + "gov* {World} 0.045075 0.043158 \n", "minister* {World} 0.030142 0.030263 \n", - "conflict {World} 0.003050 0.003684 \n", - "footbal* {Sports} 0.013050 0.015132 \n", + "conflict {World} 0.003042 0.003684 \n", + "footbal* {Sports} 0.013042 0.015132 \n", "sport* {Sports} 0.021183 0.021711 \n", - "game {Sports} 0.038950 0.043026 \n", + "game {Sports} 0.038808 0.042763 \n", "play* {Sports} 0.052608 0.057632 \n", "sci* {Sci/Tech} 0.016433 0.015658 \n", - "techno* {Sci/Tech} 0.027150 0.028816 \n", - "computer* {Sci/Tech} 0.027275 0.026447 \n", - "software {Sci/Tech} 0.030283 0.032763 \n", - "web {Sci/Tech} 0.015508 0.016316 \n", - "total {Sci/Tech, World, Sports, Business} 0.317375 0.327895 \n", + "techno* {Sci/Tech} 0.027142 0.028816 \n", + "computer* {Sci/Tech} 0.027550 0.026842 \n", + "software {Sci/Tech} 0.030233 0.032763 \n", + "web {Sci/Tech} 0.017283 0.018421 \n", + "total {Sci/Tech, World, Business, Sports} 0.321342 0.333684 \n", "\n", " overlaps conflicts correct incorrect precision \n", - "money 0.002450 0.001925 31 36 0.462687 \n", - "financ* 0.005892 0.005183 115 46 0.714286 \n", - "dollar* 0.003492 0.002850 98 31 0.759690 \n", - "war 0.003242 0.001367 44 23 0.656716 \n", - "gov* 0.010800 0.006225 156 172 0.475610 \n", - "minister* 0.007508 0.002825 207 23 0.900000 \n", - "conflict 0.001025 0.000092 20 8 0.714286 \n", - "footbal* 0.004875 0.000408 105 10 0.913043 \n", - "sport* 0.007033 0.001225 146 19 0.884848 \n", - "game 0.014067 0.002375 253 74 0.773700 \n", - "play* 0.016767 0.004992 312 126 0.712329 \n", - "sci* 0.002742 0.001275 101 18 0.848739 \n", - "techno* 0.008325 0.003108 153 66 0.698630 \n", - "computer* 0.011100 0.004483 167 34 0.830846 \n", - "software 0.009625 0.003308 202 47 0.811245 \n", - "web 0.004100 0.001608 111 13 0.895161 \n", - "total 0.053408 0.019425 2221 746 0.748568 " + "money 0.002500 0.001975 31 36 0.462687 \n", + "financ* 0.005933 0.005225 115 46 0.714286 \n", + "dollar* 0.003533 0.002892 98 31 0.759690 \n", + "war 0.004467 0.001750 76 27 0.737864 \n", + "gov* 0.011092 0.006233 156 172 0.475610 \n", + "minister* 0.007883 0.002825 207 23 0.900000 \n", + "conflict 0.001125 0.000092 20 8 0.714286 \n", + "footbal* 0.004883 0.000417 105 10 0.913043 \n", + "sport* 0.007025 0.001225 146 19 0.884848 \n", + "game 0.014042 0.002392 252 73 0.775385 \n", + "play* 0.016875 0.005133 312 126 0.712329 \n", + "sci* 0.002775 0.001292 101 18 0.848739 \n", + "techno* 0.008433 0.003142 153 66 0.698630 \n", + "computer* 0.011333 0.004542 167 37 0.818627 \n", + "software 0.009808 0.003342 202 47 0.811245 
\n", + "web 0.004625 0.001792 124 16 0.885714 \n", + "total 0.054983 0.019908 2265 755 0.750000 " ] }, - "execution_count": 5, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -537,7 +557,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "id": "751c8c8c-6632-4fe6-97db-fb616fc77358", "metadata": {}, "outputs": [ @@ -547,14 +567,14 @@ "text": [ " precision recall f1-score support\n", "\n", - " Business 0.73 0.41 0.53 493\n", - " Sports 0.77 0.97 0.86 703\n", - " World 0.69 0.83 0.75 462\n", - " Sci/Tech 0.80 0.74 0.77 833\n", + " Business 0.74 0.41 0.52 499\n", + " Sports 0.77 0.95 0.85 704\n", + " World 0.68 0.84 0.75 487\n", + " Sci/Tech 0.80 0.75 0.77 846\n", "\n", - " accuracy 0.76 2491\n", - " macro avg 0.75 0.74 0.73 2491\n", - "weighted avg 0.76 0.76 0.74 2491\n", + " accuracy 0.76 2536\n", + " macro avg 0.75 0.74 0.73 2536\n", + "weighted avg 0.76 0.76 0.74 2536\n", "\n" ] } @@ -598,10 +618,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "9c76fa80", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b94719e1b0fb47f1a4b281ebf53d79a3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/120000 [00:00\n", " thank*\n", " {gratitude}\n", - " 0.196768\n", - " 0.196237\n", - " 0.037785\n", - " 73\n", + " 0.199382\n", + " 0.198925\n", + " 0.048004\n", + " 74\n", " 0\n", " 1.000000\n", " \n", " \n", " appreciate\n", " {gratitude}\n", - " 0.016160\n", + " 0.016397\n", " 0.021505\n", - " 0.009506\n", + " 0.009981\n", " 7\n", " 1\n", " 0.875000\n", @@ -253,7 +268,7 @@ " {admiration, gratitude}\n", " 0.007842\n", " 0.010753\n", - " 0.007605\n", + " 0.007842\n", " 8\n", " 0\n", " 1.000000\n", @@ -263,7 +278,7 @@ " {admiration}\n", " 0.008317\n", " 0.008065\n", - " 0.006654\n", + " 0.007605\n", " 3\n", " 0\n", " 1.000000\n", @@ -273,7 +288,7 @@ " {admiration}\n", " 0.025428\n", " 0.021505\n", - " 0.003565\n", + " 0.004990\n", " 8\n", " 0\n", " 1.000000\n", @@ -283,7 +298,7 @@ " {admiration}\n", " 0.025190\n", " 0.034946\n", - " 0.006179\n", + " 0.007605\n", " 12\n", " 1\n", " 0.923077\n", @@ -303,7 +318,7 @@ " {admiration}\n", " 0.008555\n", " 0.018817\n", - " 0.002376\n", + " 0.003089\n", " 7\n", " 0\n", " 1.000000\n", @@ -321,32 +336,32 @@ " \n", " exactly\n", " {approval}\n", - " 0.004278\n", - " 0.002688\n", - " 0.001188\n", + " 0.007842\n", + " 0.010753\n", + " 0.002376\n", + " 3\n", " 1\n", - " 0\n", - " 1.000000\n", + " 0.750000\n", " \n", " \n", " agree\n", " {approval}\n", " 0.016873\n", " 0.021505\n", - " 0.003089\n", + " 0.003327\n", " 6\n", " 2\n", " 0.750000\n", " \n", " \n", " yeah\n", - " {approval}\n", + " {optimism}\n", " 0.024952\n", " 0.021505\n", - " 0.004990\n", - " 5\n", - " 3\n", - " 0.625000\n", + " 0.006179\n", + " 2\n", + " 6\n", + " 0.250000\n", " \n", " \n", " suck\n", @@ -363,7 +378,7 @@ " {annoyance}\n", " 0.002139\n", " 0.008065\n", - " 0.000475\n", + " 0.000713\n", " 2\n", " 1\n", " 0.666667\n", @@ -373,7 +388,7 @@ " {annoyance}\n", " 0.003327\n", " 0.018817\n", - " 0.000951\n", + " 0.001188\n", " 7\n", " 0\n", " 1.000000\n", @@ -399,11 +414,21 @@ " 1.000000\n", " \n", " \n", + " joking\n", + " {admiration, optimism}\n", + " 0.000238\n", + " 0.000000\n", + " 0.000000\n", + " 0\n", + " 0\n", + " NaN\n", + " \n", + " \n", " text:(\"good luck\")\n", " {optimism}\n", " 0.015209\n", " 0.018817\n", - " 0.002139\n", + " 0.002614\n", " 4\n", " 3\n", " 0.571429\n", @@ -423,7 
+448,7 @@ " {curiosity}\n", " 0.004040\n", " 0.005376\n", - " 0.000951\n", + " 0.001188\n", " 2\n", " 0\n", " 1.000000\n", @@ -443,20 +468,40 @@ " {curiosity}\n", " 0.000951\n", " 0.005376\n", - " 0.000000\n", + " 0.000238\n", " 2\n", " 0\n", " 1.000000\n", " \n", " \n", + " \"do you\"\n", + " {admiration, curiosity}\n", + " 0.010932\n", + " 0.018817\n", + " 0.002376\n", + " 7\n", + " 7\n", + " 0.500000\n", + " \n", + " \n", + " \"great\"\n", + " {annoyance}\n", + " 0.055133\n", + " 0.061828\n", + " 0.016873\n", + " 1\n", + " 22\n", + " 0.043478\n", + " \n", + " \n", " total\n", - " {curiosity, annoyance, admiration, approval, o...\n", - " 0.327234\n", - " 0.384409\n", - " 0.041825\n", - " 161\n", - " 11\n", - " 0.936047\n", + " {approval, gratitude, admiration, optimism, cu...\n", + " 0.379753\n", + " 0.448925\n", + " 0.060361\n", + " 169\n", + " 44\n", + " 0.793427\n", " \n", " \n", "\n", @@ -475,22 +520,25 @@ "legend {admiration} \n", "exactly {approval} \n", "agree {approval} \n", - "yeah {approval} \n", + "yeah {optimism} \n", "suck {annoyance} \n", "pissed {annoyance} \n", "annoying {annoyance} \n", "ruined {annoyance} \n", "hoping {optimism} \n", + "joking {admiration, optimism} \n", "text:(\"good luck\") {optimism} \n", "\"nice day\" {optimism} \n", "\"what is\" {curiosity} \n", "\"can you\" {curiosity} \n", "\"would you\" {curiosity} \n", - "total {curiosity, annoyance, admiration, approval, o... \n", + "\"do you\" {admiration, curiosity} \n", + "\"great\" {annoyance} \n", + "total {approval, gratitude, admiration, optimism, cu... \n", "\n", " coverage annotated_coverage \\\n", - "thank* 0.196768 0.196237 \n", - "appreciate 0.016160 0.021505 \n", + "thank* 0.199382 0.198925 \n", + "appreciate 0.016397 0.021505 \n", "text:(thanks AND good) 0.007842 0.010753 \n", "advice 0.008317 0.008065 \n", "amazing 0.025428 0.021505 \n", @@ -498,7 +546,7 @@ "impressed 0.002139 0.005376 \n", "text:(good AND (point OR call OR idea OR job)) 0.008555 0.018817 \n", "legend 0.001901 0.002688 \n", - "exactly 0.004278 0.002688 \n", + "exactly 0.007842 0.010753 \n", "agree 0.016873 0.021505 \n", "yeah 0.024952 0.021505 \n", "suck 0.002139 0.008065 \n", @@ -506,37 +554,43 @@ "annoying 0.003327 0.018817 \n", "ruined 0.000713 0.002688 \n", "hoping 0.003565 0.005376 \n", + "joking 0.000238 0.000000 \n", "text:(\"good luck\") 0.015209 0.018817 \n", "\"nice day\" 0.000713 0.005376 \n", "\"what is\" 0.004040 0.005376 \n", "\"can you\" 0.004278 0.008065 \n", "\"would you\" 0.000951 0.005376 \n", - "total 0.327234 0.384409 \n", + "\"do you\" 0.010932 0.018817 \n", + "\"great\" 0.055133 0.061828 \n", + "total 0.379753 0.448925 \n", "\n", " overlaps correct incorrect \\\n", - "thank* 0.037785 73 0 \n", - "appreciate 0.009506 7 1 \n", - "text:(thanks AND good) 0.007605 8 0 \n", - "advice 0.006654 3 0 \n", - "amazing 0.003565 8 0 \n", - "awesome 0.006179 12 1 \n", + "thank* 0.048004 74 0 \n", + "appreciate 0.009981 7 1 \n", + "text:(thanks AND good) 0.007842 8 0 \n", + "advice 0.007605 3 0 \n", + "amazing 0.004990 8 0 \n", + "awesome 0.007605 12 1 \n", "impressed 0.000000 2 0 \n", - "text:(good AND (point OR call OR idea OR job)) 0.002376 7 0 \n", + "text:(good AND (point OR call OR idea OR job)) 0.003089 7 0 \n", "legend 0.000475 1 0 \n", - "exactly 0.001188 1 0 \n", - "agree 0.003089 6 2 \n", - "yeah 0.004990 5 3 \n", + "exactly 0.002376 3 1 \n", + "agree 0.003327 6 2 \n", + "yeah 0.006179 2 6 \n", "suck 0.000475 3 0 \n", - "pissed 0.000475 2 1 \n", - "annoying 0.000951 7 0 \n", + "pissed 0.000713 2 1 \n", + 
"annoying 0.001188 7 0 \n", "ruined 0.000238 1 0 \n", "hoping 0.000713 2 0 \n", - "text:(\"good luck\") 0.002139 4 3 \n", + "joking 0.000000 0 0 \n", + "text:(\"good luck\") 0.002614 4 3 \n", "\"nice day\" 0.000000 2 0 \n", - "\"what is\" 0.000951 2 0 \n", + "\"what is\" 0.001188 2 0 \n", "\"can you\" 0.000713 3 0 \n", - "\"would you\" 0.000000 2 0 \n", - "total 0.041825 161 11 \n", + "\"would you\" 0.000238 2 0 \n", + "\"do you\" 0.002376 7 7 \n", + "\"great\" 0.016873 1 22 \n", + "total 0.060361 169 44 \n", "\n", " precision \n", "thank* 1.000000 \n", @@ -548,23 +602,26 @@ "impressed 1.000000 \n", "text:(good AND (point OR call OR idea OR job)) 1.000000 \n", "legend 1.000000 \n", - "exactly 1.000000 \n", + "exactly 0.750000 \n", "agree 0.750000 \n", - "yeah 0.625000 \n", + "yeah 0.250000 \n", "suck 1.000000 \n", "pissed 0.666667 \n", "annoying 1.000000 \n", "ruined 1.000000 \n", "hoping 1.000000 \n", + "joking NaN \n", "text:(\"good luck\") 0.571429 \n", "\"nice day\" 1.000000 \n", "\"what is\" 1.000000 \n", "\"can you\" 1.000000 \n", "\"would you\" 1.000000 \n", - "total 0.936047 " + "\"do you\" 0.500000 \n", + "\"great\" 0.043478 \n", + "total 0.793427 " ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -576,214 +633,1664 @@ }, { "cell_type": "markdown", - "id": "ba19f6b8-520a-4caa-87a4-307f52749b92", - "metadata": { - "tags": [] - }, - "source": [ - "### Create training set\n", - "\n", - "When we are happy with our heuristics, it is time to combine them and compute weak labels for the training of our downstream model.\n", - "For this we will use the `MajorityVoter`.\n", - "In the multi-label case, it sets the probability of a label to 0 or 1 depending on whether at least one non-abstaining rule voted for the respective label or not." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6985e6b7-7c28-4efc-84d2-08b7fff02ef8", - "metadata": {}, - "outputs": [], - "source": [ - "from argilla.labeling.text_classification import MajorityVoter\n", - "\n", - "# Use the majority voter as the label model\n", - "label_model = MajorityVoter(weak_labels)\n" - ] - }, - { - "cell_type": "markdown", - "id": "f82fe715-e703-457a-a158-78709f442bbd", - "metadata": {}, - "source": [ - "From our label model we get the training records together with its weak labels and probabilities.\n", - "We will use the weak labels with a probability greater than 0.5 as labels for our training, and hence copy them to the `annotation` property of our records." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "152abc10-1538-4660-aa3a-0e210887c0f1", - "metadata": {}, - "outputs": [], - "source": [ - "# Get records with the predictions from the label model to train a down-stream model\n", - "train_rb = label_model.predict()\n", - "\n", - "# Copy label model predictions to annotation with a threshold of 0.5\n", - "for rec in train_rb:\n", - " rec.annotation = [pred[0] for pred in rec.prediction if pred[1] > 0.5]\n" - ] - }, - { - "cell_type": "markdown", - "id": "24414843-70c2-4dc2-902d-e6e8bcc2a449", - "metadata": {}, - "source": [ - "We extract the test set with manual annotations from our `WeakMultiLabels` object:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fce0c5a0-bfa5-4649-a59a-f0114f7891a5", - "metadata": {}, - "outputs": [], - "source": [ - "# Get records with manual annotations to use as test set for the down-stream model\n", - "test_rg = rg.DatasetForTextClassification(weak_labels.records(has_annotation=True))\n" - ] - }, - { - "cell_type": "markdown", - "id": "d7738210-4273-49c2-a70c-634b18fa008e", - "metadata": {}, - "source": [ - "We will use the convenient `DatasetForTextClassification.prepare_for_training()` method to create datasets optimized for training with the Hugging Face transformers library:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "232a2d46-1516-4c7d-8d39-d55fe0d8130d", - "metadata": {}, - "outputs": [], - "source": [ - "from datasets import DatasetDict\n", - "\n", - "# Create dataset dictionary and shuffle training set\n", - "ds = DatasetDict(\n", - " train=train_rg.prepare_for_training().shuffle(seed=42),\n", - " test=test_rg.prepare_for_training(),\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "id": "b0be5d9c-1532-4265-8f7a-18e346e9caaf", - "metadata": {}, - "source": [ - "Let us push the dataset to the Hub to share it with our colleagues.\n", - "It is also an easy way to outsource the training of the model to an environment with an accelerator, like Google Colab for example." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "072218ae-2167-48d8-8b93-243daff497b3", - "metadata": {}, - "outputs": [], - "source": [ - "# Push dataset for training our down-stream model to the HF hub\n", - "ds.push_to_hub(\"argilla/go_emotions_training\")\n" - ] - }, - { - "cell_type": "markdown", - "id": "c0957ad2-f05d-4ea4-b303-d77495a449e9", + "id": "88b3fffe", "metadata": {}, "source": [ - "### Train a transformer downstream model" - ] - }, - { - "cell_type": "markdown", - "id": "ec8cea7b-cd6f-408d-9667-249e3d66eac0", - "metadata": {}, - "source": [ - "The following steps are basically a copy&paste from the amazing documentation of the [Hugging Face transformers](https://huggingface.co/docs/transformers) library.\n", - "\n", - "First, we will load the tokenizer corresponding to our model, which we choose to be the [distilled version](https://huggingface.co/distilbert-base-uncased) of the infamous BERT.\n", - "\n", - "
\n", - "\n", - "Note\n", - "\n", - "Since we will use a full-blown transformer as a downstream model (albeit a distilled one), we recommend executing the following code on a machine with a GPU, or in a Google Colab with a GPU backend enabled.\n", - " \n", - "
" + "We can observe that \"joking\" does not have any support and also \"do you\" is not informative, because its correct/incorrect ratio equals to 1. We can delete these two rules from the dataset using \"delete_rules\" method " ] }, { "cell_type": "code", - "execution_count": null, - "id": "88365498-15df-4aff-ae3a-48c22e4e1a66", + "execution_count": 11, + "id": "f047f88e", "metadata": {}, "outputs": [], "source": [ - "from transformers import AutoTokenizer\n", + "rules_to_delete = [\n", + " Rule(\"joking\", [\"optimism\", \"admiration\"]),\n", + " Rule('\"do you\"', [\"curiosity\", \"admiration\"])]\n", "\n", - "# Initialize tokenizer\n", - "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n" + "delete_rules(dataset=\"go_emotions\", rules=rules_to_delete)" ] }, { "cell_type": "markdown", - "id": "fd0cb66a-c31e-4811-abc3-453179ee567f", + "id": "277c8167", "metadata": {}, "source": [ - "Afterward, we tokenize our data:" + "# lets apply Weak Labeling again " ] }, { "cell_type": "code", "execution_count": null, - "id": "6635d069-e242-4d89-92be-ecce502b92db", + "id": "d053adfe", "metadata": {}, "outputs": [], "source": [ - "def tokenize_func(examples):\n", - " return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n", - "\n", - "\n", - "# Tokenize the data\n", - "tokenized_ds = ds.map(tokenize_func, batched=True)\n" - ] - }, - { - "cell_type": "markdown", - "id": "28d70c03-9602-472f-aa79-28777fe3b0e2", - "metadata": {}, - "source": [ - "The transformer model expects our labels to follow a common multi-label format of binaries, so let us use [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html) for this transformation." + "weak_labels = WeakMultiLabels(\"go_emotions\")" ] }, { "cell_type": "code", - "execution_count": null, - "id": "a7efcbc6-938c-4341-9e83-e8761974deda", - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.preprocessing import MultiLabelBinarizer\n", - "\n", - "# Turn labels into multi-label format\n", - "mb = MultiLabelBinarizer()\n", - "mb.fit(ds[\"test\"][\"label\"])\n", - "\n", - "\n", - "def binarize_labels(examples):\n", - " return {\"label\": mb.transform(examples[\"label\"])}\n", - "\n", - "\n", - "binarized_tokenized_ds = tokenized_ds.map(binarize_labels, batched=True)\n" - ] - }, - { - "cell_type": "markdown", - "id": "e71e2cd7-50e4-4dd1-883f-f2e31d1ec051", + "execution_count": 13, + "id": "9ad4fbd4", "metadata": {}, - "source": [ + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelcoverageannotated_coverageoverlapscorrectincorrectprecision
thank*{gratitude}0.1993820.1989250.0477667401.000000
appreciate{gratitude}0.0163970.0215050.009743710.875000
text:(thanks AND good){admiration, gratitude}0.0078420.0107530.007842801.000000
advice{admiration}0.0083170.0080650.007367301.000000
amazing{admiration}0.0254280.0215050.004990801.000000
awesome{admiration}0.0251900.0349460.0071291210.923077
impressed{admiration}0.0021390.0053760.000000201.000000
text:(good AND (point OR call OR idea OR job)){admiration}0.0085550.0188170.003089701.000000
legend{admiration}0.0019010.0026880.000475101.000000
exactly{approval}0.0078420.0107530.002139310.750000
agree{approval}0.0168730.0215050.003327620.750000
yeah{optimism}0.0249520.0215050.006179260.250000
suck{annoyance}0.0021390.0080650.000475301.000000
pissed{annoyance}0.0021390.0080650.000475210.666667
annoying{annoyance}0.0033270.0188170.001188701.000000
ruined{annoyance}0.0007130.0026880.000238101.000000
hoping{optimism}0.0035650.0053760.000713201.000000
text:(\"good luck\"){optimism}0.0152090.0188170.002614430.571429
\"nice day\"{optimism}0.0007130.0053760.000000201.000000
\"what is\"{curiosity}0.0040400.0053760.001188201.000000
\"can you\"{curiosity}0.0042780.0080650.000713301.000000
\"would you\"{curiosity}0.0009510.0053760.000238201.000000
\"great\"{annoyance}0.0551330.0618280.0163971220.043478
total{approval, gratitude, admiration, optimism, cu...0.3709600.4354840.058222162370.814070
\n", + "
" + ], + "text/plain": [ + " label \\\n", + "thank* {gratitude} \n", + "appreciate {gratitude} \n", + "text:(thanks AND good) {admiration, gratitude} \n", + "advice {admiration} \n", + "amazing {admiration} \n", + "awesome {admiration} \n", + "impressed {admiration} \n", + "text:(good AND (point OR call OR idea OR job)) {admiration} \n", + "legend {admiration} \n", + "exactly {approval} \n", + "agree {approval} \n", + "yeah {optimism} \n", + "suck {annoyance} \n", + "pissed {annoyance} \n", + "annoying {annoyance} \n", + "ruined {annoyance} \n", + "hoping {optimism} \n", + "text:(\"good luck\") {optimism} \n", + "\"nice day\" {optimism} \n", + "\"what is\" {curiosity} \n", + "\"can you\" {curiosity} \n", + "\"would you\" {curiosity} \n", + "\"great\" {annoyance} \n", + "total {approval, gratitude, admiration, optimism, cu... \n", + "\n", + " coverage annotated_coverage \\\n", + "thank* 0.199382 0.198925 \n", + "appreciate 0.016397 0.021505 \n", + "text:(thanks AND good) 0.007842 0.010753 \n", + "advice 0.008317 0.008065 \n", + "amazing 0.025428 0.021505 \n", + "awesome 0.025190 0.034946 \n", + "impressed 0.002139 0.005376 \n", + "text:(good AND (point OR call OR idea OR job)) 0.008555 0.018817 \n", + "legend 0.001901 0.002688 \n", + "exactly 0.007842 0.010753 \n", + "agree 0.016873 0.021505 \n", + "yeah 0.024952 0.021505 \n", + "suck 0.002139 0.008065 \n", + "pissed 0.002139 0.008065 \n", + "annoying 0.003327 0.018817 \n", + "ruined 0.000713 0.002688 \n", + "hoping 0.003565 0.005376 \n", + "text:(\"good luck\") 0.015209 0.018817 \n", + "\"nice day\" 0.000713 0.005376 \n", + "\"what is\" 0.004040 0.005376 \n", + "\"can you\" 0.004278 0.008065 \n", + "\"would you\" 0.000951 0.005376 \n", + "\"great\" 0.055133 0.061828 \n", + "total 0.370960 0.435484 \n", + "\n", + " overlaps correct incorrect \\\n", + "thank* 0.047766 74 0 \n", + "appreciate 0.009743 7 1 \n", + "text:(thanks AND good) 0.007842 8 0 \n", + "advice 0.007367 3 0 \n", + "amazing 0.004990 8 0 \n", + "awesome 0.007129 12 1 \n", + "impressed 0.000000 2 0 \n", + "text:(good AND (point OR call OR idea OR job)) 0.003089 7 0 \n", + "legend 0.000475 1 0 \n", + "exactly 0.002139 3 1 \n", + "agree 0.003327 6 2 \n", + "yeah 0.006179 2 6 \n", + "suck 0.000475 3 0 \n", + "pissed 0.000475 2 1 \n", + "annoying 0.001188 7 0 \n", + "ruined 0.000238 1 0 \n", + "hoping 0.000713 2 0 \n", + "text:(\"good luck\") 0.002614 4 3 \n", + "\"nice day\" 0.000000 2 0 \n", + "\"what is\" 0.001188 2 0 \n", + "\"can you\" 0.000713 3 0 \n", + "\"would you\" 0.000238 2 0 \n", + "\"great\" 0.016397 1 22 \n", + "total 0.058222 162 37 \n", + "\n", + " precision \n", + "thank* 1.000000 \n", + "appreciate 0.875000 \n", + "text:(thanks AND good) 1.000000 \n", + "advice 1.000000 \n", + "amazing 1.000000 \n", + "awesome 0.923077 \n", + "impressed 1.000000 \n", + "text:(good AND (point OR call OR idea OR job)) 1.000000 \n", + "legend 1.000000 \n", + "exactly 0.750000 \n", + "agree 0.750000 \n", + "yeah 0.250000 \n", + "suck 1.000000 \n", + "pissed 0.666667 \n", + "annoying 1.000000 \n", + "ruined 1.000000 \n", + "hoping 1.000000 \n", + "text:(\"good luck\") 0.571429 \n", + "\"nice day\" 1.000000 \n", + "\"what is\" 1.000000 \n", + "\"can you\" 1.000000 \n", + "\"would you\" 1.000000 \n", + "\"great\" 0.043478 \n", + "total 0.814070 " + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "weak_labels.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "f8019280", + "metadata": {}, + "source": [ + "We can 
observe that the following rules are not working well:\n", + "\n", + " Rule('\"great\"', [\"annoyance\"])\n", + "\n", + " Rule(\"yeah\", \"optimism\")\n", + "\n", + "Let's update these two rules as follows:\n", + "\n", + " Rule('\"great\"', [\"admiration\"])\n", + "\n", + " Rule(\"yeah\", \"approval\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0c299231", + "metadata": {}, + "outputs": [], + "source": [ + "rules_to_update = [\n", + " Rule('\"great\"', [\"admiration\"]),\n", + " Rule(\"yeah\", \"approval\")]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "e8d85e06", + "metadata": {}, + "outputs": [], + "source": [ + "update_rules(dataset=\"go_emotions\", rules=rules_to_update)" + ] + }, + { + "cell_type": "markdown", + "id": "296ba48a", + "metadata": {}, + "source": [ + "Let's run weak labeling with the final rules of the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb5add74", + "metadata": {}, + "outputs": [], + "source": [ + "weak_labels = WeakMultiLabels(dataset=\"go_emotions\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "449d7a17", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelcoverageannotated_coverageoverlapscorrectincorrectprecision
thank*{gratitude}0.1993820.1989250.0477667401.000000
appreciate{gratitude}0.0163970.0215050.009743710.875000
text:(thanks AND good){admiration, gratitude}0.0078420.0107530.007842801.000000
advice{admiration}0.0083170.0080650.007367301.000000
amazing{admiration}0.0254280.0215050.004990801.000000
awesome{admiration}0.0251900.0349460.0071291210.923077
impressed{admiration}0.0021390.0053760.000000201.000000
text:(good AND (point OR call OR idea OR job)){admiration}0.0085550.0188170.003089701.000000
legend{admiration}0.0019010.0026880.000475101.000000
exactly{approval}0.0078420.0107530.002139310.750000
agree{approval}0.0168730.0215050.003327620.750000
yeah{approval}0.0249520.0215050.006179530.625000
suck{annoyance}0.0021390.0080650.000475301.000000
pissed{annoyance}0.0021390.0080650.000475210.666667
annoying{annoyance}0.0033270.0188170.001188701.000000
ruined{annoyance}0.0007130.0026880.000238101.000000
hoping{optimism}0.0035650.0053760.000713201.000000
text:(\"good luck\"){optimism}0.0152090.0188170.002614430.571429
\"nice day\"{optimism}0.0007130.0053760.000000201.000000
\"what is\"{curiosity}0.0040400.0053760.001188201.000000
\"can you\"{curiosity}0.0042780.0080650.000713301.000000
\"would you\"{curiosity}0.0009510.0053760.000238201.000000
\"great\"{admiration}0.0551330.0618280.0163971940.826087
total{approval, gratitude, admiration, optimism, cu...0.3709600.4354840.058222183160.919598
\n", + "
" + ], + "text/plain": [ + " label \\\n", + "thank* {gratitude} \n", + "appreciate {gratitude} \n", + "text:(thanks AND good) {admiration, gratitude} \n", + "advice {admiration} \n", + "amazing {admiration} \n", + "awesome {admiration} \n", + "impressed {admiration} \n", + "text:(good AND (point OR call OR idea OR job)) {admiration} \n", + "legend {admiration} \n", + "exactly {approval} \n", + "agree {approval} \n", + "yeah {approval} \n", + "suck {annoyance} \n", + "pissed {annoyance} \n", + "annoying {annoyance} \n", + "ruined {annoyance} \n", + "hoping {optimism} \n", + "text:(\"good luck\") {optimism} \n", + "\"nice day\" {optimism} \n", + "\"what is\" {curiosity} \n", + "\"can you\" {curiosity} \n", + "\"would you\" {curiosity} \n", + "\"great\" {admiration} \n", + "total {approval, gratitude, admiration, optimism, cu... \n", + "\n", + " coverage annotated_coverage \\\n", + "thank* 0.199382 0.198925 \n", + "appreciate 0.016397 0.021505 \n", + "text:(thanks AND good) 0.007842 0.010753 \n", + "advice 0.008317 0.008065 \n", + "amazing 0.025428 0.021505 \n", + "awesome 0.025190 0.034946 \n", + "impressed 0.002139 0.005376 \n", + "text:(good AND (point OR call OR idea OR job)) 0.008555 0.018817 \n", + "legend 0.001901 0.002688 \n", + "exactly 0.007842 0.010753 \n", + "agree 0.016873 0.021505 \n", + "yeah 0.024952 0.021505 \n", + "suck 0.002139 0.008065 \n", + "pissed 0.002139 0.008065 \n", + "annoying 0.003327 0.018817 \n", + "ruined 0.000713 0.002688 \n", + "hoping 0.003565 0.005376 \n", + "text:(\"good luck\") 0.015209 0.018817 \n", + "\"nice day\" 0.000713 0.005376 \n", + "\"what is\" 0.004040 0.005376 \n", + "\"can you\" 0.004278 0.008065 \n", + "\"would you\" 0.000951 0.005376 \n", + "\"great\" 0.055133 0.061828 \n", + "total 0.370960 0.435484 \n", + "\n", + " overlaps correct incorrect \\\n", + "thank* 0.047766 74 0 \n", + "appreciate 0.009743 7 1 \n", + "text:(thanks AND good) 0.007842 8 0 \n", + "advice 0.007367 3 0 \n", + "amazing 0.004990 8 0 \n", + "awesome 0.007129 12 1 \n", + "impressed 0.000000 2 0 \n", + "text:(good AND (point OR call OR idea OR job)) 0.003089 7 0 \n", + "legend 0.000475 1 0 \n", + "exactly 0.002139 3 1 \n", + "agree 0.003327 6 2 \n", + "yeah 0.006179 5 3 \n", + "suck 0.000475 3 0 \n", + "pissed 0.000475 2 1 \n", + "annoying 0.001188 7 0 \n", + "ruined 0.000238 1 0 \n", + "hoping 0.000713 2 0 \n", + "text:(\"good luck\") 0.002614 4 3 \n", + "\"nice day\" 0.000000 2 0 \n", + "\"what is\" 0.001188 2 0 \n", + "\"can you\" 0.000713 3 0 \n", + "\"would you\" 0.000238 2 0 \n", + "\"great\" 0.016397 19 4 \n", + "total 0.058222 183 16 \n", + "\n", + " precision \n", + "thank* 1.000000 \n", + "appreciate 0.875000 \n", + "text:(thanks AND good) 1.000000 \n", + "advice 1.000000 \n", + "amazing 1.000000 \n", + "awesome 0.923077 \n", + "impressed 1.000000 \n", + "text:(good AND (point OR call OR idea OR job)) 1.000000 \n", + "legend 1.000000 \n", + "exactly 0.750000 \n", + "agree 0.750000 \n", + "yeah 0.625000 \n", + "suck 1.000000 \n", + "pissed 0.666667 \n", + "annoying 1.000000 \n", + "ruined 1.000000 \n", + "hoping 1.000000 \n", + "text:(\"good luck\") 0.571429 \n", + "\"nice day\" 1.000000 \n", + "\"what is\" 1.000000 \n", + "\"can you\" 1.000000 \n", + "\"would you\" 1.000000 \n", + "\"great\" 0.826087 \n", + "total 0.919598 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "weak_labels.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "61ffe13e", + "metadata": {}, + "source": [ + "Lets 
consider the case where we want to try a new rule:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "39669aee", + "metadata": {}, + "outputs": [], + "source": [ + "optimism_rule = Rule(\"wish*\", \"optimism\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "8aa2e341", + "metadata": {}, + "outputs": [], + "source": [ + "optimism_rule.apply(dataset=\"go_emotions\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "f4a1a051", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'coverage': 0.006178707224334601,\n", + " 'annotated_coverage': 0.0,\n", + " 'correct': 0,\n", + " 'incorrect': 0,\n", + " 'precision': None}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "optimism_rule.metrics(dataset=\"go_emotions\")" + ] + }, + { + "cell_type": "markdown", + "id": "50bb4a12", + "metadata": {}, + "source": [ + "The __optimism_rule__ is not informative, so we don't add it to the dataset." + ] + }, + { + "cell_type": "markdown", + "id": "e55c1c0c", + "metadata": {}, + "source": [ + "Let's try a rule for the __curiosity__ class:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "a16426a5", + "metadata": {}, + "outputs": [], + "source": [ + "curiosity_rule = Rule(\"could you\", \"curiosity\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "707d7d2a", + "metadata": {}, + "outputs": [], + "source": [ + "curiosity_rule.apply(\"go_emotions\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "c6d9dfeb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'coverage': 0.005465779467680608,\n", + " 'annotated_coverage': 0.002688172043010753,\n", + " 'correct': 1,\n", + " 'incorrect': 0,\n", + " 'precision': 1.0}" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "curiosity_rule.metrics(dataset=\"go_emotions\")" + ] + }, + { + "cell_type": "markdown", + "id": "812deb5e", + "metadata": {}, + "source": [ + "The __curiosity_rule__ has positive support, so we can add it to the dataset as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "267a7887", + "metadata": {}, + "outputs": [], + "source": [ + "curiosity_rule.add_to_dataset(dataset=\"go_emotions\")" + ] + }, + { + "cell_type": "markdown", + "id": "10f6ea0c", + "metadata": {}, + "source": [ + "Let's apply weak labeling again with the final rule set:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87f81588", + "metadata": {}, + "outputs": [], + "source": [ + "weak_labels = WeakMultiLabels(dataset=\"go_emotions\")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "eebedefb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelcoverageannotated_coverageoverlapscorrectincorrectprecision
thank*{gratitude}0.1993820.1989250.0480047401.000000
appreciate{gratitude}0.0163970.0215050.009743710.875000
text:(thanks AND good){admiration, gratitude}0.0078420.0107530.007842801.000000
advice{admiration}0.0083170.0080650.007367301.000000
amazing{admiration}0.0254280.0215050.004990801.000000
awesome{admiration}0.0251900.0349460.0073671210.923077
impressed{admiration}0.0021390.0053760.000000201.000000
text:(good AND (point OR call OR idea OR job)){admiration}0.0085550.0188170.003089701.000000
legend{admiration}0.0019010.0026880.000475101.000000
exactly{approval}0.0078420.0107530.002139310.750000
agree{approval}0.0168730.0215050.003565620.750000
yeah{approval}0.0249520.0215050.006179530.625000
suck{annoyance}0.0021390.0080650.000475301.000000
pissed{annoyance}0.0021390.0080650.000475210.666667
annoying{annoyance}0.0033270.0188170.001188701.000000
ruined{annoyance}0.0007130.0026880.000238101.000000
hoping{optimism}0.0035650.0053760.000713201.000000
text:(\"good luck\"){optimism}0.0152090.0188170.002614430.571429
\"nice day\"{optimism}0.0007130.0053760.000000201.000000
\"what is\"{curiosity}0.0040400.0053760.001188201.000000
\"can you\"{curiosity}0.0042780.0080650.000713301.000000
\"would you\"{curiosity}0.0009510.0053760.000475201.000000
\"great\"{admiration}0.0551330.0618280.0163971940.826087
could you{curiosity}0.0054660.0026880.001188101.000000
total{approval, gratitude, admiration, optimism, cu...0.3752380.4354840.059173184160.920000
\n", + "
" + ], + "text/plain": [ + " label \\\n", + "thank* {gratitude} \n", + "appreciate {gratitude} \n", + "text:(thanks AND good) {admiration, gratitude} \n", + "advice {admiration} \n", + "amazing {admiration} \n", + "awesome {admiration} \n", + "impressed {admiration} \n", + "text:(good AND (point OR call OR idea OR job)) {admiration} \n", + "legend {admiration} \n", + "exactly {approval} \n", + "agree {approval} \n", + "yeah {approval} \n", + "suck {annoyance} \n", + "pissed {annoyance} \n", + "annoying {annoyance} \n", + "ruined {annoyance} \n", + "hoping {optimism} \n", + "text:(\"good luck\") {optimism} \n", + "\"nice day\" {optimism} \n", + "\"what is\" {curiosity} \n", + "\"can you\" {curiosity} \n", + "\"would you\" {curiosity} \n", + "\"great\" {admiration} \n", + "could you {curiosity} \n", + "total {approval, gratitude, admiration, optimism, cu... \n", + "\n", + " coverage annotated_coverage \\\n", + "thank* 0.199382 0.198925 \n", + "appreciate 0.016397 0.021505 \n", + "text:(thanks AND good) 0.007842 0.010753 \n", + "advice 0.008317 0.008065 \n", + "amazing 0.025428 0.021505 \n", + "awesome 0.025190 0.034946 \n", + "impressed 0.002139 0.005376 \n", + "text:(good AND (point OR call OR idea OR job)) 0.008555 0.018817 \n", + "legend 0.001901 0.002688 \n", + "exactly 0.007842 0.010753 \n", + "agree 0.016873 0.021505 \n", + "yeah 0.024952 0.021505 \n", + "suck 0.002139 0.008065 \n", + "pissed 0.002139 0.008065 \n", + "annoying 0.003327 0.018817 \n", + "ruined 0.000713 0.002688 \n", + "hoping 0.003565 0.005376 \n", + "text:(\"good luck\") 0.015209 0.018817 \n", + "\"nice day\" 0.000713 0.005376 \n", + "\"what is\" 0.004040 0.005376 \n", + "\"can you\" 0.004278 0.008065 \n", + "\"would you\" 0.000951 0.005376 \n", + "\"great\" 0.055133 0.061828 \n", + "could you 0.005466 0.002688 \n", + "total 0.375238 0.435484 \n", + "\n", + " overlaps correct incorrect \\\n", + "thank* 0.048004 74 0 \n", + "appreciate 0.009743 7 1 \n", + "text:(thanks AND good) 0.007842 8 0 \n", + "advice 0.007367 3 0 \n", + "amazing 0.004990 8 0 \n", + "awesome 0.007367 12 1 \n", + "impressed 0.000000 2 0 \n", + "text:(good AND (point OR call OR idea OR job)) 0.003089 7 0 \n", + "legend 0.000475 1 0 \n", + "exactly 0.002139 3 1 \n", + "agree 0.003565 6 2 \n", + "yeah 0.006179 5 3 \n", + "suck 0.000475 3 0 \n", + "pissed 0.000475 2 1 \n", + "annoying 0.001188 7 0 \n", + "ruined 0.000238 1 0 \n", + "hoping 0.000713 2 0 \n", + "text:(\"good luck\") 0.002614 4 3 \n", + "\"nice day\" 0.000000 2 0 \n", + "\"what is\" 0.001188 2 0 \n", + "\"can you\" 0.000713 3 0 \n", + "\"would you\" 0.000475 2 0 \n", + "\"great\" 0.016397 19 4 \n", + "could you 0.001188 1 0 \n", + "total 0.059173 184 16 \n", + "\n", + " precision \n", + "thank* 1.000000 \n", + "appreciate 0.875000 \n", + "text:(thanks AND good) 1.000000 \n", + "advice 1.000000 \n", + "amazing 1.000000 \n", + "awesome 0.923077 \n", + "impressed 1.000000 \n", + "text:(good AND (point OR call OR idea OR job)) 1.000000 \n", + "legend 1.000000 \n", + "exactly 0.750000 \n", + "agree 0.750000 \n", + "yeah 0.625000 \n", + "suck 1.000000 \n", + "pissed 0.666667 \n", + "annoying 1.000000 \n", + "ruined 1.000000 \n", + "hoping 1.000000 \n", + "text:(\"good luck\") 0.571429 \n", + "\"nice day\" 1.000000 \n", + "\"what is\" 1.000000 \n", + "\"can you\" 1.000000 \n", + "\"would you\" 1.000000 \n", + "\"great\" 0.826087 \n", + "could you 1.000000 \n", + "total 0.920000 " + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"weak_labels.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "ba19f6b8-520a-4caa-87a4-307f52749b92", + "metadata": { + "tags": [] + }, + "source": [ + "### Create training set\n", + "\n", + "When we are happy with our heuristics, it is time to combine them and compute weak labels for the training of our downstream model.\n", + "For this we will use the `MajorityVoter`.\n", + "In the multi-label case, it sets the probability of a label to 0 or 1 depending on whether at least one non-abstaining rule voted for the respective label or not." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6985e6b7-7c28-4efc-84d2-08b7fff02ef8", + "metadata": {}, + "outputs": [], + "source": [ + "from argilla.labeling.text_classification import MajorityVoter\n", + "\n", + "# Use the majority voter as the label model\n", + "label_model = MajorityVoter(weak_labels)\n" + ] + }, + { + "cell_type": "markdown", + "id": "f82fe715-e703-457a-a158-78709f442bbd", + "metadata": {}, + "source": [ + "From our label model we get the training records together with its weak labels and probabilities.\n", + "We will use the weak labels with a probability greater than 0.5 as labels for our training, and hence copy them to the `annotation` property of our records." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "152abc10-1538-4660-aa3a-0e210887c0f1", + "metadata": {}, + "outputs": [], + "source": [ + "# Get records with the predictions from the label model to train a down-stream model\n", + "train_rg = label_model.predict()\n", + "\n", + "# Copy label model predictions to annotation with a threshold of 0.5\n", + "for rec in train_rg:\n", + " rec.annotation = [pred[0] for pred in rec.prediction if pred[1] > 0.5]\n" + ] + }, + { + "cell_type": "markdown", + "id": "24414843-70c2-4dc2-902d-e6e8bcc2a449", + "metadata": {}, + "source": [ + "We extract the test set with manual annotations from our `WeakMultiLabels` object:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fce0c5a0-bfa5-4649-a59a-f0114f7891a5", + "metadata": {}, + "outputs": [], + "source": [ + "# Get records with manual annotations to use as test set for the down-stream model\n", + "test_rg = rg.DatasetForTextClassification(weak_labels.records(has_annotation=True))\n" + ] + }, + { + "cell_type": "markdown", + "id": "d7738210-4273-49c2-a70c-634b18fa008e", + "metadata": {}, + "source": [ + "We will use the convenient `DatasetForTextClassification.prepare_for_training()` method to create datasets optimized for training with the Hugging Face transformers library:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "232a2d46-1516-4c7d-8d39-d55fe0d8130d", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import DatasetDict\n", + "\n", + "# Create dataset dictionary and shuffle training set\n", + "ds = DatasetDict(\n", + " train=train_rg.prepare_for_training().shuffle(seed=42),\n", + " test=test_rg.prepare_for_training(),\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "b0be5d9c-1532-4265-8f7a-18e346e9caaf", + "metadata": {}, + "source": [ + "Let us push the dataset to the Hub to share it with our colleagues.\n", + "It is also an easy way to outsource the training of the model to an environment with an accelerator, like Google Colab for example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "072218ae-2167-48d8-8b93-243daff497b3", + "metadata": {}, + "outputs": [], + "source": [ + "# Push dataset for training our down-stream model to the HF hub\n", + "ds.push_to_hub(\"argilla/go_emotions_training\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "c0957ad2-f05d-4ea4-b303-d77495a449e9", + "metadata": {}, + "source": [ + "### Train a transformer downstream model" + ] + }, + { + "cell_type": "markdown", + "id": "ec8cea7b-cd6f-408d-9667-249e3d66eac0", + "metadata": {}, + "source": [ + "The following steps are basically a copy&paste from the amazing documentation of the [Hugging Face transformers](https://huggingface.co/docs/transformers) library.\n", + "\n", + "First, we will load the tokenizer corresponding to our model, which we choose to be the [distilled version](https://huggingface.co/distilbert-base-uncased) of the infamous BERT.\n", + "\n", + "
\n", + "\n", + "Note\n", + "\n", + "Since we will use a full-blown transformer as a downstream model (albeit a distilled one), we recommend executing the following code on a machine with a GPU, or in a Google Colab with a GPU backend enabled.\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88365498-15df-4aff-ae3a-48c22e4e1a66", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "# Initialize tokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "fd0cb66a-c31e-4811-abc3-453179ee567f", + "metadata": {}, + "source": [ + "Afterward, we tokenize our data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6635d069-e242-4d89-92be-ecce502b92db", + "metadata": {}, + "outputs": [], + "source": [ + "def tokenize_func(examples):\n", + " return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n", + "\n", + "\n", + "# Tokenize the data\n", + "tokenized_ds = ds.map(tokenize_func, batched=True)\n" + ] + }, + { + "cell_type": "markdown", + "id": "28d70c03-9602-472f-aa79-28777fe3b0e2", + "metadata": {}, + "source": [ + "The transformer model expects our labels to follow a common multi-label format of binaries, so let us use [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html) for this transformation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7efcbc6-938c-4341-9e83-e8761974deda", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.preprocessing import MultiLabelBinarizer\n", + "\n", + "# Turn labels into multi-label format\n", + "mb = MultiLabelBinarizer()\n", + "mb.fit(ds[\"test\"][\"label\"])\n", + "\n", + "\n", + "def binarize_labels(examples):\n", + " return {\"label\": mb.transform(examples[\"label\"])}\n", + "\n", + "\n", + "binarized_tokenized_ds = tokenized_ds.map(binarize_labels, batched=True)\n" + ] + }, + { + "cell_type": "markdown", + "id": "e71e2cd7-50e4-4dd1-883f-f2e31d1ec051", + "metadata": {}, + "source": [ "Before we start the training, it is important to define our metric for the evaluation.\n", "Here we settle on the commonly used micro averaged *F1* metric, but we will also keep track of the *F1 per label*, for a more in-depth error analysis afterward." ] @@ -791,279 +2298,1354 @@ { "cell_type": "code", "execution_count": null, - "id": "d50f9ae8-2b1a-4905-b19f-62bcfa8c64d9", + "id": "d50f9ae8-2b1a-4905-b19f-62bcfa8c64d9", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_metric\n", + "import numpy as np\n", + "\n", + "# Define our metrics\n", + "metric = load_metric(\"f1\", config_name=\"multilabel\")\n", + "\n", + "\n", + "def compute_metrics(eval_pred):\n", + " logits, labels = eval_pred\n", + " # apply sigmoid\n", + " predictions = (1.0 / (1 + np.exp(-logits))) > 0.5\n", + "\n", + " # f1 micro averaged\n", + " metrics = metric.compute(\n", + " predictions=predictions, references=labels, average=\"micro\"\n", + " )\n", + " # f1 per label\n", + " per_label_metric = metric.compute(\n", + " predictions=predictions, references=labels, average=None\n", + " )\n", + " for label, f1 in zip(\n", + " ds[\"train\"].features[\"label\"][0].names, per_label_metric[\"f1\"]\n", + " ):\n", + " metrics[f\"f1_{label}\"] = f1\n", + "\n", + " return metrics\n" + ] + }, + { + "cell_type": "markdown", + "id": "d5cf20aa-0117-438d-a960-59a85adb4d37", + "metadata": {}, + "source": [ + "Now we are ready to load our pretrained transformer model and prepare it for our task: multi-label text classification with 6 labels." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fa9ee91-82f0-4cb1-a1b6-10369804e22a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "# Init our down-stream model\n", + "model = AutoModelForSequenceClassification.from_pretrained(\n", + " \"distilbert-base-uncased\", problem_type=\"multi_label_classification\", num_labels=6\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "29498e65-8edc-4086-8ab7-f2db90f8a113", + "metadata": {}, + "source": [ + "The only thing missing for the training is the `Trainer` and its `TrainingArguments`.\n", + "To keep it simple, we mostly rely on the default arguments, which often work out of the box, but tweak the batch size a bit to train faster. \n", + "We also checked that 2 epochs are enough for our rather small dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2dd0c0bd-b7d9-4e1c-a734-8c2a8e9d57d6", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import TrainingArguments\n", + "\n", + "# Set our training arguments\n", + "training_args = TrainingArguments(\n", + " output_dir=\"test_trainer\",\n", + " evaluation_strategy=\"epoch\",\n", + " num_train_epochs=2,\n", + " per_device_train_batch_size=16,\n", + " per_device_eval_batch_size=16,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61e0896c-dc7b-4844-95b6-a258aa8f8d1c", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import Trainer\n", + "\n", + "# Init the trainer\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=binarized_tokenized_ds[\"train\"],\n", + " eval_dataset=binarized_tokenized_ds[\"test\"],\n", + " compute_metrics=compute_metrics,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ddd25cb-1cf8-4721-a8e8-c16eb0c7aaf7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Train the down-stream model\n", + "trainer.train()\n" + ] + }, + { + "cell_type": "markdown", + "id": "b77e8d7e-1463-42c8-96e2-13dfe98cc318", + "metadata": {}, + "source": [ + "We achieved a micro averaged *F1* of about 0.54, which is not perfect, but a good baseline for this challenging dataset.\n", + "When inspecting the *F1s per label*, we clearly see that the worst performing labels are the ones with the poorest heuristics in terms of accuracy and coverage, which comes as no surprise." + ] + }, + { + "cell_type": "markdown", + "id": "ee0dfbaa-90f3-4803-8e93-b86e1440bd51", + "metadata": { + "tags": [] + }, + "source": [ + "## Research topic dataset\n", + "\n", + "After covering a multi-label emotion classification task, we will try to do the same for a multi-label classification task related to topic modeling.\n", + "In this dataset, research papers were classified with 6 non-exclusive labels based on their title and abstract.\n", + "\n", + "We will try to classify the papers only based on the title, which is considerably harder, but allows us to quickly scan through the data and come up with heuristics.\n", + "See Appendix B for all the details of the minimal data preprocessing."
+ ] + }, + { + "cell_type": "markdown", + "id": "8ba65805-bfb2-4027-9231-28538a7873e8", + "metadata": {}, + "source": [ + "### Define rules\n", + "\n", + "Let us start by downloading our preprocessed dataset from the Hugging Face Hub, and log it to Argilla:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd4cdd38-604c-4fe6-8104-fc6b136d331a", + "metadata": {}, + "outputs": [], + "source": [ + "import argilla as rg\n", + "from datasets import load_dataset\n", + "\n", + "# Download preprocessed dataset\n", + "ds_rb = rg.read_datasets(\n", + " load_dataset(\"argilla/research_titles_multi-label\", split=\"train\"),\n", + " task=\"TextClassification\",\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34899478-62ae-4864-8196-36348aa568c3", + "metadata": {}, + "outputs": [], + "source": [ + "# Log dataset to Argilla to find good heuristics\n", + "rg.log(ds_rb, \"research_titles\")" + ] + }, + { + "cell_type": "markdown", + "id": "1a95ac66-480b-4bf2-b35c-e548df16a084", + "metadata": {}, + "source": [ + "After uploading the dataset, we can explore and inspect it to find good heuristic rules.\n", + "For this we highly recommend the dedicated [*Define rules* mode](../../reference/webapp/features.html#weak-labelling) of the Argilla web app, that allows you to quickly iterate over heuristic rules, compute their metrics and save them.\n", + "\n", + "Here we copy our rules found via the web app to the notebook for you to easily follow along the tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "8f85fd20-7086-4581-9078-2a28c9155997", + "metadata": {}, + "outputs": [], + "source": [ + "from argilla.labeling.text_classification import Rule\n", + "\n", + "# Define our heuristic rules (can probably be improved)\n", + "\n", + "rules = [\n", + " Rule(\"stock*\", \"Quantitative Finance\"),\n", + " Rule(\"*asset*\", \"Quantitative Finance\"),\n", + " Rule(\"pric*\", \"Quantitative Finance\"),\n", + " Rule(\"economy\", \"Quantitative Finance\"),\n", + " Rule(\"deep AND neural AND network*\", \"Computer Science\"),\n", + " Rule(\"convolutional\", \"Computer Science\"),\n", + " Rule(\"allocat* AND *net*\", \"Computer Science\"),\n", + " Rule(\"program\", \"Computer Science\"),\n", + " Rule(\"classification* AND (label* OR deep)\", \"Computer Science\"),\n", + " Rule(\"scattering\", \"Physics\"),\n", + " Rule(\"astro*\", \"Physics\"),\n", + " Rule(\"optical\", \"Physics\"),\n", + " Rule(\"ray\", \"Physics\"),\n", + " Rule(\"entangle*\", \"Physics\"),\n", + " Rule(\"*algebra*\", \"Mathematics\"),\n", + " Rule(\"spaces\", \"Mathematics\"),\n", + " Rule(\"operators\", \"Mathematics\"),\n", + " Rule(\"estimation\", \"Statistics\"),\n", + " Rule(\"mixture\", \"Statistics\"),\n", + " Rule(\"gaussian\", \"Statistics\"),\n", + " Rule(\"gene\", \"Quantitative Biology\"),\n", + "]\n" + ] + }, + { + "cell_type": "markdown", + "id": "93311e7a-4522-4036-9151-10edc1101d3d", + "metadata": {}, + "source": [ + "We go on and apply these heuristic rules to our dataset creating our weak label matrix.\n", + "As mentioned in the [GoEmotions](#goemotions) section, the weak label matrix will have 3 dimensions and values of -1, 0 and 1." 
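To make that shape concrete, you can inspect the weak label matrix directly once it has been computed in the next cell. This is a minimal sketch, assuming `matrix()` returns a NumPy array as in the single-label case:

```python
import numpy as np

# Assumed shape: (number of records, number of rules, number of labels)
matrix = weak_labels.matrix()
print(matrix.shape)

# -1 marks a rule abstaining on a record; 0 and 1 encode its vote per label
print(np.unique(matrix))
```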
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c41d2c90-e550-4d44-9935-b5eeea415c67", + "metadata": {}, + "outputs": [], + "source": [ + "from argilla.labeling.text_classification import WeakMultiLabels\n", + "\n", + "# Add the rules to the dataset and compute the weak labels\n", + "# If the rules are already defined in the dataset, you can skip the add_rules call.\n", + "\n", + "\n", + "add_rules(dataset=\"research_titles\", rules=rules)\n", + "weak_labels = WeakMultiLabels(\"research_titles\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "fa6a2585-58cc-4eae-b5da-e31e29fd188d", + "metadata": {}, + "source": [ + "Let us get an overview of our heuristics and how they perform:" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "a5bfc002-0845-4d36-b2c5-8133995561ce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelcoverageannotated_coverageoverlapscorrectincorrectprecision
stock*{Quantitative Finance}0.0009540.0007150.000191301.000000
*asset*{Quantitative Finance}0.0004770.0007150.000238301.000000
pric*{Quantitative Finance}0.0034330.0033370.000668950.642857
economy{Quantitative Finance}0.0002380.0002380.000000101.000000
deep AND neural AND network*{Computer Science}0.0091550.0102500.00257532110.744186
convolutional{Computer Science}0.0101090.0092970.0020033270.820513
allocat* AND *net*{Computer Science}0.0007630.0007150.000000301.000000
program{Computer Science}0.0026230.0030990.0000951120.846154
classification* AND (label* OR deep){Computer Science}0.0033380.0040520.0012871430.823529
scattering{Physics}0.0040530.0028610.0005721020.833333
astro*{Physics}0.0030990.0040520.0004771701.000000
optical{Physics}0.0071050.0069130.0008112720.931034
ray{Physics}0.0058650.0073900.0006682740.870968
entangle*{Physics}0.0026230.0028610.0000481110.916667
*algebra*{Mathematics}0.0148290.0183550.0004297070.909091
spaces{Mathematics}0.0105860.0097740.0012873830.926829
operators{Mathematics}0.0061510.0059590.0011922230.880000
estimation{Statistics}0.0212660.0212160.00162165240.730337
mixture{Statistics}0.0032900.0030990.0009061030.769231
gaussian{Statistics}0.0092500.0112040.00152636110.765957
gene{Quantitative Biology}0.0012870.0016690.000143610.857143
total{Mathematics, Quantitative Biology, Physics, Q...0.1119110.1189510.008154447890.833955
\n", + "
" + ], + "text/plain": [ + " label \\\n", + "stock* {Quantitative Finance} \n", + "*asset* {Quantitative Finance} \n", + "pric* {Quantitative Finance} \n", + "economy {Quantitative Finance} \n", + "deep AND neural AND network* {Computer Science} \n", + "convolutional {Computer Science} \n", + "allocat* AND *net* {Computer Science} \n", + "program {Computer Science} \n", + "classification* AND (label* OR deep) {Computer Science} \n", + "scattering {Physics} \n", + "astro* {Physics} \n", + "optical {Physics} \n", + "ray {Physics} \n", + "entangle* {Physics} \n", + "*algebra* {Mathematics} \n", + "spaces {Mathematics} \n", + "operators {Mathematics} \n", + "estimation {Statistics} \n", + "mixture {Statistics} \n", + "gaussian {Statistics} \n", + "gene {Quantitative Biology} \n", + "total {Mathematics, Quantitative Biology, Physics, Q... \n", + "\n", + " coverage annotated_coverage overlaps \\\n", + "stock* 0.000954 0.000715 0.000191 \n", + "*asset* 0.000477 0.000715 0.000238 \n", + "pric* 0.003433 0.003337 0.000668 \n", + "economy 0.000238 0.000238 0.000000 \n", + "deep AND neural AND network* 0.009155 0.010250 0.002575 \n", + "convolutional 0.010109 0.009297 0.002003 \n", + "allocat* AND *net* 0.000763 0.000715 0.000000 \n", + "program 0.002623 0.003099 0.000095 \n", + "classification* AND (label* OR deep) 0.003338 0.004052 0.001287 \n", + "scattering 0.004053 0.002861 0.000572 \n", + "astro* 0.003099 0.004052 0.000477 \n", + "optical 0.007105 0.006913 0.000811 \n", + "ray 0.005865 0.007390 0.000668 \n", + "entangle* 0.002623 0.002861 0.000048 \n", + "*algebra* 0.014829 0.018355 0.000429 \n", + "spaces 0.010586 0.009774 0.001287 \n", + "operators 0.006151 0.005959 0.001192 \n", + "estimation 0.021266 0.021216 0.001621 \n", + "mixture 0.003290 0.003099 0.000906 \n", + "gaussian 0.009250 0.011204 0.001526 \n", + "gene 0.001287 0.001669 0.000143 \n", + "total 0.111911 0.118951 0.008154 \n", + "\n", + " correct incorrect precision \n", + "stock* 3 0 1.000000 \n", + "*asset* 3 0 1.000000 \n", + "pric* 9 5 0.642857 \n", + "economy 1 0 1.000000 \n", + "deep AND neural AND network* 32 11 0.744186 \n", + "convolutional 32 7 0.820513 \n", + "allocat* AND *net* 3 0 1.000000 \n", + "program 11 2 0.846154 \n", + "classification* AND (label* OR deep) 14 3 0.823529 \n", + "scattering 10 2 0.833333 \n", + "astro* 17 0 1.000000 \n", + "optical 27 2 0.931034 \n", + "ray 27 4 0.870968 \n", + "entangle* 11 1 0.916667 \n", + "*algebra* 70 7 0.909091 \n", + "spaces 38 3 0.926829 \n", + "operators 22 3 0.880000 \n", + "estimation 65 24 0.730337 \n", + "mixture 10 3 0.769231 \n", + "gaussian 36 11 0.765957 \n", + "gene 6 1 0.857143 \n", + "total 447 89 0.833955 " + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check coverage/precision of our rules\n", + "weak_labels.summary()\n" + ] + }, + { + "cell_type": "markdown", + "id": "cd8ed61a", + "metadata": {}, + "source": [ + "Consider the case we have come up with new rules and want to add them to dataset " + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "948d2058", + "metadata": {}, + "outputs": [], + "source": [ + "additional_rules = [\n", + " Rule(\"trading\", \"Quantitative Finance\"),\n", + " Rule(\"finance\", \"Quantitative Finance\"),\n", + " Rule(\"memor* AND (design* OR network*)\", \"Computer Science\"),\n", + " Rule(\"system* AND design*\", \"Computer Science\"),\n", + " Rule(\"material*\", \"Physics\"),\n", + " Rule(\"spin\", \"Physics\"),\n", + " Rule(\"magnetic\", 
\"Physics\"),\n", + " Rule(\"manifold* AND (NOT learn*)\", \"Mathematics\"),\n", + " Rule(\"equation\", \"Mathematics\"),\n", + " Rule(\"regression\", \"Statistics\"),\n", + " Rule(\"bayes*\", \"Statistics\"),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "ae80efed", + "metadata": {}, + "outputs": [], + "source": [ + "add_rules(dataset=\"research_titles\", rules=additional_rules)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0cc85ea", "metadata": {}, "outputs": [], "source": [ - "from datasets import load_metric\n", - "import numpy as np\n", - "\n", - "# Define our metrics\n", - "metric = load_metric(\"f1\", config_name=\"multilabel\")\n", - "\n", - "\n", - "def compute_metrics(eval_pred):\n", - " logits, labels = eval_pred\n", - " # apply sigmoid\n", - " predictions = (1.0 / (1 + np.exp(-logits))) > 0.5\n", - "\n", - " # f1 micro averaged\n", - " metrics = metric.compute(\n", - " predictions=predictions, references=labels, average=\"micro\"\n", - " )\n", - " # f1 per label\n", - " per_label_metric = metric.compute(\n", - " predictions=predictions, references=labels, average=None\n", - " )\n", - " for label, f1 in zip(\n", - " ds[\"train\"].features[\"label\"][0].names, per_label_metric[\"f1\"]\n", - " ):\n", - " metrics[f\"f1_{label}\"] = f1\n", - "\n", - " return metrics\n" + "weak_labels = WeakMultiLabels(\"research_titles\")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "09d35bb2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelcoverageannotated_coverageoverlapscorrectincorrectprecision
stock*{Quantitative Finance}0.0009540.0007150.000334301.000000
*asset*{Quantitative Finance}0.0004770.0007150.000286301.000000
pric*{Quantitative Finance}0.0034330.0033370.000715950.642857
economy{Quantitative Finance}0.0002380.0002380.000000101.000000
deep AND neural AND network*{Computer Science}0.0091550.0102500.00290932110.744186
convolutional{Computer Science}0.0101090.0092970.0022413270.820513
allocat* AND *net*{Computer Science}0.0007630.0007150.000000301.000000
program{Computer Science}0.0026230.0030990.0001431120.846154
classification* AND (label* OR deep){Computer Science}0.0033380.0040520.0013351430.823529
scattering{Physics}0.0040530.0028610.0010011020.833333
astro*{Physics}0.0030990.0040520.0006201701.000000
optical{Physics}0.0071050.0069130.0010972720.931034
ray{Physics}0.0058650.0073900.0011922740.870968
entangle*{Physics}0.0026230.0028610.0000951110.916667
*algebra*{Mathematics}0.0148290.0183550.0006207070.909091
spaces{Mathematics}0.0105860.0097740.0018603830.926829
operators{Mathematics}0.0061510.0059590.0015262230.880000
estimation{Statistics}0.0212660.0212160.00338565240.730337
mixture{Statistics}0.0032900.0030990.0012871030.769231
gaussian{Statistics}0.0092500.0112040.00276636110.765957
gene{Quantitative Biology}0.0012870.0016690.000191610.857143
trading{Quantitative Finance}0.0009540.0002380.000191101.000000
finance{Quantitative Finance}0.0000480.0002380.000000101.000000
memor* AND (design* OR network*){Computer Science}0.0013830.0021450.000286901.000000
system* AND design*{Computer Science}0.0011440.0023840.000238910.900000
material*{Physics}0.0041480.0030990.0002381030.769231
spin{Physics}0.0135420.0150180.0021466030.952381
magnetic{Physics}0.0113010.0128720.0024324950.907407
manifold* AND (NOT learn*){Mathematics}0.0070570.0083430.0008582870.800000
equation{Mathematics}0.0106810.0078670.0009542490.727273
regression{Statistics}0.0093930.0090580.0025753350.868421
bayes*{Statistics}0.0153060.0147790.00314749130.790323
total{Mathematics, Quantitative Biology, Physics, Q...0.1766160.1859360.0178337201350.842105
\n", + "
" + ], + "text/plain": [ + " label \\\n", + "stock* {Quantitative Finance} \n", + "*asset* {Quantitative Finance} \n", + "pric* {Quantitative Finance} \n", + "economy {Quantitative Finance} \n", + "deep AND neural AND network* {Computer Science} \n", + "convolutional {Computer Science} \n", + "allocat* AND *net* {Computer Science} \n", + "program {Computer Science} \n", + "classification* AND (label* OR deep) {Computer Science} \n", + "scattering {Physics} \n", + "astro* {Physics} \n", + "optical {Physics} \n", + "ray {Physics} \n", + "entangle* {Physics} \n", + "*algebra* {Mathematics} \n", + "spaces {Mathematics} \n", + "operators {Mathematics} \n", + "estimation {Statistics} \n", + "mixture {Statistics} \n", + "gaussian {Statistics} \n", + "gene {Quantitative Biology} \n", + "trading {Quantitative Finance} \n", + "finance {Quantitative Finance} \n", + "memor* AND (design* OR network*) {Computer Science} \n", + "system* AND design* {Computer Science} \n", + "material* {Physics} \n", + "spin {Physics} \n", + "magnetic {Physics} \n", + "manifold* AND (NOT learn*) {Mathematics} \n", + "equation {Mathematics} \n", + "regression {Statistics} \n", + "bayes* {Statistics} \n", + "total {Mathematics, Quantitative Biology, Physics, Q... \n", + "\n", + " coverage annotated_coverage overlaps \\\n", + "stock* 0.000954 0.000715 0.000334 \n", + "*asset* 0.000477 0.000715 0.000286 \n", + "pric* 0.003433 0.003337 0.000715 \n", + "economy 0.000238 0.000238 0.000000 \n", + "deep AND neural AND network* 0.009155 0.010250 0.002909 \n", + "convolutional 0.010109 0.009297 0.002241 \n", + "allocat* AND *net* 0.000763 0.000715 0.000000 \n", + "program 0.002623 0.003099 0.000143 \n", + "classification* AND (label* OR deep) 0.003338 0.004052 0.001335 \n", + "scattering 0.004053 0.002861 0.001001 \n", + "astro* 0.003099 0.004052 0.000620 \n", + "optical 0.007105 0.006913 0.001097 \n", + "ray 0.005865 0.007390 0.001192 \n", + "entangle* 0.002623 0.002861 0.000095 \n", + "*algebra* 0.014829 0.018355 0.000620 \n", + "spaces 0.010586 0.009774 0.001860 \n", + "operators 0.006151 0.005959 0.001526 \n", + "estimation 0.021266 0.021216 0.003385 \n", + "mixture 0.003290 0.003099 0.001287 \n", + "gaussian 0.009250 0.011204 0.002766 \n", + "gene 0.001287 0.001669 0.000191 \n", + "trading 0.000954 0.000238 0.000191 \n", + "finance 0.000048 0.000238 0.000000 \n", + "memor* AND (design* OR network*) 0.001383 0.002145 0.000286 \n", + "system* AND design* 0.001144 0.002384 0.000238 \n", + "material* 0.004148 0.003099 0.000238 \n", + "spin 0.013542 0.015018 0.002146 \n", + "magnetic 0.011301 0.012872 0.002432 \n", + "manifold* AND (NOT learn*) 0.007057 0.008343 0.000858 \n", + "equation 0.010681 0.007867 0.000954 \n", + "regression 0.009393 0.009058 0.002575 \n", + "bayes* 0.015306 0.014779 0.003147 \n", + "total 0.176616 0.185936 0.017833 \n", + "\n", + " correct incorrect precision \n", + "stock* 3 0 1.000000 \n", + "*asset* 3 0 1.000000 \n", + "pric* 9 5 0.642857 \n", + "economy 1 0 1.000000 \n", + "deep AND neural AND network* 32 11 0.744186 \n", + "convolutional 32 7 0.820513 \n", + "allocat* AND *net* 3 0 1.000000 \n", + "program 11 2 0.846154 \n", + "classification* AND (label* OR deep) 14 3 0.823529 \n", + "scattering 10 2 0.833333 \n", + "astro* 17 0 1.000000 \n", + "optical 27 2 0.931034 \n", + "ray 27 4 0.870968 \n", + "entangle* 11 1 0.916667 \n", + "*algebra* 70 7 0.909091 \n", + "spaces 38 3 0.926829 \n", + "operators 22 3 0.880000 \n", + "estimation 65 24 0.730337 \n", + "mixture 10 3 0.769231 \n", + "gaussian 36 11 
0.765957 \n", + "gene 6 1 0.857143 \n", + "trading 1 0 1.000000 \n", + "finance 1 0 1.000000 \n", + "memor* AND (design* OR network*) 9 0 1.000000 \n", + "system* AND design* 9 1 0.900000 \n", + "material* 10 3 0.769231 \n", + "spin 60 3 0.952381 \n", + "magnetic 49 5 0.907407 \n", + "manifold* AND (NOT learn*) 28 7 0.800000 \n", + "equation 24 9 0.727273 \n", + "regression 33 5 0.868421 \n", + "bayes* 49 13 0.790323 \n", + "total 720 135 0.842105 " + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "weak_labels.summary()" ] }, { "cell_type": "markdown", - "id": "d5cf20aa-0117-438d-a960-59a85adb4d37", + "id": "08afaae4", "metadata": {}, "source": [ - "Now we are ready to load our pretrained transformer model and prepare it for our task: multi-label text classification with 6 labels." + "Let's create new rules and see their affects, if they are informative enough we can proceed by adding them to dataset" ] }, { "cell_type": "code", - "execution_count": null, - "id": "1fa9ee91-82f0-4cb1-a1b6-10369804e22a", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from transformers import AutoModelForSequenceClassification\n", - "\n", - "# Init our down-stream model\n", - "model = AutoModelForSequenceClassification.from_pretrained(\n", - " \"distilbert-base-uncased\", problem_type=\"multi_label_classification\", num_labels=6\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "id": "29498e65-8edc-4086-8ab7-f2db90f8a113", + "execution_count": 36, + "id": "9b3f483c", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'coverage': 0.004672897196261682,\n", + " 'annotated_coverage': 0.004529201430274136,\n", + " 'correct': 17,\n", + " 'incorrect': 2,\n", + " 'precision': 0.8947368421052632}" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "The only thing missing for the training is the `Trainer` and its `TrainingArguments`.\n", - "To keep it simple, we mostly rely on the default arguments, that often work out of the box, but tweak a bit the batch size to train faster. \n", - "We also checked that 2 epochs are enough for our rather small dataset." 
+ "# create a statistics rule and get its metrics\n", + "statistics_rule = Rule(\"sample\", \"Statistics\")\n", + "statistics_rule.apply(\"research_titles\")\n", + "statistics_rule.metrics(\"research_titles\")" ] }, { "cell_type": "code", - "execution_count": null, - "id": "2dd0c0bd-b7d9-4e1c-a734-8c2a8e9d57d6", + "execution_count": 37, + "id": "6e562f31", "metadata": {}, "outputs": [], "source": [ - "from transformers import TrainingArguments\n", - "\n", - "# Set our training arguments\n", - "training_args = TrainingArguments(\n", - " output_dir=\"test_trainer\",\n", - " evaluation_strategy=\"epoch\",\n", - " num_train_epochs=2,\n", - " per_device_train_batch_size=16,\n", - " per_device_eval_batch_size=16,\n", - ")\n" + "# add the statistics_rule to the research_titles dataset\n", + "statistics_rule.add_to_dataset(\"research_titles\")" ] }, { "cell_type": "code", - "execution_count": null, - "id": "61e0896c-dc7b-4844-95b6-a258aa8f8d1c", + "execution_count": 38, + "id": "b23accdf", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'coverage': 0.004815945069616631,\n", + " 'annotated_coverage': 0.004290822407628129,\n", + " 'correct': 1,\n", + " 'incorrect': 17,\n", + " 'precision': 0.05555555555555555}" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "from transformers import Trainer\n", - "\n", - "# Init the trainer\n", - "trainer = Trainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=binarized_tokenized_ds[\"train\"],\n", - " eval_dataset=binarized_tokenized_ds[\"test\"],\n", - " compute_metrics=compute_metrics,\n", - ")\n" + "finance_rule = Rule(\"risk\", \"Quantitative Finance\")\n", + "finance_rule.apply(\"research_titles\")\n", + "finance_rule.metrics(\"research_titles\")\n" ] }, { "cell_type": "code", - "execution_count": null, - "id": "1ddd25cb-1cf8-4721-a8e8-c16eb0c7aaf7", - "metadata": { - "tags": [] - }, + "execution_count": 39, + "id": "5533291e", + "metadata": {}, "outputs": [], "source": [ - "# Train the down-stream model\n", - "trainer.train()\n" + "finance_rule.add_to_dataset(\"research_titles\")" ] }, { "cell_type": "markdown", - "id": "b77e8d7e-1463-42c8-96e2-13dfe98cc318", + "id": "245d572a", "metadata": {}, "source": [ - "We achieved an micro averaged *F1* of abut 0.54, which is not perfect, but a good baseline for this challenging dataset.\n", - "When inspecting the *F1s per label*, we clearly see that the worst performing labels are the ones with the poorest heuristics in terms of accuracy and coverage, which comes to no surprise." + "Our assertion does not seem correct lets update this rule" ] }, { - "cell_type": "markdown", - "id": "ee0dfbaa-90f3-4803-8e93-b86e1440bd51", - "metadata": { - "tags": [] - }, + "cell_type": "code", + "execution_count": 40, + "id": "ff59d9b3", + "metadata": {}, + "outputs": [], "source": [ - "## Research topic dataset\n", - "\n", - "After covering a multi-label emotion classification task, we will try to do the same for a multi-label classification task related to topic modeling.\n", - "In this dataset, research papers were classified with 6 non-exclusive labels based on their title and abstract.\n", - "\n", - "We will try to classify the papers only based on the title, which is considerably harder, but allows us to quickly scan through the data and come up with heuristics.\n", - "See Appendix B for all the details of the minimal data preprocessing." 
+ "rule = Rule(\"risk\", \"Statistics\")" ] }, { - "cell_type": "markdown", - "id": "8ba65805-bfb2-4027-9231-28538a7873e8", + "cell_type": "code", + "execution_count": 41, + "id": "83f689f1", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'coverage': 0.004815945069616631,\n", + " 'annotated_coverage': 0.004290822407628129,\n", + " 'correct': 11,\n", + " 'incorrect': 7,\n", + " 'precision': 0.6111111111111112}" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "### Define rules\n", - "\n", - "Let us start by downloading our preprocessed dataset from the Hugging Face Hub, and log it to Argilla:" + "rule.metrics(\"research_titles\")" ] }, { "cell_type": "code", - "execution_count": null, - "id": "dd4cdd38-604c-4fe6-8104-fc6b136d331a", + "execution_count": 42, + "id": "4364f261", "metadata": {}, "outputs": [], "source": [ - "import argilla as rg\n", - "from datasets import load_dataset\n", - "\n", - "# Download preprocessed dataset\n", - "ds_rb = rg.read_datasets(\n", - " load_dataset(\"argilla/research_titles_multi-label\", split=\"train\"),\n", - " task=\"TextClassification\",\n", - ")\n" + "rule.update_at_dataset(\"research_titles\")" ] }, { "cell_type": "code", - "execution_count": null, - "id": "34899478-62ae-4864-8196-36348aa568c3", + "execution_count": 43, + "id": "0df8ab0e", "metadata": {}, "outputs": [], "source": [ - "# Log dataset to Argilla to find good heuristics\n", - "rg.log(ds_rb, \"research_titles\")\n" + "quantitative_biology_rule = Rule(\"dna\", \"Quantitative Biology\")" ] }, { - "cell_type": "markdown", - "id": "1a95ac66-480b-4bf2-b35c-e548df16a084", + "cell_type": "code", + "execution_count": 44, + "id": "0949eb8d", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'coverage': 0.0013351134846461949,\n", + " 'annotated_coverage': 0.0011918951132300357,\n", + " 'correct': 4,\n", + " 'incorrect': 1,\n", + " 'precision': 0.8}" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "After uploading the dataset, we can explore and inspect it to find good heuristic rules.\n", - "For this we highly recommend the dedicated [*Define rules* mode](../../reference/webapp/features.html#weak-labelling) of the Argilla web app, that allows you to quickly iterate over heuristic rules, compute their metrics and save them.\n", - "\n", - "Here we copy our rules found via the web app to the notebook for you to easily follow along the tutorial." 
+ "quantitative_biology_rule.metrics(\"research_titles\")" ] }, { "cell_type": "code", - "execution_count": null, - "id": "8f85fd20-7086-4581-9078-2a28c9155997", + "execution_count": 45, + "id": "170b7704", "metadata": {}, "outputs": [], "source": [ - "from argilla.labeling.text_classification import Rule\n", - "\n", - "# Define our heuristic rules (can probably be improved)\n", - "rules = [\n", - " Rule(\"stock*\", \"Quantitative Finance\"),\n", - " Rule(\"*asset*\", \"Quantitative Finance\"),\n", - " Rule(\"trading\", \"Quantitative Finance\"),\n", - " Rule(\"finance\", \"Quantitative Finance\"),\n", - " Rule(\"pric*\", \"Quantitative Finance\"),\n", - " Rule(\"economy\", \"Quantitative Finance\"),\n", - " Rule(\"deep AND neural AND network*\", \"Computer Science\"),\n", - " Rule(\"convolutional\", \"Computer Science\"),\n", - " Rule(\"memor* AND (design* OR network*)\", \"Computer Science\"),\n", - " Rule(\"system* AND design*\", \"Computer Science\"),\n", - " Rule(\"allocat* AND *net*\", \"Computer Science\"),\n", - " Rule(\"program\", \"Computer Science\"),\n", - " Rule(\"classification* AND (label* OR deep)\", \"Computer Science\"),\n", - " Rule(\"scattering\", \"Physics\"),\n", - " Rule(\"astro*\", \"Physics\"),\n", - " Rule(\"material*\", \"Physics\"),\n", - " Rule(\"spin\", \"Physics\"),\n", - " Rule(\"magnetic\", \"Physics\"),\n", - " Rule(\"optical\", \"Physics\"),\n", - " Rule(\"ray\", \"Physics\"),\n", - " Rule(\"entangle*\", \"Physics\"),\n", - " Rule(\"*algebra*\", \"Mathematics\"),\n", - " Rule(\"manifold* AND (NOT learn*)\", \"Mathematics\"),\n", - " Rule(\"equation\", \"Mathematics\"),\n", - " Rule(\"spaces\", \"Mathematics\"),\n", - " Rule(\"operators\", \"Mathematics\"),\n", - " Rule(\"regression\", \"Statistics\"),\n", - " Rule(\"bayes*\", \"Statistics\"),\n", - " Rule(\"estimation\", \"Statistics\"),\n", - " Rule(\"mixture\", \"Statistics\"),\n", - " Rule(\"gaussian\", \"Statistics\"),\n", - " Rule(\"gene\", \"Quantitative Biology\"),\n", - "]\n" + "quantitative_biology_rule.add_to_dataset(\"research_titles\")" ] }, { "cell_type": "markdown", - "id": "93311e7a-4522-4036-9151-10edc1101d3d", + "id": "722fa8f9", "metadata": {}, "source": [ - "We go on and apply these heuristic rules to our dataset creating our weak label matrix.\n", - "As mentioned in the [GoEmotions](#goemotions) section, the weak label matrix will have 3 dimensions and values of -1, 0 and 1." 
+ "Lets see the final matrix with new added rules" ] }, { "cell_type": "code", "execution_count": null, - "id": "c41d2c90-e550-4d44-9935-b5eeea415c67", + "id": "204ddcf1", "metadata": {}, "outputs": [], "source": [ - "from argilla.labeling.text_classification import WeakMultiLabels\n", - "\n", - "# Compute the weak labels for our dataset given the rules\n", - "# If your dataset already contains rules you can omit the rules argument.\n", - "weak_labels = WeakMultiLabels(\"research_titles\", rules=rules)\n" - ] - }, - { - "cell_type": "markdown", - "id": "fa6a2585-58cc-4eae-b5da-e31e29fd188d", - "metadata": {}, - "source": [ - "Let us get an overview of the our heuristics and how they perform:" + "weak_labels = WeakMultiLabels(\"research_titles\")" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "a5bfc002-0845-4d36-b2c5-8133995561ce", + "execution_count": 47, + "id": "eece91f2", "metadata": {}, "outputs": [ { @@ -1112,37 +3694,17 @@ " {Quantitative Finance}\n", " 0.000477\n", " 0.000715\n", - " 0.000286\n", + " 0.000334\n", " 3\n", " 0\n", " 1.000000\n", " \n", " \n", - " trading\n", - " {Quantitative Finance}\n", - " 0.000954\n", - " 0.000238\n", - " 0.000191\n", - " 1\n", - " 0\n", - " 1.000000\n", - " \n", - " \n", - " finance\n", - " {Quantitative Finance}\n", - " 0.000048\n", - " 0.000238\n", - " 0.000000\n", - " 1\n", - " 0\n", - " 1.000000\n", - " \n", - " \n", " pric*\n", " {Quantitative Finance}\n", " 0.003433\n", " 0.003337\n", - " 0.000715\n", + " 0.000811\n", " 9\n", " 5\n", " 0.642857\n", @@ -1152,7 +3714,7 @@ " {Quantitative Finance}\n", " 0.000238\n", " 0.000238\n", - " 0.000000\n", + " 0.000048\n", " 1\n", " 0\n", " 1.000000\n", @@ -1162,7 +3724,7 @@ " {Computer Science}\n", " 0.009155\n", " 0.010250\n", - " 0.002909\n", + " 0.002956\n", " 32\n", " 11\n", " 0.744186\n", @@ -1172,12 +3734,182 @@ " {Computer Science}\n", " 0.010109\n", " 0.009297\n", - " 0.002241\n", + " 0.002336\n", " 32\n", " 7\n", " 0.820513\n", " \n", " \n", + " allocat* AND *net*\n", + " {Computer Science}\n", + " 0.000763\n", + " 0.000715\n", + " 0.000048\n", + " 3\n", + " 0\n", + " 1.000000\n", + " \n", + " \n", + " program\n", + " {Computer Science}\n", + " 0.002623\n", + " 0.003099\n", + " 0.000191\n", + " 11\n", + " 2\n", + " 0.846154\n", + " \n", + " \n", + " classification* AND (label* OR deep)\n", + " {Computer Science}\n", + " 0.003338\n", + " 0.004052\n", + " 0.001335\n", + " 14\n", + " 3\n", + " 0.823529\n", + " \n", + " \n", + " scattering\n", + " {Physics}\n", + " 0.004053\n", + " 0.002861\n", + " 0.001049\n", + " 10\n", + " 2\n", + " 0.833333\n", + " \n", + " \n", + " astro*\n", + " {Physics}\n", + " 0.003099\n", + " 0.004052\n", + " 0.000668\n", + " 17\n", + " 0\n", + " 1.000000\n", + " \n", + " \n", + " optical\n", + " {Physics}\n", + " 0.007105\n", + " 0.006913\n", + " 0.001097\n", + " 27\n", + " 2\n", + " 0.931034\n", + " \n", + " \n", + " ray\n", + " {Physics}\n", + " 0.005865\n", + " 0.007390\n", + " 0.001240\n", + " 27\n", + " 4\n", + " 0.870968\n", + " \n", + " \n", + " entangle*\n", + " {Physics}\n", + " 0.002623\n", + " 0.002861\n", + " 0.000095\n", + " 11\n", + " 1\n", + " 0.916667\n", + " \n", + " \n", + " *algebra*\n", + " {Mathematics}\n", + " 0.014829\n", + " 0.018355\n", + " 0.000620\n", + " 70\n", + " 7\n", + " 0.909091\n", + " \n", + " \n", + " spaces\n", + " {Mathematics}\n", + " 0.010586\n", + " 0.009774\n", + " 0.001860\n", + " 38\n", + " 3\n", + " 0.926829\n", + " \n", + " \n", + " operators\n", + " {Mathematics}\n", + " 0.006151\n", + " 0.005959\n", + " 
0.001574\n", + " 22\n", + " 3\n", + " 0.880000\n", + " \n", + " \n", + " estimation\n", + " {Statistics}\n", + " 0.021266\n", + " 0.021216\n", + " 0.003862\n", + " 65\n", + " 24\n", + " 0.730337\n", + " \n", + " \n", + " mixture\n", + " {Statistics}\n", + " 0.003290\n", + " 0.003099\n", + " 0.001335\n", + " 10\n", + " 3\n", + " 0.769231\n", + " \n", + " \n", + " gaussian\n", + " {Statistics}\n", + " 0.009250\n", + " 0.011204\n", + " 0.003052\n", + " 36\n", + " 11\n", + " 0.765957\n", + " \n", + " \n", + " gene\n", + " {Quantitative Biology}\n", + " 0.001287\n", + " 0.001669\n", + " 0.000191\n", + " 6\n", + " 1\n", + " 0.857143\n", + " \n", + " \n", + " trading\n", + " {Quantitative Finance}\n", + " 0.000954\n", + " 0.000238\n", + " 0.000191\n", + " 1\n", + " 0\n", + " 1.000000\n", + " \n", + " \n", + " finance\n", + " {Quantitative Finance}\n", + " 0.000048\n", + " 0.000238\n", + " 0.000000\n", + " 1\n", + " 0\n", + " 1.000000\n", + " \n", + " \n", " memor* AND (design* OR network*)\n", " {Computer Science}\n", " 0.001383\n", @@ -1198,56 +3930,6 @@ " 0.900000\n", " \n", " \n", - " allocat* AND *net*\n", - " {Computer Science}\n", - " 0.000763\n", - " 0.000715\n", - " 0.000000\n", - " 3\n", - " 0\n", - " 1.000000\n", - " \n", - " \n", - " program\n", - " {Computer Science}\n", - " 0.002623\n", - " 0.003099\n", - " 0.000143\n", - " 11\n", - " 2\n", - " 0.846154\n", - " \n", - " \n", - " classification* AND (label* OR deep)\n", - " {Computer Science}\n", - " 0.003338\n", - " 0.004052\n", - " 0.001335\n", - " 14\n", - " 3\n", - " 0.823529\n", - " \n", - " \n", - " scattering\n", - " {Physics}\n", - " 0.004053\n", - " 0.002861\n", - " 0.001001\n", - " 10\n", - " 2\n", - " 0.833333\n", - " \n", - " \n", - " astro*\n", - " {Physics}\n", - " 0.003099\n", - " 0.004052\n", - " 0.000620\n", - " 17\n", - " 0\n", - " 1.000000\n", - " \n", - " \n", " material*\n", " {Physics}\n", " 0.004148\n", @@ -1278,46 +3960,6 @@ " 0.907407\n", " \n", " \n", - " optical\n", - " {Physics}\n", - " 0.007105\n", - " 0.006913\n", - " 0.001097\n", - " 27\n", - " 2\n", - " 0.931034\n", - " \n", - " \n", - " ray\n", - " {Physics}\n", - " 0.005865\n", - " 0.007390\n", - " 0.001192\n", - " 27\n", - " 4\n", - " 0.870968\n", - " \n", - " \n", - " entangle*\n", - " {Physics}\n", - " 0.002623\n", - " 0.002861\n", - " 0.000095\n", - " 11\n", - " 1\n", - " 0.916667\n", - " \n", - " \n", - " *algebra*\n", - " {Mathematics}\n", - " 0.014829\n", - " 0.018355\n", - " 0.000620\n", - " 70\n", - " 7\n", - " 0.909091\n", - " \n", - " \n", " manifold* AND (NOT learn*)\n", " {Mathematics}\n", " 0.007057\n", @@ -1332,37 +3974,17 @@ " {Mathematics}\n", " 0.010681\n", " 0.007867\n", - " 0.000954\n", + " 0.001001\n", " 24\n", " 9\n", " 0.727273\n", " \n", " \n", - " spaces\n", - " {Mathematics}\n", - " 0.010586\n", - " 0.009774\n", - " 0.001860\n", - " 38\n", - " 3\n", - " 0.926829\n", - " \n", - " \n", - " operators\n", - " {Mathematics}\n", - " 0.006151\n", - " 0.005959\n", - " 0.001526\n", - " 22\n", - " 3\n", - " 0.880000\n", - " \n", - " \n", " regression\n", " {Statistics}\n", " 0.009393\n", " 0.009058\n", - " 0.002575\n", + " 0.002718\n", " 33\n", " 5\n", " 0.868421\n", @@ -1372,60 +3994,50 @@ " {Statistics}\n", " 0.015306\n", " 0.014779\n", - " 0.003147\n", + " 0.003481\n", " 49\n", " 13\n", " 0.790323\n", " \n", " \n", - " estimation\n", - " {Statistics}\n", - " 0.021266\n", - " 0.021216\n", - " 0.003385\n", - " 65\n", - " 24\n", - " 0.730337\n", - " \n", - " \n", - " mixture\n", + " sample\n", " {Statistics}\n", - " 0.003290\n", - " 
0.003099\n", - " 0.001287\n", - " 10\n", - " 3\n", - " 0.769231\n", + " 0.004673\n", + " 0.004529\n", + " 0.000811\n", + " 17\n", + " 2\n", + " 0.894737\n", " \n", " \n", - " gaussian\n", + " risk\n", " {Statistics}\n", - " 0.009250\n", - " 0.011204\n", - " 0.002766\n", - " 36\n", + " 0.004816\n", + " 0.004291\n", + " 0.001097\n", " 11\n", - " 0.765957\n", + " 7\n", + " 0.611111\n", " \n", " \n", - " gene\n", + " dna\n", " {Quantitative Biology}\n", - " 0.001287\n", - " 0.001669\n", - " 0.000191\n", - " 6\n", + " 0.001335\n", + " 0.001192\n", + " 0.000143\n", + " 4\n", " 1\n", - " 0.857143\n", + " 0.800000\n", " \n", " \n", " total\n", - " {Physics, Quantitative Biology, Mathematics, C...\n", - " 0.176616\n", - " 0.185936\n", - " 0.017833\n", - " 720\n", - " 135\n", - " 0.842105\n", + " {Mathematics, Quantitative Biology, Physics, Q...\n", + " 0.185390\n", + " 0.194041\n", + " 0.019788\n", + " 752\n", + " 145\n", + " 0.838350\n", " \n", " \n", "\n", @@ -1435,117 +4047,125 @@ " label \\\n", "stock* {Quantitative Finance} \n", "*asset* {Quantitative Finance} \n", - "trading {Quantitative Finance} \n", - "finance {Quantitative Finance} \n", "pric* {Quantitative Finance} \n", "economy {Quantitative Finance} \n", "deep AND neural AND network* {Computer Science} \n", "convolutional {Computer Science} \n", - "memor* AND (design* OR network*) {Computer Science} \n", - "system* AND design* {Computer Science} \n", "allocat* AND *net* {Computer Science} \n", "program {Computer Science} \n", "classification* AND (label* OR deep) {Computer Science} \n", "scattering {Physics} \n", "astro* {Physics} \n", - "material* {Physics} \n", - "spin {Physics} \n", - "magnetic {Physics} \n", "optical {Physics} \n", "ray {Physics} \n", "entangle* {Physics} \n", "*algebra* {Mathematics} \n", - "manifold* AND (NOT learn*) {Mathematics} \n", - "equation {Mathematics} \n", "spaces {Mathematics} \n", "operators {Mathematics} \n", - "regression {Statistics} \n", - "bayes* {Statistics} \n", "estimation {Statistics} \n", "mixture {Statistics} \n", "gaussian {Statistics} \n", "gene {Quantitative Biology} \n", - "total {Physics, Quantitative Biology, Mathematics, C... \n", + "trading {Quantitative Finance} \n", + "finance {Quantitative Finance} \n", + "memor* AND (design* OR network*) {Computer Science} \n", + "system* AND design* {Computer Science} \n", + "material* {Physics} \n", + "spin {Physics} \n", + "magnetic {Physics} \n", + "manifold* AND (NOT learn*) {Mathematics} \n", + "equation {Mathematics} \n", + "regression {Statistics} \n", + "bayes* {Statistics} \n", + "sample {Statistics} \n", + "risk {Statistics} \n", + "dna {Quantitative Biology} \n", + "total {Mathematics, Quantitative Biology, Physics, Q... 
\n", "\n", " coverage annotated_coverage overlaps \\\n", "stock* 0.000954 0.000715 0.000334 \n", - "*asset* 0.000477 0.000715 0.000286 \n", + "*asset* 0.000477 0.000715 0.000334 \n", + "pric* 0.003433 0.003337 0.000811 \n", + "economy 0.000238 0.000238 0.000048 \n", + "deep AND neural AND network* 0.009155 0.010250 0.002956 \n", + "convolutional 0.010109 0.009297 0.002336 \n", + "allocat* AND *net* 0.000763 0.000715 0.000048 \n", + "program 0.002623 0.003099 0.000191 \n", + "classification* AND (label* OR deep) 0.003338 0.004052 0.001335 \n", + "scattering 0.004053 0.002861 0.001049 \n", + "astro* 0.003099 0.004052 0.000668 \n", + "optical 0.007105 0.006913 0.001097 \n", + "ray 0.005865 0.007390 0.001240 \n", + "entangle* 0.002623 0.002861 0.000095 \n", + "*algebra* 0.014829 0.018355 0.000620 \n", + "spaces 0.010586 0.009774 0.001860 \n", + "operators 0.006151 0.005959 0.001574 \n", + "estimation 0.021266 0.021216 0.003862 \n", + "mixture 0.003290 0.003099 0.001335 \n", + "gaussian 0.009250 0.011204 0.003052 \n", + "gene 0.001287 0.001669 0.000191 \n", "trading 0.000954 0.000238 0.000191 \n", "finance 0.000048 0.000238 0.000000 \n", - "pric* 0.003433 0.003337 0.000715 \n", - "economy 0.000238 0.000238 0.000000 \n", - "deep AND neural AND network* 0.009155 0.010250 0.002909 \n", - "convolutional 0.010109 0.009297 0.002241 \n", "memor* AND (design* OR network*) 0.001383 0.002145 0.000286 \n", "system* AND design* 0.001144 0.002384 0.000238 \n", - "allocat* AND *net* 0.000763 0.000715 0.000000 \n", - "program 0.002623 0.003099 0.000143 \n", - "classification* AND (label* OR deep) 0.003338 0.004052 0.001335 \n", - "scattering 0.004053 0.002861 0.001001 \n", - "astro* 0.003099 0.004052 0.000620 \n", "material* 0.004148 0.003099 0.000238 \n", "spin 0.013542 0.015018 0.002146 \n", "magnetic 0.011301 0.012872 0.002432 \n", - "optical 0.007105 0.006913 0.001097 \n", - "ray 0.005865 0.007390 0.001192 \n", - "entangle* 0.002623 0.002861 0.000095 \n", - "*algebra* 0.014829 0.018355 0.000620 \n", "manifold* AND (NOT learn*) 0.007057 0.008343 0.000858 \n", - "equation 0.010681 0.007867 0.000954 \n", - "spaces 0.010586 0.009774 0.001860 \n", - "operators 0.006151 0.005959 0.001526 \n", - "regression 0.009393 0.009058 0.002575 \n", - "bayes* 0.015306 0.014779 0.003147 \n", - "estimation 0.021266 0.021216 0.003385 \n", - "mixture 0.003290 0.003099 0.001287 \n", - "gaussian 0.009250 0.011204 0.002766 \n", - "gene 0.001287 0.001669 0.000191 \n", - "total 0.176616 0.185936 0.017833 \n", + "equation 0.010681 0.007867 0.001001 \n", + "regression 0.009393 0.009058 0.002718 \n", + "bayes* 0.015306 0.014779 0.003481 \n", + "sample 0.004673 0.004529 0.000811 \n", + "risk 0.004816 0.004291 0.001097 \n", + "dna 0.001335 0.001192 0.000143 \n", + "total 0.185390 0.194041 0.019788 \n", "\n", " correct incorrect precision \n", "stock* 3 0 1.000000 \n", "*asset* 3 0 1.000000 \n", - "trading 1 0 1.000000 \n", - "finance 1 0 1.000000 \n", "pric* 9 5 0.642857 \n", "economy 1 0 1.000000 \n", "deep AND neural AND network* 32 11 0.744186 \n", "convolutional 32 7 0.820513 \n", - "memor* AND (design* OR network*) 9 0 1.000000 \n", - "system* AND design* 9 1 0.900000 \n", "allocat* AND *net* 3 0 1.000000 \n", "program 11 2 0.846154 \n", "classification* AND (label* OR deep) 14 3 0.823529 \n", "scattering 10 2 0.833333 \n", "astro* 17 0 1.000000 \n", - "material* 10 3 0.769231 \n", - "spin 60 3 0.952381 \n", - "magnetic 49 5 0.907407 \n", "optical 27 2 0.931034 \n", "ray 27 4 0.870968 \n", "entangle* 11 1 0.916667 \n", "*algebra* 
70 7 0.909091 \n", - "manifold* AND (NOT learn*) 28 7 0.800000 \n", - "equation 24 9 0.727273 \n", "spaces 38 3 0.926829 \n", "operators 22 3 0.880000 \n", - "regression 33 5 0.868421 \n", - "bayes* 49 13 0.790323 \n", "estimation 65 24 0.730337 \n", "mixture 10 3 0.769231 \n", "gaussian 36 11 0.765957 \n", "gene 6 1 0.857143 \n", - "total 720 135 0.842105 " + "trading 1 0 1.000000 \n", + "finance 1 0 1.000000 \n", + "memor* AND (design* OR network*) 9 0 1.000000 \n", + "system* AND design* 9 1 0.900000 \n", + "material* 10 3 0.769231 \n", + "spin 60 3 0.952381 \n", + "magnetic 49 5 0.907407 \n", + "manifold* AND (NOT learn*) 28 7 0.800000 \n", + "equation 24 9 0.727273 \n", + "regression 33 5 0.868421 \n", + "bayes* 49 13 0.790323 \n", + "sample 17 2 0.894737 \n", + "risk 11 7 0.611111 \n", + "dna 4 1 0.800000 \n", + "total 752 145 0.838350 " ] }, - "execution_count": 5, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Check coverage/precision of our rules\n", - "weak_labels.summary()\n" + "weak_labels.summary()" ] }, { @@ -1563,7 +4183,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 48, "id": "8152fabd-969d-40b1-a4fc-956f8783ce29", "metadata": {}, "outputs": [], @@ -1585,7 +4205,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 49, "id": "f7edc676-5c9a-4b21-82d6-1a820b466483", "metadata": {}, "outputs": [], @@ -1603,7 +4223,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 50, "id": "f45f92a3-b9a8-482a-9f37-52e4802790b4", "metadata": {}, "outputs": [], @@ -1638,7 +4258,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 51, "id": "7f2a5d4b-4d6a-4dc0-8533-287f92c745f4", "metadata": {}, "outputs": [], @@ -1664,10 +4284,33 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 52, "id": "019aae12-aab3-4d9c-9695-80018541c3ca", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
Pipeline(steps=[('vect', CountVectorizer()),\n",
+       "                ('clf',\n",
+       "                 BinaryRelevance(classifier=MultinomialNB(),\n",
+       "                                 require_dense=[True, True]))])
" + ], + "text/plain": [ + "Pipeline(steps=[('vect', CountVectorizer()),\n", + " ('clf',\n", + " BinaryRelevance(classifier=MultinomialNB(),\n", + " require_dense=[True, True]))])" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import numpy as np\n", "\n", @@ -1688,7 +4331,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 53, "id": "858f1c6e-99df-4918-803c-18647e2edd68", "metadata": {}, "outputs": [], @@ -1701,7 +4344,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 54, "id": "8b12a128-a184-494d-a926-6100a9d252da", "metadata": {}, "outputs": [ @@ -1711,17 +4354,17 @@ "text": [ " precision recall f1-score support\n", "\n", - " Computer Science 0.82 0.26 0.40 1740\n", - " Mathematics 0.71 0.64 0.67 1141\n", - " Physics 0.81 0.70 0.75 1186\n", - "Quantitative Biology 1.00 0.01 0.02 109\n", - "Quantitative Finance 0.44 0.09 0.15 45\n", - " Statistics 0.48 0.71 0.57 1069\n", + " Computer Science 0.81 0.24 0.38 1740\n", + " Mathematics 0.79 0.58 0.67 1141\n", + " Physics 0.88 0.65 0.74 1186\n", + "Quantitative Biology 0.67 0.02 0.04 109\n", + "Quantitative Finance 0.46 0.13 0.21 45\n", + " Statistics 0.52 0.69 0.60 1069\n", "\n", - " micro avg 0.66 0.53 0.59 5290\n", - " macro avg 0.71 0.40 0.43 5290\n", - " weighted avg 0.73 0.53 0.56 5290\n", - " samples avg 0.66 0.56 0.59 5290\n", + " micro avg 0.71 0.49 0.58 5290\n", + " macro avg 0.69 0.39 0.44 5290\n", + " weighted avg 0.76 0.49 0.56 5290\n", + " samples avg 0.58 0.52 0.53 5290\n", "\n" ] } @@ -1877,7 +4520,7 @@ "source": [ "# publish dataset in the Hub\n", "\n", - "ds_rb = rg.DatasetForTextClassification(records).to_datasets()\n", + "ds_rg = rg.DatasetForTextClassification(records).to_datasets()\n", "\n", "ds_rg.push_to_hub(\"argilla/go_emotions_multi-label\", private=True)\n" ] @@ -1953,7 +4596,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.13 ('.venv': venv)", + "display_name": "Python 3.9.13 ('argilla')", "language": "python", "name": "python3" }, @@ -1971,7 +4614,7 @@ }, "vscode": { "interpreter": { - "hash": "39f4e3bd8ecb53b4a2ef9bccb982583dac0632e40e094b10b94294b76eaa26cb" + "hash": "83e13ff0de9ea08cace169d1016bf08ce368842307fd88824f08736a0a9ca04b" } } }, diff --git a/docs/_source/tutorials/notebooks/labelling-textclassification-snorkel-weaksupervision.ipynb b/docs/_source/tutorials/notebooks/labelling-textclassification-snorkel-weaksupervision.ipynb index 65fd46c7ec..62364aab5d 100644 --- a/docs/_source/tutorials/notebooks/labelling-textclassification-snorkel-weaksupervision.ipynb +++ b/docs/_source/tutorials/notebooks/labelling-textclassification-snorkel-weaksupervision.ipynb @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "id": "b63e3399-6fad-483f-9f9c-7011a4793a9f", "metadata": {}, "outputs": [ @@ -271,12 +271,52 @@ }, { "cell_type": "markdown", - "id": "26a7255c-9f67-483c-b253-f186df8286cf", + "id": "6ddfe433", "metadata": {}, "source": [ - "## 2. Interactive weak labeling: Finding and defining rules\n", + "## 2. Define Rules" + ] + }, + { + "cell_type": "markdown", + "id": "62724abf", + "metadata": {}, + "source": [ + "Rules can be defined and managed (1) using the UI, and (2) using the Python client. We will add some rules with the Python Client that will be available in the UI where we can start our interactive weak labelling." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "7e064960", + "metadata": {}, + "outputs": [], + "source": [ + "from argilla.labeling.text_classification import Rule\n", + "\n", + "# define queries and patterns for each category (using ES DSL)\n", + "queries = [\n", + " ([\"money\", \"financ*\", \"dollar*\"], \"Business\"),\n", + " ([\"war\", \"gov*\", \"minister*\", \"conflict\"], \"World\"),\n", + " ([\"footbal*\", \"sport*\", \"game\", \"play*\"], \"Sports\"),\n", + " ([\"sci*\", \"techno*\", \"computer*\", \"software\", \"web\"], \"Sci/Tech\"),\n", + "]\n", + "\n", + "# define rules\n", + "rules = [Rule(query=term, label=label) for terms, label in queries for term in terms]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "fce45ae1", + "metadata": {}, + "outputs": [], + "source": [ + "from argilla.labeling.text_classification import add_rules\n", "\n", - "After logging the dataset, you can find and save rules directly with the UI. Then, you can read the rules with Python to train a label or downstream model, as we'll see in the next step. " + "# add rules to the dataset\n", + "add_rules(dataset=\"news\", rules=rules)" ] }, { @@ -303,12 +343,12 @@ "source": [ "from argilla.labeling.text_classification import WeakLabels\n", "\n", - "weak_labels = WeakLabels(dataset=\"news\")\n" + "weak_labels = WeakLabels(dataset=\"news\")" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 15, "id": "09bd9bee-9207-4e5e-8d79-41c04d9d0b19", "metadata": {}, "outputs": [ @@ -347,10 +387,10 @@ " \n", " money\n", " {Business}\n", - " 0.008276\n", + " 0.008268\n", " 0.008816\n", - " 0.002437\n", - " 0.001936\n", + " 0.002484\n", + " 0.001983\n", " 30\n", " 37\n", " 0.447761\n", @@ -360,8 +400,8 @@ " {Business}\n", " 0.019655\n", " 0.017763\n", - " 0.005893\n", - " 0.005188\n", + " 0.005933\n", + " 0.005227\n", " 80\n", " 55\n", " 0.592593\n", @@ -371,8 +411,8 @@ " {Business}\n", " 0.016591\n", " 0.016316\n", - " 0.003542\n", - " 0.002908\n", + " 0.003582\n", + " 0.002947\n", " 87\n", " 37\n", " 0.701613\n", @@ -380,21 +420,21 @@ " \n", " war\n", " {World}\n", - " 0.011779\n", - " 0.013289\n", - " 0.003213\n", - " 0.001348\n", - " 75\n", - " 26\n", - " 0.742574\n", + " 0.015627\n", + " 0.017105\n", + " 0.004459\n", + " 0.001732\n", + " 101\n", + " 29\n", + " 0.776923\n", " \n", " \n", " gov*\n", " {World}\n", - " 0.045078\n", + " 0.045086\n", " 0.045263\n", - " 0.010878\n", - " 0.006270\n", + " 0.011191\n", + " 0.006277\n", " 170\n", " 174\n", " 0.494186\n", @@ -404,7 +444,7 @@ " {World}\n", " 0.030031\n", " 0.028289\n", - " 0.007531\n", + " 0.007908\n", " 0.002821\n", " 193\n", " 22\n", @@ -413,21 +453,21 @@ " \n", " conflict\n", " {World}\n", - " 0.003041\n", - " 0.002895\n", - " 0.001003\n", + " 0.003025\n", + " 0.002763\n", + " 0.001097\n", " 0.000102\n", - " 18\n", + " 17\n", " 4\n", - " 0.818182\n", + " 0.809524\n", " \n", " \n", " footbal*\n", " {Sports}\n", - " 0.013166\n", + " 0.013158\n", " 0.015000\n", - " 0.004945\n", - " 0.000439\n", + " 0.004953\n", + " 0.000447\n", " 107\n", " 7\n", " 0.938596\n", @@ -437,7 +477,7 @@ " {Sports}\n", " 0.021191\n", " 0.021316\n", - " 0.007045\n", + " 0.007038\n", " 0.001223\n", " 139\n", " 23\n", @@ -446,21 +486,21 @@ " \n", " game\n", " {Sports}\n", - " 0.038879\n", - " 0.037763\n", - " 0.014083\n", - " 0.002375\n", + " 0.038738\n", + " 0.037632\n", + " 0.014060\n", + " 0.002390\n", " 216\n", - " 71\n", - " 0.752613\n", + " 70\n", + " 0.755245\n", " \n", " \n", " play*\n", " {Sports}\n", " 
0.052453\n", " 0.050000\n", - " 0.016889\n", - " 0.005063\n", + " 0.016991\n", + " 0.005196\n", " 268\n", " 112\n", " 0.705263\n", @@ -470,8 +510,8 @@ " {Sci/Tech}\n", " 0.016552\n", " 0.018421\n", - " 0.002735\n", - " 0.001309\n", + " 0.002782\n", + " 0.001340\n", " 114\n", " 26\n", " 0.814286\n", @@ -479,10 +519,10 @@ " \n", " techno*\n", " {Sci/Tech}\n", - " 0.027218\n", + " 0.027210\n", " 0.028289\n", - " 0.008433\n", - " 0.003174\n", + " 0.008534\n", + " 0.003205\n", " 155\n", " 60\n", " 0.720930\n", @@ -490,46 +530,46 @@ " \n", " computer*\n", " {Sci/Tech}\n", - " 0.027320\n", - " 0.028026\n", - " 0.011058\n", - " 0.004459\n", + " 0.027586\n", + " 0.028158\n", + " 0.011277\n", + " 0.004514\n", " 159\n", - " 54\n", - " 0.746479\n", + " 55\n", + " 0.742991\n", " \n", " \n", " software\n", " {Sci/Tech}\n", - " 0.030243\n", - " 0.029605\n", - " 0.009655\n", - " 0.003346\n", - " 184\n", + " 0.030188\n", + " 0.029474\n", + " 0.009828\n", + " 0.003378\n", + " 183\n", " 41\n", - " 0.817778\n", + " 0.816964\n", " \n", " \n", " web\n", " {Sci/Tech}\n", - " 0.015376\n", - " 0.013289\n", - " 0.004067\n", - " 0.001607\n", - " 76\n", + " 0.017132\n", + " 0.014737\n", + " 0.004561\n", + " 0.001779\n", + " 87\n", " 25\n", - " 0.752475\n", + " 0.776786\n", " \n", " \n", " total\n", - " {World, Sports, Business, Sci/Tech}\n", - " 0.317022\n", - " 0.311447\n", - " 0.053582\n", - " 0.019561\n", - " 2071\n", - " 774\n", - " 0.727944\n", + " {World, Sci/Tech, Business, Sports}\n", + " 0.320964\n", + " 0.315000\n", + " 0.055149\n", + " 0.020039\n", + " 2106\n", + " 777\n", + " 0.730489\n", " \n", " \n", "\n", @@ -537,59 +577,67 @@ ], "text/plain": [ " label coverage annotated_coverage \\\n", - "money {Business} 0.008276 0.008816 \n", + "money {Business} 0.008268 0.008816 \n", "financ* {Business} 0.019655 0.017763 \n", "dollar* {Business} 0.016591 0.016316 \n", - "war {World} 0.011779 0.013289 \n", - "gov* {World} 0.045078 0.045263 \n", + "war {World} 0.015627 0.017105 \n", + "gov* {World} 0.045086 0.045263 \n", "minister* {World} 0.030031 0.028289 \n", - "conflict {World} 0.003041 0.002895 \n", - "footbal* {Sports} 0.013166 0.015000 \n", + "conflict {World} 0.003025 0.002763 \n", + "footbal* {Sports} 0.013158 0.015000 \n", "sport* {Sports} 0.021191 0.021316 \n", - "game {Sports} 0.038879 0.037763 \n", + "game {Sports} 0.038738 0.037632 \n", "play* {Sports} 0.052453 0.050000 \n", "sci* {Sci/Tech} 0.016552 0.018421 \n", - "techno* {Sci/Tech} 0.027218 0.028289 \n", - "computer* {Sci/Tech} 0.027320 0.028026 \n", - "software {Sci/Tech} 0.030243 0.029605 \n", - "web {Sci/Tech} 0.015376 0.013289 \n", - "total {World, Sports, Business, Sci/Tech} 0.317022 0.311447 \n", + "techno* {Sci/Tech} 0.027210 0.028289 \n", + "computer* {Sci/Tech} 0.027586 0.028158 \n", + "software {Sci/Tech} 0.030188 0.029474 \n", + "web {Sci/Tech} 0.017132 0.014737 \n", + "total {World, Sci/Tech, Business, Sports} 0.320964 0.315000 \n", "\n", " overlaps conflicts correct incorrect precision \n", - "money 0.002437 0.001936 30 37 0.447761 \n", - "financ* 0.005893 0.005188 80 55 0.592593 \n", - "dollar* 0.003542 0.002908 87 37 0.701613 \n", - "war 0.003213 0.001348 75 26 0.742574 \n", - "gov* 0.010878 0.006270 170 174 0.494186 \n", - "minister* 0.007531 0.002821 193 22 0.897674 \n", - "conflict 0.001003 0.000102 18 4 0.818182 \n", - "footbal* 0.004945 0.000439 107 7 0.938596 \n", - "sport* 0.007045 0.001223 139 23 0.858025 \n", - "game 0.014083 0.002375 216 71 0.752613 \n", - "play* 0.016889 0.005063 268 112 0.705263 \n", - "sci* 0.002735 
0.001309 114 26 0.814286 \n", - "techno* 0.008433 0.003174 155 60 0.720930 \n", - "computer* 0.011058 0.004459 159 54 0.746479 \n", - "software 0.009655 0.003346 184 41 0.817778 \n", - "web 0.004067 0.001607 76 25 0.752475 \n", - "total 0.053582 0.019561 2071 774 0.727944 " + "money 0.002484 0.001983 30 37 0.447761 \n", + "financ* 0.005933 0.005227 80 55 0.592593 \n", + "dollar* 0.003582 0.002947 87 37 0.701613 \n", + "war 0.004459 0.001732 101 29 0.776923 \n", + "gov* 0.011191 0.006277 170 174 0.494186 \n", + "minister* 0.007908 0.002821 193 22 0.897674 \n", + "conflict 0.001097 0.000102 17 4 0.809524 \n", + "footbal* 0.004953 0.000447 107 7 0.938596 \n", + "sport* 0.007038 0.001223 139 23 0.858025 \n", + "game 0.014060 0.002390 216 70 0.755245 \n", + "play* 0.016991 0.005196 268 112 0.705263 \n", + "sci* 0.002782 0.001340 114 26 0.814286 \n", + "techno* 0.008534 0.003205 155 60 0.720930 \n", + "computer* 0.011277 0.004514 159 55 0.742991 \n", + "software 0.009828 0.003378 183 41 0.816964 \n", + "web 0.004561 0.001779 87 25 0.776786 \n", + "total 0.055149 0.020039 2106 777 0.730489 " ] }, - "execution_count": 22, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "weak_labels.summary()\n" + "weak_labels.summary()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "c27f7820-0144-4827-86b6-08f1fd190334", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100/100 [00:00<00:00, 1228.48epoch/s]\n" + ] + } + ], "source": [ "from argilla.labeling.text_classification import Snorkel\n", "\n", @@ -602,7 +650,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 17, "id": "6572368b-9b1c-4d68-ab36-ef86ca3f3f0f", "metadata": {}, "outputs": [ @@ -612,14 +660,14 @@ "text": [ " precision recall f1-score support\n", "\n", - " Sports 0.79 0.96 0.87 632\n", - " Sci/Tech 0.77 0.77 0.77 773\n", - " World 0.70 0.80 0.74 509\n", - " Business 0.65 0.36 0.46 453\n", + " Business 0.66 0.35 0.46 455\n", + " World 0.70 0.81 0.75 522\n", + " Sci/Tech 0.78 0.77 0.77 784\n", + " Sports 0.78 0.96 0.86 633\n", "\n", - " accuracy 0.75 2367\n", - " macro avg 0.73 0.72 0.71 2367\n", - "weighted avg 0.74 0.75 0.73 2367\n", + " accuracy 0.75 2394\n", + " macro avg 0.73 0.72 0.71 2394\n", + "weighted avg 0.74 0.75 0.73 2394\n", "\n" ] } @@ -648,7 +696,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "7c321871-3bfc-45b9-902b-4beb45f4ca13", "metadata": {}, "outputs": [], @@ -673,7 +721,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 19, "id": "9c81c6c3-75fa-4534-9e27-e888d3f5257b", "metadata": {}, "outputs": [ @@ -705,28 +753,28 @@ " \n", " \n", " 0\n", - " FA Cup: Third round draw - joy for Yeading and Exeter The great big fat bullies from across the playground enter the FA Cup arena at the third round stage, for which the draw was made yesterday (December 5).\n", - " 0\n", + " Tennis: Defending champion Myskina sees off world number one &lt;b&gt;...&lt;/b&gt; MOSCOW : Defending champion and French Open winner Anastasia Myskina advanced into the final of the 2.3 million dollar Kremlin Cup beating new world number one Lindsay Davenport of the United States here.\n", + " 3\n", " \n", " \n", " 1\n", - " Rats May Help Unravel Human Drug Addiction Mysteries By LAURAN NEERGAARD WASHINGTON (AP) -- Rats can become drug addicts. 
That's important to know, scientists say, and has taken a long time to prove...\n", - " 1\n", + " Britain Pays Final Respects to Beheaded Hostage British Prime Minister Tony Blair was among the hundreds of people that attended an emotional service for a man kidnapped and killed in Iraq.\n", + " 2\n", " \n", " \n", " 2\n", - " Palmer Passes Test Bengals quarterback Carson Palmer enjoyed his breakthrough game at the expense of the Super Bowl champion Patriots, racking up 179 yards on 12-of-19 passing in a 31-3 triumph on Saturday night.\n", - " 0\n", + " Skulls trojan targets Symbian smartphones A new trojan on the internet attacks the Nokia 7610 smartphone and possibly other phones running Symbian Series 60 software. quot;We have located several freeware and shareware sites offering a program, called\n", + " 1\n", " \n", " \n", " 3\n", - " Compromises urged amid deadlock in Darfur talks ABUJA, Nigeria -- Peace talks on Sudan's violence-torn Darfur region are deadlocked, a mediator said yesterday, as the chief of the African Union appealed to the Sudanese government and rebels to compromise.\n", + " Sudan Security Foils New Sabotage Plot -- Agency Sudanese authorities said Friday they foiled another plot by an opposition Islamist party to kidnap and kill senior government officials and blow up sites in the capital\n", " 2\n", " \n", " \n", " 4\n", - " CAPELLO FED UP WITH FEIGNING Juventus coach Fabio Capello has ordered his players not to kick the ball out of play when an opponent falls to the ground apparently hurt because he believes some players fake injury to stop the match.\n", - " 0\n", + " Sony and Partners Agree To Acquire MGM Sony Corp. and several financial partners have agreed in principle to acquire movie studio Metro-Goldwyn-Mayer for about \\$2.94 billion in cash, sources familiar with the talks said Monday.\n", + " 3\n", " \n", " \n", " ...\n", @@ -734,50 +782,63 @@ " ...\n", " \n", " \n", - " 38080\n", - " Apple ships Mac OS X update The 26MB upgrade to version 10.3.6 is available now via the OS #39; Software Update control panel and from Apple #39;s support web site.\n", + " 38556\n", + " Titan hangs on to its secrets Cassini #39;s close fly-by of Titan, Saturn #39;s largest moon, has left scientists with no clear idea of what to expect when the Huygens probe lands on the alien world, despite the amazingly detailed images they now have of the surface.\n", " 1\n", " \n", " \n", - " 38081\n", - " Bob Evans, mainframe pioneer, dies at 77 Bob Evans, an IBM computer scientist who helped to develop the modern mainframe computer, died Thursday. He was 77. 
Evans died of heart failure at his at his home in the San Francisco suburb of Hillsborough, his son Evan told the Associated Press.\n", - " 1\n", + " 38557\n", + " Ministers deny interest in raising inheritance tax Downing Street distanced itself last night from reports that inheritance tax will rise to 50 per cent for the wealthiest families.\n", + " 2\n", " \n", " \n", - " 38082\n", - " For Brazil's Economy, the Doctor Is In Antonio Palocci, a doctor from Brazil's farm belt, has found himself presiding as the country's finance minister during the most robust economic expansion in a decade.\n", - " 2\n", + " 38558\n", + " No Frills, but Everything Else Is on Craigslist (washingtonpost.com) washingtonpost.com - Ernie Miller, a 38-year-old software developer in Silver Spring, offers a telling clue as to how www.craigslist.org became the Internet's go-to place to solve life's vexing problems.\n", + " 1\n", " \n", " \n", - " 38083\n", - " UbiSoft Get Ready With The Next Rainbox Six Installment Ubisoft today announced its plans to launch the next installment in the Tom Clancys Rainbow Six franchise in Spring 2005. The next Rainbow Six game follows Team Rainbow, the worlds most\n", + " 38559\n", + " Familiar refrain as Singh leads Just when Vijay Singh thinks he can't play better, he does. Just when it seems he can't do much more during his Tiger Woods-like season, he does that, too.\n", " 0\n", " \n", " \n", - " 38084\n", - " PM #39;s visit to focus on reconstruction of Kashmir New Delhi: The two-day visit of Prime Minister Manmohan Singh to Jammu and Kashmir, starting on Wednesday will focus more on reconstruction and development of the state, Parliamentary Affairs Minister Ghulam Nabi Azad has said.\n", - " 2\n", + " 38560\n", + " Cisco to acquire P-Cube for \\$200m Cisco Systems has agreed to buy software developer P-Cube in a cash-and-options deal Cisco valued at \\$200m (110m). P-Cube makes software to help service providers analyse and control network traffic.\n", + " 1\n", " \n", " \n", "\n", - "

38085 rows × 2 columns

\n", + "

38561 rows × 2 columns

\n", "" ], "text/plain": [ - " text label\n", - "0 FA Cup: Third round draw - joy for Yeading and Exeter The great big fat bullies from across the playground enter the FA Cup arena at the third round stage, for which the draw was made yesterday (December 5). 0\n", - "1 Rats May Help Unravel Human Drug Addiction Mysteries By LAURAN NEERGAARD WASHINGTON (AP) -- Rats can become drug addicts. That's important to know, scientists say, and has taken a long time to prove... 1\n", - "2 Palmer Passes Test Bengals quarterback Carson Palmer enjoyed his breakthrough game at the expense of the Super Bowl champion Patriots, racking up 179 yards on 12-of-19 passing in a 31-3 triumph on Saturday night. 0\n", - "3 Compromises urged amid deadlock in Darfur talks ABUJA, Nigeria -- Peace talks on Sudan's violence-torn Darfur region are deadlocked, a mediator said yesterday, as the chief of the African Union appealed to the Sudanese government and rebels to compromise. 2\n", - "4 CAPELLO FED UP WITH FEIGNING Juventus coach Fabio Capello has ordered his players not to kick the ball out of play when an opponent falls to the ground apparently hurt because he believes some players fake injury to stop the match. 0\n", - "... ... ...\n", - "38080 Apple ships Mac OS X update The 26MB upgrade to version 10.3.6 is available now via the OS #39; Software Update control panel and from Apple #39;s support web site. 1\n", - "38081 Bob Evans, mainframe pioneer, dies at 77 Bob Evans, an IBM computer scientist who helped to develop the modern mainframe computer, died Thursday. He was 77. Evans died of heart failure at his at his home in the San Francisco suburb of Hillsborough, his son Evan told the Associated Press. 1\n", - "38082 For Brazil's Economy, the Doctor Is In Antonio Palocci, a doctor from Brazil's farm belt, has found himself presiding as the country's finance minister during the most robust economic expansion in a decade. 2\n", - "38083 UbiSoft Get Ready With The Next Rainbox Six Installment Ubisoft today announced its plans to launch the next installment in the Tom Clancys Rainbow Six franchise in Spring 2005. The next Rainbow Six game follows Team Rainbow, the worlds most 0\n", - "38084 PM #39;s visit to focus on reconstruction of Kashmir New Delhi: The two-day visit of Prime Minister Manmohan Singh to Jammu and Kashmir, starting on Wednesday will focus more on reconstruction and development of the state, Parliamentary Affairs Minister Ghulam Nabi Azad has said. 2\n", + " text \\\n", + "0 Tennis: Defending champion Myskina sees off world number one <b>...</b> MOSCOW : Defending champion and French Open winner Anastasia Myskina advanced into the final of the 2.3 million dollar Kremlin Cup beating new world number one Lindsay Davenport of the United States here. \n", + "1 Britain Pays Final Respects to Beheaded Hostage British Prime Minister Tony Blair was among the hundreds of people that attended an emotional service for a man kidnapped and killed in Iraq. \n", + "2 Skulls trojan targets Symbian smartphones A new trojan on the internet attacks the Nokia 7610 smartphone and possibly other phones running Symbian Series 60 software. quot;We have located several freeware and shareware sites offering a program, called \n", + "3 Sudan Security Foils New Sabotage Plot -- Agency Sudanese authorities said Friday they foiled another plot by an opposition Islamist party to kidnap and kill senior government officials and blow up sites in the capital \n", + "4 Sony and Partners Agree To Acquire MGM Sony Corp. 
and several financial partners have agreed in principle to acquire movie studio Metro-Goldwyn-Mayer for about \\$2.94 billion in cash, sources familiar with the talks said Monday. \n", + "... ... \n", + "38556 Titan hangs on to its secrets Cassini #39;s close fly-by of Titan, Saturn #39;s largest moon, has left scientists with no clear idea of what to expect when the Huygens probe lands on the alien world, despite the amazingly detailed images they now have of the surface. \n", + "38557 Ministers deny interest in raising inheritance tax Downing Street distanced itself last night from reports that inheritance tax will rise to 50 per cent for the wealthiest families. \n", + "38558 No Frills, but Everything Else Is on Craigslist (washingtonpost.com) washingtonpost.com - Ernie Miller, a 38-year-old software developer in Silver Spring, offers a telling clue as to how www.craigslist.org became the Internet's go-to place to solve life's vexing problems. \n", + "38559 Familiar refrain as Singh leads Just when Vijay Singh thinks he can't play better, he does. Just when it seems he can't do much more during his Tiger Woods-like season, he does that, too. \n", + "38560 Cisco to acquire P-Cube for \\$200m Cisco Systems has agreed to buy software developer P-Cube in a cash-and-options deal Cisco valued at \\$200m (110m). P-Cube makes software to help service providers analyse and control network traffic. \n", + "\n", + " label \n", + "0 3 \n", + "1 2 \n", + "2 1 \n", + "3 2 \n", + "4 3 \n", + "... ... \n", + "38556 1 \n", + "38557 2 \n", + "38558 1 \n", + "38559 0 \n", + "38560 1 \n", "\n", - "[38085 rows x 2 columns]" + "[38561 rows x 2 columns]" ] }, "metadata": {}, @@ -802,10 +863,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "1a2e1aa9-68c9-4a65-bf52-ff5254cdd9db", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])
" + ], + "text/plain": [ + "Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer\n", "from sklearn.naive_bayes import MultinomialNB\n", @@ -831,7 +906,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "db0c41f9-bde8-4bc8-8406-a048e2e73aef", "metadata": {}, "outputs": [], @@ -851,7 +926,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 22, "id": "aac5b5ae-7c0d-48d6-800c-d95934b3d03f", "metadata": {}, "outputs": [ @@ -859,7 +934,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Test accuracy: 0.8182894736842106\n" + "Test accuracy: 0.8176315789473684\n" ] } ], @@ -888,7 +963,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 23, "id": "609f169a-30ed-480e-8ab8-ae5af2c0a84f", "metadata": {}, "outputs": [ @@ -900,12 +975,12 @@ "\n", " Sports 0.86 0.98 0.91 1900\n", " Sci/Tech 0.76 0.84 0.80 1900\n", - " World 0.80 0.89 0.84 1900\n", - " Business 0.88 0.57 0.69 1900\n", + " World 0.79 0.89 0.84 1900\n", + " Business 0.89 0.56 0.69 1900\n", "\n", " accuracy 0.82 7600\n", - " macro avg 0.82 0.82 0.81 7600\n", - "weighted avg 0.82 0.82 0.81 7600\n", + " macro avg 0.83 0.82 0.81 7600\n", + "weighted avg 0.83 0.82 0.81 7600\n", "\n" ] } @@ -968,14 +1043,12 @@ "id": "f0fc6a66-5e4b-4c84-b8ea-9f986339fd55", "metadata": {}, "source": [ - "For some use cases, you might want to use Python for defining labeling rules and generating weak labels. Argilla provides you with the ability to define and test rules and labeling functions directly using Python. This might be useful for combining it with rules defined in the UI, and for leveraging structured resources such as lexicons and gazeteers which are easier to use directly a programmatic environment.\n", - "\n", - "In this section, we define the rules we've defined in the UI, this time directly using Python:" + "For some use cases, you might want to use Python for defining labeling rules and generating weak labels. Argilla provides you with the ability to define and test rules and labeling functions directly using Python. This might be useful for combining it with rules defined in the UI, and for leveraging structured resources such as lexicons and gazeteers which are easier to use directly from a programmatic environment. For ES query rules you can decide wether to add them to the Argilla dataset so they are available from the UI, or just apply them locally to build the final weakly labelled dataset. Not adding the rules to the server is recommended for testing purposes but if you trust your heuristics we recommend using the `add_rules` method to make them accessible from the UI." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "id": "1d264315-83bf-40e4-895f-44b4b82cae3b", "metadata": {}, "outputs": [], @@ -994,6 +1067,19 @@ "rules = [Rule(query=term, label=label) for terms, label in queries for term in terms]\n" ] }, + { + "cell_type": "code", + "execution_count": 26, + "id": "0f8c6d97", + "metadata": {}, + "outputs": [], + "source": [ + "from argilla.labeling.text_classification import add_rules\n", + "\n", + "# add rules to the dataset\n", + "add_rules(dataset=\"news\", rules=rules)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1001,10 +1087,57 @@ "metadata": {}, "outputs": [], "source": [ - "from argilla.labeling.text_classification import WeakLabels\n", + "from argilla.labeling.text_classification import WeakLabels, add_rules\n", "\n", "# generate the weak labels\n", - "weak_labels = WeakLabels(rules=rules, dataset=\"news\")\n" + "weak_labels = WeakLabels(dataset=\"news\")" + ] + }, + { + "cell_type": "markdown", + "id": "bbd4b2fa", + "metadata": {}, + "source": [ + "If you want to apply the rules without adding to dataset it's also possible as follows" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "3077cc68", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "953448ef8fb44f25ac972901cc76c17f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Preparing rules: 0%| | 0/16 [00:00\n", " money\n", " {Business}\n", - " 0.008276\n", + " 0.008268\n", " 0.008816\n", - " 0.002437\n", - " 0.001936\n", + " 0.002484\n", + " 0.001983\n", " 30\n", " 37\n", " 0.447761\n", @@ -1072,8 +1205,8 @@ " {Business}\n", " 0.019655\n", " 0.017763\n", - " 0.005893\n", - " 0.005188\n", + " 0.005933\n", + " 0.005227\n", " 80\n", " 55\n", " 0.592593\n", @@ -1083,8 +1216,8 @@ " {Business}\n", " 0.016591\n", " 0.016316\n", - " 0.003542\n", - " 0.002908\n", + " 0.003582\n", + " 0.002947\n", " 87\n", " 37\n", " 0.701613\n", @@ -1092,21 +1225,21 @@ " \n", " war\n", " {World}\n", - " 0.011779\n", - " 0.013289\n", - " 0.003213\n", - " 0.001348\n", - " 75\n", - " 26\n", - " 0.742574\n", + " 0.015627\n", + " 0.017105\n", + " 0.004459\n", + " 0.001732\n", + " 101\n", + " 29\n", + " 0.776923\n", " \n", " \n", " gov*\n", " {World}\n", - " 0.045078\n", + " 0.045086\n", " 0.045263\n", - " 0.010878\n", - " 0.006270\n", + " 0.011191\n", + " 0.006277\n", " 170\n", " 174\n", " 0.494186\n", @@ -1116,7 +1249,7 @@ " {World}\n", " 0.030031\n", " 0.028289\n", - " 0.007531\n", + " 0.007908\n", " 0.002821\n", " 193\n", " 22\n", @@ -1125,21 +1258,21 @@ " \n", " conflict\n", " {World}\n", - " 0.003041\n", - " 0.002895\n", - " 0.001003\n", + " 0.003025\n", + " 0.002763\n", + " 0.001097\n", " 0.000102\n", - " 18\n", + " 17\n", " 4\n", - " 0.818182\n", + " 0.809524\n", " \n", " \n", " footbal*\n", " {Sports}\n", - " 0.013166\n", + " 0.013158\n", " 0.015000\n", - " 0.004945\n", - " 0.000439\n", + " 0.004953\n", + " 0.000447\n", " 107\n", " 7\n", " 0.938596\n", @@ -1149,7 +1282,7 @@ " {Sports}\n", " 0.021191\n", " 0.021316\n", - " 0.007045\n", + " 0.007038\n", " 0.001223\n", " 139\n", " 23\n", @@ -1158,21 +1291,21 @@ " \n", " game\n", " {Sports}\n", - " 0.038879\n", - " 0.037763\n", - " 0.014083\n", - " 0.002375\n", + " 0.038738\n", + " 0.037632\n", + " 0.014060\n", + " 0.002390\n", " 216\n", - " 71\n", - " 0.752613\n", + " 70\n", + " 0.755245\n", " \n", " \n", " play*\n", " {Sports}\n", " 0.052453\n", " 0.050000\n", - " 0.016889\n", - " 0.005063\n", + " 0.016991\n", + " 
0.005196\n", " 268\n", " 112\n", " 0.705263\n", @@ -1182,8 +1315,8 @@ " {Sci/Tech}\n", " 0.016552\n", " 0.018421\n", - " 0.002735\n", - " 0.001309\n", + " 0.002782\n", + " 0.001340\n", " 114\n", " 26\n", " 0.814286\n", @@ -1191,10 +1324,10 @@ " \n", " techno*\n", " {Sci/Tech}\n", - " 0.027218\n", + " 0.027210\n", " 0.028289\n", - " 0.008433\n", - " 0.003174\n", + " 0.008534\n", + " 0.003205\n", " 155\n", " 60\n", " 0.720930\n", @@ -1202,46 +1335,46 @@ " \n", " computer*\n", " {Sci/Tech}\n", - " 0.027320\n", - " 0.028026\n", - " 0.011058\n", - " 0.004459\n", + " 0.027586\n", + " 0.028158\n", + " 0.011277\n", + " 0.004514\n", " 159\n", - " 54\n", - " 0.746479\n", + " 55\n", + " 0.742991\n", " \n", " \n", " software\n", " {Sci/Tech}\n", - " 0.030243\n", - " 0.029605\n", - " 0.009655\n", - " 0.003346\n", - " 184\n", + " 0.030188\n", + " 0.029474\n", + " 0.009828\n", + " 0.003378\n", + " 183\n", " 41\n", - " 0.817778\n", + " 0.816964\n", " \n", " \n", " web\n", " {Sci/Tech}\n", - " 0.015376\n", - " 0.013289\n", - " 0.004067\n", - " 0.001607\n", - " 76\n", + " 0.017132\n", + " 0.014737\n", + " 0.004561\n", + " 0.001779\n", + " 87\n", " 25\n", - " 0.752475\n", + " 0.776786\n", " \n", " \n", " total\n", - " {World, Sports, Business, Sci/Tech}\n", - " 0.317022\n", - " 0.311447\n", - " 0.053582\n", - " 0.019561\n", - " 2071\n", - " 774\n", - " 0.727944\n", + " {World, Sci/Tech, Business, Sports}\n", + " 0.320964\n", + " 0.315000\n", + " 0.055149\n", + " 0.020039\n", + " 2106\n", + " 777\n", + " 0.730489\n", " \n", " \n", "\n", @@ -1249,45 +1382,45 @@ ], "text/plain": [ " label coverage annotated_coverage \\\n", - "money {Business} 0.008276 0.008816 \n", + "money {Business} 0.008268 0.008816 \n", "financ* {Business} 0.019655 0.017763 \n", "dollar* {Business} 0.016591 0.016316 \n", - "war {World} 0.011779 0.013289 \n", - "gov* {World} 0.045078 0.045263 \n", + "war {World} 0.015627 0.017105 \n", + "gov* {World} 0.045086 0.045263 \n", "minister* {World} 0.030031 0.028289 \n", - "conflict {World} 0.003041 0.002895 \n", - "footbal* {Sports} 0.013166 0.015000 \n", + "conflict {World} 0.003025 0.002763 \n", + "footbal* {Sports} 0.013158 0.015000 \n", "sport* {Sports} 0.021191 0.021316 \n", - "game {Sports} 0.038879 0.037763 \n", + "game {Sports} 0.038738 0.037632 \n", "play* {Sports} 0.052453 0.050000 \n", "sci* {Sci/Tech} 0.016552 0.018421 \n", - "techno* {Sci/Tech} 0.027218 0.028289 \n", - "computer* {Sci/Tech} 0.027320 0.028026 \n", - "software {Sci/Tech} 0.030243 0.029605 \n", - "web {Sci/Tech} 0.015376 0.013289 \n", - "total {World, Sports, Business, Sci/Tech} 0.317022 0.311447 \n", + "techno* {Sci/Tech} 0.027210 0.028289 \n", + "computer* {Sci/Tech} 0.027586 0.028158 \n", + "software {Sci/Tech} 0.030188 0.029474 \n", + "web {Sci/Tech} 0.017132 0.014737 \n", + "total {World, Sci/Tech, Business, Sports} 0.320964 0.315000 \n", "\n", " overlaps conflicts correct incorrect precision \n", - "money 0.002437 0.001936 30 37 0.447761 \n", - "financ* 0.005893 0.005188 80 55 0.592593 \n", - "dollar* 0.003542 0.002908 87 37 0.701613 \n", - "war 0.003213 0.001348 75 26 0.742574 \n", - "gov* 0.010878 0.006270 170 174 0.494186 \n", - "minister* 0.007531 0.002821 193 22 0.897674 \n", - "conflict 0.001003 0.000102 18 4 0.818182 \n", - "footbal* 0.004945 0.000439 107 7 0.938596 \n", - "sport* 0.007045 0.001223 139 23 0.858025 \n", - "game 0.014083 0.002375 216 71 0.752613 \n", - "play* 0.016889 0.005063 268 112 0.705263 \n", - "sci* 0.002735 0.001309 114 26 0.814286 \n", - "techno* 0.008433 0.003174 155 60 0.720930 \n", 
- "computer* 0.011058 0.004459 159 54 0.746479 \n", - "software 0.009655 0.003346 184 41 0.817778 \n", - "web 0.004067 0.001607 76 25 0.752475 \n", - "total 0.053582 0.019561 2071 774 0.727944 " + "money 0.002484 0.001983 30 37 0.447761 \n", + "financ* 0.005933 0.005227 80 55 0.592593 \n", + "dollar* 0.003582 0.002947 87 37 0.701613 \n", + "war 0.004459 0.001732 101 29 0.776923 \n", + "gov* 0.011191 0.006277 170 174 0.494186 \n", + "minister* 0.007908 0.002821 193 22 0.897674 \n", + "conflict 0.001097 0.000102 17 4 0.809524 \n", + "footbal* 0.004953 0.000447 107 7 0.938596 \n", + "sport* 0.007038 0.001223 139 23 0.858025 \n", + "game 0.014060 0.002390 216 70 0.755245 \n", + "play* 0.016991 0.005196 268 112 0.705263 \n", + "sci* 0.002782 0.001340 114 26 0.814286 \n", + "techno* 0.008534 0.003205 155 60 0.720930 \n", + "computer* 0.011277 0.004514 159 55 0.742991 \n", + "software 0.009828 0.003378 183 41 0.816964 \n", + "web 0.004561 0.001779 87 25 0.776786 \n", + "total 0.055149 0.020039 2106 777 0.730489 " ] }, - "execution_count": 78, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1340,7 +1473,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.10.6 ('argillaghost-Q0BOJQE3-py3.10')", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1354,11 +1487,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.8.5" }, "vscode": { "interpreter": { - "hash": "03f665a1a0825a92bc240e01069a3e92c7de5b4cc92abab3eda948bbc50d6f03" + "hash": "83e13ff0de9ea08cace169d1016bf08ce368842307fd88824f08736a0a9ca04b" } } }, diff --git a/src/argilla/client/api.py b/src/argilla/client/api.py index 514b9cafcf..73c5d5eeb7 100644 --- a/src/argilla/client/api.py +++ b/src/argilla/client/api.py @@ -608,6 +608,29 @@ def compute_metric( return MetricResults(**metric_.dict(), results=response.parsed) + def add_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]): + """Adds the dataset labeling rules""" + for rule in rules: + text_classification_api.add_dataset_labeling_rule( + self._client, + name=dataset, + rule=rule, + ) + + def update_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]): + """Updates the dataset labeling rules""" + for rule in rules: + text_classification_api.update_dataset_labeling_rule( + self._client, name=dataset, rule=rule + ) + + def delete_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]): + """Deletes the dataset labeling rules""" + for rule in rules: + text_classification_api.delete_dataset_labeling_rule( + self._client, name=dataset, rule=rule + ) + def fetch_dataset_labeling_rules(self, dataset: str) -> List[LabelingRule]: response = text_classification_api.fetch_dataset_labeling_rules( self._client, name=dataset diff --git a/src/argilla/client/sdk/text_classification/api.py b/src/argilla/client/sdk/text_classification/api.py index fb821bbe26..024997cf69 100644 --- a/src/argilla/client/sdk/text_classification/api.py +++ b/src/argilla/client/sdk/text_classification/api.py @@ -58,6 +58,63 @@ def data( ) +def add_dataset_labeling_rule( + client: AuthenticatedClient, + name: str, + rule: LabelingRule, +) -> Response[Union[LabelingRule, HTTPValidationError, ErrorMessage]]: + url = "{}/api/datasets/{name}/TextClassification/labeling/rules".format( + client.base_url, name=name + ) + + response = httpx.post( + url=url, + json={"query": rule.query, "labels": rule.labels}, + headers=client.get_headers(), + 
cookies=client.get_cookies(),
+        timeout=client.get_timeout(),
+    )
+
+    return build_typed_response(response, LabelingRule)
+
+
+def update_dataset_labeling_rule(
+    client: AuthenticatedClient,
+    name: str,
+    rule: LabelingRule,
+) -> Response[Union[LabelingRule, HTTPValidationError, ErrorMessage]]:
+    url = "{}/api/datasets/TextClassification/{name}/labeling/rules/{query}".format(
+        client.base_url, name=name, query=rule.query
+    )
+
+    response = httpx.patch(
+        url,
+        json={"labels": rule.labels},
+        headers=client.get_headers(),
+        cookies=client.get_cookies(),
+        timeout=client.get_timeout(),
+    )
+
+    return build_typed_response(response, LabelingRule)
+
+
+def delete_dataset_labeling_rule(
+    client: AuthenticatedClient,
+    name: str,
+    rule: LabelingRule,
+) -> Response[Union[LabelingRule, HTTPValidationError, ErrorMessage]]:
+    url = "{}/api/datasets/TextClassification/{name}/labeling/rules/{query}".format(
+        client.base_url, name=name, query=rule.query
+    )
+
+    response = httpx.delete(
+        url,
+        headers=client.get_headers(),
+        cookies=client.get_cookies(),
+        timeout=client.get_timeout(),
+    )
+
+    return build_typed_response(response, LabelingRule)
+
+
 def fetch_dataset_labeling_rules(
     client: AuthenticatedClient,
     name: str,
diff --git a/src/argilla/client/sdk/text_classification/models.py b/src/argilla/client/sdk/text_classification/models.py
index 9aa231bf58..e1e4e77e99 100644
--- a/src/argilla/client/sdk/text_classification/models.py
+++ b/src/argilla/client/sdk/text_classification/models.py
@@ -186,7 +186,7 @@ class LabelingRule(BaseModel):
     labels: List[str] = Field(default_factory=list)
     query: str
     description: Optional[str] = None
-    author: str
+    author: Optional[str] = None
     created_at: datetime = None
 
diff --git a/src/argilla/labeling/text_classification/__init__.py b/src/argilla/labeling/text_classification/__init__.py
index 55f0554b9c..5a6e9969a4 100644
--- a/src/argilla/labeling/text_classification/__init__.py
+++ b/src/argilla/labeling/text_classification/__init__.py
@@ -15,5 +15,5 @@
 from .label_errors import find_label_errors
 from .label_models import FlyingSquid, MajorityVoter, Snorkel
-from .rule import Rule, load_rules
+from .rule import Rule, add_rules, delete_rules, load_rules, update_rules
 from .weak_labels import WeakLabels, WeakMultiLabels
diff --git a/src/argilla/labeling/text_classification/rule.py b/src/argilla/labeling/text_classification/rule.py
index da4b1e3e8e..b9d6160313 100644
--- a/src/argilla/labeling/text_classification/rule.py
+++ b/src/argilla/labeling/text_classification/rule.py
@@ -62,6 +62,10 @@ def label(self) -> Union[str, List[str]]:
         """The rule label"""
         return self._label
 
+    @label.setter
+    def label(self, value):
+        self._label = value
+
     @property
     def name(self):
         """The name of the rule."""
@@ -74,6 +78,34 @@ def author(self):
         """Who authored the rule."""
         return self._author
 
+    def _convert_to_labeling_rule(self):
+        """Converts the rule to a LabelingRule, normalizing the label to a list"""
+        if isinstance(self._label, str):
+            labels = [self._label]
+        else:
+            labels = self._label
+
+        return LabelingRule(query=self.query, labels=labels)
+
+    def add_to_dataset(self, dataset: str):
+        """Adds the rule to the given dataset"""
+        api.active_api().add_dataset_labeling_rules(
+            dataset, rules=[self._convert_to_labeling_rule()]
+        )
+
+    def remove_from_dataset(self, dataset: str):
+        """Removes the rule from the given dataset"""
+        api.active_api().delete_dataset_labeling_rules(
+            dataset, rules=[self._convert_to_labeling_rule()]
+        )
+
+    def update_at_dataset(self, dataset: str):
+        """Updates the rule at the given dataset"""
+        api.active_api().update_dataset_labeling_rules(
+            dataset,
rules=[self._convert_to_labeling_rule()] + ) + def apply(self, dataset: str): """Apply the rule to a dataset and save matching ids of the records. @@ -101,9 +133,7 @@ def metrics(self, dataset: str) -> Dict[str, Union[int, float]]: """ metrics = api.active_api().rule_metrics_for_dataset( dataset=dataset, - rule=LabelingRule( - query=self.query, label=self.label, author=self.author or "None" - ), + rule=LabelingRule(query=self.query, label=self.label), ) return { @@ -143,6 +173,45 @@ def __call__( return self._label +def add_rules(dataset: str, rules: List[Rule]): + """Adds the rules to a given dataset + + Args: + dataset: Name of the dataset. + rules: Rules to add to the dataset + + Returns: + """ + rules = [rule._convert_to_labeling_rule() for rule in rules] + return api.active_api().add_dataset_labeling_rules(dataset, rules) + + +def delete_rules(dataset: str, rules: List[Rule]): + """Deletes the rules from the given dataset + + Args: + dataset: Name of the dataset + rules: Rules to delete from the dataset + + Returns: + """ + rules = [rule._convert_to_labeling_rule() for rule in rules] + api.active_api().delete_dataset_labeling_rules(dataset, rules) + + +def update_rules(dataset: str, rules: List[Rule]): + """Updates the rules of the given dataset + + Args: + dataset: Name of the dataset + rules: Rules to update at the dataset + + Returns: + """ + rules = [rule._convert_to_labeling_rule() for rule in rules] + api.active_api().update_dataset_labeling_rules(dataset, rules) + + def load_rules(dataset: str) -> List[Rule]: """load the rules defined in a given dataset. diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py index 857281a1fb..e169ff630d 100644 --- a/tests/client/sdk/conftest.py +++ b/tests/client/sdk/conftest.py @@ -71,14 +71,17 @@ def check_schema_props(client_props, server_props): return len(different_props) < len(client_props) / 2 client_props = self._expands_schema( - client_schema["properties"], client_schema["definitions"] + client_schema["properties"], + client_schema.get("definitions", {}), ) server_props = self._expands_schema( - server_schema["properties"], server_schema["definitions"] + server_schema["properties"], + server_schema.get("definitions", {}), ) if client_props == server_props: return True + return check_schema_props(client_props, server_props) def _expands_schema( diff --git a/tests/client/sdk/text_classification/test_models.py b/tests/client/sdk/text_classification/test_models.py index 4b5bfbea8f..24bee1eed1 100644 --- a/tests/client/sdk/text_classification/test_models.py +++ b/tests/client/sdk/text_classification/test_models.py @@ -62,9 +62,7 @@ def test_labeling_rule_schema(helpers): client_schema = LabelingRule.schema() server_schema = ServerLabelingRule.schema() - assert helpers.remove_description(client_schema) == helpers.remove_description( - server_schema - ) + assert helpers.are_compatible_api_schemas(client_schema, server_schema) def test_labeling_rule_metrics_schema(helpers): diff --git a/tests/conftest.py b/tests/conftest.py index 06afc575f7..862d482b3b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,7 @@ def whoami_mocked(client): monkeypatch.setattr(users_api, "whoami", whoami_mocked) monkeypatch.setattr(httpx, "post", client_.post) + monkeypatch.setattr(httpx, "patch", client_.patch) monkeypatch.setattr(httpx.AsyncClient, "post", client_.post_async) monkeypatch.setattr(httpx, "get", client_.get) monkeypatch.setattr(httpx, "delete", client_.delete) diff --git 
a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py index e9f72aa22b..2b2b3659c5 100644 --- a/tests/labeling/text_classification/test_rule.py +++ b/tests/labeling/text_classification/test_rule.py @@ -21,7 +21,13 @@ CreationTextClassificationRecord, TextClassificationBulkData, ) -from argilla.labeling.text_classification import Rule, load_rules +from argilla.labeling.text_classification import ( + Rule, + add_rules, + delete_rules, + load_rules, + update_rules, +) from argilla.labeling.text_classification.rule import RuleNotAppliedError from argilla.server.errors import EntityNotFoundError @@ -86,6 +92,40 @@ def test_name(name, expected): assert rule.name == expected +def test_atomic_crud_operations(monkeypatch, mocked_client, log_dataset): + rule = Rule(query="inputs.text:(NOT positive)", label="negative") + with pytest.raises(RuleNotAppliedError): + rule(TextClassificationRecord(text="test")) + + monkeypatch.setattr(httpx, "get", mocked_client.get) + monkeypatch.setattr(httpx, "patch", mocked_client.patch) + monkeypatch.setattr(httpx, "delete", mocked_client.delete) + monkeypatch.setattr(httpx, "post", mocked_client.post) + monkeypatch.setattr(httpx, "stream", mocked_client.stream) + + rule.add_to_dataset(log_dataset) + + rules = load_rules(log_dataset) + assert len(rules) == 1 + assert rules[0].query == "inputs.text:(NOT positive)" + assert rules[0].label == "negative" + + rule.remove_from_dataset(log_dataset) + + rules = load_rules(log_dataset) + assert len(rules) == 0 + + rule = Rule(query="inputs.text:(NOT positive)", label="negative") + rule.add_to_dataset(log_dataset) + rule.label = "positive" + rule.update_at_dataset(log_dataset) + + rules = load_rules(log_dataset) + assert len(rules) == 1 + assert rules[0].query == "inputs.text:(NOT positive)" + assert rules[0].label == "positive" + + def test_apply(monkeypatch, mocked_client, log_dataset): rule = Rule(query="inputs.text:(NOT positive)", label="negative") with pytest.raises(RuleNotAppliedError): @@ -123,6 +163,76 @@ def test_load_rules(mocked_client, log_dataset): assert rules[0].label == "LALA" +def test_add_rules(mocked_client, log_dataset): + + expected_rules = [ + Rule(query="a query", label="La La"), + Rule(query="another query", label="La La"), + Rule(query="the other query", label="La La La"), + ] + + add_rules(log_dataset, expected_rules) + + actual_rules = load_rules(log_dataset) + + assert len(actual_rules) == 3 + for actual_rule, expected_rule in zip(actual_rules, expected_rules): + assert actual_rule.query == expected_rule.query + assert actual_rule.label == expected_rule.label + + +def test_delete_rules(mocked_client, log_dataset): + + rules = [ + Rule(query="a query", label="La La"), + Rule(query="another query", label="La La"), + Rule(query="the other query", label="La La La"), + ] + + add_rules(log_dataset, rules) + + delete_rules( + log_dataset, + [ + Rule(query="a query", label="La La"), + ], + ) + + actual_rules = load_rules(log_dataset) + + assert len(actual_rules) == 2 + + for actual_rule, expected_rule in zip(actual_rules, rules[1:]): + assert actual_rule.label == expected_rule.label + assert actual_rule.query == expected_rule.query + + +def test_update_rules(mocked_client, log_dataset): + + rules = [ + Rule(query="a query", label="La La"), + Rule(query="another query", label="La La"), + Rule(query="the other query", label="La La La"), + ] + + add_rules(log_dataset, rules) + rules_to_update = [ + Rule(query="a query", label="La La La"), + ] + 
update_rules(log_dataset, rules=rules_to_update)
+
+    actual_rules = load_rules(log_dataset)
+
+    assert len(actual_rules) == 3
+
+    assert actual_rules[0].query == "a query"
+    assert actual_rules[0].label == "La La La"
+
+    for actual_rule, expected_rule in zip(actual_rules[1:], rules[1:]):
+        assert actual_rule.label == expected_rule.label
+        assert actual_rule.query == expected_rule.query
+
+
 def test_copy_dataset_with_rules(mocked_client, log_dataset):
     import argilla as ar