# Example 1 - Default Approach

This notebook demonstrates how to use the TX2 dashboard with a sequence classification transformer, using the premade Hugging Face sequence classification models.

In [1]:
%cd -q ..
%load_ext autoreload
%autoreload 2
%matplotlib inline

We enable logging to view the output from `wrapper.prepare()` further down in the notebook. (It is a long-running function and logs which step it is on.)

In [2]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [3]:
import numpy as np
import torch
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from datasets import load_metric
import pandas as pd

In this example notebook, we use the 20 newsgroups dataset, which can be downloaded through sklearn as shown below:

In [4]:
from sklearn.datasets import fetch_20newsgroups

train_data = fetch_20newsgroups(subset='train')
test_data = fetch_20newsgroups(subset='test')

We apply some simplistic data cleaning and put all of the data into dataframes for the wrapper.

In [5]:
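def clean_text(text):
    # Drop the newsgroup header block: keep everything after the first blank line.
    text = text[text.index("\n\n")+2:]
    # Flatten newlines, then collapse runs of spaces (longest runs first).
    text = text.replace("\n", " ")
    text = text.replace("    ", " ")
    text = text.replace("   ", " ")
    text = text.replace("  ", " ")
    # Trim leading/trailing whitespace left over from the steps above.
    text = text.strip()
    return text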
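
To make the effect of the cleaning step concrete, here is a small sketch that runs the same function on a made-up post. The sample string is hypothetical and not taken from the dataset. It also shows that `str.index` raises `ValueError` when a text contains no blank line, which is worth knowing if you reuse this function on other data.

```python
def clean_text(text):
    # Same steps as the notebook cell: drop the header block, flatten whitespace.
    text = text[text.index("\n\n") + 2:]
    text = text.replace("\n", " ")
    text = text.replace("    ", " ")
    text = text.replace("   ", " ")
    text = text.replace("  ", " ")
    return text.strip()

# Hypothetical post: two header lines, a blank line, then the body.
sample = "From: someone@example.com\nSubject: Test\n\nFirst line.\nSecond  line.\n"
cleaned = clean_text(sample)
print(cleaned)

# A text with no blank line has no header separator, so str.index raises.
raised = False
try:
    clean_text("no header separator here")
except ValueError:
    raised = True
print("ValueError raised:", raised)
```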

In [6]:
train_rows = []
for i in range(len(train_data["data"])):
    row = {}
    row["text"] = clean_text(train_data["data"][i])
    row["target"] = train_data['target'][i]
    if row["text"] == "" or row["text"] == " ": continue
    train_rows.append(row)
train_df = pd.DataFrame(train_rows)

In [7]:
test_rows = []
for i in range(len(test_data["data"])):
    row = {}
    row["text"] = clean_text(test_data["data"][i])
    row["target"] = test_data['target'][i]
    if row["text"] == "" or row["text"] == " ": continue
    test_rows.append(row)
test_df = pd.DataFrame(test_rows)

In [8]:
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=20)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [9]:
class EncodedSet(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(self.len)

    def __getitem__(self, index):
        text = str(self.data.text[index])
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']

        return {
            'input_ids': torch.tensor(ids, dtype=torch.long),
            'attention_mask': torch.tensor(mask, dtype=torch.long),
            'labels': torch.tensor(self.data.target[index], dtype=torch.long)
        }

    def __len__(self):
        return self.len
    
train_set = EncodedSet(train_df, tokenizer, 256)
test_set = EncodedSet(test_df[:1000], tokenizer, 256)

11296
1000
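
The `padding='max_length'` and `truncation=True` arguments mean every item returned by `EncodedSet` has exactly `max_len` token ids, with an attention mask marking which positions hold real tokens. The following is a dependency-free sketch of that idea only. It is not the tokenizer's actual implementation; `pad_or_truncate` is a hypothetical helper, and the special-token ids are used here purely for illustration.

```python
# Illustrative ids for [CLS], [SEP], and [PAD]; treat them as assumptions.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0

def pad_or_truncate(token_ids, max_len):
    """Mimic fixed-length encoding for a single sequence of token ids."""
    # Reserve two positions for [CLS] and [SEP], then cut the body to fit.
    body = token_ids[: max_len - 2]
    ids = [CLS_ID] + body + [SEP_ID]
    mask = [1] * len(ids)
    # Right-pad both lists up to max_len; padded positions get mask 0.
    n_pad = max_len - len(ids)
    return ids + [PAD_ID] * n_pad, mask + [0] * n_pad

# A short sequence is padded, a long one is truncated.
short_ids, short_mask = pad_or_truncate([7, 8, 9], max_len=8)
long_ids, long_mask = pad_or_truncate(list(range(1000, 1020)), max_len=8)
print(short_ids, short_mask)
print(long_ids, long_mask)
```

Because every item has the same length, the default batching can stack them directly without a padding collator.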


## Training

This section minimally trains the classification and language model. Nothing fancy here; it just gives the dashboard demo something to work with. Most of this is similar to the Hugging Face tutorial notebooks.

In [10]:
metric = load_metric("accuracy")
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

training_args = TrainingArguments(output_dir="data", no_cuda=False, num_train_epochs=1)
trainer = Trainer(model=model, args=training_args, train_dataset=train_set, compute_metrics=compute_metrics)
trainer.train()



| Step | Training Loss |
|------|---------------|
| 500  | 1.232681      |




TrainOutput(global_step=706, training_loss=1.0737276158319315)
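
The `compute_metrics` function above takes the row-wise argmax of the logits and compares it with the labels. Here is a dependency-free sketch of that calculation on toy values. `accuracy_from_logits` is a hypothetical stand-in for the `np.argmax` plus accuracy-metric combination, not the library code itself.

```python
def argmax(values):
    # Index of the largest value in a list (first one wins on ties).
    return max(range(len(values)), key=lambda i: values[i])

def accuracy_from_logits(logits, labels):
    # Predicted class = position of the highest logit in each row.
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}

# Toy batch: 4 examples, 3 classes (made-up numbers).
toy_logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9], [0.4, 0.3, 0.2]]
toy_labels = [1, 0, 2, 1]
result = accuracy_from_logits(toy_logits, toy_labels)
print(result)
```

Note that the `Trainer` in this notebook is created without an evaluation dataset, so the metric function is only exercised if you evaluate explicitly, for example with `trainer.evaluate(eval_dataset=test_set)`.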

The wrapper uses an `encodings` dictionary, mapping each class name to its integer label, for various labels and visualizations. It can be built with something similar to:

In [11]:
encodings = {}
for index, entry in enumerate(train_data["target_names"]):
    encodings[entry] = index
encodings

{'alt.atheism': 0,
 'comp.graphics': 1,
 'comp.os.ms-windows.misc': 2,
 'comp.sys.ibm.pc.hardware': 3,
 'comp.sys.mac.hardware': 4,
 'comp.windows.x': 5,
 'misc.forsale': 6,
 'rec.autos': 7,
 'rec.motorcycles': 8,
 'rec.sport.baseball': 9,
 'rec.sport.hockey': 10,
 'sci.crypt': 11,
 'sci.electronics': 12,
 'sci.med': 13,
 'sci.space': 14,
 'soc.religion.christian': 15,
 'talk.politics.guns': 16,
 'talk.politics.mideast': 17,
 'talk.politics.misc': 18,
 'talk.religion.misc': 19}
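
When reading raw predictions it is often handy to go the other way, from an integer label back to a class name. Here is a minimal sketch that rebuilds the mapping from the class names listed above and inverts it. The `decodings` name is hypothetical and not part of TX2.

```python
# Class names in the order shown in the encodings output above.
target_names = [
    "alt.atheism", "comp.graphics", "comp.os.ms-windows.misc",
    "comp.sys.ibm.pc.hardware", "comp.sys.mac.hardware", "comp.windows.x",
    "misc.forsale", "rec.autos", "rec.motorcycles", "rec.sport.baseball",
    "rec.sport.hockey", "sci.crypt", "sci.electronics", "sci.med",
    "sci.space", "soc.religion.christian", "talk.politics.guns",
    "talk.politics.mideast", "talk.politics.misc", "talk.religion.misc",
]

# Name -> index, equivalent to the loop in the notebook cell.
encodings = {name: index for index, name in enumerate(target_names)}

# Index -> name, for turning predicted labels back into readable names.
decodings = {index: name for name, index in encodings.items()}

print(decodings[8])
```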

## TX2

This section shows how to put everything into the TX2 wrapper to get the dashboard widget displayed.

In [12]:
from tx2.wrapper import Wrapper
from tx2.dashboard import Dashboard

[nltk_data] Downloading package stopwords to /home/81n/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [13]:
wrapper = Wrapper(
    train_texts=train_df.text,
    train_labels=train_df.target,
    test_texts=test_df.text[:2000],
    test_labels=test_df.target[:2000],
    encodings=encodings, 
    classifier=model,  
    language_model=model.bert, 
    tokenizer=tokenizer,
    cache_path="data/example3",
    overwrite=True
)
wrapper.prepare()

INFO:root:Cache path found
INFO:root:Checking for cached predictions...
INFO:root:Running classifier...
INFO:root:Saving predictions...
INFO:root:Writing to data/example3/predictions.json
INFO:root:Done!
INFO:root:Checking for cached embeddings...
INFO:root:Embedding training and testing datasets
INFO:root:Saving embeddings...
INFO:root:Writing to data/example3/embedding_training.json
INFO:root:Writing to data/example3/embedding_testing.json
INFO:root:Done!
INFO:root:Checking for cached projections...
INFO:root:Training projector...
INFO:root:Applying projector to test dataset...
INFO:root:Saving projections...
INFO:root:Writing to data/example3/projections_training.json
INFO:root:Writing to data/example3/projections_testing.json
INFO:root:Writing to data/example3/projector.pkl.gz
INFO:root:Done!
INFO:root:Checking for cached salience maps...
INFO:root:Computing salience maps...


INFO:root:Saving salience maps...
INFO:root:Writing to data/example3/salience.pkl.gz





INFO:root:Done!
INFO:root:Checking for cached cluster profiles...
INFO:root:Clustering projections...
INFO:root:Saving cluster profiles...
INFO:root:Writing to data/example3/clusters.json
INFO:root:Writing to data/example3/cluster_profiles.pkl.gz
INFO:root:Done!
INFO:root:Saving cluster labels...
INFO:root:Writing to data/example3/cluster_labels.json
INFO:root:Done!




INFO:root:Saving cluster word counts...


INFO:root:Writing to data/example3/cluster_words.json
INFO:root:Writing to data/example3/cluster_class_words.json
INFO:root:Done!


In [14]:
%matplotlib agg
import matplotlib.pyplot as plt
dash = Dashboard(wrapper)
dash.render()

VBox(children=(VBox(children=(HTML(value='<h3>UMAP Embedding Graph</h3>'), HBox(children=(Output(), VBox(child…

To experiment with different UMAP and DBSCAN arguments without rerunning all of `prepare()`, you can use `recompute_projections` (which recomputes both the projections and the visual clusterings) or `recompute_visual_clusterings` (which only redoes the clustering).

In [15]:
# wrapper.recompute_visual_clusterings("KMeans", clustering_args=dict(n_clusters=18))
# wrapper.recompute_visual_clusterings("OPTICS", clustering_args=dict())
# wrapper.recompute_projections(umap_args=dict(min_dist=.2), dbscan_args=dict())

To test or debug the classification model, to see what raw outputs the visualizations are getting, or to create your own visualization tools, you can manually call the `classify()`, `soft_classify()`, and `embed()` functions, or access any of the cached data listed in the last cell.

In [16]:
wrapper.classify(["testing"])

[8]

In [17]:
wrapper.soft_classify(["testing"])

[[-0.7819971442222595,
  0.06718077510595322,
  -0.465032696723938,
  -0.8326480984687805,
  -0.41031721234321594,
  -0.7673346996307373,
  2.255911111831665,
  2.475999116897583,
  2.6724281311035156,
  -0.15834367275238037,
  -0.761745810508728,
  -1.1997004747390747,
  1.0040363073349,
  -0.19663360714912415,
  -0.2896929681301117,
  -0.516423761844635,
  0.4586598575115204,
  -0.7602152824401855,
  -0.20343555510044098,
  -0.5323857069015503]]
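
The values above include negatives and do not sum to 1, so they appear to be unnormalized scores (logits) rather than probabilities. Assuming that is the case, this sketch applies a softmax to the printed values to get a probability-like distribution and the top three class indices. It uses only the standard library; `softmax` here is a hand-written helper, not a TX2 function.

```python
import math

# Scores copied from the soft_classify output above.
scores = [
    -0.7819971442222595, 0.06718077510595322, -0.465032696723938,
    -0.8326480984687805, -0.41031721234321594, -0.7673346996307373,
    2.255911111831665, 2.475999116897583, 2.6724281311035156,
    -0.15834367275238037, -0.761745810508728, -1.1997004747390747,
    1.0040363073349, -0.19663360714912415, -0.2896929681301117,
    -0.516423761844635, 0.4586598575115204, -0.7602152824401855,
    -0.20343555510044098, -0.5323857069015503,
]

def softmax(values):
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax(scores)

# Class indices sorted from most to least likely; keep the top three.
top3 = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:3]
print(top3)
print([round(probs[i], 3) for i in top3])
```

The highest-scoring index matches the label returned by `classify()` for the same input, and the runners-up show which other classes the model found plausible.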

In [18]:
wrapper.embed(["testing"])

[[0.38975822925567627,
  0.23370453715324402,
  -0.6037890911102295,
  0.42801281809806824,
  0.7710248827934265,
  0.7473658323287964,
  0.466446191072464,
  -0.32463744282722473,
  -0.22400183975696564,
  -0.8100736737251282,
  0.14713317155838013,
  -0.1320997029542923,
  0.6231873035430908,
  0.4791969656944275,
  0.8242490887641907,
  0.5742146968841553,
  0.5950765013694763,
  0.9663211107254028,
  0.3941460847854614,
  -1.5301533937454224,
  -0.03778163343667984,
  -1.1597237586975098,
  0.6117329001426697,
  -0.25169265270233154,
  -0.6425670981407166,
  -0.2890930473804474,
  0.4472787082195282,
  -0.04857483506202698,
  0.9817265272140503,
  0.08757776767015457,
  -0.3973219394683838,
  0.6669486165046692,
  -0.07870160043239594,
  0.20381714403629303,
  -0.5368008017539978,
  0.25037604570388794,
  0.7944205403327942,
  0.40262216329574585,
  0.2738703191280365,
  0.5566569566726685,
  0.27378955483436584,
  0.1408196985721588,
  -0.4545992314815521,
  -0.1893201321363449,
 

In [19]:
# cached data:
# wrapper.embeddings_training
# wrapper.embeddings_testing
# wrapper.projector
# wrapper.projections_training
# wrapper.projections_testing
# wrapper.salience_maps
# wrapper.clusters
# wrapper.cluster_profiles
# wrapper.cluster_words_freqs
# wrapper.cluster_class_word_sets