In [2]:
import rubrix as rb
from datasets import load_dataset

# Drop-in substitute for the `records = label_model.predict()` line of section 4:
# load the "train" split of "rubrix/news" and read it as text classification records.
hub_news = load_dataset("rubrix/news", split="train")
records = rb.read_datasets(hub_news, task="TextClassification")



In [5]:
from datasets import load_dataset

# load the AG News dataset
dataset = load_dataset("ag_news")

# class names of the "label" feature; a label's integer value is its position in this list
label_feature = dataset["test"].features["label"]
labels = label_feature.names



  0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
import pandas as pd

# peek at the first rows of the test split, with the column width limit lifted
# so the full text is shown
test_preview = dataset["test"].to_pandas().head()
with pd.option_context("display.max_colwidth", None):
    display(test_preview)

Unnamed: 0,text,label
0,Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.,2
1,"The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket.",3
2,"Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins.",3
3,"Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows what the day will bring. Lightning will strike in places he expects. Winds will pick up, moist places will dry and flames will roar.",3
4,"Calif. Aims to Limit Farm-Related Smog (AP) AP - Southern California's smog-fighting agency went after emissions of the bovine variety Friday, adopting the nation's first rules to reduce air pollution from dairy cow manure.",3


In [8]:
import rubrix as rb

# turn every example of the test split into an annotated Rubrix record
records = []
for example in dataset["test"]:
    # translate the integer label into its class name
    class_name = labels[example["label"]]
    records.append(
        rb.TextClassificationRecord(
            text=example["text"],
            metadata={"split": "test"},
            annotation=class_name,
        )
    )

# upload the records to the "news" dataset in Rubrix
rb.log(records, name="news")

  0%|          | 0/7600 [00:00<?, ?it/s]

7600 records logged to http://localhost:6900/datasets/rubrix/news


BulkResponse(dataset='news', processed=7600, failed=0)

In [41]:
# turn every example of the train split into a Rubrix record; no annotation is
# attached, the records are only tagged as "unlabelled" in their metadata
records = []
for example in dataset["train"]:
    records.append(
        rb.TextClassificationRecord(
            text=example["text"],
            metadata={"split": "unlabelled"},
        )
    )

# upload the records to the "news" dataset in Rubrix
rb.log(records, name="news")

  0%|          | 0/120000 [00:00<?, ?it/s]

120000 records logged to http://localhost:6900/datasets/rubrix/news


BulkResponse(dataset='news', processed=120000, failed=0)

In [42]:
from rubrix.labeling.text_classification import WeakLabels

# build the weak labels object for the "news" dataset
# NOTE(review): presumably this fetches the rules stored with the dataset and applies
# them to its records (the cell output reports "Preparing rules" / "Applying rules")
# — confirm against the Rubrix documentation
weak_labels = WeakLabels(dataset="news")
# kept as the last expression of the cell so the summary table is rendered as its output
weak_labels.summary()

Preparing rules:   0%|          | 0/16 [00:00<?, ?it/s]

Applying rules:   0%|          | 0/127600 [00:00<?, ?it/s]

Unnamed: 0,label,coverage,annotated_coverage,overlaps,conflicts,correct,incorrect,precision
money,{Business},0.008268,0.008816,0.002375,0.001857,30,37,0.447761
dollar*,{Business},0.016591,0.016316,0.003558,0.002915,87,37,0.701613
war,{World},0.015627,0.017105,0.004467,0.001693,101,29,0.776923
game,{Sports},0.038738,0.037632,0.011364,0.002406,216,70,0.755245
play,{Sports},0.013534,0.012763,0.013534,0.000972,73,24,0.752577
software,{Sci/Tech},0.030188,0.029474,0.009843,0.003393,183,41,0.816964
financ*,{Business},0.019655,0.017763,0.005737,0.005024,80,55,0.592593
gov*,{World},0.045086,0.045263,0.010987,0.006003,170,174,0.494186
minister*,{World},0.030031,0.028289,0.007829,0.002727,193,22,0.897674
conflict*,{World},0.003824,0.003684,0.001356,0.000235,22,6,0.785714


In [49]:
from rubrix.labeling.text_classification import Snorkel

# wrap the weak labels in a Snorkel label model
label_model = Snorkel(weak_labels)

# train the model, then show its score report as text
label_model.fit()
score_report = label_model.score(output_str=True)
print(score_report)

100%|██████████| 100/100 [00:00<00:00, 769.24epoch/s]

              precision    recall  f1-score   support

    Sci/Tech       0.78      0.76      0.77       778
       World       0.69      0.81      0.75       523
    Business       0.66      0.36      0.46       447
      Sports       0.76      0.95      0.84       543

    accuracy                           0.74      2291
   macro avg       0.72      0.72      0.71      2291
weighted avg       0.73      0.74      0.72      2291






In [44]:
import pandas as pd

# get records with the predictions from the label model
records = label_model.predict()
# you can replace this line with
# records = rb.read_datasets(
#    load_dataset("rubrix/news", split="train"),
#    task="TextClassification",
# )

# we could also use the `weak_labels.label2int` dict
label2int = {'Sports': 0, 'Sci/Tech': 1, 'World': 2, 'Business': 3}

# Only records that actually carry a prediction can serve as training data.
# Filtering once keeps `X_train` and `y_train` aligned and avoids a TypeError /
# IndexError on records whose prediction is None or empty (possible when the
# records come from another source, e.g. the `rb.read_datasets` alternative above).
records_with_prediction = [rec for rec in records if rec.prediction]

# extract training data
X_train = [rec.text for rec in records_with_prediction]
# each prediction entry is indexed as (label, score); take the label with the highest
# score instead of relying on the list being sorted. For a list sorted by descending
# score this is the same entry as `rec.prediction[0]`.
y_train = [
    label2int[max(rec.prediction, key=lambda label_score: label_score[1])[0]]
    for rec in records_with_prediction
]

In [45]:
# show the training texts next to the weak labels assigned by the label model,
# with the column width limit lifted so the full text is shown
train_preview = pd.DataFrame({"text": X_train, "label": y_train})
with pd.option_context("display.max_colwidth", None):
    display(train_preview)

Unnamed: 0,text,label
0,"Boston Scientific Outlook Trails Views Boston Scientific Corp. (BSX.N: Quote, Profile, Research) , whose Taxus heart device has been involved in a major recall, said on Wednesday preliminary August sales of the device slowed",1
1,"India ready to consider Musharraf's proposals if made formally (AFP) AFP - India is ready to consider Pakistani President Pervez Musharraf's new proposals on disputed Kashmir if they are made formally through diplomatic channels, Foreign Minister Natwar Singh said.",2
2,"Jobless Claims Drop More Than Expected WASHINGTON (Reuters) - The number of people filing an initial claim for U.S. jobless aid fell by a larger-than-expected 25,000 last week, the government said on Thursday in a report seen as a good sign for job growth.",2
3,"YUKOS Cuts China Supplies, Oil Above \$46 NEW YORK (Reuters) - Oil prices broke above \$46 on Monday after Russian oil giant YUKOS said it would cut some oil shipments to China, the first toll on exports from the company's financial turmoil.",3
4,"Britain, Ireland push for Northern Ireland deal At a fairytale English castle, the prime ministers of Britain and Ireland will this week attempt to broker an unlikely marriage of convenience between the two extremes of Northern Irish politics.",2
...,...,...
36922,Joke e-mail virus tricks users A new version of the Bagle computer virus is spreading rapidly around the internet.,1
36923,KPMG to Pay \$10M to Settle SEC Charges Accounting giant KMPG LLP will pay \$10 million to settle Securities and Exchange Commission charges of improper conduct in the firm #39;s audit of the financial statements of Gemstar-TV Guide International Inc.,3
36924,"Iraq Leader Says Violence Won't Stop Vote BAGHDAD, Iraq - The Iraqi prime minister insisted Sunday that the raging insurgency - which has claimed 300 lives in the last week alone and resulted in a wave of kidnappings - will not delay January elections, promising the vote will strike a ""major blow"" against the violent opposition. Meanwhile, a grisly videotape posted on a Web site showed the beheading of three hostages believed to be Iraqi Kurds accused by militants of cooperating with U.S...",2
36925,"Give handball a sporting chance ATHENS -- I saw a game that featured, with only the slightest expansion of our basic concepts, fast breaks, fouled in the act of shooting, sneakaways, pivotmen, turnovers, weaves, give-and-gos, lookaway passes, backward bounce passes, skip passes, a penalty shot, a backdoor play, and great shot blocking worthy of a Russell or a Roy.",0


In [46]:
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# the final classifier: token counts fed into a multinomial naive Bayes model
pipeline_steps = [
    ("vect", CountVectorizer()),
    ("clf", MultinomialNB()),
]
classifier = Pipeline(pipeline_steps)

# train the classifier on the weakly labelled data
classifier.fit(X_train, y_train)

In [47]:
# fetch the records that have an annotation; they serve as the test set
test_ds = weak_labels.records(has_annotation=True)
# you can replace this line with
# test_ds = rb.read_datasets(
#    load_dataset("rubrix/news_test", split="train"),
#    task="TextClassification",
# )

# collect the texts and their integer-encoded annotations side by side
X_test, y_test = [], []
for rec in test_ds:
    X_test.append(rec.text)
    y_test.append(label2int[rec.annotation])

# score the classifier on the test set
accuracy = classifier.score(X_test, y_test)

print(f"Test accuracy: {accuracy}")

Test accuracy: 0.8115789473684211


In [48]:
from sklearn import metrics

# get predictions for the test set
predicted = classifier.predict(X_test)

# Pair every class id with its name explicitly instead of relying on the insertion
# order of `label2int` happening to match the integer codes. Passing `labels` also
# keeps the report rows stable if a class is missing from `y_test` / `predicted`.
class_ids = sorted(label2int.values())
int2label = {class_id: name for name, class_id in label2int.items()}
class_names = [int2label[class_id] for class_id in class_ids]

print(
    metrics.classification_report(
        y_test,
        predicted,
        labels=class_ids,
        target_names=class_names,
    )
)

              precision    recall  f1-score   support

      Sports       0.86      0.96      0.91      1900
    Sci/Tech       0.76      0.83      0.79      1900
       World       0.78      0.89      0.83      1900
    Business       0.88      0.56      0.69      1900

    accuracy                           0.81      7600
   macro avg       0.82      0.81      0.80      7600
weighted avg       0.82      0.81      0.80      7600

