In [None]:
import os
import sys
import inspect

# Put the parent of the notebook's working directory on sys.path so that the
# project-level `data` module imported below can be found.
#
# inspect.getfile(inspect.currentframe()) is not reliable inside a notebook cell:
# newer ipykernel versions compile cells under a temporary file name, which makes
# the derived directory point at a temp folder instead of the notebook folder.
# The working directory is used instead; the relative `model_path` used later
# already assumes the kernel runs from the notebook's folder.
currentdir = os.path.abspath(os.getcwd())
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

from data import load_env

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd

from os import path
from datetime import datetime, timedelta

import matplotlib.pyplot as plt

# imports from captum library
from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation

In [None]:
# Notebook configuration: the date handed to env.reset() and the saved policy
# checkpoint that is loaded with torch.load() further down.
start_date, model_path = '2014-01-01', '../runs/invest_runs/1/policy_final.pth'

# Workaround flag for the OpenMP runtime: presumably needed because more than one
# copy of the OpenMP library gets loaded in this environment — TODO confirm.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Understanding Data

In [None]:
# Build the environment with norm_state=False (presumably this disables state
# normalisation so the raw values can be inspected — confirm in load_env).
env = load_env(parentdir, norm_state=False)
# Position the environment at the configured start date.
env.reset(date=start_date)
# Keep the first 365 rows of the environment's data frame for plotting.
raw_X = env.data_df[:365]

## Plot curves of each column in raw data

In [None]:
feature_names = list(raw_X.keys())

# Three panels per row. Never fewer than two rows (the original fixed 2x3 layout),
# but grow the grid when there are more than six columns so that zip() below
# cannot silently drop any of them.
n_cols = 3
n_rows = max(2, -(-len(feature_names) // n_cols))  # ceiling division

fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(30, 10 * n_rows))
for ax, col in zip(axs.flat, feature_names):
    # Plot the bare values: the x-axis is the row position, not the frame's index.
    ax.plot(list(raw_X[col]))

    ax.set_title(col)
    ax.set_xlabel("")
    ax.set_ylabel('')

In [None]:
# Re-create the environment with load_env's default settings (presumably with
# state normalisation enabled — confirm in load_env) and take the same window as
# the raw data. `env` is deliberately rebound to this second environment.
env = load_env(parentdir)
# Use the shared start_date constant instead of repeating the literal date, so the
# raw and preprocessed views always cover the same period.
env.reset(date=start_date)
# Keep the first 365 rows of the environment's data frame.
X_norm = env.data_df[:365]

## Plot curves in each data column after preprocessing

In [None]:
# NOTE: feature_names is reused further down to label the attribution chart.
feature_names = list(X_norm.keys())

# Three panels per row. Never fewer than two rows (the original fixed 2x3 layout),
# but grow the grid when there are more than six columns so that zip() below
# cannot silently drop any of them.
n_cols = 3
n_rows = max(2, -(-len(feature_names) // n_cols))  # ceiling division

fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(30, 10 * n_rows))
for ax, col in zip(axs.flat, feature_names):
    # Plot the bare values: the x-axis is the row position, not the frame's index.
    ax.plot(list(X_norm[col]))

    ax.set_title(col)
    ax.set_xlabel("")
    ax.set_ylabel('')

In [None]:
# SECURITY: torch.load unpickles arbitrary Python objects — only load checkpoints
# that come from a trusted source.
# NOTE(review): on PyTorch >= 2.6 the default is weights_only=True, which rejects a
# fully pickled module; pass weights_only=False there if this checkpoint is trusted.
#
# map_location='cpu' keeps the cells below working (they feed CPU tensors to the
# model and call .numpy() on its weights) even if the policy was saved from a GPU.
model = torch.load(model_path, map_location='cpu')
# Switch to evaluation mode so layers such as dropout / batch-norm, if present,
# behave deterministically while attributions are computed.
model.eval()
model

In [None]:
# Preprocessed data as a float32 tensor; axis 1 holds one entry per feature column.
X = torch.tensor(X_norm.to_numpy(), dtype=torch.float32)
# Maps each action name to the index used to select the policy's output column.
action_set = {name: idx for idx, name in enumerate(('sell', 'no-op', 'buy'))}

class MaskedPolicy(nn.Module):
    """Wrap a policy network so that only one action's output is exposed.

    The attribution methods are given this wrapper, so the value they explain is
    the single output column selected by ``action``.

    Args:
        model: The wrapped module; called as ``model(x)`` and expected to return a
            2-D tensor whose dimension 1 indexes actions.
        action: Integer index of the output column to keep.
    """

    def __init__(self, model, action):
        super().__init__()
        self.action = action
        self.model = model

    def forward(self, x):
        """Return the selected action's output with shape ``(len(x), 1)``."""
        out = self.model(x)
        # Column indexing gives the same values and gradients as gathering with a
        # per-row index tensor, but builds no Python list on every call, stays on
        # the model's device, and also works for an empty batch.
        return out[:, [self.action]]

# Wrapper handed to the attribution methods: exposes only the 'sell' output.
masked_policy = MaskedPolicy(model=model, action=action_set['sell'])

In [None]:
# NoiseTunnel perturbs the inputs with random noise, so fix torch's RNG first;
# otherwise every run of the notebook produces different attributions.
torch.manual_seed(0)

# One attribution method per algorithm, all explaining the masked policy output.
ig = IntegratedGradients(masked_policy)
ig_nt = NoiseTunnel(ig)  # wraps Integrated Gradients with input noise
dl = DeepLift(masked_policy)
fa = FeatureAblation(masked_policy)

# No baselines are passed, so each method falls back to Captum's default baseline.
ig_attr_test = ig.attribute(X, n_steps=50)
ig_nt_attr_test = ig_nt.attribute(X)
dl_attr_test = dl.attribute(X)
fa_attr_test = fa.attribute(X)

In [None]:
# Prepare attributions for visualization: every method's attributions are summed
# along axis 0 and the resulting vector is divided by its L1 norm, so that the
# methods can be drawn on one common scale.

def _sum_over_rows(attr):
    """Detach an attribution tensor and sum it along axis 0 as a NumPy array."""
    return attr.detach().numpy().sum(0)


def _unit_l1(vec):
    """Divide ``vec`` by its L1 norm."""
    return vec / np.linalg.norm(vec, ord=1)


# One x position (and one label) per feature column of X.
x_axis_data = np.arange(X.shape[1])
x_axis_data_labels = [feature_names[idx] for idx in x_axis_data]

ig_attr_test_sum = _sum_over_rows(ig_attr_test)
ig_attr_test_norm_sum = _unit_l1(ig_attr_test_sum)

ig_nt_attr_test_sum = _sum_over_rows(ig_nt_attr_test)
ig_nt_attr_test_norm_sum = _unit_l1(ig_nt_attr_test_sum)

dl_attr_test_sum = _sum_over_rows(dl_attr_test)
dl_attr_test_norm_sum = _unit_l1(dl_attr_test_sum)

fa_attr_test_sum = _sum_over_rows(fa_attr_test)
fa_attr_test_norm_sum = _unit_l1(fa_attr_test_sum)

# Only row 0 of fc1's weight matrix is used as the "learned weights" reference.
lin_weight = model.fc1.weight[0].detach().numpy()
y_axis_lin_weight = _unit_l1(lin_weight)

# Font settings go first: matplotlib reads rc values when text is created or set,
# so configuring them after the title/labels makes the first run of this cell
# render differently from a re-run.
FONT_SIZE = 16
plt.rc('font', size=FONT_SIZE)            # fontsize of the text sizes
plt.rc('axes', titlesize=FONT_SIZE)       # fontsize of the axes title
plt.rc('axes', labelsize=FONT_SIZE)       # fontsize of the x and y labels
plt.rc('legend', fontsize=FONT_SIZE - 4)  # fontsize of the legend

width = 0.14
legends = ['Int Grads', 'Int Grads w/SmoothGrad','DeepLift', 'Feature Ablation', 'Weights']

# (values, alpha, colour) for each series, in the same order as `legends`.
bar_series = [
    (ig_attr_test_norm_sum, 0.8, '#eb5e7c'),
    (ig_nt_attr_test_norm_sum, 0.7, '#A90000'),
    (dl_attr_test_norm_sum, 0.6, '#34b8e0'),
    (fa_attr_test_norm_sum, 1.0, '#49ba81'),
    (y_axis_lin_weight, 1.0, 'grey'),
]

fig, ax = plt.subplots(figsize=(20, 10))
ax.set_title('Comparing input feature importances across multiple algorithms and learned weights')
ax.set_ylabel('Attributions')

# Consecutive slots 0..4: the bars of one feature sit side by side without the gap
# that the previous hand-written offsets (0, 1, 2, 4, 5) left at slot 3.
for slot, ((values, alpha, color), label) in enumerate(zip(bar_series, legends)):
    ax.bar(x_axis_data + slot * width, values, width,
           align='center', alpha=alpha, color=color, label=label)
ax.autoscale_view()

# Centre each tick under its group of bars.
ax.set_xticks(x_axis_data + (len(bar_series) - 1) * width / 2)
ax.set_xticklabels(x_axis_data_labels)

ax.legend(loc=3)
# Lay out last so that tick labels and the legend are taken into account.
fig.tight_layout()
plt.show()