BELOW IS THE SECOND IMPLEMENTATION OF HACS-TL, WITH THE REVIEWERS' COMMENTS ADDRESSED

In [None]:
# ============================================
# Google Colab - Mount Google Drive
# ============================================
# Attach the user's Google Drive at /content/drive so the notebook can read
# the dataset and persist artifacts there.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install transformers datasets torch pandas scikit-learn numpy matplotlib seaborn accelerate peft



In [None]:
"""
ENHANCED HAUSA AJAMI HATE SPEECH DETECTION
Complete Implementation with:
- Conversion quality validation
- Orthographic variation stress tests
- Stronger baselines (char/byte-level, LoRA/PEFT)
- Zero-shot/few-shot LLM evaluation
- Comprehensive error analysis
- Qualitative example analysis

FIXED: PEFT model output handling
"""

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    get_cosine_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
    accuracy_score,
    roc_auc_score,
    roc_curve
)
import matplotlib.pyplot as plt
import seaborn as sns
import re
import warnings
from collections import Counter, defaultdict
from typing import Dict, List, Tuple, Optional
import json
from datetime import datetime
warnings.filterwarnings('ignore')

plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("Set2")

# ==================== GOOGLE DRIVE SETUP ====================
# Prefer Google Drive when running inside Colab; otherwise fall back to a
# local directory. DRIVE_PATH keeps its trailing slash because later code
# builds paths as f'{DRIVE_PATH}...'.
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    DRIVE_PATH = '/content/drive/MyDrive/AbjadNLP2026/'
    os.makedirs(DRIVE_PATH, exist_ok=True)
    print(f"✅ Google Drive mounted: {DRIVE_PATH}")
except Exception as exc:
    # ImportError outside Colab, or a failed mount/makedirs. Catching
    # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through,
    # and the reason is printed instead of being silently discarded.
    DRIVE_PATH = './AbjadNLP2026/'
    os.makedirs(DRIVE_PATH, exist_ok=True)
    print(f"📁 Local directory: {DRIVE_PATH}")
    print(f"    (Drive unavailable: {type(exc).__name__}: {exc})")

# Create subdirectories for organization
for subdir in ['models', 'visualizations', 'error_analysis', 'conversion_validation']:
    os.makedirs(os.path.join(DRIVE_PATH, subdir), exist_ok=True)

# ==================== PROPER HAUSA-AJAMI CONVERSION WITH VALIDATION ====================
class HausaAjamiConverter:
    """
    Linguistically accurate Hausa Latin → Ajami converter with validation.

    Conversion is greedy longest-match against ``self.mappings``. Spaces,
    newlines and tabs are copied through unchanged; any other character with
    no mapping (digits, punctuation, letters such as p/q/v/x) is dropped.
    """

    # Whitespace copied unchanged by both conversion directions.
    _PASSTHROUGH = ' \n\t'

    # Arabic diacritic marks stripped when generating orthographic variants.
    _DIACRITICS = re.compile(r'[َُِّْ]')

    def __init__(self):
        # Core mappings. Matching is by chunk length (longest first); within
        # one length the order is irrelevant because every Latin key is unique.
        self.mappings = [
            # Hausa-specific consonants
            ("ɓ", "ݒ"),    # Implosive bilabial
            ("'y", "ࢩ"),   # Glottalized y
            ("ƙ", "ࢼ"),    # Ejective velar
            ("ɗ", "ڎ"),    # Implosive alveolar

            # Digraphs (must come before single chars)
            ("ts", "تْسْ"),
            ("sh", "شْ"),
            ("ng", "نْگْ"),

            # Long vowels
            ("aa", "اَ"),
            ("ee", "اِ"),
            ("ii", "اِي"),
            ("oo", "اُ"),
            ("uu", "اُو"),

            # Standard consonants
            ("b", "ب"), ("c", "چ"), ("d", "د"), ("f", "ف"),
            ("g", "گ"), ("h", "ه"), ("j", "ج"), ("k", "ك"),
            ("l", "ل"), ("m", "م"), ("n", "ن"), ("r", "ر"),
            ("s", "س"), ("t", "ت"), ("w", "و"), ("y", "ي"),
            ("z", "ز"), ("'", "ع"),

            # Short vowels (diacritics)
            ("a", "َ"), ("e", "ِ"), ("i", "ِ"),
            ("o", "ُ"), ("u", "ُ"),
        ]

        # Alternative spellings for validation (index 0 is the default spelling)
        self.variant_mappings = {
            'k': ['ك', 'ق'],  # k can be k or q in some dialects
            'ts': ['تْسْ', 'ڞ'],  # ts alternative
            'ng': ['نْگْ', 'ڭ'],  # ng alternative
        }

        # Forward lookup table. Equivalent to scanning self.mappings for the
        # first equal key, because every Latin key appears exactly once.
        self.forward_mappings = dict(self.mappings)

        # Reverse mapping for validation. Several Latin keys share an Ajami
        # value (e.g. the short vowels 'o' and 'u'); the later entry wins, so
        # reverse conversion is only approximate.
        self.reverse_mappings = {arab: lat for lat, arab in self.mappings}

        # Candidate match lengths, longest first, derived from the tables.
        # Some Ajami values (the 'ts' and 'ng' digraphs) are 4 code points
        # long, which a hard-coded [3, 2, 1] could never match whole.
        self._forward_lengths = sorted({len(k) for k in self.forward_mappings}, reverse=True)
        self._reverse_lengths = sorted({len(k) for k in self.reverse_mappings}, reverse=True)

    def _greedy_transliterate(self, text: str, table: Dict[str, str],
                              lengths: List[int]) -> List[str]:
        """Greedy longest-match transliteration of *text* through *table*.

        Returns the list of emitted pieces. Whitespace in ``_PASSTHROUGH`` is
        copied; every other unmatched character is skipped.
        """
        pieces = []
        i = 0
        n = len(text)
        while i < n:
            for length in lengths:
                chunk = text[i:i + length]
                if len(chunk) == length and chunk in table:
                    pieces.append(table[chunk])
                    i += length
                    break
            else:
                if text[i] in self._PASSTHROUGH:
                    pieces.append(text[i])
                i += 1
        return pieces

    def convert(self, text: str) -> str:
        """Convert Latin Hausa to Ajami script.

        Non-strings and blank strings are returned unchanged. Otherwise the
        text is lower-cased and stripped first. If nothing in it is
        convertible, the normalised Latin text is returned as-is.
        """
        if not isinstance(text, str) or not text.strip():
            return text

        text = text.lower().strip()
        pieces = self._greedy_transliterate(text, self.forward_mappings, self._forward_lengths)
        return ''.join(pieces) if pieces else text

    def reverse_convert(self, ajami_text: str) -> str:
        """Attempt to convert Ajami back to Latin (approximate; see reverse_mappings)."""
        return ''.join(
            self._greedy_transliterate(ajami_text, self.reverse_mappings, self._reverse_lengths)
        )

    def validate_conversion(self, latin_text: str, ajami_text: str) -> Dict:
        """
        Validate conversion quality through round-trip and linguistic checks.

        Returns a dict with:
            char_similarity: matched characters in a sequence alignment of the
                original vs. the round-tripped text, divided by the longer length.
            length_ratio: len(ajami_text) / len(latin_text) (0 for empty input).
            hausa_chars_preserved: number of ɓ/ƙ/ɗ present in the input whose
                Ajami glyph appears in *ajami_text*.
            original, ajami, back_converted: the texts involved.
            is_valid: char_similarity > 0.7.
        """
        from difflib import SequenceMatcher

        # Round-trip conversion
        back_converted = self.reverse_convert(ajami_text)

        latin_clean = latin_text.lower().strip()
        back_clean = back_converted.lower().strip()

        # Align the strings instead of comparing position-by-position.
        # Conversion drops punctuation and digits, so a positional zip would
        # misalign everything after the first dropped character.
        longest = max(len(latin_clean), len(back_clean))
        if longest:
            matcher = SequenceMatcher(None, latin_clean, back_clean, autojunk=False)
            matched = sum(block.size for block in matcher.get_matching_blocks())
            char_similarity = matched / longest
        else:
            char_similarity = 0

        # Hausa-specific characters: count those that are present in the input
        # AND whose Ajami counterpart survived into the converted text.
        latin_lower = latin_text.lower()
        hausa_chars_preserved = sum(
            1 for char in ('ɓ', 'ƙ', 'ɗ')
            if char in latin_lower and self.forward_mappings[char] in ajami_text
        )

        # Length preservation (should be similar)
        length_ratio = len(ajami_text) / len(latin_text) if len(latin_text) > 0 else 0

        return {
            'char_similarity': char_similarity,
            'length_ratio': length_ratio,
            'hausa_chars_preserved': hausa_chars_preserved,
            'original': latin_text,
            'ajami': ajami_text,
            'back_converted': back_converted,
            'is_valid': char_similarity > 0.7  # Threshold for valid conversion
        }

    def generate_orthographic_variants(self, text: str, num_variants: int = 3) -> List[str]:
        """
        Generate orthographic variants of an Ajami string for stress testing.

        The result always starts with *text* itself, followed by up to
        *num_variants* variants that differ from it:
        - with the diacritics in ``_DIACRITICS`` removed;
        - with default spellings swapped for ``variant_mappings`` alternatives;
        - with diacritics removed from every other word (3+ words only).
        """
        variants = [text]

        # Variant 1: remove diacritics (common in informal writing)
        no_diacritics = self._DIACRITICS.sub('', text)
        if no_diacritics != text:
            variants.append(no_diacritics)

        # Variant 2: alternative spellings using variant mappings
        for lat_char, arab_variants in self.variant_mappings.items():
            default = self.forward_mappings.get(lat_char)
            if default is None:
                continue
            for arab_char in arab_variants[1:]:  # Skip first (it's the default)
                alt_text = text.replace(default, arab_char)
                if alt_text != text:
                    variants.append(alt_text)

        # Variant 3: strip diacritics from the 0th, 2nd, 4th ... word
        words = text.split()
        if len(words) > 2:
            mixed = ' '.join(
                self._DIACRITICS.sub('', word) if i % 2 == 0 else word
                for i, word in enumerate(words)
            )
            # Like variants 1 and 2, skip a "variant" identical to the input.
            # A copy of the original would inflate stress-test consistency.
            if mixed != text:
                variants.append(mixed)

        return variants[:num_variants + 1]

converter = HausaAjamiConverter()

# ==================== CONVERSION VALIDATION SUITE ====================
class ConversionValidator:
    """Comprehensive validation of Latin-to-Ajami conversion quality.

    Wraps a converter exposing ``validate_conversion`` and
    ``generate_orthographic_variants`` (as HausaAjamiConverter does) and adds
    dataset-level summaries plus a model robustness stress test.
    """

    def __init__(self, converter: 'HausaAjamiConverter'):
        # String annotation: the converter class need not be in scope when
        # this class is defined.
        self.converter = converter
        self.validation_results = []

    def validate_dataset(self, df: pd.DataFrame, sample_size: int = 100) -> Dict:
        """Validate conversion quality on a random sample of *df*.

        Args:
            df: frame with 'text' (Latin) and 'text_ajami' columns.
            sample_size: maximum number of rows sampled (without replacement).

        Returns:
            dict with avg/min/max similarity, per-row scores, failed
            conversions (truncated to 50 chars) and pass_rate. The summary
            statistics are NaN when there is nothing to sample.
        """
        print("\n" + "="*80)
        print("🔍 CONVERSION QUALITY VALIDATION")
        print("="*80)

        n_samples = min(sample_size, len(df))
        if n_samples <= 0:
            # np.min/np.max on an empty list raise and pass_rate would divide
            # by zero, so report "undefined" explicitly instead of crashing.
            print("\n⚠️  No rows available for conversion validation.")
            nan = float('nan')
            return {
                'avg_similarity': nan,
                'min_similarity': nan,
                'max_similarity': nan,
                'validation_scores': [],
                'failed_conversions': [],
                'pass_rate': nan,
            }

        # Sample texts for validation
        sample_indices = np.random.choice(len(df), n_samples, replace=False)

        validation_scores = []
        failed_conversions = []
        n_valid = 0

        for idx in sample_indices:
            row = df.iloc[idx]
            latin_text = row['text']
            ajami_text = row['text_ajami']

            validation = self.converter.validate_conversion(latin_text, ajami_text)
            validation_scores.append(validation['char_similarity'])

            # Use the converter's own verdict so the threshold lives in one place.
            if validation['is_valid']:
                n_valid += 1
            else:
                failed_conversions.append({
                    'index': idx,
                    'similarity': validation['char_similarity'],
                    'original': latin_text[:50],
                    'ajami': ajami_text[:50]
                })

        avg_similarity = np.mean(validation_scores)
        min_similarity = np.min(validation_scores)
        max_similarity = np.max(validation_scores)

        print(f"\n📊 Conversion Quality Metrics:")
        print(f"  Average Character Similarity: {avg_similarity:.4f}")
        print(f"  Min Similarity: {min_similarity:.4f}")
        print(f"  Max Similarity: {max_similarity:.4f}")
        print(f"  Valid Conversions: {n_valid}/{len(validation_scores)}")
        print(f"  Failed Conversions: {len(failed_conversions)}")

        if failed_conversions:
            print(f"\n⚠️  Sample Failed Conversions:")
            for fail in failed_conversions[:3]:
                print(f"    Similarity: {fail['similarity']:.4f}")
                print(f"    Original: {fail['original']}")
                print(f"    Ajami: {fail['ajami']}")
                print()

        return {
            'avg_similarity': avg_similarity,
            'min_similarity': min_similarity,
            'max_similarity': max_similarity,
            'validation_scores': validation_scores,
            'failed_conversions': failed_conversions,
            'pass_rate': n_valid / len(validation_scores)
        }

    def orthographic_stress_test(self, texts: List[str], labels: List[int],
                                 model, tokenizer, device, num_variants: int = 3) -> Dict:
        """
        Test model robustness to orthographic variations.

        Samples up to 50 texts, predicts on each text and on its variants
        from ``generate_orthographic_variants``, and reports accuracy per
        variant slot plus the fraction of texts whose predictions all agree.
        Slot k holds whichever k-th variant a text produced, so slots may mix
        variant types.
        """
        print("\n" + "="*80)
        print("🧪 ORTHOGRAPHIC VARIATION STRESS TEST")
        print("="*80)

        stress_results = {
            'original_accuracy': [],
            'variant_accuracies': [[] for _ in range(num_variants)],
            'consistency_scores': []
        }

        sample_size = min(50, len(texts))
        if sample_size <= 0:
            # Nothing to evaluate; avoid np.mean([]) on empty result lists.
            print("\n⚠️  No texts available for the stress test.")
            return {
                'original_accuracy': float('nan'),
                'variant_accuracies': [],
                'consistency': float('nan'),
                'detailed_results': stress_results
            }
        sample_indices = np.random.choice(len(texts), sample_size, replace=False)

        for idx in sample_indices:
            text = texts[idx]
            label = labels[idx]

            # Generate variants (index 0 is the original text)
            variants = self.converter.generate_orthographic_variants(text, num_variants)

            # Get predictions for all variants
            predictions = [
                self._predict_single(variant, model, tokenizer, device)
                for variant in variants
            ]

            # Original prediction
            stress_results['original_accuracy'].append(int(predictions[0] == label))

            # Variant predictions
            for v_idx in range(1, min(len(predictions), num_variants + 1)):
                is_correct = int(predictions[v_idx] == label)
                stress_results['variant_accuracies'][v_idx - 1].append(is_correct)

            # Consistency: all predictions agree
            stress_results['consistency_scores'].append(int(len(set(predictions)) == 1))

        # Calculate summary statistics
        original_acc = np.mean(stress_results['original_accuracy'])
        variant_accs = [np.mean(var_acc) for var_acc in stress_results['variant_accuracies'] if var_acc]
        consistency = np.mean(stress_results['consistency_scores'])

        print(f"\n📊 Stress Test Results:")
        print(f"  Original Accuracy: {original_acc:.4f}")
        for i, var_acc in enumerate(variant_accs, 1):
            print(f"  Variant {i} Accuracy: {var_acc:.4f} (Δ: {var_acc - original_acc:+.4f})")
        print(f"  Prediction Consistency: {consistency:.4f}")

        return {
            'original_accuracy': original_acc,
            'variant_accuracies': variant_accs,
            'consistency': consistency,
            'detailed_results': stress_results
        }

    def _predict_single(self, text: str, model, tokenizer, device) -> int:
        """Predict the class index for one text.

        Puts *model* in eval mode (and leaves it there). Tokenizes to a fixed
        256 positions, calls ``model(input_ids, attention_mask)`` and accepts
        either a raw logits tensor or an output object with ``.logits``.
        """
        model.eval()
        with torch.no_grad():
            encoding = tokenizer(text, max_length=256, padding='max_length',
                               truncation=True, return_tensors='pt')
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)

            output = model(input_ids, attention_mask)
            # Handle both raw logits and model output objects
            logits = output.logits if hasattr(output, 'logits') else output
            pred = torch.argmax(logits, dim=1).cpu().item()
        return pred

# ==================== DATA AUGMENTATION ====================
class HausaAugmenter:
    """Lightweight word-level augmentation for Latin-script Hausa text.

    All randomness is drawn from NumPy's global RNG (``np.random``), so the
    output is reproducible under ``np.random.seed``.
    """

    # Tiny hand-written synonym lexicon consulted by synonym_replacement.
    _SYNONYMS = {
        'munanan': ['marasa kyau', 'miyagun'],
        'kyawawan': ['masu kyau', 'nagari'],
        'mutane': ['yan adam', 'jama\'a'],
    }

    @staticmethod
    def synonym_replacement(text, n=2):
        """Make *n* attempts to replace a random word with a lexicon synonym.

        Texts of 3 words or fewer come back whitespace-normalised but otherwise unchanged.
        """
        tokens = text.split()
        if len(tokens) > 3:
            for _ in range(n):
                pos = np.random.randint(0, len(tokens))
                options = HausaAugmenter._SYNONYMS.get(tokens[pos])
                if options is not None:
                    tokens[pos] = np.random.choice(options)
        return ' '.join(tokens)

    @staticmethod
    def random_swap(text, n=2):
        """Swap *n* random adjacent word pairs (texts longer than 3 words only)."""
        tokens = text.split()
        if len(tokens) > 3:
            for _ in range(n):
                left = np.random.randint(0, len(tokens) - 1)
                tokens[left:left + 2] = tokens[left + 1], tokens[left]
        return ' '.join(tokens)

    @staticmethod
    def random_deletion(text, p=0.1):
        """Drop each word independently with probability *p*.

        Single-word texts are returned untouched; the result may be empty.
        """
        tokens = text.split()
        if len(tokens) == 1:
            return text
        survivors = [tok for tok in tokens if np.random.random() > p]
        return ' '.join(survivors)

    @classmethod
    def augment(cls, text, label, n_aug=2):
        """Return ``[text, *augmentations]`` from *n_aug* random method draws.

        Empty results and results identical to *text* are discarded.
        *label* is accepted for interface symmetry but unused.
        """
        strategies = [cls.synonym_replacement, cls.random_swap, cls.random_deletion]
        results = [text]
        for _ in range(n_aug):
            candidate = np.random.choice(strategies)(text)
            if candidate and candidate != text:
                results.append(candidate)
        return results

# ==================== DATA LOADING ====================
def load_and_augment_dataset(augment_minority=True, aug_ratio=3, csv_path=None):
    """Load, clean, optionally oversample, and Ajami-convert the Hausa dataset.

    Steps: read the CSV, strip URLs and collapse whitespace in 'text', drop
    duplicate texts and texts of 10 characters or fewer, optionally add
    augmented copies of the offensive class, convert to Ajami, and write
    conversion-quality metrics to DRIVE_PATH/conversion_validation/.

    Args:
        augment_minority: if True, add augmented rows for every sample with
            label_offensive == 1 (via HausaAugmenter.augment).
        aug_ratio: augmentation attempts per minority row. Attempts that
            yield an empty text or the unchanged text are skipped by the augmenter.
        csv_path: dataset location. Defaults to HausaHateDataset.csv inside
            DRIVE_PATH, i.e. the Colab Drive folder or the local fallback.

    Returns:
        (ajami_texts, labels, df). Returns (None, None, None) when the CSV
        cannot be read, lacks the 'text'/'label_offensive' columns, or has
        no usable rows after cleaning.
    """
    if csv_path is None:
        csv_path = os.path.join(DRIVE_PATH, 'HausaHateDataset.csv')

    try:
        df = pd.read_csv(csv_path)
    except (OSError, ValueError) as exc:
        # OSError: missing/unreadable file. pandas parse, empty-file and
        # decode errors are ValueError subclasses. A bare except would also
        # hide KeyboardInterrupt and unrelated programming errors.
        print(f"❌ Failed to load dataset from {csv_path}: {exc}")
        return None, None, None
    print(f"✅ Loaded: {df.shape}")

    missing = {'text', 'label_offensive'} - set(df.columns)
    if missing:
        print(f"❌ Dataset is missing required columns: {sorted(missing)}")
        return None, None, None

    print("\n📊 Original Distribution:")
    print(df['label_offensive'].value_counts())

    df['text'] = df['text'].fillna('').astype(str)
    df['text'] = df['text'].apply(lambda x: re.sub(r'http\S+|www\S+', '', x))
    df['text'] = df['text'].apply(lambda x: re.sub(r'\s+', ' ', x).strip())
    df = df.drop_duplicates(subset=['text'])
    df = df[df['text'].str.len() > 10]

    if df.empty:
        # Conversion validation below cannot summarise zero rows.
        print("❌ No usable rows left after cleaning")
        return None, None, None

    if augment_minority:
        # NOTE(review): augmented rows are near-copies of their source post
        # but carry no link back to it. If train/validation splits (e.g. with
        # StratifiedKFold) are made *after* this function, copies of one post
        # can land on both sides of a split and inflate scores. Consider
        # augmenting inside each training fold instead.
        augmenter = HausaAugmenter()
        minority_df = df[df['label_offensive'] == 1]
        augmented_rows = []
        for _, row in minority_df.iterrows():
            aug_texts = augmenter.augment(row['text'], row['label_offensive'], n_aug=aug_ratio)
            for aug_text in aug_texts[1:]:  # index 0 is the original text
                augmented_rows.append({'text': aug_text, 'label_offensive': row['label_offensive']})
        aug_df = pd.DataFrame(augmented_rows)
        df = pd.concat([df, aug_df], ignore_index=True)
        print(f"\n✅ Augmented dataset: {len(df)} samples")
        print(df['label_offensive'].value_counts())

    print("\n🔄 Converting to Ajami...")
    df['text_ajami'] = df['text'].apply(converter.convert)

    # Validate conversion quality
    validator = ConversionValidator(converter)
    validation_results = validator.validate_dataset(df, sample_size=100)

    # Save validation results
    validation_df = pd.DataFrame({
        'metric': ['avg_similarity', 'min_similarity', 'max_similarity', 'pass_rate'],
        'value': [
            validation_results['avg_similarity'],
            validation_results['min_similarity'],
            validation_results['max_similarity'],
            validation_results['pass_rate']
        ]
    })
    validation_df.to_csv(
        os.path.join(DRIVE_PATH, 'conversion_validation', 'validation_metrics.csv'),
        index=False,
    )

    return df['text_ajami'].values, df['label_offensive'].values, df

# ==================== DATASET CLASS ====================
class AdvancedHausaDataset(Dataset):
    """Map-style torch dataset that tokenizes texts on access.

    Args:
        texts: indexable sequence of texts (each is passed through ``str``).
        labels: indexable sequence of integer class labels aligned with *texts*.
        tokenizer: callable with the HF tokenizer call signature that returns
            tensors with a leading batch dimension of 1.
        max_length: padded/truncated sequence length.
        augment: if True, each access has a 50% chance of substituting an
            augmented variant from HausaAugmenter.
    """
    def __init__(self, texts, labels, tokenizer, max_length=256, augment=False):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augment = augment
        self.augmenter = HausaAugmenter() if augment else None

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        if self.augment and np.random.random() > 0.5:
            # augment() returns [original, *augmented]. Take the LAST entry so
            # an augmented variant is actually used (index 0 is always the
            # original, which made augmentation a no-op). Falls back to the
            # original when no augmentation was produced.
            # NOTE(review): np.random is process-global; with DataLoader
            # workers, per-worker seeding may be needed for diverse draws.
            text = self.augmenter.augment(text, self.labels[idx], n_aug=1)[-1]

        encoding = self.tokenizer(
            text, max_length=self.max_length, padding='max_length',
            truncation=True, return_tensors='pt'
        )

        return {
            # Drop the tokenizer's batch dimension; the DataLoader re-batches.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# ==================== CHARACTER/BYTE-LEVEL BASELINE ====================
class CharLevelCNN(nn.Module):
    """Character-level CNN baseline for comparison.

    Embeds token ids (id 0 is padding), runs four parallel 1-D convolutions
    with kernel widths 3, 4, 5 and 7 (256 filters each), global-max-pools
    every branch, concatenates the results, and classifies them with a
    two-layer MLP.
    ``attention_mask`` is accepted for interface compatibility but not used,
    so padded positions still take part in the max-pool.
    """
    def __init__(self, vocab_size=256, embedding_dim=128, num_labels=2):
        super().__init__()
        kernel_widths = (3, 4, 5, 7)
        n_filters = 256

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.convs = nn.ModuleList(
            nn.Conv1d(embedding_dim, n_filters, kernel_size=width, padding=width // 2)
            for width in kernel_widths
        )
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(n_filters * len(kernel_widths), 128)
        self.fc2 = nn.Linear(128, num_labels)

    def forward(self, input_ids, attention_mask=None):
        # Conv1d wants channels first: (batch, seq, emb) -> (batch, emb, seq).
        features = self.embedding(input_ids).transpose(1, 2)

        branch_outputs = []
        for conv in self.convs:
            activation = F.relu(conv(features))
            # Max over the whole (possibly seq+1 long, for even kernels) axis.
            branch_outputs.append(F.max_pool1d(activation, activation.size(2)).squeeze(2))

        x = self.dropout(torch.cat(branch_outputs, dim=1))
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


class ByteLevelEncoder(nn.Module):
    """Byte-level model treating text as bytes.

    Embeds byte ids (0 doubles as padding), runs a 2-layer BiLSTM, pools its
    outputs with learned attention restricted to unmasked positions, and
    classifies the pooled vector with a small MLP. Returns logits with
    ``num_labels`` columns. Assumes input_ids is (batch, seq), per the
    batch_first LSTM.
    The LSTM is run over padded positions as well (no packing); only the
    attention pooling honours ``attention_mask``.
    """
    def __init__(self, num_labels=2):
        super().__init__()
        self.embedding = nn.Embedding(256, 64, padding_idx=0)

        self.lstm = nn.LSTM(64, 128, num_layers=2, batch_first=True,
                           dropout=0.3, bidirectional=True)

        # Scores each BiLSTM output (2 * 128 = 256 features) for pooling.
        self.attention = nn.Sequential(
            nn.Linear(256, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_labels)
        )

    def forward(self, input_ids, attention_mask=None):
        embedded = self.embedding(input_ids)
        lstm_out, _ = self.lstm(embedded)

        # Attention pooling: one score per position.
        attn_weights = self.attention(lstm_out)
        if attention_mask is not None:
            # Use the most negative finite value instead of -inf. A row whose
            # mask is all zeros (e.g. an empty text from CharByteTokenizer)
            # would otherwise softmax to NaN and poison the loss. For rows with
            # at least one real position, exp(min - max) underflows to 0,
            # exactly like exp(-inf).
            attn_weights = attn_weights.masked_fill(
                attention_mask.unsqueeze(-1) == 0, torch.finfo(attn_weights.dtype).min
            )
        attn_weights = torch.softmax(attn_weights, dim=1)
        pooled = torch.sum(lstm_out * attn_weights, dim=1)

        logits = self.classifier(pooled)
        return logits

# Character/Byte tokenizer
class CharByteTokenizer:
    """Minimal tokenizer producing char- or byte-level ids with an HF-like call.

    ``level='char'``: each character becomes ``ord(c) % 256``. Distinct code
    points therefore collide: U+0752 (the Ajami ɓ glyph) and U+0652 (sukun)
    both map to 82, and any code point that is a multiple of 256 maps to the
    padding id 0.
    Any other level: the UTF-8 bytes of the text.

    The output is truncated to ``max_length`` regardless of ``truncation``.
    With ``padding='max_length'`` it is right-padded with id 0 and mask 0.
    """
    def __init__(self, max_length=512, level='char'):
        self.max_length = max_length
        self.level = level

    def __call__(self, text, max_length=None, padding='max_length',
                truncation=True, return_tensors='pt'):
        limit = self.max_length if max_length is None else max_length

        if self.level == 'char':
            ids = [ord(ch) % 256 for ch in text[:limit]]
        else:
            ids = list(text.encode('utf-8')[:limit])

        mask = [1] * len(ids)
        if padding == 'max_length':
            fill = limit - len(ids)
            ids += [0] * fill
            mask += [0] * fill

        if return_tensors != 'pt':
            return {'input_ids': ids, 'attention_mask': mask}
        # Add a leading batch dimension of 1, like HF tokenizers do.
        return {
            'input_ids': torch.tensor([ids]),
            'attention_mask': torch.tensor([mask])
        }

# ==================== PEFT/LoRA MODELS (FIXED) ====================
class LoRAWrapper(nn.Module):
    """Adapter that makes a PEFT (or any HF-style) model return a bare logits tensor.

    The wrapped model is called with keyword arguments and must return an
    object exposing ``.logits``. That tensor is returned unchanged, so this
    wrapper has the same ``forward(input_ids, attention_mask) -> logits``
    shape as the custom classifiers in this file.
    """
    def __init__(self, peft_model):
        super().__init__()
        # Stored as an attribute; registered as a submodule when it is an nn.Module.
        self.model = peft_model

    def forward(self, input_ids, attention_mask):
        model_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = model_output.logits
        return logits


class LoRAMBERT:
    """mBERT sequence classifier with LoRA adapters (parameter-efficient fine-tuning).

    Loads ``bert-base-multilingual-cased`` with a classification head, injects
    rank-*lora_r* adapters into the ``query``/``value`` projections, prints the
    trainable-parameter count, and wraps the result in LoRAWrapper so that
    ``get_model()`` yields an nn.Module returning raw logits.
    """
    def __init__(self, num_labels=2, lora_r=8, lora_alpha=16):
        backbone = AutoModelForSequenceClassification.from_pretrained(
            'bert-base-multilingual-cased', num_labels=num_labels,
        )
        adapter_config = LoraConfig(
            task_type=TaskType.SEQ_CLS, r=lora_r, lora_alpha=lora_alpha,
            lora_dropout=0.1, target_modules=["query", "value"], bias="none",
        )
        adapted = get_peft_model(backbone, adapter_config)
        adapted.print_trainable_parameters()
        self.model = LoRAWrapper(adapted)

    def get_model(self):
        """Return the logits-returning LoRAWrapper around the PEFT model."""
        return self.model


class LoRAXLMR:
    """XLM-R sequence classifier with LoRA adapters.

    Loads ``xlm-roberta-base`` with a classification head, injects
    rank-*lora_r* adapters into the ``query``/``value`` projections, prints the
    trainable-parameter count, and wraps the result in LoRAWrapper so that
    ``get_model()`` yields an nn.Module returning raw logits.
    """
    def __init__(self, num_labels=2, lora_r=8, lora_alpha=16):
        backbone = AutoModelForSequenceClassification.from_pretrained(
            'xlm-roberta-base', num_labels=num_labels,
        )
        adapter_config = LoraConfig(
            task_type=TaskType.SEQ_CLS, r=lora_r, lora_alpha=lora_alpha,
            lora_dropout=0.1, target_modules=["query", "value"], bias="none",
        )
        adapted = get_peft_model(backbone, adapter_config)
        adapted.print_trainable_parameters()
        self.model = LoRAWrapper(adapted)

    def get_model(self):
        """Return the logits-returning LoRAWrapper around the PEFT model."""
        return self.model

# ==================== AFRICAN LANGUAGE BASELINE ====================
class AfroXLMR(nn.Module):
    """
    African language-focused baseline using AfroXLMR if available,
    otherwise standard XLM-R.

    Encodes with the backbone, takes the first-position hidden state, and
    classifies it with an MLP. ``backbone_name`` records which checkpoint was
    actually loaded, so results are not silently attributed to AfroXLMR.
    NOTE(review): the caller's tokenizer must match whichever backbone was
    loaded — confirm.
    """
    def __init__(self, num_labels=2, dropout=0.3):
        super().__init__()
        try:
            # Try to load African language-specific model
            self.xlmr = AutoModel.from_pretrained('Davlan/afro-xlmr-base')
            self.backbone_name = 'Davlan/afro-xlmr-base'
            print("✅ Loaded AfroXLMR model")
        except Exception as exc:
            # Fallback to standard XLM-R. Catch Exception rather than using a
            # bare except so KeyboardInterrupt/SystemExit still propagate, and
            # report why the preferred checkpoint could not be loaded.
            self.xlmr = AutoModel.from_pretrained('xlm-roberta-base')
            self.backbone_name = 'xlm-roberta-base'
            print("⚠️  Using standard XLM-R (AfroXLMR not available)")
            print(f"    Reason: {type(exc).__name__}: {exc}")

        hidden_size = self.xlmr.config.hidden_size

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.xlmr(input_ids=input_ids, attention_mask=attention_mask)
        # First-token pooling (the <s>/CLS position).
        pooled = outputs.last_hidden_state[:, 0]
        logits = self.classifier(pooled)
        return logits

# ==================== ORIGINAL MODEL ARCHITECTURES ====================
class EnhancedMBERT(nn.Module):
    """mBERT encoder with masked attention pooling and multi-sample dropout.

    Token states are scored by a small MLP. Padded positions (attention_mask
    == 0) are excluded by masking their scores to -inf before a softmax over
    the sequence axis. The weighted sum is classified once per dropout module
    (5 of them), and the resulting logits are averaged.
    """
    def __init__(self, num_labels=2, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-multilingual-cased')
        dim = self.bert.config.hidden_size

        # Multi-sample dropout: five independent masks, logits averaged in forward().
        self.dropouts = nn.ModuleList(nn.Dropout(dropout) for _ in range(5))
        # One attention score per token.
        self.attention = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, 1))
        self.classifier = nn.Sequential(
            nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(dim, dim // 2), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(dim // 2, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        token_states = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        is_padding = attention_mask.unsqueeze(-1) == 0
        scores = self.attention(token_states).masked_fill(is_padding, float('-inf'))
        weights = torch.softmax(scores, dim=1)
        pooled = torch.sum(token_states * weights, dim=1)

        per_sample_logits = [self.classifier(drop(pooled)) for drop in self.dropouts]
        return torch.mean(torch.stack(per_sample_logits), dim=0)


class EnhancedXLMR(nn.Module):
    """XLM-R encoder with CLS + masked-mean + masked-conv-max pooling.

    The three pooled vectors are concatenated (3 * hidden_size), classified
    once per dropout module (5 of them), and the logits are averaged.
    """
    def __init__(self, num_labels=2, dropout=0.3):
        super().__init__()
        self.xlmr = AutoModel.from_pretrained('xlm-roberta-base')
        hidden_size = self.xlmr.config.hidden_size
        self.conv1 = nn.Conv1d(hidden_size, hidden_size, 3, padding=1)
        # NOTE: conv2 is not used in forward(). It is kept so parameter
        # initialisation order and state_dict keys (saved checkpoints) stay
        # compatible.
        self.conv2 = nn.Conv1d(hidden_size, hidden_size, 5, padding=2)
        self.dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(5)])
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 3, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.xlmr(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        # 1.0 for real tokens, 0.0 for padding; broadcasts over the hidden dim.
        token_mask = attention_mask.unsqueeze(-1).to(hidden.dtype)

        cls_pooled = hidden[:, 0]

        # Masked mean over real tokens only. A plain mean over dim 1 also
        # averages every padding position, which dominates the representation
        # when inputs are padded to a fixed max_length.
        mean_pooled = (hidden * token_mask).sum(dim=1) / token_mask.sum(dim=1).clamp(min=1.0)

        conv_out = F.relu(self.conv1(hidden.transpose(1, 2)))
        # Zero the padded positions before max-pooling. ReLU outputs are >= 0,
        # so this equals the max over real positions only.
        conv_out = conv_out * token_mask.transpose(1, 2)
        conv_pooled = F.adaptive_max_pool1d(conv_out, 1).squeeze(-1)

        pooled = torch.cat([cls_pooled, mean_pooled, conv_pooled], dim=-1)
        logits = torch.mean(torch.stack([
            self.classifier(dropout(pooled)) for dropout in self.dropouts
        ]), dim=0)
        return logits


class EnhancedAraBERT(nn.Module):
    """AraBERT encoder + BiLSTM summary, concatenated with a normalised CLS vector.

    The BiLSTM's final forward and backward states (at each sequence's true
    length) are concatenated with the LayerNorm-ed first-token state, then
    classified once per dropout module (5 of them), and the logits averaged.
    """
    def __init__(self, num_labels=2, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained('aubmindlab/bert-base-arabertv2')
        hidden_size = self.bert.config.hidden_size
        self.lstm = nn.LSTM(
            hidden_size, hidden_size // 2, num_layers=2,
            batch_first=True, dropout=dropout, bidirectional=True
        )
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(5)])
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        # Pack so the LSTM stops at each sequence's real length. Unpacked, the
        # forward direction's final state h_n[-2] is taken after running
        # through every padding position (inputs are padded to max_length),
        # so it mostly encodes padding. Lengths come from the mask; this
        # assumes right-padding (HF default) — TODO confirm tokenizer padding_side.
        lengths = attention_mask.sum(dim=1).clamp(min=1).to('cpu')
        packed = nn.utils.rnn.pack_padded_sequence(
            hidden, lengths, batch_first=True, enforce_sorted=False
        )
        # With enforce_sorted=False, h_n is returned in the original batch order.
        _, (h_n, _) = self.lstm(packed)
        lstm_pooled = torch.cat([h_n[-2], h_n[-1]], dim=-1)

        cls_pooled = self.layer_norm(hidden[:, 0])
        pooled = torch.cat([cls_pooled, lstm_pooled], dim=-1)
        logits = torch.mean(torch.stack([
            self.classifier(dropout(pooled)) for dropout in self.dropouts
        ]), dim=0)
        return logits


class ProductionHACS_TL(nn.Module):
    """mBERT encoder with extra attention/Transformer layers and multi-view pooling.

    Pipeline: mBERT -> residual self-attention ("cross-script") -> two Transformer encoder
    layers -> four pooled views (first token, learned attention pooling, masked mean,
    multi-kernel CNN max-pool) -> fusion MLP -> multi-sample-dropout classifier.

    Args:
        num_labels: Number of output classes.
        dropout: Dropout rate for the fusion MLP, classifier and the 8 dropout branches.
    """

    def __init__(self, num_labels=2, dropout=0.2):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-multilingual-cased')
        hidden_size = self.bert.config.hidden_size

        self.cross_script_attn = nn.MultiheadAttention(
            hidden_size, num_heads=12, dropout=0.1, batch_first=True
        )
        self.attn_norm = nn.LayerNorm(hidden_size)

        self.orthographic_encoder = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_size, nhead=8,
                dim_feedforward=hidden_size*4,
                dropout=0.1, batch_first=True
            ) for _ in range(2)
        ])

        # Scores each token; softmax over the sequence gives attention-pooling weights.
        self.dialectal_attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.Tanh(),
            nn.Linear(hidden_size // 2, 1)
        )

        # Odd kernels with padding=k//2 keep the sequence length unchanged, which the
        # padding mask applied in forward() relies on.
        self.script_conv = nn.ModuleList([
            nn.Conv1d(hidden_size, hidden_size, kernel_size=k, padding=k//2)
            for k in [3, 5, 7]
        ])

        # Input: [cls; dialectal; mean; cnn] pooled views, each hidden_size wide.
        self.fusion = nn.Sequential(
            nn.Linear(hidden_size * 4, hidden_size * 2),
            nn.LayerNorm(hidden_size * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU()
        )

        self.dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(8)])

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 4, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        """Return logits of shape (batch, num_labels).

        ``attention_mask`` holds 1 for real tokens and 0 for padding. Every pooled view
        ignores padding positions, so the logits do not depend on padded length.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        padding_mask = ~attention_mask.bool()        # (batch, seq), True at padding
        token_pad = padding_mask.unsqueeze(-1)       # (batch, seq, 1)

        attn_out, _ = self.cross_script_attn(hidden, hidden, hidden,
                                             key_padding_mask=padding_mask)
        hidden = self.attn_norm(hidden + attn_out)

        for encoder in self.orthographic_encoder:
            hidden = encoder(hidden, src_key_padding_mask=padding_mask)

        # Learned attention pooling restricted to real tokens.
        attn_weights = self.dialectal_attention(hidden)
        attn_weights = torch.softmax(
            attn_weights.masked_fill(token_pad, float('-inf')), dim=1
        )
        dialectal_pooled = torch.sum(hidden * attn_weights, dim=1)

        # Zero padding states so they neither enter the mean nor leak into conv windows
        # next to the last real token (matching the conv's own zero edge padding).
        masked_hidden = hidden.masked_fill(token_pad, 0.0)

        hidden_t = masked_hidden.transpose(1, 2)      # (batch, hidden, seq)
        conv_pad = padding_mask.unsqueeze(1)          # (batch, 1, seq)
        cnn_features = []
        for conv in self.script_conv:
            conv_out = F.gelu(conv(hidden_t))
            # Exclude padding positions from the global max-pool.
            conv_out = conv_out.masked_fill(conv_pad, float('-inf'))
            cnn_features.append(conv_out.max(dim=-1).values)
        cnn_pooled = torch.stack(cnn_features).mean(dim=0)

        cls_pooled = hidden[:, 0]
        valid_counts = (~token_pad).sum(dim=1).clamp(min=1).to(hidden.dtype)
        mean_pooled = masked_hidden.sum(dim=1) / valid_counts

        multi_pooled = torch.cat([cls_pooled, dialectal_pooled, mean_pooled, cnn_pooled], dim=-1)
        fused = self.fusion(multi_pooled)

        # Multi-sample dropout: average logits over several independent dropout masks.
        logits = torch.mean(torch.stack([
            self.classifier(dropout(fused)) for dropout in self.dropouts
        ]), dim=0)

        return logits

# ==================== TRAINING (FIXED) ====================
class FocalLoss(nn.Module):
    """Focal loss on top of softmax cross-entropy, averaged over the batch.

    ``loss_i = alpha_i * (1 - p_t,i) ** gamma * CE_i`` where ``p_t,i = exp(-CE_i)`` is the
    predicted probability of example i's true class.

    Args:
        alpha: Either a scalar weight applied to every example (the default, 0.25, keeps
            the original behaviour), or a per-class sequence/tensor of weights indexed by
            the target class (the usual class-balancing alpha_t).
        gamma: Focusing exponent; 0 reduces the loss to alpha-weighted cross-entropy.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        alpha_tensor = torch.as_tensor(alpha, dtype=torch.float32)
        self._per_class = alpha_tensor.dim() > 0
        # Keep a scalar alpha exactly as given so default results are unchanged.
        self.alpha = alpha_tensor if self._per_class else alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        if self._per_class:
            # Moved lazily: the criterion is created without being sent to a device.
            alpha_t = self.alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)[targets]
        else:
            alpha_t = self.alpha
        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()


def train_epoch_advanced(model, dataloader, optimizer, scheduler, device, use_focal=True,
                         accumulation_steps=2):
    """Train ``model`` for one epoch with gradient accumulation.

    Gradients are accumulated over ``accumulation_steps`` batches, clipped to norm 1.0,
    then the optimizer and scheduler each take one step. A trailing partial window (when
    the number of batches is not a multiple of ``accumulation_steps``) is also stepped,
    with its loss scaled by its actual size, so no gradient carries over to the next epoch.

    Args:
        model: Module returning raw logits or an object with a ``.logits`` attribute.
        dataloader: Yields dicts with 'input_ids', 'attention_mask' and 'label'.
        optimizer, scheduler: Stepped once per accumulation window.
        device: Device the batch tensors are moved to.
        use_focal: FocalLoss if True, else label-smoothed cross-entropy.
        accumulation_steps: Batches per optimizer step (must be >= 1).

    Returns:
        dict with the mean per-batch loss and macro precision/recall/F1 on training batches.

    Raises:
        ValueError: if ``accumulation_steps`` < 1.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    model.train()
    total_loss = 0
    all_preds, all_labels = [], []
    criterion = FocalLoss() if use_focal else nn.CrossEntropyLoss(label_smoothing=0.1)

    n_batches = len(dataloader)
    # Batches from this index on form the final (possibly shorter) accumulation window.
    last_window_start = (n_batches // accumulation_steps) * accumulation_steps

    # Start from clean gradients regardless of what the caller left behind.
    optimizer.zero_grad()

    for idx, batch in enumerate(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        output = model(input_ids, attention_mask)
        # Handle both raw logits and model output objects (e.g. PEFT/HF models).
        logits = output.logits if hasattr(output, 'logits') else output

        window_size = (accumulation_steps if idx < last_window_start
                       else n_batches - last_window_start)
        loss = criterion(logits, labels) / window_size
        loss.backward()

        if (idx + 1) % accumulation_steps == 0 or idx + 1 == n_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item() * window_size
        all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    p, r, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro', zero_division=0)
    return {'loss': total_loss / max(n_batches, 1), 'precision': p, 'recall': r, 'f1': f1}


def evaluate_advanced(model, dataloader, device):
    """Score ``model`` on every batch of ``dataloader`` with gradients disabled.

    The model may return raw logits or an object carrying them in ``.logits``.

    Returns:
        dict with macro 'precision', 'recall', 'f1', 'accuracy', the per-example
        'predictions' and 'true_labels' (lists), and 'probabilities': a NumPy array
        stacking one softmax row per example.
    """
    model.eval()
    pred_list, gold_list, prob_rows = [], [], []

    with torch.no_grad():
        for batch in dataloader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            gold = batch['label'].to(device)

            raw = model(ids, mask)
            scores = raw.logits if hasattr(raw, 'logits') else raw
            class_probs = torch.softmax(scores, dim=1)

            pred_list.extend(torch.argmax(scores, dim=1).cpu().numpy())
            gold_list.extend(gold.cpu().numpy())
            prob_rows.extend(class_probs.cpu().numpy())

    precision, recall, f1, _ = precision_recall_fscore_support(
        gold_list, pred_list, average='macro', zero_division=0
    )
    accuracy = accuracy_score(gold_list, pred_list)

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'accuracy': accuracy,
        'predictions': pred_list,
        'true_labels': gold_list,
        'probabilities': np.array(prob_rows),
    }

# ==================== ERROR ANALYSIS ====================
class ErrorAnalyzer:
    """Comprehensive error analysis with linguistic phenomena identification.

    Buckets a binary classifier's predictions into TP/TN/FP/FN (label 1 = Offensive,
    0 = Non-offensive), counts simple heuristic phenomena among the errors, summarizes
    prediction confidence, and collects representative examples.

    Args:
        texts: Input strings, aligned element-wise with ``labels``/``predictions``/``probs``.
        labels: Gold labels (0/1).
        predictions: Predicted labels (0/1).
        probs: Per-example class-probability rows; ``probs[i][c]`` is P(class c).
        model_name: Used in printed headers and output file names.
        df_original: Stored for reference; not used by any method here.
    """

    def __init__(self, texts, labels, predictions, probs, model_name, df_original=None):
        self.texts = texts
        self.labels = labels
        self.predictions = predictions
        self.probs = probs
        self.model_name = model_name
        self.df_original = df_original

    def analyze_errors(self) -> Dict:
        """Perform comprehensive error analysis and return all buckets and summaries."""
        print(f"\n{'='*80}")
        print(f"🔍 ERROR ANALYSIS: {self.model_name}")
        print(f"{'='*80}")

        # Identify error types
        fps = []  # False Positives
        fns = []  # False Negatives
        tps = []  # True Positives
        tns = []  # True Negatives

        for i, (true, pred, prob, text) in enumerate(zip(
            self.labels, self.predictions, self.probs, self.texts
        )):
            example = {
                'index': i,
                'text': text,
                'true_label': true,
                'pred_label': pred,
                'confidence': prob[pred],   # probability of the predicted class
                'true_prob': prob[true],    # probability of the gold class
                'text_length': len(text.split())
            }

            if true == 1 and pred == 0:
                fns.append(example)
            elif true == 0 and pred == 1:
                fps.append(example)
            elif true == 1 and pred == 1:
                tps.append(example)
            else:
                tns.append(example)

        print(f"\n📊 Error Distribution:")
        print(f"  True Positives: {len(tps)}")
        print(f"  True Negatives: {len(tns)}")
        print(f"  False Positives: {len(fps)} (Non-offensive predicted as Offensive)")
        print(f"  False Negatives: {len(fns)} (Offensive predicted as Non-offensive)")

        phenomena_analysis = self._analyze_linguistic_phenomena(fps, fns)
        confidence_analysis = self._analyze_confidence(fps, fns, tps, tns)
        qualitative_examples = self._generate_qualitative_examples(fps, fns, tps, tns=tns)

        return {
            'fps': fps,
            'fns': fns,
            'tps': tps,
            'tns': tns,
            'phenomena_analysis': phenomena_analysis,
            'confidence_analysis': confidence_analysis,
            'qualitative_examples': qualitative_examples
        }

    def _analyze_linguistic_phenomena(self, fps, fns) -> Dict:
        """Count heuristic phenomena among false positives and false negatives.

        Heuristics: any ASCII letter -> code mixing; < 5 words -> short text; > 30 words ->
        long text; FN with predicted-class confidence < 0.6 -> possible implicit hate.
        'dialect_variation' and 'sarcasm' are placeholders with no detector yet, so they
        always stay 0.
        """
        phenomena = {
            'code_mixing': {'fp': 0, 'fn': 0},
            'dialect_variation': {'fp': 0, 'fn': 0},
            'implicit_hate': {'fp': 0, 'fn': 0},
            'sarcasm': {'fp': 0, 'fn': 0},
            'short_text': {'fp': 0, 'fn': 0},
            'long_text': {'fp': 0, 'fn': 0}
        }

        for fp in fps:
            text = fp['text']
            text_len = fp['text_length']

            # Code-mixing detection (presence of Latin characters in Ajami)
            if any(c.isalpha() and ord(c) < 128 for c in text):
                phenomena['code_mixing']['fp'] += 1

            if text_len < 5:
                phenomena['short_text']['fp'] += 1
            elif text_len > 30:
                phenomena['long_text']['fp'] += 1

        for fn in fns:
            text = fn['text']
            text_len = fn['text_length']

            if any(c.isalpha() and ord(c) < 128 for c in text):
                phenomena['code_mixing']['fn'] += 1

            if text_len < 5:
                phenomena['short_text']['fn'] += 1
            elif text_len > 30:
                phenomena['long_text']['fn'] += 1

            # Implicit hate (low confidence predictions that are FN)
            if fn['confidence'] < 0.6:
                phenomena['implicit_hate']['fn'] += 1

        print(f"\n📊 Linguistic Phenomena Analysis:")
        for phenom, counts in phenomena.items():
            if counts['fp'] > 0 or counts['fn'] > 0:
                print(f"  {phenom.replace('_', ' ').title()}:")
                print(f"    False Positives: {counts['fp']}")
                print(f"    False Negatives: {counts['fn']}")

        return phenomena

    def _analyze_confidence(self, fps, fns, tps, tns) -> Dict:
        """Mean (and, for errors, std) of predicted-class confidence per bucket; 0 if empty."""
        fp_confs = [ex['confidence'] for ex in fps]
        fn_confs = [ex['confidence'] for ex in fns]
        tp_confs = [ex['confidence'] for ex in tps]
        tn_confs = [ex['confidence'] for ex in tns]

        analysis = {
            'fp_conf_mean': np.mean(fp_confs) if fp_confs else 0,
            'fn_conf_mean': np.mean(fn_confs) if fn_confs else 0,
            'tp_conf_mean': np.mean(tp_confs) if tp_confs else 0,
            'tn_conf_mean': np.mean(tn_confs) if tn_confs else 0,
            'fp_conf_std': np.std(fp_confs) if fp_confs else 0,
            'fn_conf_std': np.std(fn_confs) if fn_confs else 0,
        }

        print(f"\n📊 Confidence Analysis:")
        print(f"  False Positives: {analysis['fp_conf_mean']:.4f} ± {analysis['fp_conf_std']:.4f}")
        print(f"  False Negatives: {analysis['fn_conf_mean']:.4f} ± {analysis['fn_conf_std']:.4f}")
        print(f"  True Positives: {analysis['tp_conf_mean']:.4f}")
        print(f"  True Negatives: {analysis['tn_conf_mean']:.4f}")

        return analysis

    def _generate_qualitative_examples(self, fps, fns, tps, n_examples=5, tns=None) -> Dict:
        """Pick representative examples (texts truncated to 100 characters).

        FP/FN/TP lists take the ``n_examples`` most confident items. Borderline cases are
        *correct* predictions (TPs, plus TNs when ``tns`` is given) whose confidence lies
        within 0.1 of 0.5, closest to the decision boundary first.
        """
        examples = {
            'false_positives': [],
            'false_negatives': [],
            'true_positives_high_conf': [],
            'borderline_cases': []
        }

        fps_sorted = sorted(fps, key=lambda x: x['confidence'], reverse=True)
        fns_sorted = sorted(fns, key=lambda x: x['confidence'], reverse=True)
        tps_sorted = sorted(tps, key=lambda x: x['confidence'], reverse=True)

        for fp in fps_sorted[:n_examples]:
            examples['false_positives'].append({
                'text': fp['text'][:100],
                'confidence': fp['confidence'],
                'true_label': 'Non-Offensive',
                'pred_label': 'Offensive'
            })

        for fn in fns_sorted[:n_examples]:
            examples['false_negatives'].append({
                'text': fn['text'][:100],
                'confidence': fn['confidence'],
                'true_label': 'Offensive',
                'pred_label': 'Non-Offensive'
            })

        for tp in tps_sorted[:n_examples]:
            examples['true_positives_high_conf'].append({
                'text': tp['text'][:100],
                'confidence': tp['confidence'],
                'true_label': 'Offensive',
                'pred_label': 'Offensive'
            })

        # Borderline cases: low-confidence *correct* predictions (previously this
        # mistakenly drew from the FP/FN error lists).
        correct = list(tps) + list(tns or [])
        borderline = sorted(
            (ex for ex in correct if abs(ex['confidence'] - 0.5) < 0.1),
            key=lambda ex: abs(ex['confidence'] - 0.5)
        )
        for case in borderline[:n_examples]:
            examples['borderline_cases'].append({
                'text': case['text'][:100],
                'confidence': case['confidence']
            })

        print(f"\n📋 Representative Examples Generated:")
        print(f"  False Positives: {len(examples['false_positives'])}")
        print(f"  False Negatives: {len(examples['false_negatives'])}")
        print(f"  High-Confidence TPs: {len(examples['true_positives_high_conf'])}")

        return examples

    def save_analysis(self, save_dir):
        """Run the analysis and write CSV/JSON summaries into ``save_dir`` (created if missing)."""
        analysis = self.analyze_errors()
        os.makedirs(save_dir, exist_ok=True)

        # Save qualitative examples
        qual_rows = []
        for error_type in ['false_positives', 'false_negatives', 'true_positives_high_conf']:
            for ex in analysis['qualitative_examples'][error_type]:
                qual_rows.append({
                    'model': self.model_name,
                    'error_type': error_type,
                    'text': ex['text'],
                    'confidence': float(ex.get('confidence', 0)),  # native float for CSV
                    'true_label': ex.get('true_label', ''),
                    'pred_label': ex.get('pred_label', '')
                })

        qual_df = pd.DataFrame(qual_rows)
        qual_df.to_csv(f'{save_dir}/qualitative_examples_{self.model_name}.csv',
                       index=False, encoding='utf-8')

        # Save phenomena analysis
        phenom_df = pd.DataFrame([
            {
                'model': self.model_name,
                'phenomenon': phenom,
                'false_positives': int(counts['fp']),
                'false_negatives': int(counts['fn'])
            }
            for phenom, counts in analysis['phenomena_analysis'].items()
        ])
        phenom_df.to_csv(f'{save_dir}/phenomena_analysis_{self.model_name}.csv',
                         index=False)

        # NumPy scalars are not JSON-serializable; convert them to native floats.
        confidence_analysis_json = {
            key: float(value) if isinstance(value, (np.floating, np.integer)) else value
            for key, value in analysis['confidence_analysis'].items()
        }
        with open(f'{save_dir}/confidence_analysis_{self.model_name}.json', 'w') as f:
            json.dump(confidence_analysis_json, f, indent=2)

        print(f"\n✅ Error analysis saved to: {save_dir}")

        return analysis

# ==================== VISUALIZATION ====================
def create_visualizations(all_results, save_path):
    """Enhanced visualizations with error analysis.

    Writes performance_comparison.png, confusion_matrices.png and roc_curves.png into
    ``{save_path}visualizations/`` (created if missing). ``save_path`` is concatenated
    directly, so it should end with a path separator.

    Args:
        all_results: Mapping model name -> dict with 'precision', 'recall', 'f1',
            'true_labels', 'predictions' and a 2-column 'probabilities' array whose
            column 1 is the Offensive-class probability.
        save_path: Output directory prefix.
    """
    if not all_results:
        print("⚠️  No results to visualize")
        return

    vis_dir = f'{save_path}visualizations/'
    os.makedirs(vis_dir, exist_ok=True)
    models = list(all_results.keys())
    positions = np.arange(len(models))

    # Performance Comparison
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    metrics = ['precision', 'recall', 'f1']
    names = ['Precision', 'Recall', 'F1-Score']

    for idx, (metric, name) in enumerate(zip(metrics, names)):
        ax = axes[idx]
        values = [all_results[m][metric] for m in models]

        colors = plt.cm.Set3(np.linspace(0, 1, len(models)))
        bars = ax.bar(positions, values, alpha=0.8, color=colors)
        ax.set_ylabel(name, fontsize=12, fontweight='bold')
        ax.set_title(f'{name} Comparison', fontsize=13, fontweight='bold')
        # Fix tick positions before labelling them (avoids Matplotlib's FixedFormatter warning).
        ax.set_xticks(positions)
        ax.set_xticklabels(models, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.3)
        # Keep the usual 0.5 floor, but lower it when a model scores below 0.5 so its
        # bar is still visible.
        ax.set_ylim(min(0.5, max(0.0, min(values) - 0.05)), 1.0)

        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.3f}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig(f'{vis_dir}performance_comparison.png',
                dpi=300, bbox_inches='tight')
    plt.close()
    print("✅ Saved: performance_comparison.png")

    # Confusion Matrices
    n_models = len(all_results)
    n_cols = 3
    n_rows = (n_models + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5*n_rows))
    # With n_cols == 3, subplots always returns an ndarray (1-D or 2-D), even for a
    # single model, so flatten it unconditionally.
    axes = np.asarray(axes).ravel()

    for idx, (model_name, results) in enumerate(all_results.items()):
        ax = axes[idx]
        # labels=[0, 1] keeps the matrix 2x2 even if a model predicts only one class,
        # so it always matches the two tick labels.
        cm = confusion_matrix(results['true_labels'], results['predictions'], labels=[0, 1])
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                    xticklabels=['Non-Off', 'Off'], yticklabels=['Non-Off', 'Off'])
        ax.set_title(f'{model_name} (F1: {results["f1"]:.4f})',
                     fontsize=11, fontweight='bold')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')

    # Hide extra subplots
    for idx in range(n_models, len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()
    plt.savefig(f'{vis_dir}confusion_matrices.png',
                dpi=300, bbox_inches='tight')
    plt.close()
    print("✅ Saved: confusion_matrices.png")

    # ROC Curves (computed on predictions pooled across folds)
    fig, ax = plt.subplots(figsize=(10, 8))
    colors = plt.cm.tab10(np.linspace(0, 1, len(all_results)))

    for idx, (model_name, results) in enumerate(all_results.items()):
        probs = results['probabilities'][:, 1]
        labels = results['true_labels']
        fpr, tpr, _ = roc_curve(labels, probs)
        auc = roc_auc_score(labels, probs)
        ax.plot(fpr, tpr, label=f'{model_name} (AUC={auc:.3f})',
                linewidth=2.5, color=colors[idx])

    ax.plot([0,1], [0,1], 'k--', linewidth=2, alpha=0.5)
    ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='bold')
    ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='bold')
    ax.set_title('ROC Curves - Cross-Validation Average', fontsize=14, fontweight='bold')
    ax.legend(loc='lower right', fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'{vis_dir}roc_curves.png', dpi=300, bbox_inches='tight')
    plt.close()
    print("✅ Saved: roc_curves.png")

# ==================== ZERO-SHOT/FEW-SHOT EVALUATION ====================
def evaluate_llm_baseline(texts, labels, sample_size=50, zero_shot_fn=None,
                          few_shot_fn=None, seed=42):
    """
    Zero-shot / few-shot LLM baseline evaluation framework.

    With no classifier callables (the default), this is only a simulation. It prints the
    prompt template and returns fixed placeholder scores tagged with a 'note'. These
    numbers are NOT measured and must not be reported as results.

    To measure real scores, pass ``zero_shot_fn`` and/or ``few_shot_fn``. Each is a
    callable that maps one raw text to a predicted label (0 = Non-Offensive,
    1 = Offensive), e.g. a wrapper that builds a prompt and calls an LLM API. Every
    supplied callable is scored with macro precision/recall/F1 on the same random sample
    of ``min(sample_size, len(texts))`` examples. The sample is drawn with a local RNG
    seeded by ``seed``, so the global NumPy RNG is left untouched.

    Returns:
        dict with keys 'zero_shot' and 'few_shot', each holding 'precision', 'recall',
        'f1' and 'note'.
    """
    simulated = zero_shot_fn is None and few_shot_fn is None

    print("\n" + "="*80)
    print("🤖 LLM BASELINE EVALUATION (Simulated)" if simulated
          else "🤖 LLM BASELINE EVALUATION")
    print("="*80)

    if simulated:
        print(f"\n⚠️  Note: This is a simulation framework.")
        print(f"For actual evaluation, integrate with OpenAI/Anthropic APIs")

    # Framework for API calls
    prompt_template = """
Task: Classify the following Hausa text (in Ajami script) as either "Offensive" or "Non-Offensive".

Text: {text}

Classification:"""

    print(f"\n📋 Example Prompt Template:")
    print(prompt_template.format(text="[Sample Hausa Ajami text]"))

    # Placeholder results; replaced below for every mode that has a real classifier.
    results = {
        'zero_shot': {
            'precision': 0.72,
            'recall': 0.68,
            'f1': 0.70,
            'note': 'Simulated - would require API integration'
        },
        'few_shot': {
            'precision': 0.78,
            'recall': 0.75,
            'f1': 0.76,
            'note': 'Simulated with 5 examples - would require API integration'
        }
    }

    runners = {'zero_shot': zero_shot_fn, 'few_shot': few_shot_fn}
    n_sample = min(sample_size, len(texts))
    if not simulated and n_sample > 0:
        rng = np.random.default_rng(seed)
        sample_idx = rng.choice(len(texts), n_sample, replace=False)
        sample_texts = [texts[i] for i in sample_idx]
        sample_labels = [int(labels[i]) for i in sample_idx]
        for mode, classify in runners.items():
            if classify is None:
                continue
            preds = [int(classify(text)) for text in sample_texts]
            p, r, f1, _ = precision_recall_fscore_support(
                sample_labels, preds, average='macro', zero_division=0
            )
            results[mode] = {
                'precision': p, 'recall': r, 'f1': f1,
                'note': f'Measured on {n_sample} sampled examples'
            }

    if simulated:
        print(f"\n📊 Simulated LLM Baseline Results:")
        print(f"  Zero-Shot F1: {results['zero_shot']['f1']:.4f}")
        print(f"  Few-Shot F1 (5 examples): {results['few_shot']['f1']:.4f}")
    else:
        print(f"\n📊 LLM Baseline Results:")
        print(f"  Zero-Shot F1: {results['zero_shot']['f1']:.4f} ({results['zero_shot']['note']})")
        print(f"  Few-Shot F1: {results['few_shot']['f1']:.4f} ({results['few_shot']['note']})")

    return results

# ==================== MAIN PIPELINE ====================
def main():
    """Run the full cross-validated benchmark and write all artifacts under DRIVE_PATH.

    The run has these stages:
    1. Load and augment the dataset.
    2. Run 2-fold stratified CV over every entry of ``models_config`` with early stopping.
    3. Run per-fold error analysis, plus an orthographic stress test on fold 1 only.
    4. Average the fold metrics.
    5. Run the LLM baseline stub.
    6. Write plots and CSV/JSON outputs.
    7. Print classification reports and model comparisons.

    Relies on names not defined in this function (DRIVE_PATH, load_and_augment_dataset,
    the model/tokenizer/dataset classes, ConversionValidator, converter).
    """
    print("\n" + "="*80)
    print("ENHANCED HAUSA AJAMI HATE SPEECH DETECTION")
    print("Complete Implementation with Validation & Error Analysis")
    print("="*80 + "\n")

    # Load and validate dataset
    texts, labels, df_original = load_and_augment_dataset(augment_minority=True, aug_ratio=3)
    if texts is None:
        return

    # All artifacts are written into these sub-directories; create them up front so a
    # fresh DRIVE_PATH does not fail on the first torch.save / savefig / to_csv.
    for sub_dir in ('models', 'visualizations', 'error_analysis'):
        os.makedirs(f'{DRIVE_PATH}{sub_dir}', exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n🖥️  Device: {device}")

    # Must match the gradient accumulation used by train_epoch_advanced (2). The LR
    # scheduler is stepped once per *optimizer* step, not once per batch.
    grad_accumulation_steps = 2

    # Model configurations including new baselines
    models_config = {
        # Original models
        'mBERT': {
            'class': EnhancedMBERT,
            'tokenizer': 'bert-base-multilingual-cased',
            'tokenizer_type': 'transformer',
            'lr': 2e-5, 'epochs': 10
        },
        'XLM-R': {
            'class': EnhancedXLMR,
            'tokenizer': 'xlm-roberta-base',
            'tokenizer_type': 'transformer',
            'lr': 1e-5, 'epochs': 10
        },
        'AraBERT': {
            'class': EnhancedAraBERT,
            'tokenizer': 'aubmindlab/bert-base-arabertv2',
            'tokenizer_type': 'transformer',
            'lr': 2e-5, 'epochs': 10
        },

        # New baselines
        'CharCNN': {
            'class': CharLevelCNN,
            'tokenizer': CharByteTokenizer(max_length=512, level='char'),
            'tokenizer_type': 'char',
            'lr': 1e-3, 'epochs': 15
        },
        'ByteLSTM': {
            'class': ByteLevelEncoder,
            'tokenizer': CharByteTokenizer(max_length=512, level='byte'),
            'tokenizer_type': 'byte',
            'lr': 1e-3, 'epochs': 15
        },
        'AfroXLMR': {
            'class': AfroXLMR,
            'tokenizer': 'xlm-roberta-base',  # Will try AfroXLMR tokenizer
            'tokenizer_type': 'transformer',
            'lr': 1e-5, 'epochs': 10
        },

        # PEFT models (FIXED)
        'LoRA-mBERT': {
            'class': LoRAMBERT,
            'tokenizer': 'bert-base-multilingual-cased',
            'tokenizer_type': 'peft',
            'lr': 3e-4, 'epochs': 12
        },
        'LoRA-XLMR': {
            'class': LoRAXLMR,
            'tokenizer': 'xlm-roberta-base',
            'tokenizer_type': 'peft',
            'lr': 3e-4, 'epochs': 12
        },

        # Proposed model
        'HACS-TL': {
            'class': ProductionHACS_TL,
            'tokenizer': 'bert-base-multilingual-cased',
            'tokenizer_type': 'transformer',
            'lr': 1e-5, 'epochs': 15
        }
    }

    print(f"\n🔄 Starting 2-Fold Cross-Validation with {len(models_config)} models...\n")

    skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
    fold_results = defaultdict(lambda: defaultdict(list))
    all_error_analyses = defaultdict(list)
    # Test indices per fold, in fold order. Predictions are concatenated in this same
    # order later, so texts/labels must be re-ordered with it for aggregate analysis.
    fold_test_indices = []

    for fold, (train_idx, test_idx) in enumerate(skf.split(texts, labels), 1):
        print(f"\n{'='*80}")
        print(f"FOLD {fold}/2")
        print(f"{'='*80}")

        fold_test_indices.append(test_idx)
        train_texts = texts[train_idx]
        train_labels = labels[train_idx]
        test_texts = texts[test_idx]
        test_labels = labels[test_idx]

        for model_name, config in models_config.items():
            print(f"\n🚀 Training: {model_name}")

            # Char/byte baselines carry a ready tokenizer instance; the rest load by name.
            if config['tokenizer_type'] in ['char', 'byte']:
                tokenizer = config['tokenizer']
            else:
                tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])

            # Create datasets
            max_length = 512 if config['tokenizer_type'] in ['char', 'byte'] else 256
            train_ds = AdvancedHausaDataset(train_texts, train_labels, tokenizer,
                                            max_length=max_length, augment=True)
            test_ds = AdvancedHausaDataset(test_texts, test_labels, tokenizer,
                                           max_length=max_length, augment=False)

            # Inverse-frequency sampling to balance classes within training batches.
            class_counts = Counter(train_labels)
            weights = [1.0 / class_counts[label] for label in train_labels]
            sampler = WeightedRandomSampler(weights, len(weights))

            batch_size = 16 if config['tokenizer_type'] == 'transformer' else 32
            train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler)
            test_loader = DataLoader(test_ds, batch_size=batch_size * 2)

            # Initialize model
            if config['tokenizer_type'] == 'peft':
                model_wrapper = config['class']()
                model = model_wrapper.get_model().to(device)
            else:
                model = config['class']().to(device)

            # Optimizer and scheduler, sized in optimizer steps (ceil of batches/accum).
            optimizer = AdamW(model.parameters(), lr=config['lr'], weight_decay=0.01)
            steps_per_epoch = -(-len(train_loader) // grad_accumulation_steps)
            total_steps = steps_per_epoch * config['epochs']
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=steps_per_epoch * 2,  # 2 epochs of warmup
                num_training_steps=total_steps
            )

            checkpoint_path = f'{DRIVE_PATH}models/best_{model_name}_fold{fold}.pt'

            # Training loop. best_f1 starts below any attainable F1 so that epoch 1
            # always writes a checkpoint (otherwise the load below could fail).
            # NOTE(review): early stopping / checkpoint selection uses the same test fold
            # that is reported, so fold scores are optimistically biased; a validation
            # split carved from train_idx would remove this leakage.
            best_f1 = -1.0
            patience = 3
            patience_counter = 0

            for epoch in range(1, config['epochs'] + 1):
                train_metrics = train_epoch_advanced(model, train_loader, optimizer,
                                                     scheduler, device,
                                                     use_focal=(config['tokenizer_type'] == 'transformer'))
                test_metrics = evaluate_advanced(model, test_loader, device)

                print(f"Epoch {epoch:02d} | Train F1: {train_metrics['f1']:.4f} | "
                      f"Test F1: {test_metrics['f1']:.4f}")

                if test_metrics['f1'] > best_f1:
                    best_f1 = test_metrics['f1']
                    patience_counter = 0
                    torch.save(model.state_dict(), checkpoint_path)
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"Early stopping at epoch {epoch}")
                        break

            # Load best model and evaluate
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            final_metrics = evaluate_advanced(model, test_loader, device)

            # Store results
            for key in ['precision', 'recall', 'f1', 'accuracy']:
                fold_results[model_name][key].append(final_metrics[key])

            fold_results[model_name]['predictions'].append(final_metrics['predictions'])
            fold_results[model_name]['true_labels'].append(final_metrics['true_labels'])
            fold_results[model_name]['probabilities'].append(final_metrics['probabilities'])

            # Error analysis for this fold
            analyzer = ErrorAnalyzer(
                test_texts, test_labels,
                final_metrics['predictions'],
                final_metrics['probabilities'],
                model_name, df_original
            )
            error_analysis = analyzer.analyze_errors()
            all_error_analyses[model_name].append(error_analysis)

            # Orthographic stress test (only on fold 1 to save time).
            # NOTE(review): `converter` is not defined in this function; presumably a
            # module-level object created earlier in the notebook — confirm.
            if fold == 1:
                validator = ConversionValidator(converter)
                stress_results = validator.orthographic_stress_test(
                    test_texts[:50], test_labels[:50], model, tokenizer, device
                )
                fold_results[model_name]['stress_test'] = stress_results

            print(f"✅ Fold {fold} Test F1: {final_metrics['f1']:.4f}\n")

    # Calculate average results
    print(f"\n{'='*80}")
    print("📊 FINAL RESULTS (2-Fold Cross-Validation)")
    print(f"{'='*80}\n")

    avg_results = {}
    for model_name in models_config.keys():
        avg_results[model_name] = {
            'precision': np.mean(fold_results[model_name]['precision']),
            'recall': np.mean(fold_results[model_name]['recall']),
            'f1': np.mean(fold_results[model_name]['f1']),
            'accuracy': np.mean(fold_results[model_name]['accuracy']),
            'f1_std': np.std(fold_results[model_name]['f1']),
            # Concatenated in fold order, i.e. aligned with fold_test_indices.
            'predictions': np.concatenate(fold_results[model_name]['predictions']),
            'true_labels': np.concatenate(fold_results[model_name]['true_labels']),
            'probabilities': np.concatenate(fold_results[model_name]['probabilities'])
        }

        print(f"{model_name:15} | F1: {avg_results[model_name]['f1']:.4f} ± "
              f"{avg_results[model_name]['f1_std']:.4f} | "
              f"P: {avg_results[model_name]['precision']:.4f} | "
              f"R: {avg_results[model_name]['recall']:.4f} | "
              f"Acc: {avg_results[model_name]['accuracy']:.4f}")

    # LLM baseline evaluation
    llm_results = evaluate_llm_baseline(texts, labels, sample_size=50)

    # Generate visualizations
    print(f"\n{'='*80}")
    print("📊 Generating Visualizations...")
    print(f"{'='*80}\n")
    create_visualizations(avg_results, DRIVE_PATH)

    # Save detailed results
    results_df = pd.DataFrame({
        'Model': list(avg_results.keys()),
        'F1-Score': [avg_results[m]['f1'] for m in avg_results],
        'F1-Std': [avg_results[m]['f1_std'] for m in avg_results],
        'Precision': [avg_results[m]['precision'] for m in avg_results],
        'Recall': [avg_results[m]['recall'] for m in avg_results],
        'Accuracy': [avg_results[m]['accuracy'] for m in avg_results]
    })
    results_df = results_df.sort_values('F1-Score', ascending=False)
    results_df.to_csv(f'{DRIVE_PATH}final_results.csv', index=False)
    print(f"✅ Results saved to: {DRIVE_PATH}final_results.csv")

    # Save comprehensive error analysis
    print(f"\n{'='*80}")
    print("📊 Saving Error Analysis...")
    print(f"{'='*80}\n")

    # Reorder texts/labels to the fold-concatenated order of the stored predictions;
    # using the original order would pair each prediction with the wrong text.
    eval_order = np.concatenate(fold_test_indices)
    for model_name in models_config.keys():
        analyzer = ErrorAnalyzer(
            texts[eval_order], labels[eval_order],
            avg_results[model_name]['predictions'],
            avg_results[model_name]['probabilities'],
            model_name, df_original
        )
        analyzer.save_analysis(f'{DRIVE_PATH}error_analysis')

    # Classification reports
    print(f"\n{'='*80}")
    print("📋 DETAILED CLASSIFICATION REPORTS")
    print(f"{'='*80}\n")

    for model_name in models_config.keys():
        print(f"\n{model_name}:")
        print("="*60)
        print(classification_report(
            avg_results[model_name]['true_labels'],
            avg_results[model_name]['predictions'],
            target_names=['Non-Offensive', 'Offensive'],
            digits=4
        ))

    # Model comparison insights
    print(f"\n{'='*80}")
    print("🔍 MODEL COMPARISON INSIGHTS")
    print(f"{'='*80}\n")

    # Compare mBERT vs AraBERT
    if 'mBERT' in avg_results and 'AraBERT' in avg_results:
        mbert_f1 = avg_results['mBERT']['f1']
        arabert_f1 = avg_results['AraBERT']['f1']
        print(f"mBERT vs AraBERT Performance Gap:")
        print(f"  mBERT F1: {mbert_f1:.4f}")
        print(f"  AraBERT F1: {arabert_f1:.4f}")
        print(f"  Gap: {mbert_f1 - arabert_f1:+.4f}")
        print(f"  Analysis: See error_analysis/ for linguistic phenomena causing differences")

    # Baseline comparisons
    print(f"\n📊 Baseline Model Performance:")
    baseline_models = ['CharCNN', 'ByteLSTM', 'mBERT', 'XLM-R', 'AfroXLMR']
    for model in baseline_models:
        if model in avg_results:
            print(f"  {model:12} F1: {avg_results[model]['f1']:.4f}")

    # PEFT comparison
    print(f"\n📊 PEFT vs Full Fine-tuning:")
    peft_pairs = [('mBERT', 'LoRA-mBERT'), ('XLM-R', 'LoRA-XLMR')]
    for full_model, peft_model in peft_pairs:
        if full_model in avg_results and peft_model in avg_results:
            full_f1 = avg_results[full_model]['f1']
            peft_f1 = avg_results[peft_model]['f1']
            print(f"  {full_model}: {full_f1:.4f} | {peft_model}: {peft_f1:.4f} "
                  f"(Δ: {peft_f1 - full_f1:+.4f})")

    print(f"\n{'='*80}")
    print("✅ TRAINING COMPLETE!")
    print(f"All results saved to: {DRIVE_PATH}")
    print(f"  - Models: {DRIVE_PATH}models/")
    print(f"  - Visualizations: {DRIVE_PATH}visualizations/")
    print(f"  - Error Analysis: {DRIVE_PATH}error_analysis/")
    print(f"  - Conversion Validation: {DRIVE_PATH}conversion_validation/")
    print(f"{'='*80}\n")

# Entry point: run the pipeline only when this code is executed directly
# (script or notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Mounted at /content/drive
✅ Google Drive mounted: /content/drive/MyDrive/AbjadNLP2026/

ENHANCED HAUSA AJAMI HATE SPEECH DETECTION
Complete Implementation with Validation & Error Analysis

✅ Loaded: (2000, 8)

📊 Original Distribution:
label_offensive
0    1322
1     678
Name: count, dtype: int64

✅ Augmented dataset: 3059 samples
label_offensive
1    1741
0    1318
Name: count, dtype: int64

🔄 Converting to Ajami...

🔍 CONVERSION QUALITY VALIDATION

📊 Conversion Quality Metrics:
  Average Character Similarity: 0.6357
  Min Similarity: 0.0511
  Max Similarity: 1.0000
  Valid Conversions: 51/100
  Failed Conversions: 49

⚠️  Sample Failed Conversions:
    Similarity: 0.2000
    Original: Wannan fada yafi qarfin fulani kawai akwai masu ta
    Ajami: وَننَن فَدَ يَفِ َرفِن فُلَنِ كَوَِ َكوَِ مَسُ تَِ

    Similarity: 0.6111
    Original: Amb Yahya Umar Fadawa ba duka aka taru aka zama da
    Ajami: َمب يَهيَ ُمَر فَدَوَ بَ دُكَ َكَ تَرُ َكَ زَمَ دَ

    Similarity: 0.6860
    Original: wan

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

Epoch 01 | Train F1: 0.5257 | Test F1: 0.4054
Epoch 02 | Train F1: 0.5598 | Test F1: 0.6153
Epoch 03 | Train F1: 0.5940 | Test F1: 0.6262
Epoch 04 | Train F1: 0.6013 | Test F1: 0.5876
Epoch 05 | Train F1: 0.6762 | Test F1: 0.6413
Epoch 06 | Train F1: 0.7052 | Test F1: 0.3979
Epoch 07 | Train F1: 0.7534 | Test F1: 0.6984
Epoch 08 | Train F1: 0.8228 | Test F1: 0.6166
Epoch 09 | Train F1: 0.8365 | Test F1: 0.7313
Epoch 10 | Train F1: 0.8953 | Test F1: 0.7259

🔍 ERROR ANALYSIS: mBERT

📊 Error Distribution:
  True Positives: 615
  True Negatives: 506
  False Positives: 153 (Non-offensive predicted as Offensive)
  False Negatives: 256 (Offensive predicted as Non-offensive)

📊 Linguistic Phenomena Analysis:
  Implicit Hate:
    False Positives: 0
    False Negatives: 54
  Short Text:
    False Positives: 5
    False Negatives: 24
  Long Text:
    False Positives: 28
    False Negatives: 44

📊 Confidence Analysis:
  False Positives: 0.6811 ± 0.0991
  False Negatives: 0.6965 ± 0.1135
  True Pos

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

Epoch 01 | Train F1: 0.5144 | Test F1: 0.3247
Epoch 02 | Train F1: 0.5501 | Test F1: 0.4400
Epoch 03 | Train F1: 0.5616 | Test F1: 0.5607
Epoch 04 | Train F1: 0.5613 | Test F1: 0.4636
Epoch 05 | Train F1: 0.5719 | Test F1: 0.6195
Epoch 06 | Train F1: 0.6165 | Test F1: 0.5924
Epoch 07 | Train F1: 0.6173 | Test F1: 0.6093
Epoch 08 | Train F1: 0.6227 | Test F1: 0.6201
Epoch 09 | Train F1: 0.6528 | Test F1: 0.6268
Epoch 10 | Train F1: 0.6685 | Test F1: 0.6784

🔍 ERROR ANALYSIS: XLM-R

📊 Error Distribution:
  True Positives: 708
  True Negatives: 356
  False Positives: 303 (Non-offensive predicted as Offensive)
  False Negatives: 163 (Offensive predicted as Non-offensive)

📊 Linguistic Phenomena Analysis:
  Implicit Hate:
    False Positives: 0
    False Negatives: 109
  Short Text:
    False Positives: 8
    False Negatives: 17
  Long Text:
    False Positives: 46
    False Negatives: 10

📊 Confidence Analysis:
  False Positives: 0.5511 ± 0.0303
  False Negatives: 0.5892 ± 0.0971
  True Po

tokenizer_config.json:   0%|          | 0.00/611 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/384 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/543M [00:00<?, ?B/s]

Epoch 01 | Train F1: 0.5229 | Test F1: 0.5694
Epoch 02 | Train F1: 0.5510 | Test F1: 0.3011
Epoch 03 | Train F1: 0.5342 | Test F1: 0.3496
Epoch 04 | Train F1: 0.5394 | Test F1: 0.4268
Early stopping at epoch 4

🔍 ERROR ANALYSIS: AraBERT

📊 Error Distribution:
  True Positives: 566
  True Negatives: 322
  False Positives: 337 (Non-offensive predicted as Offensive)
  False Negatives: 305 (Offensive predicted as Non-offensive)

📊 Linguistic Phenomena Analysis:
  Implicit Hate:
    False Positives: 0
    False Negatives: 303
  Short Text:
    False Positives: 0
    False Negatives: 43
  Long Text:
    False Positives: 61
    False Negatives: 1

📊 Confidence Analysis:
  False Positives: 0.5177 ± 0.0099
  False Negatives: 0.5443 ± 0.0306
  True Positives: 0.5185
  True Negatives: 0.5498

📋 Representative Examples Generated:
  False Positives: 5
  False Negatives: 5
  High-Confidence TPs: 5

🧪 ORTHOGRAPHIC VARIATION STRESS TEST

📊 Stress Test Results:
  Original Accuracy: 0.4600
  Variant 1 A

config.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

Some weights of XLMRobertaModel were not initialized from the model checkpoint at Davlan/afro-xlmr-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Loaded AfroXLMR model
Epoch 01 | Train F1: 0.5050 | Test F1: 0.5620
Epoch 02 | Train F1: 0.5099 | Test F1: 0.5389
Epoch 03 | Train F1: 0.5715 | Test F1: 0.5996
Epoch 04 | Train F1: 0.5740 | Test F1: 0.6191
Epoch 05 | Train F1: 0.6054 | Test F1: 0.6195
Epoch 06 | Train F1: 0.6335 | Test F1: 0.6418
Epoch 07 | Train F1: 0.6755 | Test F1: 0.6313
Epoch 08 | Train F1: 0.6949 | Test F1: 0.6805
Epoch 09 | Train F1: 0.7195 | Test F1: 0.6834
Epoch 10 | Train F1: 0.7988 | Test F1: 0.7029

🔍 ERROR ANALYSIS: AfroXLMR

📊 Error Distribution:
  True Positives: 640
  True Negatives: 443
  False Positives: 216 (Non-offensive predicted as Offensive)
  False Negatives: 231 (Offensive predicted as Non-offensive)

📊 Linguistic Phenomena Analysis:
  Implicit Hate:
    False Positives: 0
    False Negatives: 84
  Short Text:
    False Positives: 10
    False Negatives: 23
  Long Text:
    False Positives: 26
    False Negatives: 41

📊 Confidence Analysis:
  False Positives: 0.6404 ± 0.0773
  False Negatives

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 296,450 || all params: 178,151,428 || trainable%: 0.1664
Epoch 01 | Train F1: 0.4772 | Test F1: 0.4895
Epoch 02 | Train F1: 0.5267 | Test F1: 0.4866
Epoch 03 | Train F1: 0.5225 | Test F1: 0.5436
Epoch 04 | Train F1: 0.5974 | Test F1: 0.5855
Epoch 05 | Train F1: 0.5925 | Test F1: 0.6043
Epoch 06 | Train F1: 0.6257 | Test F1: 0.6203
Epoch 07 | Train F1: 0.6448 | Test F1: 0.4876
Epoch 08 | Train F1: 0.6349 | Test F1: 0.6234
Epoch 09 | Train F1: 0.6733 | Test F1: 0.6621
Epoch 10 | Train F1: 0.6934 | Test F1: 0.6627
Epoch 11 | Train F1: 0.7056 | Test F1: 0.6543
Epoch 12 | Train F1: 0.7302 | Test F1: 0.6743

🔍 ERROR ANALYSIS: LoRA-mBERT

📊 Error Distribution:
  True Positives: 676
  True Negatives: 375
  False Positives: 284 (Non-offensive predicted as Offensive)
  False Negatives: 195 (Offensive predicted as Non-offensive)

📊 Linguistic Phenomena Analysis:
  Implicit Hate:
    False Positives: 0
    False Negatives: 51
  Short Text:
    False Positives: 4
    False Negativ

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 887,042 || all params: 278,932,228 || trainable%: 0.3180
Epoch 01 | Train F1: 0.4781 | Test F1: 0.3011
Epoch 02 | Train F1: 0.5280 | Test F1: 0.5877
Epoch 03 | Train F1: 0.5732 | Test F1: 0.3394
Epoch 04 | Train F1: 0.5731 | Test F1: 0.5997
Epoch 05 | Train F1: 0.5516 | Test F1: 0.5173
Epoch 06 | Train F1: 0.5866 | Test F1: 0.5993
Epoch 07 | Train F1: 0.5968 | Test F1: 0.6016
Epoch 08 | Train F1: 0.5958 | Test F1: 0.5026
Epoch 09 | Train F1: 0.5783 | Test F1: 0.6214
Epoch 10 | Train F1: 0.6113 | Test F1: 0.6228
Epoch 11 | Train F1: 0.6416 | Test F1: 0.6321
Epoch 12 | Train F1: 0.6252 | Test F1: 0.6266

🔍 ERROR ANALYSIS: LoRA-XLMR

📊 Error Distribution:
  True Positives: 601
  True Negatives: 378
  False Positives: 281 (Non-offensive predicted as Offensive)
  False Negatives: 270 (Offensive predicted as Non-offensive)

📊 Linguistic Phenomena Analysis:
  Implicit Hate:
    False Positives: 0
    False Negatives: 96
  Short Text:
    False Positives: 3
    False Negative

Some weights of XLMRobertaModel were not initialized from the model checkpoint at Davlan/afro-xlmr-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Loaded AfroXLMR model
Epoch 01 | Train F1: 0.4972 | Test F1: 0.3585
Epoch 02 | Train F1: 0.4948 | Test F1: 0.3727
Epoch 03 | Train F1: 0.5582 | Test F1: 0.6124
Epoch 04 | Train F1: 0.5804 | Test F1: 0.6123
Epoch 05 | Train F1: 0.5925 | Test F1: 0.5926
Epoch 06 | Train F1: 0.6043 | Test F1: 0.6381
Epoch 07 | Train F1: 0.6693 | Test F1: 0.6648
Epoch 08 | Train F1: 0.6673 | Test F1: 0.6588
Epoch 09 | Train F1: 0.6946 | Test F1: 0.6649
Epoch 10 | Train F1: 0.7307 | Test F1: 0.6904

🔍 ERROR ANALYSIS: AfroXLMR

📊 Error Distribution:
  True Positives: 603
  True Negatives: 457
  False Positives: 202 (Non-offensive predicted as Offensive)
  False Negatives: 267 (Offensive predicted as Non-offensive)

📊 Linguistic Phenomena Analysis:
  Implicit Hate:
    False Positives: 0
    False Negatives: 141
  Short Text:
    False Positives: 8
    False Negatives: 16
  Long Text:
    False Positives: 47
    False Negatives: 21

📊 Confidence Analysis:
  False Positives: 0.5932 ± 0.0586
  False Negatives

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 296,450 || all params: 178,151,428 || trainable%: 0.1664
Epoch 01 | Train F1: 0.4183 | Test F1: 0.4013
Epoch 02 | Train F1: 0.4807 | Test F1: 0.3012
Epoch 03 | Train F1: 0.4818 | Test F1: 0.3568
Epoch 04 | Train F1: 0.5605 | Test F1: 0.5576
Epoch 05 | Train F1: 0.5644 | Test F1: 0.5394
Epoch 06 | Train F1: 0.6072 | Test F1: 0.5906
Epoch 07 | Train F1: 0.6119 | Test F1: 0.6464
Epoch 08 | Train F1: 0.6225 | Test F1: 0.6354
Epoch 09 | Train F1: 0.6469 | Test F1: 0.6040
Epoch 10 | Train F1: 0.6484 | Test F1: 0.5850
Early stopping at epoch 10

🔍 ERROR ANALYSIS: LoRA-mBERT

📊 Error Distribution:
  True Positives: 568
  True Negatives: 425
  False Positives: 234 (Non-offensive predicted as Offensive)
  False Negatives: 302 (Offensive predicted as Non-offensive)

📊 Linguistic Phenomena Analysis:
  Implicit Hate:
    False Positives: 0
    False Negatives: 189
  Short Text:
    False Positives: 5
    False Negatives: 32
  Long Text:
    False Positives: 56
    False Negatives:

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 887,042 || all params: 278,932,228 || trainable%: 0.3180
Epoch 01 | Train F1: 0.5065 | Test F1: 0.5454
Epoch 02 | Train F1: 0.4860 | Test F1: 0.3627
Epoch 03 | Train F1: 0.5209 | Test F1: 0.5739
Epoch 04 | Train F1: 0.5283 | Test F1: 0.5797
Epoch 05 | Train F1: 0.5738 | Test F1: 0.4514
Epoch 06 | Train F1: 0.5463 | Test F1: 0.4547
Epoch 07 | Train F1: 0.5702 | Test F1: 0.6127
Epoch 08 | Train F1: 0.5688 | Test F1: 0.6308
Epoch 09 | Train F1: 0.5814 | Test F1: 0.3240
Epoch 10 | Train F1: 0.5580 | Test F1: 0.6409
Epoch 11 | Train F1: 0.5920 | Test F1: 0.6359
Epoch 12 | Train F1: 0.6072 | Test F1: 0.6421

🔍 ERROR ANALYSIS: LoRA-XLMR

📊 Error Distribution:
  True Positives: 712
  True Negatives: 308
  False Positives: 351 (Non-offensive predicted as Offensive)
  False Negatives: 158 (Offensive predicted as Non-offensive)

📊 Linguistic Phenomena Analysis:
  Implicit Hate:
    False Positives: 0
    False Negatives: 66
  Short Text:
    False Positives: 10
    False Negativ


**BELOW IS THE FIRST IMPLEMENTATION OF HACS-TL BEFORE ACCEPTANCE OF OUR PAPER**

In [None]:
"""
PUBLICATION-READY: Hausa Ajami Hate Speech Detection
Targets: Baseline >88% F1, Proposed Model >98% F1
Fixed: Script conversion, overfitting, architecture, data augmentation
"""

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from transformers import (
    AutoTokenizer, AutoModel,
    get_cosine_schedule_with_warmup
)
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
    accuracy_score,
    roc_auc_score,
    roc_curve
)
import matplotlib.pyplot as plt
import seaborn as sns
import re
import warnings
from collections import Counter, defaultdict
warnings.filterwarnings('ignore')

# Global matplotlib/seaborn styling applied to every figure this script creates.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("Set2")

# ==================== GOOGLE DRIVE SETUP ====================
# Prefer Google Drive when running inside Colab. Outside Colab the import fails
# (or the mount/mkdir fails), and outputs go to a local directory instead.
# DRIVE_PATH always ends with '/', because callers build paths by concatenation.
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    DRIVE_PATH = '/content/drive/MyDrive/AbjadNLP2026/'
    os.makedirs(DRIVE_PATH, exist_ok=True)
    print(f"✅ Google Drive mounted: {DRIVE_PATH}")
except Exception as exc:
    # Catch Exception rather than using a bare except, so that
    # KeyboardInterrupt/SystemExit (e.g. an aborted Drive auth) still stop the
    # run. Log why we fell back instead of failing silently.
    DRIVE_PATH = './AbjadNLP2026/'
    os.makedirs(DRIVE_PATH, exist_ok=True)
    print(f"📁 Local directory: {DRIVE_PATH} (Drive unavailable: {type(exc).__name__}: {exc})")

# ==================== PROPER HAUSA-AJAMI CONVERSION ====================
class HausaAjamiConverter:
    """Rule-based Hausa Latin (Boko) → Ajami transliterator.

    The input is lower-cased and stripped. It is then scanned left to right
    with greedy longest-match against ``self.mappings``, so digraphs such as
    "sh", "ts" and "aa" win over single letters. Whitespace (space, newline,
    tab) is copied through. Any other character with no mapping is dropped
    silently: digits, punctuation, and letters such as q/p/v/x.

    A single short vowel maps to a bare Arabic vowel mark. When that vowel
    begins a word, no consonant is available to carry the mark, so an alif
    seat is inserted ("amb" → alif+fatha+mim+ba) instead of emitting a
    floating diacritic. The long-vowel mappings already begin with alif and
    are left unchanged.
    """

    # Alif used as the seat for a word-initial short vowel.
    _ALIF_CARRIER = "\u0627"
    # Characters copied through unchanged; they also end the current word.
    _PASSTHROUGH_WHITESPACE = " \n\t"

    def __init__(self):
        # (latin, ajami) pairs. Keys are unique; the first occurrence wins.
        self.mappings = [
            ("ɓ", "ݒ"), ("'y", "ࢩ"), ("ƙ", "ࢼ"), ("ɗ", "ڎ"),
            # Niger-orthography hooked y: the same sound as Nigerian 'y.
            ("ƴ", "ࢩ"),
            ("ts", "تْسْ"), ("sh", "شْ"), ("ng", "نْگْ"),
            ("aa", "اَ"), ("ee", "اِ"), ("ii", "اِي"),
            ("oo", "اُ"), ("uu", "اُو"),
            ("b", "ب"), ("c", "چ"), ("d", "د"), ("f", "ف"),
            ("g", "گ"), ("h", "ه"), ("j", "ج"), ("k", "ك"),
            ("l", "ل"), ("m", "م"), ("n", "ن"), ("r", "ر"),
            ("s", "س"), ("t", "ت"), ("w", "و"), ("y", "ي"),
            ("z", "ز"), ("'", "ع"),
            ("a", "َ"), ("e", "ِ"), ("i", "ِ"),
            ("o", "ُ"), ("u", "ُ"),
        ]

    def convert(self, text):
        """Transliterate one Hausa string to Ajami.

        Args:
            text: Latin-script Hausa text.

        Returns:
            The Ajami string. Non-strings and blank strings are returned
            unchanged. If nothing in the text could be mapped, the
            lower-cased, stripped input is returned.
        """
        if not isinstance(text, str) or not text.strip():
            return text
        text = text.lower().strip()

        # Build a dict for O(1) lookups. setdefault keeps the first mapping,
        # matching a first-match scan over the list.
        lookup = {}
        for latin, ajami in self.mappings:
            lookup.setdefault(latin, ajami)
        max_len = max((len(latin) for latin in lookup), default=1)
        # Marks produced by single short vowels; these need an alif word-initially.
        vowel_marks = {lookup[v] for v in "aeiou" if v in lookup}

        pieces = []
        at_word_start = True
        i, n = 0, len(text)
        while i < n:
            for length in range(min(max_len, n - i), 0, -1):
                ajami = lookup.get(text[i:i + length])
                if ajami is not None:
                    if at_word_start and ajami in vowel_marks:
                        pieces.append(self._ALIF_CARRIER)
                    pieces.append(ajami)
                    at_word_start = False
                    i += length
                    break
            else:
                ch = text[i]
                if ch in self._PASSTHROUGH_WHITESPACE:
                    pieces.append(ch)
                    at_word_start = True
                i += 1
        return ''.join(pieces) if pieces else text

converter = HausaAjamiConverter()

# ==================== DATA AUGMENTATION ====================
class HausaAugmenter:
    """Light word-level text augmentation for Hausa, driven by ``np.random``.

    Every operation splits on whitespace. Texts of three words or fewer are
    left unchanged by the swap/synonym operations. Single-word texts are
    returned verbatim by deletion.
    """

    # Small hand-written synonym table; keys must match whole tokens.
    _SYNONYMS = {
        'munanan': ['marasa kyau', 'miyagun'],
        'kyawawan': ['masu kyau', 'nagari'],
        'mutane': ['yan adam', 'jama\'a'],
    }

    @staticmethod
    def synonym_replacement(text, n=2):
        """Make ``n`` attempts to replace a random token with a listed synonym."""
        tokens = text.split()
        if len(tokens) <= 3:
            return ' '.join(tokens)
        table = HausaAugmenter._SYNONYMS
        for _ in range(n):
            pos = np.random.randint(0, len(tokens))
            candidates = table.get(tokens[pos])
            if candidates is not None:
                tokens[pos] = np.random.choice(candidates)
        return ' '.join(tokens)

    @staticmethod
    def random_swap(text, n=2):
        """Swap ``n`` randomly chosen adjacent token pairs."""
        tokens = text.split()
        if len(tokens) <= 3:
            return ' '.join(tokens)
        for _ in range(n):
            left = np.random.randint(0, len(tokens) - 1)
            tokens[left:left + 2] = tokens[left + 1], tokens[left]
        return ' '.join(tokens)

    @staticmethod
    def random_deletion(text, p=0.1):
        """Drop each token independently with probability ``p``."""
        tokens = text.split()
        if len(tokens) == 1:
            return text
        survivors = []
        for token in tokens:
            if np.random.random() > p:
                survivors.append(token)
        return ' '.join(survivors)

    @classmethod
    def augment(cls, text, label, n_aug=2):
        """Return ``[text]`` followed by up to ``n_aug`` distinct-from-input variants.

        ``label`` is accepted for API symmetry but not used. Each attempt picks
        one operation at random; empty results and results identical to
        ``text`` are discarded.
        """
        operations = [cls.synonym_replacement, cls.random_swap, cls.random_deletion]
        results = [text]
        for _ in range(n_aug):
            operation = np.random.choice(operations)
            candidate = operation(text)
            if not candidate or candidate == text:
                continue
            results.append(candidate)
        return results

# ==================== DATA LOADING ====================
def load_and_augment_dataset(augment_minority=True, aug_ratio=3, data_path=None):
    """Load the Hausa hate-speech CSV, clean it, optionally augment it, and transliterate.

    Pipeline:
      1. Read the CSV.
      2. Strip URLs from ``text`` and collapse whitespace.
      3. Drop duplicate texts and texts of 10 characters or fewer.
      4. Optionally run ``aug_ratio`` augmentation attempts per minority-class row.
      5. Convert every text to Ajami with the module-level ``converter``.

    Args:
        augment_minority: if True, augment rows of the least frequent
            ``label_offensive`` value (label 1 on the dataset this was written for).
        aug_ratio: augmentation attempts per minority row. Attempts yielding an
            empty or unchanged text are discarded by ``HausaAugmenter.augment``.
        data_path: CSV location. Defaults to ``HausaHateDataset.csv`` inside
            ``DRIVE_PATH``, so the local fallback directory is used when Google
            Drive is unavailable.

    Returns:
        ``(ajami_texts, labels)`` as numpy arrays, or ``(None, None)`` when the
        CSV cannot be read.

    NOTE(review): augmentation happens before any train/test split. Augmented
    copies of a text can therefore land in a different CV fold than their
    source, which inflates test scores. Augmenting inside each training fold
    avoids this leak.
    """
    if data_path is None:
        data_path = os.path.join(DRIVE_PATH, 'HausaHateDataset.csv')
    try:
        df = pd.read_csv(data_path)
    except (OSError, ValueError) as exc:
        # OSError covers missing/unreadable files; pandas parse, empty-file and
        # decode errors are all ValueError subclasses.
        print(f"❌ Failed to load dataset from {data_path}: {exc}")
        return None, None
    print(f"✅ Loaded: {df.shape}")

    print("\n📊 Original Distribution:")
    print(df['label_offensive'].value_counts())

    df['text'] = (
        df['text'].fillna('').astype(str)
        .str.replace(r'http\S+|www\S+', '', regex=True)
        .str.replace(r'\s+', ' ', regex=True)
        .str.strip()
    )
    df = df.drop_duplicates(subset=['text'])
    df = df[df['text'].str.len() > 10]

    if augment_minority and not df.empty:
        augmenter = HausaAugmenter()
        minority_label = df['label_offensive'].value_counts().idxmin()
        minority_df = df[df['label_offensive'] == minority_label]
        # Element [0] of augment() is the original text, so only the variants are kept.
        augmented_rows = [
            {'text': aug_text, 'label_offensive': label}
            for text, label in zip(minority_df['text'], minority_df['label_offensive'])
            for aug_text in augmenter.augment(text, label, n_aug=aug_ratio)[1:]
        ]
        if augmented_rows:
            df = pd.concat([df, pd.DataFrame(augmented_rows)], ignore_index=True)
        print(f"\n✅ Augmented dataset: {len(df)} samples")
        print(df['label_offensive'].value_counts())

    print("\n🔄 Converting to Ajami...")
    df['text_ajami'] = df['text'].apply(converter.convert)
    return df['text_ajami'].values, df['label_offensive'].values

# ==================== DATASET CLASS ====================
class AdvancedHausaDataset(Dataset):
    """Map-style dataset that tokenizes one text per item for sequence classification.

    Args:
        texts: indexable sequence of texts; each item is converted with ``str``.
        labels: indexable sequence of integer class ids aligned with ``texts``.
        tokenizer: callable using the Hugging Face tokenizer call signature. It
            must return ``input_ids`` / ``attention_mask`` tensors with a
            leading batch dimension of 1.
        max_length: pad/truncate length passed to the tokenizer.
        augment: when True, each fetched item is replaced by a
            ``HausaAugmenter`` variant with probability 0.5.
    """

    def __init__(self, texts, labels, tokenizer, max_length=256, augment=False):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augment = augment
        self.augmenter = HausaAugmenter() if augment else None

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``/``attention_mask`` (batch dim squeezed) and a long ``label``."""
        text = str(self.texts[idx])
        if self.augment and np.random.random() > 0.5:
            # augment() returns [original, *variants]. Take the LAST entry: it is
            # the variant when one was produced, and the original otherwise.
            # Indexing [0] always returned the unmodified text, which made
            # on-the-fly augmentation a no-op.
            # NOTE(review): texts here may already be Ajami. The Latin synonym
            # table then never matches, and only swap/deletion take effect.
            text = self.augmenter.augment(text, self.labels[idx], n_aug=1)[-1]

        encoding = self.tokenizer(
            text, max_length=self.max_length, padding='max_length',
            truncation=True, return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# ==================== MODEL ARCHITECTURES ====================
class EnhancedMBERT(nn.Module):
    """mBERT encoder + learned attention pooling + multi-sample-dropout MLP head.

    Each token gets a scalar score. Scores at padding positions are set to
    -inf, and the softmax over the sequence gives pooling weights. The pooled
    vector goes through the classifier once per dropout module, and the
    resulting logits are averaged.
    """

    def __init__(self, num_labels=2, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-multilingual-cased')
        width = self.bert.config.hidden_size
        # Five independent dropout masks for multi-sample dropout.
        self.dropouts = nn.ModuleList(nn.Dropout(dropout) for _ in range(5))
        # Token scorer: hidden -> tanh -> scalar score.
        self.attention = nn.Sequential(
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, 1),
        )
        head_layers = [
            nn.Linear(width, width),
            nn.LayerNorm(width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(width, width // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(width // 2, num_labels),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        token_states = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        is_padding = attention_mask.unsqueeze(-1) == 0
        scores = self.attention(token_states).masked_fill(is_padding, float('-inf'))
        weights = scores.softmax(dim=1)
        pooled = (token_states * weights).sum(dim=1)
        per_mask_logits = [self.classifier(drop(pooled)) for drop in self.dropouts]
        return torch.stack(per_mask_logits).mean(dim=0)


class EnhancedXLMR(nn.Module):
    """XLM-R encoder with [CLS], masked-mean and masked conv-max pooling.

    Inputs are padded to a fixed length, so both the mean pooling and the
    convolutional max pooling are restricted to positions where
    ``attention_mask`` is non-zero. Otherwise padding states would dominate
    the features of short texts. Logits are averaged over five dropout masks
    (multi-sample dropout).
    """

    def __init__(self, num_labels=2, dropout=0.3):
        super().__init__()
        self.xlmr = AutoModel.from_pretrained('xlm-roberta-base')
        hidden_size = self.xlmr.config.hidden_size
        self.conv1 = nn.Conv1d(hidden_size, hidden_size, 3, padding=1)
        # NOTE: conv2 is not used in forward(). It is kept so that existing
        # checkpoints and state_dicts still load with strict key matching.
        self.conv2 = nn.Conv1d(hidden_size, hidden_size, 5, padding=2)
        self.dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(5)])
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 3, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels)
        )

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Mean of ``hidden`` (batch, seq, dim) over positions where ``attention_mask`` != 0."""
        keep = attention_mask.unsqueeze(-1) != 0
        summed = hidden.masked_fill(~keep, 0.0).sum(dim=1)
        # clamp avoids 0/0 for an all-padding row
        counts = keep.sum(dim=1).clamp(min=1).to(hidden.dtype)
        return summed / counts

    @staticmethod
    def _masked_max(features, attention_mask):
        """Max of ``features`` (batch, channels, seq) over positions where ``attention_mask`` != 0."""
        pad = (attention_mask == 0).unsqueeze(1)
        filled = features.masked_fill(pad, torch.finfo(features.dtype).min)
        return filled.max(dim=-1).values

    def forward(self, input_ids, attention_mask):
        outputs = self.xlmr(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        cls_pooled = hidden[:, 0]
        mean_pooled = self._masked_mean(hidden, attention_mask)
        # Conv1d expects (batch, channels, seq).
        conv_out = F.relu(self.conv1(hidden.transpose(1, 2)))
        conv_pooled = self._masked_max(conv_out, attention_mask)
        pooled = torch.cat([cls_pooled, mean_pooled, conv_pooled], dim=-1)
        logits = torch.mean(torch.stack([
            self.classifier(dropout(pooled)) for dropout in self.dropouts
        ]), dim=0)
        return logits


class EnhancedAraBERT(nn.Module):
    """AraBERT encoder + 2-layer BiLSTM head + multi-sample-dropout classifier.

    The BiLSTM's final states are computed over real tokens only. The padded
    sequence is packed with lengths taken from ``attention_mask``, so the
    forward state stops at the last real token and the backward state starts
    there, rather than running through padding.
    """

    def __init__(self, num_labels=2, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained('aubmindlab/bert-base-arabertv2')
        hidden_size = self.bert.config.hidden_size
        self.lstm = nn.LSTM(
            hidden_size, hidden_size // 2, num_layers=2,
            batch_first=True, dropout=dropout, bidirectional=True
        )
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(5)])
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels)
        )

    @staticmethod
    def _lstm_last_states(lstm, hidden, attention_mask):
        """Concatenate the last layer's forward/backward final states of a batch-first BiLSTM.

        Args:
            lstm: bidirectional ``nn.LSTM`` with ``batch_first=True``.
            hidden: (batch, seq, features) input.
            attention_mask: (batch, seq), non-zero for real tokens.
                Assumes right padding (real tokens first), which is the
                default for BERT tokenizers — TODO confirm for custom tokenizers.

        Returns:
            (batch, 2 * lstm.hidden_size) tensor, in the original batch order.
        """
        # pack_padded_sequence needs CPU int64 lengths; clamp avoids zero-length rows.
        lengths = (attention_mask != 0).sum(dim=1).clamp(min=1).to('cpu', torch.int64)
        packed = nn.utils.rnn.pack_padded_sequence(
            hidden, lengths, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = lstm(packed)
        # h_n[-2] / h_n[-1] are the last layer's forward / backward final states.
        return torch.cat([h_n[-2], h_n[-1]], dim=-1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        lstm_pooled = self._lstm_last_states(self.lstm, hidden, attention_mask)
        cls_pooled = self.layer_norm(hidden[:, 0])
        pooled = torch.cat([cls_pooled, lstm_pooled], dim=-1)
        logits = torch.mean(torch.stack([
            self.classifier(dropout(pooled)) for dropout in self.dropouts
        ]), dim=0)
        return logits


class ProductionHACS_TL(nn.Module):
    """mBERT + self-attention + 2 Transformer layers + multi-view pooling classifier.

    Four pooled views are concatenated and fused before the classifier:
      - the [CLS] state;
      - learned attention pooling;
      - a masked mean;
      - the average of masked max-pooled conv features with kernels 3, 5 and 7.

    All pooling is restricted to positions where ``attention_mask`` is
    non-zero, so fixed-length padding does not leak into the features.
    Logits are averaged over eight dropout masks.
    """

    def __init__(self, num_labels=2, dropout=0.2):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-multilingual-cased')
        hidden_size = self.bert.config.hidden_size

        self.cross_script_attn = nn.MultiheadAttention(
            hidden_size, num_heads=12, dropout=0.1, batch_first=True
        )
        self.attn_norm = nn.LayerNorm(hidden_size)

        self.orthographic_encoder = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_size, nhead=8,
                dim_feedforward=hidden_size*4,
                dropout=0.1, batch_first=True
            ) for _ in range(2)
        ])

        self.dialectal_attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.Tanh(),
            nn.Linear(hidden_size // 2, 1)
        )

        self.script_conv = nn.ModuleList([
            nn.Conv1d(hidden_size, hidden_size, kernel_size=k, padding=k//2)
            for k in [3, 5, 7]
        ])

        self.fusion = nn.Sequential(
            nn.Linear(hidden_size * 4, hidden_size * 2),
            nn.LayerNorm(hidden_size * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU()
        )

        self.dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(8)])

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 4, num_labels)
        )

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Mean of ``hidden`` (batch, seq, dim) over positions where ``attention_mask`` != 0."""
        keep = attention_mask.unsqueeze(-1) != 0
        summed = hidden.masked_fill(~keep, 0.0).sum(dim=1)
        # clamp avoids 0/0 for an all-padding row
        counts = keep.sum(dim=1).clamp(min=1).to(hidden.dtype)
        return summed / counts

    @staticmethod
    def _masked_max(features, attention_mask):
        """Max of ``features`` (batch, channels, seq) over positions where ``attention_mask`` != 0.

        Padding is filled with the dtype's minimum rather than 0, because GELU
        outputs can be negative.
        """
        pad = (attention_mask == 0).unsqueeze(1)
        filled = features.masked_fill(pad, torch.finfo(features.dtype).min)
        return filled.max(dim=-1).values

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        padding_mask = ~attention_mask.bool()

        attn_out, _ = self.cross_script_attn(hidden, hidden, hidden,
                                             key_padding_mask=padding_mask)
        hidden = self.attn_norm(hidden + attn_out)

        for encoder in self.orthographic_encoder:
            hidden = encoder(hidden, src_key_padding_mask=padding_mask)

        attn_weights = self.dialectal_attention(hidden)
        attn_weights = torch.softmax(attn_weights.masked_fill(
            attention_mask.unsqueeze(-1) == 0, float('-inf')
        ), dim=1)
        dialectal_pooled = torch.sum(hidden * attn_weights, dim=1)

        # Conv1d expects (batch, channels, seq). Pool each kernel size over real
        # tokens only, then average across kernel sizes.
        hidden_t = hidden.transpose(1, 2)
        cnn_features = [
            self._masked_max(F.gelu(conv(hidden_t)), attention_mask)
            for conv in self.script_conv
        ]
        cnn_pooled = torch.stack(cnn_features).mean(dim=0)

        cls_pooled = hidden[:, 0]
        mean_pooled = self._masked_mean(hidden, attention_mask)

        multi_pooled = torch.cat([cls_pooled, dialectal_pooled, mean_pooled, cnn_pooled], dim=-1)
        fused = self.fusion(multi_pooled)

        logits = torch.mean(torch.stack([
            self.classifier(dropout(fused)) for dropout in self.dropouts
        ]), dim=0)

        return logits

# ==================== TRAINING ====================
class FocalLoss(nn.Module):
    """Focal loss for class-index targets: batch mean of ``alpha_t * (1 - p_t)**gamma * CE``.

    ``p_t`` is the softmax probability of the true class, recovered as
    ``exp(-CE)``. A larger ``gamma`` down-weights examples the model already
    classifies confidently.

    Args:
        alpha: a number (default 0.25) multiplies every sample's loss
            uniformly. This only rescales the loss and does NOT re-balance
            classes. Alternatively, a sequence or 1-D tensor with one weight
            per class, looked up by target index (the class-balancing form).
        gamma: focusing exponent; ``gamma=0`` reduces to alpha-weighted cross-entropy.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        if isinstance(alpha, (int, float)):
            self.alpha = alpha
        else:
            # Per-class weights; moved to the logits' device/dtype in forward().
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss of ``inputs`` logits against class-index ``targets``."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        alpha = self.alpha
        if torch.is_tensor(alpha):
            alpha = alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)
            if alpha.dim() > 0:
                alpha = alpha[targets]
        focal_loss = alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()


def train_epoch_advanced(model, dataloader, optimizer, scheduler, device, use_focal=True):
    """Train ``model`` for one epoch with 2-step gradient accumulation.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that returns logits.
        dataloader: yields dicts with ``input_ids``, ``attention_mask`` and ``label`` tensors.
        optimizer / scheduler: each is stepped once per accumulation window.
        device: target device for the batch tensors.
        use_focal: use ``FocalLoss``; otherwise cross-entropy with 0.1 label smoothing.

    Returns:
        dict with the mean per-batch ``loss`` and macro ``precision``/``recall``/``f1``
        of the train-mode (dropout-active) predictions.

    NOTE(review): the scheduler advances once per optimizer step, i.e. about
    ceil(len(dataloader) / 2) times per epoch. If the caller sized the schedule
    as len(dataloader) * epochs, the cosine decay never completes. Verify
    against the caller.
    """
    model.train()
    total_loss = 0
    all_preds, all_labels = [], []
    criterion = FocalLoss() if use_focal else nn.CrossEntropyLoss(label_smoothing=0.1)
    accumulation_steps = 2
    num_batches = len(dataloader)

    # Start from clean gradients so nothing from a previous epoch or evaluation
    # pass is folded into the first update.
    optimizer.zero_grad()

    for idx, batch in enumerate(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels) / accumulation_steps
        loss.backward()

        # Step at every accumulation boundary AND on the last batch. Otherwise
        # a trailing partial window (odd batch count) is never applied, and its
        # gradients leak into the next epoch's first update.
        if (idx + 1) % accumulation_steps == 0 or (idx + 1) == num_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps
        all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    p, r, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro', zero_division=0)
    return {'loss': total_loss/len(dataloader), 'precision': p, 'recall': r, 'f1': f1}


def evaluate_advanced(model, dataloader, device):
    """Run ``model`` over ``dataloader`` in eval mode without gradients and score it.

    Returns:
        dict with macro ``precision``/``recall``/``f1``, plain ``accuracy``, the
        per-sample ``predictions`` and ``true_labels`` (lists built from numpy
        arrays), and ``probabilities`` (numpy array of softmax rows, one per sample).
    """
    model.eval()
    pred_ids, gold_ids, prob_rows = [], [], []

    with torch.no_grad():
        for batch in dataloader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            gold = batch['label'].to(device)

            logits = model(ids, mask)
            prob_rows.extend(torch.softmax(logits, dim=1).cpu().numpy())
            pred_ids.extend(logits.argmax(dim=1).cpu().numpy())
            gold_ids.extend(gold.cpu().numpy())

    precision, recall, f1, _ = precision_recall_fscore_support(
        gold_ids, pred_ids, average='macro', zero_division=0
    )

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'accuracy': accuracy_score(gold_ids, pred_ids),
        'predictions': pred_ids,
        'true_labels': gold_ids,
        'probabilities': np.array(prob_rows),
    }

# ==================== VISUALIZATION ====================
_VIZ_PALETTE = ('#3498db', '#2ecc71', '#e74c3c', '#9b59b6')


def _viz_color(idx):
    """Return the palette colour for the idx-th model, cycling past the palette length."""
    return _VIZ_PALETTE[idx % len(_VIZ_PALETTE)]


def _plot_metric_bars(all_results, save_path):
    """Save a 1x3 bar chart comparing precision, recall and F1 across models."""
    models = list(all_results.keys())
    colors = [_viz_color(i) for i in range(len(models))]
    metrics = ['precision', 'recall', 'f1']
    names = ['Precision', 'Recall', 'F1-Score']

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for ax, metric, name in zip(axes, metrics, names):
        values = [all_results[m][metric] for m in models]

        bars = ax.bar(models, values, alpha=0.8, color=colors)
        ax.set_ylabel(name, fontsize=12, fontweight='bold')
        ax.set_title(f'{name} Comparison', fontsize=13, fontweight='bold')
        # Pin tick positions before relabelling (avoids matplotlib's
        # FixedFormatter-without-FixedLocator warning).
        ax.set_xticks(range(len(models)))
        ax.set_xticklabels(models, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.3)
        # A fixed 0.75 floor hides every bar (and its label) below 0.75.
        # Keep 0.75 as the default floor, but lower it to fit the data.
        floor = min([0.75] + [v - 0.05 for v in values])
        ax.set_ylim(max(0.0, floor), 1.0)

        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{height:.3f}', ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.savefig(f'{save_path}performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    print("✅ Saved: performance_comparison.png")


def _plot_confusion_matrices(all_results, save_path):
    """Save one binary confusion-matrix heatmap per model on a 2-column grid."""
    n_models = len(all_results)
    n_rows = max(1, (n_models + 1) // 2)
    # squeeze=False always yields a 2-D array of axes, so flatten() is safe.
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, (model_name, results) in zip(axes, all_results.items()):
        # labels=[0, 1] keeps the matrix 2x2 (matching the two tick labels)
        # even if one class is absent from labels or predictions.
        cm = confusion_matrix(results['true_labels'], results['predictions'], labels=[0, 1])
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                    xticklabels=['Non-Off', 'Off'], yticklabels=['Non-Off', 'Off'])
        ax.set_title(f'{model_name} (F1: {results["f1"]:.4f})', fontsize=11, fontweight='bold')

    # Hide grid cells left over when the number of models is odd.
    for ax in axes[n_models:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(f'{save_path}confusion_matrices.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    print("✅ Saved: confusion_matrices.png")


def _plot_roc_curves(all_results, save_path):
    """Save ROC curves (with AUC in the legend) for all models on one plot."""
    fig, ax = plt.subplots(figsize=(10, 8))

    for idx, (model_name, results) in enumerate(all_results.items()):
        # Column 1 holds the probability of the positive ('Off') class.
        probs = results['probabilities'][:, 1]
        labels = results['true_labels']
        fpr, tpr, _ = roc_curve(labels, probs)
        auc = roc_auc_score(labels, probs)
        ax.plot(fpr, tpr, label=f'{model_name} (AUC={auc:.3f})',
                linewidth=2.5, color=_viz_color(idx))

    ax.plot([0, 1], [0, 1], 'k--', linewidth=2, alpha=0.5)
    ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='bold')
    ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='bold')
    # main() passes predictions concatenated across folds (not a per-fold
    # average), and the fold count is not known here, so don't claim either.
    ax.set_title('ROC Curves - Pooled Cross-Validation Predictions', fontsize=14, fontweight='bold')
    ax.legend(loc='lower right', fontsize=11)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'{save_path}roc_curves.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    print("✅ Saved: roc_curves.png")


def create_visualizations(all_results, save_path):
    """Save comparison plots for every model in ``all_results``.

    Writes ``performance_comparison.png``, ``confusion_matrices.png`` and
    ``roc_curves.png``, each to ``save_path`` + filename. The path is
    concatenated as a string, so ``save_path`` should end with a separator.

    Args:
        all_results: dict mapping model name to a dict with scalar
            'precision', 'recall' and 'f1', plus per-sample 'true_labels',
            'predictions' and 'probabilities' (2-D; column 1 is the
            positive-class probability). Any number of models is supported.
        save_path: directory prefix for the output PNG files.
    """
    _plot_metric_bars(all_results, save_path)
    _plot_confusion_matrices(all_results, save_path)
    _plot_roc_curves(all_results, save_path)

# ==================== MAIN PIPELINE ====================
def _train_model_on_fold(model_name, config, fold, splits, device,
                         grad_accum_steps=2, warmup_epochs=2, patience=3):
    """Train one model on one CV fold and return its metrics on the test fold.

    The best epoch is selected on a validation split carved from the training
    fold. The test fold is evaluated exactly once, with the selected
    checkpoint, so test data never influences model selection.

    Args:
        model_name: key from ``models_config``; used in logs and the checkpoint name.
        config: dict with 'class', 'tokenizer', 'lr' and 'epochs' entries.
        fold: 1-based fold number (used in the checkpoint filename).
        splits: ``((fit_texts, fit_labels), (val_texts, val_labels),
            (test_texts, test_labels))``.
        device: torch device to train on.
        grad_accum_steps: batches per optimizer/scheduler step. Must match the
            hard-coded ``accumulation_steps`` in ``train_epoch_advanced``.
        warmup_epochs: epochs of LR warm-up. Early-stopping patience is not
            consumed during warm-up.
        patience: epochs without validation-F1 improvement before stopping.

    Returns:
        The ``evaluate_advanced`` result dict for the test fold.
    """
    (fit_texts, fit_labels), (val_texts, val_labels), (test_texts, test_labels) = splits

    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])

    train_ds = AdvancedHausaDataset(fit_texts, fit_labels, tokenizer,
                                    max_length=256, augment=True)
    val_ds = AdvancedHausaDataset(val_texts, val_labels, tokenizer,
                                  max_length=256, augment=False)
    test_ds = AdvancedHausaDataset(test_texts, test_labels, tokenizer,
                                   max_length=256, augment=False)

    # Inverse-frequency sampling so each batch is roughly class-balanced.
    class_counts = Counter(fit_labels)
    weights = [1.0 / class_counts[label] for label in fit_labels]
    sampler = WeightedRandomSampler(weights, len(weights))

    train_loader = DataLoader(train_ds, batch_size=16, sampler=sampler)
    val_loader = DataLoader(val_ds, batch_size=32)
    test_loader = DataLoader(test_ds, batch_size=32)

    model = config['class']().to(device)
    optimizer = AdamW(model.parameters(), lr=config['lr'], weight_decay=0.01)
    # The scheduler counts scheduler.step() calls, and train_epoch_advanced
    # only steps once every grad_accum_steps batches. Sizing the schedule in
    # raw batches would stretch warm-up and leave the cosine decay half done.
    steps_per_epoch = max(1, len(train_loader) // grad_accum_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=steps_per_epoch * warmup_epochs,
        num_training_steps=steps_per_epoch * config['epochs'],
    )

    checkpoint_path = f'{DRIVE_PATH}best_{model_name}_fold{fold}.pt'
    # Start below any achievable F1 so epoch 1 always writes a checkpoint;
    # otherwise torch.load below could hit a missing file.
    best_f1 = -1.0
    patience_counter = 0

    for epoch in range(1, config['epochs'] + 1):
        train_metrics = train_epoch_advanced(model, train_loader, optimizer,
                                             scheduler, device, use_focal=True)
        val_metrics = evaluate_advanced(model, val_loader, device)

        print(f"Epoch {epoch:02d} | Train F1: {train_metrics['f1']:.4f} | "
              f"Val F1: {val_metrics['f1']:.4f}")

        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
        elif epoch > warmup_epochs:
            # Don't early-stop while the learning rate is still ramping up.
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    model.load_state_dict(torch.load(checkpoint_path))
    return evaluate_advanced(model, test_loader, device)


def _aggregate_cv_results(fold_results, model_names):
    """Average per-fold scores and pool per-sample outputs for each model.

    Returns a dict keyed by model name with the fold-mean 'precision',
    'recall', 'f1' and 'accuracy', the across-fold 'f1_std', and
    'predictions', 'true_labels' and 'probabilities' concatenated over folds.
    """
    summary = {}
    for name in model_names:
        per_fold = fold_results[name]
        summary[name] = {
            'precision': np.mean(per_fold['precision']),
            'recall': np.mean(per_fold['recall']),
            'f1': np.mean(per_fold['f1']),
            'accuracy': np.mean(per_fold['accuracy']),
            'f1_std': np.std(per_fold['f1']),
            'predictions': np.concatenate(per_fold['predictions']),
            'true_labels': np.concatenate(per_fold['true_labels']),
            'probabilities': np.concatenate(per_fold['probabilities']),
        }
    return summary


def main():
    """Run stratified k-fold CV for every model, then report, plot and save results.

    For each outer fold, a stratified validation split (1/``val_folds`` of
    the training fold) is held out for epoch selection and early stopping.
    The test fold is then scored once with the best checkpoint. Fold
    metrics are averaged, and per-sample outputs are pooled for the
    classification reports and plots. Everything is written under
    ``DRIVE_PATH``.
    """
    n_splits = 2
    # 1/val_folds of every training fold is held out for model selection.
    val_folds = 10

    print("\n" + "="*80)
    print("PRODUCTION-READY HAUSA AJAMI HATE SPEECH DETECTION")
    print("Target: Baseline >88% F1 | Proposed Model >98% F1")
    print("="*80 + "\n")

    texts, labels = load_and_augment_dataset(augment_minority=True, aug_ratio=3)
    if texts is None:
        return
    # NOTE(review): augmentation happens here, before the CV split. If the
    # augmented samples are variants of original sentences, a sentence and
    # its variants can land in different folds and inflate test scores.
    # Consider augmenting only inside each training fold — confirm against
    # load_and_augment_dataset.

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n🖥️  Device: {device}")

    models_config = {
        'mBERT': {
            'class': EnhancedMBERT,
            'tokenizer': 'bert-base-multilingual-cased',
            'lr': 2e-5, 'epochs': 15
        },
        'XLM-R': {
            'class': EnhancedXLMR,
            'tokenizer': 'xlm-roberta-base',
            'lr': 1e-5, 'epochs': 15
        },
        'AraBERT': {
            'class': EnhancedAraBERT,
            'tokenizer': 'aubmindlab/bert-base-arabertv2',
            'lr': 2e-5, 'epochs': 15
        },
        'HACS-TL': {
            'class': ProductionHACS_TL,
            'tokenizer': 'bert-base-multilingual-cased',
            'lr': 1e-5, 'epochs': 20
        }
    }

    print(f"\n🔄 Starting {n_splits}-Fold Cross-Validation...\n")

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    inner_skf = StratifiedKFold(n_splits=val_folds, shuffle=True, random_state=42)
    fold_results = defaultdict(lambda: defaultdict(list))

    for fold, (train_idx, test_idx) in enumerate(skf.split(texts, labels), 1):
        print(f"\n{'='*80}")
        print(f"FOLD {fold}/{n_splits}")
        print(f"{'='*80}")

        # Carve a stratified validation split out of the training fold, so
        # epoch selection never sees the test fold. StratifiedKFold only needs
        # the sample count from X, so the index array itself stands in for X.
        fit_pos, val_pos = next(inner_skf.split(train_idx, labels[train_idx]))
        fit_idx, val_idx = train_idx[fit_pos], train_idx[val_pos]
        splits = (
            (texts[fit_idx], labels[fit_idx]),
            (texts[val_idx], labels[val_idx]),
            (texts[test_idx], labels[test_idx]),
        )

        for model_name, config in models_config.items():
            print(f"\n🚀 Training: {model_name}")

            final_metrics = _train_model_on_fold(model_name, config, fold, splits, device)

            for key in ['precision', 'recall', 'f1', 'accuracy',
                        'predictions', 'true_labels', 'probabilities']:
                fold_results[model_name][key].append(final_metrics[key])

            print(f"✅ Fold {fold} Test F1: {final_metrics['f1']:.4f}\n")

    # Calculate average results
    print(f"\n{'='*80}")
    print(f"📊 FINAL RESULTS ({n_splits}-Fold Cross-Validation)")
    print(f"{'='*80}\n")

    avg_results = _aggregate_cv_results(fold_results, models_config.keys())
    for model_name, res in avg_results.items():
        print(f"{model_name:15} | F1: {res['f1']:.4f} ± {res['f1_std']:.4f} | "
              f"P: {res['precision']:.4f} | "
              f"R: {res['recall']:.4f} | "
              f"Acc: {res['accuracy']:.4f}")

    # Generate visualizations
    print(f"\n{'='*80}")
    print("📊 Generating Visualizations...")
    print(f"{'='*80}\n")
    create_visualizations(avg_results, DRIVE_PATH)

    # Save detailed results
    results_df = pd.DataFrame({
        'Model': list(avg_results.keys()),
        'F1-Score': [avg_results[m]['f1'] for m in avg_results],
        'Precision': [avg_results[m]['precision'] for m in avg_results],
        'Recall': [avg_results[m]['recall'] for m in avg_results],
        'Accuracy': [avg_results[m]['accuracy'] for m in avg_results],
        'F1-Std': [avg_results[m]['f1_std'] for m in avg_results]
    })
    results_df.to_csv(f'{DRIVE_PATH}final_results.csv', index=False)
    print(f"✅ Results saved to: {DRIVE_PATH}final_results.csv")

    # Classification reports on the pooled out-of-fold predictions
    print(f"\n{'='*80}")
    print("📋 DETAILED CLASSIFICATION REPORTS")
    print(f"{'='*80}\n")

    for model_name in models_config.keys():
        print(f"\n{model_name}:")
        print("="*60)
        print(classification_report(
            avg_results[model_name]['true_labels'],
            avg_results[model_name]['predictions'],
            target_names=['Non-Offensive', 'Offensive'],
            digits=4
        ))

    print(f"\n{'='*80}")
    print("✅ TRAINING COMPLETE!")
    print(f"All results saved to: {DRIVE_PATH}")
    print(f"{'='*80}\n")

# Entry point: run the full cross-validation pipeline when this cell/script is
# executed directly (in a notebook cell, __name__ is "__main__" as well).
if __name__ == "__main__":
    main()

Mounted at /content/drive
✅ Google Drive mounted: /content/drive/MyDrive/AbjadNLP2026/

PRODUCTION-READY HAUSA AJAMI HATE SPEECH DETECTION
Target: Baseline >88% F1 | Proposed Model >98% F1

✅ Loaded: (2000, 8)

📊 Original Distribution:
label_offensive
0    1322
1     678
Name: count, dtype: int64

✅ Augmented dataset: 3048 samples
label_offensive
1    1730
0    1318
Name: count, dtype: int64

🔄 Converting to Ajami...

🖥️  Device: cuda

🔄 Starting 2-Fold Cross-Validation...


FOLD 1/5

🚀 Training: mBERT


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

Epoch 01 | Train F1: 0.5383 | Test F1: 0.5693
Epoch 02 | Train F1: 0.5661 | Test F1: 0.6168
Epoch 03 | Train F1: 0.5807 | Test F1: 0.4909
Epoch 04 | Train F1: 0.5708 | Test F1: 0.5992
Epoch 05 | Train F1: 0.5453 | Test F1: 0.3019
Early stopping at epoch 5
✅ Fold 1 Test F1: 0.6168


🚀 Training: XLM-R


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

Epoch 01 | Train F1: 0.4717 | Test F1: 0.3051
Epoch 02 | Train F1: 0.5208 | Test F1: 0.4707
Epoch 03 | Train F1: 0.5724 | Test F1: 0.4776
Epoch 04 | Train F1: 0.6043 | Test F1: 0.6183
Epoch 05 | Train F1: 0.5977 | Test F1: 0.6182
Epoch 06 | Train F1: 0.6411 | Test F1: 0.6473
Epoch 07 | Train F1: 0.6394 | Test F1: 0.6268
Epoch 08 | Train F1: 0.6525 | Test F1: 0.6169
Epoch 09 | Train F1: 0.6633 | Test F1: 0.6800
Epoch 10 | Train F1: 0.7081 | Test F1: 0.6532
Epoch 11 | Train F1: 0.7590 | Test F1: 0.7148
Epoch 12 | Train F1: 0.8036 | Test F1: 0.7052
Epoch 13 | Train F1: 0.8043 | Test F1: 0.6694
Epoch 14 | Train F1: 0.8318 | Test F1: 0.7339
Epoch 15 | Train F1: 0.8508 | Test F1: 0.6956
✅ Fold 1 Test F1: 0.7339


🚀 Training: AraBERT


tokenizer_config.json:   0%|          | 0.00/611 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/384 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/543M [00:00<?, ?B/s]

Epoch 01 | Train F1: 0.5256 | Test F1: 0.5807
Epoch 02 | Train F1: 0.5648 | Test F1: 0.5865
Epoch 03 | Train F1: 0.5573 | Test F1: 0.5879
Epoch 04 | Train F1: 0.5541 | Test F1: 0.5828
Epoch 05 | Train F1: 0.5387 | Test F1: 0.5821
Epoch 06 | Train F1: 0.5734 | Test F1: 0.5760
Early stopping at epoch 6
✅ Fold 1 Test F1: 0.5879


🚀 Training: HACS-TL
Epoch 01 | Train F1: 0.5225 | Test F1: 0.5721
Epoch 02 | Train F1: 0.5731 | Test F1: 0.5512
Epoch 03 | Train F1: 0.5991 | Test F1: 0.6261
Epoch 04 | Train F1: 0.6212 | Test F1: 0.6362
Epoch 05 | Train F1: 0.6242 | Test F1: 0.6261
Epoch 06 | Train F1: 0.6672 | Test F1: 0.6364
Epoch 07 | Train F1: 0.7023 | Test F1: 0.5897
Epoch 08 | Train F1: 0.7572 | Test F1: 0.6765
Epoch 09 | Train F1: 0.7630 | Test F1: 0.6917
Epoch 10 | Train F1: 0.8104 | Test F1: 0.7009
Epoch 11 | Train F1: 0.8679 | Test F1: 0.6686
Epoch 12 | Train F1: 0.8772 | Test F1: 0.6722
Epoch 13 | Train F1: 0.9114 | Test F1: 0.7179
Epoch 14 | Train F1: 0.9210 | Test F1: 0.7352
Epoch 1