# MoE (Mixture of Experts) Training
## Universal Restrictor - Production Model

Architecture:
```
Input → Router (8-class) → Expert Model → toxic/safe
```

Models to train:
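1. Router: Classifies the input into 8 categories (7 harm categories plus `safe`)
2. Experts: Binary (toxic/safe) classifiers, one per harm category; inputs routed to `safe` skip the expert stage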

In [None]:
# Install the Hugging Face training stack plus scikit-learn (-q keeps pip quiet).
# NOTE(review): versions are unpinned; the `eval_strategy` argument used by the
# training cells presumably needs a recent transformers release — confirm before pinning.
!pip install transformers datasets accelerate scikit-learn -q

In [None]:
from google.colab import files
import os

# Every data file this notebook opens by relative name: the router training
# set plus one file per category.
REQUIRED_FILES = [
    'router_train.jsonl',
    'dangerous.jsonl',
    'harassment.jsonl',
    'hate_speech.jsonl',
    'hindi_abuse.jsonl',
    'safe.jsonl',
    'self_harm.jsonl',
    'sexual.jsonl',
    'violence.jsonl',
]

print('Upload ALL these files from data/datasets/moe/:')
for required_name in REQUIRED_FILES:
    print(f'  - {required_name}')
print()
uploaded = files.upload()

# Fail fast: a missing category file would otherwise only surface once the
# router has already finished training. Checking the filesystem (rather than
# the keys of `uploaded`) also accepts files that were already present.
missing_files = [name for name in REQUIRED_FILES if not os.path.exists(name)]
if missing_files:
    raise FileNotFoundError(f'Missing required data files: {", ".join(missing_files)}')

In [None]:
import json
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
import numpy as np

# Report which accelerator this session will run on.
cuda_ready = torch.cuda.is_available()
device_label = torch.cuda.get_device_name(0) if cuda_ready else "CPU"
print(f'GPU: {device_label}')
print(f'CUDA available: {cuda_ready}')

# Folder that receives every saved model (no error if it already exists).
os.makedirs('moe_models', exist_ok=True)

## Part 1: Train Router Model (8-class)

In [None]:
# Load router data: one JSON object per line, each with a 'label' field.
router_data = []
# Explicit UTF-8 so non-ASCII text decodes the same way on every platform.
with open('router_train.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        if line.strip():
            router_data.append(json.loads(line))

print(f'Router data: {len(router_data)} examples')

# Get unique categories; sorting keeps the label ids stable between runs.
categories = sorted({d['label'] for d in router_data})
print(f'Categories: {categories}')

# Create label mapping (name -> id) and its inverse (id -> name).
label2id = {cat: i for i, cat in enumerate(categories)}
id2label = {i: cat for cat, i in label2id.items()}
print(f'Label mapping: {label2id}')

In [None]:
# Flatten the router records into parallel lists of texts and integer label ids.
texts = []
labels = []
for record in router_data:
    texts.append(record['text'])
    labels.append(label2id[record['label']])

# Hold out 10% for validation; stratifying on the labels keeps the class mix
# the same in both splits, and the fixed random_state makes the split repeatable.
split = train_test_split(texts, labels, test_size=0.1, random_state=42, stratify=labels)
train_texts, val_texts, train_labels, val_labels = split

print(f'Train: {len(train_texts)}, Val: {len(val_texts)}')

In [None]:
# Base checkpoint that is fine-tuned into the router.
MODEL_NAME = 'distilbert-base-uncased'

# The classification head gets one output per category; both label maps are
# handed to the model alongside it.
num_router_classes = len(categories)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
router_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=num_router_classes, id2label=id2label, label2id=label2id
)

print(f'Router model: {num_router_classes} classes')

In [None]:
# Tokenize
def tokenize(texts, labels):
    """Encode `texts` with the router tokenizer and pair them with `labels` in a Dataset.

    Sequences are truncated to at most 256 tokens and padded (padding=True) so
    every row in the returned Dataset has the same length.
    """
    encoded = tokenizer(texts, truncation=True, padding=True, max_length=256)
    columns = {key: encoded[key] for key in ('input_ids', 'attention_mask')}
    columns['labels'] = labels
    return Dataset.from_dict(columns)


train_dataset = tokenize(train_texts, train_labels)
val_dataset = tokenize(val_texts, val_labels)

In [None]:
# Metrics for multi-class
def compute_metrics_multiclass(pred):
    """Return accuracy and weighted precision/recall/F1 for a Trainer prediction object.

    `pred.predictions` holds per-class scores; the highest-scoring class along
    the last axis is taken as the predicted label.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    precision, recall, f1 = precision_recall_fscore_support(y_true, y_pred, average='weighted')[:3]
    return dict(
        accuracy=accuracy_score(y_true, y_pred),
        f1=f1,
        precision=precision,
        recall=recall,
    )

In [None]:
# Train router
router_args = TrainingArguments(
    output_dir='./router_checkpoints',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_steps=500,
    # Evaluate and checkpoint once per epoch so the best epoch (by F1) can be
    # restored when training ends.
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    # Mixed precision needs a GPU; requesting it on a CPU-only runtime makes
    # TrainingArguments raise, so enable it only when CUDA is present.
    fp16=torch.cuda.is_available(),
    report_to='none',
)

router_trainer = Trainer(
    model=router_model,
    args=router_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics_multiclass,
)

print('Training Router Model...')
router_trainer.train()

In [None]:
# Evaluate router on the validation split.
router_results = router_trainer.evaluate()
print(f'\nRouter Results:')
print(f'  Accuracy: {router_results["eval_accuracy"]:.4f}')
print(f'  F1: {router_results["eval_f1"]:.4f}')

# Detailed per-class report.
preds = router_trainer.predict(val_dataset)
pred_labels = preds.predictions.argmax(-1)
print('\nClassification Report:')
# Pass the full id list explicitly: without `labels`, the report infers the
# classes from the data and raises if any category is absent from both the
# validation labels and the predictions, because the count then no longer
# matches `target_names`. zero_division=0 reports 0 for such classes.
print(classification_report(
    val_labels,
    pred_labels,
    labels=list(range(len(categories))),
    target_names=categories,
    zero_division=0,
))

In [None]:
# Write the router weights and its tokenizer into the same directory.
router_dir = 'moe_models/router'
for artifact in (router_model, tokenizer):
    artifact.save_pretrained(router_dir)
print('Router saved!')

## Part 2: Train Expert Models (Binary Classifiers)

In [None]:
# Load safe examples for negative samples (one JSON object per line).
safe_examples = []
# Explicit UTF-8 so non-ASCII text decodes the same way on every platform.
with open('safe.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        if line.strip():
            safe_examples.append(json.loads(line))

print(f'Safe examples available: {len(safe_examples)}')

In [None]:
import random

def train_expert(category_name, category_file, safe_pool, model_name='distilbert-base-uncased', seed=42):
    """Train, evaluate and save a binary toxic-vs-safe expert for one category.

    Args:
        category_name: Category identifier; used in log output, in the
            checkpoint directory name and in the save path
            ``moe_models/<category_name>``.
        category_file: Path to a JSONL file whose records have a 'text' field.
            Every record is treated as a toxic example (label 1).
        safe_pool: List of records with a 'text' field from which the safe
            examples (label 0) are drawn.
        model_name: Hugging Face checkpoint to fine-tune.
        seed: Seed for drawing the safe sample, so repeated runs train on the
            same negatives.

    Returns:
        The metrics dict returned by ``Trainer.evaluate()`` on the held-out 10%.

    Raises:
        ValueError: If ``category_file`` holds no examples or ``safe_pool`` is empty.
    """
    print(f'\n{"="*60}')
    print(f'Training Expert: {category_name.upper()}')
    print(f'{"="*60}')

    # Load category data. Explicit UTF-8 keeps decoding platform-independent;
    # blank lines are skipped rather than crashing json.loads.
    toxic_examples = []
    with open(category_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                toxic_examples.append(json.loads(line))

    print(f'Toxic examples: {len(toxic_examples)}')
    if not toxic_examples:
        raise ValueError(f'No examples found in {category_file}')
    if not safe_pool:
        raise ValueError('safe_pool is empty; cannot build negative examples')

    # Sample safe examples (1.5x toxic for 40/60 balance), capped by the pool size.
    # A dedicated seeded generator makes the draw reproducible without touching
    # the global random state.
    n_safe = min(int(len(toxic_examples) * 1.5), len(safe_pool))
    safe_sample = random.Random(seed).sample(safe_pool, n_safe)
    print(f'Safe examples: {n_safe}')

    # Prepare data
    texts = [ex['text'] for ex in toxic_examples] + [ex['text'] for ex in safe_sample]
    labels = [1] * len(toxic_examples) + [0] * len(safe_sample)  # 1=toxic, 0=safe

    # Split: 10% validation, stratified so both splits keep the toxic/safe ratio.
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=0.1, random_state=42, stratify=labels
    )

    # Load model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        id2label={0: 'safe', 1: 'toxic'},
        label2id={'safe': 0, 'toxic': 1}
    )

    # Tokenize
    def tokenize(texts, labels):
        encodings = tokenizer(texts, truncation=True, padding=True, max_length=256)
        return Dataset.from_dict({
            'input_ids': encodings['input_ids'],
            'attention_mask': encodings['attention_mask'],
            'labels': labels
        })

    train_dataset = tokenize(train_texts, train_labels)
    val_dataset = tokenize(val_texts, val_labels)

    # Metrics: average='binary' scores the positive class, i.e. label 1 (toxic).
    def compute_metrics(pred):
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
        return {'accuracy': accuracy_score(labels, preds), 'f1': f1, 'precision': precision, 'recall': recall}

    # Train
    # NOTE(review): warmup_steps is fixed at 200; for a small category that may
    # cover most of the run — consider a ratio-based warmup.
    args = TrainingArguments(
        output_dir=f'./{category_name}_checkpoints',
        num_train_epochs=3,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=64,
        warmup_steps=200,
        weight_decay=0.01,
        logging_steps=200,
        eval_strategy='epoch',
        save_strategy='epoch',
        load_best_model_at_end=True,
        metric_for_best_model='f1',
        # Mixed precision needs a GPU; requesting it on a CPU-only runtime makes
        # TrainingArguments raise, so enable it only when CUDA is present.
        fp16=torch.cuda.is_available(),
        report_to='none',
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # Evaluate
    results = trainer.evaluate()
    print(f'\nResults for {category_name}:')
    print(f'  Accuracy: {results["eval_accuracy"]:.4f}')
    print(f'  F1: {results["eval_f1"]:.4f}')
    print(f'  Precision: {results["eval_precision"]:.4f}')
    print(f'  Recall: {results["eval_recall"]:.4f}')

    # Save model and tokenizer together so the directory loads as one unit.
    model.save_pretrained(f'moe_models/{category_name}')
    tokenizer.save_pretrained(f'moe_models/{category_name}')
    print(f'Saved to moe_models/{category_name}')

    return results

In [None]:
# Train all experts
# Each expert reads '<category>.jsonl'; the pairs below are (category name, data file).
expert_names = ['dangerous', 'harassment', 'hate_speech', 'hindi_abuse', 'self_harm', 'sexual', 'violence']
expert_categories = [(name, f'{name}.jsonl') for name in expert_names]

# Metrics of every expert, keyed by category name.
all_results = {}
for cat_name, cat_file in expert_categories:
    all_results[cat_name] = train_expert(cat_name, cat_file, safe_examples)

In [None]:
# Summary
def _summary_row(model_label, metrics):
    """Format one table row from an evaluation metrics dict."""
    metric_keys = ('eval_accuracy', 'eval_f1', 'eval_precision', 'eval_recall')
    cells = [f'{metrics[key]:>10.4f}' for key in metric_keys]
    return f'{model_label:<15} | ' + ' | '.join(cells)


print('\n' + '='*60)
print('MOE TRAINING COMPLETE - SUMMARY')
print('='*60)
print(f'{"Model":<15} | {"Accuracy":>10} | {"F1":>10} | {"Precision":>10} | {"Recall":>10}')
print('-'*60)
# Router first, then one row per expert in training order.
print(_summary_row('Router', router_results))
for cat, results in all_results.items():
    print(_summary_row(cat, results))

In [None]:
# Test the MoE pipeline
from transformers import pipeline

# Load the saved router from disk.
router = pipeline('text-classification', model='moe_models/router')

# Load one text-classification pipeline per expert, keyed by category name.
expert_model_names = ('dangerous', 'harassment', 'hate_speech', 'hindi_abuse', 'self_harm', 'sexual', 'violence')
experts = {
    name: pipeline('text-classification', model=f'moe_models/{name}')
    for name in expert_model_names
}

def moe_classify(text, router_fn=None, expert_fns=None):
    """Classify `text` in two stages: route it to a category, then ask that category's expert.

    Args:
        text: Input string to classify.
        router_fn: Callable that returns ``[{'label': ..., 'score': ...}]`` for a
            text. Defaults to the module-level ``router`` pipeline.
        expert_fns: Mapping from category name to a callable with the same
            return shape. Defaults to the module-level ``experts`` dict.

    Returns:
        A dict with 'category', 'is_toxic', 'route_confidence',
        'expert_confidence' and 'combined_confidence' (router score times the
        expert's score for its predicted label). When the router picks 'safe'
        no expert runs: 'expert_confidence' is None, 'combined_confidence'
        equals the router score, and the original 'confidence' key is kept.

    Raises:
        KeyError: If the router picks a category for which no expert is available.
    """
    if router_fn is None:
        router_fn = router
    if expert_fns is None:
        expert_fns = experts

    # Truncate to the same 256-token limit used when tokenizing the training
    # data; without truncation an over-long input fails inside the model.
    tokenizer_kwargs = {'truncation': True, 'max_length': 256}

    # Route
    route_result = router_fn(text, **tokenizer_kwargs)[0]
    category = route_result['label']
    route_conf = route_result['score']

    if category == 'safe':
        # Same key set as the expert branch, plus the legacy 'confidence' key.
        return {
            'category': 'safe',
            'is_toxic': False,
            'confidence': route_conf,
            'route_confidence': route_conf,
            'expert_confidence': None,
            'combined_confidence': route_conf,
        }

    if category not in expert_fns:
        raise KeyError(f'No expert loaded for routed category {category!r}')

    # Expert classification
    expert_result = expert_fns[category](text, **tokenizer_kwargs)[0]
    expert_conf = expert_result['score']

    return {
        'category': category,
        'is_toxic': expert_result['label'] == 'toxic',
        'route_confidence': route_conf,
        'expert_confidence': expert_conf,
        'combined_confidence': route_conf * expert_conf,
    }

# Test cases: (input text, category the router is expected to pick).
test_cases = [
    ('Hello, how are you?', 'safe'),
    ('I will kill you', 'violence'),
    ('bhenchod', 'hindi_abuse'),
    ('madarchod', 'hindi_abuse'),
    ('how to make a bomb', 'dangerous'),
    ('you are worthless garbage', 'harassment'),
    ('go back to your country', 'hate_speech'),
    ('I want to kill myself', 'self_harm'),
]

banner = '=' * 70
print('\n' + banner)
print('MOE PIPELINE TEST')
print(banner)
for text, expected in test_cases:
    result = moe_classify(text)
    # A case passes when the routed category matches, or when a non-safe
    # expectation was flagged toxic under any category.
    routed_as_expected = result['category'] == expected
    flagged_toxic = expected != 'safe' and result.get('is_toxic', False)
    status = '✅' if (routed_as_expected or flagged_toxic) else '❌'
    print(f"{status} '{text[:40]:<40}' → {result['category']:<15} (expected: {expected})")

In [None]:
# Zip and download
# Recursively bundle the moe_models/ tree into a single archive, then pass the
# archive to files.download (the Colab helper imported in the upload cell).
!zip -r moe_models.zip moe_models/
files.download('moe_models.zip')