In [None]:
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Import required libraries
import json
import os
import subprocess
import sys
from collections import defaultdict
from contextlib import ExitStack

import more_itertools as mit
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from transformers import BertTokenizer

# Import preprocessing utilities
from Preprocess.dataCollect import get_annotated_data
from Preprocess.spanMatcher import returnMask

## 1. Load and Prepare Data

In [None]:
# Dataset file and class-label file, keyed by the number of target classes
dict_data_folder = {
    '2': {
        'data_file': 'Data/dataset.json',
        'class_label': 'Data/classes_two.npy',
    },
    '3': {
        'data_file': 'Data/dataset.json',
        'class_label': 'Data/classes.npy',
    },
}

# 3-class setup: hatespeech, offensive, normal
params = {'num_classes': 3}
params.update(
    data_file=dict_data_folder[str(params['num_classes'])]['data_file'],
    class_names=dict_data_folder[str(params['num_classes'])]['class_label'],
)

# Load the annotated dataset described by `params`
data_all_labelled = get_annotated_data(params)
print("Loaded {} samples".format(len(data_all_labelled)))

In [None]:
# Show the first five rows as a quick look at the loaded data
data_all_labelled.head(5)

In [None]:
# Count how many posts carry each final label
print("\nLabel distribution:")
label_counts = data_all_labelled['final_label'].value_counts()
print(label_counts)

## 2. Configure Tokenization Parameters

In [None]:
# Tokenization and attention settings; passed on to returnMask via get_training_data
params_data = dict(
    include_special=False,      # Include special tokens in attention
    bert_tokens=False,          # Use BERT tokenizer (set True for BERT models)
    type_attention='softmax',   # Attention type
    set_decay=0.1,              # Decay parameter
    majority=2,                 # Majority threshold for rationales
    max_length=128,             # Maximum sequence length
    variance=10,                # Variance parameter
    window=4,                   # Window size
    alpha=0.5,                  # Alpha parameter
    p_value=0.8,                # P-value threshold
    method='additive',          # Attention combination method
    decay=False,                # Use decay
    normalized=False,           # Normalize attention
    not_recollect=True,         # Don't recollect data
)

# Choose the tokenizer: None selects the non-BERT path
if not params_data['bert_tokens']:
    print('Using standard (non-BERT) tokenizer...')
    tokenizer = None
else:
    print('Loading BERT tokenizer...')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=False)

## 3. Extract Rationales

In [None]:
def get_training_data(data, params_data, tokenizer):
    """
    Build one record per decided post, pairing its tokens with rationale masks.

    Rows whose 'final_label' is 'undecided' are dropped; every other row is
    passed to returnMask together with `params_data` and `tokenizer`.

    Args:
        data: DataFrame with 'post_id', 'final_label', 'label1', 'label2'
            and 'label3' columns
        params_data: Configuration parameters forwarded to returnMask
        tokenizer: BertTokenizer or None, forwarded to returnMask

    Returns:
        List of [post_id, annotation, tokens, attention_masks, annotation_list]
    """
    print(f'Processing {len(data)} samples...')
    records = []

    for _, sample in tqdm(data.iterrows(), total=len(data)):
        label = sample['final_label']
        sample_id = sample['post_id']
        annotator_labels = [sample[column] for column in ('label1', 'label2', 'label3')]

        # Posts without an agreed label carry no usable annotation
        if label == 'undecided':
            continue

        tokens, masks = returnMask(sample, params_data, tokenizer)
        records.append([sample_id, label, tokens, masks, annotator_labels])

    return records

In [None]:
# Build the per-post token/rationale records
training_data = get_training_data(
    data=data_all_labelled,
    params_data=params_data,
    tokenizer=tokenizer,
)
print("\nProcessed {} valid samples".format(len(training_data)))

## 4. Convert to ERASER Format

In [None]:
def find_ranges(iterable):
    """
    Collapse runs of consecutive integers into compact range markers.

    Args:
        iterable: Sorted list of integers

    Yields:
        The bare integer for a run of length one, otherwise a
        (first, last) tuple with both ends inclusive
    """
    for run in mit.consecutive_groups(iterable):
        members = list(run)
        yield members[0] if len(members) == 1 else (members[0], members[-1])

In [None]:
def get_evidence(post_id, anno_text, explanations):
    """
    Turn a binary rationale mask into ERASER evidence spans.

    Args:
        post_id: Document ID stored as 'docid' on every span
        anno_text: List of tokens
        explanations: Binary list (1 = rationale token)

    Returns:
        List of evidence dictionaries, one per contiguous run of marked
        tokens, with an inclusive 'start_token' and exclusive 'end_token'
    """
    # Positions flagged as rationale
    marked = sorted(pos for pos, flag in enumerate(explanations) if flag == 1)

    evidence = []
    for span in list(find_ranges(marked)):
        # find_ranges yields a bare int for single tokens and a pair for runs
        if isinstance(span, int):
            first, last = span, span
        elif len(span) == 2:
            first, last = span
        else:
            print(f'Error processing span: {span}')
            continue

        start, end = first, last + 1  # make the end index exclusive
        evidence.append({
            "docid": post_id,
            "end_sentence": -1,
            "end_token": end,
            "start_sentence": -1,
            "start_token": start,
            "text": ' '.join(map(str, anno_text[start:end])),
        })

    return evidence

In [None]:
def convert_to_eraser_format(dataset, method, save_split, save_path, id_division,
                             majority_threshold=2):
    """
    Convert dataset to ERASER benchmark format.

    Posts labelled 'normal' are skipped (no rationales expected). For every
    other post the per-annotator rationale masks are merged token by token
    according to `method` and one ERASER document is produced.

    Args:
        dataset: Iterable of rows shaped like the output of get_training_data:
            row[0] post id, row[1] class label, row[2] tokens,
            row[3] sequence of per-annotator binary masks
        method: Rationale combination method: 'union' (any annotator),
            'intersection' (all annotators) or 'majority' (at least
            `majority_threshold` annotators). Any other value falls back
            to 'union'.
        save_split: Whether to write docs/<post_id> files and the
            train/val/test .jsonl files under `save_path`
        save_path: Directory to save files (only used when save_split is true)
        id_division: Dict with 'train', 'val' and 'test' collections of post
            IDs (only read when save_split is true). A post listed in several
            splits is written to the first match in train, val, test order.
        majority_threshold: Minimum number of annotators that must mark a
            token for the 'majority' method

    Returns:
        List of ERASER-formatted documents
    """
    # Per-token vote combiners; unknown methods keep the historical union fallback
    combiners = {
        'union': any,
        'intersection': all,
        'majority': lambda votes: sum(votes) >= majority_threshold,
    }
    combine = combiners.get(method, any)

    final_output = []

    # ExitStack closes the split files even if a row raises midway
    with ExitStack() as stack:
        split_writers = []
        if save_split:
            os.makedirs(save_path, exist_ok=True)
            docs_dir = os.path.join(save_path, 'docs')
            os.makedirs(docs_dir, exist_ok=True)

            for split in ('train', 'val', 'test'):
                handle = stack.enter_context(
                    open(os.path.join(save_path, f'{split}.jsonl'), 'w')
                )
                # Sets make the per-post membership test O(1) instead of a list scan
                split_writers.append((set(id_division[split]), handle))

        for row in dataset:
            post_id, post_class, anno_text_list = row[0], row[1], row[2]

            # Skip 'normal' class (no rationales expected)
            if post_class == 'normal':
                continue

            explanations = [list(exp) for exp in row[3]]
            final_explanation = [int(combine(votes)) for votes in zip(*explanations)]

            # Create ERASER format document
            doc = {
                'annotation_id': post_id,
                'classification': post_class,
                'evidences': [get_evidence(post_id, list(anno_text_list), final_explanation)],
                'query': "What is the class?",
                'query_type': None
            }
            final_output.append(doc)

            if save_split:
                # Save document text
                with open(os.path.join(docs_dir, post_id), 'w') as fp:
                    fp.write(' '.join(str(t) for t in anno_text_list))

                # Save to the first split that lists this post
                for split_ids, handle in split_writers:
                    if post_id in split_ids:
                        handle.write(json.dumps(doc) + '\n')
                        break

    if save_split:
        print(f"Saved ERASER format files to {save_path}")

    return final_output

In [None]:
# Read the predefined train/val/test post-id split
with open('./Data/post_id_divisions.json') as fp:
    id_division = json.load(fp)

for heading, split_name in (("Train: ", 'train'), ("Val:   ", 'val'), ("Test:  ", 'test')):
    print(f"{heading}{len(id_division[split_name])} samples")

In [None]:
# Merge annotator rationales and write the ERASER files
method = 'union'  # Options: 'union', 'intersection', 'majority'
save_split = True
save_path = './Data/Evaluation/Model_Eval/'

output_eraser = convert_to_eraser_format(
    dataset=training_data,
    method=method,
    save_split=save_split,
    save_path=save_path,
    id_division=id_division,
)
print("\nConverted {} samples to ERASER format".format(len(output_eraser)))

In [None]:
# List what was written: line counts for files, entry counts for sub-directories
print("\nGenerated files:")
for item in os.listdir(save_path):
    item_path = os.path.join(save_path, item)
    if not os.path.isfile(item_path):
        doc_count = len(os.listdir(item_path))
        print(f"  {item}/: {doc_count} documents")
        continue

    line_count = 0
    with open(item_path) as f:
        for _ in f:
            line_count += 1
    print(f"  {item}: {line_count} samples")

## 5. Run ERASER Metrics

In [None]:
# Directory holding the per-model explanation dumps
parent_path = './explanations_dicts/'

# Display name of each model -> its explanation file inside parent_path
explanation_file_mapping = dict([
    ('BiRNN-Scrat', 'bestModel_birnnscrat_100_explanation_top5.json'),
    ('BiRNN-Attn', 'bestModel_birnnatt_100_explanation_top5.json'),
    ('CNN-GRU', 'bestModel_cnn_gru_100_explanation_top5.json'),
    ('BERT-Base', 'bestModel_bert_base_uncased_Attn_train_FALSE_100_explanation_top5.json'),
    ('BERT-HateXplain', 'bestModel_bert_base_uncased_Attn_train_TRUE_100_explanation_top5.json'),
])

In [None]:
# Keep only the models whose explanation file is present on disk
print("Available explanation files:")
available_models = {}

for model_name, filename in explanation_file_mapping.items():
    if not os.path.exists(os.path.join(parent_path, filename)):
        print(f"  ✗ {model_name}: {filename} (not found)")
        continue
    print(f"  ✓ {model_name}: {filename}")
    available_models[model_name] = filename

if len(available_models) == 0:
    print("\nNo explanation files found. Run testing_with_rational.py first.")

In [None]:
def run_eraser_metrics(model_name, explanation_file, data_dir, output_file,
                       benchmark_dir='eraserbenchmark', split='test'):
    """
    Run ERASER benchmark metrics for a model.

    The benchmark script rationale_benchmark/metrics.py is executed with
    `benchmark_dir` as its working directory, so relative `explanation_file`,
    `data_dir` and `output_file` paths are interpreted relative to that
    directory, not relative to the notebook. The score file is read back from
    the same location the script wrote it to.

    Args:
        model_name: Name of the model (used for progress messages only)
        explanation_file: Path to model's explanation file
        data_dir: Path to ERASER format data directory
        output_file: Path to save results
        benchmark_dir: Directory containing the ERASER benchmark checkout
        split: Dataset split passed to the benchmark's --split option

    Returns:
        Dict with metric results or None if failed
    """
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through verbatim. The kernel's own interpreter
    # is used so the benchmark sees the same installed packages.
    cmd = [
        sys.executable or 'python',
        os.path.join('rationale_benchmark', 'metrics.py'),
        '--split', split,
        '--data_dir', data_dir,
        '--results', explanation_file,
        '--score_file', output_file,
    ]

    # Same effect as the shell prefix `PYTHONPATH=./:$PYTHONPATH`
    env = dict(os.environ)
    env['PYTHONPATH'] = os.pathsep.join(
        entry for entry in ('./', env.get('PYTHONPATH', '')) if entry
    )

    print(f"\nRunning ERASER metrics for {model_name}...")

    try:
        result = subprocess.run(
            cmd, cwd=benchmark_dir, env=env, capture_output=True, text=True
        )
        if result.returncode != 0:
            print(f"Error: {result.stderr}")
            return None

        # The script resolved output_file against benchmark_dir; do the same here
        # (os.path.join keeps an absolute output_file unchanged).
        with open(os.path.join(benchmark_dir, output_file)) as fp:
            return json.load(fp)
    except Exception as e:
        # Best effort: a failing model must not stop the remaining ones
        print(f"Exception: {e}")
        return None

In [None]:
# Score every model that has an explanation file.
# Paths start with "../" because the benchmark command runs from inside eraserbenchmark/.
all_results = {}

for model_name, filename in available_models.items():
    slug = model_name.replace('-', '_').lower()
    results = run_eraser_metrics(
        model_name,
        f"../{parent_path}{filename}",
        "../Data/Evaluation/Model_Eval",
        f"../explainability_results_{slug}.json",
    )
    if not results:
        continue
    all_results[model_name] = results
    print(f"  ✓ Completed {model_name}")

## 6. Display Results

In [None]:
def display_results(results, model_name):
    """
    Print a plausibility / faithfulness report for one model.

    Sections whose key is absent from `results` are silently left out; the
    two section headings are always printed.

    Args:
        results: Dict with metric results
        model_name: Name of the model
    """
    banner = '=' * 60
    rule = "-" * 40

    def _fmt(value):
        # Numbers get four decimals; anything else (e.g. 'N/A') is shown as is
        return f"{value:.4f}" if isinstance(value, (int, float)) else f"{value}"

    print(f"\n{banner}")
    print(f"Model: {model_name}")
    print(f"{banner}")

    # Plausibility metrics
    print("\nPLAUSIBILITY (Agreement with human rationales):")
    print(rule)

    if 'iou_scores' in results:
        print(f"  IOU F1:        {results['iou_scores'][0]['macro']['f1']:.4f}")

    if 'token_prf' in results:
        macro = results['token_prf']['instance_macro']
        f1, precision, recall = macro['f1'], macro['p'], macro['r']
        print(f"  Token F1:      {f1:.4f}")
        print(f"  Token Prec:    {precision:.4f}")
        print(f"  Token Recall:  {recall:.4f}")

    if 'token_soft_metrics' in results:
        print(f"  AUPRC:         {results['token_soft_metrics']['auprc']:.4f}")

    # Faithfulness metrics
    print("\nFAITHFULNESS (Model reliance on rationales):")
    print(rule)

    if 'classification_scores' in results:
        scores = results['classification_scores']
        comp = scores.get('comprehensiveness', 'N/A')
        suff = scores.get('sufficiency', 'N/A')
        print(f"  Comprehensiveness: {_fmt(comp)}")
        print(f"  Sufficiency:       {_fmt(suff)}")

In [None]:
# Print the metric report of every model that produced results
for name, scores in all_results.items():
    display_results(results=scores, model_name=name)

In [None]:
# Fallback: when nothing was computed in this session, show result files left by earlier runs
if not all_results:
    print("Checking for existing result files...")

    result_files = [
        './model_explain_output.json',
        './explainability_results_birnn_scrat.json',
    ]

    for filepath in filter(os.path.exists, result_files):
        print(f"\nFound: {filepath}")
        with open(filepath) as fp:
            results = json.load(fp)
        display_results(results, os.path.basename(filepath))

## 7. Summary Table

In [None]:
# Create summary table: one row per model, one column per headline metric.
# Metrics missing from a model's results are left empty (NaN) in the table.
if all_results:
    summary_data = []

    for model_name, results in all_results.items():
        row = {'Model': model_name}

        # Plausibility
        if 'iou_scores' in results:
            row['IOU F1'] = results['iou_scores'][0]['macro']['f1']
        if 'token_prf' in results:
            row['Token F1'] = results['token_prf']['instance_macro']['f1']
        if 'token_soft_metrics' in results:
            row['AUPRC'] = results['token_soft_metrics']['auprc']

        # Faithfulness
        if 'classification_scores' in results:
            faithfulness = results['classification_scores']
            row['Comprehensiveness'] = faithfulness.get('comprehensiveness', np.nan)
            row['Sufficiency'] = faithfulness.get('sufficiency', np.nan)

        summary_data.append(row)

    summary_df = pd.DataFrame(summary_data).set_index('Model')

    print("\n" + "="*80)
    print("EXPLAINABILITY METRICS SUMMARY")
    print("="*80)
    print(summary_df.round(4).to_string())
else:
    print("No results available. Run testing_with_rational.py first.")

## 8. Interpretation Guide

### Plausibility Metrics (Agreement with human rationales):

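- **IOU F1**: F1 score over rationale spans, where a predicted span counts as a match when its Intersection over Union (IOU) with a ground-truth span reaches the benchmark's threshold (0.5 by default).
  - Higher is better (closer to 1.0)
  - Measures span-level agreement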

- **Token F1**: Token-level F1 score for rationale identification.
  - Higher is better
  - Measures agreement at individual token level

- **AUPRC**: Area Under Precision-Recall Curve for soft rationale scores.
  - Higher is better
  - Measures ranking quality of attention scores

### Faithfulness Metrics (Model reliance on rationales):

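- **Comprehensiveness**: How much the predicted-class probability drops when the rationale tokens are removed from the input.
  - Higher is better (the rationales were necessary for the prediction)
  - Range: usually 0 to 1; it can be negative when the model becomes more confident after the rationales are removed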

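- **Sufficiency**: How much the predicted-class probability drops when the model sees only the rationale tokens.
  - Lower is better (the rationales alone are sufficient for the prediction)
  - Range: usually 0 to 1; it can be negative when the model is more confident on the rationales alone than on the full input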

### References:
- DeYoung et al. (2020) - "ERASER: A Benchmark to Evaluate Rationalized NLP Models"
- Mathew et al. (2021) - "HateXplain: A Benchmark Dataset for Explainable Hate Speech Detection"