In [1]:
import torch

# Index of the GPU this kernel should be pinned to when enough GPUs exist.
# NOTE(review): nothing else in this notebook references `device` or torch;
# consider dropping this cell entirely.
GPU_INDEX = 1

if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    # Only pin when the requested GPU actually exists; calling set_device(1)
    # unconditionally raises on CPU-only or single-GPU machines.
    torch.cuda.set_device(GPU_INDEX)
    device = torch.device("cuda", GPU_INDEX)
elif torch.cuda.is_available():
    # Fewer GPUs than requested: fall back to the default CUDA device.
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

device

1

In [3]:
import datetime
import os
import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
import shap
import xgboost
from shap import Explanation
from shap.plots._utils import convert_ordering
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.tree import DecisionTreeClassifier

# 1. Load data

In [4]:
# Read the tab-separated dataset, then report its size and label distribution.
df = pd.read_csv(
    "../data/figlang_all.tsv",
    encoding="utf-8",
    sep="\t",
)
print(df.shape)
label_counts = df["label"].value_counts()
print(label_counts)
df.head()

(10848, 4)
0    6506
1    2212
2     884
3     625
4     621
Name: label, dtype: int64


Unnamed: 0,text,label,label_binary,source
0,I can't believe my ex didn't pay his car note ...,0,0,Sarcasm_premise
1,But then the paper would not find out about yo...,0,0,Idiom_premise
2,Last week my kid said some really mean things ...,0,0,CreativeParaphrase_premise
3,"The gravy was so fatty, it made the meat taste...",0,0,Metaphor_premise
4,He pulls a giant disc out and flashes it like ...,3,1,Simile_hypothesis


# 2. Set the model to use

**Uncomment exactly one of the lines below to select the model (`"rf"` or `"lr"`).**

In [5]:
# Key into model_config_dict; also used as a prefix in every output file name.
MODEL = "lr"
# MODEL = "rf"

In [6]:
# Paths to the pickled artifacts. "vectorizer" is the TF-IDF vectorizer; the
# remaining keys are the allowed values of MODEL.
model_config_dict = dict(
    vectorizer="../results/models/tfidf_vectorizer.sav",
    rf="../results/models/rf.sav",
    lr="../results/models/lr.sav",
)

**Keep only the rows with a non-zero `label` and convert their texts to TF-IDF vectors with the saved vectorizer:**

In [7]:
# Load the saved TF-IDF vectorizer. A context manager closes the file handle,
# which the bare `pickle.load(open(...))` leaked.
# NOTE: pickle.load can execute arbitrary code -- only load artifacts produced
# by this project's own training step.
with open(model_config_dict["vectorizer"], "rb") as f_vec:
    tfidf = pickle.load(f_vec)

# Rows with label 0 are excluded from the multiclass analysis.
df_multiclass = df[df["label"] != 0]

# Dense document-term matrix with one column per vocabulary term.
X = tfidf.transform(df_multiclass["text"]).toarray()
X = pd.DataFrame(X, columns=tfidf.get_feature_names_out())
assert len(X) == len(df_multiclass), "one TF-IDF row per kept text expected"
X

Unnamed: 0,000,10,100,1000,10000,11,12,15,150,16,...,younger,your,yourself,youthful,youtube,zero,zeus,zoe,zombies,zone
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4337,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4338,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4339,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4340,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [8]:
# Load the model selected by MODEL, closing the file handle afterwards.
# NOTE: only unpickle trusted, locally produced model files.
with open(model_config_dict[MODEL], "rb") as f_model:
    model = pickle.load(f_model)
model

# 3. Compute SHAP

In [9]:
# Compute SHAP values for the selected model. The TF-IDF matrix X serves both
# as the background data (masker) and as the set of rows being explained.
explainer = shap.Explainer(model, masker=X)
shap_values = explainer(X)

**Save the explainer and the SHAP values:** the model name is included in each file name so that results for different models do not overwrite each other.

In [10]:
# Create the output directory on first run; otherwise open() raises
# FileNotFoundError on a fresh checkout.
os.makedirs("../results/shap", exist_ok=True)

# Persist the SHAP Explanation object, tagged with the model name.
with open('../results/shap/shap_values_' + MODEL + '.pkl', 'wb') as out_value:
    pickle.dump(shap_values, out_value, pickle.HIGHEST_PROTOCOL)

In [11]:
# Persist the explainer alongside the SHAP values, tagged with the model name.
explainer_path = "../results/shap/shap_explainer_" + MODEL + ".pkl"
with open(explainer_path, "wb") as out:
    pickle.dump(explainer, out, protocol=pickle.HIGHEST_PROTOCOL)

# 4. Write out results

For each class, write out the 50 features with the highest mean SHAP value (averaged over all rows of `X`), together with that mean value.

In [12]:
# Index on the class axis of shap_values -> class name used in output file names.
# NOTE(review): assumes output index i corresponds to model.classes_[i]; confirm
# this against the label encoding used when the model was trained.
label_dict = dict(enumerate(["sarcasm", "idiom", "simile", "metaphor"]))

In [13]:
k = 50  # number of top-ranked features written per class

# Vocabulary is fixed by the fitted vectorizer: fetch it once rather than
# rebuilding the whole name array for every selected feature.
feature_names = tfidf.get_feature_names_out()

# The directory name follows k instead of hardcoding "50"; it is identical for k=50.
out_dir = "../results/shap/top_{}_features_whitebox/".format(k)
os.makedirs(out_dir, exist_ok=True)

for i, label_name in label_dict.items():
    print("Process class No. {}...".format(i))

    # Mean SHAP value of every feature for class i over all rows of X.
    # Computed once per class; previously it was recomputed for each of the k
    # selected features, i.e. k+1 full reductions per class.
    mean_shap = shap_values[:, :, i].mean(0).values

    # Rank by signed mean, descending.
    # NOTE(review): "importance" is often ranked by mean |SHAP|; the signed mean
    # favours features that push predictions toward class i.
    top_features_idx = mean_shap.argsort()[::-1][:k]

    df_result = pd.DataFrame({
        "feature": feature_names[top_features_idx],
        "shap": mean_shap[top_features_idx],
    })

    df_result.to_csv(out_dir + MODEL + "_" + label_name + ".tsv",
                     sep="\t", encoding="utf-8", index=False)

print("Done.")

Process class No. 0...
Process class No. 1...
Process class No. 2...
Process class No. 3...
Done.
