# Boundary Scorer Training

v1.8.1 - Binary + Aggressive Regularization

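Train an XLM-RoBERTa model to predict whether a semantic boundary (chunk break) falls between two adjacent sentences. In v1.8 the original 0-6 boundary-strength scale is collapsed into a binary Break / No-Break label.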

**Data**: 18,044 labeled boundaries from Gemini + Claude ensemble
**Model**: XLM-R-base with binary classification head
**Changes in v1.8.1**:
- **Anti-overfitting**: Freeze 10/12 layers, dropout 0.5, LR 1e-5
- Binary classification: No-Break (original scores 0-2) vs. Break (scores 3-6)
- Inverse-frequency class weighting to counter the class imbalance

## 1. Setup

In [None]:
# Clone the repository
# IPython/Colab shell magics: `!` runs a shell command and `%cd` changes the
# kernel's working directory. Later cells rely on the relative paths
# data/processed/... and src/... resolving from inside the clone.
# NOTE(review): assumes a fresh session. Re-running this cell after a
# successful clone makes `git clone` fail because the directory already exists.
!git clone https://github.com/HBBobo/Intelligent-Chunking.git
%cd Intelligent-Chunking

In [None]:
# Install dependencies (quietly, via the `!` shell magic).
# NOTE(review): versions are unpinned, so results may drift between sessions.
# Consider pinning transformers/torch for reproducibility.
!pip install -q transformers torch scipy scikit-learn tqdm

In [None]:
# Mount Google Drive (for saving models).
# Checkpoints and the final model are written under /content/drive/MyDrive/...
# later in this notebook, so this must run before training.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import json
import random
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Add src to path for imports
sys.path.insert(0, '.')

# Import dp_chunk_document here so it's available throughout notebook
from src.training.evaluate import dp_chunk_document

# Choose the compute device once. Later cells move the model and every batch here.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## 2. Load Data

In [None]:
# Data paths, relative to the cloned repo root (the setup cell %cd's into it).
BOUNDARIES_PATH = Path('data/processed/all_training_data.jsonl')
SENTENCES_DIR = Path('data/processed/sentences')

# Verify data exists. Raise explicitly rather than using `assert`, which is
# stripped when Python runs with -O.
if not BOUNDARIES_PATH.exists():
    raise FileNotFoundError(f"Boundaries file not found: {BOUNDARIES_PATH}")
if not SENTENCES_DIR.exists():
    raise FileNotFoundError(f"Sentences dir not found: {SENTENCES_DIR}")

# Count data.
# - Boundaries: the line count of the JSONL file (one record per line).
#   An explicit UTF-8 encoding avoids locale-dependent decode errors on the
#   multilingual text.
# - Documents: the number of *.json files in the sentences directory.
with open(BOUNDARIES_PATH, encoding='utf-8') as f:
    n_boundaries = sum(1 for _ in f)
n_docs = sum(1 for _ in SENTENCES_DIR.glob('*.json'))

print(f'Boundaries: {n_boundaries}')
print(f'Documents: {n_docs}')

## 3. Dataset & DataLoader

In [None]:
# Import from our training module
from src.training.dataset import BoundaryDataset, get_doc_splits

In [None]:
# Configuration - v1.8.1 BINARY + AGGRESSIVE REGULARIZATION
MODEL_NAME = 'xlm-roberta-base'
CONTEXT_SIZE = 8    # Larger context (was 5)
MAX_LENGTH = 512
BATCH_SIZE = 16
LEARNING_RATE = 1e-5  # Reduced from 2e-5 to prevent overfitting
EPOCHS = 30
SEED = 42

# Model parameters - AGGRESSIVE REGULARIZATION
FREEZE_LAYERS = 10  # Freeze 10/12 layers (was 6) - only train last 2
DROPOUT = 0.5       # Increased from 0.3
NUM_CLASSES = 2     # Binary: No-Break (0) vs Break (1)

# Standard classification (True would select the CORN ordinal head instead)
USE_ORDINAL = False

# Class weighting to handle imbalance (~79% No-Break, ~21% Break)
USE_CLASS_WEIGHTS = True

# Early stopping: stop after this many epochs without a val-loss improvement
EARLY_STOPPING_PATIENCE = 5

# Undersampling - disabled since we're using class weights instead
USE_UNDERSAMPLING = False
# Only read when USE_UNDERSAMPLING is True. Each class is capped at
# median_count / UNDERSAMPLE_RATIO (0.5 -> 2x median). It is defined here so
# that flipping the flag above does not raise NameError in the undersampling cell.
UNDERSAMPLE_RATIO = 0.5

# Score mapping: 7-class to binary
# No-Break (0): scores 0, 1, 2 - weak/no boundary, don't split
# Break (1): scores 3, 4, 5, 6 - meaningful boundary, split here
SCORE_MAP = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1, 6: 1}

# Set seeds for Python, NumPy and PyTorch (CPU and all GPUs)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("v1.8.1 - Binary + Aggressive Regularization")
print("=" * 50)
print("Anti-overfitting changes:")
print("  - Freeze 10/12 layers (was 6) - only train last 2")
print("  - Dropout 0.5 (was 0.3)")
print("  - Learning rate 1e-5 (was 2e-5)")
print("=" * 50)

In [None]:
# Load the tokenizer for MODEL_NAME so token ids match the encoder loaded later.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Split documents into train/val/test by document id, seeded for reproducibility.
# Splitting at the document level keeps all boundaries of a document in one
# split. This presumes get_doc_splits returns disjoint id lists; verify in
# src/training/dataset.py.
train_ids, val_ids, test_ids = get_doc_splits(SENTENCES_DIR, seed=SEED)
print(f'Train: {len(train_ids)} docs, Val: {len(val_ids)} docs, Test: {len(test_ids)} docs')

In [None]:
# Create datasets with score remapping
class RemappedDataset(torch.utils.data.Dataset):
    """View a base boundary dataset with every label passed through ``score_map``.

    Each returned item keeps only ``input_ids`` and ``attention_mask`` from the
    base item. Its ``score`` is replaced by ``score_map[original_score]`` as a
    ``torch.long`` scalar tensor. In this notebook the map collapses the 0-6
    scale to binary. Length and indexing are delegated to the base dataset.
    """

    def __init__(self, base_dataset, score_map):
        self.base = base_dataset
        self.score_map = score_map

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        sample = self.base[idx]
        # Look the label up first, so a missing 'score' key fails before the others.
        label = self.score_map[sample['score'].item()]
        remapped = {key: sample[key] for key in ('input_ids', 'attention_mask')}
        remapped['score'] = torch.tensor(label, dtype=torch.long)
        return remapped

# Create base datasets, one per split. doc_ids restricts each BoundaryDataset
# to the boundaries of that split's documents.
# NOTE(review): MAX_LENGTH is not passed here, so BoundaryDataset truncates with
# its own default length. Confirm it matches MAX_LENGTH, which the demo cell uses.
train_dataset_base = BoundaryDataset(
    BOUNDARIES_PATH, SENTENCES_DIR, tokenizer,
    context_size=CONTEXT_SIZE, doc_ids=set(train_ids)
)
val_dataset_base = BoundaryDataset(
    BOUNDARIES_PATH, SENTENCES_DIR, tokenizer,
    context_size=CONTEXT_SIZE, doc_ids=set(val_ids)
)
test_dataset_base = BoundaryDataset(
    BOUNDARIES_PATH, SENTENCES_DIR, tokenizer,
    context_size=CONTEXT_SIZE, doc_ids=set(test_ids)
)

# Wrap with score remapping (0-6 score -> binary label via SCORE_MAP)
train_dataset = RemappedDataset(train_dataset_base, SCORE_MAP)
val_dataset = RemappedDataset(val_dataset_base, SCORE_MAP)
test_dataset = RemappedDataset(test_dataset_base, SCORE_MAP)

# Keep reference to sentences for DP chunking demo (the wrapper does not
# forward attribute access to the base dataset, so copy it explicitly)
test_dataset.sentences = test_dataset_base.sentences

print(f'Train samples: {len(train_dataset)}')
print(f'Val samples: {len(val_dataset)}')
print(f'Test samples: {len(test_dataset)}')

# Show binary distribution.
# NOTE: this reads the FIRST (up to) 1000 training items in index order, which
# is not a random sample. Each read goes through the base dataset's __getitem__.
from collections import Counter
train_scores = [train_dataset[i]['score'].item() for i in range(min(1000, len(train_dataset)))]
score_counts = Counter(train_scores)
class_names = {0: 'No-Break (0-2)', 1: 'Break (3-6)'}
print(f'\nBinary class distribution (sample):')
for cls in sorted(score_counts.keys()):
    pct = score_counts[cls] / len(train_scores) * 100
    print(f'  {class_names[cls]}: {score_counts[cls]:5d} ({pct:5.1f}%)')

In [None]:
# Undersample majority classes to balance the dataset
def undersample_dataset(dataset, target_ratio=0.5):
    """
    Undersample over-represented classes to reduce class imbalance.

    Each class with more than ``int(median_count / target_ratio)`` samples is
    randomly reduced to that cap (via ``random.sample``). Smaller classes are
    kept whole. ``median_count`` is the *upper* median,
    ``sorted(counts)[len(counts) // 2]``. With exactly two classes that is the
    majority count, so for ``target_ratio <= 1`` nothing is undersampled.

    Args:
        dataset: Indexable, sized dataset whose items are dicts holding a
            scalar tensor under ``'score'``.
        target_ratio: Cap relative to the median class count
            (0.5 means cap at 2x median count). Must be > 0.

    Returns:
        ``torch.utils.data.Subset`` of ``dataset`` over the kept indices, in
        shuffled order. An empty dataset yields an empty Subset.

    Raises:
        ValueError: if ``target_ratio`` is not positive.
    """
    from collections import defaultdict

    if target_ratio <= 0:
        raise ValueError(f'target_ratio must be > 0, got {target_ratio}')

    # Group indices by score. Every item is fetched once, so this is O(len(dataset)).
    score_indices = defaultdict(list)
    for idx in range(len(dataset)):
        score_indices[dataset[idx]['score'].item()].append(idx)

    if not score_indices:
        # Nothing to balance; the original crashed here on sorted([])[0].
        print('Undersampling: dataset is empty, nothing to do')
        return torch.utils.data.Subset(dataset, [])

    # Find the (upper) median class count and derive the per-class cap
    counts = [len(v) for v in score_indices.values()]
    median_count = sorted(counts)[len(counts) // 2]
    target_count = int(median_count / target_ratio)

    print(f'Undersampling: median count = {median_count}, target cap = {target_count}')

    # Undersample classes above the cap. Iterate in sorted score order so the
    # sequence of random draws (and hence the result) is reproducible under a seed.
    balanced_indices = []
    for score in sorted(score_indices):
        indices = score_indices[score]
        if len(indices) > target_count:
            balanced_indices.extend(random.sample(indices, target_count))
            print(f'  Score {score}: {len(indices)} → {target_count} (undersampled)')
        else:
            balanced_indices.extend(indices)
            print(f'  Score {score}: {len(indices)} (kept all)')

    random.shuffle(balanced_indices)
    return torch.utils.data.Subset(dataset, balanced_indices)

# Apply undersampling if enabled. Either way, train_dataset_balanced is the
# dataset the training DataLoader draws from.
# NOTE: when USE_UNDERSAMPLING is True this reads UNDERSAMPLE_RATIO, which must
# be defined (e.g. in the config cell); otherwise this raises NameError.
if USE_UNDERSAMPLING:
    print(f'\nApplying undersampling with ratio={UNDERSAMPLE_RATIO}...')
    train_dataset_balanced = undersample_dataset(train_dataset, target_ratio=UNDERSAMPLE_RATIO)
    print(f'\nBalanced train size: {len(train_dataset_balanced)} (was {len(train_dataset)})')
else:
    train_dataset_balanced = train_dataset
    print(f'\nUsing full train dataset: {len(train_dataset)}')

In [None]:
# Create data loaders (use balanced train dataset; it is train_dataset itself
# when undersampling is off). Only the training loader shuffles. Val/test keep
# a fixed order, so per-epoch metrics are computed over the same sequence.
train_loader = DataLoader(train_dataset_balanced, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')
print(f'Test batches: {len(test_loader)}')

## 4. Model

In [None]:
# Import model from our training module
from src.training.model import BoundaryScorer

In [None]:
# Initialize the boundary model. With USE_ORDINAL=False this is a plain
# NUM_CLASSES-way (binary) classification head, not the CORN ordinal head.
model = BoundaryScorer(
    MODEL_NAME,
    freeze_layers=FREEZE_LAYERS,
    dropout=DROPOUT,
    num_classes=NUM_CLASSES,
    ordinal=USE_ORDINAL  # True would switch to CORN ordinal mode (NUM_CLASSES - 1 outputs)
)
model = model.to(device)

# Count parameters (trainable = requires_grad, i.e. excluding frozen layers)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
print(f'Frozen layers: {FREEZE_LAYERS}/12')  # 12 = hard-coded encoder depth of xlm-roberta-base
print(f'Ordinal mode: {model.ordinal}')
print(f'Output dimension: {NUM_CLASSES - 1 if USE_ORDINAL else NUM_CLASSES}')

## 5. Training

In [None]:
# Import evaluation function
from src.training.evaluate import evaluate as evaluate_model_fn

def evaluate_model(model, loader):
    """Evaluate ``model`` on ``loader`` via the project's ``evaluate`` helper.

    Thin wrapper that supplies the notebook-global ``device``. It returns
    whatever ``src.training.evaluate.evaluate`` returns (a metrics object;
    its exact structure is defined in that module).
    """
    return evaluate_model_fn(model, loader, device)

In [None]:
# Setup loss function and optimizer
from collections import Counter
from src.training.trainer import CORNLoss, FocalLoss, compute_class_weights

# Compute class weights from the data the model actually trains on.
# train_dataset_balanced is train_dataset itself unless undersampling is on.
# In that case, weighting by the full set's frequencies would correct the
# imbalance twice.
if USE_CLASS_WEIGHTS:
    print("Computing class weights from training data...")
    # Each index access runs the dataset's __getitem__, so this is a full pass.
    class_counts = Counter(
        train_dataset_balanced[i]['score'].item()
        for i in range(len(train_dataset_balanced))
    )

    total = sum(class_counts.values())
    # Inverse frequency weighting: weight = total / (num_classes * count).
    # A class absent from the training data gets weight 1.0 instead of raising
    # ZeroDivisionError. It contributes no training loss either way.
    class_weights = torch.tensor([
        total / (NUM_CLASSES * class_counts[c]) if class_counts[c] else 1.0
        for c in range(NUM_CLASSES)
    ], dtype=torch.float32)

    missing = [c for c in range(NUM_CLASSES) if class_counts[c] == 0]
    if missing:
        print(f"WARNING: classes {missing} have no training samples; using weight 1.0 for them")

    print("Class distribution:")
    class_labels = ['No-Break', 'Break']
    for c in range(NUM_CLASSES):
        pct = class_counts[c] / total * 100 if total else 0.0
        print(f"  {class_labels[c]}: {class_counts[c]:5d} ({pct:5.1f}%) → weight {class_weights[c]:.2f}")

    class_weights = class_weights.to(device)
else:
    class_weights = None

# Setup optimizer. Frozen parameters never receive gradients, so AdamW leaves them untouched.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

# Linear warmup over the first 10% of steps, then linear decay (stepped per batch)
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(total_steps * 0.1),
    num_training_steps=total_steps
)

# Use weighted CrossEntropyLoss for binary classification
criterion = nn.CrossEntropyLoss(weight=class_weights)
print('\nUsing weighted CrossEntropyLoss (binary classification)')
print('  Classes: No-Break (0-2), Break (3-6)')
if class_weights is not None:
    print(f'  Weights: {class_weights.cpu().tolist()}')

In [None]:
# Training loop with early stopping and model checkpointing
from src.training.trainer import corn_expected_value  # NOTE(review): unused in binary mode

# Create checkpoint directory (pure-Python `mkdir -p`, no shell magic needed)
CHECKPOINT_DIR = '/content/drive/MyDrive/ChunkingNN/checkpoints'
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
BEST_CHECKPOINT_PATH = f'{CHECKPOINT_DIR}/best_model.pt'

history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
best_val_loss = float('inf')
best_val_accuracy = 0
best_epoch = 0  # stays 0 if no epoch ever improves (e.g. NaN validation loss)
epochs_without_improvement = 0

print(f"Training with early stopping (patience={EARLY_STOPPING_PATIENCE})")
print("Using LOSS-BASED early stopping (more sensitive than accuracy)")
print(f"Best model will be saved to: {BEST_CHECKPOINT_PATH}")
print("=" * 60)

for epoch in range(EPOCHS):
    # ---- Train one epoch ----
    model.train()
    train_losses = []

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}')
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        scores = batch['score'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, scores)
        loss.backward()

        # Clip gradients for stability; the LR schedule advances once per batch.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        train_losses.append(loss.item())
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # ---- Validation ----
    model.eval()
    val_preds, val_targets = [], []
    val_losses = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            scores = batch['score'].to(device)

            logits = model(input_ids, attention_mask)
            loss = criterion(logits, scores)
            val_losses.append(loss.item())

            # Get predictions (argmax for classification)
            pred = logits.argmax(dim=-1)
            val_preds.extend(pred.cpu().numpy())
            val_targets.extend(scores.cpu().numpy())

    val_preds = np.array(val_preds)
    val_targets = np.array(val_targets)

    # Store plain Python floats so history and the checkpoint hold no numpy scalars.
    val_accuracy = float((val_preds == val_targets).mean())
    val_loss = float(np.mean(val_losses))  # mean of per-batch (class-weighted) losses

    avg_train_loss = float(np.mean(train_losses))
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(val_loss)
    history['val_accuracy'].append(val_accuracy)

    # Check for improvement using LOSS (not accuracy). NaN never compares less,
    # so a NaN epoch counts as "no improvement".
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        best_val_accuracy = val_accuracy
        best_epoch = epoch + 1
        epochs_without_improvement = 0

        # SAVE BEST MODEL TO DISK
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_accuracy': val_accuracy,
            'val_loss': val_loss,
        }, BEST_CHECKPOINT_PATH)

        status = "✓ NEW BEST (saved)"
    else:
        epochs_without_improvement += 1
        status = f"no improvement ({epochs_without_improvement}/{EARLY_STOPPING_PATIENCE})"

    # Show binary prediction counts
    pred_counts = [np.sum(val_preds == c) for c in range(NUM_CLASSES)]
    pred_dist = f"[NoBrk:{pred_counts[0]} Brk:{pred_counts[1]}]"

    print(f"\nEpoch {epoch+1}: train_loss={avg_train_loss:.4f}, "
          f"val_loss={val_loss:.4f}, val_acc={val_accuracy:.4f} {pred_dist} - {status}")

    # Early stopping
    if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
        print(f"\n{'='*60}")
        print(f"EARLY STOPPING at epoch {epoch+1}")
        print(f"Best model was at epoch {best_epoch} with val_loss={best_val_loss:.4f}")
        print(f"{'='*60}")
        break

# Restore the best weights, but only if THIS run saved a checkpoint. The
# directory lives on Drive and persists across runs, so an unconditional load
# could silently restore a stale model from an earlier experiment.
if best_epoch == 0:
    print("\nWARNING: val_loss never improved (NaN losses?); no checkpoint was saved "
          "this run - keeping the final-epoch weights.")
else:
    print(f"\nLoading best model from epoch {best_epoch}...")
    # map_location: materialize tensors on the current device regardless of where they were saved.
    checkpoint = torch.load(BEST_CHECKPOINT_PATH, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Restored best model (val_loss={checkpoint['val_loss']:.4f}, val_acc={checkpoint['val_accuracy']:.4f})")

## 6. Evaluation

In [None]:
# Plot training history. The x axis uses the same 1-based epoch numbers as the training log.
epochs_run = range(1, len(history['train_loss']) + 1)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(epochs_run, history['train_loss'], label='Train')
axes[0].plot(epochs_run, history['val_loss'], label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].set_title('Loss')

axes[1].plot(epochs_run, history['val_accuracy'])
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Validation Accuracy (Binary)')
axes[1].axhline(y=0.5, color='r', linestyle='--', label='Random (50%)')
# On imbalanced data, 50% is not the bar to beat. Always predicting the
# majority class already scores that class's frequency. val_targets holds the
# targets of the last validation pass; they are the same every epoch.
if len(val_targets):
    majority_baseline = max(float(np.mean(val_targets == c)) for c in range(NUM_CLASSES))
    axes[1].axhline(y=majority_baseline, color='gray', linestyle=':',
                    label=f'Majority class ({majority_baseline:.0%})')
axes[1].legend()

plt.tight_layout()
plt.show()

In [None]:
# Final evaluation on test set
model.eval()
preds, targets = [], []
probs_all = []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        logits = model(input_ids, attention_mask)

        pred = logits.argmax(dim=-1)
        probs = torch.softmax(logits, dim=-1)

        preds.extend(pred.cpu().numpy())
        targets.extend(batch['score'].numpy())
        probs_all.extend(probs.cpu().numpy())

preds = np.array(preds)
targets = np.array(targets)
probs_all = np.array(probs_all)

# Metrics
accuracy = (preds == targets).mean()

# Per-class accuracy (fraction of each true class predicted correctly, i.e. per-class recall)
class_names = ['No-Break (0-2)', 'Break (3-6)']
print('Test Set Results (Binary Classification):')
print(f"  Overall Accuracy: {accuracy:.4f}")
print("\nPer-class accuracy:")
for cls in range(NUM_CLASSES):
    mask = targets == cls
    if mask.sum() > 0:
        cls_acc = (preds[mask] == targets[mask]).mean()
        print(f"  {class_names[cls]}: {cls_acc:.4f} ({mask.sum()} samples)")

# Precision, Recall, F1 for the Break (positive) class.
# zero_division=0 reports 0 instead of warning when Break is never predicted or present.
from sklearn.metrics import precision_recall_fscore_support
precision, recall, f1, _ = precision_recall_fscore_support(
    targets, preds, average='binary', pos_label=1, zero_division=0
)
print("\nBreak class metrics:")
print(f"  Precision: {precision:.4f}")
print(f"  Recall: {recall:.4f}")
print(f"  F1 Score: {f1:.4f}")

# Confusion matrix. labels= pins a NUM_CLASSES x NUM_CLASSES matrix even when
# a class is absent from both targets and predictions. Without it cm can be
# 1x1, and indexing cm[i, 1] (here and in the heatmap cell) raises IndexError.
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(targets, preds, labels=list(range(NUM_CLASSES)))
print("\nConfusion Matrix:")
print("                   Pred No-Break  Pred Break")
for i, name in enumerate(class_names):
    print(f"  True {name:15s}: {cm[i, 0]:7d}      {cm[i, 1]:7d}")

In [None]:
# Visualization for binary classification: teacher vs model class counts on the
# test set (left) and the confusion matrix from the previous cell (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Class distribution comparison (grouped bars: teacher labels vs model predictions)
class_labels = ['No-Break', 'Break']
x = np.arange(len(class_labels))
width = 0.35

target_counts = [np.sum(targets == i) for i in range(2)]
pred_counts = [np.sum(preds == i) for i in range(2)]

axes[0].bar(x - width/2, target_counts, width, label='Teacher', alpha=0.8)
axes[0].bar(x + width/2, pred_counts, width, label='Model', alpha=0.8)
axes[0].set_xlabel('Class')
axes[0].set_ylabel('Count')
axes[0].set_xticks(x)
axes[0].set_xticklabels(class_labels)
axes[0].legend()
axes[0].set_title('Class Distribution')

# Confusion matrix heatmap (rows = true class, columns = predicted class).
# NOTE: the 2x2 indexing below assumes cm has shape (2, 2).
import matplotlib.pyplot as plt  # redundant: plt is already imported in the setup cell
im = axes[1].imshow(cm, cmap='Blues')
axes[1].set_xticks(range(2))
axes[1].set_yticks(range(2))
axes[1].set_xticklabels(class_labels)
axes[1].set_yticklabels(class_labels)
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')
axes[1].set_title('Confusion Matrix')

# Add text annotations; switch to white text on the darker (above half-max) cells
for i in range(2):
    for j in range(2):
        axes[1].text(j, i, str(cm[i, j]), ha='center', va='center', 
                    color='white' if cm[i, j] > cm.max()/2 else 'black')

plt.colorbar(im, ax=axes[1])
plt.tight_layout()
plt.show()

## 7. DP Chunking Demo

In [None]:
# dp_chunk_document is already imported in the imports cell of Section 1 (Setup)

In [None]:
# Demo on a test document - convert binary predictions to scores for DP chunking
from src.training.evaluate import dp_chunk_tunable, ChunkingParams, CHUNKING_PRESETS

# Pick the first test document (falling back to any document whose sentences are loaded).
# NOTE(review): assumes test_ids[0] is a key of test_dataset.sentences - confirm.
demo_doc_id = test_ids[0] if test_ids else list(test_dataset.sentences.keys())[0]
demo_sents = test_dataset.sentences[demo_doc_id]

print(f'Document: {demo_doc_id}')
print(f'Sentences: {len(demo_sents)}')

# Get predictions for this document (one forward pass per sentence gap)
model.eval()
demo_scores = []

# Map binary predictions to approximate scores for DP chunking
# No-Break (0) → 1 (low score, don't split)
# Break (1) → 5 (high score, split here)
CLASS_TO_SCORE = {0: 1.0, 1: 5.0}

# Gap i lies between sentence i and sentence i+1, so n sentences give n-1 scores.
for i in range(len(demo_sents) - 1):
    # Up to CONTEXT_SIZE sentences ending at i, and up to CONTEXT_SIZE starting at i+1
    left = demo_sents[max(0, i - CONTEXT_SIZE + 1):i + 1]
    right = demo_sents[i + 1:min(len(demo_sents), i + 1 + CONTEXT_SIZE)]
    # NOTE(review): this is meant to mirror BoundaryDataset's training input
    # (left context, separator token, right context). Confirm against
    # src/training/dataset.py; a mismatch silently degrades these predictions.
    text = ' '.join(left) + f' {tokenizer.sep_token} ' + ' '.join(right)

    encoding = tokenizer(text, max_length=MAX_LENGTH, truncation=True,
                         padding='max_length', return_tensors='pt')

    with torch.no_grad():
        logits = model(
            encoding['input_ids'].to(device),
            encoding['attention_mask'].to(device)
        )
        # Get class prediction (argmax discards confidence; softmax P(Break)
        # could give DP a graded score instead of two fixed values)
        pred_class = logits.argmax(dim=-1).item()
        # Convert to score for DP chunking
        demo_scores.append(CLASS_TO_SCORE[pred_class])

demo_scores = np.array(demo_scores)
print(f'\nBinary predictions mapped to scores (first 20): {demo_scores[:20].astype(int)}')
print(f'Unique values: {sorted(set(demo_scores))}')

# Show class distribution (counts of the two mapped score values)
no_break_count = np.sum(demo_scores == 1.0)
break_count = np.sum(demo_scores == 5.0)
print(f'Class distribution: No-Break={no_break_count}, Break={break_count}')

In [None]:
# Compare different chunking presets on the demo document's gap scores.
# Chunks are (start, end) sentence-index pairs, used here as 0-based half-open
# [start, end) ranges: size = end - start, and sentences start..end-1 are printed.
print("=" * 70)
print("CHUNKING PRESET COMPARISON")
print("=" * 70)

for preset_name, params in CHUNKING_PRESETS.items():
    chunks = dp_chunk_tunable(demo_sents, demo_scores.tolist(), params=params)
    chunk_sizes = [end - start for start, end in chunks]
    
    print(f'\n{preset_name.upper()} preset:')
    print(f'  target_chunk_size={params.target_chunk_size}, target_coherency={params.target_coherency}')
    print(f'  → {len(chunks)} chunks, sizes: {chunk_sizes[:10]}{"..." if len(chunk_sizes) > 10 else ""}')
    # NOTE: min()/max() raise ValueError if a preset returns no chunks (e.g. empty document)
    print(f'  → avg size: {np.mean(chunk_sizes):.1f}, min: {min(chunk_sizes)}, max: {max(chunk_sizes)}')

# Use "balanced" preset for detailed display
print("\n" + "=" * 70)
print("DETAILED VIEW (balanced preset)")
print("=" * 70)

chunks = dp_chunk_tunable(demo_sents, demo_scores.tolist(), preset="balanced")

# Show the first 5 chunks, with at most 5 (80-char-truncated) sentences each
for i, (start, end) in enumerate(chunks[:5]):
    print(f'\n{"="*60}')
    print(f'CHUNK {i+1} (sentences {start+1}-{end})')
    print('='*60)
    for j in range(start, min(end, start + 5)):
        sent = demo_sents[j][:80] + '...' if len(demo_sents[j]) > 80 else demo_sents[j]
        print(f'  [{j+1}] {sent}')
    if end - start > 5:
        print(f'  ... ({end - start - 5} more sentences)')
    # Gap end-1 is the one after this chunk's last sentence. The final chunk
    # (end == len(demo_sents)) has no following gap, so nothing is printed for it.
    if end - 1 < len(demo_scores):
        print(f'  -- SPLIT (score: {demo_scores[end-1]:.1f}) --')

if len(chunks) > 5:
    print(f'\n... and {len(chunks) - 5} more chunks')

## 8. Save Model

In [None]:
# Save final model to Google Drive (directory created via the `!` shell magic)
SAVE_PATH = '/content/drive/MyDrive/ChunkingNN/models/boundary_scorer_v1.8.1'
!mkdir -p "{SAVE_PATH}"

# Save model weights (state_dict only; rebuilding the model needs the config below)
torch.save(model.state_dict(), f'{SAVE_PATH}/model.pt')

# Save tokenizer
tokenizer.save_pretrained(SAVE_PATH)

# Save the config needed to rebuild the model, plus the test/validation metrics.
# NOTE: json.dump turns the int keys of score_map/class_to_score into strings,
# so convert them back with int() when reloading.
config = {
    'model_name': MODEL_NAME,
    'context_size': CONTEXT_SIZE,
    'max_length': MAX_LENGTH,
    'freeze_layers': FREEZE_LAYERS,
    'dropout': DROPOUT,
    'num_classes': NUM_CLASSES,
    'ordinal': USE_ORDINAL,
    'use_class_weights': USE_CLASS_WEIGHTS,
    'score_map': SCORE_MAP,
    'class_to_score': CLASS_TO_SCORE,
    'test_accuracy': float(accuracy),
    'test_precision': float(precision),
    'test_recall': float(recall),
    'test_f1': float(f1),
    'best_epoch': best_epoch,
    'best_val_loss': float(best_val_loss)
}
with open(f'{SAVE_PATH}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'Model saved to {SAVE_PATH}')
print(f'Config: {config}')

In [None]:
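# To load later (constructor arguments must match training, or the head shapes won't match):
# from src.training.model import BoundaryScorer
# model = BoundaryScorer(MODEL_NAME, freeze_layers=FREEZE_LAYERS, dropout=DROPOUT,
#                        num_classes=NUM_CLASSES, ordinal=USE_ORDINAL)
# model.load_state_dict(torch.load(f'{SAVE_PATH}/model.pt', map_location=device))
# model.eval()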