In [3]:
import os
import sys
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Add module tsmule to syspath
print("Current working dir:", os.getcwd())
sys.path.insert(0, '../../')

# Filter out all RuntimeWarning
warnings.filterwarnings("ignore", category=RuntimeWarning)

import logging
logging.getLogger("stumpy").setLevel(logging.ERROR)


Current working dir: /content


In [2]:
# Add module tsmule to system path
from tsmule.xai.lime import LimeTS
from tsmule.xai.evaluation import PerturbationAnalysis
from tsmule.sampling.segment import MatrixProfileSegmentation, SAXSegmentation
from tsmule.xai.viz import visualize_segmentation_mask


ModuleNotFoundError: No module named 'tsmule'

In [None]:
import dill
from tensorflow import keras

data_dir = "."
cnn_model = keras.models.load_model(f'{data_dir}/beijing_air_multi_site_cnn_model.h5')
with open(f'{data_dir}/beijing_air_multi_site_test_data.dill', 'rb') as f:
    dataset_test = dill.load(f)

# Define a predict fn/model
def predict_fn(x):
    if len(x.shape) == 2:
        predictions = cnn_model.predict(x[np.newaxis]).ravel()
    if len(x.shape) == 3:
        predictions = cnn_model.predict(x).ravel()
    return predictions

In [None]:
# Get test set
n_instances = 100
X = dataset_test[0][:n_instances]
Y = dataset_test[1][:n_instances]


# Perturbation Analysis - overall

In [None]:
from sklearn import metrics

# Get relevance
explainer = LimeTS(n_samples=100)
relevance = [explainer.explain(x, predict_fn) for x in X]

{'original': 0.006235582040148577,
 'percentile': 0.0062537777804530445,
 'random': 0.006238465002214432}

In [None]:
# PerturbationAnalysis for Percentile
#   replacement_method = 'zeros|global_mean|local_mean|inverse_max|inverse_mean'
pa = PerturbationAnalysis()
scores = pa.analysis_relevance(X, Y, relevance,
                        predict_fn=predict_fn,
                        replace_method='zeros',
                        eval_fn=metrics.mean_squared_error,
                        percentile=90,
                        delta=0.1
                        )
print(scores)

print("Verfication: mse(original) <= mse(percentile) <= mse(random): ", \
    scores["original"] <= scores["percentile"] <= scores["random"])

{'original': 0.006235582040148577, 'percentile': 0.0062537777804530445, 'random': 0.006237656927009091}
Verfication: mse(original) <= mse(percentile) <= mse(random):  False


In [None]:
scores = pa.analysis_relevance(X, Y, relevance,
                        predict_fn=predict_fn,
                        replace_method='inverse_mean',
                        eval_fn=metrics.mean_squared_error,
                        percentile=90,
                        delta=0.1
                        )
scores
print("Verfication: mse(original) <= mse(percentile) <= mse(random): ", \
    scores["original"] <= scores["percentile"] <= scores["random"])

Verfication: mse(original) <= mse(percentile) <= mse(random):  False


In [None]:
scores = pa.analysis_relevance(X, Y, relevance,
                        predict_fn=predict_fn,
                        replace_method='inverse_max',
                        eval_fn=metrics.mean_squared_error,
                        percentile=90,
                        delta=0.1
                        )
scores
print("Verfication: mse(original) <= mse(percentile) <= mse(random): ", \
    scores["original"] <= scores["percentile"] <= scores["random"])

Verfication: mse(original) <= mse(percentile) <= mse(random):  False
