In [None]:
import numpy as np
import glob
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import random

from keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout
from keras.regularizers import l2
from keras.models import Sequential
from keras.utils import to_categorical
from keras.models import model_from_json
import keras.backend as K

from sklearn.metrics import confusion_matrix
from sklearn.utils import resample

import shap

In [None]:
# Rebuild the model architecture from its JSON description, then restore the
# trained weights from the companion HDF5 file. The context manager guarantees
# the JSON file handle is closed even if reading it raises.
with open('model.json', 'r') as json_file:
    loaded_model_json = json_file.read()
model = model_from_json(loaded_model_json)
# load weights into new model
model.load_weights('model.h5')
print('Loaded model from disk')

In [None]:
# Load every saved example. Each .npz file stores one image under 'X' and its
# class label under 'y'. Paths are sorted so the sample order (and therefore any
# later seeded sampling) is reproducible across filesystems.
input_files = sorted(glob.glob('arrays/*.npz'))
print(len(input_files))
assert input_files, "no .npz files found under 'arrays/'"

Xs = []
ys = []

for file in input_files:
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it. Indexing reads the array fully into memory,
    # so the appended arrays stay valid after the file is closed.
    with np.load(file) as loaded:
        Xs.append(loaded['X'])
        ys.append(loaded['y'])

X_ = np.array(Xs)
y_ = np.array(ys)

In [None]:
# Draw a small, class-balanced subset: `samples` images per class, picked
# uniformly at random (with replacement) from that class. A dedicated seeded
# RNG makes the selection reproducible without touching global random state.
SEED = 42
rng = random.Random(SEED)
samples = 3

Xs_r = []
ys_r = []
for class_name in np.unique(y_):
    class_mask = y_ == class_name  # computed once per class
    Xc_ = X_[class_mask]
    yc_ = y_[class_mask]
    for _ in range(samples):
        idx = rng.randint(0, len(Xc_) - 1)
        Xs_r.append(Xc_[idx])
        ys_r.append(yc_[idx])

# X is grouped by class: `samples` consecutive rows per class, classes in
# ascending label order (np.unique sorts).
X = np.array(Xs_r)
y = np.array(ys_r)
# NOTE(review): to_categorical expects non-negative integer labels; confirm y_
# never contains the -99 sentinel that appears in state_map.
y = to_categorical(y)

In [None]:
# Compare model predictions against the true (one-hot) labels of the sample.
y_pred = model.predict(X)
true_idx = np.argmax(y, axis=1)
pred_idx = np.argmax(y_pred, axis=1)

# Pin the label set so row/column i is always class index i, even if a class
# never appears in either vector (sklearn otherwise silently drops it).
n_classes = max(y.shape[1], y_pred.shape[1])
class_ids = np.arange(n_classes)
cm = confusion_matrix(true_idx, pred_idx, labels=class_ids)
print(cm)

fig_cm, ax_cm = plt.subplots()
im_cm = ax_cm.imshow(cm)
ax_cm.set(
    title='Confusion matrix (sampled images)',
    xlabel='Predicted class index',
    ylabel='True class index',
    xticks=class_ids,
    yticks=class_ids,
)
fig_cm.colorbar(im_cm, ax=ax_cm)
plt.show()

In [None]:
# Class index -> label name, used to annotate the ranked SHAP outputs.
# NOTE(review): -99 presumably marks a missing label — confirm with the data.
state_map = {-99: 'nan', 0: 'normal', 1: 'fluid_pound', 2: 'standing_valve',
             3: 'sticking', 4: 'barrel_leak', 5: 'gas', 6: 'bad_data'}

def map2layer(x, layer):
    """Evaluate the input tensor of ``model.layers[layer]`` for model input ``x``.

    Feeds a copy of ``x`` into the model's first-layer input placeholder and
    runs the backend (TF1-style) session to fetch the requested layer's input.

    Args:
        x: Array to feed as the model input (copied before feeding).
        layer: Index into ``model.layers`` whose input tensor is evaluated.

    Returns:
        The evaluated tensor value returned by ``session.run``.
    """
    feed = {model.layers[0].input: x.copy()}
    session = K.get_session()
    target_tensor = model.layers[layer].input
    return session.run(target_tensor, feed)

In [None]:
# SHAP settings: attribute with respect to the input of layer 0, use every
# sampled image as background data, and report the top-2 ranked classes.
layer = 0
ranked_outputs = 2
X_shap = X[:, :, :, :]

In [None]:
# GradientExplainer is given the (input, output) tensor pair of the model graph
# plus background data: the sampled images mapped to `layer`'s input.
explainer_io = (model.layers[layer].input, model.layers[-1].output)
explainer = shap.GradientExplainer(explainer_io, map2layer(X_shap.copy(), layer))

In [None]:
# X holds `samples` consecutive images per class, so striding by `samples`
# keeps exactly one image per class. The labels are strided identically so
# y_explain stays aligned with X_explain (previously y was not subsampled).
X_explain = X[::samples, :, :, :]
y_explain = np.argmax(y[::samples], axis=1)
shap_values, indexes = explainer.shap_values(
    map2layer(X_explain, layer),
    ranked_outputs=ranked_outputs,
)
# Translate the ranked class indices returned by SHAP into label names.
index_names = np.vectorize(lambda x: state_map[x])(indexes)

In [None]:
def _rot270(img):
    """Rotate a 2-D array by 270° counter-clockwise (three successive rot90 calls)."""
    return np.rot90(img, k=3)


def _shap_limit(values):
    """Symmetric colour limit: 99th percentile of |values|, never zero.

    Falls back to the max |value| (then 1.0) so zmin/zmax never collapse to 0
    for sparse or all-zero attribution maps.
    """
    magnitude = np.abs(values)
    limit = np.quantile(magnitude, 0.99)
    if limit == 0:
        limit = magnitude.max()
    return float(limit) if limit > 0 else 1.0


n_channels = X_explain.shape[-1]
n_cols = ranked_outputs + 1

# One figure per explained image: row = channel; column 1 = the input channel,
# columns 2.. = SHAP maps for each ranked predicted class.
for sample_idx in range(len(X_explain)):
    # Title only the top row so the grid stays compact.
    top_titles = ['input'] + [str(name) for name in index_names[sample_idx]]
    subplot_titles = top_titles + [''] * (n_cols * (n_channels - 1))
    fig = make_subplots(
        rows=n_channels,
        cols=n_cols,
        horizontal_spacing=0.01,
        vertical_spacing=0.01,
        shared_xaxes=True,
        shared_yaxes=True,
        subplot_titles=subplot_titles,
    )
    for ch in range(n_channels):
        fig.add_trace(
            go.Heatmap(
                z=_rot270(X_explain[sample_idx][:, :, ch]),
                colorscale='Viridis',
                showscale=False,
            ),
            row=ch + 1, col=1,
        )
        for ro in range(ranked_outputs):
            shap_map = shap_values[ro][sample_idx][:, :, ch]
            # Symmetric limits keep zero at the centre of the diverging scale.
            endpt = _shap_limit(shap_map)
            fig.add_trace(
                go.Heatmap(
                    z=_rot270(shap_map),
                    zmin=-endpt,
                    zmax=endpt,
                    colorscale='Picnic',
                    showscale=False,
                ),
                row=ch + 1, col=ro + 2,
            )

    # Hide grid and ticks on every subplot, not just the second one.
    axis_style = dict(autorange=True, showgrid=False, ticks='', showticklabels=False)
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    # Extra top margin leaves room for the subplot titles.
    fig.update_layout(margin=dict(l=10, r=10, t=30, b=10))

    fig.show()