<a href="https://colab.research.google.com/github/Stevebankz/Hate_Speech_Detection/blob/main/hate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- Cell 1: Install Dependencies ---
# The `!pip` line below is an IPython/Colab shell escape, not plain-Python
# syntax; `-q` keeps pip's output quiet.

print("Installing required libraries...")
!pip install transformers datasets scikit-learn -q

print("--- Cell 1 Complete ---")


# --- Cell 2: Import Libraries & Mount Google Drive ---
import os
import pandas as pd
import torch
import numpy as np
from google.colab import drive
from datasets import load_dataset, DatasetDict

# Import individual metric functions
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.utils import resample
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
import gc # Garbage collector

# Mount your Google Drive
# This will prompt you for authorization.
# NOTE(review): the mount is commented out, but every path in Config points
# under /content/drive. Mount Drive (uncomment below or mount manually)
# before Cell 5, otherwise the data files will not be found.
#print("Mounting Google Drive...")
#drive.mount('/content/drive')

print("--- Cell 2 Complete ---")


# --- Cell 3: Define Configuration and Paths ---
# We'll set all our paths and hyperparameters here.
# This makes it super easy to change things later.

class Config:
    """Central configuration for the mBERT baseline experiment.

    Every path and hyperparameter used by the later cells lives here, so a
    run can be reconfigured by editing this one class.
    """

    # --- Paths ---
    # This is the base path in your Google Drive
    DRIVE_PATH = '/content/drive/MyDrive/hate'

    # --- IMPORTANT ---
    # Change 'csv' to 'json' or 'parquet' if your files are not CSVs.
    # The split file names below are derived from this setting, so changing it
    # here is enough (the files must then be named train.json, val.json, ...).
    DATA_FILE_TYPE = 'csv'

    # Paths to your data files. The extension follows DATA_FILE_TYPE so the
    # loader type and the file names cannot drift apart.
    TRAIN_FILE = os.path.join(DRIVE_PATH, f'train.{DATA_FILE_TYPE}')
    VAL_FILE = os.path.join(DRIVE_PATH, f'val.{DATA_FILE_TYPE}')
    TEST_FILE = os.path.join(DRIVE_PATH, f'test.{DATA_FILE_TYPE}')

    # Where we will save the trained model
    MODEL_SAVE_PATH = os.path.join(DRIVE_PATH, 'models/step1_bert_baseline')

    # --- Model Configuration ---
    # The data is multilingual, so an English-only checkpoint such as
    # 'bert-base-uncased' is unsuitable. Multilingual BERT (mBERT) is the
    # standard choice for this baseline.
    MODEL_NAME = 'bert-base-multilingual-cased'

    # --- Training Hyperparameters ---
    MAX_LENGTH = 128      # Max token length per example (truncate/pad to this)
    BATCH_SIZE = 16       # Per-device training batch size (eval uses 2x)
    EPOCHS = 5            # Number of training epochs
    LEARNING_RATE = 2e-5  # Common learning rate for fine-tuning BERT

    # --- Labels ---
    NUM_LABELS = 2  # 0 (Non-hate) and 1 (Hate)

    # --- Evaluation ---
    N_BOOTSTRAPS = 1000  # Number of bootstrap samples for CIs

print("Configuration defined.")
print(f"Model to be trained: {Config.MODEL_NAME}")
print(f"Model will be saved to: {Config.MODEL_SAVE_PATH}")
print(f"Will run {Config.N_BOOTSTRAPS} bootstrap iterations for CI.")
print("--- Cell 3 Complete ---")


# --- Cell 4: Check for GPU ---
# Let's make sure we're using a GPU. Colab notebooks should have one.
# Pick the accelerator once, then report which one was chosen.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print(f"Awesome! We are using the GPU: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU found. We are using the CPU (this will be SLOW).")

print("--- Cell 4 Complete ---")


# --- Cell 5: Load Dataset ---
# We use the 'datasets' library to load our split files directly
# into a DatasetDict object.

print(f"Loading data from {Config.DRIVE_PATH}...")
try:
    # Split name -> file path. load_dataset uses the packaged builder named by
    # Config.DATA_FILE_TYPE ('csv', 'json', 'parquet') to read each file.
    data_files = {
        'train': Config.TRAIN_FILE,
        'validation': Config.VAL_FILE,
        'test': Config.TEST_FILE
    }


    raw_datasets = load_dataset(Config.DATA_FILE_TYPE, data_files=data_files)

    print("Data loaded successfully!")
    print(raw_datasets)

    # Let's see an example
    print("\nExample from training set:")
    print(raw_datasets['train'][0])

except Exception as e:
    # Print a hint, then re-raise so the notebook stops here instead of
    # continuing without data.
    print(f"--- ERROR LOADING DATA ---")
    print(f"Could not load data. Check your paths and file type ('{Config.DATA_FILE_TYPE}').")
    print(f"Error: {e}")

    raise

print("--- Cell 5 Complete ---")


# --- Cell 6: Preprocessing (Tokenization) ---
# Convert our 'text' column into numbers (token IDs).

print(f"Loading tokenizer for {Config.MODEL_NAME}...")
# AutoTokenizer resolves the tokenizer class that matches Config.MODEL_NAME,
# so the token IDs line up with the model loaded in Cell 7.
tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)

# This function will be applied to our entire dataset
def tokenize_function(batch):
    """Tokenize one batch of examples for the classifier.

    Every text is truncated and padded to exactly ``Config.MAX_LENGTH``
    tokens, so all rows share one fixed length.

    Args:
        batch: batched mapping supplied by ``Dataset.map(batched=True)``;
            its ``'text'`` entry holds the raw strings.

    Returns:
        The tokenizer's encoding for the batch (``input_ids``,
        ``attention_mask``, ... as produced by ``tokenizer``).
    """
    encode_options = dict(
        padding='max_length',
        truncation=True,
        max_length=Config.MAX_LENGTH,
    )
    texts = batch['text']
    return tokenizer(texts, **encode_options)

print("Tokenizing datasets... (this may take a minute)")

# Apply tokenize_function to every split (batched=True is much faster than
# row-by-row) and drop every raw column except 'label'. The columns to drop
# are read from each split instead of being hard-coded, so a missing or extra
# metadata column no longer raises a ValueError. For the expected schema
# (post_id, text, label, label_name, label_3class, targets) the result is the
# same as removing ['post_id', 'text', 'label_name', 'label_3class', 'targets'].
tokenized_datasets = DatasetDict({
    split_name: split_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=[col for col in split_ds.column_names if col != 'label'],
    )
    for split_name, split_ds in raw_datasets.items()
})

# The Trainer expects the label column to be named 'labels'
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

# Set the format to 'torch' so rows come back as PyTorch tensors
tokenized_datasets.set_format('torch')

print("Tokenization complete.")
print(tokenized_datasets)
print("\nExample of processed data:")
print(tokenized_datasets['train'][0])

print("--- Cell 6 Complete ---")


# --- Cell 7: Load Baseline Model ---
# load the mBERT model, configured for sequence classification.

print(f"Loading pre-trained model: {Config.MODEL_NAME}...")
# mBERT configured for sequence classification with Config.NUM_LABELS
# outputs: 0 (Non-hate) and 1 (Hate).
model = AutoModelForSequenceClassification.from_pretrained(
    Config.MODEL_NAME,
    num_labels=Config.NUM_LABELS
)

# Move the model to the device chosen in Cell 4 (GPU if available, else CPU).
model.to(device)
# Report the real device; the old message always said "GPU", which was
# misleading on CPU-only runtimes.
print(f"Model loaded and moved to {device}.")
print("--- Cell 7 Complete ---")


# --- Cell 8: Define Evaluation Metrics ---
# This function is passed to the Trainer.
# It calculates accuracy plus macro-averaged F1, precision, and recall.

def compute_metrics(eval_pred):
    """Compute classification metrics; called by the Trainer at evaluation time.

    Args:
        eval_pred: a ``(logits, labels)`` pair (the Trainer's EvalPrediction
            unpacks the same way). ``logits`` may also arrive as a tuple whose
            first element is the logits array.

    Returns:
        dict with ``'accuracy'``, ``'f1'``, ``'precision'`` and ``'recall'``;
        the last three are macro-averaged over classes.
    """
    logits, labels = eval_pred
    # Some models return extra outputs next to the logits; keep the logits.
    if isinstance(logits, tuple):
        logits = logits[0]

    # Predicted class = index with the highest logit
    predictions = np.argmax(logits, axis=-1)

    # Macro averaging weights both classes equally, which suits potentially
    # imbalanced data. zero_division=0 yields 0.0 for a class that is never
    # predicted / never present -- the same value sklearn's default ('warn')
    # returns -- but without an UndefinedMetricWarning on every call (which
    # would flood the output during bootstrapping).
    precision = precision_score(labels, predictions, average='macro', zero_division=0)
    recall = recall_score(labels, predictions, average='macro', zero_division=0)
    f1 = f1_score(labels, predictions, average='macro', zero_division=0)
    acc = accuracy_score(labels, predictions)

    # Return as a dictionary
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

print("Metrics function 'compute_metrics' defined.")
print("--- Cell 8 Complete ---")


# --- Cell 9: Configure Training Arguments ---
# This object holds all the training settings.

print("Configuring training arguments...")

training_args = TrainingArguments(
    output_dir=Config.MODEL_SAVE_PATH,

    # --- Training Hyperparameters ---
    num_train_epochs=Config.EPOCHS,
    learning_rate=Config.LEARNING_RATE,
    per_device_train_batch_size=Config.BATCH_SIZE,
    per_device_eval_batch_size=Config.BATCH_SIZE * 2,  # presumably larger since eval needs no gradients
    warmup_steps=500,
    weight_decay=0.01,

    # --- Evaluation and Saving ---
    # Evaluate and checkpoint once per epoch; after training, the checkpoint
    # with the highest validation 'f1' (macro-F1 from compute_metrics) is reloaded.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    # Each epoch writes a full checkpoint into output_dir on Google Drive.
    # Keep at most two. With load_best_model_at_end=True the Trainer always
    # retains the best checkpoint in addition to the most recent one, so
    # restoring the best model is unaffected while disk usage stays bounded.
    save_total_limit=2,

    # --- Logging (silence step logs, keep only tqdm bars) ---
    report_to="none",
    logging_strategy="no",     # <- no step/epoch log lines
    disable_tqdm=False         # <- keep progress bars
)


print("--- Cell 9 Complete ---")


# --- Cell 10: Initialize Trainer ---
# The Trainer class handles all the complexity of training and evaluation.

print("Initializing Trainer...")

trainer = Trainer(
    model=model,                                    # The model we just loaded
    args=training_args,                             # The training arguments we just defined
    train_dataset=tokenized_datasets["train"],      # Our tokenized training data
    eval_dataset=tokenized_datasets["validation"],  # Our tokenized validation data
    # `tokenizer=` is deprecated in recent transformers releases in favour of
    # `processing_class=` (added in v4.46). The tokenizer is still saved
    # together with the model.
    processing_class=tokenizer,
    compute_metrics=compute_metrics                 # The function to calculate our metrics
)

print("Trainer initialized.")
print("--- Cell 10 Complete ---")


# --- Cell 11: Train the Model ---
#  We call .train() to start fine-tuning.

print("--- STARTING BASELINE MODEL TRAINING ---")
print(f"Training for {Config.EPOCHS} epochs...")

# Fine-tune. Because load_best_model_at_end=True (Cell 9), the checkpoint with
# the best validation 'f1' is restored into the trainer's model at the end.
training_results = trainer.train()

print("--- TRAINING COMPLETE ---")
print("--- Cell 11 Complete ---")


# --- Cell 12: Save the Best Model and Results ---


print(f"Saving the best model to {Config.MODEL_SAVE_PATH}...")

# Save the (best, restored) model and its config; the tokenizer handed to the
# Trainer in Cell 10 is saved alongside it.
trainer.save_model(Config.MODEL_SAVE_PATH)

# Also persist the Trainer state (log history, best-checkpoint info) --
# presumably as trainer_state.json in output_dir; confirm for your version.
trainer.save_state()

print(f"Model successfully saved to {Config.MODEL_SAVE_PATH}")
print("--- Cell 12 Complete ---")


# --- Cell 13: Evaluate on the TEST Set (with Bootstrap CIs) ---


print("--- EVALUATING ON THE TEST SET (SINGLE PASS) ---")

# We run one clean pass to get the point-estimate.
# evaluate() uses the default 'eval' key prefix, so the compute_metrics keys
# come back as eval_f1, eval_accuracy, eval_precision and eval_recall.
clean_test_results = trainer.evaluate(tokenized_datasets["test"])

print("\n\n--- FINAL BASELINE MODEL TEST RESULTS (CLEAN) ---")
print(f"Model: {Config.MODEL_NAME}")
print(f"Test F1-Score:   {clean_test_results['eval_f1']:.4f}")
print(f"Test Accuracy:   {clean_test_results['eval_accuracy']:.4f}")
print(f"Test Precision:  {clean_test_results['eval_precision']:.4f}")
print(f"Test Recall:     {clean_test_results['eval_recall']:.4f}")
print("---------------------------------------------------\n")

from tqdm.auto import tqdm

print(f"--- STARTING BOOTSTRAP EVALUATION ({Config.N_BOOTSTRAPS} iterations) ---")

# Seed for the bootstrap resampling so the reported CIs are reproducible
# (the previous sklearn `resample` call used no random_state).
BOOTSTRAP_SEED = 42

test_dataset = tokenized_datasets["test"]
n_samples = len(test_dataset)

# Run the model over the test set ONCE and bootstrap over the cached
# (logits, labels). The Trainer predicts with the model in eval mode, so a
# resampled row gets the same prediction it got in the full pass. Calling
# trainer.evaluate() on every replicate (as before) repeated a complete
# test-set forward pass Config.N_BOOTSTRAPS times for the same numbers.
test_output = trainer.predict(test_dataset)
test_logits = test_output.predictions
if isinstance(test_logits, tuple):
    test_logits = test_logits[0]
test_labels = np.asarray(test_output.label_ids)

rng = np.random.default_rng(BOOTSTRAP_SEED)
boot_f1_scores = []
boot_accuracy_scores = []
boot_precision_scores = []
boot_recall_scores = []

for _ in tqdm(range(Config.N_BOOTSTRAPS), desc="Bootstrapping", leave=False):
    # Draw n_samples row indices with replacement and score that replicate
    # with the same metric function the Trainer uses.
    boot_indices = rng.integers(0, n_samples, size=n_samples)
    boot_results = compute_metrics((test_logits[boot_indices], test_labels[boot_indices]))
    boot_f1_scores.append(boot_results['f1'])
    boot_accuracy_scores.append(boot_results['accuracy'])
    boot_precision_scores.append(boot_results['precision'])
    boot_recall_scores.append(boot_results['recall'])

print("--- BOOTSTRAP EVALUATION COMPLETE ---")

# Convert lists to numpy arrays for percentile calculation
boot_f1_scores = np.array(boot_f1_scores)
boot_accuracy_scores = np.array(boot_accuracy_scores)
boot_precision_scores = np.array(boot_precision_scores)
boot_recall_scores = np.array(boot_recall_scores)

# Calculate 95% percentile confidence intervals (2.5th to 97.5th percentile)
f1_ci = np.percentile(boot_f1_scores, [2.5, 97.5])
acc_ci = np.percentile(boot_accuracy_scores, [2.5, 97.5])
prec_ci = np.percentile(boot_precision_scores, [2.5, 97.5])
rec_ci = np.percentile(boot_recall_scores, [2.5, 97.5])

# Calculate means
f1_mean = np.mean(boot_f1_scores)
acc_mean = np.mean(boot_accuracy_scores)
prec_mean = np.mean(boot_precision_scores)
rec_mean = np.mean(boot_recall_scores)

print("\n\n--- FINAL BASELINE MODEL TEST RESULTS (BOOTSTRAPPED) ---")
print(f"Metrics based on {Config.N_BOOTSTRAPS} bootstrap samples.")
print(f"Format: Mean (95% CI)")
print("----------------------------------------------------------")
print(f"Test F1-Score:   {f1_mean:.4f} (95% CI: [{f1_ci[0]:.4f}, {f1_ci[1]:.4f}])")
print(f"Test Accuracy:   {acc_mean:.4f} (95% CI: [{acc_ci[0]:.4f}, {acc_ci[1]:.4f}])")
print(f"Test Precision:  {prec_mean:.4f} (95% CI: [{prec_ci[0]:.4f}, {prec_ci[1]:.4f}])")
print(f"Test Recall:     {rec_mean:.4f} (95% CI: [{rec_ci[0]:.4f}, {rec_ci[1]:.4f}])")
print("----------------------------------------------------------\n")


# save these results to a file for our records.
# Mode 'w' overwrites any previous results file. The 'models' folder is
# expected to exist already: MODEL_SAVE_PATH lives inside it and was written
# in Cell 12.
results_file = os.path.join(Config.DRIVE_PATH, 'models', 'step1_baseline_results.txt')
with open(results_file, 'w') as f:
    f.write("--- FINAL BASELINE MODEL TEST RESULTS ---\n\n")
    f.write(f"Model: {Config.MODEL_NAME}\n\n")

    f.write("--- SINGLE PASS (CLEAN) RESULTS ---\n")
    f.write(f"Test F1-Score:   {clean_test_results['eval_f1']:.4f}\n")
    f.write(f"Test Accuracy:   {clean_test_results['eval_accuracy']:.4f}\n")
    f.write(f"Test Precision:  {clean_test_results['eval_precision']:.4f}\n")
    f.write(f"Test Recall:     {clean_test_results['eval_recall']:.4f}\n\n")

    f.write(f"--- BOOTSTRAPPED RESULTS ({Config.N_BOOTSTRAPS} samples) ---\n")
    f.write(f"Format: Mean (95% CI)\n")
    f.write(f"Test F1-Score:   {f1_mean:.4f} (95% CI: [{f1_ci[0]:.4f}, {f1_ci[1]:.4f}])\n")
    f.write(f"Test Accuracy:   {acc_mean:.4f} (95% CI: [{acc_ci[0]:.4f}, {acc_ci[1]:.4f}])\n")
    f.write(f"Test Precision:  {prec_mean:.4f} (95% CI: [{prec_ci[0]:.4f}, {prec_ci[1]:.4f}])\n")
    f.write(f"Test Recall:     {rec_mean:.4f} (95% CI: [{rec_ci[0]:.4f}, {rec_ci[1]:.4f}])\n")


print(f"Test results saved to {results_file}")
print("--- Cell 13 Complete ---")


# --- Cell 14: Clean Up Memory ---


print("Cleaning up memory...")
# Drop the notebook's references to the large objects so gc can reclaim them.
del model
del trainer
del tokenized_datasets
del raw_datasets
gc.collect()
# Release cached CUDA memory. Guarded (as in the gated-model cell) so the
# call is explicitly skipped on CPU-only runtimes.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("--- STEP 1 COMPLETE ---")
print("You now have a trained, saved, and evaluated baseline model.")

In [None]:
# --- Cell 1: Install Dependencies ---
# `!pip` is an IPython/Colab shell escape (not plain Python); `-q` = quiet.
print("Installing required libraries...")
!pip install transformers datasets scikit-learn -q

print("--- Cell 1 Complete ---")


# --- Cell 2: Import Libraries & Mount Google Drive ---
import os
import pandas as pd
import torch
import numpy as np
from google.colab import drive
from datasets import load_dataset, DatasetDict
# Import individual metric functions
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.utils import resample
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModel,
    AutoConfig,
    PreTrainedModel,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
import torch.nn as nn
import gc  # Garbage collector

# (Leave Drive mount commented if you don't need it right now)
# NOTE(review): Config paths below point under /content/drive, so Drive must
# be mounted (here or manually) before Cell 5 loads the data.
# print("Mounting Google Drive...")
# drive.mount('/content/drive')

print("--- Cell 2 Complete ---")


# --- Cell 3: Define Configuration and Paths ---
class Config:
    """Central configuration for the gated-fusion experiment.

    Uses the same data and hyperparameters as the baseline configuration, but
    trains a gated-fusion head on ``BASE_MODEL_NAME`` and saves it to its own
    folder.
    """

    # --- Paths ---
    DRIVE_PATH = '/content/drive/MyDrive/hate'

    # --- File type ---
    # 'csv', 'json' or 'parquet'; the split file names below follow it.
    DATA_FILE_TYPE = 'csv'

    # Paths to your data files (extension derived from DATA_FILE_TYPE so the
    # loader type and the file names cannot drift apart)
    TRAIN_FILE = os.path.join(DRIVE_PATH, f'train.{DATA_FILE_TYPE}')
    VAL_FILE = os.path.join(DRIVE_PATH, f'val.{DATA_FILE_TYPE}')
    TEST_FILE = os.path.join(DRIVE_PATH, f'test.{DATA_FILE_TYPE}')

    # Save path: use a distinct folder for the gated model
    MODEL_SAVE_PATH = os.path.join(DRIVE_PATH, 'models/step1_gated_fusion')

    # --- Model Configuration ---
    # BASE model for tokenizer & encoder
    BASE_MODEL_NAME = 'bert-base-multilingual-cased'
    # Public-facing name for this experiment (used in printed/saved reports)
    MODEL_NAME = 'gated'

    # --- Training Hyperparameters ---
    MAX_LENGTH = 128
    BATCH_SIZE = 16
    EPOCHS = 5
    LEARNING_RATE = 2e-5

    # --- Labels ---
    NUM_LABELS = 2  # 0 (Non-hate), 1 (Hate)

    # --- Evaluation ---
    N_BOOTSTRAPS = 1000

print("Configuration defined.")
print(f"Model to be trained: {Config.MODEL_NAME}")
print(f"Base encoder: {Config.BASE_MODEL_NAME}")
print(f"Model will be saved to: {Config.MODEL_SAVE_PATH}")
print(f"Will run {Config.N_BOOTSTRAPS} bootstrap iterations for CI.")
print("--- Cell 3 Complete ---")


# --- Cell 4: Check for GPU ---
# Select the accelerator once, then announce the choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print(f"Awesome! We are using the GPU: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU found. We are using the CPU (this will be SLOW).")

print("--- Cell 4 Complete ---")


# --- Cell 5: Load Dataset ---
print(f"Loading data from {Config.DRIVE_PATH}...")
try:
    # Split name -> file path; read with the builder named by DATA_FILE_TYPE.
    data_files = {
        'train': Config.TRAIN_FILE,
        'validation': Config.VAL_FILE,
        'test': Config.TEST_FILE
    }
    raw_datasets = load_dataset(Config.DATA_FILE_TYPE, data_files=data_files)

    print("Data loaded successfully!")
    print(raw_datasets)

    print("\nExample from training set:")
    print(raw_datasets['train'][0])

except Exception as e:
    # Print a hint, then re-raise so execution stops without data.
    print(f"--- ERROR LOADING DATA ---")
    print(f"Could not load data. Check your paths and file type ('{Config.DATA_FILE_TYPE}').")
    print(f"Error: {e}")
    raise

print("--- Cell 5 Complete ---")


# --- Cell 6: Preprocessing (Tokenization) ---
print(f"Loading tokenizer for {Config.BASE_MODEL_NAME}...")
# Must match the encoder wrapped by the gated model (both use BASE_MODEL_NAME).
tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL_NAME)

def tokenize_function(batch):
    """Tokenize one batch of examples to a fixed ``Config.MAX_LENGTH``.

    Args:
        batch: batched mapping from ``Dataset.map(batched=True)``; its
            ``'text'`` entry holds the raw strings.

    Returns:
        The tokenizer's encoding for the batch, truncated and padded to
        exactly ``Config.MAX_LENGTH`` tokens.
    """
    encode_options = dict(
        padding='max_length',
        truncation=True,
        max_length=Config.MAX_LENGTH,
    )
    texts = batch['text']
    return tokenizer(texts, **encode_options)

print("Tokenizing datasets... (this may take a minute)")

# Tokenize every split and keep only the encoder inputs plus 'label'. The
# columns to drop are read from each split rather than hard-coded, so a
# missing or extra metadata column no longer raises. For the expected schema
# (post_id, text, label, label_name, label_3class, targets) the result is the
# same as removing ['post_id', 'text', 'label_name', 'label_3class', 'targets'].
tokenized_datasets = DatasetDict({
    split_name: split_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=[col for col in split_ds.column_names if col != 'label'],
    )
    for split_name, split_ds in raw_datasets.items()
})

# The Trainer expects the label column to be named 'labels'.
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format('torch')

print("Tokenization complete.")
print(tokenized_datasets)
print("\nExample of processed data:")
print(tokenized_datasets['train'][0])

print("--- Cell 6 Complete ---")


# --- Cell 7: Gated-Fusion Model (fixed & drop-in) ---
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig, PreTrainedModel, PretrainedConfig

class GFConfig(PretrainedConfig):
    """Configuration for GatedFusionForSequenceClassification.

    Every argument has a default, so the config can be constructed with no
    arguments (HF internals may do this, e.g. when reloading a saved model).

    Args:
        base_model_name: Hub name or path of the encoder to wrap.
        num_labels: number of output classes of the classifier head.
        gate_hidden: width of the gate MLP's hidden layer.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """
    model_type = "gated_fusion_wrapper"
    def __init__(
        self,
        base_model_name: str = "bert-base-multilingual-cased",
        num_labels: int = 2,
        gate_hidden: int = 256,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_labels = num_labels
        self.gate_hidden = gate_hidden

class GatedFusionForSequenceClassification(PreTrainedModel):
    """Sequence classifier that gates between [CLS] and mean-pooled features.

    A pretrained encoder (``config.base_model_name``) produces token states.
    The [CLS] vector and the attention-masked mean of all token vectors are
    concatenated and passed through a small MLP ending in a sigmoid, giving a
    per-dimension gate ``g`` in (0, 1). The classifier sees
    ``g * h_cls + (1 - g) * h_mean`` after dropout.

    ``forward`` returns a dict with ``"loss"`` (``None`` when no labels are
    given) and ``"logits"`` of shape ``[B, num_labels]``.
    """
    config_class = GFConfig

    # Extra encoder arguments that may be forwarded from ``**kwargs``.
    # Anything else (e.g. the Trainer-injected ``num_items_in_batch``) is
    # dropped because the base encoder would not accept it.
    _ENCODER_KWARGS = frozenset({
        "position_ids", "head_mask", "inputs_embeds",
        "output_attentions", "output_hidden_states", "return_dict",
        "past_key_values", "encoder_hidden_states", "encoder_attention_mask",
    })

    def __init__(self, config: GFConfig):
        super().__init__(config)
        self.base_cfg = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_pretrained(config.base_model_name, config=self.base_cfg)
        hidden = self.base_cfg.hidden_size

        # Gate over hidden dims using [CLS] and masked-mean pooled token reps
        self.gate_mlp = nn.Sequential(
            nn.Linear(2 * hidden, config.gate_hidden),
            nn.ReLU(),
            nn.Linear(config.gate_hidden, hidden),
            nn.Sigmoid()
        )
        # Reuse the encoder's dropout rate when its config defines one.
        self.dropout = nn.Dropout(getattr(self.base_cfg, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden, config.num_labels)

        # NOTE(review): post_init() may apply PreTrainedModel's weight-init
        # hook to submodules, including the encoder loaded above. Confirm with
        # the installed transformers version that the pretrained encoder
        # weights are left untouched.
        self.post_init()

    @staticmethod
    def masked_mean(last_hidden_state, attention_mask):
        """Average token states over non-padding positions.

        Args:
            last_hidden_state: ``[B, L, H]`` token representations.
            attention_mask: ``[B, L]`` with 1 for real tokens and 0 for
                padding, or ``None`` to average over every position.

        Returns:
            ``[B, H]`` pooled representation.
        """
        if attention_mask is None:
            # No mask supplied: every position counts (previously crashed).
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [B, L, 1]
        summed = (last_hidden_state * mask).sum(dim=1)                  # [B, H]
        # clamp guards the division for a row whose mask is all zeros
        denom = mask.sum(dim=1).clamp(min=1e-6)                         # [B, 1]
        return summed / denom

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs
    ):
        """Encode, fuse the pooled features, classify, and optionally score.

        Args:
            input_ids, attention_mask: encoder inputs. With no
                ``attention_mask``, mean pooling covers every position.
            token_type_ids: forwarded to the encoder only when not ``None``,
                so encoders that do not take it can serve as the base model.
            labels: optional class indices; when given, cross-entropy loss
                is computed.
            **kwargs: only keys in ``_ENCODER_KWARGS`` reach the encoder.

        Returns:
            dict with ``"loss"`` (tensor or ``None``) and ``"logits"``.
        """
        encoder_kwargs = {k: v for k, v in kwargs.items() if k in self._ENCODER_KWARGS}
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids

        enc = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **encoder_kwargs
        )
        # Index 0 is the last hidden state for both ModelOutput objects and
        # the plain tuples returned when return_dict=False is forwarded.
        last_hidden = enc[0]                                             # [B, L, H]

        # CLS and masked mean pooling
        h_cls = last_hidden[:, 0, :]                                     # [B, H]
        h_mean = self.masked_mean(last_hidden, attention_mask)           # [B, H]

        # Gated fusion
        gate_inp = torch.cat([h_cls, h_mean], dim=-1)                    # [B, 2H]
        g = self.gate_mlp(gate_inp)                                      # [B, H] in (0,1)
        fused = g * h_cls + (1.0 - g) * h_mean
        fused = self.dropout(fused)
        logits = self.classifier(fused)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return {"loss": loss, "logits": logits}

print(f"Loading gated-fusion head on base encoder: {Config.BASE_MODEL_NAME}")
# Hyperparameters of the gated head; the encoder itself is loaded from
# base_model_name inside the model's __init__.
gf_config = GFConfig(
    base_model_name=Config.BASE_MODEL_NAME,
    num_labels=Config.NUM_LABELS,
    gate_hidden=256
)
model = GatedFusionForSequenceClassification(gf_config).to(device)
print("Gated-fusion model loaded and moved to device.")
print("--- Cell 7 Complete ---")




# --- Cell 8: Define Evaluation Metrics ---
def compute_metrics(eval_pred):
    """Compute classification metrics; called by the Trainer at evaluation time.

    Args:
        eval_pred: a ``(logits, labels)`` pair (EvalPrediction unpacks the
            same way). ``logits`` may also be a tuple whose first element is
            the logits array.

    Returns:
        dict with ``'accuracy'``, ``'f1'``, ``'precision'`` and ``'recall'``;
        the last three are macro-averaged over classes.
    """
    logits, labels = eval_pred
    # Some models return extra outputs next to the logits; keep the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    # zero_division=0 returns the same 0.0 sklearn's default ('warn') returns
    # for an absent/never-predicted class, but without warning on every call
    # (which otherwise floods the output during bootstrapping).
    precision = precision_score(labels, predictions, average='macro', zero_division=0)
    recall = recall_score(labels, predictions, average='macro', zero_division=0)
    f1 = f1_score(labels, predictions, average='macro', zero_division=0)
    acc = accuracy_score(labels, predictions)
    return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}

print("Metrics function 'compute_metrics' defined.")
print("--- Cell 8 Complete ---")


# --- Cell 9: Configure Training Arguments ---
print("Configuring training arguments...")

training_args = TrainingArguments(
    output_dir=Config.MODEL_SAVE_PATH,

    # --- Training Hyperparameters ---
    num_train_epochs=Config.EPOCHS,
    learning_rate=Config.LEARNING_RATE,
    per_device_train_batch_size=Config.BATCH_SIZE,
    per_device_eval_batch_size=Config.BATCH_SIZE * 2,
    warmup_steps=500,
    weight_decay=0.01,

    # --- Evaluation and Saving ---
    # Evaluate/checkpoint per epoch and reload the best validation 'f1'
    # (macro-F1 from compute_metrics) at the end of training.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    # Cap the per-epoch checkpoints written to Drive. With
    # load_best_model_at_end=True the best checkpoint is always retained in
    # addition to the most recent one, so best-model restoration still works.
    save_total_limit=2,

    # --- Logging (progress bars only) ---
    report_to="none",
    logging_strategy="no",
    disable_tqdm=False
)

print("--- Cell 9 Complete ---")


# --- Cell 10: Initialize Trainer ---
print("Initializing Trainer...")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # `tokenizer=` is deprecated in recent transformers releases in favour of
    # `processing_class=` (added in v4.46); the tokenizer is still saved with
    # the model.
    processing_class=tokenizer,
    compute_metrics=compute_metrics
)

print("Trainer initialized.")
print("--- Cell 10 Complete ---")


# --- Cell 11: Train the Model ---
print("--- STARTING GATED MODEL TRAINING ---")
print(f"Training for {Config.EPOCHS} epochs...")

# Fine-tune; load_best_model_at_end=True (Cell 9) restores the best-'f1'
# checkpoint into the trainer's model when training finishes.
training_results = trainer.train()

print("--- TRAINING COMPLETE ---")
print("--- Cell 11 Complete ---")


# --- Cell 12: Save the Best Model and Results ---
print(f"Saving the best model to {Config.MODEL_SAVE_PATH}...")
# Saves the restored best model (with its GFConfig) plus the tokenizer given to the Trainer.
trainer.save_model(Config.MODEL_SAVE_PATH)
# Persist the Trainer state (log history, best-checkpoint info).
trainer.save_state()
print(f"Model successfully saved to {Config.MODEL_SAVE_PATH}")
print("--- Cell 12 Complete ---")


# --- Cell 13: Evaluate on the TEST Set (with Bootstrap CIs) ---
print("--- EVALUATING ON THE TEST SET (SINGLE PASS) ---")

# Point estimate on the full test set; metric keys carry the default 'eval_' prefix.
clean_test_results = trainer.evaluate(tokenized_datasets["test"])

print("\n\n--- FINAL GATED MODEL TEST RESULTS (CLEAN) ---")
print(f"Model: {Config.MODEL_NAME}")
print(f"Test F1-Score:   {clean_test_results['eval_f1']:.4f}")
print(f"Test Accuracy:   {clean_test_results['eval_accuracy']:.4f}")
print(f"Test Precision:  {clean_test_results['eval_precision']:.4f}")
print(f"Test Recall:     {clean_test_results['eval_recall']:.4f}")
print("---------------------------------------------------\n")

from tqdm.auto import tqdm

print(f"--- STARTING BOOTSTRAP EVALUATION ({Config.N_BOOTSTRAPS} iterations) ---")

# Seed for the bootstrap resampling so the reported CIs are reproducible
# (the previous sklearn `resample` call used no random_state).
BOOTSTRAP_SEED = 42

test_dataset = tokenized_datasets["test"]
n_samples = len(test_dataset)

# Predict the test set ONCE, then bootstrap over the cached (logits, labels).
# Inference runs with the model in eval mode, so a resampled row gets the
# same prediction as in the full pass; re-running trainer.evaluate() per
# replicate (as before) repeated a full forward pass N_BOOTSTRAPS times.
test_output = trainer.predict(test_dataset)
test_logits = test_output.predictions
if isinstance(test_logits, tuple):
    test_logits = test_logits[0]
test_labels = np.asarray(test_output.label_ids)

rng = np.random.default_rng(BOOTSTRAP_SEED)
boot_f1_scores = []
boot_accuracy_scores = []
boot_precision_scores = []
boot_recall_scores = []

for _ in tqdm(range(Config.N_BOOTSTRAPS), desc="Bootstrapping", leave=False):
    # n_samples indices drawn with replacement, scored with the Trainer's metric fn
    boot_indices = rng.integers(0, n_samples, size=n_samples)
    boot_results = compute_metrics((test_logits[boot_indices], test_labels[boot_indices]))
    boot_f1_scores.append(boot_results['f1'])
    boot_accuracy_scores.append(boot_results['accuracy'])
    boot_precision_scores.append(boot_results['precision'])
    boot_recall_scores.append(boot_results['recall'])

print("--- BOOTSTRAP EVALUATION COMPLETE ---")

boot_f1_scores = np.array(boot_f1_scores)
boot_accuracy_scores = np.array(boot_accuracy_scores)
boot_precision_scores = np.array(boot_precision_scores)
boot_recall_scores = np.array(boot_recall_scores)

# 95% percentile intervals (2.5th to 97.5th percentile)
f1_ci = np.percentile(boot_f1_scores, [2.5, 97.5])
acc_ci = np.percentile(boot_accuracy_scores, [2.5, 97.5])
prec_ci = np.percentile(boot_precision_scores, [2.5, 97.5])
rec_ci = np.percentile(boot_recall_scores, [2.5, 97.5])

f1_mean = np.mean(boot_f1_scores)
acc_mean = np.mean(boot_accuracy_scores)
prec_mean = np.mean(boot_precision_scores)
rec_mean = np.mean(boot_recall_scores)

print("\n\n--- FINAL GATED MODEL TEST RESULTS (BOOTSTRAPPED) ---")
print(f"Metrics based on {Config.N_BOOTSTRAPS} bootstrap samples.")
print(f"Format: Mean (95% CI)")
print("----------------------------------------------------------")
print(f"Test F1-Score:   {f1_mean:.4f} (95% CI: [{f1_ci[0]:.4f}, {f1_ci[1]:.4f}])")
print(f"Test Accuracy:   {acc_mean:.4f} (95% CI: [{acc_ci[0]:.4f}, {acc_ci[1]:.4f}])")
print(f"Test Precision:  {prec_mean:.4f} (95% CI: [{prec_ci[0]:.4f}, {prec_ci[1]:.4f}])")
print(f"Test Recall:     {rec_mean:.4f} (95% CI: [{rec_ci[0]:.4f}, {rec_ci[1]:.4f}])")
print("----------------------------------------------------------\n")

# Write clean + bootstrapped results next to the saved models (mode 'w'
# overwrites earlier runs; 'models' already exists because MODEL_SAVE_PATH is inside it).
results_file = os.path.join(Config.DRIVE_PATH, 'models', 'step1_gated_baseline_results.txt')
with open(results_file, 'w') as f:
    f.write("--- FINAL GATED MODEL TEST RESULTS ---\n\n")
    f.write(f"Model: {Config.MODEL_NAME}\n\n")

    f.write("--- SINGLE PASS (CLEAN) RESULTS ---\n")
    f.write(f"Test F1-Score:   {clean_test_results['eval_f1']:.4f}\n")
    f.write(f"Test Accuracy:   {clean_test_results['eval_accuracy']:.4f}\n")
    f.write(f"Test Precision:  {clean_test_results['eval_precision']:.4f}\n")
    f.write(f"Test Recall:     {clean_test_results['eval_recall']:.4f}\n\n")

    f.write(f"--- BOOTSTRAPPED RESULTS ({Config.N_BOOTSTRAPS} samples) ---\n")
    f.write(f"Format: Mean (95% CI)\n")
    f.write(f"Test F1-Score:   {f1_mean:.4f} (95% CI: [{f1_ci[0]:.4f}, {f1_ci[1]:.4f}])\n")
    f.write(f"Test Accuracy:   {acc_mean:.4f} (95% CI: [{acc_ci[0]:.4f}, {acc_ci[1]:.4f}])\n")
    f.write(f"Test Precision:  {prec_mean:.4f} (95% CI: [{prec_ci[0]:.4f}, {prec_ci[1]:.4f}])\n")
    f.write(f"Test Recall:     {rec_mean:.4f} (95% CI: [{rec_ci[0]:.4f}, {rec_ci[1]:.4f}])\n")

print(f"Test results saved to {results_file}")
print("--- Cell 13 Complete ---")


# --- Cell 14: Clean Up Memory ---
print("Cleaning up memory...")
# Drop the notebook's references to the heavy objects, then collect.
del model, trainer, tokenized_datasets, raw_datasets
gc.collect()
# Only touch the CUDA caching allocator when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("--- STEP 1 COMPLETE ---")
print("You now have a trained, saved, and evaluated GATED model.")


In [None]:
# SINGLE CELL: self-contained adversarial suite runner
# 1) Set your paths here:
# The two checkpoint folders must match Config.MODEL_SAVE_PATH used by the
# baseline and gated training cells above.
DRIVE_BASE          = r"/content/drive/MyDrive/hate"
TRAIN_FILE          = f"{DRIVE_BASE}/train.csv"
VAL_FILE            = f"{DRIVE_BASE}/val.csv"
TEST_FILE           = f"{DRIVE_BASE}/test.csv"
BASELINE_CHECKPOINT = f"{DRIVE_BASE}/models/step1_bert_baseline"     # saved baseline model
GATED_CHECKPOINT    = f"{DRIVE_BASE}/models/step1_gated_fusion"      # saved gated model (folder)

# 2) Install deps (IPython `!` shell escape, not plain Python)
!pip -q install nlpaug nltk transformers datasets -q

# 3) Imports
import os, re, json, random
from copy import deepcopy
from tqdm.auto import tqdm
import numpy as np
from datasets import Dataset, load_dataset
import nlpaug.augmenter.char as nac
from nlpaug.augmenter.word import SynonymAug
import nltk
nltk.download('wordnet', quiet=True)
from nltk.corpus import wordnet as wn

import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    AutoConfig,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig
)
from sklearn.metrics import precision_recall_fscore_support

# Run the adversarial suite on the GPU when one is available, else on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------
# Gated model class + config
# -----------------------
class GFConfig(PretrainedConfig):
    """Configuration for GatedFusionForSequenceClassification.

    Redefined here so this cell is self-contained. All arguments have
    defaults so the config can be built with no arguments (HF internals may
    do this when reloading a saved checkpoint).

    Args:
        base_model_name: Hub name or path of the wrapped encoder.
        num_labels: number of output classes.
        gate_hidden: width of the gate MLP's hidden layer.
        **kwargs: forwarded to ``PretrainedConfig``.
    """
    model_type = "gated_fusion_wrapper"
    def __init__(self, base_model_name: str = "bert-base-multilingual-cased", num_labels: int = 2, gate_hidden: int = 256, **kwargs):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_labels = num_labels
        self.gate_hidden = gate_hidden

class GatedFusionForSequenceClassification(PreTrainedModel):
    """Sequence classifier that gates between [CLS] and mean-pooled features.

    Redefined here so saved gated checkpoints can be reloaded in this cell.
    The [CLS] vector and the masked mean of the encoder's token states feed a
    sigmoid-gated MLP; the classifier sees ``g * h_cls + (1 - g) * h_mean``.
    ``forward`` returns ``{"loss": loss_or_None, "logits": logits}``.
    """
    config_class = GFConfig

    # Extra encoder arguments that may be forwarded from ``**kwargs``; all
    # other keys (e.g. Trainer-injected ``num_items_in_batch``) are dropped.
    _ENCODER_KWARGS = frozenset({
        "position_ids", "head_mask", "inputs_embeds",
        "output_attentions", "output_hidden_states", "return_dict",
        "past_key_values", "encoder_hidden_states", "encoder_attention_mask",
    })

    def __init__(self, config: GFConfig):
        super().__init__(config)
        self.base_cfg = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_pretrained(config.base_model_name, config=self.base_cfg)
        hidden = self.base_cfg.hidden_size
        self.gate_mlp = nn.Sequential(
            nn.Linear(2 * hidden, config.gate_hidden),
            nn.ReLU(),
            nn.Linear(config.gate_hidden, hidden),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(getattr(self.base_cfg, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden, config.num_labels)
        # NOTE(review): confirm post_init() leaves the pretrained encoder
        # weights untouched for the installed transformers version.
        self.post_init()

    @staticmethod
    def masked_mean(last_hidden_state, attention_mask):
        """Mean of ``[B, L, H]`` token states over unmasked positions -> ``[B, H]``.

        With ``attention_mask=None`` every position is averaged (previously
        this raised an AttributeError).
        """
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
        summed = (last_hidden_state * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-6)  # avoid 0/0 on an all-padding row
        return summed / denom

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        """Encode, gate-fuse, classify; compute cross-entropy when labels are given.

        ``token_type_ids`` is forwarded only when not ``None`` so encoders
        without that argument also work; only ``_ENCODER_KWARGS`` keys of
        ``**kwargs`` reach the encoder.
        """
        encoder_kwargs = {k: v for k, v in kwargs.items() if k in self._ENCODER_KWARGS}
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **encoder_kwargs)
        # enc[0] is the last hidden state for ModelOutput and return_dict=False tuples alike.
        last_hidden = enc[0]
        h_cls = last_hidden[:, 0, :]
        h_mean = self.masked_mean(last_hidden, attention_mask)
        gate_inp = torch.cat([h_cls, h_mean], dim=-1)
        g = self.gate_mlp(gate_inp)
        fused = g * h_cls + (1.0 - g) * h_mean
        fused = self.dropout(fused)
        logits = self.classifier(fused)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}

# -----------------------
# Attack tools (working set)
# -----------------------
# nlpaug character-level augmenters used by the structural attacks below:
# KeyboardAug (keyboard-typo style noise, per nlpaug) and RandomCharAug
# configured with action="insert" (random character insertion).
structural_typo   = nac.KeyboardAug()
structural_insert = nac.RandomCharAug(action="insert")

def attack_structural_case(text, p_char=0.15):
    """Randomly flip the case of alphabetic characters (structural attack).

    Each alphabetic character is swapped with probability ``p_char``
    (lower -> upper, anything else -> lower); other characters are kept.
    Non-string input is returned unchanged.
    """
    if not isinstance(text, str):
        return text

    def _maybe_flip(ch):
        # Random draw happens only for alphabetic characters (short-circuit).
        if not (ch.isalpha() and random.random() < p_char):
            return ch
        return ch.upper() if ch.islower() else ch.lower()

    return "".join(_maybe_flip(ch) for ch in text)

synonym_aug = SynonymAug(aug_p=0.15, aug_src='wordnet')  # WordNet synonym swapper used by attack_semantic_synonym

def _first_augmentation(augmenter, text):
    """Run ``augmenter.augment(text)`` and return a single string.

    Recent nlpaug releases return a list of augmented strings while older ones
    return a bare string; blindly indexing ``[0]`` on the latter would keep only
    its first character. Falls back to *text* on any augmenter error and on an
    empty or ``None`` result.
    """
    try:
        out = augmenter.augment(text)
    except Exception:
        return text
    if isinstance(out, list):
        return out[0] if out else text
    return text if out is None else out


def attack_structural_typo(text, augmenter=None):
    """Inject keyboard-neighbour typos into *text* (structural attack).

    Args:
        text: input string (returned unchanged if the augmenter fails).
        augmenter: optional object with an nlpaug-style ``augment`` method;
            defaults to the module-level ``structural_typo``.
    """
    return _first_augmentation(structural_typo if augmenter is None else augmenter, text)


def attack_structural_insert(text, augmenter=None):
    """Insert random characters into *text* (structural attack).

    Args:
        text: input string (returned unchanged if the augmenter fails).
        augmenter: optional object with an nlpaug-style ``augment`` method;
            defaults to the module-level ``structural_insert``.
    """
    return _first_augmentation(structural_insert if augmenter is None else augmenter, text)

def _wordnet_single_swap(text):
    """Fallback synonym attack: replace one randomly chosen word with a WordNet lemma.

    Words are visited in random order; the first one that has a lemma (from
    its first three synsets) differing case-insensitively from itself is
    replaced. Non-string input is returned unchanged.
    """
    if not isinstance(text, str):
        return text
    toks = text.split()
    order = list(range(len(toks)))
    random.shuffle(order)
    for i in order:
        tok = toks[i]
        synsets = wn.synsets(tok)
        if not synsets:
            continue
        candidates = []
        for syn in synsets[:3]:
            for lemma in syn.lemmas():
                cand = lemma.name().replace('_', ' ')
                if cand.lower() != tok.lower():
                    candidates.append(cand)
        if candidates:
            toks[i] = random.choice(candidates)
            break
    return " ".join(toks)


def attack_semantic_synonym(text, augmenter=None):
    """Replace words in *text* with WordNet synonyms (semantic attack).

    Args:
        text: input string.
        augmenter: optional object with an nlpaug-style ``augment`` method;
            defaults to the module-level ``synonym_aug``.

    Returns:
        The first augmented string. *text* itself is returned when the augmenter
        yields an empty list or ``None``; previously an empty list leaked through
        as the "attacked text". If the augmenter raises, a manual single-word
        WordNet swap is used instead.
    """
    try:
        aug = synonym_aug if augmenter is None else augmenter
        out = aug.augment(text)
    except Exception:
        return _wordnet_single_swap(text)
    if isinstance(out, list):
        return out[0] if out else text
    return text if out is None else out

# Identity-term regex -> replacement code word used by the coded-language attack.
CODED_LEXICON = {
    r'\bmuslims?\b': 'skittles',
    r'\bblack( people)?\b': 'googles',
    r'\bjews?\b': 'skypes',
    r'\bmexicans?\b': 'bings',
}
CODED_PATTERNS = [
    (re.compile(pattern, flags=re.IGNORECASE), code_word)
    for pattern, code_word in CODED_LEXICON.items()
]


def attack_semantic_coded(text):
    """Substitute identity terms in *text* with code words (case-insensitive).

    Substitutions are applied in ``CODED_LEXICON`` order. Non-string input is
    returned unchanged.
    """
    if not isinstance(text, str):
        return text
    for pattern, code_word in CODED_PATTERNS:
        text = pattern.sub(code_word, text)
    return text

# Multilingual slur regexes; matches are removed by the feature-removal attack.
SLUR_LEXICON = [
    r"\bmongol(s)?\b",
    r"\bretard(s|ed)?\b",
    r"\btolol\b",
    r"\bkontol\b",
    r"\bbajingan\b",
    r"\bbabi\b",
    r"\bbhen ?chod\b",
    r"\bmadar ?chod\b",
    r"\brandi\b",
    r"\bperra\b",
    r"\bzorra\b",
    r"\bputa\b",
]
SLUR_PATTERNS = [re.compile(rx, flags=re.IGNORECASE) for rx in SLUR_LEXICON]


def attack_feature_slur_removal(text, replacement='[removed]'):
    """Replace every slur match in *text* with *replacement* (case-insensitive).

    Non-string input is returned unchanged.
    """
    if not isinstance(text, str):
        return text
    for pattern in SLUR_PATTERNS:
        text = pattern.sub(replacement, text)
    return text

def apply_attack_texts(texts, attack_fn, desc="Attacking"):
    """Apply *attack_fn* to each text (with a transient tqdm bar) and return the results as a list."""
    return list(map(attack_fn, tqdm(texts, desc=desc, leave=False)))

# -----------------------
# Per-class metrics helper
# -----------------------
def _per_class_metrics(y_true, y_pred, labels=(0, 1)):
    """Per-label precision/recall/F1/support as a list of JSON-friendly dicts.

    One dict per entry of *labels*, in that order, holding plain Python
    ``int``/``float`` values; ill-defined scores are reported as 0
    (``zero_division=0``).
    """
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=list(labels), zero_division=0
    )
    return [
        {
            "label": int(lbl),
            "precision": float(p),
            "recall": float(r),
            "f1": float(f),
            "support": int(s),
        }
        for lbl, p, r, f, s in zip(labels, precision, recall, f1, support)
    ]

# -----------------------
# Data bootstrap (auto-load if absent)
# -----------------------
def _ensure_data_loaded():
    """Create the globals ``raw_datasets`` / ``tokenized_datasets`` if either is missing.

    Loads whichever of TRAIN_FILE / VAL_FILE / TEST_FILE exist as CSV splits
    ('train' / 'validation' / 'test'), then tokenizes the ``text`` column with
    the baseline checkpoint's tokenizer (or bert-base-multilingual-cased if that
    checkpoint directory is absent), padding/truncating to 128 tokens.
    A ``label`` column is renamed to ``labels`` and the result is set to torch format.

    Raises:
        RuntimeError: none of the three CSV files exists.
    """
    global raw_datasets, tokenized_datasets
    if 'raw_datasets' in globals() and 'tokenized_datasets' in globals():
        return  # already present (e.g. created by an earlier cell)

    candidate_files = {'train': TRAIN_FILE, 'validation': VAL_FILE, 'test': TEST_FILE}
    data_files = {split: path for split, path in candidate_files.items() if os.path.isfile(path)}
    if not data_files:
        raise RuntimeError("No data files found. Set TRAIN_FILE/VAL_FILE/TEST_FILE correctly.")

    raw_datasets = load_dataset("csv", data_files=data_files)

    # Prefer the baseline checkpoint's tokenizer so ids match that model.
    tok_path = BASELINE_CHECKPOINT if os.path.isdir(BASELINE_CHECKPOINT) else "bert-base-multilingual-cased"
    _tokenizer = AutoTokenizer.from_pretrained(tok_path)

    def tokenize_batch(batch):
        return _tokenizer(batch['text'], padding='max_length', truncation=True, max_length=128)

    tokenized_datasets = raw_datasets.map(tokenize_batch, batched=True)
    # Inspect whichever split was actually loaded: not every split is guaranteed
    # to exist, so indexing a fixed split name could raise KeyError.
    first_split = next(iter(tokenized_datasets.values()))
    if 'label' in first_split.column_names:
        tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

_ensure_data_loaded()

# -----------------------
# Eval helpers & suite
# -----------------------
def evaluate_on_texts(trainer, tokenizer, texts, labels, max_length=128):
    """Tokenize *texts*, run ``trainer.predict`` on them and return argmax predictions.

    Args:
        trainer: ``Trainer`` used for prediction.
        tokenizer: tokenizer matching the trainer's model.
        texts: list of strings to score.
        labels: gold labels, same length as *texts*, stored as the ``label`` column.
        max_length: pad/truncate length.

    Returns:
        ``(preds, logits, metrics)``: ``logits`` are the raw ``predictions`` from
        ``trainer.predict``, ``preds`` their argmax over the last axis, and
        ``metrics`` is currently always an empty dict.
    """
    encoded = tokenizer(texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
    feature_cols = list(encoded.keys())
    dataset = Dataset.from_dict({col: encoded[col].tolist() for col in feature_cols})
    dataset = dataset.add_column("label", labels)
    dataset.set_format(type='torch', columns=feature_cols + ["label"])
    logits = trainer.predict(dataset).predictions
    return np.argmax(logits, axis=-1), logits, {}

def _write_attack_report(report, save_dir, attack_name, attack_mode):
    """Write *report* to ``<save_dir>/<attack_name>_<attack_mode>.json``; no-op when save_dir is falsy."""
    if not save_dir:
        return
    os.makedirs(save_dir, exist_ok=True)
    report_path = os.path.join(save_dir, f"{attack_name}_{attack_mode}.json")
    with open(report_path, "w") as fh:
        json.dump(report, fh, indent=2)


def run_attack_and_report(trainer, tokenizer, raw_test_dataset, tokenized_test_dataset, attack_fn, attack_name, attack_mode='tp', save_dir=None):
    """Attack part of the test set with *attack_fn* and report per-class metrics before/after.

    Args:
        trainer: ``Trainer`` used for clean and attacked predictions.
        tokenizer: tokenizer used to re-encode the attacked texts.
        raw_test_dataset: split with ``text`` and ``label`` columns.
        tokenized_test_dataset: the same split, pre-tokenized, in the same row order.
        attack_fn: ``str -> str`` perturbation.
        attack_name: name used in the report and in the output file name.
        attack_mode: ``'tp'`` attacks only true positives (gold 1 and predicted 1);
            ``'full'`` attacks every row.
        save_dir: when truthy, the report is also written to
            ``<save_dir>/<attack_name>_<attack_mode>.json``.

    Returns:
        Report dict. In ``'tp'`` mode it includes ``samples_flipped`` (targeted
        rows now predicted 0) and ``attack_success_rate`` (that count over the
        number of targeted rows). If there are no targets, or the attack changed
        no text, a short report with a ``note`` is returned instead.

    Raises:
        ValueError: unknown *attack_mode*.
        AssertionError: the raw and tokenized splits differ in length.
    """
    clean_out = trainer.predict(tokenized_test_dataset)
    clean_preds = np.argmax(clean_out.predictions, axis=-1)
    clean_labels = clean_out.label_ids
    clean_per_class = _per_class_metrics(clean_labels, clean_preds)

    raw_texts = list(raw_test_dataset['text'])
    raw_labels = list(raw_test_dataset['label'])
    n = len(raw_texts)
    assert n == len(clean_preds) == len(clean_labels)

    if attack_mode == 'full':
        target_indices = list(range(n))
    elif attack_mode == 'tp':
        target_indices = [
            idx for idx, (gold, pred) in enumerate(zip(clean_labels, clean_preds))
            if gold == 1 and pred == 1
        ]
    else:
        raise ValueError("attack_mode must be 'tp' or 'full'")

    if not target_indices:
        report = {
            'attack_name': attack_name, 'attack_mode': attack_mode,
            'clean_metrics': {}, 'clean_per_class': clean_per_class,
            'attacked_metrics': None, 'attacked_per_class': None,
            'attack_changed': 0, 'tp_count_targeted': 0, 'note': 'no targets'
        }
        _write_attack_report(report, save_dir, attack_name, attack_mode)
        return report

    # Only targeted rows are perturbed; every other row keeps its clean text.
    attacked_texts = list(raw_texts)
    perturbed = apply_attack_texts([raw_texts[i] for i in target_indices], attack_fn, desc=attack_name)
    attack_count = 0
    for idx, new_text in zip(target_indices, perturbed):
        if new_text != raw_texts[idx]:
            attacked_texts[idx] = new_text
            attack_count += 1

    if attack_count == 0:
        report = {
            'attack_name': attack_name, 'attack_mode': attack_mode,
            'clean_metrics': {}, 'clean_per_class': clean_per_class,
            'attacked_metrics': None, 'attacked_per_class': None,
            'attack_changed': 0, 'tp_count_targeted': len(target_indices),
            'note': 'no modification made by attack_fn'
        }
        _write_attack_report(report, save_dir, attack_name, attack_mode)
        return report

    attacked_preds, _, _ = evaluate_on_texts(trainer, tokenizer, attacked_texts, raw_labels)
    attacked_per_class = _per_class_metrics(raw_labels, attacked_preds)

    # TP mode: success = a targeted true positive now predicted as class 0.
    samples_flipped = None
    asr = None
    if attack_mode == 'tp':
        samples_flipped = sum(1 for i in target_indices if attacked_preds[i] == 0)
        asr = samples_flipped / len(target_indices)

    report = {
        'attack_name': attack_name,
        'attack_mode': attack_mode,
        'clean_per_class': clean_per_class,
        'attacked_per_class': attacked_per_class,
        'attack_changed': attack_count,
        'tp_count_targeted': len(target_indices),
        'samples_flipped': int(samples_flipped) if samples_flipped is not None else None,
        'attack_success_rate': float(asr) if asr is not None else None
    }
    _write_attack_report(report, save_dir, attack_name, attack_mode)
    return report

def run_all_attacks(trainer, tokenizer, raw_test_dataset, tokenized_test_dataset, save_dir="attack_reports", attack_mode='tp', attacks_to_run=None):
    """Run a list of attacks and write a combined suite report.

    Args:
        attacks_to_run: optional list of ``(attack_fn, name)`` pairs; when falsy
            (None or empty) the built-in structural/semantic/feature attacks are used.
        Other arguments are forwarded to ``run_attack_and_report``.

    Returns:
        Dict with the model class, mode and the per-attack reports. It is also
        saved to ``<save_dir>/suite_report_<ModelClass>_<mode>.json``.
    """
    os.makedirs(save_dir, exist_ok=True)
    if attacks_to_run:
        attacks = attacks_to_run
    else:
        attacks = [
            (attack_structural_typo, 'structural_typo'),
            (attack_structural_insert, 'structural_insert'),
            (attack_structural_case, 'structural_case'),
            (attack_semantic_synonym, 'semantic_synonym'),
            (attack_semantic_coded, 'semantic_coded'),
            (lambda t: attack_feature_slur_removal(t, replacement='[removed]'), 'feature_slur_removal'),
        ]

    model_cls = trainer.model.__class__
    suite_report = {'model': str(model_cls), 'attack_mode': attack_mode, 'attacks': []}
    for attack_fn, attack_name in attacks:
        print(f"Running attack: {attack_name} (mode={attack_mode})")
        result = run_attack_and_report(
            trainer, tokenizer, raw_test_dataset, tokenized_test_dataset,
            attack_fn, attack_name, attack_mode, save_dir=save_dir,
        )
        if result is not None:
            suite_report['attacks'].append(result)

    summary_path = os.path.join(save_dir, f"suite_report_{model_cls.__name__}_{attack_mode}.json")
    with open(summary_path, 'w') as fh:
        json.dump(suite_report, fh, indent=2)
    return suite_report

# -----------------------
# Build eval-only trainers (auto; gated falls back to custom if needed)
# -----------------------
# Shared inference-only arguments: no training, no logging/reporting, no progress bars.
eval_args = TrainingArguments(
    output_dir="./tmp_eval",
    do_train=False,
    do_eval=True,
    per_device_eval_batch_size=32,
    logging_strategy="no",
    report_to="none",
    disable_tqdm=True,
)

def _load_gated_fusion_model(checkpoint_path):
    """Rebuild ``GatedFusionForSequenceClassification`` from *checkpoint_path* and load its weights.

    Looks for ``model.safetensors`` first, then ``pytorch_model.bin``; recent
    transformers write safetensors by default, so checking only the .bin file
    silently skipped the fine-tuned weights. Loading stays non-strict, but an
    absent weights file and any missing/unexpected keys now raise a warning,
    because in those cases the gate/classifier keep their constructor weights.
    """
    import warnings

    cfg = GFConfig.from_pretrained(checkpoint_path)
    model = GatedFusionForSequenceClassification(cfg)
    safetensors_path = os.path.join(checkpoint_path, "model.safetensors")
    bin_path = os.path.join(checkpoint_path, "pytorch_model.bin")
    if os.path.isfile(safetensors_path):
        from safetensors.torch import load_file
        state_dict = load_file(safetensors_path, device="cpu")
    elif os.path.isfile(bin_path):
        state_dict = torch.load(bin_path, map_location="cpu")
    else:
        warnings.warn(
            f"No model.safetensors or pytorch_model.bin in {checkpoint_path}; "
            "gated model is NOT using fine-tuned weights."
        )
        return model
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        warnings.warn(
            f"Gated checkpoint {checkpoint_path}: missing keys={list(missing)}, "
            f"unexpected keys={list(unexpected)}"
        )
    return model


def build_eval_trainer(checkpoint_path, tokenizer=None):
    """Build an inference-only ``Trainer`` for the model saved in *checkpoint_path*.

    First tries ``AutoModelForSequenceClassification``; if that fails (e.g. the
    checkpoint holds the custom gated-fusion model), rebuilds the gated model
    from its config and weights. The model is moved to ``device`` and wrapped
    with the shared ``eval_args``.

    Args:
        checkpoint_path: directory produced by ``save_pretrained``/``save_model``.
        tokenizer: optional tokenizer; loaded from *checkpoint_path* if omitted.

    Raises:
        FileNotFoundError: *checkpoint_path* is not a directory.
    """
    if not os.path.isdir(checkpoint_path):
        raise FileNotFoundError(checkpoint_path)
    tok = tokenizer or AutoTokenizer.from_pretrained(checkpoint_path)
    try:
        model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path)
    except Exception:
        # Not loadable via the Auto classes -> custom gated-fusion checkpoint.
        model = _load_gated_fusion_model(checkpoint_path)
    model.to(device)
    return Trainer(model=model, args=eval_args, processing_class=tok, compute_metrics=None)

# -----------------------
# Run: baseline then gated
# -----------------------
# Tokenizers (one per checkpoint, so embeddings/normalization match)
# Each falls back to bert-base-multilingual-cased when its checkpoint directory is missing.
baseline_tok = AutoTokenizer.from_pretrained(BASELINE_CHECKPOINT if os.path.isdir(BASELINE_CHECKPOINT) else "bert-base-multilingual-cased")
gated_tok    = AutoTokenizer.from_pretrained(GATED_CHECKPOINT    if os.path.isdir(GATED_CHECKPOINT)    else "bert-base-multilingual-cased")

# NOTE(review): both runs below reuse tokenized_datasets['test'] for the clean
# predictions, while attacked texts are re-encoded with each model's own
# tokenizer. If tokenized_datasets was built with a different tokenizer than
# gated_tok, clean vs. attacked numbers for the gated model are not comparable
# — confirm both checkpoints share the same vocabulary.
print("Building baseline trainer...")
trainer_baseline = build_eval_trainer(BASELINE_CHECKPOINT, tokenizer=baseline_tok)
print("Baseline trainer ready. Running attacks...")
# TP mode: only rows the baseline gets right as class 1 are attacked.
baseline_report = run_all_attacks(
    trainer_baseline,
    tokenizer=baseline_tok,
    raw_test_dataset=raw_datasets['test'],
    tokenized_test_dataset=tokenized_datasets['test'],
    save_dir="attack_reports/baseline",
    attack_mode='tp'
)
print("Baseline attacks finished. Reports -> attack_reports/baseline")

print("\nBuilding gated trainer...")
trainer_gated = build_eval_trainer(GATED_CHECKPOINT, tokenizer=gated_tok)
print("Gated trainer ready. Running attacks...")
gated_report = run_all_attacks(
    trainer_gated,
    tokenizer=gated_tok,
    raw_test_dataset=raw_datasets['test'],
    tokenized_test_dataset=tokenized_datasets['test'],
    save_dir="attack_reports/gated",
    attack_mode='tp'
)
print("Gated attacks finished. Reports -> attack_reports/gated")

print("\nALL DONE ✓  Check JSONs under attack_reports/ (each includes per-class metrics).")


In [None]:
# ===========================
# Sentinel Architecture + Robust Training (FGM + Consistency)
# ===========================
!pip -q install transformers datasets scikit-learn

import os, math, re, json, random
import numpy as np
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Dict, Any, List

from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModel, AutoConfig,
    PretrainedConfig, PreTrainedModel,
    Trainer, TrainingArguments
)
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# Config (EDIT THESE PATHS)
# -------------------------
class Config:
    """Paths and hyper-parameters for the Sentinel robust-training cell."""

    # --- Data paths (edit DRIVE to relocate everything) ---
    DRIVE = "/content/drive/MyDrive/hate"
    TRAIN = os.path.join(DRIVE, "train.csv")
    VAL   = os.path.join(DRIVE, "val.csv")
    TEST  = os.path.join(DRIVE, "test.csv")

    # --- Backbone / optimisation ---
    BASE_MODEL = "xlm-roberta-base"  # multilingual; swap to roberta-large for EN-only data
    MAX_LEN    = 128
    BATCH      = 16
    EPOCHS     = 5
    LR         = 2e-5
    NUM_LABELS = 2

    # --- Sentinel fusion head ---
    USE_SENTINEL       = True     # set False to run plain encoder classifier (baseline)
    HEURISTIC_DIM      = 32       # length of the build_heuristic_features() vector
    HEURISTIC_HIDDEN   = 256
    CAUSAL_HIDDEN      = 256
    ATTENTION_HEADS    = 8        # nn.MultiheadAttention needs hidden size divisible by this
    AUX_CAUSAL_LOSS_W  = 0.2      # weight of the auxiliary causal-target loss
    # Misspelled legacy name, kept as an alias so existing references keep working.
    AUX_CASUAL_LOSS_W  = AUX_CAUSAL_LOSS_W

    # --- Robustness additions ---
    ALPHA_ADV       = 0.5     # weight for adversarial loss (FGM)
    BETA_CONS       = 0.2     # weight for consistency loss
    FGM_EPS         = 1e-3    # magnitude of FGM perturbation on embeddings
    CONS_TOK_MASK_P = 0.08    # probability to mask tokens for consistency

    SAVE_DIR = os.path.join(DRIVE, "models", "sentinel_xlmr")

# -------------------------
# Heuristic feature builder
# -------------------------
# Multilingual slur regexes; match counts feed the heuristic feature vector.
SLUR_REGEXES = [
    r"\bmongol(s)?\b",
    r"\bretard(s|ed)?\b",
    r"\btolol\b",
    r"\bkontol\b",
    r"\bbajingan\b",
    r"\bbabi\b",
    r"\bbhen ?chod\b",
    r"\bmadar ?chod\b",
    r"\brandi\b",
    r"\bperra\b",
    r"\bzorra\b",
    r"\bputa\b",
]
SLUR_PATTERNS = [re.compile(rx, flags=re.IGNORECASE) for rx in SLUR_REGEXES]

def build_heuristic_features(text: str) -> np.ndarray:
    """Compute a fixed-size float32 vector of cheap surface features for *text*.

    Non-string input is coerced first (None -> "", anything else -> str()).
    The first 12 slots are, in order:
        0 log1p(#chars)         1 log1p(#words, min 1)   2 #uppercase letters
        3 #digits               4 #punctuation (.,;:!?)  5 uppercase / letters
        6 digits / chars        7 punctuation / chars    8 #slur-regex matches
        9 slur matches / words  10 #cue substrings       11 cues / words
    The vector is zero-padded or truncated to ``Config.HEURISTIC_DIM``.
    Replace/extend with LIWC, sentiment, dependency features, etc.

    NOTE: cue words are counted with ``str.count`` on the lower-cased text, so
    they also match inside longer words (e.g. "die" in "diet"). Changing that
    changes the features a trained checkpoint expects.
    """
    if not isinstance(text, str):
        text = "" if text is None else str(text)

    lowered = text.lower()
    n_chars = len(text)
    n_words = max(1, len(text.split()))
    n_alpha = sum(c.isalpha() for c in text)
    n_upper = sum(1 for c in text if c.isalpha() and c.isupper())
    n_digits = sum(1 for c in text if c.isdigit())
    n_punct = sum(1 for c in text if c in ".,;:!?")
    n_slurs = sum(len(pat.findall(text)) for pat in SLUR_PATTERNS)
    n_cues = sum(lowered.count(cue) for cue in ("kill", "die", "trash", "dirty", "dog", "pig", "scum", "hate"))

    base_feats = np.array([
        n_chars, n_words, n_upper, n_digits, n_punct,
        n_upper / max(1, n_alpha),
        n_digits / max(1, n_chars),
        n_punct / max(1, n_chars),
        n_slurs, n_slurs / n_words,
        n_cues, n_cues / n_words,
    ], dtype=np.float32)

    # Compress the two unbounded length counts (very rough normalisation).
    base_feats[0] = math.log1p(base_feats[0])
    base_feats[1] = math.log1p(base_feats[1])

    target_dim = Config.HEURISTIC_DIM
    if base_feats.shape[0] >= target_dim:
        return base_feats[:target_dim]
    return np.pad(base_feats, (0, target_dim - base_feats.shape[0]))

# -------------------------
# Dataset + Tokenization
# -------------------------
# Validate paths with a real exception (asserts are stripped under `python -O`).
for _split_path in (Config.TRAIN, Config.VAL, Config.TEST):
    if not os.path.isfile(_split_path):
        raise FileNotFoundError(
            f"Train/Val/Test CSVs not found ({_split_path} is missing). Update Config paths."
        )

data_files = {"train": Config.TRAIN, "validation": Config.VAL, "test": Config.TEST}
raw = load_dataset("csv", data_files=data_files)
tok = AutoTokenizer.from_pretrained(Config.BASE_MODEL)


def _as_text(value):
    """Coerce a raw CSV cell to ``str``: strings pass through, None -> "", anything else -> str(value)."""
    if isinstance(value, str):
        return value
    return "" if value is None else str(value)


def tok_map(batch):
    """Tokenize a batch and attach per-example heuristic features (and causal targets if present).

    Empty CSV cells arrive as None; they are coerced to "" so the tokenizer does
    not crash on them (the heuristic builder already treated them that way).
    """
    texts = [_as_text(t) for t in batch["text"]]
    enc = tok(texts, padding="max_length", truncation=True, max_length=Config.MAX_LEN)
    enc["heuristic_feats"] = [build_heuristic_features(t) for t in texts]
    # Optional auxiliary causal targets:
    if "causal_target" in batch:
        enc["causal_target"] = batch["causal_target"]
    return enc


tokenized = raw.map(tok_map, batched=True, remove_columns=[c for c in raw["train"].column_names if c not in ("text","label","causal_target")])
tokenized = tokenized.rename_column("label", "labels")
_torch_columns = ["input_ids", "attention_mask", "labels", "heuristic_feats"]
if "causal_target" in tokenized["train"].column_names:
    _torch_columns.append("causal_target")
tokenized.set_format(type="torch", columns=_torch_columns)

# -------------------------
# Data collator (to tensorize heuristic feats)
# -------------------------
@dataclass
class SentinelCollator:
    """Stack per-example tensors into a batch for SentinelModel / BaselineClassifier.

    ``input_ids``, ``attention_mask`` and ``labels`` must already be tensors
    (the dataset is in torch format). ``heuristic_feats`` may be a tensor or any
    sequence and is cast to float32. ``causal_target`` is stacked and cast to
    int64 when the first example has it.
    """

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        def _as_float32(value):
            if isinstance(value, torch.Tensor):
                return value.to(torch.float32)
            return torch.tensor(value, dtype=torch.float32)

        collated = {
            key: torch.stack([example[key] for example in batch])
            for key in ("input_ids", "attention_mask", "labels")
        }
        collated["heuristic_feats"] = torch.stack([_as_float32(example["heuristic_feats"]) for example in batch])
        if "causal_target" in batch[0]:
            collated["causal_target"] = torch.stack([example["causal_target"] for example in batch]).long()
        return collated

collator = SentinelCollator()

# -------------------------
# Sentinel Config + Model
# -------------------------
class SentinelConfig(PretrainedConfig):
    """Hugging Face config for SentinelModel (also reused by BaselineClassifier).

    Stores the backbone name, label count, fusion-branch sizes and the weight of
    the auxiliary causal loss; every other kwarg goes to ``PretrainedConfig``.
    """
    model_type = "sentinel_fusion"

    def __init__(
        self,
        base_model_name: str = "xlm-roberta-base",
        num_labels: int = 2,
        heuristic_dim: int = 32,
        heuristic_hidden: int = 256,
        causal_hidden: int = 256,
        attn_heads: int = 8,
        aux_causal_loss_weight: float = 0.2,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Backbone and label space.
        self.base_model_name, self.num_labels = base_model_name, num_labels
        # Fusion-branch sizes.
        self.heuristic_dim, self.heuristic_hidden = heuristic_dim, heuristic_hidden
        self.causal_hidden, self.attn_heads = causal_hidden, attn_heads
        # Loss weighting.
        self.aux_causal_loss_weight = aux_causal_loss_weight

class SentinelModel(PreTrainedModel):
    """Encoder classifier that fuses the first-token vector with two side channels.

    The first-token ([CLS]/<s>) vector of a pretrained encoder
    (``config.base_model_name``) is the query of a multi-head cross-attention
    whose keys/values are:
      * a projection of the hand-crafted ``heuristic_feats``;
      * an MLP transform of that same first-token vector ("causal" pathway).
    The attended vector goes through dropout and a linear head. If
    ``causal_target`` is given, an auxiliary cross-entropy on the causal
    pathway is added with weight ``config.aux_causal_loss_weight``.

    NOTE: submodule attribute names below are the state-dict keys of saved
    checkpoints; renaming them breaks loading existing checkpoints.
    """
    config_class = SentinelConfig
    def __init__(self, config: SentinelConfig):
        super().__init__(config)
        # Backbone config + pretrained weights are fetched from base_model_name.
        self.base_cfg = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_pretrained(config.base_model_name, config=self.base_cfg)
        hidden = self.base_cfg.hidden_size

        # Heuristic projector
        # Maps the heuristic_dim feature vector to the encoder hidden size.
        self.heuristic_proj = nn.Sequential(
            nn.Linear(config.heuristic_dim, config.heuristic_hidden),
            nn.ReLU(),
            nn.Linear(config.heuristic_hidden, hidden),
            nn.LayerNorm(hidden)
        )

        # Causal pathway
        # Transforms the first-token vector; also feeds the auxiliary 2-way head.
        self.causal_mlp = nn.Sequential(
            nn.Linear(hidden, config.causal_hidden),
            nn.ReLU(),
            nn.Linear(config.causal_hidden, hidden),
            nn.LayerNorm(hidden)
        )
        self.causal_aux_head = nn.Linear(hidden, 2)

        # Cross-attention: Query = [CLS], KV = [heuristic, causal]
        # nn.MultiheadAttention requires hidden to be divisible by attn_heads.
        self.xattn = nn.MultiheadAttention(embed_dim=hidden, num_heads=config.attn_heads, batch_first=True)

        self.dropout = nn.Dropout(getattr(self.base_cfg, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden, config.num_labels)
        self.post_init()

    @staticmethod
    def masked_mean(last_hidden_state, attention_mask):
        """Mean of token vectors over positions where attention_mask is non-zero.

        Not used by ``forward`` in this class.
        """
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
        summed = (last_hidden_state * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-6)
        return summed / denom

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        heuristic_feats=None,
        labels=None,
        causal_target: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Compute logits (and loss when targets are given).

        Args:
            input_ids / attention_mask / token_type_ids: passed to the encoder.
            heuristic_feats: required float tensor whose last dimension is
                ``config.heuristic_dim`` (input size of ``heuristic_proj``).
            labels: optional class ids for the main cross-entropy loss.
            causal_target: optional 0/1 targets for the auxiliary causal head.
            **kwargs: accepted and ignored (e.g. extra keys Trainer may pass).

        Returns:
            dict with ``"loss"`` (None when neither labels nor causal_target is
            given) and ``"logits"`` of shape (batch, config.num_labels).
        """
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        last_hidden = enc.last_hidden_state                  # [B, L, H]
        h_cls = last_hidden[:, 0, :]                         # [B, H] (semantic query)

        h_heu = self.heuristic_proj(heuristic_feats)         # [B, H]
        h_cau = self.causal_mlp(h_cls)                       # [B, H]

        # Single-query attention over the two side-channel vectors.
        Q = h_cls.unsqueeze(1)                               # [B, 1, H]
        KV = torch.stack([h_heu, h_cau], dim=1)              # [B, 2, H]
        fused, _ = self.xattn(Q, KV, KV)                     # [B, 1, H]
        fused = fused.squeeze(1)                             # [B, H]
        fused = self.dropout(fused)
        logits = self.classifier(fused)                      # [B, C]

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        # Auxiliary causal loss is only computed when targets are supplied.
        if causal_target is not None:
            cau_logits = self.causal_aux_head(h_cau)         # [B, 2]
            aux_loss = nn.CrossEntropyLoss()(cau_logits, causal_target)
            if loss is None:
                loss = self.config.aux_causal_loss_weight * aux_loss
            else:
                loss = loss + self.config.aux_causal_loss_weight * aux_loss

        return {"loss": loss, "logits": logits}

# -------------------------
# Baseline Model (no fusion)
# -------------------------
class BaselineClassifier(PreTrainedModel):
    """Plain encoder classifier: dropout + linear head on the first-token vector.

    Reuses ``SentinelConfig``; only ``base_model_name`` and ``num_labels`` are
    read. Extra inputs such as ``heuristic_feats`` arrive via ``**kwargs`` and
    are ignored, so the Sentinel collator works unchanged.
    """
    config_class = SentinelConfig

    def __init__(self, config: SentinelConfig):
        super().__init__(config)
        # Attribute names are checkpoint state-dict keys; keep them stable.
        self.base_cfg = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_pretrained(config.base_model_name, config=self.base_cfg)
        hidden_size = self.base_cfg.hidden_size
        self.dropout = nn.Dropout(getattr(self.base_cfg, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden_size, config.num_labels)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        """Return ``{"loss": cross-entropy or None, "logits": (batch, num_labels)}``."""
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).last_hidden_state
        logits = self.classifier(self.dropout(hidden_states[:, 0, :]))
        loss = nn.CrossEntropyLoss()(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}

# -------------------------
# Metrics (macro)
# -------------------------
def compute_metrics(eval_pred):
    """Accuracy plus macro-averaged precision/recall/F1 for ``Trainer`` evaluation.

    ``eval_pred`` unpacks to ``(logits, labels)``; predictions are the argmax
    over the last axis. Ill-defined scores count as 0 (``zero_division=0``).
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    macro = {"average": "macro", "zero_division": 0}
    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": precision_score(labels, preds, **macro),
        "recall": recall_score(labels, preds, **macro),
        "f1": f1_score(labels, preds, **macro),
    }

# -------------------------
# Robustness additions
# -------------------------
import torch.nn.functional as F

def symmetric_kl(logits_p, logits_q, temperature=1.0):
    """Symmetric KL divergence between the softmax distributions of two logit tensors.

    Returns ``0.5 * (KL(q || p) + KL(p || q))`` with both distributions computed
    at the given temperature, each term reduced with ``reduction='batchmean'``.
    """
    log_p = F.log_softmax(logits_p / temperature, dim=-1)
    log_q = F.log_softmax(logits_q / temperature, dim=-1)
    kl_q_p = F.kl_div(log_p, log_q.exp(), reduction='batchmean')  # KL(q || p)
    kl_p_q = F.kl_div(log_q, log_p.exp(), reduction='batchmean')  # KL(p || q)
    return 0.5 * (kl_q_p + kl_p_q)

@torch.no_grad()
def corrupt_inputs_for_consistency(input_ids, attention_mask, mask_token_id, p=0.08, special_ids=None):
    """Return a copy of *input_ids* with random attended tokens replaced by *mask_token_id*.

    Each position with ``attention_mask == 1`` is replaced independently with
    probability *p*. Position 0 (the leading [CLS]/<s> token) is never replaced.

    Args:
        input_ids: integer token tensor; not modified.
        attention_mask: same shape; padding positions (0) are never replaced.
        mask_token_id: id written into the corrupted positions.
        p: per-token corruption probability.
        special_ids: optional iterable of token ids that are never replaced
            (e.g. ``tokenizer.all_special_ids`` to keep the trailing </s>/[SEP]);
            default None keeps the previous behaviour.
    """
    corrupted = input_ids.clone()
    draws = torch.rand_like(corrupted.float())
    replace = (attention_mask == 1) & (draws < p)
    replace[:, 0] = False  # keep CLS intact
    if special_ids:
        protected_ids = torch.as_tensor(list(special_ids), dtype=corrupted.dtype, device=corrupted.device)
        replace &= ~torch.isin(corrupted, protected_ids)
    corrupted[replace] = mask_token_id
    return corrupted

class FGM:
    """Fast Gradient Method on the encoder's input-embedding matrix.

    ``attack()`` adds ``epsilon * grad / ||grad||`` to the embedding weights
    (saving a copy first); ``restore()`` puts the saved weights back. Expects
    ``model.encoder.get_input_embeddings()`` to return a module with ``.weight``.
    """

    def __init__(self, model, epsilon=1e-3):
        self.model = model
        self.epsilon = epsilon
        self.backup = None  # copy of the embedding weights taken by attack()

    def _emb(self):
        return self.model.encoder.get_input_embeddings().weight

    def attack(self):
        """Perturb the embeddings along their gradient; return False (no-op) if the gradient is unusable."""
        weight = self._emb()
        grad = weight.grad
        if grad is None:
            return False
        grad_norm = torch.norm(grad)
        # Skip when the norm is NaN/inf or exactly zero.
        if not torch.isfinite(grad_norm) or grad_norm.item() == 0:
            return False
        self.backup = weight.data.clone()
        weight.data.add_(self.epsilon * grad / (grad_norm + 1e-12))
        return True

    def restore(self):
        """Undo the last successful attack(); no-op if nothing is saved."""
        if self.backup is None:
            return
        self._emb().data = self.backup
        self.backup = None

from transformers import Trainer
import torch.nn.functional as F

class RobustTrainer(Trainer):
    """``Trainer`` whose training step adds consistency and FGM adversarial losses.

    Per batch (see ``training_step``), gradients are accumulated for:
      * the model's own loss (cross-entropy, plus the auxiliary causal loss if present);
      * ``beta * symmetric_kl(clean, noisy)``, where the noisy view replaces random
        attended tokens with the mask token (skipped if no mask token id is known);
      * ``alpha * adversarial loss`` after an FGM perturbation of the input-embedding
        matrix (skipped if that gradient is missing, zero or non-finite).

    Weights come from ``Config.ALPHA_ADV`` / ``Config.BETA_CONS`` / ``Config.FGM_EPS``
    / ``Config.CONS_TOK_MASK_P`` (with defaults if absent).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fgm = FGM(self.model, epsilon=getattr(Config, "FGM_EPS", 1e-3))
        self.alpha = getattr(Config, "ALPHA_ADV", 0.0)
        self.beta  = getattr(Config, "BETA_CONS", 0.0)
        self.mask_token_id = self._resolve_mask_token_id()

    def _resolve_mask_token_id(self):
        """Mask-token id from ``processing_class``, then ``tokenizer``, then the encoder config; else None."""
        token_id = getattr(getattr(self, "processing_class", None), "mask_token_id", None)
        if token_id is None:
            token_id = getattr(getattr(self, "tokenizer", None), "mask_token_id", None)
        if token_id is None:
            encoder_cfg = getattr(getattr(self.model, "encoder", None), "config", None)
            token_id = getattr(encoder_cfg, "mask_token_id", None)
        return token_id

    @staticmethod
    def _pick(outputs, key, index):
        """Fetch *key* from a dict-style model output, or position *index* from a tuple output."""
        return outputs[key] if isinstance(outputs, dict) else outputs[index]

    def _scale_for_backward(self, loss):
        """Average per-device losses (multi-GPU) and normalise for gradient accumulation.

        The models return mean-reduced losses, so ``num_items_in_batch`` is not
        needed for normalisation.
        """
        if self.args.n_gpu > 1:
            loss = loss.mean()
        return loss / self.args.gradient_accumulation_steps

    # NOTE: accept the new arg `num_items_in_batch`
    def training_step(self, model, inputs, num_items_in_batch=None):
        """Compute the robust loss and accumulate its gradients; return the detached main+consistency loss.

        Only backward passes happen here. Gradient clipping, ``optimizer.step()``,
        ``lr_scheduler.step()`` and ``zero_grad()`` are performed by ``Trainer``'s
        training loop after this method returns; doing them here as well would
        step the optimizer and the LR schedule twice per batch.
        """
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()
        inputs = self._prepare_inputs(inputs)

        # ---- main forward (CE ± aux) ----
        outputs = model(**inputs)
        total_loss = self._pick(outputs, "loss", 0)

        # ---- consistency loss (noisy masked inputs) ----
        if (self.beta > 0 and self.mask_token_id is not None
                and "input_ids" in inputs and "attention_mask" in inputs):
            noisy_inputs = dict(inputs)
            noisy_inputs["input_ids"] = corrupt_inputs_for_consistency(
                inputs["input_ids"], inputs["attention_mask"],
                mask_token_id=self.mask_token_id,
                p=getattr(Config, "CONS_TOK_MASK_P", 0.08),
            )
            logits_clean = self._pick(outputs, "logits", 1)
            logits_noisy = self._pick(model(**noisy_inputs), "logits", 1)
            # Clean logits act as a fixed target; gradients flow through the noisy view only.
            total_loss = total_loss + self.beta * symmetric_kl(logits_clean.detach(), logits_noisy)

        # Backprop main+consistency through the accelerator (handles AMP/distributed).
        total_loss = self._scale_for_backward(total_loss)
        self.accelerator.backward(total_loss)

        # ---- FGM adversarial step (gradients add to the ones above) ----
        if self.alpha > 0 and self.fgm.attack():
            try:
                adv_loss = self._pick(model(**inputs), "loss", 0)
                self.accelerator.backward(self._scale_for_backward(self.alpha * adv_loss))
            finally:
                # Always undo the embedding perturbation, even if the adversarial pass fails.
                self.fgm.restore()

        return total_loss.detach()


# -------------------------
# Build model + trainer
# -------------------------
# Model config from the Config class (AUX_CASUAL_LOSS_W is the auxiliary causal-loss weight).
cfg = SentinelConfig(
    base_model_name=Config.BASE_MODEL,
    num_labels=Config.NUM_LABELS,
    heuristic_dim=Config.HEURISTIC_DIM,
    heuristic_hidden=Config.HEURISTIC_HIDDEN,
    causal_hidden=Config.CAUSAL_HIDDEN,
    attn_heads=Config.ATTENTION_HEADS,
    aux_causal_loss_weight=Config.AUX_CASUAL_LOSS_W
)

# USE_SENTINEL toggles between the fusion model and the plain encoder baseline.
model = (SentinelModel(cfg) if Config.USE_SENTINEL else BaselineClassifier(cfg)).to(device)

# Evaluate and checkpoint every epoch into SAVE_DIR; the best checkpoint by the
# "f1" key returned by compute_metrics (macro F1) is reloaded at the end.
args = TrainingArguments(
    output_dir=Config.SAVE_DIR,
    num_train_epochs=Config.EPOCHS,
    learning_rate=Config.LR,
    per_device_train_batch_size=Config.BATCH,
    per_device_eval_batch_size=Config.BATCH * 2,
    weight_decay=0.01,
    warmup_ratio=0.06,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    report_to="none",
    logging_strategy="no",
    disable_tqdm=False
)

# processing_class=tok also lets RobustTrainer read the tokenizer's mask_token_id.
trainer = RobustTrainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    processing_class=tok,
    compute_metrics=compute_metrics
)

# -------------------------
# Train & Evaluate
# -------------------------
print(f"Training {'Sentinel' if Config.USE_SENTINEL else 'Baseline'} model with robustness losses...")
trainer.train()
print("Evaluating on test...")
# Metric keys come back with the default "eval_" prefix even though this is the test split.
test_metrics = trainer.evaluate(tokenized["test"])
print({k: round(v, 4) for k, v in test_metrics.items()})

# Save final model
# With load_best_model_at_end=True this is the best-validation-F1 checkpoint.
trainer.save_model(Config.SAVE_DIR)
print(f"Saved to: {Config.SAVE_DIR}")


In [None]:
# === SINGLE CELL: Sentinel-only Adversarial Suite (with per-class metrics) ===
# 1) Paths (edit as needed)
# Data splits and the saved Sentinel checkpoint, all under one Drive folder.
DRIVE_BASE = "/content/drive/MyDrive/hate"
TRAIN_FILE = DRIVE_BASE + "/train.csv"
VAL_FILE = DRIVE_BASE + "/val.csv"
TEST_FILE = DRIVE_BASE + "/test.csv"
SENTINEL_CHECKPOINT = DRIVE_BASE + "/models/sentinel_xlmr"  # <--- set to your saved Sentinel folder

# 2) Deps
!pip -q install nlpaug nltk transformers datasets scikit-learn

# 3) Imports
import os, re, json, math, random
import numpy as np
from tqdm.auto import tqdm
import nlpaug.augmenter.char as nac
from nlpaug.augmenter.word import SynonymAug
import nltk
nltk.download('wordnet', quiet=True)
from nltk.corpus import wordnet as wn

import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer, AutoConfig, AutoModel,
    PretrainedConfig, PreTrainedModel,
    TrainingArguments, Trainer
)
from sklearn.metrics import precision_recall_fscore_support

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# Sentinel config (must match training)
# -------------------------
class SentinelConfig(PretrainedConfig):
    """Config for the Sentinel model; fields and defaults must match the training cell.

    Every kwarg not listed here is forwarded to ``PretrainedConfig``.
    """
    model_type = "sentinel_fusion"

    def __init__(
        self,
        base_model_name: str = "xlm-roberta-base",
        num_labels: int = 2,
        heuristic_dim: int = 32,         # <--- MUST match training
        heuristic_hidden: int = 256,
        causal_hidden: int = 256,
        attn_heads: int = 8,
        aux_causal_loss_weight: float = 0.2,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Backbone and label space.
        self.base_model_name, self.num_labels = base_model_name, num_labels
        # Fusion-branch sizes.
        self.heuristic_dim, self.heuristic_hidden = heuristic_dim, heuristic_hidden
        self.causal_hidden, self.attn_heads = causal_hidden, attn_heads
        # Loss weighting (only relevant during training).
        self.aux_causal_loss_weight = aux_causal_loss_weight

# -------------------------
# Sentinel model (same as training-time)
# -------------------------
class SentinelModel(PreTrainedModel):
    """Transformer encoder fused with hand-crafted heuristic features.

    The vector at sequence position 0 (CLS) queries a 2-token key/value set made
    of (a) a projection of ``heuristic_feats`` and (b) an MLP transform of that
    same CLS vector. The attended result is classified after dropout. The layer
    names and shapes must mirror the training-time definition so the saved
    weights line up when loading.
    """
    config_class = SentinelConfig
    def __init__(self, config: SentinelConfig):
        super().__init__(config)
        # The encoder architecture and pretrained weights are resolved from config.base_model_name.
        self.base_cfg = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_pretrained(config.base_model_name, config=self.base_cfg)
        hidden = self.base_cfg.hidden_size

        # Heuristic projector
        self.heuristic_proj = nn.Sequential(
            nn.Linear(config.heuristic_dim, config.heuristic_hidden),
            nn.ReLU(),
            nn.Linear(config.heuristic_hidden, hidden),
            nn.LayerNorm(hidden)
        )

        # Causal pathway (from CLS)
        self.causal_mlp = nn.Sequential(
            nn.Linear(hidden, config.causal_hidden),
            nn.ReLU(),
            nn.Linear(config.causal_hidden, hidden),
            nn.LayerNorm(hidden)
        )
        self.causal_aux_head = nn.Linear(hidden, 2)  # (unused in eval)

        # Cross-attention: Q = CLS; K,V = [heuristic, causal]
        # nn.MultiheadAttention requires hidden to be divisible by attn_heads.
        self.xattn = nn.MultiheadAttention(embed_dim=hidden, num_heads=config.attn_heads, batch_first=True)

        self.dropout = nn.Dropout(getattr(self.base_cfg, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden, config.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        heuristic_feats=None,
        labels=None,
        causal_target: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Return ``{"loss": ..., "logits": ...}``.

        ``loss`` is cross-entropy against ``labels`` when they are given, else None.
        ``heuristic_feats`` is required: it is projected unconditionally (presumably
        a float tensor of shape [B, heuristic_dim] — confirm against the collator).
        ``causal_target`` and extra kwargs are accepted but not used here.
        """
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        last_hidden = enc.last_hidden_state              # [B, L, H]
        h_cls = last_hidden[:, 0, :]                     # [B, H]
        h_heu = self.heuristic_proj(heuristic_feats)     # [B, H]
        h_cau = self.causal_mlp(h_cls)                   # [B, H]

        Q = h_cls.unsqueeze(1)                           # [B, 1, H]
        KV = torch.stack([h_heu, h_cau], dim=1)          # [B, 2, H]
        fused, _ = self.xattn(Q, KV, KV)                 # [B, 1, H]
        fused = fused.squeeze(1)
        logits = self.classifier(self.dropout(fused))

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}

# -------------------------
# Heuristic features (must match training logic & HEURISTIC_DIM)
# -------------------------
SLUR_REGEXES = [
    r"\bmongol(s)?\b", r"\bretard(s|ed)?\b", r"\btolol\b", r"\bkontol\b",
    r"\bbajingan\b", r"\bbabi\b", r"\bbhen ?chod\b", r"\bmadar ?chod\b",
    r"\brandi\b", r"\bperra\b", r"\bzorra\b", r"\bputa\b"
]
SLUR_PATTERNS = [re.compile(p, re.IGNORECASE) for p in SLUR_REGEXES]

HEURISTIC_DIM = 32  # <--- set to the same value you trained with

# Cue substrings counted in the lowercased text. They are plain substrings, so e.g.
# "die" also matches inside "soldier".
_CUE_WORDS = ("kill", "die", "trash", "dirty", "dog", "pig", "scum", "hate")


def build_heuristic_features(text: str, dim: Optional[int] = None) -> np.ndarray:
    """Build Sentinel's hand-crafted feature vector for one text.

    Twelve surface statistics are computed:
    - log1p(char length) and log1p(word count);
    - uppercase, digit and punctuation counts and their ratios;
    - slur-regex hit count and per-word density;
    - cue-substring count and per-word density.
    The result is then zero-padded or truncated to ``dim`` entries.

    Args:
        text: Input text. ``None`` becomes "" and other non-strings are ``str()``-ed.
        dim: Output width. Defaults to the module-level ``HEURISTIC_DIM`` read at
            call time (``_load_data`` may sync it to the checkpoint config), so
            existing single-argument calls behave exactly as before.

    Returns:
        A float32 array of shape ``(dim,)``.
    """
    if dim is None:
        dim = HEURISTIC_DIM
    if not isinstance(text, str):
        text = "" if text is None else str(text)

    length = len(text)
    n_words = max(1, len(text.split()))
    char_count = max(1, length)

    upper = sum(1 for c in text if c.isalpha() and c.isupper())
    digits = sum(1 for c in text if c.isdigit())
    punct = sum(1 for c in text if c in ".,;:!?")
    alpha_count = max(1, sum(c.isalpha() for c in text))

    # SLUR_PATTERNS is a module global looked up at call time.
    slur_hits = sum(len(pat.findall(text)) for pat in SLUR_PATTERNS)
    lowered = text.lower()  # hoisted: previously re-lowered once per cue word
    cues = sum(lowered.count(k) for k in _CUE_WORDS)

    base_feats = np.array([
        math.log1p(length), math.log1p(n_words), upper, digits, punct,
        upper / alpha_count, digits / char_count, punct / char_count,
        slur_hits, slur_hits / n_words, cues, cues / n_words,
    ], dtype=np.float32)

    if base_feats.shape[0] >= dim:
        return base_feats[:dim]
    pad = np.zeros(dim - base_feats.shape[0], dtype=np.float32)
    return np.concatenate([base_feats, pad])

# -------------------------
# Data collator (stacks heuristic feats)
# -------------------------
@dataclass
class SentinelCollator:
    """Turn a list of Sentinel examples into one batch of stacked tensors.

    ``input_ids``, ``attention_mask`` and ``labels`` must already be tensors (the
    datasets use torch format) and are stacked unchanged. ``heuristic_feats`` may
    be tensors or array-likes; non-tensors are converted to float32, and the
    stacked result is always cast to float32.
    """

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        def _to_tensor(value):
            # Existing tensors pass through; the final .to() below fixes the dtype.
            if isinstance(value, torch.Tensor):
                return value
            return torch.tensor(value, dtype=torch.float32)

        stacked = {
            key: torch.stack([example[key] for example in batch])
            for key in ("input_ids", "attention_mask", "labels")
        }
        heuristic = [_to_tensor(example["heuristic_feats"]) for example in batch]
        stacked["heuristic_feats"] = torch.stack(heuristic).to(torch.float32)
        return stacked


collator = SentinelCollator()

# -------------------------
# Load data (raw + tokenized with heuristic feats)
# -------------------------
def _load_data():
    """Load the CSV splits and tokenize them for Sentinel evaluation.

    ``TEST_FILE`` is required; ``TRAIN_FILE`` / ``VAL_FILE`` are included only when
    they exist. The tokenizer, label count and heuristic width come from the saved
    Sentinel config when it can be read. Otherwise a warning is printed and
    xlm-roberta-base, 2 labels and the current ``HEURISTIC_DIM`` are used.

    Side effect:
        On a successful config read, overwrites the module-level ``HEURISTIC_DIM``
        with the checkpoint's ``heuristic_dim`` so the features built here match
        the trained projector.

    Returns:
        ``(raw, tokenized, tokenizer, base_model_name, num_labels)``. ``tokenized``
        is torch-formatted with ``input_ids``, ``attention_mask``, ``labels`` and
        ``heuristic_feats``. ``label`` is renamed to ``labels`` when the test split
        has it.

    Raises:
        FileNotFoundError: if ``TEST_FILE`` does not exist.
    """
    global HEURISTIC_DIM

    if not os.path.isfile(TEST_FILE):
        raise FileNotFoundError(f"Test CSV not found: {TEST_FILE}")
    data_files = {}
    if os.path.isfile(TRAIN_FILE):
        data_files['train'] = TRAIN_FILE
    if os.path.isfile(VAL_FILE):
        data_files['validation'] = VAL_FILE
    data_files['test'] = TEST_FILE
    raw = load_dataset("csv", data_files=data_files)

    # Prefer the settings saved with the checkpoint so the tokenizer and feature width
    # match training. The defaults are a best-effort fallback, but the reason is reported.
    try:
        cfg_on_disk = SentinelConfig.from_pretrained(SENTINEL_CHECKPOINT)
    except Exception as err:
        print(
            f"[warn] Could not read Sentinel config from {SENTINEL_CHECKPOINT!r} "
            f"({type(err).__name__}: {err}); falling back to xlm-roberta-base, "
            f"num_labels=2, HEURISTIC_DIM={HEURISTIC_DIM}."
        )
        base_model = "xlm-roberta-base"
        num_labels = 2
    else:
        base_model = cfg_on_disk.base_model_name
        num_labels = cfg_on_disk.num_labels
        HEURISTIC_DIM = cfg_on_disk.heuristic_dim  # sync to saved config

    tok = AutoTokenizer.from_pretrained(base_model)

    def tok_map(batch):
        # Same padding/truncation as evaluate_on_texts so clean and attacked inputs align.
        enc = tok(batch["text"], padding="max_length", truncation=True, max_length=128)
        enc["heuristic_feats"] = [build_heuristic_features(t) for t in batch["text"]]
        return enc

    tokenized = raw.map(tok_map, batched=True)
    if "label" in tokenized["test"].column_names:
        tokenized = tokenized.rename_column("label", "labels")
    tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels", "heuristic_feats"])
    return raw, tokenized, tok, base_model, num_labels


raw_datasets, tokenized_datasets, tokenizer, BASE_MODEL_NAME, NUM_LABELS = _load_data()

# -------------------------
# Load Sentinel model from checkpoint
# -------------------------
def load_sentinel_for_eval(checkpoint_dir: str):
    """Restore a saved Sentinel checkpoint onto ``device``.

    Returns:
        ``(model, config)``: the ``SentinelModel`` built from ``checkpoint_dir`` with
        its saved weights, and the ``SentinelConfig`` read from the same folder.
    """
    sentinel_cfg = SentinelConfig.from_pretrained(checkpoint_dir)
    # Passing the config explicitly ensures the architecture matches what was saved.
    sentinel = SentinelModel.from_pretrained(checkpoint_dir, config=sentinel_cfg)
    sentinel.to(device)
    return sentinel, sentinel_cfg

# Load the trained Sentinel once; the attack runner below reuses this model via `trainer`.
model, cfg_loaded = load_sentinel_for_eval(SENTINEL_CHECKPOINT)

# Prediction-only arguments: no training, no logging, no external reporters.
eval_args = TrainingArguments(
    output_dir="./tmp_eval_sentinel",
    per_device_eval_batch_size=32,
    do_train=False, do_eval=True,
    report_to="none",
    logging_strategy="no",
    disable_tqdm=True
)

# SentinelCollator stacks heuristic_feats into a float32 tensor next to the token inputs.
# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour of
# `processing_class=` — confirm against the installed version.
trainer = Trainer(
    model=model,
    args=eval_args,
    data_collator=collator,
    tokenizer=tokenizer,  # fine for eval
    compute_metrics=None
)

# -------------------------
# Attack tools (structural / semantic / feature-targeted)
# -------------------------
structural_typo   = nac.KeyboardAug()
structural_insert = nac.RandomCharAug(action="insert")


def _first_augmented(augmented, original):
    """Reduce an nlpaug ``augment()`` result to a single string.

    Recent nlpaug releases return a list of strings, while older ones return a bare
    string. Blindly taking ``[0]`` of a bare string keeps only its first character.
    Returns ``original`` when no usable string was produced (e.g. an empty list).
    """
    if isinstance(augmented, str):
        return augmented
    if isinstance(augmented, (list, tuple)) and augmented and isinstance(augmented[0], str):
        return augmented[0]
    return original


def attack_structural_typo(text):
    """Inject keyboard-adjacent typos; return ``text`` unchanged if augmentation fails."""
    try:
        return _first_augmented(structural_typo.augment(text), text)
    except Exception:
        # Best-effort attack: one failing sample must not abort the whole suite.
        return text


def attack_structural_insert(text):
    """Insert random characters; return ``text`` unchanged if augmentation fails."""
    try:
        return _first_augmented(structural_insert.augment(text), text)
    except Exception:
        return text

def attack_structural_case(text, p_char=0.15):
    """Randomly flip letter case as a structural perturbation.

    Each alphabetic character is flipped with probability ``p_char``: lowercase
    becomes uppercase and anything else becomes lowercase. Other characters are
    copied through. It takes one draw from the global ``random`` state per letter.
    Non-string input is returned as-is.
    """
    if not isinstance(text, str):
        return text

    def _maybe_flip(ch):
        if ch.isalpha() and random.random() < p_char:
            return ch.upper() if ch.islower() else ch.lower()
        return ch

    return "".join(map(_maybe_flip, text))

synonym_aug = SynonymAug(aug_src="wordnet", aug_p=0.15)


def _wordnet_swap_one(text):
    """Replace one randomly chosen word with a different WordNet lemma.

    Words are tried in random order. For the first word whose top-3 synsets offer
    a lemma that differs from it (case-insensitively), one such lemma is chosen at
    random and the tokens are re-joined with single spaces. If no word can be
    swapped, ``text`` is returned untouched, so mere whitespace normalisation is
    not mistaken for an attack.
    """
    toks = text.split()
    order = list(range(len(toks)))
    random.shuffle(order)
    for i in order:
        word = toks[i]
        synsets = wn.synsets(word)
        if not synsets:
            continue
        lemmas = []
        for syn in synsets[:3]:
            for lemma in syn.lemmas():
                cand = lemma.name().replace('_', ' ')
                if cand.lower() != word.lower():
                    lemmas.append(cand)
        if lemmas:
            toks[i] = random.choice(lemmas)
            return " ".join(toks)
    return text


def attack_semantic_synonym(text):
    """Swap words for WordNet synonyms.

    nlpaug's ``SynonymAug`` is tried first. If it raises or yields no usable string
    (e.g. an empty list, which previously leaked out as a non-str), a manual
    single-word WordNet swap is used instead. Non-string input is returned
    unchanged.
    """
    if not isinstance(text, str):
        return text
    try:
        augmented = synonym_aug.augment(text)
    except Exception:
        # NOTE(review): only the 'wordnet' NLTK corpus is downloaded in this cell.
        # If SynonymAug needs more NLTK data (e.g. a POS tagger), every call lands
        # here and uses the manual fallback — confirm.
        augmented = None
    if isinstance(augmented, str):
        return augmented
    if isinstance(augmented, (list, tuple)) and augmented and isinstance(augmented[0], str):
        return augmented[0]
    return _wordnet_swap_one(text)

# Group mention (regex) -> coded-language substitute.
CODED_LEXICON = {
    r'\bmuslims?\b': 'skittles',
    r'\bblack( people)?\b': 'googles',
    r'\bjews?\b': 'skypes',
    r'\bmexicans?\b': 'bings',
}
CODED_PATTERNS = [
    (re.compile(pattern, flags=re.IGNORECASE), code_word)
    for pattern, code_word in CODED_LEXICON.items()
]


def attack_semantic_coded(text):
    """Replace group mentions with their coded-language substitutes.

    Patterns are applied case-insensitively, in ``CODED_PATTERNS`` order, each to
    the output of the previous one. Non-string input is returned unchanged.
    """
    if isinstance(text, str):
        for regex, code_word in CODED_PATTERNS:
            text = regex.sub(code_word, text)
    return text

SLUR_LEXICON = [
    r"\bmongol(s)?\b", r"\bretard(s|ed)?\b", r"\btolol\b", r"\bkontol\b",
    r"\bbajingan\b", r"\bbabi\b", r"\bbhen ?chod\b", r"\bmadar ?chod\b",
    r"\brandi\b", r"\bperra\b", r"\bzorra\b", r"\bputa\b"
]
# Kept under its own name. Rebinding SLUR_PATTERNS here would also swap the patterns
# that build_heuristic_features reads at call time, so any drift between the two
# lists would silently change the model's heuristic features mid-notebook.
_SLUR_REMOVAL_PATTERNS = [re.compile(p, flags=re.IGNORECASE) for p in SLUR_LEXICON]


def attack_feature_slur_removal(text, replacement='[removed]'):
    """Replace every lexicon slur in ``text`` with ``replacement``.

    This targets the slur-count inputs of the heuristic branch. ``replacement`` is
    inserted literally: backslashes are not treated as regex group references.
    Non-string input is returned unchanged.
    """
    if not isinstance(text, str):
        return text
    out = text
    for pat in _SLUR_REMOVAL_PATTERNS:
        out = pat.sub(lambda _match: replacement, out)
    return out

def apply_attack_texts(texts, attack_fn, desc="Attacking"):
    """Apply ``attack_fn`` to each text in order (with a transient progress bar) and return a list."""
    progress = tqdm(texts, desc=desc, leave=False)
    return list(map(attack_fn, progress))

# -------------------------
# Per-class metrics helper
# -------------------------
def _per_class_metrics(y_true, y_pred, labels=(0, 1)):
    """Return per-label precision / recall / F1 / support as JSON-ready dicts.

    One dict per entry of ``labels``, in that order. Values with an undefined
    denominator are reported as 0 (``zero_division=0``).
    """
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=list(labels), zero_division=0
    )
    return [
        {"label": int(lbl), "precision": float(p), "recall": float(r), "f1": float(f), "support": int(s)}
        for lbl, p, r, f, s in zip(labels, precision, recall, f1, support)
    ]

# -------------------------
# Eval helpers
# -------------------------
def evaluate_on_texts(trainer, tokenizer, texts, labels, max_length=128):
    """Score arbitrary (typically attacked) texts with the given trainer.

    The texts are re-tokenized and their heuristic features rebuilt, so a
    perturbation reaches both the encoder input and the heuristic branch.

    Returns:
        ``(preds, logits)``: argmax class ids and the raw logits from ``trainer.predict``.
    """
    encoded = tokenizer(texts, padding='max_length', truncation=True, max_length=max_length)
    columns = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "heuristic_feats": [build_heuristic_features(t) for t in texts],
        "labels": labels,
    }
    dataset = Dataset.from_dict(columns)
    dataset.set_format(type="torch", columns=["input_ids","attention_mask","labels","heuristic_feats"])
    logits = trainer.predict(dataset).predictions
    return np.argmax(logits, axis=-1), logits

# -------------------------
# Attack runner (Sentinel only)
# -------------------------
def _save_attack_report(report, save_dir, attack_name, attack_mode):
    """Write ``report`` as indented JSON to ``<save_dir>/<attack_name>_<attack_mode>.json``."""
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, f"{attack_name}_{attack_mode}.json"), "w") as f:
        json.dump(report, f, indent=2)


def _skipped_attack_report(attack_name, attack_mode, clean_per_class, n_targeted, note):
    """Build the report for an attack that left nothing to re-evaluate."""
    return {
        "attack_name": attack_name, "attack_mode": attack_mode,
        "clean_per_class": clean_per_class,
        "attacked_per_class": None,
        "attack_changed": 0,
        "tp_count_targeted": n_targeted,
        "samples_flipped": None,
        "attack_success_rate": None,
        "note": note,
    }


def run_attack_and_report(
    trainer,
    tokenizer,
    raw_test_dataset,
    tokenized_test_dataset,
    attack_fn,
    attack_name,
    attack_mode='tp',           # 'tp' (True positives only) or 'full'
    save_dir="attack_reports/sentinel"
):
    """Attack the test set with ``attack_fn`` and save a JSON robustness report.

    Steps:
    1. Predict on the clean, pre-tokenized test set.
    2. Pick targets: true positives for ``'tp'``, every row for ``'full'``.
    3. Rewrite the targets with ``attack_fn``.
    4. Re-score the whole set and compare per-class metrics. Untargeted rows keep
       their clean text. Only string outputs that differ from the original count
       as attacked.

    In ``'tp'`` mode, ``attack_success_rate`` is the fraction of targeted true
    positives predicted as label 0 after the attack. Its denominator includes
    targets the attack left unchanged.

    Returns:
        The report dict, also written to ``<save_dir>/<attack_name>_<attack_mode>.json``.

    Raises:
        ValueError: on an unknown ``attack_mode`` or if the clean predictions do not
            line up with the raw test rows.
    """
    # Validate before the (expensive) clean prediction pass.
    if attack_mode not in ('tp', 'full'):
        raise ValueError("attack_mode must be 'tp' or 'full'")

    # Clean predictions (fast) on pre-tokenized test set
    clean_out = trainer.predict(tokenized_test_dataset)
    clean_preds = np.argmax(clean_out.predictions, axis=-1)
    clean_labels = clean_out.label_ids
    clean_per_class = _per_class_metrics(clean_labels, clean_preds)

    raw_texts = list(raw_test_dataset['text'])
    raw_labels = list(raw_test_dataset['label'])
    n = len(raw_texts)
    if not (n == len(clean_preds) == len(clean_labels)):
        raise ValueError(
            f"Clean predictions ({len(clean_preds)}) / labels ({len(clean_labels)}) "
            f"do not match the {n} raw test rows."
        )

    if attack_mode == 'tp':
        target_indices = [
            i for i, (lab, pred) in enumerate(zip(clean_labels, clean_preds))
            if lab == 1 and pred == 1
        ]
    else:
        target_indices = list(range(n))

    if not target_indices:
        report = _skipped_attack_report(attack_name, attack_mode, clean_per_class, 0, "no targets")
        _save_attack_report(report, save_dir, attack_name, attack_mode)
        return report

    # Apply attack to targeted subset
    attacked_texts = raw_texts.copy()
    attacked_outs = apply_attack_texts([raw_texts[i] for i in target_indices], attack_fn, desc=attack_name)
    attack_count = 0
    for idx, new_text in zip(target_indices, attacked_outs):
        # Accept only genuine string rewrites. A non-str result (e.g. an empty list
        # from an augmenter) would otherwise reach the tokenizer and crash the run.
        if isinstance(new_text, str) and new_text != raw_texts[idx]:
            attacked_texts[idx] = new_text
            attack_count += 1

    if attack_count == 0:
        report = _skipped_attack_report(
            attack_name, attack_mode, clean_per_class, len(target_indices),
            "no modification made by attack_fn",
        )
        _save_attack_report(report, save_dir, attack_name, attack_mode)
        return report

    # Evaluate attacked set (re-tokenize + rebuild heuristic feats)
    attacked_preds, _attacked_logits = evaluate_on_texts(trainer, tokenizer, attacked_texts, raw_labels)
    attacked_per_class = _per_class_metrics(raw_labels, attacked_preds)

    asr = None
    samples_flipped = None
    if attack_mode == 'tp':
        samples_flipped = sum(1 for i in target_indices if attacked_preds[i] == 0)
        asr = samples_flipped / len(target_indices)

    report = {
        "attack_name": attack_name,
        "attack_mode": attack_mode,
        "clean_per_class": clean_per_class,
        "attacked_per_class": attacked_per_class,
        "attack_changed": attack_count,
        "tp_count_targeted": len(target_indices),
        "samples_flipped": int(samples_flipped) if samples_flipped is not None else None,
        "attack_success_rate": float(asr) if asr is not None else None
    }
    _save_attack_report(report, save_dir, attack_name, attack_mode)
    return report

def run_all_attacks_sentinel(
    trainer,
    tokenizer,
    raw_test_dataset,
    tokenized_test_dataset,
    save_dir="attack_reports/sentinel",
    attack_mode='tp',
    attacks_to_run=None
):
    """Run a suite of attacks against Sentinel and save a combined report.

    Args:
        attacks_to_run: Optional list of ``(attack_fn, name)`` pairs. ``None`` runs
            the default structural / semantic / feature-targeted suite. An
            explicitly empty list runs no attacks; it is not treated as "use the
            defaults".
        The other arguments are forwarded to ``run_attack_and_report``.

    Returns:
        A dict with the model name, checkpoint, mode and per-attack reports. It is
        also written to ``<save_dir>/suite_report_sentinel_<attack_mode>.json``.
    """
    os.makedirs(save_dir, exist_ok=True)
    if attacks_to_run is None:
        attacks = [
            (attack_structural_typo,          "structural_typo"),
            (attack_structural_insert,        "structural_insert"),
            (attack_structural_case,          "structural_case"),
            (attack_semantic_synonym,         "semantic_synonym"),
            (attack_semantic_coded,           "semantic_coded"),
            (lambda t: attack_feature_slur_removal(t, "[removed]"), "feature_slur_removal"),
        ]
    else:
        attacks = attacks_to_run
    suite_report = {
        "model": "SentinelModel",
        "checkpoint": SENTINEL_CHECKPOINT,
        "attack_mode": attack_mode,
        "attacks": []
    }
    for fn, name in attacks:
        print(f"Running attack: {name} (mode={attack_mode})")
        rpt = run_attack_and_report(trainer, tokenizer, raw_test_dataset, tokenized_test_dataset, fn, name, attack_mode, save_dir)
        if rpt is not None:
            suite_report["attacks"].append(rpt)
    with open(os.path.join(save_dir, f"suite_report_sentinel_{attack_mode}.json"), "w") as f:
        json.dump(suite_report, f, indent=2)
    return suite_report

# -------------------------
# RUN (Sentinel only)
# -------------------------
print("Sentinel eval trainer ready. Running attacks...")
# Writes one JSON per attack plus a suite_report_sentinel_<mode>.json under save_dir.
sentinel_report = run_all_attacks_sentinel(
    trainer=trainer,
    tokenizer=tokenizer,
    raw_test_dataset=raw_datasets["test"],
    tokenized_test_dataset=tokenized_datasets["test"],
    save_dir="attack_reports/sentinel",
    attack_mode="tp"   # change to "full" to attack all test texts
)
print("✓ Sentinel attacks finished. Reports -> attack_reports/sentinel")


In [None]:
#from google.colab import drive
#drive.mount('/content/drive', force_remount=False)
# (Uncomment the two lines above if Drive is not mounted yet.)

import os, shutil

# Copy the locally written attack reports into the Drive project folder.
# NOTE(review): the attack cell writes to the relative path "attack_reports/sentinel";
# this absolute src only matches when the notebook's working directory is /content.
src = "/content/attack_reports"                     # where your reports live in Colab
dst = "/content/drive/MyDrive/hate/attack_reports"  # destination in Drive

if not os.path.exists(src):
    raise FileNotFoundError(f"Source folder not found: {src}")

os.makedirs(os.path.dirname(dst), exist_ok=True)  # ensure the parent exists; copytree creates dst itself
shutil.copytree(src, dst, dirs_exist_ok=True)  # dirs_exist_ok needs Python 3.8+

print(f"✓ Copied '{src}' → '{dst}'")


In [None]:
# === 3-Model ROC & PR Curves (600 dpi, saves PNG/PDF) ===
# Paths
# Only TEST_FILE is read in this cell; TRAIN_FILE / VAL_FILE are kept for parity with the other cells.
DRIVE_BASE           = "/content/drive/MyDrive/hate"
TRAIN_FILE           = f"{DRIVE_BASE}/train.csv"
VAL_FILE             = f"{DRIVE_BASE}/val.csv"
TEST_FILE            = f"{DRIVE_BASE}/test.csv"

# Checkpoints scored below: baseline and gated via predict_probs_baseline_or_gated,
# Sentinel via predict_probs_sentinel.
BASELINE_CKPT        = f"{DRIVE_BASE}/models/step1_bert_baseline"
GATED_CKPT           = f"{DRIVE_BASE}/models/step1_gated_fusion"
SENTINEL_CKPT        = f"{DRIVE_BASE}/models/sentinel_xlmr"

# Output folder for the ROC / PR figures.
FIG_DIR              = f"{DRIVE_BASE}/fig"

# Deps
!pip -q install transformers datasets scikit-learn matplotlib

import os, re, math, json, numpy as np, torch, torch.nn as nn
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
import matplotlib.pyplot as plt

from datasets import load_dataset, Dataset
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, average_precision_score

from transformers import (
    AutoTokenizer, AutoModel, AutoConfig,
    AutoModelForSequenceClassification,
    PretrainedConfig, PreTrainedModel,
    TrainingArguments, Trainer
)

# GPU when available; both scoring helpers move their model onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs(FIG_DIR, exist_ok=True)  # create the figure folder before any savefig call

# -------------------------
# Load test data
# -------------------------
# Explicit checks instead of asserts: asserts vanish under `python -O` and give terse errors.
if not os.path.isfile(TEST_FILE):
    raise FileNotFoundError(f"TEST_FILE not found: {TEST_FILE}")
raw = load_dataset("csv", data_files={"test": TEST_FILE})
_missing_cols = {"text", "label"} - set(raw["test"].column_names)
if _missing_cols:
    raise ValueError(f"CSV must have 'text' and 'label' columns; missing {sorted(_missing_cols)}.")

# Empty CSV cells may load as None, which the tokenizers cannot encode. Treat them as ""
# (matching how the Sentinel heuristic builder handles None).
texts  = ["" if t is None else t for t in raw["test"]["text"]]
y_true = np.array(list(raw["test"]["label"]), dtype=int)

# -------------------------
# Sentinel model/types (must match training)
# -------------------------
class SentinelConfig(PretrainedConfig):
    """Hugging Face config for the Sentinel fusion classifier (plotting-cell copy).

    Duplicates the definition in the attack-evaluation cell; keep the two in sync.
    It is re-read from the checkpoint folder, so the values must match training.
    """
    model_type = "sentinel_fusion"
    def __init__(
        self,
        base_model_name: str = "xlm-roberta-base",   # encoder id loaded via AutoModel
        num_labels: int = 2,                         # output width of the classifier
        heuristic_dim: int = 32,                     # width of the heuristic feature vector
        heuristic_hidden: int = 256,                 # hidden width of the heuristic projector
        causal_hidden: int = 256,                    # hidden width of the causal MLP
        attn_heads: int = 8,                         # heads of the cross-attention fusion
        aux_causal_loss_weight: float = 0.2,         # stored only; the forward below never reads it
        **kwargs
    ):
        # Any remaining standard config kwargs are handled by PretrainedConfig.
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_labels = num_labels
        self.heuristic_dim = heuristic_dim
        self.heuristic_hidden = heuristic_hidden
        self.causal_hidden = causal_hidden
        self.attn_heads = attn_heads
        self.aux_causal_loss_weight = aux_causal_loss_weight

class SentinelModel(PreTrainedModel):
    """Encoder + heuristic-feature fusion classifier (plotting-cell copy).

    The CLS vector (position 0) cross-attends over [projected heuristic features,
    MLP(CLS)], and the attended vector is classified after dropout. Module names
    and shapes must match the training definition so saved weights line up.
    """
    config_class = SentinelConfig
    def __init__(self, config: SentinelConfig):
        super().__init__(config)
        self.base_cfg = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_pretrained(config.base_model_name, config=self.base_cfg)
        hidden = self.base_cfg.hidden_size

        # Heuristic features -> hidden size.
        self.heuristic_proj = nn.Sequential(
            nn.Linear(config.heuristic_dim, config.heuristic_hidden),
            nn.ReLU(),
            nn.Linear(config.heuristic_hidden, hidden),
            nn.LayerNorm(hidden)
        )
        # "Causal" pathway computed from the CLS vector.
        self.causal_mlp = nn.Sequential(
            nn.Linear(hidden, config.causal_hidden),
            nn.ReLU(),
            nn.Linear(config.causal_hidden, hidden),
            nn.LayerNorm(hidden)
        )
        self.causal_aux_head = nn.Linear(hidden, 2)  # present so checkpoint weights load; not used in forward
        self.xattn = nn.MultiheadAttention(embed_dim=hidden, num_heads=config.attn_heads, batch_first=True)
        self.dropout = nn.Dropout(getattr(self.base_cfg, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden, config.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        heuristic_feats=None,
        labels=None,
        **kwargs
    ):
        """Return ``{"loss", "logits"}``; loss is cross-entropy when ``labels`` is given.

        ``heuristic_feats`` is required (it is projected unconditionally). Extra
        kwargs are accepted and ignored.
        """
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        last_hidden = enc.last_hidden_state
        h_cls = last_hidden[:, 0, :]
        h_heu = self.heuristic_proj(heuristic_feats)
        h_cau = self.causal_mlp(h_cls)
        # Single query (CLS) over two keys/values: heuristic and causal vectors.
        Q  = h_cls.unsqueeze(1)
        KV = torch.stack([h_heu, h_cau], dim=1)
        fused, _ = self.xattn(Q, KV, KV)
        fused = fused.squeeze(1)
        logits = self.classifier(self.dropout(fused))
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}

# -------------------------
# Gated fusion model/types (from earlier)
# -------------------------
class GFConfig(PretrainedConfig):
    """Config for the gated-fusion classifier: encoder id, label count and gate MLP width."""
    model_type = "gated_fusion_wrapper"
    def __init__(self, base_model_name: str = "bert-base-multilingual-cased", num_labels: int = 2, gate_hidden: int = 256, **kwargs):
        # Any remaining standard config kwargs are handled by PretrainedConfig.
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_labels = num_labels
        self.gate_hidden = gate_hidden

class GatedFusionForSequenceClassification(PreTrainedModel):
    """Encoder whose CLS vector and masked-mean token vector are mixed by a learned gate.

    A sigmoid gate computed from ``[h_cls; h_mean]`` weights the two vectors per
    hidden unit (``g * h_cls + (1 - g) * h_mean``), followed by dropout and a
    linear classifier. Must mirror the training-time definition so checkpoint
    weights line up.
    """
    config_class = GFConfig
    def __init__(self, config: GFConfig):
        super().__init__(config)
        self.base_cfg = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_pretrained(config.base_model_name, config=self.base_cfg)
        hidden = self.base_cfg.hidden_size
        # Gate: concat(CLS, mean) -> per-unit weight in (0, 1) via Sigmoid.
        self.gate_mlp = nn.Sequential(
            nn.Linear(2 * hidden, config.gate_hidden),
            nn.ReLU(),
            nn.Linear(config.gate_hidden, hidden),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(getattr(self.base_cfg, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden, config.num_labels)
        self.post_init()

    @staticmethod
    def masked_mean(last_hidden_state, attention_mask):
        """Average token vectors over positions where ``attention_mask`` is nonzero.

        ``attention_mask`` is required (it is unsqueezed unconditionally). The clamp
        keeps the division finite for rows whose mask is all zeros.
        """
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
        summed = (last_hidden_state * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-6)
        return summed / denom

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        """Return ``{"loss", "logits"}``; loss is cross-entropy when ``labels`` is given.

        Only whitelisted extra kwargs are forwarded to the encoder; anything else
        (e.g. ``num_items_in_batch``) is dropped.
        """
        allowed = {"position_ids","head_mask","inputs_embeds","output_attentions","output_hidden_states","return_dict","past_key_values","encoder_hidden_states","encoder_attention_mask"}
        safe_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
        # Redundant: num_items_in_batch is not in `allowed`, so this pop never removes anything.
        safe_kwargs.pop("num_items_in_batch", None)

        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, **safe_kwargs)
        h_cls = enc.last_hidden_state[:, 0, :]
        h_mean = self.masked_mean(enc.last_hidden_state, attention_mask)
        gate_inp = torch.cat([h_cls, h_mean], dim=-1)
        g = self.gate_mlp(gate_inp)
        fused = g * h_cls + (1.0 - g) * h_mean
        logits = self.classifier(self.dropout(fused))
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}

# -------------------------
# Heuristic features (Sentinel) — keep in sync with training
# -------------------------
_SLUR_REGEXES = [
    r"\bmongol(s)?\b", r"\bretard(s|ed)?\b", r"\btolol\b", r"\bkontol\b",
    r"\bbajingan\b", r"\bbabi\b", r"\bbhen ?chod\b", r"\bmadar ?chod\b",
    r"\brandi\b", r"\bperra\b", r"\bzorra\b", r"\bputa\b"
]
_SLUR_PATTERNS = [re.compile(p, re.IGNORECASE) for p in _SLUR_REGEXES]


def _build_heuristic_features(text: str, H: int) -> np.ndarray:
    """Return the Sentinel heuristic feature vector for ``text`` with ``H`` entries.

    Twelve statistics are computed, in this order:
    - log1p(chars), log1p(words);
    - uppercase, digit and punctuation counts;
    - uppercase/alpha, digit/char and punct/char ratios;
    - slur hits and hits per word;
    - cue hits and cues per word.
    They are zero-padded (float32) up to ``H`` or truncated to it. ``None``
    becomes "" and other non-strings are ``str()``-ed.
    """
    if not isinstance(text, str):
        text = "" if text is None else str(text)

    n_chars = len(text)
    n_words = max(1, len(text.split()))
    char_denom = max(1, n_chars)

    n_upper = sum(ch.isalpha() and ch.isupper() for ch in text)
    n_digits = sum(ch.isdigit() for ch in text)
    n_punct = sum(ch in ".,;:!?" for ch in text)
    n_alpha = max(1, sum(ch.isalpha() for ch in text))

    slur_hits = sum(len(rx.findall(text)) for rx in _SLUR_PATTERNS)
    lowered = text.lower()
    cue_hits = sum(lowered.count(cue) for cue in ["kill", "die", "trash", "dirty", "dog", "pig", "scum", "hate"])

    base = np.array(
        [
            math.log1p(n_chars), math.log1p(n_words), n_upper, n_digits, n_punct,
            n_upper / n_alpha, n_digits / char_denom, n_punct / char_denom,
            slur_hits, slur_hits / n_words, cue_hits, cue_hits / n_words,
        ],
        dtype=np.float32,
    )
    n_base = base.shape[0]
    if H <= n_base:
        return base[:H]
    return np.pad(base, (0, H - n_base))

@dataclass
class SentinelCollator:
    """Stack Sentinel examples into a batch dict of tensors.

    Token ids, attention masks and labels must already be tensors (torch-formatted
    dataset). Heuristic features may be tensors or array-likes and always come out
    as a stacked float32 tensor.
    """

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        def column(name):
            return [example[name] for example in batch]

        out = {name: torch.stack(column(name)) for name in ("input_ids", "attention_mask", "labels")}
        feats = []
        for value in column("heuristic_feats"):
            if not isinstance(value, torch.Tensor):
                value = torch.tensor(value, dtype=torch.float32)
            feats.append(value)
        out["heuristic_feats"] = torch.stack(feats).to(torch.float32)
        return out

# -------------------------
# Utility: get probabilities for positive class (label=1)
# -------------------------
def _softmax_np(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over axis 1; the row max is subtracted first for numerical stability."""
    exps = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    return exps / np.sum(exps, axis=1, keepdims=True)

def predict_probs_baseline_or_gated(ckpt: str, texts: list) -> np.ndarray:
    """Return P(label=1) per text for the Baseline or Gated-Fusion checkpoint at ``ckpt``.

    The checkpoint is first loaded as a stock ``AutoModelForSequenceClassification``.
    If that raises, it is loaded as ``GatedFusionForSequenceClassification``.
    Labels are taken from the module-level ``y_true``, so ``texts`` must be the
    test texts in the same order. The labels only satisfy the Trainer's dataset
    format.
    """
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    try:
        clf = AutoModelForSequenceClassification.from_pretrained(ckpt)
    except Exception:
        # Fall back to the custom gated head defined in this cell.
        gf_cfg = GFConfig.from_pretrained(ckpt)
        clf = GatedFusionForSequenceClassification.from_pretrained(ckpt, config=gf_cfg)
    clf.to(device).eval()

    encoded = tokenizer(texts, padding="max_length", truncation=True, max_length=128)
    eval_ds = Dataset.from_dict(
        {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": y_true.tolist(),
        }
    )
    eval_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    predictor = Trainer(
        model=clf,
        args=TrainingArguments(
            output_dir="./tmp_eval_generic",
            per_device_eval_batch_size=64,
            report_to="none", logging_strategy="no", disable_tqdm=True
        ),
        compute_metrics=None,
        tokenizer=tokenizer,
    )
    logits = predictor.predict(eval_ds).predictions
    return _softmax_np(logits)[:, 1]


def predict_probs_sentinel(ckpt: str, texts: list) -> np.ndarray:
    """Return P(label=1) per text for the Sentinel checkpoint at ``ckpt``.

    The tokenizer comes from the config's ``base_model_name``. Heuristic features
    are rebuilt with the checkpoint's ``heuristic_dim``. Labels come from the
    module-level ``y_true`` and only satisfy the Trainer's dataset format.
    """
    sentinel_cfg = SentinelConfig.from_pretrained(ckpt)
    tokenizer = AutoTokenizer.from_pretrained(sentinel_cfg.base_model_name)
    sentinel = SentinelModel.from_pretrained(ckpt, config=sentinel_cfg).to(device).eval()

    encoded = tokenizer(texts, padding="max_length", truncation=True, max_length=128)
    eval_ds = Dataset.from_dict(
        {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "heuristic_feats": [_build_heuristic_features(t, sentinel_cfg.heuristic_dim) for t in texts],
            "labels": y_true.tolist(),
        }
    )
    eval_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels", "heuristic_feats"])

    eval_args = TrainingArguments(
        output_dir="./tmp_eval_sentinel",
        per_device_eval_batch_size=64,
        report_to="none", logging_strategy="no", disable_tqdm=True
    )
    predictor = Trainer(
        model=sentinel, args=eval_args, compute_metrics=None,
        data_collator=SentinelCollator(), tokenizer=tokenizer
    )
    logits = predictor.predict(eval_ds).predictions
    return _softmax_np(logits)[:, 1]


# -------------------------
# Get probabilities for each model
# -------------------------
# Score the same test texts with each model; every helper returns P(label=1) per text.
print("Scoring test set with Baseline...")
probs_baseline = predict_probs_baseline_or_gated(BASELINE_CKPT, texts)

print("Scoring test set with Gated Fusion...")
probs_gated = predict_probs_baseline_or_gated(GATED_CKPT, texts)

print("Scoring test set with Sentinel...")
probs_sentinel = predict_probs_sentinel(SENTINEL_CKPT, texts)

# -------------------------
# Compute curves & metrics
# -------------------------
def compute_roc_pr(y_true, y_score):
    """Compute ROC and precision-recall curves plus their summary scores.

    Label 1 is the positive class.

    Returns:
        ``((fpr, tpr, roc_auc), (recall, precision, average_precision))``. The PR
        tuple is ordered recall-first, ready for plotting recall on the x-axis.
    """
    fpr, tpr, _thresholds = roc_curve(y_true, y_score, pos_label=1)
    roc_part = (fpr, tpr, roc_auc_score(y_true, y_score))
    precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=1)
    pr_part = (recall, precision, average_precision_score(y_true, y_score))
    return roc_part, pr_part

# model name -> ((fpr, tpr, roc_auc), (recall, precision, AP)); insertion order sets the legend order.
roc_pr = {}
roc_pr["Baseline"]  = compute_roc_pr(y_true, probs_baseline)
roc_pr["Gated"]     = compute_roc_pr(y_true, probs_gated)
roc_pr["Sentinel"]  = compute_roc_pr(y_true, probs_sentinel)

# -------------------------
# Plot ROC (all 3) — 600 dpi
# -------------------------
def _save_png_and_pdf(fig, png_path, pdf_path):
    """Write ``fig`` to ``png_path`` and ``pdf_path`` at 600 dpi, then close it."""
    for out_path in (png_path, pdf_path):
        fig.savefig(out_path, dpi=600)
    plt.close(fig)


# ROC: one curve per model plus the dashed chance diagonal.
fig, ax = plt.subplots(figsize=(6, 6), dpi=600)
for name, ((fpr, tpr, aucv), _) in roc_pr.items():
    ax.plot(fpr, tpr, label=f"{name} (AUC={aucv:.3f})", linewidth=1.5)
ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1.0)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curves (Test)")
ax.legend(loc="lower right", frameon=False)
fig.tight_layout()
roc_png = os.path.join(FIG_DIR, "roc_three_models.png")
roc_pdf = os.path.join(FIG_DIR, "roc_three_models.pdf")
_save_png_and_pdf(fig, roc_png, roc_pdf)

# Precision–Recall: there is no fixed chance diagonal here, because a no-skill
# classifier sits at the positive-class prevalence.
fig, ax = plt.subplots(figsize=(6, 6), dpi=600)
for name, (_, (rec, prec, ap)) in roc_pr.items():
    ax.plot(rec, prec, label=f"{name} (AP={ap:.3f})", linewidth=1.5)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Precision–Recall Curves (Test)")
ax.legend(loc="lower left", frameon=False)
fig.tight_layout()
pr_png = os.path.join(FIG_DIR, "pr_three_models.png")
pr_pdf = os.path.join(FIG_DIR, "pr_three_models.pdf")
_save_png_and_pdf(fig, pr_png, pr_pdf)

print("Saved figures:")
for saved_path in (roc_png, roc_pdf, pr_png, pr_pdf):
    print(" -", saved_path)


In [None]:
# ================================
# BAR CHARTS: Clean vs Adv + Attack-wise (ASR)
# Saves 600dpi figs to /content/drive/MyDrive/hate/figs
# ================================

import os, json, math, glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# 0) Mount Drive (skip if already mounted)
#from google.colab import drive
#drive.mount('/content/drive')

BASE = "/content/drive/MyDrive/hate"
REPORT_ROOT = os.path.join(BASE, "attack_reports")
FIG_DIR = os.path.join(BASE, "figs")
os.makedirs(FIG_DIR, exist_ok=True)  # figures from this cell are written here

# 1) Discover model report folders (e.g., baseline, gated, sentinel); fail fast if absent
if not os.path.isdir(REPORT_ROOT):
    raise FileNotFoundError(f"Report root not found: {REPORT_ROOT}")

model_dirs = list(filter(os.path.isdir, glob.glob(os.path.join(REPORT_ROOT, "*"))))
if not model_dirs:
    raise RuntimeError(f"No model subfolders under {REPORT_ROOT}")

# Display labels for known report-folder names (matched case-insensitively)
label_map = {
    "baseline": "Baseline",
    "gated": "Gated Fusion",
    "sentinel": "Sentinel",
}


def pretty_model_name(path):
    """Return the display label for a report folder, or the folder name itself if unknown."""
    folder = os.path.basename(path)
    key = folder.lower()
    return label_map[key] if key in label_map else folder

# 2) Parse attack JSONs into a tidy DataFrame
def macro_f1(pc):
    """Return the mean per-class F1 from a list of per-class metric dicts.

    Returns NaN unless ``pc`` is a non-empty list whose first entry is a dict
    containing ``"f1"``; later entries without ``"f1"`` count as 0.0.
    """
    if isinstance(pc, list) and pc and isinstance(pc[0], dict) and "f1" in pc[0]:
        return float(np.mean([float(x.get("f1", 0.0)) for x in pc]))
    return np.nan


rows = []
for mdir in model_dirs:
    model_name = pretty_model_name(mdir)
    # sorted(): glob order is arbitrary; this makes row order (and Figure 2's
    # first-match lookup for duplicate attacks) deterministic
    for f in sorted(glob.glob(os.path.join(mdir, "*.json"))):
        base = os.path.basename(f)
        # Skip suite report jsons
        if base.startswith("suite_report"):
            continue
        try:
            with open(f, "r", encoding="utf-8") as fp:
                data = json.load(fp)
        except (OSError, ValueError) as err:  # ValueError covers JSONDecodeError / bad UTF-8
            print(f"[warn] Skipping unreadable report {f}: {err}")
            continue
        if not isinstance(data, dict):
            print(f"[warn] Skipping {f}: expected a JSON object, got {type(data).__name__}")
            continue

        attack_name = data.get("attack_name", os.path.splitext(base)[0])

        # Compute macro-F1 from per-class metrics arrays
        clean_f1 = macro_f1(data.get("clean_per_class", []))
        attacked_f1 = macro_f1(data.get("attacked_per_class", []))
        delta_f1 = np.nan
        if not math.isnan(clean_f1) and not math.isnan(attacked_f1):
            delta_f1 = attacked_f1 - clean_f1

        # Missing key or JSON null -> NaN so the column stays numeric
        asr = data.get("attack_success_rate")
        if asr is None:
            asr = np.nan

        rows.append({
            "model": model_name,
            "attack": attack_name,
            "clean_macro_f1": clean_f1,
            "attacked_macro_f1": attacked_f1,
            "delta_macro_f1": delta_f1,
            "attack_success_rate": asr
        })

df = pd.DataFrame(rows)
if df.empty:
    raise RuntimeError("No usable attack JSON files found under attack_reports/*/*.json.")

# Order models: known labels first (label_map order), then any other report
# folders alphabetically. Every present model must be listed, otherwise the
# pd.Categorical casts below turn unlisted models into NaN.
present_models = set(df["model"].unique())
model_order = [lbl for lbl in label_map.values() if lbl in present_models]
model_order += sorted(present_models - set(model_order))

# -------------------------------
# Figure 1: Mean clean vs attacked F1 per model (means over all attacks)
# -------------------------------
agg = (
    df.groupby("model", as_index=False)[["clean_macro_f1", "attacked_macro_f1"]]
      .mean(numeric_only=True)
)
agg["model"] = pd.Categorical(agg["model"], categories=model_order, ordered=True)
agg = agg.sort_values("model")

fig, ax = plt.subplots(figsize=(7, 5))
x = np.arange(len(agg))
w = 0.35  # bar width; the two bars of a model sit at x -/+ w/2
ax.bar(x - w/2, agg["clean_macro_f1"], width=w, label="Clean macro-F1")
ax.bar(x + w/2, agg["attacked_macro_f1"], width=w, label="Attacked macro-F1")
ax.set_xticks(x)
ax.set_xticklabels(agg["model"], rotation=0)
ax.set_ylabel("Macro-F1")
ax.set_title("Clean vs. Adversarial (Mean Macro-F1) by Model")
ax.legend()
fig.tight_layout()
fig1_path = os.path.join(FIG_DIR, "clean_vs_adv_f1_by_model.png")
fig.savefig(fig1_path, dpi=600)
plt.close(fig)

# -------------------------------
# Figure 2: Attack-wise comparison across models
# Prefer ASR; fallback to normalized -ΔF1 if ASR missing
# -------------------------------
df_asr = df.copy()

# Fallback: where ASR is missing, use -ΔF1 clipped to [0, 1] as a rough proxy
mask_missing_asr = df_asr["attack_success_rate"].isna()
if mask_missing_asr.any():
    approx_asr = (-df_asr.loc[mask_missing_asr, "delta_macro_f1"]).clip(lower=0.0, upper=1.0)
    df_asr.loc[mask_missing_asr, "attack_success_rate"] = approx_asr

# Plot only attacks every model was run on (or all attacks if none are shared)
attack_sets = [set(df_asr.loc[df_asr["model"] == mdl, "attack"]) for mdl in df_asr["model"].unique()]
common_attacks = sorted(set.intersection(*attack_sets)) or sorted(df_asr["attack"].unique())

plot_df = df_asr[df_asr["attack"].isin(common_attacks)].copy()
plot_df["model"] = pd.Categorical(plot_df["model"], categories=model_order, ordered=True)
plot_df = plot_df.sort_values(["attack", "model"])

attacks = common_attacks
M = len(model_order)
N = len(attacks)
bar_w = 0.8 / max(1, M)
indices = np.arange(N)


def _asr_value(model, attack):
    """First ASR recorded in plot_df for (model, attack), or NaN when absent."""
    hits = plot_df.loc[(plot_df["model"] == model) & (plot_df["attack"] == attack), "attack_success_rate"]
    return np.nan if hits.empty else float(hits.iloc[0])


fig, ax = plt.subplots(figsize=(max(8, N * 0.6), 5))
group_shift = (M-1)*bar_w/2  # centres each group of M bars on its tick
for i, m in enumerate(model_order):
    vals = np.array([_asr_value(m, a) for a in attacks], dtype=float)
    ax.bar(indices + i * bar_w - group_shift, vals, width=bar_w, label=m)

ax.set_xticks(indices)
ax.set_xticklabels(attacks, rotation=35, ha="right")
ax.set_ylabel("Attack Success Rate (ASR)")
ax.set_title("Attack-wise Comparison by Model")
ax.legend()
fig.tight_layout()
fig2_path = os.path.join(FIG_DIR, "asr_by_attack_and_model.png")
fig.savefig(fig2_path, dpi=600)
plt.close(fig)

print("Saved 600dpi figures to:")
print(" -", fig1_path)
print(" -", fig2_path)


In [None]:
# ================================
# TABULAR VIEW: per-attack + grouped (3 categories)
# ================================
import os, json, glob, math
import numpy as np
import pandas as pd

# Mount Drive (skip if already mounted)
#from google.colab import drive
#drive.mount('/content/drive')

BASE = "/content/drive/MyDrive/hate"
REPORT_ROOT = os.path.join(BASE, "attack_reports")

# --- discover model subfolders (fail fast when the report tree is missing)
if not os.path.isdir(REPORT_ROOT):
    raise FileNotFoundError(f"Report root not found: {REPORT_ROOT}")

model_dirs = [*filter(os.path.isdir, glob.glob(os.path.join(REPORT_ROOT, "*")))]
if not model_dirs:
    raise RuntimeError(f"No model subfolders under {REPORT_ROOT}")

# Display labels for known report-folder names (matched case-insensitively)
label_map = {
    "baseline": "Baseline",
    "gated": "Gated Fusion",
    "sentinel": "Sentinel",
}


def pretty_model_name(path):
    """Return the display label for a report folder, or the folder name itself if unknown."""
    folder = os.path.basename(path)
    key = folder.lower()
    return label_map[key] if key in label_map else folder

# --- attack -> category mapping
def attack_category(name: str) -> str:
    """Classify an attack name into one of the reporting categories.

    Known names are matched exactly (case-insensitive). Otherwise keyword
    substrings decide, checked in priority order Structural, Semantic/Cue,
    Feature-Targeted. Anything else is "Uncategorized". Note that keywords
    are plain substrings (e.g. "case" also matches "showcase").
    """
    lowered = name.lower()
    exact = {
        "structural_typo": "Structural",
        "structural_insert": "Structural",
        "structural_case": "Structural",
        "semantic_synonym": "Semantic/Cue",
        "semantic_coded": "Semantic/Cue",
        "feature_slur_removal": "Feature-Targeted",
    }
    if lowered in exact:
        return exact[lowered]

    keyword_rules = (
        ("Structural", ("structural", "typo", "insert", "case")),
        ("Semantic/Cue", ("synonym", "coded", "semantic")),
        ("Feature-Targeted", ("slur", "feature")),
    )
    for category, keywords in keyword_rules:
        if any(word in lowered for word in keywords):
            return category
    return "Uncategorized"

# --- helpers
def macro_f1_from_perclass(per_class_list):
    """Mean of the ``"f1"`` values in a list of per-class metric dicts.

    Entries without ``"f1"`` count as 0.0. Returns NaN unless the input is a
    non-empty list whose first element is a dict.
    """
    if not (isinstance(per_class_list, list) and per_class_list and isinstance(per_class_list[0], dict)):
        return np.nan
    f1_scores = np.array([float(entry.get("f1", 0.0)) for entry in per_class_list])
    return float(f1_scores.mean())

# --- parse all JSONs (one row per model x attack report)
rows = []
for mdir in model_dirs:
    model_name = pretty_model_name(mdir)
    # sorted(): glob order is arbitrary; keep row order deterministic
    for f in sorted(glob.glob(os.path.join(mdir, "*.json"))):
        base = os.path.basename(f)
        if base.startswith("suite_report"):
            continue
        try:
            with open(f, "r", encoding="utf-8") as fp:
                data = json.load(fp)
        except (OSError, ValueError) as err:  # ValueError covers JSONDecodeError / bad UTF-8
            print(f"[warn] Skipping unreadable report {f}: {err}")
            continue
        if not isinstance(data, dict):
            print(f"[warn] Skipping {f}: expected a JSON object, got {type(data).__name__}")
            continue

        attack_name = data.get("attack_name", os.path.splitext(base)[0])

        clean_f1 = macro_f1_from_perclass(data.get("clean_per_class", []))
        attacked_f1 = macro_f1_from_perclass(data.get("attacked_per_class", []))
        dF1 = np.nan
        if not math.isnan(clean_f1) and not math.isnan(attacked_f1):
            dF1 = attacked_f1 - clean_f1

        # A JSON null must become NaN, not None: None can leave the column
        # object-typed, which breaks the mean() aggregations below.
        asr = data.get("attack_success_rate")
        if asr is None:
            asr = np.nan

        rows.append({
            "model": model_name,
            "attack": attack_name,
            "category": attack_category(attack_name),
            "clean_macro_f1": clean_f1,
            "attacked_macro_f1": attacked_f1,
            "delta_macro_f1": dF1,
            "ASR": asr
        })

df = pd.DataFrame(rows)
if df.empty:
    raise RuntimeError("No usable attack JSON files found.")

# --- tidy ordering: known models first, then any others alphabetically, so
# no model is turned into NaN by the Categorical cast below
present_models = set(df["model"].unique())
model_order = [lbl for lbl in ["Baseline", "Gated Fusion", "Sentinel"] if lbl in present_models]
model_order += sorted(present_models - set(model_order))
cat_order = ["Structural", "Semantic/Cue", "Feature-Targeted", "Uncategorized"]

df["model"] = pd.Categorical(df["model"], categories=model_order, ordered=True)
df["category"] = pd.Categorical(df["category"], categories=cat_order, ordered=True)
df = df.sort_values(["model", "category", "attack"]).reset_index(drop=True)

# --- show per-attack table (metric columns rounded to 4 dp)
metric_cols = ["clean_macro_f1", "attacked_macro_f1", "delta_macro_f1", "ASR"]
per_attack_cols = ["model", "category", "attack", *metric_cols]
per_attack_df = df[per_attack_cols].copy()
per_attack_df[metric_cols] = per_attack_df[metric_cols].round(4)

print("=== Per-attack metrics (Clean vs Attacked, ΔF1, ASR) ===")
display(per_attack_df)

# --- grouped summary: mean across attacks per (model, category)
# Metrics are coerced to numeric on a copy (unparsable -> NaN) so "mean"
# cannot fail on object columns. observed=False is passed explicitly: with
# categorical keys pandas' default differs across versions. Here every
# model x category combination is listed (empty ones: NaN means,
# n_attacks == 0), matching the pre-3.0 default.
metric_cols = ["clean_macro_f1", "attacked_macro_f1", "delta_macro_f1", "ASR"]
numeric_df = df.assign(**{c: pd.to_numeric(df[c], errors="coerce") for c in metric_cols})
grouped = (
    numeric_df.groupby(["model", "category"], observed=False)
      .agg(
          mean_clean_macro_f1=("clean_macro_f1", "mean"),
          mean_attacked_macro_f1=("attacked_macro_f1", "mean"),
          mean_delta_macro_f1=("delta_macro_f1", "mean"),
          mean_ASR=("ASR", "mean"),
          n_attacks=("attack", "nunique"),
      )
      .reset_index()
)
for c in ["mean_clean_macro_f1", "mean_attacked_macro_f1", "mean_delta_macro_f1", "mean_ASR"]:
    grouped[c] = grouped[c].round(4)

grouped = grouped.sort_values(["model", "category"])
print("\n=== Grouped summary (means over attacks) by Model × Category ===")
display(grouped)


In [None]:
# --- FIXED grouped summary (robust to pandas versions) ---
# Recompute from `df` you already built

import pandas as pd
import numpy as np

# Ensure numeric dtypes (in place on df): anything unparsable becomes NaN
metric_cols = ["clean_macro_f1", "attacked_macro_f1", "delta_macro_f1", "ASR"]
for col in metric_cols:
    df[col] = pd.to_numeric(df[col], errors="coerce")

# observed=False lists every model x category combination (empty ones get NaN
# means and n_attacks == 0); named aggregation replaces agg(dict) + rename.
grouped = (
    df.groupby(["model", "category"], observed=False)
      .agg(
          mean_clean_macro_f1=("clean_macro_f1", "mean"),
          mean_attacked_macro_f1=("attacked_macro_f1", "mean"),
          mean_delta_macro_f1=("delta_macro_f1", "mean"),
          mean_ASR=("ASR", "mean"),
          n_attacks=("attack", "nunique"),
      )
      .reset_index()
      .sort_values(["model", "category"])
)

# Round for display
for col in (f"mean_{c}" for c in metric_cols):
    grouped[col] = grouped[col].round(4)

print("\n=== Grouped summary (means over attacks) by Model × Category ===")
display(grouped)


In [None]:
# ================================
# HateEval labeled evaluation (3 models): Baseline, Gated, Sentinel
# Saves predictions + metrics; prints tidy DataFrames
# ================================
!pip -q install transformers datasets pandas scikit-learn

import os, re, math, json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from typing import List, Dict, Any
from datasets import Dataset
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, roc_auc_score, confusion_matrix
)
from transformers import (
    AutoTokenizer, AutoModel, AutoConfig, AutoModelForSequenceClassification,
    PretrainedConfig, PreTrainedModel, TrainingArguments, Trainer
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------- Paths (edit if different) --------
BASE = "/content/drive/MyDrive/hate"
TEST_CSV = f"{BASE}/hateEval_test.csv"     # must have: id,text,label (label in {0,1})
CKPT_BASELINE = f"{BASE}/models/step1_bert_baseline"
CKPT_GATED    = f"{BASE}/models/step1_gated_fusion"
CKPT_SENTINEL = f"{BASE}/models/sentinel_xlmr"
OUT_DIR       = f"{BASE}/predictions"
os.makedirs(OUT_DIR, exist_ok=True)

# -------- Load HateEval test ----------
df = pd.read_csv(TEST_CSV)
_missing_cols = {"id", "text", "label"} - set(df.columns)
if _missing_cols:
    # Explicit raise (not assert) so the check survives `python -O`
    raise ValueError(
        "hateEval_test.csv must have columns: id,text,label "
        f"(missing: {sorted(_missing_cols)})"
    )

# Coerce labels to {0,1}. Integer columns of any width are used as-is. Anything
# else is normalised and mapped: strings, or floats such as 0.0/1.0 (pandas
# reads those when the column has blanks). Unknown values raise a clear error
# instead of a cryptic NaN -> int cast failure.
if not pd.api.types.is_integer_dtype(df["label"]):
    _label_lookup = {"0": 0, "1": 1, "0.0": 0, "1.0": 1, "non-hate": 0, "hate": 1}
    _mapped = df["label"].astype(str).str.strip().str.lower().map(_label_lookup)
    if _mapped.isna().any():
        _bad = sorted(df.loc[_mapped.isna(), "label"].astype(str).unique())[:10]
        raise ValueError(f"Unrecognised HateEval labels (expected 0/1 or hate/non-hate): {_bad}")
    df["label"] = _mapped.astype(int)

texts = df["text"].astype(str).tolist()
ids   = df["id"].tolist()
y_true = df["label"].astype(int).values
# The metrics below assume binary labels 0 (non-hate) / 1 (hate)
if not set(np.unique(y_true)) <= {0, 1}:
    raise ValueError(f"HateEval labels must be 0/1, found {sorted(set(np.unique(y_true)))}")

# =======================
# Gated-fusion definition
# =======================
class GFConfig(PretrainedConfig):
    """Hugging Face config for ``GatedFusionForSequenceClassification``.

    Stores the encoder checkpoint name (``base_model_name``), the number of
    output labels, and the hidden width of the gate MLP (``gate_hidden``).
    Any other keyword arguments are forwarded to ``PretrainedConfig``.
    """
    # Type identifier for this custom config class
    model_type = "gated_fusion_wrapper"
    def __init__(self, base_model_name="bert-base-multilingual-cased", num_labels=2, gate_hidden=256, **kwargs):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        # Assigned after super().__init__() so the explicit argument is the final value
        self.num_labels = num_labels
        self.gate_hidden = gate_hidden

class GatedFusionForSequenceClassification(PreTrainedModel):
    """Sequence classifier that gates between first-token and mean-pooled encodings.

    A sigmoid gate computed from ``[h_cls; h_mean]`` mixes the two vectors
    element-wise (``g * h_cls + (1 - g) * h_mean``) before dropout and a
    linear classifier. ``forward`` returns a plain dict
    ``{"loss": ..., "logits": ...}`` (``loss`` is None without labels).
    """
    config_class = GFConfig

    def __init__(self, config: GFConfig):
        super().__init__(config)
        # The encoder is loaded from config.base_model_name on every construction
        self.base_cfg = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_pretrained(config.base_model_name, config=self.base_cfg)
        hidden = self.base_cfg.hidden_size
        # Gate: [h_cls; h_mean] (2*hidden) -> per-dimension weight in (0, 1)
        self.gate_mlp = nn.Sequential(
            nn.Linear(2 * hidden, config.gate_hidden),
            nn.ReLU(),
            nn.Linear(config.gate_hidden, hidden),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(getattr(self.base_cfg, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden, config.num_labels)
        self.post_init()

    @staticmethod
    def masked_mean(last_hidden_state, attention_mask):
        """Average token vectors over positions where ``attention_mask`` is non-zero.

        The denominator is clamped so an all-zero mask yields zeros, not NaN.
        """
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
        summed = (last_hidden_state * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-6)
        return summed / denom

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        """Encode, gate-fuse, and classify; optionally compute cross-entropy loss.

        ``attention_mask`` may be None: the encoder then attends everywhere and
        the mean pooling likewise covers every position.
        """
        # Whitelist of kwargs forwarded to the encoder; anything else (e.g. the
        # ``num_items_in_batch`` that Trainer injects) is dropped here.
        allowed = {"position_ids","head_mask","inputs_embeds","output_attentions",
                   "output_hidden_states","return_dict","past_key_values",
                   "encoder_hidden_states","encoder_attention_mask"}
        safe_kwargs = {k: v for k, v in kwargs.items() if k in allowed}

        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, **safe_kwargs)
        hidden_states = enc.last_hidden_state
        if attention_mask is None:
            # Previously crashed in masked_mean (None.unsqueeze); pool over all positions
            attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device)
        h_cls = hidden_states[:, 0, :]
        h_mean = self.masked_mean(hidden_states, attention_mask)
        g = self.gate_mlp(torch.cat([h_cls, h_mean], dim=-1))
        fused = g * h_cls + (1.0 - g) * h_mean
        logits = self.classifier(self.dropout(fused))
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}

# ===================
# Sentinel definition
# ===================
SLUR_REGEXES = [
    r"\bmongol(s)?\b", r"\bretard(s|ed)?\b", r"\btolol\b", r"\bkontol\b",
    r"\bbajingan\b", r"\bbabi\b", r"\bbhen ?chod\b", r"\bmadar ?chod\b",
    r"\brandi\b", r"\bperra\b", r"\bzorra\b", r"\bputa\b"
]
# Compiled once at import; matching is case-insensitive
SLUR_PATTERNS = [re.compile(rx, flags=re.IGNORECASE) for rx in SLUR_REGEXES]


def build_heuristic_features(text: str, H=32) -> np.ndarray:
    """Return a length-``H`` float32 vector of surface and lexicon statistics.

    The first 12 slots are, in order: log1p(char count), log1p(word count),
    uppercase-letter count, digit count, punctuation count (``.,;:!?``),
    uppercase share of letters, digit share of chars, punctuation share of
    chars, slur-pattern hits, slur hits per word, cue-substring hits, and cue
    hits per word. The vector is zero-padded (or truncated) to ``H``.

    Presumably the Sentinel checkpoint was trained on exactly this layout, so
    keep slot order and definitions unchanged.
    """
    n_chars = len(text)
    n_words = max(1, len(text.split()))
    n_alpha = sum(ch.isalpha() for ch in text)
    n_upper = sum(ch.isalpha() and ch.isupper() for ch in text)
    n_digit = sum(ch.isdigit() for ch in text)
    n_punct = sum(ch in ".,;:!?" for ch in text)
    n_slur = sum(len(pattern.findall(text)) for pattern in SLUR_PATTERNS)
    lowered = text.lower()
    # Plain substring counts (e.g. "die" also matches inside "diet")
    n_cue = sum(lowered.count(cue) for cue in ["kill","die","trash","dirty","dog","pig","scum","hate"])
    char_denom = max(1, n_chars)

    head = np.array(
        [
            math.log1p(n_chars),
            math.log1p(n_words),
            n_upper,
            n_digit,
            n_punct,
            n_upper / max(1, n_alpha),
            n_digit / char_denom,
            n_punct / char_denom,
            n_slur,
            n_slur / n_words,
            n_cue,
            n_cue / n_words,
        ],
        dtype=np.float32,
    )
    if head.shape[0] >= H:
        return head[:H]
    return np.pad(head, (0, H - head.shape[0]))

class SentinelConfig(PretrainedConfig):
    """Hugging Face config for ``SentinelModel``.

    Holds the encoder checkpoint name, label count, heuristic feature width,
    hidden widths of the heuristic and causal MLPs, the number of
    cross-attention heads, and ``aux_causal_loss_weight``. That last value is
    stored here but not read by the ``SentinelModel.forward`` defined in this
    cell. Other keyword arguments are forwarded to ``PretrainedConfig``.
    """
    # Type identifier for this custom config class
    model_type = "sentinel_fusion"
    def __init__(self,
        base_model_name="xlm-roberta-base",
        num_labels=2,
        heuristic_dim=32,        # input width of heuristic_proj; predict_probs_sentinel builds features with H=heuristic_dim
        heuristic_hidden=256,    # hidden width of the heuristic projection MLP
        causal_hidden=256,       # hidden width of the causal MLP
        attn_heads=8,            # nn.MultiheadAttention heads; must divide the encoder hidden size
        aux_causal_loss_weight=0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.num_labels = num_labels
        self.heuristic_dim = heuristic_dim
        self.heuristic_hidden = heuristic_hidden
        self.causal_hidden = causal_hidden
        self.attn_heads = attn_heads
        self.aux_causal_loss_weight = aux_causal_loss_weight

class SentinelModel(PreTrainedModel):
    """Classifier that fuses the encoder's first-token vector with heuristic features.

    The first-token encoding ``h_cls`` is used as the single query of a
    multi-head attention over two key/value vectors: a projection of the
    heuristic features and an MLP transform of ``h_cls`` itself. The attended
    vector goes through dropout and a linear classifier. ``forward`` returns a
    plain dict ``{"loss": ..., "logits": ...}``.
    """
    config_class = SentinelConfig
    def __init__(self, config: SentinelConfig):
        super().__init__(config)
        # Encoder weights are loaded from config.base_model_name at construction time
        self.base_cfg = AutoConfig.from_pretrained(config.base_model_name)
        self.encoder = AutoModel.from_pretrained(config.base_model_name, config=self.base_cfg)
        hidden = self.base_cfg.hidden_size

        # heuristic_dim -> hidden, so it can be stacked with hidden-sized vectors below
        self.heuristic_proj = nn.Sequential(
            nn.Linear(config.heuristic_dim, config.heuristic_hidden),
            nn.ReLU(),
            nn.Linear(config.heuristic_hidden, hidden),
            nn.LayerNorm(hidden)
        )
        # hidden -> hidden transform of the first-token vector ("causal" branch)
        self.causal_mlp = nn.Sequential(
            nn.Linear(hidden, config.causal_hidden),
            nn.ReLU(),
            nn.Linear(config.causal_hidden, hidden),
            nn.LayerNorm(hidden)
        )
        # batch_first=True: attention inputs are (batch, seq, hidden)
        self.xattn = nn.MultiheadAttention(embed_dim=hidden, num_heads=config.attn_heads, batch_first=True)
        self.dropout = nn.Dropout(getattr(self.base_cfg, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden, config.num_labels)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, heuristic_feats=None, labels=None, **kwargs):
        # NOTE(review): extra kwargs (e.g. token_type_ids) are accepted but not
        # forwarded to the encoder. heuristic_feats is effectively required
        # despite its None default (heuristic_proj(None) would fail).
        # NOTE(review): config.aux_causal_loss_weight is not used here.
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = enc.last_hidden_state
        h_cls = last_hidden[:, 0, :]
        h_heu = self.heuristic_proj(heuristic_feats)  # presumably (batch, heuristic_dim) -> (batch, hidden)
        h_cau = self.causal_mlp(h_cls)
        Q = h_cls.unsqueeze(1)                   # (batch, 1, hidden): one query per example
        KV = torch.stack([h_heu, h_cau], dim=1)  # (batch, 2, hidden): heuristic + causal keys/values
        fused, _ = self.xattn(Q, KV, KV)
        fused = fused.squeeze(1)
        logits = self.classifier(self.dropout(fused))
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}

# -------------------------
# Helper: softmax to probs
# -------------------------
def to_probs(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over axis 1, shifted by the row max for numerical stability."""
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=1, keepdims=True)

# ----------------------------------
# Predictors for each model family
# ----------------------------------
@torch.no_grad()
def predict_probs_baseline(checkpoint_dir: str, texts: List[str], max_length=128, batch_size: int = 64) -> np.ndarray:
    """Score ``texts`` with a fine-tuned ``AutoModelForSequenceClassification``.

    Uses a manual batched loop (like the gated/sentinel predictors) instead of
    ``Trainer.predict``. That avoids the throw-away ``TrainingArguments``
    (which created ``./tmp_eval``) and Trainer's deprecated ``tokenizer=``
    argument, and it tokenizes per batch rather than the whole corpus at once.

    Args:
        checkpoint_dir: fine-tuned checkpoint directory (model and, ideally, tokenizer).
        texts: input strings.
        max_length: tokenizer padding/truncation length.
        batch_size: texts per forward pass.

    Returns:
        ``(len(texts), num_labels)`` array of softmax probabilities.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint_dir if os.path.isdir(checkpoint_dir) else "bert-base-multilingual-cased")
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir).to(device).eval()

    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    logit_chunks = []
    for start in range(0, len(texts), batch_size):
        enc = tok(texts[start:start + batch_size], padding="max_length", truncation=True,
                  max_length=max_length, return_tensors="pt")
        # Same two inputs the Trainer-based version fed the model (no token_type_ids)
        batch = {k: enc[k].to(device) for k in ("input_ids", "attention_mask")}
        logit_chunks.append(model(**batch).logits.float().cpu().numpy())
    return to_probs(np.concatenate(logit_chunks, axis=0))

import math
import torch
import torch.nn as nn
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, AutoConfig, AutoModel
from transformers import PretrainedConfig, PreTrainedModel

# --- keep your GFConfig and GatedFusionForSequenceClassification definitions as-is ---

def _softmax_np(logits: np.ndarray) -> np.ndarray:
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

@torch.no_grad()
def predict_probs_gated(checkpoint_dir: str, texts, max_length: int = 128, batch_size: int = 64) -> np.ndarray:
    """Score ``texts`` with the gated-fusion model; return softmax probabilities.

    Manual batched inference so Trainer/accelerate never touches the model's
    plain-dict output with ``loss=None``.

    Weights are read from ``<checkpoint_dir>/pytorch_model.bin`` with
    ``strict=False``. A missing file, or missing/unexpected keys, now emits a
    warning instead of silently scoring with an untrained gate and head.

    Args:
        checkpoint_dir: checkpoint directory (config.json / tokenizer optional).
        texts: input strings.
        max_length: tokenizer padding/truncation length.
        batch_size: texts per forward pass.

    Returns:
        ``(len(texts), num_labels)`` array of probabilities.
    """
    import warnings

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_name = "bert-base-multilingual-cased"

    # tokenizer: prefer the one saved alongside the checkpoint
    tok = AutoTokenizer.from_pretrained(checkpoint_dir if os.path.isdir(checkpoint_dir) else base_name)

    # config + model (best-effort fallback to defaults, but say so)
    try:
        cfg = GFConfig.from_pretrained(checkpoint_dir)
    except Exception as err:
        warnings.warn(f"Could not load GFConfig from {checkpoint_dir!r} ({err}); using defaults.")
        cfg = GFConfig(base_model_name=base_name, num_labels=2, gate_hidden=256)
    model = GatedFusionForSequenceClassification(cfg)
    state_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
    if os.path.isfile(state_path):
        state = torch.load(state_path, map_location=device)
        result = model.load_state_dict(state, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"Gated checkpoint {state_path!r}: {len(result.missing_keys)} missing keys "
                f"(e.g. {result.missing_keys[:3]}), {len(result.unexpected_keys)} unexpected keys "
                f"(e.g. {result.unexpected_keys[:3]}); unmatched parameters keep their initial values."
            )
    else:
        warnings.warn(
            f"No pytorch_model.bin in {checkpoint_dir!r} (model.safetensors is not read here); "
            "the gate and classifier weights are untrained, so predictions are not meaningful."
        )
    model.to(device).eval()

    texts = list(texts)
    if not texts:
        # np.vstack([]) would raise; return an empty probability matrix instead
        return np.empty((0, cfg.num_labels), dtype=np.float32)

    # batched forward
    probs_list = []
    for i in range(0, len(texts), batch_size):
        enc = tok(texts[i:i + batch_size], padding="max_length", truncation=True,
                  max_length=max_length, return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}
        out = model(**enc)  # dict {"loss": None, "logits": tensor}
        probs_list.append(_softmax_np(out["logits"].detach().cpu().numpy()))
    return np.vstack(probs_list)

# ===== Optional: make Sentinel manual too (more robust and symmetric) =====

# keep your SentinelConfig and SentinelModel definitions as-is
def _build_heuristic_batch(texts, H=32):
    """Stack per-text heuristic vectors into a float32 ``(len(texts), H)`` tensor."""
    feature_rows = [build_heuristic_features(t, H=H) for t in texts]
    matrix = np.stack(feature_rows).astype(np.float32)
    return torch.tensor(matrix)

@torch.no_grad()
def predict_probs_sentinel(checkpoint_dir: str, texts, max_length: int = 128, batch_size: int = 64) -> np.ndarray:
    """Score ``texts`` with the Sentinel model; return softmax probabilities.

    Each batch is tokenized and paired with heuristic features built with
    ``H = config.heuristic_dim``. Weights are read from
    ``<checkpoint_dir>/pytorch_model.bin`` with ``strict=False``. A missing
    file, or missing/unexpected keys, now emits a warning instead of silently
    scoring with untrained fusion/classifier weights.

    Args:
        checkpoint_dir: checkpoint directory (config.json / tokenizer optional).
        texts: input strings.
        max_length: tokenizer padding/truncation length.
        batch_size: texts per forward pass.

    Returns:
        ``(len(texts), num_labels)`` array of probabilities.
    """
    import warnings

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_name = "xlm-roberta-base"

    tok = AutoTokenizer.from_pretrained(checkpoint_dir if os.path.isdir(checkpoint_dir) else base_name)
    try:
        scfg = SentinelConfig.from_pretrained(checkpoint_dir)
    except Exception as err:
        warnings.warn(f"Could not load SentinelConfig from {checkpoint_dir!r} ({err}); using defaults.")
        scfg = SentinelConfig(base_model_name=base_name, num_labels=2, heuristic_dim=32)

    model = SentinelModel(scfg)
    state_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
    if os.path.isfile(state_path):
        state = torch.load(state_path, map_location=device)
        result = model.load_state_dict(state, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"Sentinel checkpoint {state_path!r}: {len(result.missing_keys)} missing keys "
                f"(e.g. {result.missing_keys[:3]}), {len(result.unexpected_keys)} unexpected keys "
                f"(e.g. {result.unexpected_keys[:3]}); unmatched parameters keep their initial values."
            )
    else:
        warnings.warn(
            f"No pytorch_model.bin in {checkpoint_dir!r} (model.safetensors is not read here); "
            "the fusion and classifier weights are untrained, so predictions are not meaningful."
        )
    model.to(device).eval()

    texts = list(texts)
    if not texts:
        # np.vstack([]) would raise; return an empty probability matrix instead
        return np.empty((0, scfg.num_labels), dtype=np.float32)

    probs_list = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        enc = tok(batch_texts, padding="max_length", truncation=True,
                  max_length=max_length, return_tensors="pt")
        heur = _build_heuristic_batch(batch_texts, H=scfg.heuristic_dim)

        enc = {k: v.to(device) for k, v in enc.items()}
        enc["heuristic_feats"] = heur.to(device)

        out = model(**enc)
        probs_list.append(_softmax_np(out["logits"].detach().cpu().numpy()))
    return np.vstack(probs_list)


# -------------------------
# Run predictions
# -------------------------
def _score_with(label, predictor, checkpoint):
    """Announce, then run one model's predictor over the HateEval texts."""
    print(f"Scoring HateEval with {label}...")
    return predictor(checkpoint, texts)


probs_base = _score_with("Baseline", predict_probs_baseline, CKPT_BASELINE)
probs_gated = _score_with("Gated Fusion", predict_probs_gated, CKPT_GATED)
probs_sent = _score_with("Sentinel", predict_probs_sentinel, CKPT_SENTINEL)

def save_preds(name, probs, y_true):
    """Write per-example predictions to ``OUT_DIR/hateEval_<name>.csv``.

    Uses the module-level ``ids`` and ``texts``. Column 0 of ``probs`` is
    non-hate and column 1 is hate; the predicted label is the row argmax.

    Returns:
        ``(pred_labels, predictions_dataframe)``.
    """
    pred_labels = np.argmax(probs, axis=1)
    frame = pd.DataFrame(
        {
            "id": ids,
            "text": texts,
            "label": y_true,
            "prob_nonhate": probs[:, 0],
            "prob_hate": probs[:, 1],
            "pred_label": pred_labels,
        }
    )
    csv_path = os.path.join(OUT_DIR, f"hateEval_{name}.csv")
    frame.to_csv(csv_path, index=False)
    print(f"Saved predictions: {csv_path}")
    return pred_labels, frame


pred_base, df_base = save_preds("baseline", probs_base, y_true)
pred_gated, df_gated = save_preds("gated", probs_gated, y_true)
pred_sent, df_sent = save_preds("sentinel", probs_sent, y_true)

# -------------------------
# Metrics + Reports
# -------------------------
def compute_all_metrics(name: str, y_true, probs):
    """Build a JSON-serialisable metrics report for one model.

    Predictions are the row argmax of ``probs``; column 1 is the hate class.
    The report holds accuracy, macro precision/recall/F1, ROC-AUC on column 1
    (``None`` when undefined), per-class metrics for labels 0 and 1, and the
    confusion matrix over labels [0, 1].
    """
    y_pred = probs.argmax(axis=1)
    acc = accuracy_score(y_true, y_pred)

    macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0
    )
    # Per-class arrays indexed by label (0 = non-hate, 1 = hate)
    cls_p, cls_r, cls_f1, cls_n = precision_recall_fscore_support(
        y_true, y_pred, labels=[0, 1], zero_division=0
    )

    # ROC-AUC is undefined when y_true contains a single class
    try:
        auc = float(roc_auc_score(y_true, probs[:, 1]))
    except ValueError:
        auc = float("nan")

    per_class = [
        {
            "label": lbl,
            "precision": float(cls_p[lbl]),
            "recall": float(cls_r[lbl]),
            "f1": float(cls_f1[lbl]),
            "support": int(cls_n[lbl]),
        }
        for lbl in (0, 1)
    ]

    return {
        "model": name,
        "accuracy": float(acc),
        "precision_macro": float(macro_p),
        "recall_macro": float(macro_r),
        "f1_macro": float(macro_f1),
        "auc_roc_hate1": None if math.isnan(auc) else auc,  # NaN is not valid JSON
        "per_class": per_class,
        "confusion_matrix": {
            "labels": [0, 1],
            "matrix": confusion_matrix(y_true, y_pred, labels=[0, 1]).tolist(),
        },
    }

rep_base = compute_all_metrics("Baseline", y_true, probs_base)
rep_gated = compute_all_metrics("Gated Fusion", y_true, probs_gated)
rep_sent = compute_all_metrics("Sentinel", y_true, probs_sent)

# Save the full JSON report (incl. per-class and confusion matrix) per model
for nm, rep in {"baseline": rep_base, "gated": rep_gated, "sentinel": rep_sent}.items():
    json_path = os.path.join(OUT_DIR, f"hateEval_metrics_{nm}.json")
    with open(json_path, "w") as f:
        json.dump(rep, f, indent=2)

# Summary table: one row per model with only the scalar metrics
nested_keys = ("per_class", "confusion_matrix")
rows = [
    {key: value for key, value in report.items() if key not in nested_keys}
    for report in (rep_base, rep_gated, rep_sent)
]
summary = pd.DataFrame(rows)

# Fixed column order for the summary table
ordered_cols = ["model", "accuracy", "precision_macro", "recall_macro", "f1_macro", "auc_roc_hate1"]
summary = summary[ordered_cols]

# Round every metric column to 4 dp; None (undefined AUC) stays None
summary_rounded = summary.copy()
for col in ordered_cols[1:]:
    summary_rounded[col] = summary_rounded[col].apply(lambda x: None if x is None else round(float(x), 4))

# Per-class tidy
def per_class_rows(rep):
    """Flatten a report's ``per_class`` list into rows tagged with the model name.

    Precision/recall/F1 are cast to float and support to int.
    """
    casts = {"precision": float, "recall": float, "f1": float, "support": int}
    return [
        {
            "model": rep["model"],
            "label": entry["label"],
            **{key: cast(entry[key]) for key, cast in casts.items()},
        }
        for entry in rep["per_class"]
    ]

per_class_df = pd.DataFrame(
    [row for report in (rep_base, rep_gated, rep_sent) for row in per_class_rows(report)]
)
per_class_df_rounded = per_class_df.copy()
for metric in ("precision", "recall", "f1"):
    per_class_df_rounded[metric] = per_class_df_rounded[metric].map(lambda v: round(float(v), 4))

# Save CSVs
for frame, fname in (
    (summary_rounded, "hateEval_summary_metrics.csv"),
    (per_class_df_rounded, "hateEval_per_class_metrics.csv"),
):
    frame.to_csv(os.path.join(OUT_DIR, fname), index=False)

# -------- Pretty print to notebook --------
print("\n=== HateEval: Summary Metrics ===")
print(summary_rounded.to_string(index=False))

print("\n=== HateEval: Per-Class Metrics ===")
print(per_class_df_rounded.sort_values(["label", "model"]).to_string(index=False))

print(f"\nArtifacts saved under: {OUT_DIR}")



In [None]:
# %% Dataset composition summary for HateXplain (train/val/test) and HateEval
import os
import pandas as pd
from collections import OrderedDict

# ----- paths (edit if yours differ) -----
DRIVE = "/content/drive/MyDrive/hate"
_hateeval_primary = os.path.join(DRIVE, "hateEval_test.csv")
PATHS = OrderedDict({
    "HateXplain": {
        split: os.path.join(DRIVE, fname)
        for split, fname in (("train", "train.csv"), ("validation", "val.csv"), ("test", "test.csv"))
    },
    # Prefer hateEval_test.csv; fall back to hateEval.csv when that is what was saved.
    "HateEval": {
        "test": _hateeval_primary if os.path.exists(_hateeval_primary)
                else os.path.join(DRIVE, "hateEval.csv")
    },
})

# ----- helpers -----
POSSIBLE_LABEL_COLS = ["label", "labels", "target", "class", "y"]
POSSIBLE_TEXT_COLS  = ["text", "tweet", "content", "document", "sentence"]


def load_with_autocols(path):
    """Read the CSV at ``path`` and detect its label and text columns.

    Candidates are tried in list order; the first one present wins.

    Returns:
        ``(df, label_col, text_col)``.

    Raises:
        ValueError: if no label candidate (checked first) or no text candidate
            is among the CSV's columns.
    """
    frame = pd.read_csv(path)
    available = list(frame.columns)

    def first_match(candidates):
        return next(filter(lambda col: col in frame.columns, candidates), None)

    label_col = first_match(POSSIBLE_LABEL_COLS)
    text_col = first_match(POSSIBLE_TEXT_COLS)
    if label_col is None:
        raise ValueError(f"No label column found in {path}. "
                         f"Expected one of {POSSIBLE_LABEL_COLS}, got {available}")
    if text_col is None:
        raise ValueError(f"No text column found in {path}. "
                         f"Expected one of {POSSIBLE_TEXT_COLS}, got {available}")
    return frame, label_col, text_col

def summarize_split(df, label_col, dataset_name, split_name):
    """Summarize the label distribution of one dataset split.

    Args:
        df: DataFrame holding the split's rows.
        label_col: Name of the column in ``df`` that contains the class label.
        dataset_name: Value written into the ``dataset`` column of both outputs.
        split_name: Value written into the ``split`` column of both outputs.

    Returns:
        Tuple ``(counts, wide)``:

        * ``counts`` -- long form, one row per distinct label value (missing
          labels included), with columns ``label``, ``count``, ``dataset``,
          ``split`` and ``percent`` (share of the split, 0-100).
        * ``wide`` -- a single row with ``dataset``, ``split``, one count
          column per label value, ``total`` and, only when every observed
          label is 0 or 1, ``pos_frac_%`` / ``neg_frac_%`` (share of label 1 /
          label 0, 0-100). An empty split yields ``dataset``, ``split`` and
          ``total`` = 0 only.
    """
    # Labels are grouped as-is (ints or strings); we deliberately don't coerce
    # to int so string-labelled files don't crash. dropna=False keeps rows
    # with a missing label visible in the long-form counts.
    counts = (
        df.groupby(label_col, dropna=False)
          .size()
          .reset_index(name="count")
          .rename(columns={label_col: "label"})
    )
    counts["dataset"] = dataset_name
    counts["split"]   = split_name
    # Percent within split (max(.., 1) guards the empty-split division).
    total = counts["count"].sum()
    counts["percent"] = counts["count"] / max(total, 1) * 100.0

    if counts.empty:
        # Nothing to pivot: emit an explicit zero-total row instead.
        wide = pd.DataFrame({"dataset": [dataset_name], "split": [split_name], "total": [0]})
        return counts, wide

    # Wide view: one row per split with a count column per label.
    # aggfunc="sum" (rather than the default mean) keeps the counts integral.
    wide = counts.pivot_table(
        index=["dataset", "split"], columns="label", values="count",
        aggfunc="sum", fill_value=0,
    )
    wide = wide.reset_index()
    wide.columns.name = None
    wide["total"] = wide.drop(columns=["dataset", "split"]).sum(axis=1)

    # Class-balance stats only make sense for a binary {0, 1} label space.
    # Require *every* observed label to be 0 or 1. A set that merely contains
    # 0 and 1 (e.g. multiclass {0, 1, 2}) would otherwise get a misleading
    # "positive" fraction. A binary split where one class is absent still
    # gets 0% / 100%.
    observed = {lab for lab in counts["label"].tolist() if pd.notna(lab)}
    if observed and observed <= {0, 1}:
        totals = wide["total"]
        pos = wide[1] if 1 in wide.columns else 0
        neg = wide[0] if 0 in wide.columns else 0
        wide["pos_frac_%"] = (pos / totals * 100.0).where(totals > 0, 0.0)
        wide["neg_frac_%"] = (neg / totals * 100.0).where(totals > 0, 0.0)

    # Column order: dataset, split, label count columns, then derived totals/stats.
    non_label_cols = ["dataset", "split", "total", "pos_frac_%", "neg_frac_%"]
    non_label_cols = [c for c in non_label_cols if c in wide.columns]
    label_cols = [c for c in wide.columns if c not in non_label_cols]
    wide = wide[["dataset", "split"] + label_cols + [c for c in non_label_cols if c not in ["dataset", "split"]]]

    return counts, wide

# ----- build summaries -----
# Walk every (dataset, split) entry in PATHS, load it, and collect its long-form
# and wide-form label summaries. Missing files are skipped with a warning.
all_long = []
all_wide = []

for ds_name, splits in PATHS.items():
    for split_name, p in splits.items():
        if not os.path.exists(p):
            print(f"[WARN] Missing file for {ds_name}/{split_name}: {p}")
            continue
        # text_col is unused here; load_with_autocols still raises if the CSV
        # has no recognizable text column.
        df, label_col, text_col = load_with_autocols(p)
        long_df, wide_df = summarize_split(df, label_col, ds_name, split_name)
        all_long.append(long_df)
        all_wide.append(wide_df)

# Fail loudly rather than concatenate nothing if every file was missing.
if not all_long:
    raise RuntimeError("No datasets were found. Check PATHS.")

long_summary = pd.concat(all_long, ignore_index=True)
wide_summary = pd.concat(all_wide, ignore_index=True)

# Pretty sort
# NOTE(review): sorting on "label" assumes label values are mutually comparable
# across datasets (e.g. all ints or all strings); mixed types would raise here.
long_summary = long_summary.sort_values(["dataset", "split", "label"]).reset_index(drop=True)
wide_summary = wide_summary.sort_values(["dataset", "split"]).reset_index(drop=True)

# Round percents to 2 decimals for display/export. The balance columns exist
# only if at least one split produced them; rows without them are NaN after concat.
if "percent" in long_summary.columns:
    long_summary["percent"] = long_summary["percent"].map(lambda x: round(float(x), 2))
for col in ["pos_frac_%", "neg_frac_%"]:
    if col in wide_summary.columns:
        wide_summary[col] = wide_summary[col].map(lambda x: round(float(x), 2))

# ----- display -----
# display() is not imported in this cell; it comes from the IPython/Colab runtime.
print("\n=== CLASS COMPOSITION (long-form: one row per label) ===")
display(long_summary)

print("\n=== SPLIT SUMMARY (wide-form: counts by label + totals) ===")
display(wide_summary)

# ----- save to Drive for record -----
# Both tables are written (without index) to DRIVE/analysis, created if absent.
out_dir = os.path.join(DRIVE, "analysis")
os.makedirs(out_dir, exist_ok=True)
long_summary.to_csv(os.path.join(out_dir, "dataset_composition_long.csv"), index=False)
wide_summary.to_csv(os.path.join(out_dir, "dataset_composition_wide.csv"), index=False)
print(f"\nSaved CSVs to: {out_dir}")
