# 👮 Weak Supervision


This guide gives you a brief introduction to weak supervision with Argilla.

Argilla currently supports weak supervision for multi-class text classification use cases, but we'll be adding support for multilabel text classification and token classification (e.g., Named Entity Recognition) soon.

![Labeling workflow](../../_static/images/guides/weak_supervision/weak_supervision.png "Labeling workflow")

## Argilla weak supervision in a nutshell

The recommended workflow for weak supervision is:

- Log an unlabelled dataset into Argilla
- Use the `Annotate` mode for hand- and/or bulk-labelling a test set. This test set is key to measuring the quality and performance of your rules.
- Use the `Define rules` mode for testing and defining rules. Rules are defined with search queries (using the Elasticsearch query string DSL).
- Use the Python client for reading rules, defining additional rules if needed, and training a label model (for building a training set) or a downstream model (for building an end classifier).

The next sections cover the main components of this workflow. 

### Weak labeling using the UI

Since version 0.8.0 you can find and define rules directly in the UI. 
The [Define rules mode](../../reference/webapp/pages.html#metrics) is found in the right sidebar of the [Dataset page](../../reference/webapp/pages.html#dataset).
The video below shows how you can interactively find and save rules with the UI. 

### Weak supervision from Python

Doing weak supervision with Argilla should be straightforward. In the same spirit as other parts of the library, you can use virtually any weak supervision library or method, such as Snorkel or FlyingSquid.

Argilla weak supervision support is built around two basic abstractions:


### `Rule`
A rule encodes a heuristic for labeling a record.

Heuristics can be defined using [Elasticsearch's queries](../../reference/webapp/features.html#search-records):

```python
from argilla.labeling.text_classification import Rule

plz = Rule(query="plz OR please", label="SPAM")
```

or with Python functions (similar to Snorkel's labeling functions, which you can use as well):

```python
from typing import Optional

import argilla as rg


def contains_http(record: rg.TextClassificationRecord) -> Optional[str]:
    if "http" in record.inputs["text"]:
        return "SPAM"
    return None  # abstain
```

Besides textual features, Python labeling functions can exploit metadata features:

```python
def author_channel(record: rg.TextClassificationRecord) -> Optional[str]:
    # the word channel appears in the comment author name
    if "channel" in record.metadata["author"]:
        return "SPAM"
```

A rule should return either a string value, that is, a weak label, or `None` in case of abstention.


### `Weak Labels`

`WeakLabels` objects bundle and apply a set of rules to the records of an Argilla dataset. Applying a rule to a record means assigning a weak label or abstaining.

This abstraction provides you with the building blocks for training and testing weak supervision "denoising", "label" or even "end" models:

```python
from argilla.labeling.text_classification import WeakLabels

rules = [contains_http, author_channel]
weak_labels = WeakLabels(
    rules=rules, 
    dataset="weak_supervision_yt"
)

# returns a summary of the applied rules
weak_labels.summary()
```
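Conceptually, applying the rules yields a weak label matrix with one row per record and one column per rule. The following pure-Python sketch builds such a matrix for two records from the example dataset used below. It does not use Argilla. The labeling functions take plain strings instead of records, and encoding an abstention as `-1` is an assumption borrowed from Snorkel's convention, used here only for illustration.

```python
from typing import Callable, Dict, List, Optional

# A labeling function maps a text to a label, or to None to abstain
LabelingFunction = Callable[[str], Optional[str]]

ABSTAIN = -1  # assumed integer code for "no label"


def contains_http(text: str) -> Optional[str]:
    return "SPAM" if "http" in text else None


def short_comment(text: str) -> Optional[str]:
    return "HAM" if len(text.split()) < 5 else None


def build_weak_label_matrix(
    texts: List[str], rules: List[LabelingFunction], label2int: Dict[str, int]
) -> List[List[int]]:
    """Apply every rule to every text: rows are records, columns are rules."""
    matrix = []
    for text in texts:
        row = []
        for rule in rules:
            label = rule(text)
            row.append(ABSTAIN if label is None else label2int[label])
        matrix.append(row)
    return matrix


texts = [
    "pls http://www10.vakinha.com.br/VaquinhaE.aspx",
    "go here to check the views :3",
]
matrix = build_weak_label_matrix(texts, [contains_http, short_comment], {"HAM": 0, "SPAM": 1})
print(matrix)  # [[1, 0], [-1, -1]]
```

The first record shows a conflict, because one rule votes SPAM and the other votes HAM. Both rules abstain on the second record. Label models exist to resolve exactly these situations.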

More information about these abstractions can be found in [the Python Labeling module docs](../../reference/python/python_labeling.rst).

## Built-in label models

To make things even easier for you, we provide wrapper classes around the most common label models, which directly consume a `WeakLabels` object.
This makes working with those models a breeze.
Take a look at the list of built-in models in the [labeling module docs](../../reference/python/python_labeling.rst).


## Detailed Workflow

A typical workflow to use weak supervision is:

1. Create an Argilla dataset with your raw dataset. If you already have some labelled data, you can log it into the same dataset.
2. Define a set of weak labeling rules with the Define rules mode in the UI.
3. Create a `WeakLabels` object and apply the rules. You can load the rules from your dataset and add additional rules and labeling functions using Python. Typically, you'll iterate between this step and step 2.
4. Once you are satisfied with your weak labels, use the matrix of the `WeakLabels` instance with your library/method of choice to build a training set or even train a downstream text classification model.


This guide shows you an end-to-end example using Snorkel, Flyingsquid and Weasel. Let's get started!

## Example dataset

We'll be using a well-known dataset for weak supervision examples, the [YouTube Spam Collection](http://www.dt.fee.unicamp.br/~tiago//youtubespamcollection/) dataset, which is a binary classification task for detecting spam comments in YouTube videos.

```python
import pandas as pd

# load data
train_df = pd.read_csv("../../tutorials/notebooks/data/yt_comments_train.csv")
test_df = pd.read_csv("../../tutorials/notebooks/data/yt_comments_test.csv")

# preview data
train_df.head()
```

| | Unnamed: 0 | author | date | text | label | video |
|---|---|---|---|---|---|---|
| 0 | 0 | Alessandro leite | 2014-11-05T22:21:36 | pls http://www10.vakinha.com.br/VaquinhaE.aspx... | -1.0 | 1 |
| 1 | 1 | Salim Tayara | 2014-11-02T14:33:30 | if your like drones, plz subscribe to Kamal Ta... | -1.0 | 1 |
| 2 | 2 | Phuc Ly | 2014-01-20T15:27:47 | go here to check the views :3 | -1.0 | 1 |
| 3 | 3 | DropShotSk8r | 2014-01-19T04:27:18 | Came here to check the views, goodbye. | -1.0 | 1 |
| 4 | 4 | css403 | 2014-11-07T14:25:48 | i am 2,126,492,636 viewer :D | -1.0 | 1 |


## 1. Create an Argilla dataset with unlabelled data and test data

Let's load the train (non-labelled) and the test (labelled) datasets.

```python
import argilla as rg

# build records from the train dataset
records = [
    rg.TextClassificationRecord(
        text=row.text, metadata={"video": row.video, "author": row.author}
    )
    for i, row in train_df.iterrows()
]

# build records from the test dataset with annotation
labels = ["HAM", "SPAM"]
records += [
    rg.TextClassificationRecord(
        text=row.text,
        annotation=labels[row.label],
        metadata={"video": row.video, "author": row.author},
    )
    for i, row in test_df.iterrows()
]

# log records to Argilla
rg.log(records, name="weak_supervision_yt")
```

```
1836 records logged to http://localhost:6900/datasets/argilla/weak_supervision_yt
BulkResponse(dataset='weak_supervision_yt', processed=1836, failed=0)
```

After this step, you have a fully browsable dataset available that you can access via the [Argilla web app](../../reference/webapp/index.md).

## 2. Defining rules

Let's now define some of the rules proposed in the tutorial [Snorkel Intro Tutorial: Data Labeling](https://www.snorkel.org/use-cases/01-spam-tutorial). 
Most of these rules can be defined directly with our web app in the [Define rules mode](../../reference/webapp/define_rules.md) and [Elasticsearch's query strings](../../reference/webapp/features.html#search-records). 
Afterward, you can conveniently load them into your notebook with the [load_rules function](../../reference/python/python_labeling.rst).

Rules can also be defined programmatically as shown below. Depending on your use case and team structure you can mix and match both interfaces (UI or Python).

Here are some programmatic rules:

```python
from argilla.labeling.text_classification import Rule, WeakLabels

# rules defined as Elasticsearch queries
check_out = Rule(query="check out", label="SPAM")
plz = Rule(query="plz OR please", label="SPAM")
subscribe = Rule(query="subscribe", label="SPAM")
my = Rule(query="my", label="SPAM")
song = Rule(query="song", label="HAM")
love = Rule(query="love", label="HAM")
```
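Elasticsearch evaluates these query strings against its analyzed index, so the exact matching behaviour depends on the index configuration. As a rough mental model only, the sketch below approximates a query made of single terms joined by `OR` with a lowercase word-token check in plain Python. `approx_or_query_matches` is a hypothetical helper. It ignores phrases, wildcards, stemming and the rest of the query string syntax, so it is not what Argilla or Elasticsearch actually run.

```python
import re
from typing import List


def tokenize(text: str) -> List[str]:
    # Lowercase and split on anything that is not a word character
    return re.findall(r"\w+", text.lower())


def approx_or_query_matches(query: str, text: str) -> bool:
    """Very rough approximation of a query such as "plz OR please"."""
    terms = [term.lower() for term in query.split(" OR ")]
    tokens = set(tokenize(text))
    return any(term in tokens for term in terms)


print(approx_or_query_matches("plz OR please", "if your like drones, plz subscribe"))  # True
print(approx_or_query_matches("plz OR please", "Came here to check the views, goodbye."))  # False
```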


You can also define plain Python labeling functions:

```python
import re


# rules defined as Python labeling functions
def contains_http(record: rg.TextClassificationRecord):
    if "http" in record.inputs["text"]:
        return "SPAM"


def short_comment(record: rg.TextClassificationRecord):
    return "HAM" if len(record.inputs["text"].split()) < 5 else None


def regex_check_out(record: rg.TextClassificationRecord):
    return (
        "SPAM" if re.search(r"check.*out", record.inputs["text"], flags=re.I) else None
    )
```
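Labeling functions are plain Python, so you can sanity-check them locally before applying them to the whole dataset. The sketch below uses `types.SimpleNamespace` as a hypothetical stand-in for `rg.TextClassificationRecord`. It mimics only the `inputs["text"]` attribute that the functions read. Two of the functions above are repeated so that the block runs on its own.

```python
import re
from types import SimpleNamespace
from typing import Optional


def short_comment(record) -> Optional[str]:
    return "HAM" if len(record.inputs["text"].split()) < 5 else None


def regex_check_out(record) -> Optional[str]:
    return "SPAM" if re.search(r"check.*out", record.inputs["text"], flags=re.I) else None


def fake_record(text: str) -> SimpleNamespace:
    # Hypothetical stand-in exposing only the attribute the functions use
    return SimpleNamespace(inputs={"text": text})


examples = {
    "Check out this video on YouTube:": regex_check_out,
    "awesome": short_comment,
    "i am 2,126,492,636 viewer :D": short_comment,  # exactly 5 tokens, so it abstains
}
for text, lf in examples.items():
    print(f"{lf.__name__}({text!r}) -> {lf(fake_record(text))}")
```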


You can also load predefined rules from a file, convert them to `Rule` instances, and add them to the dataset:

```python
labeling_rules_df = pd.read_csv("../../tutorials/notebooks/labeling_rules.csv")

# preview the predefined rules
labeling_rules_df.head()
```

| | Unnamed: 0 | query | label |
|---|---|---|---|
| 0 | 0 | your | SPAM |
| 1 | 1 | rich | SPAM |
| 2 | 2 | film | HAM |
| 3 | 3 | meeting | HAM |
| 4 | 4 | help | HAM |


```python
# convert each row of the CSV file into a Rule
predefined_labeling_rules = []
for index, row in labeling_rules_df.iterrows():
    predefined_labeling_rules.append(
        Rule(row["query"], row["label"])
    )
```

## 3. Building and analyzing weak labels

```python
from argilla.labeling.text_classification import load_rules, add_rules

# bundle our rules in a list
rules = [
    check_out,
    plz,
    subscribe,
    my,
    song,
    love
]

labeling_functions = [
    contains_http,
    short_comment,
    regex_check_out
]

# add rules to dataset
add_rules(dataset="weak_supervision_yt", rules=rules)

# add the predefined rules loaded from external file
add_rules(dataset="weak_supervision_yt", rules=predefined_labeling_rules)

# load all the rules available in the dataset, including those defined interactively in the UI
dataset_labeling_rules = load_rules(dataset="weak_supervision_yt")

# extend the labeling rules with labeling functions
dataset_labeling_rules.extend(labeling_functions)

# apply the final rules to the dataset
weak_labels = WeakLabels(dataset="weak_supervision_yt", rules=dataset_labeling_rules)
```

```python
# show some stats about the rules, see the `summary()` docstring for details
weak_labels.summary()
```

| rule | label | coverage | annotated_coverage | overlaps | conflicts | correct | incorrect | precision |
|---|---|---|---|---|---|---|---|---|
| check out | {SPAM} | 0.224401 | 0.176 | 0.224401 | 0.03159 | 88 | 0 | 1.0 |
| plz OR please | {SPAM} | 0.104575 | 0.088 | 0.098039 | 0.036492 | 44 | 0 | 1.0 |
| subscribe | {SPAM} | 0.101852 | 0.12 | 0.082244 | 0.03159 | 60 | 0 | 1.0 |
| my | {SPAM} | 0.19281 | 0.192 | 0.168845 | 0.062636 | 84 | 12 | 0.875 |
| song | {HAM} | 0.118192 | 0.172 | 0.070806 | 0.037037 | 68 | 18 | 0.790698 |
| love | {HAM} | 0.090959 | 0.14 | 0.071351 | 0.034858 | 56 | 14 | 0.8 |
| your | {SPAM} | 0.052832 | 0.088 | 0.041939 | 0.019608 | 38 | 6 | 0.863636 |
| rich | {SPAM} | 0.000545 | 0.0 | 0.0 | 0.0 | 0 | 0 | |
| film | {} | 0.0 | 0.0 | 0.0 | 0.0 | 0 | 0 | |
| meeting | {} | 0.0 | 0.0 | 0.0 | 0.0 | 0 | 0 | |
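To read the table, it helps to see how such statistics can be computed from a weak label matrix. The sketch below follows the usual definitions, as popularized by Snorkel's labeling-function analysis. It is an illustration only, and Argilla's `summary()` may differ in details. The `-1` abstention code is an assumption. As a consistency check against the table, the `my` rule has 84 correct and 12 incorrect votes, so its precision is 84 / 96 = 0.875.

```python
from typing import Dict, List, Optional

ABSTAIN = -1  # assumed integer code for an abstention


def rule_stats(
    matrix: List[List[int]], gold: List[Optional[int]], j: int
) -> Dict[str, float]:
    """Summary statistics for the rule in column j of the weak label matrix.

    gold[i] is the annotated label of record i, or None if it is unannotated.
    """
    n = len(matrix)
    fires = [row[j] != ABSTAIN for row in matrix]
    others = [[v for k, v in enumerate(row) if k != j] for row in matrix]
    # the rule and at least one other rule label the record
    overlaps = [f and any(v != ABSTAIN for v in o) for f, o in zip(fires, others)]
    # ... and at least one other rule assigns a different label
    conflicts = [
        f and any(v != ABSTAIN and v != row[j] for v in o)
        for f, o, row in zip(fires, others, matrix)
    ]
    annotated = [i for i in range(n) if gold[i] is not None]
    covered = [i for i in annotated if fires[i]]
    correct = sum(matrix[i][j] == gold[i] for i in covered)
    return {
        "coverage": sum(fires) / n,
        "annotated_coverage": len(covered) / len(annotated) if annotated else 0.0,
        "overlaps": sum(overlaps) / n,
        "conflicts": sum(conflicts) / n,
        "correct": correct,
        "incorrect": len(covered) - correct,
        # undefined (NaN) when the rule covers no annotated record
        "precision": correct / len(covered) if covered else float("nan"),
    }


# toy example: 4 records, 2 rules (0 = HAM, 1 = SPAM); only the first two are annotated
matrix = [[1, 0], [1, -1], [-1, 0], [-1, -1]]
gold = [1, 0, None, None]
print(rule_stats(matrix, gold, 0))
```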


You can remove uninformative rules from the dataset. Here, `rich` covers almost no records, and `film` and `meeting` cover none at all:

```python
from argilla.labeling.text_classification import delete_rules

not_informative_rules = [
    Rule("rich", "SPAM"),
    Rule("film", "HAM"),
    Rule("meeting", "HAM")
]
delete_rules(dataset="weak_supervision_yt", rules=not_informative_rules)
```

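You can also update a rule. The `help` rule labels records as HAM, but its row in the summary shows that it is wrong on every annotated record it covers:

| rule | label | coverage | annotated_coverage | overlaps | conflicts | correct | incorrect | precision |
|---|---|---|---|---|---|---|---|---|
| help | {HAM} | 0.027778 | 0.036 | 0.023965 | 0.023965 | 0 | 18 | 0.000000 |

So we change its label to SPAM:

```python
help_rule = Rule("help", label="SPAM")
help_rule.update_at_dataset(dataset="weak_supervision_yt")
```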
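In a binary task, the reasoning behind this update is simple arithmetic. A rule that always emits one label is correct exactly on the covered records whose gold label is that label. The `help` row reports 18 covered annotated records and none of them is HAM, which in this two-class task means all 18 are SPAM. The sketch below compares both candidate labels on those records. `precision_of_constant_rule` is a hypothetical helper, not part of Argilla.

```python
from typing import List


def precision_of_constant_rule(covered_gold: List[str], label: str) -> float:
    """Precision of a rule that always outputs `label` on the records it covers."""
    if not covered_gold:
        raise ValueError("rule covers no annotated records")
    return sum(gold == label for gold in covered_gold) / len(covered_gold)


covered_gold = ["SPAM"] * 18  # gold labels of the 18 annotated records `help` matches
print(precision_of_constant_rule(covered_gold, "HAM"))   # 0.0
print(precision_of_constant_rule(covered_gold, "SPAM"))  # 1.0
```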

Let's load the rules again and reapply them:

```python
final_rules = labeling_functions + load_rules(dataset="weak_supervision_yt")

weak_labels = WeakLabels(dataset="weak_supervision_yt", rules=final_rules)

weak_labels.summary()
```

| rule | label | coverage | annotated_coverage | overlaps | conflicts | correct | incorrect | precision |
|---|---|---|---|---|---|---|---|---|
| contains_http | {SPAM} | 0.106209 | 0.024 | 0.078431 | 0.04902 | 12 | 0 | 1.0 |
| short_comment | {HAM} | 0.245098 | 0.368 | 0.101307 | 0.06427 | 168 | 16 | 0.913043 |
| regex_check_out | {SPAM} | 0.22658 | 0.18 | 0.226035 | 0.027778 | 90 | 0 | 1.0 |
| check out | {SPAM} | 0.224401 | 0.176 | 0.224401 | 0.027778 | 88 | 0 | 1.0 |
| plz OR please | {SPAM} | 0.104575 | 0.088 | 0.098039 | 0.02342 | 44 | 0 | 1.0 |
| subscribe | {SPAM} | 0.101852 | 0.12 | 0.082244 | 0.025054 | 60 | 0 | 1.0 |
| my | {SPAM} | 0.19281 | 0.192 | 0.168845 | 0.050654 | 84 | 12 | 0.875 |
| song | {HAM} | 0.118192 | 0.172 | 0.070806 | 0.037037 | 68 | 18 | 0.790698 |
| love | {HAM} | 0.090959 | 0.14 | 0.071351 | 0.034858 | 56 | 14 | 0.8 |
| your | {SPAM} | 0.052832 | 0.088 | 0.041939 | 0.015795 | 38 | 6 | 0.863636 |


## 4. Using the weak labels

At this step you have at least two options:

1. Use the weak labels for training a "denoising" or label model to build a less noisy training set. Highly popular options for this are [Snorkel](https://snorkel.org/) or [Flyingsquid](https://github.com/HazyResearch/flyingsquid). After this step, you can train a downstream model with the "clean" labels.

2. Use the weak labels directly with recent "end-to-end" (e.g., [Weasel](https://github.com/autonlab/weasel)) or joint models (e.g., [COSINE](https://github.com/yueyu1030/COSINE)).


Let's see some examples:

### A simple majority vote

As a first example, we will show you how to use the `WeakLabels` object together with a simple majority vote model, which is arguably the most straightforward label model.
On a per-record basis, it simply counts the votes for each label returned by the rules, and takes the majority vote.
Argilla provides a neat implementation of this logic in its `MajorityVoter` class.
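The voting logic itself is short. Here is a minimal pure-Python sketch of a majority vote over one row of the weak label matrix. It assumes that `-1` marks an abstention, and it abstains on ties or when no rule fires. This mirrors the behaviour described above, but it is not `MajorityVoter`'s actual implementation.

```python
from collections import Counter
from typing import List, Optional

ABSTAIN = -1  # assumed integer code for "rule abstained"


def majority_vote(row: List[int]) -> Optional[int]:
    """Return the most voted label, or None on ties or when no rule fired."""
    votes = Counter(v for v in row if v != ABSTAIN)
    if not votes:
        return None
    ranked = votes.most_common()
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return None  # tie between the top labels
    return ranked[0][0]


print(majority_vote([1, 1, 0, -1]))  # 1
print(majority_vote([1, 0, -1]))     # None (tie)
print(majority_vote([-1, -1]))       # None (no votes)
```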

```python
from argilla.labeling.text_classification import MajorityVoter

# instantiate the majority vote label model by simply providing the weak labels object
majority_model = MajorityVoter(weak_labels)
```


In contrast to the other label models we will discuss further down, the majority voter does not need to be fitted.
You can directly check its performance by simply calling its `score()` method.

```python
# check its performance
print(majority_model.score(output_str=True))
```

```
              precision    recall  f1-score   support

         HAM       0.94      0.99      0.96       216
        SPAM       0.99      0.93      0.96       204

    accuracy                           0.96       420
   macro avg       0.96      0.96      0.96       420
weighted avg       0.96      0.96      0.96       420
```



An accuracy of 0.96 seems surprisingly high, but keep in mind that the evaluation simply excludes the records for which the model abstained (that is, a tie in the votes or no votes at all).
So let's account for this and correct the accuracy by assuming the model performs like a random classifier for these abstained records:

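> $\text{accuracy}_c = \text{frac}_{\text{non}} \times \text{accuracy} + \text{frac}_{\text{abs}} \times \text{accuracy}_{\text{random}}$

where $\text{frac}_{\text{non}}$ is the fraction of non-abstained records and $\text{frac}_{\text{abs}} = 1 - \text{frac}_{\text{non}}$ the fraction of abstained records.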
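The correction is simple enough to wrap in a helper. `corrected_accuracy` below is a hypothetical function, not part of Argilla. It takes the number of evaluated (non-abstained) records, which is the support in the classification report, and the total number of annotated records. The numbers in the usage line are illustrative.

```python
def corrected_accuracy(
    n_evaluated: int, n_annotated: int, accuracy: float, accuracy_random: float = 0.5
) -> float:
    """Blend the measured accuracy with a random baseline for the abstained records."""
    if not 0 < n_evaluated <= n_annotated:
        raise ValueError("n_evaluated must be in (0, n_annotated]")
    frac_non = n_evaluated / n_annotated
    frac_abs = 1 - frac_non
    return frac_non * accuracy + frac_abs * accuracy_random


# e.g. 420 of 500 annotated records evaluated at 0.96 accuracy, binary task
print(round(corrected_accuracy(420, 500, 0.96), 4))  # 0.8864
```

For a task with more than two classes, a uniform random classifier would have an accuracy of `1 / n_classes`, which you can pass as `accuracy_random`.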

```python
# calculate fractions using the support metric of the report above (420 evaluated records)
frac_non = 420 / len(weak_labels.annotation())
frac_abs = 1 - frac_non

# accuracy without abstentions: 0.96; accuracy of a random classifier: 0.5
print("accuracy_c:", frac_non * 0.96 + frac_abs * 0.5)
```

Here `len(weak_labels.annotation())` is 500, so this gives an accuracy of about 0.886.
As we will see further down, **an accuracy of 0.886** is still a very decent baseline.

<div class="alert alert-info">

Note

To get a noisy estimate of the corrected accuracy, you can also set the _"tie_break_policy"_ argument: `majority_model.score(..., tie_break_policy="random")`.
    
</div>

When predicting weak labels to train a downstream model, however, you probably want to discard the abstentions.
Calling the `predict()` method on the majority voter excludes the abstentions by default and only returns records without annotations.
These are normally used to build a training set for a downstream model.

You can quickly explore the predicted records with Argilla before building a training set for a downstream text classifier.
This step is useful for validation, manual revision, or defining score thresholds for accepting labels from your label model (for example, only considering labels with a score greater than 0.8).
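A score threshold like this could look like the sketch below. It assumes, as the `rec.prediction[0][0]` access in the following cell suggests, that each prediction is a list of `(label, score)` pairs. To avoid depending on their order, it picks the highest-scoring pair explicitly. The records here are plain tuples with made-up scores, not Argilla records, and `filter_by_score` is a hypothetical helper.

```python
from typing import List, Tuple

Prediction = List[Tuple[str, float]]


def filter_by_score(
    records: List[Tuple[str, Prediction]], threshold: float = 0.8
) -> List[Tuple[str, str]]:
    """Keep (text, label) pairs whose best predicted label scores above the threshold."""
    kept = []
    for text, prediction in records:
        label, score = max(prediction, key=lambda pair: pair[1])
        if score > threshold:
            kept.append((text, label))
    return kept


records = [  # made-up scores for illustration
    ("Subscribe to My CHANNEL", [("SPAM", 0.97), ("HAM", 0.03)]),
    ("this song is better then monster by eminem", [("HAM", 0.55), ("SPAM", 0.45)]),
]
print(filter_by_score(records))  # [('Subscribe to My CHANNEL', 'SPAM')]
```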

```python
# get your training records with the predictions of the label model
records_for_training = majority_model.predict()

# optional: log the records to a new dataset in Argilla
rg.log(records_for_training, name="majority_voter_results")

# extract training data
training_data = pd.DataFrame(
    [{"text": rec.text, "label": rec.prediction[0][0]} for rec in records_for_training]
)
```

```
2106 records logged to http://localhost:6900/datasets/argilla/majority_voter_results
```


```python
# preview training data
training_data
```

| | text | label |
|---|---|---|
| 0 | this song is better then monster by eminem | HAM |
| 1 | Hey guys! My mom said if i got 100 subs before... | SPAM |
| 2 | hi everyone this is cool check out sexy and i ... | SPAM |
| 3 | I love this shit but I disliked it because it'... | HAM |
| 4 | For all you ladies out there...... Check out ... | SPAM |
| ... | ... | ... |
| 2101 | awesome | HAM |
| 2102 | Good | HAM |
| 2103 | Check out this video on YouTube: | SPAM |
| 2104 | hey guys look im aware im spamming and it piss... | SPAM |


### Label model with Snorkel

Snorkel's label model is one of the most popular options for weak supervision, and Argilla provides built-in support for it.
Using Snorkel with Argilla's `WeakLabels` is as simple as:

```python
%pip install snorkel -qqq
```


```python
from argilla.labeling.text_classification import Snorkel

# we pass our WeakLabels instance to our Snorkel label model
snorkel_model = Snorkel(weak_labels)

# we fit the model
snorkel_model.fit(lr=0.001, n_epochs=50)
```


<div class="alert alert-info">

Note

The `Snorkel` label model is not suited for multi-label classification tasks and does not support them.
    
</div>

When fitting the Snorkel model, we recommend performing a quick grid search for the learning rate `lr` and the number of epochs `n_epochs`.
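A grid search can be as simple as the loop below. To keep the sketch self-contained and runnable, `fit_and_score` is a hypothetical callable that you supply. In this guide it would wrap `Snorkel(weak_labels)`, its `fit(lr=..., n_epochs=...)` method and a metric derived from its `score()` output. The toy objective only stands in for that step.

```python
from itertools import product
from typing import Callable, Iterable, Tuple


def grid_search(
    fit_and_score: Callable[[float, int], float],
    lrs: Iterable[float],
    epochs: Iterable[int],
) -> Tuple[float, int, float]:
    """Return (best_lr, best_n_epochs, best_score) over the full grid."""
    best = None
    for lr, n_epochs in product(lrs, epochs):
        score = fit_and_score(lr, n_epochs)
        if best is None or score > best[2]:
            best = (lr, n_epochs, score)
    if best is None:
        raise ValueError("empty search grid")
    return best


# toy objective standing in for "fit the label model, then measure its accuracy"
def toy_objective(lr: float, n_epochs: int) -> float:
    return 1.0 - abs(lr - 0.001) * 100 - abs(n_epochs - 50) / 1000


print(grid_search(toy_objective, [0.0001, 0.001, 0.01], [50, 100, 200]))  # (0.001, 50, 1.0)
```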

```python
# we check its performance
print(snorkel_model.score(output_str=True))
```

```
              precision    recall  f1-score   support

         HAM       0.94      0.94      0.94       228
        SPAM       0.93      0.93      0.93       212

    accuracy                           0.94       440
   macro avg       0.94      0.94      0.94       440
weighted avg       0.94      0.94      0.94       440
```



At first sight, the model seems to perform worse than the majority vote baseline.
However, let's again correct the accuracy for the abstentions.

```python
# calculate fractions using the support metric of the report above (440 evaluated records)
frac_non = 440 / len(weak_labels.annotation())
frac_abs = 1 - frac_non

# accuracy without abstentions: 0.94; accuracy of a random classifier: 0.5
print("accuracy_c:", frac_non * 0.94 + frac_abs * 0.5)
```

With 500 annotated records, this gives an accuracy of about 0.887.
Now we can see that with **an accuracy of 0.887**, its performance over the whole test set is actually slightly better.

After fitting your label model, you can quickly explore its predictions before building a training set for a downstream text classifier.
This step is useful for validation, manual revision, or defining score thresholds for accepting labels from your label model (for example, only considering labels with a score greater than 0.8).

```python
# get your training records with the predictions of the label model
records_for_training = snorkel_model.predict()

# optional: log the records to a new dataset in Argilla
rg.log(records_for_training, name="snorkel_results")

# extract training data
training_data = pd.DataFrame(
    [{"text": rec.text, "label": rec.prediction[0][0]} for rec in records_for_training]
)
```

```
2358 records logged to http://localhost:6900/datasets/argilla/snorkel_results
```


```python
# preview training data
training_data
```

| | text | label |
|---|---|---|
| 0 | this song is better then monster by eminem | HAM |
| 1 | Hey guys! My mom said if i got 100 subs before... | SPAM |
| 2 | hi everyone this is cool check out sexy and i ... | SPAM |
| 3 | my favorite song | SPAM |
| 4 | I love this shit but I disliked it because it'... | HAM |
| ... | ... | ... |
| 2353 | Check out this video on YouTube: | SPAM |
| 2354 | http://woobox.com/33gxrf/brt0u5 FREE CS GO!!!! | HAM |
| 2355 | hey guys look im aware im spamming and it piss... | SPAM |
| 2356 | Subscribe to My CHANNEL | SPAM |


<div class="alert alert-info">

Note

For an example of how to use the `WeakLabels` object with Snorkel's raw `LabelModel` class, you can check out the [WeakLabels reference](../../reference/python/python_labeling.rst).
    
</div>

### Label model with FlyingSquid

FlyingSquid is a powerful method developed by [Hazy Research](https://hazyresearch.stanford.edu/), a research group from Stanford behind ground-breaking work on programmatic data labeling, including Snorkel.
FlyingSquid uses a closed-form solution for fitting the label model, which yields large speed gains at similar performance.
Just like for Snorkel, Argilla provides built-in support for FlyingSquid, too.

```python
%pip install flyingsquid pgmpy -qqq
```


```python
from argilla.labeling.text_classification import FlyingSquid

# we pass our WeakLabels instance to our FlyingSquid label model
flyingsquid_model = FlyingSquid(weak_labels)

# we fit the model
flyingsquid_model.fit()
```

<div class="alert alert-info">

Note

The `FlyingSquid` label model is not suited for multi-label classification tasks and does not support them.
    
</div>

```python
# we check its performance
print(flyingsquid_model.score(output_str=True))
```

```
              precision    recall  f1-score   support

         HAM       0.94      0.92      0.93       228
        SPAM       0.92      0.93      0.93       212

    accuracy                           0.93       440
   macro avg       0.93      0.93      0.93       440
weighted avg       0.93      0.93      0.93       440
```



Again, let's correct the accuracy for the abstentions.

```python
# calculate fractions using the support metric of the report above (440 evaluated records)
frac_non = 440 / len(weak_labels.annotation())
frac_abs = 1 - frac_non

# accuracy without abstentions: 0.93; accuracy of a random classifier: 0.5
print("accuracy_c:", frac_non * 0.93 + frac_abs * 0.5)
```

With 500 annotated records, this gives an accuracy of about 0.878.
Here, it really seems that with **an accuracy of 0.878**, the performance over the whole test set is slightly worse than the baseline of the majority vote.

After fitting your label model, you can quickly explore its predictions before building a training set for a downstream text classifier.
This step is useful for validation, manual revision, or defining score thresholds for accepting labels from your label model (for example, only considering labels with a score greater than 0.8).

```python
# get your training records with the predictions of the label model
records_for_training = flyingsquid_model.predict()

# log the records to a new dataset in Argilla
rg.log(records_for_training, name="flyingsquid_results")

# extract training data
training_data = pd.DataFrame(
    [{"text": rec.text, "label": rec.prediction[0][0]} for rec in records_for_training]
)
```

```
2358 records logged to http://localhost:6900/datasets/argilla/flyingsquid_results
```


```python
# preview training data
training_data
```

| | text | label |
|---|---|---|
| 0 | this song is better then monster by eminem | HAM |
| 1 | Hey guys! My mom said if i got 100 subs before... | SPAM |
| 2 | hi everyone this is cool check out sexy and i ... | SPAM |
| 3 | my favorite song | SPAM |
| 4 | I love this shit but I disliked it because it'... | HAM |
| ... | ... | ... |
| 2353 | Check out this video on YouTube: | SPAM |
| 2354 | http://woobox.com/33gxrf/brt0u5 FREE CS GO!!!! | SPAM |
| 2355 | hey guys look im aware im spamming and it piss... | SPAM |
| 2356 | Subscribe to My CHANNEL | SPAM |


### Joint Model with Weasel

[Weasel](https://github.com/autonlab/weasel) lets you train downstream models end-to-end directly with weak labels.
In contrast to Snorkel or FlyingSquid, which are two-stage approaches, Weasel is a one-stage method that trains the label model and the end model jointly.
For more details check out the [End-to-End Weak Supervision paper](https://arxiv.org/abs/2107.02233) presented at NeurIPS 2021.

In this guide we will show you how to **train a Hugging Face transformers** model directly **with weak labels using Weasel**.
Since Weasel uses [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) for the training, some basic knowledge of PyTorch is helpful, but not strictly necessary.

Let's start by installing the Weasel Python package:

```python
!python -m pip install git+https://github.com/autonlab/weasel#egg=weasel[all]
```


The first step is to obtain our weak labels.
For this we use the same dataset as in the examples above (Snorkel and FlyingSquid), together with the six Elasticsearch query rules bundled in the `rules` list in step 3.

```python
# obtain our weak labels
weak_labels = WeakLabels(rules=rules, dataset="weak_supervision_yt")
```

In a second step we instantiate our end model, which in our case will be a pre-trained transformer from the Hugging Face Hub.
Here we choose the small ELECTRA model by Google that shows excellent performance given its moderate number of parameters.
Due to its size, you can fine-tune it on your CPU within a reasonable amount of time.

```python
from weasel.models.downstream_models.transformers import Transformers

# instantiate our transformers end model
end_model = Transformers("google/electra-small-discriminator", num_labels=2)
```

With our end-model at hand, we can now instantiate the Weasel model.
Apart from the end-model, it also includes a neural encoder that tries to estimate latent labels.

```python
from weasel.models import Weasel

# instantiate our weasel end-to-end model
weasel = Weasel(
    end_model=end_model,
    num_LFs=len(weak_labels.rules),
    n_classes=2,
    encoder={"hidden_dims": [32, 10]},
    optim_encoder={"name": "adam", "lr": 1e-4},
    optim_end_model={"name": "adam", "lr": 5e-5},
)
```


Afterwards, we wrap our data in the `TransformersDataModule`, so that Weasel and PyTorch Lightning can work with it.
In this step we also tokenize the data. 
Here we need to be careful to use the tokenizer that matches our end model.

```python
from transformers import AutoTokenizer
from weasel.datamodules.transformers_datamodule import (
    TransformersDataModule,
    TransformersCollator,
)

# tokenizer for our transformers end model
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")

# tokenize train and test data
X_train = [
    tokenizer(rec.text, truncation=True)
    for rec in weak_labels.records(has_annotation=False)
]
X_test = [
    tokenizer(rec.text, truncation=True)
    for rec in weak_labels.records(has_annotation=True)
]

# instantiate data module
datamodule = TransformersDataModule(
    label_matrix=weak_labels.matrix(has_annotation=False),
    X_train=X_train,
    collator=TransformersCollator(tokenizer),
    X_test=X_test,
    Y_test=weak_labels.annotation(),
    batch_size=8,
)
```


Now we have everything ready to start the training of our Weasel model.
For the training process, Weasel relies on the excellent [PyTorch Lightning Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html).
It provides tons of options and features to optimize the training process, but the defaults below should give you reasonable results.
Keep in mind that you are fine-tuning a full-blown transformer model, albeit a small one.

```python
import pytorch_lightning as pl

# instantiate the pytorch-lightning trainer
trainer = pl.Trainer(
    gpus=0,  # >= 1 to use GPU(s)
    max_epochs=2,
    logger=None,
    callbacks=[pl.callbacks.ModelCheckpoint(monitor="Val/accuracy", mode="max")],
)

# fit the model end-to-end
trainer.fit(
    model=weasel,
    datamodule=datamodule,
)
```

```
  | Name          | Type         | Params
-----------------------------------------------
0 | end_model     | Transformers | 13.5 M
1 | encoder       | MLPEncoder   | 770
2 | accuracy_func | Softmax      | 0
-----------------------------------------------
13.6 M    Trainable params
0         Non-trainable params
13.6 M    Total params
54.200    Total estimated model params size (MB)

MisconfigurationException: You are trying to `self.log()` but it is not managed by the `Trainer` control flow
```

After the training we can call the `Trainer.test` method to check the final performance. 
The model should achieve a test accuracy of around 0.94.

```python
trainer.test()
# {'accuracy': 0.94, ...}
```


To use the model for inference, you can either use its *predict* method:

```python
# Example text for the inference
text = "In my head this is like 2 years ago.. Time FLIES"

# Get predictions for the example text
predicted_probs, predicted_label = weasel.predict(tokenizer(text, return_tensors="pt"))

# Map predicted int to label
weak_labels.int2label[int(predicted_label)]  # HAM
```
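Mapping a model output back to a label amounts to an argmax followed by a dictionary lookup. The sketch below shows that step with a made-up probability vector. Its hard-coded mapping assumes that HAM is encoded as 0 and SPAM as 1. In the guide, `weak_labels.int2label` provides the actual mapping.

```python
from typing import Dict, Sequence


def probs_to_label(probs: Sequence[float], int2label: Dict[int, str]) -> str:
    """Pick the index with the highest probability and map it to its label."""
    best_index = max(range(len(probs)), key=lambda i: probs[i])
    return int2label[best_index]


int2label = {0: "HAM", 1: "SPAM"}  # assumed encoding, for illustration only
print(probs_to_label([0.61, 0.39], int2label))  # HAM
```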


Or you can instantiate one of the popular transformers pipelines, passing the end model and the tokenizer directly:

```python
from transformers import pipeline

# modify the id2label mapping of the model
weasel.end_model.model.config.id2label = weak_labels.int2label

# create transformers pipeline
classifier = pipeline(
    "text-classification", model=weasel.end_model.model, tokenizer=tokenizer
)

# use pipeline for predictions
classifier(text)  # [{'label': 'HAM', 'score': 0.6110987663269043}]
```
