Add Refinement Loss Decomposition & TS-Refinement Metrics
Background
Based on the paper "Rethinking Early Stopping: Refine, Then Calibrate", this issue proposes adding metrics and tools to support the "Refine, Then Calibrate" training paradigm.
Key Insight from the Paper
Standard early stopping based on validation loss is suboptimal because it forces a compromise between two conflicting objectives:
- Refinement Error: Model's discriminative ability (separating classes)
- Calibration Error: Alignment of confidence scores with true probabilities
These errors are minimized at different points during training. The paper proposes:
- Train longer to minimize refinement error (ignore rising validation loss)
- Apply post-hoc calibration (e.g., Temperature Scaling) to fix calibration
Relevance to Splinator
Splinator is already a post-hoc calibration library — it's the "Then Calibrate" part! This feature adds:
- Metrics to decompose loss into refinement vs calibration components
- TS-Refinement metric for better early stopping decisions
- Temperature Scaling as a simple calibrator option
Variational Decomposition Method
The paper uses a variational approach rather than binning-based ECE. This is more stable and provides a direct estimate of how much "extra" loss is caused purely by poor probability scaling.
Formula
| Term | Formula | Meaning |
|---|---|---|
| Total Loss | `L(y, p)` | Standard validation cross-entropy |
| Refinement Error | `L(y, TS(p))` | Loss after optimal temperature scaling |
| Calibration Error | `L(y, p) - L(y, TS(p))` | "Potential risk reduction" from recalibration |
Where:
- `TS(p) = σ(logit(p) / T*)` — the temperature-scaled probability
- `T*` — the optimal temperature, i.e. the one that minimizes NLL on the calibration set
Step-by-Step Calculation
1. Get predictions: obtain model predictions `p` on the validation set
2. Fit temperature: find the scalar `T*` that minimizes `L(y, σ(logit(p) / T))`
3. Compute refinement: the minimal loss `L(y, TS(p))` is the Refinement Error
4. Compute calibration: `L(y, p) - L(y, TS(p))` is the Calibration Error
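A minimal sketch of this recipe for the binary case, assuming `scipy` and `scikit-learn` are available; the helper name `variational_decomposition` is illustrative, not part of the proposed API:

```python
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.special import expit, logit  # expit is the sigmoid
from sklearn.metrics import log_loss

def variational_decomposition(y_true, y_pred, eps=1e-7):
    # Step 1: work in logit space (clip to avoid logit(0) / logit(1))
    z = logit(np.clip(y_pred, eps, 1 - eps))
    # Step 2: fit T* by minimizing the NLL of sigma(z / T)
    result = minimize_scalar(lambda t: log_loss(y_true, expit(z / t)),
                             bounds=(1e-2, 1e2), method='bounded')
    t_star = result.x
    # Steps 3-4: refinement is the loss at T*, calibration is the gap
    total = log_loss(y_true, y_pred)
    refinement = log_loss(y_true, expit(z / t_star))
    return {'total': total, 'refinement': refinement,
            'calibration': total - refinement, 'temperature': t_star}
```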
Why Variational > Binning?
| Approach | Issues |
|---|---|
| Binning-based ECE | Biased, bin-count dependent, inconsistent in multi-class |
| Variational (TS-based) | Stable, directly measures "fixable" miscalibration, consistent |
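To see the bin-count dependence concretely, here is a hedged sketch on synthetic data (`binned_ece` is an illustrative helper, not proposed API):

```python
import numpy as np

def binned_ece(y_true, y_pred, n_bins):
    # Equal-width binning ECE: |mean label - mean confidence| per bin,
    # weighted by the fraction of samples falling in that bin.
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    idx = np.clip(np.digitize(y_pred, edges) - 1, 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = idx == b
        if mask.any():
            ece += mask.mean() * abs(y_true[mask].mean() - y_pred[mask].mean())
    return ece

rng = np.random.default_rng(0)
p = rng.uniform(0.05, 0.95, size=2000)
y = (rng.uniform(size=2000) < p).astype(float)  # perfectly calibrated by construction
print(binned_ece(y, p, 10), binned_ece(y, p, 50))  # nonzero, and moves with the bin count
```

On this perfectly calibrated data the true calibration error is zero, yet the binned estimate is positive and changes with `n_bins`; the TS-based calibration loss is a single bin-free quantity that is zero up to sampling and optimizer noise.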
Proposed Implementation
1. New Metrics in metrics.py
```python
def find_optimal_temperature(y_true, y_pred, init_temp=1.0, max_iter=50):
    """
    Find the optimal temperature for temperature scaling.

    Solves: T* = argmin_T L(y, σ(logit(p) / T))

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.
    init_temp : float, default=1.0
        Initial temperature value for optimization.
    max_iter : int, default=50
        Maximum iterations for L-BFGS-B optimization.

    Returns
    -------
    temperature : float
        Optimal temperature that minimizes NLL.
    """


def apply_temperature_scaling(y_pred, temperature):
    """
    Apply temperature scaling to predicted probabilities.

    calibrated = σ(logit(p) / T)

    Parameters
    ----------
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.
    temperature : float
        Temperature parameter (T > 1 softens, T < 1 sharpens).

    Returns
    -------
    calibrated : array-like of shape (n_samples,)
        Temperature-scaled probabilities.
    """


def ts_refinement_loss(y_true, y_pred):
    """
    Refinement Error: cross-entropy AFTER optimal temperature scaling.

    This is the irreducible loss given perfect calibration — it measures
    the model's fundamental discriminative ability. Use this as the
    early stopping criterion instead of raw validation loss.

    Formula: L(y, TS(p)) where TS applies optimal temperature scaling

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.

    Returns
    -------
    refinement_loss : float
        Cross-entropy after optimal temperature scaling.

    Examples
    --------
    >>> from splinator.metrics import ts_refinement_loss
    >>> # Use as early stopping criterion
    >>> for epoch in range(max_epochs):
    ...     model.train_one_epoch()
    ...     val_preds = model.predict_proba(X_val)[:, 1]
    ...     ts_loss = ts_refinement_loss(y_val, val_preds)
    ...     if ts_loss < best_ts_loss:
    ...         best_ts_loss = ts_loss
    ...         save_checkpoint(model)
    """


def calibration_loss(y_true, y_pred):
    """
    Calibration Error: the "potential risk reduction" from recalibration.

    This measures how much loss is due purely to miscalibrated probabilities,
    which can be fixed by post-hoc calibration.

    Formula: L(y, p) - L(y, TS(p))

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.

    Returns
    -------
    calibration_loss : float
        The reducible loss component (fixable by temperature scaling).
    """


def loss_decomposition(y_true, y_pred):
    """
    Full variational decomposition of log-loss into calibration and refinement.

    Uses temperature scaling to separate:
    - Refinement: irreducible loss (discriminative ability)
    - Calibration: reducible loss (fixable miscalibration)

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True binary labels.
    y_pred : array-like of shape (n_samples,)
        Predicted probabilities.

    Returns
    -------
    decomposition : dict
        'total_loss': float - standard log loss L(y, p)
        'refinement_loss': float - loss after TS, L(y, TS(p))
        'calibration_loss': float - reducible loss, L(y, p) - L(y, TS(p))
        'optimal_temperature': float - the fitted T*
        'calibration_fraction': float - fraction of loss from calibration

    Examples
    --------
    >>> from splinator.metrics import loss_decomposition
    >>> decomp = loss_decomposition(y_val, model.predict_proba(X_val)[:, 1])
    >>> print(f"Total Loss: {decomp['total_loss']:.4f}")
    >>> print(f"  Refinement:  {decomp['refinement_loss']:.4f}")
    >>> print(f"  Calibration: {decomp['calibration_loss']:.4f} ({decomp['calibration_fraction']:.1%})")
    >>> print(f"  Optimal T:   {decomp['optimal_temperature']:.3f}")
    """
```

2. sklearn Scorers
```python
from sklearn.metrics import make_scorer

# For use with GridSearchCV, cross_val_score, etc.
# greater_is_better=False tells sklearn to negate the loss (higher = better).
# response_method requires scikit-learn >= 1.4 (it replaces the old needs_proba flag).
ts_refinement_scorer = make_scorer(
    ts_refinement_loss,
    greater_is_better=False,
    response_method='predict_proba',
)

calibration_loss_scorer = make_scorer(
    calibration_loss,
    greater_is_better=False,
    response_method='predict_proba',
)
```

3. TemperatureScaling Estimator in estimators.py
```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import expit, logit
from sklearn.base import BaseEstimator, RegressorMixin, TransformerMixin
from sklearn.metrics import log_loss


class TemperatureScaling(RegressorMixin, TransformerMixin, BaseEstimator):
    """
    Temperature Scaling post-hoc calibrator.

    Learns a single temperature parameter T that rescales logits:

        calibrated_prob = σ(logit(p) / T)

    T > 1 softens probabilities (less confident),
    T < 1 sharpens probabilities (more confident),
    T = 1 leaves probabilities unchanged.

    This is the simplest post-hoc calibration method, from Guo et al.,
    "On Calibration of Modern Neural Networks" (ICML 2017).

    Parameters
    ----------
    init_temperature : float, default=1.0
        Initial temperature value for optimization.
    max_iter : int, default=50
        Maximum iterations for L-BFGS-B optimization.

    Attributes
    ----------
    temperature_ : float
        Learned temperature parameter.
    n_features_in_ : int
        Number of features seen during fit.

    Examples
    --------
    >>> from splinator import TemperatureScaling
    >>> ts = TemperatureScaling()
    >>> ts.fit(val_probs.reshape(-1, 1), y_val)
    >>> calibrated = ts.predict(test_probs.reshape(-1, 1))
    >>> print(f"Optimal temperature: {ts.temperature_:.3f}")

    Notes
    -----
    - Input X should be predicted probabilities, shape (n_samples,) or (n_samples, 1)
    - Works in sklearn pipelines
    - For multi-class, input should be logits of shape (n_samples, n_classes)
    """

    def __init__(self, init_temperature=1.0, max_iter=50):
        self.init_temperature = init_temperature
        self.max_iter = max_iter

    def fit(self, X, y):
        """
        Fit temperature parameter by minimizing NLL on calibration set.

        Parameters
        ----------
        X : array-like of shape (n_samples,) or (n_samples, 1)
            Predicted probabilities to calibrate.
        y : array-like of shape (n_samples,)
            True labels.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        # 1. Convert probabilities to logits (clipped for numerical stability)
        p = np.clip(np.asarray(X, dtype=float).ravel(), 1e-7, 1 - 1e-7)
        logits = logit(p)
        # 2. Optimize T to minimize NLL using L-BFGS-B
        result = minimize(
            lambda t: log_loss(y, expit(logits / t[0])),
            x0=[self.init_temperature],
            bounds=[(1e-2, 1e2)],
            method='L-BFGS-B',
            options={'maxiter': self.max_iter},
        )
        # 3. Store temperature_
        self.temperature_ = float(result.x[0])
        self.n_features_in_ = 1
        return self

    def transform(self, X):
        """Apply temperature scaling to probabilities."""
        return self.predict(X)

    def predict(self, X):
        """
        Return calibrated probabilities.

        Parameters
        ----------
        X : array-like of shape (n_samples,) or (n_samples, 1)
            Predicted probabilities to calibrate.

        Returns
        -------
        calibrated : array-like of shape (n_samples,)
            Temperature-scaled probabilities.
        """
        p = np.clip(np.asarray(X, dtype=float).ravel(), 1e-7, 1 - 1e-7)
        return expit(logit(p) / self.temperature_)

    @property
    def is_fitted(self):
        return hasattr(self, 'temperature_')
```

Usage Examples
Example 1: Early Stopping with TS-Refinement
```python
import copy

from splinator.metrics import ts_refinement_loss

best_ts_loss = float('inf')
patience_counter = 0
patience = 10

for epoch in range(max_epochs):
    model.partial_fit(X_train, y_train)
    val_probs = model.predict_proba(X_val)[:, 1]
    ts_loss = ts_refinement_loss(y_val, val_probs)

    if ts_loss < best_ts_loss:
        best_ts_loss = ts_loss
        patience_counter = 0
        best_model = copy.deepcopy(model)
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Apply final calibration to the best model
from splinator import TemperatureScaling

calibrator = TemperatureScaling()
calibrator.fit(best_model.predict_proba(X_val)[:, 1], y_val)
```

Example 2: GridSearchCV with TS-Refinement Scorer
```python
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import GradientBoostingClassifier
from splinator.metrics import ts_refinement_scorer

# Find the best hyperparameters using TS-refinement as the metric
param_grid = {'n_estimators': [50, 100, 200], 'max_depth': [3, 5, 7]}

grid_search = GridSearchCV(
    GradientBoostingClassifier(),
    param_grid,
    scoring=ts_refinement_scorer,  # Use TS-refinement instead of neg_log_loss!
    cv=5,
)
grid_search.fit(X, y)
print(f"Best params: {grid_search.best_params_}")
```

Example 3: Analyze Loss Decomposition During Training
```python
import matplotlib.pyplot as plt

from splinator.metrics import loss_decomposition

history = {'epoch': [], 'total': [], 'refinement': [], 'calibration': [], 'temperature': []}

for epoch in range(max_epochs):
    model.partial_fit(X_train, y_train)
    val_probs = model.predict_proba(X_val)[:, 1]
    decomp = loss_decomposition(y_val, val_probs)

    history['epoch'].append(epoch)
    history['total'].append(decomp['total_loss'])
    history['refinement'].append(decomp['refinement_loss'])
    history['calibration'].append(decomp['calibration_loss'])
    history['temperature'].append(decomp['optimal_temperature'])

    print(f"Epoch {epoch}: Total={decomp['total_loss']:.4f}, "
          f"Ref={decomp['refinement_loss']:.4f}, "
          f"Cal={decomp['calibration_loss']:.4f} ({decomp['calibration_fraction']:.1%})")

# Plot to visualize the divergence between refinement and total loss
plt.plot(history['epoch'], history['total'], label='Total Loss')
plt.plot(history['epoch'], history['refinement'], label='Refinement Loss')
plt.plot(history['epoch'], history['calibration'], label='Calibration Loss')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Decomposition During Training')
```

Example 4: Chaining a Classifier with Temperature Scaling
```python
from sklearn.ensemble import RandomForestClassifier
from splinator import TemperatureScaling

# A plain sklearn Pipeline cannot chain a classifier into a calibrator:
# intermediate Pipeline steps must be transformers, and the calibrator
# consumes predicted probabilities, not features. Fit the stages in
# sequence, using held-out data for the calibrator.
clf = RandomForestClassifier()
clf.fit(X_train, y_train)

calibrator = TemperatureScaling()
calibrator.fit(clf.predict_proba(X_val)[:, 1], y_val)

calibrated_probs = calibrator.predict(clf.predict_proba(X_test)[:, 1])
```

File Changes
| File | Changes |
|---|---|
| `src/splinator/metrics.py` | Add `find_optimal_temperature`, `apply_temperature_scaling`, `ts_refinement_loss`, `calibration_loss`, `loss_decomposition`, and sklearn scorers |
| `src/splinator/estimators.py` | Add `TemperatureScaling` class |
| `src/splinator/__init__.py` | Export new functions and classes |
| `tests/test_metrics.py` | Add tests for new metrics |
| `tests/test_temperature_scaling.py` | Add tests for `TemperatureScaling` |
Acceptance Criteria
- `find_optimal_temperature()` finds the T that minimizes NLL
- `apply_temperature_scaling()` correctly applies σ(logit(p) / T)
- `ts_refinement_loss()` returns the loss after optimal temperature scaling
- `calibration_loss()` returns `total_loss - refinement_loss`
- `loss_decomposition()` returns the complete decomposition with all components
- `total_loss ≈ refinement_loss + calibration_loss` (to numerical precision)
- `TemperatureScaling` is sklearn-compatible (fit/transform/predict)
- `TemperatureScaling` works in sklearn pipelines
- All sklearn scorers work with `GridSearchCV` and `cross_val_score`
- Unit tests pass with >90% coverage on new code
- Documentation with examples
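A minimal test sketch for the decomposition identity and non-negativity criteria above (assumes the proposed `loss_decomposition` API; the synthetic data and tolerances are illustrative):

```python
import numpy as np
from splinator.metrics import loss_decomposition

def test_decomposition_identity():
    rng = np.random.default_rng(0)
    y = rng.integers(0, 2, size=1000)
    # Informative but deliberately miscalibrated probabilities
    p = np.clip(0.5 + 0.9 * (y - 0.5) + rng.normal(0, 0.1, 1000), 0.01, 0.99)

    d = loss_decomposition(y, p)
    # The two components must sum back to the total loss
    assert np.isclose(d['total_loss'], d['refinement_loss'] + d['calibration_loss'])
    # T = 1 recovers the original loss, so the optimum can only improve on it
    assert d['calibration_loss'] >= -1e-8
```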
References
- Paper: "Rethinking Early Stopping: Refine, Then Calibrate"
- Guo et al. "On Calibration of Modern Neural Networks" (ICML 2017)
- sklearn calibration docs: https://scikit-learn.org/stable/modules/calibration.html
Labels
enhancement, metrics, calibration, sklearn-integration