In [None]:
! pip install indic_transliteration

In [None]:
"""
FINAL: GRU Model with K-Fold Cross-Validation + Integrated Preprocessing
IMPROVEMENTS:
- Integrated preprocessing pipeline ✓
- Per-fold training plots ✓
- Confusion matrices for validation and test ✓
- Complete visualization suite ✓
"""
import os
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from gensim.models import Word2Vec
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import random
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# PREPROCESSING UTILITIES (from preprocessing.py)
# ============================================================================
import re
import emoji
import regex

try:
    from indic_transliteration import sanscript
    from indic_transliteration.sanscript import transliterate
    TRANSLITERATION_AVAILABLE = True
except ImportError:
    TRANSLITERATION_AVAILABLE = False
    print("Warning: indic_transliteration not available. Romanization disabled.")

# Nepali stopwords written in Devanagari; removed as whole whitespace-separated tokens.
NEPALI_STOPWORDS = {
    "र", "मा", "कि", "भने", "त", "छ", "हो", "लाई", "ले",
    "गरेको", "गर्छ", "गर्छन्", "हुन्", "गरे", "न", "नभएको",
    "को", "का", "की", "ने", "पनि", "नै", "थियो", "थिए",
}

# Dirghikaran normalization table: maps orthographic variants to one spelling.
# normalize_dirghikaran applies these with str.replace one after another in
# insertion order, so ORDER MATTERS: e.g. "ऋ" -> "रि" is produced first and the
# later "ि" -> "ी" rule then turns it into "री".
DIRGHIKARAN_MAP = {
    # independent short vowels -> long counterparts
    "उ": "ऊ", "इ": "ई", "ऋ": "रि", "ए": "ऐ", "अ": "आ",
    # zero-width joiner / non-joiner are dropped
    "\u200d": "", "\u200c": "",
    # danda / double danda -> full stop
    "।": ".", "॥": ".",
    # short vowel signs -> long vowel signs
    "ि": "ी", "ु": "ू",
}

# Lazily filled by remove_stopwords_roman with the romanized stopword set.
_roman_stopwords_cache = None


def is_devanagari(text: str) -> bool:
    """Return True if more than 50% of the letters in *text* are Devanagari.

    Only characters in the Unicode letter (L) category are counted, in both the
    numerator and the denominator. Devanagari vowel signs, virama, anusvara etc.
    are combining marks: they belong to the Devanagari script but are not
    letters. If they were counted in the numerator only, the ratio would be
    inflated (it can exceed 1, e.g. for "कि") and mixed Roman/Devanagari text
    would be pushed towards "Devanagari".

    Non-strings and blank strings return False.
    """
    if not isinstance(text, str) or not text.strip():
        return False

    total_letters = len(regex.findall(r'\p{L}', text))
    if total_letters == 0:
        return False
    # The lookahead restricts \p{L} to letters that are also Devanagari-script.
    dev_letters = len(regex.findall(r'(?=\p{Devanagari})\p{L}', text))
    return dev_letters / total_letters > 0.5


def devanagari_to_roman(text: str) -> str:
    """Transliterate Devanagari *text* to the ITRANS Roman scheme.

    Best effort: if indic_transliteration is unavailable, or the transliteration
    call raises, the input is returned unchanged.
    """
    if TRANSLITERATION_AVAILABLE:
        try:
            return transliterate(text, sanscript.DEVANAGARI, sanscript.ITRANS)
        except Exception:
            # Deliberately tolerant: fall through and keep the original text.
            return text
    return text


def normalize_dirghikaran(text: str) -> str:
    """Rewrite Devanagari orthographic variants to a single spelling.

    Every (variant, canonical) pair of DIRGHIKARAN_MAP is applied with
    ``str.replace`` in the map's insertion order. Later rules therefore see the
    output of earlier ones.
    """
    normalized = text
    for variant, canonical in DIRGHIKARAN_MAP.items():
        normalized = normalized.replace(variant, canonical)
    return normalized


def clean_text(text: str) -> str:
    """
    Aggressive cleaning for ML/GRU:
    - lowercase
    - remove URLs, mentions, hashtags
    - remove emojis
    - remove digits (Python's Unicode-aware digit class also matches Devanagari digits)
    - replace danda / double danda with a space
    - remove remaining punctuation, keeping word characters, whitespace and the
      Devanagari block U+0900-U+097F
    - collapse whitespace

    Non-string input yields "".
    """
    if not isinstance(text, str):
        return ""

    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text)
    text = re.sub(r"@\w+|#\w+", "", text)
    text = emoji.replace_emoji(text, replace="")
    text = re.sub(r"\d+", "", text)
    # Danda (U+0964) and double danda (U+0965) are sentence punctuation. They
    # lie inside the Devanagari range whitelisted below, so strip them
    # explicitly; otherwise they survive cleaning glued to the previous word
    # (and DIRGHIKARAN_MAP later turns them into ".").
    text = re.sub(r"[\u0964\u0965]", " ", text)
    # The Devanagari block is whitelisted explicitly because its combining
    # vowel signs are not matched by the word-character class.
    text = re.sub(r"[^\w\s\u0900-\u097F]", "", text)
    text = re.sub(r"\s+", " ", text).strip()

    return text


_devanagari_stopwords_cache = None


def remove_stopwords_devanagari(text: str) -> str:
    """Drop Nepali stopwords from whitespace-tokenized Devanagari *text*.

    preprocess_for_ml_gru calls this after normalize_dirghikaran. That
    normalization rewrites short vowels (e.g. "हुन्" -> "हून्", "थियो" -> "थीयो"),
    so the raw NEPALI_STOPWORDS entries alone would miss those words. The
    lookup set therefore contains both the raw and the normalized spelling of
    every stopword. It is built once and cached in a module global.
    """
    global _devanagari_stopwords_cache

    if _devanagari_stopwords_cache is None:
        _devanagari_stopwords_cache = set(NEPALI_STOPWORDS) | {
            normalize_dirghikaran(w) for w in NEPALI_STOPWORDS
        }

    words = text.split()
    return " ".join(w for w in words if w not in _devanagari_stopwords_cache)


def remove_stopwords_roman(text: str) -> str:
    """Drop romanized Nepali stopwords from whitespace-tokenized *text*.

    The stopword set is NEPALI_STOPWORDS transliterated with
    devanagari_to_roman (ITRANS). It is built once and cached in
    ``_roman_stopwords_cache``. ITRANS uses capitals for long vowels and
    aspirates ("mA", "lAI", "garCha"). In the pipeline this function only sees
    text that clean_text has already lowercased, so the lowercased form of each
    entry is included as well; the original ITRANS form is kept for callers
    that pass un-lowercased text.
    """
    global _roman_stopwords_cache

    if _roman_stopwords_cache is None:
        romanized = {devanagari_to_roman(w) for w in NEPALI_STOPWORDS}
        _roman_stopwords_cache = romanized | {w.lower() for w in romanized}

    words = text.split()
    return " ".join(w for w in words if w not in _roman_stopwords_cache)


def preprocess_for_ml_gru(text: str) -> str:
    """
    Full ML / GRU preprocessing for a single text.

    Steps, in order:
      1. clean_text (lowercase, strip URLs/mentions/emoji/digits/punctuation).
      2. If the cleaned text is mostly Devanagari: dirghikaran normalization,
         Devanagari stopword removal, then romanization to ITRANS.
      3. Otherwise: romanized stopword removal only.

    Non-string input yields "".
    """
    if not isinstance(text, str):
        return ""

    cleaned = clean_text(text)

    if not is_devanagari(cleaned):
        return remove_stopwords_roman(cleaned)

    normalized = normalize_dirghikaran(cleaned)
    without_stopwords = remove_stopwords_devanagari(normalized)
    return devanagari_to_roman(without_stopwords)


def batch_preprocess(texts):
    """Run preprocess_for_ml_gru over every item of *texts*; returns a list."""
    return list(map(preprocess_for_ml_gru, texts))


def apply_preprocessing_to_dataframe(df, text_column=None, drop_empty=True):
    """
    Apply preprocessing pipeline to a dataframe and create tokens column.

    The input DataFrame is not modified; a processed copy is returned.

    Args:
        df: DataFrame with text data
        text_column: Name of the column containing text (auto-detected if None).
            Auto-detection first tries common names ('text', 'comment', ...).
            It then falls back to the first string/object column whose name does
            not contain 'label', 'class' or 'category' and is not 'id'.
        drop_empty: If True (default), drop rows whose token list is empty after
            preprocessing. Pass False to keep every row (e.g. so a test set keeps
            its size); such rows then encode to all padding downstream.

    Returns:
        DataFrame (index reset) with 'processed_text' and 'tokens' columns added

    Raises:
        ValueError: if no text column can be found, or text_column is missing.
    """
    print(f"\nApplying preprocessing to {len(df)} samples...")

    # Work on a copy so the caller's DataFrame does not silently gain columns.
    df = df.copy()

    # Auto-detect text column if not specified
    if text_column is None:
        # Check common column names
        possible_columns = ['text', 'Text', 'comment', 'Comment', 'content', 'Content',
                            'tweet', 'Tweet', 'post', 'Post', 'message', 'Message']
        text_column = next((col for col in possible_columns if col in df.columns), None)

        if text_column is not None:
            print(f"Auto-detected text column: '{text_column}'")
        else:
            # Fall back to the first string-like column that does not look like a
            # label/id. Substring matching catches names such as 'Label_Multiclass',
            # which an exact-name check would miss.
            non_text_markers = ('label', 'class', 'category')
            for col in df.columns:
                name = str(col).lower()  # column names need not be strings
                if name == 'id' or any(marker in name for marker in non_text_markers):
                    continue
                if df[col].dtype == 'object' or pd.api.types.is_string_dtype(df[col].dtype):
                    text_column = col
                    print(f"Using column: '{text_column}'")
                    break

        if text_column is None:
            raise ValueError(f"Could not find text column. Available columns: {df.columns.tolist()}")

    # Verify column exists
    if text_column not in df.columns:
        raise ValueError(f"Column '{text_column}' not found. Available columns: {df.columns.tolist()}")

    # Apply preprocessing
    df['processed_text'] = df[text_column].apply(preprocess_for_ml_gru)

    # Tokenize (simple whitespace tokenization)
    df['tokens'] = df['processed_text'].apply(lambda x: x.split() if x else [])

    original_len = len(df)
    if drop_empty:
        # Remove samples that became empty after preprocessing
        df = df[df['tokens'].apply(len) > 0]
    df = df.reset_index(drop=True)

    if len(df) < original_len:
        print(f"Removed {original_len - len(df)} empty samples after preprocessing")

    print(f"✓ Preprocessing complete. Final dataset size: {len(df)}")

    return df

# ============================================================================
# DATASET DEFINITION
# ============================================================================
class HateSpeechDataset(Dataset):
    """Map-style dataset of pre-encoded token-id sequences and integer labels.

    Items are dicts with 'input_ids' and 'labels' as ``torch.long`` tensors.
    When ``augment`` is True, about 15% of fetched items have roughly 10% of
    their positions replaced by id 0 (token dropout). The stored sequences are
    never modified; a copy is masked.
    """

    def __init__(self, input_ids, labels, augment=False):
        self.input_ids = input_ids
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        token_ids = self.input_ids[idx].copy()

        apply_dropout = self.augment and random.random() < 0.15
        if apply_dropout:
            keep = np.random.random(len(token_ids)) > 0.1
            token_ids = [tok if kept else 0 for tok, kept in zip(token_ids, keep)]

        label = self.labels[idx]
        return {
            'input_ids': torch.tensor(token_ids, dtype=torch.long),
            'labels': torch.tensor(label, dtype=torch.long),
        }

# ============================================================================
# GRU CLASSIFIER
# ============================================================================
class OptimizedGRUClassifier(nn.Module):
    """Single-layer bidirectional GRU classifier over (trainable) pretrained embeddings.

    Input ``x`` is a (batch, seq_len) tensor of token ids, right-padded with
    ``padding_idx``. The forward and backward final hidden states are
    concatenated and passed through a two-layer MLP head with ReLU and dropout.
    The output is raw logits of shape (batch, output_dim).

    Padding handling: each row's length is the position of its last non-padding
    id (+1). The embedded batch is packed with ``pack_padded_sequence``, so the
    GRU's final states are not computed after the trailing padding. Interior
    zeros (OOV ids or augmentation masks) are still fed to the GRU. The padding
    row of the embedding receives no gradient. Pass ``padding_idx=None`` to
    disable both behaviours.
    """

    def __init__(self, embedding_matrix, hidden_dim=96, output_dim=4, dropout=0.5, padding_idx=0):
        super().__init__()
        _, embedding_dim = embedding_matrix.shape
        self.padding_idx = padding_idx

        self.embedding = nn.Embedding.from_pretrained(
            torch.FloatTensor(embedding_matrix), freeze=False, padding_idx=padding_idx
        )
        self.embedding_dropout = nn.Dropout(0.3)

        self.gru = nn.GRU(
            embedding_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def _sequence_lengths(self, x):
        """Return per-row length = index of the last non-padding id + 1 (min 1)."""
        positions = torch.arange(1, x.size(1) + 1, device=x.device)
        last_real = ((x != self.padding_idx).long() * positions).max(dim=1).values
        # All-padding rows still need length >= 1 for pack_padded_sequence.
        return last_real.clamp(min=1)

    def forward(self, x):
        embedded = self.embedding_dropout(self.embedding(x))
        if self.padding_idx is None:
            _, hidden = self.gru(embedded)
        else:
            lengths = self._sequence_lengths(x)
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            # With enforce_sorted=False the returned hidden state is already in
            # the original batch order.
            _, hidden = self.gru(packed)
        # hidden: (num_directions, batch, hidden_dim); [-2] forward, [-1] backward.
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        hidden = self.dropout(hidden)
        out = self.fc1(hidden)
        out = self.relu(out)
        out = self.dropout2(out)
        logits = self.fc2(out)
        return logits

# ============================================================================
# UTILITIES
# ============================================================================
def encode_and_pad(tokens, word2idx, max_len=40):
    """Map *tokens* to vocabulary ids, truncating/right-padding to *max_len*.

    Tokens missing from *word2idx* become id 0, the same id used for padding.
    """
    ids = [word2idx.get(token, 0) for token in tokens[:max_len]]
    shortfall = max_len - len(ids)
    return ids + [0] * shortfall if shortfall > 0 else ids

def train_epoch(model, dataloader, optimizer, device, class_weights):
    """Run one optimisation pass over *dataloader*.

    Uses class-weighted cross-entropy and clips the gradient norm to 0.5 before
    every optimizer step.

    Returns:
        (mean per-batch loss, macro F1 of the argmax predictions over the epoch)
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)
    running_loss = 0
    predictions, targets = [], []

    for batch in tqdm(dataloader, desc="Training", leave=False):
        inputs = batch['input_ids'].to(device)
        gold = batch['labels'].to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = loss_fn(logits, gold)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        running_loss += loss.item()
        predictions.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(gold.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    macro_f1 = f1_score(targets, predictions, average='macro', zero_division=0)
    return mean_loss, macro_f1

def evaluate(model, dataloader, device, class_weights):
    """Score *model* on *dataloader* in eval mode without tracking gradients.

    Returns:
        (mean per-batch weighted loss, macro F1, predicted ids, true ids). The
        two id lists hold NumPy scalars in dataloader order.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)
    running_loss = 0
    predictions, targets = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            inputs = batch['input_ids'].to(device)
            gold = batch['labels'].to(device)
            logits = model(inputs)
            running_loss += loss_fn(logits, gold).item()
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(gold.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    macro_f1 = f1_score(targets, predictions, average='macro', zero_division=0)
    return mean_loss, macro_f1, predictions, targets

# ============================================================================
# VISUALIZATION FUNCTIONS
# ============================================================================
def plot_confusion_matrix_custom(y_true, y_pred, labels, save_path, title="Confusion Matrix"):
    """Plot a row-normalized confusion-matrix heatmap and save it to *save_path*.

    Args:
        y_true: Encoded true labels. Assumed to be integer class ids
            0..len(labels)-1, as produced by LabelEncoder.transform / argmax in
            this file.
        y_pred: Encoded predicted labels, same encoding as y_true.
        labels: Display names; labels[i] names class id i.
        save_path: Output image path.
        title: Figure title.

    Every class gets a row and a column even if it never occurs in y_true or
    y_pred, so the tick labels always line up. Rows with no true samples are
    shown as 0 instead of NaN.
    """
    class_ids = np.arange(len(labels))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                              where=row_sums != 0)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=labels, yticklabels=labels,
                cbar_kws={'label': 'Proportion'})
    plt.title(title, fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Confusion matrix saved: {os.path.basename(save_path)}")


def plot_fold_history(history, fold_num, save_dir):
    """Save a two-panel (loss | F1) train-vs-validation figure for one CV fold.

    *history* must hold equal-length lists under 'train_loss', 'val_loss',
    'train_f1' and 'val_f1'. *fold_num* is zero-based; the file is written as
    ``fold_<fold_num+1>_history.png`` inside *save_dir* (created if needed).
    """
    os.makedirs(save_dir, exist_ok=True)
    fold_label = fold_num + 1

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    epoch_axis = range(1, len(history['train_loss']) + 1)

    # (train key, train label, val key, val label, y-axis label, title suffix)
    panel_specs = (
        ('train_loss', 'Train Loss', 'val_loss', 'Val Loss', 'Loss',
         'Training and Validation Loss'),
        ('train_f1', 'Train F1', 'val_f1', 'Val F1', 'F1 Score',
         'Training and Validation F1'),
    )
    for ax, (tr_key, tr_label, va_key, va_label, y_label, suffix) in zip(axes, panel_specs):
        ax.plot(epoch_axis, history[tr_key], label=tr_label,
                linewidth=2, color='#1f77b4', marker='o', markersize=3)
        ax.plot(epoch_axis, history[va_key], label=va_label,
                linewidth=2, color='#ff7f0e', marker='s', markersize=3)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(f'Fold {fold_label} - {suffix}', fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    out_file = os.path.join(save_dir, f'fold_{fold_label}_history.png')
    plt.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Fold {fold_label} training curves saved")


def plot_cv_summary(fold_scores, all_histories, save_dir):
    """Save the two K-fold summary figures into *save_dir* (created if needed).

    * ``cv_f1_summary.png``: one bar per fold with its best validation F1, the
      mean as a dashed red line, and each bar annotated with its score.
    * ``cv_training_summary.png``: per-epoch train/val loss and F1 averaged
      across folds. Early-stopped folds are right-padded with NaN, which
      ``np.nanmean`` ignores.
    """
    os.makedirs(save_dir, exist_ok=True)
    mean_score = np.mean(fold_scores)

    # ---- Figure 1: best validation F1 per fold ----
    plt.figure(figsize=(10, 6))
    bars = plt.bar(range(1, len(fold_scores) + 1), fold_scores, color="#1f77b4", edgecolor='navy')
    plt.axhline(mean_score, color="red", linestyle="--", linewidth=2,
                label=f"Mean: {mean_score:.4f}")
    for bar, score in zip(bars, fold_scores):
        bar_center = bar.get_x() + bar.get_width()/2
        plt.text(bar_center, score + 0.01, f'{score:.4f}',
                 ha='center', va='bottom', fontsize=10, fontweight='bold')
    plt.xlabel("Fold", fontsize=12)
    plt.ylabel("Validation F1", fontsize=12)
    plt.title("Cross-Validation: Best F1 Score per Fold", fontsize=14, fontweight="bold")
    plt.legend(fontsize=11)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "cv_f1_summary.png"), dpi=300)
    plt.close()
    print("✓ CV summary (F1 bar chart) saved")

    # ---- Figure 2: per-epoch curves averaged over folds ----
    n_epochs = max(len(h["train_loss"]) for h in all_histories)

    def fold_mean(key):
        """Average history[key] over folds, NaN-padding shorter runs."""
        rows = []
        for h in all_histories:
            values = np.array(h[key])
            missing = n_epochs - len(values)
            if missing > 0:
                values = np.pad(values, (0, missing), constant_values=np.nan)
            rows.append(values)
        return np.nanmean(np.array(rows), axis=0)

    curves = {key: fold_mean(key) for key in ("train_loss", "val_loss", "train_f1", "val_f1")}

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    epoch_axis = range(1, n_epochs + 1)

    # (train key, train label, val key, val label, y-axis label, title)
    panel_specs = (
        ("train_loss", "Train Loss", "val_loss", "Val Loss", 'Loss',
         "Average Train/Val Loss Across Folds"),
        ("train_f1", "Train F1", "val_f1", "Val F1", 'F1 Score',
         "Average Train/Val F1 Across Folds"),
    )
    for ax, (tr_key, tr_label, va_key, va_label, y_label, heading) in zip(axes, panel_specs):
        ax.plot(epoch_axis, curves[tr_key], label=tr_label, linewidth=2, color='#1f77b4')
        ax.plot(epoch_axis, curves[va_key], label=va_label, linewidth=2, color='#ff7f0e')
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(heading, fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "cv_training_summary.png"), dpi=300)
    plt.close()
    print("✓ CV training summary (average curves) saved")


def plot_final_training_history(final_history, save_dir):
    """Save the final model's train/val loss and F1 curves as one two-panel figure.

    *final_history* holds equal-length lists under 'train_loss', 'val_loss',
    'train_f1' and 'val_f1'. The figure is written to
    ``gru_final_training_history.png`` in *save_dir* (created if needed).
    """
    os.makedirs(save_dir, exist_ok=True)
    epoch_axis = range(1, len(final_history['train_loss']) + 1)

    # (train key, label, colour, val key, label, colour, y-axis label, title)
    panel_specs = (
        ('train_loss', "Train Loss", '#1f77b4', 'val_loss', "Val Loss", '#ff7f0e',
         'Loss', "Final Model - Training and Validation Loss"),
        ('train_f1', "Train F1", '#2ca02c', 'val_f1', "Val F1", '#d62728',
         'F1 Score', "Final Model - Training and Validation F1"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, spec in zip(axes, panel_specs):
        tr_key, tr_label, tr_color, va_key, va_label, va_color, y_label, heading = spec
        ax.plot(epoch_axis, final_history[tr_key], label=tr_label,
                linewidth=2, color=tr_color, marker='o', markersize=4)
        ax.plot(epoch_axis, final_history[va_key], label=va_label,
                linewidth=2, color=va_color, marker='s', markersize=4, linestyle='--')
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(heading, fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "gru_final_training_history.png"), dpi=300)
    plt.close()
    print("✓ Final model training curves saved")

# ============================================================================
# K-FOLD TRAINING FUNCTION
# ============================================================================
def train_single_fold(train_idx, val_idx, train_df, embedding_matrix, vocab, le,
                      device, fold_num, save_dir, n_splits=5):
    """Train and validate one cross-validation fold.

    Trains a fresh OptimizedGRUClassifier for up to 50 epochs. It uses AdamW,
    cosine warm restarts and class weights clipped to [0.5, 4.0], with early
    stopping after 5 epochs without a validation-F1 improvement. The best epoch's
    weights are saved to ``gru_fold_<fold_num>.pt`` in *save_dir*, and its
    training curves and validation confusion matrix are plotted there.

    Args:
        train_idx, val_idx: Positional row indices into *train_df* for this
            fold's training and validation rows.
        train_df: Frame with 'input_ids' (padded id lists) and 'Label_Multiclass'.
        embedding_matrix: Initial embedding weights for the model.
        vocab: Not used by this function; kept for interface compatibility.
        le: Fitted LabelEncoder for 'Label_Multiclass'.
        device: torch device to train on.
        fold_num: Zero-based fold index (used in file names and messages).
        save_dir: Directory for the checkpoint and plots.
        n_splits: Total number of folds; only used in the progress banner.

    Returns:
        (best_val_f1, history): best validation macro-F1 and the per-epoch
        'train_loss', 'val_loss', 'train_f1', 'val_f1' lists.
    """
    print(f"\n{'='*60}")
    print(f" Fold {fold_num + 1}/{n_splits}")
    print(f"{'='*60}")

    fold_train_df = train_df.iloc[train_idx].copy()
    fold_val_df = train_df.iloc[val_idx].copy()

    y_train = le.transform(fold_train_df['Label_Multiclass'])
    y_val = le.transform(fold_val_df['Label_Multiclass'])

    train_ds = HateSpeechDataset(fold_train_df['input_ids'].tolist(), y_train, augment=True)
    val_ds = HateSpeechDataset(fold_val_df['input_ids'].tolist(), y_val)

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=128)

    class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    class_weights = np.clip(class_weights, 0.5, 4.0)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    model = OptimizedGRUClassifier(embedding_matrix, hidden_dim=96,
                                   output_dim=len(le.classes_), dropout=0.5).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

    # Start below any achievable F1 so the first epoch always becomes the
    # baseline. Starting at 0 left best_val_preds unbound (UnboundLocalError at
    # plotting time) when validation F1 stayed at 0.
    best_val_f1, patience, patience_counter = -1.0, 5, 0
    best_val_preds, best_val_labels = [], []
    hist = {'train_loss': [], 'val_loss': [], 'train_f1': [], 'val_f1': []}

    for epoch in range(50):
        tr_loss, tr_f1 = train_epoch(model, train_loader, optimizer, device, class_weights)
        val_loss, val_f1, val_preds, val_labels = evaluate(model, val_loader, device, class_weights)
        scheduler.step()

        hist['train_loss'].append(tr_loss)
        hist['val_loss'].append(val_loss)
        hist['train_f1'].append(tr_f1)
        hist['val_f1'].append(val_f1)

        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1:2d} | Train F1: {tr_f1:.4f} | Val F1: {val_f1:.4f} | Gap: {tr_f1-val_f1:.4f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            patience_counter = 0
            best_val_preds = val_preds
            best_val_labels = val_labels
            torch.save({'model_state_dict': model.state_dict()},
                       os.path.join(save_dir, f'gru_fold_{fold_num}.pt'))
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

    print(f"Fold {fold_num + 1} Best Val F1: {best_val_f1:.4f}")

    # Plot fold history
    plot_fold_history(hist, fold_num, save_dir)

    # Plot fold confusion matrix
    cm_path = os.path.join(save_dir, f'fold_{fold_num + 1}_confusion_matrix.png')
    plot_confusion_matrix_custom(best_val_labels, best_val_preds, le.classes_,
                                 cm_path, title=f"Fold {fold_num + 1} - Validation Confusion Matrix")

    return best_val_f1, hist

# ============================================================================
# MAIN CROSS-VALIDATION TRAINING
# ============================================================================
def train_with_cross_validation(train_df, val_df, test_df, n_splits=5,
                                text_column=None, save_dir='models/saved_models',
                                seed=42):
    """
    Main training function with integrated preprocessing.

    Steps:
        1. Preprocess train/val/test with apply_preprocessing_to_dataframe.
        2. Train a skip-gram Word2Vec (dim 100) on the train tokens and build
           the embedding matrix (row 0 reserved for '<PAD>').
        3. Encode/pad every split to 40 ids and run stratified K-fold CV on
           train_df.
        4. Train a final model on 90% of train_df, monitor a stratified 10%
           holdout, and keep the epoch with the best holdout macro-F1.
        5. Evaluate that best final checkpoint on test_df.

    NOTE(review): val_df is preprocessed and encoded but not otherwise used;
    the final model is monitored on a holdout carved out of train_df instead.

    Args:
        train_df: Training dataframe (text column + 'Label_Multiclass')
        val_df: Validation dataframe
        test_df: Test dataframe
        n_splits: Number of CV folds
        text_column: Name of column containing text data (auto-detected if None)
        save_dir: Directory to save models and plots
        seed: If not None, seeds Python's `random`, NumPy and PyTorch first, so
            augmentation, loader shuffling and weight init are repeatable.
            Word2Vec trains with workers=4 and may still vary. Pass None to
            leave the RNGs untouched.

    Returns:
        dict with 'fold_scores', 'cv_mean_f1', 'cv_std_f1', 'final_val_f1'
        and 'test_f1'.
    """
    os.makedirs(save_dir, exist_ok=True)
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n{'='*70}")
    print(f" GRU MODEL WITH {n_splits}-FOLD CROSS-VALIDATION + PREPROCESSING")
    print(f"{'='*70}")
    print(f"Using device: {device}\n")

    # ---- PREPROCESSING STEP ----
    print("="*70)
    print(" STEP 1: PREPROCESSING")
    print("="*70)

    # Show available columns
    print(f"Available columns in train_df: {train_df.columns.tolist()}")

    # Apply preprocessing to all datasets
    train_df = apply_preprocessing_to_dataframe(train_df, text_column)
    val_df = apply_preprocessing_to_dataframe(val_df, text_column)
    test_df = apply_preprocessing_to_dataframe(test_df, text_column)

    print(f"\nFinal dataset sizes:")
    print(f"  Train: {len(train_df)}")
    print(f"  Val:   {len(val_df)}")
    print(f"  Test:  {len(test_df)}")

    # ---- WORD2VEC EMBEDDINGS ----
    print("\n" + "="*70)
    print(" STEP 2: BUILDING WORD2VEC EMBEDDINGS")
    print("="*70)

    all_tokens = train_df['tokens'].tolist()
    print(f"Training Word2Vec on {len(all_tokens)} samples...")

    w2v = Word2Vec(sentences=all_tokens, vector_size=100, window=5,
                   min_count=1, workers=4, epochs=10, sg=1)

    # Id 0 is reserved for padding (and is also what unknown tokens encode to).
    vocab = {w: i+1 for i, w in enumerate(w2v.wv.index_to_key)}
    vocab['<PAD>'] = 0

    emb_matrix = np.zeros((len(vocab), 100))
    for w, i in vocab.items():
        if w in w2v.wv:
            emb_matrix[i] = w2v.wv[w]

    print(f"✓ Vocabulary size: {len(vocab)}")
    print(f"✓ Embedding matrix shape: {emb_matrix.shape}")

    # ---- ENCODE SEQUENCES ----
    print("\nEncoding sequences...")
    for split_df in (train_df, val_df, test_df):
        split_df['input_ids'] = split_df['tokens'].apply(lambda x: encode_and_pad(x, vocab, 40))
    print("✓ Sequences encoded")

    # ---- LABEL ENCODING ----
    le = LabelEncoder()
    le.fit(train_df['Label_Multiclass'])
    print(f"\n✓ Classes: {le.classes_}\n")

    # ---- K-FOLD CROSS-VALIDATION ----
    print("="*70)
    print(f" STEP 3: {n_splits}-FOLD CROSS-VALIDATION")
    print("="*70)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_scores, histories = [], []

    for fold, (tr_idx, v_idx) in enumerate(skf.split(train_df, train_df['Label_Multiclass'])):
        best_f1, hist = train_single_fold(tr_idx, v_idx, train_df, emb_matrix, vocab,
                                          le, device, fold, save_dir)
        fold_scores.append(best_f1)
        histories.append(hist)

    # CV Summary
    print("\n" + "="*70)
    print(" CROSS-VALIDATION RESULTS")
    print("="*70)
    for i, score in enumerate(fold_scores):
        print(f"Fold {i+1}: {score:.4f}")

    mean_f1, std_f1 = np.mean(fold_scores), np.std(fold_scores)
    print(f"\nMean Val F1: {mean_f1:.4f} ± {std_f1:.4f}")
    print("="*70 + "\n")

    plot_cv_summary(fold_scores, histories, save_dir)

    # ---- FINAL TRAINING WITH VALIDATION MONITORING ----
    print("="*70)
    print(" STEP 4: TRAINING FINAL MODEL")
    print("="*70 + "\n")

    tr_full, val_hold = train_test_split(train_df, test_size=0.1,
                                         stratify=train_df['Label_Multiclass'],
                                         random_state=42)
    y_tr = le.transform(tr_full['Label_Multiclass'])
    y_val_hold = le.transform(val_hold['Label_Multiclass'])

    tr_loader = DataLoader(HateSpeechDataset(tr_full['input_ids'].tolist(), y_tr, augment=True),
                           batch_size=64, shuffle=True)
    val_hold_loader = DataLoader(HateSpeechDataset(val_hold['input_ids'].tolist(), y_val_hold),
                                 batch_size=64)

    # NOTE(review): unlike the CV folds, these class weights are not clipped to
    # [0.5, 4.0], and the optimizer/scheduler/epochs differ too, so the CV score
    # is only an approximate estimate for this configuration.
    weights = compute_class_weight('balanced', classes=np.unique(y_tr), y=y_tr)
    weights = torch.tensor(weights, dtype=torch.float).to(device)

    final_model = OptimizedGRUClassifier(emb_matrix, output_dim=len(le.classes_)).to(device)
    opt = optim.AdamW(final_model.parameters(), lr=1e-4, weight_decay=1e-3)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', patience=2, factor=0.5)

    # -1 so the first epoch always becomes the baseline (F1 is never negative).
    best_val, patience, counter = -1.0, 3, 0
    best_state = None
    best_val_preds, best_val_labels = [], []
    final_hist = {'train_loss': [], 'train_f1': [], 'val_loss': [], 'val_f1': []}

    for ep in range(15):
        tr_loss, tr_f1 = train_epoch(final_model, tr_loader, opt, device, weights)
        val_loss, val_f1, val_preds, val_labels = evaluate(final_model, val_hold_loader, device, weights)
        sched.step(val_loss)

        final_hist['train_loss'].append(tr_loss)
        final_hist['train_f1'].append(tr_f1)
        final_hist['val_loss'].append(val_loss)
        final_hist['val_f1'].append(val_f1)

        if (ep + 1) % 5 == 0:
            print(f"Epoch {ep+1:2d} | Train F1: {tr_f1:.4f} | Val F1: {val_f1:.4f} | Gap: {tr_f1-val_f1:.4f}")

        if val_f1 > best_val:
            best_val = val_f1
            counter = 0
            best_val_preds = val_preds
            best_val_labels = val_labels
            # Keep an in-memory copy so the test evaluation below uses the best
            # epoch, not whatever weights the loop ended on.
            best_state = {k: v.detach().clone() for k, v in final_model.state_dict().items()}
            torch.save({'model_state_dict': final_model.state_dict(), 'vocab': vocab,
                        'label_encoder': le}, os.path.join(save_dir, 'gru_final_model.pt'))
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {ep + 1}")
                break

    # Restore the best checkpoint; after early stopping the live weights are
    # `patience` epochs past it.
    if best_state is not None:
        final_model.load_state_dict(best_state)

    # Final model plots
    plot_final_training_history(final_hist, save_dir)

    # Validation confusion matrix
    val_cm_path = os.path.join(save_dir, 'final_validation_confusion_matrix.png')
    plot_confusion_matrix_custom(best_val_labels, best_val_preds, le.classes_,
                                 val_cm_path, title="Final Model - Validation Confusion Matrix")

    # ---- TEST SET EVALUATION ----
    print("\n" + "="*70)
    print(" STEP 5: EVALUATING ON TEST SET")
    print("="*70 + "\n")

    y_test = le.transform(test_df['Label_Multiclass'])
    test_loader = DataLoader(HateSpeechDataset(test_df['input_ids'].tolist(), y_test),
                             batch_size=64)

    test_loss, test_f1, test_preds, test_labels = evaluate(final_model, test_loader, device, weights)

    # Test confusion matrix
    test_cm_path = os.path.join(save_dir, 'final_test_confusion_matrix.png')
    plot_confusion_matrix_custom(test_labels, test_preds, le.classes_,
                                 test_cm_path, title="Final Model - Test Set Confusion Matrix")

    # ---- FINAL SUMMARY ----
    print("\n" + "="*70)
    print(" FINAL SUMMARY")
    print("="*70)
    print(f"Cross-Validation F1: {mean_f1:.4f} ± {std_f1:.4f}")
    print(f"Final Validation F1: {best_val:.4f}")
    print(f"Test Set F1:         {test_f1:.4f}")
    print(f"\nGeneralization: {'✓ Good' if abs(mean_f1 - test_f1) < 0.05 else '⚠ Check variance'}")
    print("="*70 + "\n")

    print("✓ All visualizations saved to:", save_dir)
    print(f"  - fold_X_history.png ({n_splits} files)")
    print(f"  - fold_X_confusion_matrix.png ({n_splits} files)")
    print("  - cv_f1_summary.png")
    print("  - cv_training_summary.png")
    print("  - gru_final_training_history.png")
    print("  - final_validation_confusion_matrix.png")
    print("  - final_test_confusion_matrix.png")

    return {
        'fold_scores': fold_scores,
        'cv_mean_f1': mean_f1,
        'cv_std_f1': std_f1,
        'final_val_f1': best_val,
        'test_f1': test_f1,
    }

# ============================================================================
# ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    # Dataset locations on the Kaggle input mount; change these to run elsewhere.
    train_path = "/kaggle/input/nepalihatee/train_final.json"
    val_path = "/kaggle/input/nepalihatee/val_final.json"
    test_path = "/kaggle/input/nepalihatee/test.json"

    train_df, val_df, test_df = (pd.read_json(p) for p in (train_path, val_path, test_path))

    print(f"Data loaded: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")

    # Full pipeline: preprocessing, CV, final model, test evaluation.
    # text_column=None lets the pipeline auto-detect which column holds the text.
    train_with_cross_validation(
        train_df,
        val_df,
        test_df,
        n_splits=5,
        text_column=None,
        save_dir='models/saved_models',
    )