# Boundary Scorer Training

v1.5.0 - Middle Ground (9 frozen, 0.4 dropout)

Train a neural model to predict semantic boundary scores (0-6) using XLM-RoBERTa.

**Data**: 16,311 labeled boundaries scored by a Gemini teacher model
**Model**: XLM-R-base with CORN ordinal head
**Method**: CORN (Conditional Ordinal Regression for Neural Networks) loss, which respects the ordinal structure of the scores
**Changes in v1.5.0**:
- Freeze 9/12 layers (3 trainable encoder layers) - middle ground
- Dropout 0.4 - moderate regularization
- Previous results (Pearson correlation): 6 frozen / 0.3 dropout → ~0.30; 11 frozen / 0.5 dropout → ~0.16

## 1. Setup

In [None]:
# Clone the repository and make it the working directory; every relative path
# below (data/, src/) is resolved against it.
# NOTE(review): re-running this cell without restarting the runtime runs `git clone`
# and `%cd` relative to the *current* directory (already the repo) — check cwd first.
!git clone https://github.com/HBBobo/Intelligent-Chunking.git
%cd Intelligent-Chunking

In [None]:
# Install dependencies (quietly). Versions are not pinned, so results can drift between
# Colab images. NOTE(review): matplotlib is imported later but not listed here —
# presumably relies on the Colab preinstall.
!pip install -q transformers torch scipy scikit-learn tqdm

In [None]:
# Mount Google Drive (for saving models).
# Checkpoints (section 5) and the final model (section 8) are written under
# /content/drive/MyDrive/ChunkingNN, so this must run before training.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import json
import random
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Add the repo root (the cwd after `%cd Intelligent-Chunking`) to sys.path so that
# `src.*` imports resolve. Must run before any `from src...` import below.
sys.path.insert(0, '.')

# Import dp_chunk_document here so it's available throughout notebook.
# NOTE(review): no later cell calls it; the chunking demo uses dp_chunk_tunable instead.
from src.training.evaluate import dp_chunk_document

# Set device: CUDA when available, otherwise CPU. Used by every .to(device) below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 2. Load Data

In [None]:
# Data paths (relative to the cloned repo root)
BOUNDARIES_PATH = Path('data/processed/all_training_data.jsonl')
SENTENCES_DIR = Path('data/processed/sentences')

# Verify data exists. Explicit raises instead of `assert`, which is stripped under
# `python -O` and would then fail later with a less helpful error.
if not BOUNDARIES_PATH.exists():
    raise FileNotFoundError(f"Boundaries file not found: {BOUNDARIES_PATH}")
if not SENTENCES_DIR.exists():
    raise FileNotFoundError(f"Sentences dir not found: {SENTENCES_DIR}")

# Count data: one JSONL record per non-blank line (a trailing blank line is not a boundary),
# and one document per *.json file in the sentences directory.
with open(BOUNDARIES_PATH, encoding='utf-8') as f:
    n_boundaries = sum(1 for line in f if line.strip())
n_docs = sum(1 for _ in SENTENCES_DIR.glob('*.json'))

print(f'Boundaries: {n_boundaries}')
print(f'Documents: {n_docs}')

## 3. Dataset & DataLoader

In [None]:
# Import from our training module
from src.training.dataset import BoundaryDataset, get_doc_splits

In [None]:
# Configuration - v1.5.0 MIDDLE GROUND
MODEL_NAME = 'xlm-roberta-base'
CONTEXT_SIZE = 5        # sentences of context per side, passed to BoundaryDataset
MAX_LENGTH = 512        # tokenizer max length (used by the demo cell)
                        # NOTE(review): not passed to BoundaryDataset — confirm its own limit matches
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
EPOCHS = 30  # Extended - will use early stopping
SEED = 42

# Model parameters - MIDDLE GROUND
NUM_ENCODER_LAYERS = 12  # encoder depth of xlm-roberta-base ("9/12 layers" in the header)
FREEZE_LAYERS = 9        # was 6 → Pearson ~0.30, 11 → ~0.16
DROPOUT = 0.4            # Moderate dropout (was 0.3, then 0.5)
NUM_CLASSES = 7          # Scores 0-6

if not 0 <= FREEZE_LAYERS <= NUM_ENCODER_LAYERS:
    raise ValueError(
        f'FREEZE_LAYERS must be in [0, {NUM_ENCODER_LAYERS}], got {FREEZE_LAYERS}'
    )
# Derived, so the printed summary can't drift from FREEZE_LAYERS.
TRAINABLE_LAYERS = NUM_ENCODER_LAYERS - FREEZE_LAYERS

# CORN Ordinal Regression
USE_ORDINAL = True  # Use CORN instead of classification

# Early stopping: stop after this many epochs without a new best val Pearson
EARLY_STOPPING_PATIENCE = 5

# Undersampling: cap each class at median_count / UNDERSAMPLE_RATIO samples
USE_UNDERSAMPLING = True
UNDERSAMPLE_RATIO = 0.4

# Set seeds for reproducibility (python, numpy, torch CPU and all CUDA devices)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("v1.5.0 - Middle Ground")
print("=" * 50)
print("Previous results:")
print("  - FREEZE_LAYERS=6, DROPOUT=0.3 → Pearson ~0.30")
print("  - FREEZE_LAYERS=11, DROPOUT=0.5 → Pearson ~0.16")
print("=" * 50)
print("This run:")
print(f"  - FREEZE_LAYERS: {FREEZE_LAYERS} ({TRAINABLE_LAYERS} trainable layers)")
print(f"  - DROPOUT: {DROPOUT}")
print(f"  - USE_ORDINAL: {USE_ORDINAL} (CORN)")
print(f"  - EPOCHS: {EPOCHS} (patience={EARLY_STOPPING_PATIENCE})")

In [None]:
# Tokenizer for the pretrained encoder named in the config cell
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Document-level train/val/test split, seeded with SEED
# (presumably so no document contributes boundaries to more than one split)
train_ids, val_ids, test_ids = get_doc_splits(SENTENCES_DIR, seed=SEED)
split_summary = ', '.join(
    f'{split_name}: {len(split_ids)} docs'
    for split_name, split_ids in (('Train', train_ids), ('Val', val_ids), ('Test', test_ids))
)
print(split_summary)

In [None]:
# Create one BoundaryDataset per document split; all three read the same files
# and differ only in the set of document ids they keep.
def _split_dataset(split_ids):
    """Build a BoundaryDataset restricted to the documents in `split_ids`."""
    return BoundaryDataset(
        BOUNDARIES_PATH, SENTENCES_DIR, tokenizer,
        context_size=CONTEXT_SIZE, doc_ids=set(split_ids)
    )


train_dataset, val_dataset, test_dataset = (
    _split_dataset(ids) for ids in (train_ids, val_ids, test_ids)
)

for split_label, split_ds in (('Train', train_dataset), ('Val', val_dataset), ('Test', test_dataset)):
    print(f'{split_label} samples: {len(split_ds)}')

# Show class distribution before balancing.
# NOTE: indexing the dataset materialises every item (presumably tokenising it), which is slow.
from collections import Counter
train_scores = [train_dataset[idx]['score'].item() for idx in range(len(train_dataset))]
score_counts = Counter(train_scores)
n_train = len(train_scores)
print(f'\nClass distribution (before balancing):')
for score, count in sorted(score_counts.items()):
    print(f'  Score {score}: {count:5d} ({count / n_train * 100:5.1f}%)')

In [None]:
# Undersample majority classes to balance the dataset
def undersample_dataset(dataset, target_ratio=0.5, scores=None):
    """
    Undersample over-represented score classes to balance a dataset.

    Every class is capped at ``int(median_count / target_ratio)`` samples, where
    ``median_count`` is the upper median of the per-class counts. Classes at or
    below the cap are kept in full. Sampling and the final shuffle use the global
    ``random`` module, so the result is reproducible after ``random.seed``.

    Args:
        dataset: Indexable dataset whose items are dicts holding a scalar tensor
            under ``'score'``.
        target_ratio: Cap relative to the median class count; the cap is
            ``median / target_ratio`` (0.5 → 2x median, 0.4 → 2.5x median).
            Must be > 0.
        scores: Optional precomputed per-item scores in dataset order. Passing
            them avoids indexing every item of ``dataset`` again, which can be
            expensive (e.g. re-tokenising).

    Returns:
        ``torch.utils.data.Subset`` of ``dataset`` over the kept, shuffled indices.

    Raises:
        ValueError: if ``target_ratio <= 0`` or ``scores`` has the wrong length.
    """
    if target_ratio <= 0:
        raise ValueError(f'target_ratio must be > 0, got {target_ratio}')

    if scores is None:
        scores = [dataset[i]['score'].item() for i in range(len(dataset))]
    elif len(scores) != len(dataset):
        raise ValueError(
            f'scores has {len(scores)} entries but dataset has {len(dataset)} items'
        )

    # Nothing to balance; the median below would index an empty list.
    if len(scores) == 0:
        print('Undersampling: dataset is empty, nothing to do')
        return torch.utils.data.Subset(dataset, [])

    # Group indices by score
    score_indices = {}
    for idx, score in enumerate(scores):
        score_indices.setdefault(score, []).append(idx)

    # Upper median of the per-class counts determines the cap
    counts = sorted(len(v) for v in score_indices.values())
    median_count = counts[len(counts) // 2]
    target_count = int(median_count / target_ratio)

    print(f'Undersampling: median count = {median_count}, target cap = {target_count}')

    # Undersample classes above the cap. Classes are visited in sorted order so the
    # sequence of random.sample calls (and thus the result) is seed-deterministic.
    balanced_indices = []
    for score in sorted(score_indices):
        indices = score_indices[score]
        if len(indices) > target_count:
            balanced_indices.extend(random.sample(indices, target_count))
            print(f'  Score {score}: {len(indices)} → {target_count} (undersampled)')
        else:
            balanced_indices.extend(indices)
            print(f'  Score {score}: {len(indices)} (kept all)')

    random.shuffle(balanced_indices)
    return torch.utils.data.Subset(dataset, balanced_indices)

# Choose the training set: the full split, or a class-balanced Subset of it.
if not USE_UNDERSAMPLING:
    train_dataset_balanced = train_dataset
    print(f'\nUsing full train dataset: {len(train_dataset)}')
else:
    print(f'\nApplying undersampling with ratio={UNDERSAMPLE_RATIO}...')
    train_dataset_balanced = undersample_dataset(train_dataset, target_ratio=UNDERSAMPLE_RATIO)
    original_size, balanced_size = len(train_dataset), len(train_dataset_balanced)
    print(f'\nBalanced train size: {balanced_size} (was {original_size})')

In [None]:
# Data loaders: the (possibly balanced) train set is reshuffled every epoch;
# val/test keep a fixed order.
train_loader = DataLoader(train_dataset_balanced, shuffle=True, batch_size=BATCH_SIZE)
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE) for split_ds in (val_dataset, test_dataset)
)

for split_label, split_loader in (('Train', train_loader), ('Val', val_loader), ('Test', test_loader)):
    print(f'{split_label} batches: {len(split_loader)}')

## 4. Model

In [None]:
# Import model from our training module
from src.training.model import BoundaryScorer

In [None]:
# Initialize model with CORN ordinal head
model = BoundaryScorer(
    MODEL_NAME,
    freeze_layers=FREEZE_LAYERS,
    dropout=DROPOUT,
    num_classes=NUM_CLASSES,
    ordinal=USE_ORDINAL  # True: CORN head with NUM_CLASSES - 1 outputs; False: NUM_CLASSES logits
)
model = model.to(device)

# Count parameters; "trainable" = requires_grad=True
# (presumably BoundaryScorer turns requires_grad off for the frozen layers)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
print(f'Frozen layers: {FREEZE_LAYERS}/12')  # "/12" assumes the 12-layer xlm-roberta-base encoder
print(f'Ordinal mode: {model.ordinal}')
print(f'Output dimension: {NUM_CLASSES - 1 if USE_ORDINAL else NUM_CLASSES}')

## 5. Training

In [None]:
# Import evaluation function
from src.training.evaluate import evaluate as evaluate_model_fn

def evaluate_model(model, loader):
    """Evaluate `model` on `loader` with src.training.evaluate.evaluate on the global `device`.

    Returns whatever metrics object that function produces.
    NOTE(review): not called by any cell in this notebook; the training and test cells
    compute their metrics inline.
    """
    return evaluate_model_fn(model, loader, device)

In [None]:
# Setup loss function and optimizer
from src.training.trainer import CORNLoss, FocalLoss, compute_class_weights

# Setup optimizer.
# NOTE(review): model.parameters() also yields the frozen (requires_grad=False) parameters,
# and weight_decay=0.01 is applied uniformly, including to biases/LayerNorm weights.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

# Linear decay over the full EPOCHS budget with 10% linear warmup. If early stopping
# triggers, training ends before the learning rate reaches 0.
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(total_steps * 0.1),
    num_training_steps=total_steps
)

# Use CORN loss for ordinal regression
if USE_ORDINAL:
    criterion = CORNLoss(num_classes=NUM_CLASSES)
    print(f'Using CORNLoss (ordinal regression)')
    print(f'  - Converts to {NUM_CLASSES - 1} binary tasks: "Is score > k?"')
    print(f'  - Respects ordinal structure: 2→3 error < 0→3 error')
else:
    # NOTE(review): weights come from BOUNDARIES_PATH (presumably all splits, before
    # undersampling) — confirm this neither leaks val/test statistics nor double-corrects
    # a train set that undersampling has already rebalanced.
    class_weights = compute_class_weights(BOUNDARIES_PATH, NUM_CLASSES)
    criterion = FocalLoss(alpha=class_weights.to(device), gamma=2.0)
    print(f'Using FocalLoss (classification)')

In [None]:
# Training loop with early stopping and model checkpointing
from src.training.trainer import corn_expected_value

# Create checkpoint directory (pure-Python `mkdir -p`, keeps the cell plain Python).
# NOTE: this directory is shared by every notebook version, so each run overwrites best_model.pt.
CHECKPOINT_DIR = '/content/drive/MyDrive/ChunkingNN/checkpoints'
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
BEST_CHECKPOINT_PATH = f'{CHECKPOINT_DIR}/best_model.pt'


def _expected_scores(logits):
    """Map model logits to expected scores on the 0..NUM_CLASSES-1 scale.

    Ordinal mode uses the CORN expected value; otherwise the softmax-weighted
    mean of the class indices.
    """
    if USE_ORDINAL:
        return corn_expected_value(logits)
    probs = torch.softmax(logits, dim=-1)
    classes = torch.arange(NUM_CLASSES, device=logits.device).float()
    return (probs * classes).sum(dim=-1)


def _predict(model, loader):
    """Run `model` over `loader` in eval mode; return (expected-score preds, float targets)."""
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in loader:
            logits = model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
            preds.extend(_expected_scores(logits).cpu().numpy())
            targets.extend(batch['score'].numpy())
    return np.array(preds), np.array(targets).astype(float)


history = {'train_loss': [], 'val_loss': [], 'val_pearson': []}
best_val_pearson = -float('inf')
best_val_loss = float('inf')
best_epoch = 0
epochs_without_improvement = 0
# True only once *this run* has written a checkpoint. Without it, a run in which no epoch
# improved would load a missing file, or a stale one left by an earlier run on Drive.
checkpoint_saved = False

print(f"Training with early stopping (patience={EARLY_STOPPING_PATIENCE})")
print(f"Best model will be saved to: {BEST_CHECKPOINT_PATH}")
print("=" * 60)

for epoch in range(EPOCHS):
    model.train()
    train_losses = []

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}')
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        scores = batch['score'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, scores)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()  # single GPU->CPU sync per step
        train_losses.append(loss_value)
        pbar.set_postfix({'loss': f'{loss_value:.4f}'})

    # Validation (expected-score predictions vs. teacher labels)
    val_preds, val_targets = _predict(model, val_loader)

    # pearsonr yields NaN for constant predictions; cast metrics to plain floats so the
    # checkpoint and history hold built-in types rather than numpy scalars.
    val_pearson = float(pearsonr(val_preds, val_targets)[0])
    val_mse = float(mean_squared_error(val_targets, val_preds))

    avg_train_loss = float(np.mean(train_losses))
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(val_mse)
    history['val_pearson'].append(val_pearson)

    # Check for improvement (a NaN val_pearson never compares greater, so it never counts)
    improved = val_pearson > best_val_pearson
    if improved:
        best_val_pearson = val_pearson
        best_val_loss = val_mse
        best_epoch = epoch + 1
        epochs_without_improvement = 0

        # SAVE BEST MODEL TO DISK
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_pearson': val_pearson,
            'val_mse': val_mse,
        }, BEST_CHECKPOINT_PATH)
        checkpoint_saved = True

        status = "✓ NEW BEST (saved)"
    else:
        epochs_without_improvement += 1
        status = f"no improvement ({epochs_without_improvement}/{EARLY_STOPPING_PATIENCE})"
        if np.isnan(val_pearson):
            status += " [val_pearson is NaN - constant predictions?]"

    print(f"\nEpoch {epoch+1}: train_loss={avg_train_loss:.4f}, "
          f"val_mse={val_mse:.4f}, val_pearson={val_pearson:.4f} - {status}")

    # Early stopping
    if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
        print(f"\n{'='*60}")
        print(f"EARLY STOPPING at epoch {epoch+1}")
        print(f"Best model was at epoch {best_epoch} with val_pearson={best_val_pearson:.4f}")
        print(f"{'='*60}")
        break

# Restore the best weights from this run's checkpoint. weights_only=False because the
# file also holds the optimizer state (PyTorch 2.6+ defaults weights_only to True).
if checkpoint_saved:
    print(f"\nLoading best model from epoch {best_epoch}...")
    checkpoint = torch.load(BEST_CHECKPOINT_PATH, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Restored best model (val_pearson={checkpoint['val_pearson']:.4f})")
else:
    print("\nWARNING: no epoch produced a valid val_pearson, so no checkpoint was saved "
          "in this run; keeping the final-epoch weights.")

## 6. Evaluation

In [None]:
# Plot training history.
# The two loss curves are different quantities: "train" is the optimisation criterion
# (CORN loss in ordinal mode, focal loss otherwise), while "val" is the MSE between the
# expected score and the teacher label. Compare their trends, not their absolute values.
epochs_axis = range(1, len(history['train_loss']) + 1)  # 1-indexed, matching the training log
train_loss_name = 'CORN loss' if USE_ORDINAL else 'focal loss'

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(epochs_axis, history['train_loss'], label=f'Train ({train_loss_name})')
axes[0].plot(epochs_axis, history['val_loss'], label='Val (MSE of expected score)')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].set_title('Loss')

axes[1].plot(epochs_axis, history['val_pearson'])
if best_epoch:
    # Mark the epoch whose checkpoint was restored
    axes[1].axvline(best_epoch, color='r', linestyle='--', label=f'Best epoch ({best_epoch})')
    axes[1].legend()
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Pearson Correlation')
axes[1].set_title('Validation Correlation')

plt.tight_layout()
plt.show()

In [None]:
# Final evaluation on the held-out test set with the restored best weights.
# Predictions are expected scores, so regression metrics apply directly.
model.eval()
pred_chunks, target_chunks = [], []

with torch.no_grad():
    for batch in test_loader:
        batch_logits = model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
        if not USE_ORDINAL:
            # Softmax-weighted mean of the class indices
            class_values = torch.arange(NUM_CLASSES, device=device).float()
            batch_pred = (torch.softmax(batch_logits, dim=-1) * class_values).sum(dim=-1)
        else:
            batch_pred = corn_expected_value(batch_logits)
        pred_chunks.append(batch_pred.cpu().numpy())
        target_chunks.append(batch['score'].numpy())

preds = np.concatenate(pred_chunks) if pred_chunks else np.array([])
targets = (np.concatenate(target_chunks) if target_chunks else np.array([])).astype(float)

test_metrics = {
    'pearson': pearsonr(preds, targets)[0],
    'spearman': spearmanr(preds, targets)[0],
    'mse': mean_squared_error(targets, preds),
    'mae': mean_absolute_error(targets, preds)
}

metric_labels = (
    ('pearson', 'Pearson correlation'),
    ('spearman', 'Spearman correlation'),
    ('mse', 'MSE'),
    ('mae', 'MAE'),
)
print('Test Set Results:')
for metric_key, metric_label in metric_labels:
    print(f"  {metric_label}: {test_metrics[metric_key]:.4f}")

In [None]:
# Histogram comparison: teacher labels vs. model expected scores on the test set
# (uses `preds`/`targets` from the evaluation cell above).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Half-unit bins with edges 0.0 .. 6.5, so integer labels 0-6 and continuous
# predictions share one axis (each label k falls in the bin [k, k+0.5)).
bins = np.arange(0, 7, 0.5)

axes[0].hist(targets, bins=bins, alpha=0.7, label='Teacher')
axes[0].hist(preds, bins=bins, alpha=0.7, label='Model')
axes[0].set_xlabel('Score')
axes[0].set_ylabel('Count')
axes[0].legend()
axes[0].set_title('Score Distribution')

# Scatter plot: integer targets stack in vertical columns; alpha shows density.
# The dashed y = x line marks perfect agreement.
axes[1].scatter(targets, preds, alpha=0.3)
axes[1].plot([0, 6], [0, 6], 'r--', label='Perfect')
axes[1].set_xlabel('Teacher Score')
axes[1].set_ylabel('Model Score')
axes[1].set_title('Prediction vs Target')
axes[1].legend()

plt.tight_layout()
plt.show()

## 7. DP Chunking Demo

In [None]:
# dp_chunk_document was imported in the setup cell (section 1), so it is already available here.

In [None]:
# Demo on a test document with TUNABLE chunking
from src.training.evaluate import dp_chunk_tunable, ChunkingParams, CHUNKING_PRESETS

demo_doc_id = test_ids[0] if test_ids else list(test_dataset.sentences.keys())[0]
demo_sents = test_dataset.sentences[demo_doc_id]

print(f'Document: {demo_doc_id}')
print(f'Sentences: {len(demo_sents)}')


def _boundary_text(sents, i):
    """Model input for the boundary between sentence i and i+1.

    Up to CONTEXT_SIZE sentences on each side, joined with the tokenizer's separator.
    NOTE(review): this rebuilds the input by hand. Confirm it matches BoundaryDataset's
    formatting, otherwise demo scores are computed on off-distribution inputs. Also note
    that truncation=True cuts from the end, i.e. the right-hand context goes first.
    """
    left = sents[max(0, i - CONTEXT_SIZE + 1):i + 1]
    right = sents[i + 1:min(len(sents), i + 1 + CONTEXT_SIZE)]
    return ' '.join(left) + f' {tokenizer.sep_token} ' + ' '.join(right)


# Score every boundary in batches of BATCH_SIZE instead of one forward pass each.
# demo_scores[i] is the score of the boundary after sentence i.
model.eval()
boundary_texts = [_boundary_text(demo_sents, i) for i in range(len(demo_sents) - 1)]
demo_scores = []

with torch.no_grad():
    for batch_start in range(0, len(boundary_texts), BATCH_SIZE):
        encoding = tokenizer(boundary_texts[batch_start:batch_start + BATCH_SIZE],
                             max_length=MAX_LENGTH, truncation=True,
                             padding='max_length', return_tensors='pt')
        logits = model(
            encoding['input_ids'].to(device),
            encoding['attention_mask'].to(device)
        )
        # Use expected value for smoother scores
        if USE_ORDINAL:
            pred = corn_expected_value(logits)
        else:
            probs = torch.softmax(logits, dim=-1)
            classes = torch.arange(NUM_CLASSES, device=device).float()
            pred = (probs * classes).sum(dim=-1)
        demo_scores.extend(pred.reshape(-1).cpu().tolist())

demo_scores = np.array(demo_scores, dtype=float)
if demo_scores.size:
    print(f'\nPredicted scores (first 20): {demo_scores[:20].round(1)}')
    print(f'Score range: {demo_scores.min():.1f} - {demo_scores.max():.1f}')
else:
    # A document with fewer than 2 sentences has no boundaries; .min() would raise.
    print('\nDocument has fewer than 2 sentences; no boundaries to score.')

# Show available presets
print(f'\nAvailable presets: {list(CHUNKING_PRESETS.keys())}')

In [None]:
# Compare different chunking presets on the demo document's predicted boundary scores.
# Each chunk is a half-open (start, end) sentence-index range, so its size is end - start.
# NOTE(review): min()/max() below raise if a preset returns no chunks (e.g. an empty document).
print("=" * 70)
print("CHUNKING PRESET COMPARISON")
print("=" * 70)

for preset_name, params in CHUNKING_PRESETS.items():
    chunks = dp_chunk_tunable(demo_sents, demo_scores.tolist(), params=params)
    chunk_sizes = [end - start for start, end in chunks]
    
    print(f'\n{preset_name.upper()} preset:')
    print(f'  target_chunk_size={params.target_chunk_size}, target_coherency={params.target_coherency}')
    print(f'  → {len(chunks)} chunks, sizes: {chunk_sizes[:10]}{"..." if len(chunk_sizes) > 10 else ""}')
    print(f'  → avg size: {np.mean(chunk_sizes):.1f}, min: {min(chunk_sizes)}, max: {max(chunk_sizes)}')

# Use "balanced" preset for detailed display
print("\n" + "=" * 70)
print("DETAILED VIEW (balanced preset)")
print("=" * 70)

chunks = dp_chunk_tunable(demo_sents, demo_scores.tolist(), preset="balanced")

# Show the first 5 chunks, with up to 5 sentences each (truncated to 80 chars).
for i, (start, end) in enumerate(chunks[:5]):
    print(f'\n{"="*60}')
    print(f'CHUNK {i+1} (sentences {start+1}-{end})')
    print('='*60)
    for j in range(start, min(end, start + 5)):
        sent = demo_sents[j][:80] + '...' if len(demo_sents[j]) > 80 else demo_sents[j]
        print(f'  [{j+1}] {sent}')
    if end - start > 5:
        print(f'  ... ({end - start - 5} more sentences)')
    # demo_scores[end-1] scores the boundary after the chunk's last sentence;
    # the final chunk ends at the document end and has no such boundary.
    if end - 1 < len(demo_scores):
        print(f'  -- SPLIT (score: {demo_scores[end-1]:.1f}) --')

if len(chunks) > 5:
    print(f'\n... and {len(chunks) - 5} more chunks')

## 8. Save Model

In [None]:
# Save final model to Google Drive
SAVE_PATH = '/content/drive/MyDrive/ChunkingNN/models/boundary_scorer_v1.5'
Path(SAVE_PATH).mkdir(parents=True, exist_ok=True)  # pure-Python `mkdir -p`

# Save model weights (state_dict only: rebuild BoundaryScorer from config.json to load)
torch.save(model.state_dict(), f'{SAVE_PATH}/model.pt')

# Save tokenizer
tokenizer.save_pretrained(SAVE_PATH)

# Save config: the constructor arguments needed to rebuild the model, plus the run's
# hyperparameters and results. Metrics are cast to float because they may be numpy scalars.
config = {
    'model_name': MODEL_NAME,
    'context_size': CONTEXT_SIZE,
    'max_length': MAX_LENGTH,
    'freeze_layers': FREEZE_LAYERS,
    'dropout': DROPOUT,
    'num_classes': NUM_CLASSES,
    'ordinal': USE_ORDINAL,
    'test_pearson': float(test_metrics['pearson']),
    'test_spearman': float(test_metrics['spearman']),
    'test_mse': float(test_metrics['mse']),
    'best_epoch': best_epoch,
    'test_mae': float(test_metrics['mae']),
    # -inf means no epoch improved; store null instead of the non-standard JSON token -Infinity
    'best_val_pearson': float(best_val_pearson) if np.isfinite(best_val_pearson) else None,
    'learning_rate': LEARNING_RATE,
    'batch_size': BATCH_SIZE,
    'epochs': EPOCHS,
    'seed': SEED,
    'use_undersampling': USE_UNDERSAMPLING,
    'undersample_ratio': UNDERSAMPLE_RATIO,
}
with open(f'{SAVE_PATH}/config.json', 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=2)

print(f'Model saved to {SAVE_PATH}')
print(f'Config: {config}')

In [None]:
# To load later. The constructor arguments must match training; in particular `ordinal`
# changes the head's output size (NUM_CLASSES - 1 vs NUM_CLASSES), so the saved weights
# won't load otherwise. Read them from the saved config.json:
# import json, torch
# from src.training.model import BoundaryScorer
# with open(f'{SAVE_PATH}/config.json') as f:
#     cfg = json.load(f)
# model = BoundaryScorer(cfg['model_name'], freeze_layers=cfg['freeze_layers'],
#                        dropout=cfg['dropout'], num_classes=cfg['num_classes'],
#                        ordinal=cfg['ordinal'])
# model.load_state_dict(torch.load(f'{SAVE_PATH}/model.pt', map_location='cpu'))
# model.eval()