# Text Classification with TensorFlow, Keras, and Cleanlab


In [168]:
# Package installation (hidden on docs.cleanlab.ai).
# If running on Colab, may want to use GPU (select: Runtime > Change runtime type > Hardware accelerator > GPU)
# Package versions we used: tensorflow==2.9.1 scikeras==0.6.1 tensorflow_datasets==4.5.2
dependencies = ["cleanlab", "sklearn", "tensorflow", "tensorflow_datasets", "scikeras"]

# Supress outputs that may appear if tensorflow happens to be improperly installed: 
import os 
import logging 
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # suppress tensorflow log output 
logging.getLogger('tensorflow').setLevel(logging.FATAL) 

if "google.colab" in str(get_ipython()):  # Check if it's running in Google Colab
    %pip install cleanlab==v2.2.0
    cmd = ' '.join([dep for dep in dependencies if dep != "cleanlab"])
    %pip install $cmd
else:
    missing_dependencies = []
    for dependency in dependencies:
        try:
            __import__(dependency)
        except ImportError:
            missing_dependencies.append(dependency)

    if len(missing_dependencies) > 0:
        print("Missing required dependencies:")
        print(*missing_dependencies, sep=", ")
        print("\nPlease install them before running the rest of this notebook.")

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [169]:
import re 
import string 
import collections
import pathlib
import tensorflow as tf 
import pandas as pd 
import tensorflow_datasets as tfds 
from sklearn.metrics import accuracy_score, log_loss 
from sklearn.model_selection import cross_val_predict 
from tensorflow.keras import layers 
from scikeras.wrappers import KerasClassifier 
from tensorflow.keras import losses
from tensorflow.keras import utils
from tensorflow.keras.layers import TextVectorization


SEED = 123456  # single seed reused wherever this notebook needs reproducible randomness

In [170]:
# This cell is hidden from docs.cleanlab.ai 

import random 
import numpy as np 

# Never truncate cell contents when a DataFrame is displayed, so that long
# texts can be read in full.
pd.set_option("display.max_colwidth", None)

# Seed every random number generator in play with the same value.
# NOTE(review): tf.keras.utils.set_random_seed is documented to seed Python,
# NumPy and TensorFlow on its own, so the two explicit calls look redundant
# (but harmless) -- confirm before removing them.
random.seed(SEED)
np.random.seed(SEED)
tf.keras.utils.set_random_seed(SEED)

In [171]:
# Remote location of the gzipped tarball that holds the dataset.
data_url = 'https://github.com/traaaariad/AutomaticTestHomework/raw/main/txt10.tar.gz'

# Download the archive and unpack it (untar=True). `get_file` returns a path
# inside the cache; the directory one level up is kept as `dataset_dir`.
dataset_dir = pathlib.Path(
    utils.get_file(origin=data_url, untar=True, cache_dir='txt', cache_subdir='')
).parent


In [172]:
# Show the entries that the download/extraction step left in the dataset directory.
list(dataset_dir.iterdir())

[PosixPath('/tmp/.keras/train'),
 PosixPath('/tmp/.keras/test'),
 PosixPath('/tmp/.keras/txt10.tar.gz')]

In [173]:
# Point at the "train" sub-directory of the dataset and list its entries.
train_dir = dataset_dir/'train'
list(train_dir.iterdir())

[PosixPath('/tmp/.keras/train/open'),
 PosixPath('/tmp/.keras/train/close'),
 PosixPath('/tmp/.keras/train/unknown')]

In [174]:
# Number of examples grouped into each batch yielded by the dataset.
batch_size = 1024
# Build a dataset of (text, label) batches from the files under `train_dir`.
# NOTE(review): labels are presumably inferred from the sub-directory names
# of `train_dir` -- confirm against the Keras documentation.
raw_full_ds = utils.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    seed=SEED)

Found 39434 files belonging to 3 classes.


In [175]:
# Draw exactly one batch (at most `batch_size` examples) from the dataset.
# NOTE(review): only this first batch feeds the rest of the notebook, although
# the directory holds far more files -- confirm that this subsampling is intended.
text_batch, label_batch = next(iter(raw_full_ds.take(1)))

# Convert the tensors to NumPy once, instead of on every printed line.
raw_full_texts = text_batch.numpy()
full_labels = label_batch.numpy()

# Preview up to ten examples; slicing keeps this safe for batches shorter than ten.
for text, label in zip(raw_full_texts[:10], full_labels[:10]):
    print("Error:", text)
    print("Label:", label)

Error: b'Parameter docType should be final'
Label: 2
Error: b'Method is missing a javadoc comment'
Label: 1
Error: b'Parameter docType should be final'
Label: 2
Error: b'Since class "GeneratedCriteria" is designed to be inheritable, add javadoc documentation for overridable non-null method "addCriterion"'
Label: 1
Error: b'Parameter docType should be final'
Label: 2
Error: b'Parameter docType should be final'
Label: 2
Error: b'Parameter docType should be final'
Label: 2
Error: b'Parameter docType should be final'
Label: 2
Error: b'Parameter docType should be final'
Label: 2
Error: b'Line longer than 80 characters'
Label: 1


In [176]:
# Show the label and raw text of a single example from the loaded arrays.
i = 0
print(f"Example Label: {full_labels[i]}")
print(f"Example Text: {raw_full_texts[i]}")

Example Label: 2
Example Text: b'Parameter docType should be final'


In [177]:
# Print which class name each integer label stands for.
for i, label in enumerate(raw_full_ds.class_names):
  print("Label", i, "corresponds to", label)

Label 0 corresponds to close
Label 1 corresponds to open
Label 2 corresponds to unknown


In [178]:
def preprocess_text(input_data):
    """Standardize raw text before tokenization.

    Applies three steps in order: lowercase the text, replace every literal
    "<br />" with a single space, and delete every character contained in
    ``string.punctuation``.

    Args:
        input_data: String input accepted by the ``tf.strings`` operations.

    Returns:
        The result of the final ``tf.strings.regex_replace`` call.
    """
    punctuation_pattern = f"[{re.escape(string.punctuation)}]"
    text = tf.strings.lower(input_data)
    text = tf.strings.regex_replace(text, "<br />", " ")
    text = tf.strings.regex_replace(text, punctuation_pattern, "")
    return text

In [179]:
# Upper bound on the vocabulary size kept by the vectorization layer.
max_features = 10000
# Fixed length of every integer sequence the layer outputs.
sequence_length = 250

# Layer that standardizes text with `preprocess_text` and maps it to
# integer token ids.
vectorize_layer = TextVectorization(
    max_tokens=max_features,
    standardize=preprocess_text,
    output_sequence_length=sequence_length,
    output_mode="int",
)

In [180]:
%%capture

# Fit the layer's vocabulary on the raw texts.
vectorize_layer.adapt(raw_full_texts)
# Turn every text into its integer sequence and materialize the result as a NumPy array.
full_texts = vectorize_layer(raw_full_texts).numpy()

In [181]:
def get_net(num_classes=3):
    """Build and compile the text classification network.

    The network takes variable-length sequences of integer token ids and
    applies, in order: embedding, dropout, global average pooling, a dense
    layer with ``num_classes`` units and a softmax.

    Args:
        num_classes: Number of target classes, i.e. the size of the softmax
            output. Defaults to 3, matching the close/open/unknown labels of
            this dataset.

    Returns:
        A compiled ``tf.keras.Sequential`` model that outputs one probability
        per class.
    """
    net = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(None,), dtype="int64"),
            layers.Embedding(max_features + 1, 16),
            layers.Dropout(0.2),
            layers.GlobalAveragePooling1D(),
            layers.Dense(num_classes),
            layers.Softmax(),
        ]
    )  # outputs one probability per class

    net.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        # Labels are integer class ids (not one-hot vectors), so the sparse
        # accuracy variant is required; CategoricalAccuracy would compare
        # against an argmax of the label itself and report a wrong value.
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return net

In [182]:
# Wrap the Keras network in a SciKeras classifier so that it can be handed to
# scikit-learn utilities; every fit trains for 20 epochs.
# NOTE(review): an already-built model instance is passed here rather than the
# `get_net` factory -- confirm that SciKeras gives each fit a fresh copy.
model = KerasClassifier(get_net(), epochs=20)

In [183]:
num_crossval_folds = 5  # number of cross-validation folds
# Collect the `predict_proba` output for every example via cross-validation
# over the vectorized texts and their given labels.
pred_probs = cross_val_predict(
    model,
    full_texts,
    full_labels,
    cv=num_crossval_folds,
    method="predict_proba",
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20

In [184]:
# Compare the cross-validated probabilities with the given labels.
loss = log_loss(full_labels, pred_probs)  # score to evaluate probabilistic predictions, lower values are better
print(f"Cross-validated estimate of log loss: {loss:.3f}")

Cross-validated estimate of log loss: 0.665


In [185]:
from cleanlab.filter import find_label_issues

# Indices of the examples cleanlab flags as potential label errors, returned
# in the order given by the "self_confidence" ranking.
ranked_label_issues = find_label_issues(
    labels=full_labels, pred_probs=pred_probs, return_indices_ranked_by="self_confidence"
)

In [186]:
# Report how many examples were flagged and show the first ten flagged indices.
print(
    f"cleanlab found {len(ranked_label_issues)} potential label errors.\n"
    f"Here are indices of the top 10 most likely errors: \n {ranked_label_issues[:10]}"
)

cleanlab found 179 potential label errors.
Here are indices of the top 10 most likely errors: 
 [309 359 315  38 991 425 562 135  41 157]


In [187]:
def print_as_df(index, texts=None, labels=None):
    """Return the selected examples as a DataFrame with "texts" and "labels" columns.

    Despite its name, this function prints nothing; it returns the frame so
    that the notebook can render it as the cell output.

    Args:
        index: A single integer position or a sequence of integer positions
            of the examples to show. A sequence makes it possible to inspect
            several examples in one call.
        texts: Optional array-like of texts to index into. Defaults to the
            notebook-level ``raw_full_texts``.
        labels: Optional array-like of labels to index into. Defaults to the
            notebook-level ``full_labels``.

    Returns:
        A ``pandas.DataFrame`` with one row per requested position, indexed
        by those positions.
    """
    if texts is None:
        texts = raw_full_texts
    if labels is None:
        labels = full_labels

    # Promote a scalar position to a one-element array so that both the
    # scalar and the sequence case go through the same code path.
    rows = np.atleast_1d(index)
    return pd.DataFrame(
        {"texts": np.asarray(texts)[rows], "labels": np.asarray(labels)[rows]},
        index=rows,
    )

In [188]:
# Repeat the integer-label-to-class-name mapping for reference while reading the examples below.
for i, label in enumerate(raw_full_ds.class_names):
  print("Label", i, "corresponds to", label)

Label 0 corresponds to close
Label 1 corresponds to open
Label 2 corresponds to unknown


In [189]:
# Inspect the first entry of the ranked label issues.
print_as_df(ranked_label_issues[0])

Unnamed: 0,texts,labels
309,"b'Redundant ""Public"" modifier.'",0


In [190]:
# Inspect the second entry of the ranked label issues.
print_as_df(ranked_label_issues[1])

Unnamed: 0,texts,labels
359,b'Parameter docType should be final',0


In [191]:
# Inspect the third entry of the ranked label issues.
print_as_df(ranked_label_issues[2])

Unnamed: 0,texts,labels
315,"b""'X' is a magic number (immediate constant)""",0


In [192]:
# Inspect the fourth entry of the ranked label issues.
print_as_df(ranked_label_issues[3])

Unnamed: 0,texts,labels
38,"b'Variable ""ABC"" must match pattern ""[a-zA-Z0-9]*$""'",0
