In [16]:
from run import *
from tint.attr import FeatureAblation, Occlusion, Fit
from tint.metrics import mse, mae

In [17]:
parser = get_parser()

# Command-line options for the ILI (national_illness.csv) Transformer run,
# one option per line; whitespace splitting turns them into argv tokens.
cli_options = """
--root_path ./dataset/illness/
--data_path national_illness.csv
--model_id ili_36_24
--model Transformer
--data custom
--use_gpu
--features MS
--seq_len 36
--label_len 18
--pred_len 24
--e_layers 2
--d_layers 1
--factor 3
--enc_in 7
--dec_in 7
--c_out 7
--des Exp
--itr 1
"""
argv = cli_options.split()
args = parser.parse_args(argv)

In [18]:
set_random_seed(args.seed)

# Only keep GPU usage when it was requested AND CUDA is actually available.
args.use_gpu = bool(torch.cuda.is_available() and args.use_gpu)

if args.use_gpu and args.use_multi_gpu:
    # "--devices" is a comma-separated id list such as "0, 1"; the first id
    # becomes the primary GPU.
    args.devices = args.devices.replace(' ', '')
    device_ids = args.devices.split(',')
    args.device_ids = list(map(int, device_ids))
    args.gpu = args.device_ids[0]

# Pick the experiment class matching the task.
Exp = Exp_Classification if args.task_name == 'classification' else Exp_Long_Term_Forecast

In [19]:
# Run index passed to stringify_setting (presumably the `--itr` repetition
# index the checkpoint was saved under; with --itr 1 only index 0 exists) — TODO confirm.
itr_index = 0
setting = stringify_setting(args, itr_index)

In [20]:
exp = Exp(args)  # set experiments
# Only the test loader is needed for interpretation; the dataset object is unused.
_, dataloader = exp._get_data('test')

checkpoint_path = os.path.join('checkpoints', setting, 'checkpoint.pth')
# map_location: the checkpoint may have been saved on a GPU while this session
# fell back to CPU (args.use_gpu is forced to False when CUDA is unavailable).
exp.model.load_state_dict(
    torch.load(checkpoint_path, map_location=exp.device)
)

Use GPU: cuda:0
test 170


<All keys matched successfully>

In [21]:
model = exp.model
model.eval()
model.zero_grad()

# Interpretation is based on the forecast outputs only, so the model must return
# just its predictions. NOTE(review): presumably output_attention=True makes the
# forward pass also return attention maps — confirm in the model code.
assert not exp.args.output_attention, (
    "output_attention must be disabled: the explainer expects the model "
    "to return only its forecast tensor"
)

explainer = FeatureAblation(model)

In [35]:
from tqdm import tqdm

# Fraction of the most-important input entries perturbed by the tint metrics.
TOPK = 0.2

results = {
    'mae': [], 'mse': []
}

pred_len = exp.args.pred_len
label_len = exp.args.label_len

# Wrap the dataloader itself (not enumerate) so tqdm knows the total batch count.
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(
    tqdm(dataloader, desc='batches')
):
    batch_x = batch_x.float().to(exp.device)
    batch_y = batch_y.float().to(exp.device)

    batch_x_mark = batch_x_mark.float().to(exp.device)
    batch_y_mark = batch_y_mark.float().to(exp.device)

    # Decoder input: the first label_len steps of batch_y followed by pred_len
    # zeros. Both parts are already float tensors on exp.device.
    dec_inp = torch.cat(
        [batch_y[:, :label_len, :], torch.zeros_like(batch_y[:, -pred_len:, :])],
        dim=1,
    )
    forward_args = (batch_x_mark, dec_inp, batch_y_mark)

    # Attribute each forecast step separately, then average over the steps.
    mean_score = None
    for target in tqdm(range(pred_len), desc='targets', leave=False):
        score = explainer.attribute(
            inputs=batch_x,
            baselines=0,
            target=target,
            additional_forward_args=forward_args,
        )
        # clone() so the running sum never mutates a returned attribution in place.
        mean_score = score.clone() if mean_score is None else mean_score + score
    mean_score /= pred_len

    metric_kwargs = dict(
        inputs=batch_x,
        attributions=mean_score,
        baselines=0,
        additional_forward_args=forward_args,
        topk=TOPK,
    )
    mae_error = mae(model, **metric_kwargs)
    mse_error = mse(model, **metric_kwargs)
    results['mae'].append(mae_error)
    results['mse'].append(mse_error)

 25%|██▌       | 6/24 [00:55<02:46,  9.22s/it]
0it [00:55, ?it/s]


KeyboardInterrupt: 

In [23]:
# Feature ablation yields one attribution per input entry, so shapes must match.
assert score.shape == batch_x.shape, (score.shape, batch_x.shape)
batch_x.shape, score.shape

(torch.Size([10, 36, 7]), torch.Size([10, 36, 7]))

In [32]:
# Collapse each metric's per-batch values into a single mean.
results = {metric: np.mean(per_batch) for metric, per_batch in results.items()}
print(results)

{'mae': 10.65552536646525, 'mse': 4.753488858540853}


In [25]:
# outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
# outputs.shape, outputs.numel(), outputs[0].numel()

In [26]:
# score = explainer.attribute(
#     inputs=(batch_x),
#     baselines=0,
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark)
# )
# print(batch_x.shape, score.shape)

In [27]:
# score_targeted = explainer.attribute(
#     inputs=(batch_x),
#     baselines=0,
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     target=0
# )
# print(batch_x.shape, score_targeted.shape)

In [28]:
# temp = score.reshape((batch_x.shape[0], args.pred_len, args.seq_len, -1)).mean(axis=1)

In [29]:
# mse_error = mse(
#     model, inputs=batch_x, 
#     attributions=temp, baselines=0, 
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     topk=0.2
# )
# print(mse_error)

# mae_error = mae(
#     model, inputs=batch_x, 
#     attributions=temp, baselines=0, 
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     # target=0,
#     topk=0.2
# )
# print(mae_error)

In [30]:
# NOTE(review): dead experiment (per-time-step ablation via feature_mask) —
# restore or delete. If restored, it rebinds `explainer`; the final print
# previously showed `score` instead of the freshly computed `time_score`.
# temporal_mask = torch.zeros_like(batch_x, dtype=int)
# for t in range(batch_x.shape[1]):
#     temporal_mask[:, t] = t

# explainer = FeatureAblation(model)
# time_score = explainer.attribute(
#     inputs=(batch_x),
#     baselines=(batch_x*0),
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     target=0,
#     feature_mask=temporal_mask
# )
# print(time_score.shape)

In [31]:
# NOTE(review): dead experiment — restore or delete. Occlusion is already
# imported at the top of the notebook, temporal_mask is built but never passed,
# and restoring this rebinds `explainer`.
# from tint.attr import Occlusion

# temporal_mask = torch.zeros(size=(1, *batch_x.shape[1:]), dtype=int)
# for t in range(batch_x.shape[1]):
#     temporal_mask[:, t, :] = t

# explainer = Occlusion(model)
# time_score = explainer.attribute(
#     inputs=(batch_x),
#     baselines=(batch_x*0),
#     additional_forward_args=(batch_x_mark, dec_inp, batch_y_mark),
#     sliding_window_shapes=(1, 1),
#     # feature_mask=temporal_mask.to(exp.device)
# )
# print(time_score.shape)