# HMM Experiment example

In this example, we use the HMM dataset provided by this package, compute attributions
for a classifier trained on it, and evaluate them against the known true saliency.

### Imports

In [4]:
import torch as th

from pytorch_lightning import Trainer, seed_everything

from tint.attr import TemporalAugmentedOcclusion, TimeForwardTunnel
from tint.datasets import HMM
from tint.metrics.white_box import aur, information

from main import main
from classifier import StateClassifierNet

In [5]:
m = main(['lime'], device='cuda')


In [None]:
m

In [None]:
m.detach().cpu().numpy().dtype

In [None]:
hmm = HMM(n_folds=5, fold=0, seed=42)
true_saliency = hmm.true_saliency(split="test")

In [None]:
information(m, true_saliency)

### Make reproducible experiment 

For this example, we make everything reproducible. To that end, we use the 
``seed_everything`` utility from PyTorch Lightning.

In [None]:
seed = 42
seed_everything(seed=seed, workers=True)
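``seed_everything`` seeds the Python, NumPy and PyTorch random number generators in one call. The effect of fixing a seed can be illustrated with the standard library alone. This is a minimal sketch and not the Lightning implementation; the helper name ``draw`` is hypothetical.

```python
import random


def draw(seed, n=3):
    """Seed Python's global RNG, then draw n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = draw(42)
second = draw(42)
other = draw(43)

print(first == second)  # same seed, same sequence
print(first == other)   # different seed, different sequence
```

Seeding every generator before any data generation or weight initialisation is what makes the later cells repeatable from run to run.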

### Data loading

We load the HMM (Hidden Markov Model) dataset, downloading it first if necessary (since HMM is a synthetic dataset, 
"download" actually generates the data).

In [None]:
hmm = HMM(seed=seed)
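To give an intuition of what "generating" such data means, here is a minimal sketch of sampling from a two-state hidden Markov model with three observed features. It is not the generator used by ``tint.datasets.HMM``: the function name ``sample_hmm``, the ``p_stay`` transition probability and the emission means are all assumptions chosen for illustration.

```python
import random


def sample_hmm(n_steps, n_features=3, p_stay=0.9, seed=42):
    """Sample hidden states and Gaussian observations from a toy 2-state HMM."""
    rng = random.Random(seed)
    means = {0: 0.0, 1: 1.0}  # hypothetical emission mean for each state
    state = rng.randint(0, 1)
    states, observations = [], []
    for _ in range(n_steps):
        # Switch state with probability 1 - p_stay
        if rng.random() > p_stay:
            state = 1 - state
        states.append(state)
        # Each feature is drawn around the mean of the current hidden state
        observations.append(
            [rng.gauss(means[state], 1.0) for _ in range(n_features)]
        )
    return states, observations


states, observations = sample_hmm(n_steps=5)
print(states)
print(len(observations), len(observations[0]))
```

Because the hidden states are known at generation time, a synthetic dataset of this kind can also expose which inputs truly matter, which is what makes white-box evaluation possible later on.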

### Create and train a simple classifier

We will now train a simple classifier (a GRU followed by an MLP) on the HMM dataset.
The dataset provides labels, which are generated from the hidden states of the HMM.

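We use the PyTorch Lightning framework to train this model efficiently. Here, the 
accelerator is set to ``cuda``, which requires a GPU; change this to ``cpu`` if none is available.
The same string is later used as a PyTorch device, so it must be a valid device name.

In [None]:
accelerator = "cuda"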

In [None]:
classifier = StateClassifierNet(
    feature_size=3,
    n_state=2,
    hidden_size=200,
    regres=True,
    loss="cross_entropy",
    lr=0.0001,
    l2=1e-3,
)

In [None]:
th.use_deterministic_algorithms(True)
trainer = Trainer(
    max_epochs=50, accelerator=accelerator, deterministic=True
)
trainer.fit(classifier, datamodule=hmm)

### Get train and test data

We only compute the attributions over the test set. However, we also need the training data, from which the attribution method samples its replacement values.

In [None]:
x_train = hmm.preprocess(split="train")["x"].to(accelerator)
x_test = hmm.preprocess(split="test")["x"].to(accelerator)

### Utils

We set the classifier to ``evaluation`` mode and move it to the current accelerator.

We also disable ``cudnn`` when using ``cuda``, because cuDNN RNN layers cannot backpropagate in evaluation mode.
Please refer to https://captum.ai/docs/faq#how-can-i-resolve-cudnn-rnn-backward-error-for-rnn-or-lstm-network 
for more information.

In [None]:
# Switch to eval
classifier.eval()

# Set model to accelerator
classifier.to(accelerator)

if accelerator == "cuda":
    th.backends.cudnn.enabled = False

### Create attributions using temporal augmented occlusion

In this example, we use ``temporal_augmented_occlusion`` as the attribution method, first presented 
in this paper: https://arxiv.org/abs/2003.02821. Like the ``Occlusion`` method, it hides part of the data.
However, instead of replacing the hidden data with a fixed baseline, it samples the replacement
from a bootstrapped distribution. Moreover, unlike the regular ``augmented_occlusion``, this method only 
hides data at the last time step, leaving past data unchanged.
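The idea can be sketched in a few lines of plain Python. This is a conceptual illustration under stated assumptions, not the ``TemporalAugmentedOcclusion`` implementation: the toy ``model``, the function name ``augmented_occlusion_last_step`` and the ``train_values`` pools are hypothetical.

```python
import random


def model(series):
    """Toy stand-in for a classifier: weighted sum of the last time step."""
    weights = [2.0, 0.0, -1.0]
    return sum(w * v for w, v in zip(weights, series[-1]))


def augmented_occlusion_last_step(f, series, train_values, n_sampling=10, seed=0):
    """Score each feature of the last time step by the mean absolute change
    in f when that feature is replaced by values sampled from training data."""
    rng = random.Random(seed)
    reference = f(series)
    scores = []
    for j in range(len(series[-1])):
        total = 0.0
        for _ in range(n_sampling):
            perturbed = [list(step) for step in series]  # copy, past untouched
            perturbed[-1][j] = rng.choice(train_values[j])  # sampled replacement
            total += abs(reference - f(perturbed))
        scores.append(total / n_sampling)
    return scores


series = [[0.1, 0.2, 0.3], [1.0, 1.0, 1.0]]
train_values = [[-1.0, 0.0, 2.0]] * 3  # one pool of observed values per feature
scores = augmented_occlusion_last_step(model, series, train_values)
print(scores)
```

The second feature has zero weight in the toy model, so replacing it never changes the output and its score is zero, while the two features the model actually uses receive positive scores.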

We also use a special tool: ``TimeForwardTunnel``. It allows us to compute attributions
at each time step using only the past as information. The ``TimeForwardTunnel`` loops over 
every time step to compute all the attributions.

In [None]:
explainer = TimeForwardTunnel(
    TemporalAugmentedOcclusion(
        classifier, data=x_train, n_sampling=10, is_temporal=True
    )
)

attr = explainer.attribute(
    x_test,
    sliding_window_shapes=(1,),
    attributions_fn=abs,
    task="binary",
    show_progress=True,
).abs()
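The time-forward loop described above can be pictured as follows. This is a simplified sketch, not the ``TimeForwardTunnel`` code: the toy ``model``, the ``last_step_occlusion`` explainer and the ``time_forward`` helper are hypothetical, and the library may combine the per-step results differently.

```python
def model(series):
    """Toy stand-in: mean of the first feature over the steps seen so far."""
    return sum(step[0] for step in series) / len(series)


def last_step_occlusion(f, series, baseline=0.0):
    """Absolute change in f when each feature of the last step is zeroed."""
    reference = f(series)
    scores = []
    for j in range(len(series[-1])):
        perturbed = [list(step) for step in series]
        perturbed[-1][j] = baseline
        scores.append(abs(reference - f(perturbed)))
    return scores


def time_forward(f, series, explain):
    """For each time t, explain f using only series[: t + 1] (the past)."""
    return [explain(f, series[: t + 1]) for t in range(len(series))]


series = [[1.0, 5.0], [4.0, 5.0], [9.0, 5.0]]
attr_by_time = time_forward(model, series, last_step_occlusion)
for t, row in enumerate(attr_by_time):
    print(t, row)
```

Each row holds the attributions for one time step, computed from a truncated input, so no information from the future leaks into the explanation of an earlier prediction.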

### Attributions evaluation

Since we know the true attributions, we can evaluate our computed attributions 
using our white-box metrics. For instance, we compute here the ``aur`` (area under recall):

In [None]:
# Get true saliency
true_saliency = hmm.true_saliency(split="test").to(accelerator)

In [None]:
print(f"{aur(attr, true_saliency):.4}")
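For intuition, an area-under-recall score can be approximated by hand. The sketch below is one plausible reading of the metric, not the ``tint.metrics.white_box.aur`` implementation: it rescales attributions to [0, 1], measures recall of the truly salient entries at evenly spaced thresholds, and integrates with the trapezoidal rule. The function name ``recall_curve_area`` and the threshold grid are assumptions.

```python
def recall_curve_area(attr, truth, n_thresholds=101):
    """Area under the recall-vs-threshold curve for flat lists of scores.

    Assumes `truth` marks at least one entry as salient.
    """
    lo, hi = min(attr), max(attr)
    span = (hi - lo) or 1.0  # avoid division by zero for constant scores
    scaled = [(a - lo) / span for a in attr]
    positives = [s for s, t in zip(scaled, truth) if t]
    thresholds = [i / (n_thresholds - 1) for i in range(n_thresholds)]
    # Recall: share of truly salient entries scored at or above the threshold
    recalls = [
        sum(s >= tau for s in positives) / len(positives) for tau in thresholds
    ]
    area = 0.0
    for k in range(1, n_thresholds):
        width = thresholds[k] - thresholds[k - 1]
        area += 0.5 * (recalls[k] + recalls[k - 1]) * width
    return area


truth = [1, 1, 0, 0]
good = [1.0, 0.9, 0.1, 0.0]  # salient entries get the highest scores
bad = [0.0, 0.1, 0.9, 1.0]   # salient entries get the lowest scores
print(recall_curve_area(good, truth) > recall_curve_area(bad, truth))  # True
```

Higher is better: attributions that keep the truly salient entries above most thresholds retain high recall for longer and therefore cover a larger area.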

This is slightly better than the results reported in https://arxiv.org/pdf/2106.05303.

There are, however, better methods than ``temporal_augmented_occlusion`` for this task. For more details, 
please refer to our ``experiments/hmm`` section.

In [None]:
import multiprocessing as mp
import numpy as np
import random
import torch as th
import torch.nn as nn
import os

from argparse import ArgumentParser
from captum.attr import DeepLift, GradientShap, IntegratedGradients, Lime
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from typing import List
from classifier import StateClassifierNet

from tint.attr import (
    DynaMask,
    ExtremalMask,
    Fit,
    Retain,
    TemporalAugmentedOcclusion,
    TemporalOcclusion,
    TimeForwardTunnel,
)
from tint.attr.models import (
    ExtremalMaskNet,
    JointFeatureGeneratorNet,
    MaskNet,
    RetainNet,
)
from tint.datasets import HMM
from tint.metrics.white_box import (
    aup,
    aur,
    information,
    entropy,
    roc_auc,
    auprc,
)
from tint.models import MLP, RNN

In [None]:
def get_model(check_name, trainer, model, data_module, seed, rerun_all=False):
    # Reuse a saved checkpoint when one exists, unless a rerun is forced
    checkpoint = f"{seed}_{check_name}.ckpt"
    if os.path.exists(checkpoint) and not rerun_all:
        model.load_state_dict(th.load(checkpoint))
    else:
        trainer.fit(model, datamodule=data_module)
        th.save(model.state_dict(), checkpoint)
    return model

def main(
    explainers: List[str],
    device: str = "cpu",
    fold: int = 0,
    seed: int = 42,
    deterministic: bool = False,
    lambda_1: float = 1.0,
    lambda_2: float = 1.0,
    output_file: str = "results.csv",
    rerun_all: bool = False,
):
    # If deterministic, seed everything
    if deterministic:
        seed_everything(seed=seed, workers=True)

    # Get accelerator and device
    accelerator = device.split(":")[0]
    device_id = 1
    if len(device.split(":")) > 1:
        device_id = [int(device.split(":")[1])]

    # Create lock
    lock = mp.Lock()

    # Load data
    hmm = HMM(n_folds=5, fold=fold, seed=seed)

    # Create classifier
    classifier = StateClassifierNet(
        feature_size=3,
        n_state=2,
        hidden_size=200,
        regres=True,
        loss="cross_entropy",
        lr=0.0001,
        l2=1e-3,
    )

    # Train classifier
    trainer = Trainer(
        max_epochs=500,
        # max_epochs=3,
        accelerator=accelerator,
        devices=device_id,
        deterministic=deterministic
    )
    # trainer.fit(classifier, datamodule=hmm)
    classifier = get_model(
        check_name="classifier",
        trainer=trainer,
        model=classifier,
        data_module=hmm,
        seed=seed,
        rerun_all=rerun_all,
    )
    # Get data for explainers
    with lock:
        x_train = hmm.preprocess(split="train")["x"].to(device)
        x_test = hmm.preprocess(split="test")["x"].to(device)
        y_test = hmm.preprocess(split="test")["y"].to(device)
        true_saliency = hmm.true_saliency(split="test").to(device)

    # Switch to eval
    classifier.eval()

    # Set model to device
    classifier.to(device)

    # Disable cudnn if using cuda accelerator.
    # Please see https://captum.ai/docs/faq#how-can-i-resolve-cudnn-rnn-backward-error-for-rnn-or-lstm-network
    # for more information.
    if accelerator == "cuda":
        th.backends.cudnn.enabled = False

    # Create dict of attributions
    attr = dict()

    if "deep_lift" in explainers:
        explainer = TimeForwardTunnel(DeepLift(classifier))
        attr["deep_lift"] = explainer.attribute(
            x_test,
            baselines=x_test * 0,
            task="binary",
            show_progress=True,
        ).abs()

    if "dyna_mask" in explainers:
        trainer = Trainer(
            max_epochs=1000,
            accelerator=accelerator,
            devices=device_id,
            log_every_n_steps=2,
            deterministic=deterministic,
        )
        mask = MaskNet(
            forward_func=classifier,
            perturbation="gaussian_blur",
            sigma_max=1,
            keep_ratio=list(np.arange(0.25, 0.35, 0.01)),
            size_reg_factor_init=0.1,
            size_reg_factor_dilation=100,
            time_reg_factor=1.0,
        )
        explainer = DynaMask(classifier)
        _attr = explainer.attribute(
            x_test,
            additional_forward_args=(True,),
            trainer=trainer,
            mask_net=mask,
            batch_size=100,
            return_best_ratio=True,
        )
        print(f"Best keep ratio is {_attr[1]}")
        attr["dyna_mask"] = _attr[0].to(device)

    if "extremal_mask" in explainers:
        trainer = Trainer(
            max_epochs=500,
            # max_epochs=3,
            accelerator=accelerator,
            devices=device_id,
            log_every_n_steps=2,
            deterministic=deterministic,
        )
        mask = ExtremalMaskNet(
            forward_func=classifier,
            model=nn.Sequential(
                RNN(
                    input_size=x_test.shape[-1],
                    rnn="gru",
                    hidden_size=x_test.shape[-1],
                    bidirectional=True,
                ),
                MLP([2 * x_test.shape[-1], x_test.shape[-1]]),
            ),
            lambda_1=lambda_1,
            lambda_2=lambda_2,
            optim="adam",
            lr=0.01,
        )
        explainer = ExtremalMask(classifier)
        _attr = explainer.attribute(
            x_test,
            additional_forward_args=(True,),
            trainer=trainer,
            mask_net=mask,
            batch_size=100,
        )
        attr["extremal_mask"] = _attr.to(device)

    if "fit" in explainers:
        generator = JointFeatureGeneratorNet(rnn_hidden_size=6)
        trainer = Trainer(
            max_epochs=300,
            accelerator=accelerator,
            devices=device_id,
            log_every_n_steps=10,
            deterministic=deterministic,
        )
        explainer = Fit(
            classifier,
            generator=generator,
            datamodule=hmm,
            trainer=trainer,
        )
        attr["fit"] = explainer.attribute(x_test, show_progress=True)

    if "gradient_shap" in explainers:
        explainer = TimeForwardTunnel(GradientShap(classifier.cpu()))
        attr["gradient_shap"] = explainer.attribute(
            x_test.cpu(),
            baselines=th.cat([x_test.cpu() * 0, x_test.cpu()]),
            n_samples=50,
            stdevs=0.0001,
            task="binary",
            show_progress=True,
        ).abs()
        classifier.to(device)

    if "integrated_gradients" in explainers:
        explainer = TimeForwardTunnel(IntegratedGradients(classifier))
        attr["integrated_gradients"] = explainer.attribute(
            x_test,
            baselines=x_test * 0,
            internal_batch_size=200,
            task="binary",
            show_progress=True,
        ).abs()

    if "lime" in explainers:
        explainer = TimeForwardTunnel(Lime(classifier))
        attr["lime"] = explainer.attribute(
            x_test,
            task="binary",
            show_progress=True,
        ).abs()

    if "augmented_occlusion" in explainers:
        explainer = TimeForwardTunnel(
            TemporalAugmentedOcclusion(
                classifier, data=x_train, n_sampling=10, is_temporal=True
            )
        )
        attr["augmented_occlusion"] = explainer.attribute(
            x_test,
            sliding_window_shapes=(1,),
            attributions_fn=abs,
            task="binary",
            show_progress=True,
        ).abs()

    if "occlusion" in explainers:
        explainer = TimeForwardTunnel(TemporalOcclusion(classifier))
        attr["occlusion"] = explainer.attribute(
            x_test,
            sliding_window_shapes=(1,),
            baselines=x_train.mean(0, keepdim=True),
            attributions_fn=abs,
            task="binary",
            show_progress=True,
        ).abs()

    if "retain" in explainers:
        retain = RetainNet(
            dim_emb=128,
            dropout_emb=0.4,
            dim_alpha=8,
            dim_beta=8,
            dropout_context=0.4,
            dim_output=2,
            loss="cross_entropy",
        )
        explainer = Retain(
            datamodule=hmm,
            retain=retain,
            trainer=Trainer(
                max_epochs=50,
                accelerator=accelerator,
                devices=device_id,
                deterministic=deterministic
            ),
        )
        attr["retain"] = (
            explainer.attribute(x_test, target=y_test).abs().to(device)
        )

    with open(output_file, "a") as fp, lock:
        for k, v in attr.items():
            fp.write(str(seed) + ",")
            fp.write(str(fold) + ",")
            fp.write(k + ",")
            fp.write(str(lambda_1) + ",")
            fp.write(str(lambda_2) + ",")
            fp.write(f"{aup(v, true_saliency):.4},")
            fp.write(f"{aur(v, true_saliency):.4},")
            fp.write(f"{information(v, true_saliency):.4},")
            fp.write(f"{entropy(v, true_saliency):.4},")
            fp.write(f"{roc_auc(v, true_saliency):.4},")
            fp.write(f"{auprc(v, true_saliency):.4}")
            fp.write("\n")
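Since ``main`` appends rows without a header, column names have to be supplied when reading the file back. The sketch below follows the order of the ``fp.write`` calls above; the helper name ``read_results`` is hypothetical, and the demo row contains placeholder numbers, not real results.

```python
import csv
import os
import tempfile

# Column order taken from the fp.write calls in main
COLUMNS = [
    "seed", "fold", "explainer", "lambda_1", "lambda_2",
    "aup", "aur", "information", "entropy", "roc_auc", "auprc",
]


def read_results(path):
    """Read a header-less results file into a list of dicts keyed by column."""
    with open(path, newline="") as fp:
        return [dict(zip(COLUMNS, row)) for row in csv.reader(fp) if row]


# Demo on a temporary file with one placeholder row (made-up values)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "results.csv")
    with open(path, "a") as fp:
        fp.write("42,0,placeholder,1.0,1.0,0.1,0.2,0.3,0.4,0.5,0.6\n")
    rows = read_results(path)

print(rows[0]["explainer"], rows[0]["aur"])
```

All values come back as strings, so numeric columns need an explicit ``float`` conversion before aggregating across seeds or folds.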

In [None]:
main(["extremal_mask"], 'cuda')
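The device string passed to ``main`` is split into an accelerator name and a device specification for the ``Trainer``. The standalone helper below mirrors that logic so it can be checked in isolation; the name ``parse_device`` is hypothetical and does not exist in the code above.

```python
def parse_device(device):
    """Split 'cuda:1' into ('cuda', [1]); a bare name maps to one device."""
    parts = device.split(":")
    accelerator = parts[0]
    # An explicit index selects that device; otherwise request a single device
    devices = [int(parts[1])] if len(parts) > 1 else 1
    return accelerator, devices


print(parse_device("cuda"))    # ('cuda', 1)
print(parse_device("cuda:1"))  # ('cuda', [1])
print(parse_device("cpu"))     # ('cpu', 1)
```

Note the two return shapes: an integer means "use this many devices", whereas a list names specific device indices.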