In [1]:
from nlp_utils.data_module import SemEvalDataModule
from nlp_utils.model import CustomDistilBertModel
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything
from glob import glob
import ipywidgets as widgets
from tqdm.notebook import tqdm
import pandas as pd
import torch
import seaborn as sb
import re
import random
import os

%load_ext tensorboard
%load_ext autoreload
%autoreload 2
seed_everything(42)

Global seed set to 42


42

In [2]:
if not 'notebookDir' in globals():
    notebookDir = os.getcwd()
print('notebookDir: ' + notebookDir)
os.chdir(notebookDir)

notebookDir: /home/user/Documents/Github/Uni/Master/TUM_Praktikum_NLP_Explainability/understanding-opinions-on-social-media/tasks


In [4]:
# Start tensorboard
! pkill tensorboard
! rm -r /tmp/.tensorboard-info
%tensorboard --logdir ../logs/lightning_logs --bind_all

In [27]:
# get folder where logs are stored
save_folder = "../logs/StancePrediction_SemEval/lightning_logs/"
save_folder = os.path.join(notebookDir, save_folder)

# Load model

In [28]:
# Select a model
w = widgets.Dropdown(
    options=glob(os.path.join(save_folder, '*/checkpoints/*.ckpt')),
    description='Select a checkpoint:'
)
w

Dropdown(description='Select a checkpoint:', options=('/home/user/Documents/Github/Uni/Master/TUM_Praktikum_NL…

In [23]:
model_version = re.findall("version_[0-9]+", w.value)[0]
model = CustomDistilBertModel.load_from_checkpoint(w.value)
data_module = SemEvalDataModule(num_workers=4, config=model.config)

model.config, model_version

({'dataset_path': '../../data/raw/SemEval/',
  'learning_rate': 0.0010739076860714453,
  'batch_size': 32,
  'epochs': 20,
  'num_trials': 50,
  'vocab_size': 30522,
  'target_encoding': {0: 'Atheism',
   1: 'Climate Change is a Real Concern',
   2: 'Feminist Movement',
   3: 'Hillary Clinton',
   4: 'Legalization of Abortion'},
  'stance_encoding': {0: 'AGAINST', 1: 'FAVOR', 2: 'NONE', 3: 'UNKNOWN'}},
 'version_22')

In [24]:
# check performance
trainer = pl.Trainer(deterministic=True)
trainer.test(model, datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Testing:  98%|█████████▊| 39/40 [00:36<00:00,  1.24it/s]

Results				 
FAVOR     precision: 0.5906 recall: 0.4934 f-score: 0.5376
AGAINST   precision: 0.7110 recall: 0.7776 f-score: 0.7428
------------
Macro F: 0.6402

Testing: 100%|██████████| 40/40 [00:36<00:00,  1.10it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_epoch_F1': 0.7349879741668701,
 'test_epoch_target_F1': 0.7349879741668701,
 'test_loss': 0.7136555314064026}
--------------------------------------------------------------------------------


[{'test_loss': 0.7136555314064026,
  'test_epoch_target_F1': 0.7349879741668701,
  'test_epoch_F1': 0.7349879741668701}]

In [110]:
# get best config
def get_best_config(path):
    scores = []
    for root, dirs, files in os.walk(path):
        for name in files:
            if ".ckpt" in name:
                score = name.split("val")
                score = [re.findall(r"[-+]?\d*\.\d+|\d+", s)[-1] for s in score]
                score.append(os.path.join(root, name))

                scores.append(score)

    # filter scores for best version
    df = pd.DataFrame(scores, columns=["epoch", "loss", "F1", "path"])
    ckpt = df.sort_values(by=["F1", "loss"], ascending=[False,True]).head(1).path.values[0]
    return ckpt

In [111]:
# take best model
best_model_path = get_best_config(save_folder)
best_model = CustomDistilBertModel.load_from_checkpoint(best_model_path)
best_data_module = SemEvalDataModule(num_workers=4, config=model.config)

best_model.config, best_model_version

({'dataset_path': '../../data/raw/SemEval/',
  'learning_rate': 0.0013774978663536918,
  'batch_size': 32,
  'epochs': 20,
  'num_trials': 50,
  'vocab_size': 30522,
  'target_encoding': {0: 'Atheism',
   1: 'Climate Change is a Real Concern',
   2: 'Feminist Movement',
   3: 'Hillary Clinton',
   4: 'Legalization of Abortion'},
  'stance_encoding': {0: 'AGAINST', 1: 'FAVOR', 2: 'NONE', 3: 'UNKNOWN'}},
 'version_22')

In [112]:
# check performance
trainer = pl.Trainer(deterministic=True)
trainer.test(best_model, datamodule=best_data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Testing:  98%|█████████▊| 39/40 [00:35<00:00,  1.21it/s]

Results				 
FAVOR     precision: 0.6468 recall: 0.5000 f-score: 0.5640
AGAINST   precision: 0.7426 recall: 0.7385 f-score: 0.7405
------------
Macro F: 0.6523

Testing: 100%|██████████| 40/40 [00:35<00:00,  1.12it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_epoch_F1': 0.7502001523971558,
 'test_epoch_target_F1': 0.7502001523971558,
 'test_loss': 0.7219581007957458}
--------------------------------------------------------------------------------


[{'test_loss': 0.7219581007957458,
  'test_epoch_target_F1': 0.7502001523971558,
  'test_epoch_F1': 0.7502001523971558}]

### Decide for model

In [None]:
model = model
model = best_model

# Explain model

## Sage

In [115]:
import numpy as np
import scipy as sp
import spacy
import pickle
import json
import sage
import nltk
import string
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import AgglomerativeClustering
from transformers import DistilBertTokenizer
from numpy.random import default_rng
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [146]:
!pip install "git+https://github.com/iancovert/sage.git"  &> /dev/null
!python -m spacy download en &> /dev/null
!python -m spacy download en_core_web_lg &> /dev/null

In [132]:
data_module.stance_encoding, data_module.target_encoding

({0: 'AGAINST', 1: 'FAVOR', 2: 'NONE', 3: 'UNKNOWN'},
 {0: 'Atheism',
  1: 'Climate Change is a Real Concern',
  2: 'Feminist Movement',
  3: 'Hillary Clinton',
  4: 'Legalization of Abortion'})

In [133]:
text, label = data_module.trainset.texts, data_module.trainset.labels
df = pd.DataFrame(data=(text), columns=["Text"])
df["Stance"] = label[0]
df["Target"] = label[1]
df.head()

Unnamed: 0,Text,Stance,Target
0,Don't get it twisted. A major presidential can...,0,3
1,The Dukes of Hazzard has been on tv for 36 yea...,0,3
2,#BlackLivesMatter unless they are pre born bla...,0,4
3,Of mothers advising their daughter's to abort ...,0,4
4,"If you want to empower women, you need to dise...",1,2


In [144]:
df1 = df.copy()
df1["Stance"] = df1["Stance"].transform(lambda x: data_module.stance_encoding[x])
df1["Target"] = df1["Target"].transform(lambda x: data_module.target_encoding[x])
df1.head()

Unnamed: 0,Text,Stance,Target
0,Don't get it twisted. A major presidential can...,AGAINST,Hillary Clinton
1,The Dukes of Hazzard has been on tv for 36 yea...,AGAINST,Hillary Clinton
2,#BlackLivesMatter unless they are pre born bla...,AGAINST,Legalization of Abortion
3,Of mothers advising their daughter's to abort ...,AGAINST,Legalization of Abortion
4,"If you want to empower women, you need to dise...",FAVOR,Feminist Movement


In [147]:
nlp_bigger = spacy.load('en_core_web_lg')
list_of_list = [tokenizer.tokenize(x) for x in df["Text"].values]
flat_list = [item for sublist in list_of_list for item in sublist]
small_vocab = list(set(flat_list))
spacy_vocab = [nlp_bigger(x) for x in small_vocab] 
vocab_2d = [[x] for x in spacy_vocab]

In [148]:
nltk.download('stopwords')
from nltk.corpus import stopwords
stop_words = set(stopwords.words('english'))
filtered_vocab = [w for w in spacy_vocab if not str(w) in stop_words]
filtered_stopwords = [str(w) for w in spacy_vocab if str(w) in stop_words]

#removing punctuation from vocab
print(string.punctuation)
stripped_vocab = [w for w in filtered_vocab if not str(w) in string.punctuation] 
stripped_punctuation = [str(w) for w in filtered_vocab if str(w) in string.punctuation]

#removing digits from vocab
stripped_of_digits_vocab = [w for w in stripped_vocab if not str(w).isdigit()] 
stripped_digits = [str(w) for w in stripped_vocab if str(w).isdigit()]

#removing syllables from vocab
stripped_of_syllables_vocab = [w for w in stripped_of_digits_vocab if not "#" in str(w)] 
stripped_syllables = [str(w) for w in stripped_of_digits_vocab if "#" in str(w)]
 

[nltk_data] Downloading package stopwords to /home/user/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~


In [149]:
stripped_of_syllables_vocab_2d = [[x] for x in stripped_of_syllables_vocab]
print(len(stripped_of_syllables_vocab_2d))

5201


In [150]:
def similarity_func(u, v):

    if u[0].text == "#":
        token_one = u[len(u)-1]
    else:
        token_one = u[0]
    if v[0].text == "#":
        token_two = v[len(v)-1]
    else:
        token_two = v[0]
    return token_one.similarity(token_two)

In [151]:
dists = pdist(np.array(stripped_of_syllables_vocab_2d), similarity_func) 
similarity_matrix = squareform(dists)

  """Entry point for launching an IPython kernel.


In [152]:
NUMBER_OF_CLUSTERS = 13
#important to plug in the "1-similarity_matrix" !!!!
cluster_model = AgglomerativeClustering(affinity='precomputed', n_clusters=NUMBER_OF_CLUSTERS, linkage='complete').fit(1-similarity_matrix)

In [153]:
clustered_vocab = [[] for x in range(NUMBER_OF_CLUSTERS)]
_ = [clustered_vocab[cluster_model.labels_[index]].append(str(x)) for index, x in enumerate(stripped_of_syllables_vocab)]

In [154]:
token_groupings = {str(index) : element for index, element in enumerate(clustered_vocab)} 

token_groupings["stop_words"] = filtered_stopwords
token_groupings["punctuation"] = stripped_punctuation
token_groupings["digits"] = stripped_digits
token_groupings["syllables"] = stripped_syllables

In [161]:
dictionary_of_vocab = {
    "description": "",
    "text": (df["Text"].values,df["Stance"].values),
    "NUMBER_OF_CLUSTERS": NUMBER_OF_CLUSTERS,
    "similarity_matrix": similarity_matrix,
    "cluster_model": cluster_model,
    "clustered_vocab": clustered_vocab,
    "token_groupings": token_groupings
}

# !!!!!!!give a new name always to never overwrite existing data!!!!!!!!!
unique_named_file = 'dictionary_of_vocab_' + '13cluster_cleaned_test_samples32' + '.pickle'

file = open(os.path.join(notebookDir, "vocab/" + unique_named_file), 'wb')
pickle.dump(dictionary_of_vocab, file)
file.close()

In [162]:
dictionary_of_vocab_cluster_cleaned_test_samples = pickle.load(open(os.path.join(notebookDir, "vocab/" + unique_named_file), 'rb'))

_ = [print(key,':',value) for key, value in dictionary_of_vocab_cluster_cleaned_test_samples.items()]

is', 'we', 'other', 'own', 'weren', 'if', 'over', 'up', 'in', 'these', 'same', 'do', 'here', 'few', 'hers', 'don', 'ourselves', 'shouldn', 'then', 'wasn', 'just', 'because', 'myself', 'll', 'yours', 'was', 'be', 'wouldn', 'has', 'that', 'nor', 'can', 'so', 'with', 'now', 'yourself', 'her', 'your', 'any', 'my', 'for', 'such', 'himself', 'their', 'and', 'where', 'while', 'you', 'd', 'ma', 'until', 'me', 'above', 'o', 'why', 'herself', 'been', 't', 'a', 'doesn', 'hasn', 'more', 'm', 'about', 'who', 'isn', 'very'], 'punctuation': ['%', '$', '~', '!', '@', '.', '=', "'", '/', '"', '^', '*', ')', '>', '#', '+', ';', '?', ']', ':', '-', '_', '[', '&', '<', '(', ','], 'digits': ['322', '82', '92', '28', '35', '9', '2010', '68', '42', '26', '143', '205', '2', '300', '0', '4', '130', '45', '58', '25', '37', '1791', '52', '86', '67', '89', '3000', '324', '47', '99', '51', '160', '41', '500', '6', '40', '19', '12', '91', '1500', '140', '44', '61', '14', '16', '32', '297', '36', '57', '100', '30', 

In [208]:
number_of_groups = len(dictionary_of_vocab_cluster_cleaned_test_samples["token_groupings"])

# Set up imputer object
class Imputer_groupings:
    def __init__(self, model, number_of_groups):
        self.model = model
        self.num_groups = number_of_groups
    
    def __call__(self, input_array, S):
        max_length = 0
        reconstructed_array = []
        for index, sentence in enumerate(input_array):
            length_of_sentence = max(x[0] for group in sentence for x in group)
            if length_of_sentence > max_length:
                max_length = length_of_sentence

            original_input_ids = [None] * (length_of_sentence+1)
            for sub_index, group in enumerate(sentence):
                
                if S[index][sub_index]: #put in '[MASK]' elements if needed
                    for x in group:
                        original_input_ids[x[0]]=x[1]
                else: 
                    for x in group:
                        original_input_ids[x[0]]=103 #id of '[MASK]'

            original_input_ids.append(102)
            original_input_ids.insert(0, 101)
            reconstructed_array.append(original_input_ids)

        max_length+=3

        for index, sentence in enumerate(reconstructed_array):
            if len(sentence) < max_length:
                reconstructed_array[index].extend([0 for i in range(max_length-len(sentence))])

        input_ids = np.array(reconstructed_array)

        # "handmade" attention mask -> basically just set everything to one, except '[PAD]'s which are zero
        am = np.ones(input_ids.shape)
        am[input_ids == 0] = 0 #id of '[PAD]'
        tensor_attention_mask = torch.tensor(am)
        tensor_attention_mask = tensor_attention_mask.to(model.device)


        tensor_input_ids = torch.tensor(input_ids)
        tensor_input_ids = tensor_input_ids.to(model.device)


        #predict with model
        outputs = self.model(tensor_input_ids, tensor_attention_mask)
        outputs = outputs.detach().cpu().numpy()
        
        score_most_prob = [max(x) for x in outputs]

        return np.array(score_most_prob)

In [176]:
import textwrap
current_groupings = dictionary_of_vocab_cluster_cleaned_test_samples["token_groupings"]
feature_names = list(current_groupings.keys())
sage_input_all_instances = []
for text_element in dictionary_of_vocab_cluster_cleaned_test_samples["text"][0]:
    sage_input_groupings = [[] for x in range(number_of_groups)]
    input_text = tokenizer.tokenize(text_element)
    input_ids = tokenizer(text_element, add_special_tokens=False)["input_ids"]
    groups_to_text = [next((key for key, value in current_groupings.items() if x in value)) for x in input_text]
    print("\n".join(textwrap.wrap(text_element, 170)))
    print("\n".join(textwrap.wrap("--".join(groups_to_text), 170)))
    print("-"*170)
    positioned_ids = [(i, x) for i, x in enumerate(input_ids)]
    _ = [sage_input_groupings[list(current_groupings.keys()).index(groups_to_text[index])].append(x)  for index, x in enumerate(positioned_ids)]
    sage_input_all_instances.append(sage_input_groupings)

sage_input_all_instances_array = np.array(sage_input_all_instances, dtype=object)

es--syllables--0--stop_words--stop_words--8--stop_words--7--stop_words--stop_words--0--stop_words--punctuation--3--syllables--syllables
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Politics, Religion, Race; we are so busy killing & hurting each other, Earth dries up & we will all die anyway. #drought #SemST
4--punctuation--5--punctuation--1--punctuation--stop_words--stop_words--stop_words--0--6--punctuation--7--stop_words--stop_words--punctuation--5--2--syllables--stop_words
--punctuation--stop_words--stop_words--stop_words--10--0--punctuation--punctuation--6--punctuation--3--syllables--syllables
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
@sassy_gramma Good point but  Our heart starts beating 22 days after conception we are alive at con

In [209]:
imputer = Imputer_groupings(model, number_of_groups)
groupings_estimator = sage.PermutationEstimator(imputer) 

In [166]:
# sage_values = estimator(x, y)
sage_values_groupings = groupings_estimator(sage_input_all_instances_array, np.array(dictionary_of_vocab_cluster_cleaned_test_samples["text"][1]), batch_size=64, verbose=True,thresh=0.10) 

TypeError: forward() takes 2 positional arguments but 3 were given

In [210]:
np.ones((64, imputer.num_groups), dtype=bool)

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]])

In [205]:
sage_input_all_instances_array[:64]

array([[list([(13, 2034)]), list([(3, 2131), (8, 2350), (11, 2478)]),
        list([]), ...,
        list([(1, 1005), (6, 1012), (24, 1012), (25, 1001)]), list([]),
        list([(21, 19170), (22, 13320), (23, 3258), (27, 5244), (28, 2102)])],
       [list([(24, 2409)]), list([(9, 2694), (12, 2086)]),
        list([(1, 16606)]), ...,
        list([(13, 1010), (17, 1005), (22, 1001), (28, 1012), (29, 1001), (33, 1001)]),
        list([(11, 4029)]),
        list([(4, 20715), (5, 4103), (31, 16558), (32, 5657), (35, 5244), (36, 2102)])],
       [list([]), list([]), list([]), ...,
        list([(0, 1001), (16, 1005), (19, 1001), (24, 1001), (27, 1001)]),
        list([]),
        list([(2, 3669), (3, 6961), (4, 18900), (5, 3334), (21, 2964), (22, 3126), (23, 4063), (26, 27179), (29, 5244), (30, 2102)])],
       ...,
       [list([(10, 2467), (16, 2409), (27, 2879)]), list([]), list([]),
        ...,
        list([(3, 1006), (7, 1011), (14, 1007), (20, 1005), (28, 1013), (31, 1012), (32, 10

In [None]:
sign_estimator = sage.SignEstimator(Imputer_groupings())
sage_sign_groupings = sign_estimator(sage_input_all_instances_array, np.array(dictionary_of_vocab_cluster_cleaned_test_samples["text"][1]), batch_size=64,sign_confidence=0.95,narrow_thresh=0.10)


In [None]:
sage_sign_groupings.plot_sign(feature_names) 

In [None]:
sage_values_groupings.plot(feature_names) 