In [None]:
# Implementations from https://github.com/e-delaney/Instance-Based_CFE_TSC

In [1]:
cd ../

C:\Users\mrefoyo\Documents\Proyectos\counterfactuals_PFGs


In [2]:
import os
import sys
import pickle
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn import preprocessing
from tslearn.neighbors import KNeighborsTimeSeries
import tensorflow as tf
from tensorflow import keras

from experiments.experiment_utils import local_data_loader, label_encoder

print(tf.__version__)

2.13.0


In [3]:
datasets = ['CBF', 'chinatown', 'coffee', 'gunpoint', 'ECG200']
# datasets = ['CBF']

# Load data and models

In [11]:
data_dict = {}
models_dict = {}
outlier_calculators_dict = {}
nuns_idx_dict = {}
desired_classes_dict = {}

for dataset in datasets:
    X_train, y_train, X_test, y_test = local_data_loader(str(dataset), data_path="./experiments/data")
    y_train, y_test = label_encoder(y_train, y_test)
    data_dict[dataset] = (X_train, y_train, X_test, y_test)

    # Load model
    model = keras.models.load_model(f'./experiments/models/{dataset}/{dataset}_best_model.hdf5')
    y_pred = np.argmax(model.predict(X_test), axis=1)
    models_dict[dataset] = model



# Native Guide counterfactuals

In [5]:
def native_guide_retrieval(query, predicted_label, distance, n_neighbors, X_train, y_train):
    df = pd.DataFrame(y_train, columns = ['label'])
    df.index.name = 'index'
    df[df['label'] == 1].index.values, df[df['label'] != 1].index.values
    ts_length = X_train.shape[1]
    
    knn = KNeighborsTimeSeries(n_neighbors=n_neighbors, metric = distance)
    knn.fit(X_train[list(df[df['label'] != predicted_label].index.values)])
    dist,ind = knn.kneighbors(query.reshape(1,ts_length), return_distance=True)
    
    return dist[0], df[df['label'] != predicted_label].index[ind[0][:]]

In [6]:
def findSubarray(a, k): #used to find the maximum contigious subarray of length k in the explanation weight vector
    n = len(a)
    vec=[] 

    # Iterate to find all the sub-arrays 
    for i in range(n-k+1): 
        temp=[] 
        # Store the sub-array elements in the array 
        for j in range(i,i+k): 
            temp.append(a[j]) 
        # Push the vector in the container 
        vec.append(temp) 

    sum_arr = []
    for v in vec:
        sum_arr.append(np.sum(v))

    return (vec[np.argmax(sum_arr)])

In [7]:
def counterfactual_generator_swap(instance, nun, subarray_length):
    
    most_influencial_array = findSubarray((cam_training_weights[nun]), subarray_length)
    starting_point = np.where(cam_training_weights[nun]==most_influencial_array[0])[0][0]
    X_example = np.concatenate((X_test[instance][:starting_point], (X_train[nun][starting_point:subarray_length+starting_point]), X_test[instance][subarray_length+starting_point:]))
    prob_target = model.predict(X_example.reshape(1,-1,1), verbose=0)[0][y_pred[instance]]
    
    while prob_target > 0.5:
        
        subarray_length +=1
        most_influencial_array=findSubarray((cam_training_weights[nun]), subarray_length)
        starting_point = np.where(cam_training_weights[nun]==most_influencial_array[0])[0][0]
        X_example = np.concatenate((X_test[instance][:starting_point], (X_train[nun][starting_point:subarray_length+starting_point]), X_test[instance][subarray_length+starting_point:]))
        prob_target = model.predict(X_example.reshape(1,-1,1), verbose=0)[0][y_pred[instance]]
        
    return X_example

In [8]:
for dataset in datasets:
    print(f'Generating counterfactuals for {dataset}...')
    # Load data and model
    X_train, y_train, X_test, y_test = data_dict[dataset]
    model = models_dict[dataset]
    y_pred = np.argmax(model.predict(X_test, verbose=0), axis=1)

    # Get the NUNs
    nuns_idx = []
    for instance_idx in range(len(X_test)):
        nuns_idx.append(native_guide_retrieval(X_test[instance_idx], y_pred[instance_idx], 'euclidean', 1, X_train, y_train)[1][0])
    nuns_idx = np.array(nuns_idx)

    # Get cam importances 
    cam_training_weights = np.load(f'./methods/NativeGuide/Class_Activation_Mapping/{dataset}_cam_train_weights.npy')
    cam_testing_weights = np.load(f'./methods/NativeGuide/Class_Activation_Mapping/{dataset}_cam_test_weights.npy')

    # Get the counterfactuals 
    ng_cfs = []
    test_instances_idx = np.array(range(len(X_test)))
    for test_instance_idx, nun_idx in tqdm(zip(test_instances_idx, nuns_idx)):
        ng_cfs.append(counterfactual_generator_swap(test_instance_idx, nun_idx, 1))

    # Store
    # Adapt counterfactual result to our format
    results = [{'cf': np.expand_dims(cf, axis=0), 'time': -1} for cf in ng_cfs]
    # Store concatenated file
    with open(f'./experiments/results/{dataset}/ng.pickle', 'wb') as f:
        pickle.dump(results, f, pickle.HIGHEST_PROTOCOL)
    # with open(f'./experiments/results/{dataset}/ng.pickle', 'wb') as f:
    #     pickle.dump(ng_cfs, f, pickle.HIGHEST_PROTOCOL)

Generating counterfactuals for CBF...


900it [1:32:01,  6.14s/it]


In [9]:
ng_cfs[0].shape

(128, 1)

In [10]:
results = [{'cf': np.expand_dims(cf, axis=0), 'time': -1} for cf in ng_cfs]

In [11]:
results[0]['cf'].shape

(1, 128, 1)

# Watcher et al

In [4]:
import pickle

In [12]:
for dataset_name in datasets:
    X_train, _, _, _ = data_dict[dataset_name]
    ts_length, n_features = X_train.shape[1], X_train.shape[2]
    with open(f'./experiments/results/{dataset_name}/counterfactuals_wcf_ng.pickle', 'rb') as f:
        wcf_ng_cfs = pickle.load(f)
        
    results = [{'cf': cf.reshape(1 ,ts_length, n_features), 'time': -1} for cf in wcf_ng_cfs]
    # Store concatenated file
    with open(f'./experiments/results/{dataset_name}/wcf_ng.pickle', 'wb') as f:
        pickle.dump(results, f, pickle.HIGHEST_PROTOCOL)
        

In [13]:
results

[{'cf': array([[[ 0.43507313],
          [ 1.45157486],
          [ 1.01714995],
          [ 3.37530732],
          [ 2.31719462],
          [ 0.16901917],
          [-1.31602197],
          [ 0.28923912],
          [ 0.9508433 ],
          [ 0.69934562],
          [ 0.93132386],
          [ 0.7393736 ],
          [ 0.11236398],
          [-0.71280211],
          [-0.67766258],
          [-0.85951063],
          [-0.95703887],
          [-0.91617486],
          [-0.94575104],
          [-2.31482965],
          [-1.13947045],
          [-1.15696395],
          [-1.22983356],
          [-1.57529557],
          [-1.57091424],
          [-1.5025508 ],
          [-2.12631635],
          [-1.8997038 ],
          [-2.0289613 ],
          [-2.08866051],
          [-1.83555961],
          [-2.03233135],
          [-1.99215935],
          [-1.74626332],
          [-1.3559454 ],
          [-1.17373044],
          [-0.88370404],
          [-0.36429085],
          [-0.2096524 ],
          [ 0.18726

In [9]:
from scipy.optimize import minimize
from scipy import stats

In [32]:
def target_(label):
    if label == 0:
        counter = 1
    elif label == 1:
        counter = 0
    return counter

def dist_mad(query, cf):
    manhat = np.abs(query-cf)
    mad = stats.median_abs_deviation(X_train)
    return np.sum((manhat/mad).flatten())

def loss_function_mad(x_dash):
    target = target_(example_label)
    L = lamda*(model.predict(x_dash.reshape(1,-1,1), verbose=0)[0][target] - 1)**2 + \
    dist_mad(x_dash.reshape(1,-1,1), query)
    return L

In [36]:
def Wachter_Counterfactual(instance, lambda_init):
    
    global lamda
    global dist_mad
    global loss_function_mad
    global example_label
    global query

    
    pred_threshold = 0.5

    # initial conditions
    lamda = lambda_init
    x0 = X_test[instance].reshape(1,-1,1) # initial guess for cf
    query = X_test[instance].reshape(1,-1,1)
    example_label = y_pred[instance]

    res = minimize(loss_function_mad, x0.reshape(1,-1), method='nelder-mead', options={'maxiter':10, 'xatol': 50, 'adaptive': True})
    cf = res.x.reshape(1,-1,1)

    target = target_(y_pred[instance])
    prob_target = model.predict(cf)[0][target]


    i=0
    while prob_target < pred_threshold:


        lamda = lambda_init*(1+0.5)**i
        x0 = cf
        res = minimize(loss_function_mad, x0.reshape(1,-1), method='nelder-mead', options={'maxiter':10, 'xatol': 50, 'adaptive': True})
        cf = res.x.reshape(1,-1,1)
        
        """figure = plt.Figure()
        plt.plot(cf.flatten())
        plt.show()"""
        
        prob_target = model.predict(cf, verbose=0)[0][target]
        i += 1
        if i == 500:
            print('Error condition not met after',i,'iterations')
            break
    
    return cf

In [37]:
for dataset in datasets[4:]:
    print(f'Generating counterfactuals for {dataset}...')
    # Load data and model
    X_train, y_train, X_test, y_test = data_dict[dataset]
    model = models_dict[dataset]
    y_pred = np.argmax(model.predict(X_test, verbose=0), axis=1)

    # Generate counterfactuals
    wcf_cfs = []
    for instance in tqdm(range(len(X_test))):
        wcf_cfs.append(Wachter_Counterfactual(instance,lambda_init=0.1)[0])
        print(wcf_cfs)
    
    # Store
    with open(f'./counterfactuals/results/{dataset}/counterfactuals_wcf_ng.pickle', 'wb') as f:
        pickle.dump(wcf_cfs, f, pickle.HIGHEST_PROTOCOL)

Generating counterfactuals for ECG200...


  res = minimize(loss_function_mad, x0.reshape(1,-1), method='nelder-mead', options={'maxiter':10, 'xatol': 50, 'adaptive': True})




  res = minimize(loss_function_mad, x0.reshape(1,-1), method='nelder-mead', options={'maxiter':10, 'xatol': 50, 'adaptive': True})
  1%|█                                                                                                               | 1/100 [1:51:01<183:11:15, 6661.36s/it]

[array([0.43892822])]


  2%|██▏                                                                                                             | 2/100 [3:00:22<141:18:15, 5190.77s/it]

[array([0.43892822]), array([0.6644174])]


  3%|███▎                                                                                                            | 3/100 [3:41:37<106:26:28, 3950.40s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742])]


  4%|████▌                                                                                                            | 4/100 [4:25:55<91:44:15, 3440.16s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742]), array([1.0880883])]


  5%|█████▋                                                                                                           | 5/100 [4:49:46<71:39:59, 2715.78s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742]), array([1.0880883]), array([0.44832335])]


  6%|██████▊                                                                                                          | 6/100 [5:51:32<79:42:04, 3052.39s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742]), array([1.0880883]), array([0.44832335]), array([0.21277129])]


  7%|███████▉                                                                                                         | 7/100 [5:57:02<55:51:58, 2162.56s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742]), array([1.0880883]), array([0.44832335]), array([0.21277129]), array([1.0850983])]


  8%|█████████                                                                                                        | 8/100 [6:53:34<65:15:57, 2553.89s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742]), array([1.0880883]), array([0.44832335]), array([0.21277129]), array([1.0850983]), array([1.7092211])]


  9%|██████████▏                                                                                                      | 9/100 [7:47:26<69:55:02, 2765.96s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742]), array([1.0880883]), array([0.44832335]), array([0.21277129]), array([1.0850983]), array([1.7092211]), array([1.6243869])]


 10%|███████████▏                                                                                                    | 10/100 [8:46:15<75:02:01, 3001.35s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742]), array([1.0880883]), array([0.44832335]), array([0.21277129]), array([1.0850983]), array([1.7092211]), array([1.6243869]), array([0.81976906])]


 11%|████████████▎                                                                                                   | 11/100 [9:48:51<79:54:40, 3232.37s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742]), array([1.0880883]), array([0.44832335]), array([0.21277129]), array([1.0850983]), array([1.7092211]), array([1.6243869]), array([0.81976906]), array([1.43271269])]


 12%|█████████████▎                                                                                                 | 12/100 [10:07:58<63:30:20, 2597.96s/it]

[array([0.43892822]), array([0.6644174]), array([0.4150742]), array([1.0880883]), array([0.44832335]), array([0.21277129]), array([1.0850983]), array([1.7092211]), array([1.6243869]), array([0.81976906]), array([1.43271269]), array([-0.28508899])]


 12%|█████████████▎                                                                                                 | 12/100 [10:10:49<74:39:21, 3054.10s/it]

KeyboardInterrupt

