In [8]:
# Imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from typing import List

import torch
from torch.utils.data import DataLoader
from torch.jit import RecursiveScriptModule


# other libraries
from typing import Final

# own modules
from src.model_utils import set_seed
from src.model_utils import load_model
from src.model_utils import predict_single_text
from src.model_utils import predict_multiple_text
from src.model_utils import load_w2v_model

from lime.lime_text import LimeTextExplainer

%matplotlib inline

In [9]:
# static variables
DATA_PATH: Final[str] = "NLP_Data/data"
NUM_CLASSES: Final[int] = 10
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# set device
device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")
set_seed(42)

In [10]:
# load the model
model: RecursiveScriptModule = load_model("best_model")
# Load the w2v model
w2vec_model = load_w2v_model()


Explain the model with LIME:

In [11]:
file_path = DATA_PATH + '/test.csv'
data: pd.DataFrame = pd.read_csv(file_path)

# replace the target column with a binary representation
data['tag'] = data['account.type'].replace('human', 0)
data['tag'] = data['tag'].replace('bot', 1)
# Only keep columns text and tag
data = data[['text', 'tag']]

print(data.head())

                                                text  tag
0  justin timberlake really one of the goats if y...    0
1  Thank you @PMBhutan for your gracious prayers ...    0
2  Theory: the number of red lights you will hit ...    0
3  Respects on the Upt of the I good with the peo...    1
4  Might give the BASIC #10Liner game contest ano...    0


  data['tag'] = data['tag'].replace('bot', 1)


In [12]:
# Test the model with a single text
text = data['text'][0]
print(text)
predicted = predict_single_text(text, model, device)
print(f"Predicted: {predicted}, Real: {data['tag'][0]}")

justin timberlake really one of the goats if you think about it
Predicted: 1, Real: 0


  return forward_call(*args, **kwargs)


In [17]:
"""
Classifier function for explainer
classifier_fn: classifier prediction probability function, which
                takes a list of d strings and outputs a (d, k) numpy array with
                prediction probabilities, where k is the number of classes.
                For ScikitClassifiers , this is classifier.predict_proba."""
def classifier_fn(text: str) -> int:
    predictions = predict_multiple_text(text, model, device, probability=True)
    print(predictions)
    # We have 2 classes. build the array (d, k) where d is the prediction and k is the number of classes
    array = []
    for prediction in predictions:
        array.append(np.array([1 - prediction, prediction]))
    print(array)
    return array

In [18]:
class_names = [0, 1]
print(f"Explaining result for: {text}")
explainer = LimeTextExplainer(class_names = class_names)
exp = explainer.explain_instance(text, classifier_fn, num_features=6)


Explaining result for: justin timberlake really one of the goats if you think about it


  return forward_call(*args, **kwargs)


[1.5288556814193726, 0.29712703824043274, 0.9638805389404297, 0.4464068114757538, 0.8467997908592224, -0.32819750905036926, 1.2383854389190674, -0.01186206005513668, 0.8331743478775024, 0.17885102331638336, 0.3994818925857544, 0, 0.8579382300376892, 0.3084229528903961, 0.10798435658216476, 0.5788082480430603, -0.11164500564336777, 0.20918747782707214, 0.34083694219589233, 0.896324634552002, 0.9042654633522034, 0.26531508564949036, 0, 0.7254881858825684, -0.45282235741615295, 0.6463339924812317, -0.06431842595338821, 0.5688446164131165, 0.6918022036552429, 0.8579382300376892, -1.1861658096313477, -0.6639578938484192, 0.3106795847415924, -0.4869260787963867, -0.8817122578620911, -0.018795231357216835, -1.1861658096313477, 0.9291799068450928, 0.5269882082939148, 0.3855453133583069, 0.12833794951438904, 0, 0.6162332892417908, -0.4589240849018097, 0, 1.5288556814193726, 0.12823528051376343, 0.010411364957690239, 0.9739956259727478, 0.1404787003993988, 0.11748095601797104, 1.2113405466079712

TypeError: list indices must be integers or slices, not tuple