# Example 2 - Custom Approach

This notebook demonstrates how to use the TX2 dashboard with a sequence classification transformer using the custom approach as described in the "Basic Usage" docs. To demonstrate, we simply take the default functions for each and manually define them.

In [1]:
%cd -q ..
%load_ext autoreload
%autoreload 2
%matplotlib inline

We enable logging to view the output from `wrapper.prepare()` further down in the notebook. (It's a long running function, and logs which step it's on.)

In [2]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [3]:
import numpy as np
import torch
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from torch import cuda
from transformers import AutoModel, AutoTokenizer, BertTokenizer
import pandas as pd

In this example notebook, we use the 20 newsgroups dataset, which can be downloaded through sklearn via below:

In [4]:
from sklearn.datasets import fetch_20newsgroups

train_data = fetch_20newsgroups(subset='train')
test_data = fetch_20newsgroups(subset='test')

device = 'cuda' if cuda.is_available() else 'cpu'

Defined below is a simple sequence classification model with a variable for the language model itself `l1`. Since it is a BERT model, we take the sequence embedding from the `[CLS]` token (via `output_1[0][:, 0, :])`) and pipe that into the linear layer.

In [5]:
class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        self.l1 = AutoModel.from_pretrained("bert-base-cased")
        self.l2 = torch.nn.Linear(768, 20)

    def forward(self, ids, mask):
        output_1= self.l1(ids, mask)
        output = self.l2(output_1[0][:, 0, :]) # use just the [CLS] output embedding
        return output
    
model = BERTClass()
model.to(device)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

Some simplistic data cleaning, and putting all data into dataframes for the wrapper

In [6]:
def clean_text(text):
    text = text[text.index("\n\n")+2:]
    text = text.replace("\n", " ")
    text = text.replace("    ", " ")
    text = text.replace("   ", " ")
    text = text.replace("  ", " ")
    text = text.strip()
    return text

In [7]:
train_rows = []
for i in range(len(train_data["data"])):
    row = {}
    row["text"] = clean_text(train_data["data"][i])
    row["target"] = train_data['target'][i]
    if row["text"] == "" or row["text"] == " ": continue
    train_rows.append(row)
train_df = pd.DataFrame(train_rows)

In [8]:
test_rows = []
for i in range(len(test_data["data"])):
    row = {}
    row["text"] = clean_text(test_data["data"][i])
    row["target"] = test_data['target'][i]
    if row["text"] == "" or row["text"] == " ": continue
    test_rows.append(row)
test_df = pd.DataFrame(test_rows)

## Training

This section minimally trains the classification and language model - nothing fancy here, just to give the dashboard demo something to work with. Most of this is similar to the huggingface tutorial notebooks.

In [9]:
# Creating the loss function and optimizer
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params =  model.parameters(), lr=1e-05)

In [10]:
class EncodedSet(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(self.len)

    def __getitem__(self, index):
        text = str(self.data.text[index])
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            truncation=True,
            return_token_type_ids=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'targets': torch.tensor(self.data.target[index], dtype=torch.long)
        }

    def __len__(self):
        return self.len
    
train_set = EncodedSet(train_df, tokenizer, 256)
test_set = EncodedSet(test_df[:1000], tokenizer, 256)

train_params = {'batch_size': 16,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': 2,
                'shuffle': True,
                'num_workers': 0
                }

# put everything into data loaders
train_loader = DataLoader(train_set, **train_params)
test_loader = DataLoader(test_set, **test_params)

11296
1000


In [11]:
def train(epoch):
    model.train()

    
    loss_history = []
    for _,data in tqdm(enumerate(train_loader, 0), total=len(train_loader), desc=f"Epoch {epoch}"):
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.long)

        outputs = model(ids, mask).squeeze()

        optimizer.zero_grad()
        loss = loss_function(outputs, targets)
        if _%100==0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')
        loss_history.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
#         torch.cuda.empty_cache()
    return loss_history

In [12]:
losses = []
for epoch in range(1):
    losses.extend(train(epoch))

HBox(children=(HTML(value='Epoch 0'), FloatProgress(value=0.0, max=706.0), HTML(value='')))



Epoch: 0, Loss:  2.9316442012786865
Epoch: 0, Loss:  2.4549002647399902
Epoch: 0, Loss:  1.0166977643966675
Epoch: 0, Loss:  0.7657784223556519
Epoch: 0, Loss:  1.1686592102050781
Epoch: 0, Loss:  1.1292028427124023
Epoch: 0, Loss:  0.5788240432739258
Epoch: 0, Loss:  0.9440498352050781



The wrapper uses an `encodings` dictionary for various labels/visualizations, and can be set up with something similar to:

In [13]:
encodings = {}
for index, entry in enumerate(train_data["target_names"]):
    encodings[entry] = index
encodings

{'alt.atheism': 0,
 'comp.graphics': 1,
 'comp.os.ms-windows.misc': 2,
 'comp.sys.ibm.pc.hardware': 3,
 'comp.sys.mac.hardware': 4,
 'comp.windows.x': 5,
 'misc.forsale': 6,
 'rec.autos': 7,
 'rec.motorcycles': 8,
 'rec.sport.baseball': 9,
 'rec.sport.hockey': 10,
 'sci.crypt': 11,
 'sci.electronics': 12,
 'sci.med': 13,
 'sci.space': 14,
 'soc.religion.christian': 15,
 'talk.politics.guns': 16,
 'talk.politics.mideast': 17,
 'talk.politics.misc': 18,
 'talk.religion.misc': 19}

## TX2

This section shows how to put everything into the TX2 wrapper to get the dashboard widget displayed.

In [14]:
from tx2.wrapper import Wrapper
from tx2.dashboard import Dashboard

  from pkg_resources import resource_filename


In [15]:
from tx2 import utils

def custom_encoding_function(text):
    encoded = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=256,
        pad_to_max_length=True,
        truncation=True,
        return_token_type_ids=True,
    )
    return {
        "input_ids": torch.tensor(encoded["input_ids"], device=device),
        "attention_mask": torch.tensor(encoded["attention_mask"], device=device),
    }

def custom_classification_function(inputs):
    return torch.argmax(model(inputs["input_ids"], inputs["attention_mask"]), dim=1)

def custom_embedding_function(inputs):
    return model.l1(inputs["input_ids"], inputs["attention_mask"])[0][:, 0, :]  # [CLS] token embedding

def custom_soft_classification_function(inputs):
    return model(inputs["input_ids"], inputs["attention_mask"])


In [16]:
wrapper = Wrapper(
    train_texts=train_df.text,
    train_labels=train_df.target,
    test_texts=test_df.text[:2000],
    test_labels=test_df.target[:2000],
    encodings=encodings, 
    cache_path="data/custom_cache",
    overwrite=True
)
wrapper.encode_function = custom_encoding_function
wrapper.classification_function = custom_classification_function
wrapper.soft_classification_function = custom_soft_classification_function
wrapper.embedding_function = custom_embedding_function
wrapper.prepare()

INFO:root:Cache path not found, creating...
INFO:root:Checking for cached predictions...
INFO:root:Running classifier...
INFO:root:Saving predictions...
INFO:root:Writing to data/custom_cache/predictions.json
INFO:root:Done!
INFO:root:Checking for cached embeddings...
INFO:root:Embedding training and testing datasets
INFO:root:Saving embeddings...
INFO:root:Writing to data/custom_cache/embedding_training.json
INFO:root:Writing to data/custom_cache/embedding_testing.json
INFO:root:Done!
INFO:root:Checking for cached projections...
INFO:root:Training projector...
INFO:root:Applying projector to test dataset...
INFO:root:Saving projections...
INFO:root:Writing to data/custom_cache/projections_training.json
INFO:root:Writing to data/custom_cache/projections_testing.json
INFO:root:Writing to data/custom_cache/projector.pkl.gz
INFO:root:Done!
INFO:root:Checking for cached salience maps...
INFO:root:Computing salience maps...


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2000.0), HTML(value='')))

INFO:root:Saving salience maps...
INFO:root:Writing to data/custom_cache/salience.pkl.gz





INFO:root:Done!
INFO:root:Checking for cached cluster profiles...
INFO:root:Clustering projections...
INFO:root:Saving cluster profiles...
INFO:root:Writing to data/custom_cache/clusters.json
INFO:root:Writing to data/custom_cache/cluster_profiles.pkl.gz
INFO:root:Done!
INFO:root:Saving cluster labels...
INFO:root:Writing to data/custom_cache/cluster_labels.json
INFO:root:Done!
INFO:root:Saving cluster word counts...
INFO:root:Writing to data/custom_cache/cluster_words.json
INFO:root:Writing to data/custom_cache/cluster_class_words.json
INFO:root:Done!


In [17]:
%matplotlib agg
dash = Dashboard(wrapper, show_wordclouds=True)
dash.render()

VBox(children=(VBox(children=(HTML(value='<h3>UMAP Embedding Graph</h3>'), HBox(children=(Output(), VBox(child…



To play with different UMAP and DBSCAN arguments without having to recompute the entire `prepare()` function, you can use `recompute_projections` (which will recompute both the projections and visual clusterings) or `recompute_visual_clusterings` (which will only redo the clustering)

In [18]:
# wrapper.recompute_visual_clusterings(dbscan_args=dict())
# wrapper.recompute_projections(umap_args=dict(min_dist=.2), dbscan_args=dict())

To test or debug the classification model/see what raw outputs the viusualizations are getting, or create your own visualization tools, you can manually call the `classify()`, `soft_classify()`, `embed()` functions, or get access to any of the cached data as seen in the bottom cell

In [19]:
wrapper.classify(["testing"])

[7]

In [20]:
wrapper.soft_classify(["testing"])

[[-1.0488052368164062,
  -0.03285098820924759,
  -0.2580734193325043,
  -0.3668859302997589,
  0.15141808986663818,
  -0.15236662328243256,
  1.9831351041793823,
  1.2149289846420288,
  1.321704387664795,
  -0.19445110857486725,
  0.08170069754123688,
  -0.11487706750631332,
  1.019587516784668,
  -0.02104061096906662,
  -0.6935140490531921,
  -0.22049453854560852,
  -0.9581518173217773,
  0.4018412232398987,
  -0.6637047529220581,
  -0.3188351094722748]]

In [21]:
wrapper.embed(["testing"])

[[0.43734872341156006,
  0.16381454467773438,
  0.08185862749814987,
  -0.38637927174568176,
  -0.2747752070426941,
  0.29509079456329346,
  0.2484285533428192,
  -0.34395158290863037,
  -0.5298862457275391,
  -1.2994873523712158,
  0.03324988856911659,
  0.22103442251682281,
  -0.5836265087127686,
  0.8374025821685791,
  -1.2318155765533447,
  0.9985899329185486,
  0.2384842485189438,
  0.33898913860321045,
  0.40220141410827637,
  0.458391010761261,
  -0.5837693810462952,
  -0.9413780570030212,
  0.21607470512390137,
  -1.1177239418029785,
  0.49385640025138855,
  -0.10062618553638458,
  0.31760454177856445,
  1.296376347541809,
  0.18768341839313507,
  -0.25272172689437866,
  -0.0787934958934784,
  0.5208892226219177,
  0.9109460115432739,
  -0.030176391825079918,
  0.07182762026786804,
  0.11704669892787933,
  -0.11009922623634338,
  0.2647625207901001,
  -0.34480756521224976,
  -0.46474146842956543,
  -1.0423533916473389,
  -0.025282690301537514,
  0.38863736391067505,
  0.1714559

In [22]:
# cached data:
# wrapper.embeddings_training
# wrapper.embeddings_testing
# wrapper.projector
# wrapper.projections_training
# wrapper.projections_testing
# wrapper.salience_maps
# wrapper.clusters
# wrapper.cluster_profiles
# wrapper.cluster_words_freqs
# wrapper.cluster_class_word_sets