5 changes: 5 additions & 0 deletions changelog.d/fix-numerical-guards.fixed.md
@@ -0,0 +1,5 @@
Standardize the Hard Concrete numerical guards: all three modules now use
`1e-6` as the uniform-sampling epsilon (previously `distributions.py` used
`1e-8`, which underflows in fp16), and clamp `qz_logits`/`log_alpha` to
`[-20, 20]` before sampling, computing deterministic gates, and evaluating
the penalty, so gradients don't vanish on saturating inputs.
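A quick standalone check of that epsilon choice (a sketch, not part of the diff): fp16's smallest positive subnormal is about `6e-8`, so `1e-8` rounds to zero in half precision while `1e-6` remains representable.

```python
import torch

# Sketch (not part of the diff): why 1e-8 is not an fp16-safe epsilon.
# The smallest positive fp16 subnormal is 2**-24 ~= 6e-8, so 1e-8 rounds
# to zero in half precision and its log diverges; 1e-6 survives.
eps_old = torch.tensor(1e-8, dtype=torch.float16)
eps_new = torch.tensor(1e-6, dtype=torch.float16)

print(eps_old.item())                     # 0.0  -- underflowed
print(torch.log(eps_old.float()).item())  # -inf -- would poison the gate sample
print(eps_new.item())                     # ~1.01e-06 -- still representable
print(torch.log(eps_new.float()).item())  # ~-13.8 -- finite
```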
17 changes: 13 additions & 4 deletions l0/calibration.py
@@ -180,18 +180,26 @@ def _convert_sparse_to_torch(self, M_sparse: sp.spmatrix) -> torch.sparse.Tensor

return M_torch

# Bounds for `log_alpha` before scaling by `beta`. Values outside
# ``[-LOG_ALPHA_BOUND, LOG_ALPHA_BOUND]`` saturate the sigmoid and make
# gradients vanish; clamping keeps sampling and deterministic gates
# well-defined under fp16/bf16.
LOG_ALPHA_BOUND: float = 20.0

def _sample_gates(self) -> torch.Tensor:
"""Sample gates using Hard Concrete distribution."""
eps = 1e-6
u = torch.rand_like(self.log_alpha).clamp(eps, 1 - eps)
s = (torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta
log_alpha = self.log_alpha.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)
u = torch.rand_like(log_alpha).clamp(eps, 1 - eps)
s = (torch.log(u) - torch.log(1 - u) + log_alpha) / self.beta
s = torch.sigmoid(s)
s_bar = s * (self.zeta - self.gamma) + self.gamma
return s_bar.clamp(0, 1)

def get_deterministic_gates(self) -> torch.Tensor:
"""Get deterministic gate values (for inference)."""
s = torch.sigmoid(self.log_alpha / self.beta)
log_alpha = self.log_alpha.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)
s = torch.sigmoid(log_alpha / self.beta)
s_bar = s * (self.zeta - self.gamma) + self.gamma
return s_bar.clamp(0, 1)

@@ -263,7 +271,8 @@ def get_l0_penalty(self) -> torch.Tensor:
c = -self.beta * torch.log(
torch.tensor(-self.gamma / self.zeta, device=self.device)
)
pi = torch.sigmoid(self.log_alpha + c)
log_alpha = self.log_alpha.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)
pi = torch.sigmoid(log_alpha + c)
return pi.sum()

def get_l2_penalty(self) -> torch.Tensor:
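For readers skimming the diff, here is a minimal standalone sketch of the clamped sampling path `calibration.py` now uses, assuming the commonly used Hard Concrete stretch parameters `gamma=-0.1`, `zeta=1.1`, `beta=2/3` (the module's actual attribute values may differ). Even a wildly saturated `log_alpha` yields finite gates in `[0, 1]`:

```python
import torch

def sample_hard_concrete(log_alpha: torch.Tensor,
                         beta: float = 2 / 3,
                         gamma: float = -0.1,
                         zeta: float = 1.1,
                         eps: float = 1e-6,
                         bound: float = 20.0) -> torch.Tensor:
    """Standalone mirror of the clamped `_sample_gates` path (illustrative defaults)."""
    log_alpha = log_alpha.clamp(-bound, bound)          # guard against saturation
    u = torch.rand_like(log_alpha).clamp(eps, 1 - eps)  # fp16-safe uniform noise
    s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + log_alpha) / beta)
    return (s * (zeta - gamma) + gamma).clamp(0, 1)     # stretch, then hard-clip

gates = sample_hard_concrete(torch.full((4,), 1000.0))  # extreme input
print(gates)  # tensor([1., 1., 1., 1.]) -- saturated but finite, no NaN/Inf
```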
38 changes: 28 additions & 10 deletions l0/distributions.py
@@ -96,6 +96,16 @@ def forward(self, input_shape: tuple[int, ...] | None = None) -> torch.Tensor:

return gates

# Bounds for `qz_logits` before sigmoid/temperature scaling. Values
# outside ``[-LOG_ALPHA_BOUND, LOG_ALPHA_BOUND]`` saturate the sigmoid
# and make gradients vanish; clamping keeps sampling and deterministic
# gates well-defined in fp16/bf16 as well. See Louizos et al. 2017.
LOG_ALPHA_BOUND: float = 20.0

# Epsilon for uniform sampling; fp16-safe (``1e-8`` underflows to zero
# in fp16, where ``log(0) = -inf``).
_UNIFORM_EPS: float = 1e-6

def _sample_gates(self) -> torch.Tensor:
"""
Sample gates using the reparameterization trick.
@@ -105,12 +115,19 @@ def _sample_gates(self) -> torch.Tensor:
torch.Tensor
Sampled gate values in [0, 1]
"""
# Sample uniform noise (avoiding exact 0 and 1 for numerical stability)
u = torch.zeros_like(self.qz_logits).uniform_(1e-8, 1.0 - 1e-8)
# Clamp logits to a safe range so log-odds stay finite under fp16/bf16
# and gradients don't vanish at saturation.
logits = self.qz_logits.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)

# Sample uniform noise (avoiding exact 0 and 1 for numerical stability).
# ``1e-6`` matches calibration.py / sparse.py and is fp16-safe.
u = torch.zeros_like(logits).uniform_(
self._UNIFORM_EPS, 1.0 - self._UNIFORM_EPS
)

# Apply the concrete distribution transformation
# s = sigmoid((log(u) - log(1-u) + logits) / temperature)
s = torch.log(u) - torch.log(1 - u) + self.qz_logits
s = torch.log(u) - torch.log(1 - u) + logits
s = torch.sigmoid(s / self.temperature)

# Stretch and clamp to create hard concrete
@@ -133,8 +150,10 @@ def _deterministic_gates(self) -> torch.Tensor:
torch.Tensor
Deterministic gate values in [0, 1]
"""
# Clamp logits so eval output stays well-defined in fp16/bf16.
logits = self.qz_logits.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)
# Mean of the binary concrete before stretch: sigmoid(logits / temperature).
probs = torch.sigmoid(self.qz_logits / self.temperature)
probs = torch.sigmoid(logits / self.temperature)

# Apply stretching transformation
gates = probs * (self.zeta - self.gamma) + self.gamma
@@ -150,10 +169,10 @@ def get_penalty(self) -> torch.Tensor:
torch.Tensor
Expected number of non-zero gates
"""
# Clamp logits so the sigmoid doesn't saturate and zero out gradients.
logits = self.qz_logits.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)
# Shift logits to account for hard concrete bounds
logits_shifted = self.qz_logits - self.temperature * math.log(
-self.gamma / self.zeta
)
logits_shifted = logits - self.temperature * math.log(-self.gamma / self.zeta)

# Probability that gate is active (non-zero)
prob_active = torch.sigmoid(logits_shifted)
@@ -169,9 +188,8 @@ def get_active_prob(self) -> torch.Tensor:
torch.Tensor
Probability of each gate being non-zero
"""
logits_shifted = self.qz_logits - self.temperature * math.log(
-self.gamma / self.zeta
)
logits = self.qz_logits.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)
logits_shifted = logits - self.temperature * math.log(-self.gamma / self.zeta)
return torch.sigmoid(logits_shifted)

def get_sparsity(self) -> float:
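The shifted logits in `get_penalty()` / `get_active_prob()` come from the closed-form probability that a stretched gate is non-zero, `P(z != 0) = sigmoid(logits - temperature * log(-gamma / zeta))` (Louizos et al. 2017). A small worked sketch, with `gamma=-0.1`, `zeta=1.1`, `temperature=2/3` assumed as illustrative defaults (the module's actual values may differ):

```python
import math
import torch

# Worked sketch of the expected-L0 computation, with assumed defaults
# gamma=-0.1, zeta=1.1, temperature=2/3.
gamma, zeta, temperature = -0.1, 1.1, 2.0 / 3.0
qz_logits = torch.tensor([-3.0, 0.0, 3.0]).clamp(-20.0, 20.0)

# P(gate != 0) = sigmoid(logits - temperature * log(-gamma / zeta))
prob_active = torch.sigmoid(qz_logits - temperature * math.log(-gamma / zeta))
expected_l0 = prob_active.sum()

print(prob_active)  # ~[0.20, 0.83, 0.99]: per-gate probability of being non-zero
print(expected_l0)  # ~2.02: expected number of active gates (the penalty term)
```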
17 changes: 13 additions & 4 deletions l0/sparse.py
@@ -118,18 +118,26 @@ def _convert_sparse_to_torch(self, X_sparse: sp.spmatrix) -> torch.sparse.Tensor

return X_torch

# Bounds for `log_alpha` before scaling by `beta`. Values outside
# ``[-LOG_ALPHA_BOUND, LOG_ALPHA_BOUND]`` saturate the sigmoid and make
# gradients vanish; clamping keeps sampling and deterministic gates
# well-defined under fp16/bf16.
LOG_ALPHA_BOUND: float = 20.0

def _sample_gates(self) -> torch.Tensor:
"""Sample gates using Hard Concrete distribution."""
eps = 1e-6
u = torch.rand_like(self.log_alpha).clamp(eps, 1 - eps)
X = (torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta
log_alpha = self.log_alpha.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)
u = torch.rand_like(log_alpha).clamp(eps, 1 - eps)
X = (torch.log(u) - torch.log(1 - u) + log_alpha) / self.beta
s = torch.sigmoid(X)
s_bar = s * (self.zeta - self.gamma) + self.gamma
return s_bar.clamp(0, 1)

def get_deterministic_gates(self) -> torch.Tensor:
"""Get deterministic gate values (for inference)."""
X = self.log_alpha / self.beta
log_alpha = self.log_alpha.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)
X = log_alpha / self.beta
s = torch.sigmoid(X)
s_bar = s * (self.zeta - self.gamma) + self.gamma
return s_bar.clamp(0, 1)
@@ -186,7 +194,8 @@ def get_l0_penalty(self) -> torch.Tensor:
c = -self.beta * torch.log(
torch.tensor(-self.gamma / self.zeta, device=self.device)
)
pi = torch.sigmoid(self.log_alpha + c)
log_alpha = self.log_alpha.clamp(-self.LOG_ALPHA_BOUND, self.LOG_ALPHA_BOUND)
pi = torch.sigmoid(log_alpha + c)
return pi.sum()

def get_sparsity(self) -> float:
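A hypothetical usage sketch (not part of the diff) showing how the clamped penalty and deterministic gates from `SparseL0Linear` are typically consumed during training; only `get_l0_penalty()` and `get_deterministic_gates()` appear in the diff above, while the zero task loss and the `lam` weight are illustrative assumptions:

```python
import torch
from l0.sparse import SparseL0Linear

# Hypothetical sketch: fold the expected-L0 penalty into a training objective.
# `lam` and the zero task loss are placeholders, not part of the library.
model = SparseL0Linear(n_features=30, init_keep_prob=0.5)
lam = 1e-3  # sparsity strength (assumed)

task_loss = torch.tensor(0.0)  # stand-in for a real data-fit term
loss = task_loss + lam * model.get_l0_penalty()
loss.backward()  # log_alpha starts well inside the clamp range, so it receives gradients

with torch.no_grad():
    gates = model.get_deterministic_gates()        # inference-time gate values
    print(gates.min().item(), gates.max().item())  # both lie in [0, 1]
```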
16 changes: 16 additions & 0 deletions tests/test_calibration.py
@@ -985,3 +985,19 @@ def test_sparse_calibration_weights_exported(self):

assert "SparseCalibrationWeights" in l0.__all__
assert l0.SparseCalibrationWeights is SparseCalibrationWeights

def test_extreme_log_alpha_stays_finite(self):
"""Very large `log_alpha` must not corrupt gates or penalty."""
model = SparseCalibrationWeights(n_features=30, init_keep_prob=0.5)
with torch.no_grad():
model.log_alpha.fill_(1000.0)

det_gates = model.get_deterministic_gates()
sample_gates = model._sample_gates()
penalty = model.get_l0_penalty()

assert torch.isfinite(det_gates).all()
assert torch.isfinite(sample_gates).all()
assert torch.isfinite(penalty)
assert torch.all(det_gates >= 0) and torch.all(det_gates <= 1)
assert torch.all(sample_gates >= 0) and torch.all(sample_gates <= 1)
38 changes: 38 additions & 0 deletions tests/test_distributions.py
@@ -280,3 +280,41 @@ def test_sparsity_stats_match_eval_activation(self):
# Both numbers come from the same temperature-aware distribution
# so they must be close (within sampling-free rounding slack).
assert abs(reported_sparsity - eval_sparsity) < 0.15

def test_extreme_logits_stay_finite(self):
"""Very large logits must not produce NaN/Inf gates.

Without the `LOG_ALPHA_BOUND` clamp, a huge positive logit plus
``log(u) - log(1 - u)`` can push the sigmoid argument beyond the
fp16/bf16 range; with the old ``1e-8`` epsilon, ``u`` also underflows
in fp16, so ``log(u)`` becomes ``-inf`` and the gate turns ``NaN``. We
simulate that failure mode at fp32 by pushing the logits to 1e3.
"""
gate = HardConcrete(20, temperature=0.1, init_mean=0.5)
with torch.no_grad():
gate.qz_logits.fill_(1000.0)

gate.train()
for _ in range(50):
sampled = gate()
assert torch.isfinite(sampled).all()
assert torch.all(sampled >= 0) and torch.all(sampled <= 1)

gate.eval()
det = gate()
assert torch.isfinite(det).all()
assert torch.all(det >= 0) and torch.all(det <= 1)

penalty = gate.get_penalty()
assert torch.isfinite(penalty)

with torch.no_grad():
gate.qz_logits.fill_(-1000.0)
gate.train()
for _ in range(50):
assert torch.isfinite(gate()).all()
gate.eval()
assert torch.isfinite(gate()).all()

def test_uniform_eps_is_fp16_safe(self):
"""The uniform sample floor should be >= ``1e-6`` (fp16 underflow)."""
assert HardConcrete._UNIFORM_EPS >= 1e-6
18 changes: 18 additions & 0 deletions tests/test_sparse.py
@@ -241,3 +241,21 @@ def test_seed_none_uses_global_rng(self):
torch.manual_seed(0)
b = SparseL0Linear(n_features=20, init_keep_prob=0.5)
torch.testing.assert_close(a.log_alpha.data, b.log_alpha.data)

def test_extreme_log_alpha_stays_finite(self):
"""Very large `log_alpha` must not corrupt gates or penalty."""
from l0.sparse import SparseL0Linear

model = SparseL0Linear(n_features=30, init_keep_prob=0.5)
with torch.no_grad():
model.log_alpha.fill_(1000.0)

det_gates = model.get_deterministic_gates()
sample_gates = model._sample_gates()
penalty = model.get_l0_penalty()

assert torch.isfinite(det_gates).all()
assert torch.isfinite(sample_gates).all()
assert torch.isfinite(penalty)
assert torch.all(det_gates >= 0) and torch.all(det_gates <= 1)
assert torch.all(sample_gates >= 0) and torch.all(sample_gates <= 1)