In [52]:
import os
import json
import numpy as np
import pandas as pd
import random
import re
import math
from typing import List, Dict, Tuple, Any, Optional
from collections import Counter, defaultdict

# ML libs
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.model_selection import GridSearchCV

# Transformer libs
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW

# Explainability
try:
    from captum.attr import IntegratedGradients
    CAPTUM_AVAILABLE = True
except Exception:
    CAPTUM_AVAILABLE = False

# NLP
import spacy
nlp = spacy.load("en_core_web_sm")

# Utilities
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

In [53]:
# -----------------------
# Configuration
# -----------------------
CONFIG = {
    "DATA_DIR": "./",  # path to dataset files
    "OUTPUT_DIR": "./results",
    "RANDOM_SEED": 42,
    "LABEL_FILE": "classes.npy",          # 3-class mapping file
    "LABEL_FILE_2": "classes_two.npy",    # 2-class mapping file (if required)
    "DATA_JSON": "dataset.json",
    "SPLIT_JSON": "post_id_divisions.json",
    "TARGET_LABEL_TYPE": "majority",      # 'majority' or 'first' or 'consensus' strategy to pick ground truth
    "USE_2CLASS": False,                  # if True, map labels into toxic/non-toxic using classes_two.npy
    "TRANSFORMER_MODEL": "distilbert-base-uncased",
    "MAX_SEQ_LENGTH": 128,
    "BATCH_SIZE": 16,
    "EPOCHS": 3,
    "LR": 2e-5,
    "TFIDF_MAX_FEATURES": 20000,
    "SAVE_MODELS": True,
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu"
}
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# Reproducibility
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(CONFIG["RANDOM_SEED"])

In [54]:
# -----------------------
# Data Loading & Preprocessing
# -----------------------
def load_label_encoder(file_path: str) -> LabelEncoder:
    encoder = LabelEncoder()
    encoder.classes_ = np.load(file_path, allow_pickle=True)
    return encoder

def load_dataset_json(data_json_path: str) -> Dict[str, Any]:
    with open(data_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def aggregate_annotations(example: Dict[str, Any], label_strategy="majority") -> Tuple[str, List[str]]:
    """
    Aggregates annotator labels into a single ground truth label.
    Returns (label, list_of_targets)
    label_strategy: 'majority' (default), 'first' (first annotator label), 'consensus' (only if unanimous else 'normal')
    """
    annotators = example.get("annotators", [])
    labels = [ann.get("label").lower() for ann in annotators if ann.get("label")]
    # normalize labels to lower-case known labels
    if not labels:
        return "normal", []
    if label_strategy == "first":
        chosen = labels[0]
    elif label_strategy == "consensus":
        if all(l == labels[0] for l in labels):
            chosen = labels[0]
        else:
            chosen = "normal"
    else:
        # majority
        chosen = Counter(labels).most_common(1)[0][0]
    targets = []
    for ann in annotators:
        t = ann.get("target", [])
        if t:
            for v in t:
                targets.append(v)
    # de-duplicate targets and normalize
    targets = list({t.strip().lower() for t in targets if isinstance(t, str) and t.strip()})
    return chosen, targets

def tokens_to_text(tokens: List[str]) -> str:
    # Join tokens into cleaned string. The dataset tokens may be raw tokens; rejoin and clean.
    text = " ".join(tokens)
    # basic normalization
    text = text.replace(" n't", "n't")
    text = re.sub(r"\s+", " ", text).strip()
    return text

def clean_text(text: str) -> str:
    # lowercase, remove excessive whitespace, basic punctuation spacing
    text = text.lower()
    text = re.sub(r"http\S+", " ", text)  # remove URLs
    text = re.sub(r"[^a-z0-9\s'\-]", " ", text)  # allow apostrophes and hyphens
    text = re.sub(r"\s+", " ", text).strip()
    return text

def build_dataframe(data_json: Dict[str, Any],
                    split_ids: Dict[str, List[str]],
                    label_strategy="majority") -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Build train/val/test pandas DataFrames with columns:
      post_id, text, label, targets(list), rationales(list of lists), tokens(list)
    """
    rows = []
    for pid, example in data_json.items():
        post_tokens = example.get("post_tokens", [])
        text = tokens_to_text(post_tokens)
        label, targets = aggregate_annotations(example, label_strategy=label_strategy)
        rationales = example.get("rationales", [])
        rows.append({
            "post_id": pid,
            "text": text,
            "clean_text": clean_text(text),
            "label": label,
            "targets": targets,
            "rationales": rationales,
            "tokens": post_tokens
        })
    df = pd.DataFrame(rows).set_index("post_id", drop=False)
    # create splits
    train_ids = set(split_ids.get("train", []))
    val_ids = set(split_ids.get("valid", []))
    test_ids = set(split_ids.get("test", []))
    train_df = df[df.index.isin(train_ids)].copy()
    val_df = df[df.index.isin(val_ids)].copy()
    test_df = df[df.index.isin(test_ids)].copy()
    return train_df, val_df, test_df

In [55]:
# -----------------------
# Label mapping utilities
# -----------------------
def maybe_map_2class(df: pd.DataFrame, classes_two_encoder: Optional[LabelEncoder], classes_enc: LabelEncoder) -> Tuple[pd.DataFrame, LabelEncoder]:
    """
    If user wants 2-class mapping (toxic/non-toxic), map the labels accordingly.
    classes_two_encoder contains classes_two.npy content; classes_enc contains classes.npy content (3-class names).
    We'll map 3-class labels into 2-class labels by mapping known 'hatespeech'/'offensive' => 'toxic', 'normal' => 'non-toxic'.
    If USE_2CLASS is false, return original df and classes_enc.
    """
    if not CONFIG["USE_2CLASS"]:
        return df, classes_enc
    # Build map: any label in classes_two_encoder.classes_ => keep as is; else map
    # Typically classes_two contains ["toxic","non-toxic"] but we will implement direct mapping
    def map_label_to_2(label):
        if label in ("hatespeech", "offensive", "hate", "abusive", "toxic"):
            return "toxic"
        return "non-toxic"
    df = df.copy()
    df["label_2"] = df["label"].apply(map_label_to_2)
    encoder2 = classes_two_encoder
    return df, encoder2

In [56]:
# -----------------------
# TF-IDF + Logistic Regression Baseline
# -----------------------
def train_tfidf_lr(train_texts: List[str], train_labels: List[str],
                   val_texts: List[str], val_labels: List[str],
                   max_features: int = 20000) -> Dict[str, Any]:
    tfidf = TfidfVectorizer(max_features=max_features, ngram_range=(1,2))
    X_train = tfidf.fit_transform(train_texts)
    X_val = tfidf.transform(val_texts)
    # Basic LR with class_weight balanced
    lr = LogisticRegression(max_iter=2000, class_weight="balanced")
    lr.fit(X_train, train_labels)
    preds_val = lr.predict(X_val)
    return {"vectorizer": tfidf, "model": lr, "val_preds": preds_val}

In [57]:

# -----------------------
# Transformer Dataset & Training
# -----------------------
class TextDataset(Dataset):
    def __init__(self, texts: List[str], labels: List[int], tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        enc = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        item = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long)
        }
        return item

def train_transformer(train_texts, train_labels_idx, val_texts, val_labels_idx,
                      label_list, model_name="distilbert-base-uncased",
                      epochs=3, batch_size=16, lr=2e-5, max_len=128, device="cpu"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(label_list))
    model.to(device)

    train_dataset = TextDataset(train_texts, train_labels_idx, tokenizer, max_length=max_len)
    val_dataset = TextDataset(val_texts, val_labels_idx, tokenizer, max_length=max_len)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1*total_steps), num_training_steps=total_steps)

    best_val_f1 = -1.0
    best_model_state = None

    for epoch in range(1, epochs + 1):
        model.train()
        train_losses = []
        pbar = tqdm(train_loader, desc=f"Train Epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_losses.append(loss.item())
            pbar.set_postfix(loss=sum(train_losses)/len(train_losses))
        # Validation
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Val"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                preds = torch.argmax(logits, dim=1).cpu().numpy()
                all_preds.extend(preds.tolist())
                all_labels.extend(labels.cpu().numpy().tolist())

        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="macro", zero_division=0)
        val_f1 = f1
        print(f"Epoch {epoch} - Val macro F1: {val_f1:.4f}")
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    if best_model_state:
        model.load_state_dict(best_model_state)
    return {"model": model, "tokenizer": tokenizer, "label_list": label_list}

In [58]:
# -----------------------
# Evaluation & Fairness
# -----------------------
def eval_classification(y_true: List[str], y_pred: List[str], label_encoder: LabelEncoder) -> Dict[str, Any]:
    """
    Returns overall metrics, per-class metrics, confusion matrix.
    """
    overall_acc = accuracy_score(y_true, y_pred)
    precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=label_encoder.classes_, zero_division=0)
    per_class = {}
    for i, cls in enumerate(label_encoder.classes_):
        per_class[cls] = {"precision": precision[i], "recall": recall[i], "f1": f1[i], "support": int(support[i])}
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(y_true, y_pred, average="micro", zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=label_encoder.classes_)
    return {
        "accuracy": overall_acc,
        "per_class": per_class,
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "macro_f1": macro_f1,
        "micro_precision": micro_precision,
        "micro_recall": micro_recall,
        "micro_f1": micro_f1,
        "confusion_matrix": cm.tolist()
    }

def metrics_by_target_group(df: pd.DataFrame, y_true_col: str, y_pred_col: str, label_encoder: LabelEncoder,
                            min_samples: int = 20) -> Dict[str, Any]:
    """
    Computes metrics per target group (target strings found in 'targets' list column).
    Returns dict mapping target->metrics. Also returns disparity measures vs global.
    """
    # Flatten all target types found
    # We'll compute metrics for every target token (e.g., 'african', 'women', 'muslim', 'gay', etc.)
    global_metrics = eval_classification(df[y_true_col].tolist(), df[y_pred_col].tolist(), label_encoder)
    global_f1 = global_metrics["macro_f1"]
    results = {}
    target_counts = defaultdict(int)
    for idx, row in df.iterrows():
        for t in row["targets"]:
            target_counts[t] += 1
    # Evaluate for targets meeting min_samples
    for target, cnt in target_counts.items():
        if cnt < min_samples:
            continue
        # Collect rows where target present
        sub = df[df["targets"].apply(lambda ts: target in ts)]
        y_true = sub[y_true_col].tolist()
        y_pred = sub[y_pred_col].tolist()
        metrics = eval_classification(y_true, y_pred, label_encoder)
        # Disparity: difference between group's macro_f1 and global_f1
        disparity_f1 = metrics["macro_f1"] - global_f1
        # TPR (True Positive Rate) and FPR per group for 'toxic' vs 'non-toxic' not always defined; compute per class if possible
        results[target] = {
            "count": int(cnt),
            "metrics": metrics,
            "disparity_macro_f1": disparity_f1
        }
    return {"global": global_metrics, "by_target": results}

In [59]:
# -----------------------
# Rationale evaluation
# -----------------------
def aggregate_human_rationale_binary(rationales: List[List[int]]) -> List[int]:
    """
    Given list of binary rationales (one per annotator, length == number of tokens),
    aggregate them to a single token-level binary vector (e.g., majority vote).
    """
    if not rationales:
        return []
    arr = np.array(rationales)
    # majority vote: token included if >=50% annotators flagged it
    vote = (arr.sum(axis=0) >= (arr.shape[0] / 2)).astype(int)
    return vote.tolist()

def model_token_importance_attention(tokenizer, model, text: str, max_len=128, device="cpu") -> List[float]:
    """
    Approximate token importance using last-layer attention weights aggregated by token.
    Fallback if Captum not available. Returns importance per tokenized token (excluding special tokens).
    """
    # Tokenize
    enc = tokenizer.encode_plus(text, add_special_tokens=True, max_length=max_len, truncation=True, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    # Get outputs with attentions
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True)
    # outputs.attentions is a tuple: (layer_count, batch, heads, seq_len, seq_len) for BERT-like models
    attentions = outputs.attentions  # if model returns None, will raise
    # Aggregate attentions: mean over heads and layers, average attention from [CLS] to tokens or overall attention
    attn_tensor = torch.stack(attentions)  # shape (layers, batch, heads, seq_len, seq_len)
    # mean over heads and layers
    attn_mean = attn_tensor.mean(dim=0).mean(dim=1).squeeze(0)  # shape (seq_len, seq_len)
    # use attention from [CLS] token (index 0) to other tokens as importance
    cls_to_tokens = attn_mean[0].cpu().numpy()  # length seq_len
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0))
    # map importance back to tokens excluding special tokens
    token_importances = []
    for tok, imp in zip(tokens, cls_to_tokens):
        token_importances.append((tok, float(imp)))
    # remove special tokens (like [CLS],[SEP],[PAD])
    filtered = [(t, v) for t, v in token_importances if re.match(r"^\[|\<|\#", t) is None and t not in ("[CLS]", "[SEP]", "[PAD]")]
    toks, imps = zip(*filtered) if filtered else ([], [])
    return list(imps)

def compare_rationale_and_model(tokens: List[str], human_rationale_binary: List[int], model_token_importances: List[float], top_k: Optional[int] = None) -> Dict[str, Any]:
    """
    Compare human rationale (binary vector aligned to tokens) vs model token importances.
    - Convert model importances to top-k tokens flagged by model (or threshold)
    - Compute token-level precision/recall/F1 (binary)
    """
    if not tokens or not human_rationale_binary:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0, "iou": 0.0}
    L = len(human_rationale_binary)
    # Align model_token_importances length to token length if possible. If model returns truncated tokens, attempt to align.
    if len(model_token_importances) >= L:
        model_imp = np.array(model_token_importances[:L])
    else:
        # pad with zeros
        model_imp = np.array(list(model_token_importances) + [0.0] * (L - len(model_token_importances)))
    # choose top_k (default: number of positive human rationale tokens) for comparison
    if top_k is None:
        top_k = max(1, int(sum(human_rationale_binary)))
    idx_top = np.argsort(-model_imp)[:top_k]
    model_binary = np.zeros(L, dtype=int)
    model_binary[idx_top] = 1
    human = np.array(human_rationale_binary, dtype=int)
    tp = int(((model_binary == 1) & (human == 1)).sum())
    fp = int(((model_binary == 1) & (human == 0)).sum())
    fn = int(((model_binary == 0) & (human == 1)).sum())
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    intersection = int(((model_binary == 1) & (human == 1)).sum())
    union = int(((model_binary == 1) | (human == 1)).sum())
    iou = intersection / union if union > 0 else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "iou": iou}

In [60]:
# -----------------------
# Putting it all together: main pipeline
# -----------------------
def run_pipeline():
    import joblib

    # Create plot directory
    plot_dir = os.path.join(CONFIG["OUTPUT_DIR"], "plots")
    os.makedirs(plot_dir, exist_ok=True)

    # Load files
    data_dir = CONFIG["DATA_DIR"]
    data_json_path = os.path.join(data_dir, CONFIG["DATA_JSON"])
    split_json_path = os.path.join(data_dir, CONFIG["SPLIT_JSON"])
    classes_path = os.path.join(data_dir, CONFIG["LABEL_FILE"])
    classes2_path = os.path.join(data_dir, CONFIG["LABEL_FILE_2"])

    assert os.path.exists(data_json_path), f"Missing {data_json_path}"
    assert os.path.exists(split_json_path), f"Missing {split_json_path}"
    assert os.path.exists(classes_path), f"Missing {classes_path}"

    print("Loading dataset...")
    data_json = load_dataset_json(data_json_path)
    with open(split_json_path, 'r', encoding='utf-8') as f:
        split_ids = json.load(f)

    classes_encoder = load_label_encoder(classes_path)
    classes2_encoder = None
    if os.path.exists(classes2_path):
        classes2_encoder = load_label_encoder(classes2_path)

    print("Building dataframes and splits...")
    train_df, val_df, test_df = build_dataframe(
        data_json, split_ids,
        label_strategy=CONFIG["TARGET_LABEL_TYPE"]
    )
    print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")

    # If no validation data exists, create it from train (10%)
    if len(val_df) == 0:
        print("Validation split missing — creating 10% validation split from training set.")
        val_df = train_df.sample(frac=0.1, random_state=CONFIG["RANDOM_SEED"])
        train_df = train_df.drop(val_df.index)
        print(f"New Train size: {len(train_df)}, New Val size: {len(val_df)}")

    # If mapping to 2-class required
    if CONFIG["USE_2CLASS"]:
        train_df, classes2_encoder = maybe_map_2class(train_df, classes2_encoder, classes_encoder)
        val_df, classes2_encoder = maybe_map_2class(val_df, classes2_encoder, classes_encoder)
        test_df, classes2_encoder = maybe_map_2class(test_df, classes2_encoder, classes_encoder)
        label_encoder = classes2_encoder
        label_col = "label_2"
    else:
        label_encoder = classes_encoder
        label_col = "label"

    # EDA
    print("Label distribution (train):", train_df[label_col].value_counts().to_dict())

    # Prepare text & labels
    X_train = train_df["clean_text"].tolist()
    y_train = train_df[label_col].tolist()
    X_val = val_df["clean_text"].tolist()
    y_val = val_df[label_col].tolist()
    X_test = test_df["clean_text"].tolist()
    y_test = test_df[label_col].tolist()

    # ---------------- Baseline Model ----------------
    print("\nTraining TF-IDF + LogisticRegression baseline...")
    baseline = train_tfidf_lr(X_train, y_train, X_val, y_val, max_features=CONFIG["TFIDF_MAX_FEATURES"])
    tfidf = baseline["vectorizer"]
    lr = baseline["model"]
    val_preds_baseline = baseline["val_preds"]

    # Save TF-IDF model
    joblib.dump(tfidf, os.path.join(CONFIG["OUTPUT_DIR"], "tfidf_vectorizer.joblib"))
    joblib.dump(lr, os.path.join(CONFIG["OUTPUT_DIR"], "logreg_model.joblib"))

    val_metrics_baseline = eval_classification(y_val, val_preds_baseline, label_encoder)

    # Test baseline
    X_test_tfidf = tfidf.transform(X_test)
    test_preds_baseline = lr.predict(X_test_tfidf)
    test_metrics_baseline = eval_classification(y_test, test_preds_baseline, label_encoder)

    # Save metrics
    with open(os.path.join(CONFIG["OUTPUT_DIR"], "baseline_test_metrics.json"), "w") as f:
        json.dump(test_metrics_baseline, f, indent=2)

    print("Baseline test macro F1:", test_metrics_baseline["macro_f1"])

    # ---------------- Transformer Model ----------------
    print("\nPreparing Transformer model fine-tuning...")
    label_list = list(label_encoder.classes_)
    label2idx = {lbl: i for i, lbl in enumerate(label_list)}
    y_train_idx = [label2idx[v] for v in y_train]
    y_val_idx = [label2idx[v] for v in y_val]
    y_test_idx = [label2idx[v] for v in y_test]

    transformer_res = train_transformer(
        X_train, y_train_idx,
        X_val, y_val_idx,
        label_list=label_list,
        model_name=CONFIG["TRANSFORMER_MODEL"],
        epochs=CONFIG["EPOCHS"],
        batch_size=CONFIG["BATCH_SIZE"],
        lr=CONFIG["LR"],
        max_len=CONFIG["MAX_SEQ_LENGTH"],
        device=CONFIG["DEVICE"]
    )

    transformer_model = transformer_res["model"]
    tokenizer = transformer_res["tokenizer"]

    if CONFIG["SAVE_MODELS"]:
        out_dir = os.path.join(CONFIG["OUTPUT_DIR"], "transformer_model")
        os.makedirs(out_dir, exist_ok=True)
        transformer_model.save_pretrained(out_dir)
        tokenizer.save_pretrained(out_dir)

    # Transformer test evaluation
    print("Evaluating transformer on test set...")
    test_dataset = TextDataset(X_test, y_test_idx, tokenizer)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=False)

    all_preds = []
    transformer_model.eval()
    with torch.no_grad():
        for batch in test_loader:
            logits = transformer_model(
                input_ids=batch["input_ids"].to(CONFIG["DEVICE"]),
                attention_mask=batch["attention_mask"].to(CONFIG["DEVICE"])
            ).logits
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds.tolist())

    test_preds_transformer = [label_list[p] for p in all_preds]
    test_metrics_transformer = eval_classification(y_test, test_preds_transformer, label_encoder)

    with open(os.path.join(CONFIG["OUTPUT_DIR"], "transformer_test_metrics.json"), "w") as f:
        json.dump(test_metrics_transformer, f, indent=2)

    print("Transformer test macro F1:", test_metrics_transformer["macro_f1"])

    # ---------------- Fairness Evaluation ----------------
    test_df_copy = test_df.copy()
    test_df_copy["y_true"] = y_test
    test_df_copy["y_pred_baseline"] = test_preds_baseline
    test_df_copy["y_pred_transformer"] = test_preds_transformer

    fairness_baseline = metrics_by_target_group(test_df_copy, "y_true", "y_pred_baseline", label_encoder)
    fairness_transformer = metrics_by_target_group(test_df_copy, "y_true", "y_pred_transformer", label_encoder)

    # ---------------- Rationale Evaluation ----------------
    print("\nRationale evaluation...")
    rationale_results = []
    test_with_rationales = test_df_copy[test_df_copy["rationales"].apply(lambda x: bool(x))]
    sample = test_with_rationales.sample(n=min(200, len(test_with_rationales)), random_state=42)

    for idx, row in sample.iterrows():
        human_bin = aggregate_human_rationale_binary(row["rationales"])
        try:
            model_imps = model_token_importance_attention(
                tokenizer, transformer_model, row["text"],
                max_len=CONFIG["MAX_SEQ_LENGTH"],
                device=CONFIG["DEVICE"]
            )
        except:
            model_imps = [0.0] * len(row["tokens"])

        comp = compare_rationale_and_model(
            row["tokens"], human_bin, model_imps
        )
        rationale_results.append(comp)

    rationale_df = pd.DataFrame(rationale_results)
    rationale_summary = rationale_df.mean().to_dict()

    # ---------------- SAVE FAIRNESS, RATIONALE, RESULTS ----------------
    final_summary = {
        "baseline_test": test_metrics_baseline,
        "transformer_test": test_metrics_transformer,
        "fairness_baseline": fairness_baseline,
        "fairness_transformer": fairness_transformer,
        "rationale_summary": rationale_summary
    }

    with open(os.path.join(CONFIG["OUTPUT_DIR"], "final_results_summary.json"), "w") as f:
        json.dump(final_summary, f, indent=2)

    # ================================
    #    VISUALISATIONS SECTION
    # ================================

    # 1. Confusion Matrices
    def plot_confusion(cm, labels, title, filename):
        plt.figure(figsize=(6, 5))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=labels, yticklabels=labels)
        plt.title(title)
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.tight_layout()
        plt.savefig(os.path.join(plot_dir, filename))
        plt.close()

    plot_confusion(
        np.array(test_metrics_baseline["confusion_matrix"]),
        label_list, "Baseline Confusion Matrix", "baseline_confusion_matrix.png"
    )

    plot_confusion(
        np.array(test_metrics_transformer["confusion_matrix"]),
        label_list, "Transformer Confusion Matrix", "transformer_confusion_matrix.png"
    )

    # 2. Per-class F1 comparison
    df_f1 = pd.DataFrame({
        "Class": label_list,
        "Baseline F1": [test_metrics_baseline["per_class"][c]["f1"] for c in label_list],
        "Transformer F1": [test_metrics_transformer["per_class"][c]["f1"] for c in label_list]
    })
    df_f1.to_csv(os.path.join(CONFIG["OUTPUT_DIR"], "per_class_f1.csv"), index=False)

    df_f1.plot(x="Class", kind="bar", figsize=(8, 5))
    plt.title("Per-Class F1 Score Comparison")
    plt.ylabel("F1 Score")
    plt.xticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, "per_class_f1_comparison.png"))
    plt.close()

    # 3. Fairness – Macro F1 by Target Group
    ft = fairness_transformer["by_target"]
    targets = list(ft.keys())
    macro_f1 = [ft[t]["metrics"]["macro_f1"] for t in targets]

    plt.figure(figsize=(10, 6))
    sns.barplot(x=targets, y=macro_f1)
    plt.title("Transformer Macro-F1 by Target Group")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, "fairness_macro_f1_by_target.png"))
    plt.close()

    # 4. Fairness Disparity Plot
    disparities = [ft[t]["disparity_macro_f1"] for t in targets]

    plt.figure(figsize=(10, 6))
    sns.barplot(x=targets, y=disparities)
    plt.axhline(0, color="black", linewidth=1)
    plt.title("Fairness Disparity (Target Macro-F1 - Global Macro-F1)")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, "fairness_disparity_plot.png"))
    plt.close()

    # 5. Rationale Alignment Histogram
    plt.figure(figsize=(7, 5))
    sns.histplot(rationale_df["f1"], bins=15, kde=True)
    plt.title("Distribution of Rationale Alignment (Token-level F1)")
    plt.xlabel("Rationale Alignment F1")
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, "rationale_alignment_hist.png"))
    plt.close()

    print("\nAll visualisations saved to:", plot_dir)
    print("Pipeline completed successfully.")

In [61]:
if __name__ == "__main__":
    run_pipeline()

Loading dataset...
Building dataframes and splits...
Train size: 15383, Val size: 0, Test size: 1924
Validation split missing — creating 10% validation split from training set.
New Train size: 13845, New Val size: 1538
Label distribution (train): {'normal': 5600, 'hatespeech': 4310, 'offensive': 3935}

Training TF-IDF + LogisticRegression baseline...
Baseline test macro F1: 0.648954999568649

Preparing Transformer model fine-tuning...


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Train Epoch 1:   0%|          | 0/866 [00:00<?, ?it/s]

Val:   0%|          | 0/97 [00:00<?, ?it/s]

Epoch 1 - Val macro F1: 0.6795


Train Epoch 2:   0%|          | 0/866 [00:00<?, ?it/s]

Val:   0%|          | 0/97 [00:00<?, ?it/s]

Epoch 2 - Val macro F1: 0.6944


Train Epoch 3:   0%|          | 0/866 [00:00<?, ?it/s]

Val:   0%|          | 0/97 [00:00<?, ?it/s]

Epoch 3 - Val macro F1: 0.6833
Evaluating transformer on test set...
Transformer test macro F1: 0.6807284822611056

Rationale evaluation...

All visualisations saved to: ./results/plots
Pipeline completed successfully.
