# MoE Training v3 - Curated Dataset (Fixed)

**Fixes**:
- Cleaned 670 mislabeled samples
- Balanced dataset
- Higher toxic class weight (2.5x)
- Target: 95%+ recall on the toxic class (Stage 1 binary classifier)

In [None]:
# Install the packages this notebook relies on; -q keeps pip's output quiet.
!pip install transformers datasets accelerate scikit-learn sentencepiece -q

In [None]:
from google.colab import files
import os

# Tell the user which three split files to pick, then open Colab's upload dialog.
print('Upload from data/datasets/curated/final/:')
for split_file in ('train.jsonl', 'val.jsonl', 'test.jsonl'):
    print(f'  - {split_file}')
print()
uploaded = files.upload()

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Report the accelerator before anything heavy runs. get_device_name() raises
# when no CUDA device is present, so check availability first and print a
# placeholder instead of crashing in the diagnostic cell itself.
cuda_available = torch.cuda.is_available()
gpu_name = torch.cuda.get_device_name(0) if cuda_available else 'not available'
print(f'GPU: {gpu_name}')
print(f'CUDA: {cuda_available}')

# Output directory for everything this notebook saves under moe_v3/.
os.makedirs('moe_v3', exist_ok=True)

# Base checkpoint name used when loading the tokenizer and models.
MODEL_NAME = 'google/muril-base-cased'
print(f'Model: {MODEL_NAME}')

In [None]:
# Load data
def load_jsonl(path):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        path: Path to a file holding one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly so non-ASCII text is read the same way
    # regardless of the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a stray empty line at the end of the file),
        # which json.loads would otherwise reject.
        return [json.loads(line) for line in f if line.strip()]

# Parse the three split files from the current working directory.
train_data = load_jsonl('train.jsonl')
val_data = load_jsonl('val.jsonl')
test_data = load_jsonl('test.jsonl')

# One line per split with its record count.
print(
    f'Train: {len(train_data)}',
    f'Val:   {len(val_data)}',
    f'Test:  {len(test_data)}',
    sep='\n',
)

# Per-label counts for the training split, in first-seen order.
print('\nTrain distribution:')
train_label_counts = Counter(row['label'] for row in train_data)
for label_name, n_rows in train_label_counts.items():
    print(f'  {label_name}: {n_rows}')

In [None]:
# Load the tokenizer that belongs to the MODEL_NAME checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print('Tokenizer loaded')

## Stage 1: Binary Classifier
**Goal: 95%+ recall on toxic**

In [None]:
# Prepare binary data
def prep_binary(data):
    """Convert labelled records into binary (safe vs. not-safe) records.

    Args:
        data: Iterable of dicts with at least ``'text'`` and ``'label'`` keys.

    Returns:
        A new list of ``{'text': ..., 'label': ...}`` dicts where the label is
        0 when the source label equals ``'safe'`` and 1 for anything else.
    """
    binary_rows = []
    for row in data:
        text = row['text']
        # Every label other than the literal 'safe' counts as the positive class.
        label_id = int(row['label'] != 'safe')
        binary_rows.append({'text': text, 'label': label_id})
    return binary_rows

# Binary (0 = safe, 1 = everything else) views of the train and val splits.
train_binary = prep_binary(train_data)
val_binary = prep_binary(val_data)

# Split each record list into parallel text / label lists.
train_texts = [row['text'] for row in train_binary]
train_labels = [row['label'] for row in train_binary]
val_texts = [row['text'] for row in val_binary]
val_labels = [row['label'] for row in val_binary]

# Labels are 0/1, so their sum is the number of positive (toxic) samples.
n_train_toxic = sum(train_labels)
n_val_toxic = sum(val_labels)
print(f'Train: {len(train_texts)} (toxic: {n_train_toxic}, safe: {len(train_labels) - n_train_toxic})')
print(f'Val:   {len(val_texts)} (toxic: {n_val_toxic}, safe: {len(val_labels) - n_val_toxic})')

In [None]:
def tokenize_data(texts, labels):
    """Tokenize texts with the module-level tokenizer and wrap them in a Dataset.

    Args:
        texts: List of input strings.
        labels: List of integer class ids, parallel to ``texts``.

    Returns:
        A ``datasets.Dataset`` with ``input_ids``, ``attention_mask`` and
        ``labels`` columns.
    """
    # Truncate to at most 256 tokens and pad; return plain Python lists.
    enc = tokenizer(texts, truncation=True, padding=True, max_length=256, return_tensors=None)
    columns = {key: enc[key] for key in ('input_ids', 'attention_mask')}
    columns['labels'] = labels
    return Dataset.from_dict(columns)

# Tokenize the Stage 1 train and val splits up front.
print('Tokenizing...')
train_dataset = tokenize_data(texts=train_texts, labels=train_labels)
val_dataset = tokenize_data(texts=val_texts, labels=val_labels)
print('Done!')

In [None]:
# Stage 1 model: a two-class sequence classifier loaded from MODEL_NAME.
binary_id2label = {0: 'safe', 1: 'toxic'}
binary_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    id2label=binary_id2label,
    label2id={name: idx for idx, name in binary_id2label.items()},
)
print('Binary model loaded!')

In [None]:
# Recall-oriented class weights: start from sklearn's 'balanced' weights for
# classes 0 (safe) and 1 (toxic), then multiply the toxic weight by 2.5 so a
# missed toxic sample costs more than a false alarm.
class_weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=train_labels)
class_weights[1] = class_weights[1] * 2.5
safe_weight, toxic_weight = class_weights
print(f'Class weights: safe={safe_weight:.3f}, toxic={toxic_weight:.3f}')

# float32 copy of the weights on the CUDA device, read by the weighted loss.
weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to('cuda')

class HighRecallTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the module-level ``weights_tensor``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: The model being trained or evaluated.
            inputs: Batch dict; its ``'labels'`` entry is removed in place.
            return_outputs: When true, also return the model outputs.
            **kwargs: Accepted and ignored.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        # Take the labels out so the model call receives only the remaining inputs.
        targets = inputs.pop('labels')
        outputs = model(**inputs)
        loss = nn.functional.cross_entropy(outputs.logits, targets, weight=weights_tensor)
        if return_outputs:
            return loss, outputs
        return loss

def compute_metrics(pred):
    """Compute binary metrics for Stage 1 and print the confusion matrix.

    Args:
        pred: Prediction object exposing ``label_ids`` (true class ids) and
            ``predictions`` (per-class scores; the argmax over the last axis
            is taken as the predicted class).

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)

    # Pin the label set so the matrix is always 2x2. Without it, a batch whose
    # labels and predictions contain a single class yields a 1x1 matrix and the
    # four-way unpack below raises ValueError.
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
    print(f'\nConfusion Matrix: TP={tp}, FP={fp}, TN={tn}, FN={fn}')
    print(f'False Negatives (missed toxic): {fn}')

    return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}

In [None]:
# Stage 1 fine-tuning configuration.
binary_args = TrainingArguments(
    output_dir='./stage1_checkpoints',
    report_to='none',
    # Optimisation schedule.
    num_train_epochs=4,  # More epochs
    learning_rate=2e-5,
    warmup_steps=500,
    weight_decay=0.01,
    fp16=True,
    # Batching: 16 samples per device, accumulated over 2 steps.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    # Evaluate and checkpoint once per epoch; at the end reload the checkpoint
    # with the highest recall on the evaluation set.
    logging_steps=200,
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='recall',
    greater_is_better=True,
)

# Trainer with the class-weighted loss, validated on the val split.
binary_trainer = HighRecallTrainer(
    model=binary_model,
    args=binary_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

rule = '=' * 60
print(rule, 'TRAINING STAGE 1: Binary (toxic vs safe)', 'Target: 95%+ recall', rule, sep='\n')
binary_trainer.train()

In [None]:
# Evaluate Stage 1 on the trainer's evaluation set and print the headline metrics.
results1 = binary_trainer.evaluate()
rule = '=' * 60
print('\n' + rule)
print('STAGE 1 RESULTS')
print(rule)
for prefix, metric_key in (
    ('Accuracy:  ', 'eval_accuracy'),
    ('F1:        ', 'eval_f1'),
    ('Precision: ', 'eval_precision'),
    ('Recall:    ', 'eval_recall'),
):
    print(f'{prefix}{results1[metric_key]:.4f}')

# Grade recall against the 95% target (90% is the "acceptable" floor).
stage1_recall = results1['eval_recall']
if stage1_recall >= 0.95:
    verdict = '✅ EXCELLENT - 95%+ recall achieved!'
elif stage1_recall >= 0.90:
    verdict = '⚠️ GOOD - but aim for 95%'
else:
    verdict = '❌ NEEDS IMPROVEMENT - recall too low'
print(verdict)

In [None]:
# Save the Stage 1 model and the tokenizer side by side in one directory.
stage1_dir = 'moe_v3/stage1_binary'
for artifact in (binary_model, tokenizer):
    artifact.save_pretrained(stage1_dir)
print('✅ Stage 1 saved')

## Stage 2: Category Classifier

In [None]:
# Stage 2 only uses records whose label is exactly 'toxic'.
train_toxic = [row for row in train_data if row['label'] == 'toxic']
val_toxic = [row for row in val_data if row['label'] == 'toxic']

# Fixed category order; a category's list index is its class id.
categories = ['harassment', 'hate_speech', 'violence', 'sexual', 'threat']
id2label_cat = dict(enumerate(categories))
label2id_cat = {name: idx for idx, name in id2label_cat.items()}

def map_cat(cat):
    """Return the class id for a category name.

    Args:
        cat: Category name to look up in ``label2id_cat``.

    Returns:
        The mapped id, or 0 (harassment) when the name is not in the mapping.
    """
    try:
        return label2id_cat[cat]
    except KeyError:
        # Default to harassment
        return 0

# Parallel text / category-id lists for the toxic-only train and val subsets.
train_cat_texts = [row['text'] for row in train_toxic]
train_cat_labels = [map_cat(row['category']) for row in train_toxic]
val_cat_texts = [row['text'] for row in val_toxic]
val_cat_labels = [map_cat(row['category']) for row in val_toxic]

print(f'Train toxic: {len(train_cat_texts)}')
print(f'Val toxic: {len(val_cat_texts)}')

# Raw category names (before mapping to ids), most frequent first.
print('\nCategories:')
category_counts = Counter(row['category'] for row in train_toxic)
for category_name, n_rows in category_counts.most_common():
    print(f'  {category_name}: {n_rows}')

In [None]:
# Tokenize the toxic-only train and val subsets for the category classifier.
print('Tokenizing Stage 2...')
train_cat_dataset = tokenize_data(texts=train_cat_texts, labels=train_cat_labels)
val_cat_dataset = tokenize_data(texts=val_cat_texts, labels=val_cat_labels)
print('Done!')

In [None]:
# Stage 2 model: one output class per entry in `categories`.
num_categories = len(categories)
category_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_categories,
    id2label=id2label_cat,
    label2id=label2id_cat,
)
print(f'Category model loaded ({num_categories} classes)')

In [None]:
# Balanced weights for the category ids that actually occur in training.
unique_labels = np.unique(train_cat_labels)
cat_weights = compute_class_weight('balanced', classes=unique_labels, y=train_cat_labels)

# Categories absent from the training data keep a neutral weight of 1.0;
# observed ones get their balanced weight at the matching class index.
full_weights = np.ones(len(categories))
full_weights[unique_labels] = cat_weights

print('Category weights:')
for category_name, weight in zip(categories, full_weights):
    print(f'  {category_name}: {weight:.3f}')

# float32 copy of the weights on the CUDA device, read by the weighted loss.
cat_weights_tensor = torch.tensor(full_weights, dtype=torch.float32).to('cuda')

class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted by the module-level ``cat_weights_tensor``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the category-weighted cross-entropy loss for one batch.

        Args:
            model: The model being trained or evaluated.
            inputs: Batch dict; its ``'labels'`` entry is removed in place.
            return_outputs: When true, also return the model outputs.
            **kwargs: Accepted and ignored.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        # Take the labels out so the model call receives only the remaining inputs.
        targets = inputs.pop('labels')
        outputs = model(**inputs)
        loss = nn.functional.cross_entropy(outputs.logits, targets, weight=cat_weights_tensor)
        if return_outputs:
            return loss, outputs
        return loss

def compute_metrics_cat(pred):
    """Compute support-weighted multi-class metrics for Stage 2.

    Args:
        pred: Prediction object exposing ``label_ids`` (true class ids) and
            ``predictions`` (per-class scores; the argmax over the last axis
            is taken as the predicted class).

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    precision, recall, f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average='weighted'
    )
    accuracy = accuracy_score(y_true, y_pred)
    return {'accuracy': accuracy, 'f1': f1, 'precision': precision, 'recall': recall}

In [None]:
# Stage 2 fine-tuning configuration.
cat_args = TrainingArguments(
    output_dir='./stage2_checkpoints',
    report_to='none',
    # Optimisation schedule.
    num_train_epochs=4,
    learning_rate=2e-5,
    warmup_steps=200,
    weight_decay=0.01,
    fp16=True,
    # Batching: 16 samples per device, accumulated over 2 steps.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    # Evaluate and checkpoint once per epoch; 'f1' picks the checkpoint to reload.
    logging_steps=200,
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='f1',
)

# Trainer with the category-weighted loss, validated on the toxic-only val subset.
category_trainer = WeightedTrainer(
    model=category_model,
    args=cat_args,
    train_dataset=train_cat_dataset,
    eval_dataset=val_cat_dataset,
    compute_metrics=compute_metrics_cat,
)

rule = '=' * 60
print(rule, 'TRAINING STAGE 2: Category Classifier', rule, sep='\n')
category_trainer.train()

In [None]:
# Evaluate Stage 2 on the trainer's evaluation set and print the headline metrics.
results2 = category_trainer.evaluate()
print(f'\n{"="*60}')
print('STAGE 2 RESULTS')
print(f'{"="*60}')
print(f'Accuracy:  {results2["eval_accuracy"]:.4f}')
print(f'F1:        {results2["eval_f1"]:.4f}')

# Per-category breakdown on the toxic-only validation subset.
preds = category_trainer.predict(val_cat_dataset)
pred_labels = preds.predictions.argmax(-1)
print('\nClassification Report:')
# Pass the full id range explicitly: without `labels`, sklearn infers the class
# set from the data and raises ValueError when a category appears in neither
# the true labels nor the predictions (its size then differs from target_names).
print(classification_report(
    val_cat_labels,
    pred_labels,
    labels=list(range(len(categories))),
    target_names=categories,
    zero_division=0,
))

In [None]:
# Save the Stage 2 model and the tokenizer side by side in one directory.
stage2_dir = 'moe_v3/stage2_category'
for artifact in (category_model, tokenizer):
    artifact.save_pretrained(stage2_dir)
print('✅ Stage 2 saved')

## Test Pipeline

In [None]:
from transformers import pipeline

# Reload both saved stages from disk as text-classification pipelines (device index 0).
task = 'text-classification'
stage1 = pipeline(task, model='moe_v3/stage1_binary', device=0)
stage2 = pipeline(task, model='moe_v3/stage2_category', device=0)

def classify(text):
    """Run the two-stage pipeline on one text.

    Args:
        text: Input string passed to the ``stage1`` and, if needed, ``stage2`` pipelines.

    Returns:
        Dict with ``toxic`` (bool), ``category`` (``'safe'`` or the Stage 2
        label) and ``conf`` (Stage 1 score for safe results, otherwise the
        product of the Stage 1 and Stage 2 scores).
    """
    binary = stage1(text, truncation=True)[0]
    if binary['label'] != 'safe':
        # Only texts flagged by Stage 1 are sent on to the category model.
        category = stage2(text, truncation=True)[0]
        return {
            'toxic': True,
            'category': category['label'],
            'conf': binary['score'] * category['score'],
        }
    return {'toxic': False, 'category': 'safe', 'conf': binary['score']}

# Confirmation message once this cell has finished defining the pipeline.
print('Pipeline ready!')

In [None]:
# Critical test cases - MUST pass
# `tests` is a list of (text, expected_toxic) pairs: safe cases first, then toxic.
safe_cases = [
    'Hello, how are you?',
    'The weather is nice today',
    'namaste kaise ho',
    'I went back to my country for vacation',
    'This is so frustrating',
    'I hate Mondays',
    'This is really hard',
    'I hate how much I love you',
    'My brain hurts',
    'yaar bahut mushkil hai ye',
    'tera number aa gaya bhai',
]
toxic_cases = [
    'I will kill you',
    'bhenchod',
    'madarchod',
    'chutiya hai tu',
    'you are worthless garbage',
    'go back to your country',
    'all immigrants should be deported',
    'tera time aa gaya hai tayyar reh',
    'tujhe maar dunga',
    'I want to kill myself',
    'you stupid idiot',
    'fuck you bitch',
]
tests = [(case, False) for case in safe_cases] + [(case, True) for case in toxic_cases]

rule = '=' * 70
print(rule)
print('CRITICAL TEST CASES')
print(rule)

# Tally hits and collect the misses by direction.
correct = 0
false_negatives = []
false_positives = []

for text, should_be_toxic in tests:
    predicted_toxic = classify(text)['toxic']
    is_correct = predicted_toxic == should_be_toxic
    if is_correct:
        correct += 1
    elif should_be_toxic:
        false_negatives.append(text)
    else:
        false_positives.append(text)

    # One line per case: verdict icon, prediction, expectation, first 40 chars.
    icon = '✅' if is_correct else '❌'
    exp = 'TOXIC' if should_be_toxic else 'SAFE'
    got = 'TOXIC' if predicted_toxic else 'SAFE'
    print(f"{icon} [{got:5}] exp:{exp:5} | {text[:40]}")

total = len(tests)
print(f'\nAccuracy: {correct}/{total} ({100*correct/total:.0f}%)')

if false_negatives:
    print('\n❌ FALSE NEGATIVES (missed toxic):')
    for missed in false_negatives:
        print(f'   - {missed}')

if false_positives:
    print('\n⚠️ FALSE POSITIVES (wrongly flagged safe as toxic):')
    for flagged in false_positives:
        print(f'   - {flagged}')

In [None]:
# Test set evaluation (Stage 1 binary classifier only)
print('='*60)
print('TEST SET EVALUATION')
print('='*60)

# Build test labels with prep_binary, the same helper used for train/val, so
# the test split uses exactly the label mapping the model was trained with
# (0 for 'safe', 1 for every other label).
test_binary = prep_binary(test_data)
test_texts = [d['text'] for d in test_binary]
test_labels = [d['label'] for d in test_binary]
test_dataset = tokenize_data(test_texts, test_labels)

test_results = binary_trainer.evaluate(test_dataset)
print(f'Test Accuracy:  {test_results["eval_accuracy"]:.4f}')
print(f'Test F1:        {test_results["eval_f1"]:.4f}')
print(f'Test Precision: {test_results["eval_precision"]:.4f}')
print(f'Test Recall:    {test_results["eval_recall"]:.4f}')

In [None]:
# Summary of the validation metrics collected from both stages.
rule = '=' * 60
print('\n' + rule)
print('FINAL SUMMARY')
print(rule)
print(f'Model: {MODEL_NAME}')
print('Dataset: ~26K samples (cleaned & balanced)')

print('\nStage 1 (Binary):')
stage1_recall = results1['eval_recall']
# Badge shows whether validation recall met the 95% target.
recall_badge = '✅' if stage1_recall >= 0.95 else '⚠️'
print(f'  Val Recall: {stage1_recall:.2%} {recall_badge}')
print(f'  Val F1:     {results1["eval_f1"]:.2%}')

print('\nStage 2 (Category):')
print(f'  Val F1:     {results2["eval_f1"]:.2%}')
print(rule)

In [None]:
# Save config: model name, category mapping and the headline metrics per stage.
stage1_summary = {name: results1[f'eval_{name}'] for name in ('accuracy', 'f1', 'recall')}
stage2_summary = {name: results2[f'eval_{name}'] for name in ('accuracy', 'f1')}
test_summary = {name: test_results[f'eval_{name}'] for name in ('accuracy', 'f1', 'recall')}

config = {
    'model_name': MODEL_NAME,
    'categories': categories,
    'label2id_cat': label2id_cat,
    'stage1_results': stage1_summary,
    'stage2_results': stage2_summary,
    'test_results': test_summary,
}

with open('moe_v3/config.json', 'w') as config_file:
    json.dump(config, config_file, indent=2)
print('✅ Config saved')

In [None]:
# Zip the moe_v3/ output directory recursively and hand the archive to the
# browser through Colab's download helper.
!zip -r moe_v3.zip moe_v3/
files.download('moe_v3.zip')