# Imports

In [1]:
from run import *
import tint, gc, os
import pandas as pd
from exp.config import FeatureFiles
from utils.interpreter import *

from tint.attr import (
    AugmentedOcclusion,
    DynaMask,
    Occlusion, 
    FeatureAblation
)

# Arguments

In [2]:
# Options are written as a command line and tokenised on whitespace, so the
# notebook goes through the same argument parser as the project's scripts.
argv = """
  --model DLinear --result_path scratch --data_path Top_20.csv
""".split()

parser = get_parser()
args = parser.parse_args(argv)

# Project-level setup driven by the parsed arguments.
initial_setup(args)

# Experiment 

In [3]:
# Optional workaround: uncomment the assignment below to run without the GPU
# if the cuDNN RNN backward error described in the Captum FAQ occurs.
# Please see https://captum.ai/docs/faq#how-can-i-resolve-cudnn-rnn-backward-error-for-rnn-or-lstm-network
# args.no_gpu = True

# String form of the parsed arguments; passed to Exp_Forecast below.
# NOTE(review): presumably also names the result folder — confirm in run.py.
setting = stringify_setting(args)

In [4]:
# Build the forecasting experiment, restore its saved checkpoint, and keep a
# handle to the model returned by .eval().
# NOTE(review): assumes exp.model is a torch nn.Module, so .eval() switches it
# to inference mode — confirm.
exp = Exp_Forecast(args, setting)  # set experiments
exp.load_model()
model = exp.model.eval()

Starting experiment. Result folder scratch\DLinear_Top_20.
Use GPU: cuda:0
adding time index columns TimeFromStart
added time encoded known reals ['month', 'day', 'weekday'].

Train samples 12740, validation samples 560, test samples 560
637 days of training, 14 days of validation data, 14 days of test data.

Fitting scalers on train data
Loading dataset from ./dataset/processed\Top_20\train.pt
Loading dataset from ./dataset/processed\Top_20\val.pt
Loading dataset from ./dataset/processed\Top_20\test.pt
loading best model from scratch\DLinear_Top_20\checkpoint.pth


In [5]:
# Columns whose attributions are evaluated against the age-group ground truth.
age_features = DataConfig.static_reals
# Every model input feature: the static columns followed by the observed ones.
features = age_features + DataConfig.observed_reals

# Interpret

## Calculate attributions

In [6]:
# Data split to interpret. The same value selects the frame in exp.data_map
# and prefixes the names of the files saved further down.
flag = 'train'
dataset, dataloader = exp.get_data(flag)

In [7]:
# Attribution method under test, wrapped around the loaded model.
explainer = FeatureAblation(model)
# NOTE(review): presumably runs the explainer over every batch of the
# dataloader and returns the stacked attributions as a tensor — confirm in
# utils.interpreter.
attr = batch_compute_attr(dataloader, exp, explainer)

100%|██████████| 382/382 [00:23<00:00, 15.99it/s]


In [9]:
# Take a sorted *copy* of the split's frame. The previous in-place sort
# reordered exp.data_map[flag] itself, silently changing state owned by the
# experiment object; the cells below only need the local `df`.
df = exp.data_map[flag].sort_values(by=['Date', 'FIPS'])
df.head(3)

Unnamed: 0,Date,FIPS,UNDER5,AGE517,AGE1829,AGE3039,AGE4049,AGE5064,AGE6574,AGE75PLUS,VaccinationFull,Cases,month,day,weekday,TimeFromStart
746,2020-03-01,2261,0.0062,0.016,0.014,0.0146,0.0117,0.0235,0.0103,0.0004,0.0,0.0,3,1,6,0
1782,2020-03-01,4013,0.0601,0.1717,0.1678,0.1392,0.1254,0.1771,0.0912,0.0675,0.0,1.0,3,1,6,0
2818,2020-03-01,6037,0.056,0.1558,0.1738,0.1525,0.1311,0.1865,0.0826,0.0617,0.0,130.0,3,1,6,0


## Get ground truth

In [10]:
# Load ground truth
group_cases = pd.read_csv(
    os.path.join(FeatureFiles.root_folder, 'Cases by age groups.csv')
)
group_cases['end_of_week'] = pd.to_datetime(group_cases['end_of_week'])

population = pd.read_csv('dataset/raw/Population.csv')
population = population[['FIPS', 'POPESTIMATE']]

## Aggregate

In [11]:
# Bring the attributions to host memory and persist the raw, signed values.
attr_numpy = attr.detach().cpu().numpy()
raw_attr_path = os.path.join(
    exp.output_folder, f'{flag}_{explainer.get_name()}.npy'
)
np.save(raw_attr_path, attr_numpy)

# From here on only the magnitude of each feature's importance is used.
attr_numpy = np.abs(attr_numpy)

# Align the attribution scores with the dates/counties of the input data.
alignment_kwargs = dict(
    ranges=dataset.ranges,
    attr=attr_numpy,
    features=features,
    min_date=df['Date'].min(),
    seq_len=args.seq_len,
    pred_len=args.pred_len,
)
attr_df = align_interpretation(**alignment_kwargs)
print(attr_df.describe())

2020-03-15 00:00:00 2021-11-27 00:00:00
               FIPS   UNDER5   AGE517  AGE1829  AGE3039  AGE4049  AGE5064  \
count  12460.000000  12460.0  12460.0  12460.0  12460.0  12460.0  12460.0   
mean   22995.050000      0.0      0.0      0.0      0.0      0.0      0.0   
std    18249.847559      0.0      0.0      0.0      0.0      0.0      0.0   
min     2261.000000      0.0      0.0      0.0      0.0      0.0      0.0   
25%     6069.500000      0.0      0.0      0.0      0.0      0.0      0.0   
50%    14558.500000      0.0      0.0      0.0      0.0      0.0      0.0   
75%    39068.000000      0.0      0.0      0.0      0.0      0.0      0.0   
max    53033.000000      0.0      0.0      0.0      0.0      0.0      0.0   

       AGE6574  AGE75PLUS  VaccinationFull    Cases  
count  12460.0    12460.0          12460.0  12460.0  
mean       0.0        0.0              0.0      1.0  
std        0.0        0.0              0.0      0.0  
min        0.0        0.0              0.0      1.

In [15]:
# Total attribution per age group for every date, summed over all counties.
attr_by_date = (
    attr_df
    .groupby('Date')[age_features]
    .sum()
    .reset_index()
)

### Weighted

In [13]:
# Age-group values of each county, taken from the county's first row in `df`.
# NOTE(review): these look like population fractions per age group — confirm.
weights = df.groupby('FIPS').first()[age_features].reset_index()

# Scale every county's attributions by (age-group weight * county population).
scaled_groups = []
for FIPS, group_df in attr_df.groupby('FIPS'):
    county_age_weights = weights.loc[
        weights['FIPS'] == FIPS, age_features
    ].values
    county_population = population.loc[
        population['FIPS'] == FIPS, 'POPESTIMATE'
    ].values
    if county_population.size == 0:
        # Fail with a clear message instead of an opaque IndexError.
        raise KeyError(f'No POPESTIMATE entry found for FIPS {FIPS}')
    total_population = county_population[0]

    # Groups yielded by groupby must not be modified in place (this triggers
    # SettingWithCopyWarning); scale an explicit copy instead.
    group_df = group_df.copy()
    group_df[age_features] = group_df[age_features] * (
        county_age_weights * total_population
    )
    scaled_groups.append(group_df)

groups = pd.concat(scaled_groups, axis=0)
weighted_attr_df = groups[['FIPS', 'Date'] + age_features].reset_index(drop=True)

# Population-weighted attribution per age group for every date.
weighted_attr_by_date = weighted_attr_df.groupby('Date')[
    age_features].aggregate('sum').reset_index()

## Evaluate globally

In [16]:
# Date range covered by both the attributions and the weekly ground truth.
dates = attr_by_date['Date'].values
first_common_date = find_first_common_date(group_cases, dates)
last_common_date = find_last_common_date(group_cases, dates)

# Mean weekly cases per age group inside that common range.
week_ends = group_cases['end_of_week']
in_common_window = (
    (week_ends >= first_common_date) & (week_ends <= last_common_date)
)
summed_ground_truth = (
    group_cases.loc[in_common_window, age_features]
    .mean(axis=0)
    .rename_axis('age_group')
    .reset_index(name='cases')
)
summed_ground_truth

Found first common date 2020-03-21T00:00:00.000000000.
Found last common date 2021-11-27T00:00:00.000000000.


Unnamed: 0,age_group,cases
0,UNDER5,14060.865169
1,AGE517,69605.303371
2,AGE1829,112643.910112
3,AGE3039,87244.47191
4,AGE4049,76051.707865
5,AGE5064,99080.303371
6,AGE6574,36918.449438
7,AGE75PLUS,28115.640449


### Unweighted

In [17]:
# The window opens 6 days before the first common week-ending date.
# NOTE(review): presumably so the whole first week is covered — confirm.
attr_window_start = first_common_date - pd.to_timedelta(6, unit='D')
attr_dates = attr_df['Date']
in_attr_window = (
    (attr_dates >= attr_window_start) & (attr_dates <= last_common_date)
)

# Mean unweighted attribution per age group over the window.
summed_attr = (
    attr_df.loc[in_attr_window, age_features]
    .mean(axis=0)
    .rename_axis('age_group')
    .reset_index(name='attr')
)
summed_attr

Unnamed: 0,age_group,attr
0,UNDER5,0.0
1,AGE517,0.0
2,AGE1829,0.0
3,AGE3039,0.0
4,AGE4049,0.0
5,AGE5064,0.0
6,AGE6574,0.0
7,AGE75PLUS,0.0


In [24]:
# Put ground truth and attribution side by side, one row per age group.
merged = summed_ground_truth.merge(summed_attr, on='age_group', how='inner')

# Express each column as a percentage of its own total. A column that sums to
# zero yields NaN after the division, which is replaced by 0.
value_columns = ['cases', 'attr']
one_percent_of_total = merged[value_columns].sum(axis=0) / 100
merged[value_columns] = (
    merged[value_columns]
    .div(one_percent_of_total, axis=1)
    .fillna(0)
)
merged

Unnamed: 0,age_group,cases,attr
0,UNDER5,2.684803,0.0
1,AGE517,13.29054,0.0
2,AGE1829,21.508396,0.0
3,AGE3039,16.658589,0.0
4,AGE4049,14.521426,0.0
5,AGE5064,18.91854,0.0
6,AGE6574,7.049264,0.0
7,AGE75PLUS,5.368442,0.0


### Weighted

In [25]:
# The window opens 6 days before the first common week-ending date.
# NOTE(review): presumably so the whole first week is covered — confirm.
weighted_window_start = first_common_date - pd.to_timedelta(6, unit='D')
weighted_dates = weighted_attr_df['Date']
in_weighted_window = (
    (weighted_dates >= weighted_window_start)
    & (weighted_dates <= last_common_date)
)

# Mean population-weighted attribution per age group over the window.
summed_weighted_attr = (
    weighted_attr_df.loc[in_weighted_window, age_features]
    .mean(axis=0)
    .rename_axis('age_group')
    .reset_index(name='attr')
)
summed_weighted_attr

Unnamed: 0,age_group,attr
0,UNDER5,0.0
1,AGE517,0.0
2,AGE1829,0.0
3,AGE3039,0.0
4,AGE4049,0.0
5,AGE5064,0.0
6,AGE6574,0.0
7,AGE75PLUS,0.0


In [26]:
# Ground truth next to the weighted attribution, one row per age group.
global_rank = summed_ground_truth.merge(
    summed_weighted_attr, on='age_group', how='inner'
)

# Convert both columns to percentages of their totals (NaN from a zero total
# becomes 0).
share_columns = ['cases', 'attr']
share_denominator = global_rank[share_columns].sum(axis=0) / 100
global_rank[share_columns] = (
    global_rank[share_columns].div(share_denominator, axis=1).fillna(0)
)

# Rank 1 = largest share; ties receive the average of their ranks.
global_rank = global_rank.assign(
    cases_rank=global_rank['cases'].rank(ascending=False),
    attr_rank=global_rank['attr'].rank(ascending=False),
)
print(global_rank)

   age_group      cases  attr  cases_rank  attr_rank
0     UNDER5   2.684803   0.0         8.0        4.5
1     AGE517  13.290540   0.0         5.0        4.5
2    AGE1829  21.508396   0.0         1.0        4.5
3    AGE3039  16.658589   0.0         3.0        4.5
4    AGE4049  14.521426   0.0         4.0        4.5
5    AGE5064  18.918540   0.0         2.0        4.5
6    AGE6574   7.049264   0.0         6.0        4.5
7  AGE75PLUS   5.368442   0.0         7.0        4.5


In [122]:
# Persist the global ranking next to the other outputs of this experiment.
global_rank_path = os.path.join(
    exp.output_folder,
    f'{flag}_global_rank_{explainer.get_name()}.csv'
)
global_rank.to_csv(global_rank_path, index=False)

## Evaluate Local Interpretation

### Unweighted

In [136]:
# since age group ground truth is weekly aggregated
# do the same for predicted importance
weekly_agg_scores = aggregate_importance_by_window(
    attr_by_date, age_features, first_common_date
)
result_df = evaluate_interpretation(
    group_cases, weekly_agg_scores, age_features
)

Found first common date 2021-12-18T00:00:00.000000000.
Rank mae: 0.32812, rmse: 0.4239, ndcg: 0.82107
Normalized mae: 0.061398, rmse: 0.074441, ndcg: 0.82411


### Weighted

In [137]:
# since age group ground truth is weekly aggregated
# do the same for predicted importance
weekly_agg_scores_df = aggregate_importance_by_window(
    weighted_attr_by_date, age_features, first_common_date
)
result_df = evaluate_interpretation(
    group_cases, weekly_agg_scores_df, age_features
)
# result_df.to_csv(
#     os.path.join(
#         exp.output_folder, 
#         f'{flag}_int_metrics_{explainer.get_name()}.csv'
#     ), 
#     index=False
# )

Found first common date 2021-12-18T00:00:00.000000000.
Rank mae: 0.125, rmse: 0.1875, ndcg: 0.98915
Normalized mae: 0.039509, rmse: 0.047049, ndcg: 0.9155
