# Imports

In [1]:
from utils.utils import seed_torch, get_best_model_path
from dataclasses import dataclass
from pytorch_forecasting import TemporalFusionTransformer
from experiment.tft import Experiment_TFT
from explainers import *
from experiment.config import Split, DataConfig, FeatureFiles
from utils.interpreter import *
from utils.plotter import PlotResults

import numpy as np
from tqdm import tqdm
import gc, os
import pandas as pd

# Arguments

In [2]:
@dataclass
class args:
    """Run configuration for this notebook.

    The notebook reads these values straight off the class (``args.seed``,
    ``args.explainer``, ...), so every field has a default. The type
    annotations are what make ``@dataclass`` register each attribute as a
    field; without them the decorator sees no fields at all.
    """
    # folder passed to Experiment_TFT and get_best_model_path; scores and
    # figures are written under it as well
    result_folder: str = 'scratch/no_scale/' # 'results/Top_100'
    # folder that contains the input file
    input_folder: str = 'dataset/processed/'
    # input file name, joined onto `input_folder`
    input: str = 'Top_100.csv'
    # explainer key understood by `explainer_factory`: 'FO', 'AFO' or 'FA'
    explainer: str = 'FO'

    # forwarded negated to Experiment_TFT and PlotResults(show=...),
    # and as-is to explainer.attribute
    disable_progress: bool = False
    # value handed to seed_torch
    seed: int = 7
    
# Apply the configured seed through the project helper before any data or model work.
seed_torch(args.seed)

Global seed set to 7


In [3]:
def explainer_factory(
    args, model, dataloader: AgeDataLoader,
) -> BaseExplainer:
    """Build the explainer selected by ``args.explainer``.

    Args:
        args: configuration object; only ``args.explainer`` is read.
        model: model handed unchanged to the explainer constructor.
        dataloader: loader whose ``static_reals`` are the features to explain;
            it is also passed to the explainer constructor.

    Returns:
        A ``FeatureOcclusion`` ('FO'), ``AugmentedFeatureOcclusion`` ('AFO')
        or ``FeatureAblation`` ('FA') instance.

    Raises:
        ValueError: if ``args.explainer`` is none of the keys above.
    """
    # only static reals are interpreted for now
    features = dataloader.static_reals
    kind = args.explainer

    if kind == 'FO':
        return FeatureOcclusion(model, dataloader, features)
    if kind == 'AFO':
        return AugmentedFeatureOcclusion(model, dataloader, features, n_samples=2)
    if kind == 'FA':
        return FeatureAblation(model, dataloader, features, method='global')

    raise ValueError(f'{args.explainer} isn\'t supported.')

# Input

In [4]:
# Build the input path from the configured folder and file name, then create the
# experiment wrapper around it.
data_path = os.path.join(args.input_folder, args.input)
experiment = Experiment_TFT(data_path, args.result_folder, not args.disable_progress)

# Read the full frame through the experiment's data loader and show its shape
# and first three rows as a quick sanity check.
total_data = experiment.age_dataloader.read_df()
print(total_data.shape, total_data.head(3), sep='\n')

# Split into train / validation / test using the primary split definition.
train_data, val_data, test_data = experiment.age_dataloader.split_data(
    total_data, Split.primary()
)

(103600, 13)
        Date  FIPS  UNDER5  AGE517  AGE1829  AGE3039  AGE4049  AGE5064  \
0 2020-12-13  2261  0.0062   0.016    0.014   0.0146   0.0117   0.0235   
1 2020-12-14  2261  0.0062   0.016    0.014   0.0146   0.0117   0.0235   
2 2020-12-15  2261  0.0062   0.016    0.014   0.0146   0.0117   0.0235   

   AGE6574  AGE75PLUS  VaccinationFull  Cases  SinWeekly  
0   0.0103     0.0004              0.0    2.0    -0.7818  
1   0.0103     0.0004              0.0    1.0     0.0000  
2   0.0103     0.0004              0.0    1.0     0.7818  

Train samples 63700, validation samples 2800, test samples 2800
637 days of training, 14 days of validation data, 14 days of test data.



## Config

In [6]:
# Shorthand for the experiment's data loader, the name of its time-index column
# (used below as a column key) and the static real features being interpreted.
dataloader = experiment.age_dataloader
time_index, features = dataloader.time_index, dataloader.static_reals

# Interpret

## Load Model

In [7]:
# Ask the project helper for the best checkpoint under the result folder (its log
# line names the file it picked), then rebuild the TemporalFusionTransformer from it.
model_path = get_best_model_path(args.result_folder)
model = TemporalFusionTransformer.load_from_checkpoint(model_path)


Found best checkpoint model best-epoch=5.ckpt.


  rank_zero_warn(
  rank_zero_warn(


## Select Data

In [8]:
# Frame to interpret. It is a copy, so later changes to `data` do not touch `train_data`.
data = train_data.copy()

## Calculate Importance

In [11]:
explainer = explainer_factory(args, model, dataloader)

# train any baseline or parameters; this always uses the training split
explainer.train_generators(train_data)
# Attribute over `data`, the frame chosen in the "Select Data" cell, so the scores
# line up with the rows that are later sliced from `data`. Passing `train_data`
# here only worked because `data` happens to be a copy of it.
all_scores = explainer.attribute(data, args.disable_progress)

100%|██████████| 610/610 [12:12<00:00,  1.20s/it]


In [17]:
# np.savez_compressed writes a zipped .npz archive and appends '.npz' to any file
# name that lacks it, so the previous 'scores.npy.gz' was written to disk as
# 'scores.npy.gz.npz'. Use the real extension so `score_file` is the path written.
# The array is stored under numpy's default key 'arr_0'.
score_file = os.path.join(args.result_folder, 'scores.npz')
np.savez_compressed(score_file, all_scores)

In [None]:
# Repeats the Config cell's bindings: the time-index column name and the static
# real features read from the data loader.
time_index = dataloader.time_index
features = dataloader.static_reals 

In [None]:
# Take the time range from `data`, the same frame the rows below are filtered from,
# so the range and the rows cannot drift apart if `data` is changed to another split.
# NOTE(review): presumably these are the time steps that have attribution scores; confirm.
time_range = explainer.time_range(data)

# Keep only the rows inside the range (both ends inclusive), reduced to the
# identifying columns.
df = data[
    (data[time_index]>=time_range[0]) & 
    (data[time_index]<=time_range[-1])
][['Date', 'FIPS']]

global_rank = calculate_global_rank(
    df, all_scores, features
)

In [None]:
# Combine the (Date, FIPS) rows with the scores for each feature.
# NOTE(review): the exact layout of the returned frame is defined by
# align_interpretation; it has a 'Date' column, which later cells read.
group_agg_scores_df = align_interpretation(df, all_scores, features)

# plot local interpretations, using the result folder as the figure path
plotter = PlotResults(
    figPath=args.result_folder,
    targets=dataloader.targets,
    show=not args.disable_progress,
)
plotter.local_interpretation(group_agg_scores_df, dataloader.static_reals)

# Evaluate

## Load ground truth

In [None]:
# Load ground truth
# Load ground truth and convert its 'end_of_week' column to datetimes in one
# expression, instead of mutating the frame after it is read.
group_cases = pd.read_csv(
    os.path.join(FeatureFiles.root_folder, 'Cases by age groups.csv')
).assign(
    end_of_week=lambda frame: pd.to_datetime(frame['end_of_week'])
)


## Calculate rank score

In [None]:
# find a common start point
# find a common start point between the ground truth and the dates that have
# importance scores
first_common_date = find_first_common_date(
    group_cases, group_agg_scores_df['Date'].values
)

# since age group ground truth is weekly aggregated
# do the same for predicted importance, starting from the common date
weekly_agg_scores_df = aggregate_importance_by_window(
    group_agg_scores_df, dataloader.static_reals, first_common_date
)
# Compare the weekly importance with the ground truth. This is the cell's last
# expression, so any value it returns is displayed.
evaluate_interpretation(
    group_cases, weekly_agg_scores_df, dataloader.static_reals
)