# MoE Training v3 - HIGH RECALL Edition

**Target: 95%+ recall**

**Changes:**
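- Focal loss (down-weights easy examples so training focuses on hard ones)
- 4x boost on the toxic class weight (applied on top of balanced class weights)
- Threshold tuning (decision threshold searched over 0.2–0.5 at evaluation time instead of a fixed 0.5)
- XLM-RoBERTa base model (better multilingual coverage)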

In [None]:
# Install the notebook's third-party dependencies (-q keeps pip's output quiet).
!pip install transformers datasets accelerate scikit-learn -q

In [None]:
# Colab-only step: ask the user to upload the three dataset splits.
from google.colab import files
print('Upload train.jsonl, val.jsonl, test.jsonl')
# Later cells open the splits by these exact file names, so the uploaded
# files must be named train.jsonl, val.jsonl and test.jsonl.
uploaded = files.upload()

In [None]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Later cells require a CUDA device (e.g. fp16=True training and a pipeline on
# device=0), so stop here with an actionable message instead of letting an
# opaque CUDA error surface further down.
if not torch.cuda.is_available():
    raise RuntimeError(
        'No CUDA GPU detected. In Colab: Runtime > Change runtime type > GPU.'
    )
print(f'GPU: {torch.cuda.get_device_name(0)}')

# Backbone checkpoint: XLM-RoBERTa (multilingual). No automatic fallback is
# implemented; to try a different checkpoint such as MuRIL, edit MODEL_NAME.
MODEL_NAME = 'xlm-roberta-base'
print(f'Model: {MODEL_NAME}')

In [None]:
def load_jsonl(path):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        path: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Explicit UTF-8: the notebook targets multilingual text, so do not rely
    # on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline at EOF); json.loads would
        # otherwise raise on them.
        return [json.loads(line) for line in f if line.strip()]

# Read the three uploaded splits, in train / val / test order.
train_data, val_data, test_data = (
    load_jsonl(split_path)
    for split_path in ('train.jsonl', 'val.jsonl', 'test.jsonl')
)

# Report split sizes, then the class balance of the training split.
print(f'Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}')
train_label_counts = Counter(record['label'] for record in train_data)
print(f'Train toxic: {train_label_counts["toxic"]}')
print(f'Train safe: {train_label_counts["safe"]}')

In [None]:
# Load the tokenizer that belongs to the MODEL_NAME checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print('Tokenizer loaded')

In [None]:
# Prepare data: split each record list into parallel text / label lists.
# Label encoding: 'safe' -> 0, every other label value -> 1 (toxic).
train_texts = [record['text'] for record in train_data]
train_labels = [int(record['label'] != 'safe') for record in train_data]
val_texts = [record['text'] for record in val_data]
val_labels = [int(record['label'] != 'safe') for record in val_data]

def tokenize_data(texts, labels):
    """Tokenize texts and bundle them with their labels into a ``Dataset``.

    Args:
        texts: List of input strings handed to the module-level ``tokenizer``.
        labels: Label values, one per text; stored unchanged under 'labels'.

    Returns:
        A ``Dataset`` with columns 'input_ids', 'attention_mask' and 'labels'.
    """
    # Truncate to at most 256 tokens and pad within this call so every row
    # has the same length; plain Python lists are returned (return_tensors=None).
    encoded = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=256,
        return_tensors=None,
    )
    columns = {name: encoded[name] for name in ('input_ids', 'attention_mask')}
    columns['labels'] = labels
    return Dataset.from_dict(columns)

print('Tokenizing...')
# Encode the train and validation splits up front; the test split is encoded
# later, in the test-set evaluation cell.
train_dataset = tokenize_data(train_texts, train_labels)
val_dataset = tokenize_data(val_texts, val_labels)
print('Done!')

In [None]:
# Two-way classification head on top of the MODEL_NAME checkpoint.
# The id <-> label maps are stored in the model config: 0 = safe, 1 = toxic.
id2label = {0: 'safe', 1: 'toxic'}
label2id = {name: idx for idx, name in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
print('Model loaded')

In [None]:
# FOCAL LOSS - down-weights easy examples so training focuses on hard ones
class FocalLoss(nn.Module):
    """Class-weighted focal loss over classification logits.

    Per example the loss is ``alpha[t] * (1 - p_t) ** gamma * (-log p_t)``,
    where ``t`` is the target class and ``p_t`` its softmax probability.
    The result is the plain mean over all examples.

    Args:
        alpha: Optional 1-D tensor of per-class weights; ``None`` disables
            class weighting.
        gamma: Focusing exponent. ``0`` reduces the loss to (weighted)
            cross-entropy; larger values shrink the loss of well-classified
            examples more strongly.
    """

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # Class weights
        self.gamma = gamma  # Focusing parameter

    def forward(self, inputs, targets):
        """Return the mean focal loss of ``inputs`` against ``targets``.

        Args:
            inputs: Logits in any layout accepted by ``F.cross_entropy``.
            targets: Integer class indices matching ``inputs``.

        Returns:
            A scalar tensor.
        """
        # p_t must come from the UNWEIGHTED cross-entropy. Taking exp(-ce) of
        # the class-weighted loss would yield p_t ** alpha_t, so the class
        # weight would distort the focusing term.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # Probability of correct class
        modulator = (1 - pt) ** self.gamma

        if self.alpha is None:
            weighted_ce = ce_loss
        else:
            # Apply the class weight once, as a plain multiplier on the CE
            # term; keep the weights on the same device as the logits.
            weighted_ce = F.cross_entropy(
                inputs,
                targets,
                weight=self.alpha.to(inputs.device),
                reduction='none',
            )
        return (modulator * weighted_ce).mean()

# Class weights: start from sklearn's 'balanced' weights, then multiply the
# toxic entry (index 1) by 4 to push the model towards recall.
class_weights = compute_class_weight(
    'balanced',
    classes=np.array([0, 1]),
    y=train_labels,
)
class_weights[1] = class_weights[1] * 4.0
safe_weight, toxic_weight = class_weights
print(f'Class weights: safe={safe_weight:.3f}, toxic={toxic_weight:.3f}')

# Module-level loss instance; the weights are created directly on the GPU.
weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device='cuda')
focal_loss = FocalLoss(alpha=weights_tensor, gamma=2.0)

class FocalTrainer(Trainer):
    """``Trainer`` subclass whose loss is the module-level ``focal_loss``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on a batch and score its logits with ``focal_loss``.

        Args:
            model: The model to call with the batch.
            inputs: Batch mapping; its 'labels' entry is removed before the
                forward pass, so the model only receives the features.
            return_outputs: When true, also return the model outputs.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        targets = inputs.pop('labels')
        model_outputs = model(**inputs)
        batch_loss = focal_loss(model_outputs.logits, targets)
        if return_outputs:
            return (batch_loss, model_outputs)
        return batch_loss

In [None]:
# Custom metrics with threshold tuning
BEST_THRESHOLD = 0.5  # Last threshold chosen by compute_metrics (0.5 until it runs)


def compute_metrics(pred):
    """Choose a decision threshold for the toxic class and score predictions.

    Candidate thresholds from 0.20 to 0.50 are applied to the softmax
    probability of class 1. The threshold with the highest recall wins; ties
    go to the HIGHEST such threshold.

    Side effects: overwrites the module-level ``BEST_THRESHOLD`` on every
    call (so it always reflects the most recently evaluated split) and prints
    the confusion-matrix counts.

    Args:
        pred: Object exposing ``predictions`` (2-D logits; column 1 is the
            toxic class) and ``label_ids`` (integer labels).

    Returns:
        Dict with keys 'accuracy', 'f1', 'precision', 'recall', 'threshold'.
    """
    global BEST_THRESHOLD
    labels = pred.label_ids
    probs = torch.softmax(torch.tensor(pred.predictions), dim=-1)[:, 1].numpy()

    # Try different thresholds (ascending order matters for the tie-break).
    best_recall = 0
    best_thresh = 0.5
    for thresh in [0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:
        preds = (probs >= thresh).astype(int)
        _, recall, _, _ = precision_recall_fscore_support(
            labels, preds, average='binary', zero_division=0
        )
        # Recall can only stay equal or drop as the threshold rises, so a
        # strict '>' would always settle on the lowest candidate. '>=' keeps
        # the highest threshold that still reaches the best recall, which
        # avoids giving away precision for no recall gain.
        if recall >= best_recall:
            best_recall = recall
            best_thresh = thresh

    BEST_THRESHOLD = best_thresh
    preds = (probs >= best_thresh).astype(int)

    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    # Pin both classes so the matrix is always 2x2, even when a split (or the
    # predictions) contain a single class; otherwise the unpacking fails.
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()

    print(f'\nThreshold: {best_thresh}, TP={tp}, FP={fp}, TN={tn}, FN={fn}')

    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'threshold': best_thresh,
    }

In [None]:
# Hyper-parameters of the fine-tuning run, gathered in one mapping.
training_config = {
    'output_dir': './checkpoints',
    # Optimisation schedule.
    'num_train_epochs': 5,
    'learning_rate': 2e-5,
    'warmup_steps': 500,
    'weight_decay': 0.01,
    # Batching: 16 per device x 2 accumulation steps per optimizer update.
    'per_device_train_batch_size': 16,
    'per_device_eval_batch_size': 32,
    'gradient_accumulation_steps': 2,
    # Evaluate and checkpoint once per epoch; keep the checkpoint with the
    # highest recall.
    'eval_strategy': 'epoch',
    'save_strategy': 'epoch',
    'load_best_model_at_end': True,
    'metric_for_best_model': 'recall',
    'greater_is_better': True,
    # Runtime / logging.
    'fp16': True,
    'logging_steps': 200,
    'report_to': 'none',
}
args = TrainingArguments(**training_config)

trainer = FocalTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

banner = '=' * 60
print(banner)
print('TRAINING - Target: 95%+ Recall')
print(banner)
trainer.train()

In [None]:
# Final evaluation pass; no dataset is passed, so the trainer's own
# eval_dataset is scored.
results = trainer.evaluate()

rule = '=' * 60
print(f'\n{rule}')
print('RESULTS')
print(f'{rule}')
print(f'Accuracy:  {results["eval_accuracy"]:.4f}')
print(f'F1:        {results["eval_f1"]:.4f}')
print(f'Precision: {results["eval_precision"]:.4f}')
print(f'Recall:    {results["eval_recall"]:.4f}')
print(f'Threshold: {results.get("eval_threshold", BEST_THRESHOLD)}')

# Grade the run against the 95% recall target.
val_recall = results['eval_recall']
if val_recall >= 0.95:
    verdict = '\n✅ EXCELLENT - 95%+ recall!'
elif val_recall >= 0.90:
    verdict = '\n⚠️ GOOD - close to target'
else:
    verdict = '\n❌ Needs more work'
print(verdict)

In [None]:
# Save with threshold info
import os

os.makedirs('moe_v3', exist_ok=True)

# Model and tokenizer are written to the same directory, which the
# inference cell later loads as one checkpoint.
for artifact in (model, tokenizer):
    artifact.save_pretrained('moe_v3/stage1_binary')

# Save threshold
with open('moe_v3/threshold.txt', 'w') as threshold_file:
    threshold_file.write(str(BEST_THRESHOLD))

print(f'✅ Saved with threshold={BEST_THRESHOLD}')

In [None]:
# Test with custom threshold
from transformers import pipeline

# Reload the checkpoint saved under moe_v3/stage1_binary as an inference
# pipeline on device 0. top_k=None is set because classify() looks up the
# 'toxic' score by label and therefore needs an entry per label.
clf = pipeline('text-classification', model='moe_v3/stage1_binary', device=0, top_k=None)

def classify(text, threshold=None):
    """Classify one text as toxic or not at a configurable score cut-off.

    Args:
        text: Input passed to the module-level ``clf`` pipeline (with
            truncation enabled).
        threshold: Minimum 'toxic' score that counts as toxic. ``None`` means
            use the current value of ``BEST_THRESHOLD``.

    Returns:
        Dict with 'toxic' (bool verdict), 'score' (the pipeline's score for
        the 'toxic' label, or 0 if that label is absent) and 'threshold'
        (the cut-off that was applied).
    """
    cutoff = BEST_THRESHOLD if threshold is None else threshold
    per_label = clf(text, truncation=True)[0]
    score_by_label = dict((entry['label'], entry['score']) for entry in per_label)
    toxic_score = score_by_label.get('toxic', 0)
    return {'toxic': toxic_score >= cutoff, 'score': toxic_score, 'threshold': cutoff}

print(f'Classifier ready (threshold={BEST_THRESHOLD})')

In [None]:
# Critical tests: (text, expected_toxic) pairs covering benign phrases and
# clearly abusive ones.
tests = [
    # Expected SAFE
    ('Hello, how are you?', False),
    ('namaste kaise ho', False),
    ('This is so frustrating', False),
    ('I hate Mondays', False),
    ('I hate how much I love you', False),
    ('yaar bahut mushkil hai ye', False),
    # Expected TOXIC
    ('I will kill you', True),
    ('bhenchod', True),
    ('madarchod', True),
    ('chutiya hai tu', True),
    ('go back to your country', True),
    ('you stupid idiot', True),
    ('fuck you bitch', True),
    ('tujhe maar dunga', True),
]

rule = '=' * 70
print(rule)
print(f'TESTS (threshold={BEST_THRESHOLD})')
print(rule)

correct = 0
fn = []  # toxic texts the classifier missed, with their scores
fp = []  # safe texts the classifier flagged, with their scores

for text, expected in tests:
    r = classify(text)
    ok = r['toxic'] == expected
    if ok:
        correct += 1
    else:
        # A miss on an expected-toxic text is a false negative; otherwise
        # it is a false positive.
        (fn if expected else fp).append((text, r['score']))

    icon = '✅' if ok else '❌'
    expected_name = 'TOXIC' if expected else 'SAFE'
    print(f"{icon} [{r['score']:.2f}] exp:{expected_name:5} | {text[:40]}")

print(f'\nAccuracy: {correct}/{len(tests)}')

# List the misses, false negatives first; empty groups print nothing.
for heading, misses in (
    ('\n❌ FALSE NEGATIVES (missed toxic):', fn),
    ('\n⚠️ FALSE POSITIVES:', fp),
):
    if misses:
        print(heading)
        for t, s in misses:
            print(f'   [{s:.2f}] {t}')

In [None]:
# Test on held-out test set
test_texts = [d['text'] for d in test_data]
# Same encoding as the train/val splits: 'safe' -> 0, anything else -> 1.
test_labels = [0 if d['label'] == 'safe' else 1 for d in test_data]

print('Evaluating on test set...')
test_dataset = tokenize_data(test_texts, test_labels)

# compute_metrics re-tunes the threshold on whatever split it is given and
# overwrites BEST_THRESHOLD. Tuning on the test split would make the test
# numbers optimistic and would change the threshold that gets saved, so
# freeze the current (validation-tuned) value and restore it afterwards.
# NOTE: predict() still triggers compute_metrics, so its 'Threshold: ...'
# line is printed for the test split; that line is not the reported result.
val_threshold = BEST_THRESHOLD
test_output = trainer.predict(test_dataset)
BEST_THRESHOLD = val_threshold

# Score the test logits at the frozen threshold.
test_probs = torch.softmax(torch.tensor(test_output.predictions), dim=-1)[:, 1].numpy()
test_preds = (test_probs >= val_threshold).astype(int)
test_precision, test_recall, test_f1, _ = precision_recall_fscore_support(
    test_labels, test_preds, average='binary', zero_division=0
)

# Keep the 'eval_*' keys that later cells read from test_results.
test_results = {
    'eval_recall': float(test_recall),
    'eval_precision': float(test_precision),
    'eval_f1': float(test_f1),
    'eval_threshold': val_threshold,
}

print(f'\nTest Results:')
print(f'  Recall:    {test_results["eval_recall"]:.4f}')
print(f'  Precision: {test_results["eval_precision"]:.4f}')
print(f'  F1:        {test_results["eval_f1"]:.4f}')

In [None]:
# Save config: model name, decision threshold and headline val/test metrics.
config = dict(
    model_name=MODEL_NAME,
    threshold=BEST_THRESHOLD,
    val_recall=results['eval_recall'],
    val_f1=results['eval_f1'],
    test_recall=test_results['eval_recall'],
    test_f1=test_results['eval_f1'],
)

with open('moe_v3/config.json', 'w') as config_file:
    json.dump(config, config_file, indent=2)

print('Config saved')

In [None]:
# Download
# Zip the whole moe_v3/ output directory, then send the archive to the
# browser through the Colab `files` helper imported in the upload cell.
!zip -r moe_v3_highrecall.zip moe_v3/
files.download('moe_v3_highrecall.zip')