In [4]:
# go up a directory
import os
os.chdir('..')

In [5]:
from openxai.Explainer import Explainer
from openxai.evaluator import Evaluator
from openxai.explainers.catalog.perturbation_methods import NormalPerturbation
from openxai.LoadModel import DefineModel
import torch
from utils import get_model_names, get_model_architecture
from openxai.dataloader import return_loaders

In [6]:
data_name = 'beauty'
model_name = 'ann_l'
base_model_dir = './models/ClassWeighted/'
model_dir, model_file_name = get_model_names(model_name, data_name, base_model_dir)

In [7]:
# Load dataset
loader_train, loader_val, loader_test = return_loaders(data_name=data_name, download=False)

X_train, y_train = loader_train.dataset.data, loader_train.dataset.targets.to_numpy()
X_val, y_val = loader_val.dataset.data, loader_val.dataset.targets.to_numpy()
X_test, y_test = loader_test.dataset.data, loader_test.dataset.targets.to_numpy()

# load sentences
X_train_sentences = loader_train.dataset.sentences
X_val_sentences = loader_val.dataset.sentences
X_test_sentences = loader_test.dataset.sentences

num_features = X_train.shape[1]

In [8]:
# Load model
input_size = loader_train.dataset.get_number_of_features()
dim_per_layer_per_MLP, activation_per_layer_per_MLP = get_model_architecture(model_name)
model = DefineModel(model_name, input_size,
                    dim_per_layer_per_MLP,
                    activation_per_layer_per_MLP)
model.load_state_dict(torch.load(model_dir + model_file_name))
model.eval()

# Store test predictions
preds = model.predict(torch.tensor(X_test).float())
model

MLP(
  (layers): ModuleList(
    (0): Linear(in_features=384, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=16, bias=True)
    (5): ReLU()
    (6): Linear(in_features=16, out_features=2, bias=True)
  )
)

In [9]:
from sentence_transformers import SentenceTransformer
# Generate embeddings
def generate_embeddings(texts):
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    batch_size = 256  # Adjust based on your system's memory capacity
    embeddings = []
    print('transforming data...')
    num_batches = len(texts) // batch_size
    for i in range(0, len(texts), batch_size):
        if i % 100 == 0:
            print(f'Processing batch {i // batch_size + 1}/{num_batches}')
        batch = texts[i:i + batch_size]
        batch_embeddings = embedding_model.encode(batch)
        embeddings.extend(batch_embeddings)
    return embeddings

In [10]:
def classifier_fn(sentences):
    embeddings = generate_embeddings(sentences)
    with torch.no_grad():
        preds = model(torch.Tensor(embeddings))
    # return the class argmax
    # preds = torch.argmax(preds, dim=1)
    return preds

In [30]:
kernel_width           = 0.75
std_LIME               = 0.1
mode                   = 'text'
sample_around_instance = True
n_samples_LIME         = 16#1000
discretize_continuous  = False

param_dict_lime = dict()
param_dict_lime['dataset_tensor']         = X_train
param_dict_lime['kernel_width']           = kernel_width
param_dict_lime['std']                    = std_LIME
param_dict_lime['mode']                   = mode
param_dict_lime['sample_around_instance'] = sample_around_instance
param_dict_lime['n_samples']              = n_samples_LIME
param_dict_lime['discretize_continuous']  = discretize_continuous
param_dict_lime['categorical_features']   = None

explainer = Explainer(method='lime', model=classifier_fn, dataset_tensor=X_train, param_dict=param_dict_lime)

In [31]:
explanations = explainer.get_explanation(X_test_sentences[:100], seed=0)

  0%|          | 0/100 [00:00<?, ?it/s]

transforming data...
Processing batch 1/0


  1%|          | 1/100 [00:01<02:25,  1.47s/it]

transforming data...
Processing batch 1/0


  2%|▏         | 2/100 [00:02<02:03,  1.26s/it]

transforming data...
Processing batch 1/0


  3%|▎         | 3/100 [00:03<02:02,  1.26s/it]

transforming data...
Processing batch 1/0


  5%|▌         | 5/100 [00:05<01:43,  1.09s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


  7%|▋         | 7/100 [00:08<01:42,  1.10s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


  8%|▊         | 8/100 [00:09<01:45,  1.15s/it]

transforming data...
Processing batch 1/0


  9%|▉         | 9/100 [00:10<01:42,  1.12s/it]

transforming data...
Processing batch 1/0


 10%|█         | 10/100 [00:11<01:47,  1.20s/it]

transforming data...
Processing batch 1/0


 12%|█▏        | 12/100 [00:13<01:39,  1.13s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 14%|█▍        | 14/100 [00:15<01:30,  1.05s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 15%|█▌        | 15/100 [00:16<01:29,  1.05s/it]

transforming data...
Processing batch 1/0


 17%|█▋        | 17/100 [00:19<01:33,  1.13s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 18%|█▊        | 18/100 [00:20<01:32,  1.13s/it]

transforming data...
Processing batch 1/0


 19%|█▉        | 19/100 [00:21<01:33,  1.15s/it]

transforming data...
Processing batch 1/0


 20%|██        | 20/100 [00:23<01:34,  1.18s/it]

transforming data...
Processing batch 1/0


 21%|██        | 21/100 [00:24<01:40,  1.28s/it]

transforming data...
Processing batch 1/0


 23%|██▎       | 23/100 [00:26<01:34,  1.23s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 24%|██▍       | 24/100 [00:28<01:33,  1.23s/it]

transforming data...
Processing batch 1/0


 25%|██▌       | 25/100 [00:29<01:34,  1.26s/it]

transforming data...
Processing batch 1/0


 27%|██▋       | 27/100 [00:32<01:33,  1.28s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 29%|██▉       | 29/100 [00:34<01:26,  1.22s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 30%|███       | 30/100 [00:35<01:29,  1.28s/it]

transforming data...
Processing batch 1/0


 31%|███       | 31/100 [00:37<01:28,  1.29s/it]

transforming data...
Processing batch 1/0


 32%|███▏      | 32/100 [00:38<01:26,  1.28s/it]

transforming data...
Processing batch 1/0


 33%|███▎      | 33/100 [00:39<01:27,  1.30s/it]

transforming data...
Processing batch 1/0


 34%|███▍      | 34/100 [00:41<01:24,  1.29s/it]

transforming data...
Processing batch 1/0


 35%|███▌      | 35/100 [00:42<01:22,  1.28s/it]

transforming data...
Processing batch 1/0


 36%|███▌      | 36/100 [00:43<01:20,  1.26s/it]

transforming data...
Processing batch 1/0


 38%|███▊      | 38/100 [00:46<01:15,  1.22s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 39%|███▉      | 39/100 [00:47<01:16,  1.26s/it]

transforming data...
Processing batch 1/0


 41%|████      | 41/100 [00:49<01:09,  1.18s/it]

transforming data...
Processing batch 1/0


 42%|████▏     | 42/100 [00:50<01:04,  1.12s/it]

transforming data...
Processing batch 1/0


 43%|████▎     | 43/100 [00:51<00:59,  1.04s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 45%|████▌     | 45/100 [00:53<00:56,  1.03s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 46%|████▌     | 46/100 [00:54<00:59,  1.11s/it]

transforming data...
Processing batch 1/0


 48%|████▊     | 48/100 [00:57<00:57,  1.10s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 49%|████▉     | 49/100 [00:58<01:05,  1.29s/it]

transforming data...
Processing batch 1/0


 51%|█████     | 51/100 [01:01<01:04,  1.31s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 52%|█████▏    | 52/100 [01:02<01:03,  1.32s/it]

transforming data...
Processing batch 1/0


 54%|█████▍    | 54/100 [01:05<00:55,  1.20s/it]

transforming data...
Processing batch 1/0


 55%|█████▌    | 55/100 [01:06<00:50,  1.13s/it]

transforming data...
Processing batch 1/0


 56%|█████▌    | 56/100 [01:07<00:47,  1.08s/it]

transforming data...
Processing batch 1/0


 57%|█████▋    | 57/100 [01:08<00:45,  1.07s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 59%|█████▉    | 59/100 [01:11<00:54,  1.32s/it]

transforming data...
Processing batch 1/0


 60%|██████    | 60/100 [01:12<00:50,  1.27s/it]

transforming data...
Processing batch 1/0


 61%|██████    | 61/100 [01:13<00:45,  1.18s/it]

transforming data...
Processing batch 1/0


 62%|██████▏   | 62/100 [01:14<00:44,  1.17s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 64%|██████▍   | 64/100 [01:16<00:41,  1.15s/it]

transforming data...
Processing batch 1/0


 65%|██████▌   | 65/100 [01:17<00:38,  1.11s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 66%|██████▌   | 66/100 [01:19<00:40,  1.18s/it]

transforming data...
Processing batch 1/0


 68%|██████▊   | 68/100 [01:21<00:37,  1.17s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 69%|██████▉   | 69/100 [01:22<00:37,  1.21s/it]

transforming data...
Processing batch 1/0


 70%|███████   | 70/100 [01:24<00:36,  1.22s/it]

transforming data...
Processing batch 1/0


 71%|███████   | 71/100 [01:25<00:38,  1.33s/it]

transforming data...
Processing batch 1/0


 73%|███████▎  | 73/100 [01:28<00:35,  1.31s/it]

transforming data...
Processing batch 1/0


 74%|███████▍  | 74/100 [01:29<00:30,  1.18s/it]

transforming data...
Processing batch 1/0


 75%|███████▌  | 75/100 [01:30<00:27,  1.10s/it]

transforming data...
Processing batch 1/0


 76%|███████▌  | 76/100 [01:31<00:26,  1.09s/it]

transforming data...
Processing batch 1/0


 77%|███████▋  | 77/100 [01:32<00:25,  1.12s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 79%|███████▉  | 79/100 [01:34<00:23,  1.12s/it]

transforming data...
Processing batch 1/0


 80%|████████  | 80/100 [01:35<00:22,  1.10s/it]

transforming data...
Processing batch 1/0


 81%|████████  | 81/100 [01:37<00:21,  1.13s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 83%|████████▎ | 83/100 [01:39<00:19,  1.12s/it]

transforming data...
Processing batch 1/0


 84%|████████▍ | 84/100 [01:40<00:16,  1.05s/it]

transforming data...
Processing batch 1/0


 85%|████████▌ | 85/100 [01:41<00:15,  1.02s/it]

transforming data...
Processing batch 1/0


 86%|████████▌ | 86/100 [01:42<00:14,  1.00s/it]

transforming data...
Processing batch 1/0


 87%|████████▋ | 87/100 [01:43<00:14,  1.11s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 89%|████████▉ | 89/100 [01:45<00:12,  1.11s/it]

transforming data...
Processing batch 1/0


 90%|█████████ | 90/100 [01:47<00:11,  1.15s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 92%|█████████▏| 92/100 [01:49<00:09,  1.15s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


 94%|█████████▍| 94/100 [01:51<00:06,  1.12s/it]

transforming data...
Processing batch 1/0


 95%|█████████▌| 95/100 [01:52<00:05,  1.15s/it]

transforming data...
Processing batch 1/0


 96%|█████████▌| 96/100 [01:54<00:04,  1.19s/it]

transforming data...
Processing batch 1/0


 97%|█████████▋| 97/100 [01:55<00:03,  1.13s/it]

transforming data...
Processing batch 1/0


 98%|█████████▊| 98/100 [01:56<00:02,  1.07s/it]

transforming data...
Processing batch 1/0


 99%|█████████▉| 99/100 [01:57<00:01,  1.06s/it]

transforming data...
Processing batch 1/0
transforming data...
Processing batch 1/0


100%|██████████| 100/100 [01:58<00:00,  1.19s/it]


In [32]:
# for exp in explanations:
#     print(exp.as_list()[:3])
#     
k = 3
# keep the first entry of each tuple in the list
LIME_exps = [[exp.as_list()[i][0] for i in range(len(exp.as_list()[:k]))] for exp in explanations]
LIME_exps

[['faster', 'and', 'process'],
 ['works', 'stop', 'nothing'],
 ['my', 'for', 'and'],
 ['VERY', 'hydrating', 'again'],
 ['glad', 'until', 'recently'],
 ['off', 'inches', 'thin'],
 ['at', 'not', 'did'],
 ['in', 'non', 'best'],
 ['helped', 'has', 'rough'],
 ['t', 's', 'Lasting'],
 ['close', 'not', 'case'],
 ['the', 't', 'The'],
 ['sleep', 'and', 'other'],
 ['well', 'really', 'without'],
 ['torn', 'had', 'purchase'],
 ['lemongrass', 'formula', 'Works'],
 ['dyed', 'fresh', 'slight'],
 ['roomy', 'pattern', 'owned'],
 ['lot', 'sink', 'off'],
 ['great', 'it', 'love'],
 ['nope', 'think', 'easily'],
 ['lightweight', 'make', 'beautiful'],
 ['Not', 'head', 'the'],
 ['Bought', 'cuts', 'wish'],
 ['use', 'oils', 'pipes'],
 ['00', 'to', 'that'],
 ['say', 'Nice', 's'],
 ['with', 'o', 'wash'],
 ['sensitive', 'skin', 'love'],
 ['designs', 'Although', 'random'],
 ['this', 'sister', 'came'],
 ['just', 'I', 'love'],
 ['Great', 'We', 'Fast'],
 ['makeup', 'There', 'one'],
 ['EXCELLENT', 'everyone', 'i'],
 ['r

In [33]:
import numpy as np
def PGI_words(sentence, embedding, topk, text_classifier, random_baseline=False):
    # PGI = []
    # for sentence, embedding, topk in zip(input_sentences, input_embeddings, topks):
    print('original_sentence:', sentence)

    with torch.no_grad():
        pred_original = text_classifier(torch.tensor(embedding))

    if random_baseline:
        #randomly select k words to remove from sentence
        inds = np.random.choice(len(sentence.split()), len(topk), replace=False)
        topk = [sentence.split()[i] for i in inds]

    for word in topk:
        print(word)
        sentence = sentence.replace(word, '')

    print('new_sentence:', sentence)
    # get embeddings for the new sentence
    new_sentence_embedding = generate_embeddings([sentence])
    with torch.no_grad():
        pred_removed = text_classifier(torch.tensor(new_sentence_embedding))
    PGI = torch.abs(pred_original[1] - pred_removed[0][1]).item()
    print(PGI)
    # PGI.append(pgi_word)

    return PGI

In [34]:
PGI_scores = PGI_words(X_test_sentences[0], X_test[0], LIME_exps[0], model)

original_sentence: I bought this for my wife.  She says that it makes the process easier and much faster.  I say that the results are great, a very noticeable difference. My daughter says shes wants one too.
faster
and
process
new_sentence: I bought this for my wife.  She says that it makes the  easier  much .  I say that the results are great, a very noticeable difference. My daughter says shes wants one too.
transforming data...
Processing batch 1/0
0.0004265904426574707


In [35]:
from sklearn.metrics import auc
def calculateFaithfulnessAUC_text(input_sentences, input_embeddings, explanations, text_classifier, min_idx, max_idx,
                                  max_k, random_baseline=False):
    PGI_AUC = []
    for index in range(min_idx, max_idx):
        # if index == bad_reply_indices:
        #     continue
        print(index)
        if max_k > 1:
            auc_x = np.arange(max_k) / (max_k - 1)
        PGI = []
        for top_k in range(1, max_k + 1):
            print('top_k', top_k)
            PGI.append(
                PGI_words(input_sentences[index], input_embeddings[index], explanations[index][:top_k], text_classifier,
                          random_baseline))  #[exp[:top_k] for exp in explanations], text_classifier))
        if max_k > 1:
            PGI_AUC.append(auc(auc_x, PGI))
        else:
            PGI_AUC.append(PGI)

    return PGI_AUC


PGI_AUC = calculateFaithfulnessAUC_text(X_test_sentences, X_test, LIME_exps, model, 0, 100, 3)

0
top_k 1
original_sentence: I bought this for my wife.  She says that it makes the process easier and much faster.  I say that the results are great, a very noticeable difference. My daughter says shes wants one too.
faster
new_sentence: I bought this for my wife.  She says that it makes the process easier and much .  I say that the results are great, a very noticeable difference. My daughter says shes wants one too.
transforming data...
Processing batch 1/0
0.00020802021026611328
top_k 2
original_sentence: I bought this for my wife.  She says that it makes the process easier and much faster.  I say that the results are great, a very noticeable difference. My daughter says shes wants one too.
faster
and
new_sentence: I bought this for my wife.  She says that it makes the process easier  much .  I say that the results are great, a very noticeable difference. My daughter says shes wants one too.
transforming data...
Processing batch 1/0
0.0002614259719848633
top_k 3
original_sentence: I

In [36]:
# LIME 1000 = 
str(round(np.mean(PGI_AUC), 3)) + '+/-' + str(round(np.std(PGI_AUC)/np.sqrt(100), 3))

'0.16+/-0.031'

In [ ]:
str(round(np.mean(PGI_AUC), 3)) + '+/-' + str(round(np.std(PGI_AUC)/np.sqrt(100), 3))