# Import

In [1]:
from run import *
from tint.metrics import mse, mae
import tint, gc, os
from tqdm import tqdm
import pandas as pd
from exp.config import FeatureFiles
from utils.interpreter import *

from tint.attr import (
    AugmentedOcclusion,
    DynaMask,
    Occlusion, 
    FeatureAblation
)

# Arguments

In [4]:
# Build the argument namespace from a fixed command line instead of sys.argv,
# so the notebook always runs with the same configuration.
parser = get_parser()
argv = """
  --model Autoformer --use_gpu --result_path scratch --data_path Top_20.csv
""".split()
args = parser.parse_args(argv)

# Count every input column once, even if it appears in more than one group.
unique_columns = set(
    DataConfig.static_reals
    + DataConfig.observed_reals
    + DataConfig.targets
)
args.n_features = len(unique_columns)

# Encoder input, decoder input and model output all use the same width.
args.c_out = args.n_features
args.dec_in = args.n_features
args.enc_in = args.n_features

args.n_targets = len(DataConfig.targets)

# Experiment 

In [5]:
# Call the project's seeding helper with the configured seed before any
# experiment object is created.
set_random_seed(args.seed)
# Captum documents a cuDNN RNN backward error on CUDA devices, see
# https://captum.ai/docs/faq#how-can-i-resolve-cudnn-rnn-backward-error-for-rnn-or-lstm-network
# The line below is kept as an opt-in fallback.
# NOTE(review): it switches the GPU off entirely rather than only cuDNN -- confirm intent.
# args.use_gpu = False

# Experiment class instantiated by the following cell.
Exp = Exp_Forecast

# Setting string derived from the parsed arguments.
setting = stringify_setting(args)

In [6]:
# Create the experiment from the parsed arguments and the setting string.
exp = Exp(args, setting)  # set experiments
# Load the saved model (the cell output reports the checkpoint path used).
exp.load_model()
# Folder the experiment writes its outputs to.
result_folder = exp.output_folder
model = exp.model
# Put the model in evaluation mode before computing attributions
# (presumably a torch.nn.Module -- TODO confirm).
model.eval()

Use GPU: cuda:0

Train samples 12740, validation samples 560, test samples 560
637 days of training, 14 days of validation data, 14 days of test data.

Fitting scalers on train data
Loading dataset from ./dataset/processed\Top_20\train.pt
Loading dataset from ./dataset/processed\Top_20\val.pt
Loading dataset from ./dataset/processed\Top_20\test.pt
loading best model from scratch\Autoformer_Top_20\checkpoint.pth


# Interpret

## Calculate Attribute

In [94]:
# Data split to interpret; also used below as a prefix for saved file names.
flag = 'train'
# NOTE(review): `_get_data` is a private method of the experiment object.
dataset, dataloader = exp._get_data(flag)

# Attribution method under evaluation (imported from tint.attr).
explainer = FeatureAblation(model)
# Run the explainer over the whole dataloader via the project helper.
attr = batch_compute_attr(dataloader, exp, explainer)

In [211]:
# Sort into a *new* frame instead of using `inplace=True`: the in-place sort
# reordered `exp.data_map[flag]` itself, silently changing the row order of the
# experiment's own data for anything that reads it afterwards.
# The original index labels are kept on purpose (no reset_index).
df = exp.data_map[flag].sort_values(by=['Date', 'FIPS'])
df.head(3)

Unnamed: 0,Date,FIPS,UNDER5,AGE517,AGE1829,AGE3039,AGE4049,AGE5064,AGE6574,AGE75PLUS,VaccinationFull,Cases,SinWeekly,TimeFromStart
746,2020-03-01,2261,0.0062,0.016,0.014,0.0146,0.0117,0.0235,0.0103,0.0004,0.0,0.0,-0.7818,0
1782,2020-03-01,4013,0.0601,0.1717,0.1678,0.1392,0.1254,0.1771,0.0912,0.0675,0.0,1.0,-0.7818,0
2818,2020-03-01,6037,0.056,0.1558,0.1738,0.1525,0.1311,0.1865,0.0826,0.0617,0.0,130.0,-0.7818,0


## Get ground truth

In [None]:
# Ground truth: cases per age group, with `end_of_week` parsed to datetimes so
# it can be compared against the attribution dates.
group_cases_path = os.path.join(
    FeatureFiles.root_folder, 'Cases by age groups.csv'
)
group_cases = pd.read_csv(group_cases_path)
group_cases['end_of_week'] = pd.to_datetime(group_cases['end_of_week'])

# County population table, reduced to the two columns used for weighting.
population_columns = ['FIPS', 'POPESTIMATE']
population = pd.read_csv('dataset/raw/Population.csv')[population_columns]

## Aggregate

### Unweighted

In [None]:
# Save the signed attributions to disk before any post-processing.
attr_raw = attr.detach().cpu().numpy()
attr_file = os.path.join(
    exp.output_folder, f'{flag}_{explainer.get_name()}.npy'
)
np.save(attr_file, attr_raw)

# Only the magnitude of the feature importance is of interest from here on.
attr_numpy = np.abs(attr_raw)

# Column groups: the age features are the static reals; `features` is passed
# to the alignment helper below.
age_features = DataConfig.static_reals
features = exp.age_data.static_reals + exp.age_data.observed_reals

# Align the importance scores along the time axis of the input data.
attr_df = align_interpretation(
    ranges=dataset.ranges,
    attr=attr_numpy,
    features=features,
    min_date=df['Date'].min(),
    seq_len=args.seq_len,
    pred_len=args.pred_len,
)
print(attr_df.describe())

In [220]:
# Total attribution of each age feature per date.
attr_by_date = (
    attr_df
    .groupby('Date')[age_features]
    .sum()
    .reset_index()
)

### Weighted

In [257]:
# Weight each county's attributions by the size of its age groups:
#     weighted = attribution * (age-feature value of the county) * (county population)
# NOTE(review): the age-feature columns of `df` look like per-county population
# shares (they sum to ~1 in the preview above) -- confirm against the pipeline.
#
# This replaces a per-county Python loop that mutated groupby sub-frames in
# place (which can trigger chained-assignment warnings), re-scanned `weights`
# and `population` for every county, and failed with a bare IndexError when a
# county had no population row. Row order and numeric results are unchanged.

# First value of every age feature per county.
weights = df.groupby('FIPS').first()[age_features].reset_index()
weights_by_fips = weights.set_index('FIPS')[age_features]

# One population value per county; the first row wins if a FIPS is duplicated.
population_by_fips = (
    population
    .drop_duplicates(subset='FIPS', keep='first')
    .set_index('FIPS')['POPESTIMATE']
)

# Fail early, and say which counties are affected, if the lookup tables do not
# cover every county that has attributions.
known_fips = weights_by_fips.index.intersection(population_by_fips.index)
missing_fips = pd.Index(attr_df['FIPS'].unique()).difference(known_fips)
if len(missing_fips) > 0:
    raise ValueError(
        f'No age weights or population found for FIPS: {missing_fips.tolist()}'
    )

# Stable sort by county: counties in ascending order, rows inside a county in
# their original order.
groups = attr_df.sort_values('FIPS', kind='stable')
fips_per_row = groups['FIPS']

# Per-row scale factor with shape (n_rows, n_age_features).
scale = (
    weights_by_fips.loc[fips_per_row].to_numpy()
    * population_by_fips.loc[fips_per_row].to_numpy()[:, None]
)
groups[age_features] = groups[age_features].to_numpy() * scale

weighted_attr_df = groups[['FIPS', 'Date'] + age_features].reset_index(drop=True)

# Total weighted attribution of each age feature per date.
weighted_attr_by_date = weighted_attr_df.groupby('Date')[
    age_features].aggregate('sum').reset_index()

## Evaluate globally

In [None]:
# Restrict the ground truth to the date range found by the common-date helpers.
dates = attr_by_date['Date'].values
first_common_date = find_first_common_date(group_cases, dates)
last_common_date = find_last_common_date(group_cases, dates)

# Both bounds are inclusive.
in_common_window = group_cases['end_of_week'].between(
    first_common_date, last_common_date
)

# Mean value of each age-group column over the selected rows, as a two-column table.
summed_ground_truth = (
    group_cases.loc[in_common_window, age_features]
    .mean(axis=0)
    .reset_index()
)
summed_ground_truth.columns = ['age_group', 'cases']
summed_ground_truth

### Unweighted

In [241]:
# The attribution window starts six days before the first common date and ends
# on the last common date (both inclusive).
window_start = first_common_date - pd.to_timedelta(6, unit='D')
in_window = (attr_df['Date'] >= window_start) & (attr_df['Date'] <= last_common_date)

# Mean attribution of each age feature inside the window, as a two-column table.
summed_attr = attr_df.loc[in_window, age_features].mean(axis=0).reset_index()
summed_attr.columns = ['age_group', 'attr']
summed_attr

Unnamed: 0,age_group,attr
0,UNDER5,0.066876
1,AGE517,0.059743
2,AGE1829,0.067455
3,AGE3039,0.069482
4,AGE4049,0.065127
5,AGE5064,0.075087
6,AGE6574,0.05569
7,AGE75PLUS,0.085042


In [247]:
# Put ground truth and unweighted attributions side by side, one row per age group.
merged = pd.merge(summed_ground_truth, summed_attr, on='age_group', how='inner')

# Rescale both columns so that each sums to 100 (percentages).
share_columns = ['cases', 'attr']
percent_divisor = merged[share_columns].sum(axis=0) / 100
merged[share_columns] = merged[share_columns].div(percent_divisor, axis=1)
merged

Unnamed: 0,age_group,cases,attr
0,UNDER5,2.684803,12.281976
1,AGE517,13.29054,10.972135
2,AGE1829,21.508396,12.388451
3,AGE3039,16.658589,12.760715
4,AGE4049,14.521426,11.960767
5,AGE5064,18.91854,13.789963
6,AGE6574,7.049264,10.227716
7,AGE75PLUS,5.368442,15.618274


### Weighted

In [258]:
# Same window as the unweighted case: six days before the first common date up
# to the last common date (both inclusive).
window_start = first_common_date - pd.to_timedelta(6, unit='D')
in_window = (
    (weighted_attr_df['Date'] >= window_start)
    & (weighted_attr_df['Date'] <= last_common_date)
)

# Mean weighted attribution of each age feature, as a two-column table.
summed_weighted_attr = (
    weighted_attr_df.loc[in_window, age_features]
    .mean(axis=0)
    .reset_index()
)
summed_weighted_attr.columns = ['age_group', 'attr']
summed_weighted_attr

Unnamed: 0,age_group,attr
0,UNDER5,13760.914048
1,AGE517,32547.453866
2,AGE1829,36913.976632
3,AGE3039,34282.368886
4,AGE4049,27710.678503
5,AGE5064,46890.374127
6,AGE6574,16021.508204
7,AGE75PLUS,14940.071591


In [259]:
# Put ground truth and weighted attributions side by side, one row per age group.
global_rank = pd.merge(
    summed_ground_truth, summed_weighted_attr, on='age_group', how='inner'
)

# Rescale both columns so that each sums to 100 (percentages).
share_columns = ['cases', 'attr']
percent_divisor = global_rank[share_columns].sum(axis=0) / 100
global_rank[share_columns] = global_rank[share_columns].div(
    percent_divisor, axis=1
)

# Rank 1 goes to the largest share; adds `cases_rank`, then `attr_rank`.
for share_column in share_columns:
    global_rank[f'{share_column}_rank'] = global_rank[share_column].rank(
        axis=0, ascending=False
    )
print(global_rank)

Unnamed: 0,age_group,cases,attr,cases_rank,attr_rank
0,UNDER5,2.684803,6.16895,8.0,8.0
1,AGE517,13.29054,14.590864,5.0,4.0
2,AGE1829,21.508396,16.548355,1.0,2.0
3,AGE3039,16.658589,15.368618,3.0,3.0
4,AGE4049,14.521426,12.422562,4.0,5.0
5,AGE5064,18.91854,21.020725,2.0,1.0
6,AGE6574,7.049264,7.182364,6.0,6.0
7,AGE75PLUS,5.368442,6.697561,7.0,7.0


In [250]:
# Store the global ranking table in the experiment's output folder; the file
# name carries the data split and the explainer name.
global_rank_file = os.path.join(
    exp.output_folder,
    f'{flag}_global_rank_{explainer.get_name()}.csv'
)
global_rank.to_csv(global_rank_file, index=False)

   age_group      cases       attr  cases_rank  attr_rank
0     UNDER5   2.684803   6.168950         8.0        8.0
1     AGE517  13.290540  14.590864         5.0        4.0
2    AGE1829  21.508396  16.548355         1.0        2.0
3    AGE3039  16.658589  15.368618         3.0        3.0
4    AGE4049  14.521426  12.422562         4.0        5.0
5    AGE5064  18.918540  21.020725         2.0        1.0
6    AGE6574   7.049264   7.182364         6.0        6.0
7  AGE75PLUS   5.368442   6.697561         7.0        7.0


## Evaluate Local Interpretation

### Unweighted

In [260]:
# Find a start date shared by the ground truth and the unweighted attributions.
unweighted_dates = attr_by_date['Date'].values
first_common_date = find_first_common_date(group_cases, unweighted_dates)

# The age-group ground truth is aggregated weekly, so the predicted importance
# is aggregated by window from that date as well.
weekly_agg_scores = aggregate_importance_by_window(
    attr_by_date,
    age_features,
    first_common_date,
)

# Compare the aggregated scores with the ground truth.
result_df = evaluate_interpretation(
    group_cases,
    weekly_agg_scores,
    age_features,
)

Found first common date 2020-03-21T00:00:00.000000000.

        Rank mae: 0.30372, rmse: 0.38069, ndcg: 0.88171

        Normalized mae: 0.061155, rmse: 0.071744, ndcg: 0.80684
    


### Weighted

In [262]:
# Find a start date shared by the ground truth and the weighted attributions.
weighted_dates = weighted_attr_by_date['Date'].values
first_common_date = find_first_common_date(group_cases, weighted_dates)

# The age-group ground truth is aggregated weekly, so the predicted importance
# is aggregated by window from that date as well.
weekly_agg_scores_df = aggregate_importance_by_window(
    weighted_attr_by_date,
    age_features,
    first_common_date,
)

# Compare the aggregated scores with the ground truth.
result_df = evaluate_interpretation(
    group_cases,
    weekly_agg_scores_df,
    age_features,
)

# Store the metrics in the experiment's output folder; the file name carries
# the data split and the explainer name.
int_metrics_file = os.path.join(
    exp.output_folder,
    f'{flag}_int_metrics_{explainer.get_name()}.csv'
)
result_df.to_csv(int_metrics_file, index=False)

Found first common date 2020-03-21T00:00:00.000000000.

        Rank mae: 0.10218, rmse: 0.14873, ndcg: 0.98545

        Normalized mae: 0.030308, rmse: 0.039382, ndcg: 0.96357
    
