# Import

In [3]:
from run import *
from tint.metrics import mse, mae
import tint, gc, os
from tqdm import tqdm
import pandas as pd
from exp.config import FeatureFiles
from utils.interpreter import *

from tint.attr import (
    AugmentedOcclusion,
    DynaMask,
    Occlusion, 
    FeatureAblation
)

# Arguments

In [4]:
# Parse a fixed argument string (rather than sys.argv) so the notebook always
# runs with the same model, result folder and data file.
parser = get_parser()
argv = """
  --model Autoformer --use_gpu --result_path scratch --data_path Top_20.csv
""".split()
args = parser.parse_args(argv)

# Number of distinct static, observed and target columns. The encoder input,
# decoder input and output widths are all set to this same value.
args.n_features = len({
    *DataConfig.static_reals,
    *DataConfig.observed_reals,
    *DataConfig.targets,
})
args.enc_in = args.n_features
args.dec_in = args.n_features
args.c_out = args.n_features
args.n_targets = len(DataConfig.targets)

# Experiment 

In [5]:
# Seed the random number generators from the parsed arguments before anything
# else runs, so repeated executions of the notebook start from the same state.
set_random_seed(args.seed)
# Disable cudnn if using cuda accelerator.
# Please see https://captum.ai/docs/faq#how-can-i-resolve-cudnn-rnn-backward-error-for-rnn-or-lstm-network
# args.use_gpu = False

# Experiment class used below; aliased so the construction cell does not
# have to name the concrete class.
Exp = Exp_Forecast

# String form of the arguments, passed to the experiment constructor.
# NOTE(review): presumably also used to name the result folder - confirm.
setting = stringify_setting(args)

In [6]:
# Build the experiment and restore the trained weights; the log below shows
# the checkpoint being read from the experiment's result folder.
exp = Exp(args, setting)  # set experiments
exp.load_model()
# Folder where this experiment writes its outputs.
result_folder = exp.output_folder
# NOTE(review): assumes exp.model is a torch module, so .eval() switches it to
# inference mode and returns the same module - confirm.
model = exp.model.eval()

Starting experiment. Result folder scratch\Autoformer_Top_20.
Use GPU: cuda:0
adding time index columns TimeFromStart
added time encoded known reals ['month', 'day', 'weekday'].

Train samples 12740, validation samples 560, test samples 560
637 days of training, 14 days of validation data, 14 days of test data.

Fitting scalers on train data
Loading dataset from ./dataset/processed\Top_20\train.pt
Loading dataset from ./dataset/processed\Top_20\val.pt
Loading dataset from ./dataset/processed\Top_20\test.pt
loading best model from scratch\Autoformer_Top_20\checkpoint.pth


In [8]:
# Feature names later passed to align_interpretation together with the
# attribution array: static reals followed by observed reals.
# NOTE(review): these come from exp.age_data while age_features below comes
# from DataConfig - confirm both expose the same static_reals.
features = exp.age_data.static_reals + exp.age_data.observed_reals
# Age-group columns whose importance is compared with the ground truth.
age_features = DataConfig.static_reals

# Interpret

## Calculate Attributions

In [125]:
# Data split to interpret; also used below to pick the matching frame from
# exp.data_map and to name the saved result file.
flag = 'test'
# Dataset and loader for that split (dataset.ranges is used later when the
# attributions are aligned with dates).
dataset, dataloader = exp._get_data(flag)

In [111]:
# from explainers import MorrisSensitivty

# def get_all_inputs(dataloader:torch.utils.data.dataloader.DataLoader):
#     data = [batch[0] for batch in dataloader]
#     return torch.vstack(data)

# data = get_all_inputs(dataloader)
# explainer = MorrisSensitivty(model, data, args.pred_len)
# attr = batch_compute_attr(dataloader, exp, explainer)

In [126]:
# Attribution method under evaluation (imported from tint.attr).
explainer = FeatureAblation(model)
# Run the explainer over every batch of the loader.
# NOTE(review): the result is later treated as a tensor (.detach().cpu()) -
# its exact shape is defined by batch_compute_attr; confirm there.
attr = batch_compute_attr(dataloader, exp, explainer)

100%|██████████| 1/1 [00:01<00:00,  1.18s/it]


In [127]:
# Work on a sorted copy instead of sorting in place: `exp.data_map[flag]` is
# the experiment's own cached frame, and reordering it here would silently
# change state that other experiment code may read.
df = exp.data_map[flag].sort_values(by=['Date', 'FIPS'])
df.head(3)

Unnamed: 0,Date,FIPS,UNDER5,AGE517,AGE1829,AGE3039,AGE4049,AGE5064,AGE6574,AGE75PLUS,VaccinationFull,Cases,month,day,weekday,TimeFromStart
350,2021-11-28,2261,0.0062,0.016,0.014,0.0146,0.0117,0.0235,0.0103,0.0004,54.1,1.0,11,28,6,637
1386,2021-11-28,4013,0.0601,0.1717,0.1678,0.1392,0.1254,0.1771,0.0912,0.0675,51.3,0.0,11,28,6,637
2422,2021-11-28,6037,0.056,0.1558,0.1738,0.1525,0.1311,0.1865,0.0826,0.0617,63.7,1225.0,11,28,6,637


## Get ground truth

In [108]:
# Ground truth: case counts per age group, one row per reporting week.
group_cases_path = os.path.join(FeatureFiles.root_folder, 'Cases by age groups.csv')
group_cases = pd.read_csv(group_cases_path)
# Parse the week-ending column so it can be compared with attribution dates.
group_cases['end_of_week'] = pd.to_datetime(group_cases['end_of_week'])

# County population estimates; only the key and the estimate are needed.
population = pd.read_csv('dataset/raw/Population.csv')[['FIPS', 'POPESTIMATE']]

## Aggregate

In [128]:
# Bring the attributions to host memory as a NumPy array (signed values).
signed_attr = attr.detach().cpu().numpy()
# np.save(os.path.join(exp.output_folder, f'{flag}_{explainer.get_name()}.npy'), signed_attr)
# Only the magnitude of the importance matters here, so drop the sign.
attr_numpy = np.abs(signed_attr)

# Attach dates and feature names to the raw attribution array.
attr_df = align_interpretation(
    ranges=dataset.ranges,
    attr=attr_numpy,
    features=features,
    min_date=df['Date'].min(),
    seq_len=args.seq_len,
    pred_len=args.pred_len,
)
print(attr_df.describe())

2021-12-12 00:00:00 2021-12-25 00:00:00
               FIPS      UNDER5      AGE517     AGE1829     AGE3039  \
count    280.000000  280.000000  280.000000  280.000000  280.000000   
mean   22995.050000    0.065495    0.069544    0.062922    0.068926   
std    18281.790462    0.042149    0.049840    0.045401    0.054660   
min     2261.000000    0.000188    0.000120    0.000243    0.000180   
25%     6069.500000    0.031076    0.030981    0.030946    0.030950   
50%    14558.500000    0.063317    0.067861    0.055310    0.056428   
75%    39068.000000    0.090895    0.095310    0.085310    0.094058   
max    53033.000000    0.201835    0.258352    0.284300    0.277420   

          AGE4049     AGE5064     AGE6574   AGE75PLUS  VaccinationFull  \
count  280.000000  280.000000  280.000000  280.000000       280.000000   
mean     0.067515    0.086001    0.060666    0.075937         0.119456   
std      0.044248    0.065430    0.040402    0.060236         0.079675   
min      0.000148    0.0

In [129]:
# Total (unweighted) age-group importance per date, summed over all counties.
attr_by_date = attr_df.groupby('Date')[age_features].sum().reset_index()

### Weighted

In [130]:
# Per-county age-group weights: the first row of each county in `df`.
weights = df.groupby('FIPS').first()[age_features].reset_index()

# One population estimate per county; the first entry wins on duplicates,
# matching the previous `.values[0]` lookup.
population_by_fips = (
    population.drop_duplicates(subset='FIPS').set_index('FIPS')['POPESTIMATE']
)

# Fail early and name the offending counties, instead of the bare IndexError
# or broadcast ValueError a missing county used to trigger mid-loop.
known_fips = set(weights['FIPS']) & set(population_by_fips.index)
unmatched_fips = sorted(set(attr_df['FIPS'].dropna().unique()) - known_fips)
if unmatched_fips:
    raise KeyError(
        f'No age weights or population estimate for FIPS: {unmatched_fips}'
    )

# Per-county multiplier for every age column: age weight * total population.
county_scale = weights.set_index('FIPS')[age_features].mul(
    population_by_fips.reindex(weights['FIPS'].to_numpy()), axis=0
)

# Scale all rows at once. The stable sort reproduces the row order of the old
# per-county loop (counties ascending, original order within a county), and
# working on a fresh frame avoids mutating groupby slices in place.
groups = attr_df[attr_df['FIPS'].notna()].sort_values('FIPS', kind='stable')
groups[age_features] = (
    groups[age_features].to_numpy()
    * county_scale.reindex(index=groups['FIPS'].to_numpy()).to_numpy()
)
weighted_attr_df = groups[['FIPS', 'Date'] + age_features].reset_index(drop=True)

# Population-weighted importance per date, summed over all counties.
weighted_attr_by_date = weighted_attr_df.groupby('Date')[
    age_features].aggregate('sum').reset_index()

## Evaluate globally

In [131]:
# Date range shared by the attributions and the weekly ground truth.
dates = attr_by_date['Date'].values
first_common_date = find_first_common_date(group_cases, dates)
last_common_date = find_last_common_date(group_cases, dates)

# Ground-truth weeks that end inside the shared range (both ends inclusive).
week_end = group_cases['end_of_week']
in_common_weeks = (week_end >= first_common_date) & (week_end <= last_common_date)

# Despite the variable name, this is the per-age-group MEAN over those weeks.
summed_ground_truth = (
    group_cases.loc[in_common_weeks, age_features].mean(axis=0).reset_index()
)
summed_ground_truth.columns = ['age_group', 'cases']
summed_ground_truth

Found first common date 2021-12-18T00:00:00.000000000.
Found last common date 2021-12-25T00:00:00.000000000.


Unnamed: 0,age_group,cases
0,UNDER5,49827.0
1,AGE517,202210.0
2,AGE1829,343324.0
3,AGE3039,269842.0
4,AGE4049,196863.5
5,AGE5064,221850.5
6,AGE6574,70745.0
7,AGE75PLUS,41543.0


### Unweighted

In [132]:
# Start six days before the first common week-ending date so the whole first
# ground-truth week is covered by daily attributions.
unweighted_window_start = first_common_date - pd.to_timedelta(6, unit='D')
in_unweighted_window = (
    (attr_df['Date'] >= unweighted_window_start)
    & (attr_df['Date'] <= last_common_date)
)

# Despite the variable name, this is the per-age-group MEAN over the window.
summed_attr = attr_df.loc[in_unweighted_window, age_features].mean(axis=0).reset_index()
summed_attr.columns = ['age_group', 'attr']
summed_attr

Unnamed: 0,age_group,attr
0,UNDER5,0.065495
1,AGE517,0.069544
2,AGE1829,0.062922
3,AGE3039,0.068926
4,AGE4049,0.067515
5,AGE5064,0.086001
6,AGE6574,0.060666
7,AGE75PLUS,0.075937


In [133]:
# Put ground truth and unweighted attribution side by side per age group.
merged = summed_ground_truth.merge(summed_attr, on='age_group', how='inner')

# Rescale both columns so each sums to 100 (percentage share per age group).
merged_share_columns = ['cases', 'attr']
merged_column_totals = merged[merged_share_columns].sum(axis=0)
merged[merged_share_columns] = merged[merged_share_columns].div(
    merged_column_totals / 100, axis=1
)
merged

Unnamed: 0,age_group,cases,attr
0,UNDER5,3.568745,11.758368
1,AGE517,14.48283,12.485377
2,AGE1829,24.589799,11.296468
3,AGE3039,19.326818,12.374293
4,AGE4049,14.099899,12.121077
5,AGE5064,15.889536,15.439819
6,AGE6574,5.066949,10.891472
7,AGE75PLUS,2.975423,13.633129


### Weighted

In [134]:
# Same window as the unweighted case: six days before the first common
# week-ending date through the last common date, both ends inclusive.
weighted_window_start = first_common_date - pd.to_timedelta(6, unit='D')
in_weighted_window = (
    (weighted_attr_df['Date'] >= weighted_window_start)
    & (weighted_attr_df['Date'] <= last_common_date)
)

# Despite the variable name, this is the per-age-group MEAN over the window.
summed_weighted_attr = (
    weighted_attr_df.loc[in_weighted_window, age_features].mean(axis=0).reset_index()
)
summed_weighted_attr.columns = ['age_group', 'attr']
summed_weighted_attr

Unnamed: 0,age_group,attr
0,UNDER5,14448.383363
1,AGE517,36690.836517
2,AGE1829,33880.776863
3,AGE3039,36642.269908
4,AGE4049,27780.851757
5,AGE5064,50591.372789
6,AGE6574,18890.060892
7,AGE75PLUS,13519.295254


In [135]:
# Ground truth next to the population-weighted attribution, per age group.
global_rank = summed_ground_truth.merge(
    summed_weighted_attr, on='age_group', how='inner'
)

# Rescale both columns so each sums to 100 (percentage share per age group).
rank_share_columns = ['cases', 'attr']
rank_column_totals = global_rank[rank_share_columns].sum(axis=0)
global_rank[rank_share_columns] = global_rank[rank_share_columns].div(
    rank_column_totals / 100, axis=1
)

# Rank 1 = largest share; ties receive the average rank (pandas default).
global_rank['cases_rank'] = global_rank['cases'].rank(ascending=False)
global_rank['attr_rank'] = global_rank['attr'].rank(ascending=False)
print(global_rank)

   age_group      cases       attr  cases_rank  attr_rank
0     UNDER5   3.568745   6.215860         7.0        7.0
1     AGE517  14.482830  15.784817         4.0        2.0
2    AGE1829  24.589799  14.575897         1.0        4.0
3    AGE3039  19.326818  15.763923         2.0        3.0
4    AGE4049  14.099899  11.951640         5.0        5.0
5    AGE5064  15.889536  21.764987         3.0        1.0
6    AGE6574   5.066949   8.126720         6.0        6.0
7  AGE75PLUS   2.975423   5.816155         8.0        8.0


In [122]:
# Persist the global ranking next to the other experiment outputs; the file
# name records both the data split and the explainer that produced it.
global_rank_file = f'{flag}_global_rank_{explainer.get_name()}.csv'
global_rank.to_csv(
    os.path.join(exp.output_folder, global_rank_file),
    index=False,
)

## Evaluate Local Interpretation

### Unweighted

In [136]:
# Local (per-week) evaluation of the UNWEIGHTED attributions.
# find a common start point
first_common_date = find_first_common_date(
    group_cases, attr_by_date['Date'].values
)
# since age group ground truth is weekly aggregated
# do the same for predicted importance
weekly_agg_scores = aggregate_importance_by_window(
    attr_by_date, age_features, first_common_date
)
# Compare the weekly importance with the ground truth; the helper prints the
# rank and normalized mae / rmse / ndcg shown in the cell output.
result_df = evaluate_interpretation(
    group_cases, weekly_agg_scores, age_features
)

Found first common date 2021-12-18T00:00:00.000000000.
Rank mae: 0.32812, rmse: 0.4239, ndcg: 0.82107
Normalized mae: 0.061398, rmse: 0.074441, ndcg: 0.82411


### Weighted

In [137]:
# Local (per-week) evaluation of the POPULATION-WEIGHTED attributions.
# find a common start point
first_common_date = find_first_common_date(
    group_cases, weighted_attr_by_date['Date'].values
)
# since age group ground truth is weekly aggregated
# do the same for predicted importance
weekly_agg_scores_df = aggregate_importance_by_window(
    weighted_attr_by_date, age_features, first_common_date
)
# NOTE: this reuses the name result_df, so the unweighted metrics computed in
# the previous section are no longer reachable after this cell runs.
result_df = evaluate_interpretation(
    group_cases, weekly_agg_scores_df, age_features
)
# Saving the metrics is currently disabled; uncomment to write them to the
# experiment's output folder.
# result_df.to_csv(
#     os.path.join(
#         exp.output_folder, 
#         f'{flag}_int_metrics_{explainer.get_name()}.csv'
#     ), 
#     index=False
# )

Found first common date 2021-12-18T00:00:00.000000000.
Rank mae: 0.125, rmse: 0.1875, ndcg: 0.98915
Normalized mae: 0.039509, rmse: 0.047049, ndcg: 0.9155
