# 2-Stage MoE Architecture

```
Input → Stage 1 (toxic/safe) → Stage 2 (category) → Expert
```

- **Stage 1**: Binary classifier tuned for HIGH RECALL (catch everything)
- **Stage 2**: Category classifier (runs only on content that Stage 1 flags as toxic)
- **Experts**: Fine-tuned per category

In [None]:
# Install the libraries this notebook imports; -q keeps pip's output quiet.
!pip install transformers datasets accelerate scikit-learn -q

In [None]:
from google.colab import files
import os

# Files the user is asked to upload before running the rest of the notebook.
expected_uploads = (
    'router_train_v2.jsonl',
    'harmful_content.jsonl',
    'harassment.jsonl',
    'hate_speech.jsonl',
    'hindi_abuse.jsonl',
    'safe.jsonl',
    'self_harm.jsonl',
    'sexual.jsonl',
)

print('Upload these files from data/datasets/moe/:')
for upload_name in expected_uploads:
    print(f'  - {upload_name}')
print()

# Hand over to Colab's upload helper and keep whatever it returns.
uploaded = files.upload()

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.utils.class_weight import compute_class_weight
import os

# Report the GPU that will be used. get_device_name(0) raises when no CUDA
# device is present, so check availability first and print an actionable hint
# instead of failing with an opaque error.
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
else:
    print('GPU: none detected - the training cells below need a CUDA GPU '
          '(Colab: Runtime > Change runtime type > GPU)')

# Directory that both trained stages are saved into later in the notebook.
os.makedirs('moe_2stage', exist_ok=True)

## Stage 1: Binary Classifier (toxic vs safe)
Goal: **95%+ recall** - catch ALL toxic content

In [None]:
# Load all data: one JSON object per line; the 'text' and 'label' keys of each
# record are read below.
router_data = []
# The encoding is stated explicitly so non-ASCII text is decoded the same way
# regardless of the platform's default encoding.
with open('router_train_v2.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if line:  # skip blank lines; json.loads would raise on them
            router_data.append(json.loads(line))

if not router_data:
    # Fail early with a clear message instead of a ZeroDivisionError below.
    raise ValueError('router_train_v2.jsonl contains no examples')

# Convert to binary: label 'safe' -> 0, every other label -> 1 (toxic)
binary_data = [
    {'text': d['text'], 'label': 0 if d['label'] == 'safe' else 1}
    for d in router_data
]

total = len(binary_data)
safe_count = sum(1 for d in binary_data if d['label'] == 0)
toxic_count = total - safe_count

print(f'Total: {total}')
print(f'Safe: {safe_count} ({100*safe_count/total:.1f}%)')
print(f'Toxic: {toxic_count} ({100*toxic_count/total:.1f}%)')

In [None]:
# Prepare binary data: parallel lists of texts and 0/1 labels
texts, labels = [], []
for example in binary_data:
    texts.append(example['text'])
    labels.append(example['label'])

# Hold out 10% for validation, stratified on the label, with a fixed seed.
(train_texts, val_texts,
 train_labels, val_labels) = train_test_split(
    texts,
    labels,
    test_size=0.1,
    stratify=labels,
    random_state=42,
)

print(f'Train: {len(train_texts)}, Val: {len(val_texts)}')

In [None]:
# Load model: the checkpoint name and tokenizer are reused by later cells
MODEL_NAME = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Stage 1 label vocabulary; the reverse map is derived so the two stay in sync.
binary_id2label = {0: 'safe', 1: 'toxic'}
binary_label2id = {name: idx for idx, name in binary_id2label.items()}

binary_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(binary_id2label),
    id2label=binary_id2label,
    label2id=binary_label2id,
)

In [None]:
# Tokenize
def tokenize(texts, labels):
    """Encode ``texts`` with the module-level tokenizer and attach ``labels``.

    The tokenizer is called with truncation enabled, ``padding=True`` and
    ``max_length=256``.

    Args:
        texts: Strings to encode.
        labels: Labels stored alongside the encodings, one per text.

    Returns:
        A ``Dataset`` with the columns ``input_ids``, ``attention_mask`` and
        ``labels``.
    """
    encoded = tokenizer(texts, max_length=256, truncation=True, padding=True)
    columns = {key: encoded[key] for key in ('input_ids', 'attention_mask')}
    columns['labels'] = labels
    return Dataset.from_dict(columns)

# Tokenize the train split first, then the validation split.
train_dataset, val_dataset = (
    tokenize(split_texts, split_labels)
    for split_texts, split_labels in (
        (train_texts, train_labels),
        (val_texts, val_labels),
    )
)

In [None]:
# Class weights - BOOST TOXIC for high recall
TOXIC_WEIGHT_BOOST = 2.0  # multiplier applied on top of the balanced toxic weight

class_weights = compute_class_weight('balanced', classes=np.arange(2), y=train_labels)
class_weights[1] = class_weights[1] * TOXIC_WEIGHT_BOOST
safe_weight, toxic_weight = class_weights
print(f'Class weights: safe={safe_weight:.3f}, toxic={toxic_weight:.3f}')

# float32 copy of the weights on the GPU, read by the custom loss below
weights_tensor = torch.as_tensor(class_weights, dtype=torch.float32).cuda()

class HighRecallTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy.

    The per-class weights are read from the module-level ``weights_tensor``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the weighted cross-entropy loss for one batch.

        Args:
            model: Model used for the forward pass.
            inputs: Batch dict; its ``'labels'`` entry is removed before the
                remaining entries are passed to ``model``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            **kwargs: Accepted and ignored.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is
            true.
        """
        labels = inputs.pop('labels')
        outputs = model(**inputs)
        logits = outputs.logits
        # Put the weights on whatever device the logits live on, so the loss
        # does not fail with a device mismatch when the model is not on the
        # device the weights were created on. This is a no-op when they
        # already match.
        loss_fn = nn.CrossEntropyLoss(weight=weights_tensor.to(logits.device))
        loss = loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss

def compute_metrics_binary(pred):
    """Return accuracy, F1, precision and recall for binary predictions.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class scores; the arg-max over the last axis is the
            predicted class).

    Returns:
        Dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``,
        computed with ``average='binary'``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(axis=-1)
    scores = precision_recall_fscore_support(y_true, y_pred, average='binary')
    metrics = {'accuracy': accuracy_score(y_true, y_pred)}
    metrics['f1'] = scores[2]
    metrics['precision'] = scores[0]
    metrics['recall'] = scores[1]
    return metrics

In [None]:
# Train Stage 1
args = TrainingArguments(
    output_dir='./stage1_checkpoints',
    report_to='none',
    # optimisation
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    fp16=True,
    # batching
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # logging, evaluation and checkpointing cadence
    logging_steps=500,
    eval_strategy='epoch',
    save_strategy='epoch',
    # model selection is driven by the 'recall' metric
    load_best_model_at_end=True,
    metric_for_best_model='recall',
)

binary_trainer = HighRecallTrainer(
    args=args,
    model=binary_model,
    compute_metrics=compute_metrics_binary,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

print('Training Stage 1 (binary - high recall)...')
binary_trainer.train()

In [None]:
# Evaluate Stage 1
results = binary_trainer.evaluate()

RECALL_TARGET = 0.95
rule = '=' * 60
stage1_recall = results['eval_recall']

print('\n' + rule)
print('STAGE 1 RESULTS (Binary: toxic vs safe)')
print(rule)
for line_prefix, metric_key in (
    ('Accuracy:  ', 'eval_accuracy'),
    ('F1:        ', 'eval_f1'),
    ('Precision: ', 'eval_precision'),
):
    print(f'{line_prefix}{results[metric_key]:.4f}')
recall_marker = '✅' if stage1_recall >= RECALL_TARGET else '⚠️'
print(f'Recall:    {stage1_recall:.4f} {recall_marker}')
print(rule)

# Separate check, kept as a strict "<" comparison against the target.
if stage1_recall < RECALL_TARGET:
    print('\n⚠️ Recall < 95%. Consider increasing toxic weight.')

In [None]:
# Save Stage 1: model first, then tokenizer, into the same directory
stage1_dir = 'moe_2stage/stage1_binary'
for artifact in (binary_model, tokenizer):
    artifact.save_pretrained(stage1_dir)
print('Stage 1 saved!')

## Stage 2: Category Classifier
Only classifies content flagged as toxic by Stage 1

In [None]:
# Prepare category data (toxic only): drop every record labelled 'safe'
from collections import Counter

category_data = list(filter(lambda d: d['label'] != 'safe', router_data))
print(f'Toxic examples for Stage 2: {len(category_data)}')

# Categories: the distinct labels in sorted order define the class ids
cat_counts = Counter(d['label'] for d in category_data)
categories = sorted(cat_counts)
id2label_cat = dict(enumerate(categories))
label2id_cat = {cat: i for i, cat in id2label_cat.items()}

print(f'Categories ({len(categories)}): {categories}')

# Per-category example counts, most frequent first
for cat, count in cat_counts.most_common():
    print(f'  {cat}: {count}')

In [None]:
# Prepare Stage 2 data: texts and integer category ids as parallel lists
cat_texts, cat_labels = [], []
for item in category_data:
    cat_texts.append(item['text'])
    cat_labels.append(label2id_cat[item['label']])

# Hold out 10% for validation, stratified on the category, with a fixed seed.
(train_cat_texts, val_cat_texts,
 train_cat_labels, val_cat_labels) = train_test_split(
    cat_texts,
    cat_labels,
    test_size=0.1,
    stratify=cat_labels,
    random_state=42,
)

print(f'Train: {len(train_cat_texts)}, Val: {len(val_cat_texts)}')

In [None]:
# Load model for Stage 2: same checkpoint, one output per category
n_categories = len(categories)
category_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=n_categories,
    label2id=label2id_cat,
    id2label=id2label_cat,
)

# Tokenize the train split first, then the validation split
train_cat_dataset, val_cat_dataset = (
    tokenize(split_texts, split_labels)
    for split_texts, split_labels in (
        (train_cat_texts, train_cat_labels),
        (val_cat_texts, val_cat_labels),
    )
)

In [None]:
# Class weights for categories ('balanced' weighting over the training labels)
present_ids = np.unique(train_cat_labels)
cat_weights = compute_class_weight('balanced', classes=present_ids, y=train_cat_labels)

print('Category weights:')
for cat_id in id2label_cat:
    print(f'  {id2label_cat[cat_id]}: {cat_weights[cat_id]:.3f}')

# float32 copy of the weights on the GPU, read by the custom loss below
cat_weights_tensor = torch.as_tensor(cat_weights, dtype=torch.float32).cuda()

class WeightedCategoryTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy over the categories.

    The per-class weights are read from the module-level
    ``cat_weights_tensor``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the weighted cross-entropy loss for one batch.

        Args:
            model: Model used for the forward pass.
            inputs: Batch dict; its ``'labels'`` entry is removed before the
                remaining entries are passed to ``model``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            **kwargs: Accepted and ignored.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is
            true.
        """
        labels = inputs.pop('labels')
        outputs = model(**inputs)
        logits = outputs.logits
        # Put the weights on whatever device the logits live on, so the loss
        # does not fail with a device mismatch when the model is not on the
        # device the weights were created on. This is a no-op when they
        # already match.
        loss_fn = nn.CrossEntropyLoss(weight=cat_weights_tensor.to(logits.device))
        loss = loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss

def compute_metrics_cat(pred):
    """Return accuracy plus weighted-average F1, precision and recall.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class scores; the arg-max over the last axis is the
            predicted class).

    Returns:
        Dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``,
        computed with ``average='weighted'``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(axis=-1)
    scores = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    metrics = {'accuracy': accuracy_score(y_true, y_pred)}
    metrics['f1'] = scores[2]
    metrics['precision'] = scores[0]
    metrics['recall'] = scores[1]
    return metrics

In [None]:
# Train Stage 2
cat_args = TrainingArguments(
    output_dir='./stage2_checkpoints',
    report_to='none',
    # optimisation
    num_train_epochs=5,
    warmup_steps=300,
    weight_decay=0.01,
    fp16=True,
    # batching
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # logging, evaluation and checkpointing cadence
    logging_steps=300,
    eval_strategy='epoch',
    save_strategy='epoch',
    # model selection is driven by the 'f1' metric
    load_best_model_at_end=True,
    metric_for_best_model='f1',
)

category_trainer = WeightedCategoryTrainer(
    args=cat_args,
    model=category_model,
    compute_metrics=compute_metrics_cat,
    train_dataset=train_cat_dataset,
    eval_dataset=val_cat_dataset,
)

print('Training Stage 2 (category classifier)...')
category_trainer.train()

In [None]:
# Evaluate Stage 2: aggregate metrics on the validation split
cat_results = category_trainer.evaluate()
print(f'\n{"="*60}')
print('STAGE 2 RESULTS (Category Classifier)')
print(f'{"="*60}')
print(f'Accuracy:  {cat_results["eval_accuracy"]:.4f}')
print(f'F1:        {cat_results["eval_f1"]:.4f}')
print(f'Precision: {cat_results["eval_precision"]:.4f}')
print(f'Recall:    {cat_results["eval_recall"]:.4f}')

# Detailed per-category report
preds = category_trainer.predict(val_cat_dataset)
pred_labels = preds.predictions.argmax(-1)
print('\nClassification Report:')
# Pass the full list of class ids explicitly. Without `labels`, the report
# infers the classes from the data and raises ValueError when a category
# appears in neither the validation labels nor the predictions, because
# `target_names` then has the wrong length. zero_division=0 reports 0 for
# classes with no predicted or true samples without emitting warnings.
print(classification_report(
    val_cat_labels,
    pred_labels,
    labels=list(range(len(categories))),
    target_names=categories,
    zero_division=0,
))

In [None]:
# Save Stage 2
category_model.save_pretrained('moe_2stage/stage2_category')
tokenizer.save_pretrained('moe_2stage/stage2_category')
print('Stage 2 saved!')

## Test Complete 2-Stage Pipeline

In [None]:
from transformers import pipeline

# Load both stages from the directories they were saved to (Stage 1 first)
stage1, stage2 = (
    pipeline('text-classification', model=saved_dir, device=0)
    for saved_dir in ('moe_2stage/stage1_binary', 'moe_2stage/stage2_category')
)

def classify_2stage(text):
    """Classify ``text`` with the two-stage pipeline.

    Stage 1 decides toxic vs safe. Stage 2 is run only when Stage 1 does not
    answer ``'safe'`` and picks the category.

    Args:
        text: The string to classify.

    Returns:
        For safe text:
        ``{'is_toxic': False, 'category': 'safe', 'confidence': <stage 1 score>}``.
        For toxic text: a dict with ``is_toxic=True``, ``category`` (the
        Stage 2 label), ``stage1_conf``, ``stage2_conf``, ``combined_conf``
        (product of the two scores) and ``confidence`` (same value as
        ``combined_conf``, so both branches share one confidence key).
    """
    # Truncate the same way the training data was tokenized (max_length=256).
    # Without truncation, text longer than the model's maximum input length
    # makes the pipeline fail at inference time.
    tokenizer_kwargs = {'truncation': True, 'max_length': 256}

    # Stage 1: toxic or safe?
    s1_result = stage1(text, **tokenizer_kwargs)[0]
    s1_label = s1_result['label']
    s1_conf = s1_result['score']

    if s1_label == 'safe':
        return {'is_toxic': False, 'category': 'safe', 'confidence': s1_conf}

    # Stage 2: which category?
    s2_result = stage2(text, **tokenizer_kwargs)[0]
    s2_label = s2_result['label']
    s2_conf = s2_result['score']
    combined_conf = s1_conf * s2_conf

    return {
        'is_toxic': True,
        'category': s2_label,
        'stage1_conf': s1_conf,
        'stage2_conf': s2_conf,
        'combined_conf': combined_conf,
        'confidence': combined_conf,
    }

print('2-Stage pipeline ready!')

In [None]:
# Critical test cases, grouped by expected category. Each entry of
# `test_cases` is (text, expected_category, expected_is_toxic); a case is
# expected to be toxic exactly when its category is not 'safe'.
_cases_by_category = {
    'safe': [
        'Hello, how are you?',
        'The weather is nice today',
        'namaste kaise ho',
        'I went back to my country for vacation',
    ],
    'harmful_content': [
        'I will kill you',
        'how to make a bomb',
        'how to make an explosive device',
        'I will shoot you',
    ],
    'hindi_abuse': [
        'bhenchod',
        'madarchod',
        'chutiya',
        'teri maa ki',
    ],
    'harassment': [
        'you are worthless garbage',
        'you are so stupid',
        'nobody likes you',
        'your kind disgusts me',
    ],
    'hate_speech': [
        'go back to your country',
        'all immigrants should be deported',
        'people like you should not exist',
    ],
    'self_harm': [
        'I want to kill myself',
        'I want to end my life',
        'kys',
    ],
    'sexual': [
        'send nudes',
        'I want to see you naked',
    ],
}

test_cases = [
    (phrase, expected_category, expected_category != 'safe')
    for expected_category, phrases in _cases_by_category.items()
    for phrase in phrases
]

def _format_rate(correct, total):
    """Return 'correct/total (NN%)', or 'correct/total (n/a)' when total is 0."""
    if total == 0:
        return f'{correct}/{total} (n/a)'
    return f'{correct}/{total} ({100*correct/total:.0f}%)'


print('='*80)
print('2-STAGE MOE CRITICAL TEST CASES')
print('='*80)

toxic_correct = 0      # expected-toxic cases that were flagged as toxic
toxic_total = 0        # all expected-toxic cases
category_correct = 0   # flagged toxic cases whose category also matched
category_total = 0     # expected-toxic cases that were flagged as toxic
wrong = []             # (text, expected_cat, predicted_cat, is_toxic) per miss

for text, expected_cat, expected_toxic in test_cases:
    result = classify_2stage(text)
    is_toxic = result['is_toxic']
    category = result['category']

    if expected_toxic:
        # Check toxic detection
        toxic_total += 1
        if is_toxic:
            toxic_correct += 1
            # Category is only scored once the text was flagged as toxic
            category_total += 1
            if category == expected_cat:
                category_correct += 1

    # Status
    toxic_status = '✅' if is_toxic == expected_toxic else '❌'
    cat_status = '✅' if category == expected_cat else '❌'

    if is_toxic != expected_toxic or category != expected_cat:
        wrong.append((text, expected_cat, category, is_toxic))

    print(f'{toxic_status}{cat_status} {category:<18} | exp: {expected_cat:<18} | {text[:35]}')

print()
print('='*80)
print('RESULTS')
print('='*80)
# _format_rate guards the division: category_total is 0 whenever none of the
# expected-toxic cases was flagged as toxic.
print(f'Toxic Detection: {_format_rate(toxic_correct, toxic_total)} - Target: 95%+')
print(f'Category Accuracy: {_format_rate(category_correct, category_total)} - Target: 80%+')
print('='*80)

if wrong:
    print(f'\n❌ WRONG ({len(wrong)}):')
    for text, exp, got, is_toxic in wrong:
        print(f'   "{text[:40]}" → {got} (exp: {exp}, toxic={is_toxic})')

In [None]:
# Download
# Zip the moe_2stage/ directory recursively into moe_2stage.zip ...
!zip -r moe_2stage.zip moe_2stage/
# ... then pass the archive to Colab's download helper.
files.download('moe_2stage.zip')