diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29..cff8a7c 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,8 @@ +- bump: minor + changes: + added: + - Group-wise loss averaging for calibration to balance contributions from targets with different cardinalities + - Improved training output with meaningful error percentages and sparsity statistics + changed: + - Simplified active weight detection in SparseCalibrationWeights (removed threshold parameter) + - Enhanced verbose output during calibration training to show relative errors and sparsity percentage \ No newline at end of file diff --git a/l0/calibration.py b/l0/calibration.py index 3df84eb..fb59cdc 100644 --- a/l0/calibration.py +++ b/l0/calibration.py @@ -218,17 +218,12 @@ def get_sparsity(self) -> float: """ with torch.no_grad(): gates = self.get_deterministic_gates() - return (gates < 0.01).float().mean().item() + return (gates == 0).float().mean().item() - def get_active_weights(self, threshold: float = 0.01) -> dict: + def get_active_weights(self) -> dict: """ Get indices and values of active (non-zero) weights. - Parameters - ---------- - threshold : float - Gate values below this are considered zero - Returns ------- dict @@ -236,8 +231,7 @@ def get_active_weights(self, threshold: float = 0.01) -> dict: """ with torch.no_grad(): weights = self.get_weights(deterministic=True) - gates = self.get_deterministic_gates() - active_mask = gates > threshold + active_mask = weights > 0 return { "indices": torch.where(active_mask)[0], @@ -256,6 +250,7 @@ def fit( loss_type: str = "mse", verbose: bool = False, verbose_freq: int = 100, + target_groups: np.ndarray | None = None, ) -> "SparseCalibrationWeights": """ Fit calibration weights using gradient descent. @@ -280,6 +275,10 @@ def fit( Whether to print progress verbose_freq : int How often to print progress + target_groups : numpy.ndarray, optional + Array of group IDs for each target. 
Targets in the same group + will be averaged together so each group contributes equally to loss. + If None, all targets are treated independently. Returns ------- @@ -292,6 +291,27 @@ def fit( # Convert M to torch sparse (will be cached) M_torch = self._convert_sparse_to_torch(M) + # Compute group weights for loss averaging + if target_groups is not None: + # Convert to tensor + target_groups = torch.tensor( + target_groups, dtype=torch.long, device=self.device + ) + + # Calculate group weights: 1 / group_size for each target + unique_groups = torch.unique(target_groups) + group_weights = torch.zeros_like(y) + + for group_id in unique_groups: + group_mask = target_groups == group_id + group_size = group_mask.sum().item() + # Each target in the group gets weight 1/group_size + # so the group's total contribution is 1 + group_weights[group_mask] = 1.0 / group_size + else: + # No grouping - all targets weighted equally + group_weights = torch.ones_like(y) + # Initialize weights nn.init.normal_(self.log_weight, 0, 0.5) @@ -303,15 +323,25 @@ def fit( # Forward pass y_pred = self.forward(M, deterministic=False) - # Compute loss + # Compute loss with group weighting if loss_type == "relative": # Relative error: (y - y_pred)^2 / (y + 1)^2 # Adding 1 to avoid division by zero relative_errors = (y - y_pred) / (y + 1) - data_loss = relative_errors.pow(2).mean() + # Apply group weights and then average + weighted_squared_errors = ( + relative_errors.pow(2) * group_weights + ) + data_loss = ( + weighted_squared_errors.sum() + ) # Sum because weights already normalize else: - # Standard MSE - data_loss = (y - y_pred).pow(2).mean() + # Standard MSE with group weighting + squared_errors = (y - y_pred).pow(2) + weighted_squared_errors = squared_errors * group_weights + data_loss = ( + weighted_squared_errors.sum() + ) # Sum because weights already normalize l0_loss = self.get_l0_penalty() loss = data_loss + lambda_l0 * l0_loss @@ -331,18 +361,61 @@ def fit( with torch.no_grad(): 
active_info = self.get_active_weights() weights = self.get_weights(deterministic=True) - # Compute MSE for monitoring even if using relative loss - mse = (y - y_pred).pow(2).mean().item() - print( - f"Epoch {epoch+1:4d}: " - f"loss={loss.item():.4f}, " - f"data_loss={data_loss.item():.4f}, " - f"mse={mse:.4f}, " - f"l0={l0_loss.item():.2f}, " - f"active={active_info['count']}, " - f"mean_weight={weights[weights > 0.01].mean().item() if (weights > 0.01).any() else 0:.3f}" + active_weights = weights[weights > 0] + + # Compute relative errors for meaningful output + y_det = self.forward(M, deterministic=True) + if loss_type == "relative": + rel_errors = torch.abs((y - y_det) / (y + 1)) + else: + # For MSE, show relative errors anyway for interpretability + rel_errors = torch.abs((y - y_det) / (y + 1)) + + # For reporting, we can show both overall and group-averaged errors + mean_rel_err = rel_errors.mean().item() + max_rel_err = rel_errors.max().item() + + # Compute mean group loss if groups are used + if target_groups is not None: + # Calculate mean loss per group + group_losses = [] + for group_id in torch.unique(target_groups): + group_mask = target_groups == group_id + group_mean_err = ( + rel_errors[group_mask].mean().item() + ) + group_losses.append(group_mean_err) + mean_group_loss = np.mean(group_losses) + else: + mean_group_loss = mean_rel_err + + # Calculate sparsity percentage + sparsity_pct = 100 * ( + 1 - active_info["count"] / self.n_features ) + # Calculate components of the actual loss being minimized + actual_data_loss = data_loss.item() + actual_l0_loss = l0_loss.item() + actual_total_loss = loss.item() + + if target_groups is not None: + print( + f"Epoch {epoch+1:4d}: " + f"mean_group_loss={mean_group_loss:.1%}, " + f"max_error={max_rel_err:.1%}, " + f"total_loss={actual_total_loss:.3f}, " + f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)" + ) + else: + print( + f"Epoch {epoch+1:4d}: " + 
f"mean_error={mean_rel_err:.1%}, " + f"max_error={max_rel_err:.1%}, " + f"total_loss={actual_total_loss:.3f}, " + f"active={active_info['count']:4d}/{self.n_features} ({sparsity_pct:.1f}% sparse)" + ) + return self def predict(self, M: sp.spmatrix) -> torch.Tensor: diff --git a/tests/test_calibration.py b/tests/test_calibration.py index 3b49d3c..f0ee9c3 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -36,6 +36,7 @@ def test_sparse_ground_truth_relative_loss(self): N_active = 1000 # 50% sparsity np.random.seed(42) + torch.manual_seed(42) # Generate data with sparse ground truth M_dense = np.random.lognormal(mean=1.5, sigma=0.25, size=(Q, N)) @@ -62,7 +63,7 @@ def test_sparse_ground_truth_relative_loss(self): model.fit( M=M, y=y, - lambda_l0=0.00015, # Tuned for ~50% sparsity with relative loss + lambda_l0=0.0005, # Tuned for ~50% sparsity with relative loss lambda_l2=1e-6, lr=0.2, epochs=2000, @@ -88,6 +89,7 @@ def test_relative_vs_mse_loss(self): N = 500 np.random.seed(123) + torch.manual_seed(123) # Large-scale data M = sp.random(Q, N, density=0.5, format="csr") @@ -136,6 +138,9 @@ def test_sparsity_control(self): Q = 50 N = 200 + np.random.seed(123) + torch.manual_seed(123) + M = sp.random(Q, N, density=0.3, format="csr") y = np.random.randn(Q) + 10 @@ -149,7 +154,7 @@ def test_sparsity_control(self): y, lambda_l0=lambda_l0, lr=0.1, - epochs=500, + epochs=2000, loss_type="relative", verbose=False, ) @@ -174,7 +179,7 @@ def test_get_active_weights(self): model.fit(M, y, lambda_l0=0.01, epochs=100, verbose=False) - active_info = model.get_active_weights(threshold=0.01) + active_info = model.get_active_weights() assert "indices" in active_info assert "values" in active_info @@ -191,6 +196,9 @@ def test_deterministic_inference(self): N = 50 Q = 10 + np.random.seed(123) + torch.manual_seed(123) + M = sp.random(Q, N, density=0.5, format="csr") y = np.random.randn(Q) @@ -211,6 +219,9 @@ def test_l2_regularization(self): N = 100 Q = 20 + 
np.random.seed(123) + torch.manual_seed(123) + + M = sp.random(Q, N, density=0.3, format="csr") y = np.random.randn(Q) * 100 # Large scale @@ -234,3 +245,174 @@ def test_l2_regularization(self): assert ( weights_with_l2.max() <= weights_no_l2.max() * 2.0 ), "L2 should prevent extreme weights"
+
+    def test_group_wise_averaging(self):
+        """Test that group-wise averaging balances loss contributions.
+
+        Fits the same data with and without ``target_groups`` and compares
+        the mean relative errors of the singleton targets and of the two
+        18-target groups.
+        """
+        N = 100  # features (households)
+
+        # Create targets with different cardinalities:
+        # - 3 singleton targets (like national targets)
+        # - 18 targets in one group (like age bins for one state)
+        # - 18 targets in another group (like age bins for another state)
+        Q = 3 + 18 + 18  # 39 total targets
+
+        np.random.seed(42)
+        torch.manual_seed(42)
+
+        # Create matrix with varying scales
+        M = sp.random(Q, N, density=0.3, format="csr")
+
+        # Create target values with different scales
+        # Singletons: large values (billions scale)
+        y_singletons = np.array([1e9, 5e8, 2e9])
+        # Groups: smaller values (thousands scale)
+        y_group1 = np.random.uniform(1e3, 1e6, size=18)
+        y_group2 = np.random.uniform(1e3, 1e6, size=18)
+        y = np.concatenate([y_singletons, y_group1, y_group2])
+
+        # Create target groups
+        # Groups 0, 1, 2: singletons (each national target)
+        # Group 3: all 18 targets from first age group
+        # Group 4: all 18 targets from second age group
+        target_groups = np.array(
+            [0, 1, 2]  # 3 singletons
+            + [3] * 18  # Group 3
+            + [4] * 18  # Group 4
+        )
+
+        # Train WITHOUT grouping (baseline)
+        model_no_groups = SparseCalibrationWeights(n_features=N)
+        model_no_groups.fit(
+            M,
+            y,
+            lambda_l0=0.0001,
+            lr=0.1,
+            epochs=500,
+            loss_type="relative",
+            verbose=False,
+            target_groups=None,  # No grouping
+        )
+
+        # Train WITH grouping
+        model_with_groups = SparseCalibrationWeights(n_features=N)
+        model_with_groups.fit(
+            M,
+            y,
+            lambda_l0=0.0001,
+            lr=0.1,
+            epochs=500,
+            loss_type="relative",
+            verbose=False,
+            target_groups=target_groups,
+        )
+
+        # Compute errors by group
+        with torch.no_grad():
+            y_pred_no_groups = model_no_groups.predict(M).cpu().numpy()
+            y_pred_with_groups = model_with_groups.predict(M).cpu().numpy()
+
+        # Relative errors
+        rel_err_no_groups = np.abs((y - y_pred_no_groups) / (y + 1))
+        rel_err_with_groups = np.abs((y - y_pred_with_groups) / (y + 1))
+
+        # Select rows through target_groups (not hard-coded slices) so the
+        # scored rows cannot drift from the group layout defined above.
+        is_singleton = target_groups <= 2
+        in_group3 = target_groups == 3
+        in_group4 = target_groups == 4
+
+        # Average errors by group
+        singleton_err_no_groups = rel_err_no_groups[is_singleton].mean()
+        singleton_err_with_groups = rel_err_with_groups[is_singleton].mean()
+        group3_err_with_groups = rel_err_with_groups[in_group3].mean()
+        group4_err_with_groups = rel_err_with_groups[in_group4].mean()
+
+        # Grouping must not make the singleton targets substantially worse:
+        # their mean relative error has to stay below 1.5x the ungrouped
+        # baseline (this is a tolerance check, not a strict improvement).
+        assert singleton_err_with_groups < singleton_err_no_groups * 1.5, (
+            f"Grouping degraded singleton accuracy by 1.5x or more: "
+            f"{singleton_err_with_groups:.4f} vs {singleton_err_no_groups:.4f}"
+        )
+
+        # All groups should have relatively balanced errors with grouping
+        all_group_errors = [
+            singleton_err_with_groups,
+            group3_err_with_groups,
+            group4_err_with_groups,
+        ]
+        max_err = max(all_group_errors)
+        min_err = min(all_group_errors)
+
+        # Errors should be within an order of magnitude of each other
+        assert max_err < min_err * 10, (
+            f"Group errors should be balanced: "
+            f"min={min_err:.4f}, max={max_err:.4f}"
+        )
+
+    def test_group_wise_averaging_edge_cases(self):
+        """Test edge cases for group-wise averaging.
+
+        Each grouping layout is fitted on a fresh model with identical
+        hyperparameters; the mean relative error over the scored targets
+        must stay below 0.5.
+        """
+        N = 50
+        Q = 10
+
+        np.random.seed(42)
+        torch.manual_seed(42)
+
+        M = sp.random(Q, N, density=0.3, format="csr")
+        y = np.random.uniform(100, 1000, size=Q)
+
+        # Each case: (target_groups, targets to score, failure message)
+        cases = [
+            # All targets in one group: every target gets weight 1/Q, so
+            # the data loss is the mean squared relative error (Q times
+            # smaller than the unweighted sum used when target_groups=None).
+            (
+                np.zeros(Q, dtype=int),
+                slice(None),
+                "Single group should still converge",
+            ),
+            # Each target in its own group (like all singletons)
+            (
+                np.arange(Q),
+                slice(None),
+                "All singleton groups should converge",
+            ),
+            # Unbalanced groups (one group of 7 plus three singletons); only
+            # the singletons are scored so small groups can't be ignored.
+            (
+                np.array([0] * 7 + [1, 2, 3]),
+                slice(7, None),
+                "Small groups should not be ignored",
+            ),
+        ]
+
+        for target_groups, checked, message in cases:
+            model = SparseCalibrationWeights(n_features=N)
+            model.fit(
+                M,
+                y,
+                lambda_l0=0.00001,  # Lower penalty for better convergence
+                epochs=2000,  # Plenty of epochs
+                lr=0.2,  # Higher learning rate
+                loss_type="relative",
+                verbose=False,
+                target_groups=target_groups,
+            )
+
+            with torch.no_grad():
+                y_pred = model.predict(M).cpu().numpy()
+
+            rel_err = np.mean(
+                np.abs((y[checked] - y_pred[checked]) / (y[checked] + 1))
+            )
+            assert rel_err < 0.5, f"{message}, got {rel_err:.4f}"