In [25]:
import os
import torch
import numpy as np
import pandas as pd
from transformers import pipeline
from huggingface_hub import HfApi, ModelFilter
from cleanlab.filter import find_label_issues
from datasets import load_dataset, load_from_disk

# transform
import sibyl
import torch
import inspect
import random
from functools import partial

from augmenter import Augmenter
from transform import Transform

In [27]:
# Reproducibility config: seed the Python, NumPy and torch global RNGs so that
# dataset shuffles and random augmentations repeat across runs (presumably the
# augmenters draw from these global RNGs — TODO confirm). Deterministic torch
# kernels stay disabled, as before.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(False)

In [3]:
def vectorize(output, label2id=None):
    """Turn one pipeline prediction into a score vector ordered by class.

    Args:
        output: list of ``{'label': str, 'score': float}`` dicts, i.e. the
            per-input result of a ``text-classification`` pipeline created
            with ``top_k=None``.
        label2id: optional mapping from label string to class index (e.g. a
            model's ``config.label2id``). When given, scores are ordered by
            that index. When omitted, scores are ordered by sorting the label
            strings (the original behaviour). That is only correct if the label
            names sort in class-id order: for example, ``LABEL_10`` sorts
            before ``LABEL_2``.

    Returns:
        1-D numpy array of scores, one per label in ``output``.

    Raises:
        KeyError: if ``label2id`` is given but lacks one of the output labels.
    """
    if label2id is None:
        def sort_key(d):
            return d['label']
    else:
        def sort_key(d):
            return label2id[d['label']]
    sorted_output = sorted(output, key=sort_key)
    return np.array([d['score'] for d in sorted_output])

def initialize_transforms(transforms, task_name):
    """Wrap each transformation class in a ``Transform`` bound to ``task_name``.

    Returns a new list in the same order as ``transforms``.
    """
    initialized = []
    for transform_cls in transforms:
        initialized.append(Transform(transform_cls, task_name=task_name))
    return initialized

In [4]:
class CleanLabFilter:
    """Find (and optionally drop) likely label errors in a Hugging Face dataset.

    A ``text-classification`` model from the Hugging Face Hub, found by the
    dataset it was trained on, supplies predicted class probabilities.
    ``cleanlab.filter.find_label_issues`` compares these against the dataset's
    ``label`` column. Datasets are expected to have ``text`` and ``label``
    columns.

    NOTE(review): cleanlab expects out-of-sample ``pred_probs``. A model
    trained on the very split being checked yields in-sample probabilities,
    which likely understates the issue rate on that split. Confirm this is
    acceptable for the analysis.
    """

    def __init__(self):
        self.api = HfApi()
        # Set by find_model_for_dataset(); None means "no model available".
        self.pipe = None
        # transformers pipeline device convention: 0 = first CUDA GPU, -1 = CPU.
        self.device = 0 if torch.cuda.is_available() else -1

    def find_model_for_dataset(self, dataset_name):
        """Load the first Hub text-classification model trained on ``dataset_name``.

        On success, ``self.pipe`` becomes a pipeline that returns a score for
        every class (``top_k=None``). If the Hub search returns no model, a
        message is printed and ``self.pipe`` is left unchanged.

        Args:
            dataset_name: dataset identifier passed as ``trained_dataset`` to
                the Hub model filter.
        """
        model_filter = ModelFilter(
            task="text-classification",
            library="pytorch",
            trained_dataset=dataset_name)

        # next(..., None) rather than bare next(): an empty search result must
        # not escape as StopIteration.
        model_info = next(iter(self.api.list_models(filter=model_filter)), None)
        if model_info is None:
            print(f"No text-classification model found for {dataset_name!r}; "
                  "cleanlab processing disabled.")
            return

        model_id = model_info.modelId
        print('Using ' + model_id + ' to support cleanlab datalabel issues.')
        self.pipe = pipeline("text-classification",
                             model=model_id,
                             device=self.device,
                             top_k=None)

    def extract_prediction_probabilities(self, dataset):
        """Return an ``(n_rows, n_model_classes)`` array of predicted scores.

        Runs ``self.pipe`` over the ``text`` column and converts each row's list
        of label scores into a vector with ``vectorize``.
        """
        # NOTE(review): vectorize orders columns by sorting the label strings.
        # This matches the dataset's integer labels only when the model's label
        # names sort in class-id order (presumably true for NEGATIVE/POSITIVE).
        output = self.pipe(dataset['text'])
        return np.stack([vectorize(o) for o in output])

    def find_num_to_remove_per_class(self, dataset, frac_to_remove=0.1):
        """Return ``int(frac_to_remove * count)`` for each dataset class.

        Classes are taken from ``dataset.features['label'].names``. Label values
        outside ``range(num_classes)`` (e.g. -1) are not counted.
        """
        num_classes = len(dataset.features['label'].names)
        # Make one vectorised pass over the label column instead of running a
        # full dataset.filter() per class.
        labels = np.asarray(dataset['label'])
        return [int(frac_to_remove * int(np.count_nonzero(labels == i)))
                for i in range(num_classes)]

    def label_issue_rate(self, dataset):
        """Return the fraction of rows that cleanlab flags as label issues.

        Returns:
            float in [0, 1], or None when no model has been loaded.
        """
        if self.pipe is None:
            return None

        pred_probs = self.extract_prediction_probabilities(dataset)
        print(f"pred_probs.shape ({pred_probs.shape})")

        suss_idx = find_label_issues(
            labels=dataset['label'],
            pred_probs=pred_probs,
            return_indices_ranked_by='self_confidence'
        )
        return len(suss_idx) / len(dataset)

    def annotate_dataset(self, dataset):
        """Return ``dataset`` with a boolean ``cleanlab_flagged`` column added.

        Raises:
            RuntimeError: if no model has been loaded (``self.pipe is None``).
        """
        if self.pipe is None:
            raise RuntimeError(
                "No model loaded; call find_model_for_dataset() first.")

        pred_probs = self.extract_prediction_probabilities(dataset)
        print(f"pred_probs.shape ({pred_probs.shape})")

        # Without return_indices_ranked_by, find_label_issues returns a
        # boolean mask with one entry per row.
        cleanlab_flagged = find_label_issues(
            labels=dataset['label'],
            pred_probs=pred_probs,
        )
        return dataset.add_column("cleanlab_flagged",
                                  np.asarray(cleanlab_flagged).tolist())

    def clean_dataset(self, dataset, frac_to_remove=0.1):
        """Return ``dataset`` without the rows cleanlab ranks as most suspicious.

        At most about ``frac_to_remove`` of each class is removed. When no model
        has been loaded, ``dataset`` is returned unchanged.
        """
        if self.pipe is None:
            return dataset

        pred_probs = self.extract_prediction_probabilities(dataset)
        print(f"pred_probs.shape ({pred_probs.shape})")

        num_to_remove_per_class = self.find_num_to_remove_per_class(
            dataset, frac_to_remove)
        print(f"num_to_remove_per_class ({num_to_remove_per_class})")

        # The model may expose more classes than the dataset declares. cleanlab
        # wants one entry per pred_probs column, so pad extra classes with 0.
        num_to_add = pred_probs.shape[-1] - len(num_to_remove_per_class)
        print(f"num_to_add: {num_to_add}")
        num_to_remove_per_class.extend([0] * max(num_to_add, 0))

        suss_idx = find_label_issues(
            labels=dataset['label'],
            pred_probs=pred_probs,
            return_indices_ranked_by='self_confidence',
            num_to_remove_per_class=num_to_remove_per_class
        )
        print(f"suss_idx.len ({len(suss_idx)})")

        # Set membership keeps this O(n). `i not in ndarray` scanned the whole
        # index array for every row.
        flagged = set(np.asarray(suss_idx).tolist())
        idx_to_keep = [i for i in range(len(dataset)) if i not in flagged]
        print(f"idx_to_keep.len ({len(idx_to_keep)})")
        return dataset.select(idx_to_keep)

In [5]:
# dataset_name = "snips_built_in_intents"
# dataset = load_dataset(dataset_name)['train']

# print("Using cleanlab to cleanup dataset...")
# print(f"Original dataset length: {len(dataset)}")
# cl_filter = CleanLabFilter()
# cl_filter.find_model_for_dataset(dataset_name)
# dataset = cl_filter.annotate_dataset(dataset)
# print(f"Filtered dataset length: {len(dataset)}")

# Running cleanlab on the augmented datasets (FADA and uniform)

In [None]:
# Load the original SST-2 training split (renaming "sentence" to "text", the
# column CleanLabFilter reads) and two pre-computed augmented datasets from disk.
# The shuffles use a fixed seed so the row order is reproducible across runs.
dataset_config = ("glue", "sst2")
dataset = load_dataset(*dataset_config)['train']
dataset = dataset.rename_column("sentence", "text")

SHUFFLE_SEED = 42
fada_dataset = load_from_disk("./fada/datasets/glue.sst2.sibyl.fada.test").shuffle(seed=SHUFFLE_SEED)
uniform_dataset = load_from_disk("./fada/datasets/glue.sst2.sibyl.uniform").shuffle(seed=SHUFFLE_SEED)

Downloading builder script:   0%|          | 0.00/28.8k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/27.9k [00:00<?, ?B/s]

Downloading and preparing dataset glue/sst2 to /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading data:   0%|          | 0.00/7.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Search the Hub for a text-classification model trained on "sst2" and load it
# as the source of predicted probabilities for cleanlab.
cl_filter = CleanLabFilter()
cl_filter.find_model_for_dataset(dataset_config[1])

Using distilbert-base-uncased-finetuned-sst-2-english to support cleanlab datalabel issues.


Downloading (…)lve/main/config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [None]:
# Add a per-row boolean `cleanlab_flagged` column to the SST-2 training split.
dataset = cl_filter.annotate_dataset(dataset)

pred_probs.shape ((67349, 2))


In [None]:
# Write the annotated training split to CSV (pandas' default index column included).
dataset.to_pandas().to_csv("glue.sst2.annotated.csv")

In [None]:
# Flag likely label issues in both augmented datasets with the same model.
uniform_dataset = cl_filter.annotate_dataset(uniform_dataset)
fada_dataset = cl_filter.annotate_dataset(fada_dataset)

In [None]:
# Fraction of rows in the original training split that cleanlab flags as
# likely label issues (None if no model was loaded).
original_label_issue_rate = cl_filter.label_issue_rate(dataset)
print(f"original_label_issue_rate: {original_label_issue_rate}")

pred_probs.shape ((67349, 2))
original_label_issue_rate: 0.0031329344162496844


In [None]:
# Same metric for the uniformly augmented dataset.
uniform_label_issue_rate = cl_filter.label_issue_rate(uniform_dataset)
print(f"uniform_label_issue_rate: {uniform_label_issue_rate}")

pred_probs.shape ((404094, 2))
uniform_label_issue_rate: 0.23063445633936658


In [None]:
# Same metric for the FADA-augmented dataset.
fada_label_issue_rate = cl_filter.label_issue_rate(fada_dataset)
print(f"fada_label_issue_rate: {fada_label_issue_rate}")

pred_probs.shape ((404094, 2))
fada_label_issue_rate: 0.24663320910481223


# Running cleanlab on datasets augmented with one Sibyl transformation at a time

In [6]:
# Sibyl transformations excluded from the per-transform sweep. Presumably these
# are dropped because they may alter the sentiment label (emoji, phrases, links,
# negation, antonyms) or mix content across examples — TODO confirm the
# selection criterion. (AddPositiveEmoji/AddNegativeEmoji were listed twice;
# the duplicates are removed.)
blacklist = [
    sibyl.Emojify,
    sibyl.AddPositiveEmoji,
    sibyl.AddNegativeEmoji,
    sibyl.Demojify,
    sibyl.RemovePositiveEmoji,
    sibyl.RemoveNegativeEmoji,
    sibyl.InsertPositivePhrase,
    sibyl.InsertNegativePhrase,
    sibyl.AddPositiveLink,
    sibyl.AddNegativeLink,
    sibyl.ImportLinkText,
    sibyl.AddNegation,
    sibyl.RemoveNegation,
    sibyl.ChangeAntonym,
    sibyl.ConceptMix,
    sibyl.TextMix,
    sibyl.SentMix,
    sibyl.WordMix,
    sibyl.Concept2Sentence,
]
excluded = set(blacklist)
# Keep the remaining transformation classes in a stable, name-sorted order,
# then wrap each one for the sentiment task.
sibyl_transforms = sorted(
    (t for t in sibyl.TRANSFORMATIONS if t not in excluded),
    key=lambda t: t.__name__,
)
sibyl_transforms = initialize_transforms(sibyl_transforms, "sentiment")

initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', return_metadata=True
initializing class with task_name='sentiment', r

In [7]:
# SST-2 training split, with "sentence" renamed to the "text" column that
# CleanLabFilter and the augmenter read.
dataset_config = ("glue", "sst2")
dataset = (
    load_dataset(*dataset_config)["train"]
    .rename_column("sentence", "text")
)

Found cached dataset glue (C:/Users/Fabrice/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
# Load the Hub model trained on "sst2" that supplies predicted probabilities
# for cleanlab in the per-transform sweep below.
cl = CleanLabFilter()
cl.find_model_for_dataset(dataset_config[1])

Using distilbert-base-uncased-finetuned-sst-2-english to support cleanlab datalabel issues.


In [30]:
# For each remaining Sibyl transformation:
#   1. augment every training example once with that transformation alone,
#   2. flag likely label issues with cleanlab, and
#   3. save the result to disk.
# Then combine all transformed datasets side by side into one CSV.
save_dir = "./datasets/"

num_examples = len(dataset)
# One row per example and one column per transform passed to Augmenter (always
# a single transform here), all with weight 1. Presumably these are relative
# selection weights — TODO confirm against Augmenter.
uniform_policy = np.full((num_examples, 1), fill_value=1)

aug_datasets = []
for t in sibyl_transforms:
    t_name = t.transform_class.__name__

    # augment dataset
    augmenter = Augmenter(dataset=dataset,
                          transforms=[t],
                          transform_probabilities=uniform_policy,
                          num_augmentations_per_record=1,
                          num_transforms_to_apply=1,
                          keep_originals=False)
    aug_dataset = augmenter.augment()

    # annotate with cleanlab
    aug_dataset = cl.annotate_dataset(aug_dataset)

    # save to disk
    save_file = ".".join(dataset_config) + ".sibyl." + t_name
    save_path = os.path.join(save_dir, save_file)
    aug_dataset.save_to_disk(save_path)

    aug_datasets.append(aug_dataset)

dataset_dfs = []
for d, t in zip(aug_datasets, sibyl_transforms):
    t_name = t.transform_class.__name__
    if "idx" in d.column_names:
        d = d.remove_columns("idx")
    # Name the text column after the transform, and prefix every other column
    # (label, cleanlab_flagged, ...) with it too. Otherwise the side-by-side
    # concat below yields repeated "label"/"cleanlab_flagged" headers that
    # cannot be attributed to a transform.
    d = d.rename_columns({
        col: (t_name if col == "text" else f"{t_name}_{col}")
        for col in d.column_names
    })
    dataset_dfs.append(d.to_pandas())

df = pd.concat(dataset_dfs, axis=1)
df.to_csv("glue.sst2.sibyl.all.cleanlab.csv")

  0%|          | 0/6735 [00:00<?, ?ba/s]

pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

  0%|          | 0/6735 [00:00<?, ?ba/s]

pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

  0%|          | 0/6735 [00:00<?, ?ba/s]

pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-75b2471874ac4c1d.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-6c397029904a5279.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-1c77930be76bc719.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-40f8eba90d6ca743.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-4fe2557bafede195.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-91b00be14dfcf908.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-1e459eb654705959.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-5711ec8148c075a4.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-b46db288bc005aa1.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-5d301ffbb1bf8041.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-8a42b4a2921e19ba.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-c76b91329be122d2.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-642e3f3b9a924b77.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-5455e0e2737a3d84.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-6c9fb33fad4cc2aa.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-d40fd69062ba1606.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\cache-839a00b2fd09fffc.arrow


pred_probs.shape ((67349, 2))


Saving the dataset (0/1 shards):   0%|          | 0/67349 [00:00<?, ? examples/s]