# Imports

In [1]:
import pickle
from pathlib import Path
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# HF datasets
from datasets import load_dataset_builder, load_dataset, Dataset, Value, ClassLabel, Features, DatasetDict

# rubrix
import rubrix as rb
from rubrix.labeling.text_classification import Rule
from rubrix.labeling.text_classification import WeakLabels
from rubrix.labeling.text_classification import Snorkel
from rubrix.labeling.text_classification import load_rules

# sklearn
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed used for the random train split below (SET SEED).
RANDOM_STATE = 42

# Widen text columns (up to 400 characters) so news articles are readable.
pd.set_option("display.max_colwidth", 400)

# Data

## `ag_news` dataset
[ag_news](https://huggingface.co/datasets/ag_news)

In [3]:
# Dataset metadata (description, label feature), inspected before loading the data.
dataset_info = load_dataset_builder("ag_news")



In [4]:
# Dataset description: origin of the articles and the paper that introduced the benchmark.
print(dataset_info.info.description)

AG is a collection of more than 1 million news articles. News articles have been
gathered from more than 2000 news sources by ComeToMyHead in more than 1 year of
activity. ComeToMyHead is an academic news search engine which has been running
since July, 2004. The dataset is provided by the academic comunity for research
purposes in data mining (clustering, classification, etc), information retrieval
(ranking, search, etc), xml, data compression, data streaming, and any other
non-commercial activity. For more information, please refer to the link
http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html .

The AG's news topic classification dataset is constructed by Xiang Zhang
(xiang.zhang@nyu.edu) from the dataset above. It is used as a text
classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann
LeCun. Character-level Convolutional Networks for Text Classification. Advances
in Neural Information Processing Systems 28 (NIPS 2015).



In [5]:
# The label feature lists the number of classes and their names.
print(dataset_info.info.features["label"])

ClassLabel(num_classes=4, names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None)


There are 4 classes:
- World
- Sports
- Business
- Sci/Tech

## Load data

In [6]:
# Load the train and test splits of ag_news.
dataset = load_dataset("ag_news")
dataset

100%|██████████| 2/2 [00:00<00:00, 30.50it/s]


DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

## Splits

We are going to simulate the following real-world scenario:

- Let's imagine that we are interested in building a news classifier and have collected the `ag_news` dataset but without labels.
- Our labelling budget is limited, so we decide to label a small subset of the data and use weak supervision techniques to label the rest.
- We'll use some of the labeled data to understand the data better and create LFs.
- Also, after training the model on the weakly supervised dataset, we'll evaluate it on labeled data.

__Splits:__
1. We'll use the test set from the HF dataset to test the final classifier.
2. 20% of the training set keeps its labels; it will be used for creating LFs and for validation.
3. We'll use weak supervision to label the remaining 80% and create a weakly supervised dataset. The validation set will be used to measure how good each LF is: we'll check coverage, annotated coverage and precision, which Rubrix makes easy to compute.

In [7]:
def generate_splits(
    dataset: "DatasetDict",
    unlabeled_frac: float,
    random_state: Optional[int] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split ag_news into unlabeled-train, labeled-validation and test DataFrames.

    A random ``unlabeled_frac`` fraction of the training split has its labels
    dropped (simulating unlabeled data). The remaining training rows keep their
    labels and form the validation set. The test split is returned as-is.

    Args:
        dataset: Mapping with ``"train"`` and ``"test"`` splits, each convertible
            with ``pd.DataFrame`` and holding ``text`` and ``label`` columns
            (e.g. a HF ``DatasetDict``).
        unlabeled_frac: Fraction of training rows, in ``[0, 1]``, whose labels
            are dropped.
        random_state: Seed for the row sampling. Defaults to the notebook-wide
            ``RANDOM_STATE`` when ``None``.

    Returns:
        ``(unlabeled_train_data, validation_data, test_data)``: the unlabeled
        rows (``text`` column only), the labeled validation rows and the test
        rows. Row indexes from the original train split are preserved.

    Raises:
        ValueError: If ``unlabeled_frac`` is outside ``[0, 1]``.
    """
    if not 0.0 <= unlabeled_frac <= 1.0:
        raise ValueError(f"unlabeled_frac must be in [0, 1], got {unlabeled_frac}")
    if random_state is None:
        random_state = RANDOM_STATE

    train_data = pd.DataFrame(dataset["train"])

    # sample the rows whose labels we pretend not to have, then drop the labels
    unlabeled_train_data = train_data.sample(
        frac=unlabeled_frac, random_state=random_state
    ).drop(columns="label")

    # every remaining training row keeps its label and becomes the validation set
    validation_data = train_data.loc[~train_data.index.isin(unlabeled_train_data.index)]
    assert np.intersect1d(unlabeled_train_data.index, validation_data.index).size == 0, (
        "Duplicate data points. Check labeled and unlabeled splits."
    )

    test_data = pd.DataFrame(dataset["test"])

    return unlabeled_train_data, validation_data, test_data

In [8]:
# generate splits
# Drop labels for 80% of the training set; the remaining 20% keeps its labels for validation.
unlabeled_data, validation_data, test_data = generate_splits(dataset, 0.8)

In [9]:
# (rows, columns) per split; the unlabeled split has no label column.
unlabeled_data.shape, validation_data.shape, test_data.shape

((96000, 1), (24000, 2), (7600, 2))

# Quick EDA

In [10]:
# First rows of the unlabeled split (indexes from the original train split are kept).
unlabeled_data.head()

Unnamed: 0,text
71787,"BBC set for major shake-up, claims newspaper London - The British Broadcasting Corporation, the world #39;s biggest public broadcaster, is to cut almost a quarter of its 28 000-strong workforce, in the biggest shake-up in its 82-year history, The Times newspaper in London said on Monday."
67218,Marsh averts cash crunch Embattled insurance broker #39;s banks agree to waive clause that may have prevented access to credit. NEW YORK (Reuters) - Marsh amp; McLennan Cos.
54066,"Jeter, Yankees Look to Take Control (AP) AP - Derek Jeter turned a season that started with a terrible slump into one of the best in his accomplished 10-year career."
7168,"Flying the Sun to Safety When the Genesis capsule comes back to Earth with its samples of the sun, helicopter pilots will be waiting for it, ready to snag it out of the sky."
29618,"Stocks Seen Flat as Nortel and Oil Weigh NEW YORK (Reuters) - U.S. stocks were set to open near unchanged on Thursday after a warning from technology bellwether Nortel Networks Corp. &lt;A HREF=""http://www.investor.reuters.com/FullQuote.aspx?ticker=NT.N target=/stocks/quickinfo/fullquote""&gt;NT.N&lt;/A&gt; dimmed hopes, while stubbornly high oil prices also weighed on sentiment."


In [11]:
# Size of the pool of texts without labels.
print(f"There are {unlabeled_data.shape[0]} data points.")

There are 96000 data points.


# Log unlabeled data and validation set to Rubrix

We use Rubrix and Snorkel's label model to create the weak labels. The workflow is as follows:

1. Log unlabeled data and labeled validation set to Rubrix
2. Create Labelling Functions (LFs) in the Rubrix UI and compare the performance of these LFs against the validation set. As we don't have much time to study the data, the LFs will mainly rely on pattern matching.
3. Denoise the weak labels using Snorkel's label model. This model is already integrated into Rubrix workflow. The Snorkel label model uses a generative model to produce probabilistic labels that can be used to train a downstream model. It's also able to capture correlations between LFs, which can be important.
4. Train a baseline model and test it on the validation set to see if it outperforms Snorkel's model.

In [12]:
# Rubrix dataset the records are logged to and the rules are later applied on.
RUBRIX_DATASET_NAME = "ag_news_v2"

In [13]:
# Class names indexed by label id, in the order of the ClassLabel feature.
labels = dataset_info.info.features["label"].names
labels

['World', 'Sports', 'Business', 'Sci/Tech']

In [14]:
# The test split keeps its gold label ids (indexes into `labels`).
test_data.head()

Unnamed: 0,text,label
0,Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.,2
1,"The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket.",3
2,"Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins.",3
3,"Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows what the day will bring. Lightning will strike in places he expects. Winds will pick up, moist places will dry and flames will roar.",3
4,"Calif. Aims to Limit Farm-Related Smog (AP) AP - Southern California's smog-fighting agency went after emissions of the bovine variety Friday, adopting the nation's first rules to reduce air pollution from dairy cow manure.",3


In [15]:
import os
# Check whether the Rubrix API URL is provided through the environment.
os.environ.get("RUBRIX_API_URL")

'http://rubrix:80'

In [18]:
import os

# Rubrix server URL. When RUBRIX_API_URL is set (presumably inside Docker), use
# its actual value rather than a hardcoded copy of it; otherwise assume a
# local Rubrix server.
in_docker = os.environ.get("RUBRIX_API_URL") is not None
RUBRIX_URL = os.environ.get("RUBRIX_API_URL", "http://localhost:6900")
print(RUBRIX_URL)

http://rubrix:80


In [19]:

# Build the Rubrix records. Unlabeled texts come first, followed by the
# validation texts, each carrying its gold label id mapped to its class name.
unlabeled_records = [
    rb.TextClassificationRecord(text=text)
    for text in unlabeled_data["text"]
]
annotated_records = [
    rb.TextClassificationRecord(text=text, annotation=labels[label_idx])
    for text, label_idx in zip(validation_data["text"], validation_data["label"])
]
records = unlabeled_records + annotated_records

# log the records to Rubrix
rb.log(records, name=RUBRIX_DATASET_NAME)

100%|██████████| 120000/120000 [05:05<00:00, 392.96it/s]

120000 records logged to http://rubrix:80/datasets/rubrix/ag_news_v2





BulkResponse(dataset='ag_news_v2', processed=120000, failed=0)

After this step, we created the LFs in the Rubrix UI. They are stored as `Rule` objects in a pickle file so the process can be reproduced.

# Denoise weak labels

In [20]:
# The labelling functions for ag_news were authored in the Rubrix UI and are
# loaded from a pickle in the next cell. The example rules previously defined
# here used SPAM/HAM labels, which are not ag_news classes, and were never
# applied, so they were removed. Rule and WeakLabels are already imported in
# the imports cell.

In [18]:
# load rules
with open("./labeling_functions/lf_rules_rubrix.pkl", "rb") as f:
    rules = pickle.load(f)

In [19]:
# Show the name, query and author of every loaded rule.
for loaded_rule in rules:
    print({field: getattr(loaded_rule, field) for field in ("name", "query", "author")})
    

{'name': 'war', 'query': 'war', 'author': 'rubrix'}
{'name': 'war AND country', 'query': 'war AND country', 'author': 'rubrix'}
{'name': 'minister*', 'query': 'minister*', 'author': 'rubrix'}
{'name': 'minister* AND countr*', 'query': 'minister* AND countr*', 'author': 'rubrix'}
{'name': 'conflict', 'query': 'conflict', 'author': 'rubrix'}
{'name': 'countr* AND war', 'query': 'countr* AND war', 'author': 'rubrix'}
{'name': 'politic*', 'query': 'politic*', 'author': 'rubrix'}
{'name': '*ball', 'query': '*ball', 'author': 'rubrix'}
{'name': 'footbal*', 'query': 'footbal*', 'author': 'rubrix'}
{'name': 'footbal* AND game', 'query': 'footbal* AND game', 'author': 'rubrix'}
{'name': '*ball AND game', 'query': '*ball AND game', 'author': 'rubrix'}
{'name': 'sport*', 'query': 'sport*', 'author': 'rubrix'}
{'name': 'game', 'query': 'game', 'author': 'rubrix'}
{'name': 'play* AND *ball', 'query': 'play* AND *ball', 'author': 'rubrix'}
{'name': 'play* AND footbal*', 'query': 'play* AND footbal*'

Apply the rules to the records. Note: rules applied this way are not pushed to the Rubrix UI.

In [21]:
# Apply the loaded rules to every record of the Rubrix dataset.
weak_labels = WeakLabels(dataset=RUBRIX_DATASET_NAME, rules=rules)

Preparing rules: 100%|██████████| 34/34 [00:50<00:00,  1.50s/it]
Applying rules: 100%|██████████| 120000/120000 [00:04<00:00, 24776.63it/s]


In [22]:
# Per-rule coverage, overlaps, conflicts and precision against the annotated records.
weak_labels.summary()

Unnamed: 0,label,coverage,annotated_coverage,overlaps,conflicts,correct,incorrect,precision
war,{World},0.015533,0.016,0.005492,0.002283,291,93,0.757812
war AND country,{World},0.000975,0.001042,0.000975,9.2e-05,23,2,0.92
minister*,{World},0.030142,0.029125,0.008683,0.003183,608,91,0.869814
minister* AND countr*,{World},0.002983,0.002625,0.002983,0.00035,56,7,0.888889
conflict,{World},0.003042,0.003417,0.000992,0.000167,70,12,0.853659
countr* AND war,{World},0.001433,0.001417,0.001433,0.000108,32,2,0.941176
politic*,{World},0.0125,0.012708,0.003767,0.001675,227,78,0.744262
*ball,{Sports},0.0296,0.029792,0.019808,0.001667,653,62,0.913287
footbal*,{Sports},0.013042,0.012625,0.01265,0.00055,272,31,0.89769
footbal* AND game,{Sports},0.001708,0.001542,0.001708,0.000108,32,5,0.864865


Overall coverage = 38.8%

Annotated coverage = 38.8%

The two are nearly identical, so the rules do not appear to be overfitted to the annotated (validation) records.

In [23]:
# create the label model
label_model = Snorkel(weak_labels, verbose=True, device="cpu")

# default
params = {
    "n_epochs": 100,
    "lr": 0.001,
    "optimizer": "sgd",
    
}

# fit the model
label_model.fit(include_annotated_records=False, **params)

100%|██████████| 100/100 [00:00<00:00, 394.61epoch/s]


In [24]:
# Classification report of the label model.
# NOTE(review): presumably scored on the annotated records where the model did
# not abstain; the support (9325) is below the 24000 annotated records. Confirm.
print(label_model.score(output_str=True))

              precision    recall  f1-score   support

    Sci/Tech       0.81      0.72      0.76      2838
       World       0.85      0.70      0.77      1496
      Sports       0.80      0.92      0.86      2314
    Business       0.75      0.81      0.77      2677

    accuracy                           0.79      9325
   macro avg       0.80      0.79      0.79      9325
weighted avg       0.79      0.79      0.79      9325



We experimented with different hyperparameters and didn't see a significant change in the metrics, so we kept the values above.

# Prepare training set with probabilistic labels

In [25]:
# Weak labels for the non-annotated records. Records where the label model
# abstains are returned too; their `prediction` is falsy.
records = label_model.predict(include_annotated_records=False, include_abstentions=True)

In [26]:
# Map class names to the integer ids of the HF ClassLabel feature. Deriving the
# mapping from `labels` guarantees the weak labels use the same encoding as the
# gold labels, which index `labels` directly.
label2int = {label: label_id for label_id, label in enumerate(labels)}

# extract training data (covered and uncovered records)
X = [rec.text for rec in records]
# First predicted label per record, or None when the label model abstained.
# NOTE(review): assumes predictions are sorted by descending probability; confirm.
y = [label2int[rec.prediction[0][0]] if rec.prediction else None for rec in records]

In [27]:
# Number of records returned by the label model (abstentions included).
print(f"{len(X)} data points in the training set.")

96000 data points in the training set.


In [28]:
# Abstentions have label None, so pandas stores the label column as float with NaN.
all_data = pd.DataFrame({"text": X, "label": y})
all_data.head()

Unnamed: 0,text,label
0,"UK takes Linux to the heart of government Open source software is a viable alternative to commercial proprietary software, with potential significant value-for-money benefits for government, the Office of Government Commerce (OGC) has concluded.",3.0
1,Skype launches Pocket PC software Peer-to-peer IP telephony startup Skype yesterday released a version of its software designed for mobile devices running Microsoft #39;s PocketPC operating system.,3.0
2,"Compromises urged amid deadlock in Darfur talks ABUJA, Nigeria -- Peace talks on Sudan's violence-torn Darfur region are deadlocked, a mediator said yesterday, as the chief of the African Union appealed to the Sudanese government and rebels to compromise.",
3,"#39;Warne #39;s glory short-lived #39; Colombo - Sri Lanka #39;s World Cup-winning skipper Arjuna Ranatunga on Friday congratulated his old foe Shane Warne on becoming the world #39;s highest Test wicket-taker, but predicted his glory will be short-lived.",
4,"Consortium Forms to Set Network Centric Communications Standards The Network Centric Operations Industry Consortium (NCOIC) formally introduced itself this week at a Tuesday (Sept. 28) press conference in Washington, DC The new group, consisting initially of 28 companies",


In [29]:
# separate weakly supervised dataset
weakly_labeled_data = all_data[all_data["label"].notnull()]
weakly_labeled_data["label"] = weakly_labeled_data["label"].astype(int)
remaining_unlabeled_data = all_data[all_data["label"].isnull()].drop("label", axis=1)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  weakly_labeled_data["label"] = weakly_labeled_data["label"].astype(int)


In [30]:
# Plain Python lists of the weakly labeled training texts and their label ids.
X_train = weakly_labeled_data["text"].tolist()
y_train = weakly_labeled_data["label"].tolist()

In [32]:
# retrieve records with annotations for the validation set
val_data = weak_labels.records(has_annotation=True)

# extract text and labels
X_val = [rec.text for rec in val_data]
y_val = [label2int[rec.annotation] for rec in val_data]

In [33]:
# Number of annotated records used for validation.
print(f"{len(X_val)} data points in the validation set.")

24000 data points in the validation set.


# Baseline

We'll train a very simple text classifier (bag-of-words counts + multinomial Naive Bayes) to use as our baseline.

In [34]:
# Baseline: bag-of-words token counts fed into a multinomial Naive Bayes classifier.
vectorizer_step = ("vect", CountVectorizer())
classifier_step = ("clf", MultinomialNB())
model = Pipeline(steps=[vectorizer_step, classifier_step])

In [35]:
# Train the baseline on the weakly labeled data only.
model.fit(X_train, y_train)

In [36]:
# Predicted label ids for the validation texts.
predictions = model.predict(X_val)

In [37]:
# target_names follow label2int's insertion order, i.e. label ids 0..3.
print(classification_report(y_val, predictions, target_names=list(label2int.keys())))

              precision    recall  f1-score   support

       World       0.90      0.84      0.87      6043
      Sports       0.84      0.99      0.91      5927
    Business       0.86      0.76      0.81      6073
    Sci/Tech       0.80      0.83      0.81      5957

    accuracy                           0.85     24000
   macro avg       0.85      0.85      0.85     24000
weighted avg       0.85      0.85      0.85     24000



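85% accuracy on the validation set is reasonable for such a simple model. Keep in mind that this set was also used to design the LFs, so the estimate may be somewhat optimistic; it should be confirmed on the held-out test set.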

In [38]:
# Validation texts next to the baseline's predicted label ids. The frame is
# built column-wise instead of transposing a 2 x N frame, which kept every
# column as object dtype and materialized N columns.
pd.DataFrame({"text": X_val, "label": predictions})

Unnamed: 0,text,label
0,"USC and Miami Atop First BCS Standings Southern California took the top spot Monday in the season's first Bowl Championship Series standings, but surprisingly Miami is ahead of Oklahoma in a close race for the second spot. Oklahoma is No...",1
1,"AMD, IBM extend chip development deal Chipmaker will pay Big Blue nearly \$300 million, but move should help in competition with Intel.",3
2,Will Graves: Laettner has learned to live with the labels Christian Laettner knows what #39;s coming. He braces for it every time he sees an article with his name in it. Somewhere in the middle of the story - usually past the part where it talks about Laettner #39;s remarkable,1
3,"Nicotine Addiction Gene Identified Researchers say they have identified brain cell receptors that appear to be responsible for nicotine addiction, a finding of clear importance for smokers who are desperately trying to kick the habit.",3
4,"Olympics: Greeks Dive for Joy, More Gold for Thorpe ATHENS (Reuters) - Greece was plunged in national self-satisfaction on Tuesday after its synchronized divers won the host nation's first gold medal of the Athens Olympics.",1
...,...,...
23995,"Google delves through dark corners of the hard drive The modern PC is a marvel, a machine that lets an ordinary person with little training create a document, check its spelling, dress it up with graphics, send it electronically to someone across the globe - and then save it accidentally into some dark",3
23996,"Candidates Play on Fears of Attacks, Wars WASHINGTON - Playing on the fear factor, Vice President Dick Cheney suggested in a campaign speech there might be another terrorist attack on the United States if John Kerry were in the White House. President Bush's opponents' are raising their own worst fears, including the potential for more wars during a second Bush term...",0
23997,"Safeway Sets Outlook Below Estimates (Reuters) Reuters - Grocery chain operator Safeway Inc.\ on Wednesday set a 2005 earnings per share goal that\was below analysts' expectations, with added marketing expenses\and lingering impact from a strike at its Vons stores weighing\on profits.",2
23998,"New Hybrid Disc Offers CD-DVD Combination (AP) AP - Recording companies looking to wring more profits out of music sales are hoping to sell retailers on a new hybrid disc. On one side is standard CD audio; on the other, the enhanced sound, video and other media capabilities of a DVD.",3


# Save training set

In [52]:
# Persist the splits for downstream training. Disabled by default so that
# re-running the notebook does not overwrite existing files.
SAVE_SPLITS = False
DATA_DIR = Path("../data")

if SAVE_SPLITS:
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    weakly_labeled_data.to_csv(DATA_DIR / "ag_news_training_weak_data.csv", index=False)
    remaining_unlabeled_data.to_csv(DATA_DIR / "ag_news_unlabeled_data.csv", index=False)
    validation_data.to_csv(DATA_DIR / "ag_news_validation_data.csv", index=False)
    test_data.to_csv(DATA_DIR / "ag_news_test_data.csv", index=False)