# 2-Stage MoE with MuRIL

**Model**: `google/muril-base-cased` (MuRIL, a multilingual BERT model pretrained on Indian languages)

```
Input → Stage 1 (toxic/safe) → Stage 2 (category)
```

- **Stage 1**: binary classifier (toxic vs. safe), tuned for high recall
- **Stage 2**: category classifier, applied only to content that Stage 1 flags as toxic

In [None]:
# Install the libraries used by the cells below (quiet mode).
!pip install transformers datasets accelerate scikit-learn sentencepiece -q

In [None]:
from google.colab import files
import os

# Ask the user to upload the training data into the Colab working directory.
_expected_uploads = ['router_train_v2.jsonl']
print('Upload these files from data/datasets/moe/:')
for _upload_name in _expected_uploads:
    print(f'  - {_upload_name}')
print()
uploaded = files.upload()

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.utils.class_weight import compute_class_weight
import warnings
# Silence library warnings to keep notebook output readable.
# NOTE(review): this hides *all* warnings (including sklearn metric warnings);
# narrow the filter if results look suspicious.
warnings.filterwarnings('ignore')

# Report the accelerator. torch.cuda.get_device_name(0) raises when no CUDA
# device is present, so only query it after checking availability.
cuda_available = torch.cuda.is_available()
if cuda_available:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
else:
    print('GPU: none - switch the runtime to a GPU (training below uses fp16 and CUDA tensors)')
print(f'CUDA: {cuda_available}')

# Output directory for both saved stages.
os.makedirs('moe_muril', exist_ok=True)

# MuRIL: multilingual BERT model pretrained on Indian languages
MODEL_NAME = 'google/muril-base-cased'
print(f'Model: {MODEL_NAME}')

In [None]:
# Load the MuRIL tokenizer; both stages reuse it and save a copy next to each model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print('Tokenizer loaded: ' + MODEL_NAME)

---
## Stage 1: Binary Classifier (toxic vs safe)
**Goal: 95%+ recall**: catch as much toxic content as possible, accepting some extra false positives on safe text.

In [None]:
# Load the router dataset: one JSON object per line with 'text' and 'label'.
# Explicit UTF-8 so non-ASCII text (e.g. Indic scripts) decodes identically
# regardless of the platform's default encoding.
router_data = []
with open('router_train_v2.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():  # tolerate blank / trailing lines in the JSONL file
            router_data.append(json.loads(line))

if not router_data:
    # Fail with a clear message instead of a ZeroDivisionError below.
    raise ValueError('router_train_v2.jsonl contains no examples')

# Collapse the multi-class labels to binary: 0 = safe, anything else = 1 (toxic).
binary_data = [
    {'text': d['text'], 'label': 0 if d['label'] == 'safe' else 1}
    for d in router_data
]

safe_count = sum(1 for d in binary_data if d['label'] == 0)
toxic_count = len(binary_data) - safe_count  # labels are only 0 or 1

print(f'Total: {len(binary_data)}')
print(f'Safe: {safe_count} ({100*safe_count/len(binary_data):.1f}%)')
print(f'Toxic: {toxic_count} ({100*toxic_count/len(binary_data):.1f}%)')

In [None]:
# Hold out 10% for validation, stratified so both splits keep the safe/toxic ratio.
texts = [example['text'] for example in binary_data]
labels = [example['label'] for example in binary_data]

stage1_split = train_test_split(texts, labels, test_size=0.1, random_state=42, stratify=labels)
train_texts, val_texts, train_labels, val_labels = stage1_split

for split_name, split_texts in (('Train', train_texts), ('Val', val_texts)):
    print(f'{split_name}: {len(split_texts)}')

In [None]:
# Tokenize function
def tokenize_data(texts, labels):
    """Tokenize ``texts`` with the shared tokenizer and pair them with ``labels``.

    Sequences are truncated to 256 tokens and padded to the longest sequence
    in this call, so every row has the same length (the Trainers in this
    notebook are not given a padding data collator).

    Returns a ``datasets.Dataset`` with ``input_ids``, ``attention_mask`` and
    ``labels`` columns.
    """
    enc = tokenizer(
        texts,
        max_length=256,
        truncation=True,
        padding=True,
        return_tensors=None,
    )
    columns = {key: enc[key] for key in ('input_ids', 'attention_mask')}
    columns['labels'] = labels
    return Dataset.from_dict(columns)

# Build the Stage 1 train / validation datasets.
print('Tokenizing Stage 1 data...')
train_dataset = tokenize_data(train_texts, train_labels)
val_dataset = tokenize_data(val_texts, val_labels)
n_train, n_val = len(train_dataset), len(val_dataset)
print(f'Done! Train: {n_train}, Val: {n_val}')

In [None]:
# Fresh MuRIL encoder with a 2-way classification head (0 = safe, 1 = toxic).
binary_id2label = {0: 'safe', 1: 'toxic'}
binary_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(binary_id2label),
    id2label=binary_id2label,
    label2id={name: idx for idx, name in binary_id2label.items()},
)
print('Binary model loaded!')

In [None]:
# Class weights: start from sklearn's 'balanced' weighting
# (n_samples / (n_classes * class_count)), then multiply the toxic weight so
# missed toxic examples cost more in the loss -> pushes toward higher recall.
TOXIC_WEIGHT_BOOST = 2.0
class_weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=train_labels)
class_weights[1] = class_weights[1] * TOXIC_WEIGHT_BOOST
print(f'Class weights: safe={class_weights[0]:.3f}, toxic={class_weights[1]:.3f}')

weights_tensor = torch.from_numpy(class_weights).float().to('cuda')

class HighRecallTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy (Stage 1).

    Uses the module-level ``weights_tensor`` (safe / toxic weights, with the
    toxic weight boosted) so errors on toxic examples cost more.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments newer Trainer versions pass
        # (e.g. num_items_in_batch); they are unused here.
        # Remove labels so the model does not compute its own unweighted loss.
        labels = inputs.pop('labels')
        outputs = model(**inputs)
        logits = outputs.logits
        # Put the weights on the logits' device instead of assuming both live
        # on the default 'cuda' device.
        loss_fn = nn.CrossEntropyLoss(weight=weights_tensor.to(logits.device))
        loss = loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss

def compute_metrics_binary(pred):
    """Validation metrics for Stage 1 (positive class 1 = toxic).

    Args:
        pred: ``EvalPrediction`` from the Trainer. ``predictions`` holds one
            score per class (argmax over the last axis is the predicted
            label) and ``label_ids`` holds the true labels.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision``, ``recall`` and ``f2``.
        F2 (F-beta with beta=2) weights recall more heavily than precision.
        Unlike recall alone, it drops when precision collapses, e.g. when the
        model flags almost everything as toxic.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # zero_division=0: report 0 explicitly when there are no toxic predictions
    # (warnings are globally suppressed in this notebook).
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    # F-beta with beta=2: (1 + 4) * P * R / (4 * P + R)
    f2_denominator = 4 * precision + recall
    f2 = 5 * precision * recall / f2_denominator if f2_denominator > 0 else 0.0
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'f2': f2,
    }

In [None]:
# Stage 1 training configuration. Effective batch per device = 16 x 2
# gradient-accumulation steps = 32. Evaluates and checkpoints every epoch,
# then reloads the checkpoint with the best validation recall.
# NOTE(review): selecting purely on recall can favour a checkpoint that flags
# almost everything as toxic; a precision-aware metric (e.g. F2) is safer.
binary_args = TrainingArguments(
    output_dir='./stage1_checkpoints',
    report_to='none',
    # schedule
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    # batching (smaller per-device batch for MuRIL)
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    fp16=True,
    # logging / evaluation / checkpointing
    logging_steps=500,
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='recall',
)

binary_trainer = HighRecallTrainer(
    model=binary_model,
    args=binary_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics_binary,
)

rule = '=' * 60
print(rule)
print('TRAINING STAGE 1: Binary (toxic vs safe)')
print(rule)
binary_trainer.train()

In [None]:
# Evaluate Stage 1 on the validation split (after training, the Trainer has
# reloaded the checkpoint selected by metric_for_best_model).
# The pass mark matches the stated Stage 1 goal and the critical-test target
# of 95%+ recall (it was previously 0.93, which contradicted both).
RECALL_TARGET = 0.95
results1 = binary_trainer.evaluate()
rule = '=' * 60
print(f'\n{rule}')
print('STAGE 1 RESULTS')
print(rule)
print(f'Accuracy:  {results1["eval_accuracy"]:.4f}')
print(f'F1:        {results1["eval_f1"]:.4f}')
print(f'Precision: {results1["eval_precision"]:.4f}')
stage1_recall = results1['eval_recall']
recall_status = '✅ GOOD' if stage1_recall >= RECALL_TARGET else '⚠️ LOW'
print(f'Recall:    {stage1_recall:.4f} {recall_status} (target: {RECALL_TARGET:.0%})')
print(rule)

In [None]:
# Persist the fine-tuned Stage 1 model together with the tokenizer so the
# directory is self-contained (the pipeline test below loads it directly).
stage1_dir = 'moe_muril/stage1_binary'
for artifact in (binary_model, tokenizer):
    artifact.save_pretrained(stage1_dir)
print(f'✅ Stage 1 saved to {stage1_dir}')

---
## Stage 2: Category Classifier
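Classifies toxic content into categories. It is trained only on examples whose label is not `safe`, and at inference it runs only on text that Stage 1 flags as toxic.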

In [None]:
from collections import Counter

# Stage 2 learns only from toxic examples: keep everything not labelled 'safe'.
category_data = [d for d in router_data if d['label'] != 'safe']
print(f'Toxic examples: {len(category_data)}')

# Stable label <-> id mapping: category names in sorted order.
category_counts = Counter(d['label'] for d in category_data)
categories = sorted(category_counts)
id2label_cat = dict(enumerate(categories))
label2id_cat = {cat: idx for idx, cat in id2label_cat.items()}

print(f'\nCategories ({len(categories)}):')
for cat, count in category_counts.most_common():
    print(f'  {cat}: {count}')

In [None]:
# 90/10 split of the toxic examples, stratified by category id.
cat_texts, cat_labels = [], []
for example in category_data:
    cat_texts.append(example['text'])
    cat_labels.append(label2id_cat[example['label']])

(train_cat_texts, val_cat_texts,
 train_cat_labels, val_cat_labels) = train_test_split(
    cat_texts, cat_labels, test_size=0.1, random_state=42, stratify=cat_labels
)

print(f'Train: {len(train_cat_texts)}')
print(f'Val: {len(val_cat_texts)}')

In [None]:
# Tokenize the Stage 2 splits with the same settings as Stage 1.
print('Tokenizing Stage 2 data...')
train_cat_dataset, val_cat_dataset = (
    tokenize_data(train_cat_texts, train_cat_labels),
    tokenize_data(val_cat_texts, val_cat_labels),
)
print('Done!')

In [None]:
# A separate MuRIL model loaded from the pretrained checkpoint (not from
# Stage 1), with one output per toxic category.
num_categories = len(categories)
category_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_categories,
    id2label=id2label_cat,
    label2id=label2id_cat,
)
print(f'Category model loaded! ({num_categories} classes)')

In [None]:
# Category class weights ('balanced' weighting over the training split).
# Passing classes=0..K-1 explicitly ties cat_weights[i] to label id i.
# If a category has no training examples, compute_class_weight then raises a
# clear ValueError. With np.unique, it would return a shorter, misaligned
# array that only fails later.
cat_weights = compute_class_weight(
    'balanced', classes=np.arange(len(categories)), y=train_cat_labels
)
print('Category weights:')
for i, cat in id2label_cat.items():
    print(f'  {cat}: {cat_weights[i]:.3f}')

cat_weights_tensor = torch.tensor(cat_weights, dtype=torch.float32).to('cuda')

class WeightedTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy (Stage 2).

    Uses the module-level ``cat_weights_tensor`` (one weight per category id)
    so under-represented categories contribute more to the loss.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments newer Trainer versions pass
        # (e.g. num_items_in_batch); they are unused here.
        # Remove labels so the model does not compute its own unweighted loss.
        labels = inputs.pop('labels')
        outputs = model(**inputs)
        logits = outputs.logits
        # Put the weights on the logits' device instead of assuming both live
        # on the default 'cuda' device.
        loss_fn = nn.CrossEntropyLoss(weight=cat_weights_tensor.to(logits.device))
        loss = loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss

def compute_metrics_cat(pred):
    """Validation metrics for Stage 2 (multi-class category).

    Args:
        pred: ``EvalPrediction`` from the Trainer; argmax of ``predictions``
            over the last axis is the predicted category id.

    Returns:
        dict with support-weighted ``accuracy``, ``f1``, ``precision`` and
        ``recall`` (unchanged keys), plus ``f1_macro``. Macro F1 averages
        per-category F1 equally over categories present in labels or
        predictions, so weak rare categories are not masked by frequent ones.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # zero_division=0: a category that is never predicted scores 0 explicitly
    # (warnings are globally suppressed in this notebook).
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    _, _, f1_macro, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'f1_macro': f1_macro,
    }

In [None]:
# Stage 2 training configuration: 5 epochs, effective batch per device of
# 16 x 2 gradient-accumulation steps = 32. Evaluates and checkpoints every
# epoch, then reloads the checkpoint with the best (support-weighted) F1.
cat_args = TrainingArguments(
    output_dir='./stage2_checkpoints',
    report_to='none',
    # schedule
    num_train_epochs=5,
    warmup_steps=300,
    weight_decay=0.01,
    # batching
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    fp16=True,
    # logging / evaluation / checkpointing
    logging_steps=300,
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='f1',
)

category_trainer = WeightedTrainer(
    model=category_model,
    args=cat_args,
    train_dataset=train_cat_dataset,
    eval_dataset=val_cat_dataset,
    compute_metrics=compute_metrics_cat,
)

rule = '=' * 60
print(rule)
print('TRAINING STAGE 2: Category Classifier')
print(rule)
category_trainer.train()

In [None]:
# Evaluate Stage 2 on the validation split.
results2 = category_trainer.evaluate()
rule = '=' * 60
print(f'\n{rule}')
print('STAGE 2 RESULTS')
print(rule)
print(f'Accuracy:  {results2["eval_accuracy"]:.4f}')
print(f'F1:        {results2["eval_f1"]:.4f}')
print(f'Precision: {results2["eval_precision"]:.4f}')
print(f'Recall:    {results2["eval_recall"]:.4f}')

# Per-category report. labels= is passed explicitly so target_names stays
# aligned with label ids. Without it, classification_report raises a
# ValueError when a category is absent from both the validation labels and
# the predictions, which can happen for rare categories after the 90/10 split.
preds = category_trainer.predict(val_cat_dataset)
pred_labels = preds.predictions.argmax(-1)
print('\nClassification Report:')
print(classification_report(
    val_cat_labels,
    pred_labels,
    labels=list(range(len(categories))),
    target_names=categories,
    zero_division=0,
))

In [None]:
# Persist the fine-tuned Stage 2 model together with the tokenizer so the
# directory is self-contained (the pipeline test below loads it directly).
stage2_dir = 'moe_muril/stage2_category'
for artifact in (category_model, tokenizer):
    artifact.save_pretrained(stage2_dir)
print(f'✅ Stage 2 saved to {stage2_dir}')

---
## Test Complete Pipeline

In [None]:
from transformers import pipeline

# Load both stages from the directories saved above (device=0 -> first GPU).
stage1 = pipeline('text-classification', model='moe_muril/stage1_binary', device=0)
stage2 = pipeline('text-classification', model='moe_muril/stage2_category', device=0)

# Tokenizer settings passed on every pipeline call. They match the truncation
# used in training, so long inputs do not exceed the model's position limit.
PIPELINE_TOKENIZER_KWARGS = {'truncation': True, 'max_length': 256}


def classify(text, toxic_threshold=0.5):
    """Run the two-stage classifier on a single string.

    Stage 1 estimates P(toxic). If that probability is not above
    ``toxic_threshold``, the text is reported as safe. Otherwise Stage 2
    picks the category.

    Args:
        text: Input string.
        toxic_threshold: Minimum Stage 1 toxic probability for the text to be
            treated as toxic. The default of 0.5 reproduces plain argmax;
            lower it to trade precision for recall.

    Returns:
        dict with ``toxic`` (bool), ``category`` (``'safe'`` or a Stage 2
        label) and ``conf``. For safe results ``conf`` is the Stage 1 safe
        probability; for toxic results it is P(toxic) times the Stage 2 score.
    """
    r1 = stage1(text, **PIPELINE_TOKENIZER_KWARGS)[0]
    # Two-label softmax head: the probability of the label not returned is
    # 1 - score of the returned (top) label.
    if r1['label'] == 'toxic':
        p_toxic, p_safe = r1['score'], 1.0 - r1['score']
    else:
        p_safe, p_toxic = r1['score'], 1.0 - r1['score']

    if p_toxic <= toxic_threshold:
        return {'toxic': False, 'category': 'safe', 'conf': p_safe}

    r2 = stage2(text, **PIPELINE_TOKENIZER_KWARGS)[0]
    return {
        'toxic': True,
        'category': r2['label'],
        'conf': p_toxic * r2['score'],
    }


print('Pipeline ready!')

In [None]:
# Critical tests: hand-written (text, expected category, expected toxic flag) cases.
tests = [
    # Safe
    ('Hello, how are you?', 'safe', False),
    ('The weather is nice', 'safe', False),
    ('namaste kaise ho', 'safe', False),
    ('I went back to my country for vacation', 'safe', False),

    # Harmful
    ('I will kill you', 'harmful_content', True),
    ('how to make a bomb', 'harmful_content', True),
    ('how to make an explosive device', 'harmful_content', True),
    ('I will shoot you', 'harmful_content', True),

    # Hindi
    ('bhenchod', 'hindi_abuse', True),
    ('madarchod', 'hindi_abuse', True),
    ('chutiya', 'hindi_abuse', True),
    ('teri maa ki', 'hindi_abuse', True),
    ('tu bekaar insaan hai', 'hindi_abuse', True),

    # Harassment
    ('you are worthless garbage', 'harassment', True),
    ('you are so stupid', 'harassment', True),
    ('nobody likes you', 'harassment', True),
    ('your kind disgusts me', 'harassment', True),

    # Hate
    ('go back to your country', 'hate_speech', True),
    ('all immigrants should be deported', 'hate_speech', True),
    ('people like you should not exist', 'hate_speech', True),

    # Self-harm
    ('I want to kill myself', 'self_harm', True),
    ('I want to end my life', 'self_harm', True),
    ('kys', 'self_harm', True),

    # Sexual
    ('send nudes', 'sexual', True),
    ('I want to see you naked', 'sexual', True),
]


def _pct(ok, total):
    """Format ok/total as a whole-number percentage, or 'n/a' when total is 0."""
    return f'{100 * ok / total:.0f}%' if total else 'n/a'


rule = '=' * 80
print(rule)
print('CRITICAL TEST CASES')
print(rule)

toxic_ok = 0     # toxic cases flagged as toxic (recall numerator)
toxic_total = 0  # toxic cases in the suite
cat_ok = 0       # flagged toxic cases whose category is correct
cat_total = 0    # toxic cases that were flagged (category is scored only for these)
safe_ok = 0      # safe cases left unflagged
safe_total = 0   # safe cases in the suite

for text, exp_cat, exp_toxic in tests:
    r = classify(text)

    if exp_toxic:
        toxic_total += 1
        if r['toxic']:
            toxic_ok += 1
            cat_total += 1
            if r['category'] == exp_cat:
                cat_ok += 1
    else:
        safe_total += 1
        if not r['toxic']:
            safe_ok += 1

    t_ok = '✅' if r['toxic'] == exp_toxic else '❌'
    c_ok = '✅' if r['category'] == exp_cat else '❌'

    print(f"{t_ok}{c_ok} {r['category']:<18} | exp: {exp_cat:<18} | {text[:35]}")

print()
print(rule)
# _pct guards against division by zero, e.g. when no toxic case is flagged.
print(f'TOXIC DETECTION: {toxic_ok}/{toxic_total} ({_pct(toxic_ok, toxic_total)}) - Target: 95%+')
print(f'CATEGORY:        {cat_ok}/{cat_total} ({_pct(cat_ok, cat_total)}) - Target: 80%+')
# False positives on safe text were previously not summarised at all.
print(f'SAFE KEPT:       {safe_ok}/{safe_total} ({_pct(safe_ok, safe_total)})')
print(rule)

In [None]:
# Summary: headline validation metrics for both stages.
banner = '=' * 60
summary_lines = [
    '\n' + banner,
    'FINAL SUMMARY',
    banner,
    f'Model: {MODEL_NAME}',
    '\nStage 1 (Binary):',
    f'  Recall: {results1["eval_recall"]:.2%}',
    f'  F1:     {results1["eval_f1"]:.2%}',
    '\nStage 2 (Category):',
    f'  F1:     {results2["eval_f1"]:.2%}',
    banner,
]
for summary_line in summary_lines:
    print(summary_line)

In [None]:
# Download: bundle both saved stages (moe_muril/) into one zip archive and
# send it to the browser via Colab's files helper.
!zip -r moe_muril.zip moe_muril/
files.download('moe_muril.zip')