In [1]:
%load_ext autoreload
%autoreload 2

In [14]:
from functools import partial
from tqdm.notebook import tqdm
import tensorflow as tf
from transformers import TFDistilBertForSequenceClassification, DistilBertTokenizerFast
from transformers_gradients import (
    text_classification,
    html_heatmap,
    SmoothGradConfing,
    NoiseGradConfig,
    NoiseGradPlusPlusConfig,
    PlottingConfig,
)


# Sanity check: show every device TensorFlow can see (no type filter), so it is visible up front
# whether a GPU is available for the attribution runs below.
tf.config.list_physical_devices(device_type=None)

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# The model and its tokenizer must come from the same checkpoint; name it once so they cannot drift.
MODEL_CHECKPOINT = "distilbert-base-uncased-finetuned-sst-2-english"

# TF DistilBERT sequence classifier; per the log below, its weights are converted from the
# checkpoint's PyTorch weights on load.
model = TFDistilBertForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_CHECKPOINT)

Metal device set to: Apple M1 Pro


All PyTorch model weights were used when initializing TFDistilBertForSequenceClassification.

All the weights of TFDistilBertForSequenceClassification were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.


In [4]:
# Texts to explain, and one target class index per text (explanations are taken w.r.t. that class).
x = ["Like four times a year I rediscover Björk and listen to her full discography"]
# Resolve the class index from the model config instead of hard-coding it.
# NOTE(review): assumes the checkpoint's config defines a "POSITIVE" entry in label2id.
y = [model.config.label2id["POSITIVE"]]
assert len(x) == len(y), "need exactly one target label per input text"

In [7]:
# Fix the global TF seed so the noise-based explainers give the same maps on every
# Restart & Run All. NOTE(review): assumes they draw noise from TF's global RNG — confirm.
SEED = 42
tf.random.set_seed(SEED)

# Plain gradient-based baselines.
explainers = [
    text_classification.gradient_norm,
    text_classification.gradient_x_input,
    text_classification.integrated_gradients,
]

# Noise-based wrappers, each paired with the config class it accepts.
noise_methods = [
    (text_classification.smooth_grad, SmoothGradConfing),
    (text_classification.noise_grad, NoiseGradConfig),
    (text_classification.noise_grad_plus_plus, NoiseGradPlusPlusConfig),
]

# Every noise wrapper on top of Gradient Norm, then on top of Gradient X Input.
for base_explain_fn in ("GradNorm", "GradXInput"):
    explainers += [
        partial(method, config=config_cls(explain_fn=base_explain_fn))
        for method, config_cls in noise_methods
    ]

# ...and once more with each wrapper's library-default config. The plot labels these
# "+ Integrated Gradients" — NOTE(review): that assumes the default explain_fn is IntGrad; confirm.
explainers += [method for method, _ in noise_methods]

# Order matters: rows of `a_batch` are paired with the heatmap labels by position.
# `[0]` keeps the result for the first (here: only) input text.
a_batch = [
    explain(model, x, y, tokenizer=tokenizer)[0]
    for explain in tqdm(explainers, desc="explainers")
]

  0%|          | 0/12 [00:00<?, ?it/s]

2023-04-25 18:04:57.968997: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [17]:
# Row labels, in the same order as the explainers that produced `a_batch`.
heatmap_labels = [
    "Gradient Norm",
    "Gradient X Input",
    "Integrated Gradients",
    "Smooth Grad + Gradient Norm",
    "Noise Grad + Gradient Norm",
    "NoiseGrad++ + Gradient Norm",
    "Smooth Grad + Gradient X Input",
    "Noise Grad + Gradient X Input",
    "NoiseGrad++ + Gradient X Input",
    "Smooth Grad + Integrated Gradients",
    "Noise Grad + Integrated Gradients",
    "NoiseGrad++ + Integrated Gradients",
]
# Guard against this list drifting out of sync with the explainer list that built `a_batch`.
assert len(heatmap_labels) == len(a_batch), (
    f"{len(heatmap_labels)} labels for {len(a_batch)} attribution rows"
)

html_heatmap(
    a_batch,
    labels=heatmap_labels,
    config=PlottingConfig(
        ignore_special_tokens=True, color_mapping_strategy="row-wise"
    ),
)