# **1. Robustness**

In [None]:
from transformers import pipeline

nlp = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")

In [34]:
from datasets import load_dataset

dataset = load_dataset("ag_news", split="test[0:500]")
labels = dataset.features["label"].names

Using custom data configuration default
Reusing dataset ag_news (/Users/dani/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


In [22]:
labels

['World', 'Sports', 'Business', 'Sci/Tech']

In [36]:
def add_predictions(record):
    prediction = nlp(record["text"], candidate_labels=labels)
    return {"prediction": prediction["labels"][0], "score": prediction["scores"][0] }

In [37]:
dataset = dataset.map(add_predictions)

  0%|          | 0/500 [00:00<?, ?ex/s]

In [53]:
import rubrix as rb

records = [
    rb.TextClassificationRecord(
        inputs=record["text"],
        metadata={"len": len(record["text"])},
        annotation=labels[record["label"]],
        prediction=[(record["prediction"], record["score"])]
    )
    for record in dataset
]

In [54]:
rb.delete("zeroshot_agnews")
rb.log(records, name="zeroshot_agnews")

  0%|          | 0/500 [00:00<?, ?it/s]

500 records logged to http://localhost:6900/zeroshot_agnews


BulkResponse(dataset='zeroshot_agnews', processed=500, failed=0)

## Metrics and data slices

In [1]:
from rubrix.metrics.text_classification import f1

f1(name="zeroshot_agnews").data

{'micro': 0.598,
 'macro': 0.5857564446483439,
 'per_label': {'Sports': 0.7959183673469388,
  'Sci/Tech': 0.5176470588235293,
  'World': 0.4669603524229075,
  'Business': 0.5624999999999999}}

In [64]:
f1(name="zeroshot_agnews", query="metadata.len:[100 TO 200]")

MetricSummary(data={'micro': 0.5867768595041323, 'macro': 0.5701230451694848, 'per_label': {'Sports': 0.736842105263158, 'Sci/Tech': 0.5538461538461538, 'World': 0.5098039215686274, 'Business': 0.48000000000000004}})

In [77]:
f1(name="zeroshot_agnews", query="score:>=.5")

MetricSummary(data={'micro': 0.8051948051948051, 'macro': 0.7082466825113884, 'per_label': {'Sports': 0.9318181818181819, 'Sci/Tech': 0.7407407407407407, 'World': 0.45454545454545453, 'Business': 0.7058823529411765}})

In [74]:
f1(name="zeroshot_agnews", query="president")

MetricSummary(data={'micro': 0.43478260869565216, 'macro': 0.4113636363636364, 'per_label': {'Sports': 0.5, 'Sci/Tech': 0.2, 'World': 0.5454545454545454, 'Business': 0.4000000000000001}})

# **2. Weak supervision**

https://rubrix.readthedocs.io/en/master/guides/weak-supervision.html

In [152]:
dataset = load_dataset("ag_news", split="test[500:2500]")
labels = dataset.features["label"].names



In [153]:
records = [
    rb.TextClassificationRecord(
        inputs=record["text"],
        metadata={"len": len(record["text"])},
        #annotation=labels[record["label"]],
        #prediction=[(record["prediction"], record["score"])]
    )
    for record in dataset
]

In [72]:
rb.log(records, name="zeroshot_agnews")

  0%|          | 0/2000 [00:00<?, ?it/s]

2000 records logged to http://localhost:6900/zeroshot_agnews


BulkResponse(dataset='zeroshot_agnews', processed=2000, failed=0)

In [154]:
from rubrix.labeling.text_classification import Rule, WeakLabels

In [345]:
war = Rule(query="war", label="World")
gov = Rule(query="gov*", label="World")
football = Rule(query="footb*", label="Sports")
sport = Rule(query="sport*", label="Sports")
business = Rule(query="business", label="Business")
computer = Rule(query="comput*", label="Sci/Tech")
sci = Rule(query="sci*", label="Sci/Tech")

In [346]:
rules = [
    war,
    football,
    business,
    computer,
    sci,
    sport,
    gov
]

In [347]:
weak_labels = WeakLabels(rules=rules, dataset="zeroshot_agnews")

Preparing rules:   0%|          | 0/7 [00:00<?, ?it/s]

Applying rules:   0%|          | 0/2500 [00:00<?, ?it/s]

In [348]:
weak_labels.summary()

Unnamed: 0,polarity,coverage,overlaps,conflicts,correct,incorrect,precision
war,{World},0.0152,0.0008,0.0,5,2,0.714286
footb*,{Sports},0.0104,0.0004,0.0,5,0,1.0
business,{Business},0.0208,0.002,0.002,5,2,0.714286
comput*,{Sci/Tech},0.0272,0.0024,0.002,6,2,0.75
sci*,{Sci/Tech},0.0192,0.002,0.0016,10,0,1.0
sport*,{Sports},0.0208,0.0012,0.0008,10,1,0.909091
gov*,{World},0.0404,0.0032,0.0024,14,10,0.583333
total,"{World, Sports, Sci/Tech, Business}",0.148,0.006,0.0044,55,17,0.763889


In [349]:
from rubrix.labeling.text_classification import Snorkel

# we pass our WeakLabels instance to our Snorkel label model
label_model = Snorkel(weak_labels)

# we train the model
label_model.fit()

# we check its performance
label_model.score()

{'accuracy': 0.29}

In [410]:
rules = [
    war,
    football,
    business,
    computer,
    sci,
    sport,
    gov
]

rules.append(Rule(query="goal", label="Sports"))
rules.append(Rule(query="play*", label="Sports"))
rules.append(Rule(query="champion*", label="Sports"))
rules.append(Rule(query="game*", label="Sports"))
rules.append(Rule(query="olympic*", label="Sports"))

rules.append(Rule(query="party", label="World"))
rules.append(Rule(query="market", label="World"))
rules.append(Rule(query="democ*", label="World"))
rules.append(Rule(query="deal", label="World"))
rules.append(Rule(query="court", label="World"))
rules.append(Rule(query="minist*", label="World"))
rules.append(Rule(query="country", label="World"))
rules.append(Rule(query="iraq", label="World"))
rules.append(Rule(query="offici*", label="World"))

rules.append(Rule(query="price*", label="Business"))
rules.append(Rule(query="financ*", label="Business"))
rules.append(Rule(query="econom*", label="Business"))
rules.append(Rule(query="company", label="Business"))
rules.append(Rule(query="million*", label="Business"))
rules.append(Rule(query="trade", label="Business"))
rules.append(Rule(query="oil", label="Business"))
rules.append(Rule(query="mone*", label="Business"))
rules.append(Rule(query="market*", label="Business"))
rules.append(Rule(query="stock", label="Business"))
rules.append(Rule(query="corp*", label="Business"))


rules.append(Rule(query="internet", label="Sci/Tech"))
rules.append(Rule(query="web", label="Sci/Tech"))
rules.append(Rule(query="users", label="Sci/Tech"))
rules.append(Rule(query="telep*", label="Sci/Tech"))
rules.append(Rule(query="tech*", label="Sci/Tech"))
rules.append(Rule(query="Microsoft", label="Sci/Tech"))
rules.append(Rule(query="software", label="Sci/Tech"))
#rules.append(Rule(query="motors", label="Sci/Tech"))
len(rules)

39

## Using a list of country names

In [394]:
import pandas as pd

country_list = pd.read_csv("https://pkgstore.datahub.io/core/country-list/data_csv/data/d7c9d7cfb42cb69f4422dec222dbbaa8/data_csv.csv")

In [395]:
rules_country = []
for country in country_list.Name.values:
    rules_country.append(Rule(query=f"{country}", label="World"))

In [396]:
rules.extend(rules_country) ; len(rules)

288

## Using a sport names list

In [397]:
sport_list = pd.read_csv("https://github.com/ali-ce/datasets/raw/master/Most-paid-athletes/Athletes.csv")

In [398]:
sport_list = sport_list.Sport.unique()

In [399]:
sport_rules = []
for name in sport_list:
    sport_rules.append(Rule(query=f"{name}", label="Sports"))

In [400]:
rules.extend(sport_rules) ; len(rules)

298

In [411]:
weak_labels = WeakLabels(rules=rules, dataset="zeroshot_agnews")

Preparing rules:   0%|          | 0/39 [00:00<?, ?it/s]

Applying rules:   0%|          | 0/2500 [00:00<?, ?it/s]

In [412]:
weak_labels.summary()

Unnamed: 0,polarity,coverage,overlaps,conflicts,correct,incorrect,precision
war,{World},0.0152,0.0104,0.002,5,2,0.714286
footb*,{Sports},0.0104,0.0064,0.0008,5,0,1.0
business,{Business},0.0208,0.0144,0.0104,5,2,0.714286
comput*,{Sci/Tech},0.0272,0.0196,0.0136,6,2,0.75
sci*,{Sci/Tech},0.0192,0.01,0.0076,10,0,1.0
sport*,{Sports},0.0208,0.0136,0.0048,10,1,0.909091
gov*,{World},0.0404,0.0296,0.0204,14,10,0.583333
goal,{Sports},0.0044,0.0032,0.0004,2,0,1.0
play*,{Sports},0.0488,0.0304,0.0164,11,4,0.733333
champion*,{Sports},0.0316,0.0176,0.004,11,2,0.846154


In [408]:
# we pass our WeakLabels instance to our Snorkel label model
label_model = Snorkel(weak_labels)

# we train the model
label_model.fit()

# we check its performance
label_model.score()

{'accuracy': 0.364}

In [413]:
from snorkel.labeling.model import LabelModel

# train our label model
label_model = LabelModel(cardinality=4)
label_model.fit(L_train=weak_labels.matrix(has_annotation=False))

# check its performance
label_model.score(
    L=weak_labels.matrix(has_annotation=True), 
    Y=weak_labels.annotation()
)



{'accuracy': 0.7275641025641025}

In [404]:
# get your training records with the predictions of the label model
records_for_training = label_model.predict()

# log the records to a new dataset in Rubrix
rb.log(records_for_training, name="snorkel_results")

  0%|          | 0/2000 [00:00<?, ?it/s]

2000 records logged to http://localhost:6900/snorkel_results


BulkResponse(dataset='snorkel_results', processed=2000, failed=0)

## Using a pre-trained model

In [None]:
nlp = pipeline(
    "zero-shot-classification",
    model="typeform/mobilebert-uncased-mnli",
    truncation=True,
    padding=True
)

def zeroshot(record: rb.TextClassificationRecord):
    prediction = nlp(record.inputs["text"], candidate_labels=labels)
    return prediction["labels"][0]
    #if prediction["scores"][0] >=0.7:
rules.append(zeroshot)