
Add Refinement Loss Decomposition & TS-Refinement Metrics #12

@jxudata

Description


Background

Based on the paper "Rethinking Early Stopping: Refine, Then Calibrate", this issue proposes adding metrics and tools to support the "Refine, Then Calibrate" training paradigm.

Key Insight from the Paper

Standard early stopping based on validation loss is suboptimal because it forces a compromise between two conflicting objectives:

  • Refinement Error: Model's discriminative ability (separating classes)
  • Calibration Error: Alignment of confidence scores with true probabilities

These errors are minimized at different points during training. The paper proposes:

  1. Train longer to minimize refinement error (ignore rising validation loss)
  2. Apply post-hoc calibration (e.g., Temperature Scaling) to fix calibration

Relevance to Splinator

Splinator is already a post-hoc calibration library — it's the "Then Calibrate" part! This feature adds:

  1. Metrics to decompose loss into refinement vs calibration components
  2. TS-Refinement metric for better early stopping decisions
  3. Temperature Scaling as a simple calibrator option

Variational Decomposition Method

The paper uses a variational approach rather than binning-based ECE. This is more stable and provides a direct estimate of how much "extra" loss is caused purely by poor probability scaling.

Formula

| Term | Formula | Meaning |
|------|---------|---------|
| Total Loss | L(y, p) | Standard validation cross-entropy |
| Refinement Error | L(y, TS(p)) | Loss after optimal temperature scaling |
| Calibration Error | L(y, p) - L(y, TS(p)) | "Potential risk reduction" from recalibration |

Where:

  • TS(p) = σ(logit(p) / T*) — Temperature-scaled probability
  • T* — Optimal temperature that minimizes NLL on calibration set

Step-by-Step Calculation

  1. Get predictions: Obtain model predictions p on validation set
  2. Fit Temperature: Find scalar T* that minimizes L(y, σ(logit(p) / T))
  3. Compute Refinement: The minimal loss L(y, TS(p)) is the Refinement Error
  4. Compute Calibration: L(y, p) - L(y, TS(p)) is the Calibration Error (see the sketch below)
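
A minimal sketch of these four steps, assuming binary labels and NumPy/SciPy/scikit-learn (the decompose helper here is illustrative, not part of the proposed API):

import numpy as np
from scipy.optimize import minimize_scalar
from sklearn.metrics import log_loss

def decompose(y_true, p, eps=1e-12):
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)
    logits = np.log(p / (1 - p))  # logit(p)
    # NLL of the temperature-scaled probabilities sigma(logit(p) / T)
    nll = lambda t: log_loss(y_true, 1.0 / (1.0 + np.exp(-logits / t)))
    t_star = minimize_scalar(nll, bounds=(1e-3, 1e3), method='bounded').x
    total, refinement = log_loss(y_true, p), nll(t_star)
    return total, refinement, total - refinement  # calibration = total - refinement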

Why Variational > Binning?

| Approach | Issues |
|----------|--------|
| Binning-based ECE | Biased, bin-count dependent, inconsistent in multi-class |
| Variational (TS-based) | Stable, directly measures "fixable" miscalibration, consistent |
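
To make the bin-count dependence concrete, here is a toy check (a simple equal-width binned ECE, illustrative only). The labels are perfectly calibrated by construction, so the true calibration error is zero, yet the binned estimate is positive and grows with the number of bins:

import numpy as np

def binned_ece(y, p, n_bins):
    # Equal-width confidence bins; per-bin gap = |accuracy - mean confidence|
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    idx = np.clip(np.digitize(p, edges) - 1, 0, n_bins - 1)
    return sum(
        (idx == b).mean() * abs(y[idx == b].mean() - p[idx == b].mean())
        for b in range(n_bins) if (idx == b).any()
    )

rng = np.random.default_rng(0)
p = rng.uniform(0.01, 0.99, size=5000)
y = (rng.uniform(size=5000) < p).astype(float)  # perfectly calibrated labels
print([round(binned_ece(y, p, b), 4) for b in (5, 15, 50)])  # estimate grows with bins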

Proposed Implementation

1. New Metrics in metrics.py

def find_optimal_temperature(y_true, y_pred, init_temp=1.0, max_iter=50):
    """
    Find the optimal temperature for temperature scaling.
    
    Solves: T* = argmin_T L(y, σ(logit(p) / T))
    
    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.
    init_temp : float, default=1.0
        Initial temperature value for optimization.
    max_iter : int, default=50
        Maximum iterations for L-BFGS-B optimization.
        
    Returns
    -------
    temperature : float
        Optimal temperature that minimizes NLL.
    """


def apply_temperature_scaling(y_pred, temperature):
    """
    Apply temperature scaling to predicted probabilities.
    
    calibrated = σ(logit(p) / T)
    
    Parameters
    ----------
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.
    temperature : float
        Temperature parameter (T > 1 softens, T < 1 sharpens).
        
    Returns
    -------
    calibrated : array-like of shape (n_samples,)
        Temperature-scaled probabilities.
    """


def ts_refinement_loss(y_true, y_pred):
    """
    Refinement Error: Cross-entropy AFTER optimal temperature scaling.
    
    This is the irreducible loss given perfect calibration — it measures
    the model's fundamental discriminative ability. Use this as the
    early stopping criterion instead of raw validation loss.
    
    Formula: L(y, TS(p)) where TS applies optimal temperature scaling
    
    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.
        
    Returns
    -------
    refinement_loss : float
        Cross-entropy after optimal temperature scaling.
        
    Examples
    --------
    >>> from splinator.metrics import ts_refinement_loss
    >>> # Use as early stopping criterion
    >>> best_ts_loss = float('inf')
    >>> for epoch in range(max_epochs):
    ...     model.train_one_epoch()
    ...     val_preds = model.predict_proba(X_val)[:, 1]
    ...     ts_loss = ts_refinement_loss(y_val, val_preds)
    ...     if ts_loss < best_ts_loss:
    ...         best_ts_loss = ts_loss
    ...         save_checkpoint(model)
    """


def calibration_loss(y_true, y_pred):
    """
    Calibration Error: The "potential risk reduction" from recalibration.
    
    This measures how much loss is due purely to miscalibrated probabilities,
    which can be fixed by post-hoc calibration.
    
    Formula: L(y, p) - L(y, TS(p))
    
    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.
        
    Returns
    -------
    calibration_loss : float
        The reducible loss component (fixable by temperature scaling).
    """


def loss_decomposition(y_true, y_pred):
    """
    Full variational decomposition of log-loss into calibration and refinement.
    
    Uses temperature scaling to separate:
    - Refinement: irreducible loss (discriminative ability)
    - Calibration: reducible loss (fixable miscalibration)
    
    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.
        
    Returns
    -------
    decomposition : dict
        'total_loss': float - standard log loss L(y, p)
        'refinement_loss': float - loss after TS, L(y, TS(p))
        'calibration_loss': float - reducible loss, L(y, p) - L(y, TS(p))
        'optimal_temperature': float - the fitted T*
        'calibration_fraction': float - fraction of loss from calibration
        
    Examples
    --------
    >>> from splinator.metrics import loss_decomposition
    >>> decomp = loss_decomposition(y_val, model.predict_proba(X_val)[:, 1])
    >>> print(f"Total Loss: {decomp['total_loss']:.4f}")
    >>> print(f"  Refinement: {decomp['refinement_loss']:.4f}")
    >>> print(f"  Calibration: {decomp['calibration_loss']:.4f} ({decomp['calibration_fraction']:.1%})")
    >>> print(f"  Optimal T: {decomp['optimal_temperature']:.3f}")
    """

2. sklearn Scorers

from sklearn.metrics import make_scorer

# For use with GridSearchCV, cross_val_score, etc.
# Negated because sklearn expects higher = better.
# Note: response_method='predict_proba' (sklearn >= 1.4) replaces the
# deprecated needs_proba argument; only one of the two should be passed.
ts_refinement_scorer = make_scorer(
    lambda y, p: -ts_refinement_loss(y, p),
    greater_is_better=True,
    response_method='predict_proba'
)

calibration_loss_scorer = make_scorer(
    lambda y, p: -calibration_loss(y, p),
    greater_is_better=True,
    response_method='predict_proba'
)

3. TemperatureScaling Estimator in estimators.py

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, TransformerMixin

# Reuses the metric helpers proposed above (splinator.metrics):
from splinator.metrics import apply_temperature_scaling, find_optimal_temperature


class TemperatureScaling(RegressorMixin, TransformerMixin, BaseEstimator):
    """
    Temperature Scaling post-hoc calibrator.
    
    Learns a single temperature parameter T that rescales logits:
        calibrated_prob = σ(logit(p) / T)
    
    T > 1 softens probabilities (less confident)
    T < 1 sharpens probabilities (more confident)
    T = 1 leaves probabilities unchanged
    
    This is the simplest post-hoc calibration method, from Guo et al.
    "On Calibration of Modern Neural Networks" (ICML 2017).
    
    Parameters
    ----------
    init_temperature : float, default=1.0
        Initial temperature value for optimization.
    max_iter : int, default=50
        Maximum iterations for L-BFGS-B optimization.
        
    Attributes
    ----------
    temperature_ : float
        Learned temperature parameter.
    n_features_in_ : int
        Number of features seen during fit.
    
    Examples
    --------
    >>> from splinator import TemperatureScaling
    >>> ts = TemperatureScaling()
    >>> ts.fit(val_probs.reshape(-1, 1), y_val)
    >>> calibrated = ts.predict(test_probs.reshape(-1, 1))
    >>> print(f"Optimal temperature: {ts.temperature_:.3f}")
    
    Notes
    -----
    - Input X should be predicted probabilities, shape (n_samples,) or (n_samples, 1)
    - Works as the final step of an sklearn Pipeline (intermediate steps must be transformers)
    - For multi-class, input should be logits of shape (n_samples, n_classes)
    """
    
    def __init__(self, init_temperature=1.0, max_iter=50):
        self.init_temperature = init_temperature
        self.max_iter = max_iter
    
    def fit(self, X, y):
        """
        Fit temperature parameter by minimizing NLL on calibration set.
        
        Parameters
        ----------
        X : array-like of shape (n_samples,) or (n_samples, 1)
            Predicted probabilities to calibrate.
        y : array-like of shape (n_samples,)
            True labels.
            
        Returns
        -------
        self : object
            Fitted estimator.
        """
        # Sketch: delegate to the find_optimal_temperature helper proposed
        # above, which converts probabilities to logits and fits T via
        # L-BFGS-B; store the result as the learned attribute.
        X = np.asarray(X, dtype=float).reshape(-1)
        self.temperature_ = find_optimal_temperature(
            y, X, init_temp=self.init_temperature, max_iter=self.max_iter
        )
        self.n_features_in_ = 1
        return self
    
    def transform(self, X):
        """Apply temperature scaling to probabilities."""
        return self.predict(X)
    
    def predict(self, X):
        """
        Return calibrated probabilities.
        
        Parameters
        ----------
        X : array-like of shape (n_samples,) or (n_samples, 1)
            Predicted probabilities to calibrate.
            
        Returns
        -------
        calibrated : array-like of shape (n_samples,)
            Temperature-scaled probabilities.
        """
        # Sketch: rescale the logits of the input probabilities by the
        # learned temperature (same transform as apply_temperature_scaling).
        p = np.asarray(X, dtype=float).reshape(-1)
        return apply_temperature_scaling(p, self.temperature_)
    
    @property
    def is_fitted(self):
        return hasattr(self, 'temperature_')

Usage Examples

Example 1: Early Stopping with TS-Refinement

from splinator.metrics import ts_refinement_loss
import copy

best_ts_loss = float('inf')
patience_counter = 0
patience = 10

for epoch in range(max_epochs):
    model.partial_fit(X_train, y_train)
    
    val_probs = model.predict_proba(X_val)[:, 1]
    ts_loss = ts_refinement_loss(y_val, val_probs)
    
    if ts_loss < best_ts_loss:
        best_ts_loss = ts_loss
        patience_counter = 0
        best_model = copy.deepcopy(model)
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Apply final calibration to the best model
from splinator import TemperatureScaling
calibrator = TemperatureScaling()
calibrator.fit(best_model.predict_proba(X_val)[:, 1], y_val)
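
# At inference time, apply the fitted calibrator to the best model's raw
# scores on new data (X_test is assumed to exist alongside X_val):
calibrated_test = calibrator.predict(best_model.predict_proba(X_test)[:, 1])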

Example 2: GridSearchCV with TS-Refinement Scorer

from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import GradientBoostingClassifier
from splinator.metrics import ts_refinement_scorer

# Find best hyperparameters using TS-refinement as the metric
param_grid = {'n_estimators': [50, 100, 200], 'max_depth': [3, 5, 7]}

grid_search = GridSearchCV(
    GradientBoostingClassifier(),
    param_grid,
    scoring=ts_refinement_scorer,  # Use TS-refinement instead of neg_log_loss!
    cv=5
)
grid_search.fit(X, y)
print(f"Best params: {grid_search.best_params_}")

Example 3: Analyze Loss Decomposition During Training

from splinator.metrics import loss_decomposition

history = {'epoch': [], 'total': [], 'refinement': [], 'calibration': [], 'temperature': []}

for epoch in range(max_epochs):
    model.partial_fit(X_train, y_train)
    
    val_probs = model.predict_proba(X_val)[:, 1]
    decomp = loss_decomposition(y_val, val_probs)
    
    history['epoch'].append(epoch)
    history['total'].append(decomp['total_loss'])
    history['refinement'].append(decomp['refinement_loss'])
    history['calibration'].append(decomp['calibration_loss'])
    history['temperature'].append(decomp['optimal_temperature'])
    
    print(f"Epoch {epoch}: Total={decomp['total_loss']:.4f}, "
          f"Ref={decomp['refinement_loss']:.4f}, "
          f"Cal={decomp['calibration_loss']:.4f} ({decomp['calibration_fraction']:.1%})")

# Plot to visualize the divergence between refinement and total loss
import matplotlib.pyplot as plt
plt.plot(history['epoch'], history['total'], label='Total Loss')
plt.plot(history['epoch'], history['refinement'], label='Refinement Loss')
plt.plot(history['epoch'], history['calibration'], label='Calibration Loss')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Decomposition During Training')
plt.show()

Example 4: Classifier Followed by Temperature Scaling

from sklearn.ensemble import RandomForestClassifier
from splinator import TemperatureScaling

# Note: a classifier is not a transformer, so it cannot be an intermediate
# step of a plain sklearn Pipeline; fit the two stages explicitly instead.
clf = RandomForestClassifier()
clf.fit(X_train, y_train)

calibrator = TemperatureScaling()
calibrator.fit(clf.predict_proba(X_val)[:, 1], y_val)

calibrated_probs = calibrator.predict(clf.predict_proba(X_test)[:, 1])

File Changes

| File | Changes |
|------|---------|
| src/splinator/metrics.py | Add find_optimal_temperature, apply_temperature_scaling, ts_refinement_loss, calibration_loss, loss_decomposition, and sklearn scorers |
| src/splinator/estimators.py | Add TemperatureScaling class |
| src/splinator/__init__.py | Export new functions and classes |
| tests/test_metrics.py | Add tests for new metrics |
| tests/test_temperature_scaling.py | Add tests for TemperatureScaling |

Acceptance Criteria

  • find_optimal_temperature() finds T that minimizes NLL
  • apply_temperature_scaling() correctly applies σ(logit(p) / T)
  • ts_refinement_loss() returns loss after optimal temperature scaling
  • calibration_loss() returns total_loss - refinement_loss
  • loss_decomposition() returns complete decomposition with all components
  • total_loss ≈ refinement_loss + calibration_loss (numerical precision)
  • TemperatureScaling is sklearn-compatible (fit/transform/predict)
  • TemperatureScaling works in sklearn pipelines
  • All sklearn scorers work with GridSearchCV and cross_val_score
  • Unit tests pass with >90% coverage on new code (see the sketch below)
  • Documentation with examples
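
A sketch of one possible unit test for the decomposition identity, assuming the proposed API (the test name and constants are illustrative):

import numpy as np
from splinator.metrics import loss_decomposition

def test_decomposition_identity():
    rng = np.random.default_rng(42)
    p = rng.uniform(0.05, 0.95, size=1000)
    y = (rng.uniform(size=1000) < p).astype(int)  # calibrated by construction
    d = loss_decomposition(y, p)
    # total = refinement + calibration, up to optimizer tolerance
    assert np.isclose(d['total_loss'],
                      d['refinement_loss'] + d['calibration_loss'])
    # T* includes T = 1 as a candidate, so TS can only reduce the loss
    assert d['calibration_loss'] >= -1e-8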

References

  • "Rethinking Early Stopping: Refine, Then Calibrate" (the paper this proposal builds on)
  • Guo, Pleiss, Sun & Weinberger, "On Calibration of Modern Neural Networks", ICML 2017

Labels

enhancement, metrics, calibration, sklearn-integration
