<a href="https://colab.research.google.com/github/MoritzLaurer/zeroshot-classifier/blob/main/2_data_cleaning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Automatic Data Cleaning
This notebook uses cleanlab to automatically detect and remove likely label errors from the standardized training datasets.

### Install and setup

In [None]:
# install dependencies; `~=` pins each package to a compatible release and `-qqq` silences pip's output
!pip install transformers[sentencepiece]~=4.33.0 -qqq
!pip install datasets~=2.14.0 -qqq
!pip install accelerate~=0.21.0 -qqq
# for automatic dataset cleaning
!pip install cleanlab~=2.4.0 -qqq
!pip install sentence-transformers~=2.2.2 -qqq

In [None]:
# load packages
import transformers
import torch

import pandas as pd
import numpy as np
import os
from datasets import load_dataset

from sklearn.model_selection import train_test_split, cross_val_predict
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sentence_transformers import SentenceTransformer
from cleanlab.dataset import health_summary
from cleanlab.classification import CleanLearning
import warnings
from sklearn.exceptions import ConvergenceWarning
import time
import tqdm

from google.colab.data_table import DataTable
from google.colab import data_table
from IPython.display import display
data_table.enable_dataframe_formatter() # https://colab.research.google.com/notebooks/data_table.ipynb#scrollTo=JgBtx0xFFv_i

# set global seed for reproducibility and against seed hacking
SEED_GLOBAL = 42
# seed torch as well as numpy so that any torch-based randomness is reproducible too
# (kept before the numpy call so the cell does not end on an expression that returns a value)
torch.manual_seed(SEED_GLOBAL)
np.random.seed(SEED_GLOBAL)

In [None]:
# connect to google drive
# (Colab only: mounts the user's Drive under /content/drive; force_remount=False keeps an existing mount)
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# switch the working directory to the project folder on Drive so that the relative
# paths used further down (e.g. ./datasets_standardized) resolve there;
# the working directory is printed before and after the change for verification
print(os.getcwd())
os.chdir("/content/drive/My Drive/PhD/zero-shot-models")
print(os.getcwd())

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content
/content/drive/My Drive/PhD/zero-shot-models


### Overarching functions

In [None]:
# model for automatic cleaning
# choose model from MTEB leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# module-level embedder shared by the cleaning functions below (used via sbert_embedder.encode)
sbert_embedder = SentenceTransformer("BAAI/bge-base-en-v1.5")

In [None]:
# function for overall dataset health statistics https://docs.cleanlab.ai/stable/cleanlab/dataset.html
def dataset_health_analysis(texts_encoded=None, df=None, random_state=42):
    """Compute cleanlab's dataset-level health summary for one dataset.

    A logistic regression is cross-validated (5 folds) on the text embeddings to get
    out-of-sample predicted probabilities, which are then passed to cleanlab's
    ``health_summary`` together with the given labels.

    Args:
        texts_encoded: Feature matrix (text embeddings), one row per row of ``df``.
        df: DataFrame with the columns ``label_standard`` (class labels passed to
            cleanlab; presumably integers 0..K-1 - confirm against the dataset
            standardization step) and ``label_text`` (human-readable class name).
        random_state: Seed for the logistic regression.

    Returns:
        The summary returned by ``health_summary``; its ``classes_by_label_quality``
        and ``overlapping_classes`` tables are extended with class-name columns.
    """
    model = LogisticRegression(max_iter=500, random_state=random_state)
    pred_probs = cross_val_predict(model, texts_encoded, df["label_standard"], cv=5, method="predict_proba")

    summary = health_summary(labels=df["label_standard"], pred_probs=pred_probs, verbose=False)

    # add label text to columns
    # the map is built from the label pairs actually present in df instead of
    # re-deriving it alphabetically, so it stays correct even if label_standard
    # was not created with an alphabetical factorize
    label_text_map = (
        df.drop_duplicates("label_standard")
        .set_index("label_standard")["label_text"]
        .to_dict()
    )
    summary["classes_by_label_quality"]["Class Text"] = summary["classes_by_label_quality"]["Class Index"].map(label_text_map)
    summary["overlapping_classes"]["Class Label A"] = summary["overlapping_classes"]["Class Index A"].map(label_text_map)
    summary["overlapping_classes"]["Class Label B"] = summary["overlapping_classes"]["Class Index B"].map(label_text_map)

    return summary


# cleaning function
def cleanlab_cleaning(texts_encoded=None, df=None, random_state=42, label_quality_thresh=0.5, min_output_per_class=5, verbose=True):
    """Identify rows with likely label errors using cleanlab.

    Rows are marked for removal if either
      1. cleanlab's ``find_label_issues`` (``prune_by_noise_rate``) flags them, or
      2. their ``label_quality`` is below ``label_quality_thresh`` and they are not
         among the ``min_output_per_class`` highest-quality rows of their class.

    Args:
        texts_encoded: Feature matrix (text embeddings), one row per row of ``df``.
        df: DataFrame with the columns ``text``, ``label_text`` and ``label_standard``.
            Its index must be a fresh 0..n-1 range (``reset_index(drop=True)``) so
            that the returned indices line up with ``df.index``.
        random_state: Seed for the logistic regression.
        label_quality_thresh: Rows with a label quality below this value are candidates for removal.
        min_output_per_class: Passed to cleanlab as ``min_examples_per_class`` and used as
            the number of top-quality rows per class protected from the threshold rule.
        verbose: If True, display the label distribution before/after cleaning and
            plot a histogram of the label quality scores.

    Returns:
        Tuple ``(indices_to_remove, df_all_quality)``: the set of row indices to drop and
        the per-row quality table (``is_label_issue``, ``label_quality``, ``text``,
        ``given_label_text``, ``predicted_label_text``).
    """
    # find label issues
    model = LogisticRegression(max_iter=500, random_state=random_state)
    cl = CleanLearning(
        model, cv_n_folds=5,
        find_label_issues_kwargs={  # https://docs.cleanlab.ai/stable/cleanlab/filter.html#cleanlab.filter.find_label_issues
            "filter_by": "prune_by_noise_rate",
            "min_examples_per_class": min_output_per_class
        }
    )
    df_all_quality = cl.find_label_issues(
        X=texts_encoded, labels=df["label_standard"],
    )

    # add text and label_text
    df_all_quality["text"] = df["text"].to_list()
    df_all_quality["given_label_text"] = df["label_text"].to_list()
    # add label_text for prediction
    # the map is built from the (label_standard, label_text) pairs actually present in df,
    # so it cannot drift from the numeric labels even if they were not created with factorize
    label_text_map = (
        df.drop_duplicates("label_standard")
        .set_index("label_standard")["label_text"]
        .to_dict()
    )
    df_all_quality["predicted_label_text"] = df_all_quality.predicted_label.map(label_text_map)
    # remove unnecessary columns
    df_all_quality = df_all_quality.drop(["given_label", "predicted_label"], axis=1)

    ## cleaning decision
    # remove label issues identified with cleanlab find_label_issues method
    print("Total texts: ", len(df_all_quality))
    df_all_quality_to_remove_1 = df_all_quality[df_all_quality["is_label_issue"]]
    print("Issues identified via prune_by_noise_rate: ", len(df_all_quality_to_remove_1))
    # remove labels with low label quality, but keep at least N
    # Step 1: Sort the dataframe by quality and identify the top N for each class to keep at least N
    sorted_df = df_all_quality.sort_values(by=['given_label_text', 'label_quality'], ascending=[True, False])
    df_top_n_per_class = sorted_df.groupby('given_label_text').head(min_output_per_class)
    # Step 2: Filter out rows below the threshold
    df_low_quality = df_all_quality[df_all_quality.label_quality < label_quality_thresh]
    # Step 3: keep at least the top N for each class among low quality
    df_all_quality_to_remove_2 = df_low_quality[~df_low_quality.index.isin(df_top_n_per_class.index)]
    print("Issues identified via label quality: ", len(df_all_quality_to_remove_2))

    # union of both removal criteria
    indices_to_remove = set(df_all_quality_to_remove_1.index.to_list() + df_all_quality_to_remove_2.index.to_list())
    print("Total texts to remove: ", len(indices_to_remove))
    print("Make sure to reset index of initial df so that indices_to_remove align with df.index")

    if verbose:
        # label imbalance analysis
        df_cl = df[~df.index.isin(indices_to_remove)]
        df_label_balance_change = pd.DataFrame({
            "label_distribution_original": df.label_text.value_counts(),
            "label_distribution_clean": df_cl.label_text.value_counts(),
        })
        display(df_label_balance_change)

        # NOTE(review): no new figure is created here, so repeated calls from within one
        # cell may draw onto the same axes - confirm this is the intended display
        df_all_quality.label_quality.plot.hist(
            bins=10, xlabel="label quality", title="Label quality distribution"
        )

    return indices_to_remove, df_all_quality


# cleaning pipeline function
# wrapper around two functions above and additional methods (e.g. sampling)
def clean_dataset(dataset_path=None, max_input_per_class=10_000,
                  label_quality_thresh=0.5, min_output_per_class=10, max_output_per_class=1_000,
                  verbose=True, health_summary=False, embedding_batch_size=32):
    """Load one dataset, embed its texts and remove rows with likely label errors.

    Steps: read the parquet file, downsample each class to at most
    ``max_input_per_class`` rows, embed the texts with the module-level
    ``sbert_embedder``, run ``cleanlab_cleaning``, attach the label quality score to
    every row, drop the flagged rows and finally downsample each class to at most
    ``max_output_per_class`` rows.

    Args:
        dataset_path: Path to a parquet file with the columns ``text``, ``label_text`` and ``label_standard``.
        max_input_per_class: Maximum rows per class that are embedded and analysed.
        label_quality_thresh: Label quality threshold forwarded to ``cleanlab_cleaning``.
        min_output_per_class: Minimum rows per class to keep, forwarded to ``cleanlab_cleaning``.
        max_output_per_class: Maximum rows per class in the returned cleaned dataset.
        verbose: Forwarded to ``cleanlab_cleaning`` (tables and histogram).
        health_summary: If True, also display the dataset health tables. Inside this
            function the flag shadows the ``health_summary`` function imported from cleanlab.
        embedding_batch_size: Batch size used for encoding the texts.

    Returns:
        Tuple ``(df_cl, df_all_quality)``: the cleaned dataset (with an added
        ``label_quality`` column) and the per-row quality table for the dataset
        before removal.
    """
    print("*** Dataset being cleaned ***")
    print(f"*** {dataset_path} ***\n")
    df = pd.read_parquet(dataset_path).reset_index(drop=True)

    # downsample for faster iterations
    df = df.groupby("label_text", as_index=False, group_keys=False).apply(
        lambda x: x.sample(min(max_input_per_class, len(x)), random_state=SEED_GLOBAL)
    ).reset_index(drop=True)

    # encode texts with sbert model
    # https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode
    texts_encoded = sbert_embedder.encode(
        df["text"].to_list(),
        batch_size=embedding_batch_size, show_progress_bar=True, normalize_embeddings=False
    )

    # silence convergence warnings of the logistic regression (process-wide filter)
    warnings.simplefilter('ignore', ConvergenceWarning)

    # overall dataset health
    if health_summary:
        summary = dataset_health_analysis(texts_encoded=texts_encoded, df=df, random_state=SEED_GLOBAL)
        display(summary["classes_by_label_quality"])
        display(summary["overlapping_classes"])

    # forward the caller's min_output_per_class instead of a hard-coded value
    indices_to_remove, df_all_quality = cleanlab_cleaning(
        texts_encoded=texts_encoded,
        df=df, random_state=SEED_GLOBAL,
        label_quality_thresh=label_quality_thresh, min_output_per_class=min_output_per_class,
        verbose=verbose
    )

    # add quality score to final dataset
    # adding index column to enable merging on them and eliminate risk of merging issues due to text duplicates or similar
    df["index"] = df.index
    df_all_quality["index"] = df_all_quality.index
    df = df.merge(df_all_quality[["label_quality", "text", "given_label_text", "index"]], left_on=["text", "label_text", "index"], right_on=["text", "given_label_text", "index"], how="left")
    df = df.drop(columns=["index", "given_label_text"]).reset_index(drop=True)
    df_all_quality = df_all_quality.drop(columns=["index"]).reset_index(drop=True)

    # remove noisy rows
    df_cl = df[~df.index.isin(indices_to_remove)]

    # downsample to account for imbalance
    df_cl = df_cl.groupby("label_text", as_index=False, group_keys=False).apply(
        lambda x: x.sample(min(max_output_per_class, len(x)), random_state=SEED_GLOBAL)
    )

    return df_cl, df_all_quality



### Automatic data cleaning

In [None]:
import os

def find_train_files(directory):
    """Return the paths of all regular files in ``directory`` whose name contains "train".

    Only train datasets are cleaned, so other splits are filtered out here.
    The file names are sorted because ``os.listdir`` returns entries in arbitrary
    order; sorting keeps the processing order reproducible across runs.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        List of paths (``directory`` joined with the file name), sorted by file name.
    """
    train_files = []
    for file_name in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, file_name)
        # skip sub-directories and files from other splits
        if "train" in file_name and os.path.isfile(file_path):
            train_files.append(file_path)
    return train_files

# collect all train files from the standardized dataset directory
directory_path = './datasets_standardized'
train_files_lst = find_train_files(directory_path)
print(train_files_lst)

# datasets that are skipped: automatic cleaning won't add value for them,
# or they were already cleaned with another method
# (currently not excluded: "manifesto", "capsotu")
datasets_no_automatic_clean = [
    "trueteacher",
    "anthropic_harmless",
    "anthropic_helpful",
]
# keep a path only if none of the excluded dataset names occurs in it
train_files_lst = [
    path for path in train_files_lst
    if all(dataset not in path for dataset in datasets_no_automatic_clean)
]
print(train_files_lst)

['./datasets_standardized/ds_wellformedquery_train.gzip', './datasets_standardized/ds_rottentomatoes_train.gzip', './datasets_standardized/ds_amazonpolarity_train.gzip', './datasets_standardized/ds_imdb_train.gzip', './datasets_standardized/ds_yelpreviews_train.gzip', './datasets_standardized/ds_hatexplain_train.gzip', './datasets_standardized/ds_massive_train.gzip', './datasets_standardized/ds_banking77_train.gzip', './datasets_standardized/ds_emotiondair_train.gzip', './datasets_standardized/ds_emocontext_train.gzip', './datasets_standardized/ds_empathetic_train.gzip', './datasets_standardized/ds_agnews_train.gzip', './datasets_standardized/ds_yahootopics_train.gzip', './datasets_standardized/ds_biasframes_offensive_train.gzip', './datasets_standardized/ds_biasframes_sex_train.gzip', './datasets_standardized/ds_biasframes_intent_train.gzip', './datasets_standardized/ds_financialphrasebank_train.gzip', './datasets_standardized/ds_appreviews_train.gzip', './datasets_standardized/ds_hat

In [None]:
# cleaning loop
# `import tqdm` alone does not load the `tqdm.notebook` submodule; import it explicitly
# so the progress bar does not depend on another package having imported it as a side effect
import tqdm.notebook

# results per dataset, keyed by the path of the dataset file
df_cl_dic = {}
df_all_quality_dic = {}
for dataset_path in tqdm.notebook.tqdm(train_files_lst, desc="Dataset iterations"):

    df_cl, df_all_quality = clean_dataset(
        dataset_path=dataset_path, max_input_per_class=10_000,
        label_quality_thresh=0.5, min_output_per_class=10, max_output_per_class=10_000,
        verbose=True, health_summary=False, embedding_batch_size=64*4
    )
    df_cl_dic[dataset_path] = df_cl
    df_all_quality_dic[dataset_path] = df_all_quality

    # display sample of cleaned data
    display(df_cl.sample(frac=1.0, random_state=SEED_GLOBAL).head(10))

    # sleep to allow tables to display for inspection before next iter
    time.sleep(2)


### Overall cleaned data statistics

In [None]:

# count texts per class after and before cleaning
# counts are keyed by (dataset, class): if several datasets use the same class name,
# merging on the class name alone would pair every duplicate with every other
# duplicate and inflate the totals below
def _count_texts_per_class(df_dic, label_column):
    """Return a one-column DataFrame ("label_text") of text counts indexed by (dataset, class)."""
    counts_per_dataset = {
        dataset_path: df_dataset[label_column].value_counts()
        for dataset_path, df_dataset in df_dic.items()
    }
    if not counts_per_dataset:
        # nothing was cleaned: return an empty table instead of failing in pd.concat
        return pd.DataFrame(columns=["label_text"])
    counts = pd.concat(counts_per_dataset).rename_axis(["dataset", "class"])
    return counts.to_frame(name="label_text")


# count text per class after cleaning
n_texts_per_class_cl = _count_texts_per_class(df_cl_dic, "label_text")

# text per class before cleaning
n_texts_per_class = _count_texts_per_class(df_all_quality_dic, "given_label_text")

# right join so that a class that lost all of its texts still counts towards the "before" total
df_n_texts_per_class = n_texts_per_class_cl.merge(
    n_texts_per_class, how="right", left_index=True, right_index=True,
    suffixes=("_cl", ""), validate="one_to_one"
)
df_n_texts_per_class["label_text_cl"] = df_n_texts_per_class["label_text_cl"].fillna(0).astype(int)

# statistics:
print("Total texts before cleaning: ", df_n_texts_per_class.label_text.sum())
print("Total texts after cleaning: ", df_n_texts_per_class.label_text_cl.sum())

print("\nClasses in total: ", len(df_n_texts_per_class.label_text_cl))
print("\nTasks in total: ", len(df_cl_dic.keys()))

print("\nLabel distribution after/before cleaning:\n", df_n_texts_per_class, "\n")


Unnamed: 0,label_text_cl,label_text
Welfare State Expansion,2366,3465
Education Expansion,1468,1865
Technology and Infrastructure: Positive,1178,2075
Environmental Protection,1117,1771
Law and Order: Positive,851,1271
...,...,...
Domestic Commerce,10,89
Public Lands,10,50
Technology,10,76
Transportation,10,62


### Saving data to disk

In [None]:
# save cleaned data to disk
save_to_disk = True

if save_to_disk:
    for key_dataset_path, value_df in df_cl_dic.items():
        # e.g. ./datasets_standardized/ds_x_train.gzip -> ./datasets_clean/ds_x_train_cl.gzip
        key_dataset_path_cl = key_dataset_path.replace(".gzip", "").replace("datasets_standardized", "datasets_clean") + "_cl.gzip"
        # create the output directory if it does not exist yet, so that the write
        # does not fail after the long cleaning run
        output_dir = os.path.dirname(key_dataset_path_cl)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        value_df.to_parquet(key_dataset_path_cl, compression='gzip')
