# Upload data

In [None]:
from datasets import load_dataset

# Load the AG News dataset; its "train" and "test" splits are both used further down.
agnews = load_dataset("ag_news")

In [None]:
# The official test split is kept aside for the final, fair comparison.
# A 7600-record validation split is carved out of the training data and used to optimize the thresholds.
agnews_splits = agnews["train"].train_test_split(test_size=7600, seed=43)
agnews_train, agnews_valid = agnews_splits["train"], agnews_splits["test"]

In [None]:
import rubrix as rb

# validation records: the gold label is attached as the annotation
label_feature = agnews_valid.features["label"]
records = [
    rb.TextClassificationRecord(
        text=example["text"],
        metadata={"split": "labelled"},
        annotation=label_feature.int2str(example["label"]),
    )
    for example in agnews_valid
]

# training records: logged without any annotation
records.extend(
    rb.TextClassificationRecord(
        text=example["text"],
        metadata={"split": "unlabelled"},
    )
    for example in agnews_train
)

# send all records to the Rubrix dataset "news2"
rb.log(records, name="news2")

# Create weak labels

In [2]:
from rubrix.labeling.text_classification import Rule

# query patterns (ES DSL), grouped by the category each pattern votes for
queries = [
    (["money", "financ*", "dollar*"], "Business"),
    (["war", "gov*", "minister*", "conflict"], "World"),
    (["footbal*", "sport*", "game", "play*"], "Sports"),
    (["sci*", "techno*", "computer*", "software", "web"], "Sci/Tech"),
]

# one rule per pattern, in the order the patterns are listed above
rules = []
for terms, label in queries:
    rules.extend(Rule(query=term, label=label) for term in terms)

In [3]:
from rubrix.labeling.text_classification import WeakLabels

# apply every rule to the records of the "news2" dataset to obtain the weak labels
weak_labels = WeakLabels(rules=rules, dataset="news2")



Preparing rules:   0%|          | 0/16 [00:00<?, ?it/s]

Applying rules:   0%|          | 0/120000 [00:00<?, ?it/s]

In [4]:
# per-rule statistics (coverage, overlaps, conflicts, precision, ...) for the keyword rules
weak_labels.summary()

Unnamed: 0,label,coverage,annotated_coverage,overlaps,conflicts,correct,incorrect,precision
money,{Business},0.008242,0.008816,0.00245,0.001925,31,36,0.462687
financ*,{Business},0.019775,0.021184,0.005892,0.005183,115,46,0.714286
dollar*,{Business},0.016608,0.016974,0.003492,0.00285,98,31,0.75969
war,{World},0.011683,0.008816,0.003242,0.001367,44,23,0.656716
gov*,{World},0.045067,0.043158,0.0108,0.006225,156,172,0.47561
minister*,{World},0.030142,0.030263,0.007508,0.002825,207,23,0.9
conflict,{World},0.00305,0.003684,0.001025,9.2e-05,20,8,0.714286
footbal*,{Sports},0.01305,0.015132,0.004875,0.000408,105,10,0.913043
sport*,{Sports},0.021183,0.021711,0.007033,0.001225,146,19,0.884848
game,{Sports},0.03895,0.043026,0.014067,0.002375,253,74,0.7737


# Create embeddings

In [5]:
from sentence_transformers import SentenceTransformer
# No explicit device: sentence-transformers picks a GPU when one is available and
# falls back to CPU otherwise, so this cell also runs on machines without CUDA.
model = SentenceTransformer('all-mpnet-base-v2')

In [6]:
from tqdm.auto import tqdm

# Encode all record texts in one batched call instead of one call per record;
# `encode` returns the vectors in input order and shows its own progress bar.
embeddings = model.encode(
    [rec.text for rec in weak_labels.records()],
    show_progress_bar=True,
)

  0%|          | 0/120000 [00:00<?, ?it/s]

In [7]:
import numpy as np
# make sure the embeddings are a single NumPy array, then cache them on disk
embeddings = np.array(embeddings)
np.save("news2_embeddings.npy", embeddings)

# Load embeddings, compute distances

In [8]:
import numpy as np
# read back the embeddings cached in "news2_embeddings.npy" by the cell above
embeddings = np.load("news2_embeddings.npy")

In [None]:
# First extension of the weak label matrix: one threshold per rule (all 1.0) plus the embeddings.
# Later calls in this notebook pass thresholds only.
# NOTE(review): presumably the embeddings are kept by `weak_labels` after this call — confirm in the Rubrix docs.
weak_labels.extend_matrix([1.0]*len(rules), embeddings)

In [10]:
# per-rule statistics again, after the extend_matrix call
# NOTE(review): the recorded output is identical to the earlier summary — presumably because
# every threshold is 1.0, or because summary() reports the un-extended matrix; confirm.
weak_labels.summary()

Unnamed: 0,label,coverage,annotated_coverage,overlaps,conflicts,correct,incorrect,precision
money,{Business},0.008242,0.008816,0.00245,0.001925,31,36,0.462687
financ*,{Business},0.019775,0.021184,0.005892,0.005183,115,46,0.714286
dollar*,{Business},0.016608,0.016974,0.003492,0.00285,98,31,0.75969
war,{World},0.011683,0.008816,0.003242,0.001367,44,23,0.656716
gov*,{World},0.045067,0.043158,0.0108,0.006225,156,172,0.47561
minister*,{World},0.030142,0.030263,0.007508,0.002825,207,23,0.9
conflict,{World},0.00305,0.003684,0.001025,9.2e-05,20,8,0.714286
footbal*,{Sports},0.01305,0.015132,0.004875,0.000408,105,10,0.913043
sport*,{Sports},0.021183,0.021711,0.007033,0.001225,146,19,0.884848
game,{Sports},0.03895,0.043026,0.014067,0.002375,253,74,0.7737


# Label model

In [11]:
from rubrix.labeling.text_classification import Snorkel

# baseline: Snorkel label model fitted with its default settings
label_model = Snorkel(weak_labels)
label_model.fit()

baseline_report = label_model.score(output_str=True)
print(baseline_report)

100%|██████████| 100/100 [00:00<00:00, 1539.67epoch/s]

              precision    recall  f1-score   support

    Sci/Tech       0.80      0.74      0.77       831
       World       0.68      0.83      0.75       461
      Sports       0.77      0.96      0.86       703
    Business       0.73      0.41      0.53       494

    accuracy                           0.76      2489
   macro avg       0.75      0.74      0.73      2489
weighted avg       0.76      0.76      0.74      2489






## Quick grid search for label model

In [None]:
# label-model accuracy for every (n_epochs, lr) combination
epoch_grid = [10, 20, 40, 80]
lr_grid = [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1]

scores = {}
for n_epochs in epoch_grid:
    for learning_rate in lr_grid:
        candidate_model = Snorkel(weak_labels)
        candidate_model.fit(lr=learning_rate, n_epochs=n_epochs)
        scores[(n_epochs, learning_rate)] = candidate_model.score()["accuracy"]

In [24]:
# best ((n_epochs, lr), accuracy) pair; on ties `max` keeps the first one inserted,
# the same entry the previous stable descending sort put at index 0
max(scores.items(), key=lambda item: item[1])

((10, 0.002), 0.7568327974276527)

# Grid search with label model

In [29]:
def train_eval_labelmodel(ths):
    """Score a list of per-rule thresholds using the label model only.

    Extends the weak label matrix with `ths`, fits a fresh Snorkel label model
    (lr=0.002, n_epochs=10) and scores it.

    Returns the harmonic mean of the label model's accuracy and its coverage,
    where coverage is the "macro avg" support divided by the number of entries
    in `weak_labels.annotation()`. Note that this is not the plain accuracy.
    """
    weak_labels.extend_matrix(ths)

    snorkel_model = Snorkel(weak_labels)
    snorkel_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
    report = snorkel_model.score()

    accuracy = report["accuracy"]
    coverage = report["macro avg"]["support"] / len(weak_labels.annotation())
    return 2 * accuracy * coverage / (accuracy + coverage)

In [30]:
from copy import copy
from tqdm.auto import tqdm

# Per-rule search: tune one threshold at a time while keeping the best setting found so far.
# `best_acc` holds the value returned by train_eval_labelmodel (an accuracy/coverage trade-off).
ths_range = np.arange(1, 0.3, -0.1)
n_ths = len(weak_labels.rules)
best_thresholds = [1.0] * n_ths
best_acc = 0
accs = []

print(ths_range, n_ths)

for rule_idx in tqdm(range(n_ths), total=n_ths):
    # start each rule from the best configuration seen so far
    candidate = best_thresholds.copy()
    for value in ths_range:
        candidate[rule_idx] = value
        score = train_eval_labelmodel(candidate)
        accs.append(score)
        if score > best_acc:
            best_acc, best_thresholds = score, candidate.copy()

[1.  0.9 0.8 0.7 0.6 0.5 0.4] 16


  0%|          | 0/16 [00:00<?, ?it/s]

In [31]:
# thresholds selected by the per-rule search above, one entry per rule
best_thresholds

[0.40000000000000013,
 0.5000000000000001,
 0.7000000000000001,
 0.40000000000000013,
 0.40000000000000013,
 0.7000000000000001,
 0.6000000000000001,
 0.40000000000000013,
 0.40000000000000013,
 0.5000000000000001,
 1.0,
 0.5000000000000001,
 0.7000000000000001,
 0.6000000000000001,
 0.5000000000000001,
 0.5000000000000001]

## Apply best thresholds

In [54]:
# rebuild the extended matrix with the selected thresholds and refit the label model
weak_labels.extend_matrix(best_thresholds)

label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

label_model_report = label_model.score(output_str=True)
print(label_model_report)

              precision    recall  f1-score   support

    Sci/Tech       0.80      0.61      0.69      1876
       World       0.64      0.82      0.71      1804
      Sports       0.80      0.93      0.86      1873
    Business       0.70      0.56      0.62      1859

    accuracy                           0.73      7412
   macro avg       0.73      0.73      0.72      7412
weighted avg       0.73      0.73      0.72      7412



# Train and evaluate downstream model

In [68]:
import pandas as pd
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn import metrics


def final_eval(label_model):
    """Train the downstream classifier on label-model predictions and evaluate it.

    A CountVectorizer + MultinomialNB pipeline is fitted on the records returned
    by `label_model.predict()`, using the top prediction of each record as its
    training label, and then evaluated on the AG News test split.

    Args:
        label_model: a fitted label model exposing `predict()` and
            `weak_labels.label2int`.

    Returns:
        The text classification report for the test split.
    """
    # records carrying the predictions of the label model
    records = label_model.predict()

    # label name -> integer id,
    # e.g. {'Sports': 0, 'Sci/Tech': 1, 'World': 2, 'Business': 3}
    label2int = label_model.weak_labels.label2int

    # training data: text plus the id of the first entry in each record's prediction list
    X_train = [rec.text for rec in records]
    y_train = [label2int[rec.prediction[0][0]] for rec in records]

    # downstream classifier
    classifier = Pipeline([
        ('vect', CountVectorizer()),
        ('clf', MultinomialNB())
    ])
    classifier.fit(
        X=X_train,
        y=y_train,
    )

    # test data, encoded with the same label ids as the training data
    test_split = agnews["test"]
    int2str = test_split.features["label"].int2str
    X_test = [rec["text"] for rec in test_split]
    y_test = [label2int[int2str(rec["label"])] for rec in test_split]

    # a single prediction pass over the test set; the report already contains the accuracy
    predicted = classifier.predict(X_test)

    # Pair each class id with its name explicitly so the report rows cannot be
    # mislabelled if the dict order ever differs from the id order.
    # Falsy keys are skipped (presumably the abstain entry — confirm).
    named_ids = sorted((idx, name) for name, idx in label2int.items() if name)
    return metrics.classification_report(
        y_test,
        predicted,
        labels=[idx for idx, _ in named_ids],
        target_names=[name for _, name in named_ids],
    )

In [58]:
# test-set report of the downstream classifier trained on this label model's predictions
print(final_eval(label_model))

              precision    recall  f1-score   support

    Sci/Tech       0.86      0.78      0.82      1900
       World       0.83      0.89      0.86      1900
      Sports       0.88      0.98      0.93      1900
    Business       0.84      0.77      0.80      1900

    accuracy                           0.85      7600
   macro avg       0.85      0.85      0.85      7600
weighted avg       0.85      0.85      0.85      7600



# Grid search with downstream model

In [60]:
# retrieve records with annotations
test_ds = weak_labels.records(has_annotation=True)

# label name -> integer id; defined here because no earlier cell creates it at
# notebook level (final_eval only has a local copy)
label2int = weak_labels.label2int

# extract text and labels
X_test_for_grid_search = [rec.text for rec in test_ds]
y_test_for_grid_search = [label2int[rec.annotation] for rec in test_ds]

def train_eval_downstream(ths):
    """Score a list of per-rule thresholds by the accuracy of the downstream classifier.

    Extends the weak label matrix with `ths`, fits a fresh Snorkel label model
    (lr=0.002, n_epochs=10), trains a CountVectorizer + MultinomialNB pipeline
    on its predictions and scores it on the annotated records prepared in
    `X_test_for_grid_search` / `y_test_for_grid_search`.

    Returns:
        The accuracy of the downstream classifier on the annotated records.
    """
    weak_labels.extend_matrix(ths)

    label_model = Snorkel(weak_labels)
    label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

    records = label_model.predict()

    # resolve the label mapping here instead of depending on a notebook-level global
    label2int = weak_labels.label2int

    X_train = [rec.text for rec in records]
    y_train = [label2int[rec.prediction[0][0]] for rec in records]

    classifier = Pipeline([
        ('vect', CountVectorizer()),
        ('clf', MultinomialNB())
    ])
    classifier.fit(
        X=X_train,
        y=y_train,
    )

    return classifier.score(
        X=X_test_for_grid_search,
        y=y_test_for_grid_search,
    )

In [46]:
from copy import copy
from tqdm.auto import tqdm

# Per-rule search scored by the downstream classifier: tune one threshold at a time
# while keeping the best setting found so far.
ths_range = np.arange(1, 0.3, -0.1)
n_ths = len(weak_labels.rules)
best_thresholds = [1.0] * n_ths
best_acc = 0
accs = []

print(ths_range, n_ths)

for rule_idx in tqdm(range(n_ths), total=n_ths):
    # start each rule from the best configuration seen so far
    candidate = best_thresholds.copy()
    for value in ths_range:
        candidate[rule_idx] = value
        score = train_eval_downstream(candidate)
        accs.append(score)
        if score > best_acc:
            best_acc, best_thresholds = score, candidate.copy()

[1.  0.9 0.8 0.7 0.6 0.5 0.4] 16


  0%|          | 0/16 [00:00<?, ?it/s]

In [47]:
# best downstream accuracy on the annotated records, and the per-rule thresholds that produced it
best_acc, best_thresholds

(0.8544736842105263,
 [0.6000000000000001,
  0.7000000000000001,
  0.8,
  0.9,
  1.0,
  0.7000000000000001,
  1.0,
  0.9,
  0.6000000000000001,
  1.0,
  1.0,
  0.8,
  0.7000000000000001,
  0.8,
  1.0,
  1.0])

## Apply best thresholds

In [48]:
# rebuild the extended matrix with the selected thresholds and refit the label model
weak_labels.extend_matrix(best_thresholds)

label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

label_model_report = label_model.score(output_str=True)
print(label_model_report)

              precision    recall  f1-score   support

    Sci/Tech       0.78      0.67      0.72      1240
       World       0.77      0.73      0.75       851
      Sports       0.83      0.95      0.88      1321
    Business       0.64      0.65      0.64      1038

    accuracy                           0.76      4450
   macro avg       0.75      0.75      0.75      4450
weighted avg       0.76      0.76      0.76      4450



# Train and evaluate downstream model

In [52]:
# test-set report of the downstream classifier trained on this label model's predictions
print(final_eval(label_model))

              precision    recall  f1-score   support

    Sci/Tech       0.82      0.81      0.81      1900
       World       0.89      0.85      0.87      1900
      Sports       0.88      0.98      0.93      1900
    Business       0.81      0.77      0.79      1900

    accuracy                           0.85      7600
   macro avg       0.85      0.85      0.85      7600
weighted avg       0.85      0.85      0.85      7600



# Uniform grid search

In [62]:
from copy import copy
from tqdm.auto import tqdm

# Uniform search: every rule shares the same threshold; one downstream evaluation per value.
ths_range = np.arange(1, 0.3, -0.1)
n_ths = len(weak_labels.rules)
best_thresholds = [1.0] * n_ths
best_acc = 0
accs = []

print(ths_range, n_ths)

for value in tqdm(ths_range):
    candidate = [value] * n_ths
    score = train_eval_downstream(candidate)
    accs.append(score)
    if score > best_acc:
        best_acc, best_thresholds = score, candidate.copy()

[1.  0.9 0.8 0.7 0.6 0.5 0.4] 16


  0%|          | 0/7 [00:00<?, ?it/s]

In [72]:
# best downstream accuracy on the annotated records for a shared threshold, and that threshold repeated per rule
best_acc, best_thresholds

(0.8309210526315789,
 [0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001,
  0.6000000000000001])

## Apply best thresholds

In [69]:
# rebuild the extended matrix with the selected thresholds and refit the label model
weak_labels.extend_matrix(best_thresholds)

label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

label_model_report = label_model.score(output_str=True)
print(label_model_report)

              precision    recall  f1-score   support

    Sci/Tech       0.76      0.74      0.75      1636
       World       0.67      0.86      0.76      1421
      Sports       0.79      0.96      0.87      1544
    Business       0.78      0.39      0.52      1430

    accuracy                           0.74      6031
   macro avg       0.75      0.74      0.72      6031
weighted avg       0.75      0.74      0.73      6031



# Train and evaluate downstream model

In [71]:
# test-set report of the downstream classifier trained on this label model's predictions
print(final_eval(label_model))

              precision    recall  f1-score   support

    Sci/Tech       0.79      0.84      0.81      1900
       World       0.79      0.90      0.84      1900
      Sports       0.87      0.98      0.93      1900
    Business       0.89      0.60      0.72      1900

    accuracy                           0.83      7600
   macro avg       0.84      0.83      0.82      7600
weighted avg       0.84      0.83      0.82      7600

