In [None]:
# Install required packages
!pip install transformers torch pandas numpy scikit-learn matplotlib seaborn

import os
import json
import pandas as pd
import numpy as np
from typing import Dict, List, Any
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW

from google.colab import files
import io

In [None]:
# Mount Google Drive in Colab:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Base path to your working folder in Google Drive.
# Left empty here; set it before running cells that read or write files.
# NOTE(review): presumably later cells build dataset/prediction paths from `route` —
# that usage is not visible in this part of the notebook, confirm before relying on it.
route = ''

In [None]:
# Pick the torch device once: "cuda" when a CUDA GPU is available, otherwise "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Report the choice so it is visible in the notebook output.
print(f"Using device: {device}")

# Utils

## **utils_basic**

In [None]:
!pip install scikit-multilearn
!pip install pyevall

from typing import Dict, List, Any
import pandas as pd
import json
import nltk
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from datetime import datetime
from sklearn import metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, multilabel_confusion_matrix
from skmultilearn.problem_transform import BinaryRelevance, LabelPowerset
import matplotlib.pyplot as plt
import numpy as np
import os
import csv

# Import for PyEvALL

from pyevall.evaluation import PyEvALLEvaluation
from pyevall.utils.utils import PyEvALLUtils



class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy arrays and scalars into native Python values."""

    def default(self, obj):
        """
        Return a JSON-serialisable equivalent of ``obj``.

        Parameters:
        - obj: Object the standard encoder could not serialise

        Returns:
        - list for ``np.ndarray``, int for ``np.integer``, float for ``np.floating``,
          bool for ``np.bool_``; anything else is delegated to the base class,
          which raises ``TypeError``
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()

        # Scalar kinds, checked in a fixed order, each paired with its native converter
        scalar_converters = (
            (np.integer, int),
            (np.floating, float),
            (np.bool_, bool),
        )
        for numpy_kind, to_native in scalar_converters:
            if isinstance(obj, numpy_kind):
                return to_native(obj)

        return super().default(obj)
        

### Evaluation metrics
def check_matrix(confusion_matrix, gold_labels, predicted_labels):
    """
    Expand a degenerate 1x1 confusion matrix into a 2x2 matrix.

    Any matrix whose size is not 1 is returned unchanged. A 1x1 matrix holds a
    single count; that count is placed in the 2x2 cell selected by the first
    gold label (row) and the first predicted label (column), where label 0 maps
    to index 0 and any other label maps to index 1.

    Parameters:
    - confusion_matrix (numpy.ndarray): The confusion matrix to check and potentially adjust
    - gold_labels (list or array): The gold standard (true) labels
    - predicted_labels (list or array): The predicted labels

    Returns:
    - numpy.ndarray: The original matrix, or a new 2x2 float matrix with a single non-zero cell
    """
    if confusion_matrix.size != 1:
        return confusion_matrix

    count = confusion_matrix[0][0]
    expanded = np.zeros((2, 2))

    # The first element is used as the representative sample. Indexing element 1
    # would raise IndexError for single-sample inputs.
    # NOTE(review): this assumes a 1x1 matrix means every sample shares the same
    # gold label and the same predicted label (as with sklearn's confusion_matrix).
    row = 0 if gold_labels[0] == 0 else 1
    col = 0 if predicted_labels[0] == 0 else 1
    expanded[row][col] = count

    return expanded

def compute_f1(predicted_values, gold_values):
    """
    Compute the mean of the two per-class F1 scores of a binary labelling.

    The confusion matrix is built with the gold values as first argument and the
    predictions as second, then passed through check_matrix. One F1 score is
    derived from cell [0][0] and one from cell [1][1]; each uses the two
    off-diagonal cells as its error counts.

    Parameters:
    - predicted_values (list or array): Predicted label values
    - gold_values (list or array): Gold standard (true) label values

    Returns:
    - float: Average of the two per-class F1 scores (macro-averaged F1)
    """
    matrix = check_matrix(
        metrics.confusion_matrix(gold_values, predicted_values),
        gold_values,
        predicted_values,
    )

    def _class_f1(hits, same_row_errors, same_col_errors):
        # With no correct hits both ratios are zero, and so is the F1 score
        if hits == 0:
            return 0.0
        row_ratio = hits / (hits + same_row_errors)
        col_ratio = hits / (hits + same_col_errors)
        if (row_ratio + col_ratio) == 0:
            return 0.0
        # Harmonic mean of the two ratios; it is symmetric, so which ratio is
        # called precision and which recall does not change the score
        return 2 * (row_ratio * col_ratio) / (row_ratio + col_ratio)

    first_class_f1 = _class_f1(matrix[0][0], matrix[0][1], matrix[1][0])
    second_class_f1 = _class_f1(matrix[1][1], matrix[1][0], matrix[0][1])

    return (first_class_f1 + second_class_f1) / 2

def retrieve_label_values(ground_truth, model_submission, field_index):
    """
    Collect one label column from the ground truth and from the submission.

    Iterates over the ground-truth entries in their dictionary order and, for each
    key, reads position ``field_index`` from both the ground-truth value and the
    submission value stored under the same key.

    Parameters:
    - ground_truth (dict): Maps a key to a sequence of gold label values
    - model_submission (dict): Maps the same keys to sequences of predicted label values
    - field_index (int): Position of the label to read from each sequence

    Returns:
    - tuple: (gold, pred) lists, aligned element by element
    """
    paired = [
        (gold_row[field_index], model_submission[key][field_index])
        for key, gold_row in ground_truth.items()
    ]

    gold = [gold_value for gold_value, _ in paired]
    pred = [pred_value for _, pred_value in paired]
    return gold, pred

def compute_binary_f1(ground_truth, model_submission):
    """
    Score the binary label, i.e. the first label of every entry.

    Parameters:
    - ground_truth (dict): Dictionary containing ground truth values
    - model_submission (dict): Dictionary containing model submission values

    Returns:
    - float: Result of compute_f1 on the label stored at position 0
    """
    binary_column = 0
    gold_labels, predicted_labels = retrieve_label_values(
        ground_truth, model_submission, binary_column
    )
    return compute_f1(predicted_labels, gold_labels)

def compute_multilabel_f1(truth_data, prediction_data, label_count):
    """
    Compute a support-weighted average of per-label F1 scores.

    Label position 0 (the binary label) is skipped. Every other label contributes
    its compute_f1 score, weighted by how many gold values of that label equal True.

    Parameters:
    - truth_data (dict): Dictionary containing ground truth values
    - prediction_data (dict): Dictionary containing model prediction values
    - label_count (int): Total number of labels including binary classification label

    Returns:
    - float: Weighted F1 score, or 0.0 when no label has a True gold value
    """
    weighted_total = 0
    support_total = 0

    for column in range(1, label_count):
        gold_column, predicted_column = retrieve_label_values(truth_data, prediction_data, column)
        column_f1 = compute_f1(predicted_column, gold_column)
        support = gold_column.count(True)
        support_total += support
        weighted_total += column_f1 * support

    # Without any positive gold value there is nothing to weight by
    if support_total == 0:
        return 0.0
    return weighted_total / support_total

def load_data(filepath):
    """
    Load data from a tab-separated file and convert labels to boolean values.

    Every row must hold a key followed by at least one numeric label; each label
    is parsed with float() and then converted to bool (so "0" / "0.0" become
    False and any other number becomes True). A later row with the same key
    overwrites the earlier one.

    Parameters:
    - filepath (str): Path to the tab-separated data file

    Returns:
    - dict: Dictionary where keys are the first column values, and values are lists of boolean labels

    Raises:
    - ValueError: If a row has fewer than 2 columns, a different column count than
                  the previous row, or a label that is not numeric
    """
    result_dict = {}
    expected_columns = None

    # newline='' is what the csv module asks for when given a file object.
    # UTF-8 is read explicitly to match the pandas to_csv default that writes these files.
    with open(filepath, newline='', encoding='utf-8') as input_file:
        csv_reader = csv.reader(input_file, delimiter='\t')

        for line_num, entry in enumerate(csv_reader, start=1):
            if len(entry) < 2:  # ensure at least one label column is present
                raise ValueError(f'Wrong number of columns in line {line_num}, expected at least 2.')

            if expected_columns is not None and len(entry) != expected_columns:
                raise ValueError(f'Inconsistent number of columns in line {line_num}.')

            expected_columns = len(entry)

            try:
                result_dict[entry[0]] = [bool(float(val)) for val in entry[1:]]
            except ValueError as err:
                # Same exception type as the bare float() failure, now with the line number
                raise ValueError(f'Non-numeric label value in line {line_num}.') from err

    return result_dict

def evaluate_f1_scores(gold_label_path, prediction_path, num_labels):
    """
    Evaluate scores for binary and multi-label classification tasks.

    Both files are read with load_data. The binary score is always computed from
    the first label; when more than two labels are requested the weighted
    multi-label score over the remaining labels is returned as well.

    Parameters:
    - gold_label_path (str): Path to the file containing gold standard labels
    - prediction_path (str): Path to the file containing model predictions
    - num_labels (int): Total number of labels (2 for binary classification, >2 for multi-label)

    Returns:
    - float or tuple: Binary score if num_labels=2, otherwise (binary_score, multilabel_score) tuple

    Raises:
    - ValueError: If num_labels is smaller than 2, or if the submission is missing required keys
    """
    # Reject unsupported label counts up front instead of silently returning None,
    # which callers unpacking a tuple would only notice as an unrelated TypeError.
    if num_labels < 2:
        raise ValueError(f'num_labels must be at least 2, got {num_labels}.')

    truth = load_data(gold_label_path)
    submission = load_data(prediction_path)

    # Ensure submission contains all necessary keys
    for key in truth:
        if key not in submission:
            raise ValueError(f'Missing element {key} in submission')

    binary_score = compute_binary_f1(truth, submission)
    if num_labels == 2:
        return binary_score

    multilabel_score = compute_multilabel_f1(truth, submission, num_labels)
    return binary_score, multilabel_score

### Binary classification

def build_bin_classifier(X_train, y_train):
    """
    Fit a TF-IDF vectorizer and a linear SVM on the training documents.

    Parameters:
    - X_train (list): List of text documents for training
    - y_train (list): Binary labels for training documents

    Returns:
    - tuple: (svm_model, vectorizer) - Trained model and fitted vectorizer
    """
    vectorizer = TfidfVectorizer()
    svm_model = LinearSVC(max_iter=10000)

    # Learn the vocabulary and weights from the training texts, then fit the SVM on them
    training_matrix = vectorizer.fit_transform(X_train)
    svm_model.fit(training_matrix, y_train)

    return svm_model, vectorizer

def classify_data(X_test, model, vectorizer):
    """
    Predict labels for documents with an already fitted vectorizer and model.

    Parameters:
    - X_test (list): List of text documents to classify
    - model: Trained classifier exposing ``predict``
    - vectorizer: Fitted vectorizer exposing ``transform``

    Returns:
    - Whatever ``model.predict`` returns for the transformed documents
    """
    return model.predict(vectorizer.transform(X_test))


### Binary evaluation

def evaluate_binary_classification(gold_label_json, predictions_json,
                                   y_true, y_pred,
                                   gold_labels_txt, predictions_txt,
                                   label_names,
                                   model_name="Model"):
    """
    Print and return evaluation metrics for binary classification.

    Prints the classification report and the raw confusion matrix, shows the
    confusion matrix plot, prints the PyEvALL report (ICM, ICMNorm, FMeasure) and
    the macro-F1 computed by evaluate_f1_scores from the txt files.

    Parameters:
    - gold_label_json (str): Path to the file with gold labels in PyEvALL format
    - predictions_json (str): Path to the file with predicted labels in PyEvALL format
    - y_true (list or array): Gold labels
    - y_pred (list or array): Predicted labels
    - gold_labels_txt (str): Path to the file with gold labels in txt format for f1 metric
    - predictions_txt (str): Path to the file with predictions in txt format for f1 metric
    - label_names (list of str): The list of label names corresponding to the binary problem
    - model_name (str): Name of the model (default: "Model")

    Returns:
    - dict: 'binary_f1' (score from evaluate_f1_scores), 'macro_f1' (from the
            classification report), 'per_label_metrics' (classification report as dict),
            'confusion_matrix' and 'label_names'
    """
    separator = '-' * 100

    # Objects exposing .tolist() are converted as before; plain sequences such as
    # lists (which the signature documents as valid) no longer raise AttributeError.
    y_pred = y_pred.tolist() if hasattr(y_pred, "tolist") else list(y_pred)

    # Print classification report with precision, recall, f1, and support metrics
    report = classification_report(y_true, y_pred,
                                   target_names=label_names,
                                   zero_division=0,
                                   digits=3)
    print(f"{separator}\nClassification Report for {model_name}:\n{report}\n{separator}")

    # Confusion matrix of raw counts, computed once and reused for print, plot and result
    cf_matrix = confusion_matrix(y_true, y_pred)
    print(f"{separator}\nConfusion matrix for {model_name}:")

    print("Raw Confusion Matrix:")
    for i, row in enumerate(cf_matrix):
        print(f"{label_names[i]:>15}: {row.tolist()}")
    print()

    # Display confusion matrix as a heatmap
    disp = ConfusionMatrixDisplay(cf_matrix, display_labels=label_names)
    _, ax = plt.subplots(figsize=(8, 6))

    # NOTE(review): the matrix holds integer counts, so the '.3f' format is chosen
    # whenever any cell is 0 — confirm this is the intended display.
    disp.plot(cmap=plt.cm.Greens, ax=ax, values_format='.3f' if np.any(cf_matrix < 1) else 'd')

    plt.title(f"Confusion Matrix - {model_name}")
    plt.tight_layout()
    plt.show()
    print(f"{separator}")

    # PyEvALL evaluation metrics (predictions file first, gold file second)
    print(f"{separator}\nPyEvaLL Metrics for {model_name}:\n")
    evaluator = PyEvALLEvaluation()
    evaluation_params = dict()
    evaluation_params[PyEvALLUtils.PARAM_REPORT] = PyEvALLUtils.PARAM_OPTION_REPORT_DATAFRAME
    metric_list = ["ICM", "ICMNorm", "FMeasure"]
    evaluation_report = evaluator.evaluate(predictions_json, gold_label_json, metric_list, **evaluation_params)
    evaluation_report.print_report()
    print(f"{separator}")

    # MAMI F1 metric (macro-f1 for binary classification)
    print(f"{separator}\n F1 Metrics for {model_name}:\n")
    n_labels = 2

    # Computed once; the same value is printed and returned
    score_bin = evaluate_f1_scores(gold_labels_txt, predictions_txt, n_labels)
    print(f"Binary classification macro-F1 score: {score_bin:.3f}")
    print(f"{separator}")

    # Structured classification report for the returned results
    class_report_dict = classification_report(y_true, y_pred,
                                              target_names=label_names,
                                              zero_division=0,
                                              digits=3,
                                              output_dict=True)

    results = {
        'binary_f1': score_bin,
        'macro_f1': class_report_dict['macro avg']['f1-score'],
        'per_label_metrics': class_report_dict,
        'confusion_matrix': cf_matrix,
        'label_names': label_names
    }

    return results








### Multi-label classification

def build_multilabel_classifier(X_train, y_train, transform_strategy):
    """
    Fit a TF-IDF vectorizer and a multi-label model wrapping a linear SVM.

    Parameters:
    - X_train (list of str): Training text data
    - y_train (list): Multi-label training labels
    - transform_strategy (skmultilearn model wrapper): Callable that receives the base
      LinearSVC and returns the multi-label model (BinaryRelevance/LabelPowerset)

    Returns:
    - tuple: (ml_model, vectorizer) - Trained multi-label model and fitted vectorizer
    """
    # TF-IDF features with the vectorizer's default settings
    vectorizer = TfidfVectorizer()
    training_matrix = vectorizer.fit_transform(X_train)

    # Wrap the base SVM in the requested problem-transformation strategy and train it
    base_classifier = LinearSVC(max_iter=10000)
    ml_model = transform_strategy(base_classifier)
    ml_model.fit(training_matrix, y_train)

    return ml_model, vectorizer




def build_hierarchical_multilabel_classifier(all_training_texts, binary_labels,
                                             positive_subset_texts,category_labels,
                                             test_texts, binary_label_name,
                                             category_label_names, strategy_class=BinaryRelevance):
    """
    Train and apply a two-stage classifier: binary first, then fine-grained labels.

    Stage one trains a binary model on all training texts and predicts the test
    texts. Stage two trains a multi-label model on the positive training subset
    and applies it only to the test texts predicted as 1; all other test rows
    keep 0 for every fine-grained label. Stage two is skipped entirely when no
    test text is predicted positive.

    Parameters:
    - all_training_texts (list): All training text instances
    - binary_labels (list): Binary labels for all training instances
    - positive_subset_texts (list): Text instances with positive binary labels only
    - category_labels (list): Fine-grained category labels for positive instances
    - test_texts (list): Texts for evaluation
    - binary_label_name (str): Name of the binary label column
    - category_label_names (list): Names of fine-grained category labels
    - strategy_class: Multi-label classification strategy (default: BinaryRelevance)

    Returns:
    - tuple: (predictions_df, bin_model, bin_vec, ml_model, ml_vec)
            where:
            - predictions_df: DataFrame with the binary column followed by the fine-grained columns
            - bin_model: Trained binary classification model
            - bin_vec: Text vectorizer for binary classification
            - ml_model: Trained multi-label model (None if no test text was predicted positive)
            - ml_vec: Text vectorizer for multi-label classification (None in the same case)
    """
    # Stage one: binary model over every training text
    bin_model, bin_vec = build_bin_classifier(all_training_texts, binary_labels)
    binary_predictions = classify_data(test_texts, bin_model, bin_vec)

    # Test texts predicted positive (misogynous/sexist) go on to stage two
    is_positive = binary_predictions == 1
    positive_test_texts = pd.DataFrame(test_texts)[is_positive][0].tolist()

    # Start from the binary predictions with every fine-grained label defaulting to 0
    pred_df = pd.DataFrame({binary_label_name: binary_predictions})
    pred_df[category_label_names] = 0

    # Nothing predicted positive: return without a stage-two model
    if len(positive_test_texts) == 0:
        return pred_df, bin_model, bin_vec, None, None

    # Stage two: fine-grained model trained on the positive training subset
    ml_model, ml_vec = build_multilabel_classifier(
        positive_subset_texts, category_labels, strategy_class)
    multilabel_predictions = classify_data(positive_test_texts, ml_model, ml_vec)

    # Write the fine-grained predictions into the rows predicted positive
    pred_df.loc[is_positive, category_label_names] = multilabel_predictions.toarray()

    return pred_df, bin_model, bin_vec, ml_model, ml_vec



### Multi-label evaluation

def evaluate_multilabel_classification(gold_label_json, predictions_json,
                                       y_true, y_pred,
                                       gold_labels_txt, predictions_txt,
                                       label_names,
                                       hierarchy=True):
    """
    Evaluate the performance of a multi-label classification model with comprehensive metrics.

    Column 0 of ``y_true`` / ``y_pred`` is treated as the binary label and is
    replaced by its inverse (a "non-<binary label>" class) before the
    classification report and per-label confusion matrices are computed.
    The function also prints the PyEvALL report and the scores returned by
    evaluate_f1_scores for the txt files.

    Parameters:
    - gold_label_json (str): Path to the file with gold labels in PyEvALL format
    - predictions_json (str): Path to the file with predicted labels in PyEvALL format
    - y_true (array-like): Gold binary labels (multi-label binary matrix) for the test set
    - y_pred (array-like): Predicted binary labels for the test set; a NumPy array, an
                           object exposing ``toarray()`` (e.g. a sparse matrix) or anything
                           ``np.asarray`` can convert (e.g. a DataFrame)
    - gold_labels_txt (str): Path to the file with gold labels in txt format for MAMI f1 metric
    - predictions_txt (str): Path to the file with predictions in txt format for MAMI f1 metric
    - label_names (list of str): The list of label names corresponding to the multi-label problem,
                                 binary label first
    - hierarchy (bool): Whether the evaluation considers hierarchical evaluation (first binary
                       classification and then multi-label). Default is True.

    Returns:
    - dict: Evaluation results including F1 scores and per-label metrics
    """
    separator = '-' * 100

    # Convert inputs to numpy arrays if they aren't already
    if isinstance(y_pred, np.ndarray):
        y_pred_array = y_pred
    elif hasattr(y_pred, "toarray"):
        y_pred_array = y_pred.toarray()
    else:
        # e.g. a DataFrame or nested list, which has no toarray()
        y_pred_array = np.asarray(y_pred)
    y_true_array = y_true if isinstance(y_true, np.ndarray) else np.array(y_true)

    # Negative class column: 1.0 where the binary label (first column) is 0
    negative_true = (y_true_array[:, 0] == 0).astype(float).reshape(-1, 1)
    negative_pred = (y_pred_array[:, 0] == 0).astype(float).reshape(-1, 1)

    # Combined representation: negative class followed by the fine-grained labels
    combined_true = np.hstack((negative_true, y_true_array[:, 1:]))
    combined_pred = np.hstack((negative_pred, y_pred_array[:, 1:]))

    # Update label names to include negative class (list() so tuples work too)
    updated_labels = [f"non-{label_names[0]}"] + list(label_names[1:])
    total_labels = len(updated_labels)

    # Print classification metrics with 3 decimal places
    print(f"{separator}\nClassification Report:")
    class_report = classification_report(combined_true, combined_pred,
                                         target_names=updated_labels,
                                         zero_division=0,
                                         digits=3)
    print(f"{class_report}\n{separator}")

    # Same report as a dictionary for the structured results
    class_report_dict = classification_report(combined_true, combined_pred,
                                              target_names=updated_labels,
                                              zero_division=0,
                                              digits=3,
                                              output_dict=True)

    # Generate and display confusion matrices (text only)
    print(f"{separator}\nConfusion matrices:")
    confusion_matrices = multilabel_confusion_matrix(combined_true, combined_pred)

    for idx, matrix in enumerate(confusion_matrices):
        print(f"Confusion Matrix for '{updated_labels[idx]}' label:")
        print("Raw Matrix:")
        print(matrix)

        # Row-normalised matrix; rows without samples stay at 0 instead of
        # dividing by zero and printing nan
        if matrix.sum() > 0:
            row_totals = matrix.sum(axis=1, keepdims=True)
            matrix_normalized = np.divide(matrix, row_totals,
                                          out=np.zeros(matrix.shape, dtype=float),
                                          where=row_totals != 0)
            print("Normalized Matrix (3 decimal places):")
            for i, row in enumerate(matrix_normalized):
                formatted_row = [f"{val:.3f}" for val in row]
                print(f"  {['False', 'True'][i]:>5}: {formatted_row}")
        print()

    # Run PyEvALL evaluation
    print(f"{separator}\nPyEvaLL Metrics:\n")
    evaluator = PyEvALLEvaluation()
    evaluation_params = dict()

    # Configure hierarchical evaluation if needed
    if hierarchy:
        label_hierarchy = {"yes": label_names, "no": []}
        evaluation_params[PyEvALLUtils.PARAM_HIERARCHY] = label_hierarchy
        metric_list = ["ICM", "ICMNorm", "FMeasure"]
    else:
        metric_list = ["FMeasure"]

    # Set report format
    evaluation_params[PyEvALLUtils.PARAM_REPORT] = PyEvALLUtils.PARAM_OPTION_REPORT_DATAFRAME

    # PyEvALLEvaluation.evaluate takes the predictions file first and the gold
    # file second; passing them the other way round scores the gold labels
    # against the predictions.
    evaluation_report = evaluator.evaluate(predictions_json, gold_label_json, metric_list, **evaluation_params)
    evaluation_report.print_report()
    print(f"{separator}")

    # Calculate F1 metrics with 3 decimal places
    bin_score, ml_score = evaluate_f1_scores(gold_labels_txt, predictions_txt, total_labels)

    print(f"Binary classification macro-F1 score: {bin_score:.3f}")
    print(f"Multi-label classification weighted-F1 score: {ml_score:.3f}")
    print(f"{separator}")

    results = {
        'binary_f1': bin_score,
        'macro_f1': class_report_dict['macro avg']['f1-score'],
        'multilabel_f1': ml_score,
        'per_label_metrics': class_report_dict,
        'confusion_matrices': confusion_matrices,
        'label_names': updated_labels,
        'original_label_names': label_names
    }

    return results












### Convert predictions to PyEvALL format

def format_pred_for_pyevall(df, binary_label, labels, test_case, eval_type, pred_label):
    """
    Convert predictions for a meme dataset into the record format used by PyEvALLEvaluation.

    Parameters:
    - df (pandas.DataFrame): A DataFrame containing a "meme id" column
    - binary_label (str): The name of the binary label column (sexist or misogynous).
                          Kept for interface compatibility; not used by this function.
    - labels (list): Label names; labels[0] is the binary label, labels[1:] the
                     fine-grained categories
    - test_case (str): The test case identifier to be added as a new column, e.g. "MAMI" or "EXIST2024"
    - eval_type (str): The type of evaluation format to be used:
        - "binary": For binary classification
        - "hierarchical": For multi-label classification
    - pred_label: Predicted labels.
        - "binary": one value per row; 0 becomes "no" and 1 becomes "yes"
        - "hierarchical": 2-D predictions as a DataFrame, NumPy array or object with
          ``toarray()``, matched to the rows of ``df`` by position. Either one column
          per fine-grained category, or one column per entry of ``labels`` (binary
          column first, which is then ignored).

    Returns:
    - list: A list of dictionaries formatted for PyEvALLEvaluation with structure:
        - test_case: The name of the dataset
        - id: The meme id as a string
        - value: "yes"/"no" for binary; for hierarchical, the list of predicted
                 category names, or ["no"] when none is predicted

    Raises:
    - ValueError: If eval_type is not supported, or the hierarchical predictions are
                  not 2-D or have fewer columns than there are categories
    """
    if eval_type not in ("binary", "hierarchical"):
        raise ValueError(
            f"Unsupported eval_type {eval_type!r}; expected 'binary' or 'hierarchical'."
        )

    # Start from the id column and add the test_case column required by the library
    pred_labels = df[["meme id"]].copy()
    pred_labels.insert(0, "test_case", [test_case] * (len(pred_labels)), True)

    if eval_type == "binary":
        pred_labels["value"] = pred_label

        # Convert values to "yes" and "no" as required by PyEvALL
        labels_df = pred_labels.replace({"value": 0}, "no").replace({"value": 1}, "yes")

    else:
        # Positional alignment: drop df's index so row i of the predictions goes to row i
        labels_df = pred_labels.reset_index(drop=True)

        if isinstance(pred_label, pd.DataFrame):
            flags = pred_label.to_numpy()
        elif hasattr(pred_label, "toarray"):
            flags = pred_label.toarray()
        else:
            flags = np.asarray(pred_label)

        if flags.ndim != 2:
            raise ValueError("Hierarchical predictions must be two-dimensional.")

        categories = list(labels[1:])

        # One column per label means the binary column comes first (the layout the
        # hierarchical classifier and the multi-label evaluation in this file use);
        # skip it so that column i lines up with categories[i].
        if flags.shape[1] == len(labels):
            flags = flags[:, 1:]

        if flags.shape[1] < len(categories):
            raise ValueError(
                f"Expected at least {len(categories)} category columns, got {flags.shape[1]}."
            )

        # List of predicted category names per row, ["no"] when none is set
        values = [
            [name for name, flag in zip(categories, row) if flag] or ["no"]
            for row in flags
        ]
        labels_df["value"] = pd.Series(values, index=labels_df.index, dtype=object)

    # Rename the id column to match PyEvALL requirements and store ids as strings
    labels_df = labels_df.rename(columns={"meme id": "id"})
    labels_df["id"] = labels_df["id"].astype(str)

    # Convert DataFrame to list of dictionaries as required by PyEvALL
    return labels_df.to_dict(orient="records")


### Convert predictions to MAMI f1 evaluation format

def format_pred_for_mami_f1(df, eval_type, pred_label):
    """
    Combine meme ids with predictions in the column layout used for the MAMI F1 files.

    Parameters:
    - df (pandas.DataFrame): DataFrame with a "meme id" column
    - eval_type (str): Specifies the evaluation approach:
        - "binary": predictions are attached as a single "value" column
        - any other value: predictions are treated as a 2-D matrix and attached
          as one column per label, matched to the rows of ``df`` by position
    - pred_label (np.ndarray, pd.DataFrame or object with ``toarray()``): Classification
      results from model prediction

    Returns:
    - pandas.DataFrame: "meme id" (as strings) followed by the prediction column(s)
    """
    # Extract just the identifier column to a new DataFrame
    pred_labels = df[["meme id"]].copy()

    if eval_type == "binary":
        # Binary case: simply attach prediction vector as a value column
        pred_labels["value"] = pred_label

    else:
        # Bring the predictions to a plain array first. A DataFrame has no
        # toarray(), so it needs its own branch to be accepted as documented.
        if isinstance(pred_label, pd.DataFrame):
            pred_label = pred_label.to_numpy()
        elif not isinstance(pred_label, np.ndarray):
            pred_label = pred_label.toarray()

        pred_df = pd.DataFrame(pred_label)
        # Reset the index so ids and predictions line up by position in the concat
        pred_labels = pred_labels.reset_index(drop=True)
        pred_labels = pd.concat([pred_labels, pred_df], axis=1)

    # Standardize identifier type to string format
    pred_labels["meme id"] = pred_labels["meme id"].astype(str)

    return pred_labels



### Save predictions for evaluation

def write_labels_to_json(label_list, output_file, dataset_name, split_name, eval_type):
    """
    Serialise the label records to a JSON file and print a confirmation line.

    The file is written as UTF-8 with a 4-space indent and without escaping
    non-ASCII characters.

    Parameters:
    - label_list (list): List of dictionaries containing test case, meme id and labels
    - output_file (str): Path to the output JSON file
    - dataset_name (str): Name of the dataset (e.g., MAMI, EXIST2024), used in the message
    - split_name (str): Name of the data split (e.g., training, test), used in the message
    - eval_type (str): Type of evaluation (binary, flat, hierarchical), used in the message

    Returns:
    - None
    """
    with open(output_file, "w", encoding="utf-8") as handle:
        handle.write(json.dumps(label_list, ensure_ascii=False, indent=4))

    confirmation = f"Saved {dataset_name} {split_name} split {eval_type} evaluation to {output_file}"
    print(confirmation)


def write_labels_to_txt(labels_df, output_path, dataset_name, split_name):
    """
    Dump the labels DataFrame as a tab-separated file and print a confirmation line.

    Neither the header row nor the index is written.

    Parameters:
    - labels_df (pandas.DataFrame): DataFrame containing meme id and associated labels
    - output_path (str): Path to the output text file
    - dataset_name (str): Name of the dataset (e.g., MAMI, EXIST2024), used in the message
    - split_name (str): Name of the data split (e.g., training, test), used in the message

    Returns:
    - None
    """
    tsv_options = {"sep": '\t', "header": False, "index": False}
    labels_df.to_csv(output_path, **tsv_options)

    confirmation = f"Saved {dataset_name} {split_name} split to {output_path}"
    print(confirmation)


def save_evaluation(df, pred_dir, dataset_name, split_name, eval_type, model_name, predictions, binary_label, labels):
    """
    Write a model's predictions to disk in two formats and return both file paths.

    A JSON file is produced from `format_pred_for_pyevall` and a tab-separated TXT
    file from `format_pred_for_mami_f1`. Both are stored in `<pred_dir>/<dataset_name>/`,
    which is created when missing.

    Parameters:
    - df (pandas.DataFrame): Source dataset passed to both formatting helpers
    - pred_dir (str): Root directory for prediction files
    - dataset_name (str): Dataset identifier; "EXIST" is rewritten to "EXIST2024"
    - split_name (str): Partition identifier (e.g., "train", "dev", "test")
    - eval_type (str): Classification approach used ("binary", "hierarchical", "flat")
    - model_name (str): Classifier identifier used as file name prefix
    - predictions (np.ndarray or pandas.DataFrame): Model predictions; a DataFrame is
      converted to a NumPy array for the TXT export only
    - binary_label (str): Name of the binary label column, forwarded to the PyEvALL formatter
    - labels (list): Label column names, forwarded to the PyEvALL formatter

    Returns:
    - tuple: (json_filepath, txt_filepath)

    Note: the TXT file name does not contain `eval_type`, so exports of the same
    model/dataset/split with different evaluation types overwrite each other.
    """
    if dataset_name == "EXIST":
        dataset_name = "EXIST2024"

    # All files of one dataset live in their own sub-directory
    target_dir = os.path.join(pred_dir, dataset_name)
    os.makedirs(target_dir, exist_ok=True)

    # --- JSON export (PyEvALL) ---
    pyevall_records = format_pred_for_pyevall(
        df, binary_label, labels, dataset_name, eval_type, predictions
    )
    json_name = f"{model_name}_{dataset_name}_{split_name}_{eval_type}.json"
    json_filepath = os.path.join(target_dir, json_name)
    write_labels_to_json(pyevall_records, json_filepath, dataset_name, split_name, eval_type)

    # --- TXT export (MAMI F1) ---
    # The MAMI formatter receives a plain array even when predictions come as a DataFrame
    if isinstance(predictions, pd.DataFrame):
        mami_input = predictions.to_numpy()
    else:
        mami_input = predictions
    mami_table = format_pred_for_mami_f1(df, eval_type, mami_input)
    txt_name = f"{model_name}_{dataset_name}_{split_name}_answer.txt"
    txt_filepath = os.path.join(target_dir, txt_name)
    write_labels_to_txt(mami_table, txt_filepath, dataset_name, split_name)

    return json_filepath, txt_filepath

def save_predictions_csv(test_df, predictions, column_names, output_file):
    """
    Append one `<label>_prediction` column per label to the test data and save it as CSV.

    Parameters:
    - test_df (pd.DataFrame): Original test instances with gold labels (left unmodified).
    - predictions (list or np.array): Model predictions, one row per test instance.
    - column_names (list): Label names; each becomes a `<name>_prediction` column.
    - output_file (str): Path to the output CSV file.
    """
    prediction_columns = [f"{name}_prediction" for name in column_names]
    predicted = pd.DataFrame(predictions, columns=prediction_columns)

    # Rows are matched by position, so the original index is replaced by 0..n-1
    gold = test_df.reset_index(drop=True)

    # Gold columns first, prediction columns after them
    combined = pd.concat([gold, predicted], axis=1)
    combined.to_csv(output_file, index=False)

    print(f"Predictions saved to {output_file}")

## **utils_experiments**

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import re
import nltk
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import random
from collections import Counter
import nltk
import string


# Load NRC lexicon for emotion and sentiment analysis

def load_nrc_lexicon(path_nrc):
    """
    Read the NRC Emotion Lexicon into a word -> emotions mapping.

    Each usable line has exactly three tab-separated fields: word, emotion and
    a 0/1 flag. Lines with a different number of fields are ignored.

    Parameters:
    -----------
    path_nrc : str
        Path to the NRC Emotion Lexicon file (NRC-Emotion-Lexicon-Wordlevel-v0.92.txt)

    Returns:
    --------
    dict
        A dictionary where:
        - Keys are words that have at least one emotion/sentiment flagged with 1
        - Values are lists of those emotions/sentiments, in file order
    """
    word_to_emotions = {}
    with open(path_nrc, 'r', encoding='utf-8') as lexicon_file:
        for raw_line in lexicon_file:
            fields = raw_line.strip().split('\t')
            if len(fields) != 3:
                continue

            word, emotion, flag = fields
            # Only associations flagged with 1 are kept
            if int(flag) != 1:
                continue
            word_to_emotions.setdefault(word, []).append(emotion)

    return word_to_emotions



# Load HurtLex for hate speech terms

def load_hurtlex(file_path):
    """
    Load HurtLex lexicon from TSV file.

    The first line is treated as a header and skipped. Every following line with
    at least six tab-separated fields (id, pos, category, stereotype, lemma, level)
    contributes its category to the lemma's list. A lemma listed several times
    keeps one entry per line, so its list may contain repeated categories.

    Parameters:
    file_path (str): Path to the HurtLex TSV file

    Returns:
    dict: Dictionary mapping lemmas to their list of hate speech category codes.
          Empty if the file is missing (a message is printed) or has no data rows.
    """
    hurtlex = {}

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Skip the header; the default keeps an empty file from raising StopIteration
            next(f, None)

            for line in f:
                parts = line.strip().split('\t')
                # Rows with fewer than six fields are ignored
                if len(parts) < 6:
                    continue

                # parts[2] is the category code (om, qas, etc.), parts[4] the lemma
                category, lemma = parts[2], parts[4]
                hurtlex.setdefault(lemma, []).append(category)
    except FileNotFoundError:
        print(f"Error: Could not find HurtLex file at {file_path}")
        return {}

    return hurtlex

# Load function words dictionary

def load_function_words():
    """
    Build the function word dictionary, grouped by closed-class category.

    A new dictionary is created on every call. Some words are listed under
    more than one category, and a few lists contain multi-word entries
    (e.g. 'no one', 'as long as').

    Returns:
    --------
    dict
        A dictionary where:
        - Keys are closed class categories
        - Values are lists of words belonging to each category
    """
    determiners = [
        'the', 'a', 'an', 'this', 'that', 'these', 'those', 'my', 'your', 'his', 'her', 'its',
        'our', 'their', 'any', 'each', 'every', 'some', 'all', 'both', 'either', 'neither',
        'few', 'many', 'much', 'several', 'more', 'most', 'less', 'no', 'enough', 'which', 'what', 'whose',
    ]

    pronouns = [
        'i', 'me', 'my', 'mine', 'myself', 'you', 'your', 'yours', 'yourself', 'yourselves',
        'he', 'him', 'his', 'himself', 'she', 'her', 'hers', 'herself', 'it', 'its', 'itself',
        'we', 'us', 'our', 'ours', 'ourselves', 'they', 'them', 'their', 'theirs', 'themselves',
        'who', 'whom', 'whose', 'which', 'what', 'this', 'that', 'these', 'those',
        'anybody', 'somebody', 'nobody', 'everybody', 'anyone', 'someone', 'no one', 'everyone',
        'each', 'either', 'neither', 'one', 'all', 'some', 'many', 'few', 'any', 'none', 'both',
    ]

    prepositions = [
        'of', 'at', 'in', 'on', 'without', 'between', 'under', 'over', 'beside', 'through',
        'during', 'among', 'across', 'against', 'towards', 'around', 'before', 'after',
        'along', 'behind', 'below', 'beyond', 'despite', 'except', 'from', 'inside', 'near',
        'onto', 'outside', 'past', 'since', 'till', 'until', 'upon', 'within', 'about', 'above',
        'beneath', 'beside', 'by', 'down', 'into', 'like', 'off', 'out', 'throughout', 'to',
        'toward', 'underneath', 'unto', 'up', 'with', 'without', 'regarding', 'round',
    ]

    conjunctions = [
        'and', 'but', 'or', 'nor', 'for', 'yet', 'so', 'although', 'because', 'since',
        'unless', 'until', 'when', 'while', 'whereas', 'after', 'before', 'if', 'then',
        'even though', 'though', 'as long as', 'provided that', 'however', 'therefore',
        'thus', 'moreover', 'nevertheless',
    ]

    auxiliary_verbs = [
        'am', 'is', 'are', 'was', 'were', 'be', 'being', 'been',
        'have', 'has', 'had', 'having',
        'do', 'does', 'did',
        'will', 'would', 'shall', 'should', 'can', 'could', 'may', 'might', 'must', 'ought',
    ]

    enumerators = [
        'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten',
        'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth', 'ninth', 'tenth',
        'once', 'twice', 'thrice',
    ]

    particles = [
        'no', 'not', 'nor', 'as', 'to', 'up', 'out', 'off', 'down', 'about',
        'around', 'aside', 'away', 'back', 'apart',
    ]

    qualifiers = [
        'very', 'really', 'quite', 'somewhat', 'rather', 'too', 'pretty', 'fairly', 'slightly',
        'almost', 'nearly', 'barely', 'hardly', 'scarcely', 'completely', 'absolutely', 'totally',
        'utterly', 'extremely', 'especially', 'particularly', 'specifically',
    ]

    interjections = ['oh', 'ah', 'ugh', 'hey', 'oops', 'gadzooks', 'wow', 'ouch', 'eh', 'hmm']

    # Insertion order below defines the key order of the returned dictionary
    return {
        'determiners': determiners,
        'pronouns': pronouns,
        'prepositions': prepositions,
        'conjunctions': conjunctions,
        'auxiliary_verbs': auxiliary_verbs,
        'enumerators': enumerators,
        'particles': particles,
        'qualifiers': qualifiers,
        'interjections': interjections,
    }


def get_all_categories(hurtlex_dict):
    """Return the sorted list of distinct categories found across all lexicon entries"""
    unique_categories = {
        category
        for word_categories in hurtlex_dict.values()
        for category in word_categories
    }
    return sorted(unique_categories)

def get_words_by_category(hurtlex_dict, category):
    """Return the words whose category list contains `category`, in lexicon order"""
    return [
        word
        for word, word_categories in hurtlex_dict.items()
        if category in word_categories
    ]

def count_words_by_category(hurtlex_dict):
    """Map every category in the lexicon to the number of words that carry it"""
    return {
        category: len(get_words_by_category(hurtlex_dict, category))
        for category in get_all_categories(hurtlex_dict)
    }

def clean_tokens(tokens):
    """
    Strip leading and trailing ASCII punctuation from every token.

    Punctuation inside a token is kept; a token made only of punctuation
    becomes an empty string. The result has the same length as the input.
    """
    strippable = string.punctuation
    cleaned = []
    for token in tokens:
        cleaned.append(token.strip(strippable))
    return cleaned



def extract_emotion_words(text, nrc_lexicon, emotion_category):
    """
    Collect the words of a text that the NRC lexicon links to one emotion.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation before the lookup.

    Parameters:
    -----------
    text : str
        The text to analyze
    nrc_lexicon : dict
        The NRC emotion lexicon (word -> list of emotions/sentiments)
    emotion_category : str
        Emotion category to extract (e.g., 'anger', 'joy')

    Returns:
    --------
    list
        Matching words in text order; repeated words appear repeatedly
    """
    cleaned = clean_tokens(nltk.word_tokenize(text.lower()))
    # Drops empty strings (and any token that is a substring of string.punctuation)
    candidates = [token for token in cleaned if token not in string.punctuation]

    return [
        token
        for token in candidates
        if token in nrc_lexicon and emotion_category in nrc_lexicon[token]
    ]

def extract_sentiment_words(text, nrc_lexicon, sentiment_type):
    """
    Collect the words of a text that the NRC lexicon tags with one sentiment.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation before the lookup.

    Parameters:
    -----------
    text : str
        The text to analyze
    nrc_lexicon : dict
        The NRC emotion lexicon (word -> list of emotions/sentiments)
    sentiment_type : str
        Either 'positive' or 'negative'

    Returns:
    --------
    list
        Matching words in text order; repeated words appear repeatedly
    """
    cleaned = clean_tokens(nltk.word_tokenize(text.lower()))
    # Drops empty strings (and any token that is a substring of string.punctuation)
    candidates = [token for token in cleaned if token not in string.punctuation]

    return [
        token
        for token in candidates
        if token in nrc_lexicon and sentiment_type in nrc_lexicon[token]
    ]

def extract_function_words(text, function_words_dict, function_category):
    """
    Collect the words of a text that belong to one function word category.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation before the lookup.

    Parameters:
    -----------
    text : str
        The text to analyze
    function_words_dict : dict
        Dictionary of function word categories (category -> collection of words)
    function_category : str
        Category of function words to extract

    Returns:
    --------
    list
        Matching words in text order; empty when the category is not in the dictionary
    """
    cleaned = clean_tokens(nltk.word_tokenize(text.lower()))
    # Drops empty strings (and any token that is a substring of string.punctuation)
    candidates = [token for token in cleaned if token not in string.punctuation]

    if function_category not in function_words_dict:
        return []

    members = function_words_dict[function_category]
    return [token for token in candidates if token in members]

def extract_hate_speech_terms(text, hurtlex_dict, hurtlex_category):
    """
    Collect the words of a text that HurtLex lists under one category.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation before the lookup.

    Parameters:
    -----------
    text : str
        The text to analyze
    hurtlex_dict : dict
        The HurtLex dictionary (word -> list of category codes)
    hurtlex_category : str
        HurtLex category to extract (e.g., 'ps', 'om')

    Returns:
    --------
    list
        Matching words in text order; repeated words appear repeatedly
    """
    cleaned = clean_tokens(nltk.word_tokenize(text.lower()))
    # Drops empty strings (and any token that is a substring of string.punctuation)
    candidates = [token for token in cleaned if token not in string.punctuation]

    return [
        token
        for token in candidates
        if token in hurtlex_dict and hurtlex_category in hurtlex_dict[token]
    ]




def replace_mask_features(text, feature_words, mask_token="[MASK]"):
    """
    Substitute every occurrence of the given feature words with a mask token.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation; the result is the remaining tokens joined by
    single spaces. Feature words are matched case-insensitively.

    Parameters:
    -----------
    text : str
        The text to process
    feature_words : list
        List of words to mask
    mask_token : str
        Token to replace feature words with

    Returns:
    --------
    str
        Normalized text with the feature words masked
    int
        Count of words masked
    """
    cleaned = clean_tokens(nltk.word_tokenize(text.lower()))
    # Drops empty strings (and any token that is a substring of string.punctuation)
    candidates = [token for token in cleaned if token not in string.punctuation]

    targets = {feature.lower() for feature in feature_words}

    # One boolean per token: True where the token has to be masked
    is_feature = [token.lower() in targets for token in candidates]
    output_tokens = [
        mask_token if hit else token
        for token, hit in zip(candidates, is_feature)
    ]

    return ' '.join(output_tokens), sum(is_feature)

def remove_features(text, feature_words):
    """
    Remove specified features from text.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation; the result is the remaining tokens joined by
    single spaces. Feature words are matched case-insensitively, mirroring
    replace_mask_features, so masking and removal ablate the same tokens.

    Parameters:
    -----------
    text : str
        The text to process
    feature_words : list
        List of words to remove

    Returns:
    --------
    str
        Normalized text with specified words removed
    int
        Count of words removed
    """
    tokens = nltk.word_tokenize(text.lower())
    tokens = clean_tokens(tokens)
    # Drops empty strings (and any token that is a substring of string.punctuation)
    words = [w for w in tokens if w not in string.punctuation]

    # The tokens are lower-cased, so the feature words must be as well; the set
    # also makes each membership test O(1)
    feature_words_lower = {word.lower() for word in feature_words}

    filtered_words = [word for word in words if word.lower() not in feature_words_lower]
    removed_count = len(words) - len(filtered_words)

    return ' '.join(filtered_words), removed_count

##########################################################################################################

def replace_mask_random_words(text, num_words_to_mask, mask_token="[MASK]"):
    """
    Replace randomly chosen words of the text with a mask token.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation; positions are drawn with random.sample, so the
    outcome depends on the state of the global `random` generator.

    Parameters:
    -----------
    text : str
        The text to process
    num_words_to_mask : int
        Number of words to mask
    mask_token : str
        Token to replace words with

    Returns:
    --------
    str
        Normalized text with random words masked
    int
        Actual count of words masked (capped at the number of words in the text)
    """
    cleaned = clean_tokens(nltk.word_tokenize(text.lower()))
    # Drops empty strings (and any token that is a substring of string.punctuation)
    words = [token for token in cleaned if token not in string.punctuation]

    # Never ask for more positions than there are words
    mask_count = min(num_words_to_mask, len(words))
    chosen_positions = set(random.sample(range(len(words)), mask_count))

    masked_words = [
        mask_token if position in chosen_positions else word
        for position, word in enumerate(words)
    ]

    return ' '.join(masked_words), mask_count

def remove_random_words(text, num_words_to_remove):
    """
    Delete randomly chosen words from the text.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation; positions are drawn with random.sample, so the
    outcome depends on the state of the global `random` generator.

    Parameters:
    -----------
    text : str
        The text to process
    num_words_to_remove : int
        Number of words to remove

    Returns:
    --------
    str
        Normalized text with random words removed
    int
        Actual count of words removed (capped at the number of words in the text)
    """
    cleaned = clean_tokens(nltk.word_tokenize(text.lower()))
    # Drops empty strings (and any token that is a substring of string.punctuation)
    words = [token for token in cleaned if token not in string.punctuation]

    # Never ask for more positions than there are words
    removal_count = min(num_words_to_remove, len(words))
    dropped_positions = set(random.sample(range(len(words)), removal_count))

    surviving_words = [
        word
        for position, word in enumerate(words)
        if position not in dropped_positions
    ]

    return ' '.join(surviving_words), removal_count

def calculate_ablation_statistics(texts, feature_extraction_func, feature_data):
    """
    Summarize how many words a feature ablation would affect across a corpus.

    Each text is lower-cased and split on whitespace; a word counts as ablated
    when it is contained in the feature words extracted for that same text.

    Parameters:
    -----------
    texts : list
        List of text documents
    feature_extraction_func : function
        Called as feature_extraction_func(text, feature_data); returns the
        feature words of that text
    feature_data : dict/object
        Data needed by the feature extraction function

    Returns:
    --------
    dict
        Dictionary containing:
        - total_ablated_words: Total number of words ablated
        - ablated_words_per_doc: List of counts for each document
        - avg_ablated_words: Average number of ablated words per document (0 for no texts)
        - most_common_ablated: List of the 10 most frequent (word, count) pairs
    """
    hits_per_doc = []
    for text in texts:
        # Feature words are extracted per document, not once for the corpus
        feature_words = feature_extraction_func(text, feature_data)
        hits_per_doc.append(
            [token for token in text.lower().split() if token in feature_words]
        )

    ablated_words_per_doc = [len(hits) for hits in hits_per_doc]
    frequencies = Counter(token for hits in hits_per_doc for token in hits)

    if ablated_words_per_doc:
        avg_ablated_words = np.mean(ablated_words_per_doc)
    else:
        avg_ablated_words = 0

    return {
        'total_ablated_words': sum(ablated_words_per_doc),
        'ablated_words_per_doc': ablated_words_per_doc,
        'avg_ablated_words': avg_ablated_words,
        'most_common_ablated': frequencies.most_common(10)
    }


def extract_features_from_texts(texts, nrc_lexicon, hurtlex_dict, function_words_dict):
    """
    Extract various linguistic features from a list of texts.

    Every text is lower-cased and split on whitespace once; the resulting
    tokens are then matched against each emotion, sentiment, function word
    and HurtLex category.

    Parameters:
    -----------
    texts : list
        List of texts to analyze
    nrc_lexicon : dict
        The NRC emotion lexicon (word -> list of emotions/sentiments)
    hurtlex_dict : dict
        The HurtLex dictionary (word -> list of category codes)
    function_words_dict : dict
        Dictionary of function word categories (category -> collection of words)

    Returns:
    --------
    dict
        Nested dictionary with the groups 'emotion', 'sentiment',
        'function_words' and 'hate_speech'. Each group maps a category name to
        a list holding, for every text in input order, the list of matching words.
    """
    # Hate speech categories to extract from HurtLex
    hurtlex_categories = ['ps', 'pa', 'ddf', 'ddp', 'asf', 'asp', 'om', 'qas',
                         'cds', 'rci', 'pr', 'pe', 'dmc', 'is', 'or', 'an']

    features = {
        'emotion': {
            'anger': [], 'anticipation': [], 'disgust': [],
            'fear': [], 'joy': [], 'sadness': [],
            'surprise': [], 'trust': []
        },
        'sentiment': {'positive': [], 'negative': []},
        'function_words': {cat: [] for cat in function_words_dict.keys()},
        'hate_speech': {category: [] for category in hurtlex_categories}
    }

    for text in texts:
        # Tokenize once per text instead of once per category
        tokens = text.lower().split()

        # Emotion and sentiment labels live in the same NRC word entries
        for group in ('emotion', 'sentiment'):
            for label, per_text in features[group].items():
                per_text.append(
                    [word for word in tokens
                     if word in nrc_lexicon and label in nrc_lexicon[word]]
                )

        for category, per_text in features['function_words'].items():
            category_members = function_words_dict[category]
            per_text.append([word for word in tokens if word in category_members])

        for category, per_text in features['hate_speech'].items():
            per_text.append(
                [word for word in tokens
                 if word in hurtlex_dict and category in hurtlex_dict[word]]
            )

    return features

##########################################################################################################


def replace_mask_random_words_excluding_features(text, num_words_to_mask, feature_words, mask_token="[MASK]"):
    """
    Mask a specified number of random words in the text, excluding feature words.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation. The returned text is always this normalized form,
    also when nothing can be masked, so that every document of a control
    condition receives the same preprocessing.

    Parameters:
    -----------
    text : str
        The text to process
    num_words_to_mask : int
        Number of words to mask
    feature_words : set
        Set of (lower-case) words to exclude from masking (feature words)
    mask_token : str
        Token to replace words with

    Returns:
    --------
    str
        Normalized text with random words masked (excluding feature words)
    int
        Actual count of words masked (capped at the number of non-feature words)
    """
    tokens = nltk.word_tokenize(text.lower())
    tokens = clean_tokens(tokens)
    # Drops empty strings (and any token that is a substring of string.punctuation)
    words = [w for w in tokens if w not in string.punctuation]

    # Only positions holding non-feature words are eligible for masking
    non_feature_indices = [i for i, word in enumerate(words) if word.lower() not in feature_words]

    # Adjust number of words to mask if fewer non-feature words available
    num_words_to_mask = min(num_words_to_mask, len(non_feature_indices))

    if num_words_to_mask == 0:
        # Return the normalized text (not the raw input) to stay consistent
        # with the masked path and with replace_mask_random_words
        return ' '.join(words), 0

    # Select random indices from non-feature words
    indices_to_mask = random.sample(non_feature_indices, num_words_to_mask)

    for idx in indices_to_mask:
        words[idx] = mask_token

    return ' '.join(words), num_words_to_mask

def remove_random_words_excluding_features(text, num_words_to_remove, feature_words):
    """
    Remove a specified number of random words from the text, excluding feature words.

    The text is lower-cased, tokenized with nltk.word_tokenize and stripped of
    surrounding punctuation. The returned text is always this normalized form,
    also when nothing can be removed, so that every document of a control
    condition receives the same preprocessing.

    Parameters:
    -----------
    text : str
        The text to process
    num_words_to_remove : int
        Number of words to remove
    feature_words : set
        Set of (lower-case) words to exclude from removal (feature words)

    Returns:
    --------
    str
        Normalized text with random words removed (excluding feature words)
    int
        Actual count of words removed (capped at the number of non-feature words)
    """
    tokens = nltk.word_tokenize(text.lower())
    tokens = clean_tokens(tokens)
    # Drops empty strings (and any token that is a substring of string.punctuation)
    words = [w for w in tokens if w not in string.punctuation]

    # Only positions holding non-feature words are eligible for removal
    non_feature_indices = [i for i, word in enumerate(words) if word.lower() not in feature_words]

    # Adjust number of words to remove if fewer non-feature words available
    num_words_to_remove = min(num_words_to_remove, len(non_feature_indices))

    if num_words_to_remove == 0:
        # Return the normalized text (not the raw input) to stay consistent
        # with the removal path and with remove_random_words
        return ' '.join(words), 0

    # A set keeps the membership test below O(1) per word
    indices_to_remove = set(random.sample(non_feature_indices, num_words_to_remove))

    filtered_words = [word for idx, word in enumerate(words) if idx not in indices_to_remove]

    return ' '.join(filtered_words), num_words_to_remove



##########################################################################################################

# Load the small English spaCy pipeline, downloading it on first use.
try:
    import spacy
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Reached when loading fails with OSError (presumably: model package not installed)
    print("Installing spaCy English model...")
    import subprocess
    import sys
    # sys.executable installs into the environment of the running interpreter
    # (a bare "python" may resolve to a different one); check=True raises
    # CalledProcessError if the download fails instead of continuing silently
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
    import spacy
    nlp = spacy.load("en_core_web_sm")


def extract_pos_tags_batch(texts, batch_size=1000):
    """
    Tag several texts with the module-level spaCy pipeline `nlp`, chunk by chunk.

    Progress is printed when the number of processed texts reaches a multiple
    of 5000 and after the last chunk.

    Parameters:
    - texts (list): List of input texts
    - batch_size (int): Number of texts per chunk (also passed to nlp.pipe)

    Returns:
    - list: One list of (token text, coarse POS tag) tuples per input text
    """
    total = len(texts)
    tagged_texts = []

    print(f"Processing {total} texts in batches of {batch_size}...")

    for start in range(0, total, batch_size):
        end = start + batch_size

        # nlp.pipe processes the whole chunk at once
        for doc in nlp.pipe(texts[start:end], batch_size=batch_size):
            tagged_texts.append([(token.text, token.pos_) for token in doc])

        if end % 5000 == 0 or end >= total:
            print(f"Processed {min(end, total)}/{total} texts")

    return tagged_texts

def remove_pos_category(text, pos_category):
    """
    Drop every token of one POS category from a text.

    The text is parsed with the module-level spaCy pipeline `nlp`; the
    remaining token texts are joined with single spaces.

    Parameters:
    - text (str): Input text
    - pos_category (str): POS category to remove (e.g., 'NOUN', 'VERB')

    Returns:
    - tuple: (modified_text, count_removed)
    """
    all_tokens = list(nlp(text))
    kept_words = [token.text for token in all_tokens if token.pos_ != pos_category]

    # Every token that was not kept has been removed
    count_removed = len(all_tokens) - len(kept_words)

    return ' '.join(kept_words), count_removed

def create_pos_ablated_dataset(texts, pos_category, precomputed_pos_tags=None):
    """
    Create a dataset with a specific POS category removed from all texts.

    Ablated texts are rebuilt from the (word, pos_tag) pairs by joining the
    remaining words with single spaces. Precomputed POS tags can be reused to
    avoid tagging again. An empty input yields an empty list and zeroed statistics.

    Parameters:
    - texts (list): List of text documents
    - pos_category (str): POS category to remove
    - precomputed_pos_tags (list): One list of (word, pos_tag) tuples per text;
      computed with extract_pos_tags_batch when None

    Returns:
    - tuple: (ablated_texts, removal_stats) where removal_stats is a dict with
      the keys 'pos_category', 'total_documents', 'documents_affected',
      'total_words_removed', 'avg_words_removed_per_doc', 'max_words_removed'
      and 'percentage_docs_affected'
    """
    print(f"Removing POS category: {pos_category}")

    # Use precomputed tags if available, otherwise compute them
    if precomputed_pos_tags is None:
        print("Computing POS tags...")
        pos_tags_list = extract_pos_tags_batch(texts)
    else:
        pos_tags_list = precomputed_pos_tags

    total_documents = len(texts)
    ablated_texts = []
    removal_per_doc = []

    # zip stops at the shorter of texts / tags, as before
    for processed, (_, pos_tags) in enumerate(zip(texts, pos_tags_list), start=1):
        kept_words = [word for word, pos_tag in pos_tags if pos_tag != pos_category]

        ablated_texts.append(' '.join(kept_words))
        removal_per_doc.append(len(pos_tags) - len(kept_words))

        if processed % 2000 == 0:
            print(f"Processed {processed}/{total_documents} documents")

    documents_affected = sum(1 for count in removal_per_doc if count > 0)

    stats = {
        'pos_category': pos_category,
        'total_documents': total_documents,
        'documents_affected': documents_affected,
        'total_words_removed': sum(removal_per_doc),
        # Guards below keep an empty corpus from raising (max of an empty
        # sequence, division by zero) or producing NaN
        'avg_words_removed_per_doc': np.mean(removal_per_doc) if removal_per_doc else 0.0,
        'max_words_removed': max(removal_per_doc, default=0),
        'percentage_docs_affected': (documents_affected / total_documents) * 100 if total_documents else 0.0
    }

    return ablated_texts, stats


# Binary classification

In [None]:
# Define a custom binary dataset class for BERT
class MemeDataset(Dataset):
    """
    Torch dataset pairing texts with single integer labels.

    Each item is tokenized on access through `tokenizer.encode_plus`, padded or
    truncated to `max_len`, and returned as a dictionary of flattened tensors.
    """

    def __init__(self, texts, labels, tokenizer, max_len=128):
        # texts / labels are indexed with the same integer position in __getitem__
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        """Number of samples, taken from the length of `texts`."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the raw text, its encoded tensors and its label for one sample."""
        raw_text = str(self.texts[idx])
        target = self.labels[idx]

        encoded = self.tokenizer.encode_plus(
            raw_text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        sample = {'text': raw_text}
        # flatten() collapses the tensors returned by the tokenizer to 1-D
        for field in ('input_ids', 'attention_mask', 'token_type_ids'):
            sample[field] = encoded[field].flatten()
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample

# Define a BERT classifier model
class BertClassifier(torch.nn.Module):
    """
    Classifier made of a pretrained BERT encoder, dropout and one linear layer.

    The linear layer maps the encoder's pooled output to `n_classes` raw
    scores (logits); no activation is applied.
    """

    def __init__(self, n_classes, model_name="bert-base-uncased"):
        super().__init__()
        # Attribute names are part of the state_dict keys and must stay stable
        self.bert = BertModel.from_pretrained(model_name)
        self.drop = torch.nn.Dropout(p=0.3)
        self.out = torch.nn.Linear(
            in_features=self.bert.config.hidden_size,
            out_features=n_classes,
        )

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Encode the batch and return one row of class logits per input."""
        encoder_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        regularized = self.drop(encoder_outputs.pooler_output)
        return self.out(regularized)



In [None]:
# Training Function for binary classification
def train_epoch_bin(model, data_loader, optimizer, scheduler, device):
    """
    Run one training pass over `data_loader` with cross-entropy loss.

    For every batch the gradients are clipped to a norm of 1.0 before the
    optimizer step, and the scheduler is stepped once per batch.

    Parameters:
    - model: Module called with input_ids, attention_mask and token_type_ids;
      returns one row of class logits per sample
    - data_loader: Iterable of batches (dicts with 'input_ids', 'attention_mask',
      'token_type_ids' and 'labels' tensors)
    - optimizer: Optimizer updating the model parameters
    - scheduler: Learning-rate scheduler
    - device: Device the batch tensors are moved to

    Returns:
    - tuple: (accuracy as a float64 tensor, mean of the per-batch losses)
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    input_fields = ('input_ids', 'attention_mask', 'token_type_ids')

    batch_losses = []
    n_correct = 0
    n_seen = 0

    for batch in data_loader:
        model_inputs = {field: batch[field].to(device) for field in input_fields}
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        logits = model(**model_inputs)

        # Predicted class = index of the largest logit
        preds = torch.max(logits, dim=1).indices
        loss = criterion(logits, labels)

        n_correct += (preds == labels).sum()
        n_seen += len(labels)
        batch_losses.append(loss.item())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

    return n_correct.double() / n_seen, np.mean(batch_losses)

# Gradient-free evaluation pass for binary classification
def eval_model_bin(model, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` without computing gradients.

    Returns ``(accuracy, mean_loss, all_preds, all_labels)``: ``accuracy`` is a
    double tensor, ``mean_loss`` the NumPy mean of per-batch cross-entropy
    losses, and the last two are flat Python lists of predicted and gold
    class indices in iteration order.
    """
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()
    batch_losses = []
    predicted_classes = []
    gold_classes = []
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in data_loader:
            targets = batch['labels'].to(device)
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
            )

            predicted = torch.max(logits, dim=1)[1]
            batch_losses.append(criterion(logits, targets).item())

            n_correct += (predicted == targets).sum()
            n_seen += len(targets)

            predicted_classes.extend(predicted.cpu().tolist())
            gold_classes.extend(targets.cpu().tolist())

    epoch_accuracy = n_correct.double() / n_seen
    return epoch_accuracy, np.mean(batch_losses), predicted_classes, gold_classes

# Label-free inference for the binary classifier
def predict(model, data_loader, device):
    """Return the predicted class index for every sample in ``data_loader``.

    The model is put in eval mode and run under ``torch.no_grad()``; the
    result is a flat Python list in iteration order.
    """
    model.eval()
    predicted_classes = []

    with torch.no_grad():
        for batch in data_loader:
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
            )
            # Index of the largest score per row is the predicted class
            batch_classes = torch.max(logits, dim=1)[1]
            predicted_classes.extend(batch_classes.cpu().tolist())

    return predicted_classes

# Multi-label classification

In [None]:
# Torch dataset pairing each text with a multi-hot label vector
class MemeMultiLabelDataset(Dataset):
    """Tokenizes texts on access and returns float label vectors.

    Each item is a dict with the original text ('text'), the flattened
    tokenizer outputs ('input_ids', 'attention_mask', 'token_type_ids') and
    'labels' as a ``torch.float`` tensor.
    """

    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Pad/truncate to max_len so every item has the same length
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        item = {'text': text}
        for field in ('input_ids', 'attention_mask', 'token_type_ids'):
            item[field] = encoded[field].flatten()
        item['labels'] = torch.tensor(label, dtype=torch.float)
        return item

# One optimisation pass over the training data for multi-label classification
def train_epoch_multilabel(model, data_loader, optimizer, scheduler, device):
    """Train ``model`` for one epoch with binary cross-entropy loss.

    ``torch.nn.BCELoss`` is applied directly to the model outputs, so the
    model is expected to emit probabilities. Gradients are clipped to a max
    norm of 1.0 and the optimizer and scheduler are stepped once per batch.

    Returns the NumPy mean of the per-batch losses.
    """
    model.train()
    criterion = torch.nn.BCELoss()
    batch_losses = []

    for batch in data_loader:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        segments = batch['token_type_ids'].to(device)
        targets = batch['labels'].to(device)

        optimizer.zero_grad()
        probabilities = model(input_ids=ids, attention_mask=mask, token_type_ids=segments)

        loss = criterion(probabilities, targets)
        batch_losses.append(loss.item())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

    return np.mean(batch_losses)


# Gradient-free evaluation pass for multi-label classification
def eval_model_multilabel(model, data_loader, device, threshold=0.5):
    """Evaluate a multi-label model and binarise its outputs.

    ``torch.nn.BCELoss`` is applied directly to the model outputs, and an
    output strictly greater than ``threshold`` counts as a positive label.

    Returns ``(mean_loss, all_preds, all_labels)`` where the last two are
    lists with one NumPy row per sample, in iteration order.
    """
    model.eval()
    criterion = torch.nn.BCELoss()
    batch_losses = []
    predicted_rows = []
    gold_rows = []

    with torch.no_grad():
        for batch in data_loader:
            targets = batch['labels'].to(device)
            probabilities = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
            )

            batch_losses.append(criterion(probabilities, targets).item())

            # Threshold the probabilities into 0/1 decisions
            decisions = (probabilities > threshold).int()

            predicted_rows.extend(decisions.cpu().numpy())
            gold_rows.extend(targets.cpu().numpy())

    return np.mean(batch_losses), predicted_rows, gold_rows

In [None]:
# BERT encoder with a sigmoid-activated multi-label head
class BertMultiLabelClassifier(torch.nn.Module):
    """Multi-label classifier: BERT pooled output -> dropout (p=0.3) -> linear -> sigmoid.

    ``forward`` returns one value in (0, 1) per class, i.e. the sigmoid is
    applied inside the module.
    """

    def __init__(self, n_classes, model_name="bert-base-uncased"):
        super().__init__()
        encoder = BertModel.from_pretrained(model_name)
        self.bert = encoder
        self.drop = torch.nn.Dropout(p=0.3)
        self.out = torch.nn.Linear(encoder.config.hidden_size, n_classes)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask, token_type_ids):
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        scores = self.out(self.drop(encoded.pooler_output))
        return self.sigmoid(scores)


In [None]:
# Two-stage prediction: binary gate first, fine-grained labels for positives only
def predict_hierarchical(binary_model, ml_model, data_loader, ml_tokenizer, device,
                        fine_grained_labels, test_texts=None, threshold=0.5, batch_size=16):
    """
    Perform hierarchical prediction.

    The binary model is run on every sample. Samples it predicts as class 1
    are re-tokenized with ``ml_tokenizer`` and passed to the multi-label
    model; all other samples keep an all-zero fine-grained row.

    Parameters:
    -----------
    binary_model : torch.nn.Module
        Pre-trained binary classification model
    ml_model : torch.nn.Module
        Pre-trained multi-label classification model
    data_loader : DataLoader
        DataLoader for test data. When ``test_texts`` is given it must iterate
        in the same order as ``test_texts`` (i.e. not shuffled).
    ml_tokenizer : transformers.Tokenizer
        Tokenizer for processing text
    device : torch.device
        Device to run models on
    fine_grained_labels : list
        List of fine-grained label names
    test_texts : list, optional
        List of test texts. If None, the texts are read from the 'text' field
        of the batches of a sequential loader rebuilt over
        ``data_loader.dataset``.
    threshold : float, default=0.5
        Threshold for multi-label predictions
    batch_size : int, default=16
        Batch size for multi-label predictions

    Returns:
    --------
    np.ndarray : Full prediction matrix (binary + multi-label); column 0 is
        the binary prediction, the remaining columns follow
        ``fine_grained_labels``.

    Raises:
    -------
    ValueError
        If the number of texts differs from the number of binary predictions.
    """
    binary_model.eval()
    ml_model.eval()
    binary_preds = []
    all_ml_preds = np.zeros((len(data_loader.dataset), len(fine_grained_labels)))

    # When no texts are supplied, collect them in the SAME pass that produces
    # the binary predictions. Gathering them from the caller's loader and then
    # predicting with a rebuilt loader misaligns texts and predictions
    # whenever the caller's loader shuffles.
    collect_texts = test_texts is None
    if collect_texts:
        test_texts = []
        # Default DataLoader settings, so samples come in dataset order
        data_loader = DataLoader(data_loader.dataset, batch_size=data_loader.batch_size)

    # Stage 1: binary predictions for every sample
    with torch.no_grad():
        for batch in data_loader:
            if collect_texts:
                test_texts.extend(batch['text'])

            outputs = binary_model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                token_type_ids=batch['token_type_ids'].to(device)
            )

            _, preds = torch.max(outputs, dim=1)
            binary_preds.extend(preds.cpu().tolist())

    # Texts are looked up by prediction index below, so the counts must agree
    if len(test_texts) != len(binary_preds):
        raise ValueError(
            f"Number of texts ({len(test_texts)}) does not match number of "
            f"binary predictions ({len(binary_preds)})"
        )

    # Stage 2: multi-label predictions only for samples predicted positive
    positive_indices = [i for i, pred in enumerate(binary_preds) if pred == 1]

    if positive_indices:
        positive_texts = [test_texts[i] for i in positive_indices]

        positive_dataset = MemeDataset(
            texts=positive_texts,
            labels=[0] * len(positive_texts),  # Dummy labels, never read here
            tokenizer=ml_tokenizer
        )
        positive_loader = DataLoader(positive_dataset, batch_size=batch_size)

        ml_preds = []
        with torch.no_grad():
            for batch in positive_loader:
                outputs = ml_model(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    token_type_ids=batch['token_type_ids'].to(device)
                )

                # Convert probabilities to binary predictions
                preds = (outputs > threshold).int()
                ml_preds.extend(preds.cpu().numpy())

        # Write each positive sample's row back to its original position
        for row, idx in zip(ml_preds, positive_indices):
            all_ml_preds[idx] = row

    # Create full prediction matrix (binary + multi-label)
    full_preds = np.zeros((len(binary_preds), 1 + len(fine_grained_labels)))
    full_preds[:, 0] = binary_preds  # Binary classification
    full_preds[:, 1:] = all_ml_preds  # Multi-label classification

    return full_preds

# MAMI

In [None]:
# Load the three MAMI splits from index-oriented JSON files
print("Loading MAMI dataset...")
mami_training_data = route + ''  # Path to training data
mami_dev_data = route + ''  # Path to development data
mami_test_data = route + ''  # Path to test data

mami_training_df, mami_dev_df, mami_test_df = (
    pd.read_json(split_path, orient='index')
    for split_path in (mami_training_data, mami_dev_data, mami_test_data)
)

# Fold the development split into the training split, ordered by index
mami_training_df = pd.concat([mami_training_df, mami_dev_df])
mami_training_df = mami_training_df.sort_index()

# Show which columns each frame provides
print("Training data columns:", mami_training_df.columns)
print("Test data columns:", mami_test_df.columns)

# Gold-standard files used by the PyEvALL evaluation
gold_test_txt = route + ''  # Path to MAMI text gold label (txt)
gold_test_bin = route + ''  # Path to MAMI binary classification gold labels (json)
gold_test_ml = route + ''  # Path to MAMI multi-label classification gold labels (json)

### Binary Classification: Misogynous vs. non-misogynous

In [None]:
def run_bert_baseline_bin(dataset_name, training_df, test_df, binary_label, label_names,
                          epochs=3, batch_size=16, learning_rate=2e-5):
    """
    Run BERT baseline model and save results in the same format as ablation experiments.

    The function fine-tunes a ``BertClassifier`` on the "bert representation"
    column, keeps the checkpoint with the best accuracy on ``test_df``,
    writes predictions via ``save_evaluation``, stores a metrics JSON under
    ``evaluation/results/binary/BERT`` and saves the final weights under
    ``models/bert_baseline/<dataset>``.

    Parameters:
    -----------
    dataset_name : str
        Dataset name ('MAMI' or 'EXIST')
    training_df : DataFrame
        Training data
    test_df : DataFrame
        Test data
    binary_label : str
        Binary label column name
    label_names : list
        List of label names (e.g., ["non-misogynous", "misogynous"])
    epochs : int, default=3
        Number of training epochs; also determines the scheduler's total steps
    batch_size : int, default=16
        Batch size for the training and test loaders
    learning_rate : float, default=2e-5
        AdamW learning rate

    Returns:
    --------
    dict : BERT baseline results dictionary
    """

    print(f"\n🤖 RUNNING BERT BASELINE FOR {dataset_name}")
    print("=" * 50)

    # Data preparation
    X_train = training_df["bert representation"].tolist()
    X_test = test_df["bert representation"].tolist()
    y_train = training_df[binary_label].tolist()
    y_test = test_df[binary_label].tolist()

    # Set model name
    model_name = "bert_baseline"
    checkpoint_path = f'best_model_{model_name}_{dataset_name.lower()}.pt'

    # Initialize tokenizer and device
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create datasets
    train_dataset = MemeDataset(
        texts=X_train,
        labels=y_train,
        tokenizer=tokenizer
    )

    test_dataset = MemeDataset(
        texts=X_test,
        labels=y_test,
        tokenizer=tokenizer
    )

    # Create data loaders (the test loader is not shuffled, so predictions
    # come back in the same order as y_test)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size
    )

    # Initialize model, optimizer, and scheduler
    bert_model = BertClassifier(n_classes=2).to(device)
    optimizer = AdamW(bert_model.parameters(), lr=learning_rate)
    # The scheduler is stepped once per batch, so its horizon must be derived
    # from the real epoch count rather than a separate literal
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    # Train the model
    print("\n🚀 Training BERT binary classifier...")
    # Start below any possible accuracy so the first epoch always writes a
    # checkpoint; otherwise an all-zero accuracy run would leave nothing (or
    # a stale file from an earlier run) to load afterwards
    best_accuracy = float('-inf')

    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        train_acc, train_loss = train_epoch_bin(
            bert_model,
            train_loader,
            optimizer,
            scheduler,
            device
        )

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')

        # NOTE(review): the checkpoint is selected on the test loader, so the
        # reported test metrics are optimistic; a held-out dev split would
        # avoid this — confirm whether that is intended for the baseline.
        val_acc, val_loss, _, _ = eval_model_bin(
            bert_model,
            test_loader,
            device
        )

        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        if val_acc > best_accuracy:
            best_accuracy = val_acc
            torch.save(bert_model.state_dict(), checkpoint_path)
            print('✅ Best model saved!')

    # Load the best model (map_location keeps this valid on any device)
    bert_model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Make predictions on test set
    print("\n📊 Evaluating model on test set...")
    _, _, all_preds, _ = eval_model_bin(
        bert_model,
        test_loader,
        device
    )

    # Set evaluation parameters
    evaluation_type = 'binary'

    # Set gold standard file paths
    if dataset_name == "MAMI":
        gold_test_bin = route + '' # Path to MAMI binary classification gold labels (json)
        gold_test_txt = route + '' # Path to MAMI text gold label (txt)
    else:  # EXIST
        gold_test_bin = route + '' # Path to EXIST2024 binary classification gold labels (json)
        gold_test_txt = route + ''# Path to EXIST2024 text gold label (txt)

    # Save prediction results
    test_pred_json, test_pred_txt = save_evaluation(
        test_df, "evaluation/predictions", dataset_name, "test",
        evaluation_type, model_name, np.array(all_preds), binary_label, []
    )

    # Calculate all metrics (zero_division=0 matches the classification
    # report below and only silences the undefined-metric warning)
    accuracy = accuracy_score(y_test, all_preds)
    precision_macro = precision_score(y_test, all_preds, average='macro', zero_division=0)
    recall_macro = recall_score(y_test, all_preds, average='macro', zero_division=0)
    f1_macro = f1_score(y_test, all_preds, average='macro', zero_division=0)

    # Get detailed classification report
    class_report_dict = classification_report(
        y_test, all_preds,
        target_names=label_names,
        zero_division=0,
        digits=3,
        output_dict=True
    )

    # Calculate binary F1 score (MAMI evaluation metric)
    binary_f1 = evaluate_f1_scores(gold_test_txt, test_pred_txt, 2)

    # Save baseline results to JSON file
    baseline_results = {
        'binary_f1': binary_f1,
        'macro_f1': f1_macro,
        'accuracy': accuracy,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'per_label_metrics': class_report_dict,
        'model_config': {
            'model_type': 'BERT',
            'model_name': 'bert-base-uncased',
            'epochs': epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'max_length': 128
        }
    }

    # Create results directory
    os.makedirs("evaluation/results/binary/BERT", exist_ok=True)

    # Save baseline results with updated filename format
    baseline_results_file = f"evaluation/results/binary/BERT/{model_name}_{dataset_name}_bin_baseline_results.json"

    with open(baseline_results_file, 'w') as f:
        json.dump(baseline_results, f, indent=2, cls=NumpyEncoder)

    print(f"✅ Baseline results saved to: {baseline_results_file}")

    # Display evaluation results (keep original functionality)
    evaluate_binary_classification(
        gold_test_bin, test_pred_json,
        y_test, np.array(all_preds),
        gold_test_txt, test_pred_txt,
        label_names,
        model_name="BERT baseline"
    )

    # Print saved metrics (for verification)
    print(f"\n📊 SAVED BERT BASELINE METRICS:")
    print(f"   • Binary F1: {binary_f1:.3f}")
    print(f"   • Macro F1: {f1_macro:.3f}")
    print(f"   • Accuracy: {accuracy:.3f}")
    print(f"   • Precision (macro): {precision_macro:.3f}")
    print(f"   • Recall (macro): {recall_macro:.3f}")

    # Save model state dict as well
    model_save_dir = f"models/bert_baseline/{dataset_name.lower()}"
    os.makedirs(model_save_dir, exist_ok=True)

    model_save_path = f"{model_save_dir}/{model_name}_final.pt"
    torch.save(bert_model.state_dict(), model_save_path)
    print(f"✅ Model saved to: {model_save_path}")

    return baseline_results

# Run MAMI BERT baseline: train and evaluate the binary classifier on the
# "misogynous" column and keep the returned results dictionary.
mami_bert_results = run_bert_baseline_bin(
    dataset_name="MAMI",
    training_df=mami_training_df,
    test_df=mami_test_df,
    binary_label="misogynous",
    label_names=["non-misogynous", "misogynous"]
)

## Multi-label Classification

In [None]:
def run_bert_baseline_ml(dataset_name, training_df, test_df, binary_model,
                                  binary_label, fine_grained_labels, tokenizer, device,
                                  batch_size=16, route="", epochs=3, learning_rate=2e-5):
    """
    Run BERT multi-label (hierarchical) model and save results.

    A ``BertMultiLabelClassifier`` is trained only on training rows whose
    binary label is 1. At test time ``predict_hierarchical`` first applies
    ``binary_model`` and only asks the multi-label model about samples the
    binary model marks positive. Predictions, a metrics JSON (under
    ``evaluation/results/multi-label/BERT``) and the model weights are
    written to disk.

    Parameters:
    -----------
    dataset_name : str
        Dataset name; "MAMI" selects the MAMI gold files, anything else the
        EXIST ones
    training_df : DataFrame
        Training data
    test_df : DataFrame
        Test data
    binary_model : torch.nn.Module
        Already-trained binary classifier used as the first stage
    binary_label : str
        Binary label column name
    fine_grained_labels : list
        Fine-grained label column names
    tokenizer : transformers tokenizer
        Tokenizer used for both datasets
    device : torch.device
        Device to train and predict on
    batch_size : int, default=16
        Batch size for all loaders
    route : str, default=""
        Prefix for the gold-standard file paths
    epochs : int, default=3
        Number of multi-label training epochs; also sets the scheduler horizon
    learning_rate : float, default=2e-5
        AdamW learning rate

    Returns:
    --------
    dict : hierarchical results dictionary
    """

    print(f"\n🏷️ RUNNING BERT MULTI-LABEL FOR {dataset_name}")
    print("=" * 50)

    # Set evaluation parameters
    evaluation_type = "hierarchical"
    model_name = "bert_baseline"
    label_names = [binary_label] + fine_grained_labels

    print(f"Fine-grained labels: {fine_grained_labels}")
    print(f"All label names: {label_names}")

    # Filter data for positive binary instances (only train on positive examples for multi-label)
    train_df_bin_positive = training_df.loc[training_df[binary_label] == 1]
    X_train_bin_positive = train_df_bin_positive["bert representation"].tolist()
    y_train_categories = train_df_bin_positive[fine_grained_labels].values.tolist()

    print(f"Total training samples: {len(training_df)}")
    print(f"Positive binary samples for multi-label training: {len(train_df_bin_positive)}")

    # Prepare test data (all samples)
    X_test = test_df["bert representation"].tolist()
    y_test = test_df[binary_label].tolist()

    # Create datasets for multi-label classification
    ml_train_dataset = MemeMultiLabelDataset(
        texts=X_train_bin_positive,
        labels=y_train_categories,
        tokenizer=tokenizer
    )

    test_dataset = MemeDataset(
        texts=X_test,
        labels=y_test,
        tokenizer=tokenizer
    )

    # Create data loaders (the test loader is not shuffled, so it iterates in
    # the same order as X_test, which predict_hierarchical relies on)
    ml_train_loader = DataLoader(
        ml_train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size
    )

    # Initialize multi-label model
    ml_epochs = epochs
    ml_model = BertMultiLabelClassifier(n_classes=len(fine_grained_labels)).to(device)
    ml_optimizer = AdamW(ml_model.parameters(), lr=learning_rate)
    # The scheduler is stepped once per batch, so derive its horizon from the
    # actual epoch count rather than a separate literal
    ml_total_steps = len(ml_train_loader) * ml_epochs
    ml_scheduler = get_linear_schedule_with_warmup(
        ml_optimizer,
        num_warmup_steps=0,
        num_training_steps=ml_total_steps
    )

    # Train multi-label model
    print("\n🚀 Training BERT multi-label classifier...")

    for epoch in range(ml_epochs):
        print(f'Multi-label Epoch {epoch + 1}/{ml_epochs}')

        train_loss = train_epoch_multilabel(
            ml_model,
            ml_train_loader,
            ml_optimizer,
            ml_scheduler,
            device
        )

        print(f'Multi-label Train Loss: {train_loss:.4f}')

    # Save multi-label model
    torch.save(ml_model.state_dict(), f'best_model_bert_multilabel_{dataset_name.lower()}.pt')
    print('✅ Multi-label model saved!')

    # Make hierarchical predictions
    print("\n🔮 Making hierarchical predictions...")
    hierarchical_predictions = predict_hierarchical(
        binary_model,
        ml_model,
        test_loader,
        tokenizer,
        device,
        fine_grained_labels,
        test_texts=X_test,
        threshold=0.5,
        batch_size=batch_size
    )

    # Convert predictions to a DataFrame for evaluation
    test_pred_df = pd.DataFrame(
        hierarchical_predictions,
        columns=[binary_label] + fine_grained_labels
    )

    print(f"Hierarchical predictions shape: {hierarchical_predictions.shape}")
    print(f"Test predictions DataFrame shape: {test_pred_df.shape}")

    # Set gold standard file paths
    if dataset_name == "MAMI":
        gold_test_ml = route + '' # Path to MAMI multi-label classification gold labels (json)
        gold_test_txt = route + '' # Path to MAMI text gold label (txt)
    else:  # EXIST
        gold_test_ml = route + '' # Path to EXIST2024 multi-label classification gold labels (json)
        gold_test_txt = route + ''# Path to EXIST2024 text gold label (txt)

    # Save and evaluate hierarchical predictions
    test_pred_json_ml, test_pred_txt_ml = save_evaluation(
        test_df,
        "evaluation/predictions",
        dataset_name,
        "test",
        evaluation_type,
        model_name,
        test_pred_df,
        binary_label,
        label_names
    )

    # Gold labels for the binary column plus every fine-grained column
    y_true_df = test_df[label_names]

    # Calculate macro F1 score for hierarchical evaluation; the helper may
    # return either a (binary, multi-label) pair or a single score
    hierarchical_scores = evaluate_f1_scores(gold_test_txt, test_pred_txt_ml, len(label_names))
    if isinstance(hierarchical_scores, tuple):
        binary_f1, multilabel_f1 = hierarchical_scores
    else:
        binary_f1 = hierarchical_scores
        multilabel_f1 = hierarchical_scores

    # Calculate individual label metrics
    binary_true = y_true_df[binary_label].values
    binary_pred = hierarchical_predictions[:, 0]

    # Represent the negative class explicitly as a 0/1 column so it can be
    # scored like any fine-grained label
    negative_true = (binary_true == 0).astype(float).reshape(-1, 1)
    negative_pred = (binary_pred == 0).astype(float).reshape(-1, 1)

    # Get multi-label data
    multilabel_true = y_true_df[fine_grained_labels].values
    multilabel_pred = hierarchical_predictions[:, 1:]

    # Create combined representation: [negative class | fine-grained labels]
    combined_true = np.hstack((negative_true, multilabel_true))
    combined_pred = np.hstack((negative_pred, multilabel_pred))

    # Label names matching the combined columns
    updated_label_names = [f"non-{binary_label}"] + fine_grained_labels

    # Compute the per-label metrics on the combined representation
    individual_metrics = {}
    for i, label in enumerate(updated_label_names):
        y_true_label = combined_true[:, i]
        y_pred_label = combined_pred[:, i]

        individual_metrics[label] = {
            'accuracy': accuracy_score(y_true_label, y_pred_label),
            'precision': precision_score(y_true_label, y_pred_label, zero_division=0),
            'recall': recall_score(y_true_label, y_pred_label, zero_division=0),
            'f1': f1_score(y_true_label, y_pred_label, zero_division=0)
        }

    # Calculate overall macro metrics (unweighted mean over the labels)
    macro_f1 = np.mean([individual_metrics[label]['f1'] for label in updated_label_names])
    macro_precision = np.mean([individual_metrics[label]['precision'] for label in updated_label_names])
    macro_recall = np.mean([individual_metrics[label]['recall'] for label in updated_label_names])

    # Save hierarchical results to JSON file
    hierarchical_results = {
        'binary_f1': binary_f1,
        'hierarchical_f1': multilabel_f1,
        'macro_f1': macro_f1,
        'macro_precision': macro_precision,
        'macro_recall': macro_recall,
        'individual_label_metrics': individual_metrics,
        'fine_grained_labels': fine_grained_labels,
        'all_labels': label_names,
        'model_config': {
            'model_type': 'BERT',
            'model_name': 'bert-base-uncased',
            'epochs': ml_epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'max_length': 128,
            'num_classes': len(fine_grained_labels),
            'hierarchical': True
        }
    }

    # Create results directory
    os.makedirs("evaluation/results/multi-label/BERT", exist_ok=True)

    # Save hierarchical results
    hierarchical_results_file = f"evaluation/results/multi-label/BERT/{model_name}_{dataset_name}_hierarchical_results.json"

    with open(hierarchical_results_file, 'w') as f:
        json.dump(hierarchical_results, f, indent=2, cls=NumpyEncoder)

    print(f"✅ Hierarchical results saved to: {hierarchical_results_file}")

    # Display evaluation results (keep original functionality)
    print("\n📊 Multi-label classification evaluation:")
    evaluate_multilabel_classification(
        gold_test_ml,
        test_pred_json_ml,
        y_true=test_df[label_names],
        y_pred=hierarchical_predictions,
        gold_labels_txt=gold_test_txt,
        predictions_txt=test_pred_txt_ml,
        label_names=label_names
    )

    # Print saved metrics (for verification)
    print(f"\n📊 SAVED BERT HIERARCHICAL METRICS:")
    print(f"   • Hierarchical F1: {multilabel_f1:.3f}")
    print(f"   • Macro F1: {macro_f1:.3f}")
    print(f"   • Macro Precision: {macro_precision:.3f}")
    print(f"   • Macro Recall: {macro_recall:.3f}")

    # Print individual label metrics (loop variable deliberately not named
    # `metrics`, which is also the name of an imported sklearn module)
    print(f"\n📋 INDIVIDUAL LABEL METRICS:")
    for label, label_scores in individual_metrics.items():
        print(f"   • {label}: F1={label_scores['f1']:.3f}, P={label_scores['precision']:.3f}, R={label_scores['recall']:.3f}")

    # Save model state dict as well
    model_save_dir = f"models/bert_baseline/{dataset_name.lower()}"
    os.makedirs(model_save_dir, exist_ok=True)

    model_save_path = f"{model_save_dir}/{model_name}_multilabel_final.pt"
    torch.save(ml_model.state_dict(), model_save_path)
    print(f"✅ Multi-label model saved to: {model_save_path}")

    return hierarchical_results

In [None]:
# Tokenizer passed to the multi-label run below
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load binary model from the checkpoint file written by the binary baseline.
# map_location keeps the load working when the checkpoint was saved on a
# different device than the current one (e.g. saved on GPU, reloaded on CPU).
mami_bin_model = BertClassifier(n_classes=2).to(device)
mami_bin_model.load_state_dict(
    torch.load('best_model_bert_baseline_mami.pt', map_location=device)
)

# Run multi-label classification
mami_fine_grained_labels = ["shaming", "stereotype", "objectification", "violence"]

mami_bert_multilabel_results = run_bert_baseline_ml(
    dataset_name="MAMI",
    training_df=mami_training_df,
    test_df=mami_test_df,
    binary_model=mami_bin_model,
    binary_label="misogynous",
    fine_grained_labels=mami_fine_grained_labels,
    tokenizer=tokenizer,
    device=device,
    batch_size=16,
    route=route
)

# EXIST

In [None]:
# Load the three EXIST splits from index-oriented JSON files
print("\n\nLoading EXIST dataset...")
exist_training_data = route + ''  # Path to training data
exist_dev_data = route + ''  # Path to development data
exist_test_data = route + ''  # Path to test data

exist_training_df, exist_dev_df, exist_test_df = (
    pd.read_json(split_path, orient='index')
    for split_path in (exist_training_data, exist_dev_data, exist_test_data)
)

# Fold the development split into the training split, ordered by index
exist_training_df = pd.concat([exist_training_df, exist_dev_df])
exist_training_df = exist_training_df.sort_index()

# Show which columns each frame provides
print("Training data columns:", exist_training_df.columns)
print("Test data columns:", exist_test_df.columns)

# Gold-standard files used by the PyEvALL evaluation
gold_test_bin = route + ''  # Path to EXIST2024 binary classification gold labels (json)
gold_test_txt = route + ''  # Path to EXIST2024 text gold label (txt)
gold_test_ml = route + ''  # Path to EXIST2024 multi-label classification gold labels (json)

## Binary Classification: Sexist vs. non-sexist

In [None]:
# Run EXIST BERT baseline: train and evaluate the binary classifier on the
# "sexist" column and keep the returned results dictionary.
exist_bert_results = run_bert_baseline_bin(
    dataset_name="EXIST",
    training_df=exist_training_df,
    test_df=exist_test_df,
    binary_label="sexist",
    label_names=["non-sexist", "sexist"]
)

## Multi-label Classification

In [None]:


# Tokenizer passed to the multi-label run below
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

exist_fine_grained_labels = ["ideological-inequality", "stereotyping-dominance",
                            "objectification", "sexual-violence", "misogyny-non-sexual-violence"]

# Load binary model from the checkpoint file written by the binary baseline.
# map_location keeps the load working when the checkpoint was saved on a
# different device than the current one (e.g. saved on GPU, reloaded on CPU).
exist_bin_model = BertClassifier(n_classes=2).to(device)
exist_bin_model.load_state_dict(
    torch.load('best_model_bert_baseline_exist.pt', map_location=device)
)

# Run multi-label classification
exist_bert_multilabel_results = run_bert_baseline_ml(
    dataset_name="EXIST",
    training_df=exist_training_df,
    test_df=exist_test_df,
    binary_model=exist_bin_model,
    binary_label="sexist",
    fine_grained_labels=exist_fine_grained_labels,
    tokenizer=tokenizer,
    device=device,
    batch_size=16,
    route=route
)

# **Ablation Study**

## Coarse-grained ablation

### Binary

In [None]:
# Text clean-up shared by the coarse-grained and POS ablations
def clean_text_sep_token(texts):
    """
    Strip the literal '[SEP]' marker from each text and normalise whitespace.

    Every '[SEP]' occurrence is replaced by a space, then runs of whitespace
    are collapsed to single spaces and leading/trailing whitespace is dropped.

    Parameters:
    -----------
    texts : list
        List of texts to clean

    Returns:
    --------
    list
        List of cleaned texts, in the same order as the input
    """
    return [' '.join(text.replace('[SEP]', ' ').split()) for text in texts]

def _ablate_coarse_feature_texts(texts, feature_type, ablation_method,
                                 nrc_lexicon, function_words_dict, hurtlex_dict):
    """
    Ablate one coarse-grained feature group from every text of a split.

    Parameters:
    -----------
    texts : list
        Cleaned texts of one split
    feature_type : str
        'sentiment_pos', 'sentiment_neg', 'hate' or 'function'; any other value
        selects no feature words
    ablation_method : str
        'mask' replaces feature words with '[MASK]'; any other value removes them
    nrc_lexicon, function_words_dict, hurtlex_dict
        Lexicons as returned by the corresponding load_* helpers

    Returns:
    --------
    tuple
        (ablated_texts, total_feature_words_processed,
         documents_with_features, feature_words_per_doc)
    """
    ablated_texts = []
    feature_words_per_doc = []
    total_feature_words_processed = 0
    documents_with_features = 0

    for text in texts:
        tokens = text.lower().split()

        # Select the feature words present in this text
        if feature_type == 'sentiment_pos':
            feature_words = extract_sentiment_words(text, nrc_lexicon, 'positive')
        elif feature_type == 'sentiment_neg':
            feature_words = extract_sentiment_words(text, nrc_lexicon, 'negative')
        elif feature_type == 'hate':
            feature_words = list({word for word in tokens if word in hurtlex_dict})
        elif feature_type == 'function':
            function_words = []
            for category in function_words_dict.keys():
                function_words.extend(
                    extract_function_words(text, function_words_dict, category)
                )
            feature_words = list(set(function_words))
        else:
            feature_words = []

        # Count whitespace tokens that are feature words (set for O(1) lookup)
        feature_words_set = set(feature_words)
        num_feature_words = sum(1 for word in tokens if word in feature_words_set)
        feature_words_per_doc.append(num_feature_words)
        if num_feature_words > 0:
            documents_with_features += 1

        # Apply feature ablation
        if ablation_method == 'mask':
            ablated_text, count = replace_mask_features(text, feature_words, mask_token='[MASK]')
        else:  # remove
            ablated_text, count = remove_features(text, feature_words)

        ablated_texts.append(ablated_text)
        total_feature_words_processed += count

    return (ablated_texts, total_feature_words_processed,
            documents_with_features, feature_words_per_doc)


def _coarse_ablation_split_statistics(num_documents, documents_with_features,
                                      total_feature_words_processed, feature_words_per_doc):
    """Build the per-split statistics dict; empty splits yield zeros instead of errors."""
    return {
        'total_documents': num_documents,
        'documents_with_features': documents_with_features,
        'total_feature_words_processed': total_feature_words_processed,
        'avg_feature_words_per_doc': np.mean(feature_words_per_doc) if feature_words_per_doc else 0.0,
        'max_feature_words_per_doc': max(feature_words_per_doc) if feature_words_per_doc else 0,
        'feature_coverage_percentage': (
            documents_with_features / num_documents * 100 if num_documents else 0.0
        )
    }


def run_coarse_binary_bert_ablation_experiment(dataset_df_train, dataset_df_test, feature_type,
                                               ablation_method, dataset_name, binary_label='misogynous'):
    """
    Run a coarse-grained ablation experiment with BERT for binary classification.

    The "bert representation" column of both frames is cleaned, the selected
    feature group is masked or removed, a fresh BertClassifier is fine-tuned for
    3 epochs on the ablated training texts, and the best checkpoint is evaluated
    on the ablated test texts. Predictions and a JSON results file are written
    under "evaluation/".

    Parameters:
    -----------
    dataset_df_train : DataFrame
        Training dataset
    dataset_df_test : DataFrame
        Test dataset
    feature_type : str
        Type of feature:
        - 'sentiment_pos': only positive sentiment words
        - 'sentiment_neg': only negative sentiment words
        - 'hate': all hate speech terms as a group
        - 'function': all function words as a group
    ablation_method : str
        'mask' or 'remove' (any value other than 'mask' is treated as 'remove')
    dataset_name : str
        Name of dataset; 'MAMI' selects the MAMI gold files and label names,
        every other value is treated as EXIST2024
    binary_label : str
        Column for binary classification ('misogynous' or 'sexist')

    Returns:
    --------
    dict: Results of the ablation experiment (metrics, prediction file paths and
        per-split feature statistics)
    """
    # Download required NLTK data
    # (presumably needed by the feature-extraction helpers -- confirm)
    for resource in ('punkt_tab', 'punkt'):
        try:
            nltk.data.find(f'tokenizers/{resource}')
        except LookupError:
            print("Downloading required NLTK data...")
            nltk.download(resource)

    # Load necessary lexicons
    # Note: Lexicon files are not included in this repository.
    # Please download them from their respective sources before running this code.
    print(f"Loading lexicons for {feature_type} ablation...")
    nrc_lexicon = load_nrc_lexicon(route + "NRC-Emotion-Lexicon/NRC-Emotion-Lexicon-Wordlevel-v0.92.txt")
    function_words_dict = load_function_words()
    hurtlex_dict = load_hurtlex(route + "hurtlex-master/lexica/EN/1.2/hurtlex_EN.tsv")

    # Get original texts and labels
    X_train_orig_raw = dataset_df_train["bert representation"].tolist()
    X_test_orig_raw = dataset_df_test["bert representation"].tolist()
    y_train = dataset_df_train[binary_label].tolist()
    y_test = dataset_df_test[binary_label].tolist()

    # Clean texts so feature extraction sees the same input as the SVM experiments
    print("🧹 Cleaning texts for feature extraction (removing BERT tokens)...")
    X_train_orig = clean_text_sep_token(X_train_orig_raw)
    X_test_orig = clean_text_sep_token(X_test_orig_raw)

    print(f"Running coarse-grained BERT ablation experiment for {feature_type}...")

    # Ablate both splits with the same procedure
    print("Processing training data...")
    (X_train_feature_ablated, train_total_feature_words_processed,
     train_documents_with_features, train_feature_words_per_doc) = _ablate_coarse_feature_texts(
        X_train_orig, feature_type, ablation_method,
        nrc_lexicon, function_words_dict, hurtlex_dict
    )

    print("Processing test data...")
    (X_test_feature_ablated, test_total_feature_words_processed,
     test_documents_with_features, test_feature_words_per_doc) = _ablate_coarse_feature_texts(
        X_test_orig, feature_type, ablation_method,
        nrc_lexicon, function_words_dict, hurtlex_dict
    )

    train_statistics = _coarse_ablation_split_statistics(
        len(X_train_feature_ablated), train_documents_with_features,
        train_total_feature_words_processed, train_feature_words_per_doc
    )
    test_statistics = _coarse_ablation_split_statistics(
        len(X_test_feature_ablated), test_documents_with_features,
        test_total_feature_words_processed, test_feature_words_per_doc
    )

    # Summary statistics
    print(f"\n{'=' * 50}")
    print("Processing Summary:")
    for split_title, stats in (("Training set:", train_statistics), ("Test set:", test_statistics)):
        print(split_title)
        print(f"  Total documents: {stats['total_documents']}")
        print(f"  Documents with features: {stats['documents_with_features']} "
              f"({stats['feature_coverage_percentage']:.2f}%)")
        print(f"  Total feature words processed: {stats['total_feature_words_processed']}")
        print(f"  Average feature words per document: {stats['avg_feature_words_per_doc']:.2f}")
    print(f"{'=' * 50}")

    # Set up model name
    model_name = f"bert_{feature_type}_{ablation_method}_{dataset_name.lower()}"
    checkpoint_path = f'best_model_{model_name}.pt'

    # Initialize BERT tokenizer and device
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create datasets
    print("Creating BERT datasets...")
    train_dataset = MemeDataset(
        texts=X_train_feature_ablated,
        labels=y_train,
        tokenizer=tokenizer,
        max_len=128
    )

    test_dataset = MemeDataset(
        texts=X_test_feature_ablated,
        labels=y_test,
        tokenizer=tokenizer,
        max_len=128
    )

    # Create data loaders
    batch_size = 16
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size
    )

    # Initialize BERT model
    print("Initializing BERT model...")
    bert_model = BertClassifier(n_classes=2).to(device)
    optimizer = AdamW(bert_model.parameters(), lr=2e-5)

    epochs = 3
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    # Train the model
    print(f"Training BERT model for {epochs} epochs...")
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise a run stuck at 0.0 accuracy would later load a
    # missing (or stale, from a previous run) checkpoint file.
    best_accuracy = float('-inf')

    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        train_acc, train_loss = train_epoch_bin(
            bert_model,
            train_loader,
            optimizer,
            scheduler,
            device
        )

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')

        # NOTE(review): checkpoint selection is done on the test loader, i.e.
        # the reported test scores come from the epoch that scored best on the
        # same test data. Kept as in the original experimental setup.
        val_acc, val_loss, _, _ = eval_model_bin(
            bert_model,
            test_loader,
            device
        )

        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        if val_acc > best_accuracy:
            best_accuracy = val_acc
            torch.save(bert_model.state_dict(), checkpoint_path)
            print('Best model saved!')

    # Load the best model and make final predictions
    print("Loading best model and making final predictions...")
    # map_location keeps the reload working when the runtime device differs
    # from the one the checkpoint tensors were saved from.
    bert_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    _, _, y_pred_feature, _ = eval_model_bin(bert_model, test_loader, device)

    # Convert predictions to numpy array
    y_pred_feature = np.array(y_pred_feature)

    # Set up gold standard file paths based on dataset
    if dataset_name == "MAMI":
        gold_test_bin = route + '' # Path to MAMI binary classification gold labels (json)
        gold_test_txt = route + '' # Path to MAMI text gold label (txt)
        label_names = ["non-misogynous", "misogynous"]
    else:  # EXIST2024
        gold_test_bin = route + '' # Path to EXIST2024 binary classification gold labels (json)
        gold_test_txt = route + '' # Path to EXIST2024 text gold label (txt)
        label_names = ["non-sexist", "sexist"]

    # Create files with predictions
    print(f"\n{'='*20} BERT Feature Ablation Results {'='*20}")
    test_pred_json_feature, test_pred_txt_feature = save_evaluation(
        dataset_df_test, "evaluation/predictions", dataset_name, "test", "binary",
        model_name, y_pred_feature, binary_label, []
    )

    # Evaluate the model (called for its reporting; the return value is not used)
    evaluate_binary_classification(
        gold_test_bin, test_pred_json_feature, y_test, y_pred_feature,
        gold_test_txt, test_pred_txt_feature, label_names, model_name=model_name
    )

    # Calculate additional metrics; zero_division=0 matches the
    # classification_report call below and avoids undefined-metric warnings
    # when a class is never predicted.
    accuracy = accuracy_score(y_test, y_pred_feature)
    precision_macro = precision_score(y_test, y_pred_feature, average='macro', zero_division=0)
    recall_macro = recall_score(y_test, y_pred_feature, average='macro', zero_division=0)
    f1_macro = f1_score(y_test, y_pred_feature, average='macro', zero_division=0)

    # Get classification report as dictionary
    class_report_dict = classification_report(y_test, y_pred_feature,
                                            target_names=label_names,
                                            zero_division=0, digits=3,
                                            output_dict=True)

    # Calculate binary F1 score (MAMI evaluation metric)
    binary_f1 = evaluate_f1_scores(gold_test_txt, test_pred_txt_feature, 2)

    # Create structured results dictionary
    results = {
        'feature_type': feature_type,
        'ablation_method': ablation_method,
        'dataset_name': dataset_name,
        'binary_label': binary_label,
        'model_type': 'BERT',
        'feature_ablation': {
            'model_name': model_name,
            'binary_f1': binary_f1,
            'macro_f1': f1_macro,
            'accuracy': accuracy,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'per_label_metrics': class_report_dict,
            'prediction_files': {
                'json': test_pred_json_feature,
                'txt': test_pred_txt_feature
            }
        },
        'train_statistics': train_statistics,
        'test_statistics': test_statistics
    }

    # Save results to JSON file for later analysis
    os.makedirs("evaluation/results/binary/BERT", exist_ok=True)
    results_file = f"evaluation/results/binary/BERT/{model_name}_bert_results.json"

    # Create the format for saving results
    save_results = {
        'binary_f1': binary_f1,
        'macro_f1': f1_macro,
        'accuracy': accuracy,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'per_label_metrics': class_report_dict,
        'predictions': y_pred_feature.tolist(),  # Save actual predictions
        'true_labels': y_test,  # Save true labels for reference
        'processing_stats': results['train_statistics'],
        'test_stats': results['test_statistics']
    }

    with open(results_file, 'w') as f:
        json.dump(save_results, f, indent=2, cls=NumpyEncoder)

    print(f"✅ BERT ablation results saved to: {results_file}")

    return results


def run_bert_coarse_grained_experiments():
    """
    Sweep every coarse-grained BERT ablation over MAMI and then EXIST2024.

    For each dataset, each feature type and each ablation method the binary
    ablation experiment is run once. The full result dicts are returned, and a
    condensed summary (binary F1, macro F1, accuracy and the training-split
    coverage figures) is written to
    evaluation/results/binary/BERT/bert_coarse_ablation_summary.json.

    Reads the module-level frames mami_training_df, mami_test_df,
    exist_training_df and exist_test_df.

    Returns:
    --------
    dict: all_results[dataset][feature_type][ablation_method] -> experiment result
    """
    wide_rule = "=" * 80
    rule = "=" * 50
    side = '=' * 20

    print(wide_rule)
    print("RUNNING COARSE-GRAINED BERT ABLATION EXPERIMENTS")
    print(wide_rule)

    # Experiment grid
    feature_types = ['sentiment_pos', 'sentiment_neg', 'hate', 'function']
    ablation_methods = ['mask', 'remove']

    # Storage for results, keyed by dataset
    all_results = {
        'MAMI': {},
        'EXIST2024': {}
    }

    # ---- MAMI ----
    print("\n" + rule)
    print("MAMI DATASET BERT EXPERIMENTS")
    print(rule)

    mami_results = all_results['MAMI']
    for feature_type in feature_types:
        per_method = mami_results[feature_type] = {}
        for ablation_method in ablation_methods:
            print(f"\n{side} MAMI BERT: {feature_type.capitalize()} with {ablation_method} {side}")
            # binary_label is left at the callee's default for MAMI
            per_method[ablation_method] = run_coarse_binary_bert_ablation_experiment(
                mami_training_df, mami_test_df, feature_type, ablation_method, "MAMI"
            )

    # ---- EXIST2024 ----
    print("\n" + rule)
    print("EXIST2024 DATASET BERT EXPERIMENTS")
    print(rule)

    exist_results = all_results['EXIST2024']
    for feature_type in feature_types:
        per_method = exist_results[feature_type] = {}
        for ablation_method in ablation_methods:
            print(f"\n{side} EXIST2024 BERT: {feature_type.capitalize()} with {ablation_method} {side}")
            per_method[ablation_method] = run_coarse_binary_bert_ablation_experiment(
                exist_training_df, exist_test_df, feature_type, ablation_method,
                "EXIST2024", binary_label='sexist'
            )

    # ---- Summary ----
    print("\n" + rule)
    print("SAVING OVERALL RESULTS SUMMARY")
    print(rule)

    os.makedirs("evaluation/results/binary/BERT", exist_ok=True)
    summary_file = "evaluation/results/binary/BERT/bert_coarse_ablation_summary.json"

    # Keep only the headline numbers of every experiment, same nesting as all_results
    summary = {
        dataset: {
            feature_type: {
                ablation_method: {
                    'binary_f1': outcome['feature_ablation']['binary_f1'],
                    'macro_f1': outcome['feature_ablation']['macro_f1'],
                    'accuracy': outcome['feature_ablation']['accuracy'],
                    'feature_coverage_percentage': outcome['train_statistics']['feature_coverage_percentage'],
                    'avg_feature_words_per_doc': outcome['train_statistics']['avg_feature_words_per_doc']
                }
                for ablation_method, outcome in by_method.items()
            }
            for feature_type, by_method in by_feature.items()
        }
        for dataset, by_feature in all_results.items()
    }

    with open(summary_file, 'w') as handle:
        json.dump(summary, handle, indent=2, cls=NumpyEncoder)

    print(f"✅ Overall BERT ablation summary saved to: {summary_file}")

    return all_results


# Function to run individual experiments
def run_mami_bert_ablation():
    """
    Run every coarse-grained BERT ablation (4 feature types x 2 methods) on MAMI.

    Reads the module-level frames mami_training_df and mami_test_df.

    Returns:
    --------
    dict: results[feature_type][ablation_method] -> experiment result
    """
    results = {}

    for feature_type in ['sentiment_pos', 'sentiment_neg', 'hate', 'function']:
        per_method = {}
        results[feature_type] = per_method

        for ablation_method in ['mask', 'remove']:
            print(f"\nRunning BERT ablation for {feature_type} with {ablation_method} on MAMI...")
            per_method[ablation_method] = run_coarse_binary_bert_ablation_experiment(
                mami_training_df,
                mami_test_df,
                feature_type,
                ablation_method,
                'MAMI',
                'misogynous'
            )

    return results


def run_exist_bert_ablation():
    """
    Run every coarse-grained BERT ablation (4 feature types x 2 methods) on EXIST2024.

    Reads the module-level frames exist_training_df and exist_test_df.

    Returns:
    --------
    dict: results[feature_type][ablation_method] -> experiment result
    """
    grid_features = ('sentiment_pos', 'sentiment_neg', 'hate', 'function')
    grid_methods = ('mask', 'remove')

    results = {}
    for feature_type in grid_features:
        bucket = results[feature_type] = {}
        for ablation_method in grid_methods:
            print(f"\nRunning BERT ablation for {feature_type} with {ablation_method} on EXIST2024...")
            outcome = run_coarse_binary_bert_ablation_experiment(
                exist_training_df, exist_test_df,
                feature_type, ablation_method,
                'EXIST2024', 'sexist'
            )
            bucket[ablation_method] = outcome

    return results

In [None]:
# Run all BERT ablation experiments
# Covers both MAMI and EXIST2024: 4 feature types x 2 ablation methods each,
# fine-tuning one BERT model per combination. The returned dict is nested as
# all_results[dataset][feature_type][ablation_method].
all_results = run_bert_coarse_grained_experiments()

### Multi-label


In [None]:
def run_bert_coarse_grained_multilabel_ablation_experiments(dataset_df_train, dataset_df_test, feature_type,
                                                         ablation_method, dataset_name, binary_label, fine_grained_labels):
    """
    Run BERT multilabel ablation experiment for coarse-grained feature categories.

    Parameters:
    -----------
    dataset_df_train : DataFrame
        Training dataset
    dataset_df_test : DataFrame
        Test dataset
    feature_type : str
        Type of feature to ablate:
        - 'sentiment_pos': Positive sentiment words
        - 'sentiment_neg': Negative sentiment words
        - 'hate': All hate speech terms
        - 'function': All function words
    ablation_method : str
        Method for ablation ('mask' or 'remove')
    dataset_name : str
        Name of the dataset ('MAMI' or 'EXIST2024')
    binary_label : str
        Name of the binary label column
    fine_grained_labels : list
        List of fine-grained category label names
    """

    print(f"\n🤖 RUNNING BERT COARSE-GRAINED MULTILABEL ABLATION")
    print(f"Dataset: {dataset_name}")
    print(f"Feature: {feature_type}")
    print(f"Method: {ablation_method}")
    print("=" * 60)

    # Download required NLTK data if needed
    try:
        nltk.data.find('tokenizers/punkt_tab')
    except LookupError:
        print("Downloading required NLTK data...")
        nltk.download('punkt_tab')
        nltk.download('punkt')

    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')

    # Load necessary lexicons
    print("Loading lexicons...")
    nrc_lexicon = load_nrc_lexicon(route + "NRC-Emotion-Lexicon/NRC-Emotion-Lexicon-Wordlevel-v0.92.txt")
    function_words_dict = load_function_words()
    hurtlex_dict = load_hurtlex(route + "hurtlex-master/lexica/EN/1.2/hurtlex_EN.tsv")

    # Get original texts and labels (with text cleaning)
    X_train_orig_raw = dataset_df_train["bert representation"].tolist()
    X_test_orig_raw = dataset_df_test["bert representation"].tolist()

    # Clean texts for consistent feature extraction
    print("🧹 Cleaning texts for feature extraction (removing BERT tokens)...")
    X_train_orig = clean_text_sep_token(X_train_orig_raw)
    X_test_orig = clean_text_sep_token(X_test_orig_raw)

    # Get labels
    y_train_binary = dataset_df_train[binary_label].tolist()
    train_df_bin_positive = dataset_df_train.loc[dataset_df_train[binary_label] == 1]
    X_train_bin_positive_raw = train_df_bin_positive["bert representation"].tolist()
    X_train_bin_positive = clean_text_sep_token(X_train_bin_positive_raw)
    y_train_categories = train_df_bin_positive[fine_grained_labels].values.tolist()

    # Prepare complete label list (binary + fine-grained)
    all_labels = [binary_label] + fine_grained_labels
    y_test_all = dataset_df_test[all_labels]

    # Create display name and model name for feature type
    display_name = f"Coarse-Grained: {feature_type.replace('_', ' ').title()}"
    feature_name_for_model = feature_type

    print(f"Running BERT multilabel coarse-grained ablation for {display_name} with {ablation_method}...")

    # Statistics tracking
    total_feature_words_processed = 0
    documents_with_features = 0
    feature_words_per_doc = []

    # Define function to extract coarse-grained features
    def extract_coarse_grained_features(text):
        """Helper function to extract coarse-grained feature words"""
        if feature_type == 'sentiment_pos':
            return extract_sentiment_words(text, nrc_lexicon, 'positive')
        elif feature_type == 'sentiment_neg':
            return extract_sentiment_words(text, nrc_lexicon, 'negative')
        elif feature_type == 'hate':
            hate_words = []
            for word in text.lower().split():
                if word in hurtlex_dict:
                    hate_words.append(word)
            return list(set(hate_words))
        elif feature_type == 'function':
            function_words = []
            for category in function_words_dict.keys():
                category_words = extract_function_words(text, function_words_dict, category)
                function_words.extend(category_words)
            return list(set(function_words))
        else:
            print(f"Warning: Unknown coarse-grained feature type '{feature_type}'. No features will be ablated.")
            return []

    # Process training data
    X_train_ablated = []
    X_train_bin_pos_ablated = []
    X_test_ablated = []

    print("Processing training data...")
    for text in X_train_orig:
        feature_words = extract_coarse_grained_features(text)
        feature_words_set = set(feature_words)

        # Count feature words in this text
        num_feature_words = len([word for word in text.lower().split() if word in feature_words_set])
        feature_words_per_doc.append(num_feature_words)

        if num_feature_words > 0:
            documents_with_features += 1

        # Apply feature ablation
        if ablation_method == 'mask':
            ablated_text, count = replace_mask_features(text, feature_words, mask_token='[MASK]')
        elif ablation_method == 'remove':
            ablated_text, count = remove_features(text, feature_words)

        X_train_ablated.append(ablated_text)
        total_feature_words_processed += count

    # Process positive-only training data
    print("Processing positive training data...")
    for text in X_train_bin_positive:
        feature_words = extract_coarse_grained_features(text)

        if ablation_method == 'mask':
            ablated_text, _ = replace_mask_features(text, feature_words, mask_token='[MASK]')
        elif ablation_method == 'remove':
            ablated_text, _ = remove_features(text, feature_words)

        X_train_bin_pos_ablated.append(ablated_text)

    # Process test data
    print("Processing test data...")
    for text in X_test_orig:
        feature_words = extract_coarse_grained_features(text)

        if ablation_method == 'mask':
            ablated_text, _ = replace_mask_features(text, feature_words, mask_token='[MASK]')
        elif ablation_method == 'remove':
            ablated_text, _ = remove_features(text, feature_words)

        X_test_ablated.append(ablated_text)

    # Summary statistics
    print(f"\n{'=' * 50}")
    print(f"Processing Summary:")
    print(f"Total documents: {len(X_train_ablated)}")
    print(f"Documents with features: {documents_with_features} ({documents_with_features/len(X_train_ablated)*100:.2f}%)")
    print(f"Total feature words processed: {total_feature_words_processed}")
    print(f"Average feature words per document: {np.mean(feature_words_per_doc):.2f}")
    print(f"Max words per document: {max(feature_words_per_doc) if feature_words_per_doc else 0}")
    print(f"{'=' * 50}")

    # Create model name
    model_name = f"{dataset_name}_bert_ablation_{feature_name_for_model}_{ablation_method}_multilabel"

    try:
        # Initialize BERT tokenizer and device
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # STEP 1: Train binary classifier on ablated data
        print(f"\n🚀 Training BERT binary classifier on ablated data...")

        # Create binary datasets
        train_dataset_bin = MemeDataset(
            texts=X_train_ablated,
            labels=y_train_binary,
            tokenizer=tokenizer,
            max_len=128
        )

        test_dataset_bin = MemeDataset(
            texts=X_test_ablated,
            labels=[0] * len(X_test_ablated),  # Dummy labels for test
            tokenizer=tokenizer,
            max_len=128
        )

        # Create data loaders
        batch_size = 16
        train_loader_bin = DataLoader(
            train_dataset_bin,
            batch_size=batch_size,
            shuffle=True
        )

        test_loader_bin = DataLoader(
            test_dataset_bin,
            batch_size=batch_size
        )

        # Train binary model
        binary_model = BertClassifier(n_classes=2).to(device)
        optimizer_bin = AdamW(binary_model.parameters(), lr=2e-5)

        epochs = 3
        total_steps = len(train_loader_bin) * epochs
        scheduler_bin = get_linear_schedule_with_warmup(
            optimizer_bin,
            num_warmup_steps=0,
            num_training_steps=total_steps
        )

        best_accuracy = 0
        for epoch in range(epochs):
            print(f'Binary Epoch {epoch + 1}/{epochs}')

            train_acc, train_loss = train_epoch_bin(
                binary_model,
                train_loader_bin,
                optimizer_bin,
                scheduler_bin,
                device
            )

            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')

        # STEP 2: Train multi-label classifier on positive samples
        print(f"\n🏷️ Training BERT multi-label classifier...")

        # Create multi-label datasets
        ml_train_dataset = MemeMultiLabelDataset(
            texts=X_train_bin_pos_ablated,
            labels=y_train_categories,
            tokenizer=tokenizer,
            max_len=128
        )

        ml_train_loader = DataLoader(
            ml_train_dataset,
            batch_size=batch_size,
            shuffle=True
        )

        # Train multi-label model
        ml_model = BertMultiLabelClassifier(n_classes=len(fine_grained_labels)).to(device)
        ml_optimizer = AdamW(ml_model.parameters(), lr=2e-5)

        ml_total_steps = len(ml_train_loader) * epochs
        ml_scheduler = get_linear_schedule_with_warmup(
            ml_optimizer,
            num_warmup_steps=0,
            num_training_steps=ml_total_steps
        )

        for epoch in range(epochs):
            print(f'Multi-label Epoch {epoch + 1}/{epochs}')

            train_loss = train_epoch_multilabel(
                ml_model,
                ml_train_loader,
                ml_optimizer,
                ml_scheduler,
                device
            )

            print(f'Multi-label Train Loss: {train_loss:.4f}')

        # STEP 3: Make hierarchical predictions
        print(f"\n🔮 Making hierarchical predictions...")
        hierarchical_predictions = predict_hierarchical(
            binary_model,
            ml_model,
            test_loader_bin,
            tokenizer,
            device,
            fine_grained_labels,
            test_texts=X_test_ablated,
            threshold=0.5,
            batch_size=batch_size
        )

        # Convert predictions to DataFrame
        test_pred_df = pd.DataFrame(
            hierarchical_predictions,
            columns=all_labels
        )

        print(f"Hierarchical predictions shape: {hierarchical_predictions.shape}")
        print(f"Test predictions DataFrame shape: {test_pred_df.shape}")

        # Get gold file paths based on dataset
        if dataset_name == "MAMI":
            gold_test_ml = route + '' # Path to MAMI multi-label classification gold labels (json)  
            gold_test_txt = route + '' # Path to MAMI text gold label (txt)
        else:  # EXIST2024
            gold_test_ml = route + '' # Path to EXIST2024 multi-label classification gold labels (json)  
            gold_test_txt = route + ''# Path to EXIST2024 text gold label (txt)

        evaluation_type = "hierarchical"

        # Save evaluation results
        print(f"\nSaving evaluation results for: {model_name}")
        test_pred_json_ml, test_pred_txt_ml = save_evaluation(
            dataset_df_test, "evaluation/predictions", dataset_name, "test",
            evaluation_type, model_name, test_pred_df, binary_label, all_labels
        )

        # Evaluate the model
        print(f"\n{'='*20} Evaluation Results for {display_name} {'='*20}")
        evaluate_multilabel_classification(
            gold_test_ml, test_pred_json_ml,
            y_test_all, hierarchical_predictions,
            gold_test_txt, test_pred_txt_ml,
            all_labels, hierarchy=True
        )

        # Calculate detailed metrics for saving
        y_true_df = dataset_df_test[all_labels]

        # Calculate macro F1 score for hierarchical evaluation
        hierarchical_scores = evaluate_f1_scores(gold_test_txt, test_pred_txt_ml, len(all_labels))
        if isinstance(hierarchical_scores, tuple):
            binary_f1, multilabel_f1 = hierarchical_scores
        else:
            binary_f1 = hierarchical_scores
            multilabel_f1 = hierarchical_scores

        # Calculate individual label metrics
        binary_true = y_true_df[binary_label].values
        binary_pred = hierarchical_predictions[:, 0]

        # Create negative class representation
        negative_true = np.zeros((len(hierarchical_predictions), 1))
        negative_pred = np.zeros((len(hierarchical_predictions), 1))
        negative_true[:, 0] = (binary_true == 0)
        negative_pred[:, 0] = (binary_pred == 0)

        # Get multi-label data
        multilabel_true = y_true_df[fine_grained_labels].values
        multilabel_pred = hierarchical_predictions[:, 1:]

        # Create combined representation
        combined_true = np.hstack((negative_true, multilabel_true))
        combined_pred = np.hstack((negative_pred, multilabel_pred))

        # Updated label names
        updated_label_names = [f"non-{binary_label}"] + fine_grained_labels

        # Calculate individual metrics
        individual_metrics = {}
        for i, label in enumerate(updated_label_names):
            y_true_label = combined_true[:, i]
            y_pred_label = combined_pred[:, i]

            accuracy = accuracy_score(y_true_label, y_pred_label)
            precision = precision_score(y_true_label, y_pred_label, zero_division=0)
            recall = recall_score(y_true_label, y_pred_label, zero_division=0)
            f1 = f1_score(y_true_label, y_pred_label, zero_division=0)

            individual_metrics[label] = {
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1': f1
            }

        # Calculate overall macro metrics
        all_f1_scores = [individual_metrics[label]['f1'] for label in updated_label_names]
        all_precision_scores = [individual_metrics[label]['precision'] for label in updated_label_names]
        all_recall_scores = [individual_metrics[label]['recall'] for label in updated_label_names]

        macro_f1 = np.mean(all_f1_scores)
        macro_precision = np.mean(all_precision_scores)
        macro_recall = np.mean(all_recall_scores)

        # Create comprehensive metrics dictionary
        metrics = {
            'binary_f1': binary_f1,
            'hierarchical_f1': multilabel_f1,
            'macro_f1': macro_f1,
            'macro_precision': macro_precision,
            'macro_recall': macro_recall,
            'individual_label_metrics': individual_metrics,
            'predictions': hierarchical_predictions.tolist(),
            'true_labels': y_test_all.values.tolist(),
            'prediction_files': {
                'json': test_pred_json_ml,
                'txt': test_pred_txt_ml
            },
            'all_labels': all_labels,
            'feature_type': feature_type,
            'ablation_method': ablation_method,
            'model_config': {
                'model_type': 'BERT',
                'model_name': 'bert-base-uncased',
                'epochs': epochs,
                'batch_size': batch_size,
                'learning_rate': 2e-5,
                'max_length': 128,
                'num_classes': len(fine_grained_labels),
                'hierarchical': True
            }
        }

        # Save results
        os.makedirs("evaluation/results/multi-label/BERT/coarse-grained", exist_ok=True)
        results_file = f"evaluation/results/multi-label/BERT/coarse-grained/{model_name}_results.json"

        with open(results_file, 'w') as f:
            json.dump(metrics, f, indent=2, cls=NumpyEncoder)

        print(f"✅ BERT multilabel results saved to: {results_file}")

        return {
            'model_name': model_name,
            'metrics': metrics,
            'ablation_stats': {
                'total_documents': len(X_train_ablated),
                'documents_with_features': documents_with_features,
                'percentage_with_features': documents_with_features/len(X_train_ablated)*100,
                'total_feature_words': total_feature_words_processed,
                'avg_feature_words': np.mean(feature_words_per_doc),
                'max_feature_words': max(feature_words_per_doc) if feature_words_per_doc else 0
            }
        }

    except Exception as e:
        print(f"❌ Error in BERT multilabel experiment: {e}")
        import traceback
        traceback.print_exc()
        print(f"Skipping BERT multilabel coarse-grained ablation for {feature_type} with {ablation_method}")
        return None


def run_all_bert_coarse_grained_multilabel_experiments():
    """
    Sweep every feature type / ablation method on MAMI and then EXIST2024.

    Each combination is delegated to
    `run_bert_coarse_grained_multilabel_ablation_experiments`. The raw results
    are collected as `{dataset: {feature_type: {ablation_method: result}}}`
    and a condensed summary of the key scores is written to
    `bert_coarse_multilabel_summary.json`.

    Returns:
        The nested dictionary of raw results (entries are None for
        experiments that failed).
    """
    outer_rule = "=" * 80
    section_rule = "=" * 50

    print(outer_rule)
    print("RUNNING ALL BERT COARSE-GRAINED MULTILABEL ABLATION EXPERIMENTS")
    print(outer_rule)

    # Define feature types and ablation methods
    feature_types = ['sentiment_pos', 'sentiment_neg', 'hate', 'function']
    ablation_methods = ['remove']

    # Storage for results
    all_results = {'MAMI': {}, 'EXIST2024': {}}

    def sweep(dataset_key, train_df, test_df, binary_label, fine_grained_labels):
        # Run the full feature-type x ablation-method grid for one dataset.
        for feature_type in feature_types:
            per_method = all_results[dataset_key][feature_type] = {}
            for ablation_method in ablation_methods:
                print(f"\n{'='*20} {dataset_key} BERT Multilabel: {feature_type.capitalize()} with {ablation_method} {'='*20}")
                per_method[ablation_method] = run_bert_coarse_grained_multilabel_ablation_experiments(
                    train_df, test_df, feature_type, ablation_method,
                    dataset_key, binary_label, fine_grained_labels
                )

    # MAMI Dataset Experiments
    print("\n" + section_rule)
    print("MAMI DATASET BERT MULTILABEL EXPERIMENTS")
    print(section_rule)
    sweep(
        'MAMI', mami_training_df, mami_test_df, "misogynous",
        ["shaming", "stereotype", "objectification", "violence"],
    )

    # EXIST2024 Dataset Experiments
    print("\n" + section_rule)
    print("EXIST2024 DATASET BERT MULTILABEL EXPERIMENTS")
    print(section_rule)
    sweep(
        'EXIST2024', exist_training_df, exist_test_df, "sexist",
        ["ideological-inequality", "stereotyping-dominance",
         "objectification", "sexual-violence", "misogyny-non-sexual-violence"],
    )

    # Save overall results summary
    print("\n" + section_rule)
    print("SAVING OVERALL BERT MULTILABEL RESULTS SUMMARY")
    print(section_rule)

    os.makedirs("evaluation/results/multi-label/BERT/coarse-grained", exist_ok=True)
    summary_file = "evaluation/results/multi-label/BERT/coarse-grained/bert_coarse_multilabel_summary.json"

    def condense(result):
        # Reduce one experiment result to its headline numbers; failed runs stay None.
        if result is None:
            return None
        scores = result['metrics']
        stats = result['ablation_stats']
        return {
            'hierarchical_f1': scores['hierarchical_f1'],
            'macro_f1': scores['macro_f1'],
            'binary_f1': scores['binary_f1'],
            'feature_coverage_percentage': stats['percentage_with_features'],
            'avg_feature_words_per_doc': stats['avg_feature_words']
        }

    # Create summary with key metrics
    summary = {
        dataset: {
            feature_type: {
                ablation_method: condense(result)
                for ablation_method, result in by_method.items()
            }
            for feature_type, by_method in by_feature.items()
        }
        for dataset, by_feature in all_results.items()
    }

    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2, cls=NumpyEncoder)

    print(f"✅ Overall BERT multilabel summary saved to: {summary_file}")

    return all_results


def run_mami_bert_multilabel_ablation():
    """
    Run the BERT coarse-grained multilabel ablation sweep on MAMI only.

    Returns:
        `{feature_type: {ablation_method: result}}`, where each result is
        whatever `run_bert_coarse_grained_multilabel_ablation_experiments`
        returned for that combination.
    """
    mami_fine_grained_labels = ["shaming", "stereotype", "objectification", "violence"]

    outcome = {}
    for feature_type in ('sentiment_pos', 'sentiment_neg', 'hate', 'function'):
        by_method = {}
        outcome[feature_type] = by_method
        for ablation_method in ('remove',):
            print(f"\nRunning BERT multilabel ablation for {feature_type} with {ablation_method} on MAMI...")
            by_method[ablation_method] = run_bert_coarse_grained_multilabel_ablation_experiments(
                mami_training_df, mami_test_df, feature_type, ablation_method,
                'MAMI', 'misogynous', mami_fine_grained_labels
            )

    return outcome


def run_exist_bert_multilabel_ablation():
    """
    Run the BERT coarse-grained multilabel ablation sweep on EXIST2024 only.

    Returns:
        `{feature_type: {ablation_method: result}}`, where each result is
        whatever `run_bert_coarse_grained_multilabel_ablation_experiments`
        returned for that combination.
    """
    exist_fine_grained_labels = [
        "ideological-inequality",
        "stereotyping-dominance",
        "objectification",
        "sexual-violence",
        "misogyny-non-sexual-violence",
    ]

    outcome = {}
    for feature_type in ('sentiment_pos', 'sentiment_neg', 'hate', 'function'):
        by_method = {}
        outcome[feature_type] = by_method
        for ablation_method in ('remove',):
            print(f"\nRunning BERT multilabel ablation for {feature_type} with {ablation_method} on EXIST2024...")
            by_method[ablation_method] = run_bert_coarse_grained_multilabel_ablation_experiments(
                exist_training_df, exist_test_df, feature_type, ablation_method,
                'EXIST2024', 'sexist', exist_fine_grained_labels
            )

    return outcome

# Execute both coarse-grained multilabel ablation sweeps (MAMI, then EXIST2024).
# The returned result dictionaries are not assigned to any name here.
run_mami_bert_multilabel_ablation()
run_exist_bert_multilabel_ablation()

## POS tag ablation

In [None]:
def run_bert_pos_ablation_experiment(dataset_df_train, dataset_df_test, pos_category,
                                      dataset_name, binary_label='misogynous',
                                      train_pos_tags=None, test_pos_tags=None):
    """
    Fine-tune a binary BERT classifier on text with one POS category ablated.

    The "bert representation" column of both frames is passed through
    `clean_text_sep_token`, ablated for `pos_category` by
    `create_pos_ablated_dataset`, and used to train a fresh `BertClassifier`
    for 3 epochs. Predictions are saved via `save_evaluation`, and a JSON
    report is written to `evaluation/results/POS/BERT/`.

    Args:
        dataset_df_train: Training frame with a "bert representation" column
            and a `binary_label` column.
        dataset_df_test: Test frame with the same columns.
        pos_category: POS tag to ablate (also used in the model/file names).
        dataset_name: "MAMI" selects the MAMI gold files and label names; any
            other value is treated as EXIST2024.
        binary_label: Name of the binary target column.
        train_pos_tags: Precomputed POS tags for the cleaned training texts,
            forwarded to `create_pos_ablated_dataset`.
        test_pos_tags: Same, for the cleaned test texts.

    Returns:
        dict with metrics, ablation statistics and prediction file paths, or
        None if model training/evaluation raised (the traceback is printed).
    """
    print(f"\n{'='*60}")
    print(f"BERT POS Ablation Experiment: {pos_category}")
    print(f"Dataset: {dataset_name}")
    print(f"{'='*60}")

    def _report_split_stats(heading, stats):
        # Print the ablation statistics of one split under the given heading.
        print(heading)
        print(f"  - Documents affected: {stats['documents_affected']}/{stats['total_documents']} ({stats['percentage_docs_affected']:.2f}%)")
        print(f"  - Total words removed: {stats['total_words_removed']}")
        print(f"  - Max words removed: {stats['max_words_removed']}")

    # Get original texts and labels
    X_train_orig_raw = dataset_df_train["bert representation"].tolist()
    X_test_orig_raw = dataset_df_test["bert representation"].tolist()
    y_train = dataset_df_train[binary_label].tolist()
    y_test = dataset_df_test[binary_label].tolist()

    # Clean the texts before POS-based ablation so they line up with the
    # precomputed POS tags, which the caller computes on cleaned texts.
    print("🧹 Cleaning texts for POS analysis (removing [SEP] tokens)...")
    X_train_orig = clean_text_sep_token(X_train_orig_raw)
    X_test_orig = clean_text_sep_token(X_test_orig_raw)

    # Show cleaning effect
    print(f"Sample cleaning:")
    print(f"  Before: '{X_train_orig_raw[0]}'")
    print(f"  After:  '{X_train_orig[0]}'")

    # Create ablated datasets using precomputed POS tags
    print("\nCreating ablated training set...")
    X_train_ablated, train_stats = create_pos_ablated_dataset(
        X_train_orig, pos_category, train_pos_tags)

    print("\nCreating ablated test set...")
    X_test_ablated, test_stats = create_pos_ablated_dataset(
        X_test_orig, pos_category, test_pos_tags)

    # Print statistics
    print(f"\n{'-'*50}")
    print("ABLATION STATISTICS:")
    print(f"{'-'*50}")
    _report_split_stats("Training set:", train_stats)
    _report_split_stats("\nTest set:", test_stats)

    # Train BERT model on ablated data
    print(f"\n{'-'*50}")
    print("TRAINING BERT MODEL ON ABLATED DATA:")
    print(f"{'-'*50}")

    try:
        # Initialize BERT components
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Create BERT datasets with cleaned ablated texts
        train_dataset = MemeDataset(
            texts=X_train_ablated,
            labels=y_train,
            tokenizer=tokenizer,
            max_len=128
        )

        test_dataset = MemeDataset(
            texts=X_test_ablated,
            labels=y_test,
            tokenizer=tokenizer,
            max_len=128
        )

        # Create data loaders
        batch_size = 16
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size
        )

        # Initialize BERT model
        bert_model = BertClassifier(n_classes=2).to(device)
        optimizer = AdamW(bert_model.parameters(), lr=2e-5)

        epochs = 3
        total_steps = len(train_loader) * epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=total_steps
        )

        # Train the model
        print(f"Training BERT model for {epochs} epochs...")
        model_name = f"bert_pos_ablation_clean_{pos_category.lower()}_{dataset_name.lower()}"
        checkpoint_path = f'best_model_{model_name}.pt'

        # Start below any reachable accuracy so the first epoch always writes a
        # checkpoint. With a start value of 0 and a strict '>', a run whose
        # accuracy stayed at 0 would never save, and the load below would fail
        # or pick up a stale checkpoint left by an earlier run.
        best_accuracy = float('-inf')

        # NOTE(review): the best checkpoint is selected on the test loader, so
        # the reported test metrics are chosen on the same data they measure.
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}')

            train_acc, train_loss = train_epoch_bin(
                bert_model,
                train_loader,
                optimizer,
                scheduler,
                device
            )

            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')

            val_acc, val_loss, _, _ = eval_model_bin(
                bert_model,
                test_loader,
                device
            )

            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            if val_acc > best_accuracy:
                best_accuracy = val_acc
                torch.save(bert_model.state_dict(), checkpoint_path)
                print('Best model saved!')

        # Load best model and make final predictions
        bert_model.load_state_dict(torch.load(checkpoint_path))
        _, _, y_pred, _ = eval_model_bin(bert_model, test_loader, device)
        y_pred = np.array(y_pred)

        # Set up evaluation parameters
        evaluation_type = "binary"

        # Set up gold standard file paths
        if dataset_name == "MAMI":
            gold_test_bin = route + '' # Path to MAMI binary classification gold labels (json) 
            gold_test_txt = route +'' # Path to MAMI text gold label (txt)
            label_names = ["non-misogynous", "misogynous"]
        else:  # EXIST2024
            gold_test_bin = route + '' # Path to EXIST2024 binary classification gold labels (json) 
            gold_test_txt = route + '' # Path to EXIST2024 text gold label (txt)
            label_names = ["non-sexist", "sexist"]

        # Save predictions and evaluate
        test_pred_json, test_pred_txt = save_evaluation(
            dataset_df_test, "evaluation/predictions", dataset_name, "test",
            evaluation_type, model_name, y_pred, binary_label, []
        )

        print(f"\n{'-'*50}")
        print("EVALUATION RESULTS:")
        print(f"{'-'*50}")

        # Use existing evaluation function
        evaluate_binary_classification(
            gold_test_bin, test_pred_json, y_test, y_pred,
            gold_test_txt, test_pred_txt, label_names,
            model_name=f"BERT POS Ablation ({pos_category}) - CLEANED"
        )

        # Calculate all required metrics
        from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix

        # zero_division=0 matches the classification_report call below: a class
        # that is never predicted scores 0 without an UndefinedMetricWarning.
        accuracy = accuracy_score(y_test, y_pred)
        precision_macro = precision_score(y_test, y_pred, average='macro', zero_division=0)
        recall_macro = recall_score(y_test, y_pred, average='macro', zero_division=0)
        f1_macro = f1_score(y_test, y_pred, average='macro', zero_division=0)

        # Get classification report as dictionary
        class_report_dict = classification_report(y_test, y_pred,
                                                target_names=label_names,
                                                zero_division=0, digits=3,
                                                output_dict=True)

        # Calculate binary F1 score (MAMI evaluation metric)
        binary_f1 = evaluate_f1_scores(gold_test_txt, test_pred_txt, 2)

        # Create structured results dictionary
        pos_results = {
            'model_type': 'BERT_CLEANED',
            'pos_category': pos_category,
            'binary_f1': binary_f1,
            'macro_f1': f1_macro,
            'accuracy': accuracy,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'per_label_metrics': class_report_dict,
            'confusion_matrix': confusion_matrix(y_test, y_pred).tolist(),
            'label_names': label_names,
            'train_stats': train_stats,
            'test_stats': test_stats,
            'text_cleaned': True
        }

        # Save results to JSON file
        import os
        import json
        os.makedirs("evaluation/results/POS/BERT", exist_ok=True)
        results_file = f"evaluation/results/POS/BERT/{model_name}_results.json"

        with open(results_file, 'w') as f:
            json.dump(pos_results, f, indent=2, cls=NumpyEncoder)

        print(f"✅ BERT POS results (CLEANED) saved to: {results_file}")

        results = {
            'pos_category': pos_category,
            'model_name': model_name,
            'model_type': 'BERT_CLEANED',
            'accuracy': accuracy,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'binary_f1': binary_f1,
            'train_stats': train_stats,
            'test_stats': test_stats,
            'text_cleaned': True,
            'prediction_files': {
                'json': test_pred_json,
                'txt': test_pred_txt
            }
        }

        return results

    except Exception as e:
        # Best-effort: report the failure and let the caller skip this category.
        print(f"Error in BERT POS ablation for {pos_category}: {e}")
        import traceback
        traceback.print_exc()
        return None


def run_complete_bert_pos_ablation_study(dataset_df_train, dataset_df_test, dataset_name,
                                         binary_label='misogynous', pos_tags=None):
    """
    Run the full BERT POS ablation study: one baseline plus one run per POS tag.

    The "bert representation" texts are cleaned with `clean_text_sep_token`,
    POS tags are precomputed once with `extract_pos_tags_batch`, a baseline
    `BertClassifier` is trained on the cleaned (un-ablated) texts, and
    `run_bert_pos_ablation_experiment` is then called for every category in
    `pos_tags`. All results are written to a per-dataset summary JSON under
    `evaluation/results/POS/BERT/`.

    Args:
        dataset_df_train: Training frame with a "bert representation" column
            and a `binary_label` column.
        dataset_df_test: Test frame with the same columns.
        dataset_name: Dataset identifier used in messages and file names.
        binary_label: Name of the binary target column.
        pos_tags: POS categories to ablate; defaults to a copy of
            `UNIVERSAL_POS_TAGS`.

    Returns:
        dict with keys 'model_type', 'dataset', 'text_cleaned', 'baseline'
        and 'ablation_results' (only categories that succeeded are present).
    """

    if pos_tags is None:
        pos_tags = UNIVERSAL_POS_TAGS.copy()

    def _relative_drop(drop, baseline):
        # Drop as a percentage of the baseline; NaN when the baseline is 0 so
        # a degenerate baseline cannot abort the study with a division error.
        return drop / baseline * 100 if baseline else float('nan')

    print(f"\n{'='*80}")
    print(f"BERT POS ABLATION STUDY FOR {dataset_name} (WITH TEXT CLEANING)")
    print(f"Testing {len(pos_tags)} POS categories: {', '.join(pos_tags)}")
    print(f"{'='*80}")

    # Get original texts and labels
    X_train_raw = dataset_df_train["bert representation"].tolist()
    X_test_raw = dataset_df_test["bert representation"].tolist()
    y_train = dataset_df_train[binary_label].tolist()
    y_test = dataset_df_test[binary_label].tolist()

    # Clean texts so the baseline and the POS tags use the same input as the
    # per-category experiments, which clean their texts the same way.
    print("🧹 Cleaning texts for POS analysis...")
    X_train = clean_text_sep_token(X_train_raw)
    X_test = clean_text_sep_token(X_test_raw)

    print(f"Text cleaning summary:")
    print(f"  Original sample: '{X_train_raw[0][:60]}...'")
    print(f"  Cleaned sample:  '{X_train[0][:60]}...'")

    # OPTIMIZATION: Precompute POS tags for all CLEANED texts once
    print(f"\n{'='*60}")
    print("PRECOMPUTING POS TAGS FOR CLEANED TEXTS")
    print(f"{'='*60}")

    print("Computing POS tags for cleaned training data...")
    train_pos_tags = extract_pos_tags_batch(X_train, batch_size=1000)

    print("Computing POS tags for cleaned test data...")
    test_pos_tags = extract_pos_tags_batch(X_test, batch_size=1000)

    print("POS tag computation complete! Now running experiments...")

    # First run baseline (no ablation) with BERT on CLEANED texts
    print(f"\n{'='*60}")
    print("BERT BASELINE EXPERIMENT (NO ABLATION, CLEANED TEXTS)")
    print(f"{'='*60}")

    # Initialize BERT components for baseline
    from transformers import BertTokenizer, get_linear_schedule_with_warmup
    import torch
    from torch.optim import AdamW
    from torch.utils.data import DataLoader
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create baseline BERT datasets with CLEANED texts
    baseline_train_dataset = MemeDataset(
        texts=X_train,  # Using cleaned texts
        labels=y_train,
        tokenizer=tokenizer,
        max_len=128
    )

    baseline_test_dataset = MemeDataset(
        texts=X_test,   # Using cleaned texts
        labels=y_test,
        tokenizer=tokenizer,
        max_len=128
    )

    # Create baseline data loaders
    batch_size = 16
    baseline_train_loader = DataLoader(
        baseline_train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    baseline_test_loader = DataLoader(
        baseline_test_dataset,
        batch_size=batch_size
    )

    # Train baseline BERT model
    baseline_bert_model = BertClassifier(n_classes=2).to(device)
    baseline_optimizer = AdamW(baseline_bert_model.parameters(), lr=2e-5)

    epochs = 3
    total_steps = len(baseline_train_loader) * epochs
    baseline_scheduler = get_linear_schedule_with_warmup(
        baseline_optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    baseline_checkpoint = f'best_baseline_bert_cleaned_{dataset_name.lower()}.pt'

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. With a start value of 0 and a strict '>', a run stuck at 0
    # accuracy would never save, and the load below would fail or pick up a
    # stale checkpoint from a previous run.
    best_baseline_accuracy = float('-inf')

    # NOTE(review): the best checkpoint is selected on the test loader, so the
    # baseline test metrics are chosen on the same data they measure.
    for epoch in range(epochs):
        print(f'Baseline Epoch {epoch + 1}/{epochs}')

        train_acc, train_loss = train_epoch_bin(
            baseline_bert_model,
            baseline_train_loader,
            baseline_optimizer,
            baseline_scheduler,
            device
        )

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')

        val_acc, val_loss, _, _ = eval_model_bin(
            baseline_bert_model,
            baseline_test_loader,
            device
        )

        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        if val_acc > best_baseline_accuracy:
            best_baseline_accuracy = val_acc
            torch.save(baseline_bert_model.state_dict(), baseline_checkpoint)

    # Load best baseline model and get predictions
    baseline_bert_model.load_state_dict(torch.load(baseline_checkpoint))
    _, _, baseline_pred, _ = eval_model_bin(baseline_bert_model, baseline_test_loader, device)
    baseline_pred = np.array(baseline_pred)

    # Calculate baseline metrics (zero_division=0: a never-predicted class
    # scores 0 without raising an UndefinedMetricWarning)
    baseline_accuracy = accuracy_score(y_test, baseline_pred)
    baseline_f1 = f1_score(y_test, baseline_pred, average='macro', zero_division=0)
    baseline_precision = precision_score(y_test, baseline_pred, average='macro', zero_division=0)
    baseline_recall = recall_score(y_test, baseline_pred, average='macro', zero_division=0)

    print(f"BERT Baseline Results (CLEANED TEXTS):")
    print(f"  - Accuracy: {baseline_accuracy:.3f}")
    print(f"  - F1-macro: {baseline_f1:.3f}")
    print(f"  - Precision-macro: {baseline_precision:.3f}")
    print(f"  - Recall-macro: {baseline_recall:.3f}")

    # Store all results
    all_results = {
        'model_type': 'BERT_CLEANED',
        'dataset': dataset_name,
        'text_cleaned': True,
        'baseline': {
            'accuracy': baseline_accuracy,
            'f1_macro': baseline_f1,
            'precision_macro': baseline_precision,
            'recall_macro': baseline_recall
        },
        'ablation_results': {}
    }

    # Run ablation for each POS category using precomputed tags for CLEANED texts
    for pos_category in pos_tags:
        print(f"\n{'='*40}")
        print(f"Running BERT ablation for {pos_category} (CLEANED)")
        print(f"{'='*40}")

        results = run_bert_pos_ablation_experiment(
            dataset_df_train, dataset_df_test, pos_category, dataset_name, binary_label,
            train_pos_tags, test_pos_tags
        )

        if results is not None:
            all_results['ablation_results'][pos_category] = results

            # Print performance drop
            accuracy_drop = baseline_accuracy - results['accuracy']
            f1_drop = baseline_f1 - results['f1_macro']

            print(f"\n📊 Performance Impact for {pos_category} (CLEANED):")
            print(f"   Accuracy drop: {accuracy_drop:.3f} ({_relative_drop(accuracy_drop, baseline_accuracy):.1f}%)")
            print(f"   F1-macro drop: {f1_drop:.3f} ({_relative_drop(f1_drop, baseline_f1):.1f}%)")
        else:
            print(f"❌ Failed to process {pos_category}")

    # Save complete results
    import os
    import json
    os.makedirs("evaluation/results/POS/BERT", exist_ok=True)
    summary_file = f"evaluation/results/POS/BERT/bert_pos_ablation_summary_cleaned_{dataset_name.lower()}.json"

    with open(summary_file, 'w') as f:
        json.dump(all_results, f, indent=2, cls=NumpyEncoder)

    print(f"\n✅ Complete BERT POS ablation results (CLEANED) saved to: {summary_file}")

    # Print summary
    print(f"\n{'='*60}")
    print("BERT POS ABLATION SUMMARY (CLEANED TEXTS)")
    print(f"{'='*60}")
    print(f"Baseline BERT (CLEANED) - Accuracy: {baseline_accuracy:.3f}, F1: {baseline_f1:.3f}")
    print("\nPOS Category Performance Drops:")

    # Only successful experiments were stored above, so no None check is needed.
    for pos_category, result in all_results['ablation_results'].items():
        acc_drop = baseline_accuracy - result['accuracy']
        f1_drop = baseline_f1 - result['f1_macro']
        print(f"  {pos_category:6s}: Acc drop {acc_drop:.3f} ({_relative_drop(acc_drop, baseline_accuracy):+5.1f}%), F1 drop {f1_drop:.3f} ({_relative_drop(f1_drop, baseline_f1):+5.1f}%)")

    return all_results


# POS categories ablated by default when no explicit `pos_tags` list is given:
# adjective, adverb, interjection, noun, proper noun, verb.
UNIVERSAL_POS_TAGS = ['ADJ', 'ADV', 'INTJ', 'NOUN', 'PROPN', 'VERB']


In [None]:
from torch.optim import AdamW

def run_mami_bert_pos_ablation():
    """Run the complete BERT POS ablation study on the MAMI train/test frames."""
    print("Starting MAMI BERT POS Ablation Study...")

    return run_complete_bert_pos_ablation_study(
        dataset_df_train=mami_training_df,
        dataset_df_test=mami_test_df,
        dataset_name="MAMI",
        binary_label='misogynous',
    )


def run_exist2024_bert_pos_ablation():
    """Run the complete BERT POS ablation study on the EXIST2024 train/test frames."""
    print("Starting EXIST2024 BERT POS Ablation Study...")

    return run_complete_bert_pos_ablation_study(
        dataset_df_train=exist_training_df,
        dataset_df_test=exist_test_df,
        dataset_name="EXIST2024",
        binary_label='sexist',
    )


def run_all_bert_pos_ablation_studies():
    """
    Run the BERT POS ablation study on MAMI and then on EXIST2024.

    The per-dataset results are stored under the keys 'MAMI' and 'EXIST2024',
    dumped together to `bert_pos_ablation_all_datasets.json`, and a
    per-POS-tag comparison of accuracy and macro-F1 drops is printed.

    Returns:
        dict mapping each dataset name to the result of its study.
    """
    import os
    import json

    outer_rule = "=" * 100
    print(outer_rule)
    print("RUNNING ALL BERT POS ABLATION STUDIES")
    print(outer_rule)

    # Dataset key paired with the function that runs its study, in run order.
    dataset_runners = (
        ('MAMI', run_mami_bert_pos_ablation),
        ('EXIST2024', run_exist2024_bert_pos_ablation),
    )

    all_results = {}
    for dataset_key, run_study in dataset_runners:
        print("\n" + "=" * 50)
        print(f"{dataset_key} DATASET BERT POS ABLATION")
        print("=" * 50)
        all_results[dataset_key] = run_study()

    # Save combined results
    os.makedirs("evaluation/results/bert_pos", exist_ok=True)
    combined_file = "evaluation/results/bert_pos/bert_pos_ablation_all_datasets.json"

    with open(combined_file, 'w') as f:
        json.dump(all_results, f, indent=2, cls=NumpyEncoder)

    print(f"\n✅ Combined BERT POS ablation results saved to: {combined_file}")

    # Print comparison summary
    print(f"\n{'='*80}")
    print("CROSS-DATASET BERT POS ABLATION COMPARISON")
    print(f"{'='*80}")

    for pos_tag in UNIVERSAL_POS_TAGS:
        print(f"\n{pos_tag} Impact:")
        for dataset, _ in dataset_runners:
            study = all_results[dataset]
            ablations = study['ablation_results']
            # Categories whose experiment failed are absent from the results.
            if pos_tag not in ablations:
                continue
            result = ablations[pos_tag]
            baseline = study['baseline']
            acc_drop = baseline['accuracy'] - result['accuracy']
            f1_drop = baseline['f1_macro'] - result['f1_macro']
            print(f"  {dataset:8s}: Acc drop {acc_drop:.3f} ({acc_drop/baseline['accuracy']*100:+5.1f}%), F1 drop {f1_drop:.3f} ({f1_drop/baseline['f1_macro']*100:+5.1f}%)")

    return all_results


# Run the BERT POS ablation studies for both datasets and keep the combined
# results (keyed by 'MAMI' and 'EXIST2024') in the module-level `all_results`.
all_results = run_all_bert_pos_ablation_studies()