# Example 3 - Using a Hugging Face Sequence Classifier

This notebook demonstrates how to use the TX2 dashboard with a transformer sequence classifier built from one of the prebuilt Hugging Face sequence classification models.

In [1]:
%cd -q ..
%load_ext autoreload
%autoreload 2
%matplotlib inline

We enable logging so that we can see the output from `wrapper.prepare()` further down in the notebook. (It is a long-running function, and it logs which step it is on.)

In [2]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [5]:
import numpy as np
import torch
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from torch import cuda
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from datasets import load_metric
import pandas as pd

In this example notebook, we use the 20 newsgroups dataset, which can be downloaded from the Hugging Face Hub as shown below:

In [6]:
from datasets import load_dataset

# getting newsgroups data from huggingface
train_data = pd.DataFrame(data=load_dataset("SetFit/20_newsgroups", split="train"))
test_data = pd.DataFrame(data=load_dataset("SetFit/20_newsgroups", split="test"))

# setting up pytorch device
if cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

device

Using custom data configuration SetFit--20_newsgroups-f9362e018b6adf67
Reusing dataset json (/home/81n/.cache/huggingface/datasets/SetFit___json/SetFit--20_newsgroups-f9362e018b6adf67/0.0.0/a3e658c4731e59120d44081ac10bf85dc7e1388126b92338344ce9661907f253)
Using custom data configuration SetFit--20_newsgroups-f9362e018b6adf67
Reusing dataset json (/home/81n/.cache/huggingface/datasets/SetFit___json/SetFit--20_newsgroups-f9362e018b6adf67/0.0.0/a3e658c4731e59120d44081ac10bf85dc7e1388126b92338344ce9661907f253)


'cuda'

Some simple data cleaning, keeping all of the data in dataframes for the wrapper:

In [7]:
def clean_text(text):
    text = str(text)
    # text = text[text.index("\n\n") + 2 :]
    text = text.replace("\n", " ")
    text = text.replace("    ", " ")
    text = text.replace("   ", " ")
    text = text.replace("  ", " ")
    text = text.strip()
    return text
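The chained `replace()` calls in `clean_text` do not fully collapse long runs of spaces: a run of 20 spaces, for instance, ends up as two spaces. A minimal alternative sketch using a regular expression follows; `clean_text_regex` is a hypothetical name, not part of TX2, and unlike the original it also collapses tabs and other whitespace characters.

```python
import re

# Hypothetical helper (not part of the notebook or TX2): collapse every run of
# whitespace (spaces, tabs, newlines) into a single space in one pass.
_WHITESPACE_RUN = re.compile(r"\s+")


def clean_text_regex(text) -> str:
    """Normalize whitespace; str() mirrors clean_text's handling of non-strings."""
    return _WHITESPACE_RUN.sub(" ", str(text)).strip()


# A 20-space run is only partially collapsed by clean_text's chained
# replace() calls, but is fully collapsed here.
sample = "hello" + " " * 20 + "world\n\nbye"
print(repr(clean_text_regex(sample)))  # 'hello world bye'
```

It can be swapped into the next cell as `train_data.text.apply(clean_text_regex)`.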

In [8]:
# collapse long runs of whitespace and line breaks
train_data.text = train_data.text.apply(lambda x: clean_text(x))
test_data.text = test_data.text.apply(lambda x: clean_text(x))

# remove empty entries or trivially short ones
train_cleaned = train_data[train_data["text"].str.len() > 1]
test_cleaned = test_data[test_data["text"].str.len() > 1]
train_cleaned

|  | text | label | label_text |
|---|---|---|---|
| 0 | I was wondering if anyone out there could enli... | 7 | rec.autos |
| 1 | A fair number of brave souls who upgraded thei... | 4 | comp.sys.mac.hardware |
| 2 | well folks, my mac plus finally gave up the gh... | 4 | comp.sys.mac.hardware |
| 3 | Do you have Weitek's address/phone number? I'd... | 1 | comp.graphics |
| 4 | From article <C5owCB.n3p@world.std.com>, by to... | 14 | sci.space |
| ... | ... | ... | ... |
| 11309 | DN> From: nyeda@cnsvax.uwec.edu (David Nye) DN... | 13 | sci.med |
| 11310 | I have a (very old) Mac 512k and a Mac Plus, b... | 4 | comp.sys.mac.hardware |
| 11311 | I just installed a DX2-66 CPU in a clone mothe... | 3 | comp.sys.ibm.pc.hardware |
| 11312 | Wouldn't this require a hyper-sphere. In 3-spa... | 1 | comp.graphics |


In [9]:
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=20)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [20]:
class EncodedSet(Dataset):
    def __init__(self, dataframe: pd.DataFrame, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(self.len)

    def __getitem__(self, index):
        text = str(self.data.text[index])
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True,
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]

        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(mask, dtype=torch.long),
            "labels": torch.tensor(self.data.label[index], dtype=torch.long),
        }

    def __len__(self):
        return self.len


train_cleaned.reset_index(drop=True, inplace=True)
test_cleaned.reset_index(drop=True, inplace=True)

train_set = EncodedSet(train_cleaned, tokenizer, 256)
test_set = EncodedSet(test_cleaned[:1000], tokenizer, 256)

11014
1000
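Because `EncodedSet` calls `encode_plus` with `padding="max_length"` and `truncation=True`, every item it returns has exactly `max_len` (here 256) positions. The following is a minimal, dependency-free sketch of that layout for a single sequence, assuming the text has already been converted to token ids. `pad_and_truncate` is hypothetical (the real work is done by the tokenizer), and the default special-token ids are assumptions that should be read from `tokenizer.cls_token_id`, `tokenizer.sep_token_id`, and `tokenizer.pad_token_id` in real code.

```python
from typing import Dict, List


def pad_and_truncate(
    token_ids: List[int],
    max_len: int,
    cls_id: int = 101,  # assumed [CLS] id
    sep_id: int = 102,  # assumed [SEP] id
    pad_id: int = 0,    # assumed [PAD] id
) -> Dict[str, List[int]]:
    """Hypothetical illustration of single-sequence BERT-style encoding
    with truncation and padding to a fixed length (requires max_len >= 2)."""
    # Reserve two positions for [CLS] and [SEP], then truncate the content.
    body = token_ids[: max_len - 2]
    input_ids = [cls_id] + body + [sep_id]
    # Real tokens get attention_mask 1; padding positions get 0.
    attention_mask = [1] * len(input_ids)
    n_pad = max_len - len(input_ids)
    input_ids += [pad_id] * n_pad
    attention_mask += [0] * n_pad
    return {"input_ids": input_ids, "attention_mask": attention_mask}


print(pad_and_truncate([7, 8, 9], max_len=6))
# {'input_ids': [101, 7, 8, 9, 102, 0], 'attention_mask': [1, 1, 1, 1, 1, 0]}
print(pad_and_truncate(list(range(1, 11)), max_len=6))
# {'input_ids': [101, 1, 2, 3, 4, 102], 'attention_mask': [1, 1, 1, 1, 1, 1]}
```

The attention mask is what lets the model ignore the padding positions, so every newsgroup post can be batched at the same length.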


## Training

This section minimally fine-tunes the classifier and its underlying BERT language model. There is nothing fancy here; the goal is just to give the dashboard demo something to work with. Most of this is similar to the Hugging Face tutorial notebooks.

In [21]:
type(train_set)

__main__.EncodedSet

In [22]:
issubclass(EncodedSet, torch.utils.data.Dataset) 

True
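The `compute_metrics` function defined in the next cell reduces a batch of logits to accuracy by taking the argmax over the class axis. A dependency-free sketch of the same computation follows (`accuracy_from_logits` is a hypothetical name). Note that the `Trainer` in that cell is not given an `eval_dataset`, so `compute_metrics` is not called during `trainer.train()`; calling `trainer.evaluate(eval_dataset=test_set)` after training is one way to use it.

```python
from typing import Sequence


def accuracy_from_logits(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> dict:
    """Stdlib sketch of compute_metrics: argmax per row, then accuracy."""
    # Index of the highest logit in each row; like np.argmax, ties go to the
    # first maximum.
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


print(accuracy_from_logits([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]], [1, 2]))
# {'accuracy': 0.5}
```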

In [23]:
metric = load_metric("accuracy")
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

training_args = TrainingArguments(output_dir="data", no_cuda=False, num_train_epochs=1)
trainer = Trainer(model=model, args=training_args, train_dataset=train_set, compute_metrics=compute_metrics)
trainer.train()

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
***** Running training *****
  Num examples = 11014
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 689

| Step | Training Loss |
|---|---|
| 500 | 1.4282 |

Saving model checkpoint to data/checkpoint-500
Configuration saved in data/checkpoint-500/config.json
Model weights saved in data/checkpoint-500/pytorch_model.bin

Training completed. Do not forget to share your model on huggingface.co/models =)

TrainOutput(global_step=689, training_loss=1.3056424034351188, metrics={'train_runtime': 144.7068, 'train_samples_per_second': 76.113, 'train_steps_per_second': 4.761, 'total_flos': 1449186753957888.0, 'train_loss': 1.3056424034351188, 'epoch': 1.0})
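As a sanity check on the training log, the reported 689 optimization steps (and `global_step=689`) follow from the number of examples and the total batch size. With 8 examples per device, a total batch size of 16, and no gradient accumulation, the run appears to have used two devices; that is an inference from the log, not something the notebook states.

```python
import math

num_examples = 11014      # "Num examples" in the log
total_batch_size = 16     # "Total train batch size" in the log
grad_accum_steps = 1      # "Gradient Accumulation steps" in the log
num_epochs = 1

# The last, partial batch still counts as a step, hence the ceiling.
steps_per_epoch = math.ceil(num_examples / total_batch_size) // grad_accum_steps
total_steps = steps_per_epoch * num_epochs
print(total_steps)  # 689
```

The single loss row and checkpoint at step 500 appear to match the `Trainer` defaults of logging and saving every 500 steps.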

The wrapper uses an `encodings` dictionary, mapping each class name to its integer label, to label the visualizations. It can be built with something like:

In [24]:
encodings = (
    train_cleaned[["label", "label_text"]]
    .groupby(["label_text"])
    .apply(lambda x: x["label"].tolist()[0])
    .to_dict()
)
encodings

{'alt.atheism': 0,
 'comp.graphics': 1,
 'comp.os.ms-windows.misc': 2,
 'comp.sys.ibm.pc.hardware': 3,
 'comp.sys.mac.hardware': 4,
 'comp.windows.x': 5,
 'misc.forsale': 6,
 'rec.autos': 7,
 'rec.motorcycles': 8,
 'rec.sport.baseball': 9,
 'rec.sport.hockey': 10,
 'sci.crypt': 11,
 'sci.electronics': 12,
 'sci.med': 13,
 'sci.space': 14,
 'soc.religion.christian': 15,
 'talk.politics.guns': 16,
 'talk.politics.mideast': 17,
 'talk.politics.misc': 18,
 'talk.religion.misc': 19}
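The `groupby` cell above keeps the first label seen for each class name. A minimal stdlib sketch of the same mapping follows; it also checks that no class name maps to two different labels. `build_encodings` is a hypothetical helper, and it accepts plain lists or pandas Series, since both are iterable.

```python
from typing import Dict, Iterable


def build_encodings(labels: Iterable[int], label_texts: Iterable[str]) -> Dict[str, int]:
    """Map each class name to its integer label, rejecting inconsistent pairs."""
    encodings: Dict[str, int] = {}
    for label, text in zip(labels, label_texts):
        # int() also converts numpy integer types coming from a DataFrame.
        previous = encodings.setdefault(text, int(label))
        if previous != int(label):
            raise ValueError(f"{text!r} maps to both {previous} and {label}")
    # Sort by class name, matching the key order of the groupby version.
    return dict(sorted(encodings.items()))


print(build_encodings(
    [7, 4, 4, 1],
    ["rec.autos", "comp.sys.mac.hardware", "comp.sys.mac.hardware", "comp.graphics"],
))
# {'comp.graphics': 1, 'comp.sys.mac.hardware': 4, 'rec.autos': 7}
```

In this notebook it would be called as `build_encodings(train_cleaned.label, train_cleaned.label_text)`.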

## TX2

This section shows how to pass everything into the TX2 wrapper and display the dashboard widget.

In [25]:
from tx2.wrapper import Wrapper
from tx2.dashboard import Dashboard

In [26]:
wrapper = Wrapper(
    train_texts=train_cleaned.text,
    train_labels=train_cleaned.label,
    test_texts=test_cleaned.text[:2000],
    test_labels=test_cleaned.label[:2000],
    encodings=encodings, 
    classifier=model,  
    language_model=model.bert, 
    tokenizer=tokenizer,
    cache_path="data/example3",
    overwrite=True
)
wrapper.prepare()

INFO:root:Cache path found
INFO:root:Checking for cached predictions...
INFO:root:Running classifier...
INFO:root:Saving predictions...
INFO:root:Writing to data/example3/predictions.json
INFO:root:Done!
INFO:root:Checking for cached embeddings...
INFO:root:Embedding training and testing datasets
INFO:root:Saving embeddings...
INFO:root:Writing to data/example3/embedding_training.json
INFO:root:Writing to data/example3/embedding_testing.json
INFO:root:Done!
INFO:root:Checking for cached projections...
INFO:root:Training projector...
INFO:root:Applying projector to test dataset...
INFO:root:Saving projections...
INFO:root:Writing to data/example3/projections_training.json
INFO:root:Writing to data/example3/projections_testing.json
INFO:root:Writing to data/example3/projector.pkl.gz
INFO:root:Done!
INFO:root:Checking for cached salience maps...
INFO:root:Computing salience maps...
INFO:root:Saving salience maps...
INFO:root:Writing to data/example3/salience.pkl.gz
INFO:root:Done!
INFO:root:Checking for cached cluster profiles...
INFO:root:Clustering projections...
INFO:root:Saving cluster profiles...
INFO:root:Writing to data/example3/clusters.json
INFO:root:Writing to data/example3/cluster_profiles.pkl.gz
INFO:root:Done!
INFO:root:Saving cluster labels...
INFO:root:Writing to data/example3/cluster_labels.json
INFO:root:Done!
INFO:root:Saving cluster word counts...
INFO:root:Writing to data/example3/cluster_words.json
INFO:root:Writing to data/example3/cluster_class_words.json
INFO:root:Done!
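The log above follows the same check-cache, compute, write pattern for each stage (predictions, embeddings, projections, salience maps, clusters). `overwrite=True` in the `Wrapper` call presumably forces each stage to be recomputed, which matches the log: every stage runs even though the cache path already exists. A generic sketch of that pattern follows, assuming JSON-serializable results. `load_or_compute` is hypothetical and is not TX2's implementation.

```python
import json
import os
from typing import Any, Callable


def load_or_compute(path: str, compute: Callable[[], Any], overwrite: bool = False) -> Any:
    """Generic compute-or-load cache (illustrative only, not TX2's code)."""
    # Reuse the cached file unless the caller asks for a fresh computation.
    if os.path.exists(path) and not overwrite:
        with open(path) as f:
            return json.load(f)
    result = compute()
    # Make sure the cache directory exists before writing.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w") as f:
        json.dump(result, f)
    return result
```

Because each stage is cached independently, helpers such as `recompute_projections` can redo one stage without repeating the expensive classifier and embedding passes.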


In [27]:
%matplotlib agg
import matplotlib.pyplot as plt
dash = Dashboard(wrapper)
dash.render()

[Dashboard widget output: "UMAP Embedding Graph" panel]

To experiment with different UMAP and DBSCAN arguments without rerunning all of `prepare()`, you can use `recompute_projections` (which recomputes both the projections and the visual clusterings) or `recompute_visual_clusterings` (which only redoes the clustering).

In [28]:
# wrapper.recompute_visual_clusterings("KMeans", clustering_args=dict(n_clusters=18))
# wrapper.recompute_visual_clusterings("OPTICS", clustering_args=dict())
# wrapper.recompute_projections(umap_args=dict(min_dist=.2), dbscan_args=dict())

To test or debug the classification model, see the raw outputs the visualizations receive, or build your own visualization tools, you can call the `classify()`, `soft_classify()`, and `embed()` functions directly, or access any of the cached data listed in the last cell.

In [29]:
wrapper.classify(["testing"])

[12]

In [30]:
wrapper.soft_classify(["testing"])

[[-0.7564830780029297,
  0.4089648723602295,
  -0.04325471073389053,
  0.3259239196777344,
  0.9641967415809631,
  -0.6353747844696045,
  1.0004136562347412,
  0.43865421414375305,
  1.0501797199249268,
  -0.06218921020627022,
  -0.27333322167396545,
  -0.1180000826716423,
  1.3278456926345825,
  -0.4630220830440521,
  0.5328652262687683,
  -0.7756190299987793,
  -0.10802455246448517,
  -1.0514167547225952,
  -0.780843198299408,
  -0.9997470378875732]]
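The values returned by `soft_classify()` appear to be raw, unnormalized scores (logits): some are negative and they do not sum to 1. Their argmax, index 12, matches the `[12]` from `classify()`, which `encodings` maps to `'sci.electronics'`. A minimal sketch for turning one row into a class name and probabilities follows. It assumes softmax is the appropriate normalization, and `decode_logits` is a hypothetical helper, not a TX2 function.

```python
import math
from typing import Dict, List, Tuple


def decode_logits(logits: List[float], encodings: Dict[str, int]) -> Tuple[str, List[float]]:
    """Turn one row of logits into (predicted class name, softmax probabilities)."""
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Invert encodings (name -> index) to look up the predicted name.
    index_to_name = {index: name for name, index in encodings.items()}
    best = max(range(len(probs)), key=probs.__getitem__)
    return index_to_name[best], probs


name, probs = decode_logits([0.5, 2.0, -1.0], {"cats": 0, "dogs": 1, "birds": 2})
print(name)  # dogs
```

In this notebook it would be called as `decode_logits(wrapper.soft_classify(["testing"])[0], encodings)`.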

In [31]:
wrapper.embed(["testing"])

[[0.152827650308609,
  0.5606905817985535,
  0.1829121857881546,
  -0.5550004243850708,
  0.022565964609384537,
  -0.2181190848350525,
  -0.263417512178421,
  -0.09464315325021744,
  0.5924189686775208,
  -0.19153384864330292,
  -0.1327141374349594,
  0.19457969069480896,
  -0.12094241380691528,
  0.44976839423179626,
  0.3706577718257904,
  -0.3284190595149994,
  1.4127256870269775,
  -0.23218385875225067,
  0.3732975721359253,
  -0.7016769647598267,
  -0.03631226718425751,
  0.3732241690158844,
  0.39766523241996765,
  -0.28562697768211365,
  0.3034123480319977,
  -0.38593775033950806,
  0.6229720115661621,
  -0.6008175015449524,
  0.8414539694786072,
  0.3867461085319519,
  -0.6011030673980713,
  -0.30456915497779846,
  0.053773071616888046,
  -0.671636700630188,
  -0.045153357088565826,
  0.7409505844116211,
  -0.13031044602394104,
  0.31774774193763733,
  0.04103156924247742,
  -0.04292754456400871,
  0.43576565384864807,
  -0.07049784064292908,
  0.3436148464679718,
  1.087537169

In [None]:
# cached data:
# wrapper.embeddings_training
# wrapper.embeddings_testing
# wrapper.projector
# wrapper.projections_training
# wrapper.projections_testing
# wrapper.salience_maps
# wrapper.clusters
# wrapper.cluster_profiles
# wrapper.cluster_words_freqs
# wrapper.cluster_class_word_sets