# Import Library

In [1]:
from run import *
from tint.metrics import mse, mae
import tint, gc
from tqdm import tqdm
import pandas as pd
from utils.explainer import *
from exp.exp_interpret import *

In [2]:
from captum.attr import (
    DeepLift,
    GradientShap,
    IntegratedGradients,
    Lime
)

from tint.attr import (
    AugmentedOcclusion,
    DynaMask,
    Occlusion, 
    FeatureAblation
)

# Argument Parser

In [3]:
# Parse a fixed argument list (instead of sys.argv) so the notebook is
# self-contained. Each flag is listed next to its value.
parser = get_parser()
argv = [
    '--task_name', 'long_term_forecast',
    '--use_gpu',
    '--root_path', './dataset/illness/',
    '--data_path', 'national_illness.csv',
    '--model', 'Autoformer',
    '--features', 'MS',
    '--seq_len', '36',
    '--label_len', '12',
    '--pred_len', '24',
    '--enc_in', '7',
    '--dec_in', '7',
    '--c_out', '7',
    '--batch_size', '16',
]
args = parser.parse_args(argv)

In [5]:
# Lookup table: short explainer name -> attribution class implementing it.
explainer_name_map = dict(
    deep_lift=DeepLift,
    gradient_shap=GradientShap,
    integrated_gradients=IntegratedGradients,
    lime=Lime,
    occlusion=Occlusion,
    # augmented_occlusion=AugmentedOcclusion,  # currently disabled
    dyna_mask=DynaMask,
    feature_ablation=FeatureAblation,
)

# Initialize Experiment

In [6]:
# Fix the random seed taken from the parsed arguments.
set_random_seed(args.seed)

# If a cudnn RNN backward error shows up on a CUDA accelerator, cudnn may have
# to be disabled; see
# https://captum.ai/docs/faq#how-can-i-resolve-cudnn-rnn-backward-error-for-rnn-or-lstm-network
# args.use_gpu = False

# This notebook only handles the long-term forecasting task.
assert args.task_name == 'long_term_forecast', "Only long_term_forecast is supported for now"

# String form of the configuration; used below to locate checkpoints and results.
setting = stringify_setting(args)
Exp = Exp_Long_Term_Forecast

In [7]:
# Instantiate the experiment and restore the trained weights for this setting.
exp = Exp(args)  # set experiments
checkpoint_path = os.path.join('checkpoints/' + setting, 'checkpoint.pth')
# map_location remaps the stored tensors onto the experiment's device, so a
# checkpoint saved on a different device (e.g. another GPU, or a GPU when
# running on CPU) can still be loaded.
exp.model.load_state_dict(
    torch.load(checkpoint_path, map_location=exp.device)
)
# Folder where the interpretation results of this setting are written.
result_folder = './results/' + setting + '/'

Use GPU: cuda:0


# Evaluation

## Model

In [8]:
# Put the experiment's model in evaluation mode and clear any stored gradients
# before computing attributions.
model = exp.model
model.eval()
model.zero_grad()

# Only the target outputs are needed, since interpretation is based on the
# outputs; the output_attention option must therefore be off.
assert not exp.args.output_attention

## Evaluate

In [9]:
# Build the data loader for the 'test' split; the first value returned by
# _get_data is not needed here and is discarded.
flag = 'test'
_, dataloader = exp._get_data(flag)

test 73


In [10]:
# Error metrics used to score attributions, keyed by the name reported in the results.
expl_metric_map = dict(mae=mae, mse=mse)
# Values passed as `topk` to the metrics below.
areas = [0.03, 0.05, 0.1, 0.2]

explainers = ['deep_lift', 'feature_ablation'] # explainers = args.explainers
# Instantiate every selected explainer around the model.
explainers_map = {
    name: explainer_name_map[name](model) for name in explainers
}

In [15]:
results = []
# Supported baseline modes: "zero", "aug", "random"
# performance order random > zero > aug
baseline_mode = "random"

result_columns = ['batch_index', 'explainer', 'metric', 'area', 'comp', 'suff']

progress_bar = tqdm(
    enumerate(dataloader), total=len(dataloader), disable=False
)
for batch_index, (batch_x, batch_y, batch_x_mark, batch_y_mark) in progress_bar:
    # cast the whole batch to float and move it to the experiment device
    batch_x, batch_y, batch_x_mark, batch_y_mark = (
        tensor.float().to(exp.device)
        for tensor in (batch_x, batch_y, batch_x_mark, batch_y_mark)
    )

    # decoder input: the first label_len steps of batch_y followed by zeros
    # for the last pred_len steps
    dec_inp = torch.cat(
        [
            batch_y[:, :exp.args.label_len, :],
            torch.zeros_like(batch_y[:, -exp.args.pred_len:, :]).float(),
        ],
        dim=1,
    ).float()
    # outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

    inputs = batch_x
    # baseline must be a scalar or tuple of tensors with same dimension as input
    if baseline_mode == 'zero':
        baselines = torch.zeros_like(inputs)
    elif baseline_mode == 'random':
        baselines = torch.randn_like(inputs)
    else:
        # mean over the batch axis, tiled back to the batch size
        baselines = torch.mean(inputs, axis=0).repeat(inputs.shape[0], 1, 1).float()
    additional_forward_args = (batch_x_mark, dec_inp, batch_y_mark)

    # attributions of every explainer for this batch
    for name in explainers:
        explainer = explainers_map[name]
        attr = compute_attr(
            inputs, baselines, explainer, additional_forward_args, args
        )

        # score the attributions for every area / metric combination
        for area in areas:
            # arguments shared by both metric calls below
            metric_kwargs = dict(
                inputs=inputs,
                attributions=attr,
                baselines=baselines,
                additional_forward_args=additional_forward_args,
                topk=area,
            )
            for metric_name, metric in expl_metric_map.items():
                # 'comp' column: metric with mask_largest=True
                error_comp = metric(model, mask_largest=True, **metric_kwargs)
                # 'suff' column: metric with mask_largest=False
                error_suff = metric(model, mask_largest=False, **metric_kwargs)

                results.append(
                    [batch_index, name, metric_name, area, error_comp, error_suff]
                )

100%|██████████| 5/5 [00:15<00:00,  3.20s/it]


## Output

In [16]:
# Average comp/suff over the batches for each (explainer, metric, area) group.
results_df = (
    pd.DataFrame(results, columns=result_columns)
    .groupby(['explainer', 'metric', 'area'])[['comp', 'suff']]
    .agg('mean')
    .reset_index()
)
# results_df.round(6).to_csv(os.path.join(result_folder, 'interpretation_results.csv'), index=False)
print(results_df)

           explainer metric  area        comp        suff
0          deep_lift    mae  0.03   22.387512   47.552436
1          deep_lift    mae  0.05   35.472929   34.969124
2          deep_lift    mae  0.10   62.132083   12.359752
3          deep_lift    mae  0.20   80.801300   15.408948
4          deep_lift    mse  0.03   22.275765  100.830368
5          deep_lift    mse  0.05   54.752081   56.914517
6          deep_lift    mse  0.10  165.784731    9.525279
7          deep_lift    mse  0.20  286.879962   17.974950
8   feature_ablation    mae  0.03   22.090495   48.157625
9   feature_ablation    mae  0.05   34.997378   35.562492
10  feature_ablation    mae  0.10   61.651758   12.793189
11  feature_ablation    mae  0.20   79.969749   13.870515
12  feature_ablation    mse  0.03   21.877092  102.763078
13  feature_ablation    mse  0.05   53.530500   58.256120
14  feature_ablation    mse  0.10  163.161661   10.077021
15  feature_ablation    mse  0.20  279.247916   15.032560


# TSR

In [20]:
def compute_tsr_attr(
    inputs, baselines, explainer, additional_forward_args, args, device, masking='point'
):
    """Rescale attributions by per-time-step and per-feature sensitivity.

    The attribution of the unmodified ``inputs`` is compared against the
    attribution obtained after replacing parts of the input with random
    values. The summed absolute attribution change gives one score per time
    step and, for each time step, one score per feature. The returned value at
    ``[b, t, f]`` is the product of the (min-max scaled) time score of ``t``
    and the (min-max scaled) feature score of ``f`` computed at ``t``.

    Args:
        inputs: input tensor indexed as batch x seq_len x features. It is not
            modified; all perturbations are applied to a clone.
        baselines: baselines forwarded unchanged to ``compute_attr``.
        explainer: explainer forwarded unchanged to ``compute_attr``.
        additional_forward_args: extra forward arguments forwarded unchanged
            to ``compute_attr``.
        args: configuration object; ``args.seq_len`` is the number of time
            steps that are iterated over.
        device: device on which the intermediate score tensors are allocated.
        masking: ``'point'`` perturbs only time step ``t``; any other value
            perturbs the prefix ``[:t]`` (which is empty for ``t == 0``).

    Returns:
        A tensor with the same shape and dtype as the attribution returned by
        ``compute_attr`` for the unmodified inputs.
    """
    actual_attr = compute_attr(inputs, baselines, explainer, additional_forward_args, args)
    batch_size, n_features = inputs.shape[0], inputs.shape[-1]
    point_masking = masking == 'point'

    def span(t):
        # Time index (point masking) or prefix slice (otherwise) perturbed at step t.
        return t if point_masking else slice(0, t)

    # Working copy that is perturbed and then restored on every iteration.
    new_inputs = inputs.clone()

    # ---- relevance of each time step: batch x seq_len ----
    time_attr = torch.zeros((batch_size, args.seq_len), dtype=float, device=device)
    assignment = torch.randn_like(new_inputs)
    for t in range(args.seq_len):
        idx = span(t)
        # batch x seq_len x features
        new_inputs[:, idx] = assignment[:, idx]

        new_attr_per_time = compute_attr(
            new_inputs, baselines, explainer,
            additional_forward_args, args
        )

        # sum the attr difference for each input in the batch
        # batch x seq_len x features -> batch
        time_attr[:, t] = (actual_attr - new_attr_per_time).abs().sum(axis=(1, 2))

        # Restore from the untouched inputs. Saving `new_inputs[:, idx]` before
        # the assignment would only keep a view that is overwritten together
        # with new_inputs, so restoring from it would leave the perturbation
        # in place for all later iterations.
        new_inputs[:, idx] = inputs[:, idx]

    # for each input in the batch, normalize along the time axis
    time_attr = min_max_scale(time_attr, dim=1)

    # Alternatives that were tried and left disabled:
    # new_attr = (time_attr.T * actual_attr.T).T
    # mean_time_importance = np.quantile(time_attr, .55, axis=1)

    # ---- relevance of each feature at each time step ----
    input_attr = torch.zeros((batch_size, n_features), dtype=float, device=device)
    time_scaled_attr = torch.zeros_like(actual_attr)

    # One random value per (batch, time step), shared by all features. Created
    # on the inputs' device and dtype so the assignments below need no
    # cross-device copy or dtype cast.
    assignment = torch.randn(
        (inputs.shape[0], inputs.shape[1]), dtype=inputs.dtype, device=inputs.device
    )
    for t in range(args.seq_len):
        idx = span(t)
        for f in range(n_features):
            # batch x seq_len x features
            new_inputs[:, idx, f] = assignment[:, idx]

            attr = compute_attr(
                new_inputs, baselines, explainer,
                additional_forward_args, args
            )
            input_attr[:, f] = (actual_attr - attr).abs().sum(axis=(1, 2))

            # restore from the untouched inputs (see the note in the first loop)
            new_inputs[:, idx, f] = inputs[:, idx, f]

        # every column was rewritten above, so rescaling per time step is safe
        input_attr = min_max_scale(input_attr, dim=1)

        # (batch, 1) * (batch, features) -> (batch, features)
        time_scaled_attr[:, t] = time_attr[:, t].unsqueeze(-1) * input_attr

    return time_scaled_attr

In [21]:
# Same evaluation as the plain-explainer loop, but the attributions come from
# compute_tsr_attr instead of compute_attr.
results = []
# Supported baseline modes: "zero", "aug", "random"
baseline_mode = 'random'
result_columns = ['batch_index', 'explainer', 'metric', 'area', 'comp', 'suff']

progress_bar = tqdm(
    enumerate(dataloader), total=len(dataloader), disable=False
)
for batch_index, (batch_x, batch_y, batch_x_mark, batch_y_mark) in progress_bar:
    batch_x = batch_x.float().to(exp.device)
    batch_y = batch_y.float().to(exp.device)

    batch_x_mark = batch_x_mark.float().to(exp.device)
    batch_y_mark = batch_y_mark.float().to(exp.device)
    # decoder input: the first label_len steps of batch_y followed by zeros
    # for the last pred_len steps
    dec_inp = torch.zeros_like(batch_y[:, -exp.args.pred_len:, :]).float()
    dec_inp = torch.cat([batch_y[:, :exp.args.label_len, :], dec_inp], dim=1).float()
    # outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

    inputs = batch_x
    # baseline must be a scalar or tuple of tensors with same dimension as input
    # NOTE: the mode name is 'zero' (not 'zeros'), matching the first
    # evaluation loop; otherwise 'zero' silently falls through to the
    # batch-mean baseline.
    if baseline_mode == 'zero':
        baselines = torch.zeros_like(inputs)
    elif baseline_mode == 'random':
        baselines = torch.randn_like(inputs)
    else:
        # mean over the batch axis, tiled back to the batch size
        baselines = torch.mean(inputs, axis=0).repeat(inputs.shape[0], 1, 1).float()
    additional_forward_args = (batch_x_mark, dec_inp, batch_y_mark)

    # get attributions
    for name in explainers:
        explainer = explainers_map[name]
        attr = compute_tsr_attr(
            inputs, baselines, explainer, additional_forward_args, args, exp.device
        )

        # get scores
        for area in areas:
            for metric_name, metric in expl_metric_map.items():
                # 'comp' column: metric with mask_largest=True
                error_comp = metric(
                    model, inputs=inputs,
                    attributions=attr, baselines=baselines,
                    additional_forward_args=additional_forward_args,
                    topk=area, mask_largest=True
                )

                # 'suff' column: metric with mask_largest=False
                error_suff = metric(
                    model, inputs=inputs,
                    attributions=attr, baselines=baselines,
                    additional_forward_args=additional_forward_args,
                    topk=area, mask_largest=False
                )

                result_row = [batch_index, name, metric_name, area, error_comp, error_suff]
                results.append(result_row)
    # run a garbage-collection pass after each batch
    gc.collect()

  0%|          | 0/5 [00:00<?, ?it/s]

In [None]:
# Average comp/suff over the batches for each (explainer, metric, area) group.
tsr_results_df = pd.DataFrame(results, columns=result_columns)
tsr_results_df = tsr_results_df.groupby(['explainer', 'metric', 'area'])[['comp', 'suff']].aggregate('mean').reset_index()
# Make sure the output folder exists; to_csv does not create missing directories.
os.makedirs(result_folder, exist_ok=True)
tsr_results_df.round(6).to_csv(os.path.join(result_folder, 'interpretation_results_tsr.csv'), index=False)
print(tsr_results_df)

# Others

In [None]:
# explainer = Occlusion(model)
# scores = explainer.attribute(
#     inputs=inputs, baselines=baselines, sliding_window_shapes = (1,1),
#     additional_forward_args=additional_forward_args
# )
# mae_error = mae(
#     model, inputs=inputs, 
#     attributions=attr, baselines=baselines, 
#     additional_forward_args=additional_forward_args,
#     topk=0.2
# )

In [None]:
# temporal_mask = torch.zeros_like(batch_x, dtype=int)
# for t in range(batch_x.shape[1]):
#     temporal_mask[:, t] = t

# explainer = FeatureAblation(model)
# time_score = explainer.attribute(
#     inputs=(batch_x),
#     baselines=(batch_x*0),
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     target=0,
#     feature_mask=temporal_mask
# )
# print(time_score.shape)

In [None]:
# from tint.attr import Occlusion

# temporal_mask = torch.zeros(size=(1, *batch_x.shape[1:]), dtype=int)
# for t in range(batch_x.shape[1]):
#     temporal_mask[:, t, :] = t

# explainer = Occlusion(model)
# time_score = explainer.attribute(
#     inputs=(batch_x),
#     baselines=(batch_x*0),
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     sliding_window_shapes=(1, 1)
#     # feature_mask=temporal_mask.to(exp.device)
# )
# print(time_score.shape)