# Imports

In [None]:
from utils.utils import seed_torch, get_best_model_path
from pytorch_forecasting import TemporalFusionTransformer
from exp.exp_tft import Experiment_TFT
from exp.config import Split, FeatureFiles
from utils.interpreter import *
from utils.plotter import PlotResults
from tqdm import tqdm

import numpy as np
import gc, os
import pandas as pd

# Arguments

In [None]:
from run_tft import get_argparser, stringify_setting

# Emulate the command line of run_tft.py inside the notebook: the option
# text is split on whitespace and fed to the project's argument parser.
parser = get_argparser()
argv = """
--result_path scratch
--data_path Top_20.csv
--test
""".split()
args = parser.parse_args(argv)

# Tag this run with the 'FO' explainer name and seed via the project helper.
args.explainer = 'FO'
seed_torch(args.seed)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Override the automatically selected device and force everything onto the CPU.
device = torch.device('cpu')

# Input

In [None]:
# Build the experiment from the parsed arguments and the setting derived
# from them by stringify_setting.
setting = stringify_setting(args)
experiment = Experiment_TFT(args, setting)

# Read the full data frame and preview its shape and first three rows.
total_data = experiment.age_dataloader.read_df()
print(total_data.shape, total_data.head(3), sep="\n")

# Split into train / validation / test using the primary split configuration.
primary_split = Split.primary()
train_data, val_data, test_data = experiment.age_dataloader.split_data(
    total_data, primary_split
)

## Config

In [None]:
# Short aliases for the experiment's data handler, its time index and the
# static real-valued feature names used as `features` in the cells below.
age_data = experiment.age_dataloader
time_index, features = age_data.time_index, age_data.static_reals

# Interpret

## Load Model

In [None]:
# Locate the best checkpoint in the experiment's output folder and restore
# the Temporal Fusion Transformer from it, mapping weights onto `device`.
model_path = get_best_model_path(experiment.output_folder)
model = TemporalFusionTransformer.load_from_checkpoint(
    model_path,
    map_location=device,
)

# Switch to evaluation mode, then place the model on the target device.
# The result is bound to `_` so the notebook does not echo the module.
model.eval()
_ = model.to(device)

## Calculate Importance

### Utils

In [None]:
from typing import Dict, List, Tuple, Union

class OutputMixIn:
    """
    Mix-in that adds dictionary-style access to a namedtuple.

    Intended to be combined with a namedtuple class: it relies on the
    ``_fields`` attribute and on the tuple ``__getitem__`` of the next class
    in the MRO.
    """

    def __getitem__(self, k):
        """Return the field named ``k`` for string keys, else index the tuple."""
        if not isinstance(k, str):
            # Integers and slices fall through to regular tuple indexing.
            return super().__getitem__(k)
        return getattr(self, k)

    def get(self, k, default=None):
        """Return attribute ``k``, or ``default`` when it does not exist."""
        try:
            return getattr(self, k)
        except AttributeError:
            return default

    def items(self):
        """Iterate over ``(field name, value)`` pairs."""
        names = self._fields
        return zip(names, self)

    def keys(self):
        """Return the tuple of field names."""
        return self._fields

    def iget(self, idx: Union[int, slice]):
        """Select item(s) row-wise.

        Args:
            idx ([int, slice]): index applied to every field value

        Returns:
            New instance of the same class holding ``value[idx]`` per field.
        """
        selected = [value[idx] for value in self]
        return self.__class__(*selected)

def move_to_device(
    x: Union[
        Dict[str, Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]],
        torch.Tensor,
        List[torch.Tensor],
        Tuple[torch.Tensor],
    ],
    device: Union[str, torch.DeviceObjType],
) -> Union[
    Dict[str, Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]],
    torch.Tensor,
    List[torch.Tensor],
    Tuple[torch.Tensor],
]:
    """
    Move object to device.

    Containers are traversed recursively. Dictionaries are updated in place
    and returned; every other container is rebuilt around the moved
    elements. Objects of any other type are returned unchanged.

    Args:
        x (dictionary of list of tensors): object (e.g. dictionary) of tensors to move to device
        device (Union[str, torch.DeviceObjType]): device, e.g. "cpu"

    Returns:
        x on targeted device
    """
    if isinstance(device, str):
        device = torch.device(device)

    if isinstance(x, dict):
        for name in x:
            x[name] = move_to_device(x[name], device=device)
        return x

    # Must be tested before the generic tuple case: OutputMixIn instances are
    # namedtuples. Tensor.to() is not in-place and tuples are immutable, so
    # the moved fields have to be collected into a new instance.
    if isinstance(x, OutputMixIn):
        return x.__class__(*(move_to_device(xi, device=device) for xi in x))

    if isinstance(x, torch.Tensor):
        return x if x.device == device else x.to(device)

    if isinstance(x, (list, tuple)):
        # Visit every element instead of probing x[0].device: that probe
        # raised on empty sequences and non-tensor first elements, and it
        # skipped sequences whose later elements lived on another device.
        moved = [move_to_device(xi, device=device) for xi in x]
        # Keep the container type so a tuple does not silently become a list.
        return tuple(moved) if isinstance(x, tuple) else moved

    return x

### Calculate

In [None]:
# Create the dataset / dataloader pair for the training split; the
# dataloader is what the attribution loop iterates over batch by batch.
dataset, dataloader = age_data.create_timeseries(train_data)

In [None]:
# Perturbation-based importance: for each batch, replace one (time step,
# feature) cell of the encoder input with a random value and record the
# summed absolute change in the model's prediction.
attr_list = []

# Inference only. Without no_grad every perturbed forward pass records an
# autograd graph, and the in-place writes into `attr` keep all of those
# graphs alive until the batch ends.
with torch.no_grad():
    for (x, _) in tqdm(dataloader):
        x = move_to_device(x, device)

        # batch_size x seq_len x features
        inputs = x['encoder_cont']
        # one random replacement value per sample, reused for every cell
        assignment = torch.randn(inputs.shape[0], device=inputs.device)

        # passing target name in a list during training
        # returns prediction as a list despite having one target
        # list of batch_size x pred_len x 1
        y_pred = model(x)['prediction'][0]

        attr = torch.zeros_like(inputs)
        for t in range(args.seq_len):
            for f in range(len(features)):
                # perturb a fresh copy so earlier perturbations do not leak
                x_hat = inputs.clone()
                x_hat[:, t, f] = assignment
                x['encoder_cont'] = x_hat

                y_pred_hat = model(x)['prediction'][0]
                attr[:, t, f] = torch.sum(
                    torch.abs(y_pred_hat - y_pred), dim=(1, 2)
                )

        attr_list.append(attr)
        gc.collect()
        torch.cuda.empty_cache()

attr = torch.vstack(attr_list)

In [None]:
all_scores = attr.detach().cpu().numpy()
# `df` is never defined in this notebook. The scores were computed on the
# dataloader built from `train_data`, so they are aligned against that frame.
group_agg_scores_df = align_interpretation(train_data, all_scores, features)

# plot local interpretations
# `dataloader` is the batch iterator returned by create_timeseries, so the
# dataset metadata is read from `age_data` / `features` instead.
# NOTE(review): assumes age_data exposes `targets` the same way it exposes
# `static_reals` -- confirm.
# NOTE(review): the argv above sets --result_path; verify that
# `args.result_folder` is actually populated.
plotter = PlotResults(
    figPath=args.result_folder, targets=age_data.targets,
    show=not args.disable_progress
)
plotter.local_interpretation(group_agg_scores_df, features)

# Evaluate

The white box evaluation is only available for age group features.

## Load ground truth

In [None]:
# Read the age-group case counts that serve as ground truth and parse the
# week-ending column into datetimes.
ground_truth_file = os.path.join(
    FeatureFiles.root_folder, 'Cases by age groups.csv'
)
group_cases = pd.read_csv(ground_truth_file)
group_cases['end_of_week'] = pd.to_datetime(group_cases['end_of_week'])


## Calculate rank score

In [None]:
# find a common start point
first_common_date = find_first_common_date(
    group_cases, group_agg_scores_df['Date'].values
)

# since age group ground truth is weekly aggregated
# do the same for predicted importance
# `features` (age_data.static_reals) is used because `dataloader` is the
# batch iterator from create_timeseries, not the data handler.
weekly_agg_scores_df = aggregate_importance_by_window(
    group_agg_scores_df, features, first_common_date
)

In [None]:
# Score the weekly importance against the ground-truth cases.
# `features` (age_data.static_reals) is used because `dataloader` is the
# batch iterator from create_timeseries, not the data handler.
evaluate_interpretation(
    group_cases, weekly_agg_scores_df, features
)