# HMM Experiment example

In this example, we use the HMM dataset provided by this package, compute some attributions,
and evaluate them.

### Imports

In [1]:
import torch as th

from pytorch_lightning import Trainer, seed_everything

from tint.attr import TemporalAugmentedOcclusion, TimeForwardTunnel
from tint.datasets import HMM
from tint.metrics.white_box import aur

from experiments.hmm.classifier import StateClassifierNet



### Make the experiment reproducible

For this example, we make everything reproducible. To this end, we use PyTorch Lightning's
``seed_everything`` utility.

In [2]:
seed = 42
seed_everything(seed=seed, workers=True)

Global seed set to 42


42
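
To give an idea of what seeding involves, here is a minimal sketch of a seed-everything style helper.
``seed_all`` is a hypothetical name, not part of this package or of Lightning. Lightning's
``seed_everything`` does more than this (with ``workers=True`` it also configures seeding of
DataLoader worker processes), so treat the sketch as an illustration only.

```python
import random


def seed_all(seed: int) -> int:
    """Hypothetical helper: seed the RNGs an experiment typically relies on."""
    random.seed(seed)  # Python's built-in RNG
    try:
        import numpy as np

        np.random.seed(seed)  # NumPy's legacy global RNG
    except ImportError:
        pass  # NumPy not installed
    try:
        import torch

        torch.manual_seed(seed)  # PyTorch RNG on all devices
    except ImportError:
        pass  # PyTorch not installed
    return seed


# Re-seeding reproduces exactly the same random draws
seed_all(42)
first = [random.random() for _ in range(3)]
seed_all(42)
second = [random.random() for _ in range(3)]
assert first == second
```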

### Data loading

We load the HMM (Hidden Markov Model) dataset, downloading it if needed. Since HMM is a synthetic
dataset, "downloading" actually generates the data.

In [3]:
hmm = HMM(seed=seed)
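
Since the data is synthetic, it may help to see what sampling from an HMM looks like. The sketch
below is a toy two-state HMM with three Gaussian features, not the generator used by ``HMM``: the
function name ``sample_hmm``, the transition probability, the emission means and the choice of
using the hidden state directly as the label are all assumptions made for illustration.

```python
import random
from typing import List, Optional, Sequence, Tuple


def sample_hmm(
    n_steps: int,
    p_stay: float = 0.95,
    means: Sequence[Sequence[float]] = ((0.0, 0.0, 0.0), (2.0, 2.0, 2.0)),
    seed: Optional[int] = None,
) -> Tuple[List[List[float]], List[int]]:
    """Hypothetical toy sampler for a 2-state HMM emitting 3 Gaussian features."""
    rng = random.Random(seed)
    state = rng.randrange(2)  # uniform initial hidden state
    xs, states = [], []
    for _ in range(n_steps):
        # Markov transition: keep the current state with probability p_stay
        if rng.random() > p_stay:
            state = 1 - state
        # Emission: one unit-variance Gaussian per feature, centred on the state's mean
        xs.append([rng.gauss(mu, 1.0) for mu in means[state]])
        # Toy label: the hidden state itself (assumption)
        states.append(state)
    return xs, states


x, y = sample_hmm(200, seed=42)
print(len(x), len(x[0]))  # 200 3
```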

### Create and train a simple classifier

We now train a simple classifier (a GRU followed by an MLP) on the HMM dataset.
Indeed, this dataset provides labels, generated from the hidden states of the HMM.

We use the PyTorch Lightning framework to train this model efficiently. Here, the
accelerator is set to ``cpu``, but feel free to change it to ``cuda`` (the same string is also
used below as a PyTorch device).

In [4]:
accelerator = "cpu"

In [5]:
classifier = StateClassifierNet(
    feature_size=3,
    n_state=2,
    hidden_size=200,
    regres=True,
    loss="cross_entropy",
    lr=0.0001,
    l2=1e-3,
)

In [6]:
trainer = Trainer(
    max_epochs=50, accelerator=accelerator, deterministic=True
)
trainer.fit(classifier, datamodule=hmm)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

   | Name        | Type             | Params
--------------------------------------------------
0  | net         | StateClassifier  | 123 K 
1  | _loss       | CrossEntropyLoss | 0     
2  | train_acc   | Accuracy         | 0     
3  | train_pre   | Precision        | 0     
4  | train_rec   | Recall           | 0     
5  | train_auroc | AUROC            | 0     
6  | val_acc     | Accuracy         | 0     
7  | val_pre     | Precision        | 0     
8  | val_rec     | Recall           | 0     
9  | val_auroc   | AUROC            | 0     
10 | test_acc    | Accuracy         | 0     
11 | test_pre    | Precision        | 0     
12 | test_rec    | Recall           | 0     
13 | test_auroc  | AUROC            | 0     
--------------------------------------------------
123 K     Trainable params
0         Non-trainable params
123 K     To

`Trainer.fit` stopped: `max_epochs=50` reached.


### Get train and test data

We only compute attributions on the test set. However, we also need the training data here:
the attribution method used below samples its replacement values from it.

In [7]:
x_train = hmm.preprocess(split="train")["x"].to(accelerator)
x_test = hmm.preprocess(split="test")["x"].to(accelerator)

### Utils

We set the classifier to evaluation mode and move it to the selected accelerator.

We also disable cuDNN when using ``cuda``, as its RNN implementation cannot backpropagate in
evaluation mode. Please refer to https://captum.ai/docs/faq#how-can-i-resolve-cudnn-rnn-backward-error-for-rnn-or-lstm-network 
for more information.

In [8]:
# Switch to eval
classifier.eval()

# Set model to accelerator
classifier.to(accelerator)

if accelerator == "cuda":
    th.backends.cudnn.enabled = False

### Create attributions using temporal augmented occlusion

In this example, we use ``TemporalAugmentedOcclusion`` as the attribution method, first presented
in this paper: https://arxiv.org/abs/2003.02821. Like the ``Occlusion`` method, it hides part of the
input; however, instead of replacing the hidden data with a fixed baseline, it samples replacement
values from a bootstrapped distribution of the training data. Moreover, unlike regular augmented
occlusion, this method only hides data at the last time step, leaving past data unchanged.
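
A minimal pure-Python sketch of this idea, assuming a single series stored as a list of time steps
(each a list of feature values) and a model that returns one scalar score. The function name, the
toy model and the choice to average absolute output changes are illustrative assumptions;
``TemporalAugmentedOcclusion`` itself works on batched tensors and its exact aggregation may differ.

```python
import random
from typing import Callable, List, Sequence

# A single multivariate time series: one list of feature values per time step
Series = List[List[float]]


def temporal_augmented_occlusion(
    model: Callable[[Series], float],
    x: Series,
    data: Sequence[Series],
    n_sampling: int = 10,
    seed: int = 0,
) -> List[float]:
    """Hypothetical sketch: score each feature at the LAST time step of x."""
    rng = random.Random(seed)
    t = len(x) - 1
    reference = model(x)
    scores = []
    for f in range(len(x[t])):
        total = 0.0
        for _ in range(n_sampling):
            # Bootstrap: draw a training series with replacement and borrow
            # its value at the same (time, feature) position
            donor = rng.choice(data)
            perturbed = [row[:] for row in x]  # earlier steps stay untouched
            perturbed[t][f] = donor[t][f]
            total += abs(reference - model(perturbed))
        scores.append(total / n_sampling)
    return scores


# Usage: a toy "model" that only looks at feature 0 of the last time step
def toy_model(series: Series) -> float:
    return series[-1][0]


gen = random.Random(1)
train = [[[gen.gauss(0.0, 1.0) for _ in range(2)] for _ in range(4)] for _ in range(50)]
x = [[0.5, 0.5] for _ in range(4)]
scores = temporal_augmented_occlusion(toy_model, x, train)
print(scores)  # feature 1 is ignored by the model, so its score is 0.0
```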

We also use a special tool, ``TimeForwardTunnel``, which lets us compute the attributions at each
time step using only past data as information. The ``TimeForwardTunnel`` loops over every time
step and computes the attributions at each one.
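
The looping logic can be sketched as follows, assuming a per-step attribution function that returns
one score per feature for the last time step of the sequence it receives. ``time_forward_tunnel``
and ``attribute_last`` are hypothetical names; the actual ``TimeForwardTunnel`` wraps an attribution
method and operates on batched tensors.

```python
from typing import Callable, List

Series = List[List[float]]  # time steps x features


def time_forward_tunnel(
    attribute_last: Callable[[Series], List[float]], x: Series
) -> Series:
    """Hypothetical sketch: build a (time x feature) map from prefix-only calls."""
    attributions = []
    for t in range(len(x)):
        # Only data up to and including time t is visible at this step
        prefix = [row[:] for row in x[: t + 1]]
        attributions.append(attribute_last(prefix))
    return attributions


# Usage: a dummy attribution function scoring the last step by |value|
x = [[1.0, -2.0], [3.0, 0.0], [-4.0, 5.0]]
attr = time_forward_tunnel(lambda s: [abs(v) for v in s[-1]], x)
print(attr)  # [[1.0, 2.0], [3.0, 0.0], [4.0, 5.0]]
```

Because each call only sees ``x[: t + 1]``, the attribution at time ``t`` cannot depend on future
values, which matches the "past only" behaviour described above.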

In [9]:
explainer = TimeForwardTunnel(
    TemporalAugmentedOcclusion(
        classifier, data=x_train, n_sampling=10, is_temporal=True
    )
)

attr = explainer.attribute(
    x_test,
    sliding_window_shapes=(1,),
    attributions_fn=abs,
    task="binary",
    show_progress=True,
).abs()

### Attributions evaluation

Since we know the true saliency, we can evaluate the computed attributions with our
white-box metrics. Here, for instance, we compute ``aur`` (area under recall):

In [10]:
# Get true saliency
true_saliency = hmm.true_saliency(split="test").to(accelerator)

In [11]:
print(f"{aur(attr, true_saliency):.4}")

0.4847
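
As a rough sketch of what an area-under-recall metric computes, assume attributions are min-max
normalised to [0, 1], recall at a threshold tau is the fraction of truly salient entries whose
attribution is at least tau, and the recall curve is integrated over tau. The name
``area_under_recall`` and the 101-point threshold grid are assumptions; ``aur`` itself operates on
tensors and its normalisation and threshold handling may differ.

```python
from typing import Sequence


def area_under_recall(
    attr: Sequence[float], true_saliency: Sequence[int], n_thresholds: int = 101
) -> float:
    """Hypothetical sketch: area under the recall-vs-threshold curve."""
    # Min-max normalise attributions to [0, 1]
    lo, hi = min(attr), max(attr)
    span = (hi - lo) or 1.0  # guard against constant attributions
    norm = [(a - lo) / span for a in attr]
    # Normalised scores of the truly salient entries
    salient = [v for v, s in zip(norm, true_saliency) if s]
    if not salient:
        raise ValueError("true_saliency has no salient entry")
    # Recall at each threshold tau: share of salient entries with score >= tau
    taus = [i / (n_thresholds - 1) for i in range(n_thresholds)]
    recall = [sum(v >= tau for v in salient) / len(salient) for tau in taus]
    # Trapezoidal integration over the uniform threshold grid
    step = taus[1] - taus[0]
    return sum((r0 + r1) / 2 * step for r0, r1 in zip(recall, recall[1:]))


# Usage: a perfect attribution scores 1, an inverted one scores close to 0
print(f"{area_under_recall([1.0, 0.0, 1.0, 0.0], [1, 0, 1, 0]):.3f}")  # 1.000
print(f"{area_under_recall([0.0, 1.0, 0.0, 1.0], [1, 0, 1, 0]):.3f}")  # 0.005
```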


This is slightly better than the results reported in https://arxiv.org/pdf/2106.05303.

However, there are better methods than temporal augmented occlusion for this task. For more
details, please refer to our ``experiments/hmm`` section.