# Import Libraries

In [1]:
from run import *
from tint.metrics import mse, mae
import tint, gc
from tqdm import tqdm
import pandas as pd
from utils.explainer import *
from utils.tsr_tunnel import TSRTunnel
from exp.exp_interpret import *

In [2]:
from captum.attr import (
    DeepLift,
    GradientShap,
    IntegratedGradients,
    Lime
)

from tint.attr import (
    AugmentedOcclusion,
    DynaMask, Fit,
    Occlusion, 
    FeatureAblation
)

# Argument Parser

In [3]:
parser = get_parser()

# Experiment configuration written as an explicit token list; this is the same
# sequence of tokens a whitespace-split command line would produce.
argv = [
    '--task_name', 'long_term_forecast',
    '--use_gpu',
    '--result_path', 'scratch',
    '--root_path', './dataset/illness/',
    '--data_path', 'national_illness.csv',
    '--model', 'LSTM',
    '--features', 'MS',
    '--seq_len', '36',
    '--label_len', '12',
    '--pred_len', '24',
    '--n_features', '7',
]
args = parser.parse_args(argv)

# Gradient-based attribution through a cuDNN RNN/LSTM can raise a cuDNN
# backward error; see the Captum FAQ:
# https://captum.ai/docs/faq#how-can-i-resolve-cudnn-rnn-backward-error-for-rnn-or-lstm-network
# Uncommenting the next line runs everything on CPU instead.
# args.use_gpu = False
initial_setup(args)

# Initialize Experiment

In [4]:
# Pick the experiment class that matches the configured task, then build it.
Exp = Exp_Classification if args.task_name == 'classification' else Exp_Long_Term_Forecast
exp = Exp(args)  # set experiments

# Keep only the loader for the `args.flag` split; the first returned value
# (presumably the dataset object) is not needed here.
_, dataloader = exp._get_data(args.flag)

exp.load_best_model()

Use GPU: cuda:0
Experiments will be saved in scratch\national_illness_LSTM
test 73
Loading model from scratch\national_illness_LSTM\checkpoint.pth


# Evaluation

## Model

In [5]:
model = exp.model
# Switch to evaluation mode (training=False) before computing attributions.
model.eval()

# Only the target outputs are needed, since interpretation is based on the
# model's outputs; presumably enabling attention output would change what the
# forward call returns.
assert not exp.args.output_attention, (
    "output_attention must be disabled: attribution expects the model to "
    "return only its predictions"
)

## Evaluate

In [10]:
# Attribution-quality metrics, keyed by the name recorded in each result row.
expl_metric_map = dict(mae=mae, mse=mse)
# Each value is passed as `topk` to the metrics, i.e. the proportion of input
# elements that gets masked when scoring an attribution.
areas = [0.03, 0.05, 0.1, 0.2]

# Explainers to evaluate (could instead come from `args.explainers`).
explainers = ['deep_lift', 'feature_ablation']

# Build every explainer once around the trained model, looked up by name.
explainers_map = {}
for name in explainers:
    explainers_map[name] = explainer_name_map[name](model)

In [15]:
# Score every explainer on every batch. For each (area, metric) pair, record:
#   comp -> error when the top-`area` attributed inputs are masked (mask_largest=True)
#   suff -> error when the least-attributed inputs are masked instead (mask_largest=False)
results = []
# "zero", "aug", "random"
# performance order random > zero > aug
baseline_mode = "random"

result_columns = ['batch_index', 'explainer', 'metric', 'area', 'comp', 'suff']

progress_bar = tqdm(
    enumerate(dataloader), total=len(dataloader), disable=False
)
for batch_index, (batch_x, batch_y, batch_x_mark, batch_y_mark) in progress_bar:
    batch_x = batch_x.float().to(exp.device)
    batch_y = batch_y.float().to(exp.device)
    batch_x_mark = batch_x_mark.float().to(exp.device)
    batch_y_mark = batch_y_mark.float().to(exp.device)

    # Decoder input: the first `label_len` steps of batch_y followed by
    # `pred_len` zero-filled steps.
    dec_inp = torch.zeros_like(batch_y[:, -exp.args.pred_len:, :]).float()
    dec_inp = torch.cat([batch_y[:, :exp.args.label_len, :], dec_inp], dim=1).float()

    # Baselines must be a scalar or a tuple of tensors with the same shape as the input.
    baselines = get_baseline(batch_x, mode=baseline_mode)
    additional_forward_args = (batch_x_mark, dec_inp, batch_y_mark)

    for name in explainers:
        explainer = explainers_map[name]
        attr = compute_regressor_attr(
            batch_x, baselines, explainer, additional_forward_args, args
        )

        for area in areas:
            for metric_name, metric in expl_metric_map.items():
                # The comp and suff scores differ only in which end of the
                # attribution ranking gets masked.
                error_comp, error_suff = (
                    metric(
                        model, inputs=batch_x,
                        attributions=attr, baselines=baselines,
                        additional_forward_args=additional_forward_args,
                        topk=area, mask_largest=mask_largest,
                    )
                    for mask_largest in (True, False)
                )
                results.append(
                    [batch_index, name, metric_name, area, error_comp, error_suff]
                )

# Tabulate once after the loop, instead of growing a frame row by row.
results_df = pd.DataFrame(results, columns=result_columns)

100%|██████████| 5/5 [00:15<00:00,  3.20s/it]


# Other Experiments (ad-hoc attribution checks)

In [6]:
# Take one batch for ad-hoc experiments and move every tensor to the device as float.
batch_x, batch_y, batch_x_mark, batch_y_mark = (
    tensor.float().to(exp.device) for tensor in next(iter(dataloader))
)

# Decoder input: the first `label_len` steps of batch_y, then `pred_len` zero steps.
dec_inp = torch.cat(
    [
        batch_y[:, :exp.args.label_len, :],
        torch.zeros_like(batch_y[:, -exp.args.pred_len:, :]).float(),
    ],
    dim=1,
).float()

In [7]:
# Project helper; presumably gathers the data from `dataloader` onto
# `exp.device` — confirm in utils.
total_data = get_total_data(dataloader, exp.device)

# Split the forward arguments the way the forward call orders them,
# model(batch_x, batch_x_mark, dec_inp, batch_y_mark): the first two are
# attributed, the last two are passed through unchanged.
inputs, additional_forward_args = (batch_x, batch_x_mark), (dec_inp, batch_y_mark)

In [30]:
# FeatureAblation explainer wrapping the trained model.
explainer = FeatureAblation(exp.model)
# Attributed inputs vs. pass-through forward arguments.
inputs, additional_forward_args = (batch_x, batch_x_mark), (dec_inp, batch_y_mark)

In [33]:
# Ablate against zero baselines, one per tensor in `inputs`. With no
# feature_mask given, the Captum-style default treats every scalar element as
# its own feature.
attr = explainer.attribute(
    additional_forward_args=additional_forward_args,
    baselines=(0, 0),
    inputs=inputs,
)

In [None]:
# Disabled experiment: ablate whole positions along dim 1 of batch_x
# (presumably the time axis) at once. `temporal_mask` assigns group id `t` to
# every element at index t of dim 1, so FeatureAblation perturbs each such
# slice together against a zero baseline, with `target=0`.
# NOTE(review): the last line prints `score`, which is never defined in this
# cell; it presumably should be `time_score`.
# temporal_mask = torch.zeros_like(batch_x, dtype=int)
# for t in range(batch_x.shape[1]):
#     temporal_mask[:, t] = t

# explainer = FeatureAblation(model)
# time_score = explainer.attribute(
#     inputs=(batch_x),
#     baselines=(batch_x*0),
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     target=0,
#     feature_mask=temporal_mask
# )
# print(score.shape)