In [101]:
from run import *
from tint.metrics import mse, mae
import tint, captum
from tqdm import tqdm
import pandas as pd

In [14]:
from captum.attr import (
    DeepLift,
    GradientShap,
    IntegratedGradients,
    Lime
)

from tint.attr import (
    AugmentedOcclusion,
    DynaMask,
    Occlusion, 
    Fit, FeatureAblation
)

from tint.attr import (
    TemporalAugmentedOcclusion,
    TemporalIntegratedGradients,
    TemporalOcclusion,
    TimeForwardTunnel,
)

In [237]:
# Sanity check: tint.attr does not expose DeepLift (this lookup raises
# AttributeError), which is why DeepLift is imported from captum.attr instead.
getattr(tint.attr, 'DeepLift')

AttributeError: module 'tint.attr' has no attribute 'DeepLift'

In [4]:
# Build the experiment configuration by feeding a CLI-style argument string to the
# project's argument parser (get_parser presumably comes from `from run import *`).
parser = get_parser()
# Inside this non-raw string a trailing backslash is a line continuation, and
# str.split() tokenises on any whitespace, so the line without a trailing
# backslash (--use_gpu) is tokenised exactly like the others.
# NOTE(review): args.use_gpu is overwritten with False in a later cell, so the
# --use_gpu flag given here has no lasting effect.
argv = """
  --root_path ./dataset/illness/ \
  --data_path national_illness.csv \
  --model_id ili_36_24 \
  --model Transformer \
  --data custom \
    --use_gpu
  --features MS \
  --seq_len 36 \
  --label_len 18 \
  --pred_len 24 \
  --e_layers 2 \
  --d_layers 1 \
  --factor 3 \
  --enc_in 7 \
  --dec_in 7 \
  --c_out 7 \
  --des Exp \
  --itr 1
""".split()
args = parser.parse_args(argv)

In [5]:
# Explainer keys requested for this run (snake_case names).
# NOTE(review): the evaluation cells define their own `explainers` and `areas`
# lists and do not read these two attributes.
args.explainers = [
    "deep_lift", "gradient_shap", "integrated_gradients", "lime",
    "occlusion", "augmented_occlusion", "dyna_mask",
]

# Area values stored alongside the explainer list.
args.areas = [0.2, 0.5]

In [97]:
# Lookup table: snake_case explainer key -> attribution class to instantiate.
explainer_name_map = {
    "deep_lift": DeepLift,
    "gradient_shap": GradientShap,
    "integrated_gradients": IntegratedGradients,
    "lime": Lime,
    "occlusion": Occlusion,
    # "augmented_occlusion": AugmentedOcclusion,
    "dyna_mask": DynaMask,
    "feature_ablation": FeatureAblation,
}

In [7]:
# Seed the random number generators so the run is repeatable.
set_random_seed(args.seed)
# Force CPU execution rather than disabling cudnn on a CUDA device
# (presumably to avoid the cudnn RNN backward error described in the Captum FAQ).
# Please see https://captum.ai/docs/faq#how-can-i-resolve-cudnn-rnn-backward-error-for-rnn-or-lstm-network
args.use_gpu = False

# Only the long-term forecasting experiment class is wired up in this notebook.
assert args.task_name == 'long_term_forecast', "Only long_term_forecast is supported for now"
Exp = Exp_Long_Term_Forecast

# Run identifier string; later cells use it to locate the checkpoint and results folders.
setting = stringify_setting(args, 0)

In [8]:
# Build the experiment wrapper (exposes the model as exp.model) and fetch the test split.
exp = Exp(args)  # set experiments
_, dataloader = exp._get_data('test')

# Restore the trained weights. map_location remaps the stored tensors onto the device
# this experiment runs on, so a checkpoint saved from a CUDA run still loads after
# use_gpu has been forced to False.
checkpoint_path = os.path.join('checkpoints/' + setting, 'checkpoint.pth')
exp.model.load_state_dict(
    torch.load(checkpoint_path, map_location=exp.device)
)

# Folder that receives the interpretation result files written by later cells.
result_folder = './results/' + setting + '/'

Use CPU
test 170


<All keys matched successfully>

In [11]:
model = exp.model
# Put the model in evaluation mode and clear any stale gradients before attribution.
model.eval()
model.zero_grad()

# only need to output targets, since interpretation is based on outputs
assert not exp.args.output_attention

In [267]:
def compute_attr(
    inputs, baselines, explainer,
    additional_forward_args, args
):
    """Compute attributions for a tuple of inputs, averaged over the forecast horizon.

    Args:
        inputs: tuple of input tensors; each is expected to be
            batch x seq_len x features (single tensors are not supported yet).
        baselines: baselines forwarded unchanged to ``explainer.attribute``.
        explainer: attribution object exposing ``get_name()`` and ``attribute(...)``.
        additional_forward_args: extra model arguments forwarded unchanged to
            ``explainer.attribute``.
        args: namespace providing ``pred_len`` and ``seq_len``.

    Returns:
        Tuple with one tensor per input, each of shape batch x seq_len x features,
        obtained by averaging the attributions over the ``pred_len`` axis.

    Raises:
        AssertionError: if ``inputs`` is not a tuple.
        NotImplementedError: if the explainer's name is not one of the handled
            ones; the message names the offending explainer.
    """
    assert isinstance(inputs, tuple), 'inputs not generalized for single tensor yet'
    name = explainer.get_name()

    if name in ['Deep Lift', 'Lime', 'Integrated Gradients']:
        # These explainers are queried once per output step (target index).
        attr_list = [
            explainer.attribute(
                inputs=inputs, baselines=baselines, target=target,
                additional_forward_args=additional_forward_args
            )
            for target in range(args.pred_len)
        ]

        # Stack the per-target scores on a new axis 1:
        # batch x pred_len x seq_len x features for every input.
        attr = tuple(
            torch.stack([score[input_index] for score in attr_list], dim=1)
            for input_index in range(len(inputs))
        )

    elif name == 'Feature Ablation':
        attr = explainer.attribute(
            inputs=inputs, baselines=baselines,
            additional_forward_args=additional_forward_args
        )
    elif name == 'Occlusion' or name == 'Augmented Occlusion':
        attr = explainer.attribute(
            inputs=inputs,
            baselines=baselines,
            # occlude one (time step, feature) cell at a time for every input
            sliding_window_shapes=tuple([(1, 1) for _ in inputs]),
            additional_forward_args=additional_forward_args
        )
    else:
        # Name the explainer so the failure is actionable.
        raise NotImplementedError(
            f"compute_attr does not support explainer '{name}'; supported: "
            "'Deep Lift', 'Lime', 'Integrated Gradients', 'Feature Ablation', "
            "'Occlusion', 'Augmented Occlusion'"
        )

    # View every score as batch x pred_len x seq_len x features, then take the
    # mean over the output horizon -> batch x seq_len x features.
    batch_size = inputs[0].shape[0]
    attr = tuple([
        score.reshape(
            (batch_size, args.pred_len, args.seq_len, score.shape[-1])
        ).mean(dim=1)
        for score in attr
    ])

    return attr

In [262]:
def get_explainer_by_name(name: str, modules=None):
    """Return the explainer class called ``name``.

    Args:
        name: attribute name of the explainer class, e.g. ``'DeepLift'``.
        modules: optional iterable of modules to search, in priority order.
            Defaults to ``(tint.attr, captum.attr)``, so tint's implementation
            wins when both libraries define the same name.

    Returns:
        The first attribute called ``name`` found in ``modules``.

    Raises:
        AttributeError: if none of the modules defines ``name``.
    """
    if modules is None:
        modules = (tint.attr, captum.attr)

    searched = []
    for module in modules:
        # Only a missing attribute triggers the fallback; any other error
        # propagates instead of being swallowed by a bare except.
        try:
            return getattr(module, name)
        except AttributeError:
            searched.append(getattr(module, '__name__', repr(module)))

    raise AttributeError(
        f"no explainer named {name!r} found in: {', '.join(searched)}"
    )

In [264]:
# Metric callables looked up by name on tint.metrics.
expl_metrics = [
    getattr(tint.metrics, metric_name) for metric_name in ['mae', 'mse']
]
# Values passed as `topk` to the metrics in the evaluation loop.
areas = [0.1, 0.2, 0.5]

# Subset of explainers evaluated in this run.
explainers = ['deep_lift', 'feature_ablation'] # explainers = args.explainers

# Instantiate each selected explainer once, wrapped around the loaded model.
explainers_map = {}
for name in explainers:
    explainer_class = explainer_name_map[name]
    explainers_map[name] = explainer_class(model)

In [280]:
# Evaluate every selected explainer on every test batch and stream one CSV row per
# (batch, explainer, metric, area) combination to disk while also collecting the
# rows in `results`.
results = []
baseline_mode = "aug" # "zeros", "aug"
result_columns = ['batch_index', 'explainer', 'metric', 'area', 'comp', 'suff']

# Look the metric functions up once instead of on every inner-loop iteration.
metric_fns = {
    metric_name: getattr(tint.metrics, metric_name)
    for metric_name in ['mae', 'mse']
}

# Make sure the destination exists before opening the output file.
os.makedirs(result_folder, exist_ok=True)

progress_bar = tqdm(
    enumerate(dataloader), total=len(dataloader), disable=False
)

# The context manager closes the file even if an explainer or metric raises
# part-way through the (long) loop.
with open(os.path.join(result_folder, "batch_interpretation_results.csv"), 'w') as output_file:
    output_file.write(','.join(result_columns))

    for batch_index, (batch_x, batch_y, batch_x_mark, batch_y_mark) in progress_bar:
        batch_x = batch_x.float().to(exp.device)
        batch_y = batch_y.float().to(exp.device)

        batch_x_mark = batch_x_mark.float().to(exp.device)
        batch_y_mark = batch_y_mark.float().to(exp.device)
        # decoder input: the first label_len steps of batch_y followed by zeros
        # for the pred_len steps to be forecast
        dec_inp = torch.zeros_like(batch_y[:, -exp.args.pred_len:, :]).float()
        dec_inp = torch.cat([batch_y[:, :exp.args.label_len, :], dec_inp], dim=1).float()
        # outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

        inputs = (batch_x, batch_x_mark)
        # baseline must be a scalar or a tuple of tensors with the same dimension as input
        if baseline_mode == 'zeros':
            baselines = tuple([0] * len(inputs))
        else:
            # "aug": every sample's baseline is the mean over the batch dimension
            baselines = tuple([
                torch.mean(tensor, dim=0).repeat(tensor.shape[0], 1, 1).float()
                for tensor in inputs
            ])
        additional_forward_args = (dec_inp, batch_y_mark)

        # get attributions
        for name in explainers:
            explainer = explainers_map[name]
            attr = compute_attr(
                inputs, baselines, explainer, additional_forward_args, args
            )

            # get scores
            for area in areas:
                for metric_name, metric in metric_fns.items():
                    # 'comp': metric evaluated with mask_largest=True
                    error_comp = metric(
                        model, inputs=inputs,
                        attributions=attr, baselines=baselines,
                        additional_forward_args=additional_forward_args,
                        topk=area, mask_largest=True
                    )

                    # 'suff': same call with mask_largest=False
                    error_suff = metric(
                        model, inputs=inputs,
                        attributions=attr, baselines=baselines,
                        additional_forward_args=additional_forward_args,
                        topk=area, mask_largest=False
                    )

                    result_row = [batch_index, name, metric_name, area, error_comp, error_suff]
                    # print(result_row)
                    output_file.write("\n" + ','.join([str(r) for r in result_row]))
                    results.append(result_row)

            # Persist progress after each explainer so partial results survive a crash.
            output_file.flush()
        # break

6it [10:02, 100.44s/it]


In [281]:
# Average the per-batch scores for every (explainer, metric, area) combination
# and save the rounded summary next to the per-batch CSV.
group_keys = ['explainer', 'metric', 'area']
results_df = (
    pd.DataFrame(results, columns=result_columns)
    .groupby(group_keys)[['comp', 'suff']]
    .agg('mean')
    .reset_index()
)
summary_path = os.path.join(result_folder, 'interpretation_results.csv')
results_df.round(6).to_csv(summary_path, index=False)
print(results_df)

Unnamed: 0,batch_index,explainer,metric,area,comp,suff
0,0,deep_lift,mae,0.1,3.041492,4.536799
1,0,deep_lift,mse,0.1,0.424379,0.956420
2,0,deep_lift,mae,0.2,4.301575,5.090987
3,0,deep_lift,mse,0.2,0.844737,1.179293
4,0,deep_lift,mae,0.5,4.672506,4.955289
...,...,...,...,...,...,...
67,5,feature_ablation,mse,0.1,0.069005,0.139158
68,5,feature_ablation,mae,0.2,1.515319,1.857422
69,5,feature_ablation,mse,0.2,0.115939,0.162400
70,5,feature_ablation,mae,0.5,1.672927,2.036666


In [144]:
# Scratch cell: recompute attributions for a single batch.
# NOTE(review): relies on `inputs`, `baselines`, `explainer` and
# `additional_forward_args` left in the kernel by the evaluation loop.
attr = compute_attr(
    inputs, baselines, explainer, additional_forward_args, args
)

4.333446979522705

In [162]:
# Scratch cell: run Integrated Gradients once per output step (target index).
# NOTE(review): uses `inputs`, `baselines` and `additional_forward_args` left in the
# kernel by the evaluation loop; the recorded run took about 5 minutes for 24 targets.
explainer = IntegratedGradients(model)

# Collects one attribution result per target; no averaging over targets is done here.
attr = []
for target in tqdm(range(args.pred_len)):
    score = explainer.attribute(
        inputs=inputs, baselines=baselines, target=target,
        additional_forward_args=additional_forward_args
    )
    attr.append(score)

100%|██████████| 24/24 [05:13<00:00, 13.05s/it]


In [213]:
# Scratch cell: Feature Ablation attributions in a single call (no target argument).
# NOTE(review): uses `inputs`, `baselines` and `additional_forward_args` left in the
# kernel by the evaluation loop.
explainer = FeatureAblation(model)
scores = explainer.attribute(
    inputs=inputs, baselines=baselines, 
    additional_forward_args=additional_forward_args
)

In [116]:
# Scratch cell: MAE metric at topk=0.2 using scalar zero baselines for both inputs.
# NOTE(review): relies on `attr`, `inputs`, `dec_inp` and `batch_y_mark` left in the
# kernel by earlier cells.
mae_error = mae(
    model, inputs=inputs, 
    attributions=attr, baselines=(0,0), 
    additional_forward_args=(dec_inp, batch_y_mark),
    topk=0.2
)

In [29]:
# mse_error = mse(
#     model, inputs=batch_x, 
#     attributions=temp, baselines=0, 
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     topk=0.2
# )
# print(mse_error)

# mae_error = mae(
#     model, inputs=batch_x, 
#     attributions=temp, baselines=0, 
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     # target=0,
#     topk=0.2
# )
# print(mae_error)

In [30]:
# temporal_mask = torch.zeros_like(batch_x, dtype=int)
# for t in range(batch_x.shape[1]):
#     temporal_mask[:, t] = t

# explainer = FeatureAblation(model)
# time_score = explainer.attribute(
#     inputs=(batch_x),
#     baselines=(batch_x*0),
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     target=0,
#     feature_mask=temporal_mask
# )
# print(score.shape)

In [31]:
# from tint.attr import Occlusion

# temporal_mask = torch.zeros(size=(1, *batch_x.shape[1:]), dtype=int)
# for t in range(batch_x.shape[1]):
#     temporal_mask[:, t, :] = t

# explainer = Occlusion(model)
# time_score = explainer.attribute(
#     inputs=(batch_x),
#     baselines=(batch_x*0),
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     sliding_window_shapes=(1, 1)
#     # feature_mask=temporal_mask.to(exp.device)
# )
# print(time_score.shape)

In [32]:
# from typing import Any

# class Unify_Output:
#     def __init__(self, model, agg='mean') -> None:
#         self.model = model
#         self.agg = agg
#         assert self.agg in ['mean', 'sum'], 'Aggregation must be either mean or sum'
        
#     def __call__(self, *args: Any, **kwds: Any) -> Any:
#         outputs = model(*args, **kwds)
#         # batch size x target
#         if self.agg == 'mean':
#             return torch.mean(outputs, axis=1)
#         else:
#             return torch.sum(outputs, axis=1)