# ModernBERT Hallucination Grader Training
**Optimization: ModernBERT Hallucination Detector**

This notebook fine-tunes `answerdotai/ModernBERT-base` for binary hallucination detection.

## Target Metrics
- Latency: <15ms on T4
- Context: up to 8k tokens supported by ModernBERT (this notebook trains and benchmarks with `max_length=512`)
- Task: Binary classification (0=hallucinated, 1=faithful)

In [None]:
# Install dependencies
!pip install -q transformers torch accelerate datasets

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
import json
from tqdm import tqdm

# Select the compute device: use CUDA when a GPU is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. Load Dataset

In [None]:
# Load training data generated for the grader
def load_jsonl(path):
    data = []
    with open(path, 'r') as f:
        for line in f:
            data.append(json.loads(line))
    return data

import random

# Upload your hallucination_dataset.jsonl to Colab first
dataset = load_jsonl('hallucination_dataset.jsonl')
print(f'Loaded {len(dataset)} samples')

# Train/Val split (80/20).
# Shuffle a copy with a fixed seed before slicing: if the file is grouped by
# label or source (as generated datasets often are), a plain head/tail slice
# would give train and validation sets with different label distributions.
# The fixed seed keeps the split reproducible across runs.
SPLIT_SEED = 42
shuffled = list(dataset)
random.Random(SPLIT_SEED).shuffle(shuffled)
split_idx = int(len(shuffled) * 0.8)
train_data = shuffled[:split_idx]
val_data = shuffled[split_idx:]
# Both splits must be non-empty, otherwise the epoch metrics later divide by zero.
if not train_data or not val_data:
    raise ValueError(
        f'Need at least 2 samples for an 80/20 split; got {len(dataset)}'
    )
print(f'Train: {len(train_data)}, Val: {len(val_data)}')

## 2. Dataset & Model Definition

In [None]:
class HallucinationDataset(Dataset):
    """Map-style dataset that tokenizes grader examples on the fly.

    Each raw item is expected to be a dict with a ``'text'`` string and an
    integer ``'label'`` (0=hallucinated, 1=faithful per the notebook header).

    Args:
        data: Indexable sequence of example dicts.
        tokenizer: Hugging Face-style tokenizer, called with ``truncation``,
            ``max_length``, ``padding`` and ``return_tensors`` keywords.
        max_length: Every example is truncated / padded to exactly this many
            tokens.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids`` / ``attention_mask`` of shape (max_length,) and a long ``label``."""
        item = self.data[idx]
        encoding = self.tokenizer(
            item['text'],
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch axis of size 1. Drop only
        # that axis: a bare .squeeze() would also collapse the sequence axis
        # when max_length == 1, yielding a 0-d tensor the DataLoader can't stack
        # into (batch, seq_len).
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(item['label'], dtype=torch.long),
        }


class ModernBERTClassifier(nn.Module):
    """ModernBERT encoder with a linear classification head for hallucination detection.

    The head takes the final hidden state of the first token ([CLS]), applies
    dropout, and projects it to ``num_labels`` logits.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``AutoModel.from_pretrained``.
        num_labels: Number of output classes (2 for hallucinated vs. faithful).
        dropout: Dropout probability applied to the pooled representation.
    """

    def __init__(self, model_name='answerdotai/ModernBERT-base', num_labels=2, dropout=0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the encoder config instead of hard-coding 768, so
        # other checkpoints (e.g. ModernBERT-large, hidden size 1024) also work.
        # Submodule names are unchanged, so existing state_dicts still load.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised logits of shape (batch, num_labels).

        Args:
            input_ids: Token ids, presumably (batch, seq_len) as produced by
                ``HallucinationDataset`` batches.
            attention_mask: Mask with the same shape as ``input_ids``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token ([CLS]) of every sequence as its representation.
        pooled = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(self.dropout(pooled))
        return logits

## 3. Training Loop (custom PyTorch loop, not the Hugging Face `Trainer` API)

In [None]:
# Initialize: tokenizer and classifier are loaded from the same checkpoint id.
MODEL_NAME = 'answerdotai/ModernBERT-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = ModernBERTClassifier(MODEL_NAME).to(device)

# Dataloaders: only the training split is shuffled; validation order is fixed.
train_dataset, val_dataset = (
    HallucinationDataset(split, tokenizer) for split in (train_data, val_data)
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

# Optimizer & loss: AdamW at lr=2e-5 over all parameters, cross-entropy on logits.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

print(f'Model parameters: {sum(map(torch.Tensor.numel, model.parameters())):,}')

In [None]:
def train_epoch(model, loader, optimizer, criterion, device=None):
    """Run one optimisation pass over ``loader`` and report epoch metrics.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)`` that
            returns logits of shape (batch, num_labels).
        loader: Iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        criterion: Loss called as ``criterion(logits, labels)``; assumed to
            return the batch *mean* (the ``nn.CrossEntropyLoss`` default).
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        tuple[float, float]: ``(mean_loss, accuracy)`` over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in tqdm(loader, desc='Training'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # The criterion returns a per-batch mean; re-weight by batch size so a
        # short final batch does not skew the epoch-level average.
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError('train_epoch: loader yielded no samples')
    return total_loss / total, correct / total


def evaluate(model, loader, criterion, device=None):
    """Compute loss and accuracy of ``model`` on ``loader`` without updating it.

    Puts the model in eval mode (disabling dropout) and leaves it there.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)`` that
            returns logits of shape (batch, num_labels).
        loader: Iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        criterion: Loss called as ``criterion(logits, labels)``; assumed to
            return the batch *mean* (the ``nn.CrossEntropyLoss`` default).
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        tuple[float, float]: ``(mean_loss, accuracy)`` over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)

            batch_size = labels.size(0)
            # Weight each batch-mean loss by its size so the result is a true
            # per-sample mean even when the last batch is smaller.
            total_loss += loss.item() * batch_size
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError('evaluate: loader yielded no samples')
    return total_loss / total, correct / total

In [None]:
# Training Loop - 5 Epochs
EPOCHS = 5
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. With 0 and the strict '>' below, a run whose validation accuracy
# never exceeds 0 would save nothing, and the later
# torch.load('guardrail_v1.pt') in the inference cell would fail.
best_val_acc = float('-inf')

for epoch in range(EPOCHS):
    print(f'\n=== Epoch {epoch+1}/{EPOCHS} ===')

    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)

    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    # Save best model (by validation accuracy)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'guardrail_v1.pt')
        print(f'✓ Saved best model (val_acc: {val_acc:.4f})')

print(f'\n✓ Training complete! Best Val Accuracy: {best_val_acc:.4f}')

## 4. Inference Speed Test

In [None]:
import time

# Load best model. map_location places the tensors on the current device
# whichever device the checkpoint was saved from; weights_only restricts
# unpickling to plain tensors/containers, which is all a state_dict needs.
model.load_state_dict(
    torch.load('guardrail_v1.pt', map_location=device, weights_only=True)
)
model.eval()

# Test inference speed
test_text = 'Context: The company reported 15% growth. Answer: Revenue increased by 15%.'
inputs = tokenizer(test_text, return_tensors='pt', truncation=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}

on_cuda = device.type == 'cuda'

with torch.no_grad():
    # Warmup
    for _ in range(10):
        model(inputs['input_ids'], inputs['attention_mask'])
    # CUDA kernels run asynchronously: drain the warmup work so it is not
    # billed to the first timed iteration.
    if on_cuda:
        torch.cuda.synchronize()

    # Benchmark
    times = []
    for _ in range(100):
        start = time.perf_counter()
        logits = model(inputs['input_ids'], inputs['attention_mask'])
        # Wait for the forward pass to actually finish before reading the clock.
        if on_cuda:
            torch.cuda.synchronize()
        times.append((time.perf_counter() - start) * 1000)

avg_ms = sum(times) / len(times)
print(f'Average inference time: {avg_ms:.2f}ms')
print('Target: <15ms ✓' if avg_ms < 15 else 'Target: <15ms ✗')

## 5. Export Model
Download the best checkpoint, `guardrail_v1.pt`, and place it in the `models/` folder.

In [None]:
# Download the checkpoint to the local machine. google.colab only exists
# inside a Colab runtime; elsewhere, report where the file is instead of
# crashing on the import.
try:
    from google.colab import files
except ImportError:
    print("google.colab not available; copy 'guardrail_v1.pt' from the working directory manually.")
else:
    files.download('guardrail_v1.pt')