In [1]:
import numpy as np
import pandas as pd
import shap
import xgboost as xgb
from scipy.spatial import distance
from skmultiflow.data.data_stream import DataStream
from skmultiflow.drift_detection import PageHinkley
from skmultiflow.drift_detection.adwin import ADWIN
from skmultiflow.drift_detection import KSWIN 
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, matthews_corrcoef, accuracy_score
from scipy.stats import entropy as e

In [2]:
class ShapDetector:
    def __init__(self, base_detector_type = None, base_detector_object = None, base_detector_config = None):
        # default base detector
        self.base_detector = None
        self.base_detector_type = None
        
        # ph default params
        self.alpha = None
        self.delta = None
        self.min_instances = None
        self.threshold = None
        
        # adwin standard params
        self.ad_delta = None
        
        # kswin standard params
        self.ks_alpha = None
        
        # performance metrics
        self.auc = None
        self.acc_score = None
        self.weighted_f1_score = None
        self.mcc_score = None
        self.score_list = []
        self.weighted_f1_list = []


        # change statistics
        self.shap_list = []
        self.drift_detections = []
        self.distances = []
        self.p_values = []
        self.prob_predictions = []
        self.predictions = []
        self.true_label_list = []
        
        # statitics
        self.true_drift_points = []
        self.al_percentage = None
        self.sparsity = None
        
        # shap detector
        self.retrainsize = None
        self.samplesize = None
        self.initial_batch_size = None
        self.approach = None
        
        # error rate based detector
        self.err_based = None
        
        # model type
        self.model_type = None
                
        # active learning
        self.samping = None
            
        # define statistical detector
        if base_detector_object is not None:
            self.base_detector = base_detector_object
            print('Parameters of base detector object are not listed individually in the evaluation. For this purpose the configuration must be specified at the function call')

        else: 
            if base_detector_type is None: 
                self.base_detector = self.create_base_detector()
                    
            elif base_detector_type == "ph":
                self.base_detector = self.create_base_detector(base_detector_type = 'ph', base_detector_config = base_detector_config)
            
            elif base_detector_type == "adwin":
                self.base_detector = self.create_base_detector(base_detector_type = 'adwin', base_detector_config = base_detector_config)
                
            elif base_detector_type == "kswin":
                self.base_detector = self.create_base_detector(base_detector_type = 'kswin', base_detector_config = base_detector_config)
                
            else:
                raise ValueError('You have to provide either a detector object or a detector type ("ph", "adwin" or "kswin") with appropriate configuration!')

                
    def active_learning(self, x_labeled, x_unlabeled, y_labeled, columns, data_full, uncertainty_threshold_in, amount_classes, clf, sampling, al_percentage_in): 
        # initialize
        cols = list(range(1, columns+1))
        cols = [str(x) for x in cols]
        cols_full = cols + ['label']
        data_full.columns = list(cols_full)
        
        updated_threshold = False
        uncertainty_threshold = uncertainty_threshold_in
        x_unlabeled = pd.DataFrame(x_unlabeled, columns = cols)  
        num_ask_instances = int(np.around((len(x_unlabeled)*al_percentage_in/100), decimals=0)) #total amount of unlabeled instances to query
        num_ask_instances_subset = int(np.ceil(num_ask_instances/10)) #batch size to query per AL iteration
        num_already_ask = 0
        cols_probas = ['probas_class'+str(c) for c in range(1,amount_classes+1)]
        cols_probas_conf = cols_probas[:]+['Confidence']
        
        it = 0
        less_classes = False
        
        while num_already_ask < num_ask_instances:              
            # sort by least confident instances from unlabeled data
            probabilities = clf.predict_proba(x_unlabeled.values)    
                
            # in case that fewer classes are predicted than are present in the data set
            if probabilities.shape[1] != len(cols_probas):
                diff = np.abs(len(cols_probas)-probabilities.shape[1])
                for i in range(diff):
                    probabilities = [np.concatenate((p, [0]),axis=0) for p in probabilities]
                probabilities = np.vstack(probabilities)
                less_classes = True    
            
            probas = pd.DataFrame(probabilities, columns=cols_probas)
            
            if sampling == 'margin':
                probas['Confidence'] = self.get_confidence(np.array(clf.predict_proba(x_unlabeled.values)), sampling)
            else:
                probas['Confidence'] = self.get_confidence(np.array(clf.predict_proba(x_unlabeled.values)), sampling)
                
            x_unlabeled.reset_index(drop=True, inplace=True)
            probas.reset_index(drop=True, inplace=True)            
            x_unlabeled = pd.concat([x_unlabeled, probas], axis=1)
            
            # determine the most uncertain instances that have a confidence level lower than threshold
            if sampling == 'margin':
                x_unlabeled.sort_values(by=['Confidence'], ascending=True, inplace=True)
                instances_ask = x_unlabeled[x_unlabeled['Confidence']< uncertainty_threshold].iloc[:num_ask_instances_subset,:]
                num_already_ask += len(instances_ask)
            else:
                x_unlabeled.sort_values(by=['Confidence'], ascending=False, inplace=True)
                instances_ask = x_unlabeled.iloc[:num_ask_instances_subset,:].copy() 
                num_already_ask += len(instances_ask)

            # if too many instances were requested, overwrite the batch with exactly the right number of instances to meet the limit
            if num_already_ask > num_ask_instances and sampling == 'margin':
                num_already_ask = num_already_ask-len(instances_ask)
                instances_ask = x_unlabeled[x_unlabeled['Confidence']< uncertainty_threshold].iloc[:num_ask_instances-num_already_ask,:]
                num_already_ask += len(instances_ask)
            elif num_already_ask > num_ask_instances and sampling != 'margin':
                num_already_ask = num_already_ask-len(instances_ask)
                instances_ask = x_unlabeled.iloc[:num_ask_instances-num_already_ask,:].copy()
                num_already_ask += len(instances_ask)

            # if no instances fall below threshold, update the threshold
            if len(instances_ask) == 0 and num_already_ask < num_ask_instances and sampling == 'margin': 
                uncertainty_threshold += 0.05
                updated_threshold = True
                x_unlabeled.drop(labels = cols_probas_conf, axis=1, inplace=True)
                continue
            
            # if threshold was updated and therefore labels were found, reset the threshold to its initial value
            if len(instances_ask) !=0 and updated_threshold==True and sampling == 'margin':
                updated_threshold = False
                uncertainty_threshold = uncertainty_threshold_in
  
            x_unlabeled.drop(labels = cols_probas_conf, axis=1, inplace=True)  
            instances_ask.drop(labels = cols_probas_conf, axis=1, inplace=True) 
            
            # remove the requested instances from x_unlabeled
            x_unlabeled = x_unlabeled.iloc[len(instances_ask):,:]
            
            # get labels for least confident instances
            new_labels = self.get_labels(data_full, instances_ask.values)

            # attach labeled instances to labeled data (x_labeled)
            x_labeled = np.concatenate((x_labeled, instances_ask)) 
            y_labeled = np.concatenate((y_labeled, new_labels))
            
            if it == 0:
                if amount_classes > 2: 
                    clf = xgb.XGBClassifier(objective = 'multi:softprob', num_class = amount_classes)
                    clf.fit(x_labeled, y_labeled)
                else:
                    clf = xgb.XGBClassifier()
                    clf.fit(x_labeled, y_labeled) 
                    
            elif less_classes == False:
                clf.fit(x_labeled, y_labeled)
                
            else:
                clf.fit(x_labeled, y_labeled)
                less_classes = False

            it += 1
        return x_labeled, y_labeled
    
    
    def detect_drift(self, fix_drifts, data_sparse, data_full, sparsity, initial_batch_size, initial_batch_sample, samplesize, retrainsize, distance_measure, approach, al_percentage, uncertainty_threshold, true_drift_points, err_based, multiclass, amount_classes, real_world, sampling, clf=None):
        """Stream the data batch by batch, score the model and retrain it at the given drift iterations.

        An initial model is trained on the first ``initial_batch_size`` samples.
        Every following batch of ``samplesize`` samples is predicted and scored;
        depending on ``err_based`` either the distance between the batch's SHAP
        explanation and the model's SHAP explanation or the batch accuracy is
        fed into the base detector. Retraining is triggered whenever the loop
        index is contained in ``fix_drifts``; the base detector's own drift
        signal is not read in this method.

        Parameters
        ----------
        fix_drifts : container of int
            Loop iterations at which the model is retrained.
        data_sparse : pandas.DataFrame
            Streamed (with ``allow_nan=True``) when ``sparsity == 100``.
            Presumably the same rows as ``data_full`` with missing labels --
            TODO confirm against the caller.
        data_full : pandas.DataFrame
            Fully labeled data, label in the last column; streamed when
            ``sparsity != 100`` and always used to look up true labels.
        sparsity : number
            Selects the streamed frame (see above) and is stored for the statistics.
        initial_batch_size : int
            Number of samples used for the initial training.
        initial_batch_sample : pandas.DataFrame
            Only read when ``real_world`` is truthy: rows (label last) that are
            appended to every retraining batch.
        samplesize : int
            Number of samples per streamed batch.
        retrainsize : int
            Number of most recent stored samples used for retraining.
        distance_measure : str
            'euclidean' or 'manhattan', forwarded to ``compute_distance``.
        approach : int
            Explainer variant, forwarded to ``create_explainer``.
        al_percentage, uncertainty_threshold, sampling
            Forwarded to ``active_learning``.
        true_drift_points : list
            Stored for ``get_statistics``.
        err_based : bool
            False: feed SHAP distances to the base detector; truthy: feed batch accuracy.
        multiclass : bool
            Truthy: a new multi-class ``XGBClassifier`` is always created and any passed ``clf`` is ignored.
        amount_classes : int
            Number of classes.
        real_world : bool
            See ``initial_batch_sample``.
        clf : object, optional
            Classifier to use when ``multiclass`` is False; a default
            ``XGBClassifier`` is created if omitted.

        Returns
        -------
        ShapDetector
            ``self``, with predictions, scores and detections filled in. The
            base detector is set to None by ``create_export`` at the end.
        """
        self.reset_statistics()
        self.al_percentage = al_percentage
        self.true_drift_points = true_drift_points
        self.retrainsize = retrainsize
        self.samplesize = samplesize
        self.sparsity = sparsity
        self.initial_batch_size = initial_batch_size
        self.err_based = err_based
        self.approach = approach
        self.sampling = sampling
                
        data_full = data_full
        data = data_sparse

        # initialize model and prepare data
        if multiclass:
            clf = xgb.XGBClassifier(objective = 'multi:softprob', num_class = amount_classes)
            self.model_type = 'Xg'   
        elif multiclass == False and clf is None:
            clf = xgb.XGBClassifier()
            self.model_type = 'Xg'            
        else:
            self.model_type = 'Input'
            
        # with sparsity 100 the sparse frame is streamed, otherwise the fully labeled one
        if sparsity == 100:
            stream = DataStream(data = data, allow_nan = True)
        else:
            stream = DataStream(data = data_full)

        # get initial training data
        x_train, y_train = stream.next_sample(initial_batch_size)
        rows, columns = x_train.shape
        
        # only for initial model training: overwrite unlabeled data with labeled data
        if sparsity == 100:
            x_train = data_full.iloc[:initial_batch_size, :-1].values
            y_train = data_full.iloc[:initial_batch_size, -1].values

        clf.fit(x_train, y_train)
        
        # in real world scenarios, we keep the initial training batch for later retraining 
        if real_world:
            x_initial_batch_sample = initial_batch_sample.iloc[:, :-1].values 
            y_initial_batch_sample = initial_batch_sample.iloc[:, -1].values 
        
        # initial creation of explainer object and calculation of model explanation
        # (the last 100 training rows are passed as explainer data)
        if err_based == False:
            explainer = self.create_explainer(classifier=clf, data = x_train[-100:], approach = approach) 
            if not multiclass:
                model_shap_distr = explainer.shap_values(x_train).tolist()
            else:
                model_shap_distr = explainer.shap_values(x_train)

        # Set initial values
        sample_count = 1
        x_storage = np.zeros((0,columns))
        y_storage = np.zeros((0,1))
        
        # detect drift and retrain model 
        print('start')
        # upper bound on the number of batches; the loop ends early once the stream is empty
        for i in range(stream.n_remaining_samples()):
            x_test, y_test = stream.next_sample(samplesize)
            
            if len(x_test) == 0: 
                break
            
            # store all batches (labeled and unlabeled instances from stream)
            x_storage = np.append(x_storage, x_test, axis=0)
            # np.append without axis flattens, so y_storage becomes one-dimensional
            y_storage = np.append(y_storage, y_test)         

            # collect explanations
            if err_based==False:
                if not multiclass:
                    shap_values = explainer.shap_values(x_test).tolist()
                    if samplesize > 1:
                        shap_values = self.aggregate_shap_values(shap_values)
                else:
                    shap_values = explainer.shap_values(x_test)
                    shap_values = self.aggregate_shap_values_multiclass(shap_values)

            # collect results; the true labels are always looked up in the fully labeled data
            predictions = clf.predict(x_test)
            y_test = self.get_labels(data_full, x_test)
            
            if not multiclass:                
                # binary case: keep only the probability of the second class column
                self.prob_predictions.extend(clf.predict_proba(x_test)[:,1]) 
            else:
                self.prob_predictions.extend(clf.predict_proba(x_test)) 
                
                
            self.predictions.extend(predictions) 
            self.true_label_list.extend(y_test)
            self.score_list.append(accuracy_score(y_test, predictions)) 
            self.weighted_f1_list.append(precision_recall_fscore_support(y_test, predictions, average = 'weighted', warn_for=tuple())[2])
            
            if err_based==False:
                self.shap_list.append(shap_values[0])
                # compute distance between batch explanation and model explanation
                dist = self.compute_distance(shap_values, model_shap_distr, distance_measure, multiclass, columns)
                # feed the distance to the base detector
                self.base_detector.add_element(dist)
                self.distances.append(dist)
                
            elif self.err_based:
                # error-rate based variant: feed the batch accuracy to the base detector
                self.base_detector.add_element(accuracy_score(y_test, predictions))
                
            sample_count += 1

            # trigger retraining at the predetermined drift iterations
            if i in fix_drifts:
                self.drift_detections.append(i)
                print('Drift, No. of iterations:',  i, 'Samples: ', i*samplesize)

                # split the most recent samples into labeled and unlabeled instances
                x_labeled, x_unlabeled, y_labeled, y_unlabeled = self.filter_missing(x_storage[-retrainsize:], y_storage[-retrainsize:])       
                
                # In case of the KSWIN scenario with fix predetermined drift points, corresponding parameters
                # are assigned a corresponding value, which prevents this method from being started
                # since KSWIN is only evaluated in a scenario with 100 % label availability
                if len(x_unlabeled) != 0 or sparsity != 0:
                    x_labeled, y_labeled = self.active_learning(x_labeled, x_unlabeled, y_labeled, columns, data_full, uncertainty_threshold, amount_classes, clf, sampling, al_percentage) 

                # in real world scenarios both, 10% of the initial batch and the last x-samples are used together
                if real_world:
                    x_labeled = np.append(x_labeled, x_initial_batch_sample, axis=0)
                    y_labeled = np.append(y_labeled, y_initial_batch_sample)
                    
                # retrain the model
                clf.fit(x_labeled, y_labeled)
                
                # recompute the model explanation on the retraining data
                if err_based==False:
                    explainer = self.create_explainer(classifier=clf, data = x_labeled, approach = approach)
                
                    if not multiclass:
                        model_shap_distr = explainer.shap_values(x_labeled).tolist()
                    else:
                        model_shap_distr = explainer.shap_values(x_labeled)

                # resets the base detector as well
                sample_count = self.reset_change_parameters() 
        
        # compute performance metrics over the whole stream
        self.acc_score = accuracy_score(self.true_label_list, self.predictions)
        self.weighted_f1_score = precision_recall_fscore_support(self.true_label_list, self.predictions, average = 'weighted', warn_for=tuple())[2]
        self.mcc_score = matthews_corrcoef(self.true_label_list, self.predictions)                
        if not multiclass:
            self.auc = roc_auc_score(self.true_label_list, self.prob_predictions, average='weighted') 
        else:
            sorted_labels = clf.classes_
            self.auc = roc_auc_score(self.true_label_list, self.prob_predictions, average='weighted', multi_class = 'ovo', labels = sorted_labels)

        
        # drops the base detector; NOTE(review): a second detect_drift call on this object then fails in reset_statistics
        self.create_export()
        
        return self
    
    
    def get_statistics(self, drift_range):
        statistics = {}
        
        statistics["Model"] = self.model_type
        
        # shap detector
        statistics["Retrainsize"] = self.retrainsize
        statistics["Samplesize"] = self.samplesize
        statistics["Initial Instances"] = self.initial_batch_size
        statistics["Approach"] = 'Proba' if self.approach == 2 else 'Standard'
        
        # base detector
        statistics["Ph Alpha"] = self.alpha
        statistics["Ph Delta"] = self.delta
        statistics["Ph Min Inst"] = self.min_instances
        statistics["Ph Threshold"] = self.threshold
        statistics["Ad Delta"] = self.ad_delta
        statistics["Ks Alpha"] = self.ks_alpha
        statistics["Base Detector"] = self.base_detector_type
                
        # scores
        statistics["Weighted F1"] = np.round(self.weighted_f1_score, decimals = 3)
        statistics["Acc"] = np.round(self.acc_score, decimals = 3)
        statistics["ROC_AUC"] = np.round(self.auc, decimals = 3) if self.auc is not None else '-'      
        statistics["Mcc"] = np.round(self.mcc_score, decimals = 3)
        statistics["Detections Count"] = len(self.drift_detections)
        
        # drifts        
        statistics["FAC"], statistics["MDC"], statistics["MDR"], statistics["MTD"], statistics["MTFA"], statistics["MTR"] = np.round(self.get_drift_metrics(self.drift_detections, self.true_drift_points, drift_range), decimals = 3)
        statistics["True Drift Points"] = self.true_drift_points
        statistics["Triggered Drifts"] = self.drift_detections
                
        # error-rate based
        statistics["Error Based"] = str(self.err_based)
        
        #active learning
        statistics["Sampling"] = self.sampling

        # labels
        if self.err_based == True or self.sparsity == 0: # measured by the metric ("retrainsize") of the batch for retraining 
            statistics["Labels Retraining %"] = 100
        else:
            statistics["Labels Retraining %"] = self.al_percentage

        if self.err_based: 
            statistics["Labels Detection %"] = 100 # Percentage of labels available after the initial training
        else:
            statistics["Labels Detection %"] = 0

        return statistics
    
    
    @staticmethod
    def get_drift_metrics(drift_detections, true_drift_points, drift_range):
        drift_detections.sort()
        true_drift_points.sort()
        diffs = []
        
        # MDC params
        mdc = 0
        
        # FAC params
        detections = []
        false_detections = []
        
        # FAC - false alarm count
        if len(drift_detections) > 0 and len(true_drift_points) > 0:
            for idx, td in enumerate(true_drift_points):
                if idx+1 < len(true_drift_points):
                    n_td = true_drift_points[idx+1]
                    # check if drift is detected
                    for dd in drift_detections:
                        if td <= dd < n_td:
                            detections.append(dd)  
             
                    # check if more than one drift triggered for one true drift and collect false detections if detections are outside of drift_range
                    if len(detections) > 1:
                        detections.sort()    
                        detections.pop(0)
                        if len(detections) > 0:
                            for dd in detections:
                                if dd>(td+drift_range):
                                    false_detections.append(dd)   
                    detections.clear()
        
            # detections of first and last true drift 
            if drift_detections:
                # last                                            
                for dd in drift_detections:
                    if dd > max(true_drift_points):
                        detections.append(dd) 
                if len(detections) > 1:
                    detections.sort()    
                    detections.pop(0)
                    if len(detections) > 0:
                        for dd in detections:
                            if dd>(max(true_drift_points)+drift_range):
                                false_detections.append(dd)                                    
                # first                                            
                for dd in drift_detections:
                    if dd < min(true_drift_points):
                        false_detections.append(dd)
                        
            fac = len(false_detections)
            
        else:
            fac = 0

        
        # MDC - missed detection count
        for idx, td in enumerate(true_drift_points):
            found = False
            if idx+1 < len(true_drift_points):
                n_td = true_drift_points[idx+1]
                # check if drift is detected
                for dd in drift_detections:
                    if td <= dd < n_td:
                        found = True
                if not found:
                    mdc += 1
        # did not find any drift    
        if len(drift_detections) == 0:
            mdc = len(true_drift_points)
        # did not find last drift
        elif true_drift_points and (max(true_drift_points) > max(drift_detections)):
            mdc += 1
        
        
        # MTD - mean time to detection
        if len(drift_detections) > 0 and len(true_drift_points) > 0:
            for idx, td in enumerate(true_drift_points):
                if idx+1 < len(true_drift_points):
                    n_td = true_drift_points[idx+1]
                    for dd in drift_detections:
                        # check if drift is detected
                        if td <= dd < n_td:
                            diffs.append(dd-td)
                            break
            # diff last drift
            if drift_detections:
                for dd in drift_detections:
                    if dd > max(true_drift_points):
                        diffs.append(dd-td)       
                        break
            mtd = np.round(np.mean(diffs), decimals = 3) 
        else:
            mtd = 0        
        
        
        # MDR - missed detection rate
        if len(true_drift_points) >= 1 and mdc is not None :
            mdr = mdc/len(true_drift_points)
        else:
            mdr = 0

            
        # MTFA - mean time between false alarms
        false_detections.sort()
        if false_detections:
            mtfa = np.round( np.mean([false_detections[i + 1] - false_detections[i] for i in range(len(false_detections)-1)]), decimals = 3 )
        else:
            mtfa = 0
            

        # MTR
        if mtd != 0:
            mtr = np.round((mtfa/mtd)*(1-mdr), decimals = 3)
        else:
            mtr = 0
                
        return fac,mdc,mdr,mtd,mtfa,mtr 
    
    
    def create_export(self):
        # reduce file size by removing unnecessary class attributes
        self.base_detector = None
        return
              
        
    def create_base_detector(self, base_detector_type=None, base_detector_config=None):
        if base_detector_type is None: 
            self.base_detector = PageHinkley()
            self.base_detector_type = 'ph'
            
            self.alpha = self.base_detector.alpha
            self.delta = self.base_detector.delta
            self.min_instances = self.base_detector.min_instances
            self.threshold = self.base_detector.threshold
            
        elif base_detector_config is not None:
            if base_detector_type == 'adwin':
                self.base_detector = ADWIN()
                self.base_detector_type = 'adwin'
                self.ad_delta=self.base_detector.delta = base_detector_config["ad_delta"]
            
            elif base_detector_type == 'kswin':
                self.base_detector = KSWIN()
                self.base_detector_type = 'kswin'
                self.ks_alpha=self.base_detector.alpha = base_detector_config["ks_alpha"]

            elif base_detector_type == 'ph':
                self.base_detector = PageHinkley()
                self.base_detector_type = 'ph'
                self.alpha=self.base_detector.alpha = base_detector_config["alpha"]
                self.delta=self.base_detector.delta = base_detector_config["delta"]
                self.min_instances=self.base_detector.min_instances = base_detector_config["min_instances"]
                self.threshold=self.base_detector.threshold = base_detector_config["threshold"]
        else:
            raise ValueError('If you provide a detector type, you have to provide an appropriate configuration as well!')
        
        
        return self.base_detector 
        
        
    def compute_distance(self, shap_values_inst, shap_values_mod, distance_measure, multiclass, columns):
        #compute distance between local explanation of an instance and the global explanation of the model                 
        if distance_measure == 'euclidean':
            if not multiclass:
                if len(shap_values_mod) > 1:
                    shap_values_mod = self.aggregate_shap_values(shap_values_mod)
                diff = self.euclidean_distance(shap_values_inst[0], shap_values_mod[0], columns=columns)
            else:
                if len(shap_values_mod) > 1:
                    shap_values_mod = self.aggregate_shap_values_multiclass(shap_values_mod)
                diff = self.euclidean_distance(shap_values_inst[0], shap_values_mod[0], columns=columns)
                
        elif distance_measure == 'manhattan':
            if not multiclass:
                if len(shap_values_mod) > 1:
                    shap_values_mod = self.aggregate_shap_values(shap_values_mod)
                diff = self.manhattan_distance(shap_values_inst[0], shap_values_mod[0]) 
            else:
                if len(shap_values_mod) > 1:
                    shap_values_mod = self.aggregate_shap_values_multiclass(shap_values_mod)
                diff = self.manhattan_distance(shap_values_inst[0], shap_values_mod[0]) 
        else:
            raise ValueError('Please assign distance measure "manhattan" or "euclidean"!')
        
        return diff
  

    def reset_statistics(self):
        self.shap_list = []
        self.score_list = []
        self.weighted_f1_list = []
        self.distances = []
        self.base_detector.reset()
        self.drift_detections = []  
        self.p_values = []
        self.prob_predictions = []
        self.predictions = []
        self.true_label_list = []
        
        self.true_drift_points = []
        self.retrainsize = None
        self.samplesize = None
        self.al_percentage = None
        self.sparsity = None
        self.initial_batch_size = None
        self.err_based = None
        self.approach = None
        self.model_type = None
        
        self.auc = None
        self.acc_score = None
        self.weighted_f1_score = None
        self.mcc_score = None
        
        self.sampling = None
    
    def reset_change_parameters(self):
        self.base_detector.reset()
        sample_count = 1 
        
        return sample_count
    
    
    @staticmethod
    def get_indices(data_full, x_labeled): 
        data_v = data_full.iloc[:,:-1].values
        pairwise_compare = data_v == x_labeled[:, np.newaxis, :]
        result = pairwise_compare.all(axis=2)
        indices = [data_full[l].index[0] for l in result]
        
        return indices

    @staticmethod
    def get_labels(data_full, x_unlabeled):
        data_v = data_full.iloc[:,:-1].values
        pairwise_compare = data_v == x_unlabeled[:, np.newaxis, :]
        result = pairwise_compare.all(axis=2)
        labels = [data_full[l].iloc[0,-1] for l in result]
        
        return labels
    
    
    @staticmethod
    def filter_missing(data, labels):
        nan = np.isnan(labels)
        x_labeled = data[~nan]
        x_unlabeled = data[nan]
        y_labeled = labels[~nan]
        y_unlabeled = labels[nan]
        
        return x_labeled, x_unlabeled, y_labeled, y_unlabeled
    
    
    @staticmethod
    def get_confidence(probas, sampling):
        # compute difference between 1st and 2nd highest probability outputs
        if sampling == 'margin':
            part = np.partition(-probas, 1, axis=1)
            margin = - part[:, 0] + part[:, 1]
            return margin
        # compute entropy of probability outputs
        else:
            entropy = e(probas.T)
            return entropy
    
    
    @staticmethod
    def create_explainer(classifier, approach, data=None):
        if approach == 1:
            explainer = shap.TreeExplainer(classifier, feature_perturbation = "tree_path_dependent")
        elif approach == 2:
            explainer = shap.TreeExplainer(classifier, data=data, feature_perturbation = 'interventional', model_output='probability')
        else:
            raise ValueError('Please assign an approach!')
        
        return explainer
    
    
    @staticmethod
    def aggregate_shap_values(shap_values):
        aggr_shap_vector = []
        sv_df = pd.DataFrame(shap_values)
        rows, columns = sv_df.shape

        for i in range(0,columns):
            aggr_shap_vector.append(np.sum(np.abs(sv_df.iloc[:,i]))/len(sv_df))

        return [aggr_shap_vector]
    
    @staticmethod
    def aggregate_shap_values_multiclass(shap_values):
        shap_values_transposed = list(map(list, zip(*shap_values)))
        l = [list(map(list, zip(*s))) for s in shap_values_transposed]
        l2 = list(map(list, zip(*l)))
        avg = [np.mean(col) for col in l2]
        
        return [avg]
                
    @staticmethod
    def manhattan_distance(x,y):
        
        return sum(abs(a-b) for a,b in zip(x,y))
    
    
    @staticmethod
    def euclidean_distance(x,y, columns):
        a = [1]*columns
        
        return distance.seuclidean(x, y, a)
    
    
    @staticmethod
    def make_sparse(label, sparsity):
        i = np.random.randint(1,10000)
        p = sparsity*100
        if i<p:
            label = np.nan
            return label
        else:
            return label[0]
        
    
    def random_sample_data(self,x,y, percentage):
        df = pd.DataFrame(x)
        df['y'] = y

        df_elements = df.sample(frac = percentage)  
        x = df_elements.iloc[:, :-1].values
        y = df_elements.iloc[:, -1].values

        return x,y