# MoE Training v3 - BALANCED

**Target: 93%+ recall WITH 80%+ precision**

Previous issues:
- Too aggressive → 100% recall but 48% precision (useless)
- Not aggressive enough → 85% recall (misses toxic)

**Solution:**
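- Moderate class weight (sklearn "balanced" weights, with an extra 2x boost for toxic)
- Select the best training checkpoint by validation F1
- Sweep decision thresholds on the validation set and pick the one with the best recall-weighted score among those reaching recall ≥ 93% and precision ≥ 75% (a hard floor, below the 80% target)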

In [None]:
# Install the libraries used below (quiet mode); IPython shell magic, Colab/Jupyter only.
!pip install transformers datasets accelerate scikit-learn -q

In [None]:
# Colab-only: prompt the user to upload the three JSONL splits. The loading cell
# opens them by bare filename, so they are presumably saved to the working directory.
from google.colab import files
print('Upload train.jsonl, val.jsonl, test.jsonl')
uploaded = files.upload()

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, precision_recall_curve
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter
import warnings
# NOTE(review): this silences every warning (including sklearn's undefined-metric
# warnings and library deprecation notices) - consider narrowing it.
warnings.filterwarnings('ignore')

# Later cells assume a CUDA GPU (loss weights moved to 'cuda', fp16=True,
# pipeline device=0), so fail fast with an actionable message instead of the
# opaque error torch.cuda.get_device_name() raises when no GPU is present.
if not torch.cuda.is_available():
    raise RuntimeError(
        'No CUDA GPU available: switch the Colab runtime to a GPU '
        '(Runtime > Change runtime type) and re-run.'
    )
print(f'GPU: {torch.cuda.get_device_name(0)}')

# Pretrained checkpoint fine-tuned below as the stage-1 binary (safe/toxic) classifier.
MODEL_NAME = 'xlm-roberta-base'
print(f'Model: {MODEL_NAME}')

In [None]:
def load_jsonl(path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        path: Path to a UTF-8 encoded file containing one JSON value per line.

    Returns:
        A list with one parsed value per non-blank line, in file order.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale default.
    with open(path, 'r', encoding='utf-8') as f:
        # Skip blank / whitespace-only lines (e.g. a trailing empty line),
        # which json.loads would otherwise reject.
        return [json.loads(line) for line in f if line.strip()]

# Read the three uploaded splits, in train / val / test order.
train_data, val_data, test_data = (
    load_jsonl(split_file) for split_file in ('train.jsonl', 'val.jsonl', 'test.jsonl')
)

print(f'Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}')
# Per-label tally of the training split; Counter reports 0 for an absent label.
train_label_counts = Counter(record['label'] for record in train_data)
print(f'Train: toxic={train_label_counts["toxic"]}, safe={train_label_counts["safe"]}')

In [None]:
# Tokenizer for MODEL_NAME; used by tokenize_data() and saved next to the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print('Tokenizer loaded')

In [None]:
# Stage-1 binary targets: 'safe' maps to 0, every other label maps to 1 (toxic).
train_texts = [record['text'] for record in train_data]
train_labels = [int(record['label'] != 'safe') for record in train_data]
val_texts = [record['text'] for record in val_data]
val_labels = [int(record['label'] != 'safe') for record in val_data]

def tokenize_data(texts, labels):
    """Tokenize ``texts`` and pair them with ``labels`` in a ``datasets.Dataset``.

    Sequences are truncated to 256 tokens and, because ``padding=True`` is applied to
    the whole list at once, padded to the longest sequence in ``texts``. Every row
    therefore has the same length (presumably required because the Trainer in this
    notebook is built without a padding data collator).

    Args:
        texts: List of input strings.
        labels: List of integer class ids, aligned with ``texts``.

    Returns:
        A Dataset with ``input_ids``, ``attention_mask`` and ``labels`` columns.
    """
    encoded = tokenizer(texts, truncation=True, padding=True, max_length=256, return_tensors=None)
    columns = {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
        'labels': labels,
    }
    return Dataset.from_dict(columns)

# Tokenize both splits up front; each split is padded to its own longest sequence
# (padding=True inside tokenize_data).
print('Tokenizing...')
train_dataset = tokenize_data(train_texts, train_labels)
val_dataset = tokenize_data(val_texts, val_labels)
print('Done!')

In [None]:
# Two-class head on top of MODEL_NAME. The id2label names are what the saved
# pipeline reports, and classify() later looks up the score by the 'toxic' name.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=2, id2label={0: 'safe', 1: 'toxic'}, label2id={'safe': 0, 'toxic': 1}
)
print('Model loaded')

In [None]:
# Class weights for the loss: start from sklearn's 'balanced' (inverse-frequency)
# weights, then multiply the toxic weight by a moderate 2x (deliberately not 4x).
TOXIC_WEIGHT_BOOST = 2.0
class_weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=train_labels)
class_weights[1] = class_weights[1] * TOXIC_WEIGHT_BOOST
print(f'Class weights: safe={class_weights[0]:.3f}, toxic={class_weights[1]:.3f}')

# Index 0 = safe, index 1 = toxic; placed on the GPU for WeightedTrainer's loss.
weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to('cuda')

class WeightedTrainer(Trainer):
    """Hugging Face ``Trainer`` whose loss is class-weighted cross-entropy.

    Errors on each class are scaled by the module-level ``weights_tensor``
    (index 0 = safe, index 1 = toxic).
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the weighted cross-entropy loss (plus the model outputs if requested).

        Args:
            model: The model being trained or evaluated.
            inputs: Batch dict. Its ``'labels'`` entry is popped before the forward
                pass so the model does not compute its own unweighted loss.
            return_outputs: When True, return ``(loss, outputs)`` instead of ``loss``.
            **kwargs: Extra arguments some ``Trainer`` versions pass (e.g.
                ``num_items_in_batch``); accepted for compatibility and ignored.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` when ``return_outputs``.
        """
        labels = inputs.pop('labels')
        outputs = model(**inputs)
        logits = outputs.logits
        # Put the weights on whatever device the logits live on instead of relying on
        # the hard-coded 'cuda' placement (avoids a device-mismatch error on cuda:N/CPU),
        # and compute the loss in float32 in case the logits come back as fp16.
        weight = weights_tensor.to(device=logits.device, dtype=torch.float32)
        loss_fn = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fn(logits.float(), labels)
        return (loss, outputs) if return_outputs else loss

In [None]:
# Toxic-class probabilities and true labels captured by the most recent
# compute_metrics call; the threshold-tuning cell reads them after the final
# trainer.evaluate(). Every evaluation overwrites them.
VAL_PROBS = None
VAL_LABELS = None

def compute_metrics(pred):
    """Compute accuracy/F1/precision/recall at a 0.5 threshold and cache raw scores.

    Side effect: stores the toxic-class probabilities and the true labels in the
    module globals ``VAL_PROBS`` / ``VAL_LABELS`` so a later cell can sweep
    thresholds without re-running inference.

    Args:
        pred: Trainer ``EvalPrediction`` with ``predictions`` (two-class logits, one
            row per example, or a tuple whose first element is the logits) and
            ``label_ids``.

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``, treating
        toxic (1) as the positive class.
    """
    global VAL_PROBS, VAL_LABELS
    labels = pred.label_ids
    logits = pred.predictions
    # If the model returns extra outputs, the logits are the first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Upcast before softmax: logits may arrive as float16 under fp16 evaluation,
    # and float16 probabilities are too coarse for a fine threshold sweep.
    logits = torch.as_tensor(np.asarray(logits, dtype=np.float32))
    probs = torch.softmax(logits, dim=-1)[:, 1].numpy()

    VAL_PROBS = probs
    VAL_LABELS = labels

    # Use default 0.5 threshold for training metrics
    preds = (probs >= 0.5).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    # labels=[0, 1] keeps the matrix 2x2 even when an eval split contains only one
    # class, so the 4-way unpack cannot fail.
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()

    print(f'\n[thresh=0.5] TP={tp}, FP={fp}, TN={tn}, FN={fn}')

    return {'accuracy': accuracy_score(labels, preds), 'f1': f1, 'precision': precision, 'recall': recall}

In [None]:
# Fine-tuning configuration: per-device train batch 16 with 2 gradient-accumulation
# steps, evaluation and checkpointing every epoch, and the best-F1 checkpoint
# restored at the end of training.
training_config = dict(
    output_dir='./checkpoints',
    report_to='none',
    # Optimisation
    num_train_epochs=4,
    learning_rate=2e-5,
    warmup_steps=500,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=32,
    fp16=True,
    # Evaluation / checkpoint selection
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='f1',  # Optimize for F1 (balance)
    greater_is_better=True,
    logging_steps=200,
)
args = TrainingArguments(**training_config)

trainer = WeightedTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

banner = '=' * 60
print(banner)
print('TRAINING')
print(banner)
trainer.train()

In [None]:
# Re-evaluate the restored best checkpoint (load_best_model_at_end=True). This also
# refreshes VAL_PROBS / VAL_LABELS through compute_metrics for the threshold sweep.
results = trainer.evaluate()
print('\nResults at threshold=0.5:')
for caption, key in (
    ('Accuracy:', 'eval_accuracy'),
    ('F1:', 'eval_f1'),
    ('Precision:', 'eval_precision'),
    ('Recall:', 'eval_recall'),
):
    # Left-align captions in an 11-character column so the values line up.
    print(f'  {caption:<11}{results[key]:.4f}')

In [None]:
# THRESHOLD TUNING - find best threshold
# Sweep candidate thresholds over the validation probabilities captured by
# compute_metrics. A threshold is acceptable when it meets both floors below;
# among acceptable ones the highest recall-weighted score wins. If none is
# acceptable, fall back to the best-F1 candidate (and say so) instead of silently
# keeping 0.5 and reporting it as "selected".
MIN_RECALL = 0.93
MIN_PRECISION = 0.75
CANDIDATE_THRESHOLDS = [0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]

print('\n' + '='*60)
print('THRESHOLD TUNING')
print('='*60)

best_threshold = 0.5
best_score = 0
acceptable_found = False
best_f1_threshold = 0.5
best_f1 = -1.0

print(f'{"Thresh":>8} {"Recall":>8} {"Precision":>10} {"F1":>8} {"Note"}')
print('-'*50)

for thresh in CANDIDATE_THRESHOLDS:
    preds = (VAL_PROBS >= thresh).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(VAL_LABELS, preds, average='binary', zero_division=0)

    # Fallback candidate: best F1 seen so far (the lower threshold wins ties).
    if f1 > best_f1:
        best_f1 = f1
        best_f1_threshold = thresh

    note = ''
    if recall >= MIN_RECALL and precision >= MIN_PRECISION:
        note = '✅ GOOD'
        acceptable_found = True
        # Prefer higher recall if both are acceptable
        score = recall + (precision * 0.5)  # Weight recall more
        if score > best_score:
            best_score = score
            best_threshold = thresh
    elif recall >= 0.90:
        note = '⚠️ OK'

    print(f'{thresh:>8.2f} {recall:>8.2%} {precision:>10.2%} {f1:>8.2%} {note}')

if not acceptable_found:
    best_threshold = best_f1_threshold
    print(f'\n⚠️ No threshold reached recall >= {MIN_RECALL:.0%} with precision >= {MIN_PRECISION:.0%}; '
          f'falling back to the best-F1 threshold (F1={best_f1:.2%}).')

print(f'\n→ Selected threshold: {best_threshold}')

In [None]:
# Final results with best threshold: re-score the validation set at the chosen cut-off.
# precision / recall / f1 keep these names because the config cell reads them.
preds = (VAL_PROBS >= best_threshold).astype(int)
precision, recall, f1, _ = precision_recall_fscore_support(VAL_LABELS, preds, average='binary')
tn, fp, fn, tp = confusion_matrix(VAL_LABELS, preds).ravel()

rule = '=' * 60
recall_flag = '✅' if recall >= 0.93 else '⚠️'
precision_flag = '✅' if precision >= 0.75 else '⚠️'

print(rule)
print(f'FINAL RESULTS (threshold={best_threshold})')
print(rule)
print(f'Recall:    {recall:.2%} {recall_flag}')
print(f'Precision: {precision:.2%} {precision_flag}')
print(f'F1:        {f1:.2%}')
print('\nConfusion Matrix:')
for line in (
    f'  TP={tp} (caught toxic)',
    f'  FP={fp} (false alarms)',
    f'  TN={tn} (correct safe)',
    f'  FN={fn} (missed toxic) ← want this LOW',
):
    print(line)
print(rule)

In [None]:
# Save model, tokenizer and tuned threshold under moe_v3/.
import os

os.makedirs('moe_v3', exist_ok=True)

# Assumes `model` holds the best checkpoint (the Trainer was configured with
# load_best_model_at_end=True) - NOTE(review): confirm trainer.model is `model`.
for artifact in (model, tokenizer):
    artifact.save_pretrained('moe_v3/stage1_binary')

with open('moe_v3/threshold.txt', 'w') as threshold_file:
    threshold_file.write(str(best_threshold))

print(f'✅ Saved with threshold={best_threshold}')

In [None]:
# Test classifier
from transformers import pipeline

# Reload the exported model as a pipeline; top_k=None requests the score of every label.
clf = pipeline('text-classification', model='moe_v3/stage1_binary', device=0, top_k=None)

def classify(text, threshold=None):
    """Score ``text`` with the saved stage-1 model and apply a decision threshold.

    Args:
        text: A single input string.
        threshold: Cut-off applied to the 'toxic' score; ``None`` means use the
            global ``best_threshold``.

    Returns:
        ``{'toxic': bool, 'score': float}`` where ``score`` is the pipeline's 'toxic'
        score (0 if no 'toxic' label was reported).
    """
    cutoff = best_threshold if threshold is None else threshold
    label_entries = clf(text, truncation=True)[0]
    score_by_label = {entry['label']: entry['score'] for entry in label_entries}
    toxic_score = score_by_label.get('toxic', 0)
    return {'toxic': toxic_score >= cutoff, 'score': toxic_score}

print(f'Classifier ready (threshold={best_threshold})')

In [None]:
# Critical tests
# Hand-picked (text, expected_is_toxic) sanity cases: True means the classifier
# should flag the text as toxic at the tuned threshold.
tests = [
    # SAFE - must NOT flag these
    ('Hello, how are you?', False),
    ('namaste kaise ho', False),
    ('This is so frustrating', False),
    ('I hate Mondays', False),
    ('I hate how much I love you', False),
    ('yaar bahut mushkil hai ye', False),
    ('The weather is nice', False),
    ('My brain hurts', False),
    ('tera number aa gaya bhai', False),
    
    # TOXIC - must flag these
    ('I will kill you', True),
    ('bhenchod', True),
    ('madarchod', True),
    ('chutiya hai tu', True),
    ('go back to your country', True),
    ('you stupid idiot', True),
    ('fuck you bitch', True),
    ('tujhe maar dunga', True),
    ('I want to kill myself', True),
    ('you worthless piece of garbage', True),
]

# Run every hand-picked case through classify() and report per-case results,
# overall accuracy, and the misclassified texts grouped by error type.
rule = '=' * 70
print(rule)
print(f'CRITICAL TESTS (threshold={best_threshold})')
print(rule)

correct = 0
fn_list = []  # expected toxic, predicted safe
fp_list = []  # expected safe, predicted toxic

for text, expected in tests:
    outcome = classify(text)
    flagged, score = outcome['toxic'], outcome['score']
    ok = flagged == expected
    if ok:
        correct += 1
    else:
        (fn_list if expected else fp_list).append((text, score))

    status = '✅' if ok else '❌'
    got = 'TOXIC' if flagged else 'SAFE'
    exp = 'TOXIC' if expected else 'SAFE'
    print(f"{status} [{score:.2f}] {got:5} exp:{exp:5} | {text[:35]}")

print(f'\nAccuracy: {correct}/{len(tests)} ({100*correct/len(tests):.0f}%)')

for heading, misses in (
    ('\n❌ FALSE NEGATIVES (missed toxic):', fn_list),
    ('\n⚠️ FALSE POSITIVES (wrongly flagged):', fp_list),
):
    if not misses:
        continue
    print(heading)
    for miss_text, miss_score in misses:
        print(f'   [{miss_score:.2f}] {miss_text}')

In [None]:
# Test set evaluation
# Same label mapping as training ('safe' -> 0, anything else -> 1), so test
# metrics measure the task the model was actually trained on.
test_texts = [d['text'] for d in test_data]
test_labels_list = [0 if d['label'] == 'safe' else 1 for d in test_data]
test_dataset = tokenize_data(test_texts, test_labels_list)

print('Evaluating on test set...')
# trainer.predict() would call compute_metrics, which overwrites the validation
# VAL_PROBS / VAL_LABELS globals with test-set values (re-running the threshold cell
# would then tune on test data). Detach it for this call and always restore it.
saved_compute_metrics = trainer.compute_metrics
trainer.compute_metrics = None
try:
    pred_output = trainer.predict(test_dataset)
finally:
    trainer.compute_metrics = saved_compute_metrics

test_logits = pred_output.predictions
if isinstance(test_logits, tuple):  # extra model outputs: logits come first
    test_logits = test_logits[0]
# Upcast in case the logits come back as fp16, so softmax probabilities are precise.
test_logits = np.asarray(test_logits, dtype=np.float32)
test_probs = torch.softmax(torch.from_numpy(test_logits), dim=-1)[:, 1].numpy()
test_preds = (test_probs >= best_threshold).astype(int)

# test_-prefixed names so the validation precision/recall/f1 from the final-results
# cell are not overwritten (the config cell saves those as val_* metrics).
test_precision, test_recall, test_f1, _ = precision_recall_fscore_support(
    test_labels_list, test_preds, average='binary', zero_division=0
)
test_tn, test_fp, test_fn, test_tp = confusion_matrix(
    test_labels_list, test_preds, labels=[0, 1]
).ravel()

print(f'\nTest Results (threshold={best_threshold}):')
print(f'  Recall:    {test_recall:.2%}')
print(f'  Precision: {test_precision:.2%}')
print(f'  F1:        {test_f1:.2%}')
print(f'  TP={test_tp}, FP={test_fp}, TN={test_tn}, FN={test_fn}')

In [None]:
# Save config: model name, tuned threshold and headline metrics.
# NOTE(review): the val_* entries read the module-level precision / recall / f1,
# i.e. whatever the most recently executed cell assigned to them. Confirm that is the
# validation final-results cell; any later cell reusing those names (e.g. a test-set
# evaluation) would silently turn them into non-validation metrics.
config = dict(
    model_name=MODEL_NAME,
    threshold=best_threshold,
    val_recall=float(recall),
    val_precision=float(precision),
    val_f1=float(f1),
)

with open('moe_v3/config.json', 'w') as config_file:
    json.dump(config, config_file, indent=2)

print('✅ Config saved')

In [None]:
# Zip everything saved under moe_v3/ (model, tokenizer, threshold.txt, config.json)
# and trigger a browser download via Colab's files helper.
!zip -r moe_v3_balanced.zip moe_v3/
files.download('moe_v3_balanced.zip')