In [None]:
# LIME Analysis - Insitu Surface Porosity Predictions

import types
from lime.utils.generic_utils import has_arg
from skimage.segmentation import felzenszwalb, slic, quickshift
import copy
from functools import partial

import sklearn
import sklearn.preprocessing
from sklearn.utils import check_random_state
from skimage.color import gray2rgb
from tqdm.auto import tqdm

import scipy.ndimage as ndi
from skimage.segmentation._quickshift_cy import _quickshift_cython

from lime import lime_base
from lime.wrappers.scikit_image import SegmentationAlgorithm

import skimage
from matplotlib import colors
from skimage.segmentation import mark_boundaries, find_boundaries
from skimage.morphology import dilation,square
from collections import Counter


class BaseWrapper(object):
    """Base class for the LIME scikit-image wrappers.

    Args:
        target_fn: callable function or class instance to wrap
        target_params: dict, keyword parameters to pass to the target_fn
    'target_params' takes the parameters required to instantiate the
        desired scikit-image class/model
    """

    def __init__(self, target_fn=None, **target_params):
        self.target_fn = target_fn
        self.target_params = target_params

    def _check_params(self, parameters):
        """Checks that every name in 'parameters' is accepted by the target.

        The callable that is inspected is, in order of preference:
        ``self.__call__`` when no target_fn was given, the target_fn itself
        when it is a plain function or bound method, otherwise
        ``target_fn.__call__``.

        Args:
            parameters: iterable of parameter names (a dict works, its keys
                are iterated)
        Raises:
            ValueError: if any parameter is not a valid argument for the
                target function
            TypeError: if no target_fn was given and this object is not
                callable itself, or if 'parameters' is a string
        """
        if self.target_fn is None:
            if not callable(self):
                # Adjacent literals instead of a backslash continuation, which
                # used to leak the source indentation into the message.
                raise TypeError('invalid argument: tested object is not callable, '
                                'please provide a valid target_fn')
            fn = self.__call__
        elif isinstance(self.target_fn, (types.FunctionType, types.MethodType)):
            fn = self.target_fn
        else:
            fn = self.target_fn.__call__

        # A string is iterable character by character, which is never what
        # the caller meant, so it is rejected explicitly.
        if isinstance(parameters, str):
            raise TypeError('invalid argument: list or dictionnary expected')

        for p in parameters:
            if not has_arg(fn, p):
                raise ValueError('{} is not a valid parameter'.format(p))

    def set_params(self, **params):
        """Validates 'params' and stores them as the new target_params.

        The previous target_params are replaced, not merged.

        Args:
            **params: Dictionary of parameter names mapped to their values.
        Raises:
            ValueError: if any parameter is not a valid argument
                for the target function
        """
        self._check_params(params)
        self.target_params = params

    def filter_params(self, fn, override=None):
        """Filters `target_params` and returns those in `fn`'s arguments.

        Args:
            fn: arbitrary function
            override: dict, values added on top of the filtered parameters
                (they are not checked against `fn`)
        Returns:
            result: dict, the entries of target_params whose name is an
            argument of `fn`, updated with `override`.
        """
        result = {name: value
                  for name, value in self.target_params.items()
                  if has_arg(fn, name)}
        result.update(override or {})
        return result


class SegmentationAlgorithm(BaseWrapper):
    """Image segmentation function based on a Scikit-Image implementation
    and a set of provided parameters.

    Args:
        algo_type: string, segmentation algorithm among the following:
            'quickshift', 'slic', 'felzenszwalb'
        target_params: dict, algorithm parameters (valid model parameters
            as defined in the Scikit-Image documentation)
    """

    def __init__(self, algo_type, **target_params):
        """Binds the scikit-image function named by 'algo_type'.

        Parameters that the selected function does not accept are dropped
        (via filter_params) before they are stored.

        Raises:
            ValueError: if 'algo_type' is not one of the supported names
        """
        self.algo_type = algo_type
        if algo_type == 'quickshift':
            target_fn = quickshift
        elif algo_type == 'felzenszwalb':
            target_fn = felzenszwalb
        elif algo_type == 'slic':
            target_fn = slic
        else:
            # Without this the instance would be left with no target_fn and
            # only fail later, at call time.
            raise ValueError(
                "unknown algo_type {!r}: expected 'quickshift', "
                "'felzenszwalb' or 'slic'".format(algo_type))

        BaseWrapper.__init__(self, target_fn, **target_params)
        kwargs = self.filter_params(target_fn)
        self.set_params(**kwargs)

    def __call__(self, *args):
        """Runs the segmentation on args[0] with the stored parameters."""
        return self.target_fn(args[0], **self.target_params)



class ImageExplanation(object):
    """Container for the result of explaining one image."""

    def __init__(self, image, segments):
        """Stores the explained image and its segmentation.

        Args:
            image: 3d numpy array
            segments: 2d numpy array, with the output from skimage.segmentation
        """
        self.image, self.segments = image, segments
        # Filled in per label by the explainer.
        self.intercept = {}
        self.local_exp = {}
        self.local_pred = None

    def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
                           num_features=5, min_weight=0.):
        """Builds an image/mask pair highlighting the explaining superpixels.

        Args:
            label: label to explain
            positive_only: if True, only take superpixels that positively
                contribute to the prediction of the label.
            negative_only: if True, only take superpixels that negatively
                contribute to the prediction of the label. If False, and so
                is positive_only, then both negative and positive
                contributions are taken. Both can't be True at the same time.
            hide_rest: if True, the non-explanation part of the returned
                image is zero-filled instead of copied from the original
            num_features: number of superpixels to include in explanation
            min_weight: minimum weight of the superpixels to include in
                explanation
        Returns:
            (image, mask), where image is a numpy array shaped like the
            original image and mask is an array shaped like the segments
            that can be used with skimage.segmentation.mark_boundaries
        Raises:
            KeyError: if 'label' has no stored explanation
            ValueError: if positive_only and negative_only are both set
        """
        if label not in self.local_exp:
            raise KeyError('Label not in explanation')
        if positive_only & negative_only:
            raise ValueError("Positive_only and negative_only cannot be true at the same time.")

        seg_map = self.segments
        source = self.image
        weighted = self.local_exp[label]
        mask = np.zeros(seg_map.shape, seg_map.dtype)
        canvas = np.zeros(source.shape) if hide_rest else source.copy()

        if positive_only:
            chosen = [pair[0] for pair in weighted
                      if pair[1] > 0 and pair[1] > min_weight][:num_features]
        if negative_only:
            chosen = [pair[0] for pair in weighted
                      if pair[1] < 0 and abs(pair[1]) > min_weight][:num_features]

        if not (positive_only or negative_only):
            # Signed mode: mask carries the sign and one channel is saturated
            # (channel 0 for negative weights, channel 1 otherwise).
            for feature, weight in weighted[:num_features]:
                if np.abs(weight) < min_weight:
                    continue
                region = seg_map == feature
                is_negative = weight < 0
                mask[region] = -1 if is_negative else 1
                canvas[region] = source[region].copy()
                canvas[region, 0 if is_negative else 1] = np.max(source)
            return canvas, mask

        for feature in chosen:
            region = seg_map == feature
            canvas[region] = source[region].copy()
            mask[region] = 1
        return canvas, mask


class LimeImageExplainer(object):
    """Explains predictions on stacked image (i.e. matrix) data.

    The neighbourhood of an instance is generated by randomly switching
    (layer, superpixel) pairs off, i.e. replacing them with the matching
    pixels of a "fudged" image. A locally weighted linear model is then
    fitted on these binary features (see lime_base.py).

    Several intermediate results are also published as module-level globals
    (``yuhao_*``) so they can be inspected after a run.
    """

    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
                 feature_selection='auto', random_state=None):
        """Init function.
        Args:
            kernel_width: kernel width for the exponential kernel. Must be
                convertible with float().
            kernel: similarity kernel that takes euclidean distances and kernel
                width as input and outputs weights in (0,1). If None, defaults to
                an exponential kernel.
            verbose: if true, print local prediction values from linear model
            feature_selection: feature selection method. can be
                'forward_selection', 'lasso_path', 'none' or 'auto'.
                See function 'explain_instance_with_data' in lime_base.py for
                details on what each of the options does.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        kernel_width = float(kernel_width)

        if kernel is None:
            def kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

        kernel_fn = partial(kernel, kernel_width=kernel_width)

        self.random_state = check_random_state(random_state)
        self.feature_selection = feature_selection
        self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state)

    def explain_instance(self, image, classifier_fn, labels=(1,),
                         hide_color=0,
                         top_labels=5, num_features=10000, num_samples=1000,
                         batch_size=10,
                         segmentation_fn=None, k_size = 100,
                         stack = 20,
                         distance_metric='cosine',
                         model_regressor=None,
                         random_seed=None):
        """Generates explanations for a prediction.

        First, we generate neighborhood data by randomly perturbing features
        from the instance (see data_labels). We then learn locally weighted
        linear models on this neighborhood data to explain each of the classes
        in an interpretable way (see lime_base.py).

        Side effects: assigns the module-level globals yuhao_segment,
        yuhao_data, yuhao_distance, yuhao_label, yuhao_exp, yuhao_score and
        yuhao_origpred (the last four only if at least one label is
        explained).

        Args:
            image: array whose first three axes are (stack, rows, columns).
                If it is only two dimensional, gray2rgb is applied first.
            classifier_fn: classifier prediction function, which takes a
                numpy array of perturbed images and returns one prediction
                row per image.
            labels: iterable with labels to be explained.
            hide_color: value written into every element of the fudged
                image. If None, each superpixel is filled with its own mean
                instead.
            top_labels: if not None, ignore labels and produce explanations for
                the K labels with highest prediction probabilities, where K is
                this parameter.
            num_features: maximum number of features present in explanation
            num_samples: size of the neighborhood to learn the linear model
            batch_size: classifier_fn is called on batches of this size.
            segmentation_fn: callable returning a 2d label array for the
                image. If None, every pixel of the (rows, columns) grid is
                its own superpixel.
            k_size: number of superpixels of the default grid; must equal
                rows * columns. Ignored when segmentation_fn is given.
            stack: number of layers along the first image axis.
            distance_metric: the distance metric to use for weights.
            model_regressor: sklearn regressor to use in explanation. Defaults
                to Ridge regression in LimeBase. Must have model_regressor.coef_
                and 'sample_weight' as a parameter to model_regressor.fit()
            random_seed: currently unused by the segmentation. If None, an
                integer between 0 and 1000 is still drawn from the internal
                random number generator.
        Returns:
            An ImageExplanation object (see lime_image.py) with the corresponding
            explanations.
        Raises:
            ValueError: if segmentation_fn is None and k_size does not match
                the (rows, columns) grid of the image.
        """
        global yuhao_exp
        global yuhao_score
        global yuhao_origpred
        global yuhao_label
        global yuhao_distance
        global yuhao_segment
        global yuhao_data

        if len(image.shape) == 2:
            image = gray2rgb(image)
        if random_seed is None:
            # The value itself is not consumed below, but the draw is kept so
            # that a seeded random_state yields the same perturbations as before.
            random_seed = self.random_state.randint(0, high=1000)

        if segmentation_fn is None:
            # One superpixel per pixel. The grid is taken from the image
            # rather than from notebook globals; data_labels requires
            # segments.shape == image.shape[1:3] anyway.
            grid_shape = tuple(image.shape[1:3])
            if k_size != grid_shape[0] * grid_shape[1]:
                raise ValueError(
                    'k_size={} does not match the image grid {}'.format(k_size, grid_shape))
            segments = np.arange(0, k_size).reshape(grid_shape)
        else:
            segments = segmentation_fn(image)
        yuhao_segment = segments

        fudged_image = image.copy()
        print("fudged image shape", fudged_image.shape)
        if hide_color is None:
            # NOTE(review): per-layer superpixel mean, the stack-first
            # analogue of upstream LIME's per-channel mean; the previous
            # indexing applied the 2d mask to the wrong axes. Confirm this is
            # the intended fill.
            for x in np.unique(segments):
                region = segments == x
                fudged_image[:, region] = image[:, region].mean(axis=1, keepdims=True)
        else:
            fudged_image[:] = hide_color

        top = labels

        data, labels = self.data_labels(image, fudged_image, segments, stack,
                                        classifier_fn, num_samples,
                                        batch_size=batch_size)
        yuhao_data = data

        # Distance of every perturbed sample to the unperturbed one (row 0).
        distances = sklearn.metrics.pairwise_distances(
            data,
            data[0].reshape(1, -1),
            metric=distance_metric
        ).ravel()
        yuhao_distance = distances

        ret_exp = ImageExplanation(image, segments)
        if top_labels:
            top = np.argsort(labels[0])[-top_labels:]
            ret_exp.top_labels = list(top)
            ret_exp.top_labels.reverse()
        for label in top:
            # score and local_pred are overwritten on every iteration, so
            # they end up describing the last explained label.
            (ret_exp.intercept[label],
             ret_exp.local_exp[label],
             ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
                data, labels, distances, label, num_features,
                model_regressor=model_regressor,
                feature_selection=self.feature_selection)
            yuhao_label = labels
            yuhao_exp = ret_exp.local_exp
            yuhao_score = ret_exp.score
            yuhao_origpred = ret_exp.local_pred
        return ret_exp

    @staticmethod
    def _perturbation_mask(row, segment_index, stack):
        """Returns the boolean mask of pixels switched off by one sample.

        Args:
            row: 1d binary array of length stack * n_segments; entry
                ``layer * n_segments + s`` is 0 when superpixel ``s`` of
                ``layer`` is switched off.
            segment_index: integer array giving, for every pixel, the
                position of its superpixel in 0..n_segments-1
            stack: number of layers
        Returns:
            Boolean array of shape (stack,) + segment_index.shape.
        """
        row = np.asarray(row)
        switched_off = (row == 0).reshape(stack, row.shape[0] // stack)
        return switched_off[:, segment_index]

    def data_labels(self,
                    image,
                    fudged_image,
                    segments,
                    stack,
                    classifier_fn,
                    num_samples,
                    batch_size=10):
        """Generates images and predictions in the neighborhood of this image.

        Side effects: resets the module-level global yuhao_imgs and appends
        every batch of perturbed images to it.

        Args:
            image: numpy array, the image; its leading axes must be
                (stack,) + segments.shape
            fudged_image: numpy array, image to replace original image when
                superpixel is turned off
            segments: segmentation of the image
            stack: number of layers along the first image axis
            classifier_fn: function that takes an array of images and returns
                one prediction row per image
            num_samples: size of the neighborhood to learn the linear model
            batch_size: classifier_fn will be called on batches of this size.
        Returns:
            A tuple (data, labels), where:
                data: dense num_samples * (num_superpixels * stack)
                labels: prediction matrix
        """
        global yuhao_imgs

        yuhao_imgs = []
        # Map every pixel to the position of its label among the unique
        # labels, so the feature layout no longer relies on a hard-coded
        # superpixel count or on labels being exactly 0..n-1.
        unique_labels, segment_index = np.unique(segments, return_inverse=True)
        segment_index = segment_index.reshape(segments.shape)
        n_features = unique_labels.shape[0] * stack

        data = self.random_state.randint(0, 2, num_samples * n_features)\
            .reshape((num_samples, n_features))
        # Row 0 is the unperturbed instance.
        data[0, :] = 1
        labels = []
        imgs = []

        for row in tqdm(data):
            mask = self._perturbation_mask(row, segment_index, stack)
            temp = image.copy()
            temp[mask] = fudged_image[mask]
            imgs.append(temp)
            if len(imgs) == batch_size:
                batch = np.array(imgs)
                labels.extend(classifier_fn(batch))
                yuhao_imgs.append(batch)
                imgs = []

        if len(imgs) > 0:
            # Remainder that did not fill a whole batch.
            batch = np.array(imgs)
            labels.extend(classifier_fn(batch))
            yuhao_imgs.append(batch)
        return data, np.array(labels)


In [None]:
# Driver cell: run LIME on every correctly classified sample and store the
# explanation weights of the good-quality fits.
# NOTE(review): np, pd, os, load_model, TENSOR_INPUT_DATA_PRINTING and
# y_target are not defined in this cell; presumably they come from earlier
# notebook cells.
NUMPY_INPUT_DATA = TENSOR_INPUT_DATA_PRINTING.numpy()

# seg_fn = 'slic'
# LX = np.load('cwm_markers_dataset/LX.npy')
# LY = np.load('cwm_markers_dataset/LXY.npy')

LX = NUMPY_INPUT_DATA[:,:,:,:]
LY = y_target[:]
print("LX shape", LX.shape)
print("LY shape", LY.shape)
model = load_model('porosity_bin_model_printing_fold_no_4.h5')
pred = model.predict(LX)

# LY = np.array([1 if a == 0 else 0 for a in Y])
pred_class = [0 if a < 0.5 else 1 for a in pred]
# Compare as arrays: a plain list == list comparison would yield a single bool.
correct_X_ind = list(np.where(np.asarray(pred_class) == LY)[0])
print(len(correct_X_ind))

import lime
from lime import lime_image
from matplotlib import pyplot as plt

sam_l = 1
stack = 4  # Tensor size
C1 = 129  # Number of frequency bands in spectrogram
C2 = 15   # Number of time intervals in spectrogram
stack_shift = 1
window_length = 300
train_size = 0.9
n_class = 2
k_size = C1*C2
num_s = 3500   # Number of perturbations

channel = 4    # Number of data channels
inter_period = 5400
lag = 0
fs = 2000
w_shift = int(window_length/2)
batch_size = 32
epochs = 2000   # Number of epochs
lr = 0.0002     # Learning rate

# Rename file directories
# Raw string: same value as before, without the invalid '\.' escapes.
fname_out = r'C:\Users\.............\...'
# NOTE(review): `C` is not defined in this cell (only C1 and C2 are); confirm
# which value is meant, otherwise the next line raises NameError.
data_f = fname_out +'/'+str(sam_l)+'s_continuous/inter_'+str(inter_period)+'s/with_seizure/DPGMM_C'+str(C)+'/window'+str(window_length)+'/shift'+str(w_shift)
save_data_f = data_f +'/stack'+str(stack)+'_shift'+str(stack_shift)+'/scaled01'        #############
save_f = fname_out +'/'+str(n_class)+'class/CNN_best_train_lr'+str(lr)+'_epoch'+str(epochs)


exp = LimeImageExplainer()

score = []
f_name1 = save_f+'/detailed_lime/numseg'+str(k_size)+'_nums'+str(num_s)+'/local results'
os.makedirs(f_name1, exist_ok=True)

# Iterate over the indices of the correctly classified samples themselves;
# range(len(correct_X_ind)) would explain the first N samples instead.
for sample_ind in correct_X_ind:
# for sample_ind in range(0,1):

    # Explain a single sample
    sample = LX[sample_ind, :, :, :].astype('double')
    y_samp = LY[sample_ind]
    explanation = exp.explain_instance(sample, model, top_labels=n_class, hide_color = 0, num_samples=num_s, k_size = k_size, stack = stack)
    score.append(explanation.score)

    # only study the good quality explanations
    if 0.75 <= explanation.score <= 1:

        # save weights and lime ratios
        top_label = explanation.top_labels[0]
        weightRatio = [
            (i, e, explanation.local_pred, explanation.intercept, top_label)
            for i, e in explanation.local_exp[top_label]
        ]

        WR = pd.DataFrame(weightRatio)
        WR.to_csv(f_name1+'/'+str(sample_ind)+'_weightRatioPred_score'+str(round(explanation.score,2))+'_numseg'+str(len(explanation.local_exp[top_label]))+'.csv', index = False)

score = np.array(score)
np.save(save_f+'/detailed_lime/numseg'+str(k_size)+'_nums'+str(num_s)+'/score.npy',score)