In [1]:
from datasets import load_dataset

In [None]:
# Fetch the AG News dataset; the returned object is indexed by split name
# ("train" / "test") in the cells that follow.
dataset = load_dataset(path='ag_news')

In [None]:
# The test split is used as-is.
ds_test = dataset["test"]

# The training split is shuffled with a fixed seed so the order is repeatable.
ds_train = dataset["train"].shuffle(seed=43)

# Class names, read from the "label" feature of the training split.
# They are indexed by the integer value stored in each example's "label" field.
label_feature = ds_train.features["label"]
labels = label_feature.names

In [115]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Two-step text classifier: a count vectorizer ("vect") feeding a
# multinomial Naive Bayes model ("clf").
vectorizer = CountVectorizer()
naive_bayes = MultinomialNB()

classifier = Pipeline(
    steps=[
        ("vect", vectorizer),
        ("clf", naive_bayes),
    ]
)

In [None]:
# Train the pipeline on the shuffled training split.
train_texts = ds_train["text"]
train_targets = ds_train["label"]

classifier.fit(X=train_texts, y=train_targets)

In [117]:
# Score the fitted pipeline on the held-out test split.
test_targets = ds_test["label"]
acc = classifier.score(X=ds_test["text"], y=test_targets)

# Last expression of the cell, so the notebook displays it.
f"Test accuracy: {acc}"

'Test accuracy: 0.900921052631579'

In [118]:
# Per-class probabilities for every test example (one row per example).
test_texts = ds_test["text"]
probabilities = classifier.predict_proba(test_texts)

In [119]:
# `rb` is used from this cell onwards but was never imported, which raises a
# NameError on a fresh kernel. Import it here, following the notebook's
# convention of importing in the cell of first use.
import rubrix as rb

# Build one Rubrix record per test example:
#   - inputs:     the raw text
#   - prediction: (label name, probability) pairs from the classifier
#   - annotation: the dataset's own label, as a label name
#   - metadata:   marks the record as coming from the test split
# NOTE(review): pairing `labels` with each probability row assumes the columns
# of `probabilities` follow the same order as `labels` -- confirm against
# `classifier.classes_`.
records = [
    rb.TextClassificationRecord(
        inputs=example["text"],
        prediction=list(zip(labels, probs)),
        annotation=labels[example["label"]],
        metadata={"split": "test"},
    )
    for example, probs in zip(ds_test, probabilities)
]

In [120]:
from rubrix.labeling.text_classification import find_label_errors

# Keep only the test records that `find_label_errors` flags as potentially
# mislabeled.
records_with_label_error = find_label_errors(records=records)

In [121]:
# Send the flagged test records to the Rubrix web app for manual inspection.
# Kept as the last expression so the notebook displays the logging response.
rb.log(
    records=records_with_label_error,
    name="label_errors_in_ag_news",
)

  0%|          | 0/616 [00:00<?, ?it/s]

616 records logged to http://localhost:6900/label_errors_in_ag_news


BulkResponse(dataset='label_errors_in_ag_news', processed=616, failed=0)

In [110]:
from sklearn.model_selection import cross_val_predict

# Out-of-fold predicted probabilities for the whole dataset via cross
# validation. Train rows come first, followed by test rows, so the leading
# len(ds_train) rows of the result belong to the training split.
all_texts = ds_train["text"] + ds_test["text"]
all_targets = ds_train["label"] + ds_test["label"]

# Number of folds = how many times the test split fits into the train split
# (rounded down).
n_folds = len(ds_train) // len(ds_test)

cv_probs = cross_val_predict(
    classifier,
    X=all_texts,
    y=all_targets,
    cv=n_folds,
    method="predict_proba",
    n_jobs=-1,
)

In [112]:
# Build one Rubrix record per training example, using the cross-validated
# probabilities as predictions.
# `cv_probs` covers train + test rows (train first), so only its leading
# len(ds_train) rows are paired with the training examples.
train_probs = cv_probs[: len(ds_train)]

records = [
    rb.TextClassificationRecord(
        inputs=example["text"],
        prediction=list(zip(labels, probs)),
        annotation=labels[example["label"]],
        metadata={"split": "train"},
    )
    for example, probs in zip(ds_train, train_probs)
]

In [114]:
# Flag potentially mislabeled training records, then log them to the same
# Rubrix dataset as the test records (the "split" metadata tells them apart).
train_label_errors = find_label_errors(records)
rb.log(train_label_errors, "label_errors_in_ag_news")

  0%|          | 0/9378 [00:00<?, ?it/s]

9378 records logged to http://localhost:6900/label_errors_in_ag_news


BulkResponse(dataset='label_errors_in_ag_news', processed=9378, failed=0)