# 🧐 Using Rubrix to find label errors with cleanlab 

In [None]:
# Uncomment to install dependencies (%pip installs into the running kernel's environment).
# cleanlab is pinned below 2.0: this notebook imports cleanlab.pruning.get_noise_indices,
# the cleanlab 1.x API that cleanlab 2.0 replaced (cleanlab.filter.find_label_issues).
# %pip install -U transformers datasets torch rubrix
# %pip install "cleanlab<2"

In [142]:
import datasets
import numpy as np
import pandas as pd
import rubrix as rb
import torch
from cleanlab.pruning import get_noise_indices
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [165]:
# Connect to the Rubrix server; with no arguments the client uses its default /
# environment-provided connection settings (NOTE(review): confirm server URL/API key env vars).
rb.init()

In [None]:
# NLI classifier whose class probabilities are compared against the gold MNLI labels.
MODEL_NAME = "typeform/distilbert-base-uncased-mnli"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

In [218]:
# Load the MNLI matched validation split (the examples we audit for label errors).
dataset = datasets.load_dataset(
    'multi_nli',
    split='validation_matched',
)

Using custom data configuration default
Reusing dataset multi_nli (/home/david/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


In [219]:
# Peek at one example to see the available fields.
first_example = dataset[0]
first_example

{'promptID': 63735,
 'pairID': '63735n',
 'premise': 'The new rights are nice enough',
 'premise_binary_parse': '( ( The ( new rights ) ) ( are ( nice enough ) ) )',
 'premise_parse': '(ROOT (S (NP (DT The) (JJ new) (NNS rights)) (VP (VBP are) (ADJP (JJ nice) (RB enough)))))',
 'hypothesis': 'Everyone really likes the newest benefits ',
 'hypothesis_binary_parse': '( Everyone ( really ( likes ( the ( newest benefits ) ) ) ) )',
 'hypothesis_parse': '(ROOT (S (NP (NN Everyone)) (VP (ADVP (RB really)) (VBZ likes) (NP (DT the) (JJS newest) (NNS benefits)))))',
 'genre': 'slate',
 'label': 1}

In [220]:
# get model predictions
def predict_probs(example):
    """Return the model's class probabilities for one premise/hypothesis pair.

    Args:
        example: a dataset row with ``premise`` and ``hypothesis`` strings.

    Returns:
        ``{"prob": [p_0, p_1, ...]}`` — softmax over the model's logits.
    """
    # truncation=True cuts over-long pairs to the tokenizer's max length instead
    # of raising inside the model.
    encoded = tokenizer(
        example["premise"],
        example["hypothesis"],
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: skip building an autograd graph for every example.
    with torch.no_grad():
        logits = model(**encoded).logits
    # A single pair yields one row of logits; take it as a plain list of floats.
    return {"prob": torch.softmax(logits, dim=1)[0].tolist()}


probs_ds = dataset.map(predict_probs, remove_columns=dataset.column_names)

  0%|          | 0/9815 [00:00<?, ?ex/s]

In [221]:
# Combine the raw examples with the model's probabilities, one row per example
# (probs_ds comes from dataset.map, so rows line up with `dataset`).
frame_columns = {
    "premise": dataset["premise"],
    "hypothesis": dataset["hypothesis"],
    "probs": probs_ds["prob"],
    "label": dataset["label"],
}
data_df = pd.DataFrame(frame_columns)

In [222]:
# Stack the per-example probability lists into a 2-D (examples x classes) array for cleanlab.
probs_matrix = np.array(list(data_df["probs"]))

In [223]:
# Ask cleanlab which given labels look inconsistent with the model's probabilities.
# sorted_index_method=None returns a boolean mask over rows (not a ranked index
# array), which is how data_df is indexed when building the records.
label_errors = get_noise_indices(
    psx=probs_matrix,
    s=data_df["label"].to_numpy(),
    sorted_index_method=None,
)

In [224]:
# MNLI label names indexed by label id; also the order in which the model's
# output probabilities are paired with names.
MNLI_LABELS = ("entailment", "neutral", "contradiction")


def mnli_label_name(label_id):
    """Return the MNLI label name for an integer label id.

    Raises:
        ValueError: if ``label_id`` is not a valid index into ``MNLI_LABELS``
            (e.g. a negative id), rather than silently defaulting to a label.
    """
    idx = int(label_id)
    if not 0 <= idx < len(MNLI_LABELS):
        raise ValueError(f"unexpected MNLI label id: {label_id!r}")
    return MNLI_LABELS[idx]


def make_rec(row):
    """Build a Rubrix text-classification record for one premise/hypothesis pair.

    Args:
        row: a row of ``data_df`` exposing ``premise``, ``hypothesis``, ``probs``
            (model class probabilities) and ``label`` (integer MNLI label id).

    Returns:
        An ``rb.TextClassificationRecord`` with the model probabilities as the
        prediction and the dataset's gold label as the annotation.

    Raises:
        ValueError: if ``row.label`` is not a known MNLI label id.
    """
    # NOTE(review): assumes the model's output columns follow the MNLI label ids
    # (0=entailment, 1=neutral, 2=contradiction) — confirm via model.config.id2label.
    preds = list(zip(MNLI_LABELS, row.probs))
    inputs = {"premise": row.premise, "hypothesis": row.hypothesis}
    return rb.TextClassificationRecord(
        inputs=inputs,
        prediction=preds,
        prediction_agent="typeform/distilbert-base-uncased-mnli",
        annotation=mnli_label_name(row.label),
        annotation_agent="mnli",
    )

# Build one Rubrix record per flagged example. `.loc` accepts both the boolean mask
# cleanlab returns with sorted_index_method=None and a ranked index array (data_df
# has the default RangeIndex, so labels equal positions). Building the Series from a
# list keeps `recs` a Series even when nothing is flagged (an empty
# DataFrame.apply(axis=1) would return a DataFrame instead).
flagged_df = data_df.loc[label_errors]
recs = pd.Series([make_rec(row) for _, row in flagged_df.iterrows()], dtype=object)

In [225]:
# Upload the flagged examples to a Rubrix dataset for manual review.
records_to_log = recs.to_list()
rb.log(records=records_to_log, name="mnli_label_error_m")

BulkResponse(dataset='mnli_label_error_m', processed=1560, failed=0)