# 📰 Building a news classifier with weak supervision

In this tutorial, we will build a news classifier using rules and weak supervision: 

- 📰 For this example, we use the AG News dataset but you can follow this process to programmatically label any dataset.
- 🤿 The train split, without its labels, is used to build a training set with rules, Rubrix, and Snorkel's Label Model.
- 🔧 The test set is used for evaluating our weak labels, label model and downstream news classifier.
- 🤯 In the run shown below, the downstream model reaches about 0.72 accuracy (0.69 macro avg. f1-score) without using a single label from the original training set, and with a pretty lightweight model (scikit-learn's `MultinomialNB`).


<video width="100%" controls><source src="../_static/tutorials/weak-supervision-with-rubrix/ws_news.mp4" type="video/mp4"></video>

The following diagram shows the overall process for using weak supervision with Rubrix:

![Labeling workflow](../_static/tutorials/weak-supervision-with-rubrix/weak_supervision1.svg "Labeling workflow")

## Introduction

> *Weak supervision is a branch of machine learning where noisy, limited, or imprecise sources are used to provide supervision signal for labeling large amounts of training data in a supervised learning setting. This approach alleviates the burden of obtaining hand-labeled data sets, which can be costly or impractical. Instead, inexpensive weak labels are employed with the understanding that they are imperfect, but can nonetheless be used to create a strong predictive model.* [[Wikipedia]](https://en.wikipedia.org/wiki/Weak_supervision)

For a broader introduction to weak supervision, as well as further references, we recommend the excellent [overview by Alex Ratner et al.](https://www.snorkel.org/blog/weak-supervision).

This tutorial aims to be a practical introduction to weak supervision and will walk you through its entire process.
First, we will generate weak labels with *Rubrix*, then combine these labels with *Snorkel*, and finally train a classifier with *scikit-learn*.

## Setup

Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.

If you are new to Rubrix, check out the ⭐ [GitHub repository](https://github.com/recognai/rubrix).

If you have not installed and launched Rubrix yet, check the [Setup and Installation guide](../getting_started/setup&installation.rst).

For this tutorial, we also need some third-party libraries that can be installed via pip:

In [None]:
%pip install snorkel datasets scikit-learn sentence-transformers -qqq

<div class="alert alert-info">

Note

If you want to skip the first three sections of this tutorial, and only prepare the training set and train a downstream model, you can load the records directly from the [Hugging Face Hub](https://huggingface.co/datasets):

```python
import rubrix as rb
from datasets import load_dataset

records = rb.read_datasets(
    load_dataset("rubrix/news", split="train"),
    task="TextClassification",
)
```

</div>

## 1. Load test and unlabelled datasets into Rubrix

First, let's download the `ag_news` data set and have a quick look at it.

In [74]:
from datasets import load_dataset

# load our data
dataset = load_dataset("ag_news")

# get the index to label mapping 
labels = dataset["test"].features["label"].names



In [73]:
import pandas as pd

# quick look at our data
with pd.option_context('display.max_colwidth', None):
    display(dataset["test"].to_pandas().head())

|   | text | label |
|---|------|-------|
| 0 | Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul. | 2 |
| 1 | The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket. | 3 |
| 2 | Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins. | 3 |
| 3 | Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows what the day will bring. Lightning will strike in places he expects. Winds will pick up, moist places will dry and flames will roar. | 3 |
| 4 | Calif. Aims to Limit Farm-Related Smog (AP) AP - Southern California's smog-fighting agency went after emissions of the bovine variety Friday, adopting the nation's first rules to reduce air pollution from dairy cow manure. | 3 |


Now we will log the test split of our data set to *Rubrix*, which we will be using for testing our label and downstream models.

In [75]:
import rubrix as rb

# build our test records
records = [
    rb.TextClassificationRecord(
        text=record["text"],
        metadata={"split": "test"},
        annotation=labels[record["label"]]
    )
    for record in dataset["test"]
]

# log the records to Rubrix
rb.log(records, name="news")

7600 records logged to http://localhost:6900/ws/rubrix/news


BulkResponse(dataset='news', processed=7600, failed=0)

In a second step we log the train split without labels.
Remember, our goal is to programmatically build a training set using rules and weak supervision.

In [76]:
# build our training records without labels
records = [
    rb.TextClassificationRecord(
        text=record["text"],
        metadata={"split": "unlabelled"},
    )
    for record in dataset["train"]
]

# log the records to Rubrix
rb.log(records, name="news")

120000 records logged to http://localhost:6900/ws/rubrix/news


BulkResponse(dataset='news', processed=120000, failed=0)

The result is a single dataset in Rubrix with **127,600 records** (120,000 unlabelled and 7,600 for testing).

You can use the web app to find good rules for programmatic labeling!

## 2. Interactive weak labeling: Finding and defining rules

After logging the dataset, you can find and save rules directly with the UI. Then, you can read the rules with Python to train a label or downstream model, as we'll see in the next step. 

<video width="100%" controls><source src="../_static/tutorials/weak-supervision-with-rubrix/ws_news.mp4" type="video/mp4"></video>

## 3. Denoise weak labels with Snorkel's Label Model

The goal at this step is to **denoise** the weak labels we've just created using rules. There are several approaches to this problem using different statistical methods.

In this tutorial, we're going to use Snorkel, but you can use any other label model or weak supervision method, such as FlyingSquid (see the [Weak supervision guide](../guides/weak-supervision.ipynb) for more details).
For convenience, Rubrix defines a simple wrapper over Snorkel's Label Model so that it is easier to use with Rubrix weak labels and datasets.
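
To make the idea of denoising more concrete, here is a minimal sketch of the simplest possible baseline: a majority vote over the rule outputs. This is not what Snorkel's Label Model does (it estimates how reliable each rule is and weights the votes accordingly). The toy matrix, the label ids and the `majority_vote` helper are all hypothetical and only illustrate the shape of the problem.

```python
from collections import Counter

ABSTAIN = -1  # convention assumed here: a rule that does not match abstains

# Toy weak label matrix: one row per record, one column per rule.
# Label ids are made up for this illustration: 0 = Business, 1 = Sports.
weak_label_matrix = [
    [0, ABSTAIN, 0],              # two rules agree on Business
    [ABSTAIN, 1, ABSTAIN],        # a single Sports vote
    [0, 1, 1],                    # conflict, Sports wins 2:1
    [ABSTAIN, ABSTAIN, ABSTAIN],  # no rule matched
]


def majority_vote(row):
    """Return the most frequent non-abstain label, or ABSTAIN if none or tied."""
    votes = Counter(label for label in row if label != ABSTAIN)
    if not votes:
        return ABSTAIN
    ranked = votes.most_common(2)
    if len(ranked) == 2 and ranked[0][1] == ranked[1][1]:
        return ABSTAIN  # tie: refuse to guess
    return ranked[0][0]


denoised = [majority_vote(row) for row in weak_label_matrix]
print(denoised)  # [0, 1, 1, -1]
```

A majority vote treats every rule as equally trustworthy, which is exactly the assumption a label model tries to relax.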

Let's first read the rules defined in our dataset and create our weak labels:

In [27]:
from rubrix.labeling.text_classification import load_rules, WeakLabels

rules = load_rules(dataset="news")

weak_labels = WeakLabels(
    rules=rules, 
    dataset="news"
)
weak_labels.summary()

Preparing rules:   0%|          | 0/18 [00:00<?, ?it/s]

Applying rules:   0%|          | 0/129100 [00:00<?, ?it/s]

| rule | polarity | coverage | overlaps | conflicts | correct | incorrect | precision |
|------|----------|----------|----------|-----------|---------|-----------|-----------|
| `sci*` | {Sci/Tech} | 0.0166 | 0.003176 | 0.001588 | 138 | 33 | 0.807018 |
| `dollar*` | {Business} | 0.016592 | 0.006723 | 0.00299 | 108 | 41 | 0.724832 |
| `*ball` | {Sports} | 0.030132 | 0.010015 | 0.001425 | 257 | 31 | 0.892361 |
| `conflict` | {World} | 0.003052 | 0.000999 | 0.000287 | 23 | 5 | 0.821429 |
| `financ*` | {Business} | 0.01962 | 0.007622 | 0.005298 | 90 | 70 | 0.5625 |
| `match` | {Sports} | 0.008629 | 0.002138 | 0.000287 | 78 | 7 | 0.917647 |
| `goal` | {Sports} | 0.005585 | 0.001774 | 0.000395 | 41 | 9 | 0.82 |
| `election` | {World} | 0.017235 | 0.011789 | 0.002192 | 128 | 27 | 0.825806 |
| `president*` | {World} | 0.053346 | 0.01859 | 0.007188 | 353 | 130 | 0.730849 |
| `techn*` | {Sci/Tech} | 0.03031 | 0.012277 | 0.005143 | 193 | 75 | 0.720149 |
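
To help read the columns of this summary, the following sketch recomputes the same kind of statistics for a toy weak label matrix. The definitions used here are assumptions in the spirit of Snorkel's labeling function analysis: coverage is the share of records a rule votes on, overlaps is the share where at least one other rule votes too, conflicts is the share where another rule votes differently, and precision is `correct / (correct + incorrect)` on annotated records. The data and the `rule_stats` helper are made up for illustration and are not Rubrix's implementation.

```python
ABSTAIN = -1

# Toy data (made up): rows are records, columns are two rules.
matrix = [
    [0, ABSTAIN],
    [0, 0],
    [0, 1],
    [ABSTAIN, 1],
    [ABSTAIN, ABSTAIN],
]
# Gold labels for annotated records, None where no annotation exists.
gold = [0, 0, 1, None, None]


def rule_stats(matrix, gold, rule):
    """Compute summary statistics for the rule in column `rule`."""
    n = len(matrix)
    votes = [row[rule] for row in matrix]
    # Non-abstain votes of all the other rules, per record.
    others = [
        [v for j, v in enumerate(row) if j != rule and v != ABSTAIN]
        for row in matrix
    ]
    covered = [v != ABSTAIN for v in votes]
    overlaps = [c and len(o) > 0 for c, o in zip(covered, others)]
    conflicts = [
        c and any(x != v for x in o) for c, v, o in zip(covered, votes, others)
    ]
    correct = sum(
        c and g is not None and v == g for c, v, g in zip(covered, votes, gold)
    )
    incorrect = sum(
        c and g is not None and v != g for c, v, g in zip(covered, votes, gold)
    )
    judged = correct + incorrect
    return {
        "coverage": sum(covered) / n,
        "overlaps": sum(overlaps) / n,
        "conflicts": sum(conflicts) / n,
        "correct": correct,
        "incorrect": incorrect,
        "precision": correct / judged if judged else None,
    }


stats = rule_stats(matrix, gold, rule=0)
print(stats)

# Sanity check against the first row of the summary above (sci*):
print(round(138 / (138 + 33), 6))  # 0.807018
```

Note that precision can only be computed on records that carry an annotation, which is why the labelled test split was logged to the same dataset.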


In [59]:
from rubrix.labeling.text_classification import Snorkel

# create the label model
label_model = Snorkel(weak_labels)

# fit the model
label_model.fit()

# test it with labeled test set
label_model.score()

{'Sci/Tech': {'precision': 0.7726692209450831,
  'recall': 0.7826649417852523,
  'f1-score': 0.7776349614395888,
  'support': 773},
 'World': {'precision': 0.6925675675675675,
  'recall': 0.8055009823182712,
  'f1-score': 0.7447774750227066,
  'support': 509},
 'Sports': {'precision': 0.8102288021534321,
  'recall': 0.9525316455696202,
  'f1-score': 0.8756363636363637,
  'support': 632},
 'Business': {'precision': 0.6506024096385542,
  'recall': 0.3576158940397351,
  'f1-score': 0.46153846153846156,
  'support': 453},
 'accuracy': 0.7515842839036755,
 'macro avg': {'precision': 0.7315170000761593,
  'recall': 0.7245783659282198,
  'f1-score': 0.7148968154092802,
  'support': 2367},
 'weighted avg': {'precision': 0.7421114043978349,
  'recall': 0.7515842839036755,
  'f1-score': 0.7362410920466687,
  'support': 2367}}

## 4. Prepare our training set

Now we have a "denoised" training set, which we can prepare for training a downstream model.
The label model's `predict` method returns `TextClassificationRecord` objects with the `prediction` from the label model.

We can either refine and review these records using the Rubrix web app, use them as is, or filter them by score, for example.

In this case, we assume the predictions are precise enough and use them without any revision. 
Our training set has ~38,000 records, which corresponds to all records where the label model has not abstained.
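
If you would rather filter by score than use all predictions, the idea can be sketched as follows. The tuples below are hypothetical stand-ins for the label model output (a text plus a list of `(label, score)` pairs with the highest score first), and the `MIN_SCORE` threshold is an arbitrary value chosen for illustration.

```python
# Hypothetical stand-ins for label model output: (text, [(label, score), ...]),
# with the predictions sorted by descending score.
toy_predictions = [
    ("Stocks rally as dollar climbs", [("Business", 0.91), ("World", 0.05)]),
    ("Minister comments on match result", [("World", 0.48), ("Sports", 0.46)]),
    ("New football season kicks off", [("Sports", 0.97), ("World", 0.02)]),
]

MIN_SCORE = 0.8  # arbitrary threshold chosen for illustration

# Keep only records whose top prediction is confident enough.
confident = [
    {"text": text, "label": prediction[0][0], "score": prediction[0][1]}
    for text, prediction in toy_predictions
    if prediction[0][1] >= MIN_SCORE
]

for row in confident:
    print(row)
```

A higher threshold gives a cleaner but smaller training set, so the right value depends on how much noise the downstream model tolerates.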

In [60]:
import pandas as pd

# get records with the predictions from the label model
records = label_model.predict()

# build a simple dataframe with text and the prediction with the highest score
df_train = pd.DataFrame([
    {"text": record.text, "label": label_model.weak_labels.label2int[record.prediction[0][0]]}
    for record in records
])

# quick look at our training data with the weak labels from our label model 
with pd.option_context('display.max_colwidth', None):
    display(df_train)

|   | text | label |
|---|------|-------|
| 0 | Saturn's Moon Titan: Planet Wannabe by Henry Bortman Jonathan Lunine, professor of planetary science and physics at the University of Arizona's Lunar and Planetary Laboratory in Tucson, Arizona, has long been fascinated by Saturn's largest moon, Titan. In this first part of the interview, Lunine explains what scientists hope to learn from Huygens... | 0 |
| 1 | Reds with a Spanish spine Rafael Benitez would not have imagined a more troubled introduction to his debut Premiership season following Michael Owen #39;s shock transfer. NANTHA KUMAR examines the tough transition period at Liverpool.LIVERPOOL Football Club were not the only party ... | 2 |
| 2 | Gateway spreads out at retail Gateway computers will be more widely available at Office Depot, in the PC maker #39;s latest move to broaden distribution at retail stores since acquiring rival eMachines this year. | 0 |
| 3 | Apple recalls 15-in. PowerBook batteries Apple Computer Inc. has issued a recall for about 28,000 PowerBook batteries sold between January and August for use with its 15-in. PowerBook G4 computers | 0 |
| 4 | Microsoft delays security update for Windows XP Professional REDMOND, Wash. -- Microsoft has delayed distribution of a security update for users of its Windows XP Professional operating system to give some companies more time to test it, the software company said Tuesday. | 0 |
| ... | ... | ... |
| 1468 | New Clot Preventer Saves Lives and Money By Ed Edelson, HealthDay Reporter HealthDayNews -- A new anti-clotting drug for people having artery-opening procedures lowers the rate of complications, gets patients out of the hospital faster, and probably saves lives, a study finds. And it saves money to boot, says Dr... | 3 |
| 1469 | Intuit 4Q Loss Widens on Charge MOUNTAIN VIEW, Calif. (AP)--Intuit Inc. #39;s loss widened for the fourth quarter ended July 31, hurt by slower seasonal sales of the company #39;s tax and finance software and an impairment charge from its decision to sell one of its businesses, the company said ... | 0 |
| 1470 | Real v Apple music war: iPod freedom petition backfires Hostilities started in late July, when Real cracked Apple #39;s FairPlay code, meaning songs bought from the RealPlayer Music Store could be played on the iPod - a move that went down very badly over at Apple. Real then decided to ratchet up the pressure by ... | 2 |
| 1471 | Rwanda Troops Start AU Mission in Darfur EL FASHER, Sudan (Reuters) - Rwandan troops arrived in Darfur Sunday as the first foreign force there, mandated to protect observers monitoring a shaky cease-fire between the Sudanese government and rebels in the remote western region. | 1 |


Now we can load the test records again:

In [61]:
# for the test set, we can retrieve the records with validated annotations (the original ag_news test set)
df_test = rb.load("news", query="status:Validated")

# transform data to match our training set format
df_test['annotation'] = df_test['annotation'].apply(
    lambda r: label_model.weak_labels.label2int[r]
)

## 5. Train a downstream model with scikit-learn

Now, let's train our final model using `scikit-learn`:

In [62]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# define our final classifier
classifier = Pipeline([
    ('vect', CountVectorizer()),
    ('clf', MultinomialNB())
])

# fit the classifier
classifier.fit(
    X=df_train.text.tolist(), 
    y=df_train.label.values
)

Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])

In [63]:
# compute the test accuracy
accuracy = classifier.score(
    X=df_test.text.tolist(), 
    y=label_model.weak_labels.annotation()
)

print(f"Test accuracy: {accuracy}")

Test accuracy: 0.7227631578947369


Not too bad! 🥳

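We have achieved around **0.72 accuracy** without using a single label from the original `ag_news` train set and with a small set of rules (fewer than 30). Note that this figure is not directly comparable to the 0.75 accuracy of our Label Model: the label model is only scored on the test records for which it does not abstain (2,367 in the output above), whereas the downstream model is evaluated on all 7,600 test records.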

Finally, let's take a look at more detailed metrics:

In [64]:
from sklearn import metrics

# get predictions for the test set
predicted = classifier.predict(df_test.text.tolist())

print(metrics.classification_report(label_model.weak_labels.annotation(), predicted, target_names=labels))

              precision    recall  f1-score   support

       World       0.60      0.87      0.71      1900
      Sports       0.73      0.84      0.78      1900
    Business       0.83      0.94      0.88      1900
    Sci/Tech       0.86      0.25      0.39      1900

    accuracy                           0.72      7600
   macro avg       0.76      0.72      0.69      7600
weighted avg       0.76      0.72      0.69      7600



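At this point, we could go back to the UI to define more rules for the labels that perform worst. Keep in mind that the row names in the report above follow the `ag_news` label order passed as `target_names`, while the integer ids come from the label model's `label2int` mapping, so the two may not line up. Judging by the training examples in section 4, id 0 appears to be `Sci/Tech`, and the low-recall last row (0.25) appears to correspond to `Business`, the label on which the label model also had the lowest recall. We might therefore want to add some more rules to increase the recall of the `Business` label.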
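
When reading a classification report, it is worth checking that the row names follow the integer ids actually used in `y`. The sketch below shows one way to derive the names from a label-to-id mapping. The `label2int` dictionary here is hypothetical (including the `None` entry for abstentions); the real mapping lives in `label_model.weak_labels.label2int` and may use a different order. The `recall_per_class` helper is likewise only an illustration.

```python
from collections import defaultdict

# Hypothetical mapping, for illustration only; the real one lives in
# label_model.weak_labels.label2int and may use a different order.
label2int = {None: -1, "Sci/Tech": 0, "World": 1, "Sports": 2, "Business": 3}

# Names ordered by integer id, skipping the abstain entry.
target_names = [
    name
    for name, idx in sorted(label2int.items(), key=lambda item: item[1])
    if name is not None
]
print(target_names)  # ['Sci/Tech', 'World', 'Sports', 'Business']


def recall_per_class(y_true, y_pred):
    """Recall for each class id present in y_true, keyed by its name."""
    hits, totals = defaultdict(int), defaultdict(int)
    for truth, pred in zip(y_true, y_pred):
        totals[truth] += 1
        hits[truth] += int(truth == pred)
    return {target_names[idx]: hits[idx] / totals[idx] for idx in sorted(totals)}


# Made-up labels and predictions.
y_true = [0, 0, 1, 2, 3, 3, 3, 3]
y_pred = [0, 0, 1, 2, 3, 0, 0, 1]
recalls = recall_per_class(y_true, y_pred)
print(recalls)
```

With scikit-learn, a list built this way can be passed as `target_names` to `metrics.classification_report`, which assigns the names to the sorted label ids.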

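## 6. Extending the weak labels with embeddings

As an optional, experimental step, we can try to increase the coverage of our rules with sentence embeddings. `WeakLabels.extend_matrix` takes one similarity threshold per rule together with the record embeddings and extends the weak label matrix accordingly. We then search for the thresholds that maximise the downstream accuracy.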

In [None]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-mpnet-base-v2', device='cuda')
embeddings = model.encode([x.text for x in weak_labels.records()])

In [82]:
thresholds = [0.9] * len(rules)
weak_labels.extend_matrix(thresholds, embeddings)

In [83]:
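def avoid_label_overwrite(weak_labels):
    """Restore the original rule votes for every record on which at least one rule fired.

    Only records without any vote keep their embedding-extended labels.
    Note that this relies on the private `_matrix` and `_extended_matrix` attributes.
    """
    null_label = weak_labels.label2int[None]
    for idx, row in enumerate(weak_labels._matrix):
        # at least one rule voted on this record: keep the original row
        if not all(x == null_label for x in row):
            weak_labels._extended_matrix[idx] = weak_labels._matrix[idx]
    return weak_labels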

weak_labels = avoid_label_overwrite(weak_labels)
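
Conceptually, extending weak labels through embeddings can be pictured as follows. This is a simplified sketch of the general idea, not Rubrix's implementation: a record on which a rule abstains inherits the vote of its most similar record where that rule did fire, provided the cosine similarity reaches the rule's threshold. The `extend_votes` helper, the 2-d embeddings and the votes are all made up for illustration.

```python
import math

ABSTAIN = -1


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def extend_votes(votes, embeddings, threshold):
    """Extend one rule's votes to abstaining records with a close enough neighbour."""
    labeled = [i for i, v in enumerate(votes) if v != ABSTAIN]
    extended = list(votes)
    for i, vote in enumerate(votes):
        if vote != ABSTAIN or not labeled:
            continue  # existing votes are never overwritten
        nearest = max(labeled, key=lambda j: cosine(embeddings[i], embeddings[j]))
        if cosine(embeddings[i], embeddings[nearest]) >= threshold:
            extended[i] = votes[nearest]
    return extended


# Toy 2-d embeddings and the votes of a single rule (values are made up).
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
votes = [2, ABSTAIN, ABSTAIN, 0]

loose = extend_votes(votes, embeddings, threshold=0.9)
strict = extend_votes(votes, embeddings, threshold=0.999)
print(loose)   # [2, 2, 0, 0]
print(strict)  # [2, -1, -1, 0]
```

A lower threshold increases coverage at the risk of propagating a rule's mistakes to its neighbours, which is why the thresholds are tuned in the cells that follow.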

In [88]:
from rubrix.labeling.text_classification import Snorkel
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

def get_accuracy(weak_labels, thresholds):
    # define our final classifier
    classifier = Pipeline([
        ('vect', CountVectorizer()),
        ('clf', MultinomialNB())
    ])
    label_model = Snorkel(weak_labels)
    # fit the model
    label_model.fit()
    # test it with labeled test set
    label_model.score()
    records = label_model.predict()
    # build a simple dataframe with text and the prediction with the highest score
    df_train = pd.DataFrame([
        {"text": record.text, "label": label_model.weak_labels.label2int[record.prediction[0][0]]}
        for record in records
    ])
    classifier.fit(
        X=df_train.text.tolist(), 
        y=df_train.label.values
    )
    # compute the test accuracy
    accuracy = classifier.score(
        X=df_test.text.tolist(), 
        y=label_model.weak_labels.annotation()
    )
    return accuracy

thresholds = [0.2] * len(rules)
weak_labels.extend_matrix(thresholds)
weak_labels = avoid_label_overwrite(weak_labels)
acc = get_accuracy(weak_labels, thresholds)
acc

0.25

In [42]:
import numpy as np

# Step 1: find the best single threshold shared by all rules
search_space = np.linspace(0.1, 1.0, num=10)
threshold_dict = {}
for i in search_space:
    thresholds = [i] * len(rules)
    weak_labels = avoid_label_overwrite(weak_labels)
    weak_labels.extend_matrix(thresholds, embeddings)
    acc = get_accuracy(weak_labels, thresholds)
    threshold_dict[i] = acc

max_threshold = max(threshold_dict, key=threshold_dict.get)
best_thresholds = [max_threshold] * len(rules)

# Step 2: refine the threshold of one rule at a time, keeping the others fixed
for idx in range(len(best_thresholds)):
    search_space = np.linspace(0.1, 1.0, num=10)
    threshold_dict = {}
    for i in search_space:
        best_thresholds[idx] = i
        weak_labels.extend_matrix(best_thresholds, embeddings)
        acc = get_accuracy(weak_labels, best_thresholds)
        threshold_dict[i] = acc
    # keep the best value for this rule and report: accuracy, rule index, threshold
    max_val = max(threshold_dict, key=threshold_dict.get)
    best_thresholds[idx] = max_val
    print(threshold_dict[max_val], idx, max_val)

0.7563157894736842 0 0.7000000000000001
0.7610526315789473 1 0.9
0.7610526315789473 2 1.0
0.7610526315789473 3 1.0
0.7610526315789473 4 1.0
0.7610526315789473 5 1.0
0.7613157894736842 6 0.7000000000000001
0.7613157894736842 7 1.0
0.7613157894736842 8 1.0
0.7613157894736842 9 1.0

KeyboardInterrupt: 

## Summary

In this tutorial, we saw how you can leverage weak supervision to quickly build up a large training dataset and use it to train a first lightweight model.

*Rubrix* is a very handy tool to start the weak supervision process, since it makes it easy to find a good set of starting rules and to iterate on them dynamically.
Since *Rubrix* also provides built-in support for the most common label models, you can get from rules to weak labels in a few straightforward steps.
For more suggestions on how to leverage weak labels, you can check out our [weak supervision guide](../guides/weak-supervision.ipynb), where we describe an [interesting approach](../guides/weak-supervision.ipynb#Joint-Model-with-Weasel) to jointly train the label model and a transformers downstream model.

## Next steps

If you are interested in the topic of weak supervision check our [weak supervision guide](../guides/weak-supervision.ipynb).

### ⭐ Star the Rubrix [GitHub repo](https://github.com/recognai/rubrix) to stay updated.

### 📚 [Rubrix documentation](https://docs.rubrix.ml) for more guides and tutorials.

### 🙋‍♀️ Join the Rubrix community on [Slack](https://bit.ly/3o0Pfyk)

## Appendix I: Create rules and weak labels from Python

For some use cases, you might want to use Python for defining labeling rules and generating weak labels. Rubrix lets you define and test rules and labeling functions directly in Python. This can be useful for combining them with rules defined in the UI, and for leveraging structured resources such as lexicons and gazetteers, which are easier to use in a programmatic environment.

In this section, we recreate rules like the ones we defined in the UI, this time directly in Python:

In [77]:
from rubrix.labeling.text_classification import Rule

# define queries and patterns for each category (using ES DSL)
queries = [
  (["money", "financ*", "dollar*"], "Business"),
  (["war", "gov*", "minister*", "conflict"], "World"),
  (["footbal*", "sport*", "game", "play*"], "Sports"),
  (["sci*", "techno*", "computer*", "software", "web"], "Sci/Tech")
] 

# define rules
rules = [
    Rule(query=term, label=label)
    for terms,label in queries
    for term in terms
]
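
To see what the nested comprehension produces, and roughly what a wildcard term does, here is a small stand-alone sketch. It uses plain tuples instead of `Rule` objects and approximates matching with Python's `fnmatch` on lowercased, whitespace-separated tokens. This is only an approximation for illustration: the actual rules are evaluated by Elasticsearch, whose tokenization and query semantics differ, and `rough_match` is a hypothetical helper.

```python
from fnmatch import fnmatch

queries = [
    (["money", "financ*", "dollar*"], "Business"),
    (["footbal*", "sport*", "game", "play*"], "Sports"),
]

# Same nested comprehension as above, but producing plain (term, label) tuples:
# one entry per term, each carrying the label of its group.
pairs = [(term, label) for terms, label in queries for term in terms]
print(len(pairs))  # 7


def rough_match(text, pattern):
    """Approximate a wildcard term by matching it against lowercased tokens."""
    return any(fnmatch(token, pattern) for token in text.lower().split())


text = "Dollars flow into football clubs"
votes = [label for term, label in pairs if rough_match(text, term)]
print(votes)  # ['Business', 'Sports']
```

A record matched by rules with different labels, as in this example, is what the `conflicts` column of the summary counts.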

In [78]:
from rubrix.labeling.text_classification import WeakLabels

# generate the weak labels 
weak_labels = WeakLabels(
    rules=rules, 
    dataset="news"
)

Preparing rules:   0%|          | 0/16 [00:00<?, ?it/s]

Applying rules:   0%|          | 0/127600 [00:00<?, ?it/s]

On our machine it took around 24 seconds to apply the rules and to generate weak labels for the 127,600 examples.

Typically, you want to iterate on the rules and check their statistics.
For this, you can use the `weak_labels.summary` method:

In [79]:
weak_labels.summary()

| rule | label | coverage | annotated_coverage | overlaps | conflicts | correct | incorrect | precision |
|------|-------|----------|--------------------|----------|-----------|---------|-----------|-----------|
| `money` | {Business} | 0.008276 | 0.008816 | 0.002437 | 0.001936 | 30 | 37 | 0.447761 |
| `financ*` | {Business} | 0.019655 | 0.017763 | 0.005893 | 0.005188 | 80 | 55 | 0.592593 |
| `dollar*` | {Business} | 0.016591 | 0.016316 | 0.003542 | 0.002908 | 87 | 37 | 0.701613 |
| `war` | {World} | 0.011779 | 0.013289 | 0.003213 | 0.001348 | 75 | 26 | 0.742574 |
| `gov*` | {World} | 0.045078 | 0.045263 | 0.010878 | 0.00627 | 170 | 174 | 0.494186 |
| `minister*` | {World} | 0.030031 | 0.028289 | 0.007531 | 0.002821 | 193 | 22 | 0.897674 |
| `conflict` | {World} | 0.003041 | 0.002895 | 0.001003 | 0.000102 | 18 | 4 | 0.818182 |
| `footbal*` | {Sports} | 0.013166 | 0.015 | 0.004945 | 0.000439 | 107 | 7 | 0.938596 |
| `sport*` | {Sports} | 0.021191 | 0.021316 | 0.007045 | 0.001223 | 139 | 23 | 0.858025 |
| `game` | {Sports} | 0.038879 | 0.037763 | 0.014083 | 0.002375 | 216 | 71 | 0.752613 |


From the above, we see that our rules cover around **30% of the original training set** with an **average precision of 0.72**. Our hope is that the label and downstream models will improve both the recall and the precision of the final classifier.

## Appendix II: Log datasets to the Hugging Face Hub

Here we will show you an example of how you can push a Rubrix dataset (records) to the [Hugging Face Hub](https://huggingface.co/datasets).
In this way you can effectively version any of your Rubrix datasets.

In [None]:
records = rb.load("news", as_pandas=False)
records.to_datasets().push_to_hub("<name of the dataset on the HF Hub>")