# Load the Pretrained Model and the dataset
As an example, we use ernie-2.0-base-en as the model and SST-2 as the dataset. More models can be found in the [PaddleNLP Model Zoo](https://paddlenlp.readthedocs.io/zh/latest/model_zoo/index.html#transformer).

PaddleNLP is required to run this notebook. It can be installed with:
```bash
pip install setuptools_scm 
pip install --upgrade paddlenlp
```

In [2]:
import paddle
import paddlenlp
from assets.ernie import ErnieForSequenceClassification
from paddlenlp.transformers import ErnieTokenizer

# Name of the pretrained checkpoint; the model and the tokenizer are both
# loaded from this same name.
MODEL_NAME = "ernie-2.0-base-en"

# Sequence classifier with two output classes.
model = ErnieForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_classes=2,
)
tokenizer = ErnieTokenizer.from_pretrained(MODEL_NAME)

  from .autonotebook import tqdm as notebook_tqdm
[32m[2022-07-06 16:15:30,885] [    INFO][0m - Already cached /root/.paddlenlp/models/ernie-2.0-base-en/ernie_v2_eng_base.pdparams[0m
W0706 16:15:30.888118 131170 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 8.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0706 16:15:30.892442 131170 gpu_context.cc:306] device: 0, cuDNN Version: 8.1.
[32m[2022-07-06 16:15:41,836] [    INFO][0m - Already cached /root/.paddlenlp/models/ernie-2.0-base-en/vocab.txt[0m
[32m[2022-07-06 16:15:41,869] [    INFO][0m - tokenizer config file saved in /root/.paddlenlp/models/ernie-2.0-base-en/tokenizer_config.json[0m
[32m[2022-07-06 16:15:41,871] [    INFO][0m - Special tokens file saved in /root/.paddlenlp/models/ernie-2.0-base-en/special_tokens_map.json[0m


In [3]:
from paddlenlp.datasets import load_dataset
# Load the three SST-2 splits of the GLUE benchmark in one call.
train_ds, dev_ds, test_ds = load_dataset(
    "glue",
    name='sst-2',
    splits=["train", "dev", "test"],
)

# Prepare the Model

## Train the model

In [5]:
# Train the model and save the result to save_dir.
# This cell only needs to be run once.
# Total steps: ~1700 (1 epoch).

from assets.utils import training_model
# Plain string literal: the path has no placeholders, so no f-string is needed.
training_model(model, tokenizer, train_ds, dev_ds, save_dir='assets/sst-2-ernie-2.0-en')

dataset labels: ['0', '1']
dataset examples:
{'sentence': 'hide new secretions from the parental units ', 'labels': 0}
{'sentence': 'contains no wit , only labored gags ', 'labels': 0}
{'sentence': 'that loves its characters and communicates something rather beautiful about human nature ', 'labels': 1}
{'sentence': 'remains utterly satisfied to remain the same throughout ', 'labels': 0}
{'sentence': 'on the worst revenge-of-the-nerds clichés the filmmakers could dredge up ', 'labels': 0}
Training Starts:
global step 100, epoch: 1, batch: 100, loss: 0.31666, acc: 0.76844
global step 200, epoch: 1, batch: 200, loss: 0.28145, acc: 0.82578
global step 300, epoch: 1, batch: 300, loss: 0.28762, acc: 0.84729
global step 400, epoch: 1, batch: 400, loss: 0.39735, acc: 0.86125
global step 500, epoch: 1, batch: 500, loss: 0.07971, acc: 0.87100
global step 600, epoch: 1, batch: 600, loss: 0.17368, acc: 0.87906
global step 700, epoch: 1, batch: 700, loss: 0.19687, acc: 0.88571
global step 800, epoc

[32m[2022-07-06 14:23:45,598] [    INFO][0m - tokenizer config file saved in assets/sst-2-ernie-2.0-en/tokenizer_config.json[0m
[32m[2022-07-06 14:23:45,599] [    INFO][0m - Special tokens file saved in assets/sst-2-ernie-2.0-en/special_tokens_map.json[0m


## Or Load the trained model

In [4]:
# Load the fine-tuned weights from disk and copy them into the model.
# Plain string literal: the path has no placeholders, so no f-string is needed.
state_dict = paddle.load('assets/sst-2-ernie-2.0-en/model_state.pdparams')
model.set_dict(state_dict)

# Prepare for Interpretations

In [5]:
import interpretdl as it
import numpy as np
from assets.utils import convert_example, aggregate_subwords_and_importances
from paddlenlp.data import Stack, Tuple, Pad
from interpretdl.data_processor.visualizer import VisualizationTextRecord, visualize_text

def preprocess_fn(data):
    """Convert one example, or a list of examples, into padded model inputs.

    Args:
        data: A single example or a list of examples. A non-list value is
            wrapped into a one-element list. Each example is handed unchanged
            to ``convert_example`` together with the notebook-level
            ``tokenizer``.

    Returns:
        A tuple ``(input_ids, segment_ids)`` of ``paddle.Tensor`` objects,
        each built from the padded batch with ``stop_gradient=False``.
    """
    if not isinstance(data, list):
        data = [data]

    examples = []
    for text in data:
        input_ids, segment_ids = convert_example(
            text,
            tokenizer,
            max_seq_length=128,
            is_test=True
        )
        examples.append((input_ids, segment_ids))

    # One Pad per field. Segment (token-type) ids are padded with the
    # tokenizer's dedicated pad_token_type_id rather than the vocabulary
    # pad_token_id, which is only meaningful for input ids.
    batchify_fn = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),       # input ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment ids
    )

    input_ids, segment_ids = batchify_fn(examples)
    # NOTE(review): stop_gradient=False is kept as in the original; presumably
    # the interpreters rely on it — confirm before changing.
    return paddle.to_tensor(input_ids, stop_gradient=False), paddle.to_tensor(segment_ids, stop_gradient=False)

## BT Interpreter

### Token-wise

In [6]:
from assets.utils import predict

# Example sentences to classify and, in the cells below, to interpret.
data = [
    {"text": "it 's a charming and often affecting journey . "},
    {"text": 'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . '},
    {"text": 'this one is definitely one to skip , even for horror movie fanatics . '},
    {"text": 'in its best moments , resembles a bad high school production of grease , without benefit of song . '},
]

# Maps a class index to a readable name.
label_map = {
    0: 'negative',
    1: 'positive',
}

batch_size = 32

results = predict(model, data, tokenizer, label_map, batch_size=batch_size)

# Print every example next to the label returned for it.
for idx, text in enumerate(data):
    print(f'Data: {text} \t Lable: {results[idx]}')

# Class probabilities for all examples; later cells index this as pred[sample, class].
logits = model(*preprocess_fn(data))
pred = paddle.nn.functional.softmax(logits, axis=1)

Data: {'text': "it 's a charming and often affecting journey . "} 	 Lable: positive
Data: {'text': 'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . '} 	 Lable: positive
Data: {'text': 'this one is definitely one to skip , even for horror movie fanatics . '} 	 Lable: negative
Data: {'text': 'in its best moments , resembles a bad high school production of grease , without benefit of song . '} 	 Lable: negative


In [7]:
# Token-wise BT interpretation of every sentence in `data`.
bt = it.BTNLPInterpreter(model, device='gpu:0')
interp_class = [1, 1, 0, 0]  # class to explain for each sentence
true_label = [1, 1, 0, 0]    # label reported as "True Label" in the visualization
recs = []

for idx, sentence in enumerate(data):
    # Preprocess once and reuse the tensors for both the token lookup and the
    # interpreter, instead of tokenizing the same sentence twice.
    inputs = preprocess_fn(sentence)

    # Tokens of the single example in the batch, without the first and last
    # token (presumably the special start/end tokens — confirm).
    subwords = tokenizer.convert_ids_to_tokens(inputs[0][0])[1:-1]

    subword_importances = bt.interpret(
        ap_mode="token",
        data=inputs,
        label=interp_class[idx],
        start_layer=9)

    words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances[0])
    # Scale the importances by their norm.
    word_importances = np.array(word_importances) / np.linalg.norm(
            word_importances)

    # Flip the sign when class 0 is explained (presumably so that evidence for
    # class 0 is coloured as negative in the visualization).
    if interp_class[idx] == 0:
        word_importances = -word_importances

    # Predicted class for this sentence, computed once and reused below.
    pred_label = np.argmax(pred[idx])
    recs.append(
            VisualizationTextRecord(words, word_importances, true_label[idx],
                               pred_label, pred[idx, pred_label].item(), interp_class[idx])
        )

visualize_text(recs)
# Note: the rendered visualization is not displayed on GitHub.

True Label,Predicted Label (Prob),Target Label,Word Importance
1.0,1 (1.00),1.0,it ' s a charming and often affecting journey .
,,,
1.0,1 (0.97),1.0,the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them .
,,,
0.0,0 (0.89),0.0,"this one is definitely one to skip , even for horror movie fanatics ."
,,,
0.0,0 (1.00),0.0,"in its best moments , resembles a bad high school production of grease , without benefit of song ."
,,,


### Head-wise

In [8]:
# Head-wise BT interpretation of every sentence in `data`.
bt = it.BTNLPInterpreter(model, device='gpu:0')
interp_class = [1, 1, 0, 0]  # class to explain for each sentence
true_label = [1, 1, 0, 0]    # label reported as "True Label" in the visualization
recs = []

for idx, sentence in enumerate(data):
    # Preprocess once and reuse the tensors for both the token lookup and the
    # interpreter, instead of tokenizing the same sentence twice.
    inputs = preprocess_fn(sentence)

    # Tokens of the single example in the batch, without the first and last
    # token (presumably the special start/end tokens — confirm).
    subwords = tokenizer.convert_ids_to_tokens(inputs[0][0])[1:-1]

    # ap_mode is not passed, so the interpreter's default mode is used
    # (presumably head-wise, matching this section's title — confirm).
    subword_importances = bt.interpret(
        data=inputs,
        label=interp_class[idx],
        start_layer=11)

    words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances[0])
    # Scale the importances by their norm.
    word_importances = np.array(word_importances) / np.linalg.norm(
            word_importances)

    # Flip the sign when class 0 is explained (presumably so that evidence for
    # class 0 is coloured as negative in the visualization).
    if interp_class[idx] == 0:
        word_importances = -word_importances

    # Predicted class for this sentence, computed once and reused below.
    pred_label = np.argmax(pred[idx])
    recs.append(
            VisualizationTextRecord(words, word_importances, true_label[idx],
                               pred_label, pred[idx, pred_label].item(), interp_class[idx])
        )

visualize_text(recs)
# Note: the rendered visualization is not displayed on GitHub.

True Label,Predicted Label (Prob),Target Label,Word Importance
1.0,1 (1.00),1.0,it ' s a charming and often affecting journey .
,,,
1.0,1 (0.97),1.0,the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them .
,,,
0.0,0 (0.89),0.0,"this one is definitely one to skip , even for horror movie fanatics ."
,,,
0.0,0 (1.00),0.0,"in its best moments , resembles a bad high school production of grease , without benefit of song ."
,,,


## GA Interpreter

In [9]:
# GA interpretation of every sentence in `data`.
# NOTE: the interpreter is still bound to the name `bt`, as in the original
# cell, so that any other cell relying on that name keeps working.
bt = it.GANLPInterpreter(model, device='gpu:0')
interp_class = [1, 1, 0, 0]  # class to explain for each sentence
true_label = [1, 1, 0, 0]    # label reported as "True Label" in the visualization
recs = []

for idx, sentence in enumerate(data):
    # Preprocess once and reuse the tensors for both the token lookup and the
    # interpreter, instead of tokenizing the same sentence twice.
    inputs = preprocess_fn(sentence)

    # Tokens of the single example in the batch, without the first and last
    # token (presumably the special start/end tokens — confirm).
    subwords = tokenizer.convert_ids_to_tokens(inputs[0][0])[1:-1]

    subword_importances = bt.interpret(
        data=inputs,
        label=interp_class[idx],
        start_layer=11)

    words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances[0])
    # Scale the importances by their norm.
    word_importances = np.array(word_importances) / np.linalg.norm(
            word_importances)

    # Flip the sign when class 0 is explained (presumably so that evidence for
    # class 0 is coloured as negative in the visualization).
    if interp_class[idx] == 0:
        word_importances = -word_importances

    # Predicted class for this sentence, computed once and reused below.
    pred_label = np.argmax(pred[idx])
    recs.append(
            VisualizationTextRecord(words, word_importances, true_label[idx],
                               pred_label, pred[idx, pred_label].item(), interp_class[idx])
        )

visualize_text(recs)

True Label,Predicted Label (Prob),Target Label,Word Importance
1.0,1 (1.00),1.0,it ' s a charming and often affecting journey .
,,,
1.0,1 (0.97),1.0,the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them .
,,,
0.0,0 (0.89),0.0,"this one is definitely one to skip , even for horror movie fanatics ."
,,,
0.0,0 (1.00),0.0,"in its best moments , resembles a bad high school production of grease , without benefit of song ."
,,,
