In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from utils import *
from model_utils import *
from data import *
from c2d_models import *

import matplotlib.pyplot as plt

Using TensorFlow backend.


In [2]:
from deepexplain.tensorflow import DeepExplain




In [3]:
# c2d_model = C2D_AE_128_3x3(isTrain = False, max_value=1000)
c2d_model = C2D_AE_128_3x3(isTrain = True)
c2d_model.model.load_weights("trained_models/C2D_AE_128_3x3_IR_DISTRACTION/model.h5")
model = c2d_model.model

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [4]:
IR_TEST_DIR = "TEST_FOR_PAPER/"

test_dict = dict()
for class_dir in read_directory_contents(IR_TEST_DIR):
    test_dict[class_dir.split("/")[-1]] = np.array([read_image(file, resize_to=(128,128))/255. for file in read_directory_contents(class_dir)])

In [5]:
out = model.predict(test_dict["gps"])




In [14]:
# 1 -> High Loss -> Anomaly
# 0 -> Low Loss -> Normal
X_test = list()
y_test = list()
for key in set(test_dict.keys()):
    X_test.append(test_dict[key])
    y_test += [0 if "normal" in key else 1]*len(test_dict[key])
X_test = np.concatenate(X_test)
y_test = np.array(y_test)

In [25]:
xai_methods = {
    "Gradient Inputs": "grad*input",
    "Saliency": "saliency",
    "Integrated Gradients": "intgrad",
    "Deep LIFT": "deeplift",
    "e-LRP": "elrp",
    "Occlusion": "occlusion"
    }

xai_results = dict()

with DeepExplain(session=K.get_session()) as de:
    input_tensor = model.layers[0].input
    fModel = Model(inputs=input_tensor, outputs = model.layers[-2].output)
    target_tensor = fModel(input_tensor)

    for method_name, method_tag in xai_methods.items():
        print("-"*40)
        print("Method: %s"%method_name)
        print("-"*40)
        try: attributions = de.explain(method_tag, target_tensor, input_tensor, X_test, ys=X_test)
        except Exception as e: 
            print("ERROR:", e)
            continue
        xai_results[method_name] = attributions

    print("-"*40)
    print("Method: %s"%method_name)
    print("-"*40)
    try: xai_results["Shapely"] = de.explain('shapley_sampling', target_tensor, input_tensor, X_test, ys=X_test, samples=20)
    except Exception as e: print("ERROR:", e)

----------------------------------------
Method: Gradient Inputs
----------------------------------------
----------------------------------------
Method: Saliency
----------------------------------------
----------------------------------------
Method: Integrated Gradients
----------------------------------------
----------------------------------------
Method: Deep LIFT
----------------------------------------
ERROR: 'NoneType' object cannot be interpreted as an integer
----------------------------------------
Method: e-LRP
----------------------------------------
----------------------------------------
Method: Occlusion
----------------------------------------
----------------------------------------
Method: Occlusion
----------------------------------------


KeyboardInterrupt: 

In [27]:
dump_pickle(xai_results, "deep_explainer_results.pkl")

True

In [54]:
xai_results.keys()

dict_keys(['Gradient Inputs', 'Saliency', 'Integrated Gradients', 'e-LRP', 'Occlusion'])

In [57]:
from torchvision.utils import save_image, make_grid

In [58]:
import torch

In [62]:
sample = xai_results["Gradient Inputs"][:16]

In [64]:
sample = (sample - sample.min()) / (sample.max()-sample.min())

In [67]:
save_image(make_grid(torch.tensor(sample).transpose(2,3).transpose(1,2), nrow=4), "results/graident_inputs.png")

# LIME

In [28]:
from lime import lime_image

In [29]:
from skimage.segmentation import mark_boundaries

In [30]:
c2d_model = C2D_AE_128_3x3(isTrain = False, max_value=1000)
c2d_model.model.load_weights("trained_models/C2D_AE_128_3x3_IR_DISTRACTION/model.h5")
model = c2d_model.model

In [31]:
explainer = lime_image.LimeImageExplainer()

In [35]:
lime_results = list()
for xs in X_test:
    explanation = explainer.explain_instance(xs.astype('double'), model.predict, num_features=128*128)
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=32, hide_rest=True)
    result_mask = mark_boundaries(temp / 2 + 0.5, mask)
    lime_results.append(result_mask)

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [38]:
dump_pickle(lime_results, "lime_results.pkl")

True

In [36]:
from keras_explain.lrp import LRP
from keras_explain.grad_cam import GradCam
from keras_explain.guided_bp import GuidedBP

In [42]:
lrp_explainer = LRP(model)
guidedbp_explainer = GuidedBP(model)
gradcam_explainer = GradCam(model, layer = -2)

k_explainer_results = list()

for xs in X_test:
    exp_dict = dict()
    for explainer, type_ in zip([lrp_explainer, guidedbp_explainer, gradcam_explainer], ["LRP", "Guided_BP", "GradCAM"]):
        try:
            exp = explainer.explain(xs, 1)
            exp_dict[type_] = exp
        except:
            print(type_)
    k_explainer_results.append(exp_dict)



LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Guided_BP
LRP
Gu