In [None]:
%load_ext autoreload
%autoreload 2

# Performing perturbation analysis on NLP datasets

## Load model and data

In [None]:
from symbxai.lrp.symbolic_xai import BERTSymbXAI, ViTSymbolicXAI
from symbxai.model.transformer import bert_base_uncased_model
from symbxai.dataset.utils import load_sst_treebank, load_imdb_dataset
import transformers


sample_range = [1]
data_mode = 'sst'

# Fine-tuned BERT checkpoint (used for both the model weights and the tokenizer)
# for every supported dataset.
checkpoints = {
    'sst': 'textattack/bert-base-uncased-SST-2',
    'imdb': 'textattack/bert-base-uncased-imdb',
}
if data_mode not in checkpoints:
    # Fail loudly instead of leaving `model` / `dataset` undefined (or stale from an earlier run).
    raise ValueError(f"unsupported data_mode {data_mode!r}; expected one of {sorted(checkpoints)}")

checkpoint = checkpoints[data_mode]
model = bert_base_uncased_model(pretrained_model_name_or_path=checkpoint)
model.eval()  # evaluation mode: the notebook only runs inference
tokenizer = transformers.BertTokenizer.from_pretrained(checkpoint)
input_type = 'sentence'

if data_mode == 'sst':
    dataset = load_sst_treebank(sample_range, verbose=False)['train']
    print('got', len(dataset['label']), 'samples from sst')
else:  # 'imdb'
    dataset = load_imdb_dataset(sample_range)

## Define perturbation order by XAI method

In [None]:
import numpy as np
from tqdm import tqdm


## Load precomputed perturbation results

Run the script `script/perform_perturbation.py` first; the cell below then loads the pickled results it produced.

In [None]:
import pickle
load_data = True
data_mode = 'sst'
sample_range = range(201)

attribution_methods = ['SymbXAI', 'LRP', 'PredDiff', 'random']
optimize_parameter = [('minimize', 'removal'), ('maximize', 'removal'), ('minimize', 'generation'), ('maximize', 'generation')]

# Folder with the result pickles.
# NOTE(review): machine-specific absolute path — adjust to your local results folder.
# Only unpickle files you produced yourself: pickle.load can execute arbitrary code.
results_dir = '/Users/thomasschnake/Downloads/temp'

# Errors meaning "this result file is missing or not (completely) written"; those
# samples are skipped, every other error surfaces instead of being swallowed.
skippable_errors = (FileNotFoundError, EOFError, pickle.UnpicklingError)

if load_data:
    # all_output_sequences[(auc_task, perturbation_type)][attribution_method] -> merged per-sample entries
    all_output_sequences = {param: {attribution_method: {} for attribution_method in attribution_methods}
                            for param in optimize_parameter}
    n_skipped = 0

    for i in sample_range:
        if data_mode == 'sst':
            # One pickle per sample holding every (auc_task, perturbation_type) setting.
            file_names = [f'{results_dir}/perturbation_results_{data_mode}_{i}.pkl']
        elif data_mode in ['imdb']:
            # One pickle per sample and per (auc_task, perturbation_type) setting.
            file_names = [f'{results_dir}/perturbation_results_{data_mode}_{i}_{par1}_{par2}_curves.pkl'
                          for par1, par2 in optimize_parameter]
        else:
            raise ValueError(f"unsupported data_mode {data_mode!r} for loading perturbation curves")

        # Fresh per sample, so a skipped file never re-merges the previous sample's results.
        output_sequences = {}
        for file_name in file_names:
            try:
                # 'rb' (not 'rb+'): reading must not require write permission.
                with open(file_name, 'rb') as f:
                    output_sequences.update(pickle.load(f))
            except skippable_errors:
                n_skipped += 1

        for param, per_method in output_sequences.items():
            for attribution_method in attribution_methods:
                all_output_sequences[param][attribution_method].update(per_method.get(attribution_method, {}))

    if n_skipped:
        print(f'skipped {n_skipped} missing or unreadable result files in {results_dir}')

# Number of loaded samples per method and setting.
# NOTE: with load_data = False this reuses all_output_sequences from a previous execution.
for attribution_method in attribution_methods:
    for auc_task, perturbation_type in optimize_parameter:
        print(attribution_method, auc_task, perturbation_type, '\t\t', len(all_output_sequences[(auc_task,perturbation_type)][attribution_method].keys()))
        

## Plot the average perturbation curves and their integrals (areas under the curves)

In [None]:
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import numpy as np
def resample_sequence(sequence, k):
    """Resample a 1-D numeric sequence onto ``k`` points.

    If ``k`` is smaller than the sequence length, the sequence is split into
    ``k`` contiguous bins (edges floored to integer indices) and each bin is
    averaged. Otherwise the sequence is placed on an evenly spaced grid over
    [0, 1] and linearly interpolated at ``k`` evenly spaced points. With
    ``k == len(sequence)`` the values are returned unchanged. A single-element
    sequence is broadcast to all ``k`` points.

    Args:
        sequence: 1-D sequence of numbers.
        k: number of points in the result.

    Returns:
        np.ndarray of ``k`` floats.

    Raises:
        ValueError: if ``sequence`` is empty.
    """
    values = np.asarray(sequence, dtype=float)
    n = len(values)
    if n == 0:
        raise ValueError('cannot resample an empty sequence')

    if k < n:
        # Downsampling: each bin is wider than one element (n / k > 1), so no bin is empty.
        edges = np.linspace(0, n, k + 1).astype(int)
        return np.array([values[start:end].mean() for start, end in zip(edges[:-1], edges[1:])])

    # Upsampling: linear interpolation. Unlike interp1d, np.interp also accepts a single point.
    return np.interp(np.linspace(0, 1, k), np.linspace(0, 1, n), values)

import os

text_parser = { 'removal' : 'AURC',
                'generation': 'AUGC',
                'minimize' : '$\\min$',
                'maximize' : '$\\max$'}

grid_resolution = 50  # number of points every perturbation curve is resampled to before averaging

# averages[(auc_task, perturbation_type)][attribution_method] -> mean resampled curve (length grid_resolution)
averages = {param: {} for param in optimize_parameter}
for param in optimize_parameter:
    for attribution_method in attribution_methods:
        # NOTE(review): the first and last value of every curve are dropped before resampling —
        # presumably the unperturbed / fully perturbed endpoints; confirm against the producing script.
        resampled_curves = [resample_sequence(sublist[1:-1], grid_resolution)
                            for sublist in all_output_sequences[param][attribution_method].values()]
        # With no loaded curves np.mean would return a scalar NaN, which breaks the AUC table
        # below; keep the array shape and mark the curve as missing instead.
        averages[param][attribution_method] = (np.mean(resampled_curves, axis=0) if resampled_curves
                                               else np.full(grid_resolution, np.nan))

# Place point j at j / (grid_resolution - 1) of the perturbation, so the last point sits exactly at 100%.
percent_axis = np.linspace(0, 100, grid_resolution)

n_panels = len(optimize_parameter)
fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
for ax, (auc_task, perturbation_type) in zip(axs[0], optimize_parameter):
    for attribution_method in attribution_methods:
        ax.plot(percent_axis, averages[(auc_task, perturbation_type)][attribution_method],
                label=attribution_method, lw=3)
    ax.set_title(text_parser[auc_task] + ' ' + text_parser[perturbation_type])
    ax.set_xticks([0, 25, 50, 75, 100], labels=['0%', '25%', '50%', '75%', '100%'])
    ax.set_xlabel('perturbation progress')
axs[0, 0].set_ylabel('average curve value')
axs[0, -1].legend()

os.makedirs('pics', exist_ok=True)  # savefig does not create missing directories
fig.savefig(f'pics/perturbation_curves_{data_mode}.png')
plt.show()

In [None]:
from functools import reduce 

# Print the area under each average curve as LaTeX table rows: one column per
# (auc_task, perturbation_type) setting, one row per attribution method.
# Raw strings keep the LaTeX backslashes literal. The previous '\h' and '\ ' escapes
# are invalid Python escapes and raise SyntaxWarning on Python 3.12+.
header = ' & '.join(text_parser[auc_task] + ' ' + text_parser[perturbation_type]
                    for auc_task, perturbation_type in optimize_parameter)
print(' & ' + header + r' \\ ')
print(r'\hline \hline')
for attribution_method in attribution_methods:
    # The mean of the curve sampled on grid_resolution points approximates its area on a [0, 1] x-axis.
    # Missing curves (all NaN) are printed as 'nan'.
    cells = ''.join(f' & {float(np.mean(averages[param][attribution_method])):.2f} '
                    for param in optimize_parameter)
    print(attribution_method + cells + r' \\ ')
    print(r'\hline')

# Look at the feature orderings

In [None]:

import pickle  # keeps this cell runnable on its own

data_mode = 'fer'
sample_range = range(201)

attribution_methods = ['SymbXAI', 'LRP', 'PredDiff', 'random']
optimize_parameter = [('minimize', 'removal'), ('maximize', 'removal'), ('minimize', 'generation'), ('maximize', 'generation')]

# Folder with the result pickles.
# NOTE(review): machine-specific absolute path — adjust to your local results folder.
# Only unpickle files you produced yourself: pickle.load can execute arbitrary code.
results_dir = '/Users/thomasschnake/Downloads/temp'

# Errors meaning "this result file is missing or not (completely) written"; those
# samples are skipped, every other error surfaces instead of being swallowed.
skippable_errors = (FileNotFoundError, EOFError, pickle.UnpicklingError)

# all_output_orderings[(auc_task, perturbation_type)][attribution_method] -> merged per-sample entries
all_output_orderings = {param: {attribution_method: {} for attribution_method in attribution_methods}
                        for param in optimize_parameter}
n_skipped = 0

for i in sample_range:
    if data_mode == 'sst':
        # One pickle per sample holding every (auc_task, perturbation_type) setting.
        file_names = [f'{results_dir}/perturbation_results_{data_mode}_{i}.pkl']
    elif data_mode in ['imdb', 'fer']:
        # One pickle per sample and per (auc_task, perturbation_type) setting.
        file_names = [f'{results_dir}/perturbation_results_{data_mode}_{i}_{par1}_{par2}_orderings.pkl'
                      for par1, par2 in optimize_parameter]
    else:
        raise ValueError(f"unsupported data_mode {data_mode!r} for loading feature orderings")

    # Fresh per sample, so a skipped file never re-merges the previous sample's results.
    output_orderings = {}
    for file_name in file_names:
        try:
            # 'rb' (not 'rb+'): reading must not require write permission.
            with open(file_name, 'rb') as f:
                output_orderings.update(pickle.load(f))
        except skippable_errors:
            n_skipped += 1

    for param, per_method in output_orderings.items():
        for attribution_method in attribution_methods:
            all_output_orderings[param][attribution_method].update(per_method.get(attribution_method, {}))

if n_skipped:
    print(f'skipped {n_skipped} missing or unreadable result files in {results_dir}')

# Number of loaded samples per method and setting.
for attribution_method in attribution_methods:
    for auc_task, perturbation_type in optimize_parameter:
        print(attribution_method, auc_task, perturbation_type, '\t\t', len(all_output_orderings[(auc_task,perturbation_type)][attribution_method].keys()))
        