# Imports

In [2]:
#standard libraries
import copy
import filecmp
import gc
import inspect
import logging
import json
import os
import os.path
os.makedirs('results', exist_ok=True)
import pickle
import random
import string
import time
import unicodedata

#third-party libraries
from datasets import load_dataset, load_from_disk
from itertools import combinations
import lime
from lime.lime_text import LimeTextExplainer
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
from scipy.stats import spearmanr
import seaborn as sns
import shap
import sklearn
from sklearn.cluster import DBSCAN
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, roc_auc_score, classification_report,roc_curve, auc
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments,TrainerCallback,AutoModel,BertConfig as BertConfig, AutoModelForSequenceClassification

#const
labels_list = ['chinese', 'haitian', 'roma', 'muslim', 'nigerian', 'indian', 'christian']

In [None]:
class Comparator:
    """Helpers for content-level comparison of files and directory trees."""

    @staticmethod
    def are_files_equal(file1,file2):
        """Return True if *file1* and *file2* have identical contents.

        ``shallow=False`` forces a byte comparison instead of trusting the
        ``os.stat`` signatures.
        """
        return filecmp.cmp(file1,file2,shallow=False)

    @staticmethod
    def are_dir_trees_equal(dir1, dir2):
        """
        Compare two directories recursively. Files in each directory are
        assumed to be equal if their names and contents are equal.

        Names ignored by ``filecmp.dircmp`` by default (e.g. ``.git``,
        ``__pycache__``) are not compared.

        @param dir1: First directory path
        @param dir2: Second directory path

        @return: True if the directory trees are the same and
            there were no errors while accessing the directories or files,
            False otherwise.
        """
        dirs_cmp = filecmp.dircmp(dir1, dir2)
        # common_funny holds names present on both sides whose type differs
        # (a file in one tree, a directory in the other) or that could not be
        # stat-ed. Those names are in neither common_files nor common_dirs,
        # so without this check they were silently treated as equal.
        if (dirs_cmp.left_only or dirs_cmp.right_only
                or dirs_cmp.funny_files or dirs_cmp.common_funny):
            return False
        (_, mismatch, errors) = filecmp.cmpfiles(
            dir1, dir2, dirs_cmp.common_files, shallow=False)
        if mismatch or errors:
            return False
        return all(
            Comparator.are_dir_trees_equal(os.path.join(dir1, sub), os.path.join(dir2, sub))
            for sub in dirs_cmp.common_dirs
        )
class Loader:
    """Resolve on-disk locations of datasets, fine-tuned models and embeddings."""

    # Accepted values for the selector arguments shared by every method.
    _DATASET_NAMES = ('asylex-outcome','asylex-norp','sentiment1','sentiment2')
    _MODEL_NAMES = ('bert','roberta')
    _STRATEGIES = ('first','last','rand','cas','')

    @staticmethod
    def _validate(dataset_name,model_name,paragraph_selection_strategy):
        """Raise ValueError if any selector is not one of the supported values."""
        if dataset_name not in Loader._DATASET_NAMES:
            raise ValueError('dataset name not found')
        if model_name not in Loader._MODEL_NAMES:
            raise ValueError('model name not found')
        if paragraph_selection_strategy not in Loader._STRATEGIES:
            raise ValueError('long_text technique not found')

    @staticmethod
    def import_dataset_and_model(dataset_name,model_name,paragraph_selection_strategy):
        """Load the train/test/validation splits plus tokenizer and classifier.

        Returns ``(train_set, test_set, validation_set, tokenizer, model)``.
        """
        dataset_path, model_path, num_labels = Loader.import_paths_and_nlabels(dataset_name,model_name,paragraph_selection_strategy)
        train_set = load_from_disk(dataset_path+'train_set')
        test_set = load_from_disk(dataset_path+'test_set')
        validation_set = load_from_disk(dataset_path+'validation_set')
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path,num_labels = num_labels)
        return train_set,test_set,validation_set,tokenizer,model

    @staticmethod
    def import_paths_and_nlabels(dataset_name,model_name,paragraph_selection_strategy):
        """Return ``(dataset_path_prefix, model_path, num_labels)``.

        The dataset path is a prefix: callers append ``train_set``,
        ``test_set`` or ``validation_set``. ``asylex-norp`` has 7 labels,
        every other dataset has 2.

        Raises ValueError for an unknown dataset, model or strategy.
        """
        Loader._validate(dataset_name,model_name,paragraph_selection_strategy)

        base_dataset_path = 'datasets/' + model_name + '/'
        base_model_path = 'Models/'  + model_name + '/'

        dataset_filename = {'asylex-norp':'norp_','asylex-outcome':'outcome_','sentiment1':'sentiment1','sentiment2':'sentiment2'}
        model_filename = {'bert' : 'BERT512-', 'roberta' : 'RoBERTa512-'}
        model_filename2 = {'asylex-norp':'norp_','asylex-outcome':'out_','sentiment1':'sentiment1','sentiment2':'sentiment2'}

        num_labels = 7 if dataset_name == 'asylex-norp' else 2

        dataset_path = base_dataset_path+ dataset_name+ '/' + dataset_filename[dataset_name] + paragraph_selection_strategy + '_'
        model_path = base_model_path + dataset_name + '/' + model_filename[model_name] + model_filename2[dataset_name] +paragraph_selection_strategy
        return dataset_path, model_path, num_labels

    @staticmethod
    def create_embeddings_path(dataset_name,model_name,paragraph_selection_strategy):
        """Return the directory (with trailing '/') that stores cached embeddings.

        The strategy sub-directory is only added when the strategy is non-empty.
        Raises ValueError for an unknown dataset, model or strategy.
        """
        Loader._validate(dataset_name,model_name,paragraph_selection_strategy)
        embeddings_path = 'embeddings/' + model_name + '/' + dataset_name + '/'
        if paragraph_selection_strategy != '':
            embeddings_path += paragraph_selection_strategy + '/'
        return embeddings_path

class method_list_data():
    """One DataObject of per-sentence results per XAI method, looked up by name."""

    # Column layout shared by every per-method DataObject; create_comparison()
    # unpacks rows in exactly this order.
    COLUMNS = ('counts','ground_truth_list','input_ids_list','p_org_list','pred_class_list','sentence','sentences_list','tokens_lists','undecided_threshold','visual_explanations_lists')

    def __init__(self,method_names_list):
        self.method_data_list = [DataObject(list(self.COLUMNS),method_name) for method_name in method_names_list]

    def get_method_data(self,method_name):
        """Return the DataObject whose ``name`` equals *method_name*.

        Raises KeyError if no such method was registered. Previously a bare
        StopIteration escaped, which is easy to misread or swallow.
        """
        for method_data in self.method_data_list:
            if method_data.name == method_name:
                return method_data
        raise KeyError(f'no data stored for method {method_name!r}')

    def get_counts(self):
        """Return the ``counts`` column (sentence indices) of the first method's data."""
        return self.method_data_list[0].df['counts'].tolist()

class DataObject:
    """Thin wrapper around a pandas DataFrame with a fixed column list.

    Rows are appended one at a time with put(). The object can be persisted
    to and restored from CSV or pickle with write() / read().
    """
    def __init__(self,columns,name=''):
        self.columns = columns
        self.df = pd.DataFrame(columns=columns)
        self.name = name

    def put(self,row):
        """Append *row*, a sequence with exactly one value per column."""
        if len(row) != len(self.columns):
            raise ValueError('Row length doesnt match with column number')
        self.df.loc[len(self.df)] = dict(zip(self.columns,row))

    def write(self,path,format):
        """Save to ``path + '.csv'`` (frame only) or ``path + '.pkl'`` (whole object)."""
        if format=='csv':
            self.df.to_csv(path+'.csv',index=False)
        elif format=='pickle':
            with open(path+'.pkl', 'wb') as f:
                pickle.dump(self, f)
        else:
            raise ValueError('non supported format')

    def get(self,column,index):
        """Return the value of *column* at positional row *index*."""
        return self.df[column].iloc[index]

    def get_unpacked_row(self,index):
        """Return the row labelled *index* as a tuple ordered like ``self.columns``."""
        return tuple(self.df.loc[index,self.columns])

    def __repr__(self):
        return repr(self.df)

    def read(self,path,format):
        """Load from ``path + '.csv'`` or ``path + '.pkl'``, replacing the current frame.

        The column list is re-synchronised with the loaded frame, so an object
        built with placeholder columns (e.g. ``DataObject(['generic'])``) can
        afterwards be used with put() / get_unpacked_row().
        """
        if format=='csv':
            self.df = pd.read_csv(path+'.csv')
        elif format == 'pickle':
            # NOTE: pickle.load can execute arbitrary code; only read trusted files.
            with open(path+'.pkl','rb') as f:
                loaded_object = pickle.load(f)
            self.df = loaded_object.df
            self.name = loaded_object.name
        else:
            raise ValueError('not supported format')
        self.columns = list(self.df.columns)

class embeddings_manager:
    """Compute and cache per-example model outputs (prediction, probability, hidden states)."""

    @staticmethod
    def _calculate_embeddings(dataset,tokenizer,model,n_embeddings):
        """Run *model* on the first *n_embeddings* rows of *dataset*, one row at a time.

        Returns a DataObject with one row per example holding:
        - the predicted class;
        - the ``labels`` value;
        - the max softmax probability;
        - ``outputs.hidden_states[12]`` with the batch dimension squeezed out;
        - the decoded input text.

        *n_embeddings* is clamped to ``len(dataset)``.

        NOTE(review): index 12 is presumably the last layer of a 12-layer
        encoder; confirm if other model sizes are added.
        """
        data = DataObject(['predicted_classes','ground_truth','probs','embeddings','sentences'])
        # Read each column once: ``dataset[column]`` materialises the whole
        # column, which the previous per-row access repeated on every iteration.
        input_ids_column = dataset['input_ids']
        attention_mask_column = dataset['attention_mask']
        labels_column = dataset['labels']
        for i in range(min(n_embeddings, len(dataset))):
            # token_type_ids deliberately not passed to the model.
            inputs = {
                'input_ids': torch.as_tensor(input_ids_column[i]),
                'attention_mask': torch.as_tensor(attention_mask_column[i]),
            }
            sentence = tokenizer.decode(input_ids_column[i], skip_special_tokens=True)
            with torch.no_grad():
                # A single example is 1-D; add the batch dimension the model expects.
                for key in ('input_ids', 'attention_mask'):
                    if inputs[key].dim() == 1:
                        inputs[key] = inputs[key].unsqueeze(0)
                outputs = model(**inputs,output_hidden_states = True)
            print(outputs.logits)
            probabilities = F.softmax(outputs.logits,dim = 1).tolist()[0]
            data.put([torch.argmax(outputs.logits, dim=1).item(),labels_column[i],max(probabilities),outputs.hidden_states[12].squeeze(0),sentence])
        return data

    @staticmethod
    def generate_embeddings(dataset_name,model_name,paragraph_selection_strategy,n_embeddings):
        """Embed the first *n_embeddings* test examples and pickle them.

        Writes ``<Loader.create_embeddings_path(...)>embeddings.pkl``, creating
        the directory if it does not exist yet.
        """
        _, test_set, _, tokenizer, model = Loader.import_dataset_and_model(dataset_name,model_name,paragraph_selection_strategy)
        embs = embeddings_manager._calculate_embeddings(test_set,tokenizer,model,n_embeddings)
        output_dir = Loader.create_embeddings_path(dataset_name,model_name,paragraph_selection_strategy)
        os.makedirs(output_dir, exist_ok=True)
        embs.write(output_dir+'embeddings','pickle')

    @staticmethod
    def generate_all_embeddings():
        """Generate embeddings for every configuration enumerated by repeat_for_all().

        Up to 500 test examples per configuration. _calculate_embeddings()
        clamps the count to the test-set size, so every example of a smaller
        test set is included. The old ``len(dataset) - 1`` dropped the last one.
        """
        def todo_function(dataset_name,model_name,paragraph_selection_strategy):
            embeddings_manager.generate_embeddings(dataset_name,model_name,paragraph_selection_strategy,500)
        repeat_for_all(todo_function)


In [5]:
class AnalysisCreator:
    """Run XAI methods over cached embeddings and write per-run reports.

    Each call to create_scores() opens a new ``run_<n>`` directory under
    ``results/<analysis_name>/``. It holds the pipeline log, plots, pickled
    scores, the analytics CSV and a plain-text resume.

    NOTE(review): reads the notebook-level globals ``dataset_name``,
    ``model_name``, ``paragraph_selection_strategy``, ``dataset``,
    ``tokenizer`` and (optionally) ``my_lime``. It also publishes output paths
    as globals for the module-level helpers.
    """
    def __init__(self,analysis_name):
        self.output_dir = 'results/'+analysis_name
        os.makedirs(self.output_dir, exist_ok=True)
        # Restarts at 0 for every new instance; a re-created AnalysisCreator
        # reuses (and overwrites) run_1, run_2, ...
        self.run_number = 0

    def _setup_logging(self,run_dir):
        """Point this module's logger at ``<run_dir>/pipeline.log`` (INFO level)."""
        log_file = os.path.join(run_dir, "pipeline.log")
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # Drop handlers left over from previous runs (useful in Jupyter or
        # tests). Close them first: clearing alone leaves their log files open.
        for old_handler in list(self.logger.handlers):
            old_handler.close()
        self.logger.handlers.clear()
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
        self.logger.addHandler(file_handler)

    def _create_new_run(self,method_list,n_sentences,undecided_threshold):
        """Create the next run directory, set the output-path globals and start resume.txt."""
        self.run_number += 1
        run_dir = os.path.join(self.output_dir,'run_'+ str(self.run_number))
        os.makedirs(run_dir, exist_ok=True)
        self._setup_logging(run_dir)

        # printHTML, to_resume and save/load_scores read these paths as globals.
        global output_plot_path
        output_plot_path = os.path.join(run_dir,'plots')
        os.makedirs(output_plot_path, exist_ok=True)

        os.makedirs(output_plot_path+'/conf', exist_ok=True)
        os.makedirs(output_plot_path+'/hist', exist_ok=True)
        os.makedirs(output_plot_path+'/roc', exist_ok=True)

        global data_path
        data_path = os.path.join(run_dir,'data.pkl')

        global analytics_csv_path
        analytics_csv_path = os.path.join(run_dir,'analytics.csv')

        global resume_path
        resume_path = os.path.join(run_dir,'resume.txt')
        with open(resume_path,'w') as f:
            f.write('RESUME\n'+model_name+' '+dataset_name+' '+paragraph_selection_strategy+'\n')
            f.write('run number '+str(self.run_number)+'\n')
            f.write('n_sentences: '+str(n_sentences)+' undecided_threshold: '+str(undecided_threshold)+'\n')
            for method in method_list:
                f.write(method.name +' '+ method.input_args +'\n')

    def load_embeddings(self,n_embeddings):
        """Load the first *n_embeddings* cached rows into per-column lists on self."""
        embeddings_data = DataObject(['generic'])
        embeddings_data.read(Loader.create_embeddings_path(dataset_name,model_name,paragraph_selection_strategy)+'embeddings','pickle')
        embeddings_data = embeddings_data.df.head(n_embeddings)
        self.predicted_classes = embeddings_data['predicted_classes'].tolist()
        self.ground_truth = embeddings_data['ground_truth'].tolist()
        self.probs = embeddings_data['probs'].tolist()
        self.embeddings = embeddings_data['embeddings'].tolist()
        self.sentences = embeddings_data['sentences'].tolist()
        if len(self.predicted_classes) != n_embeddings:
            print('actual embeddings number:'+str(len(self.predicted_classes)))

    def create_scores(self,method_list,n_sentences='max',undecided_threshold=0.4):
        """Run every method on the first *n_sentences* sentences and pickle the results.

        ``'max'`` means all loaded embeddings. Sentences rejected by to_skip()
        or is_sentence_included() are skipped. One row per (sentence, method)
        is stored in ``self.data`` and saved to data.pkl.
        """
        if n_sentences == 'max':
            n_sentences = len(self.predicted_classes)
        self._create_new_run(method_list,n_sentences,undecided_threshold)
        self.data = method_list_data([method.name for method in method_list])
        for count in range(n_sentences):
            if self.to_skip(count):
                continue
            if not is_sentence_included(dataset['input_ids'][count],clean_input_sentences(dataset['all_sentences'][count]),tokenizer.tokenize(self.sentences[count])):
                continue
            for my_method in method_list:
                sentence = self.sentences[count]
                embedding = self.embeddings[count]
                p_org = self.probs[count]
                pred_class = self.predicted_classes[count]
                tokens = tokenizer.tokenize(sentence)
                ground_truth = self.ground_truth[count]
                visual_explanation = my_method.method_pipeline(sentence,embedding,tokens,p_org,pred_class)
                if len(visual_explanation) != len(tokens):
                    print(my_method.name)
                    print(visual_explanation)
                    print(tokens)
                    print('ERROR IN LENGTH')
                self.data.get_method_data(my_method.name).put([count,ground_truth,dataset['input_ids'][count],p_org,pred_class,sentence,clean_input_sentences(dataset['all_sentences'][count]),tokens,undecided_threshold,visual_explanation])
        self.save_scores()

    def to_skip(self,count):
        """Return True for sentences that must not be explained.

        Skips two hard-coded (configuration, index) pairs, consecutive
        duplicates, and sentences starting with a sub-word marker
        (``##`` or ``Ġ``).
        """
        if model_name =='roberta' and dataset_name=='asylex-outcome' and paragraph_selection_strategy == 'last' and count==20:
            return True
        elif model_name =='bert' and dataset_name=='asylex-norp' and paragraph_selection_strategy == 'first' and count==92:
            return True
        if count != 0 and self.sentences[count] == self.sentences[count-1]: #avoid printing duplicates
            return True
        if self.sentences[count].startswith('##') or self.sentences[count].startswith('Ġ'):
            print('started with ##, so skipped')
            return True
        return False

    def create_comparison(self,method_list):
        """Render stored scores as one HTML heatmap page and write the analytics.

        Explanations are grouped by classification outcome (see
        classification_analysis). Each method's statistics then go to
        resume.txt and the concatenated analytics frames to analytics.csv.
        """
        to_print = {}
        to_print['true_positive'] = 'Positive: '
        to_print['false_positive'] = 'Negative '
        to_print['true_negative'] = 'Negative '
        to_print['false_negative'] = 'Positive '
        to_print['undecided_positive'] = 'Positive '
        to_print['undecided_negative'] = 'Negative '

        for count in range(len(self.data.get_counts())):
            print(count)
            for my_method in method_list:
                _,ground_truth,input_ids,p_org,pred_class,sentence,all_sentences,tokens,undecided_threshold,visual_explanation = self.data.get_method_data(my_method.name).get_unpacked_row(count)
                plot_category = classification_analysis(pred_class,p_org,ground_truth,undecided_threshold)

                indexes_list = my_method.run_analytics(visual_explanation,input_ids,all_sentences,tokens,count,my_method.name)
                print(indexes_list)
                if indexes_list is None:
                    # No analytics configured: -1 never matches a token position.
                    indexes_list = [[-1,-1],[-1,-1]]

                if dataset_name == 'asylex-norp':
                    outcome_info =  'g_truth: '+labels_list[ground_truth] + 'prediction:' + labels_list[pred_class]
                else:
                    outcome_info = ''

                to_print[plot_category] += '<br>' + create_heatmap(tokens,visual_explanation,pred_class,p_org,indexes_list,undecided_threshold, my_method.clipped_heatmap) +'  '+ my_method.name + outcome_info
            to_print[plot_category] += '<br>'

        printHTML('Correctly classified:<br>'+to_print['true_positive'] + '<br><br>' +to_print['true_negative'] + '<br><br>' +
                'Undecided: <br>'+to_print['undecided_positive'] +'<br><br>'+to_print['undecided_negative'] +'<br><br>'+
                'Misclassified: <br>'+to_print['false_negative'] +'<br><br>' + to_print['false_positive'],'_'+dataset_name+'_'+model_name+'_'+paragraph_selection_strategy)

        # Optional rank-correlation reference: a LIME instance named ``my_lime``
        # defined at notebook level. It is looked up through globals() so that
        # its absence skips the comparison instead of raising NameError before
        # analytics.csv is written.
        reference_lime = globals().get('my_lime')
        analytics_dfs = []
        for my_method in method_list:
            to_resume(my_method.name + ' RESUME')
            my_method.stat_resume()
            to_resume(len(my_method.tokens_lists))
            to_resume(len(my_method.visual_explanations_lists))
            my_method.analytics.stats.df['method_name'] = my_method.name
            analytics_dfs.append(my_method.analytics.stats.df)

            if reference_lime is not None and reference_lime.visual_explanations_lists != []:
                to_resume(my_method.name + ' rank corr')
                rank_corr_scores = []
                p_values = []
                for index,v_expl in enumerate(reference_lime.visual_explanations_lists):
                    corr, p_value = spearmanr(v_expl,my_method.visual_explanations_lists[index])
                    rank_corr_scores.append(corr)
                    p_values.append(p_value)
                to_resume(sum(rank_corr_scores)/len(rank_corr_scores))
                to_resume(sum(p_values)/len(p_values))
                to_resume()
        analytics_dfs = pd.concat(analytics_dfs,ignore_index=True)
        analytics_dfs.to_csv(analytics_csv_path,index=False)

    def save_scores(self):
        """Pickle ``self.data`` to the current run's data.pkl."""
        with open(data_path,'wb') as f:
            pickle.dump(self.data,f)

    def load_scores(self):
        """Restore ``self.data`` from the current run's data.pkl (trusted file only)."""
        with open(data_path,'rb') as f:
            self.data = pickle.load(f)


# XAI methods

In [6]:
class Generic_xai_method():
    """Base class for the filter-masking XAI methods.

    method_pipeline() calls two hooks that subclasses must provide:
    ``generate_filters(sentence, embedding)`` and
    ``filters_to_probabilities(filters, sentence, pred_class)``.
    """

    # Default sigma for calculate_SD. Subclasses may override it as a class or
    # instance attribute (before or after calling super().__init__()).
    sdparam = 1

    def __init__(self,clipped_heatmap,uniqueness,analytics=False):
        self.name = 'generic_method'
        self.clipped_heatmap = clipped_heatmap
        if uniqueness == False:
            # Instance attribute shadows the method: every filter gets weight 1.
            self.calculate_uniqueness = self.no_uniqueness
        self.analytics = analytics
        self.tokens_lists = []
        self.visual_explanations_lists = []
        self.input_args = get_input_args()

    #utils
    def run_analytics(self,similarities,text_ids,sentences,tokens,count,name):
        """Delegate to the analytics object; returns None when analytics is disabled."""
        if self.analytics==False:
            return None
        return self.analytics.run_analysis(similarities,text_ids,sentences,tokens,count,name)

    def stat_resume(self):
        """Let the analytics object write its summary (no-op when disabled)."""
        if self.analytics!=False:
            self.analytics.stat_resume()

    def calculate_uniqueness(self,probabilities):
        """Return, per filter, the sum of |p_i - p_j| over all filters (``[1]`` for one filter)."""
        probabilities = np.array(probabilities)
        if len(probabilities) == 1:
            return np.array([1])
        return np.sum(np.abs(probabilities[:, None] - probabilities), axis=1)

    def no_uniqueness(self,probabilities):
        """Uniform weight of 1 for every filter."""
        return np.ones(len(probabilities))

    def calculate_SD(self,probabilities,p_org,sigma):
        """Similarity-difference weight ``exp(-|p_org - p| / (2 * sigma**2))`` per filter."""
        probabilities = np.array(probabilities)
        # The denominator must be parenthesised. ``/2*(sigma**2)`` multiplied
        # by sigma**2, so a larger sigma made the decay sharper, not wider.
        return np.exp(-np.abs(p_org - probabilities)/(2*sigma**2))

    def method_pipeline(self,sentence,embedding,tokens,p_org,pred_class):
        """Compute one explanation score per token of *sentence*.

        Each filter is a set of token indices. Its row starts at 1, is zeroed
        on those indices, and is scaled by uniqueness * similarity-difference.
        The score of a token is the mean over all filter rows. The tokens and
        scores are also appended to ``tokens_lists`` /
        ``visual_explanations_lists``.
        """
        filters = self.generate_filters(sentence,embedding)
        probabilities = self.filters_to_probabilities(filters,sentence,pred_class)
        visual_explanation = np.ones((len(filters),len(tokens)))
        uniqueness = self.calculate_uniqueness(probabilities)
        similarity_difference = self.calculate_SD(probabilities,p_org,self.sdparam)
        weights = (uniqueness*similarity_difference)
        for index,masked_tokens in enumerate(filters):
            visual_explanation[index,masked_tokens] = 0
            visual_explanation[index] = visual_explanation[index]*weights[index]
        visual_explanation = np.mean(visual_explanation,axis = 0)

        self.tokens_lists.append(tokens)
        self.visual_explanations_lists.append(visual_explanation)
        return visual_explanation


In [7]:
#utils

def create_heatmap(tokens, similarities,pred_class,p_org,underlined_text_indices,undecided_threshold,clipped=True):
    """Render *tokens* as coloured HTML boxes, one per token.

    Parameters
    ----------
    tokens : sequence of str
        Sub-word tokens; a leading WordPiece ``##`` or RoBERTa ``Ġ`` is
        stripped for display, and the text is HTML-escaped.
    similarities : array-like of float
        One score per token, mapped to a colour via a matplotlib colormap.
    pred_class, p_org, undecided_threshold
        Choose the colormap: ``Greys`` when ``p_org <= undecided_threshold``,
        otherwise ``Blues`` for class 0 and ``Greens`` for any other class.
    underlined_text_indices : pair of containers
        A yellow "start" box is emitted before tokens whose position is in
        ``[0]`` and an "end" box after tokens whose position is in ``[1]``.
    clipped : bool
        If true, the colour scale starts at the smallest positive score and
        lower scores are clipped. If no score is positive, the full range is
        used instead of raising.

    Returns the HTML fragment, prefixed with the rounded probability.
    """
    import html  # escape token text so "<", ">" or "&" cannot break the page markup

    similarities = np.asarray(similarities)
    positive = similarities[similarities > 0]
    if similarities.size == 0:
        vmin = vmax = 0.0
    elif clipped and positive.size:
        vmin, vmax = positive.min(), similarities.max()
    else:
        # Full range; also the fallback when clipping has no positive score
        # (np.min of that empty selection used to raise ValueError).
        vmin, vmax = similarities.min(), similarities.max()
    norm = plt.Normalize(vmin=vmin, vmax=vmax, clip=bool(clipped))

    if p_org>undecided_threshold:
        cmap = plt.get_cmap('Blues') if pred_class == 0 else plt.get_cmap('Greens')
    else:
        cmap = plt.get_cmap('Greys')
    prediction = 'Prob:'+str(round(p_org,2))+ '  -'
    pieces = [f"<span style='color: black;'>{prediction}</span> "]

    for i, (token, similarity) in enumerate(zip(tokens, similarities)):
        color = cmap(norm(similarity))  # RGBA colour for this score
        rgb = (int(color[0] * 220), int(color[1] * 220), int(color[2] * 220))  # scale to 0-220 instead of 0-255
        hex_color = '#{:02x}{:02x}{:02x}'.format(*rgb)

        if i in underlined_text_indices[0]:
            pieces.append(f'<div class="box" style="background-color: yellow; border: 1px solid black;">start</div>\n')
        if token.startswith('##'):
            shown = token[2:]
        elif token.startswith('Ġ'):
            shown = token[1:]
        else:
            shown = token
        pieces.append(f'<div class="box" style="background-color: {hex_color};">{html.escape(shown, quote=False)}</div>\n')
        if i in underlined_text_indices[1]:
            pieces.append(f'<div class="box" style="background-color: yellow; border: 1px solid black;">end</div>\n')

    return ''.join(pieces)

def printHTML(highlighted_texts,epoch):
    """Wrap *highlighted_texts* in a standalone HTML page and save it.

    The file is ``<output_plot_path>/heatmap<epoch>.html`` (UTF-8).
    ``output_plot_path`` is the global set by AnalysisCreator._create_new_run();
    *epoch* only serves as a filename suffix.
    """
    page = ('''
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
        body {
            font-size: 16px; /* Riduce la dimensione generale del testo */
            line-height: 40px; /* Riduce l'altezza delle righe */
        }
            .box {
                display: inline-block;
                margin: 1px;
                padding: 4px;
                border-radius: 3px;
                color: white;
                font-family: Arial, sans_serif;
                font-size: 16px;
                line-height: normal;
            }

        </style>
    </head>
    <body>''' + highlighted_texts + '''

    </body>
    </html>
    ''')
    destination = f"{output_plot_path}/heatmap{epoch}.html"
    with open(destination, "w", encoding="utf-8") as html_file:
        html_file.write(page)

def load_data(data_folder, embedding_file):
    """Load a dict saved with torch.save and return ``(keys, values)`` as two lists.

    Tensors are mapped to CPU. NOTE: torch.load unpickles the file, so only
    load trusted embedding files.
    """
    stored = torch.load(os.path.join(data_folder, embedding_file), map_location='cpu')
    return list(stored.keys()), list(stored.values())

def predict(sentence,tokenizer,model):
    """Classify *sentence*; return ``(max softmax probability, predicted class index)``.

    Input is truncated and padded to the tokenizer's max length. The raw
    logits are printed for inspection.
    """
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
    with torch.no_grad():
        logits = model(**encoded).logits
    print(logits)
    class_probs = F.softmax(logits, dim=1).tolist()[0]
    return max(class_probs), torch.argmax(logits, dim=1).item()

def classification_analysis(pred_class,prob,true_class,undecided_threshold):
    """Bucket a prediction for the report.

    Class 1 is "positive" and any other class "negative". A prediction whose
    probability does not exceed *undecided_threshold* is "undecided_*",
    labelled by the true class only. Otherwise it is true/false
    positive/negative.
    """
    actual_positive = true_class == 1
    if not prob > undecided_threshold:
        return 'undecided_positive' if actual_positive else 'undecided_negative'
    predicted_positive = pred_class == 1
    if actual_positive:
        return 'true_positive' if predicted_positive else 'false_negative'
    return 'false_positive' if predicted_positive else 'true_negative'
def classify_embedding(embedding):
    """Return classifier logits for *embedding* using the global model's pooler and head.

    The pooler's dense layer and activation are applied to *embedding*, then
    ``model.classifier``. Everything runs under ``torch.no_grad()``;
    previously the pooler ran outside it and built an unused autograd graph.

    NOTE(review): assumes a BERT-style global ``model`` exposing
    ``model.bert.pooler``. It presumably fails for the RoBERTa models also
    used in this file. Also assumes *embedding* has the hidden size as its
    last dimension — confirm.
    """
    with torch.no_grad():
        pooler_output = model.bert.pooler.activation(model.bert.pooler.dense(embedding))
        logits = model.classifier(pooler_output)
    return logits

def find_sublist_in_list(my_list,my_sublist):
    """Locate the first contiguous occurrence of *my_sublist* inside *my_list*.

    Elements are compared by their ``str()`` form, as before. Returns
    ``[start, end]`` with *end* exclusive, or ``[-1, -1]`` when absent. An
    empty *my_sublist* matches at ``[0, 0]``.
    """
    haystack = [str(item) for item in my_list]
    needle = [str(item) for item in my_sublist]
    width = len(needle)
    # Compare whole elements. The former comma-joined string search could
    # match inside an element (e.g. [1, 2] was "found" in [11, 2]).
    for start in range(len(haystack) - width + 1):
        if haystack[start:start + width] == needle:
            return [start, start + width]
    return [-1, -1]

def find_sublists_in_list(my_list, my_list_of_sublists):
    """Return find_sublist_in_list(my_list, s) for every s, in input order."""
    return [find_sublist_in_list(my_list, candidate) for candidate in my_list_of_sublists]

def clean_input_sentences(sentences):
    """De-duplicate *sentences* and drop any sentence contained in a longer kept one.

    A non-list argument is wrapped in a one-element list. The survivors keep
    their original (first-occurrence) order.
    """
    if type(sentences) is not list:
        sentences = [sentences]
    unique = list(dict.fromkeys(sentences))  # order-preserving de-duplication
    kept = []
    # Longest first, so a sentence is tested against every longer survivor.
    for candidate in sorted(unique, key=len, reverse=True):
        if all(candidate not in longer for longer in kept):
            kept.append(candidate)
    return [s for s in unique if s in kept]

def clean_roberta_tokens(token):
    """Fold *token* to plain ASCII.

    The token is NFKD-decomposed and every non-ASCII code point is dropped,
    e.g. accents are removed and ``Ġ`` becomes ``G``.
    """
    decomposed = unicodedata.normalize('NFKD', token)
    return ''.join(ch for ch in decomposed if ch.isascii())

def get_input_args():
    """Describe the calling function's named arguments as ``"a=1, b='x'\\n"``.

    ``self`` is omitted. Values of type int, float, str, bool or None are
    shown with repr(); anything else as ``<TypeName object>``. Only named
    parameters are listed (no *args/**kwargs).
    """
    caller = inspect.currentframe().f_back  # frame of the function that called us
    arg_names, _, _, arg_values = inspect.getargvalues(caller)
    plain_types = (int, float, str, bool, type(None))
    rendered = []
    for arg_name in arg_names:
        if arg_name == 'self':
            continue
        value = arg_values[arg_name]
        text = repr(value) if isinstance(value, plain_types) else f"<{type(value).__name__} object>"
        rendered.append(f"{arg_name}={text}")
    return ', '.join(rendered) + '\n'

def to_resume(to_write=''):
    """Append ``str(to_write)`` and a newline to the current run's resume file.

    ``resume_path`` is the global set by AnalysisCreator._create_new_run().
    """
    with open(resume_path, 'a') as resume_file:
        print(to_write, file=resume_file)

def is_sentence_included(text_ids,sentences,tokens):
    """Return True if any of *sentences*, once tokenised, appears verbatim in the input window.

    The window is ``text_ids[1:len(tokens)+1]``: the first id is skipped and
    only the positions covered by *tokens* are kept. Each sentence is
    tokenised with the global ``tokenizer`` and its first and last ids are
    dropped. For RoBERTa, blank sentences are ignored and a leading space is
    prepended before tokenising. The scan stops at the first match (the
    previous version kept tokenising and searching after a hit, and carried
    dead ``indexes`` / loop-index updates).
    """
    if type(text_ids) != list: #if it is pytorch, it should be converted to list
        text_ids = text_ids.tolist()
    if type(sentences) == str:
        sentences = [sentences]
    if model_name == 'roberta':
        sentences = [elem for elem in sentences if elem.strip()]

    window = text_ids[1:len(tokens)+1]
    for sentence in sentences:
        if model_name == 'roberta':
            needle = tokenizer(' '+sentence)['input_ids'][1:-1]
        else:
            needle = tokenizer(sentence)['input_ids'][1:-1]
        width = len(needle)
        if any(window[start:start+width] == needle for start in range(len(window)-width+1)):
            return True
    return False

def repeat_for_all(todo_function):
    """Call ``todo_function(dataset_name, model_name, strategy)`` for every configuration.

    The two AsyLex datasets run with each strategy (first, last, cas, rand).
    The sentiment datasets use the empty strategy ''. Each configuration name
    is printed before its call.
    """
    long_text_datasets = ('asylex-outcome', 'asylex-norp')
    for dataset_name in ['asylex-norp','asylex-outcome','sentiment1','sentiment2']:
        strategies = ['first','last','cas','rand'] if dataset_name in long_text_datasets else ['']
        for model_name in ['bert','roberta']:
            for paragraph_selection_strategy in strategies:
                print(dataset_name + '_' + model_name +'_'+paragraph_selection_strategy)
                todo_function(dataset_name,model_name,paragraph_selection_strategy)

In [8]:
class TokenizerWrapper:
    """Callable that tokenises text and strips the WordPiece ``##`` continuation prefix."""
    def __init__(self,tokenizer):
        self.tokenizer = tokenizer
    def __call__(self,text):
        return [piece[2:] if piece.startswith('##') else piece for piece in self.tokenizer.tokenize(text)]

class LIME(Generic_xai_method):
    """LIME explanations computed over the model's own sub-word tokens.

    Uses the notebook-level globals ``tokenizer``, ``model`` and ``dataset_name``.
    """
    def __init__(self, class_names, num_features, num_samples,clipped_heatmap = False, uniqueness = True, analytics = False):
        super().__init__(clipped_heatmap,uniqueness,analytics)
        self.name = 'LIME'
        # class_names is only recorded in input_args; method_pipeline derives
        # the class list from dataset_name. num_features is reset per sentence.
        self.num_features = num_features
        self.num_samples = num_samples
        self.input_args = self.input_args + get_input_args() + (self.analytics.input_args if self.analytics != False else '')
        self.tokenizer_for_lime = TokenizerWrapper(tokenizer)
        # Most recent lime Explanation; set by method_pipeline(), read by create_heatmap().
        self.exp = None

    def predictor(self, texts, batch_size=64):
        """Return softmax class probabilities for *texts*, computed in batches.

        ``self.device`` is set by method_pipeline(), which must run first.
        """
        probas = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(self.device)
            with torch.no_grad():
                outputs = model(**inputs)
                probas.extend(F.softmax(outputs.logits, dim=-1).cpu().numpy())
            del inputs, outputs
        return np.array(probas)

    def create_heatmap(self,input,pred_class,p_org,undecided_threshold):
        """Render the most recent explanation as an HTML heatmap fragment.

        *input* is unused and kept for interface compatibility. The former
        body referenced ``self.shap_values`` and an undefined ``vals`` and
        could never run.
        """
        last_exp = getattr(self, 'exp', None)
        if last_exp is not None:
            # explain_instance() only computed pred_class; as_list() would
            # otherwise default to label 1, which may not exist.
            token_weights = last_exp.as_list(label=pred_class)
            print(token_weights)
            for token, weight in token_weights:
                print(f"{token}: {weight:.4f}")
        # Module-level create_heatmap(); -1 indices add no start/end markers.
        return create_heatmap(self.tokens_lists[-1],self.visual_explanations_lists[-1],pred_class,p_org,[[-1,-1],[-1,-1]],undecided_threshold,self.clipped_heatmap)

    def method_pipeline(self,sentence,embedding,tokens,p_org,pred_class):
        """Explain *sentence* for *pred_class*; return one LIME weight per token position.

        *embedding* is unused (shared pipeline interface). The sub-word tokens
        are joined with spaces and perturbed position-wise (``bow=False``).
        The returned weights are sorted by token position and also appended
        to ``tokens_lists`` / ``visual_explanations_lists``.
        """
        input_sentence = tokenizer.tokenize(sentence)

        self.num_features = len(input_sentence)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        model.to(device).eval()  # move the model to the device and switch to eval mode
        if dataset_name == 'asylex-norp':
            class_names = [0,1,2,3,4,5,6]
        else:
            class_names = [0,1]
        explainer = LimeTextExplainer(class_names=class_names, bow=False, split_expression = ' ',random_state=1)
        exp = explainer.explain_instance(
            ' '.join(input_sentence),
            self.predictor,
            labels=[pred_class],
            num_features=self.num_features,
            num_samples=self.num_samples  # batched inside predictor()
        )
        self.exp = exp

        idxs = []
        weights = []
        for idx, weight in exp.local_exp[pred_class]:  # (feature position, weight) for the predicted class
            idxs.append(int(idx))
            weights.append(float(weight))

        idxs = np.array(idxs)
        weights = np.array(weights)
        # local_exp is ordered by importance; restore token order.
        correct_order = np.argsort(idxs)
        visual_explanation = weights[correct_order]

        self.tokens_lists.append(tokens)
        self.visual_explanations_lists.append(visual_explanation)

        return visual_explanation


In [9]:
class DOA(Generic_xai_method):
    """Persistent-homology masking based on the angular distance to the [CLS] embedding.

    Every token gets the angle theta_i between its embedding and the [CLS]
    embedding. Two tokens are linked when exp(-|theta_i - theta_j| / tau) exceeds
    a threshold that is swept downward. At each step, the tokens that are still
    isolated (singleton components) form a masking filter.

    Uses the notebook globals ``tokenizer``, ``model`` and ``get_input_args``.
    """

    def __init__(self,clipped_heatmap = True,uniqueness = True,analytics = False):
        super().__init__(clipped_heatmap,uniqueness,analytics)
        self.name = 'Persistent homology Masking (Angular distance)'
        self.sdparam = 1
        self.input_args = self.input_args + get_input_args() + (self.analytics.input_args if self.analytics != False else '')

    def compute_distance_matrix(self,cls_embedding, embeddings, tempeature=0.2):
        """Return the pairwise angular-similarity matrix between tokens.

        theta_i = arccos(cos_sim(cls_embedding, embeddings[i]))
        D[i, j] = exp(-|theta_i - theta_j| / tempeature)

        Args:
            cls_embedding: [CLS] embedding, broadcast against ``embeddings``
                (assumes shape (dim,) — TODO confirm).
            embeddings: token embeddings, one row per token.
            tempeature: temperature tau (misspelled name kept for backward compatibility).

        Returns:
            (n_tokens, n_tokens) tensor with values in (0, 1] and 1 on the diagonal.
        """
        cosine = torch.nn.functional.cosine_similarity(cls_embedding, embeddings)
        # Rounding can push the cosine slightly outside [-1, 1], where acos returns NaN.
        angles = torch.acos(torch.clamp(cosine, -1.0, 1.0))
        return torch.exp(-torch.abs(angles.unsqueeze(1) - angles) / tempeature)

    def find_connected_components(self,adj_matrix):
        """Return the connected components of a 0/1 adjacency matrix.

        The matrix is turned into a "distance" matrix (0 where adjacent, 1
        otherwise). DBSCAN with eps=0.5 and min_samples=1 then links exactly the
        adjacent pairs, so its clusters are the connected components.

        Returns:
            list of lists of node indices, one list per component.
        """
        if torch.is_tensor(adj_matrix):
            # np.where cannot read CUDA tensors; bring the matrix to host memory.
            adj_matrix = adj_matrix.detach().cpu().numpy()
        boolean_distance_matrix = np.where(adj_matrix == 1, 0, 1)
        dbscan = DBSCAN(eps=0.5, min_samples=1, metric='precomputed')
        labels = dbscan.fit_predict(boolean_distance_matrix)
        clusters = []
        for cluster_id in sorted(set(labels)):
            if cluster_id == -1:
                continue  # DBSCAN noise label; kept as a guard (min_samples=1 makes every point a core point)
            clusters.append(np.flatnonzero(labels == cluster_id).tolist())
        return clusters

    def generate_filters(self,sentence,embedding):
        """Build the sequence of masking filters for ``sentence``.

        The threshold is swept from 0.9998 down to 0 in steps of 1e-4. At each
        step the tokens that are still singleton components form a filter. A new
        filter is recorded whenever it differs from the previous one, and the
        sweep stops as soon as no singleton remains.

        Args:
            sentence: raw text; '' returns the sentinel string 'empty'.
            embedding: embeddings with [CLS] at row 0 and token i at row i + 1
                (as implied by the slicing below).

        Returns:
            list of filters (lists of 0-based token indices), or 'empty'.
        """
        if sentence == '':
            return 'empty'

        tokens = tokenizer.tokenize(sentence)
        cls_embedding = embedding[0]
        token_embeddings = embedding[1:len(tokens)+1]
        d_matrix = self.compute_distance_matrix(cls_embedding, token_embeddings)

        filters = [[]]
        for threshold in np.arange(0.9998,-0.0001,-0.0001):
            adj_matrix = (d_matrix > threshold).int()
            components = self.find_connected_components(adj_matrix)
            singletons = [c[0] for c in components if len(c) == 1]
            if not singletons:
                break
            if singletons != filters[-1]:
                filters.append(singletons)

        # Drop the empty seed and the first (trivial) filter.
        return filters[2:]

    def filters_to_probabilities(self,filters,sentence,ground_truth):
        """Return P(class ``ground_truth``) for the sentence with each filter masked.

        For every filter, the input embeddings of the listed tokens are set to zero
        and the model is run on ``inputs_embeds``. Filter indices are 0-based token
        positions and are shifted by +1 to skip the leading special token.

        Args:
            filters: iterable of lists of token indices to mask.
            sentence: raw input text.
            ground_truth: index of the class whose probability is collected (per
                the original note, this is the class being explained rather than a
                gold label).

        Returns:
            list with one probability (float) per filter.
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        # Inference only: no autograd graph is needed. Tokenization and the
        # embedding lookup do not depend on the filter, so they are done once.
        with torch.no_grad():
            word_embeddings = model.get_input_embeddings()(inputs["input_ids"])
            for token_filter in filters:
                masked_embeddings = word_embeddings.clone()
                masked_embeddings[0, [i + 1 for i in token_filter], :] = 0
                output = model(inputs_embeds=masked_embeddings, attention_mask=inputs['attention_mask'])
                probabilities.append(F.softmax(output.logits, dim=1)[0, ground_truth].item())
        return probabilities

    def filter_prob_masked_attention(self,filters,sentence,ground_truth):
        """Like ``filters_to_probabilities``, but hides tokens by zeroing their attention-mask entries.

        Returns:
            list with one probability of class ``ground_truth`` per filter.
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        # All inputs must live on the model's device.
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        with torch.no_grad():
            for token_filter in filters:
                attention_mask = inputs['attention_mask'].clone()
                # The attention mask also covers the leading [CLS] token, so each
                # filter index is shifted by +1 to hit the right entry.
                attention_mask[0, [i + 1 for i in token_filter]] = 0
                logits = model(**{**inputs, 'attention_mask': attention_mask}).logits
                probabilities.append(F.softmax(logits, dim=1)[0, ground_truth].item())
        return probabilities


In [10]:
class SHA(Generic_xai_method):
    """SHAP explanations of the global ``model``.

    ``shap.Explainer`` is built around ``f_batch``, which scores texts for the
    class stored in ``self.pred_class``. The global ``tokenizer`` is passed as the
    masker argument.
    """

    def __init__(self,clipped_heatmap=True, uniqueness=True, analytics=False):
        super().__init__(clipped_heatmap,uniqueness,analytics)
        self.name = 'SHAP'
        self.explainer_bert = shap.Explainer(self.f_batch, tokenizer)
        # Kept as a public attribute; inference itself follows model.device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_args = self.input_args + get_input_args() + (self.analytics.input_args if self.analytics != False else '')

    def predictor(self,x,pred_class):
        """Return logit(P(pred_class)) for the text(s) in ``x``.

        Args:
            x: a string or list of strings accepted by ``tokenizer``.
            pred_class: index of the class whose probability is transformed.

        Returns:
            numpy array with one value per input text.
        """
        # Use the model's actual device: self.device may be CUDA while the model is still on CPU.
        encoded = tokenizer(x, return_tensors="pt", padding=True, truncation=True).to(model.device)
        with torch.no_grad():  # inference only
            logits = model(**encoded).logits
        probas = F.softmax(logits, dim=1).cpu().numpy()
        # logit(p) = log(p / (1 - p)) of the probability assigned to pred_class.
        return sp.special.logit(probas[:, pred_class])

    def f_batch(self,x):
        """Score every text in ``x`` for ``self.pred_class`` (the model function given to SHAP).

        Texts are scored one at a time, so no padding is introduced between them.

        Returns:
            flat float64 numpy array with one entry per text (empty for empty ``x``).
        """
        scores = [np.ravel(self.predictor(text, self.pred_class)) for text in x]
        if not scores:
            return np.array([])
        # Collect and concatenate once instead of np.append in a loop, which is quadratic.
        return np.concatenate(scores).astype(np.float64)

    def create_heatmap(self,input,pred_class,p_org,indexes_list,undecided_threshold):
        """Render the heatmap of the most recent SHAP explanation.

        For ``pred_class == 0`` the values are negated before plotting (presumably
        the sign convention expected by the global ``create_heatmap`` — confirm).
        ``input`` is unused.
        """
        vals = self.shap_values.values.squeeze()
        vals = -vals if pred_class == 0 else np.array(vals)
        return create_heatmap(self.shap_values.data[0].tolist(),vals,pred_class,p_org,indexes_list,undecided_threshold)

    def method_pipeline(self,sentence,embedding,tokens,p_org,pred_class):
        """Compute SHAP values for ``sentence`` and return one value per inner token.

        ``embedding`` and ``p_org`` are unused.
        """
        self.pred_class = pred_class  # read by f_batch while SHAP evaluates masked texts
        self.shap_values = self.explainer_bert([sentence], fixed_context=1)
        # Drop the first and last positions (assumes these are the tokenizer's
        # special tokens, e.g. [CLS]/[SEP] — confirm).
        vals = np.array(self.shap_values.values.squeeze()[1:-1])
        self.tokens_lists.append(tokens)
        self.visual_explanations_lists.append(vals)
        return vals

In [11]:
class cosine_similarity(Generic_xai_method):
    """Baseline explanation: cosine similarity between each token embedding and [CLS].

    Row 0 of ``embedding`` is treated as the [CLS] embedding, and rows
    1..len(tokens) are treated as the token embeddings.
    """

    def __init__(self, analytics=False):
        # Positional flags: clipped_heatmap=False, uniqueness=True.
        super().__init__(False, True, analytics)
        self.name = 'cosine similarity'
        own_args = self.input_args + get_input_args()
        analytics_args = self.analytics.input_args if self.analytics != False else ''
        self.input_args = own_args + analytics_args

    def method_pipeline(self, sentence, embedding, tokens, p_org, pred_class):
        """Return the per-token cosine similarity to the [CLS] embedding.

        ``sentence``, ``p_org`` and ``pred_class`` are unused. The result is also
        appended to ``visual_explanations_lists`` (and ``tokens`` to ``tokens_lists``).
        """
        cls_row = embedding[0].unsqueeze(0)
        token_rows = embedding[1:len(tokens) + 1]
        scores = F.cosine_similarity(token_rows, cls_row, dim=1)
        self.tokens_lists.append(tokens)
        self.visual_explanations_lists.append(scores)
        return scores

In [12]:
class rangedCSM(Generic_xai_method):
    """Thresholded cosine-similarity masking.

    A token is masked when its cosine similarity to the [CLS] embedding is at or
    below a threshold. The threshold is stepped from 0 to 1 in increments of 0.03,
    and each distinct non-empty set of masked tokens becomes one filter.
    """

    def __init__(self,clipped_heatmap = True,uniqueness = True,analytics= False):
        super().__init__(clipped_heatmap,uniqueness,analytics)
        self.name = 'Thresholded Cosine Similarity Masking with stepsizes'
        self.sdparam = 1
        self.input_args = self.input_args + get_input_args() + (self.analytics.input_args if self.analytics != False else '')

    def generate_filters(self,sentence,embedding):
        """Return the threshold filters for ``sentence``, largest first.

        Each filter is the sorted list of token indices whose similarity to [CLS]
        is <= the current threshold. Consecutive duplicates and empty sets are
        skipped. When more than one filter exists, the last one (highest
        threshold) is dropped, and the list is then reversed.

        Args:
            sentence: raw text.
            embedding: embeddings with [CLS] at row 0 and token i at row i + 1
                (as implied by the slicing below).
        """
        tokens = tokenizer.tokenize(sentence)
        cls_embedding = embedding[0]
        token_embeddings = embedding[1:len(tokens)+1]
        similarities = torch.nn.functional.cosine_similarity(cls_embedding, token_embeddings)

        step = 0.03
        filters = []
        for threshold in np.arange(0.0, 1.0 + step, step):
            # Ascending indices of the tokens at or below the threshold.
            below = torch.nonzero(similarities <= threshold, as_tuple=True)[0].tolist()
            if below and (not filters or below != filters[-1]):
                filters.append(below)

        if len(filters) > 1:
            filters = filters[:-1]  # drop the last (highest-threshold) filter
        return filters[::-1]

    def filters_to_probabilities(self,filters,sentence,ground_truth):
        """Return P(class ``ground_truth``) for the sentence with each filter masked.

        For every filter, the input embeddings of the listed tokens are set to zero
        and the model is run on ``inputs_embeds``. Filter indices are 0-based token
        positions and are shifted by +1 to skip the leading special token.

        Returns:
            list with one probability (float) per filter.
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        # Inference only; tokenization and the embedding lookup are filter-independent.
        with torch.no_grad():
            word_embeddings = model.get_input_embeddings()(inputs["input_ids"])
            for token_filter in filters:
                masked_embeddings = word_embeddings.clone()
                masked_embeddings[0, [i + 1 for i in token_filter], :] = 0
                output = model(inputs_embeds=masked_embeddings, attention_mask=inputs['attention_mask'])
                probabilities.append(F.softmax(output.logits, dim=1)[0, ground_truth].item())
        return probabilities

    def filter_prob_masked_attention(self,filters,sentence,ground_truth):
        """Like ``filters_to_probabilities``, but hides tokens by zeroing their attention-mask entries.

        Returns:
            list with one probability of class ``ground_truth`` per filter
            (the original version fell off the end and returned None).
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        with torch.no_grad():
            for token_filter in filters:
                attention_mask = inputs['attention_mask'].clone()
                # +1 skips the leading special token. NOTE(review): the original TODO
                # questioned whether this shift is correct — verify.
                attention_mask[0, [i + 1 for i in token_filter]] = 0
                logits = model(**{**inputs, 'attention_mask': attention_mask}).logits
                probabilities.append(F.softmax(logits, dim=1)[0, ground_truth].item())
        return probabilities

# Output data analysis

In [13]:
class DataAnalysis():
    """Compare explanation scores of tokens inside annotated sentences with all other tokens.

    Optionally writes distribution plots and binary-classification statistics,
    treating tokens inside the annotated sentences as positives. Uses the
    notebook globals ``tokenizer``, ``model_name``, ``output_plot_path``,
    ``get_input_args``, ``DataObject`` and ``to_resume``.
    """

    def __init__(self,method_name,create_plot=False,create_stats=False,relevance_threshold=0.7):
        self.method_name = method_name
        self.create_plot = create_plot
        self.create_stats = create_stats
        self.relevance_threshold = relevance_threshold
        self.input_args = get_input_args()
        self.stats = DataObject(['positive_percentage','tp','tn','fp','fn','precision','recall','f1','f1_micro','f1_macro','f1_weighted','accuracy','sent_rel_mean','oth_rel_mean'])

    def plot_distributions(self,sent_rel,oth_rel,count,name):
        """Save overlaid KDE plots (with mean lines) of the two score groups to ``output_plot_path``/hist."""
        sns.kdeplot(sent_rel, color='blue', fill=True, alpha=0.5, label='tokens in relevant sentences')
        media = np.mean(sent_rel)
        plt.axvline(media, color='blue', linestyle='--', label=f'Mean: {media:.2f}')
        sns.kdeplot(oth_rel, color='green', fill=True, alpha=0.3, label='Other tokens')
        media = np.mean(oth_rel)
        plt.axvline(media, color='green', linestyle='--', label=f'Mean: {media:.2f}')
        plt.legend()
        plt.savefig(output_plot_path+'/hist/hist_'+str(count)+name+'.png')
        plt.close()

    def classify_relevance(self,similarities,text_ids,sentences,tokens):
        """Split per-token scores into those inside the given sentences and the rest.

        Args:
            similarities: per-token scores (anything with ``.tolist()``), aligned with ``tokens``.
            text_ids: input ids of the whole text including a leading special token
                (list or tensor); only positions 1..len(tokens) are used.
            sentences: a string or list of strings marking the relevant spans.
            tokens: tokenized text (only its length is used).

        Returns:
            ``(sentences_relevances, other_text_relevances, indexes_list)``.
            ``indexes_list`` is the transpose of one ``[start, end)`` pair per
            sentence (``[-1, -1]`` when the sentence is not found). Matches of the
            same sentence never overlap.

        Raises:
            Exception: if scores and ids have different lengths, or if a score is
                counted twice or missed (e.g. two sentences overlap).
        """
        if not isinstance(text_ids, list): #if it is pytorch, it should be converted to list
            text_ids = text_ids.tolist()
        if isinstance(sentences, str):
            sentences = [sentences]
        if model_name == 'roberta':
            sentences = [elem for elem in sentences if elem.strip()]

        text_ids = text_ids[1:len(tokens)+1]
        similarities = similarities.tolist()
        if len(similarities) != len(text_ids):
            raise Exception("similarities and text_ids length are different")

        inside_sentence = [False] * len(similarities)
        sentences_relevances = []
        indexes_list = []
        for sentence in sentences:
            if model_name == 'roberta':
                # Leading space presumably makes the first word tokenize as it does
                # mid-text (RoBERTa's BPE is space-sensitive).
                sentence_ids = tokenizer(' '+sentence)['input_ids'][1:-1]
            else:
                sentence_ids = tokenizer(sentence)['input_ids'][1:-1]
            span = len(sentence_ids)
            indexes = [-1,-1]
            # An empty id list would "match" at every position; treat it as not found.
            if span > 0:
                i = 0
                while i + span <= len(text_ids):
                    if text_ids[i:i+span] == sentence_ids:
                        indexes = [i,i+span]
                        sentences_relevances += similarities[i:i+span]
                        inside_sentence[i:i+span] = [True] * span
                        i += span  # continue after the match so matches cannot overlap
                    else:
                        i += 1
            indexes_list.append(indexes)

        other_text_relevances = [elem for elem, inside in zip(similarities, inside_sentence) if not inside]
        if len(other_text_relevances)+len(sentences_relevances) != len(similarities):
            print(len(other_text_relevances))
            print(len(sentences_relevances))
            print(len(similarities))
            print(sentences)
            print(tokens)
            raise Exception("not all the similarities have been added")

        if indexes_list == []:
            indexes_list = [[-1,-1],[-1,-1]]
        else:
            indexes_list =  list(map(list, zip(*indexes_list)))  # transpose: [[starts...], [ends...]]
        return sentences_relevances,other_text_relevances, indexes_list

    def save_stats(self,positive_percentage,cm,precision,recall,f1,f1_micro,f1_macro,f1_weighted,accuracy,sent_rel_mean,oth_rel_mean):
        """Append one row of metrics; ``cm`` must be the 2x2 matrix [[tn, fp], [fn, tp]]."""
        self.stats.put([positive_percentage,cm[1,1],cm[0,0],cm[0,1],cm[1,0],precision,recall,f1,f1_micro,f1_macro,f1_weighted,accuracy,sent_rel_mean,oth_rel_mean])

    def calculate_stats(self,sent_rel,oth_rel,count,name):
        """Threshold the scores, store metrics, and save confusion-matrix and ROC plots.

        Tokens inside relevant sentences are the positive class. A score is a
        positive prediction when it is >= ``self.relevance_threshold``. An empty
        group gets a NaN mean, and the ROC plot is skipped when only one class is
        present (ROC is undefined then).
        """
        data = sent_rel + oth_rel
        ground_truth = [1] * len(sent_rel) + [0] * len(oth_rel)
        positive_percentage = len(sent_rel)/(len(sent_rel)+len(oth_rel))

        predictions = [1 if p>=self.relevance_threshold else 0 for p in data]
        # Avoid ZeroDivisionError when, e.g., every token belongs to a relevant sentence.
        sent_rel_mean = sum(sent_rel)/len(sent_rel) if sent_rel else float('nan')
        oth_rel_mean = sum(oth_rel)/len(oth_rel) if oth_rel else float('nan')

        # Pin the label set so cm is always 2x2; save_stats indexes cm[1,1], cm[0,1], ...
        cm = confusion_matrix(ground_truth, predictions, labels=[0, 1])

        precision = precision_score(ground_truth, predictions)
        recall = recall_score(ground_truth, predictions)
        f1 = f1_score(ground_truth, predictions)
        f1_micro = f1_score(ground_truth, predictions, average='micro')
        f1_macro = f1_score(ground_truth, predictions, average='macro')
        f1_weighted = f1_score(ground_truth, predictions, average='weighted')
        accuracy = accuracy_score(ground_truth, predictions)

        self.save_stats(positive_percentage,cm,precision,recall,f1,f1_micro,f1_macro,f1_weighted,accuracy,sent_rel_mean,oth_rel_mean)

        plt.figure(figsize=(6, 4))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Negativo', 'Positivo'], yticklabels=['Negativo', 'Positivo'])
        plt.xlabel('Predicted')
        plt.ylabel('Real')
        plt.title('Confusion Matrix')
        plt.savefig(output_plot_path+'/conf/conf'+str(count)+name+'.png')
        plt.close()

        if len(set(ground_truth)) < 2:
            return  # ROC/AUC need both classes

        # Min-max scale the scores to [0, 1]; constant scores map to all zeros.
        data_min = min(data)
        data_range = max(data) - data_min
        probs = [(d - data_min) / data_range for d in data] if data_range else [0.0] * len(data)

        fpr, tpr, thresholds = roc_curve(ground_truth, probs)
        roc_auc = auc(fpr, tpr)

        # Draw the ROC curve
        plt.figure(figsize=(6, 6))
        plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
        plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
        plt.xlabel('False Positive Rate (FPR)')
        plt.ylabel('True Positive Rate (TPR)')
        plt.title('Curva ROC')
        plt.legend()
        plt.savefig(output_plot_path+'/roc/roc'+str(count)+name+'.png')
        plt.close()

    def run_analysis(self,similarities,text_ids,sentences,tokens,count,name):
        """Classify token relevances, then optionally plot and compute stats.

        Returns:
            the ``indexes_list`` from ``classify_relevance``. Plots and stats are
            skipped when no token falls inside a relevant sentence.
        """
        sentences_relevances,other_text_relevances, indexes_list = self.classify_relevance(similarities,text_ids,sentences,tokens)
        if sentences_relevances == []:
            return indexes_list
        if self.create_plot:
            self.plot_distributions(sentences_relevances,other_text_relevances,count,name)
        if self.create_stats:
            self.calculate_stats(sentences_relevances,other_text_relevances,count,name)
        return indexes_list

    def plot_hist(self,x,y,x_name,y_name,title):
        """Save two overlaid 20-bin histograms to ``output_plot_path``/``title``.png."""
        plt.hist(x,bins=20,alpha=0.5,label=x_name,color='blue',density=False)
        plt.hist(y,bins=20,alpha=0.5,label=y_name,color='red',density=False)
        plt.title(title)
        plt.xlabel('values')
        plt.ylabel('count')
        plt.legend()
        plt.grid(True)
        plt.savefig(output_plot_path+'/'+title+'.png',dpi=300,bbox_inches='tight')
        plt.close()

    def stat_resume(self):
        """Report each stats column name and its mean via ``to_resume``, then plot the relevance means."""
        if not self.stats.df.empty:
            for column in self.stats.columns:
                to_resume(column)
                to_resume(self.stats.df[column].mean())
            self.plot_hist(self.stats.df['sent_rel_mean'],self.stats.df['oth_rel_mean'],'Relevant words avg score','NonRelevant words avg score','relevance_distribution_'+self.method_name)

    @staticmethod
    def void_function():
        pass


# Old methods

In [14]:
class CSM(Generic_xai_method):
    """Cosine Similarity Masking: progressively mask the tokens least similar to [CLS]."""

    def __init__(self,clipped_heatmap = True,uniqueness = True,analysis= False):
        super().__init__(clipped_heatmap,uniqueness,analysis)
        self.name = 'Cosine Similarity Masking'
        self.sdparam = 1

    def generate_filters(self,sentence,embedding):
        """Return nested filters of the least-[CLS]-similar tokens, largest first.

        With token indices ordered by ascending cosine similarity to [CLS], filter k
        holds the first k of them, for k = n-1 down to 1. The filter that masks
        every token is never produced.

        Args:
            sentence: raw text.
            embedding: embeddings with [CLS] at row 0 and token i at row i + 1
                (as implied by the slicing below).
        """
        tokens = tokenizer.tokenize(sentence)
        cls_embedding = embedding[0]
        token_embeddings = embedding[1:len(tokens)+1]

        similarities = torch.nn.functional.cosine_similarity(cls_embedding, token_embeddings)
        ascending = torch.argsort(similarities, descending = False).tolist()
        # Shuffling `ascending` here instead would yield random filters (baseline experiment).
        return [ascending[:k] for k in range(len(ascending) - 1, 0, -1)]

    def filters_to_probabilities(self,filters,sentence,ground_truth):
        """Return P(class ``ground_truth``) for the sentence with each filter masked.

        For every filter, the input embeddings of the listed tokens are set to zero
        and the model is run on ``inputs_embeds``. Filter indices are 0-based token
        positions and are shifted by +1 to skip the leading special token.

        Returns:
            list with one probability (float) per filter.
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        # Inference only; tokenization and the embedding lookup are filter-independent.
        with torch.no_grad():
            word_embeddings = model.get_input_embeddings()(inputs["input_ids"])
            for token_filter in filters:
                masked_embeddings = word_embeddings.clone()
                masked_embeddings[0, [i + 1 for i in token_filter], :] = 0
                output = model(inputs_embeds=masked_embeddings, attention_mask=inputs['attention_mask'])
                probabilities.append(F.softmax(output.logits, dim=1)[0, ground_truth].item())
        return probabilities

    def filter_prob_masked_attention(self,filters,sentence,ground_truth):
        """Like ``filters_to_probabilities``, but hides tokens by zeroing their attention-mask entries.

        Returns:
            list with one probability of class ``ground_truth`` per filter.
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        with torch.no_grad():
            for token_filter in filters:
                attention_mask = inputs['attention_mask'].clone()
                # +1 skips the leading special token. NOTE(review): the original TODO
                # questioned whether this shift is correct — verify.
                attention_mask[0, [i + 1 for i in token_filter]] = 0
                logits = model(**{**inputs, 'attention_mask': attention_mask}).logits
                probabilities.append(F.softmax(logits, dim=1)[0, ground_truth].item())
        return probabilities

In [15]:
class NED(Generic_xai_method):
    """Persistent-homology masking with Euclidean distance between L2-normalised embeddings.

    The [CLS] row and the token rows are linked when their distance is below a
    threshold. The threshold sweeps all pairwise distances in ascending order. At
    each step, the union of the components that do not contain [CLS] (node 0)
    forms a filter.
    """

    def __init__(self,clipped_heatmap = True,uniqueness = True,analysis= False):
        super().__init__(clipped_heatmap,uniqueness,analysis)
        self.name = 'Persistent Homology Masking (Euclidean distance)'
        self.sdparam = 1

    def compute_d_matrix_normalized_emb(self,embeddings):
        """Return pairwise Euclidean distances between row-wise L2-normalised embeddings."""
        embeddings = torch.nn.functional.normalize(embeddings, p = 2.0, dim = 1)
        return torch.cdist(embeddings,embeddings)

    def find_connected_components(self,adj_matrix):
        """Return the connected components of a 0/1 adjacency matrix.

        DBSCAN (eps=0.5, min_samples=1) runs on the "distance" matrix that is 0 where
        nodes are adjacent and 1 otherwise. It therefore links exactly the adjacent
        pairs, so its clusters are the connected components.

        Returns:
            list of lists of node indices, one list per component.
        """
        if torch.is_tensor(adj_matrix):
            # np.where cannot read CUDA tensors; bring the matrix to host memory.
            adj_matrix = adj_matrix.detach().cpu().numpy()
        boolean_distance_matrix = np.where(adj_matrix == 1, 0, 1)
        dbscan = DBSCAN(eps=0.5, min_samples=1, metric='precomputed')
        labels = dbscan.fit_predict(boolean_distance_matrix)
        clusters = []
        for cluster_id in sorted(set(labels)):
            if cluster_id == -1:
                continue  # DBSCAN noise label; kept as a guard (min_samples=1 makes every point a core point)
            clusters.append(np.flatnonzero(labels == cluster_id).tolist())
        return clusters

    def generate_filters(self,sentence,embedding):
        """Build the masking filters for ``sentence``.

        Args:
            sentence: raw text; '' returns the sentinel string 'empty'.
            embedding: embeddings with [CLS] at row 0 and token i at row i + 1
                (as implied by the slicing below).

        Returns:
            list of filters (sorted lists of 0-based token indices), or 'empty'.
        """
        if sentence == '':
            return 'empty'

        tokens = tokenizer.tokenize(sentence)
        # Keep the [CLS] row (node 0) so components can be anchored to it.
        d_matrix = self.compute_d_matrix_normalized_emb(embedding[0:len(tokens)+1])

        # Candidate thresholds: every upper-triangular pairwise distance, ascending.
        # A detached CPU copy lets numpy handle CUDA or grad-tracking tensors.
        pairwise = d_matrix.detach().cpu().numpy()
        thresholds = np.sort(pairwise[np.triu_indices(len(pairwise), k=1)]).tolist()

        filters = [[]]
        for threshold in thresholds:
            components = self.find_connected_components((d_matrix < threshold).int())
            if len(components) == 1:
                break
            token_filter = sorted(node for c in components if 0 not in c for node in c)
            if token_filter != filters[-1]:
                filters.append(token_filter)

        # Drop the empty seed and the first trivial filter
        # (original note: if this errors, try slicing from 3 instead of 2).
        filters = filters[2:]
        # Node k is token k-1 because row 0 was [CLS].
        return [[i - 1 for i in f] for f in filters]

    def filters_to_probabilities(self,filters,sentence,ground_truth):
        """Return P(class ``ground_truth``) for the sentence with each filter masked.

        For every filter, the input embeddings of the listed tokens are set to zero
        and the model is run on ``inputs_embeds``. Filter indices are 0-based token
        positions and are shifted by +1 to skip the leading special token.

        Returns:
            list with one probability (float) per filter.
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        # Inference only; tokenization and the embedding lookup are filter-independent.
        with torch.no_grad():
            word_embeddings = model.get_input_embeddings()(inputs["input_ids"])
            for token_filter in filters:
                masked_embeddings = word_embeddings.clone()
                masked_embeddings[0, [i + 1 for i in token_filter], :] = 0
                output = model(inputs_embeds=masked_embeddings, attention_mask=inputs['attention_mask'])
                probabilities.append(F.softmax(output.logits, dim=1)[0, ground_truth].item())
        return probabilities

    def filter_prob_masked_attention(self,filters,sentence,ground_truth):
        """Like ``filters_to_probabilities``, but hides tokens by zeroing their attention-mask entries.

        Returns:
            list with one probability of class ``ground_truth`` per filter.
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        with torch.no_grad():
            for token_filter in filters:
                attention_mask = inputs['attention_mask'].clone()
                # +1 skips the leading special token. NOTE(review): the original TODO
                # questioned whether this shift is correct — verify.
                attention_mask[0, [i + 1 for i in token_filter]] = 0
                logits = model(**{**inputs, 'attention_mask': attention_mask}).logits
                probabilities.append(F.softmax(logits, dim=1)[0, ground_truth].item())
        return probabilities

In [16]:
class embedding_classification(Generic_xai_method):
    """Explanation that scores each token embedding with ``classify_embedding``.

    Every token embedding (rows 1..len(tokens) of ``embedding``) is classified on
    its own. Its score for ``pred_class`` becomes the token's relevance.
    """

    def __init__(self,analytics= False):
        # Positional flags: clipped_heatmap=False, uniqueness=True.
        # Original TODO: decide which setting is better — possibly True for the first flag too.
        super().__init__(False,True,analytics)
        self.name = 'Embedding classification'

    def method_pipeline(self,sentence,embedding,tokens,p_org,pred_class):
        """Return a numpy array with the ``pred_class`` score of each token embedding.

        ``sentence`` and ``p_org`` are unused. Nothing is appended to
        ``tokens_lists`` / ``visual_explanations_lists`` here.
        """
        token_embeddings = embedding[1:len(tokens)+1]
        # Per the original note, applying F.softmax to the classifier output gives the same result.
        scores = [classify_embedding(t)[pred_class].tolist() for t in token_embeddings]
        return np.array(scores)

In [17]:
class newCSM(Generic_xai_method): #csm with threshold
    """Thresholded cosine-similarity masking over combinations of high-similarity tokens.

    Each filter masks every token except a non-empty subset of the tokens most
    similar to [CLS] whose similarity exceeds a threshold.
    """

    def __init__(self,clipped_heatmap = True,uniqueness = True,analysis= False):
        super().__init__(clipped_heatmap,uniqueness,analysis)
        self.name = 'Thresholded Cosine Similarity Masking with Combinations'
        self.sdparam = 1

    def get_combinations(self,indexes,values,max_candidates=7,min_value=0.7):
        """Return filters ``set(indexes) - S`` for every non-empty subset S of the candidates.

        Candidates are the last ``max_candidates`` entries of ``indexes`` whose
        ``values[x]`` exceeds ``min_value``. ``generate_filters`` passes indexes in
        ascending-similarity order, so these are the most [CLS]-similar tokens.
        The number of filters is 2**k - 1 for k candidates, so keep
        ``max_candidates`` small.

        Args:
            indexes: token indices (any order; the tail is taken as candidates).
            values: per-token scores indexable by token index.
            max_candidates: how many tail entries to consider (default 7).
            min_value: strict lower bound a candidate's value must exceed (default 0.7).

        Returns:
            list of filters; each is a list of token indices to mask, in set order.
        """
        # indexes[-0:] would be the whole list, so guard the non-positive case.
        tail = indexes[-max_candidates:] if max_candidates > 0 else []
        candidates = [x for x in tail if values[x] > min_value]
        indexes_set = set(indexes)
        filters = []
        for size in range(1, len(candidates) + 1):
            for subset in combinations(candidates, size):  # all subsets of this size
                filters.append(list(indexes_set - set(subset)))
        return filters

    def generate_filters(self,sentence,embedding):
        """Build combination filters from cosine similarity to [CLS] (row 0 of ``embedding``)."""
        tokens = tokenizer.tokenize(sentence)
        cls_embedding = embedding[0]
        token_embeddings = embedding[1:len(tokens)+1]

        similarities = torch.nn.functional.cosine_similarity(cls_embedding, token_embeddings)
        sorted_indices = torch.argsort(similarities, descending = False).tolist()
        return self.get_combinations(sorted_indices,similarities)

    def filters_to_probabilities(self,filters,sentence,ground_truth):
        """Return P(class ``ground_truth``) for the sentence with each filter masked.

        For every filter, the input embeddings of the listed tokens are set to zero
        and the model is run on ``inputs_embeds``. Filter indices are 0-based token
        positions and are shifted by +1 to skip the leading special token.

        Returns:
            list with one probability (float) per filter.
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        # Inference only; tokenization and the embedding lookup are filter-independent.
        with torch.no_grad():
            word_embeddings = model.get_input_embeddings()(inputs["input_ids"])
            for token_filter in filters:
                masked_embeddings = word_embeddings.clone()
                masked_embeddings[0, [i + 1 for i in token_filter], :] = 0
                output = model(inputs_embeds=masked_embeddings, attention_mask=inputs['attention_mask'])
                probabilities.append(F.softmax(output.logits, dim=1)[0, ground_truth].item())
        return probabilities

    def filter_prob_masked_attention(self,filters,sentence,ground_truth):
        """Like ``filters_to_probabilities``, but hides tokens by zeroing their attention-mask entries.

        Returns:
            list with one probability of class ``ground_truth`` per filter.
        """
        device = model.device
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        probabilities = []
        with torch.no_grad():
            for token_filter in filters:
                attention_mask = inputs['attention_mask'].clone()
                # +1 skips the leading special token. NOTE(review): the original TODO
                # questioned whether this shift is correct — verify.
                attention_mask[0, [i + 1 for i in token_filter]] = 0
                logits = model(**{**inputs, 'attention_mask': attention_mask}).logits
                probabilities.append(F.softmax(logits, dim=1)[0, ground_truth].item())
        return probabilities